import argparse
import csv
import json
import math
import os
import random
import statistics
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi, SimpleHead

# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
DAYS_PER_YEAR = 365.25
DEFAULT_DEATH_CAUSE_ID = 1256
# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
    name: str
    model_type: str  # delphi_fork | sap_delphi
    loss_type: str  # exponential | discrete_time_cif
    full_cov: bool
    checkpoint_path: str


# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    if s in {"true", "1", "yes", "y"}:
        return True
    if s in {"false", "0", "no", "n"}:
        return False
    raise ValueError(f"Cannot parse boolean: {x!r}")
def load_models_json(path: str) -> List[ModelSpec]:
    with open(path, "r") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")
    specs: List[ModelSpec] = []
    for row in data:
        specs.append(
            ModelSpec(
                name=str(row["name"]),
                model_type=str(row["model_type"]),
                loss_type=str(row["loss_type"]),
                full_cov=_parse_bool(row["full_cov"]),
                checkpoint_path=str(row["checkpoint_path"]),
            )
        )
    return specs


def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r") as f:
        cfg = json.load(f)
    return cfg
def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    if split == "train":
        return train_ds
    if split == "val":
        return val_ds
    if split == "test":
        return test_ds
    if split == "all":
        return dataset
    raise ValueError("split must be one of: train, val, test, all")
# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select the per-sample prediction context index.

    IMPORTANT SEMANTICS:
      - The last observed token time is treated as the FOLLOW-UP END time.
      - We pick the last valid token with time <= (followup_end_time - offset).
      - We do NOT interpret followup_end_time as an event time.

    Returns:
      keep_mask: (B,) bool, which samples have a valid context
      t_ctx: (B,) long, index into the sequence
      t_ctx_time: (B,) float, time (days) at the context
    """
    # Valid tokens are event != 0 (padding is 0).
    valid = event_seq != 0
    lengths = valid.sum(dim=1)
    last_idx = torch.clamp(lengths - 1, min=0)
    b = torch.arange(event_seq.size(0), device=event_seq.device)
    followup_end_time = time_seq[b, last_idx]
    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    eligible_counts = eligible.sum(dim=1)
    keep = eligible_counts > 0
    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[b, t_ctx]
    return keep, t_ctx, t_ctx_time
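

# Illustrative sketch of the offset semantics (hypothetical values, not executed):
# for a time-sorted sequence with token times [0, 400, 900, 1200] days and
# offset_years=0.5 (~183 days), the follow-up end is day 1200 and the cutoff is
# ~day 1017, so the context index is 2 (day 900); tokens after the cutoff are used
# only as labels, never as model input.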
def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the next disease event after the context.

    Returns:
      dt_years: (B,) float, time to the next disease in years; +inf if none
      cause: (B,) long, disease id in [0, K) for the next event; -1 if none
    """
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    # Allow same-day events while excluding the context token itself.
    # We rely on time-sorted sequences and select the FIRST valid future event by index.
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
    idx_min = torch.where(future, idxs, torch.full_like(idxs, L)).min(dim=1).values
    has = idx_min < L
    t_next = torch.where(has, idx_min, torch.zeros_like(idx_min))
    t_next_time = time_seq[b, t_next]
    dt_days = t_next_time - t0
    dt_years = dt_days / DAYS_PER_YEAR
    dt_years = torch.where(has, dt_years, torch.full_like(dt_years, float("inf")))
    cause_token = event_seq[b, t_next]
    cause = (cause_token - 2).to(torch.long)
    cause = torch.where(has, cause, torch.full_like(cause, -1))
    return dt_years, cause
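

# Token-layout assumption used by the event/label helpers in this module: token 0
# is padding and tokens below 2 appear to be technical (non-disease) tokens
# (n_tech_tokens=2 in the model constructors), so disease k (0-based) is encoded
# as token k + 2; hence the `event_seq >= 2` filters and the `- 2` shift back to
# disease ids.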
def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after the context (any occurrence)."""
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    # Include same-day events after the context; exclude any token at/before the context index.
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )
    if not in_window.any():
        return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    y[b_idx, disease_ids] = True
    return y
def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs ANYTIME after the prediction context.

    This is Delphi2M-compatible for the Task A case/control definition.
    Same-day events are included as long as they occur after the context token index.
    """
    B, L = event_seq.shape
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    if not future.any():
        return y
    b_idx, t_idx = future.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y[b_idx, disease_ids] = True
    return y
def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after the context?"""
    B, L = event_seq.shape
    device = event_seq.device
    b = torch.arange(B, device=device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)
    idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )
    out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
    if not in_window.any():
        return out
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    # Filter to selected causes via a boolean membership mask over the global disease space.
    selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
    selected[cause_ids] = True
    keep = selected[disease_ids]
    if not keep.any():
        return out
    b_idx = b_idx[keep]
    disease_ids = disease_ids[keep]
    # Map disease_id -> local index in cause_ids.
    # Build a lookup table (global disease space) where lookup[disease_id] = local_index.
    lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
    lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
    local = lookup[disease_ids]
    out[b_idx, local] = True
    return out
# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential cause-specific logits -> CIFs at taus.

    logits: (B, K)
    returns: (B, K, H), or (cif, survival) if return_survival
    """
    hazards = F.softplus(logits) + eps
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)
    taus_t = torch.tensor(list(taus), device=logits.device, dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B,1,1)
    # 1 - exp(-Lambda * tau)
    one_minus_surv = 1.0 - torch.exp(-total_h * taus_t)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B,K,H)
    # If total == 0, set the CIF to 0.
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))
    if not return_survival:
        return cif
    survival = torch.exp(-total_h * taus_t).squeeze(1)  # (B,H)
    # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival
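

# Under the competing-risks exponential model above, with cause-specific hazards
# lambda_k and total hazard Lambda = sum_k lambda_k, the closed form is
#   CIF_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau)),  S(tau) = exp(-Lambda * tau),
# so sum_k CIF_k(tau) = 1 - S(tau) exactly. Illustrative numbers (hypothetical):
# lambda = (0.10, 0.30), tau = 1 -> Lambda = 0.4, 1 - exp(-0.4) ~ 0.3297,
# CIF ~ (0.0824, 0.2473), S ~ 0.6703.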
def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert discrete-time CIF logits -> CIFs at taus.

    logits: (B, K+1, n_bins+1)
    bin_edges: len = n_bins + 1 (including 0 and inf)
    taus: subset of the finite bin edges (recommended)
    returns: (B, K, H), or (cif, survival) if return_survival
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")
    B, K_plus_1, n_bins_plus_1 = logits.shape
    K = K_plus_1 - 1
    edges = [float(x) for x in bin_edges]
    # Drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf.
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)
    if n_bins_plus_1 != len(edges):
        raise ValueError("logits last dim must match len(bin_edges)")
    probs = torch.softmax(logits, dim=1)  # (B, K+1, n_bins+1)
    # Use bins 1..n_bins (ignore bin 0 and the +inf bin edge slot).
    hazards = probs[:, :K, 1:1 + n_bins]  # (B,K,n_bins)
    p_comp = probs[:, K, 1:1 + n_bins]  # (B,n_bins)
    # Survival before each bin: S_prev[0] = 1, S_prev[u] = prod_{v<u} p_comp[v].
    ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
    cum = torch.cumprod(p_comp, dim=1)
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # (B, n_bins)
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * hazards, dim=2)  # (B,K,n_bins)
    # Robust mapping from tau -> edge index (floating-point safe).
    # taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        j = int(np.argmin(diffs))
        if diffs[j] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
            )
        tau_to_idx.append(j)
    idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,K,H)
    if not return_survival:
        return cif
    # Survival at each horizon = prod_{u <= idx[h]} p_comp[u].
    survival_bins = cum  # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
    survival = survival_bins.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival
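

# The construction above follows the standard discrete-time competing-risks
# recursion: with per-bin cause probabilities h_k[u] and "no event" probability
# p_comp[u] from the softmax over causes, survival to the start of bin u is
# S_prev[u] = prod_{v<u} p_comp[v], and the CIF at edge j is sum_{u<=j} S_prev[u] * h_k[u].
# Because each bin's softmax column sums to 1 (sum_k h_k[u] = 1 - p_comp[u]),
# sum_k CIF_k + S telescopes to 1 exactly; the conservation check below verifies
# this numerically.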
# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on CIF arrays.

    Args:
      cause_cif: (N, K, H)
      horizons: length H
      tol: tolerance for inequalities
      name: model name for messages
      strict: if True, raise ValueError on the first failure
      survival: optional (N, H) survival values at the same horizons

    Returns:
      (integrity_ok, notes)
    """
    notes: List[str] = []
    model_tag = f"[{name}]" if name else ""

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    N, K, H = cif.shape
    if H != len(horizons):
        _fail(f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")
    # (5) Finite
    if not np.isfinite(cif).all():
        _fail("integrity: non-finite values (NaN/Inf) in cause_cif")
    # (1) Range
    cmin = float(np.nanmin(cif))
    cmax = float(np.nanmax(cif))
    if cmin < -tol:
        _fail(f"integrity: range min={cmin} < -tol={-tol}")
    if cmax > 1.0 + tol:
        _fail(f"integrity: range max={cmax} > 1+tol={1.0 + tol}")
    # (2) Monotonicity in horizons (per n, k)
    diffs = np.diff(cif, axis=2)
    if diffs.size > 0:
        if np.nanmin(diffs) < -tol:
            _fail("integrity: monotonicity violated (found negative diff along horizons)")
    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N,H)
    mass_max = float(np.nanmax(mass))
    if mass_max > 1.0 + tol:
        _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")
    # (4) Conservation with survival, if provided
    if survival is None:
        warn = "integrity: survival not provided; skipping conservation check"
        if strict:
            # Still skip (requested behavior), but keep the message for context.
            notes.append(warn)
        else:
            print("WARNING:", model_tag + warn)
            notes.append(warn)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        else:
            recon = 1.0 - s
            err = np.abs(recon - mass)
            # Discrete-time should be very tight; exponential may accumulate slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if float(np.nanmax(err)) > tol_cons:
                _fail(
                    f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})"
                )
    ok = len([n for n in notes if not n.endswith("skipping conservation check")]) == 0
    return ok, notes
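

# Minimal usage sketch (hypothetical model name): given cif_full of shape (N, K, H)
# and survival of shape (N, H) from one of the converters above,
#   ok, notes = check_cif_integrity(cif_full, horizons, name="model_a", survival=survival)
# prints a WARNING (or raises, when strict=True) for any range, monotonicity,
# probability-mass, or conservation violation, and returns the collected notes.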
# ============================================================
# Metrics
# ============================================================
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x, dtype=float)
    order = np.argsort(x)
    z = x[order]
    n = x.shape[0]
    t = np.zeros(n, dtype=float)
    i = 0
    while i < n:
        j = i
        while j < n and z[j] == z[i]:
            j += 1
        t[i:j] = 0.5 * (i + j - 1) + 1.0
        i = j
    out = np.empty(n, dtype=float)
    out[order] = t
    return out
def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for computing the AUC covariance.

    predictions_sorted_transposed: shape (n_classifiers, n_examples) with positive examples first.
    """
    preds = np.asarray(predictions_sorted_transposed, dtype=float)
    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        return np.array([float("nan")]), np.array([[float("nan")]])
    pos = preds[:, :m]
    neg = preds[:, m:]
    tx = np.array([compute_midrank(x) for x in pos])
    ty = np.array([compute_midrank(x) for x in neg])
    tz = np.array([compute_midrank(x) for x in preds])
    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
    v01 = (tz[:, :m] - tx) / n
    v10 = 1.0 - (tz[:, m:] - ty) / m
    if v01.shape[0] > 1:
        sx = np.cov(v01)
        sy = np.cov(v10)
    else:
        # Single-classifier case: compute the row-wise variance (do not flatten).
        var_v01 = float(np.var(v01, axis=1, ddof=1)[0])
        var_v10 = float(np.var(v10, axis=1, ddof=1)[0])
        sx = np.array([[var_v01]])
        sy = np.array([[var_v10]])
    delong_cov = sx / m + sy / n
    return aucs, delong_cov
def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    y = np.asarray(ground_truth, dtype=int)
    p = np.asarray(predictions, dtype=float)
    if y.ndim != 1 or p.ndim != 1 or y.shape[0] != p.shape[0]:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan"), float("nan")
    order = np.argsort(-y)  # positives first
    preds_sorted = p[order]
    aucs, cov = fastDeLong(preds_sorted[np.newaxis, :], m)
    auc = float(aucs[0])
    var = float(cov[0, 0])
    return auc, var


def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return (auc, ci_low, ci_high) using the DeLong variance and a normal CI."""
    auc, var = calc_auc_variance(ground_truth, predictions)
    if not np.isfinite(var) or var <= 0:
        print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
        return float(auc), float("nan"), float("nan")
    sd = math.sqrt(var)
    z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0)
    lo = max(0.0, auc - z * sd)
    hi = min(1.0, auc + z * sd)
    return float(auc), float(lo), float(hi)
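

# Worked sanity check (hypothetical data): for y = [1, 1, 0, 0] with perfectly
# separating scores p = [0.9, 0.8, 0.2, 0.1], delong_ci(y, p) returns auc = 1.0;
# the DeLong variance collapses to 0 in this degenerate case, so the CI is
# reported as (nan, nan) together with the warning above.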
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC via the Mann–Whitney U statistic (ties handled by midranks).

    Returns NaN for degenerate labels.
    """
    y = np.asarray(y_true, dtype=int)
    s = np.asarray(y_score, dtype=float)
    if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan")
    ranks = compute_midrank(s)
    sum_pos = float(np.sum(ranks[y == 1]))
    auc = (sum_pos - m * (m + 1) / 2.0) / (m * n)
    return float(auc)
def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Bootstrap (percentile) CI for the ROC AUC."""
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    if n == 0 or np.all(labels == labels[0]):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")
    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")
    aucs: List[float] = []
    for _ in range(int(n_bootstrap)):
        idx = rng.integers(0, n, size=n)
        yb = labels[idx]
        if np.all(yb == yb[0]):
            continue
        pb = scores[idx]
        auc = roc_auc_rank(yb, pb)
        if np.isfinite(auc):
            aucs.append(float(auc))
    if len(aucs) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), float("nan"), float("nan")
    lo_q = (1.0 - float(alpha)) / 2.0
    hi_q = 1.0 - lo_q
    lo = float(np.quantile(aucs, lo_q))
    hi = float(np.quantile(aucs, hi_q))
    return float(auc_full), lo, hi
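

# Both CI methods estimate uncertainty for the same rank-based AUC: delong_ci uses
# the closed-form asymptotic (DeLong) variance with a normal interval, while
# bootstrap_auc_ci resamples individuals with replacement and takes percentile
# quantiles of the resampled AUCs. main() selects between them (or "none") via
# the --auc_ci_method flag.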
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.mean((p - y) ** 2))


def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    # Guard against empty input.
    if p.size == 0:
        return {"bins": [], "ici": float("nan")}
    edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
    # Open the outer edges so every prediction falls into some bin; interior
    # duplicate edges simply produce empty bins, which are skipped below.
    edges[0] = -np.inf
    edges[-1] = np.inf
    bins = []
    ici_accum = 0.0
    n = p.shape[0]
    for i in range(n_bins):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        if not np.any(mask):
            continue
        p_mean = float(np.mean(p[mask]))
        y_mean = float(np.mean(y[mask]))
        bins.append({"bin": i, "p_mean": p_mean, "y_mean": y_mean, "n": int(mask.sum())})
        ici_accum += abs(p_mean - y_mean)
    # ICI here is the unweighted mean absolute calibration error over non-empty bins.
    ici = ici_accum / max(len(bins), 1)
    return {"bins": bins, "ici": float(ici)}
def _safe_float(x: Any, default: float = float("nan")) -> float:
    try:
        return float(x)
    except Exception:
        return float(default)


def _ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)


def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
    """Load a 0-based cause_id -> name mapping.

    labels.csv is assumed to contain one label per line, in disease-id order.
    """
    if not os.path.exists(path):
        return {}
    mapping: Dict[int, str] = {}
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            name = line.strip()
            if name:
                mapping[int(i)] = name
    return mapping
def pick_focus_causes(
    *,
    counts_within_tau: Optional[np.ndarray],
    n_disease: int,
    death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
    k: int = 5,
) -> List[int]:
    """Pick focus causes for user-facing evaluation.

    Rule:
      1) Always include death_cause_id first.
      2) Then add K additional causes by descending event count, if available.
         If counts_within_tau is None, fall back to a descending cause_id coverage proxy.

    Notes:
      - counts_within_tau is expected to have shape (n_disease,).
      - Deterministic: ties are broken by the smaller cause id.
    """
    n_disease_i = int(n_disease)
    if death_cause_id < 0 or death_cause_id >= n_disease_i:
        print(
            f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); "
            "it will be omitted from focus causes."
        )
        focus: List[int] = []
    else:
        focus = [int(death_cause_id)]
    candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)]
    if counts_within_tau is not None:
        c = np.asarray(counts_within_tau).astype(float)
        if c.shape[0] != n_disease_i:
            print("WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering.")
            counts_within_tau = None
        else:
            # Sort by (-count, cause_id).
            order = sorted(candidates, key=lambda i: (-float(c[i]), int(i)))
            order = [i for i in order if float(c[i]) > 0]
            focus.extend([int(i) for i in order[:int(k)]])
    if counts_within_tau is None:
        # Fallback: deterministic coverage proxy (descending id, excluding death), then take K.
        # (Real coverage requires data; this path is mostly for robustness.)
        order = sorted(candidates, key=lambda i: (-int(i)))
        focus.extend([int(i) for i in order[:int(k)]])
    # De-duplicate while preserving order.
    seen = set()
    out: List[int] = []
    for cid in focus:
        if cid not in seen:
            out.append(cid)
            seen.add(cid)
    return out
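

# Illustrative sketch (hypothetical counts): with n_disease=6, death_cause_id=5,
# counts_within_tau=[3, 0, 9, 9, 1, 2] and k=2, the death cause (5) comes first,
# then the tie at count 9 breaks toward the smaller id, giving [5, 2, 3].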
def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
    _ensure_dir(os.path.dirname(os.path.abspath(path)) or ".")
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
    """Return a list of (sex_label, mask) slices, including an 'all' slice.

    If sex is missing, returns only ('all', None).
    """
    out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
    if sex is None:
        return out
    s = np.asarray(sex)
    if s.ndim != 1:
        return out
    for val in [0, 1]:
        m = (s == val)
        if int(np.sum(m)) > 0:
            out.append((str(val), m))
    return out


def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
    edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1))
    edges = np.asarray(edges, dtype=float)
    edges[0] = -np.inf
    edges[-1] = np.inf
    return edges
def compute_risk_stratification_bins(
    p: np.ndarray,
    y: np.ndarray,
    *,
    q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
    """Compute quantile-based risk strata and a compact summary."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return 0, [], {
            "y_overall": float("nan"),
            "top_decile_y_rate": float("nan"),
            "bottom_half_y_rate": float("nan"),
            "lift_top10_vs_bottom50": float("nan"),
            "slope_pred_vs_obs": float("nan"),
        }
    # Choose the number of quantiles robustly.
    q = int(q_default)
    if n < 200:
        q = 5
    edges = _quantile_edges(p, q)
    y_overall = float(np.mean(y))
    bin_rows: List[Dict[str, Any]] = []
    p_means: List[float] = []
    y_rates: List[float] = []
    n_bins: List[int] = []
    for i in range(q):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        nb = int(np.sum(mask))
        if nb == 0:
            # Keep the row for consistent plotting; set NaNs.
            bin_rows.append(
                {
                    "q": int(i + 1),
                    "n_bin": 0,
                    "p_mean": float("nan"),
                    "y_rate": float("nan"),
                    "y_overall": y_overall,
                    "lift_vs_overall": float("nan"),
                }
            )
            continue
        p_mean = float(np.mean(p[mask]))
        y_rate = float(np.mean(y[mask]))
        lift = float(y_rate / y_overall) if y_overall > 0 else float("nan")
        bin_rows.append(
            {
                "q": int(i + 1),
                "n_bin": nb,
                "p_mean": p_mean,
                "y_rate": y_rate,
                "y_overall": y_overall,
                "lift_vs_overall": lift,
            }
        )
        p_means.append(p_mean)
        y_rates.append(y_rate)
        n_bins.append(nb)
    # Summary
    top_mask = (p > edges[q - 1]) & (p <= edges[q])
    bot_half_mask = (p > edges[0]) & (p <= edges[q // 2])
    top_y = float(np.mean(y[top_mask])) if int(np.sum(top_mask)) > 0 else float("nan")
    bot_y = float(np.mean(y[bot_half_mask])) if int(np.sum(bot_half_mask)) > 0 else float("nan")
    lift_top_vs_bottom = float(top_y / bot_y) if (
        np.isfinite(top_y) and np.isfinite(bot_y) and bot_y > 0
    ) else float("nan")
    slope = float("nan")
    if len(p_means) >= 2:
        # Weighted least-squares slope of y_rate ~ p_mean.
        x = np.asarray(p_means, dtype=float)
        yy = np.asarray(y_rates, dtype=float)
        w = np.asarray(n_bins, dtype=float)
        xm = float(np.average(x, weights=w))
        ym = float(np.average(yy, weights=w))
        denom = float(np.sum(w * (x - xm) ** 2))
        if denom > 0:
            slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom)
    summary = {
        "y_overall": y_overall,
        "top_decile_y_rate": top_y,
        "bottom_half_y_rate": bot_y,
        "lift_top10_vs_bottom50": lift_top_vs_bottom,
        "slope_pred_vs_obs": slope,
    }
    return q, bin_rows, summary
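

# The slope above is the bin-count-weighted least-squares fit of observed event
# rate on mean predicted risk across the quantile bins:
#   slope = sum_i w_i (x_i - xm)(y_i - ym) / sum_i w_i (x_i - xm)^2.
# Loosely, a well-calibrated model should land near 1; this is an aggregate
# diagnostic, not a replacement for the per-bin reliability rows.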
def compute_capture_points(
    p: np.ndarray,
    y: np.ndarray,
    k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return []
    order = np.argsort(-p)
    y_sorted = y[order]
    events_total = float(np.sum(y_sorted))
    rows: List[Dict[str, Any]] = []
    for k in k_pcts:
        kf = float(k)
        n_targeted = int(math.ceil(n * kf / 100.0))
        n_targeted = max(1, min(n_targeted, n))
        events_targeted = float(np.sum(y_sorted[:n_targeted]))
        capture = float(events_targeted / events_total) if events_total > 0 else float("nan")
        precision = float(events_targeted / float(n_targeted))
        rows.append(
            {
                "k_pct": int(k),
                "n_targeted": int(n_targeted),
                "events_targeted": float(events_targeted),
                "events_total": float(events_total),
                "event_capture_rate": capture,
                "precision_in_targeted": precision,
            }
        )
    return rows
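

# Illustrative sketch (hypothetical numbers): with n=1000 people, 50 events, and
# k_pct=10, the 100 highest-risk individuals are "targeted"; if 30 of the 50
# events fall in that group, event_capture_rate=0.60 (sensitivity at 10% screened)
# and precision_in_targeted=0.30 (events per person targeted).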
def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]:
    """Bucketize horizons into short/medium/long using the continuous-horizon rule."""
    uniq = sorted({float(h) for h in horizons})
    mapping: Dict[float, str] = {}
    rows: List[Dict[str, Any]] = []
    # First 4 short, next 4 medium, rest long.
    for i, h in enumerate(uniq):
        if i < 4:
            g, gr = "short", 1
        elif i < 8:
            g, gr = "medium", 2
        else:
            g, gr = "long", 3
        mapping[float(h)] = g
        rows.append({"horizon": float(h), "group": g, "group_rank": int(gr)})
    method = "continuous_unique_horizons_first4_next4_rest"
    return rows, mapping, method
def count_occurs_within_horizon(
    loader: DataLoader,
    offset_years: float,
    tau_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count per-person occurrence within tau after the prediction context.

    Returns counts[k] = number of individuals with disease k at least once in (t_ctx, t_ctx + tau].
    """
    counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_total_eval = 0
    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        n_total_eval += int(keep.sum().item())
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        t_ctx = t_ctx[keep]
        B, L = event_seq.shape
        b = torch.arange(B, device=device)
        t0 = time_seq[b, t_ctx]
        t1 = t0 + (float(tau_years) * DAYS_PER_YEAR)
        idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
        in_window = (
            (idxs > t_ctx.unsqueeze(1))
            & (time_seq >= t0.unsqueeze(1))
            & (time_seq <= t1.unsqueeze(1))
            & (event_seq >= 2)
            & (event_seq != 0)
        )
        if not in_window.any():
            continue
        b_idx, t_idx = in_window.nonzero(as_tuple=True)
        disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
        # Unique per (person, disease) pair so each person contributes at most once per disease.
        key = b_idx.to(torch.long) * int(n_disease) + disease_ids
        uniq = torch.unique(key)
        uniq_disease = uniq % int(n_disease)
        counts.scatter_add_(0, uniq_disease, torch.ones_like(uniq_disease, dtype=torch.long))
    return counts.detach().cpu().numpy(), int(n_total_eval)
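

# The (person, disease) de-duplication above works by encoding each pair as a
# single integer key b_idx * n_disease + disease_id, taking torch.unique over the
# keys, and recovering the disease id with `% n_disease`; counting unique keys per
# disease therefore yields per-person (not per-event) counts.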
# ============================================================
# Evaluation core
# ============================================================
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])
    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
    if model_type == "delphi_fork":
        backbone = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
        ).to(device)
    elif model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, fall back to pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in {"", None}:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in {"", None}:
            run_dir = os.path.dirname(os.path.abspath(checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
    return backbone, head, loss_type, bin_edges
@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the model and produce cause-specific, time-dependent CIF outputs.

    Returns:
      cause_cif: (N, topK, H)
      cif_full: (N, K, H)
      survival: (N, H)
      y_cause_within_tau: (N, topK, H)
      sex: (N,)

    NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
    """
    backbone.eval()
    head.eval()
    # Accumulate in CPU lists, then concatenate.
    cause_cif_list: List[np.ndarray] = []
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []
    top_cause_ids_t = torch.tensor(top_cause_ids, dtype=torch.long, device=device)
    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        # Filter the batch to samples with a valid context.
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep].to(device)
        t_ctx = t_ctx[keep]
        h = backbone(event_seq, time_seq, sexes_k, cont_feats, cate_feats)  # (B,L,D)
        b = torch.arange(h.size(0), device=device)
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)
        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                logits, bin_edges, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")
        cause_cif = cif_full.index_select(dim=1, index=top_cause_ids_t)  # (B,topK,H)
        # Within-horizon labels for cause-specific CIF quality + discrimination.
        n_disease = int(cif_full.size(1))
        y_within_top = torch.stack(
            [
                multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau),
                    cause_ids=top_cause_ids_t,
                    n_disease=n_disease,
                ).to(torch.float32)
                for tau in eval_horizons
            ],
            dim=2,
        )  # (B,topK,H)
        cause_cif_list.append(cause_cif.detach().cpu().numpy())
        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())
    if not cause_cif_list:
        raise RuntimeError("No valid samples for evaluation (all batches filtered out by offset).")
    cause_cif = np.concatenate(cause_cif_list, axis=0)
    cif_full = np.concatenate(cif_full_list, axis=0)
    survival = np.concatenate(survival_list, axis=0)
    y_cause_within = np.concatenate(y_cause_within_list, axis=0)
    sex = np.concatenate(sex_list, axis=0) if sex_list else np.array([], dtype=int)
    return cause_cif, cif_full, survival, y_cause_within, sex
def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    counts = y_ever.sum(axis=0)
    order = np.argsort(-counts)
    order = order[counts[order] > 0]
    return order[:top_k]
def evaluate_one_model(
    model_name: str,
    cause_cif: np.ndarray,
    y_cause_within_tau: np.ndarray,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
) -> None:
    # Cause-specific, time-dependent metrics per horizon.
    for h_i, tau in enumerate(eval_horizons):
        p_tau = cause_cif[:, :, h_i]  # (N, topK)
        y_tau = y_cause_within_tau[:, :, h_i]  # (N, topK)
        for j, cause_id in enumerate(top_cause_ids.tolist()):
            p = p_tau[:, j]
            y = y_tau[:, j]
            # Primary: CIF-based Brier score + ICI (calibration).
            out_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_brier",
                    "horizon": float(tau),
                    "cause": int(cause_id),
                    "value": brier_score(p, y),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            cal = calibration_deciles(p, y, n_bins=n_calib_bins)
            out_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_ici",
                    "horizon": float(tau),
                    "cause": int(cause_id),
                    "value": cal["ici"],
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            # Secondary: discrimination via AUC at the same horizon.
            if auc_ci_method == "none":
                auc, lo, hi = float("nan"), float("nan"), float("nan")
            elif auc_ci_method == "bootstrap":
                auc, lo, hi = bootstrap_auc_ci(p, y, n_bootstrap=bootstrap_n, alpha=0.95)
            else:
                auc, lo, hi = delong_ci(y, p, alpha=0.95)
            out_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_auc",
                    "horizon": float(tau),
                    "cause": int(cause_id),
                    "value": auc,
                    "ci_low": lo,
                    "ci_high": hi,
                }
            )
            # Calibration curve bins for this cause + horizon.
            for binfo in cal.get("bins", []):
                calib_rows.append(
                    {
                        "model_name": model_name,
                        "task": "cause_k",
                        "horizon": float(tau),
                        "cause_id": int(cause_id),
                        "bin_index": int(binfo["bin"]),
                        "p_mean": float(binfo["p_mean"]),
                        "y_mean": float(binfo["y_mean"]),
                        "n_in_bin": int(binfo["n"]),
                    }
                )
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    fieldnames = [
        "model_name",
        "task",
        "horizon",
        "cause_id",
        "bin_index",
        "p_mean",
        "y_mean",
        "n_in_bin",
    ]
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    fieldnames = [
        "model_name",
        "metric_name",
        "horizon",
        "cause",
        "value",
        "ci_low",
        "ci_high",
    ]
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def _make_eval_tag(split: str, offset_years: float) -> str:
    """Short tag for filenames written into run directories."""
    off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
    return f"{split}_offset{off}y"
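

# Example invocation (illustrative; the script filename and paths are hypothetical,
# the flags are the ones defined in main() below):
#   python evaluate_downstream.py --models_json models.json --split test \
#       --offset_years 0.5 --eval_horizons 0.72 1.61 3.84 10.0 \
#       --auc_ci_method delong --export_dir eval_exports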
def main() -> int:
    ap = argparse.ArgumentParser(description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True, help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str, default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"], help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5, help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float, nargs="*", default=DEFAULT_EVAL_HORIZONS)
    ap.add_argument("--top_k_causes", type=int, default=50)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--out_csv", type=str, default="eval_results.csv")
    ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
    # Integrity checks
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)
    # AUC CI methods
    ap.add_argument(
        "--auc_ci_method",
        type=str,
        default="delong",
        choices=["delong", "bootstrap", "none"],
    )
    ap.add_argument("--bootstrap_n", type=int, default=2000)
    # Export settings for user-facing experiments
    ap.add_argument("--export_dir", type=str, default="eval_exports")
    ap.add_argument("--death_cause_id", type=int, default=DEFAULT_DEATH_CAUSE_ID)
    ap.add_argument("--focus_k", type=int, default=5, help="Additional non-death causes to include")
    ap.add_argument("--capture_k_pcts", type=int, nargs="*", default=[1, 5, 10, 20])
    ap.add_argument(
        "--capture_curve_max_pct",
        type=int,
        default=50,
        help="If >0, also export a dense capture curve for k=1..max_pct",
    )
    args = ap.parse_args()
    set_deterministic(args.seed)
    specs = load_models_json(args.models_json)
    if not specs:
        raise ValueError("No models provided")
    export_dir = str(args.export_dir)
    _ensure_dir(export_dir)
    cause_names = load_cause_names("labels.csv")
    # Determine the top-K causes from the evaluation split only (model-agnostic),
    # aligned to time-dependent risk: occurrence within tau_max after the context.
    first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
    cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else ["bmi", "smoking", "alcohol"]
    dataset_for_top = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
    subset_for_top = build_eval_subset(
        dataset_for_top,
        train_ratio=float(first_cfg.get("train_ratio", 0.7)),
        val_ratio=float(first_cfg.get("val_ratio", 0.15)),
        seed=int(first_cfg.get("random_seed", 42)),
        split=args.split,
    )
    loader_top = DataLoader(
        subset_for_top,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=health_collate_fn,
    )
    tau_max = float(max(args.eval_horizons))
    counts, n_total_eval = count_occurs_within_horizon(
        loader=loader_top,
        offset_years=args.offset_years,
        tau_years=tau_max,
        n_disease=dataset_for_top.n_disease,
        device=args.device,
    )
    focus_causes = pick_focus_causes(
        counts_within_tau=counts,
        n_disease=int(dataset_for_top.n_disease),
        death_cause_id=int(args.death_cause_id),
        k=int(args.focus_k),
    )
    top_cause_ids = np.asarray(focus_causes, dtype=int)
    # Export the chosen focus causes.
    focus_rows: List[Dict[str, Any]] = []
    for r, cid in enumerate(focus_causes, start=1):
        row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)}
        if cid in cause_names:
            row["cause_name"] = cause_names[cid]
        focus_rows.append(row)
    focus_fieldnames = ["cause", "rank"] + \
        (["cause_name"] if any("cause_name" in r for r in focus_rows) else [])
    write_simple_csv(os.path.join(export_dir, "focus_causes.csv"), focus_fieldnames, focus_rows)
    # Metadata for focus causes (within tau_max).
    top_causes_meta: List[Dict[str, Any]] = []
    for cid in focus_causes:
        n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0
        top_causes_meta.append(
            {
                "cause_id": int(cid),
                "tau_years": float(tau_max),
                "n_case_within_tau": n_case,
                "n_control_within_tau": int(n_total_eval - n_case),
                "n_total_eval": int(n_total_eval),
            }
        )
    # Horizon groups for Experiment 3
    hg_rows, horizon_to_group, hg_method = make_horizon_groups(args.eval_horizons)
    write_simple_csv(
        os.path.join(export_dir, "horizon_groups.csv"),
        ["horizon", "group", "group_rank"],
        hg_rows,
    )
    rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []
    # Experiment exports (accumulated across models)
    rs_bins_rows: List[Dict[str, Any]] = []
    rs_sum_rows: List[Dict[str, Any]] = []
    cap_points_rows: List[Dict[str, Any]] = []
    cap_curve_rows: List[Dict[str, Any]] = []
    cal_group_sum_rows: List[Dict[str, Any]] = []
    cal_group_bins_rows: List[Dict[str, Any]] = []
    # Track per-model integrity status for the meta JSON.
    integrity_meta: Dict[str, Any] = {}
    # Evaluate each model
    for spec in specs:
        run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))
        tag = _make_eval_tag(args.split, float(args.offset_years))
        # Remember list offsets so we can write per-model slices to the model's run_dir.
        rows_start = len(rows)
        calib_start = len(calib_rows)
        cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
        # Identifiers for consistent exports
        model_id = str(spec.name)
        model_type = str(cfg.get("model_type", spec.model_type if hasattr(spec, "model_type") else ""))
        loss_type_id = str(cfg.get("loss_type", spec.loss_type if hasattr(spec, "loss_type") else ""))
        age_encoder = str(cfg.get("age_encoder", ""))
        cov_type = "full" if _parse_bool(cfg.get("full_cov", False)) else "partial"
        cov_list = None if _parse_bool(cfg.get("full_cov", False)) else ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
        subset = build_eval_subset(
            dataset,
            train_ratio=float(cfg.get("train_ratio", 0.7)),
            val_ratio=float(cfg.get("val_ratio", 0.15)),
            seed=int(cfg.get("random_seed", 42)),
            split=args.split,
        )
        loader = DataLoader(
            subset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=health_collate_fn,
        )
        backbone, head, loss_type, bin_edges = instantiate_model_and_head(
            cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
        ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
        backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
        head.load_state_dict(ckpt["head_state_dict"], strict=True)
        (
            cause_cif,
            cif_full,
            survival,
            y_cause_within_tau,
            sex,
        ) = predict_cifs_for_model(
            backbone,
            head,
            loss_type,
            bin_edges,
            loader,
            args.device,
            args.offset_years,
            args.eval_horizons,
            top_cause_ids,
        )
        # CIF integrity checks before metrics.
        integrity_ok, integrity_notes = check_cif_integrity(
            cif_full,
            args.eval_horizons,
            tol=float(args.integrity_tol),
            name=spec.name,
            strict=bool(args.integrity_strict),
            survival=survival,
        )
        integrity_meta[spec.name] = {
            "integrity_ok": bool(integrity_ok),
            "integrity_notes": integrity_notes,
        }
        evaluate_one_model(
            model_name=spec.name,
            cause_cif=cause_cif,
            y_cause_within_tau=y_cause_within_tau,
            eval_horizons=args.eval_horizons,
            top_cause_ids=top_cause_ids,
            out_rows=rows,
            calib_rows=calib_rows,
            auc_ci_method=str(args.auc_ci_method),
            bootstrap_n=int(args.bootstrap_n),
        )
        # ============================================================
        # Experiment 1: Risk stratification bins + summary
        # ============================================================
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                for j, cause_id in enumerate(top_cause_ids.tolist()):
                    p = cause_cif[:, j, h_i]
                    y = y_cause_within_tau[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    q_used, bin_rows, summary = compute_risk_stratification_bins(p, y, q_default=10)
                    for br in bin_rows:
                        rs_bins_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "horizon": float(tau),
                                "sex": sex_label,
                                "q": int(br["q"]),
                                "n_bin": int(br["n_bin"]),
                                "p_mean": _safe_float(br["p_mean"]),
                                "y_rate": _safe_float(br["y_rate"]),
                                "y_overall": _safe_float(br["y_overall"]),
                                "lift_vs_overall": _safe_float(br["lift_vs_overall"]),
                                "q_total": int(q_used),
                            }
                        )
                    rs_sum_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "cause": int(cause_id),
                            "horizon": float(tau),
                            "sex": sex_label,
                            "q_total": int(q_used),
                            "top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]),
                            "bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]),
                            "lift_top10_vs_bottom50": _safe_float(summary["lift_top10_vs_bottom50"]),
                            "slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]),
                        }
                    )
        # ============================================================
        # Experiment 2: High-risk capture points (+ optional curve)
        # ============================================================
        k_pcts = [int(x) for x in args.capture_k_pcts]
        curve_max = int(args.capture_curve_max_pct)
        curve_grid = list(range(1, curve_max + 1)) if curve_max and curve_max > 0 else []
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                for j, cause_id in enumerate(top_cause_ids.tolist()):
                    p = cause_cif[:, j, h_i]
                    y = y_cause_within_tau[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    for r in compute_capture_points(p, y, k_pcts):
                        cap_points_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "horizon": float(tau),
                                "sex": sex_label,
                                **r,
                            }
                        )
                    if curve_grid:
                        for r in compute_capture_points(p, y, curve_grid):
                            cap_curve_rows.append(
                                {
                                    "model_id": model_id,
                                    "model_type": model_type,
                                    "loss_type": loss_type_id,
                                    "age_encoder": age_encoder,
                                    "cov_type": cov_type,
                                    "cause": int(cause_id),
                                    "horizon": float(tau),
                                    "sex": sex_label,
                                    **r,
                                }
                            )
# ============================================================
# Experiment 3: Short/Medium/Long horizon-group calibration
# ============================================================
# Per-horizon metrics for grouping
# Build a dict for quick access: (cause_id, horizon) -> (brier, ici)
per_h : Dict [ Tuple [ int , float ] , Dict [ str , float ] ] = { }
for rr in rows [ rows_start : ] :
if rr . get ( " model_name " ) != spec . name :
continue
if rr . get ( " metric_name " ) not in { " cause_brier " , " cause_ici " } :
continue
try :
cid = int ( rr . get ( " cause " ) )
except Exception :
continue
h = _safe_float ( rr . get ( " horizon " ) )
if not np . isfinite ( h ) :
continue
key = ( cid , float ( h ) )
d = per_h . get ( key , { } )
d [ str ( rr . get ( " metric_name " ) ) ] = _safe_float ( rr . get ( " value " ) )
per_h [ key ] = d
# Compute group summaries and pooled reliability bins per slice, using quantile bins as in Experiment 1 (Q = 10, or Q = 5 for small slices).
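# Pooling: for each horizon group g and quantile bin q, the per-horizon bin means are
# accumulated weighted by bin size, so the exported point is
#   p_mean = sum_h ( n_bin_h * p_mean_h ) / sum_h ( n_bin_h )
#   y_rate = sum_h ( n_bin_h * y_rate_h ) / sum_h ( n_bin_h )
# i.e. a bin-size-weighted reliability curve across all horizons in the group.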
for sex_label , sex_mask in _sex_slices ( sex if sex . size else None ) :
for j , cause_id in enumerate ( top_cause_ids . tolist ( ) ) :
# Decide Q per slice for pooled reliability curve
n_slice = int ( np . sum ( sex_mask ) ) if sex_mask is not None else int ( sex . shape [ 0 ] )
q_pool = 10 if n_slice > = 200 else 5
# Collect per-horizon brier/ici values
group_vals : Dict [ str , Dict [ str , List [ float ] ] ] = {
    " short " : { " brier " : [ ] , " ici " : [ ] } ,
    " medium " : { " brier " : [ ] , " ici " : [ ] } ,
    " long " : { " brier " : [ ] , " ici " : [ ] } ,
}
group_n_total : Dict [ str , int ] = {
" short " : 0 , " medium " : 0 , " long " : 0 }
# Pooled bins: group -> q -> accumulators
pooled : Dict [ str , Dict [ int , Dict [ str , float ] ] ] = {
" short " : { } , " medium " : { } , " long " : { } }
for h_i , tau in enumerate ( args . eval_horizons ) :
g = horizon_to_group . get ( float ( tau ) , " long " )
# brier/ici per horizon (already computed at full-sample level)
d = per_h . get ( ( int ( cause_id ) , float ( tau ) ) , { } )
brier_h = _safe_float ( d . get ( " cause_brier " ) )
ici_h = _safe_float ( d . get ( " cause_ici " ) )
if np . isfinite ( brier_h ) :
group_vals [ g ] [ " brier " ] . append ( brier_h )
if np . isfinite ( ici_h ) :
group_vals [ g ] [ " ici " ] . append ( ici_h )
# pooled reliability bins from raw p/y
p = cause_cif [ : , j , h_i ]
y = y_cause_within_tau [ : , j , h_i ]
if sex_mask is not None :
p = p [ sex_mask ]
y = y [ sex_mask ]
if p . size == 0 :
continue
edges = _quantile_edges ( p , q_pool )
for qi in range ( q_pool ) :
m = ( p > edges [ qi ] ) & ( p < = edges [ qi + 1 ] )
nb = int ( np . sum ( m ) )
if nb == 0 :
continue
pm = float ( np . mean ( p [ m ] ) )
yr = float ( np . mean ( y [ m ] ) )
acc = pooled [ g ] . get (
qi + 1 , { " n " : 0.0 , " p_sum " : 0.0 , " y_sum " : 0.0 } )
acc [ " n " ] + = float ( nb )
acc [ " p_sum " ] + = float ( nb ) * pm
acc [ " y_sum " ] + = float ( nb ) * yr
pooled [ g ] [ qi + 1 ] = acc
group_n_total [ g ] = max ( group_n_total [ g ] , int ( p . size ) )
for g in [ " short " , " medium " , " long " ] :
bvals = group_vals [ g ] [ " brier " ]
ivals = group_vals [ g ] [ " ici " ]
cal_group_sum_rows . append (
{
" model_id " : model_id ,
" model_type " : model_type ,
" loss_type " : loss_type_id ,
" age_encoder " : age_encoder ,
" cov_type " : cov_type ,
" cause " : int ( cause_id ) ,
" sex " : sex_label ,
" horizon_group " : g ,
" brier_mean " : float ( np . mean ( bvals ) ) if bvals else float ( " nan " ) ,
" brier_median " : float ( np . median ( bvals ) ) if bvals else float ( " nan " ) ,
" ici_mean " : float ( np . mean ( ivals ) ) if ivals else float ( " nan " ) ,
" ici_median " : float ( np . median ( ivals ) ) if ivals else float ( " nan " ) ,
" n_total " : int ( group_n_total [ g ] ) ,
" horizon_grouping_method " : hg_method ,
}
)
for qi in range ( 1 , q_pool + 1 ) :
acc = pooled [ g ] . get ( qi )
if not acc or float ( acc . get ( " n " , 0.0 ) ) < = 0 :
continue
n_bin = float ( acc [ " n " ] )
cal_group_bins_rows . append (
{
" model_id " : model_id ,
" model_type " : model_type ,
" loss_type " : loss_type_id ,
" age_encoder " : age_encoder ,
" cov_type " : cov_type ,
" cause " : int ( cause_id ) ,
" sex " : sex_label ,
" horizon_group " : g ,
" q " : int ( qi ) ,
" n_bin " : int ( n_bin ) ,
" p_mean " : float ( acc [ " p_sum " ] / n_bin ) ,
" y_rate " : float ( acc [ " y_sum " ] / n_bin ) ,
" q_total " : int ( q_pool ) ,
" horizon_grouping_method " : hg_method ,
}
)
# Optionally write top-cause counts into the main results CSV as metric rows.
for tc in top_causes_meta :
    for metric_name , count in [
        ( " topcause_n_case_within_tau " , tc [ " n_case_within_tau " ] ) ,
        ( " topcause_n_control_within_tau " , tc [ " n_control_within_tau " ] ) ,
        ( " topcause_n_total_eval " , tc [ " n_total_eval " ] ) ,
    ] :
        rows . append (
            {
                " model_name " : spec . name ,
                " metric_name " : metric_name ,
                " horizon " : float ( tc [ " tau_years " ] ) ,
                " cause " : int ( tc [ " cause_id " ] ) ,
                " value " : int ( count ) ,
                " ci_low " : " " ,
                " ci_high " : " " ,
            }
        )
# Write per-model results into the model's run directory.
model_rows = rows [ rows_start : ]
model_calib_rows = calib_rows [ calib_start : ]
model_out_csv = os . path . join ( run_dir , f " eval_results_ { tag } .csv " )
model_calib_csv = os . path . join ( run_dir , f " calibration_bins_ { tag } .csv " )
model_meta_json = os . path . join ( run_dir , f " eval_meta_ { tag } .json " )
write_results_csv ( model_out_csv , model_rows )
write_calibration_bins_csv ( model_calib_csv , model_calib_rows )
model_meta = {
" model_name " : spec . name ,
" checkpoint_path " : spec . checkpoint_path ,
" run_dir " : run_dir ,
" split " : args . split ,
" offset_years " : args . offset_years ,
" eval_horizons " : [ float ( x ) for x in args . eval_horizons ] ,
" top_k_causes " : int ( args . top_k_causes ) ,
" top_cause_ids " : top_cause_ids . tolist ( ) ,
" top_causes " : top_causes_meta ,
" integrity " : { spec . name : integrity_meta . get ( spec . name , { } ) } ,
" paths " : {
" results_csv " : model_out_csv ,
" calibration_bins_csv " : model_calib_csv ,
} ,
}
with open ( model_meta_json , " w " ) as f :
json . dump ( model_meta , f , indent = 2 )
print ( f " Wrote per-model results to { model_out_csv } " )
write_results_csv ( args . out_csv , rows )
# Write calibration curve points to a separate CSV.
out_dir = os . path . dirname ( os . path . abspath ( args . out_csv ) ) or " . "
calib_csv_path = os . path . join ( out_dir , " calibration_bins.csv " )
write_calibration_bins_csv ( calib_csv_path , calib_rows )
# Write experiment exports
write_simple_csv (
os . path . join ( export_dir , " risk_stratification_bins.csv " ) ,
[
" model_id " ,
" model_type " ,
" loss_type " ,
" age_encoder " ,
" cov_type " ,
" cause " ,
" horizon " ,
" sex " ,
" q " ,
" n_bin " ,
" p_mean " ,
" y_rate " ,
" y_overall " ,
" lift_vs_overall " ,
" q_total " ,
] ,
rs_bins_rows ,
)
write_simple_csv (
os . path . join ( export_dir , " risk_stratification_summary.csv " ) ,
[
" model_id " ,
" model_type " ,
" loss_type " ,
" age_encoder " ,
" cov_type " ,
" cause " ,
" horizon " ,
" sex " ,
" q_total " ,
" top_decile_y_rate " ,
" bottom_half_y_rate " ,
" lift_top10_vs_bottom50 " ,
" slope_pred_vs_obs " ,
] ,
rs_sum_rows ,
)
write_simple_csv (
os . path . join ( export_dir , " lift_capture_points.csv " ) ,
[
" model_id " ,
" model_type " ,
" loss_type " ,
" age_encoder " ,
" cov_type " ,
" cause " ,
" horizon " ,
" sex " ,
" k_pct " ,
" n_targeted " ,
" events_targeted " ,
" events_total " ,
" event_capture_rate " ,
" precision_in_targeted " ,
] ,
cap_points_rows ,
)
if cap_curve_rows :
write_simple_csv (
os . path . join ( export_dir , " lift_capture_curve.csv " ) ,
[
" model_id " ,
" model_type " ,
" loss_type " ,
" age_encoder " ,
" cov_type " ,
" cause " ,
" horizon " ,
" sex " ,
" k_pct " ,
" n_targeted " ,
" events_targeted " ,
" events_total " ,
" event_capture_rate " ,
" precision_in_targeted " ,
] ,
cap_curve_rows ,
)
write_simple_csv (
os . path . join ( export_dir , " calibration_groups_summary.csv " ) ,
[
" model_id " ,
" model_type " ,
" loss_type " ,
" age_encoder " ,
" cov_type " ,
" cause " ,
" sex " ,
" horizon_group " ,
" brier_mean " ,
" brier_median " ,
" ici_mean " ,
" ici_median " ,
" n_total " ,
" horizon_grouping_method " ,
] ,
cal_group_sum_rows ,
)
write_simple_csv (
os . path . join ( export_dir , " calibration_groups_bins.csv " ) ,
[
" model_id " ,
" model_type " ,
" loss_type " ,
" age_encoder " ,
" cov_type " ,
" cause " ,
" sex " ,
" horizon_group " ,
" q " ,
" n_bin " ,
" p_mean " ,
" y_rate " ,
" q_total " ,
" horizon_grouping_method " ,
] ,
cal_group_bins_rows ,
)
# Manifest markdown (stable, user-facing)
manifest_path = os . path . join ( export_dir , " eval_exports_manifest.md " )
with open ( manifest_path , " w " , encoding = " utf-8 " ) as f :
f . write (
" # Evaluation Exports Manifest \n \n "
" This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
" All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced. \n \n "
" ## Files \n \n "
" - focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table. \n "
" - horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons. \n "
" - risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines. \n "
" - risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table. \n "
" - lift_capture_points.csv: Capture/precision at top { 1,5,10,20} % r isk. Intended plot: bar/line showing event capture vs resources. \n "
" - lift_capture_curve.csv (optional): Dense capture curve for k=1..N % . Intended plot: gain curve overlay across models. \n "
" - calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket. \n "
" - calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). Intended plot: 3-panel reliability curves per model. \n "
)
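# Downstream usage sketch (illustrative only, not executed by this script; assumes
# pandas and matplotlib are available in the analysis environment):
#   import pandas as pd
#   import matplotlib.pyplot as plt
#   df = pd.read_csv(os.path.join(export_dir, "calibration_groups_bins.csv"))
#   for g, sub in df.groupby("horizon_group"):
#       plt.plot(sub["p_mean"], sub["y_rate"], marker="o", label=g)
#   plt.plot([0, 1], [0, 1], linestyle="--")  # ideal calibration reference
#   plt.xlabel("Mean predicted CIF")
#   plt.ylabel("Observed event rate")
#   plt.legend()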
meta = {
" split " : args . split ,
" offset_years " : args . offset_years ,
" eval_horizons " : [ float ( x ) for x in args . eval_horizons ] ,
" tau_max " : float ( tau_max ) ,
" top_k_causes " : int ( args . top_k_causes ) ,
" top_cause_ids " : top_cause_ids . tolist ( ) ,
" top_causes " : top_causes_meta ,
" integrity " : integrity_meta ,
" notes " : {
" label " : " Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau]) " ,
" primary_metrics " : " cause_brier (CIF-based) and cause_ici (calibration) " ,
" secondary_metrics " : " cause_auc (discrimination) with optional CI " ,
" exclusions " : " No all-cause aggregation; no next-event formulation; ECE not reported " ,
" warning " : " This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time. " ,
" exports_dir " : export_dir ,
" focus_causes " : focus_causes ,
" horizon_grouping_method " : hg_method ,
} ,
}
with open ( args . out_meta_json , " w " ) as f :
json . dump ( meta , f , indent = 2 )
print ( f " Wrote { args . out_csv } with { len ( rows ) } rows " )
print ( f " Wrote { calib_csv_path } with { len ( calib_rows ) } rows " )
print ( f " Wrote { args . out_meta_json } " )
return 0
if __name__ == " __main__ " :
raise SystemExit ( main ( ) )