"""
|
|
Landmark Analysis Evaluation Script for Longitudinal Health Prediction Models
|
|
|
|
Implements the comprehensive evaluation framework defined in evaluate_design.md:
|
|
- Landmark analysis at age cutoffs {50, 60, 70}
|
|
- Prediction horizons {0.25, 0.5, 1, 2, 5, 10} years
|
|
- Two tracks: Complete-Case (primary) and Clean Control (academic benchmark)
|
|
- Metrics: AUC, Brier Score, Disease-Capture@K, Lift, Yield, DCA
|
|
"""
|
|
|
|
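# Example invocation (the script filename and run directory below are illustrative;
# the flags are the ones defined in main() at the bottom of this file):
#
#   python evaluate.py --run_dir runs/exp1 --device cuda --batch_size 256 --no_compile
#
# Outputs: a CSV bundle under <run_dir>/evaluation_outputs/ and a single JSON summary
# at <run_dir>/evaluation_summary.json (see save_results_csv_bundle / save_summary_json).
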
import argparse
import json
import os
import time
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, brier_score_loss

# Import model components
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from losses import (
    ExponentialNLLLoss,
    DiscreteTimeCIFNLLLoss,
    PiecewiseExponentialCIFNLLLoss,
)

warnings.filterwarnings('ignore')

def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
    """Best-effort torch.compile() wrapper (PyTorch 2.x).

    Notes:
    - Some PyTorch builds run compiled graphs via CUDA Graphs in certain modes.
      If you keep references to graph outputs across steps, PyTorch may raise:
      "accessing tensor output of CUDAGraphs that has been overwritten".
    - We default to settings that avoid cudagraph output-lifetime pitfalls.
    """
    if not enabled:
        return module
    try:
        torch_compile = getattr(torch, "compile", None)
        if torch_compile is None:
            return module
        # Prefer a safer mode for evaluation code; best-effort disable cudagraphs.
        kwargs = {"mode": "default", "options": {"triton.cudagraphs": False}}
        return torch_compile(module, **kwargs)
    except Exception:
        return module


def _maybe_cudagraph_mark_step_begin() -> None:
    """Best-effort step marker for CUDA Graphs compiled execution."""
    try:
        compiler_mod = getattr(torch, "compiler", None)
        if compiler_mod is None:
            return
        mark = getattr(compiler_mod, "cudagraph_mark_step_begin", None)
        if mark is None:
            return
        mark()
    except Exception:
        return

def _ensure_dir(path: str) -> str:
    os.makedirs(path, exist_ok=True)
    return path


def _to_float(x):
    try:
        return float(x)
    except Exception:
        return np.nan

def compute_disease_capture_at_k_fast(
    y_true: np.ndarray,
    y_scores: np.ndarray,
    valid_mask: np.ndarray,
    top_k_list: List[int],
) -> Dict[int, Dict[int, float]]:
    """Vectorized Disease-Capture@K.

    Definition: for each disease d, among valid positives (y_true == 1 and valid_mask),
    capture@K is the fraction whose predicted top-K diseases contain d.

    This implementation avoids per-positive full argsorts by computing top-k_max once
    per sample (using argpartition), sorting those k_max indices by score, and then
    aggregating hits via bincount.
    """
    if y_scores.ndim != 2:
        raise ValueError(f"Expected y_scores 2D (N, K), got shape={y_scores.shape}")
    if y_true.shape != y_scores.shape or valid_mask.shape != y_scores.shape:
        raise ValueError(
            f"Shape mismatch: y_true={y_true.shape}, y_scores={y_scores.shape}, "
            f"valid_mask={valid_mask.shape}")

    N, K = y_scores.shape
    top_k_list = sorted({int(k) for k in top_k_list if int(k) > 0})
    capture_rates: Dict[int, Dict[int, float]] = {int(k): {} for k in top_k_list}
    if N == 0 or K == 0 or len(top_k_list) == 0:
        return capture_rates

    topk_max = min(max(top_k_list), K)

    # Valid positives per disease are the denominator.
    pos_valid = (y_true == 1) & valid_mask.astype(bool)
    denom = pos_valid.sum(axis=0).astype(np.int64)  # (K,)

    # Compute top-k_max indices per sample once (unordered), then sort those indices by score.
    part = np.argpartition(y_scores, -topk_max, axis=1)[:, -topk_max:]  # (N, topk_max)
    part_scores = np.take_along_axis(y_scores, part, axis=1)
    order = np.argsort(part_scores, axis=1)[:, ::-1]
    topk_idx = np.take_along_axis(part, order, axis=1).astype(np.int32)  # (N, topk_max)

    rows = np.arange(N)[:, None]

    for k_val in top_k_list:
        k_eff = min(int(k_val), topk_max)
        idx_k = topk_idx[:, :k_eff]  # (N, k_eff)

        # For each sample, we count a "hit" for disease d when:
        #   d is in the top-K (true by construction for idx_k)
        #   AND the sample is a valid positive for disease d.
        hits_mask = pos_valid[rows, idx_k]  # (N, k_eff) bool
        hit_diseases = idx_k[hits_mask]
        hits = np.bincount(hit_diseases, minlength=K).astype(np.int64)

        # Convert to dict with NaNs for diseases with no valid positives.
        out_k: Dict[int, float] = {}
        with np.errstate(divide='ignore', invalid='ignore'):
            frac = hits / denom
        for d in range(K):
            out_k[d] = float(frac[d]) if denom[d] > 0 else float('nan')
        capture_rates[int(k_val)] = out_k

    return capture_rates

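# Illustrative sanity check for compute_disease_capture_at_k_fast (hypothetical toy
# values, not executed by this script):
#
#   scores = np.array([[0.9, 0.1, 0.4],
#                      [0.2, 0.8, 0.3]])              # (N=2 patients, K=3 diseases)
#   y      = np.array([[1, 0, 0],
#                      [0, 1, 1]], dtype=float)       # valid positives: d0 (pat 0), d1+d2 (pat 1)
#   valid  = np.ones_like(y, dtype=bool)
#   rates  = compute_disease_capture_at_k_fast(y, scores, valid, top_k_list=[1, 2])
#   # rates[1] == {0: 1.0, 1: 1.0, 2: 0.0}   (disease 2 only reaches rank 2 for patient 1)
#   # rates[2] == {0: 1.0, 1: 1.0, 2: 1.0}
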
def save_summary_json(summary: Dict, output_path: str) -> None:
    """Save a single JSON summary file."""

    def convert_to_serializable(obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.integer,)):
            return int(obj)
        if isinstance(obj, (np.floating,)):
            return float(obj)
        if isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [convert_to_serializable(v) for v in obj]
        return obj

    summary_serializable = convert_to_serializable(summary)
    with open(output_path, 'w') as f:
        json.dump(summary_serializable, f, indent=2)

def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]:
    """Save evaluation results into multiple CSV files.

    Produces long-form tables so they are easy to analyze/plot:
    - landmarks_summary.csv
    - auc_per_disease.csv
    - capture_at_k.csv
    - lift_yield.csv
    - dca.csv

    Returns:
        Mapping from logical name to file path.
    """
    out_dir = _ensure_dir(out_dir)

    summary_rows: List[Dict] = []
    auc_rows: List[Dict] = []
    capture_rows: List[Dict] = []
    lift_rows: List[Dict] = []
    dca_rows: List[Dict] = []

    landmarks = results.get('landmarks', [])
    for lm in landmarks:
        age = lm.get('age_cutoff')
        horizon = lm.get('horizon')
        for track in ['complete_case', 'clean_control']:
            track_res = lm.get(track) or {}
            if not track_res:
                continue

            summary_rows.append({
                'age_cutoff': age,
                'horizon': horizon,
                'track': track,
                'n_patients': track_res.get('n_patients', np.nan),
                'n_valid': track_res.get('n_valid', np.nan),
                'n_valid_patients': track_res.get('n_valid_patients', np.nan),
                'mean_auc': track_res.get('mean_auc', np.nan),
                'brier_score': track_res.get('brier_score', np.nan),
                'brier_skill_score': track_res.get('brier_skill_score', np.nan),
            })

            auc_per_disease = track_res.get('auc_per_disease') or {}
            for disease_idx, auc in auc_per_disease.items():
                auc_rows.append({
                    'age_cutoff': age,
                    'horizon': horizon,
                    'track': track,
                    'disease_idx': int(disease_idx),
                    'auc': _to_float(auc),
                })

            if track == 'complete_case':
                capture = track_res.get('disease_capture_at_k') or {}
                for k_val, per_disease in capture.items():
                    k_int = int(k_val)
                    for disease_idx, rate in (per_disease or {}).items():
                        capture_rows.append({
                            'age_cutoff': age,
                            'horizon': horizon,
                            'track': track,
                            'k': k_int,
                            'disease_idx': int(disease_idx),
                            'capture_rate': _to_float(rate),
                        })

            lift_yield = track_res.get('lift_and_yield') or {}
            overall = (lift_yield.get('overall') or {}) if isinstance(lift_yield, dict) else {}
            for frac, metrics in overall.items():
                lift_rows.append({
                    'age_cutoff': age,
                    'horizon': horizon,
                    'track': track,
                    'level': 'overall',
                    'disease_idx': '',
                    'workload_frac': _to_float(frac),
                    'lift': _to_float((metrics or {}).get('lift')),
                    'yield': _to_float((metrics or {}).get('yield')),
                })

            per_disease = (lift_yield.get('per_disease') or {}) if isinstance(lift_yield, dict) else {}
            for disease_idx, disease_metrics in per_disease.items():
                for frac, metrics in (disease_metrics or {}).items():
                    lift_rows.append({
                        'age_cutoff': age,
                        'horizon': horizon,
                        'track': track,
                        'level': 'per_disease',
                        'disease_idx': int(disease_idx),
                        'workload_frac': _to_float(frac),
                        'lift': _to_float((metrics or {}).get('lift')),
                        'yield': _to_float((metrics or {}).get('yield')),
                    })

            dca = track_res.get('dca') or {}
            thresholds = dca.get('thresholds')
            net_benefit = dca.get('net_benefit')
            if thresholds is not None and net_benefit is not None:
                thresholds_arr = np.asarray(thresholds, dtype=np.float64)
                nb_arr = np.asarray(net_benefit, dtype=np.float64)
                for thr, nb in zip(thresholds_arr, nb_arr):
                    dca_rows.append({
                        'age_cutoff': age,
                        'horizon': horizon,
                        'track': track,
                        'threshold': float(thr),
                        'net_benefit': float(nb),
                    })

    paths: Dict[str, str] = {}

    df_summary = pd.DataFrame(summary_rows)
    summary_path = os.path.join(out_dir, 'landmarks_summary.csv')
    df_summary.to_csv(summary_path, index=False)
    paths['landmarks_summary'] = summary_path

    df_auc = pd.DataFrame(auc_rows)
    auc_path = os.path.join(out_dir, 'auc_per_disease.csv')
    df_auc.to_csv(auc_path, index=False)
    paths['auc_per_disease'] = auc_path

    df_capture = pd.DataFrame(capture_rows)
    capture_path = os.path.join(out_dir, 'capture_at_k.csv')
    df_capture.to_csv(capture_path, index=False)
    paths['capture_at_k'] = capture_path

    df_lift = pd.DataFrame(lift_rows)
    lift_path = os.path.join(out_dir, 'lift_yield.csv')
    df_lift.to_csv(lift_path, index=False)
    paths['lift_yield'] = lift_path

    df_dca = pd.DataFrame(dca_rows)
    dca_path = os.path.join(out_dir, 'dca.csv')
    df_dca.to_csv(dca_path, index=False)
    paths['dca'] = dca_path

    return paths

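# Downstream-analysis sketch (hypothetical, not executed by this script; it only
# assumes pandas and the file/column names produced by save_results_csv_bundle above):
#
#   df = pd.read_csv(os.path.join(out_dir, 'landmarks_summary.csv'))
#   pivot = df[df['track'] == 'complete_case'].pivot(
#       index='age_cutoff', columns='horizon', values='mean_auc')
#   print(pivot.round(3))   # mean AUC per (age cutoff, horizon) cell
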
class LandmarkEvaluator:
    """
    Comprehensive landmark analysis evaluator for survival/competing-risks models.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        head: torch.nn.Module,
        loss_fn: torch.nn.Module,
        dataset: HealthDataset,
        eval_indices: Optional[List[int]] = None,
        device: str = 'cuda',
        batch_size: int = 256,
        num_workers: int = 4,
        compile_model: bool = True,
        check_capture_at_k: bool = False,
        profile_metrics: bool = False,
        capture_check_n: int = 200,
    ):
        self.model = model.to(device).eval()
        self.head = head.to(device).eval()
        self.loss_fn = loss_fn
        self.dataset = dataset
        self.eval_indices = eval_indices
        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.check_capture_at_k = bool(check_capture_at_k)
        self.profile_metrics = bool(profile_metrics)
        self.capture_check_n = int(capture_check_n)
        self._did_capture_check = False

        use_cuda = str(self.device).startswith("cuda") and torch.cuda.is_available()
        if use_cuda:
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            try:
                torch.set_float32_matmul_precision("high")
            except Exception:
                pass

        # JIT/compile optimization (best effort)
        if compile_model and use_cuda:
            self.model = _maybe_torch_compile(self.model, enabled=True)
            self.head = _maybe_torch_compile(self.head, enabled=True)

        # Evaluation parameters from the design doc
        self.age_cutoffs = [50, 60, 70]
        self.horizons = [0.25, 0.5, 1, 2, 5, 10]
        self.top_k_values = [5, 10, 20, 50]
        self.workload_fracs = [0.01, 0.05, 0.10, 0.20, 0.50]

        # Convert ages/horizons to days for comparison against event times
        self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
        self.horizons_days = [h * 365.25 for h in self.horizons]

    @staticmethod
    def _last_time(time_batch: torch.Tensor, event_batch: torch.Tensor) -> torch.Tensor:
        """Compute the last observed (non-padding) time per patient."""
        real_mask = event_batch >= 1
        masked = time_batch.masked_fill(~real_mask, float('-inf'))
        return masked.max(dim=1).values

    @staticmethod
    def _anchor_indices(
        time_batch: torch.Tensor,
        event_batch: torch.Tensor,
        cutoff_days: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Find the anchor index/time: the last valid record before the cutoff."""
        real_mask = event_batch >= 1
        before = time_batch < cutoff_days
        valid_before = real_mask & before
        has_anchor = valid_before.any(dim=1)

        # argmax of position under the mask gives the last True position
        L = event_batch.size(1)
        pos = torch.arange(L, device=event_batch.device).view(1, L)
        anchor_idx = (valid_before.to(torch.long) * pos).max(dim=1).values.to(torch.long)
        t_anchor = time_batch.gather(1, anchor_idx.view(-1, 1)).squeeze(1)
        return has_anchor, anchor_idx, t_anchor

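    # Illustrative trace of the anchor-index trick above (hypothetical mask):
    #   valid_before = [False, True, True, False]  ->  positions = [0, 1, 2, 3]
    #   valid_before * positions = [0, 1, 2, 0]    ->  max = 2, the last True position
    # When no position is valid the max degenerates to 0, but those rows are filtered
    # out via has_anchor before the anchor index is ever used.
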
    def _labels_and_validity_for_cutoff(
        self,
        time_batch: torch.Tensor,
        event_batch: torch.Tensor,
        cutoff_days: float,
        horizons_days: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Vectorized label + validity computation for all horizons at a cutoff.

        Returns:
            labels: (B, H, K) float32 {0, 1}
            valid_cc: (B, H, K) bool
            valid_clean: (B, H, K) bool
        """
        n_tech_tokens = 2
        K = int(self.dataset.n_disease)
        death_code = int(K - 1)

        B, L = event_batch.shape
        H = int(horizons_days.numel())

        # Disease token mask and indices
        is_disease = event_batch >= n_tech_tokens
        disease_idx = (event_batch - n_tech_tokens).clamp(min=0, max=K - 1)

        # ever_has_disease: (B, K)
        ever = torch.zeros((B, K), dtype=torch.bool, device=event_batch.device)
        if is_disease.any():
            b_idx, t_idx = is_disease.nonzero(as_tuple=True)
            d_idx = disease_idx[b_idx, t_idx]
            ever[b_idx, d_idx] = True

        # Events within horizon windows: (B, L, H)
        offset = time_batch - float(cutoff_days)
        within = is_disease.unsqueeze(-1) & (offset.unsqueeze(-1) >= 0) & (
            offset.unsqueeze(-1) <= horizons_days.view(1, 1, H)
        )

        labels_bool = torch.zeros((B, H, K), dtype=torch.bool, device=event_batch.device)
        if within.any():
            b2, t2, h2 = within.nonzero(as_tuple=True)
            d2 = disease_idx[b2, t2]
            labels_bool[b2, h2, d2] = True

        labels = labels_bool.to(torch.float32)

        last_time = self._last_time(time_batch, event_batch)       # (B,)
        horizon_end = float(cutoff_days) + horizons_days.view(1, H)  # (1, H)

        death_in_horizon = labels_bool[:, :, death_code]            # (B, H)
        observed_past_horizon = last_time.view(B, 1) > horizon_end
        lost_within_horizon = last_time.view(B, 1) <= horizon_end

        # Track A (Complete-Case):
        # - if observed past horizon OR death in horizon => valid for all diseases
        # - else (censored within horizon) => valid only for diseases that occurred within the horizon
        valid_cc = labels_bool.clone()
        full_mask = (observed_past_horizon | death_in_horizon).unsqueeze(-1)
        if full_mask.any():
            valid_cc = torch.where(
                full_mask.expand(-1, -1, K), torch.ones_like(valid_cc), valid_cc)

        # Track B (Clean-Control) per disease:
        #   valid[k] = hit_in_window(k) OR (never_has_k AND not lost_within_window)
        never = ~ever  # (B, K)
        valid_clean = (~death_in_horizon).unsqueeze(-1) & (
            labels_bool | (never.unsqueeze(1) & (~lost_within_horizon).unsqueeze(-1))
        )

        return labels, valid_cc, valid_clean

    def _compute_risk_scores_many_horizons(
        self,
        logits: torch.Tensor,
        t_start_days: torch.Tensor,
        horizons_days: torch.Tensor,
    ) -> torch.Tensor:
        """Compute risk increments for all horizons in one vectorized call.

        Args:
            logits: model head outputs at the anchor points.
            t_start_days: (B,) time from anchor to cutoff (days).
            horizons_days: (H,) horizons in days.

        Returns:
            risk: (B, H, K) float32
        """
        t_start_days = torch.clamp(t_start_days, min=0)
        t_end_days = torch.clamp(t_start_days.unsqueeze(1) + horizons_days.view(1, -1), min=0)

        # (B, H+1) query times in years: anchor->cutoff, then anchor->each horizon end
        t_query_years = torch.cat([t_start_days.unsqueeze(1), t_end_days], dim=1) / 365.25

        # calculate_cifs returns (B, K) if scalar/per-sample, else (B, K, T)
        if hasattr(self.loss_fn, "calculate_cifs"):
            cifs = self.loss_fn.calculate_cifs(logits, t_query_years, return_survival=False)
        else:
            raise ValueError(
                f"Loss function does not support calculate_cifs: {type(self.loss_fn)}")

        if cifs.ndim == 2:
            # (B, K) -> (B, 1, K)
            cifs_bt_k = cifs.unsqueeze(1)
        else:
            # (B, K, T) -> (B, T, K)
            cifs_bt_k = cifs.permute(0, 2, 1)

        cif_start = cifs_bt_k[:, :1, :]  # (B, 1, K)
        cif_end = cifs_bt_k[:, 1:, :]    # (B, H, K)
        risk = torch.clamp(cif_end - cif_start, min=0)
        return risk

    @torch.no_grad()
    def compute_risk_scores(
        self,
        indices: List[int],
        age_cutoff_days: float,
        horizon_days: float,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute risk scores for the specified patient indices at a given landmark.

        Args:
            indices: Patient indices to evaluate
            age_cutoff_days: Age cutoff in days (T_cut)
            horizon_days: Prediction horizon in days (H)

        Returns:
            risk_scores: (N, K) array of risk scores per disease
            t_anchors: (N,) array of anchor times (days from birth)
            valid_mask: (N,) boolean array indicating valid predictions
        """
        subset = Subset(self.dataset, indices)
        loader = DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=self.num_workers,
            pin_memory=str(self.device).startswith('cuda'),
        )

        all_risk_scores = []
        all_t_anchors = []
        all_valid_mask = []

        for batch in loader:
            event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
            event_batch = event_batch.to(self.device)
            time_batch = time_batch.to(self.device)
            cont_batch = cont_batch.to(self.device)
            cate_batch = cate_batch.to(self.device)
            sex_batch = sex_batch.to(self.device)

            # Find anchor point: last valid record before age_cutoff.
            # Valid records are non-padding events (event >= 1).
            valid_mask = event_batch >= 1                  # (B, L)
            before_cutoff = time_batch < age_cutoff_days   # (B, L)
            valid_before = valid_mask & before_cutoff      # (B, L)

            # Find the last valid position for each patient
            batch_size = event_batch.size(0)
            t_anchor = torch.zeros(batch_size, device=self.device)
            anchor_idx = torch.zeros(batch_size, dtype=torch.long, device=self.device)
            has_anchor = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

            for b in range(batch_size):
                valid_positions = valid_before[b].nonzero(as_tuple=True)[0]
                if len(valid_positions) > 0:
                    last_pos = valid_positions[-1]
                    anchor_idx[b] = last_pos
                    t_anchor[b] = time_batch[b, last_pos]
                    has_anchor[b] = True

            # Get model predictions at anchor points
            if has_anchor.any():
                # If torch.compile uses CUDA Graphs under the hood, mark a new step
                # before each compiled invocation to avoid output-lifetime issues.
                _maybe_cudagraph_mark_step_begin()
                # Forward pass
                hidden = self.model(event_batch, time_batch, sex_batch, cont_batch, cate_batch)

                # Get predictions at anchor positions
                batch_indices = torch.arange(batch_size, device=self.device)
                hidden_at_anchor = hidden[batch_indices, anchor_idx]  # (B, n_embd)

                # Compute logits using the loaded head:
                # (B, n_disease, ...) or (B, K+1, n_bins+1) etc.
                logits = self.head(hidden_at_anchor)

                # Compute CIF scores.
                # Time gap from anchor to start of horizon
                t_start = age_cutoff_days - t_anchor                 # (B,)
                # Time gap from anchor to end of horizon
                t_end = age_cutoff_days + horizon_days - t_anchor    # (B,)

                # Ensure non-negative time gaps
                t_start = torch.clamp(t_start, min=0)
                t_end = torch.clamp(t_end, min=0)

                # Calculate CIF at both time points
                cif_start = self._compute_cif(logits, t_start)  # (B, K)
                cif_end = self._compute_cif(logits, t_end)      # (B, K)

                # Risk score is the increment within the horizon; ensure non-negative
                risk_scores = torch.clamp(cif_end - cif_start, min=0)  # (B, K)
            else:
                # No valid anchor points in this batch
                risk_scores = torch.zeros(
                    batch_size, self.dataset.n_disease, device=self.device)

            all_risk_scores.append(risk_scores.cpu().numpy())
            all_t_anchors.append(t_anchor.cpu().numpy())
            all_valid_mask.append(has_anchor.cpu().numpy())

        # Concatenate results
        risk_scores = np.vstack(all_risk_scores)     # (N, K)
        t_anchors = np.concatenate(all_t_anchors)    # (N,)
        valid_mask = np.concatenate(all_valid_mask)  # (N,)

        return risk_scores, t_anchors, valid_mask

    def _compute_cif(self, logits: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the Cumulative Incidence Function at time t.

        Args:
            logits: Model output logits (B, K, ...) depending on the loss type
            t: Time points (B,) in days from the anchor

        Returns:
            cif: (B, K) cumulative incidence probabilities
        """
        t_years = t / 365.25  # Convert days to years

        if isinstance(self.loss_fn, ExponentialNLLLoss):
            # Exponential: logits are (B, K)
            lambdas = F.softplus(logits) + 1e-6               # (B, K)
            total_lambda = lambdas.sum(dim=-1, keepdim=True)  # (B, 1)

            # CIF_k(t) = (λ_k / Σλ) * (1 - exp(-Σλ * t))
            frac = lambdas / total_lambda                                       # (B, K)
            exp_term = 1.0 - torch.exp(-total_lambda * t_years.unsqueeze(-1))   # (B, 1)
            cif = frac * exp_term                                               # (B, K)

        elif isinstance(self.loss_fn, (DiscreteTimeCIFNLLLoss, PiecewiseExponentialCIFNLLLoss)):
            # Discrete-time / piecewise-exponential CIF: use the loss's calculate_cifs method
            cif = self.loss_fn.calculate_cifs(logits, t_years, return_survival=False)

        else:
            raise ValueError(f"Unknown loss type: {type(self.loss_fn)}")

        return cif

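    # Worked example of the exponential (competing-risks) CIF above, with hypothetical
    # rates lambdas = [0.1, 0.3] per year and t = 2 years:
    #   total_lambda = 0.4
    #   1 - exp(-0.4 * 2) = 1 - exp(-0.8) ≈ 0.5507
    #   CIF_0(2) ≈ (0.1 / 0.4) * 0.5507 ≈ 0.138
    #   CIF_1(2) ≈ (0.3 / 0.4) * 0.5507 ≈ 0.413
    # The per-cause CIFs sum to the all-cause incidence 1 - exp(-Σλ t), as expected.
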
    def prepare_evaluation_cohort(
        self,
        age_cutoff_days: float,
        horizon_days: float,
        track: str = 'complete_case',
    ) -> Tuple[List[int], np.ndarray, np.ndarray]:
        """Prepare the evaluation cohort per the design protocol with per-disease validity.

        Key fixes vs the earlier implementation:
        - Multi-disease labeling: mark *all* diseases that occur within the horizon.
        - Per-disease validity mask: mask[i, k] = 1 if patient i is valid for disease k.

        Returns:
            indices: list of patient indices included in this cohort table
            labels: (N, K) float array
            valid_mask: (N, K) bool array
        """
        n_tech_tokens = 2  # PAD=0, DOA=1
        K = int(self.dataset.n_disease)

        # Competing-risk token: treat death as a competing event.
        # As requested: DEATH_CODE is the last disease index.
        DEATH_CODE = int(K - 1)

        horizon_end_days = age_cutoff_days + horizon_days

        indices: List[int] = []
        labels_rows: List[np.ndarray] = []
        valid_rows: List[np.ndarray] = []

        candidate_indices = self.eval_indices if self.eval_indices is not None else list(
            range(len(self.dataset)))
        for idx in candidate_indices:
            patient_id = self.dataset.patient_ids[idx]
            records = self.dataset.patient_events.get(patient_id, [])
            if not records:
                continue

            # Must have some information strictly prior to the cutoff (anchor existence in the data).
            has_pre_cutoff = any(t < age_cutoff_days for t, _ in records)
            if not has_pre_cutoff:
                continue

            # Events after the cutoff (already sorted in dataset init)
            events_after = [(t, e) for t, e in records if t >= age_cutoff_days]

            labels = np.zeros(K, dtype=np.float32)
            valid = np.zeros(K, dtype=bool)

            # Identify diseases within the horizon (multi-label; no early break on the first hit)
            diseases_in_horizon: set[int] = set()
            death_in_horizon = False
            for t, e in events_after:
                if t > horizon_end_days:
                    break  # events are time-sorted
                if e < n_tech_tokens:
                    continue
                disease_idx = int(e - n_tech_tokens)
                if 0 <= disease_idx < K:
                    diseases_in_horizon.add(disease_idx)
                    if disease_idx == DEATH_CODE:
                        death_in_horizon = True

            for d in diseases_in_horizon:
                labels[d] = 1.0

            last_time = float(records[-1][0])

            if track == 'complete_case':
                # Track A: Complete-Case at Horizon
                # - Hit: disease occurs within horizon => valid positive for that disease
                # - Healthy: last record > horizon_end => valid negative for all non-hit diseases
                # - Death: death within horizon => valid negative for diseases not hit before death
                #   (implemented as valid for all diseases, with labels marking any hits incl. death)
                # - Loss: censored within horizon with no hit for a disease => invalid for that disease
                if last_time > horizon_end_days:
                    # Observed past the horizon: negative for all non-hit diseases
                    valid[:] = True
                elif death_in_horizon:
                    # Competing risk: include as explicit negatives (label 0) for diseases not hit
                    # and positives for diseases hit before death (we don't model ordering here).
                    # This matches the requirement: death within the window is label=0 for target diseases.
                    valid[:] = True
                else:
                    # Lost within the horizon: only diseases that actually occurred in the horizon
                    # are valid (as positives); no assumptions about negatives.
                    for d in diseases_in_horizon:
                        valid[d] = True

            elif track == 'clean_control':
                # Track B: Clean Control (per disease k)
                # - Hit: disease k occurs within horizon => label=1, valid
                # - Pure clean for k: disease k never occurs in the entire record => label=0, valid
                # - Death within window => drop (invalid) for all k
                # - Loss within window => drop (invalid) for all k
                # - Late onset of k (after the horizon) => drop (invalid) for k
                if death_in_horizon:
                    # Drop all diseases for this patient
                    valid[:] = False
                else:
                    # Precompute per-disease first occurrence time after the cutoff
                    # (including occurrences after the horizon).
                    first_occ_after_cutoff = np.full(K, np.inf, dtype=np.float64)
                    ever_has_disease = np.zeros(K, dtype=bool)
                    for t, e in records:
                        if e < n_tech_tokens:
                            continue
                        disease_idx = int(e - n_tech_tokens)
                        if 0 <= disease_idx < K:
                            ever_has_disease[disease_idx] = True
                            if t >= age_cutoff_days and t < first_occ_after_cutoff[disease_idx]:
                                first_occ_after_cutoff[disease_idx] = float(t)

                    # If lost within the horizon and disease k did not occur in the horizon,
                    # the patient is invalid for k.
                    lost_within_horizon = last_time <= horizon_end_days

                    for k in range(K):
                        if labels[k] == 1.0:
                            valid[k] = True
                            continue

                        if not ever_has_disease[k]:
                            # Lifetime clean for k. Still requires complete follow-up within the
                            # horizon: if censored within the horizon, we cannot be sure k did not
                            # occur in the window.
                            valid[k] = not lost_within_horizon
                            continue

                        # Has disease k at some point but not within the horizon: either prevalent
                        # before the cutoff or late onset after the horizon. Both are dropped for k.
                        valid[k] = False

            else:
                raise ValueError(f"Unknown track: {track}")

            # Keep the patient row if they are valid for at least one disease.
            if valid.any():
                indices.append(idx)
                labels_rows.append(labels)
                valid_rows.append(valid)

        if not indices:
            return [], np.zeros((0, K), dtype=np.float32), np.zeros((0, K), dtype=bool)

        return indices, np.stack(labels_rows, axis=0), np.stack(valid_rows, axis=0)

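    # Toy patient to illustrate the two tracks above (hypothetical ages; cutoff 60y,
    # horizon 5y, so the window ends at 65y): disease A at 58y, disease B at 62y,
    # last record at 63y, no death.
    #   Complete-Case: censored at 63y < 65y with no death, so only B (a hit in the
    #   window) is valid; A and all never-seen diseases are invalid for this patient.
    #   Clean-Control: B is a valid hit; A is dropped (prevalent before the cutoff);
    #   never-seen diseases are dropped because follow-up does not cover the window.
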
    def compute_auc_per_disease(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[int, float]:
        """
        Compute time-dependent AUC for each disease.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask for valid evaluations

        Returns:
            auc_scores: Dict mapping disease_idx -> AUC
        """
        auc_scores: Dict[int, float] = {}
        n_diseases = risk_scores.shape[1]
        m = valid_mask.astype(bool)

        # Fast pre-filter: skip diseases without both classes among valid entries.
        pos = (labels == 1) & m
        n_valid = m.sum(axis=0)
        n_pos = pos.sum(axis=0)
        n_neg = n_valid - n_pos

        for k in range(n_diseases):
            if n_valid[k] == 0 or n_pos[k] == 0 or n_neg[k] == 0:
                auc_scores[k] = np.nan
                continue

            mk = m[:, k]
            y_true_k = labels[mk, k]
            y_score_k = risk_scores[mk, k]
            try:
                auc_scores[k] = float(roc_auc_score(y_true_k, y_score_k))
            except Exception:
                auc_scores[k] = np.nan

        return auc_scores

    def compute_brier_score(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[str, float]:
        """
        Compute the Brier Score and Brier Skill Score.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            metrics: Dict with 'brier_score' and 'brier_skill_score'
        """
        # Apply the per-entry valid mask and flatten
        m = valid_mask.astype(bool)
        y_true_flat = labels[m]
        y_pred_flat = risk_scores[m]

        mask = ~(np.isnan(y_true_flat) | np.isnan(y_pred_flat))
        y_true_flat = y_true_flat[mask]
        y_pred_flat = y_pred_flat[mask]

        if len(y_true_flat) == 0:
            return {'brier_score': np.nan, 'brier_skill_score': np.nan}

        bs = brier_score_loss(y_true_flat, y_pred_flat)

        # Brier Skill Score: the reference model predicts the mean event rate
        p_mean = y_true_flat.mean()
        bs_ref = ((p_mean - y_true_flat) ** 2).mean()
        bss = 1.0 - (bs / bs_ref) if bs_ref > 0 else 0.0

        return {'brier_score': bs, 'brier_skill_score': bss}

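    # Worked example for the Brier Skill Score with hypothetical numbers:
    # labels y = [1, 0, 0, 0] give p_mean = 0.25 and a reference Brier score
    #   bs_ref = ((0.25 - 1)^2 + 3 * (0.25 - 0)^2) / 4 = 0.1875.
    # A model with Brier score bs = 0.15 then has BSS = 1 - 0.15 / 0.1875 = 0.20,
    # i.e. a 20% reduction in squared error relative to always predicting the base rate.
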
    def compute_disease_capture_at_k(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[int, Dict[int, float]]:
        """
        Compute Disease-Capture@K: the fraction of valid positives for which the true
        disease appears in the patient's top-K predicted risks.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            capture_rates: Dict[K_value][disease_idx] -> capture rate
        """
        # Fast path (vectorized): compute top-k_max once per sample.
        t0 = time.perf_counter() if self.profile_metrics else None
        capture_fast = compute_disease_capture_at_k_fast(
            y_true=labels,
            y_scores=risk_scores,
            valid_mask=valid_mask,
            top_k_list=self.top_k_values,
        )
        if self.profile_metrics and t0 is not None:
            dt = time.perf_counter() - t0
            print(f"  [profile] capture@K fast: {dt:.3f}s "
                  f"(N={risk_scores.shape[0]}, K={risk_scores.shape[1]})")

        # Optional correctness check against the slow reference on a small subset.
        if self.check_capture_at_k and (not self._did_capture_check):
            self._did_capture_check = True
            rng = np.random.default_rng(0)
            n_sub = min(self.capture_check_n, risk_scores.shape[0])
            sub_idx = rng.choice(risk_scores.shape[0], size=n_sub, replace=False)
            rs = risk_scores[sub_idx]
            ys = labels[sub_idx]
            vm = valid_mask[sub_idx]

            t1 = time.perf_counter()
            slow = self._compute_disease_capture_at_k_slow(rs, ys, vm)
            t2 = time.perf_counter()
            fast = compute_disease_capture_at_k_fast(ys, rs, vm, self.top_k_values)
            t3 = time.perf_counter()

            def _eq(a: float, b: float) -> bool:
                if np.isnan(a) and np.isnan(b):
                    return True
                return float(a) == float(b)

            for k_val in self.top_k_values:
                for d in range(rs.shape[1]):
                    if not _eq(slow[int(k_val)][d], fast[int(k_val)][d]):
                        raise AssertionError(
                            f"Capture@{k_val} mismatch for disease {d}: "
                            f"slow={slow[int(k_val)][d]} fast={fast[int(k_val)][d]}")

            print(f"  [check] capture@K ok on subset (N={n_sub}). "
                  f"slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s")

        return capture_fast

    def _compute_disease_capture_at_k_slow(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[int, Dict[int, float]]:
        """Reference implementation (slow): kept for correctness checking."""
        capture_rates = {int(k): {} for k in self.top_k_values}
        n_diseases = risk_scores.shape[1]

        for disease_idx in range(n_diseases):
            mk = valid_mask[:, disease_idx]
            if not np.any(mk):
                for k in self.top_k_values:
                    capture_rates[int(k)][disease_idx] = np.nan
                continue

            y_true = labels[mk, disease_idx]
            y_scores = risk_scores[mk]  # (N_valid_k, K)

            pos_mask = y_true == 1
            if pos_mask.sum() == 0:
                for k in self.top_k_values:
                    capture_rates[int(k)][disease_idx] = np.nan
                continue

            pos_idx = np.where(pos_mask)[0]
            for k_val in self.top_k_values:
                captures = []
                for i in pos_idx:
                    top_k_diseases = np.argsort(y_scores[i])[::-1][:int(k_val)]
                    captures.append(int(disease_idx in top_k_diseases))
                capture_rates[int(k_val)][disease_idx] = float(
                    np.mean(captures)) if captures else np.nan

        return capture_rates

    def compute_lift_and_yield(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[str, Dict]:
        """
        Compute Lift and Yield at various workload fractions.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            metrics:
            {
                'overall': {workload_frac: {'lift': ..., 'yield': ...}, ...},
                'per_disease': {disease_idx: {workload_frac: {'lift': ..., 'yield': ...}, ...}, ...}
            }
        """
        # The overall metric uses a patient-level mask to avoid including purely censored negatives.
        # We include patients who either (a) have known outcomes for all diseases, or (b) have at least one hit.
        has_any_hit = labels.max(axis=1) > 0
        has_all_known = valid_mask.all(axis=1)
        overall_patient_mask = has_any_hit | has_all_known

        risk_scores_overall = risk_scores[overall_patient_mask]
        labels_overall = labels[overall_patient_mask]

        # Flatten to patient level: any disease event, ranked by the max risk across diseases
        max_risk_per_patient = risk_scores_overall.max(axis=1)
        any_disease_label = labels_overall.max(axis=1)

        n_patients = len(max_risk_per_patient)
        base_rate = any_disease_label.mean() if n_patients > 0 else 0.0

        overall: Dict[float, Dict[str, float]] = {}
        for workload_frac in self.workload_fracs:
            n_screen = max(1, int(n_patients * workload_frac))
            top_n_idx = np.argsort(max_risk_per_patient)[::-1][:n_screen]
            top_n_labels = any_disease_label[top_n_idx]
            yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
            lift_val = (yield_val / float(base_rate)) if base_rate > 0 else 0.0
            overall[workload_frac] = {'lift': lift_val, 'yield': yield_val}

        per_disease: Dict[int, Dict[float, Dict[str, float]]] = {}
        n_diseases = risk_scores.shape[1]
        for disease_idx in range(n_diseases):
            mk = valid_mask[:, disease_idx]
            if not np.any(mk):
                per_disease[disease_idx] = {
                    frac: {'lift': np.nan, 'yield': np.nan} for frac in self.workload_fracs}
                continue

            disease_scores = risk_scores[mk, disease_idx]
            disease_labels = labels[mk, disease_idx]
            disease_base_rate = disease_labels.mean() if disease_labels.size > 0 else 0.0

            n_patients_k = disease_scores.shape[0]

            disease_metrics: Dict[float, Dict[str, float]] = {}
            for workload_frac in self.workload_fracs:
                n_screen = max(1, int(n_patients_k * workload_frac))
                top_n_idx = np.argsort(disease_scores)[::-1][:n_screen]
                top_n_labels = disease_labels[top_n_idx]
                yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
                lift_val = (yield_val / float(disease_base_rate)) if disease_base_rate > 0 else 0.0
                disease_metrics[workload_frac] = {'lift': lift_val, 'yield': yield_val}

            per_disease[disease_idx] = disease_metrics

        return {
            'overall': overall,
            'per_disease': per_disease,
        }

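    # Worked example for Lift/Yield at a workload fraction, with hypothetical numbers:
    # if the overall base rate is 5% and the top 10% of patients ranked by max risk
    # contain events at a 20% rate, then yield = 0.20 and lift = 0.20 / 0.05 = 4.0,
    # i.e. screening that slice finds events 4x more efficiently than random selection.
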
    def compute_dca_net_benefit(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
        threshold_range: np.ndarray = np.linspace(0, 0.5, 51),
    ) -> Dict[str, np.ndarray]:
        """
        Compute Decision Curve Analysis (DCA) net benefit.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask
            threshold_range: Array of threshold probabilities

        Returns:
            dca_results: Dict with 'thresholds' and 'net_benefit' arrays
        """
        # Use the same overall patient mask as lift/yield (complete-case style)
        has_any_hit = labels.max(axis=1) > 0
        has_all_known = valid_mask.all(axis=1)
        patient_mask = has_any_hit | has_all_known

        risk_scores = risk_scores[patient_mask]
        labels = labels[patient_mask]

        # Use the max risk across diseases and the any-disease label
        max_risk = risk_scores.max(axis=1)
        any_disease = labels.max(axis=1)

        n = len(max_risk)
        net_benefits = []

        for pt in threshold_range:
            if pt == 0:
                # Treat all
                nb = any_disease.mean()
            else:
                # Treat if predicted risk >= threshold
                treat = max_risk >= pt
                tp = (treat & (any_disease == 1)).sum()
                fp = (treat & (any_disease == 0)).sum()

                # Net benefit = (TP/N) - (FP/N) * (pt / (1 - pt))
                nb = (tp / n) - (fp / n) * (pt / (1 - pt))

            net_benefits.append(nb)

        return {
            'thresholds': threshold_range,
            'net_benefit': np.array(net_benefits),
        }

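    # Worked example for the net-benefit formula with hypothetical counts:
    # N = 100 patients, threshold pt = 0.10, and 30 patients flagged for treatment,
    # of whom TP = 12 and FP = 18:
    #   NB = 12/100 - 18/100 * (0.10 / 0.90) = 0.12 - 0.02 = 0.10
    # so the policy at pt = 0.10 is equivalent to finding 10 true cases per 100
    # patients with no false positives.
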
    def evaluate_landmark(
        self,
        age_cutoff: float,
        horizon: float,
    ) -> Dict:
        """
        Evaluate the model at a specific landmark (age_cutoff, horizon).

        Args:
            age_cutoff: Age cutoff in years
            horizon: Prediction horizon in years

        Returns:
            results: Dictionary with all metrics
        """
        age_cutoff_days = age_cutoff * 365.25
        horizon_days = horizon * 365.25

        print(f"\nEvaluating Landmark: Age={age_cutoff}, Horizon={horizon}y")

        results = {
            'age_cutoff': age_cutoff,
            'horizon': horizon,
            'complete_case': {},
            'clean_control': {},
        }

        for track in ['complete_case', 'clean_control']:
            print(f"  Track: {track}")

            # Prepare cohort
            indices, labels_array, valid_mask = self.prepare_evaluation_cohort(
                age_cutoff_days, horizon_days, track
            )

            if len(indices) == 0:
                print(f"    No valid patients for track {track}")
                continue

            print(f"    Cohort size: {len(indices)}")

            # Compute risk scores
            risk_scores, t_anchors, anchor_mask = self.compute_risk_scores(
                indices, age_cutoff_days, horizon_days
            )

            # Combine anchor availability with per-disease validity
            valid_mask = valid_mask & anchor_mask.astype(bool)[:, None]

            # Compute metrics
            print("    Computing AUC...")
            auc_scores = self.compute_auc_per_disease(risk_scores, labels_array, valid_mask)
            mean_auc = np.nanmean(list(auc_scores.values()))

            print("    Computing Brier Score...")
            brier_metrics = self.compute_brier_score(risk_scores, labels_array, valid_mask)

            # Only compute patient-level and population metrics for complete_case
            if track == 'complete_case':
                print("    Computing Disease-Capture@K...")
                capture_metrics = self.compute_disease_capture_at_k(
                    risk_scores, labels_array, valid_mask)

                print("    Computing Lift & Yield...")
                lift_yield_metrics = self.compute_lift_and_yield(
                    risk_scores, labels_array, valid_mask)

                print("    Computing DCA...")
                dca_metrics = self.compute_dca_net_benefit(
                    risk_scores, labels_array, valid_mask)

                results[track] = {
                    'n_patients': len(indices),
                    'n_valid': int(valid_mask.sum()),
                    'n_valid_patients': int(valid_mask.any(axis=1).sum()),
                    'auc_per_disease': auc_scores,
                    'mean_auc': mean_auc,
                    'brier_score': brier_metrics['brier_score'],
                    'brier_skill_score': brier_metrics['brier_skill_score'],
                    'disease_capture_at_k': capture_metrics,
                    'lift_and_yield': lift_yield_metrics,
                    'dca': dca_metrics,
                }
            else:
                # Clean-control track: only discrimination metrics
                results[track] = {
                    'n_patients': len(indices),
                    'n_valid': int(valid_mask.sum()),
                    'n_valid_patients': int(valid_mask.any(axis=1).sum()),
                    'auc_per_disease': auc_scores,
                    'mean_auc': mean_auc,
                }

        return results

    def run_full_evaluation(self) -> Dict:
        """Run the full evaluation using a single-pass DataLoader.

        Key optimizations:
        - iterate the DataLoader exactly once
        - run the transformer backbone once per batch
        - reuse hidden states per cutoff (3x head only)
        - vectorize CIF/risk over all horizons in one call
        """
        # Build the evaluation subset loader
        indices = self.eval_indices if self.eval_indices is not None else list(
            range(len(self.dataset)))
        subset = Subset(self.dataset, indices)
        loader = DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=self.num_workers,
            pin_memory=str(self.device).startswith('cuda'),
        )

        cutoffs_days = torch.tensor(
            self.age_cutoffs_days, dtype=torch.float32, device=self.device)  # (C,)
        horizons_days = torch.tensor(
            self.horizons_days, dtype=torch.float32, device=self.device)     # (H,)
        C = int(cutoffs_days.numel())
        H = int(horizons_days.numel())
        K = int(self.dataset.n_disease)

        # Buffers: store per landmark/track arrays in chunks to avoid repeated I/O.
        # Each key stores lists of numpy arrays to be concatenated at the end.
        buffers: Dict[Tuple[int, int, str], Dict[str, List[np.ndarray]]] = {}
        for ci in range(C):
            for hi in range(H):
                for track in ("complete_case", "clean_control"):
                    buffers[(ci, hi, track)] = {"risk": [], "labels": [], "valid": []}

        with torch.inference_mode():
            for batch in tqdm(loader, desc="Single-pass evaluation", ncols=100):
                event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
                event_batch = event_batch.to(self.device, non_blocking=True)
                time_batch = time_batch.to(self.device, non_blocking=True)
                cont_batch = cont_batch.to(self.device, non_blocking=True)
                cate_batch = cate_batch.to(self.device, non_blocking=True)
                sex_batch = sex_batch.to(self.device, non_blocking=True)

                B, L = event_batch.shape
                batch_idx = torch.arange(B, device=self.device)

                # Backbone once per batch
                _maybe_cudagraph_mark_step_begin()
                hidden = self.model(
                    event_batch, time_batch, sex_batch, cont_batch, cate_batch)  # (B, L, D)

                for ci in range(C):
                    cutoff = float(cutoffs_days[ci].item())

                    has_anchor, anchor_idx, t_anchor = self._anchor_indices(
                        time_batch, event_batch, cutoff)
                    if not has_anchor.any():
                        continue

                    # Hidden states at anchor positions
                    hidden_anchor = hidden[batch_idx, anchor_idx]  # (B, D)
                    logits = self.head(hidden_anchor)

                    # Vectorized labels/validity for all horizons
                    labels_bhk, valid_cc_bhk, valid_clean_bhk = self._labels_and_validity_for_cutoff(
                        time_batch, event_batch, cutoff, horizons_days
                    )

                    # Risk scores for all horizons (B, H, K)
                    t_start = torch.clamp(
                        torch.tensor(cutoff, device=self.device) - t_anchor, min=0)
                    risk_bhk = self._compute_risk_scores_many_horizons(
                        logits, t_start, horizons_days)

                    # Apply the anchor constraint to validity
                    anchor_mask = has_anchor.view(B, 1, 1)
                    valid_cc_bhk = valid_cc_bhk & anchor_mask
                    valid_clean_bhk = valid_clean_bhk & anchor_mask

                    # Push per-horizon chunks
                    for hi in range(H):
                        for track, valid_bk in (
                            ("complete_case", valid_cc_bhk[:, hi, :]),
                            ("clean_control", valid_clean_bhk[:, hi, :]),
                        ):
                            row_mask = valid_bk.any(dim=1)
                            if not row_mask.any():
                                continue

                            r = risk_bhk[row_mask, hi, :].to(torch.float32).cpu().numpy()
                            y = labels_bhk[row_mask, hi, :].to(torch.float32).cpu().numpy()
                            m = valid_bk[row_mask, :].to(torch.bool).cpu().numpy()

                            buffers[(ci, hi, track)]["risk"].append(r)
                            buffers[(ci, hi, track)]["labels"].append(y)
                            buffers[(ci, hi, track)]["valid"].append(m)

        # Assemble results in the original output schema
        all_results: Dict = {
            'age_cutoffs': self.age_cutoffs,
            'horizons': self.horizons,
            'landmarks': [],
        }

        for ci, age in enumerate(self.age_cutoffs):
            for hi, horizon in enumerate(self.horizons):
                landmark_results = {
                    'age_cutoff': age,
                    'horizon': horizon,
                    'complete_case': {},
                    'clean_control': {},
                }

                for track in ("complete_case", "clean_control"):
                    chunks = buffers[(ci, hi, track)]
                    if len(chunks["risk"]) == 0:
                        continue

                    risk_scores = np.concatenate(chunks["risk"], axis=0)
                    labels = np.concatenate(chunks["labels"], axis=0)
                    valid_mask = np.concatenate(chunks["valid"], axis=0)

                    auc_scores = self.compute_auc_per_disease(risk_scores, labels, valid_mask)
                    mean_auc = np.nanmean(list(auc_scores.values()))

                    track_out = {
                        'n_patients': int(valid_mask.shape[0]),
                        'n_valid': int(valid_mask.sum()),
                        'n_valid_patients': int((valid_mask.any(axis=1)).sum()),
                        'auc_per_disease': auc_scores,
                        'mean_auc': mean_auc,
                    }

                    if track == "complete_case":
                        brier_metrics = self.compute_brier_score(risk_scores, labels, valid_mask)
                        capture_metrics = self.compute_disease_capture_at_k(
                            risk_scores, labels, valid_mask)
                        lift_yield_metrics = self.compute_lift_and_yield(
                            risk_scores, labels, valid_mask)
                        dca_metrics = self.compute_dca_net_benefit(risk_scores, labels, valid_mask)
                        track_out.update({
                            'brier_score': brier_metrics['brier_score'],
                            'brier_skill_score': brier_metrics['brier_skill_score'],
                            'disease_capture_at_k': capture_metrics,
                            'lift_and_yield': lift_yield_metrics,
                            'dca': dca_metrics,
                        })

                    landmark_results[track] = track_out

                all_results['landmarks'].append(landmark_results)

        return all_results

def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
    """
    Load the trained model and configuration from a run directory.

    Args:
        run_dir: Path to run directory containing train_config.json and best_model.pt
        device: Device to load the model on

    Returns:
        model, head, loss_fn, dataset, config, test_indices
    """
    run_path = Path(run_dir)

    # Load config
    config_path = run_path / 'train_config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)

    print(f"Loading model from {run_dir}")
    print(f"Model type: {config['model_type']}")
    print(f"Loss type: {config['loss_type']}")

    # Load the dataset (same as training) and reproduce the train/val/test split.
    # IMPORTANT: do NOT change data_prefix; train.py reads files like
    # <data_prefix>_basic_info.csv, <data_prefix>_table.csv, <data_prefix>_event_data.npy
    data_prefix = config['data_prefix']

    if config.get('full_cov', False):
        covariate_list = None
    else:
        # Match train.py partial-cov settings
        covariate_list = ["bmi", "smoking", "alcohol"]

    dataset = HealthDataset(
        data_prefix=data_prefix,
        covariate_list=covariate_list,
        cache_event_tensors=True,
    )

    # Reproduce the random_split used in train.py to obtain the held-out test subset.
    n_total = len(dataset)
    train_ratio = float(config.get('train_ratio', 0.7))
    val_ratio = float(config.get('val_ratio', 0.15))
    seed = int(config.get('random_seed', 42))

    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    if n_test < 0:
        raise ValueError(
            f"Invalid split sizes from config: n_total={n_total}, "
            f"train_ratio={train_ratio}, val_ratio={val_ratio}")

    from torch.utils.data import random_split
    _, _, test_subset = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    test_indices = list(getattr(test_subset, 'indices', []))

    # Determine output dimensions based on the loss type
    import math
    if config['loss_type'] == 'exponential':
        out_dims = [dataset.n_disease]
    elif config['loss_type'] == 'discrete_time_cif':
        # logits shape (M, K+1, n_bins+1)
        bin_edges = config.get(
            'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')])
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif config['loss_type'] == 'pwe_cif':
        # Piecewise-exponential requires finite edges
        bin_edges = config.get(
            'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0])
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        n_bins = len(pwe_edges) - 1
        # logits shape (M, K, n_bins)
        out_dims = [dataset.n_disease, n_bins]
    else:
        raise ValueError(f"Unknown loss type: {config['loss_type']}")

    # Build the model
    if config['model_type'] == 'delphi_fork':
        model = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,  # PAD=0, DOA=1
            n_embd=config['n_embd'],
            n_head=config['n_head'],
            n_layer=config['n_layer'],
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            age_encoder_type=config['age_encoder'],
            pdrop=config['pdrop'],
        )
    elif config['model_type'] == 'sap_delphi':
        model = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=config['n_embd'],
            n_head=config['n_head'],
            n_layer=config['n_layer'],
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            age_encoder_type=config['age_encoder'],
            pdrop=config['pdrop'],
            pretrained_weights_path=config.get('pretrained_emd_path'),
        )
    else:
        raise ValueError(f"Unknown model type: {config['model_type']}")

    # Build the head
    head = SimpleHead(
        n_embd=config['n_embd'],
        out_dims=out_dims,
    )

    # Load model weights (the checkpoint contains model and head state dicts)
    model_path = run_path / 'best_model.pt'
    checkpoint = torch.load(model_path, map_location=device)

    # The checkpoint is a dict with 'model_state_dict' and 'head_state_dict'
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        head.load_state_dict(checkpoint['head_state_dict'])
        print("Loaded model and head from checkpoint")
    else:
        raise ValueError(
            "Checkpoint format not recognized. Expected 'model_state_dict' and 'head_state_dict' keys.")

    model.to(device)
    head.to(device)

    # Build the loss function
    if config['loss_type'] == 'exponential':
        loss_fn = ExponentialNLLLoss(
            lambda_reg=config.get('lambda_reg', 0.0)
        )
    elif config['loss_type'] == 'discrete_time_cif':
        loss_fn = DiscreteTimeCIFNLLLoss(
            bin_edges=config.get(
                'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]),
            lambda_reg=config.get('lambda_reg', 0.0),
        )
    elif config['loss_type'] == 'pwe_cif':
        loss_fn = PiecewiseExponentialCIFNLLLoss(
            bin_edges=config.get(
                'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]),
            lambda_reg=config.get('lambda_reg', 0.0),
        )
    else:
        raise ValueError(f"Unknown loss type: {config['loss_type']}")

    return model, head, loss_fn, dataset, config, test_indices

def print_summary(results: Dict):
    """
    Print a summary of the evaluation results.

    Args:
        results: Results dictionary
    """
    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY")
    print("=" * 80)

    for landmark in results['landmarks']:
        age = landmark['age_cutoff']
        horizon = landmark['horizon']

        print(f"\nLandmark: Age {age}, Horizon {horizon}y")
        print("-" * 40)

        # Complete-case results
        if 'complete_case' in landmark and landmark['complete_case']:
            cc = landmark['complete_case']
            print("  Complete-Case Track:")
            print(f"    Patients: {cc['n_patients']}")
            print(f"    Mean AUC: {cc['mean_auc']:.4f}")
            print(f"    Brier Score: {cc['brier_score']:.4f}")
            print(f"    Brier Skill Score: {cc['brier_skill_score']:.4f}")

            # Show top-K capture rates (averaged across diseases)
            if 'disease_capture_at_k' in cc:
                print("    Disease Capture:")
                for k in [5, 10, 20, 50]:
                    if k in cc['disease_capture_at_k']:
                        rates = list(cc['disease_capture_at_k'][k].values())
                        mean_rate = np.nanmean(rates)
                        print(f"      Top-{k}: {mean_rate:.3f}")

            # Show lift and yield
            if 'lift_and_yield' in cc:
                print("    Lift & Yield:")
                overall = cc['lift_and_yield'].get('overall', {}) if isinstance(
                    cc['lift_and_yield'], dict) else {}
                for frac in [0.01, 0.05, 0.10]:
                    if frac in overall:
                        lift = overall[frac].get('lift', np.nan)
                        yield_val = overall[frac].get('yield', np.nan)
                        print(f"      Overall Top {int(frac * 100)}%: "
                              f"Lift={lift:.2f}, Yield={yield_val:.3f}")

        # Clean-control results
        if 'clean_control' in landmark and landmark['clean_control']:
            clean = landmark['clean_control']
            print("  Clean-Control Track:")
            print(f"    Patients: {clean['n_patients']}")
            print(f"    Mean AUC: {clean['mean_auc']:.4f}")

    print("\n" + "=" * 80)

def main():
    parser = argparse.ArgumentParser(
        description='Evaluate a longitudinal health prediction model using landmark analysis'
    )
    parser.add_argument(
        '--run_dir',
        type=str,
        required=True,
        help='Path to run directory containing train_config.json and best_model.pt'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Output path for the summary JSON (default: <run_dir>/evaluation_summary.json)'
    )
    parser.add_argument(
        '--out_dir',
        type=str,
        default=None,
        help='Directory to write CSV outputs (default: <run_dir>/evaluation_outputs)'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help='Device to run evaluation on'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=256,
        help='Batch size for evaluation'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of data loader workers'
    )
    parser.add_argument(
        '--no_compile',
        action='store_true',
        help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
    )
    parser.add_argument(
        '--check_capture_at_k',
        action='store_true',
        help='Run a one-time correctness check: slow vs fast Disease-Capture@K on a small subset'
    )
    parser.add_argument(
        '--profile_metrics',
        action='store_true',
        help='Print basic timings for CPU-side metric computations'
    )
    parser.add_argument(
        '--capture_check_n',
        type=int,
        default=200,
        help='Number of samples used for the capture@K slow-vs-fast check (default: 200)'
    )

    args = parser.parse_args()

    # Load model and dataset
    model, head, loss_fn, dataset, config, test_indices = load_model_and_config(
        args.run_dir, args.device)

    # Create evaluator
    evaluator = LandmarkEvaluator(
        model=model,
        head=head,
        loss_fn=loss_fn,
        dataset=dataset,
        eval_indices=test_indices,
        device=args.device,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        compile_model=(not args.no_compile),
        check_capture_at_k=args.check_capture_at_k,
        profile_metrics=args.profile_metrics,
        capture_check_n=args.capture_check_n,
    )

    # Run evaluation
    print("\nStarting landmark analysis evaluation...")
    print(f"Age cutoffs: {evaluator.age_cutoffs}")
    print(f"Horizons: {evaluator.horizons}")

    results = evaluator.run_full_evaluation()

    # Add metadata
    results['metadata'] = {
        'run_dir': args.run_dir,
        'config': config,
        'n_diseases': dataset.n_disease,
        'device': args.device,
    }

    # Save results (CSV bundle + single JSON summary)
    if args.out_dir is None:
        args.out_dir = os.path.join(args.run_dir, 'evaluation_outputs')
    csv_paths = save_results_csv_bundle(results, args.out_dir)

    summary = {
        'metadata': results.get('metadata', {}),
        'age_cutoffs': results.get('age_cutoffs', []),
        'horizons': results.get('horizons', []),
        'csv_outputs': csv_paths,
        'notes': {
            'metrics': [
                'AUC (per disease + mean)',
                'Brier Score / Brier Skill Score (complete-case only)',
                'Disease-Capture@K (complete-case only)',
                'Lift/Yield (complete-case only)',
                'Decision Curve Analysis (complete-case only)',
            ],
        },
    }

    if args.output is None:
        args.output = os.path.join(args.run_dir, 'evaluation_summary.json')
    save_summary_json(summary, args.output)

    print(f"\nWrote CSV outputs to: {args.out_dir}")
    print(f"Wrote summary JSON to: {args.output}")

    # Print summary
    print_summary(results)


if __name__ == '__main__':
    main()