Enhance evaluation metrics and progress visualization in model evaluation

- Introduced parallel processing for per-disease metrics using ThreadPoolExecutor (sketched below).
- Added command-line arguments for metric workers and progress visualization options.
- Refactored evaluation functions to compute metrics for all diseases and summarize results.
- Updated output CSV filenames for clarity and consistency.
2026-01-10 23:49:37 +08:00
parent 87baef3ecf
commit d87752d1f8
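
For readers skimming the diff, here is a minimal, self-contained sketch of the chunked ThreadPoolExecutor pattern this commit introduces. The helper names (score_one, score_all) and the toy data are hypothetical, not part of the repository; the chunking mirrors the np.array_split approach used in evaluate_one_model below.

import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed

def score_one(p: np.ndarray, y: np.ndarray) -> float:
    # Stand-in per-disease metric (Brier score: mean squared error of probabilities).
    return float(np.mean((p - y) ** 2))

def score_all(P: np.ndarray, Y: np.ndarray, workers: int = 4) -> np.ndarray:
    # Chunk the K disease columns so we submit ~`workers` futures,
    # not one future per disease.
    n, k = P.shape
    out = np.full(k, np.nan)
    chunks = np.array_split(np.arange(k), workers)

    def eval_chunk(ids: np.ndarray):
        return [(int(c), score_one(P[:, c], Y[:, c])) for c in ids.tolist()]

    with ThreadPoolExecutor(max_workers=workers) as ex:
        futures = [ex.submit(eval_chunk, ch) for ch in chunks if ch.size > 0]
        for fut in as_completed(futures):
            for c, v in fut.result():
                out[c] = v
    return out

rng = np.random.default_rng(0)
P = rng.random((200, 12))                     # predicted risks, shape (N, K)
Y = (rng.random(P.shape) < P).astype(float)   # matching 0/1 labels
print(score_all(P, Y))                        # one score per disease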

@@ -5,6 +5,9 @@ import math
import os
import random
import statistics
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
@@ -522,20 +525,32 @@ def check_cif_integrity(
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
"""Vectorized midrank computation (ties -> average ranks)."""
x = np.asarray(x, dtype=float)
order = np.argsort(x)
n = int(x.shape[0])
if n == 0:
return np.asarray([], dtype=float)
order = np.argsort(x, kind="mergesort")
z = x[order]
n = x.shape[0]
t = np.zeros(n, dtype=float)
i = 0
while i < n:
j = i
while j < n and z[j] == z[i]:
j += 1
t[i:j] = 0.5 * (i + j - 1) + 1.0
i = j
# Find tie groups in sorted order.
diff = np.diff(z)
# boundaries includes 0 and n
boundaries = np.concatenate(
[np.array([0], dtype=int), np.nonzero(diff != 0)[0] + 1, np.array([n], dtype=int)]
)
starts = boundaries[:-1]
ends = boundaries[1:]
lens = ends - starts
# Midrank for each group in 1-based rank space.
mids = 0.5 * (starts + ends - 1) + 1.0
t_sorted = np.repeat(mids, lens).astype(float, copy=False)
out = np.empty(n, dtype=float)
out[order] = t
out[order] = t_sorted
return out
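# Worked example (hypothetical input): x = [3.0, 1.0, 3.0, 2.0] sorts to
# z = [1, 2, 3, 3]; boundaries = [0, 1, 2, 4] give group midranks [1.0, 2.0, 3.5],
# so compute_midrank(x) returns [3.5, 1.0, 3.5, 2.0] (the tied 3.0s share
# rank (3 + 4) / 2 = 3.5).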
@@ -678,34 +693,98 @@ def brier_score(p: np.ndarray, y: np.ndarray) -> float:
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
bins, ici = _calibration_bins_and_ici(
p, y, n_bins=int(n_bins), return_bins=True)
return {"bins": bins, "ici": float(ici)}
def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
"""Fast ICI only (no per-bin point export)."""
_, ici = _calibration_bins_and_ici(
p, y, n_bins=int(n_bins), return_bins=False)
return float(ici)
def _calibration_bins_and_ici(
p: np.ndarray,
y: np.ndarray,
*,
n_bins: int,
return_bins: bool,
) -> Tuple[List[Dict[str, Any]], float]:
"""Vectorized quantile binning for calibration + ICI."""
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
# guard
if p.size == 0:
return {"bins": [], "ici": float("nan")}
return [], float("nan")
edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
# make strictly increasing where possible
q = np.linspace(0.0, 1.0, int(n_bins) + 1)
edges = np.quantile(p, q)
edges = np.asarray(edges, dtype=float)
edges[0] = -np.inf
edges[-1] = np.inf
bins = []
ici_accum = 0.0
n = p.shape[0]
# Bin assignment: i if edges[i] < p <= edges[i+1]
bin_idx = np.searchsorted(edges, p, side="right") - 1
bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1)
for i in range(n_bins):
mask = (p > edges[i]) & (p <= edges[i + 1])
if not np.any(mask):
continue
p_mean = float(np.mean(p[mask]))
y_mean = float(np.mean(y[mask]))
bins.append({"bin": i, "p_mean": p_mean,
"y_mean": y_mean, "n": int(mask.sum())})
ici_accum += abs(p_mean - y_mean)
counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float)
sum_p = np.bincount(bin_idx, weights=p,
minlength=int(n_bins)).astype(float)
sum_y = np.bincount(bin_idx, weights=y,
minlength=int(n_bins)).astype(float)
ici = ici_accum / max(len(bins), 1)
return {"bins": bins, "ici": float(ici)}
nonempty = counts > 0
if not np.any(nonempty):
return [], float("nan")
p_mean = np.zeros(int(n_bins), dtype=float)
y_mean = np.zeros(int(n_bins), dtype=float)
p_mean[nonempty] = sum_p[nonempty] / counts[nonempty]
y_mean[nonempty] = sum_y[nonempty] / counts[nonempty]
diffs = np.abs(p_mean[nonempty] - y_mean[nonempty])
ici = float(np.mean(diffs)) if diffs.size else float("nan")
if not return_bins:
return [], ici
bins: List[Dict[str, Any]] = []
idxs = np.nonzero(nonempty)[0]
for i in idxs.tolist():
bins.append(
{
"bin": int(i),
"p_mean": float(p_mean[i]),
"y_mean": float(y_mean[i]),
"n": int(counts[i]),
}
)
return bins, ici
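# Aggregation sketch (hypothetical values): with bin_idx = [0, 0, 1] and
# p = [0.1, 0.3, 0.9], np.bincount(bin_idx, weights=p, minlength=2) yields
# per-bin sums [0.4, 0.9]; dividing by counts [2, 1] gives p_mean = [0.2, 0.9],
# and the same aggregation over y gives y_mean.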
def _progress_line(done: int, total: int, prefix: str = "") -> str:
total_i = max(1, int(total))
done_i = max(0, min(int(done), total_i))
width = 28
frac = done_i / total_i
filled = int(round(width * frac))
bar = "#" * filled + "-" * (width - filled)
pct = 100.0 * frac
return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)"
def _should_show_progress(mode: str) -> bool:
m = str(mode).strip().lower()
if m in {"0", "false", "no", "none", "off"}:
return False
if m in {"1", "true", "yes", "on", "bar"}:
return True
# Default ("auto"): show only when stdout is interactive.
try:
return bool(sys.stdout.isatty())
except Exception:
return True
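# Examples: _should_show_progress("none") -> False; "bar" -> True;
# "auto" -> True only when stdout is a TTY.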
def _safe_float(x: Any, default: float = float("nan")) -> float:
@@ -1116,15 +1195,14 @@ def predict_cifs_for_model(
device: str,
offset_years: float,
eval_horizons: Sequence[float],
top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
n_disease: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Run model and produce cause-specific, time-dependent CIF outputs.
Returns:
cause_cif: (N, topK, H)
cif_full: (N, K, H)
survival: (N, H)
y_cause_within_tau: (N, topK, H)
y_cause_within_tau: (N, K, H)
NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
"""
@@ -1132,13 +1210,10 @@ def predict_cifs_for_model(
head.eval()
# We will accumulate in CPU lists, then concat.
cause_cif_list: List[np.ndarray] = []
cif_full_list: List[np.ndarray] = []
survival_list: List[np.ndarray] = []
y_cause_within_list: List[np.ndarray] = []
sex_list: List[np.ndarray] = []
top_cause_ids_t = torch.tensor(
top_cause_ids, dtype=torch.long, device=device)
for batch in loader:
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
@@ -1177,95 +1252,124 @@ def predict_cifs_for_model(
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")
cause_cif = cif_full.index_select(
dim=1, index=top_cause_ids_t) # (B,topK,H)
# Within-horizon labels for cause-specific CIF quality + discrimination.
n_disease = int(cif_full.size(1))
y_within_top = torch.stack(
# Within-horizon labels for all causes: disease k occurs within tau after context.
y_within_full = torch.stack(
[
multi_hot_selected_causes_within_horizon(
multi_hot_ever_within_horizon(
event_seq=event_seq,
time_seq=time_seq,
t_ctx=t_ctx,
tau_years=float(tau),
cause_ids=top_cause_ids_t,
n_disease=n_disease,
n_disease=int(n_disease),
).to(torch.float32)
for tau in eval_horizons
],
dim=2,
) # (B,topK,H)
) # (B,K,H)
cause_cif_list.append(cause_cif.detach().cpu().numpy())
cif_full_list.append(cif_full.detach().cpu().numpy())
survival_list.append(survival.detach().cpu().numpy())
y_cause_within_list.append(y_within_top.detach().cpu().numpy())
y_cause_within_list.append(y_within_full.detach().cpu().numpy())
sex_list.append(sexes_k.detach().cpu().numpy())
if not cause_cif_list:
if not cif_full_list:
raise RuntimeError(
"No valid samples for evaluation (all batches filtered out by offset).")
cause_cif = np.concatenate(cause_cif_list, axis=0)
cif_full = np.concatenate(cif_full_list, axis=0)
survival = np.concatenate(survival_list, axis=0)
y_cause_within = np.concatenate(y_cause_within_list, axis=0)
sex = np.concatenate(
sex_list, axis=0) if sex_list else np.array([], dtype=int)
return cause_cif, cif_full, survival, y_cause_within, sex
def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
counts = y_ever.sum(axis=0)
order = np.argsort(-counts)
order = order[counts[order] > 0]
return order[:top_k]
return cif_full, survival, y_cause_within, sex
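# Shape reminder (per the docstring above): with N patients, K diseases and
# H horizons, cif_full[i, k, h] is the predicted probability that disease k
# occurs within eval_horizons[h] years of the context time for patient i,
# and y_cause_within_tau holds the matching 0/1 labels.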
def evaluate_one_model(
model_name: str,
cause_cif: np.ndarray,
cif_full: np.ndarray,
y_cause_within_tau: np.ndarray,
eval_horizons: Sequence[float],
top_cause_ids: np.ndarray,
out_rows: List[Dict[str, Any]],
calib_rows: List[Dict[str, Any]],
calib_cause_ids: Optional[Sequence[int]],
auc_ci_method: str,
bootstrap_n: int,
n_calib_bins: int = 10,
metric_workers: int = 0,
progress: str = "auto",
) -> None:
# Cause-specific, time-dependent metrics per horizon.
for h_i, tau in enumerate(eval_horizons):
p_tau = cause_cif[:, :, h_i] # (N, topK)
y_tau = y_cause_within_tau[:, :, h_i] # (N, topK)
"""Compute per-cause metrics for ALL diseases.
for j, cause_id in enumerate(top_cause_ids.tolist()):
p = p_tau[:, j]
y = y_tau[:, j]
Notes:
- Writes scalar metrics for all causes into out_rows.
- Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable).
"""
cif_full = np.asarray(cif_full, dtype=float)
y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float)
if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3:
raise ValueError(
"Expected cif_full and y_cause_within_tau with shape (N, K, H)")
if cif_full.shape != y_cause_within_tau.shape:
raise ValueError(
f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}"
)
# Primary: CIF-based Brier score + ICI (calibration).
out_rows.append(
N, K, H = cif_full.shape
if H != len(eval_horizons):
raise ValueError("H mismatch between cif_full and eval_horizons")
calib_set = set(int(x)
for x in calib_cause_ids) if calib_cause_ids is not None else set()
workers = int(metric_workers)
if workers <= 0:
workers = int(min(8, os.cpu_count() or 1))
workers = max(1, workers)
show_progress = _should_show_progress(progress)
def _eval_chunk(
*,
tau: float,
p_tau: np.ndarray,
y_tau: np.ndarray,
brier_by_cause: np.ndarray,
cause_ids: np.ndarray,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]:
local_rows: List[Dict[str, Any]] = []
local_calib: List[Dict[str, Any]] = []
for cid in cause_ids.tolist():
p = p_tau[:, cid]
y = y_tau[:, cid]
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_brier",
"horizon": float(tau),
"cause": int(cause_id),
"value": brier_score(p, y),
"cause": int(cid),
"value": float(brier_by_cause[cid]),
"ci_low": "",
"ci_high": "",
}
)
# ICI: compute bins only if we will export them.
need_bins = (not calib_set) or (int(cid) in calib_set)
if need_bins:
cal = calibration_deciles(p, y, n_bins=n_calib_bins)
out_rows.append(
ici = float(cal["ici"])
else:
cal = None
ici = calibration_ici_only(p, y, n_bins=n_calib_bins)
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_ici",
"horizon": float(tau),
"cause": int(cause_id),
"value": cal["ici"],
"cause": int(cid),
"value": float(ici),
"ci_low": "",
"ci_high": "",
}
@@ -1276,35 +1380,150 @@ def evaluate_one_model(
auc, lo, hi = float("nan"), float("nan"), float("nan")
elif auc_ci_method == "bootstrap":
auc, lo, hi = bootstrap_auc_ci(
p, y, n_bootstrap=bootstrap_n, alpha=0.95)
p, y, n_bootstrap=bootstrap_n, alpha=0.95
)
else:
auc, lo, hi = delong_ci(y, p, alpha=0.95)
out_rows.append(
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_auc",
"horizon": float(tau),
"cause": int(cause_id),
"value": auc,
"cause": int(cid),
"value": float(auc),
"ci_low": lo,
"ci_high": hi,
}
)
# Calibration curve bins for this cause + horizon.
if need_bins and cal is not None:
for binfo in cal.get("bins", []):
calib_rows.append(
local_calib.append(
{
"model_name": model_name,
"task": "cause_k",
"horizon": float(tau),
"cause_id": int(cause_id),
"cause_id": int(cid),
"bin_index": int(binfo["bin"]),
"p_mean": float(binfo["p_mean"]),
"y_mean": float(binfo["y_mean"]),
"n_in_bin": int(binfo["n"]),
}
)
return local_rows, local_calib, int(cause_ids.size)
# Cause-specific, time-dependent metrics per horizon.
for h_i, tau in enumerate(eval_horizons):
p_tau = cif_full[:, :, h_i] # (N, K)
y_tau = y_cause_within_tau[:, :, h_i] # (N, K)
# Vectorized Brier for speed.
brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0) # (K,)
# Parallelize disease-level metrics; chunk to avoid millions of futures.
all_ids = np.arange(int(K), dtype=int)
chunks = np.array_split(all_ids, workers)
done = 0
prefix = f"[{model_name}] tau={float(tau)}y "
t0 = time.time()
if workers <= 1:
for ch in chunks:
r_chunk, c_chunk, n_done = _eval_chunk(
tau=float(tau),
p_tau=p_tau,
y_tau=y_tau,
brier_by_cause=brier_by_cause,
cause_ids=ch,
)
out_rows.extend(r_chunk)
calib_rows.extend(c_chunk)
done += int(n_done)
if show_progress:
sys.stdout.write(
"\r" + _progress_line(done, int(K), prefix=prefix))
sys.stdout.flush()
else:
with ThreadPoolExecutor(max_workers=workers) as ex:
futs = [
ex.submit(
_eval_chunk,
tau=float(tau),
p_tau=p_tau,
y_tau=y_tau,
brier_by_cause=brier_by_cause,
cause_ids=ch,
)
for ch in chunks
if int(ch.size) > 0
]
for fut in as_completed(futs):
r_chunk, c_chunk, n_done = fut.result()
out_rows.extend(r_chunk)
calib_rows.extend(c_chunk)
done += int(n_done)
if show_progress:
sys.stdout.write(
"\r" + _progress_line(done, int(K), prefix=prefix))
sys.stdout.flush()
if show_progress:
dt = time.time() - t0
sys.stdout.write("\r" + _progress_line(int(K),
int(K), prefix=prefix) + f" ({dt:.1f}s)\n")
sys.stdout.flush()
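# Sizing note: for K diseases and H horizons this appends 3 * K * H scalar rows
# (cause_brier, cause_ici, cause_auc) to out_rows, plus calibration-bin rows
# only for the causes selected via calib_cause_ids.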
def summarize_over_diseases(
rows: List[Dict[str, Any]],
*,
model_name: str,
eval_horizons: Sequence[float],
metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"),
) -> List[Dict[str, Any]]:
"""Summarize mean/median of each metric over diseases (per horizon)."""
out: List[Dict[str, Any]] = []
# Build metric_name -> horizon -> list of values
bucket: Dict[Tuple[str, float], List[float]] = {}
for r in rows:
if r.get("model_name") != model_name:
continue
m = str(r.get("metric_name"))
if m not in set(metrics):
continue
h = _safe_float(r.get("horizon"))
v = _safe_float(r.get("value"))
if not np.isfinite(h):
continue
if not np.isfinite(v):
continue
bucket.setdefault((m, float(h)), []).append(float(v))
for tau in eval_horizons:
ht = float(tau)
for m in metrics:
vals = bucket.get((str(m), ht), [])
if vals:
arr = np.asarray(vals, dtype=float)
mean_v = float(np.mean(arr))
med_v = float(np.median(arr))
n_valid = int(arr.size)
else:
mean_v = float("nan")
med_v = float("nan")
n_valid = 0
out.append(
{
"model_name": str(model_name),
"metric_name": str(m),
"horizon": ht,
"mean": mean_v,
"median": med_v,
"n_valid": n_valid,
}
)
return out
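# Example output row (hypothetical values):
# {"model_name": "m1", "metric_name": "cause_auc", "horizon": 5.0,
#  "mean": 0.71, "median": 0.70, "n_valid": 1240}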
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
@@ -1361,13 +1580,12 @@ def main() -> int:
help="Anti-leakage offset (years)")
ap.add_argument("--eval_horizons", type=float,
nargs="*", default=DEFAULT_EVAL_HORIZONS)
ap.add_argument("--top_k_causes", type=int, default=50)
ap.add_argument("--batch_size", type=int, default=128)
ap.add_argument("--num_workers", type=int, default=0)
ap.add_argument("--seed", type=int, default=123)
ap.add_argument("--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
ap.add_argument("--out_csv", type=str, default="eval_results.csv")
ap.add_argument("--out_csv", type=str, default="eval_summary.csv")
ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
# Integrity checks
@@ -1383,6 +1601,21 @@ def main() -> int:
)
ap.add_argument("--bootstrap_n", type=int, default=2000)
# Speed/UX
ap.add_argument(
"--metric_workers",
type=int,
default=0,
help="Threads for per-disease metrics (0=auto, 1=disable parallelism)",
)
ap.add_argument(
"--progress",
type=str,
default="auto",
choices=["auto", "bar", "none"],
help="Progress visualization during per-disease evaluation",
)
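# Example invocation (script name and paths are hypothetical):
#   python evaluate.py --split test --offset_years 1.0 \
#       --metric_workers 8 --progress bar --out_csv eval_summary.csv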
# Export settings for user-facing experiments
ap.add_argument("--export_dir", type=str, default="eval_exports")
ap.add_argument("--death_cause_id", type=int,
@@ -1463,7 +1696,8 @@ def main() -> int:
# Metadata for focus causes (within tau_max).
top_causes_meta: List[Dict[str, Any]] = []
for cid in focus_causes:
n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0
n_case = int(counts[int(cid)]) if int(
cid) < int(counts.shape[0]) else 0
top_causes_meta.append(
{
"cause_id": int(cid),
@@ -1483,7 +1717,7 @@ def main() -> int:
hg_rows,
)
rows: List[Dict[str, Any]] = []
summary_rows: List[Dict[str, Any]] = []
calib_rows: List[Dict[str, Any]] = []
# Experiment exports (accumulated across models)
@@ -1503,7 +1737,6 @@ def main() -> int:
tag = _make_eval_tag(args.split, float(args.offset_years))
# Remember list offsets so we can write per-model slices to the model's run_dir.
rows_start = len(rows)
calib_start = len(calib_rows)
cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
@@ -1544,7 +1777,6 @@ def main() -> int:
head.load_state_dict(ckpt["head_state_dict"], strict=True)
(
cause_cif,
cif_full,
survival,
y_cause_within_tau,
@@ -1558,7 +1790,7 @@ def main() -> int:
args.device,
args.offset_years,
args.eval_horizons,
top_cause_ids,
n_disease=int(dataset.n_disease),
)
# CIF integrity checks before metrics.
@@ -1575,26 +1807,42 @@ def main() -> int:
"integrity_notes": integrity_notes,
}
# Per-disease metrics for ALL diseases (written into the model's run_dir).
model_rows: List[Dict[str, Any]] = []
evaluate_one_model(
model_name=spec.name,
cause_cif=cause_cif,
cif_full=cif_full,
y_cause_within_tau=y_cause_within_tau,
eval_horizons=args.eval_horizons,
top_cause_ids=top_cause_ids,
out_rows=rows,
out_rows=model_rows,
calib_rows=calib_rows,
calib_cause_ids=top_cause_ids.tolist(),
auc_ci_method=str(args.auc_ci_method),
bootstrap_n=int(args.bootstrap_n),
metric_workers=int(args.metric_workers),
progress=str(args.progress),
)
# Summary over diseases (mean/median per horizon).
model_summary_rows = summarize_over_diseases(
model_rows,
model_name=spec.name,
eval_horizons=args.eval_horizons,
)
summary_rows.extend(model_summary_rows)
# Convenience slices for user-facing experiments (focus causes only).
cause_cif_focus = cif_full[:, top_cause_ids, :]
y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
# ============================================================
# Experiment 1: Risk stratification bins + summary
# ============================================================
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for h_i, tau in enumerate(args.eval_horizons):
for j, cause_id in enumerate(top_cause_ids.tolist()):
p = cause_cif[:, j, h_i]
y = y_cause_within_tau[:, j, h_i]
p = cause_cif_focus[:, j, h_i]
y = y_within_focus[:, j, h_i]
if sex_mask is not None:
p = p[sex_mask]
y = y[sex_mask]
@@ -1648,8 +1896,8 @@ def main() -> int:
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for h_i, tau in enumerate(args.eval_horizons):
for j, cause_id in enumerate(top_cause_ids.tolist()):
p = cause_cif[:, j, h_i]
y = y_cause_within_tau[:, j, h_i]
p = cause_cif_focus[:, j, h_i]
y = y_within_focus[:, j, h_i]
if sex_mask is not None:
p = p[sex_mask]
y = y[sex_mask]
@@ -1690,15 +1938,15 @@ def main() -> int:
# Per-horizon metrics for grouping
# Build a dict for quick access: (cause_id, horizon) -> (brier, ici)
per_h: Dict[Tuple[int, float], Dict[str, float]] = {}
for rr in rows[rows_start:]:
if rr.get("model_name") != spec.name:
continue
for rr in model_rows:
if rr.get("metric_name") not in {"cause_brier", "cause_ici"}:
continue
try:
cid = int(rr.get("cause"))
except Exception:
continue
if cid not in set(int(x) for x in top_cause_ids.tolist()):
continue
h = _safe_float(rr.get("horizon"))
if not np.isfinite(h):
continue
@@ -1738,8 +1986,8 @@ def main() -> int:
group_vals[g]["ici"].append(ici_h)
# pooled reliability bins from raw p/y
p = cause_cif[:, j, h_i]
y = y_cause_within_tau[:, j, h_i]
p = cause_cif_focus[:, j, h_i]
y = y_within_focus[:, j, h_i]
if sex_mask is not None:
p = p[sex_mask]
y = y[sex_mask]
@@ -1809,7 +2057,7 @@ def main() -> int:
# Optionally write top-cause counts into the main results CSV as metric rows.
for tc in top_causes_meta:
rows.append(
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_case_within_tau",
@@ -1820,7 +2068,7 @@ def main() -> int:
"ci_high": "",
}
)
rows.append(
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_control_within_tau",
@@ -1831,7 +2079,7 @@ def main() -> int:
"ci_high": "",
}
)
rows.append(
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_total_eval",
@@ -1844,13 +2092,18 @@ def main() -> int:
)
# Write per-model results into the model's run directory.
model_rows = rows[rows_start:]
model_calib_rows = calib_rows[calib_start:]
model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")
write_results_csv(model_out_csv, model_rows)
write_simple_csv(
model_summary_csv,
["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
model_summary_rows,
)
write_calibration_bins_csv(model_calib_csv, model_calib_rows)
model_meta = {
@@ -1860,12 +2113,13 @@ def main() -> int:
"split": args.split,
"offset_years": args.offset_years,
"eval_horizons": [float(x) for x in args.eval_horizons],
"top_k_causes": int(args.top_k_causes),
"n_disease": int(dataset.n_disease),
"top_cause_ids": top_cause_ids.tolist(),
"top_causes": top_causes_meta,
"integrity": {spec.name: integrity_meta.get(spec.name, {})},
"paths": {
"results_csv": model_out_csv,
"summary_csv": model_summary_csv,
"calibration_bins_csv": model_calib_csv,
},
}
@@ -1874,7 +2128,12 @@ def main() -> int:
print(f"Wrote per-model results to {model_out_csv}")
write_results_csv(args.out_csv, rows)
# Write global summary (across diseases) across all models.
write_simple_csv(
args.out_csv,
["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
summary_rows,
)
# Write calibration curve points to a separate CSV.
out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
@@ -2012,7 +2271,7 @@ def main() -> int:
"This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
"All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
"## Files\n\n"
"- focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table.\n"
"- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
"- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n"
"- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n"
"- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n"
@@ -2027,7 +2286,7 @@ def main() -> int:
"offset_years": args.offset_years,
"eval_horizons": [float(x) for x in args.eval_horizons],
"tau_max": float(tau_max),
"top_k_causes": int(args.top_k_causes),
"n_disease": int(dataset_for_top.n_disease),
"top_cause_ids": top_cause_ids.tolist(),
"top_causes": top_causes_meta,
"integrity": integrity_meta,
@@ -2045,7 +2304,7 @@ def main() -> int:
with open(args.out_meta_json, "w") as f:
json.dump(meta, f, indent=2)
print(f"Wrote {args.out_csv} with {len(rows)} rows")
print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
print(f"Wrote {args.out_meta_json}")
return 0