Add evaluation and utility functions for time-dependent metrics

- Introduced `evaluate.py` for time-dependent evaluation of models, including data loading and model inference.
- Added `evaluation_time_dependent.py` to compute evaluation metrics such as AUC, average precision, and recall/precision at specified top-K% risk thresholds.
- Implemented CIF calculation methods in `losses.py` for different loss types, including exponential and piecewise exponential models.
- Created utility functions in `utils.py` for context selection and multi-hot encoding of events within specified horizons.
commit 34d8d8ce9d
parent 660dc969bc
2026-01-16 14:55:09 +08:00
4 changed files with 1066 additions and 0 deletions
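
A typical invocation of the new script (run directory and argument values are illustrative, not part of this commit):

    python evaluate.py --run_dir runs/my_run --split test --horizons 1.0 5.0 --topk_percent 1 5 10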

evaluate.py (new file, 234 lines)

@@ -0,0 +1,234 @@
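"""Time-dependent evaluation entry point for DeepHealth.

Rebuilds the training-time dataset split, model, head, and loss criterion
from a run directory, runs evaluate_time_dependent over the requested
split, and writes per-horizon metrics to CSV.
"""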
from __future__ import annotations
import argparse
import json
import math
import os
from typing import List, Sequence
import torch
from torch.utils.data import DataLoader, random_split
from dataset import HealthDataset, health_collate_fn
from evaluation_time_dependent import EvalConfig, evaluate_time_dependent
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead


def _parse_floats(items: Sequence[str]) -> List[float]:
out: List[float] = []
for x in items:
x = x.strip()
if not x:
continue
out.append(float(x))
    return out


def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
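    # Each branch returns (criterion, out_dims); out_dims sets the per-task
    # output widths of SimpleHead so the model's predictions match what the
    # chosen loss expects.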
if loss_type == "exponential":
criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
out_dims = [n_disease]
return criterion, out_dims
if loss_type == "discrete_time_cif":
criterion = DiscreteTimeCIFNLLLoss(
bin_edges=bin_edges, lambda_reg=lambda_reg)
out_dims = [n_disease + 1, len(bin_edges)]
return criterion, out_dims
if loss_type == "pwe_cif":
pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
if len(pwe_edges) < 2:
raise ValueError(
"pwe_cif requires at least 2 finite bin edges (including 0)")
if float(pwe_edges[0]) != 0.0:
raise ValueError("pwe_cif requires bin_edges[0]==0.0")
criterion = PiecewiseExponentialCIFNLLLoss(
bin_edges=pwe_edges, lambda_reg=lambda_reg)
n_bins = len(pwe_edges) - 1
out_dims = [n_disease, n_bins]
return criterion, out_dims
raise ValueError(f"Unsupported loss_type: {loss_type}")
def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
if model_type == "delphi_fork":
return DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg["age_encoder"]),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
)
if model_type == "sap_delphi":
return SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg["age_encoder"]),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
pretrained_weights_path=str(
cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
)
raise ValueError(f"Unsupported model_type: {model_type}")
def main() -> None:
parser = argparse.ArgumentParser(
description="Time-dependent evaluation for DeepHealth")
parser.add_argument(
"--run_dir",
type=str,
required=True,
help="Training run directory (contains best_model.pt and train_config.json)",
)
parser.add_argument("--data_prefix", type=str, default=None,
help="Dataset prefix (overrides config if provided)")
parser.add_argument("--split", type=str,
choices=["train", "val", "test", "all"], default="val")
parser.add_argument("--horizons", type=str, nargs="+",
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="One or more horizons (years)")
parser.add_argument("--offset_years", type=float, default=0.0,
help="Context selection offset (years before follow-up end)")
parser.add_argument(
"--topk_percent",
type=float,
nargs="+",
default=[1, 5, 10, 20, 50],
help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)",
)
parser.add_argument("--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--num_workers", type=int,
default=0, help="Keep 0 on Windows")
parser.add_argument("--out_csv", type=str, default=None,
help="Optional output CSV path")
args = parser.parse_args()
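    # The run directory is expected to hold the artifacts written by
    # training: the best checkpoint and the training config.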
ckpt_path = os.path.join(args.run_dir, "best_model.pt")
cfg_path = os.path.join(args.run_dir, "train_config.json")
if not os.path.exists(ckpt_path):
raise SystemExit(f"Missing checkpoint: {ckpt_path}")
if not os.path.exists(cfg_path):
raise SystemExit(f"Missing config: {cfg_path}")
with open(cfg_path, "r") as f:
cfg = json.load(f)
data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
"data_prefix", "ukb")
# Match training covariate selection.
full_cov = bool(cfg.get("full_cov", False))
cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
# Recreate the same split scheme as train.py
train_ratio = float(cfg.get("train_ratio", 0.7))
val_ratio = float(cfg.get("val_ratio", 0.15))
seed = int(cfg.get("random_seed", 42))
n_total = len(dataset)
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
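    # The remainder goes to test so the three splits always cover the dataset.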
train_ds, val_ds, test_ds = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
if args.split == "train":
ds = train_ds
elif args.split == "val":
ds = val_ds
elif args.split == "test":
ds = test_ds
else:
ds = dataset
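    # No shuffling: keep a deterministic sample order for evaluation.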
loader = DataLoader(
ds,
batch_size=int(args.batch_size),
shuffle=False,
collate_fn=health_collate_fn,
num_workers=int(args.num_workers),
pin_memory=str(args.device).startswith("cuda"),
)
criterion, out_dims = build_criterion_and_out_dims(
loss_type=str(cfg["loss_type"]),
n_disease=int(dataset.n_disease),
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
)
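    # Rebuild the backbone and head from the saved config; the head's output
    # dims must match what the criterion expects.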
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
device = torch.device(args.device)
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
if "criterion_state_dict" in checkpoint:
try:
criterion.load_state_dict(
checkpoint["criterion_state_dict"], strict=False)
except Exception:
pass
model.to(device)
head.to(device)
criterion.to(device)
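    # Assemble the evaluation config; cause_ids=None passes no explicit
    # cause filter to evaluate_time_dependent.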
eval_cfg = EvalConfig(
horizons_years=_parse_floats(args.horizons),
offset_years=float(args.offset_years),
topk_percents=[float(x) for x in args.topk_percent],
cause_ids=None,
)
df = evaluate_time_dependent(
model=model,
head=head,
criterion=criterion,
dataloader=loader,
n_disease=int(dataset.n_disease),
cfg=eval_cfg,
device=device,
)
if args.out_csv is None:
out_csv = os.path.join(
args.run_dir, f"time_dependent_metrics_{args.split}.csv")
else:
out_csv = args.out_csv
df.to_csv(out_csv, index=False)
print(f"Wrote: {out_csv}")
if __name__ == "__main__":
main()