diff --git a/README.md b/README.md index 7207c25..8f89f0e 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,40 @@ # DeepHealth +## Evaluation + +This repo includes two event-driven evaluation entrypoints: + +- `evaluate_next_event.py`: next-event prediction using short-window CIF +- `evaluate_horizon.py`: horizon-capture evaluation using CIF at multiple horizons + +### IMPORTANT metric disclaimers + +- **AUC** reported by `evaluate_horizon.py` is “time-dependent” only because the label depends on the chosen horizon $\tau$. + Without explicit follow-up end times / censoring, this is **not** a classical risk-set AUC with IPCW. + Use it for **model comparison and diagnostics**, not strict statistical interpretation. + +- **Brier score** reported by `evaluate_horizon.py` is an unadjusted diagnostic/proxy metric (no censoring adjustment). + Use it to detect probability-mass compression / numerical stability issues; do not claim calibrated absolute risk. + +### Example + +```bash +# Next-event (no --horizons) +python evaluate_next_event.py \ + --run_dir runs/your_run \ + --tau_short 0.25 \ + --age_bins 40 45 50 55 60 65 70 inf \ + --device cuda \ + --batch_size 256 \ + --seed 0 + +# Horizon-capture +python evaluate_horizon.py \ + --run_dir runs/your_run \ + --horizons 0.25 0.5 1.0 2.0 5.0 10.0 \ + --age_bins 40 45 50 55 60 65 70 inf \ + --device cuda \ + --batch_size 256 \ + --seed 0 +``` + diff --git a/evaluate_horizon.py b/evaluate_horizon.py new file mode 100644 index 0000000..6ce9fe7 --- /dev/null +++ b/evaluate_horizon.py @@ -0,0 +1,277 @@ +"""Horizon-capture evaluation. + +DISCLAIMERS (important): +- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$. + Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW. + Use it for model comparison and diagnostics, not strict statistical interpretation. + +- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment). + Use them to detect probability-mass compression / numerical stability issues; do not claim + calibrated absolute risk. 
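+
+Concretely, with y_ik(tau) = 1 iff record i has at least one cause-k event in
+(t0, t0 + tau], the per-cause Brier reported here is the plain mean squared error
+
+    Brier_k(tau) = mean_i (y_ik(tau) - CIF_ik(tau))^2
+
+and the AUC is the usual rank statistic computed against the same y_ik(tau) labels.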
+""" + +import argparse +import os +from typing import Dict, List, Sequence + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import DataLoader + +from utils import ( + EvalRecordDataset, + build_dataset_from_config, + build_event_driven_records, + build_model_head_criterion, + eval_collate_fn, + flatten_future_events, + get_test_subset, + load_checkpoint_into, + load_train_config, + make_inference_dataloader_kwargs, + parse_float_list, + predict_cifs, + roc_auc_ovr, + seed_everything, + topk_indices, +) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Evaluate horizon-capture using CIF at horizons") + p.add_argument("--run_dir", type=str, required=True) + p.add_argument( + "--horizons", + type=str, + nargs="+", + default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], + help="Horizon grid in years", + ) + p.add_argument( + "--age_bins", + type=str, + nargs="+", + default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"], + help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)", + ) + + p.add_argument( + "--device", + type=str, + default=("cuda" if torch.cuda.is_available() else "cpu"), + ) + p.add_argument("--batch_size", type=int, default=256) + p.add_argument("--num_workers", type=int, default=0) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--min_pos", type=int, default=20) + p.add_argument( + "--topk_list", + type=int, + nargs="+", + default=[1, 5, 10, 20, 50], + ) + return p.parse_args() + + +def build_labels_within_tau_flat( + n_records: int, + n_causes: int, + event_record_idx: np.ndarray, + event_cause: np.ndarray, + event_dt_years: np.ndarray, + tau_years: float, +) -> np.ndarray: + """Build y_within_tau using a flattened (record,cause,dt) representation. + + This preserves the exact label definition: y[i,k]=1 iff at least one event of cause k + occurs in (t0, t0+tau]. 
+ """ + y = np.zeros((n_records, n_causes), dtype=np.int8) + if event_dt_years.size == 0: + return y + m = event_dt_years <= float(tau_years) + if not np.any(m): + return y + y[event_record_idx[m], event_cause[m]] = 1 + return y + + +def main() -> None: + args = parse_args() + seed_everything(args.seed) + + run_dir = args.run_dir + cfg = load_train_config(run_dir) + + dataset = build_dataset_from_config(cfg) + test_subset = get_test_subset(dataset, cfg) + + age_bins_years = parse_float_list(args.age_bins) + horizons = parse_float_list(args.horizons) + horizons = [float(h) for h in horizons] + + records = build_event_driven_records( + dataset=dataset, + subset=test_subset, + age_bins_years=age_bins_years, + seed=args.seed, + ) + + device = torch.device(args.device) + model, head, criterion = build_model_head_criterion(cfg, dataset, device) + load_checkpoint_into(run_dir, model, head, criterion, device) + + rec_ds = EvalRecordDataset(dataset, records) + + dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers) + loader = DataLoader( + rec_ds, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + collate_fn=eval_collate_fn, + **dl_kwargs, + ) + + # Print disclaimers every run (requested) + print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).") + print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).") + + scores = predict_cifs(model, head, criterion, loader, + horizons, device=device) + # scores shape: (N, K, H) + if scores.ndim != 3: + raise ValueError( + f"Expected CIF scores with shape (N,K,H), got {scores.shape}") + + N, K, H = scores.shape + if N != len(records): + raise ValueError("Record count mismatch") + + # Pre-flatten all future events once to avoid repeated per-record scans. 
+ evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K) + + per_tau_rows: List[Dict[str, object]] = [] + per_cause_rows: List[Dict[str, object]] = [] + workload_rows: List[Dict[str, object]] = [] + + for h_idx, tau in enumerate(horizons): + s_tau = scores[:, :, h_idx] + y_tau = build_labels_within_tau_flat( + N, K, evt_rec_idx, evt_cause, evt_dt, tau) + + # Per-cause counts + Brier (vectorized) + n_pos = y_tau.sum(axis=0).astype(np.int64) + n_neg = (int(N) - n_pos).astype(np.int64) + + # Brier per cause: mean_i (y - s)^2 + brier_per_cause = np.mean( + (y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0) + brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan") + brier_weighted = float(np.sum( + brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan") + + # AUC: compute only for causes with enough positives and at least 1 negative + auc = np.full((K,), np.nan, dtype=np.float64) + min_pos = int(args.min_pos) + candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0)) + for k in candidates: + auc[k] = roc_auc_ovr(y_tau[:, k].astype( + np.int32), s_tau[:, k].astype(np.float64)) + + finite_auc = auc[np.isfinite(auc)] + auc_macro = float(np.mean(finite_auc) + ) if finite_auc.size > 0 else float("nan") + w_mask = np.isfinite(auc) + auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum( + n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan") + n_valid_auc = int(np.isfinite(auc).sum()) + + # Append per-cause rows (vectorized via DataFrame to avoid Python loops) + per_cause_rows.append( + pd.DataFrame( + { + "tau_years": float(tau), + "cause_id": np.arange(K, dtype=np.int64), + "n_pos": n_pos, + "n_neg": n_neg, + "auc": auc, + "brier": brier_per_cause, + } + ) + ) + + # Business metrics for each topK + denom_true_pairs = int(y_tau.sum()) + for topk in args.topk_list: + topk = int(topk) + idx = topk_indices(s_tau, topk) + captured = np.take_along_axis(y_tau, idx, axis=1) + hits = captured.sum(axis=1).astype(np.float64) + true_cnt = y_tau.sum(axis=1).astype(np.float64) + + precision_like = hits / float(min(topk, K)) + mean_precision = float(np.mean(precision_like) + ) if N > 0 else float("nan") + + mask_has_true = true_cnt > 0 + recall_like = np.full((N,), np.nan, dtype=np.float64) + recall_like[mask_has_true] = hits[mask_has_true] / \ + true_cnt[mask_has_true] + mean_recall = float(np.nanmean(recall_like)) if np.any( + mask_has_true) else float("nan") + median_recall = float(np.nanmedian(recall_like)) if np.any( + mask_has_true) else float("nan") + + numer_captured_pairs = int(captured.sum()) + pop_capture_rate = float( + numer_captured_pairs / denom_true_pairs) if denom_true_pairs > 0 else float("nan") + + workload_rows.append( + { + "tau_years": float(tau), + "topk": int(topk), + "population_capture_rate": pop_capture_rate, + "mean_precision_like": mean_precision, + "mean_recall_like": mean_recall, + "median_recall_like": median_recall, + "denom_true_pairs": denom_true_pairs, + "numer_captured_pairs": numer_captured_pairs, + } + ) + + per_tau_rows.append( + { + "tau_years": float(tau), + "n_records": int(N), + "n_causes": int(K), + "auc_macro": auc_macro, + "auc_weighted_by_npos": auc_weighted, + "n_causes_valid_auc": int(n_valid_auc), + "brier_macro": brier_macro, + "brier_weighted_by_npos": brier_weighted, + "total_true_pairs": denom_true_pairs, + } + ) + + out_metrics = os.path.join(run_dir, "horizon_metrics.csv") + out_pc = os.path.join(run_dir, "horizon_per_cause.csv") + out_wy = 
os.path.join(run_dir, "workload_yield.csv") + + pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False) + if per_cause_rows: + pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False) + else: + pd.DataFrame(columns=["tau_years", "cause_id", "n_pos", + "n_neg", "auc", "brier"]).to_csv(out_pc, index=False) + pd.DataFrame(workload_rows).to_csv(out_wy, index=False) + + print(f"Wrote {out_metrics}") + print(f"Wrote {out_pc}") + print(f"Wrote {out_wy}") + + +if __name__ == "__main__": + main() diff --git a/evaluate_next_event.py b/evaluate_next_event.py new file mode 100644 index 0000000..92124dd --- /dev/null +++ b/evaluate_next_event.py @@ -0,0 +1,188 @@ +import argparse +import os +from typing import List + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import DataLoader + +from utils import ( + EvalRecordDataset, + build_dataset_from_config, + build_event_driven_records, + build_model_head_criterion, + eval_collate_fn, + get_test_subset, + make_inference_dataloader_kwargs, + load_checkpoint_into, + load_train_config, + parse_float_list, + predict_cifs, + roc_auc_ovr, + seed_everything, + topk_indices, +) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Evaluate next-event prediction using short-window CIF" + ) + p.add_argument("--run_dir", type=str, required=True) + p.add_argument("--tau_short", type=float, required=True, help="years") + p.add_argument( + "--age_bins", + type=str, + nargs="+", + default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"], + help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)", + ) + + p.add_argument( + "--device", + type=str, + default=("cuda" if torch.cuda.is_available() else "cpu"), + ) + p.add_argument("--batch_size", type=int, default=256) + p.add_argument("--num_workers", type=int, default=0) + p.add_argument("--seed", type=int, default=0) + p.add_argument( + "--min_pos", + type=int, + default=20, + help="Minimum positives for per-cause AUC", + ) + return p.parse_args() + + +def main() -> None: + args = parse_args() + seed_everything(args.seed) + + run_dir = args.run_dir + cfg = load_train_config(run_dir) + + dataset = build_dataset_from_config(cfg) + test_subset = get_test_subset(dataset, cfg) + + age_bins_years = parse_float_list(args.age_bins) + records = build_event_driven_records( + dataset=dataset, + subset=test_subset, + age_bins_years=age_bins_years, + seed=args.seed, + ) + + device = torch.device(args.device) + model, head, criterion = build_model_head_criterion(cfg, dataset, device) + load_checkpoint_into(run_dir, model, head, criterion, device) + + rec_ds = EvalRecordDataset(dataset, records) + + dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers) + loader = DataLoader( + rec_ds, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + collate_fn=eval_collate_fn, + **dl_kwargs, + ) + + tau = float(args.tau_short) + scores = predict_cifs(model, head, criterion, loader, [tau], device=device) + # scores shape: (N,K,1) for multi-taus; squeeze last + if scores.ndim == 3: + scores = scores[:, :, 0] + + n_records_total = len(records) + y_next = np.array( + [(-1 if r.next_event_cause is None else int(r.next_event_cause)) + for r in records], + dtype=np.int64, + ) + eligible = y_next >= 0 + n_eligible = int(eligible.sum()) + coverage = float( + n_eligible / n_records_total) if n_records_total > 0 else 0.0 + + metrics_rows: List[dict] = [] + metrics_rows.append({"metric": "n_records_total", 
"value": n_records_total}) + metrics_rows.append( + {"metric": "n_next_event_eligible", "value": n_eligible}) + metrics_rows.append({"metric": "coverage", "value": coverage}) + metrics_rows.append({"metric": "tau_short_years", "value": tau}) + + if n_eligible == 0: + out_path = os.path.join(run_dir, "next_event_metrics.csv") + pd.DataFrame(metrics_rows).to_csv(out_path, index=False) + print(f"No eligible records; wrote {out_path}") + return + + scores_e = scores[eligible] + y_e = y_next[eligible] + + pred = scores_e.argmax(axis=1) + acc = float((pred == y_e).mean()) + metrics_rows.append({"metric": "top1_accuracy", "value": acc}) + + # MRR + order = np.argsort(-scores_e, axis=1, kind="mergesort") + ranks = np.empty(y_e.shape[0], dtype=np.int32) + for i in range(y_e.shape[0]): + ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1 + mrr = float((1.0 / ranks).mean()) + metrics_rows.append({"metric": "mrr", "value": mrr}) + + # HitRate@K + for k in [1, 3, 5, 10, 20]: + topk = topk_indices(scores_e, k) + hit = (topk == y_e[:, None]).any(axis=1) + metrics_rows.append({"metric": f"hitrate_at_{k}", + "value": float(hit.mean())}) + + # Macro OvR AUC per cause (optional) + K = scores.shape[1] + n_pos = np.bincount(y_e, minlength=K).astype(np.int64) + n_neg = (int(y_e.size) - n_pos).astype(np.int64) + + auc = np.full((K,), np.nan, dtype=np.float64) + min_pos = int(args.min_pos) + candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0)) + for k in candidates: + auc_k = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k]) + auc[k] = auc_k + + included = (n_pos >= min_pos) & (n_neg > 0) + per_cause_df = pd.DataFrame( + { + "cause_id": np.arange(K, dtype=np.int64), + "n_pos": n_pos, + "n_neg": n_neg, + "auc": auc, + "included": included, + } + ) + + aucs = auc[np.isfinite(auc)] + + if aucs: + metrics_rows.append( + {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))}) + else: + metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")}) + + out_metrics = os.path.join(run_dir, "next_event_metrics.csv") + pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False) + + # optional per-cause + out_pc = os.path.join(run_dir, "next_event_per_cause.csv") + per_cause_df.to_csv(out_pc, index=False) + + print(f"Wrote {out_metrics}") + print(f"Wrote {out_pc}") + + +if __name__ == "__main__": + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..d05c4e9 --- /dev/null +++ b/utils.py @@ -0,0 +1,566 @@ +import json +import math +import os +import random +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset, Subset, random_split + +from dataset import HealthDataset +from losses import ( + DiscreteTimeCIFNLLLoss, + ExponentialNLLLoss, + PiecewiseExponentialCIFNLLLoss, +) +from model import DelphiFork, SapDelphi, SimpleHead + + +DAYS_PER_YEAR = 365.25 +N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2 + + +def make_inference_dataloader_kwargs( + device: torch.device, + num_workers: int, +) -> Dict[str, Any]: + """DataLoader kwargs tuned for inference throughput. + + Behavior/metrics are unchanged; this only impacts speed. + """ + use_cuda = device.type == "cuda" and torch.cuda.is_available() + kwargs: Dict[str, Any] = { + "pin_memory": bool(use_cuda), + } + if num_workers > 0: + kwargs["persistent_workers"] = True + # default prefetch is 2; set explicitly for clarity. 
+ kwargs["prefetch_factor"] = 2 + return kwargs + + +# ------------------------- +# Config + determinism +# ------------------------- + +def _replace_nonstandard_json_numbers(text: str) -> str: + # Python's json.dump writes Infinity/-Infinity/NaN for non-finite floats. + # Replace bare tokens (not within quotes) with string placeholders. + def repl(match: re.Match[str]) -> str: + token = match.group(0) + if token == "-Infinity": + return '"__NINF__"' + if token == "Infinity": + return '"__INF__"' + if token == "NaN": + return '"__NAN__"' + return token + + return re.sub(r'(? Any: + if isinstance(obj, dict): + return {k: _restore_placeholders(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_restore_placeholders(v) for v in obj] + if obj == "__INF__": + return float("inf") + if obj == "__NINF__": + return float("-inf") + if obj == "__NAN__": + return float("nan") + return obj + + +def load_train_config(run_dir: str) -> Dict[str, Any]: + cfg_path = os.path.join(run_dir, "train_config.json") + with open(cfg_path, "r", encoding="utf-8") as f: + raw = f.read() + raw = _replace_nonstandard_json_numbers(raw) + cfg = json.loads(raw) + cfg = _restore_placeholders(cfg) + return cfg + + +def seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def parse_float_list(values: Sequence[str]) -> List[float]: + out: List[float] = [] + for v in values: + s = str(v).strip().lower() + if s in {"inf", "+inf", "infty", "infinity", "+infinity"}: + out.append(float("inf")) + elif s in {"-inf", "-infty", "-infinity"}: + out.append(float("-inf")) + else: + out.append(float(v)) + return out + + +# ------------------------- +# Dataset + split (match train.py) +# ------------------------- + +def build_dataset_from_config(cfg: Dict[str, Any]) -> HealthDataset: + data_prefix = cfg["data_prefix"] + full_cov = bool(cfg.get("full_cov", False)) + + if full_cov: + cov_list = None + else: + cov_list = ["bmi", "smoking", "alcohol"] + + dataset = HealthDataset( + data_prefix=data_prefix, + covariate_list=cov_list, + ) + return dataset + + +def get_test_subset(dataset: HealthDataset, cfg: Dict[str, Any]) -> Subset: + n_total = len(dataset) + train_ratio = float(cfg["train_ratio"]) + val_ratio = float(cfg["val_ratio"]) + seed = int(cfg["random_seed"]) + + n_train = int(n_total * train_ratio) + n_val = int(n_total * val_ratio) + n_test = n_total - n_train - n_val + + _, _, test_subset = random_split( + dataset, + [n_train, n_val, n_test], + generator=torch.Generator().manual_seed(seed), + ) + return test_subset + + +# ------------------------- +# Model + head + criterion (match train.py) +# ------------------------- + +def build_model_head_criterion( + cfg: Dict[str, Any], + dataset: HealthDataset, + device: torch.device, +) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]: + loss_type = cfg["loss_type"] + + if loss_type == "exponential": + criterion = ExponentialNLLLoss(lambda_reg=float( + cfg.get("lambda_reg", 0.0))).to(device) + out_dims = [dataset.n_disease] + elif loss_type == "discrete_time_cif": + bin_edges = [float(x) for x in cfg["bin_edges"]] + criterion = DiscreteTimeCIFNLLLoss( + bin_edges=bin_edges, + lambda_reg=float(cfg.get("lambda_reg", 0.0)), + ).to(device) + out_dims = [dataset.n_disease + 1, len(bin_edges)] + elif loss_type == "pwe_cif": + # training drops +inf for PWE + raw_edges = [float(x) for x in cfg["bin_edges"]] + pwe_edges = [float(x) for x in raw_edges if math.isfinite(float(x))] + if 
len(pwe_edges) < 2: + raise ValueError( + "pwe_cif requires at least 2 finite bin edges (including 0). " + f"Got bin_edges={raw_edges}" + ) + if float(pwe_edges[0]) != 0.0: + raise ValueError( + f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}") + + criterion = PiecewiseExponentialCIFNLLLoss( + bin_edges=pwe_edges, + lambda_reg=float(cfg.get("lambda_reg", 0.0)), + ).to(device) + n_bins = len(pwe_edges) - 1 + out_dims = [dataset.n_disease, n_bins] + else: + raise ValueError(f"Unsupported loss_type: {loss_type}") + + model_type = cfg["model_type"] + if model_type == "delphi_fork": + model = DelphiFork( + n_disease=dataset.n_disease, + n_tech_tokens=N_TECH_TOKENS, + n_embd=int(cfg["n_embd"]), + n_head=int(cfg["n_head"]), + n_layer=int(cfg["n_layer"]), + pdrop=float(cfg.get("pdrop", 0.0)), + age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), + n_cont=int(dataset.n_cont), + n_cate=int(dataset.n_cate), + cate_dims=list(dataset.cate_dims), + ).to(device) + elif model_type == "sap_delphi": + model = SapDelphi( + n_disease=dataset.n_disease, + n_tech_tokens=N_TECH_TOKENS, + n_embd=int(cfg["n_embd"]), + n_head=int(cfg["n_head"]), + n_layer=int(cfg["n_layer"]), + pdrop=float(cfg.get("pdrop", 0.0)), + age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), + n_cont=int(dataset.n_cont), + n_cate=int(dataset.n_cate), + cate_dims=list(dataset.cate_dims), + pretrained_weights_path=str( + cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")), + freeze_embeddings=True, + ).to(device) + else: + raise ValueError(f"Unsupported model_type: {model_type}") + + head = SimpleHead( + n_embd=int(cfg["n_embd"]), + out_dims=list(out_dims), + ).to(device) + + return model, head, criterion + + +def load_checkpoint_into( + run_dir: str, + model: torch.nn.Module, + head: torch.nn.Module, + criterion: Optional[torch.nn.Module], + device: torch.device, +) -> Dict[str, Any]: + ckpt_path = os.path.join(run_dir, "best_model.pt") + ckpt = torch.load(ckpt_path, map_location=device) + model.load_state_dict(ckpt["model_state_dict"], strict=True) + head.load_state_dict(ckpt["head_state_dict"], strict=True) + if criterion is not None and "criterion_state_dict" in ckpt: + try: + criterion.load_state_dict( + ckpt["criterion_state_dict"], strict=False) + except Exception: + # Criterion state is not essential for inference. 
+ pass + return ckpt + + +# ------------------------- +# Evaluation record construction (event-driven) +# ------------------------- + +@dataclass(frozen=True) +class EvalRecord: + patient_idx: int + patient_id: Any + doa_days: float + t0_days: float + cutoff_pos: int # baseline position (inclusive) + next_event_cause: Optional[int] + next_event_dt_years: Optional[float] + future_causes: np.ndarray # (E,) in [0..K-1] + future_dt_years: np.ndarray # (E,) strictly > 0 + + +def _to_days(x_years: float) -> float: + if math.isinf(float(x_years)): + return float("inf") + return float(x_years) * DAYS_PER_YEAR + + +def build_event_driven_records( + dataset: HealthDataset, + subset: Subset, + age_bins_years: Sequence[float], + seed: int, +) -> List[EvalRecord]: + if len(age_bins_years) < 2: + raise ValueError("age_bins must have at least 2 boundaries") + + age_bins_days = [_to_days(b) for b in age_bins_years] + if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)): + raise ValueError("age_bins must be strictly increasing") + + rng = np.random.default_rng(seed) + + records: List[EvalRecord] = [] + + # Subset.indices is deterministic from random_split + indices = list(getattr(subset, "indices", range(len(subset)))) + + # Speed: avoid calling dataset.__getitem__ for every patient here. + # We only need DOA + event times/codes to create evaluation records. + eps = 1e-6 + for patient_idx in indices: + patient_id = dataset.patient_ids[patient_idx] + + doa_days = float(dataset._doa[patient_idx]) + + raw_records = dataset.patient_events.get(patient_id, []) + if raw_records: + times = np.asarray([t for t, _ in raw_records], dtype=np.float64) + codes = np.asarray([c for _, c in raw_records], dtype=np.int64) + else: + times = np.zeros((0,), dtype=np.float64) + codes = np.zeros((0,), dtype=np.int64) + + # Mirror HealthDataset insertion logic exactly. + insert_pos = int(np.searchsorted(times, doa_days, side="left")) + times_ins = np.insert(times, insert_pos, doa_days) + codes_ins = np.insert(codes, insert_pos, 1) + + is_disease = codes_ins >= N_TECH_TOKENS + disease_times = times_ins[is_disease] + + for b in range(len(age_bins_days) - 1): + lo = age_bins_days[b] + hi = age_bins_days[b + 1] + + # Inclusion rule: + # 1) DOA <= bin_upper + if not (doa_days <= hi): + continue + + # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA + in_bin = (disease_times >= lo) & ( + disease_times < hi) & (disease_times >= doa_days) + cand_times = disease_times[in_bin] + if cand_times.size == 0: + continue + + t0_days = float(rng.choice(cand_times)) + + # Baseline position (inclusive) in the *post-DOA-inserted* sequence. 
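+            # t0_days was drawn from disease_times, i.e. from times_ins itself, so an
+            # exact match is expected; np.isclose guards against float drift and the
+            # nearest-event fallback below only fires if no position matches.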
+ pos = np.flatnonzero(is_disease & np.isclose( + times_ins, t0_days, rtol=0.0, atol=eps)) + if pos.size == 0: + disease_pos = np.flatnonzero(is_disease) + if disease_pos.size == 0: + continue + disease_times_full = times_ins[disease_pos] + closest_idx = int( + np.argmin(np.abs(disease_times_full - t0_days))) + cutoff_pos = int(disease_pos[closest_idx]) + t0_days = float(disease_times_full[closest_idx]) + else: + cutoff_pos = int(pos[0]) + + # Future disease events strictly after t0 + future_mask = (times_ins > (t0_days + eps)) & is_disease + future_pos = np.flatnonzero(future_mask) + if future_pos.size == 0: + next_cause = None + next_dt_years = None + future_causes = np.zeros((0,), dtype=np.int64) + future_dt_years_arr = np.zeros((0,), dtype=np.float32) + else: + future_times_days = times_ins[future_pos] + future_tokens = codes_ins[future_pos] + future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64) + future_dt_years_arr = ( + (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32) + + # next-event = minimal time > t0 (tie broken by earliest position) + next_idx = int(np.argmin(future_times_days)) + next_cause = int(future_causes[next_idx]) + next_dt_years = float(future_dt_years_arr[next_idx]) + + records.append( + EvalRecord( + patient_idx=int(patient_idx), + patient_id=patient_id, + doa_days=float(doa_days), + t0_days=float(t0_days), + cutoff_pos=int(cutoff_pos), + next_event_cause=next_cause, + next_event_dt_years=next_dt_years, + future_causes=future_causes, + future_dt_years=future_dt_years_arr, + ) + ) + + return records + + +class EvalRecordDataset(Dataset): + def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]): + self.base = base_dataset + self.records = list(records) + self._cache: Dict[int, Tuple[torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, int]] = {} + self._cache_order: List[int] = [] + self._cache_max = 2048 + + def __len__(self) -> int: + return len(self.records) + + def __getitem__(self, idx: int): + rec = self.records[idx] + cached = self._cache.get(rec.patient_idx) + if cached is None: + event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx] + cached = (event_seq, time_seq, cont, cate, int(sex)) + self._cache[rec.patient_idx] = cached + self._cache_order.append(rec.patient_idx) + if len(self._cache_order) > self._cache_max: + drop = self._cache_order.pop(0) + self._cache.pop(drop, None) + else: + event_seq, time_seq, cont, cate, sex = cached + cutoff = rec.cutoff_pos + 1 + event_seq = event_seq[:cutoff] + time_seq = time_seq[:cutoff] + baseline_pos = rec.cutoff_pos # same index in truncated sequence + return event_seq, time_seq, cont, cate, sex, baseline_pos + + +def eval_collate_fn(batch): + from torch.nn.utils.rnn import pad_sequence + + event_seqs, time_seqs, cont_feats, cate_feats, sexes, baseline_pos = zip( + *batch) + event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0) + time_batch = pad_sequence( + time_seqs, batch_first=True, padding_value=36525.0) + cont_batch = torch.stack(cont_feats, dim=0).unsqueeze(1) + cate_batch = torch.stack(cate_feats, dim=0).unsqueeze(1) + sex_batch = torch.tensor(sexes, dtype=torch.long) + baseline_pos = torch.tensor(baseline_pos, dtype=torch.long) + return event_batch, time_batch, cont_batch, cate_batch, sex_batch, baseline_pos + + +# ------------------------- +# Inference utilities +# ------------------------- + +def predict_cifs( + model: torch.nn.Module, + head: torch.nn.Module, + criterion: torch.nn.Module, + loader: DataLoader, + 
taus_years: Sequence[float], + device: torch.device, +) -> np.ndarray: + model.eval() + head.eval() + + taus_t = torch.tensor(list(taus_years), dtype=torch.float32, device=device) + + all_out: List[np.ndarray] = [] + with torch.no_grad(): + for batch in loader: + event_seq, time_seq, cont, cate, sex, baseline_pos = batch + event_seq = event_seq.to(device, non_blocking=True) + time_seq = time_seq.to(device, non_blocking=True) + cont = cont.to(device, non_blocking=True) + cate = cate.to(device, non_blocking=True) + sex = sex.to(device, non_blocking=True) + baseline_pos = baseline_pos.to(device, non_blocking=True) + + h = model(event_seq, time_seq, sex, cont, cate) + b_idx = torch.arange(h.size(0), device=device) + c = h[b_idx, baseline_pos] + logits = head(c) + + cifs = criterion.calculate_cifs(logits, taus_t) + out = cifs.detach().cpu().numpy() + all_out.append(out) + + return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,)) + + +def flatten_future_events( + records: Sequence[EvalRecord], + n_causes: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Flatten (record_idx, cause, dt_years) across all future events. + + Used to build horizon labels via vectorized masking + scatter. + """ + rec_idx_parts: List[np.ndarray] = [] + cause_parts: List[np.ndarray] = [] + dt_parts: List[np.ndarray] = [] + + for i, r in enumerate(records): + if r.future_causes.size == 0: + continue + causes = r.future_causes + dts = r.future_dt_years + # Keep only valid cause ids. + m = (causes >= 0) & (causes < n_causes) + if not np.any(m): + continue + causes = causes[m].astype(np.int64, copy=False) + dts = dts[m].astype(np.float32, copy=False) + rec_idx_parts.append(np.full((causes.size,), i, dtype=np.int32)) + cause_parts.append(causes) + dt_parts.append(dts) + + if not rec_idx_parts: + return ( + np.zeros((0,), dtype=np.int32), + np.zeros((0,), dtype=np.int64), + np.zeros((0,), dtype=np.float32), + ) + + return ( + np.concatenate(rec_idx_parts, axis=0), + np.concatenate(cause_parts, axis=0), + np.concatenate(dt_parts, axis=0), + ) + + +# ------------------------- +# Metrics helpers +# ------------------------- + +def roc_auc_ovr(y_true: np.ndarray, y_score: np.ndarray) -> float: + """Binary ROC AUC with tie-aware average ranks. + + Returns NaN if y_true has no positives or no negatives. 
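+
+    Equivalent to the Mann-Whitney U statistic: with 1-based average ranks r_i
+    assigned over all n = n_pos + n_neg scores,
+
+        AUC = (sum of r_i over positives - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
+
+    which is what the tie-handling loop below computes.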
+ """ + y_true = np.asarray(y_true).astype(np.int32) + y_score = np.asarray(y_score).astype(np.float64) + + n_pos = int(y_true.sum()) + n = int(y_true.size) + n_neg = n - n_pos + if n_pos == 0 or n_neg == 0: + return float("nan") + + order = np.argsort(y_score, kind="mergesort") + scores_sorted = y_score[order] + y_sorted = y_true[order] + + ranks = np.empty(n, dtype=np.float64) + i = 0 + while i < n: + j = i + 1 + while j < n and scores_sorted[j] == scores_sorted[i]: + j += 1 + # average rank for ties, ranks are 1..n + avg_rank = 0.5 * (i + 1 + j) + ranks[i:j] = avg_rank + i = j + + sum_ranks_pos = float((ranks * y_sorted).sum()) + auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg) + return float(auc) + + +def topk_indices(scores: np.ndarray, k: int) -> np.ndarray: + """Return indices of top-k scores per row (descending).""" + if k <= 0: + raise ValueError("k must be positive") + n, K = scores.shape + k = min(k, K) + # argpartition gives arbitrary order within topk; sort those by score + part = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k] + part_scores = np.take_along_axis(scores, part, axis=1) + order = np.argsort(-part_scores, axis=1, kind="mergesort") + return np.take_along_axis(part, order, axis=1)