"""
Landmark Analysis Evaluation Script for Longitudinal Health Prediction Models
Implements the comprehensive evaluation framework defined in evaluate_design.md:
- Landmark analysis at age cutoffs {50, 60, 70}
- Prediction horizons {0.25, 0.5, 1, 2, 5, 10} years
- Two tracks: Complete-Case (primary) and Clean Control (academic benchmark)
- Metrics: AUC, Brier Score, Disease-Capture@K, Lift, Yield, DCA
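
Example invocation (illustrative; the run directory is a placeholder):
    python evaluate.py --run_dir runs/example_run --batch_size 128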
"""
import argparse
import json
import math
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import warnings
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, brier_score_loss
# Import model components
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from losses import (
ExponentialNLLLoss,
DiscreteTimeCIFNLLLoss,
PiecewiseExponentialCIFNLLLoss
)
warnings.filterwarnings('ignore')
def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
"""Best-effort torch.compile() wrapper (PyTorch 2.x).
Notes:
- Some PyTorch builds run compiled graphs via CUDA Graphs in certain modes.
If you keep references to graph outputs across steps, PyTorch may raise:
"accessing tensor output of CUDAGraphs that has been overwritten".
- We default to settings that avoid cudagraph output-lifetime pitfalls.
"""
if not enabled:
return module
try:
torch_compile = getattr(torch, "compile", None)
if torch_compile is None:
return module
        # Prefer a safer configuration for evaluation; best-effort disable
        # cudagraphs. Note: torch.compile() rejects passing both `mode` and
        # `options`, so we pass only `options` and let the default mode apply.
        return torch_compile(module, options={"triton.cudagraphs": False})
except Exception:
return module
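# Illustrative usage of the wrapper above (a toy module; on builds without
# torch.compile, or if compilation fails, the original module is returned
# unchanged):
#   layer = torch.nn.Linear(8, 4)
#   layer = _maybe_torch_compile(layer, enabled=True)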
def _maybe_cudagraph_mark_step_begin() -> None:
"""Best-effort step marker for CUDA Graphs compiled execution."""
try:
compiler_mod = getattr(torch, "compiler", None)
if compiler_mod is None:
return
mark = getattr(compiler_mod, "cudagraph_mark_step_begin", None)
if mark is None:
return
mark()
except Exception:
return
def _ensure_dir(path: str) -> str:
os.makedirs(path, exist_ok=True)
return path
def _to_float(x):
try:
return float(x)
except Exception:
return np.nan
def save_summary_json(summary: Dict, output_path: str) -> None:
"""Save a single JSON summary file."""
def convert_to_serializable(obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (np.integer,)):
return int(obj)
if isinstance(obj, (np.floating,)):
return float(obj)
if isinstance(obj, dict):
return {k: convert_to_serializable(v) for k, v in obj.items()}
if isinstance(obj, list):
return [convert_to_serializable(v) for v in obj]
return obj
summary_serializable = convert_to_serializable(summary)
with open(output_path, 'w') as f:
json.dump(summary_serializable, f, indent=2)
def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]:
"""Save evaluation results into multiple CSV files.
Produces long-form tables so they are easy to analyze/plot:
- landmarks_summary.csv
- auc_per_disease.csv
- capture_at_k.csv
- lift_yield.csv
- dca.csv
Returns:
Mapping from logical name to file path.
"""
out_dir = _ensure_dir(out_dir)
summary_rows: List[Dict] = []
auc_rows: List[Dict] = []
capture_rows: List[Dict] = []
lift_rows: List[Dict] = []
dca_rows: List[Dict] = []
landmarks = results.get('landmarks', [])
for lm in landmarks:
age = lm.get('age_cutoff')
horizon = lm.get('horizon')
for track in ['complete_case', 'clean_control']:
track_res = lm.get(track) or {}
if not track_res:
continue
summary_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'n_patients': track_res.get('n_patients', np.nan),
'n_valid': track_res.get('n_valid', np.nan),
'n_valid_patients': track_res.get('n_valid_patients', np.nan),
'mean_auc': track_res.get('mean_auc', np.nan),
'brier_score': track_res.get('brier_score', np.nan),
'brier_skill_score': track_res.get('brier_skill_score', np.nan),
})
auc_per_disease = track_res.get('auc_per_disease') or {}
for disease_idx, auc in auc_per_disease.items():
auc_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'disease_idx': int(disease_idx),
'auc': _to_float(auc),
})
if track == 'complete_case':
capture = track_res.get('disease_capture_at_k') or {}
for k_val, per_disease in capture.items():
k_int = int(k_val)
for disease_idx, rate in (per_disease or {}).items():
capture_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'k': k_int,
'disease_idx': int(disease_idx),
'capture_rate': _to_float(rate),
})
            lift_yield = track_res.get('lift_and_yield') or {}
            if not isinstance(lift_yield, dict):
                lift_yield = {}
            overall = lift_yield.get('overall') or {}
for frac, metrics in overall.items():
lift_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'level': 'overall',
'disease_idx': '',
'workload_frac': _to_float(frac),
'lift': _to_float((metrics or {}).get('lift')),
'yield': _to_float((metrics or {}).get('yield')),
})
            per_disease = lift_yield.get('per_disease') or {}
for disease_idx, disease_metrics in per_disease.items():
for frac, metrics in (disease_metrics or {}).items():
lift_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'level': 'per_disease',
'disease_idx': int(disease_idx),
'workload_frac': _to_float(frac),
'lift': _to_float((metrics or {}).get('lift')),
'yield': _to_float((metrics or {}).get('yield')),
})
dca = track_res.get('dca') or {}
thresholds = dca.get('thresholds')
net_benefit = dca.get('net_benefit')
if thresholds is not None and net_benefit is not None:
thresholds_arr = np.asarray(thresholds, dtype=np.float64)
nb_arr = np.asarray(net_benefit, dtype=np.float64)
for thr, nb in zip(thresholds_arr, nb_arr):
dca_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'threshold': float(thr),
'net_benefit': float(nb),
})
paths: Dict[str, str] = {}
df_summary = pd.DataFrame(summary_rows)
summary_path = os.path.join(out_dir, 'landmarks_summary.csv')
df_summary.to_csv(summary_path, index=False)
paths['landmarks_summary'] = summary_path
df_auc = pd.DataFrame(auc_rows)
auc_path = os.path.join(out_dir, 'auc_per_disease.csv')
df_auc.to_csv(auc_path, index=False)
paths['auc_per_disease'] = auc_path
df_capture = pd.DataFrame(capture_rows)
capture_path = os.path.join(out_dir, 'capture_at_k.csv')
df_capture.to_csv(capture_path, index=False)
paths['capture_at_k'] = capture_path
df_lift = pd.DataFrame(lift_rows)
lift_path = os.path.join(out_dir, 'lift_yield.csv')
df_lift.to_csv(lift_path, index=False)
paths['lift_yield'] = lift_path
df_dca = pd.DataFrame(dca_rows)
dca_path = os.path.join(out_dir, 'dca.csv')
df_dca.to_csv(dca_path, index=False)
paths['dca'] = dca_path
return paths
class LandmarkEvaluator:
"""
Comprehensive landmark analysis evaluator for survival/competing risks models.
"""
def __init__(
self,
model: torch.nn.Module,
head: torch.nn.Module,
loss_fn: torch.nn.Module,
dataset: HealthDataset,
eval_indices: Optional[List[int]] = None,
device: str = 'cuda',
batch_size: int = 256,
num_workers: int = 4,
compile_model: bool = True,
):
self.model = model.to(device).eval()
self.head = head.to(device).eval()
self.loss_fn = loss_fn
self.dataset = dataset
self.eval_indices = eval_indices
self.device = device
self.batch_size = batch_size
self.num_workers = num_workers
        use_cuda = str(self.device).startswith("cuda") and torch.cuda.is_available()
if use_cuda:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
try:
torch.set_float32_matmul_precision("high")
except Exception:
pass
# JIT/compile optimization (best effort)
if compile_model and use_cuda:
self.model = _maybe_torch_compile(self.model, enabled=True)
self.head = _maybe_torch_compile(self.head, enabled=True)
# Evaluation parameters from design doc
self.age_cutoffs = [50, 60, 70]
self.horizons = [0.25, 0.5, 1, 2, 5, 10]
self.top_k_values = [5, 10, 20, 50]
self.workload_fracs = [0.01, 0.05, 0.10, 0.20, 0.50]
# Convert age to days for comparison
self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
self.horizons_days = [h * 365.25 for h in self.horizons]
@staticmethod
def _last_time(time_batch: torch.Tensor, event_batch: torch.Tensor) -> torch.Tensor:
"""Compute last observed (non-padding) time per patient."""
real_mask = event_batch >= 1
masked = time_batch.masked_fill(~real_mask, float('-inf'))
return masked.max(dim=1).values
@staticmethod
def _anchor_indices(
time_batch: torch.Tensor,
event_batch: torch.Tensor,
cutoff_days: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Find anchor index/time: last valid record before cutoff."""
real_mask = event_batch >= 1
before = time_batch < cutoff_days
valid_before = real_mask & before
has_anchor = valid_before.any(dim=1)
# argmax of position under mask gives last True position
L = event_batch.size(1)
pos = torch.arange(L, device=event_batch.device).view(1, L)
        anchor_idx = (valid_before.to(torch.long) * pos).max(dim=1).values
t_anchor = time_batch.gather(1, anchor_idx.view(-1, 1)).squeeze(1)
return has_anchor, anchor_idx, t_anchor
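    # Worked example for the masked-argmax trick above (illustrative): with L=4
    # and a row of valid_before = [True, False, True, False], positions
    # [0, 1, 2, 3] multiplied by the mask give [0, 0, 2, 0]; the max is 2, the
    # last True position. An all-False row also yields index 0, which is why
    # callers must gate anchor_idx/t_anchor on has_anchor.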
def _labels_and_validity_for_cutoff(
self,
time_batch: torch.Tensor,
event_batch: torch.Tensor,
cutoff_days: float,
horizons_days: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Vectorized label + validity computation for all horizons at a cutoff.
Returns:
labels: (B, H, K) float32 {0,1}
valid_cc: (B, H, K) bool
valid_clean: (B, H, K) bool
"""
n_tech_tokens = 2
K = int(self.dataset.n_disease)
death_code = int(K - 1)
B, L = event_batch.shape
H = int(horizons_days.numel())
# Disease token mask and indices
is_disease = event_batch >= n_tech_tokens
disease_idx = (event_batch - n_tech_tokens).clamp(min=0, max=K - 1)
# ever_has_disease: (B, K)
ever = torch.zeros((B, K), dtype=torch.bool, device=event_batch.device)
if is_disease.any():
b_idx, t_idx = is_disease.nonzero(as_tuple=True)
d_idx = disease_idx[b_idx, t_idx]
ever[b_idx, d_idx] = True
# Events within horizon windows: (B, L, H)
offset = time_batch - float(cutoff_days)
within = is_disease.unsqueeze(-1) & (offset.unsqueeze(-1) >= 0) & (
offset.unsqueeze(-1) <= horizons_days.view(1, 1, H)
)
labels_bool = torch.zeros(
(B, H, K), dtype=torch.bool, device=event_batch.device)
if within.any():
b2, t2, h2 = within.nonzero(as_tuple=True)
d2 = disease_idx[b2, t2]
labels_bool[b2, h2, d2] = True
labels = labels_bool.to(torch.float32)
last_time = self._last_time(time_batch, event_batch) # (B,)
horizon_end = float(cutoff_days) + horizons_days.view(1, H) # (1, H)
death_in_horizon = labels_bool[:, :, death_code] # (B, H)
observed_past_horizon = last_time.view(B, 1) > horizon_end
lost_within_horizon = last_time.view(B, 1) <= horizon_end
# Track A (Complete-Case):
# - if observed past horizon OR death in horizon => valid all diseases
# - else (censored within horizon) => valid only for diseases that occurred within horizon
        full_mask = (observed_past_horizon | death_in_horizon).unsqueeze(-1)  # (B, H, 1)
        valid_cc = labels_bool | full_mask
# Track B (Clean-Control) per disease:
# valid[k] = hit_in_window(k) OR (never_has_k AND not lost_within_window)
never = ~ever # (B, K)
        valid_clean = (~death_in_horizon).unsqueeze(-1) & (
            labels_bool | (never.unsqueeze(1) & (~lost_within_horizon).unsqueeze(-1))
        )
return labels, valid_cc, valid_clean
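    # Toy example of the two validity tracks (illustrative): cutoff at age 60,
    # a 2-year horizon, one disease code at +0.5y, and censoring at +1y.
    # Complete-case: only that disease is valid (a confirmed positive); other
    # diseases are invalid because follow-up ends inside the window.
    # Clean-control: that disease is valid, while every never-observed disease
    # is invalid, since censoring inside the window leaves its status unknown.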
def _compute_risk_scores_many_horizons(
self,
logits: torch.Tensor,
t_start_days: torch.Tensor,
horizons_days: torch.Tensor,
) -> torch.Tensor:
"""Compute risk increments for all horizons in one vectorized call.
Args:
logits: model head outputs for anchor points.
t_start_days: (B,) time from anchor to cutoff (days).
horizons_days: (H,) horizons in days.
Returns:
risk: (B, H, K) float32
"""
t_start_days = torch.clamp(t_start_days, min=0)
        t_end_days = torch.clamp(t_start_days.unsqueeze(1) + horizons_days.view(1, -1), min=0)
        t_query_years = torch.cat([t_start_days.unsqueeze(1), t_end_days], dim=1) / 365.25  # (B, H+1)
# calculate_cifs returns (B, K) if scalar/per-sample, else (B, K, T)
if hasattr(self.loss_fn, "calculate_cifs"):
cifs = self.loss_fn.calculate_cifs(
logits, t_query_years, return_survival=False)
else:
raise ValueError(
f"Loss function does not support calculate_cifs: {type(self.loss_fn)}")
if cifs.ndim == 2:
# (B, K) -> (B, 1, K)
cifs_bt_k = cifs.unsqueeze(1)
else:
# (B, K, T) -> (B, T, K)
cifs_bt_k = cifs.permute(0, 2, 1)
cif_start = cifs_bt_k[:, :1, :] # (B, 1, K)
cif_end = cifs_bt_k[:, 1:, :] # (B, H, K)
risk = torch.clamp(cif_end - cif_start, min=0)
return risk
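    # Shape walkthrough for the vectorized risk computation above
    # (illustrative): with B=2 patients and H=6 horizons, t_query_years is
    # (2, 7), i.e. [t_start, t_start + h_1, ..., t_start + h_6] in years;
    # calculate_cifs then yields (2, K, 7), permuted to (2, 7, K), and the
    # first time column is subtracted from the remaining six, giving (2, 6, K)
    # risk increments.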
@torch.no_grad()
def compute_risk_scores(
self,
indices: List[int],
age_cutoff_days: float,
horizon_days: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Compute risk scores for specified patient indices at given landmark.
Args:
indices: Patient indices to evaluate
age_cutoff_days: Age cutoff in days (T_cut)
horizon_days: Prediction horizon in days (H)
Returns:
risk_scores: (N, K) array of risk scores per disease
t_anchors: (N,) array of anchor times (days from birth)
valid_mask: (N,) boolean array indicating valid predictions
"""
subset = Subset(self.dataset, indices)
loader = DataLoader(
subset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=health_collate_fn,
num_workers=self.num_workers,
            pin_memory=str(self.device).startswith('cuda'),
)
all_risk_scores = []
all_t_anchors = []
all_valid_mask = []
for batch in loader:
event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
event_batch = event_batch.to(self.device)
time_batch = time_batch.to(self.device)
cont_batch = cont_batch.to(self.device)
cate_batch = cate_batch.to(self.device)
sex_batch = sex_batch.to(self.device)
# Find anchor point: last valid record before age_cutoff
# Valid records are non-padding events (event >= 1)
valid_mask = event_batch >= 1 # (B, L)
before_cutoff = time_batch < age_cutoff_days # (B, L)
valid_before = valid_mask & before_cutoff # (B, L)
# Find last valid position for each patient
batch_size = event_batch.size(0)
t_anchor = torch.zeros(batch_size, device=self.device)
anchor_idx = torch.zeros(
batch_size, dtype=torch.long, device=self.device)
has_anchor = torch.zeros(
batch_size, dtype=torch.bool, device=self.device)
for b in range(batch_size):
valid_positions = valid_before[b].nonzero(as_tuple=True)[0]
if len(valid_positions) > 0:
last_pos = valid_positions[-1]
anchor_idx[b] = last_pos
t_anchor[b] = time_batch[b, last_pos]
has_anchor[b] = True
# Get model predictions at anchor points
if has_anchor.any():
# If torch.compile uses CUDA Graphs under the hood, mark a new step
# before each compiled invocation to avoid output lifetime issues.
_maybe_cudagraph_mark_step_begin()
# Forward pass
hidden = self.model(event_batch, time_batch,
sex_batch, cont_batch, cate_batch)
# Get predictions at anchor positions
batch_indices = torch.arange(batch_size, device=self.device)
# (B, n_embd)
hidden_at_anchor = hidden[batch_indices, anchor_idx]
# Compute logits using the loaded head
# (B, n_disease, ...) or (B, K+1, n_bins+1) etc.
logits = self.head(hidden_at_anchor)
# Compute CIF scores
# Time gap from anchor to start of horizon
t_start = age_cutoff_days - t_anchor # (B,)
# Time gap from anchor to end of horizon
t_end = age_cutoff_days + horizon_days - t_anchor # (B,)
# Ensure non-negative time gaps
t_start = torch.clamp(t_start, min=0)
t_end = torch.clamp(t_end, min=0)
# Calculate CIF at both time points
cif_start = self._compute_cif(logits, t_start) # (B, K)
cif_end = self._compute_cif(logits, t_end) # (B, K)
# Risk score is the increment within the horizon
risk_scores = cif_end - cif_start # (B, K)
                risk_scores = torch.clamp(risk_scores, min=0)  # ensure non-negative
else:
# No valid anchor points in this batch
risk_scores = torch.zeros(
batch_size, self.dataset.n_disease, device=self.device)
all_risk_scores.append(risk_scores.cpu().numpy())
all_t_anchors.append(t_anchor.cpu().numpy())
all_valid_mask.append(has_anchor.cpu().numpy())
# Concatenate results
risk_scores = np.vstack(all_risk_scores) # (N, K)
t_anchors = np.concatenate(all_t_anchors) # (N,)
valid_mask = np.concatenate(all_valid_mask) # (N,)
return risk_scores, t_anchors, valid_mask
def _compute_cif(self, logits: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Compute Cumulative Incidence Function at time t.
Args:
logits: Model output logits (B, K, ...) depending on loss type
            t: Time points (B,) in days from anchor (converted to years internally)
Returns:
cif: (B, K) cumulative incidence probabilities
"""
t_years = t / 365.25 # Convert to years
if isinstance(self.loss_fn, ExponentialNLLLoss):
# Exponential: logits are (B, K)
            lambdas = F.softplus(logits) + 1e-6  # (B, K)
            total_lambda = lambdas.sum(dim=-1, keepdim=True)  # (B, 1)
            # CIF_k(t) = (λ_k / Σλ) * (1 - exp(-Σλ * t))
            frac = lambdas / total_lambda  # (B, K)
            exp_term = 1.0 - torch.exp(-total_lambda * t_years.unsqueeze(-1))  # (B, 1)
            cif = frac * exp_term  # (B, K)
elif isinstance(self.loss_fn, DiscreteTimeCIFNLLLoss):
# Discrete-time CIF: use calculate_cifs method
cif = self.loss_fn.calculate_cifs(
logits, t_years, return_survival=False)
elif isinstance(self.loss_fn, PiecewiseExponentialCIFNLLLoss):
# PWE CIF: use calculate_cifs method
cif = self.loss_fn.calculate_cifs(
logits, t_years, return_survival=False)
else:
raise ValueError(f"Unknown loss type: {type(self.loss_fn)}")
return cif
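    # Numeric check of the exponential branch above (illustrative): with K=2
    # cause-specific hazards λ = (0.1, 0.3) per year and t = 2 years,
    # Σλ = 0.4 and 1 - exp(-0.8) ≈ 0.5507, so
    # CIF ≈ (0.25 * 0.5507, 0.75 * 0.5507) ≈ (0.138, 0.413).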
def prepare_evaluation_cohort(
self,
age_cutoff_days: float,
horizon_days: float,
track: str = 'complete_case',
) -> Tuple[List[int], np.ndarray, np.ndarray]:
"""Prepare evaluation cohort per design protocol with per-disease validity.
Key fixes vs the earlier implementation:
- Multi-disease labeling: mark *all* diseases that occur within the horizon.
- Per-disease validity mask: mask[i, k]=1 if patient i is valid for disease k.
Returns:
indices: list of patient indices included in this cohort table
labels: (N, K) float array
valid_mask: (N, K) bool array
"""
n_tech_tokens = 2 # PAD=0, DOA=1
K = int(self.dataset.n_disease)
# Competing risk token: treat death as competing event.
# As requested: DEATH_CODE is the last disease index.
DEATH_CODE = int(K - 1)
horizon_end_days = age_cutoff_days + horizon_days
indices: List[int] = []
labels_rows: List[np.ndarray] = []
valid_rows: List[np.ndarray] = []
        if self.eval_indices is not None:
            candidate_indices = self.eval_indices
        else:
            candidate_indices = list(range(len(self.dataset)))
for idx in candidate_indices:
patient_id = self.dataset.patient_ids[idx]
records = self.dataset.patient_events.get(patient_id, [])
if not records:
continue
# Must have some information strictly prior to cutoff (anchor existence in the data).
has_pre_cutoff = any(t < age_cutoff_days for t, _ in records)
if not has_pre_cutoff:
continue
# Events after cutoff (already sorted in dataset init)
events_after = [(t, e) for t, e in records if t >= age_cutoff_days]
labels = np.zeros(K, dtype=np.float32)
valid = np.zeros(K, dtype=bool)
# Identify diseases within horizon (multi-label; no early break)
diseases_in_horizon: set[int] = set()
death_in_horizon = False
for t, e in events_after:
if t > horizon_end_days:
break # events are time-sorted
if e < n_tech_tokens:
continue
disease_idx = int(e - n_tech_tokens)
if 0 <= disease_idx < K:
diseases_in_horizon.add(disease_idx)
if disease_idx == DEATH_CODE:
death_in_horizon = True
for d in diseases_in_horizon:
labels[d] = 1.0
last_time = float(records[-1][0])
if track == 'complete_case':
# Track A: Complete-Case at Horizon
# - Hit: disease occurs within horizon => valid positive for that disease
# - Healthy: last record > horizon_end => valid negative for all diseases
# - Death: death within horizon => valid negative for diseases not occurred before death
# (we implement as valid for all diseases, with labels marking any hits incl death)
# - Loss: censored within horizon (no hit for disease) => invalid for that disease
if last_time > horizon_end_days:
# Observed past horizon: negative for all non-hit diseases
valid[:] = True
elif death_in_horizon:
# Competing risk: include as explicit negatives (label 0) for diseases not hit
# and positives for diseases hit before death (we don't model ordering here).
# This matches the requirement: death within window is label=0 for target diseases.
valid[:] = True
else:
# Lost within horizon: only diseases that actually occurred in the horizon are valid positives
# (no assumptions about negatives).
for d in diseases_in_horizon:
valid[d] = True
elif track == 'clean_control':
# Track B: Clean Control (per-disease)
# For each disease k:
# - Hit: disease k occurs within horizon => label=1, valid
# - Pure clean for k: disease k never occurs in entire record => label=0, valid
# - Death within window => drop (invalid) for all k
# - Loss within window => drop (invalid) for all k
# - Late onset of k (after horizon) => drop (invalid) for k
if death_in_horizon:
# Drop all diseases for this patient
valid[:] = False
                else:
                    # Determine whether disease k ever occurs anywhere in the record.
                    ever_has_disease = np.zeros(K, dtype=bool)
                    for t, e in records:
                        if e < n_tech_tokens:
                            continue
                        disease_idx = int(e - n_tech_tokens)
                        if 0 <= disease_idx < K:
                            ever_has_disease[disease_idx] = True
                    # If lost within the horizon, negatives cannot be confirmed.
                    lost_within_horizon = last_time <= horizon_end_days
                    for k in range(K):
                        if labels[k] == 1.0:
                            valid[k] = True
                            continue
                        if not ever_has_disease[k]:
                            # Lifetime clean for k. Complete follow-up within the
                            # horizon is still required: if the patient is censored
                            # inside the window, we cannot be sure k did not occur
                            # there.
                            valid[k] = not lost_within_horizon
                            continue
                        # Patient has disease k at some point, but k did not occur
                        # within the horizon: either late onset (first occurrence
                        # after horizon end) or prevalent before the cutoff.
                        # Either way, the patient is not a clean control for k.
                        valid[k] = False
else:
raise ValueError(f"Unknown track: {track}")
# Keep patient row if they are valid for at least one disease.
if valid.any():
indices.append(idx)
labels_rows.append(labels)
valid_rows.append(valid)
if not indices:
return [], np.zeros((0, K), dtype=np.float32), np.zeros((0, K), dtype=bool)
return indices, np.stack(labels_rows, axis=0), np.stack(valid_rows, axis=0)
def compute_auc_per_disease(
self,
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
) -> Dict[int, float]:
"""
Compute time-dependent AUC for each disease.
Args:
risk_scores: (N, K) risk scores
labels: (N, K) binary labels
valid_mask: (N, K) boolean mask for valid evaluations
Returns:
auc_scores: Dict mapping disease_idx -> AUC
"""
auc_scores = {}
n_diseases = risk_scores.shape[1]
for k in range(n_diseases):
mk = valid_mask[:, k]
if not np.any(mk):
auc_scores[k] = np.nan
continue
y_true = labels[mk, k]
y_score = risk_scores[mk, k]
# Check if we have both classes
if len(np.unique(y_true)) < 2:
auc_scores[k] = np.nan
else:
try:
auc = roc_auc_score(y_true, y_score)
auc_scores[k] = auc
except Exception:
auc_scores[k] = np.nan
return auc_scores
def compute_brier_score(
self,
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
) -> Dict[str, float]:
"""
Compute Brier Score and Brier Skill Score.
Args:
risk_scores: (N, K) risk scores
labels: (N, K) binary labels
valid_mask: (N, K) boolean mask
Returns:
metrics: Dict with 'brier_score' and 'brier_skill_score'
"""
# Apply per-entry valid mask and flatten
m = valid_mask.astype(bool)
y_true_flat = labels[m]
y_pred_flat = risk_scores[m]
mask = ~(np.isnan(y_true_flat) | np.isnan(y_pred_flat))
y_true_flat = y_true_flat[mask]
y_pred_flat = y_pred_flat[mask]
if len(y_true_flat) == 0:
return {'brier_score': np.nan, 'brier_skill_score': np.nan}
bs = brier_score_loss(y_true_flat, y_pred_flat)
# Brier Skill Score: reference is predicting the mean
p_mean = y_true_flat.mean()
bs_ref = ((p_mean - y_true_flat) ** 2).mean()
bss = 1.0 - (bs / bs_ref) if bs_ref > 0 else 0.0
return {'brier_score': bs, 'brier_skill_score': bss}
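    # Worked example for the skill score above (illustrative): with
    # y_true = [0, 0, 1] and y_pred = [0.1, 0.2, 0.7],
    # BS = (0.01 + 0.04 + 0.09) / 3 ≈ 0.047; the reference always predicts the
    # mean p̄ = 1/3, so BS_ref = 2/9 ≈ 0.222 and BSS = 1 - 0.047/0.222 ≈ 0.79.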
def compute_disease_capture_at_k(
self,
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
) -> Dict[int, Dict[int, float]]:
"""
Compute Disease-Capture@K: fraction of true positives where the true
disease appears in the patient's top-K predicted risks.
Args:
risk_scores: (N, K) risk scores
labels: (N, K) binary labels
valid_mask: (N, K) boolean mask
Returns:
capture_rates: Dict[K_value][disease_idx] -> capture rate
"""
capture_rates = {k: {} for k in self.top_k_values}
n_diseases = risk_scores.shape[1]
for disease_idx in range(n_diseases):
mk = valid_mask[:, disease_idx]
if not np.any(mk):
for k in self.top_k_values:
capture_rates[k][disease_idx] = np.nan
continue
y_true = labels[mk, disease_idx]
y_scores = risk_scores[mk] # (N_valid_k, K)
# Find patients with positive label for this disease
pos_mask = y_true == 1
if pos_mask.sum() == 0:
for k in self.top_k_values:
capture_rates[k][disease_idx] = np.nan
continue
# For each positive patient, check if true disease is in top-K
for k_val in self.top_k_values:
captures = []
for i in np.where(pos_mask)[0]:
# Get top-K disease indices for this patient
top_k_diseases = np.argsort(y_scores[i])[::-1][:k_val]
# Check if true disease is in top-K
is_captured = disease_idx in top_k_diseases
captures.append(int(is_captured))
capture_rate = np.mean(captures) if captures else np.nan
capture_rates[k_val][disease_idx] = capture_rate
return capture_rates
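    # Performance note (optional alternative, not used above): per-patient
    # top-K membership can avoid a full sort via np.argpartition, e.g.
    #   top_k_diseases = np.argpartition(y_scores[i], -k_val)[-k_val:]
    # With ties at the K-th score this may select a different but equally
    # ranked set, so capture rates can differ slightly on tied inputs.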
def compute_lift_and_yield(
self,
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
) -> Dict[str, Dict]:
"""
Compute Lift and Yield at various workload fractions.
Args:
risk_scores: (N, K) risk scores
labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask
Returns:
metrics:
{
'overall': {workload_frac: {'lift': ..., 'yield': ...}, ...},
'per_disease': {disease_idx: {workload_frac: {'lift': ..., 'yield': ...}, ...}, ...}
}
"""
# Overall metric uses a patient-level mask to avoid including purely censored negatives.
# We include patients who either (a) have known outcomes for all diseases, or (b) have at least one hit.
has_any_hit = labels.max(axis=1) > 0
has_all_known = valid_mask.all(axis=1)
overall_patient_mask = has_any_hit | has_all_known
risk_scores_overall = risk_scores[overall_patient_mask]
labels_overall = labels[overall_patient_mask]
valid_overall = valid_mask[overall_patient_mask]
# Flatten to patient-level: any disease event
# Max risk across all diseases for each patient
max_risk_per_patient = risk_scores_overall.max(axis=1)
any_disease_label = labels_overall.max(axis=1)
n_patients = len(max_risk_per_patient)
base_rate = any_disease_label.mean() if n_patients > 0 else 0.0
overall: Dict[float, Dict[str, float]] = {}
for workload_frac in self.workload_fracs:
n_screen = max(1, int(n_patients * workload_frac))
top_n_idx = np.argsort(max_risk_per_patient)[::-1][:n_screen]
top_n_labels = any_disease_label[top_n_idx]
yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
lift_val = (yield_val / float(base_rate)) if base_rate > 0 else 0.0
overall[workload_frac] = {'lift': lift_val, 'yield': yield_val}
per_disease: Dict[int, Dict[float, Dict[str, float]]] = {}
n_diseases = risk_scores.shape[1]
for disease_idx in range(n_diseases):
mk = valid_mask[:, disease_idx]
if not np.any(mk):
per_disease[disease_idx] = {
frac: {'lift': np.nan, 'yield': np.nan} for frac in self.workload_fracs}
continue
disease_scores = risk_scores[mk, disease_idx]
disease_labels = labels[mk, disease_idx]
disease_base_rate = disease_labels.mean() if disease_labels.size > 0 else 0.0
n_patients_k = disease_scores.shape[0]
disease_metrics: Dict[float, Dict[str, float]] = {}
for workload_frac in self.workload_fracs:
n_screen = max(1, int(n_patients_k * workload_frac))
top_n_idx = np.argsort(disease_scores)[::-1][:n_screen]
top_n_labels = disease_labels[top_n_idx]
                yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
                lift_val = (yield_val / float(disease_base_rate)) if disease_base_rate > 0 else 0.0
disease_metrics[workload_frac] = {
'lift': lift_val, 'yield': yield_val}
per_disease[disease_idx] = disease_metrics
return {
'overall': overall,
'per_disease': per_disease,
}
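    # Worked example for lift/yield (illustrative): 1000 valid patients with a
    # 5% base rate; screening the top 5% (n_screen = 50) and finding 10 true
    # positives gives yield = 10/50 = 0.20 and lift = 0.20/0.05 = 4.0.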
def compute_dca_net_benefit(
self,
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
threshold_range: np.ndarray = np.linspace(0, 0.5, 51),
) -> Dict[str, np.ndarray]:
"""
Compute Decision Curve Analysis (DCA) net benefit.
Args:
risk_scores: (N, K) risk scores
labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask
threshold_range: Array of threshold probabilities
Returns:
dca_results: Dict with 'thresholds' and 'net_benefit' arrays
"""
# Use the same overall patient mask as lift/yield (complete-case style)
has_any_hit = labels.max(axis=1) > 0
has_all_known = valid_mask.all(axis=1)
patient_mask = has_any_hit | has_all_known
risk_scores = risk_scores[patient_mask]
labels = labels[patient_mask]
# Use max risk and any disease label
max_risk = risk_scores.max(axis=1)
any_disease = labels.max(axis=1)
n = len(max_risk)
net_benefits = []
for pt in threshold_range:
if pt == 0:
# Treat all
nb = any_disease.mean()
else:
# Treat if predicted risk > threshold
treat = max_risk >= pt
tp = (treat & (any_disease == 1)).sum()
fp = (treat & (any_disease == 0)).sum()
# Net benefit = (TP/N) - (FP/N) * (pt / (1-pt))
nb = (tp / n) - (fp / n) * (pt / (1 - pt))
net_benefits.append(nb)
return {
'thresholds': threshold_range,
'net_benefit': np.array(net_benefits),
}
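    # Worked example for net benefit (illustrative): n = 200 patients at
    # threshold pt = 0.10, with 12 treated true positives and 30 treated false
    # positives: NB = 12/200 - (30/200) * (0.1/0.9) ≈ 0.060 - 0.017 = 0.043.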
def evaluate_landmark(
self,
age_cutoff: float,
horizon: float,
) -> Dict:
"""
Evaluate model at a specific landmark (age_cutoff, horizon).
Args:
age_cutoff: Age cutoff in years
horizon: Prediction horizon in years
Returns:
results: Dictionary with all metrics
"""
age_cutoff_days = age_cutoff * 365.25
horizon_days = horizon * 365.25
print(f"\nEvaluating Landmark: Age={age_cutoff}, Horizon={horizon}y")
results = {
'age_cutoff': age_cutoff,
'horizon': horizon,
'complete_case': {},
'clean_control': {},
}
for track in ['complete_case', 'clean_control']:
print(f" Track: {track}")
# Prepare cohort
indices, labels_array, valid_mask = self.prepare_evaluation_cohort(
age_cutoff_days, horizon_days, track
)
if len(indices) == 0:
print(f" No valid patients for track {track}")
continue
print(f" Cohort size: {len(indices)}")
# Compute risk scores
risk_scores, t_anchors, anchor_mask = self.compute_risk_scores(
indices, age_cutoff_days, horizon_days
)
# Combine anchor availability with per-disease validity
valid_mask = valid_mask & anchor_mask.astype(bool)[:, None]
# Compute metrics
print(" Computing AUC...")
auc_scores = self.compute_auc_per_disease(
risk_scores, labels_array, valid_mask)
mean_auc = np.nanmean(list(auc_scores.values()))
print(" Computing Brier Score...")
brier_metrics = self.compute_brier_score(
risk_scores, labels_array, valid_mask)
# Only compute patient-level and population metrics for complete_case
if track == 'complete_case':
print(" Computing Disease-Capture@K...")
capture_metrics = self.compute_disease_capture_at_k(
risk_scores, labels_array, valid_mask
)
print(" Computing Lift & Yield...")
lift_yield_metrics = self.compute_lift_and_yield(
risk_scores, labels_array, valid_mask
)
print(" Computing DCA...")
dca_metrics = self.compute_dca_net_benefit(
risk_scores, labels_array, valid_mask
)
results[track] = {
'n_patients': len(indices),
'n_valid': int(valid_mask.sum()),
'n_valid_patients': int(valid_mask.any(axis=1).sum()),
'auc_per_disease': auc_scores,
'mean_auc': mean_auc,
'brier_score': brier_metrics['brier_score'],
'brier_skill_score': brier_metrics['brier_skill_score'],
'disease_capture_at_k': capture_metrics,
'lift_and_yield': lift_yield_metrics,
'dca': dca_metrics,
}
else:
# Clean control track: only discrimination metrics
results[track] = {
'n_patients': len(indices),
'n_valid': int(valid_mask.sum()),
'n_valid_patients': int(valid_mask.any(axis=1).sum()),
'auc_per_disease': auc_scores,
'mean_auc': mean_auc,
}
return results
def run_full_evaluation(self) -> Dict:
"""Run the full evaluation using a single-pass DataLoader.
Key optimizations:
- iterate DataLoader exactly once
- run transformer backbone once per batch
- reuse hidden states per cutoff (3x head only)
- vectorize CIF/risk over all horizons in one call
"""
# Build evaluation subset loader
        if self.eval_indices is not None:
            indices = self.eval_indices
        else:
            indices = list(range(len(self.dataset)))
subset = Subset(self.dataset, indices)
loader = DataLoader(
subset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=health_collate_fn,
num_workers=self.num_workers,
pin_memory=True if str(self.device).startswith('cuda') else False,
)
        cutoffs_days = torch.tensor(self.age_cutoffs_days, dtype=torch.float32, device=self.device)  # (C,)
        horizons_days = torch.tensor(self.horizons_days, dtype=torch.float32, device=self.device)  # (H,)
C = int(cutoffs_days.numel())
H = int(horizons_days.numel())
K = int(self.dataset.n_disease)
# Buffers: store per landmark/track arrays in chunks to avoid repeated I/O.
# Each key stores lists of numpy arrays to be concatenated at the end.
buffers: Dict[Tuple[int, int, str], Dict[str, List[np.ndarray]]] = {}
for ci in range(C):
for hi in range(H):
for track in ("complete_case", "clean_control"):
buffers[(ci, hi, track)] = {
"risk": [], "labels": [], "valid": []}
with torch.inference_mode():
for batch in tqdm(loader, desc="Single-pass evaluation", ncols=100):
event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
event_batch = event_batch.to(self.device, non_blocking=True)
time_batch = time_batch.to(self.device, non_blocking=True)
cont_batch = cont_batch.to(self.device, non_blocking=True)
cate_batch = cate_batch.to(self.device, non_blocking=True)
sex_batch = sex_batch.to(self.device, non_blocking=True)
B, L = event_batch.shape
batch_idx = torch.arange(B, device=self.device)
# Backbone once per batch
_maybe_cudagraph_mark_step_begin()
                hidden = self.model(event_batch, time_batch, sex_batch, cont_batch, cate_batch)  # (B, L, D)
for ci in range(C):
cutoff = float(cutoffs_days[ci].item())
has_anchor, anchor_idx, t_anchor = self._anchor_indices(
time_batch, event_batch, cutoff)
if not has_anchor.any():
continue
# Hidden states at anchor positions
hidden_anchor = hidden[batch_idx, anchor_idx] # (B, D)
logits = self.head(hidden_anchor)
# Vectorized labels/validity for all horizons
labels_bhk, valid_cc_bhk, valid_clean_bhk = self._labels_and_validity_for_cutoff(
time_batch, event_batch, cutoff, horizons_days
)
# Risk scores for all horizons (B, H, K)
                    t_start = torch.clamp(cutoff - t_anchor, min=0)
risk_bhk = self._compute_risk_scores_many_horizons(
logits, t_start, horizons_days)
# Apply anchor constraint to validity
anchor_mask = has_anchor.view(B, 1, 1)
valid_cc_bhk = valid_cc_bhk & anchor_mask
valid_clean_bhk = valid_clean_bhk & anchor_mask
# Push per-horizon chunks
for hi in range(H):
for track, valid_bk in (
("complete_case", valid_cc_bhk[:, hi, :]),
("clean_control", valid_clean_bhk[:, hi, :]),
):
row_mask = valid_bk.any(dim=1)
if not row_mask.any():
continue
                            r = risk_bhk[row_mask, hi, :].to(torch.float32).cpu().numpy()
                            y = labels_bhk[row_mask, hi, :].to(torch.float32).cpu().numpy()
                            m = valid_bk[row_mask, :].cpu().numpy()
buffers[(ci, hi, track)]["risk"].append(r)
buffers[(ci, hi, track)]["labels"].append(y)
buffers[(ci, hi, track)]["valid"].append(m)
# Assemble results in the original output schema
all_results: Dict = {
'age_cutoffs': self.age_cutoffs,
'horizons': self.horizons,
'landmarks': [],
}
for ci, age in enumerate(self.age_cutoffs):
for hi, horizon in enumerate(self.horizons):
landmark_results = {
'age_cutoff': age,
'horizon': horizon,
'complete_case': {},
'clean_control': {},
}
for track in ("complete_case", "clean_control"):
chunks = buffers[(ci, hi, track)]
if len(chunks["risk"]) == 0:
continue
risk_scores = np.concatenate(chunks["risk"], axis=0)
labels = np.concatenate(chunks["labels"], axis=0)
valid_mask = np.concatenate(chunks["valid"], axis=0)
auc_scores = self.compute_auc_per_disease(
risk_scores, labels, valid_mask)
mean_auc = np.nanmean(list(auc_scores.values()))
track_out = {
'n_patients': int(valid_mask.shape[0]),
'n_valid': int(valid_mask.sum()),
'n_valid_patients': int((valid_mask.any(axis=1)).sum()),
'auc_per_disease': auc_scores,
'mean_auc': mean_auc,
}
if track == "complete_case":
brier_metrics = self.compute_brier_score(
risk_scores, labels, valid_mask)
capture_metrics = self.compute_disease_capture_at_k(
risk_scores, labels, valid_mask)
lift_yield_metrics = self.compute_lift_and_yield(
risk_scores, labels, valid_mask)
dca_metrics = self.compute_dca_net_benefit(
risk_scores, labels, valid_mask)
track_out.update({
'brier_score': brier_metrics['brier_score'],
'brier_skill_score': brier_metrics['brier_skill_score'],
'disease_capture_at_k': capture_metrics,
'lift_and_yield': lift_yield_metrics,
'dca': dca_metrics,
})
landmark_results[track] = track_out
all_results['landmarks'].append(landmark_results)
return all_results
def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
"""
Load trained model and configuration from run directory.
Args:
run_dir: Path to run directory containing train_config.json and best_model.pt
device: Device to load model on
Returns:
        model, head, loss_fn, dataset, config, test_indices
"""
run_path = Path(run_dir)
# Load config
config_path = run_path / 'train_config.json'
with open(config_path, 'r') as f:
config = json.load(f)
print(f"Loading model from {run_dir}")
print(f"Model type: {config['model_type']}")
print(f"Loss type: {config['loss_type']}")
# Load dataset (same as training) and reproduce the train/val/test split.
# IMPORTANT: do NOT change data_prefix; train.py reads files like
# <data_prefix>_basic_info.csv, <data_prefix>_table.csv, <data_prefix>_event_data.npy
data_prefix = config['data_prefix']
if config.get('full_cov', False):
covariate_list = None
else:
# Match train.py partial-cov settings
covariate_list = ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(
data_prefix=data_prefix,
covariate_list=covariate_list,
cache_event_tensors=True,
)
# Reproduce the random_split used in train.py to obtain the held-out test subset.
n_total = len(dataset)
train_ratio = float(config.get('train_ratio', 0.7))
val_ratio = float(config.get('val_ratio', 0.15))
seed = int(config.get('random_seed', 42))
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
if n_test < 0:
raise ValueError(
f"Invalid split sizes from config: n_total={n_total}, train_ratio={train_ratio}, val_ratio={val_ratio}"
)
from torch.utils.data import random_split
_, _, test_subset = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
test_indices = list(getattr(test_subset, 'indices', []))
# Determine output dimensions based on loss type
if config['loss_type'] == 'exponential':
out_dims = [dataset.n_disease]
elif config['loss_type'] == 'discrete_time_cif':
# logits shape (M, K+1, n_bins+1)
bin_edges = config.get(
'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')])
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif config['loss_type'] == 'pwe_cif':
# Piecewise-exponential requires finite edges
bin_edges = config.get(
'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0])
pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
n_bins = len(pwe_edges) - 1
# logits shape (M, K, n_bins)
out_dims = [dataset.n_disease, n_bins]
else:
raise ValueError(f"Unknown loss type: {config['loss_type']}")
# Build model
if config['model_type'] == 'delphi_fork':
model = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2, # PAD=0, DOA=1
n_embd=config['n_embd'],
n_head=config['n_head'],
n_layer=config['n_layer'],
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
age_encoder_type=config['age_encoder'],
pdrop=config['pdrop'],
)
elif config['model_type'] == 'sap_delphi':
model = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=config['n_embd'],
n_head=config['n_head'],
n_layer=config['n_layer'],
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
age_encoder_type=config['age_encoder'],
pdrop=config['pdrop'],
pretrained_weights_path=config.get('pretrained_emd_path'),
)
else:
raise ValueError(f"Unknown model type: {config['model_type']}")
# Build head
head = SimpleHead(
n_embd=config['n_embd'],
out_dims=out_dims,
)
# Load model weights (checkpoint contains model and head state dicts)
model_path = run_path / 'best_model.pt'
checkpoint = torch.load(model_path, map_location=device)
# The checkpoint is a dict with 'model_state_dict' and 'head_state_dict'
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
model.load_state_dict(checkpoint['model_state_dict'])
head.load_state_dict(checkpoint['head_state_dict'])
print("Loaded model and head from checkpoint")
else:
raise ValueError(
"Checkpoint format not recognized. Expected 'model_state_dict' and 'head_state_dict' keys.")
model.to(device)
head.to(device)
# Build loss function
if config['loss_type'] == 'exponential':
loss_fn = ExponentialNLLLoss(
lambda_reg=config.get('lambda_reg', 0.0)
)
elif config['loss_type'] == 'discrete_time_cif':
loss_fn = DiscreteTimeCIFNLLLoss(
bin_edges=config.get(
'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]),
lambda_reg=config.get('lambda_reg', 0.0),
)
elif config['loss_type'] == 'pwe_cif':
loss_fn = PiecewiseExponentialCIFNLLLoss(
bin_edges=config.get(
'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]),
lambda_reg=config.get('lambda_reg', 0.0),
)
else:
raise ValueError(f"Unknown loss type: {config['loss_type']}")
return model, head, loss_fn, dataset, config, test_indices
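# Illustrative usage of the loader above (the run directory is a placeholder;
# it must contain train_config.json and best_model.pt):
#   model, head, loss_fn, dataset, config, test_idx = load_model_and_config(
#       'runs/example_run', device='cpu')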
def print_summary(results: Dict):
"""
Print summary of evaluation results.
Args:
results: Results dictionary
"""
print("\n" + "=" * 80)
print("EVALUATION SUMMARY")
print("=" * 80)
for landmark in results['landmarks']:
age = landmark['age_cutoff']
horizon = landmark['horizon']
print(f"\nLandmark: Age {age}, Horizon {horizon}y")
print("-" * 40)
# Complete-case results
if 'complete_case' in landmark and landmark['complete_case']:
cc = landmark['complete_case']
print(f" Complete-Case Track:")
print(f" Patients: {cc['n_patients']}")
print(f" Mean AUC: {cc['mean_auc']:.4f}")
print(f" Brier Score: {cc['brier_score']:.4f}")
print(f" Brier Skill Score: {cc['brier_skill_score']:.4f}")
# Show top-K capture rates (average across diseases)
if 'disease_capture_at_k' in cc:
print(f" Disease Capture:")
for k in [5, 10, 20, 50]:
if k in cc['disease_capture_at_k']:
rates = list(cc['disease_capture_at_k'][k].values())
mean_rate = np.nanmean(rates)
print(f" Top-{k}: {mean_rate:.3f}")
# Show lift and yield
if 'lift_and_yield' in cc:
print(f" Lift & Yield:")
                ly = cc['lift_and_yield']
                overall = ly.get('overall', {}) if isinstance(ly, dict) else {}
for frac in [0.01, 0.05, 0.10]:
if frac in overall:
lift = overall[frac].get('lift', np.nan)
yield_val = overall[frac].get('yield', np.nan)
print(
f" Overall Top {int(frac*100)}%: Lift={lift:.2f}, Yield={yield_val:.3f}")
# Clean control results
if 'clean_control' in landmark and landmark['clean_control']:
clean = landmark['clean_control']
print(f" Clean-Control Track:")
print(f" Patients: {clean['n_patients']}")
print(f" Mean AUC: {clean['mean_auc']:.4f}")
print("\n" + "=" * 80)
def main():
parser = argparse.ArgumentParser(
description='Evaluate longitudinal health prediction model using landmark analysis'
)
parser.add_argument(
'--run_dir',
type=str,
required=True,
help='Path to run directory containing train_config.json and best_model.pt'
)
parser.add_argument(
'--output',
type=str,
default=None,
help='Output path for results JSON (default: <run_dir>/evaluation_results.json)'
)
parser.add_argument(
'--out_dir',
type=str,
default=None,
help='Directory to write CSV outputs (default: <run_dir>/evaluation_outputs)'
)
parser.add_argument(
'--device',
type=str,
default='cuda' if torch.cuda.is_available() else 'cpu',
help='Device to run evaluation on'
)
parser.add_argument(
'--batch_size',
type=int,
default=256,
help='Batch size for evaluation'
)
parser.add_argument(
'--num_workers',
type=int,
default=4,
help='Number of data loader workers'
)
parser.add_argument(
'--no_compile',
action='store_true',
help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
)
args = parser.parse_args()
# Load model and dataset
model, head, loss_fn, dataset, config, test_indices = load_model_and_config(
args.run_dir, args.device)
# Create evaluator
evaluator = LandmarkEvaluator(
model=model,
head=head,
loss_fn=loss_fn,
dataset=dataset,
eval_indices=test_indices,
device=args.device,
batch_size=args.batch_size,
num_workers=args.num_workers,
compile_model=(not args.no_compile),
)
# Run evaluation
print("\nStarting landmark analysis evaluation...")
print(f"Age cutoffs: {evaluator.age_cutoffs}")
print(f"Horizons: {evaluator.horizons}")
results = evaluator.run_full_evaluation()
# Add metadata
results['metadata'] = {
'run_dir': args.run_dir,
'config': config,
'n_diseases': dataset.n_disease,
'device': args.device,
}
# Save results (CSV bundle + single JSON summary)
if args.out_dir is None:
args.out_dir = os.path.join(args.run_dir, 'evaluation_outputs')
csv_paths = save_results_csv_bundle(results, args.out_dir)
summary = {
'metadata': results.get('metadata', {}),
'age_cutoffs': results.get('age_cutoffs', []),
'horizons': results.get('horizons', []),
'csv_outputs': csv_paths,
'notes': {
'metrics': [
'AUC (per disease + mean)',
'Brier Score / Brier Skill Score (complete-case only)',
'Disease-Capture@K (complete-case only)',
'Lift/Yield (complete-case only)',
'Decision Curve Analysis (complete-case only)',
],
},
}
if args.output is None:
args.output = os.path.join(args.run_dir, 'evaluation_summary.json')
save_summary_json(summary, args.output)
print(f"\nWrote CSV outputs to: {args.out_dir}")
print(f"Wrote summary JSON to: {args.output}")
# Print summary
print_summary(results)
if __name__ == '__main__':
main()