import argparse
import csv
import json
import math
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi, SimpleHead

# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
DAYS_PER_YEAR = 365.25
DEFAULT_DEATH_CAUSE_ID = 1256

# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
    name: str
    model_type: str  # delphi_fork | sap_delphi
    loss_type: str  # exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif
    full_cov: bool
    checkpoint_path: str

# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    if s in {"true", "1", "yes", "y"}:
        return True
    if s in {"false", "0", "no", "n"}:
        return False
    raise ValueError(f"Cannot parse boolean: {x!r}")

def load_models_json(path: str) -> List[ModelSpec]:
    with open(path, "r") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")
    specs: List[ModelSpec] = []
    for row in data:
        specs.append(
            ModelSpec(
                name=str(row["name"]),
                model_type=str(row["model_type"]),
                loss_type=str(row["loss_type"]),
                full_cov=_parse_bool(row["full_cov"]),
                checkpoint_path=str(row["checkpoint_path"]),
            )
        )
    return specs
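
# Illustrative models_json entry (a hedged sketch: the keys mirror the fields
# read by load_models_json() above; the name and path values are hypothetical
# placeholders, not files shipped with this repo):
#
# [
#   {
#     "name": "delphi_fork_exp",
#     "model_type": "delphi_fork",
#     "loss_type": "exponential",
#     "full_cov": "true",
#     "checkpoint_path": "runs/exp0/best.pt"
#   }
# ]
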
def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r") as f:
        cfg = json.load(f)
    return cfg

def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    if split == "train":
        return train_ds
    if split == "val":
        return val_ds
    if split == "test":
        return test_ds
    if split == "all":
        return dataset
    raise ValueError("split must be one of: train, val, test, all")

# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select the per-sample prediction context index.

    IMPORTANT SEMANTICS:
      - The last observed token time is treated as the FOLLOW-UP END time.
      - We pick the last valid token with time <= (followup_end_time - offset).
      - We do NOT interpret followup_end_time as an event time.

    Returns:
        keep_mask: (B,) bool, which samples have a valid context
        t_ctx: (B,) long, index into sequence
        t_ctx_time: (B,) float, time (days) at context
    """
    # Valid tokens are event != 0 (padding is 0).
    valid = event_seq != 0
    lengths = valid.sum(dim=1)
    last_idx = torch.clamp(lengths - 1, min=0)
    b = torch.arange(event_seq.size(0), device=event_seq.device)
    followup_end_time = time_seq[b, last_idx]
    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    eligible_counts = eligible.sum(dim=1)
    keep = eligible_counts > 0
    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[b, t_ctx]
    return keep, t_ctx, t_ctx_time
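
# Hedged toy example of the anti-leakage rule above (values are made up): one
# person with events at days 0, 400, and 1000, where day 1000 is the follow-up
# end. With offset_years=1 the cutoff is 1000 - 365.25 = 634.75, so the
# context is the token at day 400 (index 1).
# >>> ev = torch.tensor([[2, 3, 4]])
# >>> tm = torch.tensor([[0.0, 400.0, 1000.0]])
# >>> keep, t_ctx, t_ctx_time = select_context_indices(ev, tm, offset_years=1.0)
# >>> bool(keep[0]), int(t_ctx[0]), float(t_ctx_time[0])
# (True, 1, 400.0)
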
def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the next disease event after the context.

    Returns:
        dt_years: (B,) float, time to next disease in years; +inf if none
        cause: (B,) long, disease id in [0, K) for the next event; -1 if none
    """
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    # Allow same-day events while excluding the context token itself.
    # We rely on time-sorted sequences and select the FIRST valid future event by index.
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
    idx_min = torch.where(future, idxs, torch.full_like(idxs, L)).min(dim=1).values
    has = idx_min < L
    t_next = torch.where(has, idx_min, torch.zeros_like(idx_min))
    t_next_time = time_seq[b, t_next]
    dt_days = t_next_time - t0
    dt_years = dt_days / DAYS_PER_YEAR
    dt_years = torch.where(has, dt_years, torch.full_like(dt_years, float("inf")))
    cause_token = event_seq[b, t_next]
    cause = (cause_token - 2).to(torch.long)
    cause = torch.where(has, cause, torch.full_like(cause, -1))
    return dt_years, cause

def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after context (any occurrence)."""
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    # Include same-day events after context; exclude any token at/before the context index.
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )
    if not in_window.any():
        return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    y[b_idx, disease_ids] = True
    return y

def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs ANYTIME after the prediction context.

    This is Delphi2M-compatible for the Task A case/control definition.
    Same-day events are included as long as they occur after the context token index.
    """
    B, L = event_seq.shape
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    if not future.any():
        return y
    b_idx, t_idx = future.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y[b_idx, disease_ids] = True
    return y

def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after context?"""
    B, L = event_seq.shape
    device = event_seq.device
    b = torch.arange(B, device=device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)
    idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )
    out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
    if not in_window.any():
        return out
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    # Filter to selected causes via a boolean membership mask over the global disease space.
    selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
    selected[cause_ids] = True
    keep = selected[disease_ids]
    if not keep.any():
        return out
    b_idx = b_idx[keep]
    disease_ids = disease_ids[keep]
    # Map disease_id -> local index in cause_ids.
    # Build a lookup table (global disease space) where lookup[disease_id] = local_index.
    lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
    lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
    local = lookup[disease_ids]
    out[b_idx, local] = True
    return out
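
# Hedged illustration of the lookup-table trick used above: it maps global
# disease ids to local positions in cause_ids without a Python-level dict.
# Ids below are arbitrary toy values.
# >>> cause_ids = torch.tensor([7, 3, 11])
# >>> lookup = torch.full((12,), -1, dtype=torch.long)
# >>> lookup[cause_ids] = torch.arange(cause_ids.numel())
# >>> lookup[torch.tensor([3, 11, 7])]
# tensor([1, 2, 0])
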
# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential cause-specific logits -> CIFs at taus.

    logits: (B, K)
    returns: (B, K, H), or (cif, survival) if return_survival
    """
    hazards = F.softplus(logits) + eps  # (B,K)
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)
    taus_t = torch.tensor(list(taus), device=logits.device,
                          dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B,1,1)
    # CIF_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau))
    one_minus_surv = 1.0 - torch.exp(-total_h * taus_t)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B,K,H)
    # If total == 0, set the CIF to 0.
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))
    if not return_survival:
        return cif
    survival = torch.exp(-total_h * taus_t).squeeze(1)  # (B,H)
    # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival
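
# Hedged sanity sketch for the closed form above: since the per-cause fractions
# sum to 1, the cause CIFs must conserve total mass, sum_k CIF_k(tau) = 1 - S(tau).
# Zero logits are a convenient toy input (each hazard is softplus(0) + eps).
# >>> cif, surv = cifs_from_exponential_logits(
# ...     torch.zeros(1, 2), taus=[1.0], return_survival=True)
# >>> torch.allclose(cif.sum(dim=1), 1.0 - surv, atol=1e-5)
# True
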
def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert discrete-time CIF logits -> CIFs at taus.

    logits: (B, K+1, n_bins+1)
    bin_edges: len = n_bins + 1 (including 0 and inf)
    taus: subset of finite bin edges (recommended)
    returns: (B, K, H), or (cif, survival) if return_survival
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")
    B, K_plus_1, n_bins_plus_1 = logits.shape
    K = K_plus_1 - 1
    edges = [float(x) for x in bin_edges]
    # Drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf.
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)
    if n_bins_plus_1 != len(edges):
        raise ValueError("logits last dim must match len(bin_edges)")
    probs = torch.softmax(logits, dim=1)  # (B, K+1, n_bins+1)
    # Use bins 1..n_bins (ignore bin 0 and the +inf bin edge slot).
    hazards = probs[:, :K, 1:1 + n_bins]  # (B,K,n_bins)
    p_comp = probs[:, K, 1:1 + n_bins]  # (B,n_bins)
    # Survival before each bin: S_prev[0] = 1, S_prev[u] = prod_{v<u} p_comp[v].
    ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
    cum = torch.cumprod(p_comp, dim=1)
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # (B,n_bins)
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * hazards, dim=2)  # (B,K,n_bins)
    # Robust mapping from tau -> edge index (floating-point safe).
    # taus are expected to align with bin edges, but may differ slightly due to
    # parsing/serialization.
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        j = int(np.argmin(diffs))
        if diffs[j] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
            )
        tau_to_idx.append(j)
    idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,K,H)
    if not return_survival:
        return cif
    # Survival at each horizon = prod_{u <= idx[h]} p_comp[u].
    survival_bins = cum  # (B,n_bins), cum[:, u] = prod_{v<=u} p_comp[:, v]
    survival = survival_bins.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival
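
# Hedged sanity sketch for the discrete-time conversion above: because the
# per-bin class probabilities sum to 1 over the K+1 classes, the telescoping
# survival product gives sum_k CIF_k(tau) + S(tau) = 1 exactly at every bin
# edge. Logits below are random toy values; edges follow DEFAULT_BIN_EDGES.
# >>> g = torch.Generator().manual_seed(0)
# >>> logits = torch.randn(2, 4, 8, generator=g)  # B=2, K=3, n_bins+1=8
# >>> cif, surv = cifs_from_discrete_time_logits(
# ...     logits, DEFAULT_BIN_EDGES, taus=[0.72, 10.0], return_survival=True)
# >>> torch.allclose(cif.sum(dim=1) + surv, torch.ones(2, 2), atol=1e-6)
# True
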
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
    z = torch.clamp(z, -12.0, 12.0)
    return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

def cifs_from_lognormal_basis_binned_hazard_logits(
    logits: torch.Tensor,
    *,
    centers: Sequence[float],
    sigma: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    eps: float = 1e-8,
    alpha_floor: float = 0.0,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert Route-3 binned hazard logits -> CIFs at taus.

    logits: (B, J, R) OR (B, J*R) OR (B, 1 + J*R) (leading column ignored).
    taus are expected to align with finite bin edges.
    """
    if logits.ndim not in {2, 3}:
        raise ValueError("logits must be 2D or 3D")
    if sigma.ndim != 0:
        raise ValueError("sigma must be a scalar tensor")
    device = logits.device
    dtype = logits.dtype
    centers_t = torch.tensor([float(x) for x in centers], device=device, dtype=dtype)
    r = int(centers_t.numel())
    if r <= 0:
        raise ValueError("centers must be non-empty")
    offset = 0
    if logits.ndim == 3:
        j = int(logits.shape[1])
        if int(logits.shape[2]) != r:
            raise ValueError(
                f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
            )
    else:
        d = int(logits.shape[1])
        if d % r == 0:
            jr = d
        elif (d - 1) % r == 0:
            offset = 1
            jr = d - 1
        else:
            raise ValueError(
                f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}")
        j = jr // r
    if j <= 0:
        raise ValueError("Inferred J must be >= 1")
    edges = [float(x) for x in bin_edges]
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)
    if n_bins <= 0:
        raise ValueError("bin_edges must contain at least one finite edge")
    # Build finite bins [edges[k-1], edges[k]) for k = 1..n_bins.
    left = torch.tensor(edges[:n_bins], device=device, dtype=dtype)
    right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype)
    # Stable t_min clamp (aligns with the training loss rule).
    t_min = 1e-12
    if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
        t_min = edges[1] * 1e-6
    t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
    left_is_zero = left <= 0
    left_clamped = torch.clamp(left, min=t_min_t)
    log_left = torch.log(left_clamped)
    right_clamped = torch.clamp(right, min=t_min_t)
    log_right = torch.log(right_clamped)
    sigma_c = sigma.to(device=device, dtype=dtype)
    z_left = (log_left.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
    z_right = (log_right.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
    cdf_left = _normal_cdf_stable(z_left)
    if left_is_zero.any():
        cdf_left = torch.where(left_is_zero.unsqueeze(-1),
                               torch.zeros_like(cdf_left), cdf_left)
    cdf_right = _normal_cdf_stable(z_right)
    delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0)  # (n_bins, R)
    if logits.ndim == 3:
        alpha = F.softplus(logits) + float(alpha_floor)  # (B,J,R)
    else:
        logits_used = logits[:, offset:]
        alpha = (F.softplus(logits_used) + float(alpha_floor)).view(
            logits.size(0), j, r)  # (B,J,R)
    h_jk = torch.einsum("bjr,kr->bjk", alpha, delta_basis)  # (B,J,n_bins)
    h_k = h_jk.sum(dim=1)  # (B,n_bins)
    h_k = torch.clamp(h_k, min=eps)
    h_jk = torch.clamp(h_jk, min=eps)
    p_comp = torch.exp(-h_k)  # (B,n_bins)
    one_minus = -torch.expm1(-h_k)  # (B,n_bins) = 1 - exp(-H)
    ratio = h_jk / torch.clamp(h_k.unsqueeze(1), min=eps)
    p_event = one_minus.unsqueeze(1) * ratio  # (B,J,n_bins)
    ones = torch.ones((alpha.size(0), 1), device=device, dtype=dtype)
    cum = torch.cumprod(p_comp, dim=1)  # survival after each bin
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # survival before each bin
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * p_event, dim=2)  # (B,J,n_bins)
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        idx0 = int(np.argmin(diffs))
        if diffs[idx0] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
            )
        tau_to_idx.append(idx0)
    idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,J,H)
    if not return_survival:
        return cif
    survival = cum.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival
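
# Reading guide for the Route-3 conversion above (restating the code, not new
# behavior): with basis weights alpha_{j,r} = softplus(logit) + alpha_floor and
# per-bin lognormal basis mass
#   Delta_{k,r} = Phi((log t_k - mu_r)/sigma) - Phi((log t_{k-1} - mu_r)/sigma),
# the per-bin cumulative hazard is H_{j,k} = sum_r alpha_{j,r} * Delta_{k,r}.
# The competing-risk split assigns each bin's total event mass (1 - exp(-H_k))
# to causes in proportion to H_{j,k} / H_k, and the CIF is the usual
# discrete-time cumulative sum with survival carried between bins.
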
# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on CIF arrays.

    Args:
        cause_cif: (N, K, H)
        horizons: length H
        tol: tolerance for inequalities
        name: model name for messages
        strict: if True, raise ValueError on first failure
        survival: optional (N, H) survival values at the same horizons

    Returns:
        (integrity_ok, notes)
    """
    notes: List[str] = []
    model_tag = f"[{name}] " if name else ""

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    N, K, H = cif.shape
    if H != len(horizons):
        _fail(
            f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")
    # (5) Finite
    if not np.isfinite(cif).all():
        _fail("integrity: non-finite values (NaN/Inf) in cause_cif")
    # (1) Range
    cmin = float(np.nanmin(cif))
    cmax = float(np.nanmax(cif))
    if cmin < -tol:
        _fail(f"integrity: range min={cmin} < -tol={-tol}")
    if cmax > 1.0 + tol:
        _fail(f"integrity: range max={cmax} > 1+tol={1.0 + tol}")
    # (2) Monotonicity in horizons (per n, k)
    diffs = np.diff(cif, axis=2)
    if diffs.size > 0:
        if np.nanmin(diffs) < -tol:
            _fail("integrity: monotonicity violated (found negative diff along horizons)")
    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N,H)
    mass_max = float(np.nanmax(mass))
    if mass_max > 1.0 + tol:
        _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")
    # (4) Conservation with survival, if provided
    if survival is None:
        warn = "integrity: survival not provided; skipping conservation check"
        if strict:
            # Still skip (requested behavior), but keep the message for context.
            notes.append(warn)
        else:
            print("WARNING:", model_tag + warn)
            notes.append(warn)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(
                f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        else:
            recon = 1.0 - s
            err = np.abs(recon - mass)
            # Discrete-time should be very tight; exponential may accumulate
            # slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if float(np.nanmax(err)) > tol_cons:
                _fail(
                    f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})")
    ok = len([n for n in notes if not n.endswith("skipping conservation check")]) == 0
    return ok, notes
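
# Hedged usage sketch (synthetic arrays, illustrative only): a monotone,
# in-range CIF passes; without `survival` the conservation check is skipped
# with a printed WARNING note, but integrity_ok stays True.
# >>> good = np.zeros((4, 3, 2)); good[..., 1] = 0.1
# >>> ok, notes = check_cif_integrity(good, horizons=[1.0, 5.0], name="demo")
# >>> ok
# True
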
# ============================================================
# Metrics
# ============================================================
# --- Rank-based ROC AUC (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Vectorized midrank computation (ties -> average ranks)."""
    x = np.asarray(x, dtype=float)
    n = int(x.shape[0])
    if n == 0:
        return np.asarray([], dtype=float)
    order = np.argsort(x, kind="mergesort")
    z = x[order]
    # Find tie groups in sorted order.
    diff = np.diff(z)
    # boundaries includes 0 and n
    boundaries = np.concatenate(
        [np.array([0], dtype=int), np.nonzero(diff != 0)[0] + 1, np.array([n], dtype=int)]
    )
    starts = boundaries[:-1]
    ends = boundaries[1:]
    lens = ends - starts
    # Midrank for each group in 1-based rank space.
    mids = 0.5 * (starts + ends - 1) + 1.0
    t_sorted = np.repeat(mids, lens).astype(float, copy=False)
    out = np.empty(n, dtype=float)
    out[order] = t_sorted
    return out

def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC via the Mann–Whitney U statistic (ties handled by midranks).

    Returns NaN for degenerate labels.
    """
    y = np.asarray(y_true, dtype=int)
    s = np.asarray(y_score, dtype=float)
    if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan")
    ranks = compute_midrank(s)
    sum_pos = float(np.sum(ranks[y == 1]))
    auc = (sum_pos - m * (m + 1) / 2.0) / (m * n)
    return float(auc)
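
# Hedged worked example of the Mann–Whitney identity above (toy numbers):
# scores [0.1, 0.4, 0.35, 0.8] with labels [0, 0, 1, 1] get midranks
# [1, 3, 2, 4]; the positive ranks sum to 6, so
# AUC = (6 - 2*3/2) / (2*2) = 0.75.
# >>> roc_auc_rank(np.array([0, 0, 1, 1]), np.array([0.1, 0.4, 0.35, 0.8]))
# 0.75
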
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.mean((p - y) ** 2))

def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    bins, ici = _calibration_bins_and_ici(
        p, y, n_bins=int(n_bins), return_bins=True)
    return {"bins": bins, "ici": float(ici)}


def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
    """Fast ICI only (no per-bin point export)."""
    _, ici = _calibration_bins_and_ici(
        p, y, n_bins=int(n_bins), return_bins=False)
    return float(ici)

def _calibration_bins_and_ici(
    p: np.ndarray,
    y: np.ndarray,
    *,
    n_bins: int,
    return_bins: bool,
) -> Tuple[List[Dict[str, Any]], float]:
    """Vectorized quantile binning for calibration + ICI."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    if p.size == 0:
        return [], float("nan")
    q = np.linspace(0.0, 1.0, int(n_bins) + 1)
    edges = np.quantile(p, q)
    edges = np.asarray(edges, dtype=float)
    edges[0] = -np.inf
    edges[-1] = np.inf
    # Bin assignment: i if edges[i] < p <= edges[i+1]
    bin_idx = np.searchsorted(edges, p, side="right") - 1
    bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1)
    counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float)
    sum_p = np.bincount(bin_idx, weights=p, minlength=int(n_bins)).astype(float)
    sum_y = np.bincount(bin_idx, weights=y, minlength=int(n_bins)).astype(float)
    nonempty = counts > 0
    if not np.any(nonempty):
        return [], float("nan")
    p_mean = np.zeros(int(n_bins), dtype=float)
    y_mean = np.zeros(int(n_bins), dtype=float)
    p_mean[nonempty] = sum_p[nonempty] / counts[nonempty]
    y_mean[nonempty] = sum_y[nonempty] / counts[nonempty]
    diffs = np.abs(p_mean[nonempty] - y_mean[nonempty])
    ici = float(np.mean(diffs)) if diffs.size else float("nan")
    if not return_bins:
        return [], ici
    bins: List[Dict[str, Any]] = []
    idxs = np.nonzero(nonempty)[0]
    for i in idxs.tolist():
        bins.append(
            {
                "bin": int(i),
                "p_mean": float(p_mean[i]),
                "y_mean": float(y_mean[i]),
                "n": int(counts[i]),
            }
        )
    return bins, ici
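
# Hedged worked example of the quantile-binned ICI above (toy numbers): with
# n_bins=2, p=[0.1, 0.2, 0.8, 0.9] splits at the median into bin means
# (0.15, 0.85); the observed rates are (0.0, 1.0), so
# ICI = (|0.15 - 0| + |0.85 - 1|) / 2 = 0.15.
# >>> round(calibration_ici_only(np.array([0.1, 0.2, 0.8, 0.9]),
# ...                            np.array([0.0, 0.0, 1.0, 1.0]), n_bins=2), 6)
# 0.15
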
def _progress_line(done: int, total: int, prefix: str = "") -> str:
    total_i = max(1, int(total))
    done_i = max(0, min(int(done), total_i))
    width = 28
    frac = done_i / total_i
    filled = int(round(width * frac))
    bar = "#" * filled + "-" * (width - filled)
    pct = 100.0 * frac
    return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)"

def _should_show_progress(mode: str) -> bool:
    m = str(mode).strip().lower()
    if m in {"0", "false", "no", "none", "off"}:
        return False
    # Default: show if interactive.
    if m in {"auto", "1", "true", "yes", "on", "bar"}:
        try:
            return bool(sys.stdout.isatty())
        except Exception:
            return True
    return True

def _safe_float(x: Any, default: float = float("nan")) -> float:
    try:
        return float(x)
    except Exception:
        return float(default)


def _ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)

def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
    """Load a 0-based cause_id -> name mapping.

    labels.csv is assumed to contain one label per line, in disease-id order.
    """
    if not os.path.exists(path):
        return {}
    mapping: Dict[int, str] = {}
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            name = line.strip()
            if name:
                mapping[int(i)] = name
    return mapping

def pick_focus_causes(
    *,
    counts_within_tau: Optional[np.ndarray],
    n_disease: int,
    death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
    k: int = 5,
) -> List[int]:
    """Pick focus causes for user-facing evaluation.

    Rule:
      1) Always include death_cause_id first.
      2) Then add K additional causes by descending event count if available.
         If counts_within_tau is None, fall back to a descending cause_id coverage proxy.

    Notes:
      - counts_within_tau is expected to have shape (n_disease,).
      - Deterministic: ties are broken by the smaller cause id.
    """
    n_disease_i = int(n_disease)
    if death_cause_id < 0 or death_cause_id >= n_disease_i:
        print(
            f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); "
            "it will be omitted from focus causes."
        )
        focus: List[int] = []
    else:
        focus = [int(death_cause_id)]
    candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)]
    if counts_within_tau is not None:
        c = np.asarray(counts_within_tau).astype(float)
        if c.shape[0] != n_disease_i:
            print(
                "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering."
            )
            counts_within_tau = None
        else:
            # Sort by (-count, cause_id).
            order = sorted(candidates, key=lambda i: (-float(c[i]), int(i)))
            order = [i for i in order if float(c[i]) > 0]
            focus.extend([int(i) for i in order[:int(k)]])
    if counts_within_tau is None:
        # Fallback: deterministic coverage proxy (descending id, excluding death), then take K.
        # (Real coverage requires data; this path is mostly for robustness.)
        order = sorted(candidates, key=lambda i: (-int(i)))
        focus.extend([int(i) for i in order[:int(k)]])
    # De-dup while preserving order.
    seen = set()
    out: List[int] = []
    for cid in focus:
        if cid not in seen:
            out.append(cid)
            seen.add(cid)
    return out

def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
    _ensure_dir(os.path.dirname(os.path.abspath(path)) or ".")
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)

def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
    """Return a list of (sex_label, mask) slices including an 'all' slice.

    If sex is missing, returns only ('all', None).
    """
    out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
    if sex is None:
        return out
    s = np.asarray(sex)
    if s.ndim != 1:
        return out
    for val in [0, 1]:
        m = (s == val)
        if int(np.sum(m)) > 0:
            out.append((str(val), m))
    return out

def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
    edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1))
    edges = np.asarray(edges, dtype=float)
    edges[0] = -np.inf
    edges[-1] = np.inf
    return edges

def compute_risk_stratification_bins(
    p: np.ndarray,
    y: np.ndarray,
    *,
    q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
    """Compute quantile-based risk strata and a compact summary."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return 0, [], {
            "y_overall": float("nan"),
            "top_decile_y_rate": float("nan"),
            "bottom_half_y_rate": float("nan"),
            "lift_top10_vs_bottom50": float("nan"),
            "slope_pred_vs_obs": float("nan"),
        }
    # Choose the number of quantiles robustly.
    q = int(q_default)
    if n < 200:
        q = 5
    edges = _quantile_edges(p, q)
    y_overall = float(np.mean(y))
    bin_rows: List[Dict[str, Any]] = []
    p_means: List[float] = []
    y_rates: List[float] = []
    n_bins: List[int] = []
    for i in range(q):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        nb = int(np.sum(mask))
        if nb == 0:
            # Keep the row for consistent plotting; set NaNs.
            bin_rows.append(
                {
                    "q": int(i + 1),
                    "n_bin": 0,
                    "p_mean": float("nan"),
                    "y_rate": float("nan"),
                    "y_overall": y_overall,
                    "lift_vs_overall": float("nan"),
                }
            )
            continue
        p_mean = float(np.mean(p[mask]))
        y_rate = float(np.mean(y[mask]))
        lift = float(y_rate / y_overall) if y_overall > 0 else float("nan")
        bin_rows.append(
            {
                "q": int(i + 1),
                "n_bin": nb,
                "p_mean": p_mean,
                "y_rate": y_rate,
                "y_overall": y_overall,
                "lift_vs_overall": lift,
            }
        )
        p_means.append(p_mean)
        y_rates.append(y_rate)
        n_bins.append(nb)
    # Summary
    top_mask = (p > edges[q - 1]) & (p <= edges[q])
    bot_half_mask = (p > edges[0]) & (p <= edges[q // 2])
    top_y = float(np.mean(y[top_mask])) if int(np.sum(top_mask)) > 0 else float("nan")
    bot_y = float(np.mean(y[bot_half_mask])) if int(np.sum(bot_half_mask)) > 0 else float("nan")
    lift_top_vs_bottom = float(top_y / bot_y) if (
        np.isfinite(top_y) and np.isfinite(bot_y) and bot_y > 0) else float("nan")
    slope = float("nan")
    if len(p_means) >= 2:
        # Weighted least-squares slope of y_rate ~ p_mean.
        x = np.asarray(p_means, dtype=float)
        yy = np.asarray(y_rates, dtype=float)
        w = np.asarray(n_bins, dtype=float)
        xm = float(np.average(x, weights=w))
        ym = float(np.average(yy, weights=w))
        denom = float(np.sum(w * (x - xm) ** 2))
        if denom > 0:
            slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom)
    summary = {
        "y_overall": y_overall,
        "top_decile_y_rate": top_y,
        "bottom_half_y_rate": bot_y,
        "lift_top10_vs_bottom50": lift_top_vs_bottom,
        "slope_pred_vs_obs": slope,
    }
    return q, bin_rows, summary
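
# Hedged worked example of the weighted-least-squares slope above (toy bins):
# two bins with (p_mean, y_rate, n) = (0.1, 0.1, 50) and (0.3, 0.3, 50) are
# perfectly calibrated, so slope_pred_vs_obs = 1.0:
#   xm = ym = 0.2
#   slope = sum(w*(x-xm)*(y-ym)) / sum(w*(x-xm)^2)
#         = (50*0.01 + 50*0.01) / (50*0.01 + 50*0.01) = 1.0
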
def compute_capture_points(
    p: np.ndarray,
    y: np.ndarray,
    k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return []
    order = np.argsort(-p)
    y_sorted = y[order]
    events_total = float(np.sum(y_sorted))
    rows: List[Dict[str, Any]] = []
    for k in k_pcts:
        kf = float(k)
        n_targeted = int(math.ceil(n * kf / 100.0))
        n_targeted = max(1, min(n_targeted, n))
        events_targeted = float(np.sum(y_sorted[:n_targeted]))
        capture = float(events_targeted / events_total) if events_total > 0 else float("nan")
        precision = float(events_targeted / float(n_targeted))
        rows.append(
            {
                "k_pct": int(k),
                "n_targeted": int(n_targeted),
                "events_targeted": float(events_targeted),
                "events_total": float(events_total),
                "event_capture_rate": capture,
                "precision_in_targeted": precision,
            }
        )
    return rows
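
# Hedged worked example for compute_capture_points (toy arrays): with
# p=[0.9, 0.8, 0.2, 0.1] and y=[1, 0, 1, 0], targeting the top 50% selects
# 2 people and captures 1 of the 2 events.
# >>> rows = compute_capture_points(np.array([0.9, 0.8, 0.2, 0.1]),
# ...                               np.array([1.0, 0.0, 1.0, 0.0]), k_pcts=[50])
# >>> rows[0]["event_capture_rate"], rows[0]["precision_in_targeted"]
# (0.5, 0.5)
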
def compute_event_rate_at_topk_causes(
    p_tau: np.ndarray,
    y_tau: np.ndarray,
    topk_list: Sequence[int],
) -> List[Dict[str, Any]]:
    """Compute Event Rate@K for cross-cause prioritization.

    For each individual, rank causes by predicted risk p_tau at a fixed horizon.
    For each K, select the top-K causes and compute the fraction that occur
    within the horizon.

    Args:
        p_tau: (N, K) predicted CIFs at a fixed horizon
        y_tau: (N, K) binary labels (0/1) whether the cause occurs within the horizon
        topk_list: list of K values to evaluate

    Returns:
        List of rows with:
          - topk
          - event_rate_mean / event_rate_median
          - recall_mean / recall_median (averaged over individuals with >= 1 true cause)
          - n_total / n_valid_recall
    """
    p = np.asarray(p_tau, dtype=float)
    y = np.asarray(y_tau, dtype=float)
    if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape:
        raise ValueError(
            "compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape")
    n, k_total = p.shape
    if n == 0 or k_total == 0:
        out: List[Dict[str, Any]] = []
        for kk in topk_list:
            out.append(
                {
                    "topk": int(max(1, int(kk))),
                    "event_rate_mean": float("nan"),
                    "event_rate_median": float("nan"),
                    "recall_mean": float("nan"),
                    "recall_median": float("nan"),
                    "n_total": int(n),
                    "n_valid_recall": 0,
                }
            )
        return out
    # Sanitize the K list.
    topks = sorted({int(x) for x in topk_list if int(x) > 0})
    if not topks:
        return []
    max_k = min(int(max(topks)), int(k_total))
    if max_k <= 0:
        return []
    # Efficient: get the top max_k causes per individual, then sort within those.
    part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k]  # (N, max_k)
    p_part = np.take_along_axis(p, part, axis=1)
    order = np.argsort(-p_part, axis=1)
    top_sorted = np.take_along_axis(part, order, axis=1)  # (N, max_k)
    out_rows: List[Dict[str, Any]] = []
    for kk in topks:
        kk_eff = min(int(kk), int(k_total))
        idx = top_sorted[:, :kk_eff]
        y_sel = np.take_along_axis(y, idx, axis=1)
        # Selected true causes per person.
        hit = np.sum(y_sel, axis=1)
        # Precision-like: fraction of selected causes that occur.
        per_person = hit / float(kk_eff) if kk_eff > 0 else np.full((n,), np.nan)
        # Recall@K: fraction of true causes covered by the top-K (undefined when no true cause).
        g = np.sum(y, axis=1)
        valid = g > 0
        recall = np.full((n,), np.nan, dtype=float)
        recall[valid] = hit[valid] / g[valid]
        out_rows.append(
            {
                "topk": int(kk_eff),
                "event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"),
                "event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"),
                "recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"),
                "recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"),
                "n_total": int(n),
                "n_valid_recall": int(np.sum(valid)),
            }
        )
    return out_rows

def compute_random_ranking_baseline_topk(
    y_tau: np.ndarray,
    topk_list: Sequence[int],
    *,
    z: float = 1.645,
) -> List[Dict[str, Any]]:
    """Random-ranking baseline for Event Rate@K and Recall@K.

    Baseline definition:
      - For each individual, pick K causes uniformly at random without replacement.
      - EventRate@K = (# selected causes that occur) / K.
      - Recall@K = (# selected causes that occur) / (# causes that occur),
        averaged over individuals with >= 1 true cause.

    This function computes the expected baseline mean and an approximate 5-95%
    range for the population mean using a normal approximation of the
    hypergeometric variance.

    Args:
        y_tau: (N, K_total) binary labels
        topk_list: K values
        z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%)

    Returns:
        Rows with baseline means and p05/p95 for both metrics.
    """
    y = np.asarray(y_tau, dtype=float)
    if y.ndim != 2:
        raise ValueError(
            "compute_random_ranking_baseline_topk expects y_tau with shape (N,K)")
    n, k_total = y.shape
    topks = sorted({int(x) for x in topk_list if int(x) > 0})
    if not topks:
        return []
    g = np.sum(y, axis=1)  # (N,)
    valid = g > 0
    n_valid = int(np.sum(valid))
    out: List[Dict[str, Any]] = []
    for kk in topks:
        kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk)
        if n == 0 or k_total == 0 or kk_eff <= 0:
            out.append(
                {
                    "topk": int(max(1, kk_eff)),
                    "baseline_event_rate_mean": float("nan"),
                    "baseline_event_rate_p05": float("nan"),
                    "baseline_event_rate_p95": float("nan"),
                    "baseline_recall_mean": float("nan"),
                    "baseline_recall_p05": float("nan"),
                    "baseline_recall_p95": float("nan"),
                    "n_total": int(n),
                    "n_valid_recall": int(n_valid),
                    "k_total": int(k_total),
                    "baseline_method": "random_ranking_hypergeometric_normal_approx",
                }
            )
            continue
        # Expected EventRate@K per person is E[X]/K = (K * (g/K_total)) / K = g/K_total.
        er_mean = float(np.mean(g / float(k_total)))
        # Variance of the hypergeometric count X:
        #   Var(X) = K * p * (1-p) * ((K_total - K) / (K_total - 1)), where p = g/K_total.
        if k_total > 1 and kk_eff < k_total:
            p = g / float(k_total)
            finite_corr = float(k_total - kk_eff) / float(k_total - 1)
            var_x = float(kk_eff) * p * (1.0 - p) * finite_corr
        else:
            var_x = np.zeros_like(g, dtype=float)
        var_er = var_x / (float(kk_eff) ** 2)
        se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n))
        er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0))
        er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0))
        # Expected Recall@K for individuals with g > 0 is K/K_total (clipped).
        rec_mean = float(min(float(kk_eff) / float(k_total), 1.0))
        if n_valid > 0:
            gv = g[valid]
            var_xv = var_x[valid]
            # Var(X / g) = Var(X) / g^2 (approx; g is fixed per individual).
            var_rec_v = var_xv / (gv ** 2)
            se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid)
            rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0))
            rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0))
        else:
            rec_p05 = float("nan")
            rec_p95 = float("nan")
        out.append(
            {
                "topk": int(kk_eff),
                "baseline_event_rate_mean": er_mean,
                "baseline_event_rate_p05": er_p05,
                "baseline_event_rate_p95": er_p95,
                "baseline_recall_mean": rec_mean,
                "baseline_recall_p05": float(rec_p05),
                "baseline_recall_p95": float(rec_p95),
                "n_total": int(n),
                "n_valid_recall": int(n_valid),
                "k_total": int(k_total),
                "baseline_method": "random_ranking_hypergeometric_normal_approx",
            }
        )
    return out
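
# Hedged worked example for the random-ranking baseline (toy numbers): with
# K_total=4 causes and a person having g=2 true causes, picking K=2 causes at
# random gives E[EventRate@2] = g/K_total = 0.5 and E[Recall@2] = K/K_total = 0.5;
# the hypergeometric count has Var(X) = 2 * 0.5 * 0.5 * (4-2)/(4-1) = 1/3.
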
def count_occurs_within_horizon(
    loader: DataLoader,
    offset_years: float,
    tau_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count per-person occurrence within tau after the prediction context.

    Returns counts[k] = number of individuals with disease k at least once in
    (t_ctx, t_ctx + tau].
    """
    counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_total_eval = 0
    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        n_total_eval += int(keep.sum().item())
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        t_ctx = t_ctx[keep]
        B, L = event_seq.shape
        b = torch.arange(B, device=device)
        t0 = time_seq[b, t_ctx]
        t1 = t0 + (float(tau_years) * DAYS_PER_YEAR)
        idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
        in_window = (
            (idxs > t_ctx.unsqueeze(1))
            & (time_seq >= t0.unsqueeze(1))
            & (time_seq <= t1.unsqueeze(1))
            & (event_seq >= 2)
            & (event_seq != 0)
        )
        if not in_window.any():
            continue
        b_idx, t_idx = in_window.nonzero(as_tuple=True)
        disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
        # Unique per (person, disease) pair to count per-person within-window occurrence.
        key = b_idx.to(torch.long) * int(n_disease) + disease_ids
        uniq = torch.unique(key)
        uniq_disease = uniq % int(n_disease)
        counts.scatter_add_(0, uniq_disease,
                            torch.ones_like(uniq_disease, dtype=torch.long))
    return counts.detach().cpu().numpy(), int(n_total_eval)
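
# Hedged illustration of the (person, disease) de-duplication trick above:
# encoding pairs as b * n_disease + d makes torch.unique collapse repeat
# diagnoses of the same disease for the same person. Toy values only.
# >>> n_disease = 10
# >>> b_idx = torch.tensor([0, 0, 1])  # person 0 has disease 4 twice
# >>> d_idx = torch.tensor([4, 4, 4])
# >>> torch.unique(b_idx * n_disease + d_idx) % n_disease
# tensor([4, 4])
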
# ============================================================
# Evaluation core
# ============================================================
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])
    loss_params: Dict[str, Any] = {}
    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif loss_type == "lognormal_basis_binned_hazard_cif":
        centers = cfg.get("lognormal_centers", None)
        if centers is None:
            centers = cfg.get("centers", None)
        if not isinstance(centers, list) or len(centers) == 0:
            raise ValueError(
                "lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
            )
        r = len(centers)
        desired_total = int(dataset.n_disease) * int(r)
        legacy_total = 1 + desired_total
        # Prefer the new shape (K,R) but keep compatibility with older checkpoints
        # that used a single flattened dimension (1 + K*R).
        out_dims = [int(dataset.n_disease), int(r)]
        if checkpoint_path:
            try:
                ckpt = torch.load(checkpoint_path, map_location="cpu")
                head_sd = ckpt.get("head_state_dict", {})
                w = head_sd.get("net.2.weight", None)
                if isinstance(w, torch.Tensor) and w.ndim == 2:
                    out_features = int(w.shape[0])
                    if out_features == legacy_total:
                        out_dims = [legacy_total]
                    elif out_features == desired_total:
                        out_dims = [int(dataset.n_disease), int(r)]
                    else:
                        raise ValueError(
                            f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)"
                        )
            except Exception as e:
                raise ValueError(
                    f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}"
                )
        loss_params["centers"] = centers
        loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
        loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
        loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
        loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
        loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0))
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
    if model_type == "delphi_fork":
        backbone = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
        ).to(device)
    elif model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, fall back to pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in {"", None}:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in {"", None}:
            run_dir = os.path.dirname(os.path.abspath(checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
    return backbone, head, loss_type, bin_edges, loss_params
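
# Illustrative (hypothetical) train_config.json fragment for the Route-3 loss;
# the key names mirror the cfg.get(...) calls above, but every value is a
# placeholder, not a config shipped with this repo:
# {
#   "model_type": "delphi_fork",
#   "loss_type": "lognormal_basis_binned_hazard_cif",
#   "n_embd": 256, "n_head": 8, "n_layer": 8,
#   "bin_edges": [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, Infinity],
#   "lognormal_centers": [-1.4, 0.0, 1.4],
#   "bandwidth_init": 0.7
# }
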
@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loss_params: Dict[str, Any],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    n_disease: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the model and produce cause-specific, time-dependent CIF outputs.

    Returns:
        cif_full: (N, K, H)
        survival: (N, H)
        y_cause_within_tau: (N, K, H)
        sex: (N,)

    NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
    """
    backbone.eval()
    head.eval()
    # Accumulate in CPU lists, then concatenate.
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []
    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)
        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        # Filter the batch to samples with a valid context.
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep]
        t_ctx = t_ctx[keep]
        h = backbone(event_seq, time_seq, sexes_k, cont_feats, cate_feats)  # (B,L,D)
        b = torch.arange(h.size(0), device=device)
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)
        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                logits, bin_edges, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "lognormal_basis_binned_hazard_cif":
            centers = loss_params.get("centers", None)
            sigma = loss_params.get("sigma", None)
            if centers is None or sigma is None:
                raise ValueError(
                    "lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']")
            cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits(
                logits,
                centers=centers,
                sigma=sigma,
                bin_edges=bin_edges,
                taus=eval_horizons,
                eps=float(loss_params.get("loss_eps", 1e-8)),
                alpha_floor=float(loss_params.get("alpha_floor", 0.0)),
                return_survival=True,
            )
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")
        # Within-horizon labels for all causes: disease k occurs within tau after context.
        y_within_full = torch.stack(
            [
                multi_hot_ever_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau),
                    n_disease=int(n_disease),
                ).to(torch.float32)
                for tau in eval_horizons
            ],
            dim=2,
        )  # (B,K,H)
        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        y_cause_within_list.append(y_within_full.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())
    if not cif_full_list:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")
    cif_full = np.concatenate(cif_full_list, axis=0)
    survival = np.concatenate(survival_list, axis=0)
    y_cause_within = np.concatenate(y_cause_within_list, axis=0)
    sex = np.concatenate(sex_list, axis=0) if sex_list else np.array([], dtype=int)
    return cif_full, survival, y_cause_within, sex
def evaluate_one_model(
    model_name: str,
    cif_full: np.ndarray,
    y_cause_within_tau: np.ndarray,
    eval_horizons: Sequence[float],
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    calib_cause_ids: Optional[Sequence[int]],
    n_calib_bins: int = 10,
    metric_workers: int = 0,
    progress: str = "auto",
) -> None:
    """Compute per-cause metrics for ALL diseases.

    Notes:
      - Writes scalar metrics for all causes into out_rows.
      - Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable).
    """
    cif_full = np.asarray(cif_full, dtype=float)
    y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float)
    if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3:
        raise ValueError(
            "Expected cif_full and y_cause_within_tau with shape (N, K, H)")
    if cif_full.shape != y_cause_within_tau.shape:
        raise ValueError(
            f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}"
        )
    N, K, H = cif_full.shape
    if H != len(eval_horizons):
        raise ValueError("H mismatch between cif_full and eval_horizons")
    calib_set = set(int(x) for x in calib_cause_ids) if calib_cause_ids is not None else set()
    workers = int(metric_workers)
    if workers <= 0:
        workers = int(min(8, os.cpu_count() or 1))
    workers = max(1, workers)
    show_progress = _should_show_progress(progress)

    def _eval_chunk(
        *,
        tau: float,
        p_tau: np.ndarray,
        y_tau: np.ndarray,
        brier_by_cause: np.ndarray,
        cause_ids: np.ndarray,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]:
        local_rows: List[Dict[str, Any]] = []
        local_calib: List[Dict[str, Any]] = []
        for cid in cause_ids.tolist():
            p = p_tau[:, cid]
            y = y_tau[:, cid]
            local_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_brier",
                    "horizon": float(tau),
                    "cause": int(cid),
                    "value": float(brier_by_cause[cid]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            # ICI: compute bins only if we will export them.
            need_bins = (not calib_set) or (int(cid) in calib_set)
            if need_bins:
                cal = calibration_deciles(p, y, n_bins=n_calib_bins)
                ici = float(cal["ici"])
            else:
                cal = None
                ici = calibration_ici_only(p, y, n_bins=n_calib_bins)
            local_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_ici",
                    "horizon": float(tau),
                    "cause": int(cid),
                    "value": float(ici),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            # Secondary: discrimination via AUC at the same horizon (point estimate only).
            auc = roc_auc_rank(y, p)
            local_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_auc",
                    "horizon": float(tau),
                    "cause": int(cid),
                    "value": float(auc),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            if need_bins and cal is not None:
                for binfo in cal.get("bins", []):
                    local_calib.append(
                        {
                            "model_name": model_name,
                            "task": "cause_k",
                            "horizon": float(tau),
                            "cause_id": int(cid),
                            "bin_index": int(binfo["bin"]),
                            "p_mean": float(binfo["p_mean"]),
                            "y_mean": float(binfo["y_mean"]),
                            "n_in_bin": int(binfo["n"]),
                        }
                    )
        return local_rows, local_calib, int(cause_ids.size)

    # Cause-specific, time-dependent metrics per horizon.
    for h_i, tau in enumerate(eval_horizons):
        p_tau = cif_full[:, :, h_i]  # (N, K)
        y_tau = y_cause_within_tau[:, :, h_i]  # (N, K)
        # Vectorized Brier for speed.
        brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0)  # (K,)
        # Parallelize disease-level metrics; chunk to avoid millions of futures.
        all_ids = np.arange(int(K), dtype=int)
        chunks = np.array_split(all_ids, workers)
        done = 0
        prefix = f"[{model_name}] tau={float(tau)}y "
t0 = time . time ( )
if workers < = 1 :
for ch in chunks :
r_chunk , c_chunk , n_done = _eval_chunk (
tau = float ( tau ) ,
p_tau = p_tau ,
y_tau = y_tau ,
brier_by_cause = brier_by_cause ,
cause_ids = ch ,
2026-01-10 11:37:12 +08:00
)
2026-01-10 23:49:37 +08:00
out_rows . extend ( r_chunk )
calib_rows . extend ( c_chunk )
done + = int ( n_done )
if show_progress :
sys . stdout . write (
" \r " + _progress_line ( done , int ( K ) , prefix = prefix ) )
sys . stdout . flush ( )
else :
with ThreadPoolExecutor ( max_workers = workers ) as ex :
futs = [
ex . submit (
_eval_chunk ,
tau = float ( tau ) ,
p_tau = p_tau ,
y_tau = y_tau ,
brier_by_cause = brier_by_cause ,
cause_ids = ch ,
)
for ch in chunks
if int ( ch . size ) > 0
]
for fut in as_completed ( futs ) :
r_chunk , c_chunk , n_done = fut . result ( )
out_rows . extend ( r_chunk )
calib_rows . extend ( c_chunk )
done + = int ( n_done )
if show_progress :
sys . stdout . write (
" \r " + _progress_line ( done , int ( K ) , prefix = prefix ) )
sys . stdout . flush ( )
if show_progress :
dt = time . time ( ) - t0
sys . stdout . write ( " \r " + _progress_line ( int ( K ) ,
int ( K ) , prefix = prefix ) + f " ( { dt : .1f } s) \n " )
sys . stdout . flush ( )
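

# Minimal usage sketch (not executed): evaluate_one_model on random toy data,
# following the (N, K, H) convention above. It relies on the metric helpers
# defined earlier in this module (calibration_deciles, roc_auc_rank, ...).
def _demo_evaluate_one_model() -> List[Dict[str, Any]]:
    rng = np.random.default_rng(0)
    n, k, horizons = 64, 3, [1.0, 5.0]
    p = rng.uniform(size=(n, k, len(horizons)))
    y = (rng.uniform(size=(n, k, len(horizons))) < p).astype(float)
    rows: List[Dict[str, Any]] = []
    calib: List[Dict[str, Any]] = []
    evaluate_one_model(
        model_name="toy",
        cif_full=p,
        y_cause_within_tau=y,
        eval_horizons=horizons,
        out_rows=rows,
        calib_rows=calib,
        calib_cause_ids=[0],  # export calibration bins for cause 0 only
        metric_workers=1,
        progress="none",
    )
    return rows  # 3 metrics x 3 causes x 2 horizons = 18 rows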


def summarize_over_diseases(
    rows: List[Dict[str, Any]],
    *,
    model_name: str,
    eval_horizons: Sequence[float],
    metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"),
) -> List[Dict[str, Any]]:
    """Summarize the mean/median of each metric over diseases (per horizon)."""
    out: List[Dict[str, Any]] = []
    # Bucket finite metric values by (metric_name, horizon).
    bucket: Dict[Tuple[str, float], List[float]] = {}
    for r in rows:
        if r.get("model_name") != model_name:
            continue
        m = str(r.get("metric_name"))
        if m not in set(metrics):
            continue
        h = _safe_float(r.get("horizon"))
        v = _safe_float(r.get("value"))
        if not np.isfinite(h) or not np.isfinite(v):
            continue
        bucket.setdefault((m, float(h)), []).append(float(v))
    for tau in eval_horizons:
        ht = float(tau)
        for m in metrics:
            vals = bucket.get((str(m), ht), [])
            if vals:
                arr = np.asarray(vals, dtype=float)
                mean_v = float(np.mean(arr))
                med_v = float(np.median(arr))
                n_valid = int(arr.size)
            else:
                mean_v = float("nan")
                med_v = float("nan")
                n_valid = 0
            out.append(
                {
                    "model_name": str(model_name),
                    "metric_name": str(m),
                    "horizon": ht,
                    "mean": mean_v,
                    "median": med_v,
                    "n_valid": n_valid,
                }
            )
    return out
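

# Usage sketch (not executed): summarizing hand-built metric rows. One output
# row is produced per (metric, horizon); metrics with no finite values get
# NaN mean/median and n_valid=0.
def _demo_summarize() -> List[Dict[str, Any]]:
    rows = [
        {"model_name": "toy", "metric_name": "cause_auc", "horizon": 1.0, "value": 0.70},
        {"model_name": "toy", "metric_name": "cause_auc", "horizon": 1.0, "value": 0.80},
    ]
    # ("cause_auc", 1.0) -> mean=0.75, median=0.75, n_valid=2;
    # ("cause_brier", 1.0) and ("cause_ici", 1.0) -> NaN rows with n_valid=0.
    return summarize_over_diseases(rows, model_name="toy", eval_horizons=[1.0])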


def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    fieldnames = [
        "model_name",
        "task",
        "horizon",
        "cause_id",
        "bin_index",
        "p_mean",
        "y_mean",
        "n_in_bin",
    ]
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    fieldnames = [
        "model_name",
        "metric_name",
        "horizon",
        "cause",
        "value",
        "ci_low",
        "ci_high",
    ]
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def _make_eval_tag(split: str, offset_years: float) -> str:
    """Short tag for filenames written into run directories."""
    off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
    return f"{split}_offset{off}y"


def main() -> int:
    ap = argparse.ArgumentParser(
        description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True,
                    help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str,
                    default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"],
                    help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5,
                    help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float,
                    nargs="*", default=DEFAULT_EVAL_HORIZONS)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--out_csv", type=str, default="eval_summary.csv")
    ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
    # Integrity checks
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)
    # Speed/UX
    ap.add_argument(
        "--metric_workers",
        type=int,
        default=0,
        help="Threads for per-disease metrics (0=auto, 1=disable parallelism)",
    )
    ap.add_argument(
        "--progress",
        type=str,
        default="auto",
        choices=["auto", "bar", "none"],
        help="Progress visualization during per-disease evaluation",
    )
    # Export settings for user-facing experiments
    ap.add_argument("--export_dir", type=str, default="eval_exports")
    ap.add_argument("--death_cause_id", type=int,
                    default=DEFAULT_DEATH_CAUSE_ID)
    ap.add_argument("--focus_k", type=int, default=5,
                    help="Additional non-death causes to include")
    ap.add_argument("--capture_k_pcts", type=int,
                    nargs="*", default=[1, 5, 10, 20])
    ap.add_argument(
        "--capture_curve_max_pct",
        type=int,
        default=50,
        help="If >0, also export a dense capture curve for k=1..max_pct",
    )
    # High-risk cause concentration (cross-cause prioritization)
    ap.add_argument(
        "--cause_concentration_topk",
        type=int,
        nargs="*",
        default=[5, 10, 20, 50],
        help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
    )
    ap.add_argument(
        "--cause_concentration_write_random_baseline",
        action="store_true",
        default=False,
        help="If set, also export a random-ranking baseline (expected Event Rate@K "
             "and Recall@K with an uncertainty range)",
    )

    args = ap.parse_args()
    set_deterministic(args.seed)
    specs = load_models_json(args.models_json)
    if not specs:
        raise ValueError("No models provided")

    export_dir = str(args.export_dir)
    _ensure_dir(export_dir)
    cause_names = load_cause_names("labels.csv")

    # Determine top-K causes from the evaluation split only (model-agnostic),
    # aligned to time-dependent risk: occurrence within tau_max after context.
    first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
    cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [
        "bmi", "smoking", "alcohol"]
    dataset_for_top = HealthDataset(
        data_prefix=args.data_prefix, covariate_list=cov_list)
    subset_for_top = build_eval_subset(
        dataset_for_top,
        train_ratio=float(first_cfg.get("train_ratio", 0.7)),
        val_ratio=float(first_cfg.get("val_ratio", 0.15)),
        seed=int(first_cfg.get("random_seed", 42)),
        split=args.split,
    )
    loader_top = DataLoader(
        subset_for_top,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=health_collate_fn,
    )

    tau_max = float(max(args.eval_horizons))
    counts, n_total_eval = count_occurs_within_horizon(
        loader=loader_top,
        offset_years=args.offset_years,
        tau_years=tau_max,
        n_disease=dataset_for_top.n_disease,
        device=args.device,
    )
    focus_causes = pick_focus_causes(
        counts_within_tau=counts,
        n_disease=int(dataset_for_top.n_disease),
        death_cause_id=int(args.death_cause_id),
        k=int(args.focus_k),
    )
    top_cause_ids = np.asarray(focus_causes, dtype=int)

    # Export the chosen focus causes.
    focus_rows: List[Dict[str, Any]] = []
    for r, cid in enumerate(focus_causes, start=1):
        row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)}
        if cid in cause_names:
            row["cause_name"] = cause_names[cid]
        focus_rows.append(row)
    focus_fieldnames = ["cause", "rank"] + \
        (["cause_name"] if any("cause_name" in r for r in focus_rows) else [])
    write_simple_csv(os.path.join(export_dir, "focus_causes.csv"),
                     focus_fieldnames, focus_rows)

    # Metadata for focus causes (within tau_max).
    top_causes_meta: List[Dict[str, Any]] = []
    for cid in focus_causes:
        n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0
        top_causes_meta.append(
            {
                "cause_id": int(cid),
                "tau_years": float(tau_max),
                "n_case_within_tau": n_case,
                "n_control_within_tau": int(n_total_eval - n_case),
                "n_total_eval": int(n_total_eval),
            }
        )

    summary_rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []
    # Experiment exports (accumulated across models)
    cap_points_rows: List[Dict[str, Any]] = []
    cap_curve_rows: List[Dict[str, Any]] = []
    conc_rows: List[Dict[str, Any]] = []
    conc_base_rows: List[Dict[str, Any]] = []

    # Track per-model integrity status for meta JSON.
    integrity_meta: Dict[str, Any] = {}

    # Evaluate each model.
    for spec in specs:
        run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))
        tag = _make_eval_tag(args.split, float(args.offset_years))
        # Remember list offsets so we can write per-model slices to the model's run_dir.
        calib_start = len(calib_rows)
        cfg = load_train_config_for_checkpoint(spec.checkpoint_path)

        # Identifiers for consistent exports (train config wins; spec is the fallback).
        model_id = str(spec.name)
        model_type = str(cfg.get("model_type", spec.model_type))
        loss_type_id = str(cfg.get("loss_type", spec.loss_type))
        age_encoder = str(cfg.get("age_encoder", ""))
        cov_type = "full" if _parse_bool(cfg.get("full_cov", False)) else "partial"

        cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [
            "bmi", "smoking", "alcohol"]
        dataset = HealthDataset(
            data_prefix=args.data_prefix, covariate_list=cov_list)
        subset = build_eval_subset(
            dataset,
            train_ratio=float(cfg.get("train_ratio", 0.7)),
            val_ratio=float(cfg.get("val_ratio", 0.15)),
            seed=int(cfg.get("random_seed", 42)),
            split=args.split,
        )
        loader = DataLoader(
            subset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=health_collate_fn,
        )
        ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
        backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
            cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
        backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
        head.load_state_dict(ckpt["head_state_dict"], strict=True)

        if loss_type == "lognormal_basis_binned_hazard_cif":
            # Recover the learned kernel bandwidth from the criterion state;
            # fall back to the configured init, clamped to the allowed range.
            crit_state = ckpt.get("criterion_state_dict", {})
            log_sigma = crit_state.get("log_sigma", None)
            if isinstance(log_sigma, torch.Tensor):
                sigma = torch.exp(log_sigma.to(device=args.device))
            else:
                sigma = torch.tensor(float(loss_params.get(
                    "bandwidth_init", 0.7)), device=args.device)
            bmin = float(loss_params.get("bandwidth_min", 1e-3))
            bmax = float(loss_params.get("bandwidth_max", 10.0))
            loss_params["sigma"] = torch.clamp(sigma, min=bmin, max=bmax)

        (
            cif_full,
            survival,
            y_cause_within_tau,
            sex,
        ) = predict_cifs_for_model(
            backbone,
            head,
            loss_type,
            bin_edges,
            loss_params,
            loader,
            args.device,
            args.offset_years,
            args.eval_horizons,
            n_disease=int(dataset.n_disease),
        )

        # CIF integrity checks before metrics.
        integrity_ok, integrity_notes = check_cif_integrity(
            cif_full,
            args.eval_horizons,
            tol=float(args.integrity_tol),
            name=spec.name,
            strict=bool(args.integrity_strict),
            survival=survival,
        )
        integrity_meta[spec.name] = {
            "integrity_ok": bool(integrity_ok),
            "integrity_notes": integrity_notes,
        }

        # Per-disease metrics for ALL diseases (written into the model's run_dir).
        model_rows: List[Dict[str, Any]] = []
        evaluate_one_model(
            model_name=spec.name,
            cif_full=cif_full,
            y_cause_within_tau=y_cause_within_tau,
            eval_horizons=args.eval_horizons,
            out_rows=model_rows,
            calib_rows=calib_rows,
            calib_cause_ids=top_cause_ids.tolist(),
            metric_workers=int(args.metric_workers),
            progress=str(args.progress),
        )

        # Summary over diseases (mean/median per horizon).
        model_summary_rows = summarize_over_diseases(
            model_rows,
            model_name=spec.name,
            eval_horizons=args.eval_horizons,
        )
        summary_rows.extend(model_summary_rows)

        # ============================================================
        # Experiment: high-risk cause concentration at a fixed horizon
        # (cross-cause prioritization accuracy). For each individual,
        # rank ALL causes by predicted CIF at the horizon and score how
        # often the top-K slots correspond to events within tau.
        # ============================================================
        topk_causes = [int(x) for x in args.cause_concentration_topk]
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float)
                y_tau_all = np.asarray(y_cause_within_tau[:, :, h_i], dtype=float)
                if sex_mask is not None:
                    p_tau_all = p_tau_all[sex_mask]
                    y_tau_all = y_tau_all[sex_mask]
                for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes):
                    conc_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "horizon": float(tau),
                            "sex": sex_label,
                            **rr,
                        }
                    )
                if bool(args.cause_concentration_write_random_baseline):
                    for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes):
                        conc_base_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "horizon": float(tau),
                                "sex": sex_label,
                                **rr,
                            }
                        )

        # Convenience slices for user-facing experiments (focus causes only).
        cause_cif_focus = cif_full[:, top_cause_ids, :]
        y_within_focus = y_cause_within_tau[:, top_cause_ids, :]

        # ============================================================
        # Experiment 2: high-risk capture points (+ optional curve)
        # ============================================================
        k_pcts = [int(x) for x in args.capture_k_pcts]
        curve_max = int(args.capture_curve_max_pct)
        curve_grid = list(range(1, curve_max + 1)) if curve_max > 0 else []
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                for j, cause_id in enumerate(top_cause_ids.tolist()):
                    p = cause_cif_focus[:, j, h_i]
                    y = y_within_focus[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    for r in compute_capture_points(p, y, k_pcts):
                        cap_points_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "horizon": float(tau),
                                "sex": sex_label,
                                **r,
                            }
                        )
                    if curve_grid:
                        for r in compute_capture_points(p, y, curve_grid):
                            cap_curve_rows.append(
                                {
                                    "model_id": model_id,
                                    "model_type": model_type,
                                    "loss_type": loss_type_id,
                                    "age_encoder": age_encoder,
                                    "cov_type": cov_type,
                                    "cause": int(cause_id),
                                    "horizon": float(tau),
                                    "sex": sex_label,
                                    **r,
                                }
                            )

        # Write focus-cause support counts into the per-model results CSV as metric rows.
        for tc in top_causes_meta:
            model_rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_case_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_case_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            model_rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_control_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_control_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            model_rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_total_eval",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_total_eval"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

        # Write per-model results into the model's run directory.
        model_calib_rows = calib_rows[calib_start:]
        model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
        model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
        model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")
        write_results_csv(model_out_csv, model_rows)
        write_simple_csv(
            model_summary_csv,
            ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
            model_summary_rows,
        )
        write_calibration_bins_csv(model_calib_csv, model_calib_rows)
        model_meta = {
            "model_name": spec.name,
            "checkpoint_path": spec.checkpoint_path,
            "run_dir": run_dir,
            "split": args.split,
            "offset_years": args.offset_years,
            "eval_horizons": [float(x) for x in args.eval_horizons],
            "n_disease": int(dataset.n_disease),
            "top_cause_ids": top_cause_ids.tolist(),
            "top_causes": top_causes_meta,
            "integrity": {spec.name: integrity_meta.get(spec.name, {})},
            "paths": {
                "results_csv": model_out_csv,
                "summary_csv": model_summary_csv,
                "calibration_bins_csv": model_calib_csv,
            },
        }
        with open(model_meta_json, "w") as f:
            json.dump(model_meta, f, indent=2)
        print(f"Wrote per-model results to {model_out_csv}")

    # Write the global summary (over diseases) across all models.
    write_simple_csv(
        args.out_csv,
        ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
        summary_rows,
    )

    # Write calibration curve points to a separate CSV.
    out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
    calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
    write_calibration_bins_csv(calib_csv_path, calib_rows)

    # Write experiment exports.
    write_simple_csv(
        os.path.join(export_dir, "lift_capture_points.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "k_pct",
            "n_targeted",
            "events_targeted",
            "events_total",
            "event_capture_rate",
            "precision_in_targeted",
        ],
        cap_points_rows,
    )
    if cap_curve_rows:
        write_simple_csv(
            os.path.join(export_dir, "lift_capture_curve.csv"),
            [
                "model_id",
                "model_type",
                "loss_type",
                "age_encoder",
                "cov_type",
                "cause",
                "horizon",
                "sex",
                "k_pct",
                "n_targeted",
                "events_targeted",
                "events_total",
                "event_capture_rate",
                "precision_in_targeted",
            ],
            cap_curve_rows,
        )

    write_simple_csv(
        os.path.join(export_dir, "high_risk_cause_concentration.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "horizon",
            "sex",
            "topk",
            "event_rate_mean",
            "event_rate_median",
            "recall_mean",
            "recall_median",
            "n_total",
            "n_valid_recall",
        ],
        conc_rows,
    )
    if conc_base_rows:
        write_simple_csv(
            os.path.join(
                export_dir, "high_risk_cause_concentration_random_baseline.csv"),
            [
                "model_id",
                "model_type",
                "loss_type",
                "age_encoder",
                "cov_type",
                "horizon",
                "sex",
                "topk",
                "baseline_event_rate_mean",
                "baseline_event_rate_p05",
                "baseline_event_rate_p95",
                "baseline_recall_mean",
                "baseline_recall_p05",
                "baseline_recall_p95",
                "n_total",
                "n_valid_recall",
                "k_total",
                "baseline_method",
            ],
            conc_base_rows,
        )

    # Manifest markdown (stable, user-facing).
    manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
    with open(manifest_path, "w", encoding="utf-8") as f:
        f.write(
            "# Evaluation Exports Manifest\n\n"
            "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, "
            "time-dependent risk evaluation (CIF-based). All exports are per-cause and per-horizon "
            "unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
            "## Files\n\n"
            "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). "
            "Intended plot: bar of event support + label table.\n"
            "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. "
            "Intended plot: bar/line showing event capture vs resources.\n"
            "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. "
            "Intended plot: gain curve overlay across models.\n"
            "- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes "
            "per individual by predicted CIF at each horizon (K from --cause_concentration_topk). "
            "Intended plot: line chart vs K.\n"
            "- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline "
            "for Event Rate@K and Recall@K with an uncertainty range "
            "(enabled by --cause_concentration_write_random_baseline).\n"
        )

    meta = {
        "split": args.split,
        "offset_years": args.offset_years,
        "eval_horizons": [float(x) for x in args.eval_horizons],
        "tau_max": float(tau_max),
        "n_disease": int(dataset_for_top.n_disease),
        "top_cause_ids": top_cause_ids.tolist(),
        "top_causes": top_causes_meta,
        "integrity": integrity_meta,
        "notes": {
            "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
            "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
            "secondary_metrics": "cause_auc (discrimination)",
            "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
            "exports_dir": export_dir,
            "focus_causes": [int(x) for x in focus_causes],
        },
    }
    with open(args.out_meta_json, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
    print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
    print(f"Wrote {args.out_meta_json}")
    return 0
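

# Illustrative sketch (not called anywhere) of the capture-point semantics
# suggested by the exported column names: for a top-k% budget, rank individuals
# by predicted risk, target the highest-risk k%, and report the share of all
# events captured plus the event rate inside the targeted group. The project's
# compute_capture_points helper used in main() is the authoritative
# implementation; this reference only encodes those assumptions.
def _demo_capture_point(p: np.ndarray, y: np.ndarray, k_pct: int) -> Dict[str, float]:
    order = np.argsort(-p)  # highest predicted risk first
    n_targeted = max(1, int(round(len(p) * k_pct / 100.0)))
    targeted = order[:n_targeted]
    events_total = float(y.sum())
    events_targeted = float(y[targeted].sum())
    return {
        "k_pct": float(k_pct),
        "n_targeted": float(n_targeted),
        "events_targeted": events_targeted,
        "events_total": events_total,
        "event_capture_rate": events_targeted / events_total if events_total else float("nan"),
        "precision_in_targeted": events_targeted / float(n_targeted),
    }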


if __name__ == "__main__":
    raise SystemExit(main())