diff --git a/evaluate_age.py b/evaluate_age.py
index f5aa209..0dea88b 100644
--- a/evaluate_age.py
+++ b/evaluate_age.py
@@ -4,13 +4,19 @@
 import argparse
 import json
 import math
 import os
+import multiprocessing as mp
 from typing import List, Sequence, Tuple
 
+import pandas as pd
 import torch
 from torch.utils.data import DataLoader, random_split
 
 from dataset import HealthDataset, health_collate_fn
-from evaluation_age_time_dependent import EvalAgeConfig, evaluate_time_dependent_age_bins
+from evaluation_age_time_dependent import (
+    EvalAgeConfig,
+    aggregate_age_bin_results,
+    evaluate_time_dependent_age_bins,
+)
 from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
 from model import DelphiFork, SapDelphi, SimpleHead
@@ -39,6 +45,144 @@ def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
     return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]
 
 
+def _parse_gpus(gpus: str | None) -> List[int]:
+    if gpus is None:
+        return []
+    s = gpus.strip()
+    if not s:
+        return []
+    parts = [p.strip() for p in s.split(",") if p.strip()]
+    out: List[int] = []
+    for p in parts:
+        out.append(int(p))
+    return out
+
+
+def _worker_eval_mcs_on_gpu(
+    queue: "mp.Queue",
+    *,
+    run_dir: str,
+    split: str,
+    data_prefix_override: str | None,
+    horizons: List[float],
+    age_bins: List[Tuple[float, float]],
+    topk_percents: List[float],
+    n_mc: int,
+    seed: int,
+    batch_size: int,
+    num_workers: int,
+    gpu_id: int,
+    mc_indices: List[int],
+    out_path: str,
+) -> None:
+    """Worker process: evaluate a subset of MC indices on a single GPU."""
+    try:
+        ckpt_path = os.path.join(run_dir, "best_model.pt")
+        cfg_path = os.path.join(run_dir, "train_config.json")
+        with open(cfg_path, "r") as f:
+            cfg = json.load(f)
+
+        data_prefix = (
+            data_prefix_override
+            if data_prefix_override is not None
+            else cfg.get("data_prefix", "ukb")
+        )
+
+        full_cov = bool(cfg.get("full_cov", False))
+        cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
+        dataset = HealthDataset(data_prefix=data_prefix,
+                                covariate_list=cov_list)
+
+        train_ratio = float(cfg.get("train_ratio", 0.7))
+        val_ratio = float(cfg.get("val_ratio", 0.15))
+        seed_split = int(cfg.get("random_seed", 42))
+
+        n_total = len(dataset)
+        n_train = int(n_total * train_ratio)
+        n_val = int(n_total * val_ratio)
+        n_test = n_total - n_train - n_val
+
+        train_ds, val_ds, test_ds = random_split(
+            dataset,
+            [n_train, n_val, n_test],
+            generator=torch.Generator().manual_seed(seed_split),
+        )
+
+        if split == "train":
+            ds = train_ds
+        elif split == "val":
+            ds = val_ds
+        elif split == "test":
+            ds = test_ds
+        else:
+            ds = dataset
+
+        loader = DataLoader(
+            ds,
+            batch_size=int(batch_size),
+            shuffle=False,
+            collate_fn=health_collate_fn,
+            num_workers=int(num_workers),
+            pin_memory=True,
+        )
+
+        criterion, out_dims = build_criterion_and_out_dims(
+            loss_type=str(cfg["loss_type"]),
+            n_disease=int(dataset.n_disease),
+            bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
+            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
+        )
+
+        model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
+        head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
+
+        device = torch.device(f"cuda:{int(gpu_id)}")
+        checkpoint = torch.load(ckpt_path, map_location=device)
+
+        model.load_state_dict(checkpoint["model_state_dict"], strict=True)
+        head.load_state_dict(checkpoint["head_state_dict"], strict=True)
+        if "criterion_state_dict" in checkpoint:
+            try:
+                criterion.load_state_dict(
checkpoint["criterion_state_dict"], strict=False) + except Exception: + pass + + model.to(device) + head.to(device) + criterion.to(device) + + frames: List[pd.DataFrame] = [] + for mc_idx in mc_indices: + eval_cfg = EvalAgeConfig( + horizons_years=horizons, + age_bins=age_bins, + topk_percents=topk_percents, + n_mc=1, + seed=int(seed), + cause_ids=None, + ) + + df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins( + model=model, + head=head, + criterion=criterion, + dataloader=loader, + n_disease=int(dataset.n_disease), + cfg=eval_cfg, + device=device, + mc_offset=int(mc_idx), + ) + frames.append(df_by_bin) + + df_all = pd.concat(frames, ignore_index=True) if len( + frames) else pd.DataFrame() + df_all.to_csv(out_path, index=False) + queue.put({"ok": True, "out_path": out_path}) + except Exception as e: + queue.put({"ok": False, "error": repr(e)}) + + def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float): if loss_type == "exponential": criterion = ExponentialNLLLoss(lambda_reg=lambda_reg) @@ -134,6 +278,13 @@ def main() -> None: parser.add_argument("--n_mc", type=int, default=5) parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--gpus", + type=str, + default=None, + help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). Example: --gpus 0,1,3", + ) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--batch_size", type=int, default=256) @@ -232,16 +383,6 @@ def main() -> None: cause_ids=None, ) - df_by_bin, df_agg = evaluate_time_dependent_age_bins( - model=model, - head=head, - criterion=criterion, - dataloader=loader, - n_disease=int(dataset.n_disease), - cfg=eval_cfg, - device=device, - ) - if args.out_prefix is None: out_prefix = os.path.join( args.run_dir, f"age_bin_time_dependent_{args.split}") @@ -251,9 +392,91 @@ def main() -> None: out_bin = out_prefix + "_by_bin.csv" out_agg = out_prefix + "_agg.csv" + gpus = _parse_gpus(args.gpus) + if len(gpus) <= 1: + df_by_bin, df_agg = evaluate_time_dependent_age_bins( + model=model, + head=head, + criterion=criterion, + dataloader=loader, + n_disease=int(dataset.n_disease), + cfg=eval_cfg, + device=device, + ) + + df_by_bin.to_csv(out_bin, index=False) + df_agg.to_csv(out_agg, index=False) + print(f"Wrote: {out_bin}") + print(f"Wrote: {out_agg}") + return + + if not torch.cuda.is_available(): + raise SystemExit("--gpus was provided but CUDA is not available") + + # Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU). 
+    mc_indices_all = list(range(int(args.n_mc)))
+    per_gpu: List[Tuple[int, List[int]]] = []
+    for pos, gpu_id in enumerate(gpus):
+        assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
+        if assigned:
+            per_gpu.append((int(gpu_id), assigned))
+
+    ctx = mp.get_context("spawn")
+    queue: "mp.Queue" = ctx.Queue()
+    procs: List[mp.Process] = []
+    tmp_paths: List[str] = []
+
+    for gpu_id, mc_idxs in per_gpu:
+        tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
+        tmp_paths.append(tmp_path)
+        p = ctx.Process(
+            target=_worker_eval_mcs_on_gpu,
+            kwargs=dict(
+                queue=queue,
+                run_dir=str(args.run_dir),
+                split=str(args.split),
+                data_prefix_override=(
+                    str(args.data_prefix) if args.data_prefix is not None else None
+                ),
+                horizons=_parse_floats(args.horizons),
+                age_bins=age_bins,
+                topk_percents=[float(x) for x in args.topk_percent],
+                n_mc=int(args.n_mc),
+                seed=int(args.seed),
+                batch_size=int(args.batch_size),
+                num_workers=int(args.num_workers),
+                gpu_id=int(gpu_id),
+                mc_indices=mc_idxs,
+                out_path=tmp_path,
+            ),
+        )
+        p.start()
+        procs.append(p)
+
+    results = [queue.get() for _ in range(len(procs))]
+    for p in procs:
+        p.join()
+
+    for r in results:
+        if not r.get("ok", False):
+            raise SystemExit(f"Worker failed: {r.get('error')}")
+
+    frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
+    df_by_bin = pd.concat(frames, ignore_index=True) if len(
+        frames) else pd.DataFrame()
+    df_agg = aggregate_age_bin_results(df_by_bin)
+
     df_by_bin.to_csv(out_bin, index=False)
     df_agg.to_csv(out_agg, index=False)
 
+    # Best-effort cleanup.
+    for p in tmp_paths:
+        try:
+            if os.path.exists(p):
+                os.remove(p)
+        except Exception:
+            pass
+
     print(f"Wrote: {out_bin}")
     print(f"Wrote: {out_agg}")
diff --git a/evaluation_age_time_dependent.py b/evaluation_age_time_dependent.py
index 4368497..d3f8c3e 100644
--- a/evaluation_age_time_dependent.py
+++ b/evaluation_age_time_dependent.py
@@ -18,6 +18,158 @@ from utils import (
 )
 
+
+def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
+    """Aggregate per-bin age evaluation results.
+
+    Produces both:
+      - macro: unweighted mean over bins with n_samples>0
+      - weighted: weighted mean over bins using weights=n_samples
+
+    Then aggregates across MC repetitions (mean/std).
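+    (Example: two bins with n_samples 100/300 and AUC 0.70/0.80 give a macro
+    AUC of 0.75 and a weighted AUC of (100*0.70 + 300*0.80) / 400 = 0.775.)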
+
+    Requires df_by_bin to include:
+        mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
+        n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score
+
+    Returns:
+        DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id)
+    """
+    if df_by_bin is None or len(df_by_bin) == 0:
+        return pd.DataFrame(
+            columns=[
+                "agg_type",
+                "horizon_tau",
+                "topk_percent",
+                "cause_id",
+                "n_mc",
+                "n_bins_used_mean",
+                "n_samples_total_mean",
+                "n_positives_total_mean",
+                "auc_mean",
+                "auc_std",
+                "auprc_mean",
+                "auprc_std",
+                "recall_at_K_mean",
+                "recall_at_K_std",
+                "precision_at_K_mean",
+                "precision_at_K_std",
+                "brier_score_mean",
+                "brier_score_std",
+            ]
+        )
+
+    def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
+        g = group[group["n_samples"] > 0]
+        if len(g) == 0:
+            return pd.Series(
+                dict(
+                    n_bins_used=0,
+                    n_samples_total=0,
+                    n_positives_total=0,
+                    auc=float("nan"),
+                    auprc=float("nan"),
+                    recall_at_K=float("nan"),
+                    precision_at_K=float("nan"),
+                    brier_score=float("nan"),
+                )
+            )
+
+        n_bins_used = int(g["age_bin_id"].nunique())
+        n_samples_total = int(g["n_samples"].sum())
+        n_positives_total = int(g["n_positives"].sum())
+
+        if not weighted:
+            return pd.Series(
+                dict(
+                    n_bins_used=n_bins_used,
+                    n_samples_total=n_samples_total,
+                    n_positives_total=n_positives_total,
+                    auc=float(g["auc"].mean()),
+                    auprc=float(g["auprc"].mean()),
+                    recall_at_K=float(g["recall_at_K"].mean()),
+                    precision_at_K=float(g["precision_at_K"].mean()),
+                    brier_score=float(g["brier_score"].mean()),
+                )
+            )
+
+        w = g["n_samples"].to_numpy(dtype=float)
+        w_sum = float(w.sum())
+        if w_sum <= 0.0:
+            return pd.Series(
+                dict(
+                    n_bins_used=n_bins_used,
+                    n_samples_total=n_samples_total,
+                    n_positives_total=n_positives_total,
+                    auc=float("nan"),
+                    auprc=float("nan"),
+                    recall_at_K=float("nan"),
+                    precision_at_K=float("nan"),
+                    brier_score=float("nan"),
+                )
+            )
+
+        def _wavg(col: str) -> float:
+            return float(np.average(g[col].to_numpy(dtype=float), weights=w))
+
+        return pd.Series(
+            dict(
+                n_bins_used=n_bins_used,
+                n_samples_total=n_samples_total,
+                n_positives_total=n_positives_total,
+                auc=_wavg("auc"),
+                auprc=_wavg("auprc"),
+                recall_at_K=_wavg("recall_at_K"),
+                precision_at_K=_wavg("precision_at_K"),
+                brier_score=_wavg("brier_score"),
+            )
+        )
+
+    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
+
+    df_mc_macro = (
+        df_by_bin.groupby(group_keys)
+        .apply(lambda g: _bin_aggregate(g, weighted=False))
+        .reset_index()
+    )
+    df_mc_macro["agg_type"] = "macro"
+
+    df_mc_weighted = (
+        df_by_bin.groupby(group_keys)
+        .apply(lambda g: _bin_aggregate(g, weighted=True))
+        .reset_index()
+    )
+    df_mc_weighted["agg_type"] = "weighted"
+
+    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
+
+    df_agg = (
+        df_mc_binagg.groupby(
+            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
+        )
+        .agg(
+            n_mc=("mc_idx", "nunique"),
+            n_bins_used_mean=("n_bins_used", "mean"),
+            n_samples_total_mean=("n_samples_total", "mean"),
+            n_positives_total_mean=("n_positives_total", "mean"),
+            auc_mean=("auc", "mean"),
+            auc_std=("auc", "std"),
+            auprc_mean=("auprc", "mean"),
+            auprc_std=("auprc", "std"),
+            recall_at_K_mean=("recall_at_K", "mean"),
+            recall_at_K_std=("recall_at_K", "std"),
+            precision_at_K_mean=("precision_at_K", "mean"),
+            precision_at_K_std=("precision_at_K", "std"),
+            brier_score_mean=("brier_score", "mean"),
+            brier_score_std=("brier_score", "std"),
+        )
+        .sort_values(
+            ["agg_type", "horizon_tau",
"topk_percent", "cause_id"], + ignore_index=True, + ) + ) + return df_agg + + def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float: """Compute ROC AUC for binary labels with tie-aware ranking. @@ -138,6 +290,7 @@ def evaluate_time_dependent_age_bins( n_disease: int, cfg: EvalAgeConfig, device: str | torch.device, + mc_offset: int = 0, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Delphi-2M-style age-bin evaluation with strict horizon alignment. @@ -196,6 +349,7 @@ def evaluate_time_dependent_age_bins( ] for mc_idx in range(int(cfg.n_mc)): + global_mc_idx = int(mc_offset) + int(mc_idx) # tqdm over batches; include MC idx in description. for batch_idx, batch in enumerate( tqdm(dataloader, @@ -216,7 +370,7 @@ def evaluate_time_dependent_age_bins( # Diversify RNG stream across MC/tau/bin/batch to reduce correlation. seed = ( int(cfg.seed) - + (100_000 * int(mc_idx)) + + (100_000 * int(global_mc_idx)) + (1_000 * int(tau_idx)) + (10 * int(bin_idx)) + int(batch_idx) @@ -277,6 +431,7 @@ def evaluate_time_dependent_age_bins( rows_by_bin: List[Dict[str, float | int]] = [] for mc_idx in range(int(cfg.n_mc)): + global_mc_idx = int(mc_offset) + int(mc_idx) for h_idx, tau_y in enumerate(horizons_years): for bin_idx, (a_lo, a_hi) in enumerate(age_bins): if len(y_true[mc_idx][h_idx][bin_idx]) == 0: @@ -287,7 +442,7 @@ def evaluate_time_dependent_age_bins( for k_percent in topk_percents: rows_by_bin.append( dict( - mc_idx=mc_idx, + mc_idx=global_mc_idx, age_bin_id=bin_idx, age_bin_low=float(a_lo), age_bin_high=float(a_hi), @@ -332,7 +487,7 @@ def evaluate_time_dependent_age_bins( yk, pk, float(k_percent)) rows_by_bin.append( dict( - mc_idx=mc_idx, + mc_idx=global_mc_idx, age_bin_id=bin_idx, age_bin_low=float(a_lo), age_bin_high=float(a_hi), @@ -351,115 +506,6 @@ def evaluate_time_dependent_age_bins( df_by_bin = pd.DataFrame(rows_by_bin) - def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series: - g = group[group["n_samples"] > 0] - if len(g) == 0: - return pd.Series( - dict( - n_bins_used=0, - n_samples_total=0, - n_positives_total=0, - auc=float("nan"), - auprc=float("nan"), - recall_at_K=float("nan"), - precision_at_K=float("nan"), - brier_score=float("nan"), - ) - ) - - n_bins_used = int(g["age_bin_id"].nunique()) - n_samples_total = int(g["n_samples"].sum()) - n_positives_total = int(g["n_positives"].sum()) - - if not weighted: - return pd.Series( - dict( - n_bins_used=n_bins_used, - n_samples_total=n_samples_total, - n_positives_total=n_positives_total, - auc=float(g["auc"].mean()), - auprc=float(g["auprc"].mean()), - recall_at_K=float(g["recall_at_K"].mean()), - precision_at_K=float(g["precision_at_K"].mean()), - brier_score=float(g["brier_score"].mean()), - ) - ) - - w = g["n_samples"].to_numpy(dtype=float) - w_sum = float(w.sum()) - if w_sum <= 0.0: - return pd.Series( - dict( - n_bins_used=n_bins_used, - n_samples_total=n_samples_total, - n_positives_total=n_positives_total, - auc=float("nan"), - auprc=float("nan"), - recall_at_K=float("nan"), - precision_at_K=float("nan"), - brier_score=float("nan"), - ) - ) - - def _wavg(col: str) -> float: - return float(np.average(g[col].to_numpy(dtype=float), weights=w)) - - return pd.Series( - dict( - n_bins_used=n_bins_used, - n_samples_total=n_samples_total, - n_positives_total=n_positives_total, - auc=_wavg("auc"), - auprc=_wavg("auprc"), - recall_at_K=_wavg("recall_at_K"), - precision_at_K=_wavg("precision_at_K"), - brier_score=_wavg("brier_score"), - ) - ) - - group_keys = ["mc_idx", "horizon_tau", "topk_percent", 
"cause_id"] - - df_mc_macro = ( - df_by_bin.groupby(group_keys) - .apply(lambda g: _bin_aggregate(g, weighted=False)) - .reset_index() - ) - df_mc_macro["agg_type"] = "macro" - - df_mc_weighted = ( - df_by_bin.groupby(group_keys) - .apply(lambda g: _bin_aggregate(g, weighted=True)) - .reset_index() - ) - df_mc_weighted["agg_type"] = "weighted" - - df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True) - - # Then average over MC repetitions. - df_agg = ( - df_mc_binagg.groupby( - ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False - ) - .agg( - n_mc=("mc_idx", "nunique"), - n_bins_used_mean=("n_bins_used", "mean"), - n_samples_total_mean=("n_samples_total", "mean"), - n_positives_total_mean=("n_positives_total", "mean"), - auc_mean=("auc", "mean"), - auc_std=("auc", "std"), - auprc_mean=("auprc", "mean"), - auprc_std=("auprc", "std"), - recall_at_K_mean=("recall_at_K", "mean"), - recall_at_K_std=("recall_at_K", "std"), - precision_at_K_mean=("precision_at_K", "mean"), - precision_at_K_std=("precision_at_K", "std"), - brier_score_mean=("brier_score", "mean"), - brier_score_std=("brier_score", "std"), - ) - .sort_values( - ["agg_type", "horizon_tau", "topk_percent", "cause_id"], - ignore_index=True, - ) - ) + df_agg = aggregate_age_bin_results(df_by_bin) return df_by_bin, df_agg