Add evaluation scripts for next-event prediction and horizon-capture evaluation with detailed metric disclaimers
38  README.md

@@ -1,2 +1,40 @@
# DeepHealth

## Evaluation

This repo includes two event-driven evaluation entrypoints:

- `evaluate_next_event.py`: next-event prediction using short-window CIF
- `evaluate_horizon.py`: horizon-capture evaluation using CIF at multiple horizons

### IMPORTANT metric disclaimers

- **AUC** reported by `evaluate_horizon.py` is “time-dependent” only because the label depends on the chosen horizon $\tau$.
  Without explicit follow-up end times / censoring, this is **not** a classical risk-set AUC with IPCW.
  Use it for **model comparison and diagnostics**, not strict statistical interpretation.

- **Brier score** reported by `evaluate_horizon.py` is an unadjusted diagnostic/proxy metric (no censoring adjustment).
  Use it to detect probability-mass compression / numerical stability issues; do not claim calibrated absolute risk.
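For concreteness, the quantities behind these disclaimers can be written out as follows (notation is ours, mirroring the label construction in `evaluate_horizon.py`; $\hat{F}_k(\tau \mid x_i)$ denotes the model CIF for cause $k$ at horizon $\tau$ from baseline $t_0$):

$$
y_{i,k}(\tau) = \mathbf{1}\!\left[\exists\, e:\ \mathrm{cause}(e)=k,\ t_e \in (t_0,\, t_0+\tau]\right],
\qquad
\mathrm{Brier}_k(\tau) = \frac{1}{N}\sum_{i=1}^{N}\bigl(y_{i,k}(\tau) - \hat{F}_k(\tau \mid x_i)\bigr)^2 .
$$

Neither quantity re-weights records by follow-up time, which is why the caveats above apply.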
### Example

```bash
# Next-event (no --horizons)
python evaluate_next_event.py \
    --run_dir runs/your_run \
    --tau_short 0.25 \
    --age_bins 40 45 50 55 60 65 70 inf \
    --device cuda \
    --batch_size 256 \
    --seed 0

# Horizon-capture
python evaluate_horizon.py \
    --run_dir runs/your_run \
    --horizons 0.25 0.5 1.0 2.0 5.0 10.0 \
    --age_bins 40 45 50 55 60 65 70 inf \
    --device cuda \
    --batch_size 256 \
    --seed 0
```
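Both scripts write their results as CSV files into `--run_dir` (`next_event_metrics.csv`, `next_event_per_cause.csv`, `horizon_metrics.csv`, `horizon_per_cause.csv`, `workload_yield.csv`). A minimal sketch for inspecting them afterwards (assumes pandas is installed; the run directory name is a placeholder):

```python
import pandas as pd

run_dir = "runs/your_run"

# Written by evaluate_next_event.py
next_event = pd.read_csv(f"{run_dir}/next_event_metrics.csv")       # metric/value pairs
next_event_pc = pd.read_csv(f"{run_dir}/next_event_per_cause.csv")  # per-cause OvR AUC

# Written by evaluate_horizon.py
horizon = pd.read_csv(f"{run_dir}/horizon_metrics.csv")             # per-tau summary
per_cause = pd.read_csv(f"{run_dir}/horizon_per_cause.csv")         # per-tau, per-cause AUC / Brier
workload = pd.read_csv(f"{run_dir}/workload_yield.csv")             # top-K capture metrics

print(horizon[["tau_years", "auc_macro", "brier_macro"]])
```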
277  evaluate_horizon.py  Normal file

@@ -0,0 +1,277 @@
"""Horizon-capture evaluation.
|
||||||
|
|
||||||
|
DISCLAIMERS (important):
|
||||||
|
- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$.
|
||||||
|
Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW.
|
||||||
|
Use it for model comparison and diagnostics, not strict statistical interpretation.
|
||||||
|
|
||||||
|
- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment).
|
||||||
|
Use them to detect probability-mass compression / numerical stability issues; do not claim
|
||||||
|
calibrated absolute risk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from utils import (
|
||||||
|
EvalRecordDataset,
|
||||||
|
build_dataset_from_config,
|
||||||
|
build_event_driven_records,
|
||||||
|
build_model_head_criterion,
|
||||||
|
eval_collate_fn,
|
||||||
|
flatten_future_events,
|
||||||
|
get_test_subset,
|
||||||
|
load_checkpoint_into,
|
||||||
|
load_train_config,
|
||||||
|
make_inference_dataloader_kwargs,
|
||||||
|
parse_float_list,
|
||||||
|
predict_cifs,
|
||||||
|
roc_auc_ovr,
|
||||||
|
seed_everything,
|
||||||
|
topk_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
description="Evaluate horizon-capture using CIF at horizons")
|
||||||
|
p.add_argument("--run_dir", type=str, required=True)
|
||||||
|
p.add_argument(
|
||||||
|
"--horizons",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
|
||||||
|
help="Horizon grid in years",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--age_bins",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
|
||||||
|
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
|
||||||
|
)
|
||||||
|
|
||||||
|
p.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default=("cuda" if torch.cuda.is_available() else "cpu"),
|
||||||
|
)
|
||||||
|
p.add_argument("--batch_size", type=int, default=256)
|
||||||
|
p.add_argument("--num_workers", type=int, default=0)
|
||||||
|
p.add_argument("--seed", type=int, default=0)
|
||||||
|
p.add_argument("--min_pos", type=int, default=20)
|
||||||
|
p.add_argument(
|
||||||
|
"--topk_list",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=[1, 5, 10, 20, 50],
|
||||||
|
)
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_labels_within_tau_flat(
|
||||||
|
n_records: int,
|
||||||
|
n_causes: int,
|
||||||
|
event_record_idx: np.ndarray,
|
||||||
|
event_cause: np.ndarray,
|
||||||
|
event_dt_years: np.ndarray,
|
||||||
|
tau_years: float,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Build y_within_tau using a flattened (record,cause,dt) representation.
|
||||||
|
|
||||||
|
This preserves the exact label definition: y[i,k]=1 iff at least one event of cause k
|
||||||
|
occurs in (t0, t0+tau].
|
||||||
|
"""
|
||||||
|
y = np.zeros((n_records, n_causes), dtype=np.int8)
|
||||||
|
if event_dt_years.size == 0:
|
||||||
|
return y
|
||||||
|
m = event_dt_years <= float(tau_years)
|
||||||
|
if not np.any(m):
|
||||||
|
return y
|
||||||
|
y[event_record_idx[m], event_cause[m]] = 1
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
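# Worked example of the label definition above: with n_records=2, n_causes=3,
# flattened events (record_idx, cause, dt_years) = [(0, 1, 0.1), (1, 2, 2.0)] and
# tau_years=1.0, only the first event falls inside (t0, t0 + tau], so
#   y = [[0, 1, 0],
#        [0, 0, 0]].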
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
seed_everything(args.seed)
|
||||||
|
|
||||||
|
run_dir = args.run_dir
|
||||||
|
cfg = load_train_config(run_dir)
|
||||||
|
|
||||||
|
dataset = build_dataset_from_config(cfg)
|
||||||
|
test_subset = get_test_subset(dataset, cfg)
|
||||||
|
|
||||||
|
age_bins_years = parse_float_list(args.age_bins)
|
||||||
|
horizons = parse_float_list(args.horizons)
|
||||||
|
horizons = [float(h) for h in horizons]
|
||||||
|
|
||||||
|
records = build_event_driven_records(
|
||||||
|
dataset=dataset,
|
||||||
|
subset=test_subset,
|
||||||
|
age_bins_years=age_bins_years,
|
||||||
|
seed=args.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
||||||
|
load_checkpoint_into(run_dir, model, head, criterion, device)
|
||||||
|
|
||||||
|
rec_ds = EvalRecordDataset(dataset, records)
|
||||||
|
|
||||||
|
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
||||||
|
loader = DataLoader(
|
||||||
|
rec_ds,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
collate_fn=eval_collate_fn,
|
||||||
|
**dl_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print disclaimers every run (requested)
|
||||||
|
print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
|
||||||
|
print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
|
||||||
|
|
||||||
|
scores = predict_cifs(model, head, criterion, loader,
|
||||||
|
horizons, device=device)
|
||||||
|
# scores shape: (N, K, H)
|
||||||
|
if scores.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected CIF scores with shape (N,K,H), got {scores.shape}")
|
||||||
|
|
||||||
|
N, K, H = scores.shape
|
||||||
|
if N != len(records):
|
||||||
|
raise ValueError("Record count mismatch")
|
||||||
|
|
||||||
|
# Pre-flatten all future events once to avoid repeated per-record scans.
|
||||||
|
evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)
|
||||||
|
|
||||||
|
per_tau_rows: List[Dict[str, object]] = []
|
||||||
|
per_cause_rows: List[Dict[str, object]] = []
|
||||||
|
workload_rows: List[Dict[str, object]] = []
|
||||||
|
|
||||||
|
for h_idx, tau in enumerate(horizons):
|
||||||
|
s_tau = scores[:, :, h_idx]
|
||||||
|
y_tau = build_labels_within_tau_flat(
|
||||||
|
N, K, evt_rec_idx, evt_cause, evt_dt, tau)
|
||||||
|
|
||||||
|
# Per-cause counts + Brier (vectorized)
|
||||||
|
n_pos = y_tau.sum(axis=0).astype(np.int64)
|
||||||
|
n_neg = (int(N) - n_pos).astype(np.int64)
|
||||||
|
|
||||||
|
# Brier per cause: mean_i (y - s)^2
|
||||||
|
brier_per_cause = np.mean(
|
||||||
|
(y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
|
||||||
|
brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
|
||||||
|
brier_weighted = float(np.sum(
|
||||||
|
brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
|
||||||
|
|
||||||
|
# AUC: compute only for causes with enough positives and at least 1 negative
|
||||||
|
auc = np.full((K,), np.nan, dtype=np.float64)
|
||||||
|
min_pos = int(args.min_pos)
|
||||||
|
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
|
||||||
|
for k in candidates:
|
||||||
|
auc[k] = roc_auc_ovr(y_tau[:, k].astype(
|
||||||
|
np.int32), s_tau[:, k].astype(np.float64))
|
||||||
|
|
||||||
|
finite_auc = auc[np.isfinite(auc)]
|
||||||
|
auc_macro = float(np.mean(finite_auc)
|
||||||
|
) if finite_auc.size > 0 else float("nan")
|
||||||
|
w_mask = np.isfinite(auc)
|
||||||
|
auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
|
||||||
|
n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
|
||||||
|
n_valid_auc = int(np.isfinite(auc).sum())
|
||||||
|
|
||||||
|
# Append per-cause rows (vectorized via DataFrame to avoid Python loops)
|
||||||
|
per_cause_rows.append(
|
||||||
|
pd.DataFrame(
|
||||||
|
{
|
||||||
|
"tau_years": float(tau),
|
||||||
|
"cause_id": np.arange(K, dtype=np.int64),
|
||||||
|
"n_pos": n_pos,
|
||||||
|
"n_neg": n_neg,
|
||||||
|
"auc": auc,
|
||||||
|
"brier": brier_per_cause,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
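        # Interpretation of the top-K ("workload/yield") metrics below:
        #   - population_capture_rate: captured (record, cause) pairs in the top-K
        #     lists divided by all true pairs within tau (denom_true_pairs).
        #   - precision_like: per-record fraction of the top-K list that is true.
        #   - recall_like: per-record fraction of the true causes that appear in
        #     the top-K list (only defined for records with at least one true cause).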
        # Business metrics for each topK
        denom_true_pairs = int(y_tau.sum())
        for topk in args.topk_list:
            topk = int(topk)
            idx = topk_indices(s_tau, topk)
            captured = np.take_along_axis(y_tau, idx, axis=1)
            hits = captured.sum(axis=1).astype(np.float64)
            true_cnt = y_tau.sum(axis=1).astype(np.float64)

            precision_like = hits / float(min(topk, K))
            mean_precision = float(np.mean(precision_like)) if N > 0 else float("nan")

            mask_has_true = true_cnt > 0
            recall_like = np.full((N,), np.nan, dtype=np.float64)
            recall_like[mask_has_true] = hits[mask_has_true] / true_cnt[mask_has_true]
            mean_recall = float(np.nanmean(recall_like)) if np.any(mask_has_true) else float("nan")
            median_recall = float(np.nanmedian(recall_like)) if np.any(mask_has_true) else float("nan")

            numer_captured_pairs = int(captured.sum())
            pop_capture_rate = (
                float(numer_captured_pairs / denom_true_pairs)
                if denom_true_pairs > 0 else float("nan")
            )

            workload_rows.append(
                {
                    "tau_years": float(tau),
                    "topk": int(topk),
                    "population_capture_rate": pop_capture_rate,
                    "mean_precision_like": mean_precision,
                    "mean_recall_like": mean_recall,
                    "median_recall_like": median_recall,
                    "denom_true_pairs": denom_true_pairs,
                    "numer_captured_pairs": numer_captured_pairs,
                }
            )

        per_tau_rows.append(
            {
                "tau_years": float(tau),
                "n_records": int(N),
                "n_causes": int(K),
                "auc_macro": auc_macro,
                "auc_weighted_by_npos": auc_weighted,
                "n_causes_valid_auc": int(n_valid_auc),
                "brier_macro": brier_macro,
                "brier_weighted_by_npos": brier_weighted,
                "total_true_pairs": denom_true_pairs,
            }
        )

    out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
    out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
    out_wy = os.path.join(run_dir, "workload_yield.csv")

    pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
    if per_cause_rows:
        pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False)
    else:
        pd.DataFrame(columns=["tau_years", "cause_id", "n_pos",
                              "n_neg", "auc", "brier"]).to_csv(out_pc, index=False)
    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)

    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")
    print(f"Wrote {out_wy}")


if __name__ == "__main__":
    main()
188  evaluate_next_event.py  Normal file

@@ -0,0 +1,188 @@
import argparse
import os
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using short-window CIF"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument("--tau_short", type=float, required=True, help="years")
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
    )

    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    return p.parse_args()


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)

    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        dataset=dataset,
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(dataset, records)

    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    tau = float(args.tau_short)
    scores = predict_cifs(model, head, criterion, loader, [tau], device=device)
    # predict_cifs returns (N, K, H); with a single tau H == 1, so keep that slice.
    if scores.ndim == 3:
        scores = scores[:, :, 0]

    n_records_total = len(records)
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
         for r in records],
        dtype=np.int64,
    )
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = float(n_eligible / n_records_total) if n_records_total > 0 else 0.0

    metrics_rows: List[dict] = []
    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
    metrics_rows.append({"metric": "n_next_event_eligible", "value": n_eligible})
    metrics_rows.append({"metric": "coverage", "value": coverage})
    metrics_rows.append({"metric": "tau_short_years", "value": tau})

    if n_eligible == 0:
        out_path = os.path.join(run_dir, "next_event_metrics.csv")
        pd.DataFrame(metrics_rows).to_csv(out_path, index=False)
        print(f"No eligible records; wrote {out_path}")
        return

    scores_e = scores[eligible]
    y_e = y_next[eligible]

    pred = scores_e.argmax(axis=1)
    acc = float((pred == y_e).mean())
    metrics_rows.append({"metric": "top1_accuracy", "value": acc})

    # MRR
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = np.empty(y_e.shape[0], dtype=np.int32)
    for i in range(y_e.shape[0]):
        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
    mrr = float((1.0 / ranks).mean())
    metrics_rows.append({"metric": "mrr", "value": mrr})
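    # MRR here is the mean of 1 / rank(true next cause) over eligible records,
    # e.g. ranks of 1, 2 and 4 give (1 + 0.5 + 0.25) / 3 = 0.583...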
    # HitRate@K
    for k in [1, 3, 5, 10, 20]:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append({"metric": f"hitrate_at_{k}",
                             "value": float(hit.mean())})

    # Macro OvR AUC per cause (optional)
    K = scores.shape[1]
    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
    n_neg = (int(y_e.size) - n_pos).astype(np.int64)

    auc = np.full((K,), np.nan, dtype=np.float64)
    min_pos = int(args.min_pos)
    candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
    for k in candidates:
        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])

    included = (n_pos >= min_pos) & (n_neg > 0)
    per_cause_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": n_pos,
            "n_neg": n_neg,
            "auc": auc,
            "included": included,
        }
    )

    aucs = auc[np.isfinite(auc)]

    if aucs.size > 0:
        metrics_rows.append(
            {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
    else:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})

    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)

    # optional per-cause
    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
    per_cause_df.to_csv(out_pc, index=False)

    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")


if __name__ == "__main__":
    main()
566  utils.py  Normal file

@@ -0,0 +1,566 @@
import json
import math
import os
import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split

from dataset import HealthDataset
from losses import (
    DiscreteTimeCIFNLLLoss,
    ExponentialNLLLoss,
    PiecewiseExponentialCIFNLLLoss,
)
from model import DelphiFork, SapDelphi, SimpleHead


DAYS_PER_YEAR = 365.25
N_TECH_TOKENS = 2  # pad=0, DOA=1, diseases start at 2


def make_inference_dataloader_kwargs(
    device: torch.device,
    num_workers: int,
) -> Dict[str, Any]:
    """DataLoader kwargs tuned for inference throughput.

    Behavior/metrics are unchanged; this only impacts speed.
    """
    use_cuda = device.type == "cuda" and torch.cuda.is_available()
    kwargs: Dict[str, Any] = {
        "pin_memory": bool(use_cuda),
    }
    if num_workers > 0:
        kwargs["persistent_workers"] = True
        # default prefetch is 2; set explicitly for clarity.
        kwargs["prefetch_factor"] = 2
    return kwargs


# -------------------------
# Config + determinism
# -------------------------

def _replace_nonstandard_json_numbers(text: str) -> str:
    # Python's json.dump writes Infinity/-Infinity/NaN for non-finite floats.
    # Replace bare tokens (not within quotes) with string placeholders.
    def repl(match: re.Match[str]) -> str:
        token = match.group(0)
        if token == "-Infinity":
            return '"__NINF__"'
        if token == "Infinity":
            return '"__INF__"'
        if token == "NaN":
            return '"__NAN__"'
        return token

    return re.sub(r'(?<![\w\."])(-Infinity|Infinity|NaN)(?![\w\."])', repl, text)


def _restore_placeholders(obj: Any) -> Any:
    if isinstance(obj, dict):
        return {k: _restore_placeholders(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_restore_placeholders(v) for v in obj]
    if obj == "__INF__":
        return float("inf")
    if obj == "__NINF__":
        return float("-inf")
    if obj == "__NAN__":
        return float("nan")
    return obj


def load_train_config(run_dir: str) -> Dict[str, Any]:
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r", encoding="utf-8") as f:
        raw = f.read()
    raw = _replace_nonstandard_json_numbers(raw)
    cfg = json.loads(raw)
    cfg = _restore_placeholders(cfg)
    return cfg
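# Round-trip example for the two helpers above: the raw config text
#   {"max_grad_norm": Infinity, "best_val": NaN}
# becomes
#   {"max_grad_norm": "__INF__", "best_val": "__NAN__"}
# after _replace_nonstandard_json_numbers, and _restore_placeholders then maps the
# placeholder strings back to float("inf") / float("nan") after json.loads.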
def seed_everything(seed: int) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_float_list(values: Sequence[str]) -> List[float]:
|
||||||
|
out: List[float] = []
|
||||||
|
for v in values:
|
||||||
|
s = str(v).strip().lower()
|
||||||
|
if s in {"inf", "+inf", "infty", "infinity", "+infinity"}:
|
||||||
|
out.append(float("inf"))
|
||||||
|
elif s in {"-inf", "-infty", "-infinity"}:
|
||||||
|
out.append(float("-inf"))
|
||||||
|
else:
|
||||||
|
out.append(float(v))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
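# Example: parse_float_list(["40", "70", "inf"]) == [40.0, 70.0, float("inf")],
# which is how the --age_bins and --horizons CLI strings are converted.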
# -------------------------
# Dataset + split (match train.py)
# -------------------------

def build_dataset_from_config(cfg: Dict[str, Any]) -> HealthDataset:
    data_prefix = cfg["data_prefix"]
    full_cov = bool(cfg.get("full_cov", False))

    if full_cov:
        cov_list = None
    else:
        cov_list = ["bmi", "smoking", "alcohol"]

    dataset = HealthDataset(
        data_prefix=data_prefix,
        covariate_list=cov_list,
    )
    return dataset


def get_test_subset(dataset: HealthDataset, cfg: Dict[str, Any]) -> Subset:
    n_total = len(dataset)
    train_ratio = float(cfg["train_ratio"])
    val_ratio = float(cfg["val_ratio"])
    seed = int(cfg["random_seed"])

    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val

    _, _, test_subset = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    return test_subset


# -------------------------
# Model + head + criterion (match train.py)
# -------------------------

def build_model_head_criterion(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: torch.device,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    loss_type = cfg["loss_type"]

    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(
            lambda_reg=float(cfg.get("lambda_reg", 0.0))).to(device)
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = [float(x) for x in cfg["bin_edges"]]
        criterion = DiscreteTimeCIFNLLLoss(
            bin_edges=bin_edges,
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        ).to(device)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif loss_type == "pwe_cif":
        # training drops +inf for PWE
        raw_edges = [float(x) for x in cfg["bin_edges"]]
        pwe_edges = [float(x) for x in raw_edges if math.isfinite(float(x))]
        if len(pwe_edges) < 2:
            raise ValueError(
                "pwe_cif requires at least 2 finite bin edges (including 0). "
                f"Got bin_edges={raw_edges}"
            )
        if float(pwe_edges[0]) != 0.0:
            raise ValueError(
                f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}")

        criterion = PiecewiseExponentialCIFNLLLoss(
            bin_edges=pwe_edges,
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        ).to(device)
        n_bins = len(pwe_edges) - 1
        out_dims = [dataset.n_disease, n_bins]
    else:
        raise ValueError(f"Unsupported loss_type: {loss_type}")

    model_type = cfg["model_type"]
    if model_type == "delphi_fork":
        model = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=N_TECH_TOKENS,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
        ).to(device)
    elif model_type == "sap_delphi":
        model = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=N_TECH_TOKENS,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
            pretrained_weights_path=str(
                cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    head = SimpleHead(
        n_embd=int(cfg["n_embd"]),
        out_dims=list(out_dims),
    ).to(device)

    return model, head, criterion


def load_checkpoint_into(
    run_dir: str,
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion: Optional[torch.nn.Module],
    device: torch.device,
) -> Dict[str, Any]:
    ckpt_path = os.path.join(run_dir, "best_model.pt")
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"], strict=True)
    head.load_state_dict(ckpt["head_state_dict"], strict=True)
    if criterion is not None and "criterion_state_dict" in ckpt:
        try:
            criterion.load_state_dict(
                ckpt["criterion_state_dict"], strict=False)
        except Exception:
            # Criterion state is not essential for inference.
            pass
    return ckpt


# -------------------------
# Evaluation record construction (event-driven)
# -------------------------

@dataclass(frozen=True)
class EvalRecord:
    patient_idx: int
    patient_id: Any
    doa_days: float
    t0_days: float
    cutoff_pos: int  # baseline position (inclusive)
    next_event_cause: Optional[int]
    next_event_dt_years: Optional[float]
    future_causes: np.ndarray  # (E,) in [0..K-1]
    future_dt_years: np.ndarray  # (E,) strictly > 0


def _to_days(x_years: float) -> float:
    if math.isinf(float(x_years)):
        return float("inf")
    return float(x_years) * DAYS_PER_YEAR


def build_event_driven_records(
    dataset: HealthDataset,
    subset: Subset,
    age_bins_years: Sequence[float],
    seed: int,
) -> List[EvalRecord]:
    if len(age_bins_years) < 2:
        raise ValueError("age_bins must have at least 2 boundaries")

    age_bins_days = [_to_days(b) for b in age_bins_years]
    if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
        raise ValueError("age_bins must be strictly increasing")

    rng = np.random.default_rng(seed)

    records: List[EvalRecord] = []

    # Subset.indices is deterministic from random_split
    indices = list(getattr(subset, "indices", range(len(subset))))

    # Speed: avoid calling dataset.__getitem__ for every patient here.
    # We only need DOA + event times/codes to create evaluation records.
    eps = 1e-6
    for patient_idx in indices:
        patient_id = dataset.patient_ids[patient_idx]

        doa_days = float(dataset._doa[patient_idx])

        raw_records = dataset.patient_events.get(patient_id, [])
        if raw_records:
            times = np.asarray([t for t, _ in raw_records], dtype=np.float64)
            codes = np.asarray([c for _, c in raw_records], dtype=np.int64)
        else:
            times = np.zeros((0,), dtype=np.float64)
            codes = np.zeros((0,), dtype=np.int64)

        # Mirror HealthDataset insertion logic exactly.
        insert_pos = int(np.searchsorted(times, doa_days, side="left"))
        times_ins = np.insert(times, insert_pos, doa_days)
        codes_ins = np.insert(codes, insert_pos, 1)

        is_disease = codes_ins >= N_TECH_TOKENS
        disease_times = times_ins[is_disease]

        for b in range(len(age_bins_days) - 1):
            lo = age_bins_days[b]
            hi = age_bins_days[b + 1]

            # Inclusion rule:
            # 1) DOA <= bin_upper
            if not (doa_days <= hi):
                continue

            # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA
            in_bin = (disease_times >= lo) & (
                disease_times < hi) & (disease_times >= doa_days)
            cand_times = disease_times[in_bin]
            if cand_times.size == 0:
                continue

            t0_days = float(rng.choice(cand_times))

            # Baseline position (inclusive) in the *post-DOA-inserted* sequence.
            pos = np.flatnonzero(is_disease & np.isclose(
                times_ins, t0_days, rtol=0.0, atol=eps))
            if pos.size == 0:
                disease_pos = np.flatnonzero(is_disease)
                if disease_pos.size == 0:
                    continue
                disease_times_full = times_ins[disease_pos]
                closest_idx = int(
                    np.argmin(np.abs(disease_times_full - t0_days)))
                cutoff_pos = int(disease_pos[closest_idx])
                t0_days = float(disease_times_full[closest_idx])
            else:
                cutoff_pos = int(pos[0])

            # Future disease events strictly after t0
            future_mask = (times_ins > (t0_days + eps)) & is_disease
            future_pos = np.flatnonzero(future_mask)
            if future_pos.size == 0:
                next_cause = None
                next_dt_years = None
                future_causes = np.zeros((0,), dtype=np.int64)
                future_dt_years_arr = np.zeros((0,), dtype=np.float32)
            else:
                future_times_days = times_ins[future_pos]
                future_tokens = codes_ins[future_pos]
                future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64)
                future_dt_years_arr = (
                    (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)

                # next-event = minimal time > t0 (tie broken by earliest position)
                next_idx = int(np.argmin(future_times_days))
                next_cause = int(future_causes[next_idx])
                next_dt_years = float(future_dt_years_arr[next_idx])

            records.append(
                EvalRecord(
                    patient_idx=int(patient_idx),
                    patient_id=patient_id,
                    doa_days=float(doa_days),
                    t0_days=float(t0_days),
                    cutoff_pos=int(cutoff_pos),
                    next_event_cause=next_cause,
                    next_event_dt_years=next_dt_years,
                    future_causes=future_causes,
                    future_dt_years=future_dt_years_arr,
                )
            )

    return records


class EvalRecordDataset(Dataset):
    def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]):
        self.base = base_dataset
        self.records = list(records)
        self._cache: Dict[int, Tuple[torch.Tensor,
                                     torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
        self._cache_order: List[int] = []
        self._cache_max = 2048

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        rec = self.records[idx]
        cached = self._cache.get(rec.patient_idx)
        if cached is None:
            event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx]
            cached = (event_seq, time_seq, cont, cate, int(sex))
            self._cache[rec.patient_idx] = cached
            self._cache_order.append(rec.patient_idx)
            if len(self._cache_order) > self._cache_max:
                drop = self._cache_order.pop(0)
                self._cache.pop(drop, None)
        else:
            event_seq, time_seq, cont, cate, sex = cached
        cutoff = rec.cutoff_pos + 1
        event_seq = event_seq[:cutoff]
        time_seq = time_seq[:cutoff]
        baseline_pos = rec.cutoff_pos  # same index in truncated sequence
        return event_seq, time_seq, cont, cate, sex, baseline_pos


def eval_collate_fn(batch):
    from torch.nn.utils.rnn import pad_sequence

    event_seqs, time_seqs, cont_feats, cate_feats, sexes, baseline_pos = zip(
        *batch)
    event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0)
    time_batch = pad_sequence(
        time_seqs, batch_first=True, padding_value=36525.0)
    cont_batch = torch.stack(cont_feats, dim=0).unsqueeze(1)
    cate_batch = torch.stack(cate_feats, dim=0).unsqueeze(1)
    sex_batch = torch.tensor(sexes, dtype=torch.long)
    baseline_pos = torch.tensor(baseline_pos, dtype=torch.long)
    return event_batch, time_batch, cont_batch, cate_batch, sex_batch, baseline_pos
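# Note on eval_collate_fn: sequences are right-padded with event token 0 (the pad
# token) and a time value of 36525.0 days, which is 100 years at
# DAYS_PER_YEAR == 365.25.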
# -------------------------
# Inference utilities
# -------------------------

def predict_cifs(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion: torch.nn.Module,
    loader: DataLoader,
    taus_years: Sequence[float],
    device: torch.device,
) -> np.ndarray:
    model.eval()
    head.eval()

    taus_t = torch.tensor(list(taus_years), dtype=torch.float32, device=device)

    all_out: List[np.ndarray] = []
    with torch.no_grad():
        for batch in loader:
            event_seq, time_seq, cont, cate, sex, baseline_pos = batch
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)
            cont = cont.to(device, non_blocking=True)
            cate = cate.to(device, non_blocking=True)
            sex = sex.to(device, non_blocking=True)
            baseline_pos = baseline_pos.to(device, non_blocking=True)

            h = model(event_seq, time_seq, sex, cont, cate)
            b_idx = torch.arange(h.size(0), device=device)
            c = h[b_idx, baseline_pos]
            logits = head(c)

            cifs = criterion.calculate_cifs(logits, taus_t)
            out = cifs.detach().cpu().numpy()
            all_out.append(out)

    return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))


def flatten_future_events(
    records: Sequence[EvalRecord],
    n_causes: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Flatten (record_idx, cause, dt_years) across all future events.

    Used to build horizon labels via vectorized masking + scatter.
    """
    rec_idx_parts: List[np.ndarray] = []
    cause_parts: List[np.ndarray] = []
    dt_parts: List[np.ndarray] = []

    for i, r in enumerate(records):
        if r.future_causes.size == 0:
            continue
        causes = r.future_causes
        dts = r.future_dt_years
        # Keep only valid cause ids.
        m = (causes >= 0) & (causes < n_causes)
        if not np.any(m):
            continue
        causes = causes[m].astype(np.int64, copy=False)
        dts = dts[m].astype(np.float32, copy=False)
        rec_idx_parts.append(np.full((causes.size,), i, dtype=np.int32))
        cause_parts.append(causes)
        dt_parts.append(dts)

    if not rec_idx_parts:
        return (
            np.zeros((0,), dtype=np.int32),
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.float32),
        )

    return (
        np.concatenate(rec_idx_parts, axis=0),
        np.concatenate(cause_parts, axis=0),
        np.concatenate(dt_parts, axis=0),
    )


# -------------------------
# Metrics helpers
# -------------------------

def roc_auc_ovr(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Binary ROC AUC with tie-aware average ranks.

    Returns NaN if y_true has no positives or no negatives.
    """
    y_true = np.asarray(y_true).astype(np.int32)
    y_score = np.asarray(y_score).astype(np.float64)

    n_pos = int(y_true.sum())
    n = int(y_true.size)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    order = np.argsort(y_score, kind="mergesort")
    scores_sorted = y_score[order]
    y_sorted = y_true[order]

    ranks = np.empty(n, dtype=np.float64)
    i = 0
    while i < n:
        j = i + 1
        while j < n and scores_sorted[j] == scores_sorted[i]:
            j += 1
        # average rank for ties, ranks are 1..n
        avg_rank = 0.5 * (i + 1 + j)
        ranks[i:j] = avg_rank
        i = j

    sum_ranks_pos = float((ranks * y_sorted).sum())
    auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)
    return float(auc)
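# roc_auc_ovr implements the rank-sum (Mann-Whitney) form of the AUC:
#   AUC = (sum of ranks of positives - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg),
# with tied scores sharing their average rank. Worked example:
#   y_true  = [0, 1, 1, 0]
#   y_score = [0.2, 0.8, 0.3, 0.5]
# ascending ranks are 1..4, the positives get ranks 2 and 4, so
#   AUC = (6 - 3) / (2 * 2) = 0.75,
# matching the fraction of (positive, negative) pairs ranked correctly (3 of 4).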
def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
|
||||||
|
"""Return indices of top-k scores per row (descending)."""
|
||||||
|
if k <= 0:
|
||||||
|
raise ValueError("k must be positive")
|
||||||
|
n, K = scores.shape
|
||||||
|
k = min(k, K)
|
||||||
|
# argpartition gives arbitrary order within topk; sort those by score
|
||||||
|
part = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
|
||||||
|
part_scores = np.take_along_axis(scores, part, axis=1)
|
||||||
|
order = np.argsort(-part_scores, axis=1, kind="mergesort")
|
||||||
|
return np.take_along_axis(part, order, axis=1)
|
||||||
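# Illustration: topk_indices(np.array([[0.1, 0.9, 0.4]]), 2) returns [[1, 2]],
# i.e. the column indices of the two highest scores, ordered by descending score.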