diff --git a/evaluate_age.py b/evaluate_age.py deleted file mode 100644 index 2de473a..0000000 --- a/evaluate_age.py +++ /dev/null @@ -1,507 +0,0 @@ -from __future__ import annotations - -import argparse -import json -import math -import os -import multiprocessing as mp -from typing import List, Sequence, Tuple - -import pandas as pd -import torch -from torch.utils.data import DataLoader, random_split - -from dataset import HealthDataset, health_collate_fn -from evaluation_age_time_dependent import ( - EvalAgeConfig, - aggregate_age_bin_results, - evaluate_time_dependent_age_bins, -) -from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss -from model import DelphiFork, SapDelphi, SimpleHead - - -def _parse_floats(items: Sequence[str]) -> List[float]: - out: List[float] = [] - for x in items: - x = x.strip() - if not x: - continue - out.append(float(x)) - return out - - -def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]: - vals = _parse_floats(edges) - if len(vals) < 2: - raise ValueError("--age_bin_edges must have at least 2 values") - for i in range(1, len(vals)): - if not (vals[i] > vals[i - 1]): - raise ValueError("--age_bin_edges must be strictly increasing") - return vals - - -def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]: - return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)] - - -def _parse_gpus(gpus: str | None) -> List[int]: - if gpus is None: - return [] - s = gpus.strip() - if not s: - return [] - parts = [p.strip() for p in s.split(",") if p.strip()] - out: List[int] = [] - for p in parts: - out.append(int(p)) - return out - - -def _worker_eval_mcs_on_gpu( - queue: "mp.Queue", - *, - run_dir: str, - split: str, - data_prefix_override: str | None, - horizons: List[float], - age_bins: List[Tuple[float, float]], - topk_percents: List[float], - n_mc: int, - seed: int, - batch_size: int, - num_workers: int, - gpu_id: int, - mc_indices: List[int], - out_path: str, -) -> None: - """Worker process: evaluate a subset of MC indices on a single GPU.""" - try: - ckpt_path = os.path.join(run_dir, "best_model.pt") - cfg_path = os.path.join(run_dir, "train_config.json") - with open(cfg_path, "r") as f: - cfg = json.load(f) - - data_prefix = ( - data_prefix_override - if data_prefix_override is not None - else cfg.get("data_prefix", "ukb") - ) - - full_cov = bool(cfg.get("full_cov", False)) - cov_list = None if full_cov else ["bmi", "smoking", "alcohol"] - dataset = HealthDataset(data_prefix=data_prefix, - covariate_list=cov_list) - - train_ratio = float(cfg.get("train_ratio", 0.7)) - val_ratio = float(cfg.get("val_ratio", 0.15)) - seed_split = int(cfg.get("random_seed", 42)) - - n_total = len(dataset) - n_train = int(n_total * train_ratio) - n_val = int(n_total * val_ratio) - n_test = n_total - n_train - n_val - - train_ds, val_ds, test_ds = random_split( - dataset, - [n_train, n_val, n_test], - generator=torch.Generator().manual_seed(seed_split), - ) - - if split == "train": - ds = train_ds - elif split == "val": - ds = val_ds - elif split == "test": - ds = test_ds - else: - ds = dataset - - loader = DataLoader( - ds, - batch_size=int(batch_size), - shuffle=False, - collate_fn=health_collate_fn, - num_workers=int(num_workers), - pin_memory=True, - ) - - criterion, out_dims = build_criterion_and_out_dims( - loss_type=str(cfg["loss_type"]), - n_disease=int(dataset.n_disease), - bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]), - lambda_reg=float(cfg.get("lambda_reg", 0.0)), - ) - - 
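        # Hedged illustration of the head sizing implied by build_criterion_and_out_dims
        # (defined later in this file); the concrete numbers are assumed example values,
        # not taken from a real config. With n_disease = 10 and bin_edges = [0.0, 1.0, 5.0, inf]:
        #   "exponential"       -> out_dims == [10]      # one rate per disease
        #   "discrete_time_cif" -> out_dims == [11, 4]   # n_disease + 1 causes, len(bin_edges) bins
        #   "pwe_cif"           -> out_dims == [10, 2]   # n_disease causes, len(finite_edges) - 1 pieces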
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg) - head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims) - - device = torch.device(f"cuda:{int(gpu_id)}") - checkpoint = torch.load(ckpt_path, map_location=device) - - model.load_state_dict(checkpoint["model_state_dict"], strict=True) - head.load_state_dict(checkpoint["head_state_dict"], strict=True) - if "criterion_state_dict" in checkpoint: - try: - criterion.load_state_dict( - checkpoint["criterion_state_dict"], strict=False) - except Exception: - pass - - model.to(device) - head.to(device) - criterion.to(device) - - frames: List[pd.DataFrame] = [] - for mc_idx in mc_indices: - eval_cfg = EvalAgeConfig( - horizons_years=horizons, - age_bins=age_bins, - topk_percents=topk_percents, - n_mc=1, - seed=int(seed), - cause_ids=None, - ) - - df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins( - model=model, - head=head, - criterion=criterion, - dataloader=loader, - n_disease=int(dataset.n_disease), - cfg=eval_cfg, - device=device, - mc_offset=int(mc_idx), - ) - frames.append(df_by_bin) - - df_all = pd.concat(frames, ignore_index=True) if len( - frames) else pd.DataFrame() - df_all = _drop_zero_positives_rows(df_all, "n_positives") - df_all.to_csv(out_path, index=False) - queue.put({"ok": True, "out_path": out_path}) - except Exception as e: - queue.put({"ok": False, "error": repr(e)}) - - -def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float): - if loss_type == "exponential": - criterion = ExponentialNLLLoss(lambda_reg=lambda_reg) - out_dims = [n_disease] - return criterion, out_dims - - if loss_type == "discrete_time_cif": - criterion = DiscreteTimeCIFNLLLoss( - bin_edges=bin_edges, lambda_reg=lambda_reg) - out_dims = [n_disease + 1, len(bin_edges)] - return criterion, out_dims - - if loss_type == "pwe_cif": - pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))] - if len(pwe_edges) < 2: - raise ValueError( - "pwe_cif requires at least 2 finite bin edges (including 0)") - if float(pwe_edges[0]) != 0.0: - raise ValueError("pwe_cif requires bin_edges[0]==0.0") - criterion = PiecewiseExponentialCIFNLLLoss( - bin_edges=pwe_edges, lambda_reg=lambda_reg) - n_bins = len(pwe_edges) - 1 - out_dims = [n_disease, n_bins] - return criterion, out_dims - - raise ValueError(f"Unsupported loss_type: {loss_type}") - - -def _drop_zero_positives_rows(df: pd.DataFrame, positive_col: str) -> pd.DataFrame: - """Drop rows where the provided positives column is <= 0. - - Intended to reduce CSV size by omitting (cause, horizon, bin) rows that have - no positives, which otherwise yield undefined/NaN metrics. 
- """ - if df is None or len(df) == 0: - return df - if positive_col not in df.columns: - return df - pos = pd.to_numeric(df[positive_col], errors="coerce") - return df[pos > 0].copy() - - -def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict): - if model_type == "delphi_fork": - return DelphiFork( - n_disease=dataset.n_disease, - n_tech_tokens=2, - n_embd=int(cfg["n_embd"]), - n_head=int(cfg["n_head"]), - n_layer=int(cfg["n_layer"]), - pdrop=float(cfg.get("pdrop", 0.0)), - age_encoder_type=str(cfg["age_encoder"]), - n_cont=int(dataset.n_cont), - n_cate=int(dataset.n_cate), - cate_dims=list(dataset.cate_dims), - ) - - if model_type == "sap_delphi": - return SapDelphi( - n_disease=dataset.n_disease, - n_tech_tokens=2, - n_embd=int(cfg["n_embd"]), - n_head=int(cfg["n_head"]), - n_layer=int(cfg["n_layer"]), - pdrop=float(cfg.get("pdrop", 0.0)), - age_encoder_type=str(cfg["age_encoder"]), - n_cont=int(dataset.n_cont), - n_cate=int(dataset.n_cate), - cate_dims=list(dataset.cate_dims), - pretrained_weights_path=str( - cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")), - freeze_embeddings=bool(cfg.get("freeze_embeddings", True)), - ) - - raise ValueError(f"Unsupported model_type: {model_type}") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and aggregated CSVs; aggregated includes agg_type={macro,weighted})") - parser.add_argument( - "--run_dir", - type=str, - required=True, - help="Training run directory (contains best_model.pt and train_config.json)", - ) - parser.add_argument("--data_prefix", type=str, default=None) - parser.add_argument("--split", type=str, - choices=["train", "val", "test", "all"], default="val") - - parser.add_argument("--horizons", type=str, nargs="+", - default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"]) - parser.add_argument( - "--age_bin_edges", - type=str, - nargs="+", - default=["40", "45", "50", "55", "60", "65", "70", "75", "80"], - help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).", - ) - parser.add_argument( - "--topk_percent", - type=float, - nargs="+", - default=[1, 5, 10, 20, 50], - help="One or more K%% values for recall/precision@K%%", - ) - parser.add_argument("--n_mc", type=int, default=5) - parser.add_argument("--seed", type=int, default=0) - - parser.add_argument( - "--gpus", - type=str, - default=None, - help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). 
Example: --gpus 0,1,3", - ) - - parser.add_argument("--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") - parser.add_argument("--batch_size", type=int, default=256) - parser.add_argument("--num_workers", type=int, default=0) - - parser.add_argument("--out_prefix", type=str, - default=None, help="Output prefix for CSVs") - - args = parser.parse_args() - - ckpt_path = os.path.join(args.run_dir, "best_model.pt") - cfg_path = os.path.join(args.run_dir, "train_config.json") - if not os.path.exists(ckpt_path): - raise SystemExit(f"Missing checkpoint: {ckpt_path}") - if not os.path.exists(cfg_path): - raise SystemExit(f"Missing config: {cfg_path}") - - with open(cfg_path, "r") as f: - cfg = json.load(f) - - data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get( - "data_prefix", "ukb") - - full_cov = bool(cfg.get("full_cov", False)) - cov_list = None if full_cov else ["bmi", "smoking", "alcohol"] - dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list) - - train_ratio = float(cfg.get("train_ratio", 0.7)) - val_ratio = float(cfg.get("val_ratio", 0.15)) - seed_split = int(cfg.get("random_seed", 42)) - - n_total = len(dataset) - n_train = int(n_total * train_ratio) - n_val = int(n_total * val_ratio) - n_test = n_total - n_train - n_val - - train_ds, val_ds, test_ds = random_split( - dataset, - [n_train, n_val, n_test], - generator=torch.Generator().manual_seed(seed_split), - ) - - if args.split == "train": - ds = train_ds - elif args.split == "val": - ds = val_ds - elif args.split == "test": - ds = test_ds - else: - ds = dataset - - loader = DataLoader( - ds, - batch_size=int(args.batch_size), - shuffle=False, - collate_fn=health_collate_fn, - num_workers=int(args.num_workers), - pin_memory=str(args.device).startswith("cuda"), - ) - - criterion, out_dims = build_criterion_and_out_dims( - loss_type=str(cfg["loss_type"]), - n_disease=int(dataset.n_disease), - bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]), - lambda_reg=float(cfg.get("lambda_reg", 0.0)), - ) - - model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg) - head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims) - - device = torch.device(args.device) - checkpoint = torch.load(ckpt_path, map_location=device) - - model.load_state_dict(checkpoint["model_state_dict"], strict=True) - head.load_state_dict(checkpoint["head_state_dict"], strict=True) - if "criterion_state_dict" in checkpoint: - try: - criterion.load_state_dict( - checkpoint["criterion_state_dict"], strict=False) - except Exception: - pass - - model.to(device) - head.to(device) - criterion.to(device) - - age_edges = _parse_age_bin_edges(args.age_bin_edges) - age_bins = _edges_to_bins(age_edges) - - eval_cfg = EvalAgeConfig( - horizons_years=_parse_floats(args.horizons), - age_bins=age_bins, - topk_percents=[float(x) for x in args.topk_percent], - n_mc=int(args.n_mc), - seed=int(args.seed), - cause_ids=None, - ) - - if args.out_prefix is None: - out_prefix = os.path.join( - args.run_dir, f"age_bin_time_dependent_{args.split}") - else: - out_prefix = args.out_prefix - - out_bin = out_prefix + "_by_bin.csv" - out_agg = out_prefix + "_agg.csv" - - gpus = _parse_gpus(args.gpus) - if len(gpus) <= 1: - df_by_bin, df_agg = evaluate_time_dependent_age_bins( - model=model, - head=head, - criterion=criterion, - dataloader=loader, - n_disease=int(dataset.n_disease), - cfg=eval_cfg, - device=device, - ) - - df_by_bin_csv = _drop_zero_positives_rows(df_by_bin, "n_positives") - df_agg_csv = 
_drop_zero_positives_rows(df_agg, "n_positives_total_mean") - df_by_bin_csv.to_csv(out_bin, index=False) - df_agg_csv.to_csv(out_agg, index=False) - print(f"Wrote: {out_bin}") - print(f"Wrote: {out_agg}") - return - - if not torch.cuda.is_available(): - raise SystemExit("--gpus was provided but CUDA is not available") - - # Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU). - mc_indices_all = list(range(int(args.n_mc))) - per_gpu: List[Tuple[int, List[int]]] = [] - for pos, gpu_id in enumerate(gpus): - assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos] - if assigned: - per_gpu.append((int(gpu_id), assigned)) - - ctx = mp.get_context("spawn") - queue: "mp.Queue" = ctx.Queue() - procs: List[mp.Process] = [] - tmp_paths: List[str] = [] - - for gpu_id, mc_idxs in per_gpu: - tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv" - tmp_paths.append(tmp_path) - p = ctx.Process( - target=_worker_eval_mcs_on_gpu, - kwargs=dict( - queue=queue, - run_dir=str(args.run_dir), - split=str(args.split), - data_prefix_override=( - str(args.data_prefix) if args.data_prefix is not None else None - ), - horizons=_parse_floats(args.horizons), - age_bins=age_bins, - topk_percents=[float(x) for x in args.topk_percent], - n_mc=int(args.n_mc), - seed=int(args.seed), - batch_size=int(args.batch_size), - num_workers=int(args.num_workers), - gpu_id=int(gpu_id), - mc_indices=mc_idxs, - out_path=tmp_path, - ), - ) - p.start() - procs.append(p) - - results = [queue.get() for _ in range(len(procs))] - for p in procs: - p.join() - - for r in results: - if not r.get("ok", False): - raise SystemExit(f"Worker failed: {r.get('error')}") - - frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)] - df_by_bin = pd.concat(frames, ignore_index=True) if len( - frames) else pd.DataFrame() - - # Ensure we don't keep zero-positive rows even if a temp file was produced - # by an older version of the worker. - df_by_bin = _drop_zero_positives_rows(df_by_bin, "n_positives") - df_agg = aggregate_age_bin_results(df_by_bin) - - df_agg = _drop_zero_positives_rows(df_agg, "n_positives_total_mean") - df_by_bin.to_csv(out_bin, index=False) - df_agg.to_csv(out_agg, index=False) - - # Best-effort cleanup. 
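    # (Hedged illustration with assumed flags: given --gpus 0,1,3 and --n_mc 5, the
    # round-robin split above assigns mc [0, 3] to gpu 0, [1, 4] to gpu 1 and [2] to
    # gpu 3, so the merged frame written above covers every MC repetition exactly once
    # and the per-GPU temp CSVs are now redundant.)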
- for p in tmp_paths: - try: - if os.path.exists(p): - os.remove(p) - except Exception: - pass - - print(f"Wrote: {out_bin}") - print(f"Wrote: {out_agg}") - - -if __name__ == "__main__": - main() diff --git a/evaluation_age_time_dependent.py b/evaluation_age_time_dependent.py deleted file mode 100644 index fbc26c2..0000000 --- a/evaluation_age_time_dependent.py +++ /dev/null @@ -1,852 +0,0 @@ -from __future__ import annotations - -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import numpy as np -import pandas as pd -import torch - -try: - from tqdm import tqdm -except Exception: # pragma: no cover - - def tqdm(x, **kwargs): - return x -from utils import ( - multi_hot_ever_within_horizon, - multi_hot_selected_causes_within_horizon, - sample_context_in_fixed_age_bin, -) - -from torch_metrics import compute_binary_metrics_torch - - -def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray: - with np.errstate(invalid="ignore"): - return np.nanmean(x, axis=axis) - - -def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray: - """NaN-aware sample std (ddof=1), matching pandas std() semantics.""" - x = np.asarray(x, dtype=float) - mask = np.isfinite(x) - cnt = mask.sum(axis=axis) - # mean over finite entries - x0 = np.where(mask, x, 0.0) - mean = x0.sum(axis=axis) / np.maximum(cnt, 1) - # sum of squared deviations over finite entries - dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0) - ss = dev2.sum(axis=axis) - denom = cnt - 1 - out = np.sqrt(ss / np.maximum(denom, 1)) - out = np.where(denom > 0, out, np.nan) - return out - - -def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray: - """NaN-aware weighted mean. - - Only bins with finite x contribute to both numerator and denominator. - If denom==0 -> NaN. - """ - x = np.asarray(x, dtype=float) - w = np.asarray(w, dtype=float) - - if axis != 0: - raise ValueError("_weighted_mean_np currently supports axis=0 only") - - # Broadcast weights along trailing dims of x. - while w.ndim < x.ndim: - w = w[..., None] - w = np.broadcast_to(w, x.shape) - - mask = np.isfinite(x) - num = np.where(mask, x * w, 0.0).sum(axis=0) - denom = np.where(mask, w, 0.0).sum(axis=0) - return np.where(denom > 0.0, num / denom, np.nan) - - -def _blocks_to_df_by_bin( - blocks: List[Dict[str, Any]], - *, - topk_percents: np.ndarray, -) -> pd.DataFrame: - """Convert per-block column vectors into the long-format per-bin DataFrame. - - This does a single vectorized reshape per block (cause-major ordering), and - concatenates columns once at the end. 
- """ - if len(blocks) == 0: - return pd.DataFrame( - columns=[ - "mc_idx", - "age_bin_id", - "age_bin_low", - "age_bin_high", - "horizon_tau", - "topk_percent", - "cause_id", - "n_samples", - "n_positives", - "auc", - "auprc", - "recall_at_K", - "precision_at_K", - "brier_score", - ] - ) - - P = int(topk_percents.size) - - cols: Dict[str, List[np.ndarray]] = { - "mc_idx": [], - "age_bin_id": [], - "age_bin_low": [], - "age_bin_high": [], - "horizon_tau": [], - "topk_percent": [], - "cause_id": [], - "n_samples": [], - "n_positives": [], - "auc": [], - "auprc": [], - "recall_at_K": [], - "precision_at_K": [], - "brier_score": [], - } - - for blk in blocks: - cause_id = np.asarray(blk["cause_id"], dtype=int) - K = int(cause_id.size) - n_rows = K * P - - cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int)) - cols["age_bin_id"].append( - np.full(n_rows, int(blk["age_bin_id"]), dtype=int)) - cols["age_bin_low"].append( - np.full(n_rows, float(blk["age_bin_low"]), dtype=float)) - cols["age_bin_high"].append( - np.full(n_rows, float(blk["age_bin_high"]), dtype=float)) - cols["horizon_tau"].append( - np.full(n_rows, float(blk["horizon_tau"]), dtype=float)) - - cols["cause_id"].append(np.repeat(cause_id, P)) - cols["topk_percent"].append(np.tile(topk_percents.astype(float), K)) - cols["n_samples"].append( - np.full(n_rows, int(blk["n_samples"]), dtype=int)) - - n_pos = np.asarray(blk["n_positives"], dtype=int) - cols["n_positives"].append(np.repeat(n_pos, P)) - - auc = np.asarray(blk["auc"], dtype=float) - auprc = np.asarray(blk["auprc"], dtype=float) - brier = np.asarray(blk["brier_score"], dtype=float) - cols["auc"].append(np.repeat(auc, P)) - cols["auprc"].append(np.repeat(auprc, P)) - cols["brier_score"].append(np.repeat(brier, P)) - - # precision/recall are stored as (P,K); we want cause-major rows, i.e. - # (K,P) then flatten. - prec = np.asarray(blk["precision_at_K"], dtype=float) - rec = np.asarray(blk["recall_at_K"], dtype=float) - if prec.shape != (P, K) or rec.shape != (P, K): - raise ValueError( - f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}" - ) - cols["precision_at_K"].append(prec.T.reshape(-1)) - cols["recall_at_K"].append(rec.T.reshape(-1)) - - out = {k: np.concatenate(v, axis=0) for k, v in cols.items()} - return pd.DataFrame(out) - - -def aggregate_metrics_columnar( - blocks: List[Dict[str, Any]], - *, - topk_percents: np.ndarray, - cause_id: np.ndarray, -) -> pd.DataFrame: - """Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std). - - This is a vectorized, columnar replacement for the old pandas groupby/apply. 
- Semantics match the previous implementation: - - bins with n_samples==0 are excluded from bin-aggregation - - macro: unweighted mean over bins (NaN-aware) - - weighted: weighted mean over bins using weights=n_samples (NaN-aware) - - across MC: mean/std (ddof=1), NaN-aware - """ - if len(blocks) == 0: - return pd.DataFrame( - columns=[ - "agg_type", - "horizon_tau", - "topk_percent", - "cause_id", - "n_mc", - "n_bins_used_mean", - "n_samples_total_mean", - "n_positives_total_mean", - "auc_mean", - "auc_std", - "auprc_mean", - "auprc_std", - "recall_at_K_mean", - "recall_at_K_std", - "precision_at_K_mean", - "precision_at_K_std", - "brier_score_mean", - "brier_score_std", - ] - ) - - P = int(topk_percents.size) - cause_id = np.asarray(cause_id, dtype=int) - K = int(cause_id.size) - - # Group blocks by (mc_idx, horizon_tau) - keys: List[Tuple[int, float]] = [] - grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {} - for blk in blocks: - key = (int(blk["mc_idx"]), float(blk["horizon_tau"])) - if key not in grouped: - grouped[key] = [] - keys.append(key) - grouped[key].append(blk) - - mc_vals = sorted({k[0] for k in keys}) - tau_vals = sorted({k[1] for k in keys}) - M = len(mc_vals) - T = len(tau_vals) - - mc_index = {mc: i for i, mc in enumerate(mc_vals)} - tau_index = {tau: i for i, tau in enumerate(tau_vals)} - - # Per (agg_type, mc, tau): store arrays - # metrics: (M,T,K) and (M,T,P,K) - auc_macro = np.full((M, T, K), np.nan, dtype=float) - auc_weighted = np.full((M, T, K), np.nan, dtype=float) - ap_macro = np.full((M, T, K), np.nan, dtype=float) - ap_weighted = np.full((M, T, K), np.nan, dtype=float) - brier_macro = np.full((M, T, K), np.nan, dtype=float) - brier_weighted = np.full((M, T, K), np.nan, dtype=float) - - prec_macro = np.full((M, T, P, K), np.nan, dtype=float) - prec_weighted = np.full((M, T, P, K), np.nan, dtype=float) - rec_macro = np.full((M, T, P, K), np.nan, dtype=float) - rec_weighted = np.full((M, T, P, K), np.nan, dtype=float) - - n_bins_used = np.zeros((M, T), dtype=float) - n_samples_total = np.zeros((M, T), dtype=float) - n_pos_total = np.zeros((M, T, K), dtype=float) - - for (mc, tau), blks in grouped.items(): - mi = mc_index[mc] - ti = tau_index[tau] - - # keep only bins with n_samples>0 - blks_nz = [b for b in blks if int(b["n_samples"]) > 0] - if len(blks_nz) == 0: - n_bins_used[mi, ti] = 0.0 - n_samples_total[mi, ti] = 0.0 - n_pos_total[mi, ti, :] = 0.0 - continue - - w = np.asarray([int(b["n_samples"]) - for b in blks_nz], dtype=float) # (B,) - n_bins_used[mi, ti] = float(len(w)) - n_samples_total[mi, ti] = float(w.sum()) - - npos = np.stack([np.asarray(b["n_positives"], dtype=float) - for b in blks_nz], axis=0) # (B,K) - n_pos_total[mi, ti, :] = npos.sum(axis=0) - - auc_b = np.stack([np.asarray(b["auc"], dtype=float) - for b in blks_nz], axis=0) # (B,K) - ap_b = np.stack([np.asarray(b["auprc"], dtype=float) - for b in blks_nz], axis=0) - brier_b = np.stack([np.asarray(b["brier_score"], dtype=float) - for b in blks_nz], axis=0) - - auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0) - ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0) - brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0) - - auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0) - ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0) - brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0) - - prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float) - for b in blks_nz], axis=0) # (B,P,K) - rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float) - for b 
in blks_nz], axis=0) - - # macro mean over bins - prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0) - rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0) - - # weighted mean over bins (weights along bin axis) - w3 = w.reshape(-1, 1, 1) - prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0) - rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0) - - # Across-MC aggregation (mean/std), then emit long-format df keyed by - # (agg_type, horizon_tau, topk_percent, cause_id) - rows: Dict[str, List[np.ndarray]] = { - "agg_type": [], - "horizon_tau": [], - "topk_percent": [], - "cause_id": [], - "n_mc": [], - "n_bins_used_mean": [], - "n_samples_total_mean": [], - "n_positives_total_mean": [], - "auc_mean": [], - "auc_std": [], - "auprc_mean": [], - "auprc_std": [], - "recall_at_K_mean": [], - "recall_at_K_std": [], - "precision_at_K_mean": [], - "precision_at_K_std": [], - "brier_score_mean": [], - "brier_score_std": [], - } - - cause_long = np.repeat(cause_id, P) - topk_long = np.tile(topk_percents.astype(float), K) - n_mc_val = float(M) - - for ti, tau in enumerate(tau_vals): - # scalar totals (repeat across causes/topk) - n_bins_mean = float( - np.mean(n_bins_used[:, ti])) if M > 0 else float("nan") - n_samp_mean = float( - np.mean(n_samples_total[:, ti])) if M > 0 else float("nan") - n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0) # (K,) - - for agg_type in ("macro", "weighted"): - if agg_type == "macro": - auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0) - auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0) - ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0) - ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0) - brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0) - brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0) - prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0) # (P,K) - prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0) - rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0) - rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0) - else: - auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0) - auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0) - ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0) - ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0) - brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0) - brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0) - prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0) - prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0) - rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0) - rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0) - - n_rows = K * P - rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object)) - rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float)) - rows["topk_percent"].append(topk_long) - rows["cause_id"].append(cause_long) - rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float)) - rows["n_bins_used_mean"].append( - np.full(n_rows, n_bins_mean, dtype=float)) - rows["n_samples_total_mean"].append( - np.full(n_rows, n_samp_mean, dtype=float)) - rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P)) - - rows["auc_mean"].append(np.repeat(auc_m, P)) - rows["auc_std"].append(np.repeat(auc_s, P)) - rows["auprc_mean"].append(np.repeat(ap_m, P)) - rows["auprc_std"].append(np.repeat(ap_s, P)) - rows["brier_score_mean"].append(np.repeat(brier_m, P)) - rows["brier_score_std"].append(np.repeat(brier_s, P)) - - rows["precision_at_K_mean"].append(prec_m.T.reshape(-1)) - 
rows["precision_at_K_std"].append(prec_s.T.reshape(-1)) - rows["recall_at_K_mean"].append(rec_m.T.reshape(-1)) - rows["recall_at_K_std"].append(rec_s.T.reshape(-1)) - - out = {k: np.concatenate(v, axis=0) for k, v in rows.items()} - df = pd.DataFrame(out) - return df.sort_values( - ["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True - ) - - -def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: - """Aggregate per-bin age evaluation results. - - Produces both: - - macro: unweighted mean over bins with n_samples>0 - - weighted: weighted mean over bins using weights=n_samples - - Then aggregates across MC repetitions (mean/std). - - Requires df_by_bin to include: - mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id, - n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score - - Returns: - DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id) - """ - if df_by_bin is None or len(df_by_bin) == 0: - return pd.DataFrame( - columns=[ - "agg_type", - "horizon_tau", - "topk_percent", - "cause_id", - "n_mc", - "n_bins_used_mean", - "n_samples_total_mean", - "n_positives_total_mean", - "auc_mean", - "auc_std", - "auprc_mean", - "auprc_std", - "recall_at_K_mean", - "recall_at_K_std", - "precision_at_K_mean", - "precision_at_K_std", - "brier_score_mean", - "brier_score_std", - ] - ) - - def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series: - g = group[group["n_samples"] > 0] - if len(g) == 0: - return pd.Series( - dict( - n_bins_used=0, - n_samples_total=0, - n_positives_total=0, - auc=float("nan"), - auprc=float("nan"), - recall_at_K=float("nan"), - precision_at_K=float("nan"), - brier_score=float("nan"), - ) - ) - - n_bins_used = int(g["age_bin_id"].nunique()) - n_samples_total = int(g["n_samples"].sum()) - n_positives_total = int(g["n_positives"].sum()) - - if not weighted: - return pd.Series( - dict( - n_bins_used=n_bins_used, - n_samples_total=n_samples_total, - n_positives_total=n_positives_total, - auc=float(g["auc"].mean()), - auprc=float(g["auprc"].mean()), - recall_at_K=float(g["recall_at_K"].mean()), - precision_at_K=float(g["precision_at_K"].mean()), - brier_score=float(g["brier_score"].mean()), - ) - ) - - w = g["n_samples"].to_numpy(dtype=float) - w_sum = float(w.sum()) - if w_sum <= 0.0: - return pd.Series( - dict( - n_bins_used=n_bins_used, - n_samples_total=n_samples_total, - n_positives_total=n_positives_total, - auc=float("nan"), - auprc=float("nan"), - recall_at_K=float("nan"), - precision_at_K=float("nan"), - brier_score=float("nan"), - ) - ) - - def _wavg(col: str) -> float: - return float(np.average(g[col].to_numpy(dtype=float), weights=w)) - - return pd.Series( - dict( - n_bins_used=n_bins_used, - n_samples_total=n_samples_total, - n_positives_total=n_positives_total, - auc=_wavg("auc"), - auprc=_wavg("auprc"), - recall_at_K=_wavg("recall_at_K"), - precision_at_K=_wavg("precision_at_K"), - brier_score=_wavg("brier_score"), - ) - ) - - # Kept for backward compatibility (e.g., if callers load a CSV and need to - # aggregate). Prefer `aggregate_metrics_columnar` during evaluation. 
- group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"] - - df = df_by_bin[df_by_bin["n_samples"] > 0].copy() - if len(df) == 0: - return pd.DataFrame( - columns=[ - "agg_type", - "horizon_tau", - "topk_percent", - "cause_id", - "n_mc", - "n_bins_used_mean", - "n_samples_total_mean", - "n_positives_total_mean", - "auc_mean", - "auc_std", - "auprc_mean", - "auprc_std", - "recall_at_K_mean", - "recall_at_K_std", - "precision_at_K_mean", - "precision_at_K_std", - "brier_score_mean", - "brier_score_std", - ] - ) - - # Macro: mean over bins - df_mc_macro = ( - df.groupby(group_keys, as_index=False) - .agg( - n_bins_used=("age_bin_id", "nunique"), - n_samples_total=("n_samples", "sum"), - n_positives_total=("n_positives", "sum"), - auc=("auc", "mean"), - auprc=("auprc", "mean"), - recall_at_K=("recall_at_K", "mean"), - precision_at_K=("precision_at_K", "mean"), - brier_score=("brier_score", "mean"), - ) - ) - df_mc_macro["agg_type"] = "macro" - - # Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric - w = df["n_samples"].astype(float) - df_w = df.copy() - for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]: - m = df_w[col].astype(float) - ww = w.where(m.notna(), other=0.0) - df_w[f"__num_{col}"] = (m.fillna(0.0) * w) - df_w[f"__den_{col}"] = ww - - df_mc_w = df_w.groupby(group_keys, as_index=False).agg( - n_bins_used=("age_bin_id", "nunique"), - n_samples_total=("n_samples", "sum"), - n_positives_total=("n_positives", "sum"), - **{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]}, - **{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]}, - ) - for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]: - num = df_mc_w[f"__num_{col}"].astype(float) - den = df_mc_w[f"__den_{col}"].astype(float) - df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan")) - df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True) - df_mc_w["agg_type"] = "weighted" - - df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True) - - df_agg = ( - df_mc_binagg.groupby( - ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False) - .agg( - n_mc=("mc_idx", "nunique"), - n_bins_used_mean=("n_bins_used", "mean"), - n_samples_total_mean=("n_samples_total", "mean"), - n_positives_total_mean=("n_positives_total", "mean"), - auc_mean=("auc", "mean"), - auc_std=("auc", "std"), - auprc_mean=("auprc", "mean"), - auprc_std=("auprc", "std"), - recall_at_K_mean=("recall_at_K", "mean"), - recall_at_K_std=("recall_at_K", "std"), - precision_at_K_mean=("precision_at_K", "mean"), - precision_at_K_std=("precision_at_K", "std"), - brier_score_mean=("brier_score", "mean"), - brier_score_std=("brier_score", "std"), - ) - .sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True) - ) - return df_agg - - -# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`. -# NumPy/Pandas are only used for final CSV formatting/aggregation. 
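# Usage sketch (hedged): EvalAgeConfig and evaluate_time_dependent_age_bins are
# defined below; model, head, criterion and loader are assumed to be built by the
# caller (as in evaluate_age.py above).
#
#   eval_cfg = EvalAgeConfig(
#       horizons_years=[1.0, 5.0],
#       age_bins=[(40.0, 50.0), (50.0, 60.0)],
#       topk_percents=[5.0, 10.0],
#       n_mc=3,
#       seed=0,
#   )
#   df_by_bin, df_agg = evaluate_time_dependent_age_bins(
#       model=model, head=head, criterion=criterion, dataloader=loader,
#       n_disease=n_disease, cfg=eval_cfg, device="cuda:0",
#   )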
- - -@dataclass -class EvalAgeConfig: - horizons_years: Sequence[float] - age_bins: Sequence[Tuple[float, float]] - topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0) - n_mc: int = 5 - seed: int = 0 - cause_ids: Optional[Sequence[int]] = None - store_per_cause: bool = True - - -@torch.inference_mode() -def evaluate_time_dependent_age_bins( - model: torch.nn.Module, - head: torch.nn.Module, - criterion, - dataloader: torch.utils.data.DataLoader, - n_disease: int, - cfg: EvalAgeConfig, - device: str | torch.device, - mc_offset: int = 0, -) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Delphi-2M-style age-bin evaluation with strict horizon alignment. - - Semantics (strict): for each (MC, horizon tau, age bin) we independently: - - build the eligible token set within that bin - - enforce follow-up coverage: t_ctx + tau <= t_end - - randomly sample exactly one token per individual within the bin (de-dup) - - recompute context representations and predictions for that (tau, bin) - - Returns: - df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id) - df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted} - """ - device = torch.device(device) - model.eval() - head.eval() - - horizons_years = [float(x) for x in cfg.horizons_years] - if len(horizons_years) == 0: - raise ValueError("cfg.horizons_years must be non-empty") - - age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins] - if len(age_bins) == 0: - raise ValueError("cfg.age_bins must be non-empty") - for (a, b) in age_bins: - if not (b > a): - raise ValueError( - f"age_bins must be (low, high) with high>low; got {(a, b)}") - - topk_percents = [float(x) for x in cfg.topk_percents] - if len(topk_percents) == 0: - raise ValueError("cfg.topk_percents must be non-empty") - if any((p <= 0.0 or p > 100.0) for p in topk_percents): - raise ValueError( - f"All topk_percents must be in (0,100]; got {topk_percents}") - - if int(cfg.n_mc) <= 0: - raise ValueError("cfg.n_mc must be >= 1") - - if cfg.cause_ids is None: - cause_ids = None - n_causes_eval = int(n_disease) - cause_id_vec = np.arange(n_causes_eval, dtype=int) - else: - cause_ids = torch.tensor( - list(cfg.cause_ids), dtype=torch.long, device=device) - n_causes_eval = int(cause_ids.numel()) - cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int) - - topk_percents_np = np.asarray(topk_percents, dtype=float) - - # Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends. - blocks: List[Dict[str, Any]] = [] - - for mc_idx in range(int(cfg.n_mc)): - global_mc_idx = int(mc_offset) + int(mc_idx) - - # Storage for this MC only: (tau, bin) -> list of GPU tensors. - # This keeps computations GPU-first while preventing a factor-n_mc - # blow-up in GPU memory. - y_true_mc: List[List[List[torch.Tensor]]] = [ - [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) - ] - y_pred_mc: List[List[List[torch.Tensor]]] = [ - [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) - ] - - # tqdm over batches; include MC idx in description. 
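        # Storage sketch (hedged): with, say, 2 horizons and 3 age bins, y_true_mc and
        # y_pred_mc form a 2 x 3 grid of lists; each pass of the batch loop below
        # appends one (n_kept, K) tensor per (tau, bin) cell, and each cell is
        # concatenated once after the batch loop.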
- for batch_idx, batch in enumerate( - tqdm(dataloader, - desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch") - ): - event_seq, time_seq, cont_feats, cate_feats, sexes = batch - event_seq = event_seq.to(device) - time_seq = time_seq.to(device) - cont_feats = cont_feats.to(device) - cate_feats = cate_feats.to(device) - sexes = sexes.to(device) - - B = int(event_seq.size(0)) - b = torch.arange(B, device=device) - - # Hoist backbone forward pass: inputs are identical across (tau, age_bin) - # within this batch, so this is safe and numerically identical. - h = model(event_seq, time_seq, sexes, - cont_feats, cate_feats) # (B,L,D) - - for tau_idx, tau_y in enumerate(horizons_years): - tau_tensor = torch.tensor(float(tau_y), device=device) - for bin_idx, (a_lo, a_hi) in enumerate(age_bins): - # Diversify RNG stream across MC/tau/bin/batch to reduce correlation. - seed = ( - int(cfg.seed) - + (100_000 * int(global_mc_idx)) - + (1_000 * int(tau_idx)) - + (10 * int(bin_idx)) - + int(batch_idx) - ) - - keep, t_ctx = sample_context_in_fixed_age_bin( - event_seq=event_seq, - time_seq=time_seq, - tau_years=float(tau_y), - age_bin=(float(a_lo), float(a_hi)), - seed=seed, - ) - if not keep.any(): - continue - - # Bin-specific prediction: context indices differ per (tau, bin) - # but the backbone features do not. - c = h[b, t_ctx] - logits = head(c) - - cifs = criterion.calculate_cifs( - logits, taus=tau_tensor - ) - if cifs.ndim != 2: - raise ValueError( - "criterion.calculate_cifs must return (B,K) for scalar tau; " - f"got shape={tuple(cifs.shape)}" - ) - - if cause_ids is None: - y = multi_hot_ever_within_horizon( - event_seq=event_seq, - time_seq=time_seq, - t_ctx=t_ctx, - tau_years=float(tau_y), - n_disease=n_disease, - ) - preds = cifs - else: - y = multi_hot_selected_causes_within_horizon( - event_seq=event_seq, - time_seq=time_seq, - t_ctx=t_ctx, - tau_years=float(tau_y), - cause_ids=cause_ids, - n_disease=n_disease, - ) - preds = cifs.index_select(dim=1, index=cause_ids) - - y_true_mc[tau_idx][bin_idx].append( - y[keep].detach().to(dtype=torch.bool) - ) - y_pred_mc[tau_idx][bin_idx].append( - preds[keep].detach().to(dtype=torch.float32) - ) - - # Aggregate this MC immediately (frees GPU memory before next MC). - for h_idx, tau_y in enumerate(horizons_years): - for bin_idx, (a_lo, a_hi) in enumerate(age_bins): - if len(y_true_mc[h_idx][bin_idx]) == 0: - # No samples in this bin for this (mc, tau): store a single - # block with NaN metric vectors. 
- K = int(n_causes_eval) - P = int(topk_percents_np.size) - blocks.append( - dict( - mc_idx=global_mc_idx, - age_bin_id=bin_idx, - age_bin_low=float(a_lo), - age_bin_high=float(a_hi), - horizon_tau=float(tau_y), - n_samples=0, - cause_id=cause_id_vec, - n_positives=np.zeros((K,), dtype=int), - auc=np.full((K,), np.nan, dtype=float), - auprc=np.full((K,), np.nan, dtype=float), - brier_score=np.full((K,), np.nan, dtype=float), - precision_at_K=np.full((P, K), np.nan, dtype=float), - recall_at_K=np.full((P, K), np.nan, dtype=float), - ) - ) - continue - - yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0) - pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0) - if tuple(yb_t.shape) != tuple(pb_t.shape): - raise ValueError( - f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}" - ) - - n_samples = int(yb_t.size(0)) - - metrics = compute_binary_metrics_torch( - y_true=yb_t, - y_pred=pb_t, - k_percents=topk_percents, - tie_mode="exact", - chunk_size=128, - compute_ici=False, - ) - - # Collect a single columnar block (vectors, not per-row dicts). - blocks.append( - dict( - mc_idx=global_mc_idx, - age_bin_id=bin_idx, - age_bin_low=float(a_lo), - age_bin_high=float(a_hi), - horizon_tau=float(tau_y), - n_samples=int(n_samples), - cause_id=cause_id_vec, - n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int), - auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float), - auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float), - brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float), - precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float), - recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float), - ) - ) - - # Aggregation is computed from columnar blocks (fast, no pandas apply). - df_agg = aggregate_metrics_columnar( - blocks, - topk_percents=topk_percents_np, - cause_id=cause_id_vec, - ) - - if bool(cfg.store_per_cause): - df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np) - else: - df_by_bin = pd.DataFrame( - columns=[ - "mc_idx", - "age_bin_id", - "age_bin_low", - "age_bin_high", - "horizon_tau", - "topk_percent", - "cause_id", - "n_samples", - "n_positives", - "auc", - "auprc", - "recall_at_K", - "precision_at_K", - "brier_score", - ] - ) - - return df_by_bin, df_agg diff --git a/utils.py b/utils.py deleted file mode 100644 index 3631574..0000000 --- a/utils.py +++ /dev/null @@ -1,207 +0,0 @@ -import torch -from typing import Tuple - -DAYS_PER_YEAR = 365.25 - - -def sample_context_in_fixed_age_bin( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - tau_years: float, - age_bin: Tuple[float, float], - seed: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Sample one context token per individual within a fixed age bin. - - Delphi-2M semantics for a specific (tau, age_bin): - - Token times are interpreted as age in *days* (converted to years). - - Follow-up end time is the last valid token time per individual. - - A token index j is eligible iff: - (token is valid) - AND (age_years in [age_low, age_high)) - AND (time_seq[i, j] + tau_days <= followup_end_time[i]) - - For each individual, randomly select exactly one eligible token in this bin. - - Args: - event_seq: (B, L) token ids, 0 is padding. - time_seq: (B, L) token times in days. - tau_years: horizon length in years. - age_bin: (low, high) bounds in years, interpreted as [low, high). - seed: RNG seed for deterministic sampling. 
- - Returns: - keep: (B,) bool, True if a context was sampled for this bin. - t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0). - """ - low, high = float(age_bin[0]), float(age_bin[1]) - if not (high > low): - raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}") - - device = event_seq.device - B, _ = event_seq.shape - - valid = event_seq != 0 - lengths = valid.sum(dim=1) - last_idx = torch.clamp(lengths - 1, min=0) - b = torch.arange(B, device=device) - followup_end_time = time_seq[b, last_idx] # (B,) - - tau_days = float(tau_years) * DAYS_PER_YEAR - age_years = time_seq / DAYS_PER_YEAR - - in_bin = (age_years >= low) & (age_years < high) - eligible = valid & in_bin & ( - (time_seq + tau_days) <= followup_end_time.unsqueeze(1)) - - # Vectorized, uniform sampling over eligible indices per sample. - # Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform - # choice among eligible indices by symmetry (ties have probability ~0). - keep = eligible.any(dim=1) - - # Prefer a per-call generator on the target device for reproducibility without - # touching global RNG state. If unavailable, fall back to seeding the global - # CUDA RNG for this call. - gen = None - if device.type == "cuda": - try: - gen = torch.Generator(device=device) - gen.manual_seed(int(seed)) - except Exception: - gen = None - torch.cuda.manual_seed(int(seed)) - else: - gen = torch.Generator() - gen.manual_seed(int(seed)) - - r = torch.rand((B, eligible.size(1)), device=device, generator=gen) - r = r.masked_fill(~eligible, -1.0) - t_ctx = r.argmax(dim=1).to(torch.long) - - # When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0). - return keep, t_ctx - - -def select_context_indices( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - offset_years: float, - tau_years: float = 0.0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Select per-sample prediction context index. - - IMPORTANT SEMANTICS: - - The last observed token time is treated as the FOLLOW-UP END time. - - We pick the last valid token with time <= (followup_end_time - offset). - - We do NOT interpret followup_end_time as an event time. - - Returns: - keep_mask: (B,) bool, which samples have a valid context - t_ctx: (B,) long, index into sequence - t_ctx_time: (B,) float, time (days) at context - """ - # valid tokens are event != 0 (padding is 0) - valid = event_seq != 0 - lengths = valid.sum(dim=1) - last_idx = torch.clamp(lengths - 1, min=0) - - b = torch.arange(event_seq.size(0), device=event_seq.device) - followup_end_time = time_seq[b, last_idx] - t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR) - - eligible = valid & (time_seq <= t_cut.unsqueeze(1)) - eligible_counts = eligible.sum(dim=1) - keep = eligible_counts > 0 - - t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long) - t_ctx_time = time_seq[b, t_ctx] - - # Horizon-aligned eligibility: require enough follow-up time after the selected context. - # All times are in days. 
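    # Worked example (hedged, values in years for readability; the code works in days):
    # followup_end_time = 70, offset_years = 1 -> t_cut = 69; if the last token at or
    # before t_cut is at 68.5, then tau_years = 5 drops the sample (68.5 + 5 > 70)
    # while tau_years = 1 keeps it (68.5 + 1 <= 70).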
- keep = keep & (followup_end_time >= ( - t_ctx_time + (tau_years * DAYS_PER_YEAR))) - - return keep, t_ctx, t_ctx_time - - -def multi_hot_ever_within_horizon( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - t_ctx: torch.Tensor, - tau_years: float, - n_disease: int, -) -> torch.Tensor: - """Binary labels: disease k occurs within tau after context (any occurrence).""" - B, L = event_seq.shape - b = torch.arange(B, device=event_seq.device) - t0 = time_seq[b, t_ctx] - t1 = t0 + (tau_years * DAYS_PER_YEAR) - - idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1) - # Include same-day events after context, exclude any token at/before context index. - in_window = ( - (idxs > t_ctx.unsqueeze(1)) - & (time_seq >= t0.unsqueeze(1)) - & (time_seq <= t1.unsqueeze(1)) - & (event_seq >= 2) - & (event_seq != 0) - ) - - if not in_window.any(): - return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) - - b_idx, t_idx = in_window.nonzero(as_tuple=True) - disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) - - y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) - y[b_idx, disease_ids] = True - return y - - -def multi_hot_selected_causes_within_horizon( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - t_ctx: torch.Tensor, - tau_years: float, - cause_ids: torch.Tensor, - n_disease: int, -) -> torch.Tensor: - """Labels for selected causes only: does cause k occur within tau after context?""" - B, L = event_seq.shape - device = event_seq.device - b = torch.arange(B, device=device) - t0 = time_seq[b, t_ctx] - t1 = t0 + (tau_years * DAYS_PER_YEAR) - - idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1) - in_window = ( - (idxs > t_ctx.unsqueeze(1)) - & (time_seq >= t0.unsqueeze(1)) - & (time_seq <= t1.unsqueeze(1)) - & (event_seq >= 2) - & (event_seq != 0) - ) - - out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device) - if not in_window.any(): - return out - - b_idx, t_idx = in_window.nonzero(as_tuple=True) - disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) - - # Filter to selected causes via a boolean membership mask over the global disease space. - selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device) - selected[cause_ids] = True - keep = selected[disease_ids] - if not keep.any(): - return out - - b_idx = b_idx[keep] - disease_ids = disease_ids[keep] - - # Map disease_id -> local index in cause_ids - # Build a lookup table (global disease space) where lookup[disease_id] = local_index - lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device) - lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device) - local = lookup[disease_ids] - out[b_idx, local] = True - return out