# DeepHealth/evaluate_age.py
from __future__ import annotations
import argparse
import json
import math
import os
import multiprocessing as mp
from typing import List, Sequence, Tuple
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split
from dataset import HealthDataset, health_collate_fn
from evaluation_age_time_dependent import (
    EvalAgeConfig,
    aggregate_age_bin_results,
    evaluate_time_dependent_age_bins,
)
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead


def _parse_floats(items: Sequence[str]) -> List[float]:
    out: List[float] = []
    for x in items:
        x = x.strip()
        if not x:
            continue
        out.append(float(x))
    return out


def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
    vals = _parse_floats(edges)
    if len(vals) < 2:
        raise ValueError("--age_bin_edges must have at least 2 values")
    for i in range(1, len(vals)):
        if not (vals[i] > vals[i - 1]):
            raise ValueError("--age_bin_edges must be strictly increasing")
    return vals


def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]
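# Illustration: _edges_to_bins([40.0, 50.0, 60.0]) -> [(40.0, 50.0), (50.0, 60.0)];
# per the --age_bin_edges help below, each bin is half-open: [edge[i], edge[i+1]).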


def _parse_gpus(gpus: str | None) -> List[int]:
    if gpus is None:
        return []
    s = gpus.strip()
    if not s:
        return []
    return [int(p) for p in s.split(",") if p.strip()]
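# Illustration: _parse_gpus("0,1,3") -> [0, 1, 3]; both _parse_gpus(None) and
# _parse_gpus("") return [].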


def _worker_eval_mcs_on_gpu(
    queue: "mp.Queue",
    *,
    run_dir: str,
    split: str,
    data_prefix_override: str | None,
    horizons: List[float],
    age_bins: List[Tuple[float, float]],
    topk_percents: List[float],
    seed: int,
    batch_size: int,
    num_workers: int,
    gpu_id: int,
    mc_indices: List[int],
    out_path: str,
) -> None:
    """Worker process: evaluate a subset of MC indices on a single GPU."""
    try:
        ckpt_path = os.path.join(run_dir, "best_model.pt")
        cfg_path = os.path.join(run_dir, "train_config.json")
        with open(cfg_path, "r") as f:
            cfg = json.load(f)
        data_prefix = (
            data_prefix_override
            if data_prefix_override is not None
            else cfg.get("data_prefix", "ukb")
        )
        full_cov = bool(cfg.get("full_cov", False))
        cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
        # Rebuild the exact train/val/test split used at training time: same
        # ratios and the same generator seed recorded in train_config.json.
        train_ratio = float(cfg.get("train_ratio", 0.7))
        val_ratio = float(cfg.get("val_ratio", 0.15))
        seed_split = int(cfg.get("random_seed", 42))
        n_total = len(dataset)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        n_test = n_total - n_train - n_val
        train_ds, val_ds, test_ds = random_split(
            dataset,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(seed_split),
        )
if split == "train":
ds = train_ds
elif split == "val":
ds = val_ds
elif split == "test":
ds = test_ds
else:
ds = dataset
loader = DataLoader(
ds,
batch_size=int(batch_size),
shuffle=False,
collate_fn=health_collate_fn,
num_workers=int(num_workers),
pin_memory=True,
)
criterion, out_dims = build_criterion_and_out_dims(
loss_type=str(cfg["loss_type"]),
n_disease=int(dataset.n_disease),
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
)
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
device = torch.device(f"cuda:{int(gpu_id)}")
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
if "criterion_state_dict" in checkpoint:
try:
criterion.load_state_dict(
checkpoint["criterion_state_dict"], strict=False)
except Exception:
pass
model.to(device)
head.to(device)
criterion.to(device)
        frames: List[pd.DataFrame] = []
        for mc_idx in mc_indices:
            # Each replicate runs as a single-MC evaluation; the distinct
            # mc_offset keys its randomness so the union over workers matches
            # a single-process run over all replicates.
            eval_cfg = EvalAgeConfig(
                horizons_years=horizons,
                age_bins=age_bins,
                topk_percents=topk_percents,
                n_mc=1,
                seed=int(seed),
                cause_ids=None,
            )
            df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins(
                model=model,
                head=head,
                criterion=criterion,
                dataloader=loader,
                n_disease=int(dataset.n_disease),
                cfg=eval_cfg,
                device=device,
                mc_offset=int(mc_idx),
            )
            frames.append(df_by_bin)
        df_all = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        df_all.to_csv(out_path, index=False)
        queue.put({"ok": True, "out_path": out_path})
    except Exception as e:
        queue.put({"ok": False, "error": repr(e)})


def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
        out_dims = [n_disease]
        return criterion, out_dims
    if loss_type == "discrete_time_cif":
        criterion = DiscreteTimeCIFNLLLoss(bin_edges=bin_edges, lambda_reg=lambda_reg)
        out_dims = [n_disease + 1, len(bin_edges)]
        return criterion, out_dims
    if loss_type == "pwe_cif":
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        if len(pwe_edges) < 2:
            raise ValueError("pwe_cif requires at least 2 finite bin edges (including 0)")
        if float(pwe_edges[0]) != 0.0:
            raise ValueError("pwe_cif requires bin_edges[0]==0.0")
        criterion = PiecewiseExponentialCIFNLLLoss(bin_edges=pwe_edges, lambda_reg=lambda_reg)
        n_bins = len(pwe_edges) - 1
        out_dims = [n_disease, n_bins]
        return criterion, out_dims
    raise ValueError(f"Unsupported loss_type: {loss_type}")


def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
    if model_type == "delphi_fork":
        return DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
        )
    if model_type == "sap_delphi":
        return SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
            pretrained_weights_path=str(cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
            freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
        )
    raise ValueError(f"Unsupported model_type: {model_type}")


def main() -> None:
    parser = argparse.ArgumentParser(
        description=(
            "Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and "
            "aggregated CSVs; aggregated includes agg_type={macro,weighted})"
        )
    )
    parser.add_argument(
        "--run_dir",
        type=str,
        required=True,
        help="Training run directory (contains best_model.pt and train_config.json)",
    )
    parser.add_argument("--data_prefix", type=str, default=None)
    parser.add_argument(
        "--split", type=str, choices=["train", "val", "test", "all"], default="val",
    )
    parser.add_argument(
        "--horizons", type=str, nargs="+",
        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
    )
    parser.add_argument(
        "--age_bin_edges",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
        help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).",
    )
    parser.add_argument(
        "--topk_percent",
        type=float,
        nargs="+",
        default=[1, 5, 10, 20, 50],
        help="One or more K%% values for recall/precision@K%%",
    )
    parser.add_argument("--n_mc", type=int, default=5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). Example: --gpus 0,1,3",
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--out_prefix", type=str, default=None, help="Output prefix for CSVs")
    args = parser.parse_args()
    ckpt_path = os.path.join(args.run_dir, "best_model.pt")
    cfg_path = os.path.join(args.run_dir, "train_config.json")
    if not os.path.exists(ckpt_path):
        raise SystemExit(f"Missing checkpoint: {ckpt_path}")
    if not os.path.exists(cfg_path):
        raise SystemExit(f"Missing config: {cfg_path}")
    with open(cfg_path, "r") as f:
        cfg = json.load(f)
    data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get("data_prefix", "ukb")
    full_cov = bool(cfg.get("full_cov", False))
    cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
    dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
    # Rebuild the train/val/test split exactly as in training (same ratios, same seed).
    train_ratio = float(cfg.get("train_ratio", 0.7))
    val_ratio = float(cfg.get("val_ratio", 0.15))
    seed_split = int(cfg.get("random_seed", 42))
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed_split),
    )
    if args.split == "train":
        ds = train_ds
    elif args.split == "val":
        ds = val_ds
    elif args.split == "test":
        ds = test_ds
    else:
        ds = dataset
    loader = DataLoader(
        ds,
        batch_size=int(args.batch_size),
        shuffle=False,
        collate_fn=health_collate_fn,
        num_workers=int(args.num_workers),
        pin_memory=str(args.device).startswith("cuda"),
    )
    criterion, out_dims = build_criterion_and_out_dims(
        loss_type=str(cfg["loss_type"]),
        n_disease=int(dataset.n_disease),
        bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
        lambda_reg=float(cfg.get("lambda_reg", 0.0)),
    )
    model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
    device = torch.device(args.device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    head.load_state_dict(checkpoint["head_state_dict"], strict=True)
    if "criterion_state_dict" in checkpoint:
        try:
            criterion.load_state_dict(checkpoint["criterion_state_dict"], strict=False)
        except Exception:
            pass
    model.to(device)
    head.to(device)
    criterion.to(device)
    age_edges = _parse_age_bin_edges(args.age_bin_edges)
    age_bins = _edges_to_bins(age_edges)
    eval_cfg = EvalAgeConfig(
        horizons_years=_parse_floats(args.horizons),
        age_bins=age_bins,
        topk_percents=[float(x) for x in args.topk_percent],
        n_mc=int(args.n_mc),
        seed=int(args.seed),
        cause_ids=None,
    )
    if args.out_prefix is None:
        out_prefix = os.path.join(args.run_dir, f"age_bin_time_dependent_{args.split}")
    else:
        out_prefix = args.out_prefix
    out_bin = out_prefix + "_by_bin.csv"
    out_agg = out_prefix + "_agg.csv"
    gpus = _parse_gpus(args.gpus)
    if gpus and not torch.cuda.is_available():
        raise SystemExit("--gpus was provided but CUDA is not available")
    if len(gpus) <= 1:
        if len(gpus) == 1:
            # Honor a single explicit --gpus id instead of the default --device.
            device = torch.device(f"cuda:{gpus[0]}")
            model.to(device)
            head.to(device)
            criterion.to(device)
        df_by_bin, df_agg = evaluate_time_dependent_age_bins(
            model=model,
            head=head,
            criterion=criterion,
            dataloader=loader,
            n_disease=int(dataset.n_disease),
            cfg=eval_cfg,
            device=device,
        )
        df_by_bin.to_csv(out_bin, index=False)
        df_agg.to_csv(out_agg, index=False)
        print(f"Wrote: {out_bin}")
        print(f"Wrote: {out_agg}")
        return
    # Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU).
    # Indices are dealt round-robin, so MC index i runs on gpus[i % len(gpus)].
    mc_indices_all = list(range(int(args.n_mc)))
    per_gpu: List[Tuple[int, List[int]]] = []
    for pos, gpu_id in enumerate(gpus):
        assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
        if assigned:
            per_gpu.append((int(gpu_id), assigned))
    ctx = mp.get_context("spawn")
    queue: "mp.Queue" = ctx.Queue()
    procs: List[mp.Process] = []
    tmp_paths: List[str] = []
    for gpu_id, mc_idxs in per_gpu:
        tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
        tmp_paths.append(tmp_path)
        p = ctx.Process(
            target=_worker_eval_mcs_on_gpu,
            kwargs=dict(
                queue=queue,
                run_dir=str(args.run_dir),
                split=str(args.split),
                data_prefix_override=(
                    str(args.data_prefix) if args.data_prefix is not None else None
                ),
                horizons=_parse_floats(args.horizons),
                age_bins=age_bins,
                topk_percents=[float(x) for x in args.topk_percent],
                seed=int(args.seed),
                batch_size=int(args.batch_size),
                num_workers=int(args.num_workers),
                gpu_id=int(gpu_id),
                mc_indices=mc_idxs,
                out_path=tmp_path,
            ),
        )
        p.start()
        procs.append(p)
    # Collect one status message per worker before joining.
    results = [queue.get() for _ in range(len(procs))]
    for p in procs:
        p.join()
    for r in results:
        if not r.get("ok", False):
            raise SystemExit(f"Worker failed: {r.get('error')}")
    frames = [pd.read_csv(path) for path in tmp_paths if os.path.exists(path)]
    df_by_bin = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    df_agg = aggregate_age_bin_results(df_by_bin)
    df_by_bin.to_csv(out_bin, index=False)
    df_agg.to_csv(out_agg, index=False)
    # Best-effort cleanup of the per-GPU temporary CSVs.
    for path in tmp_paths:
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception:
            pass
    print(f"Wrote: {out_bin}")
    print(f"Wrote: {out_agg}")
if __name__ == "__main__":
main()
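

# Example invocation (illustrative; the run directory and GPU ids are placeholders):
#   python evaluate_age.py --run_dir runs/exp1 --split test \
#       --horizons 1.0 5.0 10.0 --age_bin_edges 40 50 60 70 80 \
#       --topk_percent 5 10 --n_mc 4 --gpus 0,1
# With --gpus 0,1 and --n_mc 4, GPU 0 evaluates MC indices {0, 2} and GPU 1
# evaluates {1, 3}; the per-GPU CSVs are then concatenated and aggregated.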