Add evaluation scripts for age-bin time-dependent metrics and remove obsolete evaluation_time_dependent.py
 evaluate.py | 234 ----------------------------------------
@@ -1,234 +0,0 @@
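"""Time-dependent evaluation entry point (deleted by this commit).

Rebuilds the training-time dataset splits from a run directory's
train_config.json, restores best_model.pt, and writes time-dependent
metrics to CSV via evaluation_time_dependent.evaluate_time_dependent.
"""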
from __future__ import annotations

import argparse
import json
import math
import os
from typing import List, Sequence

import torch
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from evaluation_time_dependent import EvalConfig, evaluate_time_dependent
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead


def _parse_floats(items: Sequence[str]) -> List[float]:
    out: List[float] = []
    for x in items:
        x = x.strip()
        if not x:
            continue
        out.append(float(x))
    return out


def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
        out_dims = [n_disease]
        return criterion, out_dims

    if loss_type == "discrete_time_cif":
        criterion = DiscreteTimeCIFNLLLoss(
            bin_edges=bin_edges, lambda_reg=lambda_reg)
        out_dims = [n_disease + 1, len(bin_edges)]
        return criterion, out_dims

    if loss_type == "pwe_cif":
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        if len(pwe_edges) < 2:
            raise ValueError(
                "pwe_cif requires at least 2 finite bin edges (including 0)")
        if float(pwe_edges[0]) != 0.0:
            raise ValueError("pwe_cif requires bin_edges[0]==0.0")
        criterion = PiecewiseExponentialCIFNLLLoss(
            bin_edges=pwe_edges, lambda_reg=lambda_reg)
        n_bins = len(pwe_edges) - 1
        out_dims = [n_disease, n_bins]
        return criterion, out_dims

    raise ValueError(f"Unsupported loss_type: {loss_type}")


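# Worked example (illustrative numbers, not from the original file): with the
# default bin_edges = [0.0, 1.0, inf] and, say, n_disease = 100, the mapping
# above yields:
#   exponential       -> out_dims = [100]
#   discrete_time_cif -> out_dims = [101, 3]   # [n_disease + 1, len(bin_edges)]
#   pwe_cif           -> out_dims = [100, 1]   # finite edges [0.0, 1.0] => 1 bin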
def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
    if model_type == "delphi_fork":
        return DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
        )

    if model_type == "sap_delphi":
        return SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
            pretrained_weights_path=str(
                cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
            freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
        )

    raise ValueError(f"Unsupported model_type: {model_type}")


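# Sketch of the train_config.json fields consumed by this script (values are
# illustrative placeholders, not from the original file):
#   {
#     "model_type": "delphi_fork", "loss_type": "discrete_time_cif",
#     "n_embd": 128, "n_head": 8, "n_layer": 4, "pdrop": 0.1,
#     "age_encoder": "...", "bin_edges": [0.0, 1.0, Infinity],
#     "lambda_reg": 0.0, "data_prefix": "ukb", "full_cov": false,
#     "train_ratio": 0.7, "val_ratio": 0.15, "random_seed": 42
#   }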
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Time-dependent evaluation for DeepHealth")
    parser.add_argument(
        "--run_dir",
        type=str,
        required=True,
        help="Training run directory (contains best_model.pt and train_config.json)",
    )
    parser.add_argument("--data_prefix", type=str, default=None,
                        help="Dataset prefix (overrides config if provided)")
    parser.add_argument("--split", type=str,
                        choices=["train", "val", "test", "all"], default="val")

    parser.add_argument("--horizons", type=str, nargs="+",
                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
                        help="One or more horizons (years)")
    parser.add_argument("--offset_years", type=float, default=0.0,
                        help="Context selection offset (years before follow-up end)")
    parser.add_argument(
        "--topk_percent",
        type=float,
        nargs="+",
        default=[1, 5, 10, 20, 50],
        help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)",
    )

    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_workers", type=int,
                        default=0, help="Keep 0 on Windows")

    parser.add_argument("--out_csv", type=str, default=None,
                        help="Optional output CSV path")

    args = parser.parse_args()

    ckpt_path = os.path.join(args.run_dir, "best_model.pt")
    cfg_path = os.path.join(args.run_dir, "train_config.json")

    if not os.path.exists(ckpt_path):
        raise SystemExit(f"Missing checkpoint: {ckpt_path}")
    if not os.path.exists(cfg_path):
        raise SystemExit(f"Missing config: {cfg_path}")

    with open(cfg_path, "r") as f:
        cfg = json.load(f)

    data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
        "data_prefix", "ukb")

    # Match training covariate selection.
    full_cov = bool(cfg.get("full_cov", False))
    cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
    dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)

    # Recreate the same split scheme as train.py
    train_ratio = float(cfg.get("train_ratio", 0.7))
    val_ratio = float(cfg.get("val_ratio", 0.15))
    seed = int(cfg.get("random_seed", 42))

    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )

    if args.split == "train":
        ds = train_ds
    elif args.split == "val":
        ds = val_ds
    elif args.split == "test":
        ds = test_ds
    else:
        ds = dataset

    loader = DataLoader(
        ds,
        batch_size=int(args.batch_size),
        shuffle=False,
        collate_fn=health_collate_fn,
        num_workers=int(args.num_workers),
        pin_memory=str(args.device).startswith("cuda"),
    )

    criterion, out_dims = build_criterion_and_out_dims(
        loss_type=str(cfg["loss_type"]),
        n_disease=int(dataset.n_disease),
        bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
        lambda_reg=float(cfg.get("lambda_reg", 0.0)),
    )

    model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)

    device = torch.device(args.device)
    checkpoint = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    head.load_state_dict(checkpoint["head_state_dict"], strict=True)
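    # Criterion state may be absent from older checkpoints (some losses may
    # carry buffers such as bin edges), so the load below is best-effort.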
if "criterion_state_dict" in checkpoint:
|
||||
try:
|
||||
criterion.load_state_dict(
|
||||
checkpoint["criterion_state_dict"], strict=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model.to(device)
|
||||
head.to(device)
|
||||
criterion.to(device)
|
||||
|
||||
eval_cfg = EvalConfig(
|
||||
horizons_years=_parse_floats(args.horizons),
|
||||
offset_years=float(args.offset_years),
|
||||
topk_percents=[float(x) for x in args.topk_percent],
|
||||
cause_ids=None,
|
||||
)
|
||||
|
||||
df = evaluate_time_dependent(
|
||||
model=model,
|
||||
head=head,
|
||||
criterion=criterion,
|
||||
dataloader=loader,
|
||||
n_disease=int(dataset.n_disease),
|
||||
cfg=eval_cfg,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if args.out_csv is None:
|
||||
out_csv = os.path.join(
|
||||
args.run_dir, f"time_dependent_metrics_{args.split}.csv")
|
||||
else:
|
||||
out_csv = args.out_csv
|
||||
|
||||
df.to_csv(out_csv, index=False)
|
||||
print(f"Wrote: {out_csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
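For reference, the deleted script was driven entirely by the flags defined above. A typical invocation (run directory name illustrative) was:

    python evaluate.py --run_dir runs/exp01 --split test --horizons 1.0 5.0 10.0 --topk_percent 1 5 10

When --out_csv is omitted, results are written to time_dependent_metrics_<split>.csv inside --run_dir.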