Remove evaluation_age_time_dependent.py and utils.py files

- Deleted evaluation_age_time_dependent.py, which implemented the Delphi-2M-style age-bin, time-dependent evaluation: per-cause AUROC, AUPRC, Brier score, and precision/recall@K% within fixed age bins, plus macro/weighted aggregation across bins and Monte Carlo repetitions.
- Removed utils.py, which provided helpers for sampling one context token per individual within a fixed age bin and for multi-hot encoding of disease occurrences within a prediction horizon (see the example invocation sketched below).
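
For reference, the deleted evaluation was driven from the command line; the sketch below is reconstructed from the argparse definition in the deleted driver script (the script name is a placeholder, since the first deleted file's name is not shown in this diff, and the flag values are illustrative):

    python evaluate_age_bins.py --run_dir runs/my_run --split test \
        --horizons 1.0 5.0 --age_bin_edges 40 50 60 70 80 \
        --topk_percent 1 5 10 --n_mc 5 --gpus 0,1

This wrote <out_prefix>_by_bin.csv and <out_prefix>_agg.csv, defaulting to age_bin_time_dependent_<split>_*.csv inside --run_dir.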
2026-01-17 09:53:15 +08:00
parent a637beb220
commit c1bba30de4
3 changed files with 0 additions and 1566 deletions


@@ -1,507 +0,0 @@
from __future__ import annotations
import argparse
import json
import math
import os
import multiprocessing as mp
from typing import List, Sequence, Tuple
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split
from dataset import HealthDataset, health_collate_fn
from evaluation_age_time_dependent import (
EvalAgeConfig,
aggregate_age_bin_results,
evaluate_time_dependent_age_bins,
)
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead
def _parse_floats(items: Sequence[str]) -> List[float]:
out: List[float] = []
for x in items:
x = x.strip()
if not x:
continue
out.append(float(x))
return out
def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
vals = _parse_floats(edges)
if len(vals) < 2:
raise ValueError("--age_bin_edges must have at least 2 values")
for i in range(1, len(vals)):
if not (vals[i] > vals[i - 1]):
raise ValueError("--age_bin_edges must be strictly increasing")
return vals
def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]
def _parse_gpus(gpus: str | None) -> List[int]:
if gpus is None:
return []
s = gpus.strip()
if not s:
return []
parts = [p.strip() for p in s.split(",") if p.strip()]
out: List[int] = []
for p in parts:
out.append(int(p))
return out
def _worker_eval_mcs_on_gpu(
queue: "mp.Queue",
*,
run_dir: str,
split: str,
data_prefix_override: str | None,
horizons: List[float],
age_bins: List[Tuple[float, float]],
topk_percents: List[float],
n_mc: int,
seed: int,
batch_size: int,
num_workers: int,
gpu_id: int,
mc_indices: List[int],
out_path: str,
) -> None:
"""Worker process: evaluate a subset of MC indices on a single GPU."""
try:
ckpt_path = os.path.join(run_dir, "best_model.pt")
cfg_path = os.path.join(run_dir, "train_config.json")
with open(cfg_path, "r") as f:
cfg = json.load(f)
data_prefix = (
data_prefix_override
if data_prefix_override is not None
else cfg.get("data_prefix", "ukb")
)
full_cov = bool(cfg.get("full_cov", False))
cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(data_prefix=data_prefix,
covariate_list=cov_list)
train_ratio = float(cfg.get("train_ratio", 0.7))
val_ratio = float(cfg.get("val_ratio", 0.15))
seed_split = int(cfg.get("random_seed", 42))
n_total = len(dataset)
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
train_ds, val_ds, test_ds = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed_split),
)
if split == "train":
ds = train_ds
elif split == "val":
ds = val_ds
elif split == "test":
ds = test_ds
else:
ds = dataset
loader = DataLoader(
ds,
batch_size=int(batch_size),
shuffle=False,
collate_fn=health_collate_fn,
num_workers=int(num_workers),
pin_memory=True,
)
criterion, out_dims = build_criterion_and_out_dims(
loss_type=str(cfg["loss_type"]),
n_disease=int(dataset.n_disease),
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
)
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
device = torch.device(f"cuda:{int(gpu_id)}")
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
if "criterion_state_dict" in checkpoint:
try:
criterion.load_state_dict(
checkpoint["criterion_state_dict"], strict=False)
except Exception:
pass
model.to(device)
head.to(device)
criterion.to(device)
frames: List[pd.DataFrame] = []
for mc_idx in mc_indices:
eval_cfg = EvalAgeConfig(
horizons_years=horizons,
age_bins=age_bins,
topk_percents=topk_percents,
n_mc=1,
seed=int(seed),
cause_ids=None,
)
df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins(
model=model,
head=head,
criterion=criterion,
dataloader=loader,
n_disease=int(dataset.n_disease),
cfg=eval_cfg,
device=device,
mc_offset=int(mc_idx),
)
frames.append(df_by_bin)
df_all = pd.concat(frames, ignore_index=True) if len(
frames) else pd.DataFrame()
df_all = _drop_zero_positives_rows(df_all, "n_positives")
df_all.to_csv(out_path, index=False)
queue.put({"ok": True, "out_path": out_path})
except Exception as e:
queue.put({"ok": False, "error": repr(e)})
def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
if loss_type == "exponential":
criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
out_dims = [n_disease]
return criterion, out_dims
if loss_type == "discrete_time_cif":
criterion = DiscreteTimeCIFNLLLoss(
bin_edges=bin_edges, lambda_reg=lambda_reg)
out_dims = [n_disease + 1, len(bin_edges)]
return criterion, out_dims
if loss_type == "pwe_cif":
pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
if len(pwe_edges) < 2:
raise ValueError(
"pwe_cif requires at least 2 finite bin edges (including 0)")
if float(pwe_edges[0]) != 0.0:
raise ValueError("pwe_cif requires bin_edges[0]==0.0")
criterion = PiecewiseExponentialCIFNLLLoss(
bin_edges=pwe_edges, lambda_reg=lambda_reg)
n_bins = len(pwe_edges) - 1
out_dims = [n_disease, n_bins]
return criterion, out_dims
raise ValueError(f"Unsupported loss_type: {loss_type}")
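# Illustrative head output shapes (assumed values): with n_disease=100 and
# bin_edges=[0.0, 1.0, inf]:
#   "exponential"       -> out_dims = [100]      (one rate per disease)
#   "discrete_time_cif" -> out_dims = [101, 3]   (n_disease + 1, len(bin_edges))
#   "pwe_cif"           -> out_dims = [100, 1]   (finite edges [0.0, 1.0] give one interval)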
def _drop_zero_positives_rows(df: pd.DataFrame, positive_col: str) -> pd.DataFrame:
"""Drop rows where the provided positives column is <= 0.
Intended to reduce CSV size by omitting (cause, horizon, bin) rows that have
no positives, which otherwise yield undefined/NaN metrics.
"""
if df is None or len(df) == 0:
return df
if positive_col not in df.columns:
return df
pos = pd.to_numeric(df[positive_col], errors="coerce")
return df[pos > 0].copy()
def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
if model_type == "delphi_fork":
return DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg["age_encoder"]),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
)
if model_type == "sap_delphi":
return SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg["age_encoder"]),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
pretrained_weights_path=str(
cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
)
raise ValueError(f"Unsupported model_type: {model_type}")
def main() -> None:
parser = argparse.ArgumentParser(
description="Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and aggregated CSVs; aggregated includes agg_type={macro,weighted})")
parser.add_argument(
"--run_dir",
type=str,
required=True,
help="Training run directory (contains best_model.pt and train_config.json)",
)
parser.add_argument("--data_prefix", type=str, default=None)
parser.add_argument("--split", type=str,
choices=["train", "val", "test", "all"], default="val")
parser.add_argument("--horizons", type=str, nargs="+",
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"])
parser.add_argument(
"--age_bin_edges",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).",
)
parser.add_argument(
"--topk_percent",
type=float,
nargs="+",
default=[1, 5, 10, 20, 50],
help="One or more K%% values for recall/precision@K%%",
)
parser.add_argument("--n_mc", type=int, default=5)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--gpus",
type=str,
default=None,
help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). Example: --gpus 0,1,3",
)
parser.add_argument("--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--out_prefix", type=str,
default=None, help="Output prefix for CSVs")
args = parser.parse_args()
ckpt_path = os.path.join(args.run_dir, "best_model.pt")
cfg_path = os.path.join(args.run_dir, "train_config.json")
if not os.path.exists(ckpt_path):
raise SystemExit(f"Missing checkpoint: {ckpt_path}")
if not os.path.exists(cfg_path):
raise SystemExit(f"Missing config: {cfg_path}")
with open(cfg_path, "r") as f:
cfg = json.load(f)
data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
"data_prefix", "ukb")
full_cov = bool(cfg.get("full_cov", False))
cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
train_ratio = float(cfg.get("train_ratio", 0.7))
val_ratio = float(cfg.get("val_ratio", 0.15))
seed_split = int(cfg.get("random_seed", 42))
n_total = len(dataset)
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
train_ds, val_ds, test_ds = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed_split),
)
if args.split == "train":
ds = train_ds
elif args.split == "val":
ds = val_ds
elif args.split == "test":
ds = test_ds
else:
ds = dataset
loader = DataLoader(
ds,
batch_size=int(args.batch_size),
shuffle=False,
collate_fn=health_collate_fn,
num_workers=int(args.num_workers),
pin_memory=str(args.device).startswith("cuda"),
)
criterion, out_dims = build_criterion_and_out_dims(
loss_type=str(cfg["loss_type"]),
n_disease=int(dataset.n_disease),
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
)
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
device = torch.device(args.device)
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
if "criterion_state_dict" in checkpoint:
try:
criterion.load_state_dict(
checkpoint["criterion_state_dict"], strict=False)
except Exception:
pass
model.to(device)
head.to(device)
criterion.to(device)
age_edges = _parse_age_bin_edges(args.age_bin_edges)
age_bins = _edges_to_bins(age_edges)
eval_cfg = EvalAgeConfig(
horizons_years=_parse_floats(args.horizons),
age_bins=age_bins,
topk_percents=[float(x) for x in args.topk_percent],
n_mc=int(args.n_mc),
seed=int(args.seed),
cause_ids=None,
)
if args.out_prefix is None:
out_prefix = os.path.join(
args.run_dir, f"age_bin_time_dependent_{args.split}")
else:
out_prefix = args.out_prefix
out_bin = out_prefix + "_by_bin.csv"
out_agg = out_prefix + "_agg.csv"
gpus = _parse_gpus(args.gpus)
if len(gpus) <= 1:
df_by_bin, df_agg = evaluate_time_dependent_age_bins(
model=model,
head=head,
criterion=criterion,
dataloader=loader,
n_disease=int(dataset.n_disease),
cfg=eval_cfg,
device=device,
)
df_by_bin_csv = _drop_zero_positives_rows(df_by_bin, "n_positives")
df_agg_csv = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
df_by_bin_csv.to_csv(out_bin, index=False)
df_agg_csv.to_csv(out_agg, index=False)
print(f"Wrote: {out_bin}")
print(f"Wrote: {out_agg}")
return
if not torch.cuda.is_available():
raise SystemExit("--gpus was provided but CUDA is not available")
# Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU).
mc_indices_all = list(range(int(args.n_mc)))
per_gpu: List[Tuple[int, List[int]]] = []
for pos, gpu_id in enumerate(gpus):
assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
if assigned:
per_gpu.append((int(gpu_id), assigned))
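# Illustrative round-robin assignment (assumed values): --gpus 0,1,3 with
# --n_mc 5 yields GPU 0 -> MC [0, 3], GPU 1 -> MC [1, 4], GPU 3 -> MC [2].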
ctx = mp.get_context("spawn")
queue: "mp.Queue" = ctx.Queue()
procs: List[mp.Process] = []
tmp_paths: List[str] = []
for gpu_id, mc_idxs in per_gpu:
tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
tmp_paths.append(tmp_path)
p = ctx.Process(
target=_worker_eval_mcs_on_gpu,
kwargs=dict(
queue=queue,
run_dir=str(args.run_dir),
split=str(args.split),
data_prefix_override=(
str(args.data_prefix) if args.data_prefix is not None else None
),
horizons=_parse_floats(args.horizons),
age_bins=age_bins,
topk_percents=[float(x) for x in args.topk_percent],
n_mc=int(args.n_mc),
seed=int(args.seed),
batch_size=int(args.batch_size),
num_workers=int(args.num_workers),
gpu_id=int(gpu_id),
mc_indices=mc_idxs,
out_path=tmp_path,
),
)
p.start()
procs.append(p)
results = [queue.get() for _ in range(len(procs))]
for p in procs:
p.join()
for r in results:
if not r.get("ok", False):
raise SystemExit(f"Worker failed: {r.get('error')}")
frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
df_by_bin = pd.concat(frames, ignore_index=True) if len(
frames) else pd.DataFrame()
# Ensure we don't keep zero-positive rows even if a temp file was produced
# by an older version of the worker.
df_by_bin = _drop_zero_positives_rows(df_by_bin, "n_positives")
df_agg = aggregate_age_bin_results(df_by_bin)
df_agg = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
df_by_bin.to_csv(out_bin, index=False)
df_agg.to_csv(out_agg, index=False)
# Best-effort cleanup.
for p in tmp_paths:
try:
if os.path.exists(p):
os.remove(p)
except Exception:
pass
print(f"Wrote: {out_bin}")
print(f"Wrote: {out_agg}")
if __name__ == "__main__":
main()


@@ -1,852 +0,0 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
import torch
try:
from tqdm import tqdm
except Exception: # pragma: no cover
def tqdm(x, **kwargs):
return x
from utils import (
multi_hot_ever_within_horizon,
multi_hot_selected_causes_within_horizon,
sample_context_in_fixed_age_bin,
)
from torch_metrics import compute_binary_metrics_torch
def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray:
with np.errstate(invalid="ignore"):
return np.nanmean(x, axis=axis)
def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray:
"""NaN-aware sample std (ddof=1), matching pandas std() semantics."""
x = np.asarray(x, dtype=float)
mask = np.isfinite(x)
cnt = mask.sum(axis=axis)
# mean over finite entries
x0 = np.where(mask, x, 0.0)
mean = x0.sum(axis=axis) / np.maximum(cnt, 1)
# sum of squared deviations over finite entries
dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0)
ss = dev2.sum(axis=axis)
denom = cnt - 1
out = np.sqrt(ss / np.maximum(denom, 1))
out = np.where(denom > 0, out, np.nan)
return out
def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray:
"""NaN-aware weighted mean.
Only bins with finite x contribute to both numerator and denominator.
If denom==0 -> NaN.
"""
x = np.asarray(x, dtype=float)
w = np.asarray(w, dtype=float)
if axis != 0:
raise ValueError("_weighted_mean_np currently supports axis=0 only")
# Broadcast weights along trailing dims of x.
while w.ndim < x.ndim:
w = w[..., None]
w = np.broadcast_to(w, x.shape)
mask = np.isfinite(x)
num = np.where(mask, x * w, 0.0).sum(axis=0)
denom = np.where(mask, w, 0.0).sum(axis=0)
return np.where(denom > 0.0, num / denom, np.nan)
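# Example (illustrative numbers): x = [0.8, nan, 0.6] with w = [10, 5, 20]
# gives (0.8*10 + 0.6*20) / (10 + 20) = 0.667; the NaN bin contributes to
# neither numerator nor denominator.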
def _blocks_to_df_by_bin(
blocks: List[Dict[str, Any]],
*,
topk_percents: np.ndarray,
) -> pd.DataFrame:
"""Convert per-block column vectors into the long-format per-bin DataFrame.
This does a single vectorized reshape per block (cause-major ordering), and
concatenates columns once at the end.
"""
if len(blocks) == 0:
return pd.DataFrame(
columns=[
"mc_idx",
"age_bin_id",
"age_bin_low",
"age_bin_high",
"horizon_tau",
"topk_percent",
"cause_id",
"n_samples",
"n_positives",
"auc",
"auprc",
"recall_at_K",
"precision_at_K",
"brier_score",
]
)
P = int(topk_percents.size)
cols: Dict[str, List[np.ndarray]] = {
"mc_idx": [],
"age_bin_id": [],
"age_bin_low": [],
"age_bin_high": [],
"horizon_tau": [],
"topk_percent": [],
"cause_id": [],
"n_samples": [],
"n_positives": [],
"auc": [],
"auprc": [],
"recall_at_K": [],
"precision_at_K": [],
"brier_score": [],
}
for blk in blocks:
cause_id = np.asarray(blk["cause_id"], dtype=int)
K = int(cause_id.size)
n_rows = K * P
cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int))
cols["age_bin_id"].append(
np.full(n_rows, int(blk["age_bin_id"]), dtype=int))
cols["age_bin_low"].append(
np.full(n_rows, float(blk["age_bin_low"]), dtype=float))
cols["age_bin_high"].append(
np.full(n_rows, float(blk["age_bin_high"]), dtype=float))
cols["horizon_tau"].append(
np.full(n_rows, float(blk["horizon_tau"]), dtype=float))
cols["cause_id"].append(np.repeat(cause_id, P))
cols["topk_percent"].append(np.tile(topk_percents.astype(float), K))
cols["n_samples"].append(
np.full(n_rows, int(blk["n_samples"]), dtype=int))
n_pos = np.asarray(blk["n_positives"], dtype=int)
cols["n_positives"].append(np.repeat(n_pos, P))
auc = np.asarray(blk["auc"], dtype=float)
auprc = np.asarray(blk["auprc"], dtype=float)
brier = np.asarray(blk["brier_score"], dtype=float)
cols["auc"].append(np.repeat(auc, P))
cols["auprc"].append(np.repeat(auprc, P))
cols["brier_score"].append(np.repeat(brier, P))
# precision/recall are stored as (P,K); we want cause-major rows, i.e.
# (K,P) then flatten.
prec = np.asarray(blk["precision_at_K"], dtype=float)
rec = np.asarray(blk["recall_at_K"], dtype=float)
if prec.shape != (P, K) or rec.shape != (P, K):
raise ValueError(
f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}"
)
cols["precision_at_K"].append(prec.T.reshape(-1))
cols["recall_at_K"].append(rec.T.reshape(-1))
out = {k: np.concatenate(v, axis=0) for k, v in cols.items()}
return pd.DataFrame(out)
def aggregate_metrics_columnar(
blocks: List[Dict[str, Any]],
*,
topk_percents: np.ndarray,
cause_id: np.ndarray,
) -> pd.DataFrame:
"""Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std).
This is a vectorized, columnar replacement for the old pandas groupby/apply.
Semantics match the previous implementation:
- bins with n_samples==0 are excluded from bin-aggregation
- macro: unweighted mean over bins (NaN-aware)
- weighted: weighted mean over bins using weights=n_samples (NaN-aware)
- across MC: mean/std (ddof=1), NaN-aware
"""
if len(blocks) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
P = int(topk_percents.size)
cause_id = np.asarray(cause_id, dtype=int)
K = int(cause_id.size)
# Group blocks by (mc_idx, horizon_tau)
keys: List[Tuple[int, float]] = []
grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {}
for blk in blocks:
key = (int(blk["mc_idx"]), float(blk["horizon_tau"]))
if key not in grouped:
grouped[key] = []
keys.append(key)
grouped[key].append(blk)
mc_vals = sorted({k[0] for k in keys})
tau_vals = sorted({k[1] for k in keys})
M = len(mc_vals)
T = len(tau_vals)
mc_index = {mc: i for i, mc in enumerate(mc_vals)}
tau_index = {tau: i for i, tau in enumerate(tau_vals)}
# Per (agg_type, mc, tau): store arrays
# metrics: (M,T,K) and (M,T,P,K)
auc_macro = np.full((M, T, K), np.nan, dtype=float)
auc_weighted = np.full((M, T, K), np.nan, dtype=float)
ap_macro = np.full((M, T, K), np.nan, dtype=float)
ap_weighted = np.full((M, T, K), np.nan, dtype=float)
brier_macro = np.full((M, T, K), np.nan, dtype=float)
brier_weighted = np.full((M, T, K), np.nan, dtype=float)
prec_macro = np.full((M, T, P, K), np.nan, dtype=float)
prec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
rec_macro = np.full((M, T, P, K), np.nan, dtype=float)
rec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
n_bins_used = np.zeros((M, T), dtype=float)
n_samples_total = np.zeros((M, T), dtype=float)
n_pos_total = np.zeros((M, T, K), dtype=float)
for (mc, tau), blks in grouped.items():
mi = mc_index[mc]
ti = tau_index[tau]
# keep only bins with n_samples>0
blks_nz = [b for b in blks if int(b["n_samples"]) > 0]
if len(blks_nz) == 0:
n_bins_used[mi, ti] = 0.0
n_samples_total[mi, ti] = 0.0
n_pos_total[mi, ti, :] = 0.0
continue
w = np.asarray([int(b["n_samples"])
for b in blks_nz], dtype=float) # (B,)
n_bins_used[mi, ti] = float(len(w))
n_samples_total[mi, ti] = float(w.sum())
npos = np.stack([np.asarray(b["n_positives"], dtype=float)
for b in blks_nz], axis=0) # (B,K)
n_pos_total[mi, ti, :] = npos.sum(axis=0)
auc_b = np.stack([np.asarray(b["auc"], dtype=float)
for b in blks_nz], axis=0) # (B,K)
ap_b = np.stack([np.asarray(b["auprc"], dtype=float)
for b in blks_nz], axis=0)
brier_b = np.stack([np.asarray(b["brier_score"], dtype=float)
for b in blks_nz], axis=0)
auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0)
ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0)
brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0)
auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0)
ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0)
brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0)
prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float)
for b in blks_nz], axis=0) # (B,P,K)
rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float)
for b in blks_nz], axis=0)
# macro mean over bins
prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0)
rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0)
# weighted mean over bins (weights along bin axis)
w3 = w.reshape(-1, 1, 1)
prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0)
rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0)
# Across-MC aggregation (mean/std), then emit long-format df keyed by
# (agg_type, horizon_tau, topk_percent, cause_id)
rows: Dict[str, List[np.ndarray]] = {
"agg_type": [],
"horizon_tau": [],
"topk_percent": [],
"cause_id": [],
"n_mc": [],
"n_bins_used_mean": [],
"n_samples_total_mean": [],
"n_positives_total_mean": [],
"auc_mean": [],
"auc_std": [],
"auprc_mean": [],
"auprc_std": [],
"recall_at_K_mean": [],
"recall_at_K_std": [],
"precision_at_K_mean": [],
"precision_at_K_std": [],
"brier_score_mean": [],
"brier_score_std": [],
}
cause_long = np.repeat(cause_id, P)
topk_long = np.tile(topk_percents.astype(float), K)
n_mc_val = float(M)
for ti, tau in enumerate(tau_vals):
# scalar totals (repeat across causes/topk)
n_bins_mean = float(
np.mean(n_bins_used[:, ti])) if M > 0 else float("nan")
n_samp_mean = float(
np.mean(n_samples_total[:, ti])) if M > 0 else float("nan")
n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0) # (K,)
for agg_type in ("macro", "weighted"):
if agg_type == "macro":
auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0)
auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0)
ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0)
ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0)
brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0)
brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0)
prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0) # (P,K)
prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0)
rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0)
rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0)
else:
auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0)
auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0)
ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0)
ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0)
brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0)
brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0)
prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0)
prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0)
rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0)
rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0)
n_rows = K * P
rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object))
rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float))
rows["topk_percent"].append(topk_long)
rows["cause_id"].append(cause_long)
rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float))
rows["n_bins_used_mean"].append(
np.full(n_rows, n_bins_mean, dtype=float))
rows["n_samples_total_mean"].append(
np.full(n_rows, n_samp_mean, dtype=float))
rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P))
rows["auc_mean"].append(np.repeat(auc_m, P))
rows["auc_std"].append(np.repeat(auc_s, P))
rows["auprc_mean"].append(np.repeat(ap_m, P))
rows["auprc_std"].append(np.repeat(ap_s, P))
rows["brier_score_mean"].append(np.repeat(brier_m, P))
rows["brier_score_std"].append(np.repeat(brier_s, P))
rows["precision_at_K_mean"].append(prec_m.T.reshape(-1))
rows["precision_at_K_std"].append(prec_s.T.reshape(-1))
rows["recall_at_K_mean"].append(rec_m.T.reshape(-1))
rows["recall_at_K_std"].append(rec_s.T.reshape(-1))
out = {k: np.concatenate(v, axis=0) for k, v in rows.items()}
df = pd.DataFrame(out)
return df.sort_values(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True
)
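# Worked example of the bin-aggregation semantics (assumed numbers): for one
# (mc, tau, cause), two non-empty age bins with AUC 0.80 (n_samples=100) and
# AUC 0.60 (n_samples=300) give
#   macro    = (0.80 + 0.60) / 2                       = 0.70
#   weighted = (0.80*100 + 0.60*300) / (100 + 300)     = 0.65
# The *_std columns are then the ddof=1 std of these per-MC values across MC runs.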
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
"""Aggregate per-bin age evaluation results.
Produces both:
- macro: unweighted mean over bins with n_samples>0
- weighted: weighted mean over bins using weights=n_samples
Then aggregates across MC repetitions (mean/std).
Requires df_by_bin to include:
mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score
Returns:
DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id)
"""
if df_by_bin is None or len(df_by_bin) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
g = group[group["n_samples"] > 0]
if len(g) == 0:
return pd.Series(
dict(
n_bins_used=0,
n_samples_total=0,
n_positives_total=0,
auc=float("nan"),
auprc=float("nan"),
recall_at_K=float("nan"),
precision_at_K=float("nan"),
brier_score=float("nan"),
)
)
n_bins_used = int(g["age_bin_id"].nunique())
n_samples_total = int(g["n_samples"].sum())
n_positives_total = int(g["n_positives"].sum())
if not weighted:
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=float(g["auc"].mean()),
auprc=float(g["auprc"].mean()),
recall_at_K=float(g["recall_at_K"].mean()),
precision_at_K=float(g["precision_at_K"].mean()),
brier_score=float(g["brier_score"].mean()),
)
)
w = g["n_samples"].to_numpy(dtype=float)
w_sum = float(w.sum())
if w_sum <= 0.0:
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=float("nan"),
auprc=float("nan"),
recall_at_K=float("nan"),
precision_at_K=float("nan"),
brier_score=float("nan"),
)
)
def _wavg(col: str) -> float:
return float(np.average(g[col].to_numpy(dtype=float), weights=w))
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=_wavg("auc"),
auprc=_wavg("auprc"),
recall_at_K=_wavg("recall_at_K"),
precision_at_K=_wavg("precision_at_K"),
brier_score=_wavg("brier_score"),
)
)
# Kept for backward compatibility (e.g., if callers load a CSV and need to
# aggregate). Prefer `aggregate_metrics_columnar` during evaluation.
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
df = df_by_bin[df_by_bin["n_samples"] > 0].copy()
if len(df) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
# Macro: mean over bins
df_mc_macro = (
df.groupby(group_keys, as_index=False)
.agg(
n_bins_used=("age_bin_id", "nunique"),
n_samples_total=("n_samples", "sum"),
n_positives_total=("n_positives", "sum"),
auc=("auc", "mean"),
auprc=("auprc", "mean"),
recall_at_K=("recall_at_K", "mean"),
precision_at_K=("precision_at_K", "mean"),
brier_score=("brier_score", "mean"),
)
)
df_mc_macro["agg_type"] = "macro"
# Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric
w = df["n_samples"].astype(float)
df_w = df.copy()
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
m = df_w[col].astype(float)
ww = w.where(m.notna(), other=0.0)
df_w[f"__num_{col}"] = (m.fillna(0.0) * w)
df_w[f"__den_{col}"] = ww
df_mc_w = df_w.groupby(group_keys, as_index=False).agg(
n_bins_used=("age_bin_id", "nunique"),
n_samples_total=("n_samples", "sum"),
n_positives_total=("n_positives", "sum"),
**{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
**{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
)
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
num = df_mc_w[f"__num_{col}"].astype(float)
den = df_mc_w[f"__den_{col}"].astype(float)
df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan"))
df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True)
df_mc_w["agg_type"] = "weighted"
df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True)
df_agg = (
df_mc_binagg.groupby(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False)
.agg(
n_mc=("mc_idx", "nunique"),
n_bins_used_mean=("n_bins_used", "mean"),
n_samples_total_mean=("n_samples_total", "mean"),
n_positives_total_mean=("n_positives_total", "mean"),
auc_mean=("auc", "mean"),
auc_std=("auc", "std"),
auprc_mean=("auprc", "mean"),
auprc_std=("auprc", "std"),
recall_at_K_mean=("recall_at_K", "mean"),
recall_at_K_std=("recall_at_K", "std"),
precision_at_K_mean=("precision_at_K", "mean"),
precision_at_K_std=("precision_at_K", "std"),
brier_score_mean=("brier_score", "mean"),
brier_score_std=("brier_score", "std"),
)
.sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True)
)
return df_agg
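# Backward-compatibility usage sketch (file names assumed): re-aggregate a
# previously written per-bin CSV without re-running the model.
#   df_bin = pd.read_csv("age_bin_time_dependent_val_by_bin.csv")
#   df_agg = aggregate_age_bin_results(df_bin)
#   df_agg.to_csv("age_bin_time_dependent_val_agg.csv", index=False)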
# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`.
# NumPy/Pandas are only used for final CSV formatting/aggregation.
@dataclass
class EvalAgeConfig:
horizons_years: Sequence[float]
age_bins: Sequence[Tuple[float, float]]
topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
n_mc: int = 5
seed: int = 0
cause_ids: Optional[Sequence[int]] = None
store_per_cause: bool = True
@torch.inference_mode()
def evaluate_time_dependent_age_bins(
model: torch.nn.Module,
head: torch.nn.Module,
criterion,
dataloader: torch.utils.data.DataLoader,
n_disease: int,
cfg: EvalAgeConfig,
device: str | torch.device,
mc_offset: int = 0,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Delphi-2M-style age-bin evaluation with strict horizon alignment.
Semantics (strict): for each (MC, horizon tau, age bin) we independently:
- build the eligible token set within that bin
- enforce follow-up coverage: t_ctx + tau <= t_end
- randomly sample exactly one token per individual within the bin (de-dup)
- recompute context representations and predictions for that (tau, bin)
Returns:
df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id)
df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted}
"""
device = torch.device(device)
model.eval()
head.eval()
horizons_years = [float(x) for x in cfg.horizons_years]
if len(horizons_years) == 0:
raise ValueError("cfg.horizons_years must be non-empty")
age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
if len(age_bins) == 0:
raise ValueError("cfg.age_bins must be non-empty")
for (a, b) in age_bins:
if not (b > a):
raise ValueError(
f"age_bins must be (low, high) with high>low; got {(a, b)}")
topk_percents = [float(x) for x in cfg.topk_percents]
if len(topk_percents) == 0:
raise ValueError("cfg.topk_percents must be non-empty")
if any((p <= 0.0 or p > 100.0) for p in topk_percents):
raise ValueError(
f"All topk_percents must be in (0,100]; got {topk_percents}")
if int(cfg.n_mc) <= 0:
raise ValueError("cfg.n_mc must be >= 1")
if cfg.cause_ids is None:
cause_ids = None
n_causes_eval = int(n_disease)
cause_id_vec = np.arange(n_causes_eval, dtype=int)
else:
cause_ids = torch.tensor(
list(cfg.cause_ids), dtype=torch.long, device=device)
n_causes_eval = int(cause_ids.numel())
cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int)
topk_percents_np = np.asarray(topk_percents, dtype=float)
# Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends.
blocks: List[Dict[str, Any]] = []
for mc_idx in range(int(cfg.n_mc)):
global_mc_idx = int(mc_offset) + int(mc_idx)
# Storage for this MC only: (tau, bin) -> list of GPU tensors.
# This keeps computations GPU-first while preventing a factor-n_mc
# blow-up in GPU memory.
y_true_mc: List[List[List[torch.Tensor]]] = [
[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
]
y_pred_mc: List[List[List[torch.Tensor]]] = [
[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
]
# tqdm over batches; include MC idx in description.
for batch_idx, batch in enumerate(
tqdm(dataloader,
desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
):
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
event_seq = event_seq.to(device)
time_seq = time_seq.to(device)
cont_feats = cont_feats.to(device)
cate_feats = cate_feats.to(device)
sexes = sexes.to(device)
B = int(event_seq.size(0))
b = torch.arange(B, device=device)
# Hoist backbone forward pass: inputs are identical across (tau, age_bin)
# within this batch, so this is safe and numerically identical.
h = model(event_seq, time_seq, sexes,
cont_feats, cate_feats) # (B,L,D)
for tau_idx, tau_y in enumerate(horizons_years):
tau_tensor = torch.tensor(float(tau_y), device=device)
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
seed = (
int(cfg.seed)
+ (100_000 * int(global_mc_idx))
+ (1_000 * int(tau_idx))
+ (10 * int(bin_idx))
+ int(batch_idx)
)
keep, t_ctx = sample_context_in_fixed_age_bin(
event_seq=event_seq,
time_seq=time_seq,
tau_years=float(tau_y),
age_bin=(float(a_lo), float(a_hi)),
seed=seed,
)
if not keep.any():
continue
# Bin-specific prediction: context indices differ per (tau, bin)
# but the backbone features do not.
c = h[b, t_ctx]
logits = head(c)
cifs = criterion.calculate_cifs(
logits, taus=tau_tensor
)
if cifs.ndim != 2:
raise ValueError(
"criterion.calculate_cifs must return (B,K) for scalar tau; "
f"got shape={tuple(cifs.shape)}"
)
if cause_ids is None:
y = multi_hot_ever_within_horizon(
event_seq=event_seq,
time_seq=time_seq,
t_ctx=t_ctx,
tau_years=float(tau_y),
n_disease=n_disease,
)
preds = cifs
else:
y = multi_hot_selected_causes_within_horizon(
event_seq=event_seq,
time_seq=time_seq,
t_ctx=t_ctx,
tau_years=float(tau_y),
cause_ids=cause_ids,
n_disease=n_disease,
)
preds = cifs.index_select(dim=1, index=cause_ids)
y_true_mc[tau_idx][bin_idx].append(
y[keep].detach().to(dtype=torch.bool)
)
y_pred_mc[tau_idx][bin_idx].append(
preds[keep].detach().to(dtype=torch.float32)
)
# Aggregate this MC immediately (frees GPU memory before next MC).
for h_idx, tau_y in enumerate(horizons_years):
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
if len(y_true_mc[h_idx][bin_idx]) == 0:
# No samples in this bin for this (mc, tau): store a single
# block with NaN metric vectors.
K = int(n_causes_eval)
P = int(topk_percents_np.size)
blocks.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
n_samples=0,
cause_id=cause_id_vec,
n_positives=np.zeros((K,), dtype=int),
auc=np.full((K,), np.nan, dtype=float),
auprc=np.full((K,), np.nan, dtype=float),
brier_score=np.full((K,), np.nan, dtype=float),
precision_at_K=np.full((P, K), np.nan, dtype=float),
recall_at_K=np.full((P, K), np.nan, dtype=float),
)
)
continue
yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0)
if tuple(yb_t.shape) != tuple(pb_t.shape):
raise ValueError(
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
)
n_samples = int(yb_t.size(0))
metrics = compute_binary_metrics_torch(
y_true=yb_t,
y_pred=pb_t,
k_percents=topk_percents,
tie_mode="exact",
chunk_size=128,
compute_ici=False,
)
# Collect a single columnar block (vectors, not per-row dicts).
blocks.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
n_samples=int(n_samples),
cause_id=cause_id_vec,
n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int),
auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float),
auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float),
brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float),
precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float),
recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float),
)
)
# Aggregation is computed from columnar blocks (fast, no pandas apply).
df_agg = aggregate_metrics_columnar(
blocks,
topk_percents=topk_percents_np,
cause_id=cause_id_vec,
)
if bool(cfg.store_per_cause):
df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np)
else:
df_by_bin = pd.DataFrame(
columns=[
"mc_idx",
"age_bin_id",
"age_bin_low",
"age_bin_high",
"horizon_tau",
"topk_percent",
"cause_id",
"n_samples",
"n_positives",
"auc",
"auprc",
"recall_at_K",
"precision_at_K",
"brier_score",
]
)
return df_by_bin, df_agg

utils.py (207 deletions)

@@ -1,207 +0,0 @@
import torch
from typing import Tuple
DAYS_PER_YEAR = 365.25
def sample_context_in_fixed_age_bin(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
tau_years: float,
age_bin: Tuple[float, float],
seed: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Sample one context token per individual within a fixed age bin.
Delphi-2M semantics for a specific (tau, age_bin):
- Token times are interpreted as age in *days* (converted to years).
- Follow-up end time is the last valid token time per individual.
- A token index j is eligible iff:
(token is valid)
AND (age_years in [age_low, age_high))
AND (time_seq[i, j] + tau_days <= followup_end_time[i])
- For each individual, randomly select exactly one eligible token in this bin.
Args:
event_seq: (B, L) token ids, 0 is padding.
time_seq: (B, L) token times in days.
tau_years: horizon length in years.
age_bin: (low, high) bounds in years, interpreted as [low, high).
seed: RNG seed for deterministic sampling.
Returns:
keep: (B,) bool, True if a context was sampled for this bin.
t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0).
"""
low, high = float(age_bin[0]), float(age_bin[1])
if not (high > low):
raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")
device = event_seq.device
B, _ = event_seq.shape
valid = event_seq != 0
lengths = valid.sum(dim=1)
last_idx = torch.clamp(lengths - 1, min=0)
b = torch.arange(B, device=device)
followup_end_time = time_seq[b, last_idx] # (B,)
tau_days = float(tau_years) * DAYS_PER_YEAR
age_years = time_seq / DAYS_PER_YEAR
in_bin = (age_years >= low) & (age_years < high)
eligible = valid & in_bin & (
(time_seq + tau_days) <= followup_end_time.unsqueeze(1))
# Vectorized, uniform sampling over eligible indices per sample.
# Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
# choice among eligible indices by symmetry (ties have probability ~0).
keep = eligible.any(dim=1)
# Prefer a per-call generator on the target device for reproducibility without
# touching global RNG state. If unavailable, fall back to seeding the global
# CUDA RNG for this call.
gen = None
if device.type == "cuda":
try:
gen = torch.Generator(device=device)
gen.manual_seed(int(seed))
except Exception:
gen = None
torch.cuda.manual_seed(int(seed))
else:
gen = torch.Generator()
gen.manual_seed(int(seed))
r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
r = r.masked_fill(~eligible, -1.0)
t_ctx = r.argmax(dim=1).to(torch.long)
# When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0).
return keep, t_ctx
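# Worked example (assumed values): an individual with valid tokens at ages
# 42.0, 47.5 and 58.0 years has followup_end_time = 58.0y. For age_bin=(45, 50)
# and tau_years=5, only the 47.5y token is eligible (it lies in [45, 50) and
# 47.5 + 5 <= 58), so keep=True and t_ctx points at that token; with
# tau_years=15 no token satisfies the follow-up requirement and keep=False.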
def select_context_indices(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
offset_years: float,
tau_years: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Select per-sample prediction context index.
IMPORTANT SEMANTICS:
- The last observed token time is treated as the FOLLOW-UP END time.
- We pick the last valid token with time <= (followup_end_time - offset).
- We do NOT interpret followup_end_time as an event time.
Returns:
keep_mask: (B,) bool, which samples have a valid context
t_ctx: (B,) long, index into sequence
t_ctx_time: (B,) float, time (days) at context
"""
# valid tokens are event != 0 (padding is 0)
valid = event_seq != 0
lengths = valid.sum(dim=1)
last_idx = torch.clamp(lengths - 1, min=0)
b = torch.arange(event_seq.size(0), device=event_seq.device)
followup_end_time = time_seq[b, last_idx]
t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
eligible = valid & (time_seq <= t_cut.unsqueeze(1))
eligible_counts = eligible.sum(dim=1)
keep = eligible_counts > 0
t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
t_ctx_time = time_seq[b, t_ctx]
# Horizon-aligned eligibility: require enough follow-up time after the selected context.
# All times are in days.
keep = keep & (followup_end_time >= (
t_ctx_time + (tau_years * DAYS_PER_YEAR)))
return keep, t_ctx, t_ctx_time
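# Worked example (assumed times, in days): token times [0, 7305, 10957.5]
# (ages 0, 20, 30 years), so followup_end_time = 10957.5. With offset_years=1,
# t_cut = 10957.5 - 365.25 and the last eligible token is the 20-year one
# (t_ctx=1); with tau_years=2 the horizon check 10957.5 >= 7305 + 730.5 holds,
# so keep=True.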
def multi_hot_ever_within_horizon(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
tau_years: float,
n_disease: int,
) -> torch.Tensor:
"""Binary labels: disease k occurs within tau after context (any occurrence)."""
B, L = event_seq.shape
b = torch.arange(B, device=event_seq.device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (tau_years * DAYS_PER_YEAR)
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
# Include same-day events after context, exclude any token at/before context index.
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
if not in_window.any():
return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
y[b_idx, disease_ids] = True
return y
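# Example (assumed ids): with t_ctx at age 50.0y and tau_years=1.0, a later
# token with event id 7 at age 50.4y sets y[i, 5] = True (disease index = id - 2);
# padding and technical tokens are never counted because of event_seq >= 2.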
def multi_hot_selected_causes_within_horizon(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
tau_years: float,
cause_ids: torch.Tensor,
n_disease: int,
) -> torch.Tensor:
"""Labels for selected causes only: does cause k occur within tau after context?"""
B, L = event_seq.shape
device = event_seq.device
b = torch.arange(B, device=device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (tau_years * DAYS_PER_YEAR)
idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
if not in_window.any():
return out
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
# Filter to selected causes via a boolean membership mask over the global disease space.
selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
selected[cause_ids] = True
keep = selected[disease_ids]
if not keep.any():
return out
b_idx = b_idx[keep]
disease_ids = disease_ids[keep]
# Map disease_id -> local index in cause_ids
# Build a lookup table (global disease space) where lookup[disease_id] = local_index
lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
local = lookup[disease_ids]
out[b_idx, local] = True
return out
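# Example (assumed ids): with cause_ids = [4, 10, 25], an in-window token whose
# global disease id is 10 maps through `lookup` to local index 1, so
# out[i, 1] = True; events whose disease id is not in cause_ids are ignored.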