""" Landmark Analysis Evaluation Script for Longitudinal Health Prediction Models Implements the comprehensive evaluation framework defined in evaluate_design.md: - Landmark analysis at age cutoffs {50, 60, 70} - Prediction horizons {0.25, 0.5, 1, 2, 5, 10} years - Two tracks: Complete-Case (primary) and Clean Control (academic benchmark) - Metrics: AUC, Brier Score, Disease-Capture@K, Lift, Yield, DCA """ import argparse import json import os from pathlib import Path from typing import Dict, List, Tuple, Optional import warnings import numpy as np import pandas as pd import torch import torch.nn.functional as F from torch.utils.data import DataLoader, Subset from tqdm import tqdm from sklearn.metrics import roc_auc_score, brier_score_loss # Import model components from model import DelphiFork, SapDelphi, SimpleHead from dataset import HealthDataset, health_collate_fn from losses import ( ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, PiecewiseExponentialCIFNLLLoss ) warnings.filterwarnings('ignore') def _ensure_dir(path: str) -> str: os.makedirs(path, exist_ok=True) return path def _to_float(x): try: return float(x) except Exception: return np.nan def save_summary_json(summary: Dict, output_path: str) -> None: """Save a single JSON summary file.""" def convert_to_serializable(obj): if isinstance(obj, np.ndarray): return obj.tolist() if isinstance(obj, (np.integer,)): return int(obj) if isinstance(obj, (np.floating,)): return float(obj) if isinstance(obj, dict): return {k: convert_to_serializable(v) for k, v in obj.items()} if isinstance(obj, list): return [convert_to_serializable(v) for v in obj] return obj summary_serializable = convert_to_serializable(summary) with open(output_path, 'w') as f: json.dump(summary_serializable, f, indent=2) def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]: """Save evaluation results into multiple CSV files. Produces long-form tables so they are easy to analyze/plot: - landmarks_summary.csv - auc_per_disease.csv - capture_at_k.csv - lift_yield.csv - dca.csv Returns: Mapping from logical name to file path. 
""" out_dir = _ensure_dir(out_dir) summary_rows: List[Dict] = [] auc_rows: List[Dict] = [] capture_rows: List[Dict] = [] lift_rows: List[Dict] = [] dca_rows: List[Dict] = [] landmarks = results.get('landmarks', []) for lm in landmarks: age = lm.get('age_cutoff') horizon = lm.get('horizon') for track in ['complete_case', 'clean_control']: track_res = lm.get(track) or {} if not track_res: continue summary_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'n_patients': track_res.get('n_patients', np.nan), 'n_valid': track_res.get('n_valid', np.nan), 'n_valid_patients': track_res.get('n_valid_patients', np.nan), 'mean_auc': track_res.get('mean_auc', np.nan), 'brier_score': track_res.get('brier_score', np.nan), 'brier_skill_score': track_res.get('brier_skill_score', np.nan), }) auc_per_disease = track_res.get('auc_per_disease') or {} for disease_idx, auc in auc_per_disease.items(): auc_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'disease_idx': int(disease_idx), 'auc': _to_float(auc), }) if track == 'complete_case': capture = track_res.get('disease_capture_at_k') or {} for k_val, per_disease in capture.items(): k_int = int(k_val) for disease_idx, rate in (per_disease or {}).items(): capture_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'k': k_int, 'disease_idx': int(disease_idx), 'capture_rate': _to_float(rate), }) lift_yield = track_res.get('lift_and_yield') or {} overall = (lift_yield.get('overall') or {}) if isinstance( lift_yield, dict) else {} for frac, metrics in overall.items(): lift_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'level': 'overall', 'disease_idx': '', 'workload_frac': _to_float(frac), 'lift': _to_float((metrics or {}).get('lift')), 'yield': _to_float((metrics or {}).get('yield')), }) per_disease = (lift_yield.get('per_disease') or {} ) if isinstance(lift_yield, dict) else {} for disease_idx, disease_metrics in per_disease.items(): for frac, metrics in (disease_metrics or {}).items(): lift_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'level': 'per_disease', 'disease_idx': int(disease_idx), 'workload_frac': _to_float(frac), 'lift': _to_float((metrics or {}).get('lift')), 'yield': _to_float((metrics or {}).get('yield')), }) dca = track_res.get('dca') or {} thresholds = dca.get('thresholds') net_benefit = dca.get('net_benefit') if thresholds is not None and net_benefit is not None: thresholds_arr = np.asarray(thresholds, dtype=np.float64) nb_arr = np.asarray(net_benefit, dtype=np.float64) for thr, nb in zip(thresholds_arr, nb_arr): dca_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'threshold': float(thr), 'net_benefit': float(nb), }) paths: Dict[str, str] = {} df_summary = pd.DataFrame(summary_rows) summary_path = os.path.join(out_dir, 'landmarks_summary.csv') df_summary.to_csv(summary_path, index=False) paths['landmarks_summary'] = summary_path df_auc = pd.DataFrame(auc_rows) auc_path = os.path.join(out_dir, 'auc_per_disease.csv') df_auc.to_csv(auc_path, index=False) paths['auc_per_disease'] = auc_path df_capture = pd.DataFrame(capture_rows) capture_path = os.path.join(out_dir, 'capture_at_k.csv') df_capture.to_csv(capture_path, index=False) paths['capture_at_k'] = capture_path df_lift = pd.DataFrame(lift_rows) lift_path = os.path.join(out_dir, 'lift_yield.csv') df_lift.to_csv(lift_path, index=False) paths['lift_yield'] = lift_path df_dca = pd.DataFrame(dca_rows) dca_path = os.path.join(out_dir, 'dca.csv') df_dca.to_csv(dca_path, 
class LandmarkEvaluator:
    """
    Comprehensive landmark analysis evaluator for survival/competing risks models.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        head: torch.nn.Module,
        loss_fn: torch.nn.Module,
        dataset: HealthDataset,
        device: str = 'cuda',
        batch_size: int = 256,
        num_workers: int = 4,
    ):
        self.model = model.to(device)
        self.model.eval()
        self.head = head.to(device)
        self.head.eval()
        self.loss_fn = loss_fn
        self.dataset = dataset
        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Evaluation parameters from design doc
        self.age_cutoffs = [50, 60, 70]
        self.horizons = [0.25, 0.5, 1, 2, 5, 10]
        self.top_k_values = [5, 10, 20, 50]
        self.workload_fracs = [0.01, 0.05, 0.10, 0.20, 0.50]

        # Convert age to days for comparison
        self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
        self.horizons_days = [h * 365.25 for h in self.horizons]
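
    # Worked example of the conversions above: an age cutoff of 60 becomes
    # 60 * 365.25 = 21915.0 days, and a 1-year horizon ends at
    # 61 * 365.25 = 22280.25 days, so predictions are anchored at the last
    # record strictly before day 21915.0 and scored on events falling in
    # [21915.0, 22280.25).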
    @torch.no_grad()
    def compute_risk_scores(
        self,
        indices: List[int],
        age_cutoff_days: float,
        horizon_days: float,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute risk scores for specified patient indices at a given landmark.

        Args:
            indices: Patient indices to evaluate
            age_cutoff_days: Age cutoff in days (T_cut)
            horizon_days: Prediction horizon in days (H)

        Returns:
            risk_scores: (N, K) array of risk scores per disease
            t_anchors: (N,) array of anchor times (days from birth)
            valid_mask: (N,) boolean array indicating valid predictions
        """
        subset = Subset(self.dataset, indices)
        loader = DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=self.num_workers,
            pin_memory=(self.device == 'cuda'),
        )

        all_risk_scores = []
        all_t_anchors = []
        all_valid_mask = []

        for batch in loader:
            event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
            event_batch = event_batch.to(self.device)
            time_batch = time_batch.to(self.device)
            cont_batch = cont_batch.to(self.device)
            cate_batch = cate_batch.to(self.device)
            sex_batch = sex_batch.to(self.device)

            # Find anchor point: last valid record before age_cutoff.
            # Valid records are non-padding events (event >= 1).
            valid_mask = event_batch >= 1  # (B, L)
            before_cutoff = time_batch < age_cutoff_days  # (B, L)
            valid_before = valid_mask & before_cutoff  # (B, L)

            # Find the last valid position for each patient
            batch_size = event_batch.size(0)
            t_anchor = torch.zeros(batch_size, device=self.device)
            anchor_idx = torch.zeros(
                batch_size, dtype=torch.long, device=self.device)
            has_anchor = torch.zeros(
                batch_size, dtype=torch.bool, device=self.device)

            for b in range(batch_size):
                valid_positions = valid_before[b].nonzero(as_tuple=True)[0]
                if len(valid_positions) > 0:
                    last_pos = valid_positions[-1]
                    anchor_idx[b] = last_pos
                    t_anchor[b] = time_batch[b, last_pos]
                    has_anchor[b] = True

            # Get model predictions at anchor points
            if has_anchor.any():
                # Forward pass
                hidden = self.model(event_batch, time_batch, sex_batch,
                                    cont_batch, cate_batch)

                # Get predictions at anchor positions
                batch_indices = torch.arange(batch_size, device=self.device)
                hidden_at_anchor = hidden[batch_indices, anchor_idx]  # (B, n_embd)

                # Compute logits using the loaded head:
                # (B, n_disease, ...) or (B, K+1, n_bins+1) etc.
                logits = self.head(hidden_at_anchor)

                # Compute CIF scores.
                # Time gap from anchor to start of horizon
                t_start = age_cutoff_days - t_anchor  # (B,)
                # Time gap from anchor to end of horizon
                t_end = age_cutoff_days + horizon_days - t_anchor  # (B,)

                # Ensure non-negative time gaps
                t_start = torch.clamp(t_start, min=0)
                t_end = torch.clamp(t_end, min=0)

                # Calculate CIF at both time points
                cif_start = self._compute_cif(logits, t_start)  # (B, K)
                cif_end = self._compute_cif(logits, t_end)  # (B, K)

                # Risk score is the CIF increment within the horizon
                risk_scores = cif_end - cif_start  # (B, K)
                risk_scores = torch.clamp(risk_scores, min=0)  # Ensure non-negative
            else:
                # No valid anchor points in this batch
                risk_scores = torch.zeros(
                    batch_size, self.dataset.n_disease, device=self.device)

            all_risk_scores.append(risk_scores.cpu().numpy())
            all_t_anchors.append(t_anchor.cpu().numpy())
            all_valid_mask.append(has_anchor.cpu().numpy())

        # Concatenate results
        risk_scores = np.vstack(all_risk_scores)  # (N, K)
        t_anchors = np.concatenate(all_t_anchors)  # (N,)
        valid_mask = np.concatenate(all_valid_mask)  # (N,)

        return risk_scores, t_anchors, valid_mask

    def _compute_cif(self, logits: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the Cumulative Incidence Function at time t.

        Args:
            logits: Model output logits (B, K, ...) depending on loss type
            t: Time points (B,) in days from anchor (converted to years internally)

        Returns:
            cif: (B, K) cumulative incidence probabilities
        """
        t_years = t / 365.25  # Convert days to years

        if isinstance(self.loss_fn, ExponentialNLLLoss):
            # Exponential: logits are (B, K)
            lambdas = F.softplus(logits) + 1e-6  # (B, K)
            total_lambda = lambdas.sum(dim=-1, keepdim=True)  # (B, 1)
            # CIF_k(t) = (λ_k / Σλ) * (1 - exp(-Σλ * t))
            frac = lambdas / total_lambda  # (B, K)
            exp_term = 1.0 - torch.exp(-total_lambda * t_years.unsqueeze(-1))  # (B, 1)
            cif = frac * exp_term  # (B, K)
        elif isinstance(self.loss_fn, DiscreteTimeCIFNLLLoss):
            # Discrete-time CIF: use calculate_cifs method
            cif = self.loss_fn.calculate_cifs(
                logits, t_years, return_survival=False)
        elif isinstance(self.loss_fn, PiecewiseExponentialCIFNLLLoss):
            # PWE CIF: use calculate_cifs method
            cif = self.loss_fn.calculate_cifs(
                logits, t_years, return_survival=False)
        else:
            raise ValueError(f"Unknown loss type: {type(self.loss_fn)}")

        return cif
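
    # Numeric sanity check for the exponential branch above (assumed,
    # illustrative rates): with per-year cause-specific rates λ = (0.1, 0.3),
    # Σλ = 0.4 and t = 2 years give 1 - exp(-0.8) ≈ 0.5507, hence
    # CIF_1(2) ≈ 0.25 * 0.5507 ≈ 0.138 and CIF_2(2) ≈ 0.75 * 0.5507 ≈ 0.413.
    # The per-cause CIFs always sum to the all-cause event probability.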
    def prepare_evaluation_cohort(
        self,
        age_cutoff_days: float,
        horizon_days: float,
        track: str = 'complete_case',
    ) -> Tuple[List[int], np.ndarray, np.ndarray]:
        """Prepare evaluation cohort per design protocol with per-disease validity.

        Key fixes vs the earlier implementation:
        - Multi-disease labeling: mark *all* diseases that occur within the horizon.
        - Per-disease validity mask: mask[i, k]=1 if patient i is valid for disease k.

        Returns:
            indices: list of patient indices included in this cohort table
            labels: (N, K) float array
            valid_mask: (N, K) bool array
        """
        n_tech_tokens = 2  # PAD=0, DOA=1
        K = int(self.dataset.n_disease)
        # Competing risk token: treat death as a competing event.
        # As requested: DEATH_CODE is the last disease index.
        DEATH_CODE = int(K - 1)

        horizon_end_days = age_cutoff_days + horizon_days

        indices: List[int] = []
        labels_rows: List[np.ndarray] = []
        valid_rows: List[np.ndarray] = []

        for idx in range(len(self.dataset)):
            patient_id = self.dataset.patient_ids[idx]
            records = self.dataset.patient_events.get(patient_id, [])
            if not records:
                continue

            # Must have some information strictly prior to the cutoff
            # (anchor existence in the data).
            has_pre_cutoff = any(t < age_cutoff_days for t, _ in records)
            if not has_pre_cutoff:
                continue

            # Events after cutoff (already sorted in dataset init)
            events_after = [(t, e) for t, e in records if t >= age_cutoff_days]

            labels = np.zeros(K, dtype=np.float32)
            valid = np.zeros(K, dtype=bool)

            # Identify diseases within horizon (multi-label; no early break)
            diseases_in_horizon: set[int] = set()
            death_in_horizon = False
            for t, e in events_after:
                if t > horizon_end_days:
                    break  # events are time-sorted
                if e < n_tech_tokens:
                    continue
                disease_idx = int(e - n_tech_tokens)
                if 0 <= disease_idx < K:
                    diseases_in_horizon.add(disease_idx)
                    if disease_idx == DEATH_CODE:
                        death_in_horizon = True

            for d in diseases_in_horizon:
                labels[d] = 1.0

            last_time = float(records[-1][0])

            if track == 'complete_case':
                # Track A: Complete-Case at Horizon
                # - Hit: disease occurs within horizon => valid positive for that disease
                # - Healthy: last record > horizon_end => valid negative for all diseases
                # - Death: death within horizon => valid negative for diseases that did
                #   not occur before death (we implement as valid for all diseases, with
                #   labels marking any hits, including death)
                # - Loss: censored within horizon (no hit for disease) => invalid for
                #   that disease
                if last_time > horizon_end_days:
                    # Observed past horizon: negative for all non-hit diseases
                    valid[:] = True
                elif death_in_horizon:
                    # Competing risk: include as explicit negatives (label 0) for
                    # diseases not hit and positives for diseases hit before death
                    # (we don't model ordering here). This matches the requirement:
                    # death within the window is label=0 for target diseases.
                    valid[:] = True
                else:
                    # Lost within horizon: only diseases that actually occurred in
                    # the horizon are valid positives (no assumptions about negatives).
                    for d in diseases_in_horizon:
                        valid[d] = True
            elif track == 'clean_control':
                # Track B: Clean Control (per-disease)
                # For each disease k:
                # - Hit: disease k occurs within horizon => label=1, valid
                # - Pure clean for k: disease k never occurs in entire record => label=0, valid
                # - Death within window => drop (invalid) for all k
                # - Loss within window => drop (invalid) for all k
                # - Late onset of k (after horizon) => drop (invalid) for k
                if death_in_horizon:
                    # Drop all diseases for this patient
                    valid[:] = False
                else:
                    # Precompute per-disease first occurrence time after cutoff
                    # (including after horizon)
                    first_occ_after_cutoff = np.full(K, np.inf, dtype=np.float64)
                    ever_has_disease = np.zeros(K, dtype=bool)
                    for t, e in records:
                        if e < n_tech_tokens:
                            continue
                        disease_idx = int(e - n_tech_tokens)
                        if 0 <= disease_idx < K:
                            ever_has_disease[disease_idx] = True
                            if t >= age_cutoff_days and t < first_occ_after_cutoff[disease_idx]:
                                first_occ_after_cutoff[disease_idx] = float(t)

                    # If lost within the horizon and disease k was not observed in
                    # the horizon, the patient is invalid for k.
                    lost_within_horizon = last_time <= horizon_end_days
                    for k in range(K):
                        if labels[k] == 1.0:
                            valid[k] = True
                            continue
                        if not ever_has_disease[k]:  # Lifetime clean for k
                            # Still need complete follow-up within the horizon for
                            # the clean-control track: if censored within the horizon,
                            # we cannot be sure k didn't occur in the window.
                            valid[k] = not lost_within_horizon
                            continue
                        # Has disease k at some point.
                        t_first = first_occ_after_cutoff[k]
                        if np.isfinite(t_first) and t_first > horizon_end_days:
                            # Late onset of k -> invalid for k
                            valid[k] = False
                        else:
                            # Prevalent before the cutoff (no onset after cutoff):
                            # invalid for k.
                            valid[k] = False
            else:
                raise ValueError(f"Unknown track: {track}")

            # Keep patient row if they are valid for at least one disease.
            if valid.any():
                indices.append(idx)
                labels_rows.append(labels)
                valid_rows.append(valid)

        if not indices:
            return [], np.zeros((0, K), dtype=np.float32), np.zeros((0, K), dtype=bool)

        return indices, np.stack(labels_rows, axis=0), np.stack(valid_rows, axis=0)

    def compute_auc_per_disease(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[int, float]:
        """
        Compute time-dependent AUC for each disease.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask for valid evaluations

        Returns:
            auc_scores: Dict mapping disease_idx -> AUC
        """
        auc_scores = {}
        n_diseases = risk_scores.shape[1]

        for k in range(n_diseases):
            mk = valid_mask[:, k]
            if not np.any(mk):
                auc_scores[k] = np.nan
                continue

            y_true = labels[mk, k]
            y_score = risk_scores[mk, k]

            # Check that we have both classes
            if len(np.unique(y_true)) < 2:
                auc_scores[k] = np.nan
            else:
                try:
                    auc = roc_auc_score(y_true, y_score)
                    auc_scores[k] = auc
                except Exception:
                    auc_scores[k] = np.nan

        return auc_scores

    def compute_brier_score(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[str, float]:
        """
        Compute Brier Score and Brier Skill Score.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            metrics: Dict with 'brier_score' and 'brier_skill_score'
        """
        # Apply per-entry valid mask and flatten
        m = valid_mask.astype(bool)
        y_true_flat = labels[m]
        y_pred_flat = risk_scores[m]

        mask = ~(np.isnan(y_true_flat) | np.isnan(y_pred_flat))
        y_true_flat = y_true_flat[mask]
        y_pred_flat = y_pred_flat[mask]

        if len(y_true_flat) == 0:
            return {'brier_score': np.nan, 'brier_skill_score': np.nan}

        bs = brier_score_loss(y_true_flat, y_pred_flat)

        # Brier Skill Score: reference is predicting the mean
        p_mean = y_true_flat.mean()
        bs_ref = ((p_mean - y_true_flat) ** 2).mean()
        bss = 1.0 - (bs / bs_ref) if bs_ref > 0 else 0.0

        return {'brier_score': bs, 'brier_skill_score': bss}
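
    # Example of the skill score above: with base rate 0.1 the reference
    # Brier score is 0.1 * 0.9 = 0.09, so a model with BS = 0.072 scores
    # BSS = 1 - 0.072 / 0.09 = 0.2, i.e. 20% better than always predicting
    # the mean event rate.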
    def compute_disease_capture_at_k(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[int, Dict[int, float]]:
        """
        Compute Disease-Capture@K: the fraction of true positives for which the
        true disease appears among the patient's top-K predicted risks.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            capture_rates: Dict[K_value][disease_idx] -> capture rate
        """
        capture_rates = {k: {} for k in self.top_k_values}
        n_diseases = risk_scores.shape[1]

        for disease_idx in range(n_diseases):
            mk = valid_mask[:, disease_idx]
            if not np.any(mk):
                for k in self.top_k_values:
                    capture_rates[k][disease_idx] = np.nan
                continue

            y_true = labels[mk, disease_idx]
            y_scores = risk_scores[mk]  # (N_valid_k, K)

            # Find patients with a positive label for this disease
            pos_mask = y_true == 1
            if pos_mask.sum() == 0:
                for k in self.top_k_values:
                    capture_rates[k][disease_idx] = np.nan
                continue

            # For each positive patient, check if the true disease is in the top-K
            for k_val in self.top_k_values:
                captures = []
                for i in np.where(pos_mask)[0]:
                    # Get top-K disease indices for this patient
                    top_k_diseases = np.argsort(y_scores[i])[::-1][:k_val]
                    # Check if the true disease is in the top-K
                    is_captured = disease_idx in top_k_diseases
                    captures.append(int(is_captured))

                capture_rate = np.mean(captures) if captures else np.nan
                capture_rates[k_val][disease_idx] = capture_rate

        return capture_rates
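
    # Capture@K example: if a patient's true disease carries the 7th-highest
    # predicted risk among all diseases, it is missed at K=5 but captured at
    # K=10, 20, and 50; the reported rate averages this indicator over all
    # positive patients for the disease.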
    def compute_lift_and_yield(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[str, Dict]:
        """
        Compute Lift and Yield at various workload fractions.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            metrics: {
                'overall': {workload_frac: {'lift': ..., 'yield': ...}, ...},
                'per_disease': {disease_idx: {workload_frac: {'lift': ..., 'yield': ...}, ...}, ...}
            }
        """
        # The overall metric uses a patient-level mask to avoid including purely
        # censored negatives. We include patients who either (a) have known
        # outcomes for all diseases, or (b) have at least one hit.
        has_any_hit = labels.max(axis=1) > 0
        has_all_known = valid_mask.all(axis=1)
        overall_patient_mask = has_any_hit | has_all_known

        risk_scores_overall = risk_scores[overall_patient_mask]
        labels_overall = labels[overall_patient_mask]

        # Flatten to patient level: any disease event.
        # Max risk across all diseases for each patient.
        max_risk_per_patient = risk_scores_overall.max(axis=1)
        any_disease_label = labels_overall.max(axis=1)

        n_patients = len(max_risk_per_patient)
        base_rate = any_disease_label.mean() if n_patients > 0 else 0.0

        overall: Dict[float, Dict[str, float]] = {}
        for workload_frac in self.workload_fracs:
            n_screen = max(1, int(n_patients * workload_frac))
            top_n_idx = np.argsort(max_risk_per_patient)[::-1][:n_screen]
            top_n_labels = any_disease_label[top_n_idx]

            yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
            lift_val = (yield_val / float(base_rate)) if base_rate > 0 else 0.0

            overall[workload_frac] = {'lift': lift_val, 'yield': yield_val}

        per_disease: Dict[int, Dict[float, Dict[str, float]]] = {}
        n_diseases = risk_scores.shape[1]
        for disease_idx in range(n_diseases):
            mk = valid_mask[:, disease_idx]
            if not np.any(mk):
                per_disease[disease_idx] = {
                    frac: {'lift': np.nan, 'yield': np.nan}
                    for frac in self.workload_fracs}
                continue

            disease_scores = risk_scores[mk, disease_idx]
            disease_labels = labels[mk, disease_idx]
            disease_base_rate = disease_labels.mean() if disease_labels.size > 0 else 0.0
            n_patients_k = disease_scores.shape[0]

            disease_metrics: Dict[float, Dict[str, float]] = {}
            for workload_frac in self.workload_fracs:
                n_screen = max(1, int(n_patients_k * workload_frac))
                top_n_idx = np.argsort(disease_scores)[::-1][:n_screen]
                top_n_labels = disease_labels[top_n_idx]

                yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
                lift_val = (yield_val / float(disease_base_rate)
                            ) if disease_base_rate > 0 else 0.0
                disease_metrics[workload_frac] = {
                    'lift': lift_val, 'yield': yield_val}

            per_disease[disease_idx] = disease_metrics

        return {
            'overall': overall,
            'per_disease': per_disease,
        }

    def compute_dca_net_benefit(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
        threshold_range: Optional[np.ndarray] = None,
    ) -> Dict[str, np.ndarray]:
        """
        Compute Decision Curve Analysis (DCA) net benefit.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask
            threshold_range: Array of threshold probabilities
                (defaults to np.linspace(0, 0.5, 51))

        Returns:
            dca_results: Dict with 'thresholds' and 'net_benefit' arrays
        """
        if threshold_range is None:
            threshold_range = np.linspace(0, 0.5, 51)

        # Use the same overall patient mask as lift/yield (complete-case style)
        has_any_hit = labels.max(axis=1) > 0
        has_all_known = valid_mask.all(axis=1)
        patient_mask = has_any_hit | has_all_known

        risk_scores = risk_scores[patient_mask]
        labels = labels[patient_mask]

        # Use max risk and any-disease label
        max_risk = risk_scores.max(axis=1)
        any_disease = labels.max(axis=1)
        n = len(max_risk)
        if n == 0:
            return {
                'thresholds': threshold_range,
                'net_benefit': np.full_like(threshold_range, np.nan),
            }

        net_benefits = []
        for pt in threshold_range:
            if pt == 0:
                # Treat all
                nb = any_disease.mean()
            else:
                # Treat if predicted risk >= threshold
                treat = max_risk >= pt
                tp = (treat & (any_disease == 1)).sum()
                fp = (treat & (any_disease == 0)).sum()
                # Net benefit = (TP/N) - (FP/N) * (pt / (1 - pt))
                nb = (tp / n) - (fp / n) * (pt / (1 - pt))
            net_benefits.append(nb)

        return {
            'thresholds': threshold_range,
            'net_benefit': np.array(net_benefits),
        }
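
    # Net-benefit arithmetic example for the formula above: at threshold
    # p_t = 0.1 with TP/N = 0.08 and FP/N = 0.30,
    # NB = 0.08 - 0.30 * (0.1 / 0.9) ≈ 0.047, so screening at that threshold
    # beats "treat none" (NB = 0).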
    def evaluate_landmark(
        self,
        age_cutoff: float,
        horizon: float,
    ) -> Dict:
        """
        Evaluate the model at a specific landmark (age_cutoff, horizon).

        Args:
            age_cutoff: Age cutoff in years
            horizon: Prediction horizon in years

        Returns:
            results: Dictionary with all metrics
        """
        age_cutoff_days = age_cutoff * 365.25
        horizon_days = horizon * 365.25

        print(f"\nEvaluating Landmark: Age={age_cutoff}, Horizon={horizon}y")

        results = {
            'age_cutoff': age_cutoff,
            'horizon': horizon,
            'complete_case': {},
            'clean_control': {},
        }

        for track in ['complete_case', 'clean_control']:
            print(f"  Track: {track}")

            # Prepare cohort
            indices, labels_array, valid_mask = self.prepare_evaluation_cohort(
                age_cutoff_days, horizon_days, track
            )

            if len(indices) == 0:
                print(f"    No valid patients for track {track}")
                continue

            print(f"    Cohort size: {len(indices)}")

            # Compute risk scores
            risk_scores, t_anchors, anchor_mask = self.compute_risk_scores(
                indices, age_cutoff_days, horizon_days
            )

            # Combine anchor availability with per-disease validity
            valid_mask = valid_mask & anchor_mask.astype(bool)[:, None]

            # Compute metrics
            print("    Computing AUC...")
            auc_scores = self.compute_auc_per_disease(
                risk_scores, labels_array, valid_mask)
            mean_auc = np.nanmean(list(auc_scores.values()))

            print("    Computing Brier Score...")
            brier_metrics = self.compute_brier_score(
                risk_scores, labels_array, valid_mask)

            # Only compute patient-level and population metrics for complete_case
            if track == 'complete_case':
                print("    Computing Disease-Capture@K...")
                capture_metrics = self.compute_disease_capture_at_k(
                    risk_scores, labels_array, valid_mask
                )

                print("    Computing Lift & Yield...")
                lift_yield_metrics = self.compute_lift_and_yield(
                    risk_scores, labels_array, valid_mask
                )

                print("    Computing DCA...")
                dca_metrics = self.compute_dca_net_benefit(
                    risk_scores, labels_array, valid_mask
                )

                results[track] = {
                    'n_patients': len(indices),
                    'n_valid': int(valid_mask.sum()),
                    'n_valid_patients': int(valid_mask.any(axis=1).sum()),
                    'auc_per_disease': auc_scores,
                    'mean_auc': mean_auc,
                    'brier_score': brier_metrics['brier_score'],
                    'brier_skill_score': brier_metrics['brier_skill_score'],
                    'disease_capture_at_k': capture_metrics,
                    'lift_and_yield': lift_yield_metrics,
                    'dca': dca_metrics,
                }
            else:
                # Clean-control track: only discrimination metrics
                results[track] = {
                    'n_patients': len(indices),
                    'n_valid': int(valid_mask.sum()),
                    'n_valid_patients': int(valid_mask.any(axis=1).sum()),
                    'auc_per_disease': auc_scores,
                    'mean_auc': mean_auc,
                }

        return results

    def run_full_evaluation(self) -> Dict:
        """
        Run the complete landmark analysis across all cutoffs and horizons.

        Returns:
            all_results: Nested dictionary with all evaluation results
        """
        all_results = {
            'age_cutoffs': self.age_cutoffs,
            'horizons': self.horizons,
            'landmarks': [],
        }

        # Evaluate each landmark
        for age_cutoff in self.age_cutoffs:
            for horizon in self.horizons:
                landmark_results = self.evaluate_landmark(age_cutoff, horizon)
                all_results['landmarks'].append(landmark_results)

        return all_results
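
# Illustrative helper (hypothetical; not called by the CLI below): flatten
# mean AUC across the landmark grid returned by
# `LandmarkEvaluator.run_full_evaluation` into a tidy DataFrame for plotting.
def _example_auc_grid(results: Dict) -> pd.DataFrame:
    rows = []
    for lm in results.get('landmarks', []):
        for track in ('complete_case', 'clean_control'):
            track_res = lm.get(track) or {}
            if track_res:
                rows.append({
                    'age_cutoff': lm['age_cutoff'],
                    'horizon': lm['horizon'],
                    'track': track,
                    'mean_auc': track_res.get('mean_auc', np.nan),
                })
    return pd.DataFrame(rows)
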
def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
    """
    Load trained model and configuration from a run directory.

    Args:
        run_dir: Path to run directory containing train_config.json and best_model.pt
        device: Device to load model on

    Returns:
        model, head, loss_fn, dataset, config
    """
    run_path = Path(run_dir)

    # Load config
    config_path = run_path / 'train_config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)

    print(f"Loading model from {run_dir}")
    print(f"Model type: {config['model_type']}")
    print(f"Loss type: {config['loss_type']}")

    # Load dataset to get dimensions
    data_prefix = config['data_prefix']

    # Determine covariate list based on full_cov
    if config['full_cov']:
        covariate_list = None  # Use all covariates
    else:
        # Use partial covariates (define your partial list here)
        covariate_list = ['age_at_assessment', 'bmi', 'smoking_status']  # Example

    dataset = HealthDataset(
        data_prefix=f"{data_prefix}_test",
        covariate_list=covariate_list,
        cache_event_tensors=True,
    )

    # Determine output dimensions based on loss type
    if config['loss_type'] == 'exponential':
        out_dims = [dataset.n_disease]
    elif config['loss_type'] == 'discrete_time_cif':
        # logits shape (M, K+1, n_bins+1)
        bin_edges = config.get(
            'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')])
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif config['loss_type'] == 'pwe_cif':
        # Piecewise-exponential requires finite edges
        bin_edges = config.get(
            'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0])
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        n_bins = len(pwe_edges) - 1
        # logits shape (M, K, n_bins)
        out_dims = [dataset.n_disease, n_bins]
    else:
        raise ValueError(f"Unknown loss type: {config['loss_type']}")

    # Build model
    if config['model_type'] == 'delphi_fork':
        model = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,  # PAD=0, DOA=1
            n_embd=config['n_embd'],
            n_head=config['n_head'],
            n_layer=config['n_layer'],
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            age_encoder_type=config['age_encoder'],
            pdrop=config['pdrop'],
        )
    elif config['model_type'] == 'sap_delphi':
        model = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=config['n_embd'],
            n_head=config['n_head'],
            n_layer=config['n_layer'],
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            age_encoder_type=config['age_encoder'],
            pdrop=config['pdrop'],
            pretrained_weights_path=config.get('pretrained_emd_path'),
        )
    else:
        raise ValueError(f"Unknown model type: {config['model_type']}")

    # Build head
    head = SimpleHead(
        n_embd=config['n_embd'],
        out_dims=out_dims,
    )

    # Load model weights (checkpoint contains model and head state dicts)
    model_path = run_path / 'best_model.pt'
    checkpoint = torch.load(model_path, map_location=device)

    # The checkpoint is a dict with 'model_state_dict' and 'head_state_dict'
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        head.load_state_dict(checkpoint['head_state_dict'])
        print("Loaded model and head from checkpoint")
    else:
        raise ValueError(
            "Checkpoint format not recognized. "
            "Expected 'model_state_dict' and 'head_state_dict' keys.")
    model.to(device)
    head.to(device)

    # Build loss function
    if config['loss_type'] == 'exponential':
        loss_fn = ExponentialNLLLoss(
            lambda_reg=config.get('lambda_reg', 0.0)
        )
    elif config['loss_type'] == 'discrete_time_cif':
        loss_fn = DiscreteTimeCIFNLLLoss(
            bin_edges=config.get(
                'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]),
            lambda_reg=config.get('lambda_reg', 0.0),
        )
    elif config['loss_type'] == 'pwe_cif':
        loss_fn = PiecewiseExponentialCIFNLLLoss(
            bin_edges=config.get(
                'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]),
            lambda_reg=config.get('lambda_reg', 0.0),
        )
    else:
        raise ValueError(f"Unknown loss type: {config['loss_type']}")

    return model, head, loss_fn, dataset, config


def print_summary(results: Dict):
    """
    Print a summary of evaluation results.

    Args:
        results: Results dictionary
    """
    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY")
    print("=" * 80)

    for landmark in results['landmarks']:
        age = landmark['age_cutoff']
        horizon = landmark['horizon']

        print(f"\nLandmark: Age {age}, Horizon {horizon}y")
        print("-" * 40)

        # Complete-case results
        if 'complete_case' in landmark and landmark['complete_case']:
            cc = landmark['complete_case']
            print("  Complete-Case Track:")
            print(f"    Patients: {cc['n_patients']}")
            print(f"    Mean AUC: {cc['mean_auc']:.4f}")
            print(f"    Brier Score: {cc['brier_score']:.4f}")
            print(f"    Brier Skill Score: {cc['brier_skill_score']:.4f}")

            # Show top-K capture rates (average across diseases)
            if 'disease_capture_at_k' in cc:
                print("    Disease Capture:")
                for k in [5, 10, 20, 50]:
                    if k in cc['disease_capture_at_k']:
                        rates = list(cc['disease_capture_at_k'][k].values())
                        mean_rate = np.nanmean(rates)
                        print(f"      Top-{k}: {mean_rate:.3f}")

            # Show lift and yield
            if 'lift_and_yield' in cc:
                print("    Lift & Yield:")
                overall = cc['lift_and_yield'].get('overall', {}) if isinstance(
                    cc['lift_and_yield'], dict) else {}
                for frac in [0.01, 0.05, 0.10]:
                    if frac in overall:
                        lift = overall[frac].get('lift', np.nan)
                        yield_val = overall[frac].get('yield', np.nan)
                        print(
                            f"      Overall Top {int(frac * 100)}%: "
                            f"Lift={lift:.2f}, Yield={yield_val:.3f}")

        # Clean-control results
        if 'clean_control' in landmark and landmark['clean_control']:
            clean = landmark['clean_control']
            print("  Clean-Control Track:")
            print(f"    Patients: {clean['n_patients']}")
            print(f"    Mean AUC: {clean['mean_auc']:.4f}")

    print("\n" + "=" * 80)
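
# Example invocation (assuming this file is saved as evaluate.py; the run
# directory must contain train_config.json and best_model.pt):
#
#   python evaluate.py --run_dir runs/exp1 --device cuda --batch_size 256
#
# By default the CSV tables are written to <run_dir>/evaluation_outputs/ and
# the JSON summary to <run_dir>/evaluation_summary.json.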
def main():
    parser = argparse.ArgumentParser(
        description='Evaluate longitudinal health prediction model using landmark analysis'
    )
    parser.add_argument(
        '--run_dir', type=str, required=True,
        help='Path to run directory containing train_config.json and best_model.pt'
    )
    parser.add_argument(
        '--output', type=str, default=None,
        help='Output path for the summary JSON (default: <run_dir>/evaluation_summary.json)'
    )
    parser.add_argument(
        '--out_dir', type=str, default=None,
        help='Directory to write CSV outputs (default: <run_dir>/evaluation_outputs)'
    )
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help='Device to run evaluation on'
    )
    parser.add_argument(
        '--batch_size', type=int, default=256,
        help='Batch size for evaluation'
    )
    parser.add_argument(
        '--num_workers', type=int, default=4,
        help='Number of data loader workers'
    )

    args = parser.parse_args()

    # Load model and dataset
    model, head, loss_fn, dataset, config = load_model_and_config(
        args.run_dir, args.device)

    # Create evaluator
    evaluator = LandmarkEvaluator(
        model=model,
        head=head,
        loss_fn=loss_fn,
        dataset=dataset,
        device=args.device,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Run evaluation
    print("\nStarting landmark analysis evaluation...")
    print(f"Age cutoffs: {evaluator.age_cutoffs}")
    print(f"Horizons: {evaluator.horizons}")

    results = evaluator.run_full_evaluation()

    # Add metadata
    results['metadata'] = {
        'run_dir': args.run_dir,
        'config': config,
        'n_diseases': dataset.n_disease,
        'device': args.device,
    }

    # Save results (CSV bundle + single JSON summary)
    if args.out_dir is None:
        args.out_dir = os.path.join(args.run_dir, 'evaluation_outputs')
    csv_paths = save_results_csv_bundle(results, args.out_dir)

    summary = {
        'metadata': results.get('metadata', {}),
        'age_cutoffs': results.get('age_cutoffs', []),
        'horizons': results.get('horizons', []),
        'csv_outputs': csv_paths,
        'notes': {
            'metrics': [
                'AUC (per disease + mean)',
                'Brier Score / Brier Skill Score (complete-case only)',
                'Disease-Capture@K (complete-case only)',
                'Lift/Yield (complete-case only)',
                'Decision Curve Analysis (complete-case only)',
            ],
        },
    }
    if args.output is None:
        args.output = os.path.join(args.run_dir, 'evaluation_summary.json')
    save_summary_json(summary, args.output)

    print(f"\nWrote CSV outputs to: {args.out_dir}")
    print(f"Wrote summary JSON to: {args.output}")

    # Print summary
    print_summary(results)


if __name__ == '__main__':
    main()