Enhance evaluation metrics and progress visualization in model evaluation
- Introduced parallel processing for per-disease metrics using ThreadPoolExecutor (a standalone sketch of the pattern follows below).
- Added command-line arguments for metric workers and progress visualization options.
- Refactored evaluation functions to compute metrics for all diseases and summarize results.
- Updated output CSV filenames for clarity and consistency.
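A minimal, standalone sketch of the chunked thread-pool pattern this commit adopts (names here are illustrative stand-ins, not the project's API; the real version is _eval_chunk inside evaluate_one_model, further down):

from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np

def eval_chunk(cause_ids):
    # Stand-in for the per-disease metric work the real _eval_chunk performs.
    return [int(c) for c in cause_ids]

K, workers = 1000, 4
# Chunking keeps the number of futures at `workers`, not K.
chunks = [c for c in np.array_split(np.arange(K), workers) if c.size > 0]
results = []
with ThreadPoolExecutor(max_workers=workers) as ex:
    futs = [ex.submit(eval_chunk, ch) for ch in chunks]
    for fut in as_completed(futs):
        results.extend(fut.result())
print(len(results))  # 1000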
@@ -5,6 +5,9 @@ import math
 import os
 import random
 import statistics
 import sys
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass
 from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
 
@@ -522,20 +525,32 @@ def check_cif_integrity(
 # --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
 
 
 def compute_midrank(x: np.ndarray) -> np.ndarray:
     """Vectorized midrank computation (ties -> average ranks)."""
     x = np.asarray(x, dtype=float)
-    order = np.argsort(x)
+    n = int(x.shape[0])
+    if n == 0:
+        return np.asarray([], dtype=float)
+
+    order = np.argsort(x, kind="mergesort")
     z = x[order]
-    n = x.shape[0]
-    t = np.zeros(n, dtype=float)
-    i = 0
-    while i < n:
-        j = i
-        while j < n and z[j] == z[i]:
-            j += 1
-        t[i:j] = 0.5 * (i + j - 1) + 1.0
-        i = j
+
+    # Find tie groups in sorted order.
+    diff = np.diff(z)
+    # boundaries includes 0 and n
+    boundaries = np.concatenate(
+        [np.array([0], dtype=int), np.nonzero(diff != 0)[0] + 1, np.array([n], dtype=int)]
+    )
+    starts = boundaries[:-1]
+    ends = boundaries[1:]
+    lens = ends - starts
+
+    # Midrank for each group in 1-based rank space.
+    mids = 0.5 * (starts + ends - 1) + 1.0
+    t_sorted = np.repeat(mids, lens).astype(float, copy=False)
+
     out = np.empty(n, dtype=float)
-    out[order] = t
+    out[order] = t_sorted
     return out
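A quick sanity check of the vectorized midrank above (assuming compute_midrank as defined in this hunk is in scope). Ties receive the average of the ranks they span, which is what the DeLong computation expects:

import numpy as np

x = np.array([1.0, 2.0, 2.0, 3.0])
# Sorted ranks would be 1, 2, 3, 4; the tied 2.0s share (2 + 3) / 2 = 2.5.
print(compute_midrank(x))  # midranks: 1.0, 2.5, 2.5, 4.0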
@@ -678,34 +693,98 @@ def brier_score(p: np.ndarray, y: np.ndarray) -> float:
 
 
 def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
+    bins, ici = _calibration_bins_and_ici(
+        p, y, n_bins=int(n_bins), return_bins=True)
+    return {"bins": bins, "ici": float(ici)}
+
+
+def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
+    """Fast ICI only (no per-bin point export)."""
+    _, ici = _calibration_bins_and_ici(
+        p, y, n_bins=int(n_bins), return_bins=False)
+    return float(ici)
+
+
+def _calibration_bins_and_ici(
+    p: np.ndarray,
+    y: np.ndarray,
+    *,
+    n_bins: int,
+    return_bins: bool,
+) -> Tuple[List[Dict[str, Any]], float]:
+    """Vectorized quantile binning for calibration + ICI."""
     p = np.asarray(p, dtype=float)
     y = np.asarray(y, dtype=float)
 
     # guard
     if p.size == 0:
-        return {"bins": [], "ici": float("nan")}
+        return ([], float("nan")) if return_bins else ([], float("nan"))
 
-    edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
     # make strictly increasing where possible
+    q = np.linspace(0.0, 1.0, int(n_bins) + 1)
+    edges = np.quantile(p, q)
+    edges = np.asarray(edges, dtype=float)
     edges[0] = -np.inf
     edges[-1] = np.inf
 
-    bins = []
-    ici_accum = 0.0
-    n = p.shape[0]
+    # Bin assignment: i if edges[i] < p <= edges[i+1]
+    bin_idx = np.searchsorted(edges, p, side="right") - 1
+    bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1)
 
-    for i in range(n_bins):
-        mask = (p > edges[i]) & (p <= edges[i + 1])
-        if not np.any(mask):
-            continue
-        p_mean = float(np.mean(p[mask]))
-        y_mean = float(np.mean(y[mask]))
-        bins.append({"bin": i, "p_mean": p_mean,
-                     "y_mean": y_mean, "n": int(mask.sum())})
-        ici_accum += abs(p_mean - y_mean)
+    counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float)
+    sum_p = np.bincount(bin_idx, weights=p,
+                        minlength=int(n_bins)).astype(float)
+    sum_y = np.bincount(bin_idx, weights=y,
+                        minlength=int(n_bins)).astype(float)
 
-    ici = ici_accum / max(len(bins), 1)
-    return {"bins": bins, "ici": float(ici)}
+    nonempty = counts > 0
+    if not np.any(nonempty):
+        return ([], float("nan")) if return_bins else ([], float("nan"))
+
+    p_mean = np.zeros(int(n_bins), dtype=float)
+    y_mean = np.zeros(int(n_bins), dtype=float)
+    p_mean[nonempty] = sum_p[nonempty] / counts[nonempty]
+    y_mean[nonempty] = sum_y[nonempty] / counts[nonempty]
+
+    diffs = np.abs(p_mean[nonempty] - y_mean[nonempty])
+    ici = float(np.mean(diffs)) if diffs.size else float("nan")
+
+    if not return_bins:
+        return [], ici
+
+    bins: List[Dict[str, Any]] = []
+    idxs = np.nonzero(nonempty)[0]
+    for i in idxs.tolist():
+        bins.append(
+            {
+                "bin": int(i),
+                "p_mean": float(p_mean[i]),
+                "y_mean": float(y_mean[i]),
+                "n": int(counts[i]),
+            }
+        )
+    return bins, ici
+
+
+def _progress_line(done: int, total: int, prefix: str = "") -> str:
+    total_i = max(1, int(total))
+    done_i = max(0, min(int(done), total_i))
+    width = 28
+    frac = done_i / total_i
+    filled = int(round(width * frac))
+    bar = "#" * filled + "-" * (width - filled)
+    pct = 100.0 * frac
+    return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)"
+
+
+def _should_show_progress(mode: str) -> bool:
+    m = str(mode).strip().lower()
+    if m in {"0", "false", "no", "none", "off"}:
+        return False
+    # Default: show if interactive.
+    if m in {"auto", "1", "true", "yes", "on", "bar"}:
+        try:
+            return bool(sys.stdout.isatty())
+        except Exception:
+            return True
+    return True
 
 
 def _safe_float(x: Any, default: float = float("nan")) -> float:
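The vectorized binning above, in isolation (a sketch, not the project's API): quantile edges, searchsorted bin assignment, and bincount-based per-bin means replace the old per-bin Python loop:

import numpy as np

rng = np.random.default_rng(0)
p = rng.random(1000)                      # predicted risks
y = (rng.random(1000) < p).astype(float)  # outcomes simulated to be well calibrated

n_bins = 10
edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
edges[0], edges[-1] = -np.inf, np.inf
idx = np.clip(np.searchsorted(edges, p, side="right") - 1, 0, n_bins - 1)

counts = np.bincount(idx, minlength=n_bins)
p_mean = np.bincount(idx, weights=p, minlength=n_bins) / np.maximum(counts, 1)
y_mean = np.bincount(idx, weights=y, minlength=n_bins) / np.maximum(counts, 1)
ici = float(np.abs(p_mean[counts > 0] - y_mean[counts > 0]).mean())
print(round(ici, 3))  # small, since y was drawn from p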
@@ -1116,15 +1195,14 @@ def predict_cifs_for_model(
     device: str,
     offset_years: float,
     eval_horizons: Sequence[float],
-    top_cause_ids: np.ndarray,
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    n_disease: int,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     """Run model and produce cause-specific, time-dependent CIF outputs.
 
     Returns:
-        cause_cif: (N, topK, H)
         cif_full: (N, K, H)
         survival: (N, H)
-        y_cause_within_tau: (N, topK, H)
+        y_cause_within_tau: (N, K, H)
+        sex: (N,)
 
     NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
     """
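The new shape contract, as a toy numpy sketch (sizes illustrative): the top-K selection this function used to do internally is now a plain slice taken downstream in main():

import numpy as np

N, K, H = 4, 10, 3                      # samples, diseases, horizons
cif_full = np.random.rand(N, K, H)      # stand-in for the model's CIFs
top_cause_ids = np.array([2, 5, 7])

cause_cif_focus = cif_full[:, top_cause_ids, :]  # the slice main() now takes
print(cause_cif_focus.shape)            # (4, 3, 3)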
@@ -1132,13 +1210,10 @@ def predict_cifs_for_model(
     head.eval()
 
     # We will accumulate in CPU lists, then concat.
-    cause_cif_list: List[np.ndarray] = []
     cif_full_list: List[np.ndarray] = []
     survival_list: List[np.ndarray] = []
     y_cause_within_list: List[np.ndarray] = []
     sex_list: List[np.ndarray] = []
-    top_cause_ids_t = torch.tensor(
-        top_cause_ids, dtype=torch.long, device=device)
 
     for batch in loader:
         event_seq, time_seq, cont_feats, cate_feats, sexes = batch
@@ -1177,95 +1252,124 @@ def predict_cifs_for_model(
         else:
             raise ValueError(f"Unsupported loss_type: {loss_type}")
 
-        cause_cif = cif_full.index_select(
-            dim=1, index=top_cause_ids_t)  # (B,topK,H)
-
-        # Within-horizon labels for cause-specific CIF quality + discrimination.
-        n_disease = int(cif_full.size(1))
-        y_within_top = torch.stack(
+        # Within-horizon labels for all causes: disease k occurs within tau after context.
+        y_within_full = torch.stack(
             [
-                multi_hot_selected_causes_within_horizon(
+                multi_hot_ever_within_horizon(
                     event_seq=event_seq,
                     time_seq=time_seq,
                     t_ctx=t_ctx,
                     tau_years=float(tau),
-                    cause_ids=top_cause_ids_t,
-                    n_disease=n_disease,
+                    n_disease=int(n_disease),
                 ).to(torch.float32)
                 for tau in eval_horizons
             ],
             dim=2,
-        )  # (B,topK,H)
+        )  # (B,K,H)
 
-        cause_cif_list.append(cause_cif.detach().cpu().numpy())
         cif_full_list.append(cif_full.detach().cpu().numpy())
         survival_list.append(survival.detach().cpu().numpy())
-        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
+        y_cause_within_list.append(y_within_full.detach().cpu().numpy())
         sex_list.append(sexes_k.detach().cpu().numpy())
 
-    if not cause_cif_list:
+    if not cif_full_list:
         raise RuntimeError(
             "No valid samples for evaluation (all batches filtered out by offset).")
 
-    cause_cif = np.concatenate(cause_cif_list, axis=0)
     cif_full = np.concatenate(cif_full_list, axis=0)
     survival = np.concatenate(survival_list, axis=0)
     y_cause_within = np.concatenate(y_cause_within_list, axis=0)
     sex = np.concatenate(
         sex_list, axis=0) if sex_list else np.array([], dtype=int)
 
-    return cause_cif, cif_full, survival, y_cause_within, sex
-
-
-def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
-    counts = y_ever.sum(axis=0)
-    order = np.argsort(-counts)
-    order = order[counts[order] > 0]
-    return order[:top_k]
+    return cif_full, survival, y_cause_within, sex
 
 
 def evaluate_one_model(
     model_name: str,
-    cause_cif: np.ndarray,
+    cif_full: np.ndarray,
     y_cause_within_tau: np.ndarray,
     eval_horizons: Sequence[float],
     top_cause_ids: np.ndarray,
     out_rows: List[Dict[str, Any]],
     calib_rows: List[Dict[str, Any]],
+    calib_cause_ids: Optional[Sequence[int]],
     auc_ci_method: str,
     bootstrap_n: int,
     n_calib_bins: int = 10,
+    metric_workers: int = 0,
+    progress: str = "auto",
 ) -> None:
-    # Cause-specific, time-dependent metrics per horizon.
-    for h_i, tau in enumerate(eval_horizons):
-        p_tau = cause_cif[:, :, h_i]  # (N, topK)
-        y_tau = y_cause_within_tau[:, :, h_i]  # (N, topK)
-
-        for j, cause_id in enumerate(top_cause_ids.tolist()):
-            p = p_tau[:, j]
-            y = y_tau[:, j]
+    """Compute per-cause metrics for ALL diseases.
+
+    Notes:
+      - Writes scalar metrics for all causes into out_rows.
+      - Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable).
+    """
+    cif_full = np.asarray(cif_full, dtype=float)
+    y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float)
+    if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3:
+        raise ValueError(
+            "Expected cif_full and y_cause_within_tau with shape (N, K, H)")
+    if cif_full.shape != y_cause_within_tau.shape:
+        raise ValueError(
+            f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}"
+        )
+
+    N, K, H = cif_full.shape
+    if H != len(eval_horizons):
+        raise ValueError("H mismatch between cif_full and eval_horizons")
+
+    calib_set = set(int(x)
+                    for x in calib_cause_ids) if calib_cause_ids is not None else set()
+
+    workers = int(metric_workers)
+    if workers <= 0:
+        workers = int(min(8, os.cpu_count() or 1))
+    workers = max(1, workers)
+    show_progress = _should_show_progress(progress)
+
+    def _eval_chunk(
+        *,
+        tau: float,
+        p_tau: np.ndarray,
+        y_tau: np.ndarray,
+        brier_by_cause: np.ndarray,
+        cause_ids: np.ndarray,
+    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]:
+        local_rows: List[Dict[str, Any]] = []
+        local_calib: List[Dict[str, Any]] = []
+        for cid in cause_ids.tolist():
+            p = p_tau[:, cid]
+            y = y_tau[:, cid]
 
-            # Primary: CIF-based Brier score + ICI (calibration).
-            out_rows.append(
+            local_rows.append(
                 {
                     "model_name": model_name,
                     "metric_name": "cause_brier",
                     "horizon": float(tau),
-                    "cause": int(cause_id),
-                    "value": brier_score(p, y),
+                    "cause": int(cid),
+                    "value": float(brier_by_cause[cid]),
                     "ci_low": "",
                     "ci_high": "",
                 }
             )
 
-            cal = calibration_deciles(p, y, n_bins=n_calib_bins)
-            out_rows.append(
+            # ICI: compute bins only if we will export them.
+            need_bins = (not calib_set) or (int(cid) in calib_set)
+            if need_bins:
+                cal = calibration_deciles(p, y, n_bins=n_calib_bins)
+                ici = float(cal["ici"])
+            else:
+                cal = None
+                ici = calibration_ici_only(p, y, n_bins=n_calib_bins)
+
+            local_rows.append(
                 {
                     "model_name": model_name,
                     "metric_name": "cause_ici",
                     "horizon": float(tau),
-                    "cause": int(cause_id),
-                    "value": cal["ici"],
+                    "cause": int(cid),
+                    "value": float(ici),
                     "ci_low": "",
                     "ci_high": "",
                 }
@@ -1276,35 +1380,150 @@ def evaluate_one_model(
                 auc, lo, hi = float("nan"), float("nan"), float("nan")
             elif auc_ci_method == "bootstrap":
                 auc, lo, hi = bootstrap_auc_ci(
-                    p, y, n_bootstrap=bootstrap_n, alpha=0.95)
+                    p, y, n_bootstrap=bootstrap_n, alpha=0.95
+                )
             else:
                 auc, lo, hi = delong_ci(y, p, alpha=0.95)
-            out_rows.append(
+
+            local_rows.append(
                 {
                     "model_name": model_name,
                     "metric_name": "cause_auc",
                     "horizon": float(tau),
-                    "cause": int(cause_id),
-                    "value": auc,
+                    "cause": int(cid),
+                    "value": float(auc),
                     "ci_low": lo,
                     "ci_high": hi,
                 }
             )
 
             # Calibration curve bins for this cause + horizon.
-            for binfo in cal.get("bins", []):
-                calib_rows.append(
-                    {
-                        "model_name": model_name,
-                        "task": "cause_k",
-                        "horizon": float(tau),
-                        "cause_id": int(cause_id),
-                        "bin_index": int(binfo["bin"]),
-                        "p_mean": float(binfo["p_mean"]),
-                        "y_mean": float(binfo["y_mean"]),
-                        "n_in_bin": int(binfo["n"]),
-                    }
-                )
+            if need_bins and cal is not None:
+                for binfo in cal.get("bins", []):
+                    local_calib.append(
+                        {
+                            "model_name": model_name,
+                            "task": "cause_k",
+                            "horizon": float(tau),
+                            "cause_id": int(cid),
+                            "bin_index": int(binfo["bin"]),
+                            "p_mean": float(binfo["p_mean"]),
+                            "y_mean": float(binfo["y_mean"]),
+                            "n_in_bin": int(binfo["n"]),
+                        }
+                    )
+        return local_rows, local_calib, int(cause_ids.size)
+
+    # Cause-specific, time-dependent metrics per horizon.
+    for h_i, tau in enumerate(eval_horizons):
+        p_tau = cif_full[:, :, h_i]  # (N, K)
+        y_tau = y_cause_within_tau[:, :, h_i]  # (N, K)
+
+        # Vectorized Brier for speed.
+        brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0)  # (K,)
+
+        # Parallelize disease-level metrics; chunk to avoid millions of futures.
+        all_ids = np.arange(int(K), dtype=int)
+        chunks = np.array_split(all_ids, workers)
+        done = 0
+        prefix = f"[{model_name}] tau={float(tau)}y "
+        t0 = time.time()
+
+        if workers <= 1:
+            for ch in chunks:
+                r_chunk, c_chunk, n_done = _eval_chunk(
+                    tau=float(tau),
+                    p_tau=p_tau,
+                    y_tau=y_tau,
+                    brier_by_cause=brier_by_cause,
+                    cause_ids=ch,
+                )
+                out_rows.extend(r_chunk)
+                calib_rows.extend(c_chunk)
+                done += int(n_done)
+                if show_progress:
+                    sys.stdout.write(
+                        "\r" + _progress_line(done, int(K), prefix=prefix))
+                    sys.stdout.flush()
+        else:
+            with ThreadPoolExecutor(max_workers=workers) as ex:
+                futs = [
+                    ex.submit(
+                        _eval_chunk,
+                        tau=float(tau),
+                        p_tau=p_tau,
+                        y_tau=y_tau,
+                        brier_by_cause=brier_by_cause,
+                        cause_ids=ch,
+                    )
+                    for ch in chunks
+                    if int(ch.size) > 0
+                ]
+                for fut in as_completed(futs):
+                    r_chunk, c_chunk, n_done = fut.result()
+                    out_rows.extend(r_chunk)
+                    calib_rows.extend(c_chunk)
+                    done += int(n_done)
+                    if show_progress:
+                        sys.stdout.write(
+                            "\r" + _progress_line(done, int(K), prefix=prefix))
+                        sys.stdout.flush()
+
+        if show_progress:
+            dt = time.time() - t0
+            sys.stdout.write("\r" + _progress_line(int(K),
+                             int(K), prefix=prefix) + f" ({dt:.1f}s)\n")
+            sys.stdout.flush()
+
+
+def summarize_over_diseases(
+    rows: List[Dict[str, Any]],
+    *,
+    model_name: str,
+    eval_horizons: Sequence[float],
+    metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"),
+) -> List[Dict[str, Any]]:
+    """Summarize mean/median of each metric over diseases (per horizon)."""
+    out: List[Dict[str, Any]] = []
+    # Build (metric_name, horizon) -> list of values
+    bucket: Dict[Tuple[str, float], List[float]] = {}
+    for r in rows:
+        if r.get("model_name") != model_name:
+            continue
+        m = str(r.get("metric_name"))
+        if m not in set(metrics):
+            continue
+        h = _safe_float(r.get("horizon"))
+        v = _safe_float(r.get("value"))
+        if not np.isfinite(h):
+            continue
+        if not np.isfinite(v):
+            continue
+        bucket.setdefault((m, float(h)), []).append(float(v))
+
+    for tau in eval_horizons:
+        ht = float(tau)
+        for m in metrics:
+            vals = bucket.get((str(m), ht), [])
+            if vals:
+                arr = np.asarray(vals, dtype=float)
+                mean_v = float(np.mean(arr))
+                med_v = float(np.median(arr))
+                n_valid = int(arr.size)
+            else:
+                mean_v = float("nan")
+                med_v = float("nan")
+                n_valid = 0
+            out.append(
+                {
+                    "model_name": str(model_name),
+                    "metric_name": str(m),
+                    "horizon": ht,
+                    "mean": mean_v,
+                    "median": med_v,
+                    "n_valid": n_valid,
+                }
+            )
+    return out
 
 
 def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
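Expected use of summarize_over_diseases, as a sketch with hand-made rows (the real rows come from evaluate_one_model; assumes the functions above are in scope). Metrics absent from the rows still produce a summary row, with NaNs and n_valid=0:

rows = [
    {"model_name": "m", "metric_name": "cause_auc", "horizon": 1.0, "value": 0.8},
    {"model_name": "m", "metric_name": "cause_auc", "horizon": 1.0, "value": 0.7},
]
for s in summarize_over_diseases(rows, model_name="m", eval_horizons=[1.0]):
    print(s)
# cause_auc at horizon 1.0 -> mean 0.75, median 0.75, n_valid 2;
# cause_brier and cause_ici -> NaN summaries with n_valid 0.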
@@ -1361,13 +1580,12 @@ def main() -> int:
                     help="Anti-leakage offset (years)")
     ap.add_argument("--eval_horizons", type=float,
                     nargs="*", default=DEFAULT_EVAL_HORIZONS)
     ap.add_argument("--top_k_causes", type=int, default=50)
     ap.add_argument("--batch_size", type=int, default=128)
     ap.add_argument("--num_workers", type=int, default=0)
     ap.add_argument("--seed", type=int, default=123)
     ap.add_argument("--device", type=str,
                     default="cuda" if torch.cuda.is_available() else "cpu")
-    ap.add_argument("--out_csv", type=str, default="eval_results.csv")
+    ap.add_argument("--out_csv", type=str, default="eval_summary.csv")
     ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
 
     # Integrity checks
@@ -1383,6 +1601,21 @@ def main() -> int:
     )
     ap.add_argument("--bootstrap_n", type=int, default=2000)
 
+    # Speed/UX
+    ap.add_argument(
+        "--metric_workers",
+        type=int,
+        default=0,
+        help="Threads for per-disease metrics (0=auto, 1=disable parallelism)",
+    )
+    ap.add_argument(
+        "--progress",
+        type=str,
+        default="auto",
+        choices=["auto", "bar", "none"],
+        help="Progress visualization during per-disease evaluation",
+    )
+
     # Export settings for user-facing experiments
     ap.add_argument("--export_dir", type=str, default="eval_exports")
     ap.add_argument("--death_cause_id", type=int,
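How the new flags surface at runtime (illustrative values; _progress_line and _should_show_progress are defined in an earlier hunk): --progress none disables output, auto/bar follow the TTY check, and the bar itself renders like this:

print(_progress_line(100, 400, prefix="[model_a] tau=5.0y "))
# -> [model_a] tau=5.0y [#######---------------------] 100/400 ( 25.0%)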
@@ -1463,7 +1696,8 @@ def main() -> int:
     # Metadata for focus causes (within tau_max).
     top_causes_meta: List[Dict[str, Any]] = []
     for cid in focus_causes:
-        n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0
+        n_case = int(counts[int(cid)]) if int(
+            cid) < int(counts.shape[0]) else 0
         top_causes_meta.append(
             {
                 "cause_id": int(cid),
@@ -1483,7 +1717,7 @@ def main() -> int:
         hg_rows,
     )
 
-    rows: List[Dict[str, Any]] = []
+    summary_rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []
 
     # Experiment exports (accumulated across models)
@@ -1503,7 +1737,6 @@ def main() -> int:
         tag = _make_eval_tag(args.split, float(args.offset_years))
 
         # Remember list offsets so we can write per-model slices to the model's run_dir.
-        rows_start = len(rows)
         calib_start = len(calib_rows)
 
         cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
@@ -1544,7 +1777,6 @@ def main() -> int:
         head.load_state_dict(ckpt["head_state_dict"], strict=True)
 
         (
-            cause_cif,
            cif_full,
             survival,
             y_cause_within_tau,
@@ -1558,7 +1790,7 @@ def main() -> int:
             args.device,
             args.offset_years,
             args.eval_horizons,
-            top_cause_ids,
+            n_disease=int(dataset.n_disease),
         )
 
         # CIF integrity checks before metrics.
@@ -1575,26 +1807,42 @@ def main() -> int:
             "integrity_notes": integrity_notes,
         }
 
+        # Per-disease metrics for ALL diseases (written into the model's run_dir).
+        model_rows: List[Dict[str, Any]] = []
         evaluate_one_model(
             model_name=spec.name,
-            cause_cif=cause_cif,
+            cif_full=cif_full,
             y_cause_within_tau=y_cause_within_tau,
             eval_horizons=args.eval_horizons,
             top_cause_ids=top_cause_ids,
-            out_rows=rows,
+            out_rows=model_rows,
             calib_rows=calib_rows,
+            calib_cause_ids=top_cause_ids.tolist(),
             auc_ci_method=str(args.auc_ci_method),
             bootstrap_n=int(args.bootstrap_n),
+            metric_workers=int(args.metric_workers),
+            progress=str(args.progress),
         )
 
+        # Summary over diseases (mean/median per horizon).
+        model_summary_rows = summarize_over_diseases(
+            model_rows,
+            model_name=spec.name,
+            eval_horizons=args.eval_horizons,
+        )
+        summary_rows.extend(model_summary_rows)
+
+        # Convenience slices for user-facing experiments (focus causes only).
+        cause_cif_focus = cif_full[:, top_cause_ids, :]
+        y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
+
         # ============================================================
         # Experiment 1: Risk stratification bins + summary
         # ============================================================
         for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
             for h_i, tau in enumerate(args.eval_horizons):
                 for j, cause_id in enumerate(top_cause_ids.tolist()):
-                    p = cause_cif[:, j, h_i]
-                    y = y_cause_within_tau[:, j, h_i]
+                    p = cause_cif_focus[:, j, h_i]
+                    y = y_within_focus[:, j, h_i]
                     if sex_mask is not None:
                         p = p[sex_mask]
                         y = y[sex_mask]
@@ -1648,8 +1896,8 @@ def main() -> int:
         for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
             for h_i, tau in enumerate(args.eval_horizons):
                 for j, cause_id in enumerate(top_cause_ids.tolist()):
-                    p = cause_cif[:, j, h_i]
-                    y = y_cause_within_tau[:, j, h_i]
+                    p = cause_cif_focus[:, j, h_i]
+                    y = y_within_focus[:, j, h_i]
                     if sex_mask is not None:
                         p = p[sex_mask]
                         y = y[sex_mask]
@@ -1690,15 +1938,15 @@ def main() -> int:
         # Per-horizon metrics for grouping
         # Build a dict for quick access: (cause_id, horizon) -> (brier, ici)
         per_h: Dict[Tuple[int, float], Dict[str, float]] = {}
-        for rr in rows[rows_start:]:
-            if rr.get("model_name") != spec.name:
-                continue
+        for rr in model_rows:
             if rr.get("metric_name") not in {"cause_brier", "cause_ici"}:
                 continue
             try:
                 cid = int(rr.get("cause"))
             except Exception:
                 continue
+            if cid not in set(int(x) for x in top_cause_ids.tolist()):
+                continue
             h = _safe_float(rr.get("horizon"))
             if not np.isfinite(h):
                 continue
@@ -1738,8 +1986,8 @@ def main() -> int:
                     group_vals[g]["ici"].append(ici_h)
 
                 # pooled reliability bins from raw p/y
-                p = cause_cif[:, j, h_i]
-                y = y_cause_within_tau[:, j, h_i]
+                p = cause_cif_focus[:, j, h_i]
+                y = y_within_focus[:, j, h_i]
                 if sex_mask is not None:
                     p = p[sex_mask]
                     y = y[sex_mask]
@@ -1809,7 +2057,7 @@ def main() -> int:
 
         # Optionally write top-cause counts into the main results CSV as metric rows.
         for tc in top_causes_meta:
-            rows.append(
+            model_rows.append(
                 {
                     "model_name": spec.name,
                     "metric_name": "topcause_n_case_within_tau",
@@ -1820,7 +2068,7 @@ def main() -> int:
                     "ci_high": "",
                 }
             )
-            rows.append(
+            model_rows.append(
                 {
                     "model_name": spec.name,
                     "metric_name": "topcause_n_control_within_tau",
@@ -1831,7 +2079,7 @@ def main() -> int:
                     "ci_high": "",
                 }
             )
-            rows.append(
+            model_rows.append(
                 {
                     "model_name": spec.name,
                     "metric_name": "topcause_n_total_eval",
@@ -1844,13 +2092,18 @@ def main() -> int:
             )
 
         # Write per-model results into the model's run directory.
-        model_rows = rows[rows_start:]
         model_calib_rows = calib_rows[calib_start:]
         model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
+        model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
         model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
         model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")
 
         write_results_csv(model_out_csv, model_rows)
+        write_simple_csv(
+            model_summary_csv,
+            ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
+            model_summary_rows,
+        )
         write_calibration_bins_csv(model_calib_csv, model_calib_rows)
 
         model_meta = {
@@ -1860,12 +2113,13 @@ def main() -> int:
             "split": args.split,
             "offset_years": args.offset_years,
             "eval_horizons": [float(x) for x in args.eval_horizons],
-            "top_k_causes": int(args.top_k_causes),
+            "n_disease": int(dataset.n_disease),
             "top_cause_ids": top_cause_ids.tolist(),
             "top_causes": top_causes_meta,
             "integrity": {spec.name: integrity_meta.get(spec.name, {})},
             "paths": {
                 "results_csv": model_out_csv,
+                "summary_csv": model_summary_csv,
                 "calibration_bins_csv": model_calib_csv,
             },
         }
@@ -1874,7 +2128,12 @@ def main() -> int:
 
     print(f"Wrote per-model results to {model_out_csv}")
 
-    write_results_csv(args.out_csv, rows)
+    # Write global summary (across diseases) across all models.
+    write_simple_csv(
+        args.out_csv,
+        ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
+        summary_rows,
+    )
 
     # Write calibration curve points to a separate CSV.
     out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
@@ -2012,7 +2271,7 @@ def main() -> int:
         "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
         "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
         "## Files\n\n"
-        "- focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table.\n"
+        "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
         "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n"
         "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n"
         "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n"
@@ -2027,7 +2286,7 @@ def main() -> int:
         "offset_years": args.offset_years,
         "eval_horizons": [float(x) for x in args.eval_horizons],
         "tau_max": float(tau_max),
-        "top_k_causes": int(args.top_k_causes),
+        "n_disease": int(dataset_for_top.n_disease),
         "top_cause_ids": top_cause_ids.tolist(),
         "top_causes": top_causes_meta,
         "integrity": integrity_meta,
@@ -2045,7 +2304,7 @@ def main() -> int:
     with open(args.out_meta_json, "w") as f:
         json.dump(meta, f, indent=2)
 
-    print(f"Wrote {args.out_csv} with {len(rows)} rows")
+    print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
     print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
     print(f"Wrote {args.out_meta_json}")
     return 0