Add evaluation scripts for next-event prediction and horizon capture, with detailed metric disclaimers

2026-01-17 13:49:39 +08:00
parent 07916ee529
commit bfab601a77
4 changed files with 1069 additions and 0 deletions


@@ -1,2 +1,40 @@
# DeepHealth
## Evaluation
This repo includes two event-driven evaluation entrypoints:
- `evaluate_next_event.py`: next-event prediction using short-window CIF
- `evaluate_horizon.py`: horizon-capture evaluation using CIF at multiple horizons
### IMPORTANT metric disclaimers
- **AUC** reported by `evaluate_horizon.py` is “time-dependent” only because the label depends on the chosen horizon $\tau$.
Without explicit follow-up end times / censoring, this is **not** a classical risk-set AUC with IPCW.
Use it for **model comparison and diagnostics**, not strict statistical interpretation.
- **Brier score** reported by `evaluate_horizon.py` is an unadjusted diagnostic/proxy metric (no censoring adjustment).
  Use it to detect probability-mass compression / numerical stability issues; do not claim calibrated absolute risk. A minimal sketch of both definitions follows this list.
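For concreteness, here is a minimal NumPy sketch of what those two quantities reduce to (illustrative variable names, not the scripts' internals): a horizon-dependent binary label per (record, cause), and a plain squared-error mean with no censoring weighting.
```python
import numpy as np

def horizon_labels(future_dt_years, future_causes, n_causes, tau):
    """y[i, k] = 1 iff record i has >= 1 event of cause k in (t0, t0 + tau]."""
    y = np.zeros((len(future_dt_years), n_causes), dtype=np.int8)
    for i, (dts, causes) in enumerate(zip(future_dt_years, future_causes)):
        within = np.asarray(dts, dtype=float) <= tau   # dt measured from baseline t0
        y[i, np.asarray(causes, dtype=int)[within]] = 1
    return y

def unadjusted_brier(y, cif):
    # Per-cause mean squared error between label and predicted CIF; no IPCW.
    return np.mean((y.astype(np.float64) - cif) ** 2, axis=0)
```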
### Example
```bash
# Next-event (no --horizons)
python evaluate_next_event.py \
--run_dir runs/your_run \
--tau_short 0.25 \
--age_bins 40 45 50 55 60 65 70 inf \
--device cuda \
--batch_size 256 \
--seed 0
# Horizon-capture
python evaluate_horizon.py \
--run_dir runs/your_run \
--horizons 0.25 0.5 1.0 2.0 5.0 10.0 \
--age_bins 40 45 50 55 60 65 70 inf \
--device cuda \
--batch_size 256 \
--seed 0
```
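Both scripts write CSV summaries into `--run_dir`. A quick way to inspect them (file names as written by the scripts):
```python
import pandas as pd

run_dir = "runs/your_run"
# evaluate_horizon.py outputs
horizon_metrics = pd.read_csv(f"{run_dir}/horizon_metrics.csv")      # per-tau summary (AUC/Brier)
horizon_per_cause = pd.read_csv(f"{run_dir}/horizon_per_cause.csv")  # per-(tau, cause) rows
workload_yield = pd.read_csv(f"{run_dir}/workload_yield.csv")        # top-K capture metrics
# evaluate_next_event.py outputs
next_event_metrics = pd.read_csv(f"{run_dir}/next_event_metrics.csv")
next_event_per_cause = pd.read_csv(f"{run_dir}/next_event_per_cause.csv")
```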

evaluate_horizon.py Normal file

@@ -0,0 +1,277 @@
"""Horizon-capture evaluation.
DISCLAIMERS (important):
- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$.
Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW.
Use it for model comparison and diagnostics, not strict statistical interpretation.
- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment).
Use them to detect probability-mass compression / numerical stability issues; do not claim
calibrated absolute risk.
"""
import argparse
import os
from typing import Dict, List, Sequence
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from utils import (
EvalRecordDataset,
build_dataset_from_config,
build_event_driven_records,
build_model_head_criterion,
eval_collate_fn,
flatten_future_events,
get_test_subset,
load_checkpoint_into,
load_train_config,
make_inference_dataloader_kwargs,
parse_float_list,
predict_cifs,
roc_auc_ovr,
seed_everything,
topk_indices,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Evaluate horizon-capture using CIF at horizons")
p.add_argument("--run_dir", type=str, required=True)
p.add_argument(
"--horizons",
type=str,
nargs="+",
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
help="Horizon grid in years",
)
p.add_argument(
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
)
p.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else "cpu"),
)
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--num_workers", type=int, default=0)
p.add_argument("--seed", type=int, default=0)
p.add_argument("--min_pos", type=int, default=20)
p.add_argument(
"--topk_list",
type=int,
nargs="+",
default=[1, 5, 10, 20, 50],
)
return p.parse_args()
def build_labels_within_tau_flat(
n_records: int,
n_causes: int,
event_record_idx: np.ndarray,
event_cause: np.ndarray,
event_dt_years: np.ndarray,
tau_years: float,
) -> np.ndarray:
"""Build y_within_tau using a flattened (record,cause,dt) representation.
This preserves the exact label definition: y[i,k]=1 iff at least one event of cause k
occurs in (t0, t0+tau].
"""
y = np.zeros((n_records, n_causes), dtype=np.int8)
if event_dt_years.size == 0:
return y
m = event_dt_years <= float(tau_years)
if not np.any(m):
return y
y[event_record_idx[m], event_cause[m]] = 1
return y
def main() -> None:
args = parse_args()
seed_everything(args.seed)
run_dir = args.run_dir
cfg = load_train_config(run_dir)
dataset = build_dataset_from_config(cfg)
test_subset = get_test_subset(dataset, cfg)
age_bins_years = parse_float_list(args.age_bins)
horizons = parse_float_list(args.horizons)
horizons = [float(h) for h in horizons]
records = build_event_driven_records(
dataset=dataset,
subset=test_subset,
age_bins_years=age_bins_years,
seed=args.seed,
)
device = torch.device(args.device)
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
load_checkpoint_into(run_dir, model, head, criterion, device)
rec_ds = EvalRecordDataset(dataset, records)
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
loader = DataLoader(
rec_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=eval_collate_fn,
**dl_kwargs,
)
# Print disclaimers every run (requested)
print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
scores = predict_cifs(model, head, criterion, loader,
horizons, device=device)
# scores shape: (N, K, H)
if scores.ndim != 3:
raise ValueError(
f"Expected CIF scores with shape (N,K,H), got {scores.shape}")
N, K, H = scores.shape
if N != len(records):
raise ValueError("Record count mismatch")
# Pre-flatten all future events once to avoid repeated per-record scans.
evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)
per_tau_rows: List[Dict[str, object]] = []
    per_cause_rows: List[pd.DataFrame] = []
workload_rows: List[Dict[str, object]] = []
for h_idx, tau in enumerate(horizons):
s_tau = scores[:, :, h_idx]
y_tau = build_labels_within_tau_flat(
N, K, evt_rec_idx, evt_cause, evt_dt, tau)
# Per-cause counts + Brier (vectorized)
n_pos = y_tau.sum(axis=0).astype(np.int64)
n_neg = (int(N) - n_pos).astype(np.int64)
# Brier per cause: mean_i (y - s)^2
brier_per_cause = np.mean(
(y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
brier_weighted = float(np.sum(
brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
# AUC: compute only for causes with enough positives and at least 1 negative
auc = np.full((K,), np.nan, dtype=np.float64)
min_pos = int(args.min_pos)
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
for k in candidates:
auc[k] = roc_auc_ovr(y_tau[:, k].astype(
np.int32), s_tau[:, k].astype(np.float64))
finite_auc = auc[np.isfinite(auc)]
auc_macro = float(np.mean(finite_auc)
) if finite_auc.size > 0 else float("nan")
w_mask = np.isfinite(auc)
auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
n_valid_auc = int(np.isfinite(auc).sum())
# Append per-cause rows (vectorized via DataFrame to avoid Python loops)
per_cause_rows.append(
pd.DataFrame(
{
"tau_years": float(tau),
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
"n_neg": n_neg,
"auc": auc,
"brier": brier_per_cause,
}
)
)
# Business metrics for each topK
denom_true_pairs = int(y_tau.sum())
for topk in args.topk_list:
topk = int(topk)
idx = topk_indices(s_tau, topk)
captured = np.take_along_axis(y_tau, idx, axis=1)
hits = captured.sum(axis=1).astype(np.float64)
true_cnt = y_tau.sum(axis=1).astype(np.float64)
precision_like = hits / float(min(topk, K))
mean_precision = float(np.mean(precision_like)
) if N > 0 else float("nan")
mask_has_true = true_cnt > 0
recall_like = np.full((N,), np.nan, dtype=np.float64)
recall_like[mask_has_true] = hits[mask_has_true] / \
true_cnt[mask_has_true]
mean_recall = float(np.nanmean(recall_like)) if np.any(
mask_has_true) else float("nan")
median_recall = float(np.nanmedian(recall_like)) if np.any(
mask_has_true) else float("nan")
numer_captured_pairs = int(captured.sum())
pop_capture_rate = float(
numer_captured_pairs / denom_true_pairs) if denom_true_pairs > 0 else float("nan")
workload_rows.append(
{
"tau_years": float(tau),
"topk": int(topk),
"population_capture_rate": pop_capture_rate,
"mean_precision_like": mean_precision,
"mean_recall_like": mean_recall,
"median_recall_like": median_recall,
"denom_true_pairs": denom_true_pairs,
"numer_captured_pairs": numer_captured_pairs,
}
)
per_tau_rows.append(
{
"tau_years": float(tau),
"n_records": int(N),
"n_causes": int(K),
"auc_macro": auc_macro,
"auc_weighted_by_npos": auc_weighted,
"n_causes_valid_auc": int(n_valid_auc),
"brier_macro": brier_macro,
"brier_weighted_by_npos": brier_weighted,
"total_true_pairs": denom_true_pairs,
}
)
out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
out_wy = os.path.join(run_dir, "workload_yield.csv")
pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
if per_cause_rows:
pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False)
else:
pd.DataFrame(columns=["tau_years", "cause_id", "n_pos",
"n_neg", "auc", "brier"]).to_csv(out_pc, index=False)
pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
print(f"Wrote {out_metrics}")
print(f"Wrote {out_pc}")
print(f"Wrote {out_wy}")
if __name__ == "__main__":
main()

evaluate_next_event.py Normal file

@@ -0,0 +1,188 @@
import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from utils import (
EvalRecordDataset,
build_dataset_from_config,
build_event_driven_records,
build_model_head_criterion,
eval_collate_fn,
get_test_subset,
make_inference_dataloader_kwargs,
load_checkpoint_into,
load_train_config,
parse_float_list,
predict_cifs,
roc_auc_ovr,
seed_everything,
topk_indices,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Evaluate next-event prediction using short-window CIF"
)
p.add_argument("--run_dir", type=str, required=True)
p.add_argument("--tau_short", type=float, required=True, help="years")
p.add_argument(
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
)
p.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else "cpu"),
)
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--num_workers", type=int, default=0)
p.add_argument("--seed", type=int, default=0)
p.add_argument(
"--min_pos",
type=int,
default=20,
help="Minimum positives for per-cause AUC",
)
return p.parse_args()
def main() -> None:
args = parse_args()
seed_everything(args.seed)
run_dir = args.run_dir
cfg = load_train_config(run_dir)
dataset = build_dataset_from_config(cfg)
test_subset = get_test_subset(dataset, cfg)
age_bins_years = parse_float_list(args.age_bins)
records = build_event_driven_records(
dataset=dataset,
subset=test_subset,
age_bins_years=age_bins_years,
seed=args.seed,
)
device = torch.device(args.device)
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
load_checkpoint_into(run_dir, model, head, criterion, device)
rec_ds = EvalRecordDataset(dataset, records)
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
loader = DataLoader(
rec_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=eval_collate_fn,
**dl_kwargs,
)
tau = float(args.tau_short)
scores = predict_cifs(model, head, criterion, loader, [tau], device=device)
    # predict_cifs returns (N, K, H); with a single tau, keep the only horizon slice
if scores.ndim == 3:
scores = scores[:, :, 0]
n_records_total = len(records)
y_next = np.array(
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
for r in records],
dtype=np.int64,
)
eligible = y_next >= 0
n_eligible = int(eligible.sum())
coverage = float(
n_eligible / n_records_total) if n_records_total > 0 else 0.0
metrics_rows: List[dict] = []
metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
metrics_rows.append(
{"metric": "n_next_event_eligible", "value": n_eligible})
metrics_rows.append({"metric": "coverage", "value": coverage})
metrics_rows.append({"metric": "tau_short_years", "value": tau})
if n_eligible == 0:
out_path = os.path.join(run_dir, "next_event_metrics.csv")
pd.DataFrame(metrics_rows).to_csv(out_path, index=False)
print(f"No eligible records; wrote {out_path}")
return
scores_e = scores[eligible]
y_e = y_next[eligible]
pred = scores_e.argmax(axis=1)
acc = float((pred == y_e).mean())
metrics_rows.append({"metric": "top1_accuracy", "value": acc})
# MRR
order = np.argsort(-scores_e, axis=1, kind="mergesort")
ranks = np.empty(y_e.shape[0], dtype=np.int32)
for i in range(y_e.shape[0]):
ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
mrr = float((1.0 / ranks).mean())
metrics_rows.append({"metric": "mrr", "value": mrr})
# HitRate@K
for k in [1, 3, 5, 10, 20]:
topk = topk_indices(scores_e, k)
hit = (topk == y_e[:, None]).any(axis=1)
metrics_rows.append({"metric": f"hitrate_at_{k}",
"value": float(hit.mean())})
# Macro OvR AUC per cause (optional)
K = scores.shape[1]
n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
n_neg = (int(y_e.size) - n_pos).astype(np.int64)
auc = np.full((K,), np.nan, dtype=np.float64)
min_pos = int(args.min_pos)
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
for k in candidates:
auc_k = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
auc[k] = auc_k
included = (n_pos >= min_pos) & (n_neg > 0)
per_cause_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
"n_neg": n_neg,
"auc": auc,
"included": included,
}
)
aucs = auc[np.isfinite(auc)]
    if aucs.size > 0:
metrics_rows.append(
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
else:
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
# optional per-cause
out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
per_cause_df.to_csv(out_pc, index=False)
print(f"Wrote {out_metrics}")
print(f"Wrote {out_pc}")
if __name__ == "__main__":
main()

utils.py Normal file

@@ -0,0 +1,566 @@
import json
import math
import os
import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from dataset import HealthDataset
from losses import (
DiscreteTimeCIFNLLLoss,
ExponentialNLLLoss,
PiecewiseExponentialCIFNLLLoss,
)
from model import DelphiFork, SapDelphi, SimpleHead
DAYS_PER_YEAR = 365.25
N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2
def make_inference_dataloader_kwargs(
device: torch.device,
num_workers: int,
) -> Dict[str, Any]:
"""DataLoader kwargs tuned for inference throughput.
Behavior/metrics are unchanged; this only impacts speed.
"""
use_cuda = device.type == "cuda" and torch.cuda.is_available()
kwargs: Dict[str, Any] = {
"pin_memory": bool(use_cuda),
}
if num_workers > 0:
kwargs["persistent_workers"] = True
# default prefetch is 2; set explicitly for clarity.
kwargs["prefetch_factor"] = 2
return kwargs
# -------------------------
# Config + determinism
# -------------------------
def _replace_nonstandard_json_numbers(text: str) -> str:
# Python's json.dump writes Infinity/-Infinity/NaN for non-finite floats.
# Replace bare tokens (not within quotes) with string placeholders.
def repl(match: re.Match[str]) -> str:
token = match.group(0)
if token == "-Infinity":
return '"__NINF__"'
if token == "Infinity":
return '"__INF__"'
if token == "NaN":
return '"__NAN__"'
return token
return re.sub(r'(?<![\w\."])(-Infinity|Infinity|NaN)(?![\w\."])', repl, text)
def _restore_placeholders(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: _restore_placeholders(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_restore_placeholders(v) for v in obj]
if obj == "__INF__":
return float("inf")
if obj == "__NINF__":
return float("-inf")
if obj == "__NAN__":
return float("nan")
return obj
def load_train_config(run_dir: str) -> Dict[str, Any]:
cfg_path = os.path.join(run_dir, "train_config.json")
with open(cfg_path, "r", encoding="utf-8") as f:
raw = f.read()
raw = _replace_nonstandard_json_numbers(raw)
cfg = json.loads(raw)
cfg = _restore_placeholders(cfg)
return cfg
def seed_everything(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def parse_float_list(values: Sequence[str]) -> List[float]:
out: List[float] = []
for v in values:
s = str(v).strip().lower()
if s in {"inf", "+inf", "infty", "infinity", "+infinity"}:
out.append(float("inf"))
elif s in {"-inf", "-infty", "-infinity"}:
out.append(float("-inf"))
else:
out.append(float(v))
return out
# -------------------------
# Dataset + split (match train.py)
# -------------------------
def build_dataset_from_config(cfg: Dict[str, Any]) -> HealthDataset:
data_prefix = cfg["data_prefix"]
full_cov = bool(cfg.get("full_cov", False))
if full_cov:
cov_list = None
else:
cov_list = ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(
data_prefix=data_prefix,
covariate_list=cov_list,
)
return dataset
def get_test_subset(dataset: HealthDataset, cfg: Dict[str, Any]) -> Subset:
n_total = len(dataset)
train_ratio = float(cfg["train_ratio"])
val_ratio = float(cfg["val_ratio"])
seed = int(cfg["random_seed"])
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
_, _, test_subset = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
return test_subset
# -------------------------
# Model + head + criterion (match train.py)
# -------------------------
def build_model_head_criterion(
cfg: Dict[str, Any],
dataset: HealthDataset,
device: torch.device,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
loss_type = cfg["loss_type"]
if loss_type == "exponential":
criterion = ExponentialNLLLoss(lambda_reg=float(
cfg.get("lambda_reg", 0.0))).to(device)
out_dims = [dataset.n_disease]
elif loss_type == "discrete_time_cif":
bin_edges = [float(x) for x in cfg["bin_edges"]]
criterion = DiscreteTimeCIFNLLLoss(
bin_edges=bin_edges,
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
).to(device)
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif loss_type == "pwe_cif":
# training drops +inf for PWE
raw_edges = [float(x) for x in cfg["bin_edges"]]
pwe_edges = [float(x) for x in raw_edges if math.isfinite(float(x))]
if len(pwe_edges) < 2:
raise ValueError(
"pwe_cif requires at least 2 finite bin edges (including 0). "
f"Got bin_edges={raw_edges}"
)
if float(pwe_edges[0]) != 0.0:
raise ValueError(
f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}")
criterion = PiecewiseExponentialCIFNLLLoss(
bin_edges=pwe_edges,
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
).to(device)
n_bins = len(pwe_edges) - 1
out_dims = [dataset.n_disease, n_bins]
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")
model_type = cfg["model_type"]
if model_type == "delphi_fork":
model = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=N_TECH_TOKENS,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
).to(device)
elif model_type == "sap_delphi":
model = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=N_TECH_TOKENS,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
pretrained_weights_path=str(
cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
freeze_embeddings=True,
).to(device)
else:
raise ValueError(f"Unsupported model_type: {model_type}")
head = SimpleHead(
n_embd=int(cfg["n_embd"]),
out_dims=list(out_dims),
).to(device)
return model, head, criterion
def load_checkpoint_into(
run_dir: str,
model: torch.nn.Module,
head: torch.nn.Module,
criterion: Optional[torch.nn.Module],
device: torch.device,
) -> Dict[str, Any]:
ckpt_path = os.path.join(run_dir, "best_model.pt")
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt["model_state_dict"], strict=True)
head.load_state_dict(ckpt["head_state_dict"], strict=True)
if criterion is not None and "criterion_state_dict" in ckpt:
try:
criterion.load_state_dict(
ckpt["criterion_state_dict"], strict=False)
except Exception:
# Criterion state is not essential for inference.
pass
return ckpt
# -------------------------
# Evaluation record construction (event-driven)
# -------------------------
@dataclass(frozen=True)
class EvalRecord:
patient_idx: int
patient_id: Any
doa_days: float
t0_days: float
cutoff_pos: int # baseline position (inclusive)
next_event_cause: Optional[int]
next_event_dt_years: Optional[float]
future_causes: np.ndarray # (E,) in [0..K-1]
future_dt_years: np.ndarray # (E,) strictly > 0
def _to_days(x_years: float) -> float:
if math.isinf(float(x_years)):
return float("inf")
return float(x_years) * DAYS_PER_YEAR
def build_event_driven_records(
dataset: HealthDataset,
subset: Subset,
age_bins_years: Sequence[float],
seed: int,
) -> List[EvalRecord]:
if len(age_bins_years) < 2:
raise ValueError("age_bins must have at least 2 boundaries")
age_bins_days = [_to_days(b) for b in age_bins_years]
if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
raise ValueError("age_bins must be strictly increasing")
rng = np.random.default_rng(seed)
records: List[EvalRecord] = []
# Subset.indices is deterministic from random_split
indices = list(getattr(subset, "indices", range(len(subset))))
# Speed: avoid calling dataset.__getitem__ for every patient here.
# We only need DOA + event times/codes to create evaluation records.
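    # One EvalRecord per (patient, age bin): a random in-bin disease event at or
    # after DOA is sampled as the baseline t0; later events become the targets.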
eps = 1e-6
for patient_idx in indices:
patient_id = dataset.patient_ids[patient_idx]
doa_days = float(dataset._doa[patient_idx])
raw_records = dataset.patient_events.get(patient_id, [])
if raw_records:
times = np.asarray([t for t, _ in raw_records], dtype=np.float64)
codes = np.asarray([c for _, c in raw_records], dtype=np.int64)
else:
times = np.zeros((0,), dtype=np.float64)
codes = np.zeros((0,), dtype=np.int64)
# Mirror HealthDataset insertion logic exactly.
insert_pos = int(np.searchsorted(times, doa_days, side="left"))
times_ins = np.insert(times, insert_pos, doa_days)
codes_ins = np.insert(codes, insert_pos, 1)
is_disease = codes_ins >= N_TECH_TOKENS
disease_times = times_ins[is_disease]
for b in range(len(age_bins_days) - 1):
lo = age_bins_days[b]
hi = age_bins_days[b + 1]
# Inclusion rule:
# 1) DOA <= bin_upper
if not (doa_days <= hi):
continue
# 2) at least one disease event within bin, and baseline must satisfy t0>=DOA
in_bin = (disease_times >= lo) & (
disease_times < hi) & (disease_times >= doa_days)
cand_times = disease_times[in_bin]
if cand_times.size == 0:
continue
t0_days = float(rng.choice(cand_times))
# Baseline position (inclusive) in the *post-DOA-inserted* sequence.
pos = np.flatnonzero(is_disease & np.isclose(
times_ins, t0_days, rtol=0.0, atol=eps))
if pos.size == 0:
disease_pos = np.flatnonzero(is_disease)
if disease_pos.size == 0:
continue
disease_times_full = times_ins[disease_pos]
closest_idx = int(
np.argmin(np.abs(disease_times_full - t0_days)))
cutoff_pos = int(disease_pos[closest_idx])
t0_days = float(disease_times_full[closest_idx])
else:
cutoff_pos = int(pos[0])
# Future disease events strictly after t0
future_mask = (times_ins > (t0_days + eps)) & is_disease
future_pos = np.flatnonzero(future_mask)
if future_pos.size == 0:
next_cause = None
next_dt_years = None
future_causes = np.zeros((0,), dtype=np.int64)
future_dt_years_arr = np.zeros((0,), dtype=np.float32)
else:
future_times_days = times_ins[future_pos]
future_tokens = codes_ins[future_pos]
future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64)
future_dt_years_arr = (
(future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
# next-event = minimal time > t0 (tie broken by earliest position)
next_idx = int(np.argmin(future_times_days))
next_cause = int(future_causes[next_idx])
next_dt_years = float(future_dt_years_arr[next_idx])
records.append(
EvalRecord(
patient_idx=int(patient_idx),
patient_id=patient_id,
doa_days=float(doa_days),
t0_days=float(t0_days),
cutoff_pos=int(cutoff_pos),
next_event_cause=next_cause,
next_event_dt_years=next_dt_years,
future_causes=future_causes,
future_dt_years=future_dt_years_arr,
)
)
return records
class EvalRecordDataset(Dataset):
def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]):
self.base = base_dataset
self.records = list(records)
self._cache: Dict[int, Tuple[torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
self._cache_order: List[int] = []
self._cache_max = 2048
def __len__(self) -> int:
return len(self.records)
def __getitem__(self, idx: int):
rec = self.records[idx]
cached = self._cache.get(rec.patient_idx)
if cached is None:
event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx]
cached = (event_seq, time_seq, cont, cate, int(sex))
self._cache[rec.patient_idx] = cached
self._cache_order.append(rec.patient_idx)
if len(self._cache_order) > self._cache_max:
drop = self._cache_order.pop(0)
self._cache.pop(drop, None)
else:
event_seq, time_seq, cont, cate, sex = cached
cutoff = rec.cutoff_pos + 1
event_seq = event_seq[:cutoff]
time_seq = time_seq[:cutoff]
baseline_pos = rec.cutoff_pos # same index in truncated sequence
return event_seq, time_seq, cont, cate, sex, baseline_pos
def eval_collate_fn(batch):
from torch.nn.utils.rnn import pad_sequence
event_seqs, time_seqs, cont_feats, cate_feats, sexes, baseline_pos = zip(
*batch)
event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0)
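    # Pad times with a far-future sentinel: 36525 days = 100 * 365.25 (about 100 years).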
time_batch = pad_sequence(
time_seqs, batch_first=True, padding_value=36525.0)
cont_batch = torch.stack(cont_feats, dim=0).unsqueeze(1)
cate_batch = torch.stack(cate_feats, dim=0).unsqueeze(1)
sex_batch = torch.tensor(sexes, dtype=torch.long)
baseline_pos = torch.tensor(baseline_pos, dtype=torch.long)
return event_batch, time_batch, cont_batch, cate_batch, sex_batch, baseline_pos
# -------------------------
# Inference utilities
# -------------------------
def predict_cifs(
model: torch.nn.Module,
head: torch.nn.Module,
criterion: torch.nn.Module,
loader: DataLoader,
taus_years: Sequence[float],
device: torch.device,
) -> np.ndarray:
model.eval()
head.eval()
taus_t = torch.tensor(list(taus_years), dtype=torch.float32, device=device)
all_out: List[np.ndarray] = []
with torch.no_grad():
for batch in loader:
event_seq, time_seq, cont, cate, sex, baseline_pos = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
cont = cont.to(device, non_blocking=True)
cate = cate.to(device, non_blocking=True)
sex = sex.to(device, non_blocking=True)
baseline_pos = baseline_pos.to(device, non_blocking=True)
h = model(event_seq, time_seq, sex, cont, cate)
b_idx = torch.arange(h.size(0), device=device)
c = h[b_idx, baseline_pos]
logits = head(c)
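            # CIF per cause and horizon; callers expect shape (batch, n_causes, n_horizons).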
cifs = criterion.calculate_cifs(logits, taus_t)
out = cifs.detach().cpu().numpy()
all_out.append(out)
return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))
def flatten_future_events(
records: Sequence[EvalRecord],
n_causes: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Flatten (record_idx, cause, dt_years) across all future events.
Used to build horizon labels via vectorized masking + scatter.
"""
rec_idx_parts: List[np.ndarray] = []
cause_parts: List[np.ndarray] = []
dt_parts: List[np.ndarray] = []
for i, r in enumerate(records):
if r.future_causes.size == 0:
continue
causes = r.future_causes
dts = r.future_dt_years
# Keep only valid cause ids.
m = (causes >= 0) & (causes < n_causes)
if not np.any(m):
continue
causes = causes[m].astype(np.int64, copy=False)
dts = dts[m].astype(np.float32, copy=False)
rec_idx_parts.append(np.full((causes.size,), i, dtype=np.int32))
cause_parts.append(causes)
dt_parts.append(dts)
if not rec_idx_parts:
return (
np.zeros((0,), dtype=np.int32),
np.zeros((0,), dtype=np.int64),
np.zeros((0,), dtype=np.float32),
)
return (
np.concatenate(rec_idx_parts, axis=0),
np.concatenate(cause_parts, axis=0),
np.concatenate(dt_parts, axis=0),
)
# -------------------------
# Metrics helpers
# -------------------------
def roc_auc_ovr(y_true: np.ndarray, y_score: np.ndarray) -> float:
"""Binary ROC AUC with tie-aware average ranks.
Returns NaN if y_true has no positives or no negatives.
"""
y_true = np.asarray(y_true).astype(np.int32)
y_score = np.asarray(y_score).astype(np.float64)
n_pos = int(y_true.sum())
n = int(y_true.size)
n_neg = n - n_pos
if n_pos == 0 or n_neg == 0:
return float("nan")
order = np.argsort(y_score, kind="mergesort")
scores_sorted = y_score[order]
y_sorted = y_true[order]
ranks = np.empty(n, dtype=np.float64)
i = 0
while i < n:
j = i + 1
while j < n and scores_sorted[j] == scores_sorted[i]:
j += 1
# average rank for ties, ranks are 1..n
avg_rank = 0.5 * (i + 1 + j)
ranks[i:j] = avg_rank
i = j
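    # Rank-sum (Mann-Whitney) identity: AUC = (R_pos - n_pos*(n_pos+1)/2) / (n_pos * n_neg),
    # where R_pos is the sum of tie-averaged ranks of the positive samples.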
sum_ranks_pos = float((ranks * y_sorted).sum())
auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)
return float(auc)
def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
"""Return indices of top-k scores per row (descending)."""
if k <= 0:
raise ValueError("k must be positive")
n, K = scores.shape
k = min(k, K)
# argpartition gives arbitrary order within topk; sort those by score
part = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
part_scores = np.take_along_axis(scores, part, axis=1)
order = np.argsort(-part_scores, axis=1, kind="mergesort")
return np.take_along_axis(part, order, axis=1)