Add evaluation scripts for age-bin time-dependent metrics and remove obsolete evaluation_time_dependent.py

2026-01-16 16:13:31 +08:00
parent 502ddd153b
commit 90dffc3211
4 changed files with 597 additions and 349 deletions

View File

@@ -4,13 +4,13 @@ import argparse
 import json
 import math
 import os
-from typing import List, Sequence
+from typing import List, Sequence, Tuple

 import torch
 from torch.utils.data import DataLoader, random_split

 from dataset import HealthDataset, health_collate_fn
-from evaluation_time_dependent import EvalConfig, evaluate_time_dependent
+from evaluation_age_time_dependent import EvalAgeConfig, evaluate_time_dependent_age_bins
 from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
 from model import DelphiFork, SapDelphi, SimpleHead
@@ -25,6 +25,20 @@ def _parse_floats(items: Sequence[str]) -> List[float]:
     return out


+def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
+    vals = _parse_floats(edges)
+    if len(vals) < 2:
+        raise ValueError("--age_bin_edges must have at least 2 values")
+    for i in range(1, len(vals)):
+        if not (vals[i] > vals[i - 1]):
+            raise ValueError("--age_bin_edges must be strictly increasing")
+    return vals
+
+
+def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
+    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]
+
+
 def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
     if loss_type == "exponential":
         criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
@@ -90,44 +104,48 @@ def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
 def main() -> None:
     parser = argparse.ArgumentParser(
-        description="Time-dependent evaluation for DeepHealth")
+        description="Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and aggregated CSVs; aggregated includes agg_type={macro,weighted})")
     parser.add_argument(
         "--run_dir",
         type=str,
         required=True,
         help="Training run directory (contains best_model.pt and train_config.json)",
     )
-    parser.add_argument("--data_prefix", type=str, default=None,
-                        help="Dataset prefix (overrides config if provided)")
+    parser.add_argument("--data_prefix", type=str, default=None)
     parser.add_argument("--split", type=str,
                         choices=["train", "val", "test", "all"], default="val")
     parser.add_argument("--horizons", type=str, nargs="+",
-                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="One or more horizons (years)")
-    parser.add_argument("--offset_years", type=float, default=0.0,
-                        help="Context selection offset (years before follow-up end)")
+                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"])
+    parser.add_argument(
+        "--age_bin_edges",
+        type=str,
+        nargs="+",
+        default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
+        help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).",
+    )
     parser.add_argument(
         "--topk_percent",
         type=float,
         nargs="+",
         default=[1, 5, 10, 20, 50],
-        help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)",
+        help="One or more K%% values for recall/precision@K%%",
     )
+    parser.add_argument("--n_mc", type=int, default=5)
+    parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--device", type=str,
                         default="cuda" if torch.cuda.is_available() else "cpu")
     parser.add_argument("--batch_size", type=int, default=256)
-    parser.add_argument("--num_workers", type=int,
-                        default=0, help="Keep 0 on Windows")
-    parser.add_argument("--out_csv", type=str, default=None,
-                        help="Optional output CSV path")
+    parser.add_argument("--num_workers", type=int, default=0)
+    parser.add_argument("--out_prefix", type=str,
+                        default=None, help="Output prefix for CSVs")
     args = parser.parse_args()

     ckpt_path = os.path.join(args.run_dir, "best_model.pt")
     cfg_path = os.path.join(args.run_dir, "train_config.json")
     if not os.path.exists(ckpt_path):
         raise SystemExit(f"Missing checkpoint: {ckpt_path}")
     if not os.path.exists(cfg_path):
@@ -139,24 +157,23 @@ def main() -> None:
     data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
         "data_prefix", "ukb")

-    # Match training covariate selection.
     full_cov = bool(cfg.get("full_cov", False))
     cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
     dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)

-    # Recreate the same split scheme as train.py
     train_ratio = float(cfg.get("train_ratio", 0.7))
     val_ratio = float(cfg.get("val_ratio", 0.15))
-    seed = int(cfg.get("random_seed", 42))
+    seed_split = int(cfg.get("random_seed", 42))
     n_total = len(dataset)
     n_train = int(n_total * train_ratio)
     n_val = int(n_total * val_ratio)
     n_test = n_total - n_train - n_val
     train_ds, val_ds, test_ds = random_split(
         dataset,
         [n_train, n_val, n_test],
-        generator=torch.Generator().manual_seed(seed),
+        generator=torch.Generator().manual_seed(seed_split),
     )

     if args.split == "train":
@@ -203,14 +220,19 @@ def main() -> None:
     head.to(device)
     criterion.to(device)

-    eval_cfg = EvalConfig(
+    age_edges = _parse_age_bin_edges(args.age_bin_edges)
+    age_bins = _edges_to_bins(age_edges)
+    eval_cfg = EvalAgeConfig(
         horizons_years=_parse_floats(args.horizons),
-        offset_years=float(args.offset_years),
+        age_bins=age_bins,
         topk_percents=[float(x) for x in args.topk_percent],
+        n_mc=int(args.n_mc),
+        seed=int(args.seed),
         cause_ids=None,
     )
-    df = evaluate_time_dependent(
+    df_by_bin, df_agg = evaluate_time_dependent_age_bins(
         model=model,
         head=head,
         criterion=criterion,
@@ -220,14 +242,20 @@ def main() -> None:
         device=device,
     )

-    if args.out_csv is None:
-        out_csv = os.path.join(
-            args.run_dir, f"time_dependent_metrics_{args.split}.csv")
+    if args.out_prefix is None:
+        out_prefix = os.path.join(
+            args.run_dir, f"age_bin_time_dependent_{args.split}")
     else:
-        out_csv = args.out_csv
-    df.to_csv(out_csv, index=False)
-    print(f"Wrote: {out_csv}")
+        out_prefix = args.out_prefix
+    out_bin = out_prefix + "_by_bin.csv"
+    out_agg = out_prefix + "_agg.csv"
+    df_by_bin.to_csv(out_bin, index=False)
+    df_agg.to_csv(out_agg, index=False)
+    print(f"Wrote: {out_bin}")
+    print(f"Wrote: {out_agg}")


 if __name__ == "__main__":
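
For orientation, the two new helpers turn the --age_bin_edges flag into half-open bins as sketched below; this is an illustrative snippet assuming _parse_age_bin_edges and _edges_to_bins from the diff above are in scope, not part of the commit:

    edges = _parse_age_bin_edges(["40", "45", "50"])  # [40.0, 45.0, 50.0]
    bins = _edges_to_bins(edges)                      # [(40.0, 45.0), (45.0, 50.0)]
    # Bins are half-open: an age of exactly 45 falls in (45.0, 50.0), not (40.0, 45.0).
    # Non-monotone edges fail fast:
    # _parse_age_bin_edges(["50", "40"])  # ValueError: must be strictly increasing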

View File: evaluation_age_time_dependent.py (new file)

@@ -0,0 +1,469 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

try:
    from tqdm import tqdm
except Exception:  # pragma: no cover
    def tqdm(x, **kwargs):
        return x

from utils import (
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    sample_context_in_fixed_age_bin,
)


def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Returns NaN if y_true has no positives or no negatives.
    Uses the Mann-Whitney U statistic with average ranks for ties.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # Rank scores ascending, average ranks for ties.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    ranks = np.empty(n, dtype=float)
    i = 0
    # 1-based ranks
    while i < n:
        j = i + 1
        while j < n and sorted_scores[j] == sorted_scores[i]:
            j += 1
        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
        ranks[order[i:j]] = avg_rank
        i = j
    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))
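
A worked example of the tie-aware ranking (illustrative, assuming the helper above and numpy are in scope): the single tied (positive, negative) pair contributes half a correctly ordered pair, so the result is 0.875 rather than 0.75 or 1.0:

    import numpy as np

    y = np.array([1, 0, 1, 0])
    s = np.array([0.9, 0.8, 0.8, 0.1])
    # (pos, neg) pairs: three ordered correctly, one tied (0.8 vs 0.8) counts 0.5,
    # so AUC = (3 + 0.5) / 4 = 0.875.
    assert abs(_binary_roc_auc(y, s) - 0.875) < 1e-12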
def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under PR curve using step-wise interpolation).

    Returns NaN if no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    y = y_true[order]
    tp = np.cumsum(y).astype(float)
    fp = np.cumsum(~y).astype(float)
    precision = tp / np.maximum(tp + fp, 1.0)
    # AP = sum over each positive of precision at that point / n_pos
    # (equivalent to ∑ Δrecall * precision)
    ap = float(np.sum(precision[y]) / n_pos)
    # handle potential tiny numerical overshoots
    return float(max(0.0, min(1.0, ap)))
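
For intuition, a small worked case (illustrative, reusing the function above): with positives ranked 1st and 3rd, precision at those hits is 1/1 and 2/3:

    import numpy as np

    y = np.array([1, 0, 1])
    s = np.array([0.9, 0.8, 0.7])
    # AP = (1.0 + 2/3) / 2 = 5/6
    assert abs(_average_precision(y, s) - 5.0 / 6.0) < 1e-12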
def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    Returns (precision, recall). Returns NaN for recall if no positives.
    Returns NaN for precision if k leads to 0 selected.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")
    n_pos = int(y_true.sum())
    k = int(math.ceil((float(k_percent) / 100.0) * n))
    if k <= 0:
        return float("nan"), float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    top = order[:k]
    tp_top = int(y_true[top].sum())
    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)
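
Example with n=10 and K=20% (illustrative, values invented): k = ceil(0.2 * 10) = 2, so the two highest-scored samples are selected:

    import numpy as np

    y = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0, 0])
    s = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05])
    p, r = _precision_recall_at_k_percent(y, s, 20.0)
    # Top 2 contain 1 of 3 positives: precision = 1/2, recall = 1/3.
    assert p == 0.5 and abs(r - 1.0 / 3.0) < 1e-12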
@dataclass
class EvalAgeConfig:
    horizons_years: Sequence[float]
    age_bins: Sequence[Tuple[float, float]]
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    n_mc: int = 5
    seed: int = 0
    cause_ids: Optional[Sequence[int]] = None
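
A typical configuration might look like the sketch below; the horizon, bin, and top-K values are illustrative, not defaults from the repo:

    cfg = EvalAgeConfig(
        horizons_years=[1.0, 5.0],
        age_bins=[(40.0, 50.0), (50.0, 60.0), (60.0, 70.0), (70.0, 80.0)],
        topk_percents=(5.0, 10.0),
        n_mc=3,
        seed=0,
        cause_ids=None,  # None = evaluate all n_disease causes
    )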
@torch.no_grad()
def evaluate_time_dependent_age_bins(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalAgeConfig,
    device: str | torch.device,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Delphi-2M-style age-bin evaluation with strict horizon alignment.

    Semantics (strict): for each (MC, horizon tau, age bin) we independently:
      - build the eligible token set within that bin
      - enforce follow-up coverage: t_ctx + tau <= t_end
      - randomly sample exactly one token per individual within the bin (de-dup)
      - recompute context representations and predictions for that (tau, bin)

    Returns:
        df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id)
        df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted}
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")
    age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
    if len(age_bins) == 0:
        raise ValueError("cfg.age_bins must be non-empty")
    for (a, b) in age_bins:
        if not (b > a):
            raise ValueError(
                f"age_bins must be (low, high) with high>low; got {(a, b)}")
    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0,100]; got {topk_percents}")
    if int(cfg.n_mc) <= 0:
        raise ValueError("cfg.n_mc must be >= 1")

    if cfg.cause_ids is None:
        cause_ids = None
        n_causes_eval = int(n_disease)
    else:
        cause_ids = torch.tensor(
            list(cfg.cause_ids), dtype=torch.long, device=device)
        n_causes_eval = int(cause_ids.numel())

    # Storage: (mc, h, bin) -> list of arrays
    y_true: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
        for _ in range(int(cfg.n_mc))
    ]
    y_pred: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
        for _ in range(int(cfg.n_mc))
    ]

    for mc_idx in range(int(cfg.n_mc)):
        # tqdm over batches; include MC idx in description.
        for batch_idx, batch in enumerate(
            tqdm(dataloader,
                 desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
        ):
            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
            event_seq = event_seq.to(device)
            time_seq = time_seq.to(device)
            cont_feats = cont_feats.to(device)
            cate_feats = cate_feats.to(device)
            sexes = sexes.to(device)

            B = int(event_seq.size(0))
            b = torch.arange(B, device=device)

            for tau_idx, tau_y in enumerate(horizons_years):
                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                    # Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
                    seed = (
                        int(cfg.seed)
                        + (100_000 * int(mc_idx))
                        + (1_000 * int(tau_idx))
                        + (10 * int(bin_idx))
                        + int(batch_idx)
                    )
                    keep, t_ctx = sample_context_in_fixed_age_bin(
                        event_seq=event_seq,
                        time_seq=time_seq,
                        tau_years=float(tau_y),
                        age_bin=(float(a_lo), float(a_hi)),
                        seed=seed,
                    )
                    if not keep.any():
                        continue

                    # Strict bin-specific prediction: recompute representations and logits per (tau, bin).
                    h = model(event_seq, time_seq, sexes,
                              cont_feats, cate_feats)  # (B,L,D)
                    c = h[b, t_ctx]
                    logits = head(c)
                    cifs = criterion.calculate_cifs(
                        logits, taus=torch.tensor(float(tau_y), device=device)
                    )
                    if cifs.ndim != 2:
                        raise ValueError(
                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
                            f"got shape={tuple(cifs.shape)}"
                        )

                    if cause_ids is None:
                        y = multi_hot_ever_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            n_disease=n_disease,
                        )
                        preds = cifs
                    else:
                        y = multi_hot_selected_causes_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            cause_ids=cause_ids,
                            n_disease=n_disease,
                        )
                        preds = cifs.index_select(dim=1, index=cause_ids)

                    y_true[mc_idx][tau_idx][bin_idx].append(
                        y[keep].detach().to(torch.bool).cpu().numpy()
                    )
                    y_pred[mc_idx][tau_idx][bin_idx].append(
                        preds[keep].detach().to(torch.float32).cpu().numpy()
                    )

    rows_by_bin: List[Dict[str, float | int]] = []
    for mc_idx in range(int(cfg.n_mc)):
        for h_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
                    # No samples in this bin for this (mc, tau)
                    for cause_k in range(n_causes_eval):
                        cause_id = int(cause_k) if cause_ids is None else int(
                            cfg.cause_ids[cause_k])
                        for k_percent in topk_percents:
                            rows_by_bin.append(
                                dict(
                                    mc_idx=mc_idx,
                                    age_bin_id=bin_idx,
                                    age_bin_low=float(a_lo),
                                    age_bin_high=float(a_hi),
                                    horizon_tau=float(tau_y),
                                    topk_percent=float(k_percent),
                                    cause_id=cause_id,
                                    n_samples=0,
                                    n_positives=0,
                                    auc=float("nan"),
                                    auprc=float("nan"),
                                    recall_at_K=float("nan"),
                                    precision_at_K=float("nan"),
                                    brier_score=float("nan"),
                                )
                            )
                    continue
                yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
                pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
                if yb.shape != pb.shape:
                    raise ValueError(
                        f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
                    )
                n_samples = int(yb.shape[0])
                for cause_k in range(n_causes_eval):
                    yk = yb[:, cause_k]
                    pk = pb[:, cause_k]
                    n_pos = int(yk.sum())
                    auc = _binary_roc_auc(yk, pk)
                    auprc = _average_precision(yk, pk)
                    brier = float(np.mean(
                        (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan")
                    cause_id = int(cause_k) if cause_ids is None else int(
                        cfg.cause_ids[cause_k])
                    for k_percent in topk_percents:
                        precision_k, recall_k = _precision_recall_at_k_percent(
                            yk, pk, float(k_percent))
                        rows_by_bin.append(
                            dict(
                                mc_idx=mc_idx,
                                age_bin_id=bin_idx,
                                age_bin_low=float(a_lo),
                                age_bin_high=float(a_hi),
                                horizon_tau=float(tau_y),
                                topk_percent=float(k_percent),
                                cause_id=cause_id,
                                n_samples=n_samples,
                                n_positives=n_pos,
                                auc=float(auc),
                                auprc=float(auprc),
                                recall_at_K=float(recall_k),
                                precision_at_K=float(precision_k),
                                brier_score=float(brier),
                            )
                        )

    df_by_bin = pd.DataFrame(rows_by_bin)

    def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
        g = group[group["n_samples"] > 0]
        if len(g) == 0:
            return pd.Series(
                dict(
                    n_bins_used=0,
                    n_samples_total=0,
                    n_positives_total=0,
                    auc=float("nan"),
                    auprc=float("nan"),
                    recall_at_K=float("nan"),
                    precision_at_K=float("nan"),
                    brier_score=float("nan"),
                )
            )
        n_bins_used = int(g["age_bin_id"].nunique())
        n_samples_total = int(g["n_samples"].sum())
        n_positives_total = int(g["n_positives"].sum())
        if not weighted:
            return pd.Series(
                dict(
                    n_bins_used=n_bins_used,
                    n_samples_total=n_samples_total,
                    n_positives_total=n_positives_total,
                    auc=float(g["auc"].mean()),
                    auprc=float(g["auprc"].mean()),
                    recall_at_K=float(g["recall_at_K"].mean()),
                    precision_at_K=float(g["precision_at_K"].mean()),
                    brier_score=float(g["brier_score"].mean()),
                )
            )
        w = g["n_samples"].to_numpy(dtype=float)
        w_sum = float(w.sum())
        if w_sum <= 0.0:
            return pd.Series(
                dict(
                    n_bins_used=n_bins_used,
                    n_samples_total=n_samples_total,
                    n_positives_total=n_positives_total,
                    auc=float("nan"),
                    auprc=float("nan"),
                    recall_at_K=float("nan"),
                    precision_at_K=float("nan"),
                    brier_score=float("nan"),
                )
            )

        def _wavg(col: str) -> float:
            return float(np.average(g[col].to_numpy(dtype=float), weights=w))

        return pd.Series(
            dict(
                n_bins_used=n_bins_used,
                n_samples_total=n_samples_total,
                n_positives_total=n_positives_total,
                auc=_wavg("auc"),
                auprc=_wavg("auprc"),
                recall_at_K=_wavg("recall_at_K"),
                precision_at_K=_wavg("precision_at_K"),
                brier_score=_wavg("brier_score"),
            )
        )

    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
    df_mc_macro = (
        df_by_bin.groupby(group_keys)
        .apply(lambda g: _bin_aggregate(g, weighted=False))
        .reset_index()
    )
    df_mc_macro["agg_type"] = "macro"
    df_mc_weighted = (
        df_by_bin.groupby(group_keys)
        .apply(lambda g: _bin_aggregate(g, weighted=True))
        .reset_index()
    )
    df_mc_weighted["agg_type"] = "weighted"
    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)

    # Then average over MC repetitions.
    df_agg = (
        df_mc_binagg.groupby(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
        )
        .agg(
            n_mc=("mc_idx", "nunique"),
            n_bins_used_mean=("n_bins_used", "mean"),
            n_samples_total_mean=("n_samples_total", "mean"),
            n_positives_total_mean=("n_positives_total", "mean"),
            auc_mean=("auc", "mean"),
            auc_std=("auc", "std"),
            auprc_mean=("auprc", "mean"),
            auprc_std=("auprc", "std"),
            recall_at_K_mean=("recall_at_K", "mean"),
            recall_at_K_std=("recall_at_K", "std"),
            precision_at_K_mean=("precision_at_K", "mean"),
            precision_at_K_std=("precision_at_K", "std"),
            brier_score_mean=("brier_score", "mean"),
            brier_score_std=("brier_score", "std"),
        )
        .sort_values(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
            ignore_index=True,
        )
    )
    return df_by_bin, df_agg
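
End-to-end usage, plus the difference between the two aggregation modes, as a sketch; model, head, criterion, loader, and n_disease=1256 are stand-ins for objects loaded from a real run directory, and the macro/weighted numbers in the comment are invented for illustration. Macro gives every non-empty age bin equal weight; weighted weights bins by n_samples, so sparse bins pull the weighted mean less:

    cfg = EvalAgeConfig(
        horizons_years=[1.0, 5.0],
        age_bins=[(40.0, 50.0), (50.0, 60.0)],
        n_mc=2,
        seed=0,
    )
    df_by_bin, df_agg = evaluate_time_dependent_age_bins(
        model=model, head=head, criterion=criterion,
        dataloader=loader, n_disease=1256, cfg=cfg, device="cpu",
    )
    # Aggregation on a toy (horizon, cause) slice:
    #   bin 0: auc=0.80, n_samples=900; bin 1: auc=0.60, n_samples=100
    #   macro    = (0.80 + 0.60) / 2            = 0.70  (bins weighted equally)
    #   weighted = (0.80*900 + 0.60*100) / 1000 = 0.78  (bins weighted by n_samples)
    top = df_agg[df_agg["agg_type"] == "weighted"].nlargest(5, "auc_mean")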

View File: evaluation_time_dependent.py (deleted)

@@ -1,322 +0,0 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

from utils import (
    DAYS_PER_YEAR,
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    select_context_indices,
)


def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Returns NaN if y_true has no positives or no negatives.
    Uses the Mann-Whitney U statistic with average ranks for ties.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # Rank scores ascending, average ranks for ties.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    ranks = np.empty(n, dtype=float)
    i = 0
    # 1-based ranks
    while i < n:
        j = i + 1
        while j < n and sorted_scores[j] == sorted_scores[i]:
            j += 1
        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
        ranks[order[i:j]] = avg_rank
        i = j
    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))


def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under PR curve using step-wise interpolation).

    Returns NaN if no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    y = y_true[order]
    tp = np.cumsum(y).astype(float)
    fp = np.cumsum(~y).astype(float)
    precision = tp / np.maximum(tp + fp, 1.0)
    recall = tp / n_pos
    # AP = sum over each positive of precision at that point / n_pos
    # (equivalent to ∑ Δrecall * precision)
    ap = float(np.sum(precision[y]) / n_pos)
    # handle potential tiny numerical overshoots
    return float(max(0.0, min(1.0, ap)))


def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    Returns (precision, recall). Returns NaN for recall if no positives.
    Returns NaN for precision if k leads to 0 selected.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")
    n_pos = int(y_true.sum())
    k = int(math.ceil((float(k_percent) / 100.0) * n))
    if k <= 0:
        return float("nan"), float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    top = order[:k]
    tp_top = int(y_true[top].sum())
    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)


@dataclass
class EvalConfig:
    horizons_years: Sequence[float]
    offset_years: float = 0.0
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    cause_ids: Optional[Sequence[int]] = None


@torch.no_grad()
def evaluate_time_dependent(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalConfig,
    device: str | torch.device,
) -> pd.DataFrame:
    """Evaluate time-dependent metrics per cause and per horizon.

    Assumptions:
      - time_seq is in days
      - horizons_years and the loss CIF times are in years
      - disease token ids in event_seq are >= 2 and map to cause_id = token_id - 2

    Returns:
        DataFrame with columns:
            cause_id, horizon_tau, topk_percent, n_samples, n_positives, auc, auprc,
            recall_at_K, precision_at_K, brier_score
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")
    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0,100]; got {topk_percents}")

    taus_tensor = torch.tensor(
        horizons_years, device=device, dtype=torch.float32)

    if cfg.cause_ids is None:
        cause_ids = None
        n_causes_eval = int(n_disease)
    else:
        cause_ids = torch.tensor(
            list(cfg.cause_ids), dtype=torch.long, device=device)
        n_causes_eval = int(cause_ids.numel())

    # Accumulate per horizon
    y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
    y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]

    for batch in dataloader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)  # (B,L,D)

        # Select a single fixed context per sample for this batch.
        # Horizon-specific eligibility is derived from this context (do not re-select per horizon).
        keep0, t_ctx, t_ctx_time = select_context_indices(
            event_seq=event_seq,
            time_seq=time_seq,
            offset_years=float(cfg.offset_years),
            tau_years=0.0,
        )
        if not keep0.any():
            continue

        b = torch.arange(event_seq.size(0), device=device)
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)

        # CIFs for all horizons at once
        cifs_all = criterion.calculate_cifs(
            logits, taus=taus_tensor)  # (B,K,T) or (B,K)
        if cifs_all.ndim != 3:
            raise ValueError(
                f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
            )

        # Follow-up end time per sample = time at last valid token.
        valid = event_seq != 0
        lengths = valid.sum(dim=1)
        last_idx = torch.clamp(lengths - 1, min=0)
        followup_end_time = time_seq[b, last_idx]

        for h_idx, tau_y in enumerate(horizons_years):
            # Horizon-specific eligibility without reselecting context:
            # keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau)
            keep_tau = keep0 & (
                followup_end_time >= (
                    t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
            )
            if not keep_tau.any():
                continue
            if cause_ids is None:
                y = multi_hot_ever_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    n_disease=n_disease,
                )
                y = y[keep_tau]
                preds = cifs_all[keep_tau, :, h_idx]
            else:
                y = multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    cause_ids=cause_ids,
                    n_disease=n_disease,
                )
                y = y[keep_tau]
                preds = cifs_all[keep_tau, :, h_idx].index_select(
                    dim=1, index=cause_ids)
            y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
            y_pred_by_h[h_idx].append(
                preds.detach().to(torch.float32).cpu().numpy())

    rows: List[Dict[str, float | int]] = []
    for h_idx, tau_y in enumerate(horizons_years):
        if len(y_true_by_h[h_idx]) == 0:
            # No eligible samples for this horizon.
            for k in range(n_causes_eval):
                cause_id = int(k) if cause_ids is None else int(
                    cfg.cause_ids[k])
                for k_percent in topk_percents:
                    rows.append(
                        dict(
                            cause_id=cause_id,
                            horizon_tau=float(tau_y),
                            topk_percent=float(k_percent),
                            n_samples=0,
                            n_positives=0,
                            auc=float("nan"),
                            auprc=float("nan"),
                            recall_at_K=float("nan"),
                            precision_at_K=float("nan"),
                            brier_score=float("nan"),
                        )
                    )
            continue
        y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
        y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}"
            )
        n_samples = int(y_true.shape[0])
        for k in range(n_causes_eval):
            yk = y_true[:, k]
            pk = y_pred[:, k]
            n_pos = int(yk.sum())
            auc = _binary_roc_auc(yk, pk)
            auprc = _average_precision(yk, pk)
            brier = float(np.mean((yk.astype(float) - pk.astype(float))
                                  ** 2)) if n_samples > 0 else float("nan")
            cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
            for k_percent in topk_percents:
                precision_k, recall_k = _precision_recall_at_k_percent(
                    yk, pk, float(k_percent))
                rows.append(
                    dict(
                        cause_id=cause_id,
                        horizon_tau=float(tau_y),
                        topk_percent=float(k_percent),
                        n_samples=n_samples,
                        n_positives=n_pos,
                        auc=float(auc),
                        auprc=float(auprc),
                        recall_at_K=float(recall_k),
                        precision_at_K=float(precision_k),
                        brier_score=float(brier),
                    )
                )

    return pd.DataFrame(rows)

View File: utils.py

@@ -4,6 +4,79 @@ from typing import Tuple
 DAYS_PER_YEAR = 365.25


+def sample_context_in_fixed_age_bin(
+    event_seq: torch.Tensor,
+    time_seq: torch.Tensor,
+    tau_years: float,
+    age_bin: Tuple[float, float],
+    seed: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Sample one context token per individual within a fixed age bin.
+
+    Delphi-2M semantics for a specific (tau, age_bin):
+      - Token times are interpreted as age in *days* (converted to years).
+      - Follow-up end time is the last valid token time per individual.
+      - A token index j is eligible iff:
+            (token is valid)
+            AND (age_years in [age_low, age_high))
+            AND (time_seq[i, j] + tau_days <= followup_end_time[i])
+      - For each individual, randomly select exactly one eligible token in this bin.
+
+    Args:
+        event_seq: (B, L) token ids, 0 is padding.
+        time_seq: (B, L) token times in days.
+        tau_years: horizon length in years.
+        age_bin: (low, high) bounds in years, interpreted as [low, high).
+        seed: RNG seed for deterministic sampling.
+
+    Returns:
+        keep: (B,) bool, True if a context was sampled for this bin.
+        t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0).
+    """
+    low, high = float(age_bin[0]), float(age_bin[1])
+    if not (high > low):
+        raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")
+    device = event_seq.device
+    B, _ = event_seq.shape
+
+    valid = event_seq != 0
+    lengths = valid.sum(dim=1)
+    last_idx = torch.clamp(lengths - 1, min=0)
+    b = torch.arange(B, device=device)
+    followup_end_time = time_seq[b, last_idx]  # (B,)
+
+    tau_days = float(tau_years) * DAYS_PER_YEAR
+    age_years = time_seq / DAYS_PER_YEAR
+    in_bin = (age_years >= low) & (age_years < high)
+    eligible = valid & in_bin & (
+        (time_seq + tau_days) <= followup_end_time.unsqueeze(1))
+
+    keep = torch.zeros((B,), dtype=torch.bool, device=device)
+    t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
+
+    gen = torch.Generator(device="cpu")
+    gen.manual_seed(int(seed))
+    for i in range(B):
+        m = eligible[i]
+        if not m.any():
+            continue
+        idxs = m.nonzero(as_tuple=False).view(-1).cpu()
+        chosen_idx_pos = int(
+            torch.randint(low=0, high=int(idxs.numel()),
+                          size=(1,), generator=gen).item()
+        )
+        chosen_t = int(idxs[chosen_idx_pos].item())
+        keep[i] = True
+        t_ctx[i] = chosen_t
+    return keep, t_ctx
+
+
 def select_context_indices(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,
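
A small deterministic check of the eligibility rule above (illustrative, not part of the commit; the import path follows the diff's own `from utils import ...`):

    import torch
    from utils import DAYS_PER_YEAR, sample_context_in_fixed_age_bin

    # One individual with events at ages 40, 45, and 50 years; the last valid
    # token (age 50) defines the end of follow-up, and index 3 is padding.
    event_seq = torch.tensor([[3, 4, 5, 0]])
    time_seq = torch.tensor([[40.0, 45.0, 50.0, 0.0]]) * DAYS_PER_YEAR
    keep, t_ctx = sample_context_in_fixed_age_bin(
        event_seq=event_seq, time_seq=time_seq,
        tau_years=2.0, age_bin=(40.0, 50.0), seed=0,
    )
    # Ages 40 and 45 lie in [40, 50) and leave >= 2 years of follow-up before
    # age 50; age 50 is outside the half-open bin.
    assert bool(keep[0]) and int(t_ctx[0]) in (0, 1)
    # With tau_years=12.0, neither 40+12 nor 45+12 fits before age 50,
    # so keep[0] would be False.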