""" Landmark Analysis Evaluation Script for Longitudinal Health Prediction Models Implements the comprehensive evaluation framework defined in evaluate_design.md: - Landmark analysis at age cutoffs {50, 60, 70} - Prediction horizons {0.25, 0.5, 1, 2, 5, 10} years - Two tracks: Complete-Case (primary) and Clean Control (academic benchmark) - Metrics: AUC, Brier Score, Disease-Capture@K, Lift, Yield, DCA """ import argparse import json import math import os import time from pathlib import Path from typing import Dict, List, Tuple, Optional import warnings import numpy as np import pandas as pd import torch import torch.nn.functional as F from torch.utils.data import DataLoader, Subset from tqdm import tqdm from sklearn.metrics import roc_auc_score, brier_score_loss # Import model components from model import DelphiFork, SapDelphi, SimpleHead from dataset import HealthDataset, health_collate_fn from losses import ( ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, PiecewiseExponentialCIFNLLLoss ) warnings.filterwarnings('ignore') def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module: """Best-effort torch.compile() wrapper (PyTorch 2.x). Notes: - Some PyTorch builds run compiled graphs via CUDA Graphs in certain modes. If you keep references to graph outputs across steps, PyTorch may raise: "accessing tensor output of CUDAGraphs that has been overwritten". - We default to settings that avoid cudagraph output-lifetime pitfalls. """ if not enabled: return module try: torch_compile = getattr(torch, "compile", None) if torch_compile is None: return module # Prefer a safer mode for evaluation code; best-effort disable cudagraphs. kwargs = {"mode": "default"} try: kwargs["options"] = {"triton.cudagraphs": False} except Exception: pass return torch_compile(module, **kwargs) except Exception: return module def _maybe_cudagraph_mark_step_begin() -> None: """Best-effort step marker for CUDA Graphs compiled execution.""" try: compiler_mod = getattr(torch, "compiler", None) if compiler_mod is None: return mark = getattr(compiler_mod, "cudagraph_mark_step_begin", None) if mark is None: return mark() except Exception: return def _ensure_dir(path: str) -> str: os.makedirs(path, exist_ok=True) return path def _to_float(x): try: return float(x) except Exception: return np.nan def compute_disease_capture_at_k_fast( y_true: np.ndarray, y_scores: np.ndarray, valid_mask: np.ndarray, top_k_list: List[int], return_counts: bool = False, ) -> Dict[int, Dict[int, float]] | Tuple[Dict[int, Dict[int, float]], np.ndarray, Dict[int, np.ndarray]]: """Vectorized Disease-Capture@K. Definition: for each disease d, among valid positives (y_true==1 and valid_mask), capture@K is the fraction whose predicted top-K diseases contain d. This implementation avoids per-positive full argsorts by computing top-k_max once per sample (using argpartition), sorting those k_max indices by score, and then aggregating hits via bincount. """ if y_scores.ndim != 2: raise ValueError( f"Expected y_scores 2D (N,K), got shape={y_scores.shape}") if y_true.shape != y_scores.shape or valid_mask.shape != y_scores.shape: raise ValueError( f"Shape mismatch: y_true={y_true.shape}, y_scores={y_scores.shape}, valid_mask={valid_mask.shape}") N, K = y_scores.shape top_k_list = sorted({int(k) for k in top_k_list if int(k) > 0}) capture_rates: Dict[int, Dict[int, float]] = { int(k): {} for k in top_k_list} hits_by_k: Dict[int, np.ndarray] = {} if N == 0 or K == 0 or len(top_k_list) == 0: if return_counts: return capture_rates, np.zeros((K,), dtype=np.int64), hits_by_k return capture_rates topk_max = min(max(top_k_list), K) # Valid positives per disease are the denominator. pos_valid = (y_true == 1) & valid_mask.astype(bool) denom = pos_valid.sum(axis=0).astype(np.int64) # (K,) # Compute top-k_max indices per sample once (unordered), then sort those indices by score. part = np.argpartition(y_scores, -topk_max, axis=1)[:, -topk_max:] # (N, topk_max) part_scores = np.take_along_axis(y_scores, part, axis=1) order = np.argsort(part_scores, axis=1)[:, ::-1] topk_idx = np.take_along_axis( part, order, axis=1).astype(np.int32) # (N, topk_max) rows = np.arange(N)[:, None] for k_val in top_k_list: k_eff = min(int(k_val), topk_max) idx_k = topk_idx[:, :k_eff] # (N, k_eff) # For each sample, we count a "hit" for disease d when: # d is in top-K (true by construction for idx_k) # AND sample is a valid positive for disease d. hits_mask = pos_valid[rows, idx_k] # (N, k_eff) bool hit_diseases = idx_k[hits_mask] hits = np.bincount(hit_diseases, minlength=K).astype(np.int64) if return_counts: hits_by_k[int(k_val)] = hits # Convert to dict with NaNs for diseases with no valid positives. out_k: Dict[int, float] = {} with np.errstate(divide='ignore', invalid='ignore'): frac = hits / denom for d in range(K): out_k[d] = float(frac[d]) if denom[d] > 0 else float('nan') capture_rates[int(k_val)] = out_k if return_counts: return capture_rates, denom, hits_by_k return capture_rates def save_summary_json(summary: Dict, output_path: str) -> None: """Save a single JSON summary file.""" def convert_to_serializable(obj): if isinstance(obj, np.ndarray): return obj.tolist() if isinstance(obj, (np.integer,)): return int(obj) if isinstance(obj, (np.floating,)): return float(obj) if isinstance(obj, dict): return {k: convert_to_serializable(v) for k, v in obj.items()} if isinstance(obj, list): return [convert_to_serializable(v) for v in obj] return obj summary_serializable = convert_to_serializable(summary) with open(output_path, 'w') as f: json.dump(summary_serializable, f, indent=2) def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]: """Save evaluation results into multiple CSV files. Produces long-form tables so they are easy to analyze/plot: - landmarks_summary.csv - auc_per_disease.csv - capture_at_k.csv - lift_yield.csv - dca.csv Returns: Mapping from logical name to file path. """ out_dir = _ensure_dir(out_dir) summary_rows: List[Dict] = [] auc_rows: List[Dict] = [] capture_rows: List[Dict] = [] capture_mean_rows: List[Dict] = [] lift_rows: List[Dict] = [] dca_rows: List[Dict] = [] landmarks = results.get('landmarks', []) for lm in landmarks: age = lm.get('age_cutoff') horizon = lm.get('horizon') for track in ['complete_case', 'clean_control']: track_res = lm.get(track) or {} if not track_res: continue summary_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'n_patients': track_res.get('n_patients', np.nan), 'n_valid': track_res.get('n_valid', np.nan), 'n_valid_patients': track_res.get('n_valid_patients', np.nan), 'mean_auc': track_res.get('mean_auc', np.nan), 'brier_score': track_res.get('brier_score', np.nan), 'brier_skill_score': track_res.get('brier_skill_score', np.nan), }) auc_per_disease = track_res.get('auc_per_disease') or {} for disease_idx, auc in auc_per_disease.items(): auc_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'disease_idx': int(disease_idx), 'auc': _to_float(auc), }) if track == 'complete_case': capture = track_res.get('disease_capture_at_k') or {} # Backward-compatible parsing: # - new format: {per_disease: {k: {d: rate}}, n_positive: {d: n}, macro_avg: {k: x}, micro_avg: {k: y}} # - old format: {k: {d: rate}} if isinstance(capture, dict) and 'per_disease' in capture: per_disease_by_k = capture.get('per_disease') or {} n_positive_by_disease = capture.get('n_positive') or {} macro_by_k = capture.get('macro_avg') or {} micro_by_k = capture.get('micro_avg') or {} else: per_disease_by_k = capture n_positive_by_disease = {} macro_by_k = {} micro_by_k = {} # Per-disease rows (+ n_positive) for k_val, per_disease in (per_disease_by_k or {}).items(): try: k_int = int(k_val) except Exception: continue for disease_idx, rate in (per_disease or {}).items(): try: d_int = int(disease_idx) except Exception: continue n_pos = n_positive_by_disease.get( d_int, n_positive_by_disease.get(str(d_int), np.nan)) capture_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'k': k_int, 'disease_idx': d_int, 'capture_rate': _to_float(rate), 'n_positive': _to_float(n_pos), }) # Macro/Micro summary rows # Prefer explicit macro/micro from the new format; otherwise compute macro from rates. for k_val, per_disease in (per_disease_by_k or {}).items(): try: k_int = int(k_val) except Exception: continue macro = macro_by_k.get( k_int, macro_by_k.get(str(k_int), np.nan)) micro = micro_by_k.get( k_int, micro_by_k.get(str(k_int), np.nan)) if macro is None or (isinstance(macro, float) and np.isnan(macro)): rates = [ _to_float(r) for r in (per_disease or {}).values() ] macro = float(np.nanmean(rates)) if len( rates) else np.nan capture_mean_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'k': k_int, 'macro_avg': _to_float(macro), 'micro_avg': _to_float(micro), }) lift_yield = track_res.get('lift_and_yield') or {} overall = (lift_yield.get('overall') or {}) if isinstance( lift_yield, dict) else {} for frac, metrics in overall.items(): lift_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'level': 'overall', 'disease_idx': '', 'workload_frac': _to_float(frac), 'lift': _to_float((metrics or {}).get('lift')), 'yield': _to_float((metrics or {}).get('yield')), }) per_disease = (lift_yield.get('per_disease') or {} ) if isinstance(lift_yield, dict) else {} for disease_idx, disease_metrics in per_disease.items(): for frac, metrics in (disease_metrics or {}).items(): lift_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'level': 'per_disease', 'disease_idx': int(disease_idx), 'workload_frac': _to_float(frac), 'lift': _to_float((metrics or {}).get('lift')), 'yield': _to_float((metrics or {}).get('yield')), }) dca = track_res.get('dca') or {} thresholds = dca.get('thresholds') net_benefit = dca.get('net_benefit') if thresholds is not None and net_benefit is not None: thresholds_arr = np.asarray(thresholds, dtype=np.float64) nb_arr = np.asarray(net_benefit, dtype=np.float64) for thr, nb in zip(thresholds_arr, nb_arr): dca_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'threshold': float(thr), 'net_benefit': float(nb), }) paths: Dict[str, str] = {} df_summary = pd.DataFrame(summary_rows) summary_path = os.path.join(out_dir, 'landmarks_summary.csv') df_summary.to_csv(summary_path, index=False) paths['landmarks_summary'] = summary_path df_auc = pd.DataFrame(auc_rows) auc_path = os.path.join(out_dir, 'auc_per_disease.csv') df_auc.to_csv(auc_path, index=False) paths['auc_per_disease'] = auc_path df_capture = pd.DataFrame(capture_rows) capture_path = os.path.join(out_dir, 'capture_at_k.csv') df_capture.to_csv(capture_path, index=False) paths['capture_at_k'] = capture_path df_capture_mean = pd.DataFrame(capture_mean_rows) capture_mean_path = os.path.join(out_dir, 'capture_at_k_mean.csv') df_capture_mean.to_csv(capture_mean_path, index=False) paths['capture_at_k_mean'] = capture_mean_path df_lift = pd.DataFrame(lift_rows) lift_path = os.path.join(out_dir, 'lift_yield.csv') df_lift.to_csv(lift_path, index=False) paths['lift_yield'] = lift_path df_dca = pd.DataFrame(dca_rows) dca_path = os.path.join(out_dir, 'dca.csv') df_dca.to_csv(dca_path, index=False) paths['dca'] = dca_path return paths class LandmarkEvaluator: """ Comprehensive landmark analysis evaluator for survival/competing risks models. """ def __init__( self, model: torch.nn.Module, head: torch.nn.Module, loss_fn: torch.nn.Module, dataset: HealthDataset, eval_indices: Optional[List[int]] = None, device: str = 'cuda', batch_size: int = 256, num_workers: int = 4, compile_model: bool = True, check_capture_at_k: bool = False, profile_metrics: bool = False, capture_check_n: int = 200, ): self.model = model.to(device).eval() self.head = head.to(device).eval() self.loss_fn = loss_fn self.dataset = dataset self.eval_indices = eval_indices self.device = device self.batch_size = batch_size self.num_workers = num_workers self.check_capture_at_k = bool(check_capture_at_k) self.profile_metrics = bool(profile_metrics) self.capture_check_n = int(capture_check_n) self._did_capture_check = False use_cuda = str(self.device).startswith( "cuda") and torch.cuda.is_available() if use_cuda: torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True try: torch.set_float32_matmul_precision("high") except Exception: pass # JIT/compile optimization (best effort) if compile_model and use_cuda: self.model = _maybe_torch_compile(self.model, enabled=True) self.head = _maybe_torch_compile(self.head, enabled=True) # Evaluation parameters from design doc self.age_cutoffs = [50, 60, 70] self.horizons = [0.25, 0.5, 1, 2, 5, 10] self.top_k_values = [5, 10, 20, 50] self.workload_fracs = [0.01, 0.05, 0.10, 0.20, 0.50] # Convert age to days for comparison self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs] self.horizons_days = [h * 365.25 for h in self.horizons] @staticmethod def _last_time(time_batch: torch.Tensor, event_batch: torch.Tensor) -> torch.Tensor: """Compute last observed (non-padding) time per patient.""" real_mask = event_batch >= 1 masked = time_batch.masked_fill(~real_mask, float('-inf')) return masked.max(dim=1).values @staticmethod def _anchor_indices( time_batch: torch.Tensor, event_batch: torch.Tensor, cutoff_days: float, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Find anchor index/time: last valid record before cutoff.""" real_mask = event_batch >= 1 before = time_batch < cutoff_days valid_before = real_mask & before has_anchor = valid_before.any(dim=1) # argmax of position under mask gives last True position L = event_batch.size(1) pos = torch.arange(L, device=event_batch.device).view(1, L) anchor_idx = (valid_before.to(torch.long) * pos).max(dim=1).values.to(torch.long) t_anchor = time_batch.gather(1, anchor_idx.view(-1, 1)).squeeze(1) return has_anchor, anchor_idx, t_anchor def _labels_and_validity_for_cutoff( self, time_batch: torch.Tensor, event_batch: torch.Tensor, cutoff_days: float, horizons_days: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Vectorized label + validity computation for all horizons at a cutoff. Returns: labels: (B, H, K) float32 {0,1} valid_cc: (B, H, K) bool valid_clean: (B, H, K) bool """ n_tech_tokens = 2 K = int(self.dataset.n_disease) death_code = int(K - 1) B, L = event_batch.shape H = int(horizons_days.numel()) # Disease token mask and indices is_disease = event_batch >= n_tech_tokens disease_idx = (event_batch - n_tech_tokens).clamp(min=0, max=K - 1) # ever_has_disease: (B, K) ever = torch.zeros((B, K), dtype=torch.bool, device=event_batch.device) if is_disease.any(): b_idx, t_idx = is_disease.nonzero(as_tuple=True) d_idx = disease_idx[b_idx, t_idx] ever[b_idx, d_idx] = True # Events within horizon windows: (B, L, H) offset = time_batch - float(cutoff_days) within = is_disease.unsqueeze(-1) & (offset.unsqueeze(-1) >= 0) & ( offset.unsqueeze(-1) <= horizons_days.view(1, 1, H) ) labels_bool = torch.zeros( (B, H, K), dtype=torch.bool, device=event_batch.device) if within.any(): b2, t2, h2 = within.nonzero(as_tuple=True) d2 = disease_idx[b2, t2] labels_bool[b2, h2, d2] = True labels = labels_bool.to(torch.float32) last_time = self._last_time(time_batch, event_batch) # (B,) horizon_end = float(cutoff_days) + horizons_days.view(1, H) # (1, H) death_in_horizon = labels_bool[:, :, death_code] # (B, H) observed_past_horizon = last_time.view(B, 1) > horizon_end lost_within_horizon = last_time.view(B, 1) <= horizon_end # Track A (Complete-Case): # - if observed past horizon OR death in horizon => valid all diseases # - else (censored within horizon) => valid only for diseases that occurred within horizon valid_cc = labels_bool.clone() full_mask = (observed_past_horizon | death_in_horizon).unsqueeze(-1) if full_mask.any(): valid_cc = torch.where( full_mask.expand(-1, -1, K), torch.ones_like(valid_cc), valid_cc) # Track B (Clean-Control) per disease: # valid[k] = hit_in_window(k) OR (never_has_k AND not lost_within_window) never = ~ever # (B, K) valid_clean = (~death_in_horizon).unsqueeze(-1) & ( labels_bool | (never.unsqueeze(1) & ( ~lost_within_horizon).unsqueeze(-1)) ) return labels, valid_cc, valid_clean def _compute_risk_scores_many_horizons( self, logits: torch.Tensor, t_start_days: torch.Tensor, horizons_days: torch.Tensor, ) -> torch.Tensor: """Compute risk increments for all horizons in one vectorized call. Args: logits: model head outputs for anchor points. t_start_days: (B,) time from anchor to cutoff (days). horizons_days: (H,) horizons in days. Returns: risk: (B, H, K) float32 """ t_start_days = torch.clamp(t_start_days, min=0) t_end_days = torch.clamp(t_start_days.unsqueeze( 1) + horizons_days.view(1, -1), min=0) t_query_years = torch.cat([t_start_days.unsqueeze( 1), t_end_days], dim=1) / 365.25 # (B, H+1) # calculate_cifs returns (B, K) if scalar/per-sample, else (B, K, T) if hasattr(self.loss_fn, "calculate_cifs"): cifs = self.loss_fn.calculate_cifs( logits, t_query_years, return_survival=False) else: raise ValueError( f"Loss function does not support calculate_cifs: {type(self.loss_fn)}") if cifs.ndim == 2: # (B, K) -> (B, 1, K) cifs_bt_k = cifs.unsqueeze(1) else: # (B, K, T) -> (B, T, K) cifs_bt_k = cifs.permute(0, 2, 1) cif_start = cifs_bt_k[:, :1, :] # (B, 1, K) cif_end = cifs_bt_k[:, 1:, :] # (B, H, K) risk = torch.clamp(cif_end - cif_start, min=0) return risk @torch.no_grad() def compute_risk_scores( self, indices: List[int], age_cutoff_days: float, horizon_days: float, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Compute risk scores for specified patient indices at given landmark. Args: indices: Patient indices to evaluate age_cutoff_days: Age cutoff in days (T_cut) horizon_days: Prediction horizon in days (H) Returns: risk_scores: (N, K) array of risk scores per disease t_anchors: (N,) array of anchor times (days from birth) valid_mask: (N,) boolean array indicating valid predictions """ subset = Subset(self.dataset, indices) loader = DataLoader( subset, batch_size=self.batch_size, shuffle=False, collate_fn=health_collate_fn, num_workers=self.num_workers, pin_memory=True if self.device == 'cuda' else False, ) all_risk_scores = [] all_t_anchors = [] all_valid_mask = [] for batch in loader: event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch event_batch = event_batch.to(self.device) time_batch = time_batch.to(self.device) cont_batch = cont_batch.to(self.device) cate_batch = cate_batch.to(self.device) sex_batch = sex_batch.to(self.device) # Find anchor point: last valid record before age_cutoff # Valid records are non-padding events (event >= 1) valid_mask = event_batch >= 1 # (B, L) before_cutoff = time_batch < age_cutoff_days # (B, L) valid_before = valid_mask & before_cutoff # (B, L) # Find last valid position for each patient batch_size = event_batch.size(0) t_anchor = torch.zeros(batch_size, device=self.device) anchor_idx = torch.zeros( batch_size, dtype=torch.long, device=self.device) has_anchor = torch.zeros( batch_size, dtype=torch.bool, device=self.device) for b in range(batch_size): valid_positions = valid_before[b].nonzero(as_tuple=True)[0] if len(valid_positions) > 0: last_pos = valid_positions[-1] anchor_idx[b] = last_pos t_anchor[b] = time_batch[b, last_pos] has_anchor[b] = True # Get model predictions at anchor points if has_anchor.any(): # If torch.compile uses CUDA Graphs under the hood, mark a new step # before each compiled invocation to avoid output lifetime issues. _maybe_cudagraph_mark_step_begin() # Forward pass hidden = self.model(event_batch, time_batch, sex_batch, cont_batch, cate_batch) # Get predictions at anchor positions batch_indices = torch.arange(batch_size, device=self.device) # (B, n_embd) hidden_at_anchor = hidden[batch_indices, anchor_idx] # Compute logits using the loaded head # (B, n_disease, ...) or (B, K+1, n_bins+1) etc. logits = self.head(hidden_at_anchor) # Compute CIF scores # Time gap from anchor to start of horizon t_start = age_cutoff_days - t_anchor # (B,) # Time gap from anchor to end of horizon t_end = age_cutoff_days + horizon_days - t_anchor # (B,) # Ensure non-negative time gaps t_start = torch.clamp(t_start, min=0) t_end = torch.clamp(t_end, min=0) # Calculate CIF at both time points cif_start = self._compute_cif(logits, t_start) # (B, K) cif_end = self._compute_cif(logits, t_end) # (B, K) # Risk score is the increment within the horizon risk_scores = cif_end - cif_start # (B, K) risk_scores = torch.clamp( risk_scores, min=0) # Ensure non-negative else: # No valid anchor points in this batch risk_scores = torch.zeros( batch_size, self.dataset.n_disease, device=self.device) all_risk_scores.append(risk_scores.cpu().numpy()) all_t_anchors.append(t_anchor.cpu().numpy()) all_valid_mask.append(has_anchor.cpu().numpy()) # Concatenate results risk_scores = np.vstack(all_risk_scores) # (N, K) t_anchors = np.concatenate(all_t_anchors) # (N,) valid_mask = np.concatenate(all_valid_mask) # (N,) return risk_scores, t_anchors, valid_mask def _compute_cif(self, logits: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """ Compute Cumulative Incidence Function at time t. Args: logits: Model output logits (B, K, ...) depending on loss type t: Time points (B,) in years from anchor Returns: cif: (B, K) cumulative incidence probabilities """ t_years = t / 365.25 # Convert to years if isinstance(self.loss_fn, ExponentialNLLLoss): # Exponential: logits are (B, K) lambdas = F.softplus(logits) + 1e-6 # (B, K) total_lambda = lambdas.sum(dim=-1, keepdim=True) # (B, 1) # CIF_k(t) = (λ_k / Σλ) * (1 - exp(-Σλ * t)) frac = lambdas / total_lambda # (B, K) exp_term = 1.0 - \ torch.exp(-total_lambda.squeeze(-1).unsqueeze(-1) * t_years.unsqueeze(-1)) cif = frac * exp_term # (B, K) elif isinstance(self.loss_fn, DiscreteTimeCIFNLLLoss): # Discrete-time CIF: use calculate_cifs method cif = self.loss_fn.calculate_cifs( logits, t_years, return_survival=False) elif isinstance(self.loss_fn, PiecewiseExponentialCIFNLLLoss): # PWE CIF: use calculate_cifs method cif = self.loss_fn.calculate_cifs( logits, t_years, return_survival=False) else: raise ValueError(f"Unknown loss type: {type(self.loss_fn)}") return cif def prepare_evaluation_cohort( self, age_cutoff_days: float, horizon_days: float, track: str = 'complete_case', ) -> Tuple[List[int], np.ndarray, np.ndarray]: """Prepare evaluation cohort per design protocol with per-disease validity. Key fixes vs the earlier implementation: - Multi-disease labeling: mark *all* diseases that occur within the horizon. - Per-disease validity mask: mask[i, k]=1 if patient i is valid for disease k. Returns: indices: list of patient indices included in this cohort table labels: (N, K) float array valid_mask: (N, K) bool array """ n_tech_tokens = 2 # PAD=0, DOA=1 K = int(self.dataset.n_disease) # Competing risk token: treat death as competing event. # As requested: DEATH_CODE is the last disease index. DEATH_CODE = int(K - 1) horizon_end_days = age_cutoff_days + horizon_days indices: List[int] = [] labels_rows: List[np.ndarray] = [] valid_rows: List[np.ndarray] = [] candidate_indices = self.eval_indices if self.eval_indices is not None else list( range(len(self.dataset))) for idx in candidate_indices: patient_id = self.dataset.patient_ids[idx] records = self.dataset.patient_events.get(patient_id, []) if not records: continue # Must have some information strictly prior to cutoff (anchor existence in the data). has_pre_cutoff = any(t < age_cutoff_days for t, _ in records) if not has_pre_cutoff: continue # Events after cutoff (already sorted in dataset init) events_after = [(t, e) for t, e in records if t >= age_cutoff_days] labels = np.zeros(K, dtype=np.float32) valid = np.zeros(K, dtype=bool) # Identify diseases within horizon (multi-label; no early break) diseases_in_horizon: set[int] = set() death_in_horizon = False for t, e in events_after: if t > horizon_end_days: break # events are time-sorted if e < n_tech_tokens: continue disease_idx = int(e - n_tech_tokens) if 0 <= disease_idx < K: diseases_in_horizon.add(disease_idx) if disease_idx == DEATH_CODE: death_in_horizon = True for d in diseases_in_horizon: labels[d] = 1.0 last_time = float(records[-1][0]) if track == 'complete_case': # Track A: Complete-Case at Horizon # - Hit: disease occurs within horizon => valid positive for that disease # - Healthy: last record > horizon_end => valid negative for all diseases # - Death: death within horizon => valid negative for diseases not occurred before death # (we implement as valid for all diseases, with labels marking any hits incl death) # - Loss: censored within horizon (no hit for disease) => invalid for that disease if last_time > horizon_end_days: # Observed past horizon: negative for all non-hit diseases valid[:] = True elif death_in_horizon: # Competing risk: include as explicit negatives (label 0) for diseases not hit # and positives for diseases hit before death (we don't model ordering here). # This matches the requirement: death within window is label=0 for target diseases. valid[:] = True else: # Lost within horizon: only diseases that actually occurred in the horizon are valid positives # (no assumptions about negatives). for d in diseases_in_horizon: valid[d] = True elif track == 'clean_control': # Track B: Clean Control (per-disease) # For each disease k: # - Hit: disease k occurs within horizon => label=1, valid # - Pure clean for k: disease k never occurs in entire record => label=0, valid # - Death within window => drop (invalid) for all k # - Loss within window => drop (invalid) for all k # - Late onset of k (after horizon) => drop (invalid) for k if death_in_horizon: # Drop all diseases for this patient valid[:] = False else: # Precompute per-disease first occurrence time after cutoff (including after horizon) first_occ_after_cutoff = np.full( K, np.inf, dtype=np.float64) ever_has_disease = np.zeros(K, dtype=bool) for t, e in records: if e < n_tech_tokens: continue disease_idx = int(e - n_tech_tokens) if 0 <= disease_idx < K: ever_has_disease[disease_idx] = True if t >= age_cutoff_days and t < first_occ_after_cutoff[disease_idx]: first_occ_after_cutoff[disease_idx] = float(t) # If lost within horizon and did not have disease k in horizon, it's invalid for k. lost_within_horizon = last_time <= horizon_end_days for k in range(K): if labels[k] == 1.0: valid[k] = True continue if not ever_has_disease[k]: # Lifetime clean for k # Still need to have complete follow-up within horizon for clean-control track. # If censored within horizon, we cannot be sure k didn't occur in the window. valid[k] = not lost_within_horizon continue # Has disease k at some point. t_first = first_occ_after_cutoff[k] if np.isfinite(t_first) and t_first > horizon_end_days: # Late onset of k -> invalid for k valid[k] = False else: # Has k before cutoff (prevalent) or has k after cutoff but not in horizon? -> invalid. valid[k] = False else: raise ValueError(f"Unknown track: {track}") # Keep patient row if they are valid for at least one disease. if valid.any(): indices.append(idx) labels_rows.append(labels) valid_rows.append(valid) if not indices: return [], np.zeros((0, K), dtype=np.float32), np.zeros((0, K), dtype=bool) return indices, np.stack(labels_rows, axis=0), np.stack(valid_rows, axis=0) def compute_auc_per_disease( self, risk_scores: np.ndarray, labels: np.ndarray, valid_mask: np.ndarray, ) -> Dict[int, float]: """ Compute time-dependent AUC for each disease. Args: risk_scores: (N, K) risk scores labels: (N, K) binary labels valid_mask: (N, K) boolean mask for valid evaluations Returns: auc_scores: Dict mapping disease_idx -> AUC """ auc_scores: Dict[int, float] = {} n_diseases = risk_scores.shape[1] m = valid_mask.astype(bool) # Fast pre-filter: skip diseases without both classes among valid entries. pos = (labels == 1) & m n_valid = m.sum(axis=0) n_pos = pos.sum(axis=0) n_neg = n_valid - n_pos for k in range(n_diseases): if n_valid[k] == 0 or n_pos[k] == 0 or n_neg[k] == 0: auc_scores[k] = np.nan continue mk = m[:, k] y_true_k = labels[mk, k] y_score_k = risk_scores[mk, k] try: auc_scores[k] = float(roc_auc_score(y_true_k, y_score_k)) except Exception: auc_scores[k] = np.nan return auc_scores def compute_brier_score( self, risk_scores: np.ndarray, labels: np.ndarray, valid_mask: np.ndarray, ) -> Dict[str, float]: """ Compute Brier Score and Brier Skill Score. Args: risk_scores: (N, K) risk scores labels: (N, K) binary labels valid_mask: (N, K) boolean mask Returns: metrics: Dict with 'brier_score' and 'brier_skill_score' """ # Apply per-entry valid mask and flatten m = valid_mask.astype(bool) y_true_flat = labels[m] y_pred_flat = risk_scores[m] mask = ~(np.isnan(y_true_flat) | np.isnan(y_pred_flat)) y_true_flat = y_true_flat[mask] y_pred_flat = y_pred_flat[mask] if len(y_true_flat) == 0: return {'brier_score': np.nan, 'brier_skill_score': np.nan} bs = brier_score_loss(y_true_flat, y_pred_flat) # Brier Skill Score: reference is predicting the mean p_mean = y_true_flat.mean() bs_ref = ((p_mean - y_true_flat) ** 2).mean() bss = 1.0 - (bs / bs_ref) if bs_ref > 0 else 0.0 return {'brier_score': bs, 'brier_skill_score': bss} def compute_disease_capture_at_k( self, risk_scores: np.ndarray, labels: np.ndarray, valid_mask: np.ndarray, ) -> Dict[str, Dict]: """ Compute Disease-Capture@K: fraction of true positives where the true disease appears in the patient's top-K predicted risks. Args: risk_scores: (N, K) risk scores labels: (N, K) binary labels valid_mask: (N, K) boolean mask Returns: metrics: Dict with keys: - per_disease: Dict[k][disease_idx] -> capture rate - n_positive: Dict[disease_idx] -> number of valid positives (support) - n_captured: Dict[k][disease_idx] -> number of captured positives - macro_avg: Dict[k] -> macro-average capture rate - micro_avg: Dict[k] -> micro-average capture rate """ # Fast path (vectorized): compute top-k_max once per sample. t0 = time.perf_counter() if self.profile_metrics else None capture_rates, denom, hits_by_k = compute_disease_capture_at_k_fast( y_true=labels, y_scores=risk_scores, valid_mask=valid_mask, top_k_list=self.top_k_values, return_counts=True, ) if self.profile_metrics and t0 is not None: dt = time.perf_counter() - t0 print( f" [profile] capture@K fast: {dt:.3f}s (N={risk_scores.shape[0]}, K={risk_scores.shape[1]})") # Optional correctness check against the slow reference on a small subset. if self.check_capture_at_k and (not self._did_capture_check): self._did_capture_check = True rng = np.random.default_rng(0) n_sub = min(self.capture_check_n, risk_scores.shape[0]) sub_idx = rng.choice( risk_scores.shape[0], size=n_sub, replace=False) rs = risk_scores[sub_idx] ys = labels[sub_idx] vm = valid_mask[sub_idx] t1 = time.perf_counter() slow = self._compute_disease_capture_at_k_slow(rs, ys, vm) t2 = time.perf_counter() fast = compute_disease_capture_at_k_fast( ys, rs, vm, self.top_k_values) t3 = time.perf_counter() def _eq(a: float, b: float) -> bool: if np.isnan(a) and np.isnan(b): return True return float(a) == float(b) for k_val in self.top_k_values: for d in range(rs.shape[1]): if not _eq(slow[int(k_val)][d], fast[int(k_val)][d]): raise AssertionError( f"Capture@{k_val} mismatch for disease {d}: slow={slow[int(k_val)][d]} fast={fast[int(k_val)][d]}" ) print( f" [check] capture@K ok on subset (N={n_sub}). slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s" ) K = int(risk_scores.shape[1]) n_positive: Dict[int, int] = {int(d): int(denom[d]) for d in range(K)} n_captured: Dict[int, Dict[int, int]] = {} macro_avg: Dict[int, float] = {} micro_avg: Dict[int, float] = {} total_pos = int(denom.sum()) for k_val in self.top_k_values: k_int = int(k_val) hits = hits_by_k.get(k_int, np.zeros((K,), dtype=np.int64)) n_captured[k_int] = {int(d): int(hits[d]) for d in range(K)} # Macro: mean across diseases with support (ignore NaNs) rates = capture_rates.get(k_int, {}) rate_values = np.array([rates.get(d, np.nan) for d in range(K)], dtype=np.float64) macro_avg[k_int] = float(np.nanmean( rate_values)) if rate_values.size else float('nan') # Micro: sum captured / sum positives micro_avg[k_int] = float( hits.sum() / total_pos) if total_pos > 0 else float('nan') return { 'per_disease': capture_rates, 'n_positive': n_positive, 'n_captured': n_captured, 'macro_avg': macro_avg, 'micro_avg': micro_avg, } def _compute_disease_capture_at_k_slow( self, risk_scores: np.ndarray, labels: np.ndarray, valid_mask: np.ndarray, ) -> Dict[int, Dict[int, float]]: """Reference implementation (slow): kept for correctness checking.""" capture_rates = {int(k): {} for k in self.top_k_values} n_diseases = risk_scores.shape[1] for disease_idx in range(n_diseases): mk = valid_mask[:, disease_idx] if not np.any(mk): for k in self.top_k_values: capture_rates[int(k)][disease_idx] = np.nan continue y_true = labels[mk, disease_idx] y_scores = risk_scores[mk] # (N_valid_k, K) pos_mask = y_true == 1 if pos_mask.sum() == 0: for k in self.top_k_values: capture_rates[int(k)][disease_idx] = np.nan continue pos_idx = np.where(pos_mask)[0] for k_val in self.top_k_values: captures = [] for i in pos_idx: top_k_diseases = np.argsort(y_scores[i])[::-1][:int(k_val)] captures.append(int(disease_idx in top_k_diseases)) capture_rates[int(k_val)][disease_idx] = float( np.mean(captures)) if captures else np.nan return capture_rates def compute_lift_and_yield( self, risk_scores: np.ndarray, labels: np.ndarray, valid_mask: np.ndarray, ) -> Dict[str, Dict]: """ Compute Lift and Yield at various workload fractions. Args: risk_scores: (N, K) risk scores labels: (N, K) binary labels valid_mask: (N,) boolean mask Returns: metrics: { 'overall': {workload_frac: {'lift': ..., 'yield': ...}, ...}, 'per_disease': {disease_idx: {workload_frac: {'lift': ..., 'yield': ...}, ...}, ...} } """ # Overall metric uses a patient-level mask to avoid including purely censored negatives. # We include patients who either (a) have known outcomes for all diseases, or (b) have at least one hit. has_any_hit = labels.max(axis=1) > 0 has_all_known = valid_mask.all(axis=1) overall_patient_mask = has_any_hit | has_all_known risk_scores_overall = risk_scores[overall_patient_mask] labels_overall = labels[overall_patient_mask] valid_overall = valid_mask[overall_patient_mask] # Flatten to patient-level: any disease event # Max risk across all diseases for each patient max_risk_per_patient = risk_scores_overall.max(axis=1) any_disease_label = labels_overall.max(axis=1) n_patients = len(max_risk_per_patient) base_rate = any_disease_label.mean() if n_patients > 0 else 0.0 overall: Dict[float, Dict[str, float]] = {} for workload_frac in self.workload_fracs: n_screen = max(1, int(n_patients * workload_frac)) top_n_idx = np.argsort(max_risk_per_patient)[::-1][:n_screen] top_n_labels = any_disease_label[top_n_idx] yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan lift_val = (yield_val / float(base_rate)) if base_rate > 0 else 0.0 overall[workload_frac] = {'lift': lift_val, 'yield': yield_val} per_disease: Dict[int, Dict[float, Dict[str, float]]] = {} n_diseases = risk_scores.shape[1] for disease_idx in range(n_diseases): mk = valid_mask[:, disease_idx] if not np.any(mk): per_disease[disease_idx] = { frac: {'lift': np.nan, 'yield': np.nan} for frac in self.workload_fracs} continue disease_scores = risk_scores[mk, disease_idx] disease_labels = labels[mk, disease_idx] disease_base_rate = disease_labels.mean() if disease_labels.size > 0 else 0.0 n_patients_k = disease_scores.shape[0] disease_metrics: Dict[float, Dict[str, float]] = {} for workload_frac in self.workload_fracs: n_screen = max(1, int(n_patients_k * workload_frac)) top_n_idx = np.argsort(disease_scores)[::-1][:n_screen] top_n_labels = disease_labels[top_n_idx] yield_val = float(top_n_labels.mean() ) if n_screen > 0 else np.nan lift_val = (yield_val / float(disease_base_rate) ) if disease_base_rate > 0 else 0.0 disease_metrics[workload_frac] = { 'lift': lift_val, 'yield': yield_val} per_disease[disease_idx] = disease_metrics return { 'overall': overall, 'per_disease': per_disease, } def compute_dca_net_benefit( self, risk_scores: np.ndarray, labels: np.ndarray, valid_mask: np.ndarray, threshold_range: np.ndarray = np.linspace(0, 0.5, 51), ) -> Dict[str, np.ndarray]: """ Compute Decision Curve Analysis (DCA) net benefit. Args: risk_scores: (N, K) risk scores labels: (N, K) binary labels valid_mask: (N,) boolean mask threshold_range: Array of threshold probabilities Returns: dca_results: Dict with 'thresholds' and 'net_benefit' arrays """ # Use the same overall patient mask as lift/yield (complete-case style) has_any_hit = labels.max(axis=1) > 0 has_all_known = valid_mask.all(axis=1) patient_mask = has_any_hit | has_all_known risk_scores = risk_scores[patient_mask] labels = labels[patient_mask] # Use max risk and any disease label max_risk = risk_scores.max(axis=1) any_disease = labels.max(axis=1) n = len(max_risk) net_benefits = [] for pt in threshold_range: if pt == 0: # Treat all nb = any_disease.mean() else: # Treat if predicted risk > threshold treat = max_risk >= pt tp = (treat & (any_disease == 1)).sum() fp = (treat & (any_disease == 0)).sum() # Net benefit = (TP/N) - (FP/N) * (pt / (1-pt)) nb = (tp / n) - (fp / n) * (pt / (1 - pt)) net_benefits.append(nb) return { 'thresholds': threshold_range, 'net_benefit': np.array(net_benefits), } def evaluate_landmark( self, age_cutoff: float, horizon: float, ) -> Dict: """ Evaluate model at a specific landmark (age_cutoff, horizon). Args: age_cutoff: Age cutoff in years horizon: Prediction horizon in years Returns: results: Dictionary with all metrics """ age_cutoff_days = age_cutoff * 365.25 horizon_days = horizon * 365.25 print(f"\nEvaluating Landmark: Age={age_cutoff}, Horizon={horizon}y") results = { 'age_cutoff': age_cutoff, 'horizon': horizon, 'complete_case': {}, 'clean_control': {}, } for track in ['complete_case', 'clean_control']: print(f" Track: {track}") # Prepare cohort indices, labels_array, valid_mask = self.prepare_evaluation_cohort( age_cutoff_days, horizon_days, track ) if len(indices) == 0: print(f" No valid patients for track {track}") continue print(f" Cohort size: {len(indices)}") # Compute risk scores risk_scores, t_anchors, anchor_mask = self.compute_risk_scores( indices, age_cutoff_days, horizon_days ) # Combine anchor availability with per-disease validity valid_mask = valid_mask & anchor_mask.astype(bool)[:, None] # Compute metrics print(" Computing AUC...") auc_scores = self.compute_auc_per_disease( risk_scores, labels_array, valid_mask) mean_auc = np.nanmean(list(auc_scores.values())) print(" Computing Brier Score...") brier_metrics = self.compute_brier_score( risk_scores, labels_array, valid_mask) # Only compute patient-level and population metrics for complete_case if track == 'complete_case': print(" Computing Disease-Capture@K...") capture_metrics = self.compute_disease_capture_at_k( risk_scores, labels_array, valid_mask ) print(" Computing Lift & Yield...") lift_yield_metrics = self.compute_lift_and_yield( risk_scores, labels_array, valid_mask ) print(" Computing DCA...") dca_metrics = self.compute_dca_net_benefit( risk_scores, labels_array, valid_mask ) results[track] = { 'n_patients': len(indices), 'n_valid': int(valid_mask.sum()), 'n_valid_patients': int(valid_mask.any(axis=1).sum()), 'auc_per_disease': auc_scores, 'mean_auc': mean_auc, 'brier_score': brier_metrics['brier_score'], 'brier_skill_score': brier_metrics['brier_skill_score'], 'disease_capture_at_k': capture_metrics, 'lift_and_yield': lift_yield_metrics, 'dca': dca_metrics, } else: # Clean control track: only discrimination metrics results[track] = { 'n_patients': len(indices), 'n_valid': int(valid_mask.sum()), 'n_valid_patients': int(valid_mask.any(axis=1).sum()), 'auc_per_disease': auc_scores, 'mean_auc': mean_auc, } return results def run_full_evaluation(self) -> Dict: """Run the full evaluation using a single-pass DataLoader. Key optimizations: - iterate DataLoader exactly once - run transformer backbone once per batch - reuse hidden states per cutoff (3x head only) - vectorize CIF/risk over all horizons in one call """ # Build evaluation subset loader indices = self.eval_indices if self.eval_indices is not None else list( range(len(self.dataset))) subset = Subset(self.dataset, indices) loader = DataLoader( subset, batch_size=self.batch_size, shuffle=False, collate_fn=health_collate_fn, num_workers=self.num_workers, pin_memory=True if str(self.device).startswith('cuda') else False, ) cutoffs_days = torch.tensor( # (C,) self.age_cutoffs_days, dtype=torch.float32, device=self.device) horizons_days = torch.tensor( # (H,) self.horizons_days, dtype=torch.float32, device=self.device) C = int(cutoffs_days.numel()) H = int(horizons_days.numel()) K = int(self.dataset.n_disease) # Buffers: store per landmark/track arrays in chunks to avoid repeated I/O. # Each key stores lists of numpy arrays to be concatenated at the end. buffers: Dict[Tuple[int, int, str], Dict[str, List[np.ndarray]]] = {} for ci in range(C): for hi in range(H): for track in ("complete_case", "clean_control"): buffers[(ci, hi, track)] = { "risk": [], "labels": [], "valid": []} with torch.inference_mode(): for batch in tqdm(loader, desc="Single-pass evaluation", ncols=100): event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch event_batch = event_batch.to(self.device, non_blocking=True) time_batch = time_batch.to(self.device, non_blocking=True) cont_batch = cont_batch.to(self.device, non_blocking=True) cate_batch = cate_batch.to(self.device, non_blocking=True) sex_batch = sex_batch.to(self.device, non_blocking=True) B, L = event_batch.shape batch_idx = torch.arange(B, device=self.device) # Backbone once per batch _maybe_cudagraph_mark_step_begin() hidden = self.model( # (B, L, D) event_batch, time_batch, sex_batch, cont_batch, cate_batch) for ci in range(C): cutoff = float(cutoffs_days[ci].item()) has_anchor, anchor_idx, t_anchor = self._anchor_indices( time_batch, event_batch, cutoff) if not has_anchor.any(): continue # Hidden states at anchor positions hidden_anchor = hidden[batch_idx, anchor_idx] # (B, D) logits = self.head(hidden_anchor) # Vectorized labels/validity for all horizons labels_bhk, valid_cc_bhk, valid_clean_bhk = self._labels_and_validity_for_cutoff( time_batch, event_batch, cutoff, horizons_days ) # Risk scores for all horizons (B, H, K) t_start = torch.clamp(torch.tensor( cutoff, device=self.device) - t_anchor, min=0) risk_bhk = self._compute_risk_scores_many_horizons( logits, t_start, horizons_days) # Apply anchor constraint to validity anchor_mask = has_anchor.view(B, 1, 1) valid_cc_bhk = valid_cc_bhk & anchor_mask valid_clean_bhk = valid_clean_bhk & anchor_mask # Push per-horizon chunks for hi in range(H): for track, valid_bk in ( ("complete_case", valid_cc_bhk[:, hi, :]), ("clean_control", valid_clean_bhk[:, hi, :]), ): row_mask = valid_bk.any(dim=1) if not row_mask.any(): continue r = risk_bhk[row_mask, hi, :].to( torch.float32).cpu().numpy() y = labels_bhk[row_mask, hi, :].to( torch.float32).cpu().numpy() m = valid_bk[row_mask, :].to( torch.bool).cpu().numpy() buffers[(ci, hi, track)]["risk"].append(r) buffers[(ci, hi, track)]["labels"].append(y) buffers[(ci, hi, track)]["valid"].append(m) # Assemble results in the original output schema all_results: Dict = { 'age_cutoffs': self.age_cutoffs, 'horizons': self.horizons, 'landmarks': [], } for ci, age in enumerate(self.age_cutoffs): for hi, horizon in enumerate(self.horizons): landmark_results = { 'age_cutoff': age, 'horizon': horizon, 'complete_case': {}, 'clean_control': {}, } for track in ("complete_case", "clean_control"): chunks = buffers[(ci, hi, track)] if len(chunks["risk"]) == 0: continue risk_scores = np.concatenate(chunks["risk"], axis=0) labels = np.concatenate(chunks["labels"], axis=0) valid_mask = np.concatenate(chunks["valid"], axis=0) auc_scores = self.compute_auc_per_disease( risk_scores, labels, valid_mask) mean_auc = np.nanmean(list(auc_scores.values())) track_out = { 'n_patients': int(valid_mask.shape[0]), 'n_valid': int(valid_mask.sum()), 'n_valid_patients': int((valid_mask.any(axis=1)).sum()), 'auc_per_disease': auc_scores, 'mean_auc': mean_auc, } if track == "complete_case": brier_metrics = self.compute_brier_score( risk_scores, labels, valid_mask) capture_metrics = self.compute_disease_capture_at_k( risk_scores, labels, valid_mask) lift_yield_metrics = self.compute_lift_and_yield( risk_scores, labels, valid_mask) dca_metrics = self.compute_dca_net_benefit( risk_scores, labels, valid_mask) track_out.update({ 'brier_score': brier_metrics['brier_score'], 'brier_skill_score': brier_metrics['brier_skill_score'], 'disease_capture_at_k': capture_metrics, 'lift_and_yield': lift_yield_metrics, 'dca': dca_metrics, }) landmark_results[track] = track_out all_results['landmarks'].append(landmark_results) return all_results def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple: """ Load trained model and configuration from run directory. Args: run_dir: Path to run directory containing train_config.json and best_model.pt device: Device to load model on Returns: model, head, loss_fn, dataset, config """ run_path = Path(run_dir) # Load config config_path = run_path / 'train_config.json' with open(config_path, 'r') as f: config = json.load(f) print(f"Loading model from {run_dir}") print(f"Model type: {config['model_type']}") print(f"Loss type: {config['loss_type']}") # Load dataset (same as training) and reproduce the train/val/test split. # IMPORTANT: do NOT change data_prefix; train.py reads files like # _basic_info.csv, _table.csv, _event_data.npy data_prefix = config['data_prefix'] if config.get('full_cov', False): covariate_list = None else: # Match train.py partial-cov settings covariate_list = ["bmi", "smoking", "alcohol"] dataset = HealthDataset( data_prefix=data_prefix, covariate_list=covariate_list, cache_event_tensors=True, ) # Reproduce the random_split used in train.py to obtain the held-out test subset. n_total = len(dataset) train_ratio = float(config.get('train_ratio', 0.7)) val_ratio = float(config.get('val_ratio', 0.15)) seed = int(config.get('random_seed', 42)) n_train = int(n_total * train_ratio) n_val = int(n_total * val_ratio) n_test = n_total - n_train - n_val if n_test < 0: raise ValueError( f"Invalid split sizes from config: n_total={n_total}, train_ratio={train_ratio}, val_ratio={val_ratio}" ) from torch.utils.data import random_split _, _, test_subset = random_split( dataset, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(seed), ) test_indices = list(getattr(test_subset, 'indices', [])) # Determine output dimensions based on loss type import math if config['loss_type'] == 'exponential': out_dims = [dataset.n_disease] elif config['loss_type'] == 'discrete_time_cif': # logits shape (M, K+1, n_bins+1) bin_edges = config.get( 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]) out_dims = [dataset.n_disease + 1, len(bin_edges)] elif config['loss_type'] == 'pwe_cif': # Piecewise-exponential requires finite edges bin_edges = config.get( 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]) pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))] n_bins = len(pwe_edges) - 1 # logits shape (M, K, n_bins) out_dims = [dataset.n_disease, n_bins] else: raise ValueError(f"Unknown loss type: {config['loss_type']}") # Build model if config['model_type'] == 'delphi_fork': model = DelphiFork( n_disease=dataset.n_disease, n_tech_tokens=2, # PAD=0, DOA=1 n_embd=config['n_embd'], n_head=config['n_head'], n_layer=config['n_layer'], n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, age_encoder_type=config['age_encoder'], pdrop=config['pdrop'], ) elif config['model_type'] == 'sap_delphi': model = SapDelphi( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=config['n_embd'], n_head=config['n_head'], n_layer=config['n_layer'], n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, age_encoder_type=config['age_encoder'], pdrop=config['pdrop'], pretrained_weights_path=config.get('pretrained_emd_path'), ) else: raise ValueError(f"Unknown model type: {config['model_type']}") # Build head head = SimpleHead( n_embd=config['n_embd'], out_dims=out_dims, ) # Load model weights (checkpoint contains model and head state dicts) model_path = run_path / 'best_model.pt' checkpoint = torch.load(model_path, map_location=device) # The checkpoint is a dict with 'model_state_dict' and 'head_state_dict' if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) head.load_state_dict(checkpoint['head_state_dict']) print("Loaded model and head from checkpoint") else: raise ValueError( "Checkpoint format not recognized. Expected 'model_state_dict' and 'head_state_dict' keys.") model.to(device) head.to(device) # Build loss function if config['loss_type'] == 'exponential': loss_fn = ExponentialNLLLoss( lambda_reg=config.get('lambda_reg', 0.0) ) elif config['loss_type'] == 'discrete_time_cif': loss_fn = DiscreteTimeCIFNLLLoss( bin_edges=config.get( 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]), lambda_reg=config.get('lambda_reg', 0.0), ) elif config['loss_type'] == 'pwe_cif': # Piecewise-exponential (PWE) requires a FINITE last edge. # If bin_edges ends with +inf (default), drop it and train up to the last finite edge. raw_edges = config.get( 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]) pwe_edges = [float(x) for x in raw_edges if math.isfinite(float(x))] if len(pwe_edges) < 2: raise ValueError( "pwe_cif requires at least 2 finite bin edges (including 0). " f"Got bin_edges={list(raw_edges)}" ) if pwe_edges[0] != 0.0: raise ValueError( f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}" ) loss_fn = PiecewiseExponentialCIFNLLLoss( bin_edges=pwe_edges, lambda_reg=config.get('lambda_reg', 0.0), ) else: raise ValueError(f"Unknown loss type: {config['loss_type']}") return model, head, loss_fn, dataset, config, test_indices def print_summary(results: Dict): """ Print summary of evaluation results. Args: results: Results dictionary """ print("\n" + "=" * 80) print("EVALUATION SUMMARY") print("=" * 80) for landmark in results['landmarks']: age = landmark['age_cutoff'] horizon = landmark['horizon'] print(f"\nLandmark: Age {age}, Horizon {horizon}y") print("-" * 40) # Complete-case results if 'complete_case' in landmark and landmark['complete_case']: cc = landmark['complete_case'] print(f" Complete-Case Track:") print(f" Patients: {cc['n_patients']}") print(f" Mean AUC: {cc['mean_auc']:.4f}") print(f" Brier Score: {cc['brier_score']:.4f}") print(f" Brier Skill Score: {cc['brier_skill_score']:.4f}") # Show top-K capture rates (average across diseases) if 'disease_capture_at_k' in cc: print(f" Disease Capture:") for k in [5, 10, 20, 50]: capture = cc.get('disease_capture_at_k') or {} if isinstance(capture, dict) and 'macro_avg' in capture: macro = (capture.get('macro_avg') or {}).get(k, np.nan) micro = (capture.get('micro_avg') or {}).get(k, np.nan) print( f" Top-{k}: macro={_to_float(macro):.3f}, micro={_to_float(micro):.3f}") elif isinstance(capture, dict) and k in capture: rates = list((capture.get(k) or {}).values()) mean_rate = np.nanmean([_to_float(r) for r in rates]) print(f" Top-{k}: {mean_rate:.3f}") # Show lift and yield if 'lift_and_yield' in cc: print(f" Lift & Yield:") overall = cc['lift_and_yield'].get('overall', {}) if isinstance( cc['lift_and_yield'], dict) else {} for frac in [0.01, 0.05, 0.10]: if frac in overall: lift = overall[frac].get('lift', np.nan) yield_val = overall[frac].get('yield', np.nan) print( f" Overall Top {int(frac*100)}%: Lift={lift:.2f}, Yield={yield_val:.3f}") # Clean control results if 'clean_control' in landmark and landmark['clean_control']: clean = landmark['clean_control'] print(f" Clean-Control Track:") print(f" Patients: {clean['n_patients']}") print(f" Mean AUC: {clean['mean_auc']:.4f}") print("\n" + "=" * 80) def main(): parser = argparse.ArgumentParser( description='Evaluate longitudinal health prediction model using landmark analysis' ) parser.add_argument( '--run_dir', type=str, required=True, help='Path to run directory containing train_config.json and best_model.pt' ) parser.add_argument( '--output', type=str, default=None, help='Output path for results JSON (default: /evaluation_results.json)' ) parser.add_argument( '--out_dir', type=str, default=None, help='Directory to write CSV outputs (default: /evaluation_outputs)' ) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to run evaluation on' ) parser.add_argument( '--batch_size', type=int, default=256, help='Batch size for evaluation' ) parser.add_argument( '--num_workers', type=int, default=4, help='Number of data loader workers' ) parser.add_argument( '--no_compile', action='store_true', help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)' ) parser.add_argument( '--check_capture_at_k', action='store_true', help='Run a one-time correctness check: slow vs fast Disease-Capture@K on a small subset' ) parser.add_argument( '--profile_metrics', action='store_true', help='Print basic timings for CPU-side metric computations' ) parser.add_argument( '--capture_check_n', type=int, default=200, help='Number of samples used for the capture@K slow-vs-fast check (default: 200)' ) args = parser.parse_args() # Load model and dataset model, head, loss_fn, dataset, config, test_indices = load_model_and_config( args.run_dir, args.device) # Create evaluator evaluator = LandmarkEvaluator( model=model, head=head, loss_fn=loss_fn, dataset=dataset, eval_indices=test_indices, device=args.device, batch_size=args.batch_size, num_workers=args.num_workers, compile_model=(not args.no_compile), check_capture_at_k=args.check_capture_at_k, profile_metrics=args.profile_metrics, capture_check_n=args.capture_check_n, ) # Run evaluation print("\nStarting landmark analysis evaluation...") print(f"Age cutoffs: {evaluator.age_cutoffs}") print(f"Horizons: {evaluator.horizons}") results = evaluator.run_full_evaluation() # Add metadata results['metadata'] = { 'run_dir': args.run_dir, 'config': config, 'n_diseases': dataset.n_disease, 'device': args.device, } # Save results (CSV bundle + single JSON summary) if args.out_dir is None: args.out_dir = os.path.join(args.run_dir, 'evaluation_outputs') csv_paths = save_results_csv_bundle(results, args.out_dir) summary = { 'metadata': results.get('metadata', {}), 'age_cutoffs': results.get('age_cutoffs', []), 'horizons': results.get('horizons', []), 'csv_outputs': csv_paths, 'notes': { 'metrics': [ 'AUC (per disease + mean)', 'Brier Score / Brier Skill Score (complete-case only)', 'Disease-Capture@K (complete-case only)', 'Lift/Yield (complete-case only)', 'Decision Curve Analysis (complete-case only)', ], }, } if args.output is None: args.output = os.path.join(args.run_dir, 'evaluation_summary.json') save_summary_json(summary, args.output) print(f"\nWrote CSV outputs to: {args.out_dir}") print(f"Wrote summary JSON to: {args.output}") # Print summary print_summary(results) if __name__ == '__main__': main()