From 90dffc321154852ddd35e58a341d5eb9ca3ed92a Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Fri, 16 Jan 2026 16:13:31 +0800
Subject: [PATCH] Add evaluation scripts for age-bin time-dependent metrics
 and remove obsolete evaluation_time_dependent.py

---
 evaluate.py => evaluate_age.py   |  82 ++++--
 evaluation_age_time_dependent.py | 469 +++++++++++++++++++++++++++++++
 evaluation_time_dependent.py     | 322 ---------------------
 utils.py                         |  73 +++++
 4 files changed, 597 insertions(+), 349 deletions(-)
 rename evaluate.py => evaluate_age.py (74%)
 create mode 100644 evaluation_age_time_dependent.py
 delete mode 100644 evaluation_time_dependent.py

diff --git a/evaluate.py b/evaluate_age.py
similarity index 74%
rename from evaluate.py
rename to evaluate_age.py
index 0b5002e..f5aa209 100644
--- a/evaluate.py
+++ b/evaluate_age.py
@@ -4,13 +4,13 @@ import argparse
 import json
 import math
 import os
-from typing import List, Sequence
+from typing import List, Sequence, Tuple
 
 import torch
 from torch.utils.data import DataLoader, random_split
 
 from dataset import HealthDataset, health_collate_fn
-from evaluation_time_dependent import EvalConfig, evaluate_time_dependent
+from evaluation_age_time_dependent import EvalAgeConfig, evaluate_time_dependent_age_bins
 from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
 from model import DelphiFork, SapDelphi, SimpleHead
@@ -25,6 +25,20 @@ def _parse_floats(items: Sequence[str]) -> List[float]:
     return out
 
 
+def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
+    vals = _parse_floats(edges)
+    if len(vals) < 2:
+        raise ValueError("--age_bin_edges must have at least 2 values")
+    for i in range(1, len(vals)):
+        if not (vals[i] > vals[i - 1]):
+            raise ValueError("--age_bin_edges must be strictly increasing")
+    return vals
+
+
+def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
+    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]
+
+
 def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
     if loss_type == "exponential":
         criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
@@ -90,44 +104,48 @@ def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
 
 def main() -> None:
     parser = argparse.ArgumentParser(
-        description="Time-dependent evaluation for DeepHealth")
+        description="Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and aggregated CSVs; aggregated includes agg_type={macro,weighted})")
     parser.add_argument(
         "--run_dir",
        type=str,
        required=True,
        help="Training run directory (contains best_model.pt and train_config.json)",
    )
-    parser.add_argument("--data_prefix", type=str, default=None,
-                        help="Dataset prefix (overrides config if provided)")
+    parser.add_argument("--data_prefix", type=str, default=None)
     parser.add_argument("--split", type=str,
                         choices=["train", "val", "test", "all"], default="val")
 
     parser.add_argument("--horizons", type=str, nargs="+",
-                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="One or more horizons (years)")
-    parser.add_argument("--offset_years", type=float, default=0.0,
-                        help="Context selection offset (years before follow-up end)")
+                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"])
+    parser.add_argument(
+        "--age_bin_edges",
+        type=str,
+        nargs="+",
+        default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
+        help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).",
+    )
     parser.add_argument(
         "--topk_percent",
         type=float,
         nargs="+",
         default=[1, 5, 10, 20, 50],
-        help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)",
+        help="One or more K%% values for recall/precision@K%%",
     )
+    parser.add_argument("--n_mc", type=int, default=5)
+    parser.add_argument("--seed", type=int, default=0)
 
     parser.add_argument("--device", type=str,
                         default="cuda" if torch.cuda.is_available() else "cpu")
     parser.add_argument("--batch_size", type=int, default=256)
-    parser.add_argument("--num_workers", type=int,
-                        default=0, help="Keep 0 on Windows")
+    parser.add_argument("--num_workers", type=int, default=0)
 
-    parser.add_argument("--out_csv", type=str, default=None,
-                        help="Optional output CSV path")
+    parser.add_argument("--out_prefix", type=str,
+                        default=None, help="Output prefix for CSVs")
 
     args = parser.parse_args()
 
     ckpt_path = os.path.join(args.run_dir, "best_model.pt")
     cfg_path = os.path.join(args.run_dir, "train_config.json")
-
     if not os.path.exists(ckpt_path):
         raise SystemExit(f"Missing checkpoint: {ckpt_path}")
     if not os.path.exists(cfg_path):
@@ -139,24 +157,23 @@ def main() -> None:
     data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
         "data_prefix", "ukb")
 
-    # Match training covariate selection.
     full_cov = bool(cfg.get("full_cov", False))
     cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
     dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
 
-    # Recreate the same split scheme as train.py
     train_ratio = float(cfg.get("train_ratio", 0.7))
     val_ratio = float(cfg.get("val_ratio", 0.15))
-    seed = int(cfg.get("random_seed", 42))
+    seed_split = int(cfg.get("random_seed", 42))
 
     n_total = len(dataset)
     n_train = int(n_total * train_ratio)
     n_val = int(n_total * val_ratio)
     n_test = n_total - n_train - n_val
+
     train_ds, val_ds, test_ds = random_split(
         dataset,
         [n_train, n_val, n_test],
-        generator=torch.Generator().manual_seed(seed),
+        generator=torch.Generator().manual_seed(seed_split),
     )
 
     if args.split == "train":
@@ -203,14 +220,19 @@ def main() -> None:
     head.to(device)
     criterion.to(device)
 
-    eval_cfg = EvalConfig(
+    age_edges = _parse_age_bin_edges(args.age_bin_edges)
+    age_bins = _edges_to_bins(age_edges)
+
+    eval_cfg = EvalAgeConfig(
         horizons_years=_parse_floats(args.horizons),
-        offset_years=float(args.offset_years),
+        age_bins=age_bins,
         topk_percents=[float(x) for x in args.topk_percent],
+        n_mc=int(args.n_mc),
+        seed=int(args.seed),
         cause_ids=None,
     )
 
-    df = evaluate_time_dependent(
+    df_by_bin, df_agg = evaluate_time_dependent_age_bins(
         model=model,
         head=head,
         criterion=criterion,
@@ -220,14 +242,20 @@ def main() -> None:
         device=device,
     )
 
-    if args.out_csv is None:
-        out_csv = os.path.join(
-            args.run_dir, f"time_dependent_metrics_{args.split}.csv")
+    if args.out_prefix is None:
+        out_prefix = os.path.join(
+            args.run_dir, f"age_bin_time_dependent_{args.split}")
     else:
-        out_csv = args.out_csv
+        out_prefix = args.out_prefix
 
-    df.to_csv(out_csv, index=False)
-    print(f"Wrote: {out_csv}")
+    out_bin = out_prefix + "_by_bin.csv"
+    out_agg = out_prefix + "_agg.csv"
+
+    df_by_bin.to_csv(out_bin, index=False)
+    df_agg.to_csv(out_agg, index=False)
+
+    print(f"Wrote: {out_bin}")
+    print(f"Wrote: {out_agg}")
 
 
 if __name__ == "__main__":
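
A minimal usage sketch for the renamed entry point. The run directory and the
edge values below are hypothetical; the flag names and helpers are the ones
defined in evaluate_age.py above (importing the underscore-prefixed helpers
works because they are module-level functions, but they are private API):

    # Shell: python evaluate_age.py --run_dir runs/exp1 --split val \
    #            --age_bin_edges 40 45 50 --n_mc 5
    #
    # The same composition in Python:
    from evaluate_age import _edges_to_bins, _parse_age_bin_edges
    from evaluation_age_time_dependent import EvalAgeConfig

    edges = _parse_age_bin_edges(["40", "45", "50"])  # -> [40.0, 45.0, 50.0]
    bins = _edges_to_bins(edges)                      # -> [(40.0, 45.0), (45.0, 50.0)]

    cfg = EvalAgeConfig(
        horizons_years=[1.0, 5.0],
        age_bins=bins,                  # half-open bins [low, high)
        topk_percents=[1.0, 5.0, 10.0],
        n_mc=5,
        seed=0,
        cause_ids=None,                 # None = evaluate every disease
    )
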
diff --git a/evaluation_age_time_dependent.py b/evaluation_age_time_dependent.py
new file mode 100644
index 0000000..f309750
--- /dev/null
+++ b/evaluation_age_time_dependent.py
@@ -0,0 +1,469 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+
+try:
+    from tqdm import tqdm
+except Exception:  # pragma: no cover
+
+    def tqdm(x, **kwargs):
+        return x
+
+from utils import (
+    multi_hot_ever_within_horizon,
+    multi_hot_selected_causes_within_horizon,
+    sample_context_in_fixed_age_bin,
+)
+
+
+def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    """Compute ROC AUC for binary labels with tie-aware ranking.
+
+    Returns NaN if y_true has no positives or no negatives.
+
+    Uses the Mann–Whitney U statistic with average ranks for ties.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan")
+
+    n_pos = int(y_true.sum())
+    n_neg = n - n_pos
+    if n_pos == 0 or n_neg == 0:
+        return float("nan")
+
+    # Rank scores ascending, average ranks for ties.
+    order = np.argsort(y_score, kind="mergesort")
+    sorted_scores = y_score[order]
+
+    ranks = np.empty(n, dtype=float)
+    i = 0
+    # 1-based ranks
+    while i < n:
+        j = i + 1
+        while j < n and sorted_scores[j] == sorted_scores[i]:
+            j += 1
+        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
+        ranks[order[i:j]] = avg_rank
+        i = j
+
+    sum_ranks_pos = float(ranks[y_true].sum())
+    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
+    return float(u / (n_pos * n_neg))
+
+
+def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    """Average precision (area under PR curve using step-wise interpolation).
+
+    Returns NaN if no positives.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan")
+
+    n_pos = int(y_true.sum())
+    if n_pos == 0:
+        return float("nan")
+
+    order = np.argsort(-y_score, kind="mergesort")
+    y = y_true[order]
+
+    tp = np.cumsum(y).astype(float)
+    fp = np.cumsum(~y).astype(float)
+    precision = tp / np.maximum(tp + fp, 1.0)
+
+    # AP = sum over each positive of precision at that point / n_pos
+    # (equivalent to ∑ Δrecall * precision)
+    ap = float(np.sum(precision[y]) / n_pos)
+    # handle potential tiny numerical overshoots
+    return float(max(0.0, min(1.0, ap)))
+
+
+def _precision_recall_at_k_percent(
+    y_true: np.ndarray,
+    y_score: np.ndarray,
+    k_percent: float,
+) -> Tuple[float, float]:
+    """Precision@K% and Recall@K% for binary labels.
+
+    Returns (precision, recall). Returns NaN for recall if no positives.
+    Returns NaN for precision if k leads to 0 selected.
+ """ + y_true = np.asarray(y_true).astype(bool) + y_score = np.asarray(y_score).astype(float) + + n = y_true.size + if n == 0: + return float("nan"), float("nan") + + n_pos = int(y_true.sum()) + + k = int(math.ceil((float(k_percent) / 100.0) * n)) + if k <= 0: + return float("nan"), float("nan") + + order = np.argsort(-y_score, kind="mergesort") + top = order[:k] + tp_top = int(y_true[top].sum()) + + precision = tp_top / k + recall = float("nan") if n_pos == 0 else (tp_top / n_pos) + return float(precision), float(recall) + + +@dataclass +class EvalAgeConfig: + horizons_years: Sequence[float] + age_bins: Sequence[Tuple[float, float]] + topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0) + n_mc: int = 5 + seed: int = 0 + cause_ids: Optional[Sequence[int]] = None + + +@torch.no_grad() +def evaluate_time_dependent_age_bins( + model: torch.nn.Module, + head: torch.nn.Module, + criterion, + dataloader: torch.utils.data.DataLoader, + n_disease: int, + cfg: EvalAgeConfig, + device: str | torch.device, +) -> Tuple[pd.DataFrame, pd.DataFrame]: + """Delphi-2M-style age-bin evaluation with strict horizon alignment. + + Semantics (strict): for each (MC, horizon tau, age bin) we independently: + - build the eligible token set within that bin + - enforce follow-up coverage: t_ctx + tau <= t_end + - randomly sample exactly one token per individual within the bin (de-dup) + - recompute context representations and predictions for that (tau, bin) + + Returns: + df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id) + df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted} + """ + device = torch.device(device) + model.eval() + head.eval() + + horizons_years = [float(x) for x in cfg.horizons_years] + if len(horizons_years) == 0: + raise ValueError("cfg.horizons_years must be non-empty") + + age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins] + if len(age_bins) == 0: + raise ValueError("cfg.age_bins must be non-empty") + for (a, b) in age_bins: + if not (b > a): + raise ValueError( + f"age_bins must be (low, high) with high>low; got {(a, b)}") + + topk_percents = [float(x) for x in cfg.topk_percents] + if len(topk_percents) == 0: + raise ValueError("cfg.topk_percents must be non-empty") + if any((p <= 0.0 or p > 100.0) for p in topk_percents): + raise ValueError( + f"All topk_percents must be in (0,100]; got {topk_percents}") + + if int(cfg.n_mc) <= 0: + raise ValueError("cfg.n_mc must be >= 1") + + if cfg.cause_ids is None: + cause_ids = None + n_causes_eval = int(n_disease) + else: + cause_ids = torch.tensor( + list(cfg.cause_ids), dtype=torch.long, device=device) + n_causes_eval = int(cause_ids.numel()) + + # Storage: (mc, h, bin) -> list of arrays + y_true: List[List[List[List[np.ndarray]]]] = [ + [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))] + for _ in range(int(cfg.n_mc)) + ] + y_pred: List[List[List[List[np.ndarray]]]] = [ + [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))] + for _ in range(int(cfg.n_mc)) + ] + + for mc_idx in range(int(cfg.n_mc)): + # tqdm over batches; include MC idx in description. 
+        for batch_idx, batch in enumerate(
+            tqdm(dataloader,
+                 desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
+        ):
+            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
+            event_seq = event_seq.to(device)
+            time_seq = time_seq.to(device)
+            cont_feats = cont_feats.to(device)
+            cate_feats = cate_feats.to(device)
+            sexes = sexes.to(device)
+
+            B = int(event_seq.size(0))
+            b = torch.arange(B, device=device)
+
+            for tau_idx, tau_y in enumerate(horizons_years):
+                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
+                    # Give each (MC, tau, bin, batch) its own RNG stream. The
+                    # strides are wide enough that the components cannot
+                    # collide (assumes < 100_000 batches, < 1_000 bins,
+                    # < 100 horizons).
+                    seed = (
+                        int(cfg.seed)
+                        + (10_000_000_000 * int(mc_idx))
+                        + (100_000_000 * int(tau_idx))
+                        + (100_000 * int(bin_idx))
+                        + int(batch_idx)
+                    )
+
+                    keep, t_ctx = sample_context_in_fixed_age_bin(
+                        event_seq=event_seq,
+                        time_seq=time_seq,
+                        tau_years=float(tau_y),
+                        age_bin=(float(a_lo), float(a_hi)),
+                        seed=seed,
+                    )
+                    if not keep.any():
+                        continue
+
+                    # Strict bin-specific prediction: recompute representations and logits per (tau, bin).
+                    h = model(event_seq, time_seq, sexes,
+                              cont_feats, cate_feats)  # (B,L,D)
+                    c = h[b, t_ctx]
+                    logits = head(c)
+
+                    cifs = criterion.calculate_cifs(
+                        logits, taus=torch.tensor(float(tau_y), device=device)
+                    )
+                    if cifs.ndim != 2:
+                        raise ValueError(
+                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
+                            f"got shape={tuple(cifs.shape)}"
+                        )
+
+                    if cause_ids is None:
+                        y = multi_hot_ever_within_horizon(
+                            event_seq=event_seq,
+                            time_seq=time_seq,
+                            t_ctx=t_ctx,
+                            tau_years=float(tau_y),
+                            n_disease=n_disease,
+                        )
+                        preds = cifs
+                    else:
+                        y = multi_hot_selected_causes_within_horizon(
+                            event_seq=event_seq,
+                            time_seq=time_seq,
+                            t_ctx=t_ctx,
+                            tau_years=float(tau_y),
+                            cause_ids=cause_ids,
+                            n_disease=n_disease,
+                        )
+                        preds = cifs.index_select(dim=1, index=cause_ids)
+
+                    y_true[mc_idx][tau_idx][bin_idx].append(
+                        y[keep].detach().to(torch.bool).cpu().numpy()
+                    )
+                    y_pred[mc_idx][tau_idx][bin_idx].append(
+                        preds[keep].detach().to(torch.float32).cpu().numpy()
+                    )
+
+    rows_by_bin: List[Dict[str, float | int]] = []
+
+    for mc_idx in range(int(cfg.n_mc)):
+        for h_idx, tau_y in enumerate(horizons_years):
+            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
+                if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
+                    # No samples in this bin for this (mc, tau)
+                    for cause_k in range(n_causes_eval):
+                        cause_id = int(cause_k) if cause_ids is None else int(
+                            cfg.cause_ids[cause_k])
+                        for k_percent in topk_percents:
+                            rows_by_bin.append(
+                                dict(
+                                    mc_idx=mc_idx,
+                                    age_bin_id=bin_idx,
+                                    age_bin_low=float(a_lo),
+                                    age_bin_high=float(a_hi),
+                                    horizon_tau=float(tau_y),
+                                    topk_percent=float(k_percent),
+                                    cause_id=cause_id,
+                                    n_samples=0,
+                                    n_positives=0,
+                                    auc=float("nan"),
+                                    auprc=float("nan"),
+                                    recall_at_K=float("nan"),
+                                    precision_at_K=float("nan"),
+                                    brier_score=float("nan"),
+                                )
+                            )
+                    continue
+
+                yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
+                pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
+                if yb.shape != pb.shape:
+                    raise ValueError(
+                        f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
+                    )
+
+                n_samples = int(yb.shape[0])
+
+                for cause_k in range(n_causes_eval):
+                    yk = yb[:, cause_k]
+                    pk = pb[:, cause_k]
+                    n_pos = int(yk.sum())
+
+                    auc = _binary_roc_auc(yk, pk)
+                    auprc = _average_precision(yk, pk)
+                    brier = float(np.mean(
+                        (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan")
+
+                    cause_id = int(cause_k) if cause_ids is None else int(
+                        cfg.cause_ids[cause_k])
+
+                    for k_percent in topk_percents:
+                        precision_k, recall_k = _precision_recall_at_k_percent(
+                            yk, pk, float(k_percent))
+                        rows_by_bin.append(
+                            dict(
+                                mc_idx=mc_idx,
+                                age_bin_id=bin_idx,
+                                age_bin_low=float(a_lo),
+                                age_bin_high=float(a_hi),
+                                horizon_tau=float(tau_y),
+                                topk_percent=float(k_percent),
+                                cause_id=cause_id,
+                                n_samples=n_samples,
+                                n_positives=n_pos,
+                                auc=float(auc),
+                                auprc=float(auprc),
+                                recall_at_K=float(recall_k),
+                                precision_at_K=float(precision_k),
+                                brier_score=float(brier),
+                            )
+                        )
+
+    df_by_bin = pd.DataFrame(rows_by_bin)
+
+    def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
+        g = group[group["n_samples"] > 0]
+        if len(g) == 0:
+            return pd.Series(
+                dict(
+                    n_bins_used=0,
+                    n_samples_total=0,
+                    n_positives_total=0,
+                    auc=float("nan"),
+                    auprc=float("nan"),
+                    recall_at_K=float("nan"),
+                    precision_at_K=float("nan"),
+                    brier_score=float("nan"),
+                )
+            )
+
+        n_bins_used = int(g["age_bin_id"].nunique())
+        n_samples_total = int(g["n_samples"].sum())
+        n_positives_total = int(g["n_positives"].sum())
+
+        if not weighted:
+            return pd.Series(
+                dict(
+                    n_bins_used=n_bins_used,
+                    n_samples_total=n_samples_total,
+                    n_positives_total=n_positives_total,
+                    auc=float(g["auc"].mean()),
+                    auprc=float(g["auprc"].mean()),
+                    recall_at_K=float(g["recall_at_K"].mean()),
+                    precision_at_K=float(g["precision_at_K"].mean()),
+                    brier_score=float(g["brier_score"].mean()),
+                )
+            )
+
+        w = g["n_samples"].to_numpy(dtype=float)
+        w_sum = float(w.sum())
+        if w_sum <= 0.0:
+            return pd.Series(
+                dict(
+                    n_bins_used=n_bins_used,
+                    n_samples_total=n_samples_total,
+                    n_positives_total=n_positives_total,
+                    auc=float("nan"),
+                    auprc=float("nan"),
+                    recall_at_K=float("nan"),
+                    precision_at_K=float("nan"),
+                    brier_score=float("nan"),
+                )
+            )
+
+        def _wavg(col: str) -> float:
+            # Skip NaNs (e.g., AUC in a bin with no positives) so a single
+            # degenerate bin does not make the weighted average NaN; this
+            # mirrors the NaN-skipping .mean() on the macro path.
+            v = g[col].to_numpy(dtype=float)
+            m = ~np.isnan(v)
+            if not m.any():
+                return float("nan")
+            return float(np.average(v[m], weights=w[m]))
+
+        return pd.Series(
+            dict(
+                n_bins_used=n_bins_used,
+                n_samples_total=n_samples_total,
+                n_positives_total=n_positives_total,
+                auc=_wavg("auc"),
+                auprc=_wavg("auprc"),
+                recall_at_K=_wavg("recall_at_K"),
+                precision_at_K=_wavg("precision_at_K"),
+                brier_score=_wavg("brier_score"),
+            )
+        )
+
+    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
+
+    df_mc_macro = (
+        df_by_bin.groupby(group_keys)
+        .apply(lambda g: _bin_aggregate(g, weighted=False))
+        .reset_index()
+    )
+    df_mc_macro["agg_type"] = "macro"
+
+    df_mc_weighted = (
+        df_by_bin.groupby(group_keys)
+        .apply(lambda g: _bin_aggregate(g, weighted=True))
+        .reset_index()
+    )
+    df_mc_weighted["agg_type"] = "weighted"
+
+    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
+
+    # Then average over MC repetitions.
+    df_agg = (
+        df_mc_binagg.groupby(
+            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
+        )
+        .agg(
+            n_mc=("mc_idx", "nunique"),
+            n_bins_used_mean=("n_bins_used", "mean"),
+            n_samples_total_mean=("n_samples_total", "mean"),
+            n_positives_total_mean=("n_positives_total", "mean"),
+            auc_mean=("auc", "mean"),
+            auc_std=("auc", "std"),
+            auprc_mean=("auprc", "mean"),
+            auprc_std=("auprc", "std"),
+            recall_at_K_mean=("recall_at_K", "mean"),
+            recall_at_K_std=("recall_at_K", "std"),
+            precision_at_K_mean=("precision_at_K", "mean"),
+            precision_at_K_std=("precision_at_K", "std"),
+            brier_score_mean=("brier_score", "mean"),
+            brier_score_std=("brier_score", "std"),
+        )
+        .sort_values(
+            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
+            ignore_index=True,
+        )
+    )
+
+    return df_by_bin, df_agg
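
The three metric helpers above are self-contained and can be sanity-checked
by hand. A small worked example (illustrative only, not part of the patch):

    import numpy as np

    from evaluation_age_time_dependent import (
        _average_precision,
        _binary_roc_auc,
        _precision_recall_at_k_percent,
    )

    y = np.array([1, 0, 1, 0], dtype=bool)
    s = np.array([0.9, 0.8, 0.7, 0.1])

    # 3 of the 4 positive/negative score pairs are ordered correctly.
    print(_binary_roc_auc(y, s))                       # 0.75
    # Precision at the two positives: 1/1 and 2/3, averaged over 2 positives.
    print(_average_precision(y, s))                    # 0.8333...
    # K=50% of 4 samples -> top-2 by score holds 1 of the 2 positives.
    print(_precision_recall_at_k_percent(y, s, 50.0))  # (0.5, 0.5)
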
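The two agg_type values differ only in how per-bin rows are pooled: "macro"
averages bins equally, while "weighted" weights each bin by its n_samples
(NaN metrics from positive-free bins are skipped on both paths). A
hand-checkable illustration with hypothetical per-bin AUCs:

    import numpy as np

    auc = np.array([0.80, 0.60])   # per-bin AUC
    n = np.array([100.0, 300.0])   # per-bin n_samples

    macro = auc.mean()                     # 0.70: every bin counts equally
    weighted = np.average(auc, weights=n)  # 0.65: larger bins dominate
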
diff --git a/evaluation_time_dependent.py b/evaluation_time_dependent.py
deleted file mode 100644
index 955b02b..0000000
--- a/evaluation_time_dependent.py
+++ /dev/null
@@ -1,322 +0,0 @@
-from __future__ import annotations
-
-import math
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
-
-import numpy as np
-import pandas as pd
-import torch
-
-from utils import (
-    DAYS_PER_YEAR,
-    multi_hot_ever_within_horizon,
-    multi_hot_selected_causes_within_horizon,
-    select_context_indices,
-)
-
-
-def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
-    """Compute ROC AUC for binary labels with tie-aware ranking.
-
-    Returns NaN if y_true has no positives or no negatives.
-
-    Uses the Mann–Whitney U statistic with average ranks for ties.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-
-    n = y_true.size
-    if n == 0:
-        return float("nan")
-
-    n_pos = int(y_true.sum())
-    n_neg = n - n_pos
-    if n_pos == 0 or n_neg == 0:
-        return float("nan")
-
-    # Rank scores ascending, average ranks for ties.
-    order = np.argsort(y_score, kind="mergesort")
-    sorted_scores = y_score[order]
-
-    ranks = np.empty(n, dtype=float)
-    i = 0
-    # 1-based ranks
-    while i < n:
-        j = i + 1
-        while j < n and sorted_scores[j] == sorted_scores[i]:
-            j += 1
-        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
-        ranks[order[i:j]] = avg_rank
-        i = j
-
-    sum_ranks_pos = float(ranks[y_true].sum())
-    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
-    return float(u / (n_pos * n_neg))
-
-
-def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
-    """Average precision (area under PR curve using step-wise interpolation).
-
-    Returns NaN if no positives.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-
-    n = y_true.size
-    if n == 0:
-        return float("nan")
-
-    n_pos = int(y_true.sum())
-    if n_pos == 0:
-        return float("nan")
-
-    order = np.argsort(-y_score, kind="mergesort")
-    y = y_true[order]
-
-    tp = np.cumsum(y).astype(float)
-    fp = np.cumsum(~y).astype(float)
-    precision = tp / np.maximum(tp + fp, 1.0)
-    recall = tp / n_pos
-
-    # AP = sum over each positive of precision at that point / n_pos
-    # (equivalent to ∑ Δrecall * precision)
-    ap = float(np.sum(precision[y]) / n_pos)
-    # handle potential tiny numerical overshoots
-    return float(max(0.0, min(1.0, ap)))
-
-
-def _precision_recall_at_k_percent(
-    y_true: np.ndarray,
-    y_score: np.ndarray,
-    k_percent: float,
-) -> Tuple[float, float]:
-    """Precision@K% and Recall@K% for binary labels.
-
-    Returns (precision, recall). Returns NaN for recall if no positives.
-    Returns NaN for precision if k leads to 0 selected.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-
-    n = y_true.size
-    if n == 0:
-        return float("nan"), float("nan")
-
-    n_pos = int(y_true.sum())
-
-    k = int(math.ceil((float(k_percent) / 100.0) * n))
-    if k <= 0:
-        return float("nan"), float("nan")
-
-    order = np.argsort(-y_score, kind="mergesort")
-    top = order[:k]
-    tp_top = int(y_true[top].sum())
-
-    precision = tp_top / k
-    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
-    return float(precision), float(recall)
-
-
-@dataclass
-class EvalConfig:
-    horizons_years: Sequence[float]
-    offset_years: float = 0.0
-    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
-    cause_ids: Optional[Sequence[int]] = None
-
-
-@torch.no_grad()
-def evaluate_time_dependent(
-    model: torch.nn.Module,
-    head: torch.nn.Module,
-    criterion,
-    dataloader: torch.utils.data.DataLoader,
-    n_disease: int,
-    cfg: EvalConfig,
-    device: str | torch.device,
-) -> pd.DataFrame:
-    """Evaluate time-dependent metrics per cause and per horizon.
-
-    Assumptions:
-    - time_seq is in days
-    - horizons_years and the loss CIF times are in years
-    - disease token ids in event_seq are >= 2 and map to cause_id = token_id - 2
-
-    Returns:
-        DataFrame with columns:
-        cause_id, horizon_tau, topk_percent, n_samples, n_positives, auc, auprc,
-        recall_at_K, precision_at_K, brier_score
-    """
-    device = torch.device(device)
-    model.eval()
-    head.eval()
-
-    horizons_years = [float(x) for x in cfg.horizons_years]
-    if len(horizons_years) == 0:
-        raise ValueError("cfg.horizons_years must be non-empty")
-
-    topk_percents = [float(x) for x in cfg.topk_percents]
-    if len(topk_percents) == 0:
-        raise ValueError("cfg.topk_percents must be non-empty")
-    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
-        raise ValueError(
-            f"All topk_percents must be in (0,100]; got {topk_percents}")
-
-    taus_tensor = torch.tensor(
-        horizons_years, device=device, dtype=torch.float32)
-
-    if cfg.cause_ids is None:
-        cause_ids = None
-        n_causes_eval = int(n_disease)
-    else:
-        cause_ids = torch.tensor(
-            list(cfg.cause_ids), dtype=torch.long, device=device)
-        n_causes_eval = int(cause_ids.numel())
-
-    # Accumulate per horizon
-    y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
-    y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
-
-    for batch in dataloader:
-        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
-        event_seq = event_seq.to(device)
-        time_seq = time_seq.to(device)
-        cont_feats = cont_feats.to(device)
-        cate_feats = cate_feats.to(device)
-        sexes = sexes.to(device)
-
-        h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)  # (B,L,D)
-
-        # Select a single fixed context per sample for this batch.
-        # Horizon-specific eligibility is derived from this context (do not re-select per horizon).
-        keep0, t_ctx, t_ctx_time = select_context_indices(
-            event_seq=event_seq,
-            time_seq=time_seq,
-            offset_years=float(cfg.offset_years),
-            tau_years=0.0,
-        )
-
-        if not keep0.any():
-            continue
-
-        b = torch.arange(event_seq.size(0), device=device)
-        c = h[b, t_ctx]  # (B,D)
-        logits = head(c)
-
-        # CIFs for all horizons at once
-        cifs_all = criterion.calculate_cifs(
-            logits, taus=taus_tensor)  # (B,K,T) or (B,K)
-        if cifs_all.ndim != 3:
-            raise ValueError(
-                f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
-            )
-
-        # Follow-up end time per sample = time at last valid token.
-        valid = event_seq != 0
-        lengths = valid.sum(dim=1)
-        last_idx = torch.clamp(lengths - 1, min=0)
-        followup_end_time = time_seq[b, last_idx]
-
-        for h_idx, tau_y in enumerate(horizons_years):
-            # Horizon-specific eligibility without reselecting context:
-            # keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau)
-            keep_tau = keep0 & (
-                followup_end_time >= (
-                    t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
-            )
-            if not keep_tau.any():
-                continue
-
-            if cause_ids is None:
-                y = multi_hot_ever_within_horizon(
-                    event_seq=event_seq,
-                    time_seq=time_seq,
-                    t_ctx=t_ctx,
-                    tau_years=float(tau_y),
-                    n_disease=n_disease,
-                )
-                y = y[keep_tau]
-                preds = cifs_all[keep_tau, :, h_idx]
-            else:
-                y = multi_hot_selected_causes_within_horizon(
-                    event_seq=event_seq,
-                    time_seq=time_seq,
-                    t_ctx=t_ctx,
-                    tau_years=float(tau_y),
-                    cause_ids=cause_ids,
-                    n_disease=n_disease,
-                )
-                y = y[keep_tau]
-                preds = cifs_all[keep_tau, :, h_idx].index_select(
-                    dim=1, index=cause_ids)
-
-            y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
-            y_pred_by_h[h_idx].append(
-                preds.detach().to(torch.float32).cpu().numpy())
-
-    rows: List[Dict[str, float | int]] = []
-
-    for h_idx, tau_y in enumerate(horizons_years):
-        if len(y_true_by_h[h_idx]) == 0:
-            # No eligible samples for this horizon.
-            for k in range(n_causes_eval):
-                cause_id = int(k) if cause_ids is None else int(
-                    cfg.cause_ids[k])
-                for k_percent in topk_percents:
-                    rows.append(
-                        dict(
-                            cause_id=cause_id,
-                            horizon_tau=float(tau_y),
-                            topk_percent=float(k_percent),
-                            n_samples=0,
-                            n_positives=0,
-                            auc=float("nan"),
-                            auprc=float("nan"),
-                            recall_at_K=float("nan"),
-                            precision_at_K=float("nan"),
-                            brier_score=float("nan"),
-                        )
-                    )
-            continue
-
-        y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
-        y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
-
-        if y_true.shape != y_pred.shape:
-            raise ValueError(
-                f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}"
-            )
-
-        n_samples = int(y_true.shape[0])
-
-        for k in range(n_causes_eval):
-            yk = y_true[:, k]
-            pk = y_pred[:, k]
-            n_pos = int(yk.sum())
-
-            auc = _binary_roc_auc(yk, pk)
-            auprc = _average_precision(yk, pk)
-            brier = float(np.mean((yk.astype(float) - pk.astype(float))
-                                  ** 2)) if n_samples > 0 else float("nan")
-
-            cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
-            for k_percent in topk_percents:
-                precision_k, recall_k = _precision_recall_at_k_percent(
-                    yk, pk, float(k_percent))
-                rows.append(
-                    dict(
-                        cause_id=cause_id,
-                        horizon_tau=float(tau_y),
-                        topk_percent=float(k_percent),
-                        n_samples=n_samples,
-                        n_positives=n_pos,
-                        auc=float(auc),
-                        auprc=float(auprc),
-                        recall_at_K=float(recall_k),
-                        precision_at_K=float(precision_k),
-                        brier_score=float(brier),
-                    )
-                )
-
-    return pd.DataFrame(rows)
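
Eligibility in the new sampler (added to utils.py below) can be traced on a
single toy individual; the numbers are chosen by hand for illustration, and
token times are ages in days:

    import torch

    from utils import DAYS_PER_YEAR, sample_context_in_fixed_age_bin

    ages_y = torch.tensor([[40.0, 47.0, 50.0, 55.0]])
    event_seq = torch.tensor([[3, 4, 5, 6]])  # nonzero = valid tokens
    time_seq = ages_y * DAYS_PER_YEAR

    keep, t_ctx = sample_context_in_fixed_age_bin(
        event_seq=event_seq,
        time_seq=time_seq,
        tau_years=5.0,
        age_bin=(45.0, 50.0),  # half-open: age 50.0 itself is excluded
        seed=0,
    )
    # Only the age-47 token qualifies: it lies in [45, 50) and satisfies
    # 47 + 5 <= 55 (follow-up ends at the last valid token). Hence
    # keep == [True] and t_ctx == [1], independent of the seed here.
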
diff --git a/utils.py b/utils.py
index 25b6d3f..0aded3b 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,79 @@ from typing import Tuple
 
 DAYS_PER_YEAR = 365.25
 
+
+def sample_context_in_fixed_age_bin(
+    event_seq: torch.Tensor,
+    time_seq: torch.Tensor,
+    tau_years: float,
+    age_bin: Tuple[float, float],
+    seed: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Sample one context token per individual within a fixed age bin.
+
+    Delphi-2M semantics for a specific (tau, age_bin):
+    - Token times are interpreted as age in *days* (converted to years).
+    - Follow-up end time is the last valid token time per individual.
+    - A token index j is eligible iff:
+        (token is valid)
+        AND (age_years in [age_low, age_high))
+        AND (time_seq[i, j] + tau_days <= followup_end_time[i])
+    - For each individual, randomly select exactly one eligible token in this bin.
+
+    Args:
+        event_seq: (B, L) token ids, 0 is padding.
+        time_seq: (B, L) token times in days.
+        tau_years: horizon length in years.
+        age_bin: (low, high) bounds in years, interpreted as [low, high).
+        seed: RNG seed for deterministic sampling.
+
+    Returns:
+        keep: (B,) bool, True if a context was sampled for this bin.
+        t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0).
+    """
+    low, high = float(age_bin[0]), float(age_bin[1])
+    if not (high > low):
+        raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")
+
+    device = event_seq.device
+    B, _ = event_seq.shape
+
+    valid = event_seq != 0
+    lengths = valid.sum(dim=1)
+    last_idx = torch.clamp(lengths - 1, min=0)
+    b = torch.arange(B, device=device)
+    followup_end_time = time_seq[b, last_idx]  # (B,)
+
+    tau_days = float(tau_years) * DAYS_PER_YEAR
+    age_years = time_seq / DAYS_PER_YEAR
+
+    in_bin = (age_years >= low) & (age_years < high)
+    eligible = valid & in_bin & (
+        (time_seq + tau_days) <= followup_end_time.unsqueeze(1))
+
+    keep = torch.zeros((B,), dtype=torch.bool, device=device)
+    t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
+
+    gen = torch.Generator(device="cpu")
+    gen.manual_seed(int(seed))
+
+    for i in range(B):
+        m = eligible[i]
+        if not m.any():
+            continue
+
+        idxs = m.nonzero(as_tuple=False).view(-1).cpu()
+        chosen_idx_pos = int(
+            torch.randint(low=0, high=int(idxs.numel()),
+                          size=(1,), generator=gen).item()
+        )
+        chosen_t = int(idxs[chosen_idx_pos].item())
+
+        keep[i] = True
+        t_ctx[i] = chosen_t
+
+    return keep, t_ctx
+
+
 def select_context_indices(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,