"""
|
||
|
|
Landmark Analysis Evaluation Script for Longitudinal Health Prediction Models
|
||
|
|
|
||
|
|
Implements the comprehensive evaluation framework defined in evaluate_design.md:
|
||
|
|
- Landmark analysis at age cutoffs {50, 60, 70}
|
||
|
|
- Prediction horizons {0.25, 0.5, 1, 2, 5, 10} years
|
||
|
|
- Two tracks: Complete-Case (primary) and Clean Control (academic benchmark)
|
||
|
|
- Metrics: AUC, Brier Score, Disease-Capture@K, Lift, Yield, DCA
|
||
|
|
"""

import argparse
import json
import math
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import roc_auc_score, brier_score_loss

# Import model components
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from losses import (
    ExponentialNLLLoss,
    DiscreteTimeCIFNLLLoss,
    PiecewiseExponentialCIFNLLLoss,
)

warnings.filterwarnings('ignore')


def _ensure_dir(path: str) -> str:
    os.makedirs(path, exist_ok=True)
    return path


def _to_float(x):
    try:
        return float(x)
    except Exception:
        return np.nan


def save_summary_json(summary: Dict, output_path: str) -> None:
    """Save a single JSON summary file."""

    def convert_to_serializable(obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.integer,)):
            return int(obj)
        if isinstance(obj, (np.floating,)):
            return float(obj)
        if isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [convert_to_serializable(v) for v in obj]
        return obj

    summary_serializable = convert_to_serializable(summary)
    with open(output_path, 'w') as f:
        json.dump(summary_serializable, f, indent=2)


def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]:
    """Save evaluation results into multiple CSV files.

    Produces long-form tables so they are easy to analyze/plot:
    - landmarks_summary.csv
    - auc_per_disease.csv
    - capture_at_k.csv
    - lift_yield.csv
    - dca.csv

    Returns:
        Mapping from logical name to file path.
    """
    out_dir = _ensure_dir(out_dir)

    summary_rows: List[Dict] = []
    auc_rows: List[Dict] = []
    capture_rows: List[Dict] = []
    lift_rows: List[Dict] = []
    dca_rows: List[Dict] = []

    landmarks = results.get('landmarks', [])
    for lm in landmarks:
        age = lm.get('age_cutoff')
        horizon = lm.get('horizon')
        for track in ['complete_case', 'clean_control']:
            track_res = lm.get(track) or {}
            if not track_res:
                continue

            summary_rows.append({
                'age_cutoff': age,
                'horizon': horizon,
                'track': track,
                'n_patients': track_res.get('n_patients', np.nan),
                'n_valid': track_res.get('n_valid', np.nan),
                'n_valid_patients': track_res.get('n_valid_patients', np.nan),
                'mean_auc': track_res.get('mean_auc', np.nan),
                'brier_score': track_res.get('brier_score', np.nan),
                'brier_skill_score': track_res.get('brier_skill_score', np.nan),
            })

            auc_per_disease = track_res.get('auc_per_disease') or {}
            for disease_idx, auc in auc_per_disease.items():
                auc_rows.append({
                    'age_cutoff': age,
                    'horizon': horizon,
                    'track': track,
                    'disease_idx': int(disease_idx),
                    'auc': _to_float(auc),
                })

            if track == 'complete_case':
                capture = track_res.get('disease_capture_at_k') or {}
                for k_val, per_disease in capture.items():
                    k_int = int(k_val)
                    for disease_idx, rate in (per_disease or {}).items():
                        capture_rows.append({
                            'age_cutoff': age,
                            'horizon': horizon,
                            'track': track,
                            'k': k_int,
                            'disease_idx': int(disease_idx),
                            'capture_rate': _to_float(rate),
                        })

            lift_yield = track_res.get('lift_and_yield') or {}
            if not isinstance(lift_yield, dict):
                lift_yield = {}

            overall = lift_yield.get('overall') or {}
            for frac, metrics in overall.items():
                lift_rows.append({
                    'age_cutoff': age,
                    'horizon': horizon,
                    'track': track,
                    'level': 'overall',
                    'disease_idx': '',
                    'workload_frac': _to_float(frac),
                    'lift': _to_float((metrics or {}).get('lift')),
                    'yield': _to_float((metrics or {}).get('yield')),
                })

            per_disease = lift_yield.get('per_disease') or {}
            for disease_idx, disease_metrics in per_disease.items():
                for frac, metrics in (disease_metrics or {}).items():
                    lift_rows.append({
                        'age_cutoff': age,
                        'horizon': horizon,
                        'track': track,
                        'level': 'per_disease',
                        'disease_idx': int(disease_idx),
                        'workload_frac': _to_float(frac),
                        'lift': _to_float((metrics or {}).get('lift')),
                        'yield': _to_float((metrics or {}).get('yield')),
                    })

            dca = track_res.get('dca') or {}
            thresholds = dca.get('thresholds')
            net_benefit = dca.get('net_benefit')
            if thresholds is not None and net_benefit is not None:
                thresholds_arr = np.asarray(thresholds, dtype=np.float64)
                nb_arr = np.asarray(net_benefit, dtype=np.float64)
                for thr, nb in zip(thresholds_arr, nb_arr):
                    dca_rows.append({
                        'age_cutoff': age,
                        'horizon': horizon,
                        'track': track,
                        'threshold': float(thr),
                        'net_benefit': float(nb),
                    })

    paths: Dict[str, str] = {}

    df_summary = pd.DataFrame(summary_rows)
    summary_path = os.path.join(out_dir, 'landmarks_summary.csv')
    df_summary.to_csv(summary_path, index=False)
    paths['landmarks_summary'] = summary_path

    df_auc = pd.DataFrame(auc_rows)
    auc_path = os.path.join(out_dir, 'auc_per_disease.csv')
    df_auc.to_csv(auc_path, index=False)
    paths['auc_per_disease'] = auc_path

    df_capture = pd.DataFrame(capture_rows)
    capture_path = os.path.join(out_dir, 'capture_at_k.csv')
    df_capture.to_csv(capture_path, index=False)
    paths['capture_at_k'] = capture_path

    df_lift = pd.DataFrame(lift_rows)
    lift_path = os.path.join(out_dir, 'lift_yield.csv')
    df_lift.to_csv(lift_path, index=False)
    paths['lift_yield'] = lift_path

    df_dca = pd.DataFrame(dca_rows)
    dca_path = os.path.join(out_dir, 'dca.csv')
    df_dca.to_csv(dca_path, index=False)
    paths['dca'] = dca_path

    return paths
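
# Sketch: reading the CSV bundle back for analysis ('runs/demo' is hypothetical).
#
#   paths = save_results_csv_bundle(results, 'runs/demo/evaluation_outputs')
#   df = pd.read_csv(paths['landmarks_summary'])
#   cc = df[df['track'] == 'complete_case']
#   auc_grid = cc.pivot_table(index='age_cutoff', columns='horizon',
#                             values='mean_auc')  # AUC across landmarks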


class LandmarkEvaluator:
    """
    Comprehensive landmark analysis evaluator for survival/competing-risks models.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        head: torch.nn.Module,
        loss_fn: torch.nn.Module,
        dataset: HealthDataset,
        device: str = 'cuda',
        batch_size: int = 256,
        num_workers: int = 4,
    ):
        self.model = model.to(device)
        self.model.eval()
        self.head = head.to(device)
        self.head.eval()
        self.loss_fn = loss_fn
        self.dataset = dataset
        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Evaluation parameters from the design doc
        self.age_cutoffs = [50, 60, 70]
        self.horizons = [0.25, 0.5, 1, 2, 5, 10]
        self.top_k_values = [5, 10, 20, 50]
        self.workload_fracs = [0.01, 0.05, 0.10, 0.20, 0.50]

        # Convert ages and horizons to days for comparison with event times
        self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
        self.horizons_days = [h * 365.25 for h in self.horizons]

    @torch.no_grad()
    def compute_risk_scores(
        self,
        indices: List[int],
        age_cutoff_days: float,
        horizon_days: float,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute risk scores for specified patient indices at a given landmark.

        Args:
            indices: Patient indices to evaluate
            age_cutoff_days: Age cutoff in days (T_cut)
            horizon_days: Prediction horizon in days (H)

        Returns:
            risk_scores: (N, K) array of risk scores per disease
            t_anchors: (N,) array of anchor times (days from birth)
            valid_mask: (N,) boolean array indicating valid predictions
        """
        subset = Subset(self.dataset, indices)
        loader = DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=self.num_workers,
            pin_memory=(self.device == 'cuda'),
        )

        all_risk_scores = []
        all_t_anchors = []
        all_valid_mask = []

        for batch in loader:
            event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
            event_batch = event_batch.to(self.device)
            time_batch = time_batch.to(self.device)
            cont_batch = cont_batch.to(self.device)
            cate_batch = cate_batch.to(self.device)
            sex_batch = sex_batch.to(self.device)

            # Find the anchor point: the last valid record before age_cutoff.
            # Valid records are non-padding events (event >= 1).
            valid_mask = event_batch >= 1  # (B, L)
            before_cutoff = time_batch < age_cutoff_days  # (B, L)
            valid_before = valid_mask & before_cutoff  # (B, L)

            # Find the last valid position for each patient
            batch_size = event_batch.size(0)
            t_anchor = torch.zeros(batch_size, device=self.device)
            anchor_idx = torch.zeros(
                batch_size, dtype=torch.long, device=self.device)
            has_anchor = torch.zeros(
                batch_size, dtype=torch.bool, device=self.device)

            for b in range(batch_size):
                valid_positions = valid_before[b].nonzero(as_tuple=True)[0]
                if len(valid_positions) > 0:
                    last_pos = valid_positions[-1]
                    anchor_idx[b] = last_pos
                    t_anchor[b] = time_batch[b, last_pos]
                    has_anchor[b] = True
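
            # Note: a vectorized alternative to the per-patient loop above
            # would look like this (sketch, not used here):
            #   pos = torch.arange(valid_before.size(1), device=self.device)
            #   masked = torch.where(valid_before, pos, pos.new_full((), -1))
            #   last = masked.max(dim=1).values   # last valid index, -1 if none
            #   has_anchor = last >= 0
            #   anchor_idx = last.clamp(min=0)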

            # Get model predictions at the anchor points
            if has_anchor.any():
                # Forward pass
                hidden = self.model(event_batch, time_batch,
                                    sex_batch, cont_batch, cate_batch)

                # Gather hidden states at the anchor positions
                batch_indices = torch.arange(batch_size, device=self.device)
                hidden_at_anchor = hidden[batch_indices, anchor_idx]  # (B, n_embd)

                # Compute logits using the loaded head; the shape depends on the
                # loss type, e.g. (B, K), (B, K+1, n_bins+1), or (B, K, n_bins).
                logits = self.head(hidden_at_anchor)

                # Compute CIF scores.
                # Time gap from the anchor to the start of the horizon:
                t_start = age_cutoff_days - t_anchor  # (B,)
                # Time gap from the anchor to the end of the horizon:
                t_end = age_cutoff_days + horizon_days - t_anchor  # (B,)

                # Ensure non-negative time gaps
                t_start = torch.clamp(t_start, min=0)
                t_end = torch.clamp(t_end, min=0)

                # Evaluate the CIF at both time points
                cif_start = self._compute_cif(logits, t_start)  # (B, K)
                cif_end = self._compute_cif(logits, t_end)  # (B, K)

                # The risk score is the CIF increment within the horizon
                risk_scores = cif_end - cif_start  # (B, K)
                risk_scores = torch.clamp(risk_scores, min=0)  # ensure non-negative

            else:
                # No valid anchor points in this batch
                risk_scores = torch.zeros(
                    batch_size, self.dataset.n_disease, device=self.device)

            all_risk_scores.append(risk_scores.cpu().numpy())
            all_t_anchors.append(t_anchor.cpu().numpy())
            all_valid_mask.append(has_anchor.cpu().numpy())

        # Concatenate results
        risk_scores = np.vstack(all_risk_scores)  # (N, K)
        t_anchors = np.concatenate(all_t_anchors)  # (N,)
        valid_mask = np.concatenate(all_valid_mask)  # (N,)

        return risk_scores, t_anchors, valid_mask

    def _compute_cif(self, logits: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the Cumulative Incidence Function (CIF) at time t.

        Args:
            logits: Model output logits, (B, K, ...) depending on the loss type
            t: Time points (B,) in days from the anchor

        Returns:
            cif: (B, K) cumulative incidence probabilities
        """
        t_years = t / 365.25  # convert days to years

        if isinstance(self.loss_fn, ExponentialNLLLoss):
            # Exponential: logits are (B, K)
            lambdas = F.softplus(logits) + 1e-6  # (B, K)
            total_lambda = lambdas.sum(dim=-1, keepdim=True)  # (B, 1)

            # CIF_k(t) = (λ_k / Σλ) * (1 - exp(-Σλ * t))
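            # Derivation (standard competing-risks identity): with constant
            # cause-specific hazards λ_k, the overall survival is
            # S(t) = exp(-Σλ·t), and CIF_k(t) = ∫₀ᵗ λ_k·S(u) du, which
            # integrates to the closed form above.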
            frac = lambdas / total_lambda  # (B, K)
            exp_term = 1.0 - torch.exp(-total_lambda * t_years.unsqueeze(-1))  # (B, 1)
            cif = frac * exp_term  # (B, K)

        elif isinstance(self.loss_fn, (DiscreteTimeCIFNLLLoss,
                                       PiecewiseExponentialCIFNLLLoss)):
            # Discrete-time and piecewise-exponential CIF losses both expose
            # a calculate_cifs method.
            cif = self.loss_fn.calculate_cifs(
                logits, t_years, return_survival=False)

        else:
            raise ValueError(f"Unknown loss type: {type(self.loss_fn)}")

        return cif

    def prepare_evaluation_cohort(
        self,
        age_cutoff_days: float,
        horizon_days: float,
        track: str = 'complete_case',
    ) -> Tuple[List[int], np.ndarray, np.ndarray]:
        """Prepare the evaluation cohort per the design protocol with per-disease validity.

        Key fixes vs the earlier implementation:
        - Multi-disease labeling: mark *all* diseases that occur within the horizon.
        - Per-disease validity mask: mask[i, k]=1 if patient i is valid for disease k.

        Returns:
            indices: list of patient indices included in this cohort table
            labels: (N, K) float array
            valid_mask: (N, K) bool array
        """
        n_tech_tokens = 2  # PAD=0, DOA=1
        K = int(self.dataset.n_disease)

        # Competing-risk token: treat death as a competing event.
        # As requested: DEATH_CODE is the last disease index.
        DEATH_CODE = int(K - 1)

        horizon_end_days = age_cutoff_days + horizon_days
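
        # Worked example (hypothetical patient): cutoff at age 60 with a
        # 5-year horizon. A diagnosis of disease k at age 62 -> label[k]=1.
        # A patient whose last record is at age 63 with no diagnosis of k is
        # censored before the horizon end, so in the complete-case track they
        # are not a valid negative for k.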

        indices: List[int] = []
        labels_rows: List[np.ndarray] = []
        valid_rows: List[np.ndarray] = []

        for idx in range(len(self.dataset)):
            patient_id = self.dataset.patient_ids[idx]
            records = self.dataset.patient_events.get(patient_id, [])
            if not records:
                continue

            # Must have some information strictly prior to the cutoff
            # (i.e., an anchor exists in the data).
            has_pre_cutoff = any(t < age_cutoff_days for t, _ in records)
            if not has_pre_cutoff:
                continue

            # Events after the cutoff (already sorted in dataset init)
            events_after = [(t, e) for t, e in records if t >= age_cutoff_days]

            labels = np.zeros(K, dtype=np.float32)
            valid = np.zeros(K, dtype=bool)

            # Identify diseases within the horizon (multi-label; no early break)
            diseases_in_horizon: set[int] = set()
            death_in_horizon = False
            for t, e in events_after:
                if t > horizon_end_days:
                    break  # events are time-sorted
                if e < n_tech_tokens:
                    continue
                disease_idx = int(e - n_tech_tokens)
                if 0 <= disease_idx < K:
                    diseases_in_horizon.add(disease_idx)
                    if disease_idx == DEATH_CODE:
                        death_in_horizon = True

            for d in diseases_in_horizon:
                labels[d] = 1.0

            last_time = float(records[-1][0])

            if track == 'complete_case':
                # Track A: Complete-Case at Horizon
                # - Hit: disease occurs within horizon => valid positive for that disease
                # - Healthy: last record > horizon_end => valid negative for all diseases
                # - Death: death within horizon => valid negative for diseases not observed
                #   before death (implemented as valid for all diseases, with labels
                #   marking any hits, including death)
                # - Loss: censored within horizon (no hit for disease) => invalid for that disease

                if last_time > horizon_end_days:
                    # Observed past the horizon: negative for all non-hit diseases
                    valid[:] = True
                elif death_in_horizon:
                    # Competing risk: include as explicit negatives (label 0) for
                    # diseases not hit, and positives for diseases hit before death
                    # (ordering is not modeled here). This matches the requirement
                    # that death within the window is label=0 for target diseases.
                    valid[:] = True
                else:
                    # Lost within the horizon: only diseases that actually occurred
                    # in the horizon are valid positives (no assumptions about
                    # negatives).
                    for d in diseases_in_horizon:
                        valid[d] = True

            elif track == 'clean_control':
                # Track B: Clean Control (per-disease)
                # For each disease k:
                # - Hit: disease k occurs within horizon => label=1, valid
                # - Pure clean for k: disease k never occurs in the entire record => label=0, valid
                # - Death within window => drop (invalid) for all k
                # - Loss within window => drop (invalid) for all k
                # - Late onset of k (after horizon) => drop (invalid) for k

                if death_in_horizon:
                    # Drop all diseases for this patient
                    valid[:] = False
                else:
                    # Precompute the per-disease first occurrence time after the
                    # cutoff (including occurrences after the horizon)
                    first_occ_after_cutoff = np.full(K, np.inf, dtype=np.float64)
                    ever_has_disease = np.zeros(K, dtype=bool)
                    for t, e in records:
                        if e < n_tech_tokens:
                            continue
                        disease_idx = int(e - n_tech_tokens)
                        if 0 <= disease_idx < K:
                            ever_has_disease[disease_idx] = True
                            if t >= age_cutoff_days and t < first_occ_after_cutoff[disease_idx]:
                                first_occ_after_cutoff[disease_idx] = float(t)

                    # If lost within the horizon and disease k was not hit in the
                    # horizon, the patient is invalid for k.
                    lost_within_horizon = last_time <= horizon_end_days

                    for k in range(K):
                        if labels[k] == 1.0:
                            valid[k] = True
                            continue

                        if not ever_has_disease[k]:
                            # Lifetime clean for k. The clean-control track still
                            # requires complete follow-up within the horizon: if
                            # censored within the horizon, we cannot be sure k did
                            # not occur in the window.
                            valid[k] = not lost_within_horizon
                            continue

                        # Patient has disease k at some point, but not within the
                        # horizon: either prevalent before the cutoff or late onset
                        # after the horizon. Both cases are invalid for k.
                        valid[k] = False

            else:
                raise ValueError(f"Unknown track: {track}")

            # Keep the patient row if they are valid for at least one disease.
            if valid.any():
                indices.append(idx)
                labels_rows.append(labels)
                valid_rows.append(valid)

        if not indices:
            return [], np.zeros((0, K), dtype=np.float32), np.zeros((0, K), dtype=bool)

        return indices, np.stack(labels_rows, axis=0), np.stack(valid_rows, axis=0)

    def compute_auc_per_disease(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[int, float]:
        """
        Compute time-dependent AUC for each disease.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask for valid evaluations

        Returns:
            auc_scores: Dict mapping disease_idx -> AUC
        """
        auc_scores = {}
        n_diseases = risk_scores.shape[1]

        for k in range(n_diseases):
            mk = valid_mask[:, k]
            if not np.any(mk):
                auc_scores[k] = np.nan
                continue

            y_true = labels[mk, k]
            y_score = risk_scores[mk, k]

            # AUC is undefined unless both classes are present
            if len(np.unique(y_true)) < 2:
                auc_scores[k] = np.nan
            else:
                try:
                    auc_scores[k] = roc_auc_score(y_true, y_score)
                except Exception:
                    auc_scores[k] = np.nan

        return auc_scores

    def compute_brier_score(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[str, float]:
        """
        Compute the Brier Score and Brier Skill Score.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            metrics: Dict with 'brier_score' and 'brier_skill_score'
        """
        # Apply the per-entry valid mask and flatten
        m = valid_mask.astype(bool)
        y_true_flat = labels[m]
        y_pred_flat = risk_scores[m]

        mask = ~(np.isnan(y_true_flat) | np.isnan(y_pred_flat))
        y_true_flat = y_true_flat[mask]
        y_pred_flat = y_pred_flat[mask]

        if len(y_true_flat) == 0:
            return {'brier_score': np.nan, 'brier_skill_score': np.nan}

        bs = brier_score_loss(y_true_flat, y_pred_flat)

        # Brier Skill Score: the reference model predicts the base rate
        p_mean = y_true_flat.mean()
        bs_ref = ((p_mean - y_true_flat) ** 2).mean()
        bss = 1.0 - (bs / bs_ref) if bs_ref > 0 else 0.0
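
        # Worked example (hypothetical numbers): with base rate p_mean = 0.1,
        # bs_ref = 0.1 * 0.81 + 0.9 * 0.01 = 0.09; a model Brier score of
        # bs = 0.08 then gives BSS = 1 - 0.08 / 0.09 ≈ 0.11.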

        return {'brier_score': bs, 'brier_skill_score': bss}

    def compute_disease_capture_at_k(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[int, Dict[int, float]]:
        """
        Compute Disease-Capture@K: the fraction of true positives for which the
        true disease appears in the patient's top-K predicted risks.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            capture_rates: Dict[K_value][disease_idx] -> capture rate
        """
        capture_rates = {k: {} for k in self.top_k_values}

        n_diseases = risk_scores.shape[1]

        for disease_idx in range(n_diseases):
            mk = valid_mask[:, disease_idx]
            if not np.any(mk):
                for k in self.top_k_values:
                    capture_rates[k][disease_idx] = np.nan
                continue

            y_true = labels[mk, disease_idx]
            y_scores = risk_scores[mk]  # (N_valid_k, K)

            # Find patients with a positive label for this disease
            pos_mask = y_true == 1
            if pos_mask.sum() == 0:
                for k in self.top_k_values:
                    capture_rates[k][disease_idx] = np.nan
                continue

            # For each positive patient, check whether the true disease is in the top-K
            for k_val in self.top_k_values:
                captures = []
                for i in np.where(pos_mask)[0]:
                    # Top-K disease indices for this patient
                    top_k_diseases = np.argsort(y_scores[i])[::-1][:k_val]
                    # Check whether the true disease is among them
                    is_captured = disease_idx in top_k_diseases
                    captures.append(int(is_captured))

                capture_rate = np.mean(captures) if captures else np.nan
                capture_rates[k_val][disease_idx] = capture_rate

        return capture_rates

    def compute_lift_and_yield(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
    ) -> Dict[str, Dict]:
        """
        Compute Lift and Yield at various workload fractions.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask

        Returns:
            metrics:
            {
                'overall': {workload_frac: {'lift': ..., 'yield': ...}, ...},
                'per_disease': {disease_idx: {workload_frac: {'lift': ..., 'yield': ...}, ...}, ...}
            }
        """
        # The overall metric uses a patient-level mask to avoid including purely
        # censored negatives. We include patients who either (a) have known
        # outcomes for all diseases, or (b) have at least one hit.
        has_any_hit = labels.max(axis=1) > 0
        has_all_known = valid_mask.all(axis=1)
        overall_patient_mask = has_any_hit | has_all_known

        risk_scores_overall = risk_scores[overall_patient_mask]
        labels_overall = labels[overall_patient_mask]

        # Flatten to the patient level ("any disease event"), scoring each
        # patient by their maximum risk across all diseases
        max_risk_per_patient = risk_scores_overall.max(axis=1)
        any_disease_label = labels_overall.max(axis=1)

        n_patients = len(max_risk_per_patient)
        base_rate = any_disease_label.mean() if n_patients > 0 else 0.0
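
        # Definitions: yield = event rate among the screened top fraction;
        # lift = yield / base_rate. Worked example (hypothetical numbers):
        # a 2% base rate and a 10% yield in the top-5% workload gives lift = 5.0.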

        overall: Dict[float, Dict[str, float]] = {}
        for workload_frac in self.workload_fracs:
            n_screen = max(1, int(n_patients * workload_frac))
            top_n_idx = np.argsort(max_risk_per_patient)[::-1][:n_screen]
            top_n_labels = any_disease_label[top_n_idx]
            yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
            lift_val = (yield_val / float(base_rate)) if base_rate > 0 else 0.0
            overall[workload_frac] = {'lift': lift_val, 'yield': yield_val}

        per_disease: Dict[int, Dict[float, Dict[str, float]]] = {}
        n_diseases = risk_scores.shape[1]
        for disease_idx in range(n_diseases):
            mk = valid_mask[:, disease_idx]
            if not np.any(mk):
                per_disease[disease_idx] = {
                    frac: {'lift': np.nan, 'yield': np.nan}
                    for frac in self.workload_fracs}
                continue

            disease_scores = risk_scores[mk, disease_idx]
            disease_labels = labels[mk, disease_idx]
            disease_base_rate = disease_labels.mean() if disease_labels.size > 0 else 0.0

            n_patients_k = disease_scores.shape[0]

            disease_metrics: Dict[float, Dict[str, float]] = {}
            for workload_frac in self.workload_fracs:
                n_screen = max(1, int(n_patients_k * workload_frac))
                top_n_idx = np.argsort(disease_scores)[::-1][:n_screen]
                top_n_labels = disease_labels[top_n_idx]
                yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan
                lift_val = (yield_val / float(disease_base_rate)
                            if disease_base_rate > 0 else 0.0)
                disease_metrics[workload_frac] = {
                    'lift': lift_val, 'yield': yield_val}

            per_disease[disease_idx] = disease_metrics

        return {
            'overall': overall,
            'per_disease': per_disease,
        }

    def compute_dca_net_benefit(
        self,
        risk_scores: np.ndarray,
        labels: np.ndarray,
        valid_mask: np.ndarray,
        threshold_range: Optional[np.ndarray] = None,
    ) -> Dict[str, np.ndarray]:
        """
        Compute Decision Curve Analysis (DCA) net benefit.

        Args:
            risk_scores: (N, K) risk scores
            labels: (N, K) binary labels
            valid_mask: (N, K) boolean mask
            threshold_range: Array of threshold probabilities
                (default: np.linspace(0, 0.5, 51))

        Returns:
            dca_results: Dict with 'thresholds' and 'net_benefit' arrays
        """
        if threshold_range is None:
            threshold_range = np.linspace(0, 0.5, 51)

        # Use the same overall patient mask as lift/yield (complete-case style)
        has_any_hit = labels.max(axis=1) > 0
        has_all_known = valid_mask.all(axis=1)
        patient_mask = has_any_hit | has_all_known

        risk_scores = risk_scores[patient_mask]
        labels = labels[patient_mask]

        # Use the max risk across diseases and the "any disease" label
        max_risk = risk_scores.max(axis=1)
        any_disease = labels.max(axis=1)

        n = len(max_risk)
        net_benefits = []

        for pt in threshold_range:
            if pt == 0:
                # Treat all
                nb = any_disease.mean()
            else:
                # Treat if predicted risk >= threshold
                treat = max_risk >= pt
                tp = (treat & (any_disease == 1)).sum()
                fp = (treat & (any_disease == 0)).sum()

                # Net benefit = (TP/N) - (FP/N) * (pt / (1 - pt))
                nb = (tp / n) - (fp / n) * (pt / (1 - pt))
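
                # Worked example (hypothetical numbers): with N=1000, pt=0.1,
                # TP=50 and FP=100, nb = 0.05 - 0.1 * (0.1 / 0.9) ≈ 0.039.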

            net_benefits.append(nb)

        return {
            'thresholds': threshold_range,
            'net_benefit': np.array(net_benefits),
        }

    def evaluate_landmark(
        self,
        age_cutoff: float,
        horizon: float,
    ) -> Dict:
        """
        Evaluate the model at a specific landmark (age_cutoff, horizon).

        Args:
            age_cutoff: Age cutoff in years
            horizon: Prediction horizon in years

        Returns:
            results: Dictionary with all metrics
        """
        age_cutoff_days = age_cutoff * 365.25
        horizon_days = horizon * 365.25

        print(f"\nEvaluating Landmark: Age={age_cutoff}, Horizon={horizon}y")

        results = {
            'age_cutoff': age_cutoff,
            'horizon': horizon,
            'complete_case': {},
            'clean_control': {},
        }

        for track in ['complete_case', 'clean_control']:
            print(f"  Track: {track}")

            # Prepare the cohort
            indices, labels_array, valid_mask = self.prepare_evaluation_cohort(
                age_cutoff_days, horizon_days, track
            )

            if len(indices) == 0:
                print(f"    No valid patients for track {track}")
                continue

            print(f"    Cohort size: {len(indices)}")

            # Compute risk scores
            risk_scores, t_anchors, anchor_mask = self.compute_risk_scores(
                indices, age_cutoff_days, horizon_days
            )

            # Combine anchor availability with per-disease validity
            valid_mask = valid_mask & anchor_mask.astype(bool)[:, None]

            # Compute metrics
            print("    Computing AUC...")
            auc_scores = self.compute_auc_per_disease(
                risk_scores, labels_array, valid_mask)
            mean_auc = np.nanmean(list(auc_scores.values()))

            print("    Computing Brier Score...")
            brier_metrics = self.compute_brier_score(
                risk_scores, labels_array, valid_mask)

            # Patient-level and population metrics are only computed for complete_case
            if track == 'complete_case':
                print("    Computing Disease-Capture@K...")
                capture_metrics = self.compute_disease_capture_at_k(
                    risk_scores, labels_array, valid_mask
                )

                print("    Computing Lift & Yield...")
                lift_yield_metrics = self.compute_lift_and_yield(
                    risk_scores, labels_array, valid_mask
                )

                print("    Computing DCA...")
                dca_metrics = self.compute_dca_net_benefit(
                    risk_scores, labels_array, valid_mask
                )

                results[track] = {
                    'n_patients': len(indices),
                    'n_valid': int(valid_mask.sum()),
                    'n_valid_patients': int(valid_mask.any(axis=1).sum()),
                    'auc_per_disease': auc_scores,
                    'mean_auc': mean_auc,
                    'brier_score': brier_metrics['brier_score'],
                    'brier_skill_score': brier_metrics['brier_skill_score'],
                    'disease_capture_at_k': capture_metrics,
                    'lift_and_yield': lift_yield_metrics,
                    'dca': dca_metrics,
                }
            else:
                # Clean-control track: only discrimination metrics
                results[track] = {
                    'n_patients': len(indices),
                    'n_valid': int(valid_mask.sum()),
                    'n_valid_patients': int(valid_mask.any(axis=1).sum()),
                    'auc_per_disease': auc_scores,
                    'mean_auc': mean_auc,
                }

        return results

    def run_full_evaluation(self) -> Dict:
        """
        Run the complete landmark analysis across all cutoffs and horizons.

        Returns:
            all_results: Nested dictionary with all evaluation results
        """
        all_results = {
            'age_cutoffs': self.age_cutoffs,
            'horizons': self.horizons,
            'landmarks': [],
        }

        # Evaluate each landmark
        for age_cutoff in self.age_cutoffs:
            for horizon in self.horizons:
                landmark_results = self.evaluate_landmark(age_cutoff, horizon)
                all_results['landmarks'].append(landmark_results)

        return all_results
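

# Minimal usage sketch (assumes a loaded model/head/loss_fn/dataset; see
# load_model_and_config below; the output path is hypothetical):
#
#   evaluator = LandmarkEvaluator(model, head, loss_fn, dataset, device='cuda')
#   results = evaluator.run_full_evaluation()
#   save_results_csv_bundle(results, 'evaluation_outputs')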


def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
    """
    Load the trained model and configuration from a run directory.

    Args:
        run_dir: Path to the run directory containing train_config.json and best_model.pt
        device: Device to load the model on

    Returns:
        model, head, loss_fn, dataset, config
    """
    run_path = Path(run_dir)

    # Load config
    config_path = run_path / 'train_config.json'
    with open(config_path, 'r') as f:
        config = json.load(f)

    print(f"Loading model from {run_dir}")
    print(f"Model type: {config['model_type']}")
    print(f"Loss type: {config['loss_type']}")

    # Load the dataset to get dimensions
    data_prefix = config['data_prefix']

    # Determine the covariate list based on full_cov
    if config['full_cov']:
        covariate_list = None  # use all covariates
    else:
        # Use partial covariates (define your partial list here)
        covariate_list = ['age_at_assessment',
                          'bmi', 'smoking_status']  # example

    dataset = HealthDataset(
        data_prefix=f"{data_prefix}_test",
        covariate_list=covariate_list,
        cache_event_tensors=True,
    )

    # Determine output dimensions based on the loss type
    if config['loss_type'] == 'exponential':
        out_dims = [dataset.n_disease]
    elif config['loss_type'] == 'discrete_time_cif':
        # logits shape (M, K+1, n_bins+1)
        bin_edges = config.get(
            'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')])
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif config['loss_type'] == 'pwe_cif':
        # Piecewise-exponential requires finite edges
        bin_edges = config.get(
            'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0])
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        n_bins = len(pwe_edges) - 1
        # logits shape (M, K, n_bins)
        out_dims = [dataset.n_disease, n_bins]
    else:
        raise ValueError(f"Unknown loss type: {config['loss_type']}")

    # Build the model
    if config['model_type'] == 'delphi_fork':
        model = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,  # PAD=0, DOA=1
            n_embd=config['n_embd'],
            n_head=config['n_head'],
            n_layer=config['n_layer'],
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            age_encoder_type=config['age_encoder'],
            pdrop=config['pdrop'],
        )
    elif config['model_type'] == 'sap_delphi':
        model = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=config['n_embd'],
            n_head=config['n_head'],
            n_layer=config['n_layer'],
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            age_encoder_type=config['age_encoder'],
            pdrop=config['pdrop'],
            pretrained_weights_path=config.get('pretrained_emd_path'),
        )
    else:
        raise ValueError(f"Unknown model type: {config['model_type']}")

    # Build the head
    head = SimpleHead(
        n_embd=config['n_embd'],
        out_dims=out_dims,
    )

    # Load model weights (the checkpoint contains model and head state dicts)
    model_path = run_path / 'best_model.pt'
    checkpoint = torch.load(model_path, map_location=device)

    # The checkpoint is a dict with 'model_state_dict' and 'head_state_dict'
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        head.load_state_dict(checkpoint['head_state_dict'])
        print("Loaded model and head from checkpoint")
    else:
        raise ValueError(
            "Checkpoint format not recognized. Expected 'model_state_dict' "
            "and 'head_state_dict' keys.")

    model.to(device)
    head.to(device)

    # Build the loss function
    if config['loss_type'] == 'exponential':
        loss_fn = ExponentialNLLLoss(
            lambda_reg=config.get('lambda_reg', 0.0)
        )
    elif config['loss_type'] == 'discrete_time_cif':
        loss_fn = DiscreteTimeCIFNLLLoss(
            bin_edges=config.get(
                'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]),
            lambda_reg=config.get('lambda_reg', 0.0),
        )
    elif config['loss_type'] == 'pwe_cif':
        loss_fn = PiecewiseExponentialCIFNLLLoss(
            bin_edges=config.get(
                'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]),
            lambda_reg=config.get('lambda_reg', 0.0),
        )
    else:
        raise ValueError(f"Unknown loss type: {config['loss_type']}")

    return model, head, loss_fn, dataset, config
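

# Keys read from train_config.json above: model_type ('delphi_fork' |
# 'sap_delphi'), loss_type ('exponential' | 'discrete_time_cif' | 'pwe_cif'),
# data_prefix, full_cov, n_embd, n_head, n_layer, age_encoder, pdrop, and
# optionally bin_edges, lambda_reg, pretrained_emd_path.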


def print_summary(results: Dict):
    """
    Print a summary of the evaluation results.

    Args:
        results: Results dictionary
    """
    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY")
    print("=" * 80)

    for landmark in results['landmarks']:
        age = landmark['age_cutoff']
        horizon = landmark['horizon']

        print(f"\nLandmark: Age {age}, Horizon {horizon}y")
        print("-" * 40)

        # Complete-case results
        if 'complete_case' in landmark and landmark['complete_case']:
            cc = landmark['complete_case']
            print("  Complete-Case Track:")
            print(f"    Patients: {cc['n_patients']}")
            print(f"    Mean AUC: {cc['mean_auc']:.4f}")
            print(f"    Brier Score: {cc['brier_score']:.4f}")
            print(f"    Brier Skill Score: {cc['brier_skill_score']:.4f}")

            # Show top-K capture rates (averaged across diseases)
            if 'disease_capture_at_k' in cc:
                print("    Disease Capture:")
                for k in [5, 10, 20, 50]:
                    if k in cc['disease_capture_at_k']:
                        rates = list(cc['disease_capture_at_k'][k].values())
                        mean_rate = np.nanmean(rates)
                        print(f"      Top-{k}: {mean_rate:.3f}")

            # Show lift and yield
            if 'lift_and_yield' in cc:
                print("    Lift & Yield:")
                overall = (cc['lift_and_yield'].get('overall', {})
                           if isinstance(cc['lift_and_yield'], dict) else {})
                for frac in [0.01, 0.05, 0.10]:
                    if frac in overall:
                        lift = overall[frac].get('lift', np.nan)
                        yield_val = overall[frac].get('yield', np.nan)
                        print(f"      Overall Top {int(frac * 100)}%: "
                              f"Lift={lift:.2f}, Yield={yield_val:.3f}")

        # Clean-control results
        if 'clean_control' in landmark and landmark['clean_control']:
            clean = landmark['clean_control']
            print("  Clean-Control Track:")
            print(f"    Patients: {clean['n_patients']}")
            print(f"    Mean AUC: {clean['mean_auc']:.4f}")

    print("\n" + "=" * 80)


def main():
    parser = argparse.ArgumentParser(
        description='Evaluate a longitudinal health prediction model using landmark analysis'
    )
    parser.add_argument(
        '--run_dir',
        type=str,
        required=True,
        help='Path to the run directory containing train_config.json and best_model.pt'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Output path for the summary JSON (default: <run_dir>/evaluation_summary.json)'
    )
    parser.add_argument(
        '--out_dir',
        type=str,
        default=None,
        help='Directory to write CSV outputs (default: <run_dir>/evaluation_outputs)'
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help='Device to run the evaluation on'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=256,
        help='Batch size for evaluation'
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of data loader workers'
    )

    args = parser.parse_args()

    # Load the model and dataset
    model, head, loss_fn, dataset, config = load_model_and_config(
        args.run_dir, args.device)

    # Create the evaluator
    evaluator = LandmarkEvaluator(
        model=model,
        head=head,
        loss_fn=loss_fn,
        dataset=dataset,
        device=args.device,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Run the evaluation
    print("\nStarting landmark analysis evaluation...")
    print(f"Age cutoffs: {evaluator.age_cutoffs}")
    print(f"Horizons: {evaluator.horizons}")

    results = evaluator.run_full_evaluation()

    # Add metadata
    results['metadata'] = {
        'run_dir': args.run_dir,
        'config': config,
        'n_diseases': dataset.n_disease,
        'device': args.device,
    }

    # Save results (CSV bundle + a single JSON summary)
    if args.out_dir is None:
        args.out_dir = os.path.join(args.run_dir, 'evaluation_outputs')
    csv_paths = save_results_csv_bundle(results, args.out_dir)

    summary = {
        'metadata': results.get('metadata', {}),
        'age_cutoffs': results.get('age_cutoffs', []),
        'horizons': results.get('horizons', []),
        'csv_outputs': csv_paths,
        'notes': {
            'metrics': [
                'AUC (per disease + mean)',
                'Brier Score / Brier Skill Score (complete-case only)',
                'Disease-Capture@K (complete-case only)',
                'Lift/Yield (complete-case only)',
                'Decision Curve Analysis (complete-case only)',
            ],
        },
    }

    if args.output is None:
        args.output = os.path.join(args.run_dir, 'evaluation_summary.json')
    save_summary_json(summary, args.output)

    print(f"\nWrote CSV outputs to: {args.out_dir}")
    print(f"Wrote summary JSON to: {args.output}")

    # Print the summary
    print_summary(results)


if __name__ == '__main__':
    main()