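"""Delphi-2M-style age-bin, time-dependent evaluation of a trained run.

Loads best_model.pt and train_config.json from --run_dir, rebuilds the
dataset split and model from the saved config, and writes two CSVs:
<out_prefix>_by_bin.csv (per age bin) and <out_prefix>_agg.csv (aggregated,
with agg_type in {macro, weighted}). Monte Carlo replicates can be fanned
out across GPUs with --gpus.

Illustrative invocation (script name assumed):

    python evaluate_age_bins.py --run_dir runs/exp1 --split test --gpus 0,1
"""
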
from __future__ import annotations

import argparse
import json
import math
import multiprocessing as mp
import os
from typing import List, Sequence, Tuple

import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from evaluation_age_time_dependent import (
    EvalAgeConfig,
    aggregate_age_bin_results,
    evaluate_time_dependent_age_bins,
)
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead


def _parse_floats(items: Sequence[str]) -> List[float]:
    out: List[float] = []
    for x in items:
        x = x.strip()
        if not x:
            continue
        out.append(float(x))
    return out


def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
    vals = _parse_floats(edges)
    if len(vals) < 2:
        raise ValueError("--age_bin_edges must have at least 2 values")
    for i in range(1, len(vals)):
        if not (vals[i] > vals[i - 1]):
            raise ValueError("--age_bin_edges must be strictly increasing")
    return vals


def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]


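# Example: _parse_age_bin_edges(["40", "50", "60"]) -> [40.0, 50.0, 60.0], and
# _edges_to_bins([40.0, 50.0, 60.0]) -> [(40.0, 50.0), (50.0, 60.0)]. Each bin
# is half-open, [edge[i], edge[i+1]), per the --age_bin_edges help text.

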
def _parse_gpus(gpus: str | None) -> List[int]:
    if gpus is None:
        return []
    s = gpus.strip()
    if not s:
        return []
    return [int(p.strip()) for p in s.split(",") if p.strip()]


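# Example: _parse_gpus("0,1,3") -> [0, 1, 3]; _parse_gpus(None) and
# _parse_gpus("") both -> [] (single-process evaluation).

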
def _worker_eval_mcs_on_gpu(
    queue: "mp.Queue",
    *,
    run_dir: str,
    split: str,
    data_prefix_override: str | None,
    horizons: List[float],
    age_bins: List[Tuple[float, float]],
    topk_percents: List[float],
    n_mc: int,
    seed: int,
    batch_size: int,
    num_workers: int,
    gpu_id: int,
    mc_indices: List[int],
    out_path: str,
) -> None:
    """Worker process: evaluate a subset of MC indices on a single GPU."""
    try:
        ckpt_path = os.path.join(run_dir, "best_model.pt")
        cfg_path = os.path.join(run_dir, "train_config.json")
        with open(cfg_path, "r") as f:
            cfg = json.load(f)

        data_prefix = (
            data_prefix_override
            if data_prefix_override is not None
            else cfg.get("data_prefix", "ukb")
        )

        full_cov = bool(cfg.get("full_cov", False))
        cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)

        train_ratio = float(cfg.get("train_ratio", 0.7))
        val_ratio = float(cfg.get("val_ratio", 0.15))
        seed_split = int(cfg.get("random_seed", 42))

        n_total = len(dataset)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        n_test = n_total - n_train - n_val

        train_ds, val_ds, test_ds = random_split(
            dataset,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(seed_split),
        )

        if split == "train":
            ds = train_ds
        elif split == "val":
            ds = val_ds
        elif split == "test":
            ds = test_ds
        else:
            ds = dataset

        loader = DataLoader(
            ds,
            batch_size=int(batch_size),
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=int(num_workers),
            pin_memory=True,
        )

        criterion, out_dims = build_criterion_and_out_dims(
            loss_type=str(cfg["loss_type"]),
            n_disease=int(dataset.n_disease),
            bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        )

        model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
        head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)

        device = torch.device(f"cuda:{int(gpu_id)}")
        checkpoint = torch.load(ckpt_path, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"], strict=True)
        head.load_state_dict(checkpoint["head_state_dict"], strict=True)
        if "criterion_state_dict" in checkpoint:
            try:
                criterion.load_state_dict(checkpoint["criterion_state_dict"], strict=False)
            except Exception:
                pass

        model.to(device)
        head.to(device)
        criterion.to(device)

        frames: List[pd.DataFrame] = []
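        # Each iteration evaluates a single MC replicate (n_mc=1), passing
        # mc_offset=mc_idx so replicates differ across workers (assuming
        # evaluate_time_dependent_age_bins offsets its sampling by mc_offset);
        # the parent process then simply concatenates the per-GPU frames.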
        for mc_idx in mc_indices:
            eval_cfg = EvalAgeConfig(
                horizons_years=horizons,
                age_bins=age_bins,
                topk_percents=topk_percents,
                n_mc=1,
                seed=int(seed),
                cause_ids=None,
            )

            df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins(
                model=model,
                head=head,
                criterion=criterion,
                dataloader=loader,
                n_disease=int(dataset.n_disease),
                cfg=eval_cfg,
                device=device,
                mc_offset=int(mc_idx),
            )
            frames.append(df_by_bin)

        df_all = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        df_all.to_csv(out_path, index=False)
        queue.put({"ok": True, "out_path": out_path})
    except Exception as e:
        queue.put({"ok": False, "error": repr(e)})


def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
        out_dims = [n_disease]
        return criterion, out_dims

    if loss_type == "discrete_time_cif":
        criterion = DiscreteTimeCIFNLLLoss(bin_edges=bin_edges, lambda_reg=lambda_reg)
        out_dims = [n_disease + 1, len(bin_edges)]
        return criterion, out_dims

    if loss_type == "pwe_cif":
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        if len(pwe_edges) < 2:
            raise ValueError("pwe_cif requires at least 2 finite bin edges (including 0)")
        if float(pwe_edges[0]) != 0.0:
            raise ValueError("pwe_cif requires bin_edges[0] == 0.0")
        criterion = PiecewiseExponentialCIFNLLLoss(bin_edges=pwe_edges, lambda_reg=lambda_reg)
        n_bins = len(pwe_edges) - 1
        out_dims = [n_disease, n_bins]
        return criterion, out_dims

    raise ValueError(f"Unsupported loss_type: {loss_type}")


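# Worked example (pwe_cif): with bin_edges [0.0, 1.0, 5.0, inf], the finite
# edges are [0.0, 1.0, 5.0], giving n_bins = 2 and out_dims = [n_disease, 2];
# the infinite edge is dropped before constructing the piecewise loss.

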
def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
    if model_type == "delphi_fork":
        return DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
        )

    if model_type == "sap_delphi":
        return SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
            pretrained_weights_path=str(
                cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")
            ),
            freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
        )

    raise ValueError(f"Unsupported model_type: {model_type}")


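# A minimal train_config.json needs at least the keys read above without
# cfg.get(...) defaults, e.g. (values illustrative, not from this repo):
# {"model_type": "delphi_fork", "loss_type": "exponential", "n_embd": 128,
#  "n_head": 4, "n_layer": 4, "age_encoder": "sinusoidal"}.

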
def main() -> None:
    parser = argparse.ArgumentParser(
        description=(
            "Delphi-2M-style age-bin time-dependent evaluation (writes per-bin "
            "and aggregated CSVs; aggregated includes agg_type={macro,weighted})"
        )
    )
    parser.add_argument(
        "--run_dir",
        type=str,
        required=True,
        help="Training run directory (contains best_model.pt and train_config.json)",
    )
    parser.add_argument("--data_prefix", type=str, default=None)
    parser.add_argument("--split", type=str,
                        choices=["train", "val", "test", "all"], default="val")

    parser.add_argument("--horizons", type=str, nargs="+",
                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"])
    parser.add_argument(
        "--age_bin_edges",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
        help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).",
    )
    parser.add_argument(
        "--topk_percent",
        type=float,
        nargs="+",
        default=[1, 5, 10, 20, 50],
        help="One or more K%% values for recall/precision@K%%",
    )
    parser.add_argument("--n_mc", type=int, default=5)
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). Example: --gpus 0,1,3",
    )

    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=0)

    parser.add_argument("--out_prefix", type=str, default=None,
                        help="Output prefix for CSVs")

    args = parser.parse_args()

    ckpt_path = os.path.join(args.run_dir, "best_model.pt")
    cfg_path = os.path.join(args.run_dir, "train_config.json")
    if not os.path.exists(ckpt_path):
        raise SystemExit(f"Missing checkpoint: {ckpt_path}")
    if not os.path.exists(cfg_path):
        raise SystemExit(f"Missing config: {cfg_path}")

    with open(cfg_path, "r") as f:
        cfg = json.load(f)

    data_prefix = (
        args.data_prefix if args.data_prefix is not None else cfg.get("data_prefix", "ukb")
    )

    full_cov = bool(cfg.get("full_cov", False))
    cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
    dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)

    train_ratio = float(cfg.get("train_ratio", 0.7))
    val_ratio = float(cfg.get("val_ratio", 0.15))
    seed_split = int(cfg.get("random_seed", 42))

    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val

    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed_split),
    )

    if args.split == "train":
        ds = train_ds
    elif args.split == "val":
        ds = val_ds
    elif args.split == "test":
        ds = test_ds
    else:
        ds = dataset

    loader = DataLoader(
        ds,
        batch_size=int(args.batch_size),
        shuffle=False,
        collate_fn=health_collate_fn,
        num_workers=int(args.num_workers),
        pin_memory=str(args.device).startswith("cuda"),
    )

    criterion, out_dims = build_criterion_and_out_dims(
        loss_type=str(cfg["loss_type"]),
        n_disease=int(dataset.n_disease),
        bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
        lambda_reg=float(cfg.get("lambda_reg", 0.0)),
    )

    model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)

    device = torch.device(args.device)
    checkpoint = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    head.load_state_dict(checkpoint["head_state_dict"], strict=True)
    if "criterion_state_dict" in checkpoint:
        try:
            criterion.load_state_dict(checkpoint["criterion_state_dict"], strict=False)
        except Exception:
            pass

    model.to(device)
    head.to(device)
    criterion.to(device)

    age_edges = _parse_age_bin_edges(args.age_bin_edges)
    age_bins = _edges_to_bins(age_edges)

    eval_cfg = EvalAgeConfig(
        horizons_years=_parse_floats(args.horizons),
        age_bins=age_bins,
        topk_percents=[float(x) for x in args.topk_percent],
        n_mc=int(args.n_mc),
        seed=int(args.seed),
        cause_ids=None,
    )

    if args.out_prefix is None:
        out_prefix = os.path.join(args.run_dir, f"age_bin_time_dependent_{args.split}")
    else:
        out_prefix = args.out_prefix

    out_bin = out_prefix + "_by_bin.csv"
    out_agg = out_prefix + "_agg.csv"

    gpus = _parse_gpus(args.gpus)
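    # With zero or one GPU id requested, evaluate in-process on --device; note
    # that a single --gpus entry does not override --device on this path.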
    if len(gpus) <= 1:
        df_by_bin, df_agg = evaluate_time_dependent_age_bins(
            model=model,
            head=head,
            criterion=criterion,
            dataloader=loader,
            n_disease=int(dataset.n_disease),
            cfg=eval_cfg,
            device=device,
        )

        df_by_bin.to_csv(out_bin, index=False)
        df_agg.to_csv(out_agg, index=False)
        print(f"Wrote: {out_bin}")
        print(f"Wrote: {out_agg}")
        return

    if not torch.cuda.is_available():
        raise SystemExit("--gpus was provided but CUDA is not available")

    # Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU).
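    # Round-robin assignment, e.g. gpus=[0, 1] and n_mc=5 gives GPU 0 the MC
    # indices [0, 2, 4] and GPU 1 the indices [1, 3].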
    mc_indices_all = list(range(int(args.n_mc)))
    per_gpu: List[Tuple[int, List[int]]] = []
    for pos, gpu_id in enumerate(gpus):
        assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
        if assigned:
            per_gpu.append((int(gpu_id), assigned))

    ctx = mp.get_context("spawn")
    queue: "mp.Queue" = ctx.Queue()
    procs: List[mp.Process] = []
    tmp_paths: List[str] = []

    for gpu_id, mc_idxs in per_gpu:
        tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
        tmp_paths.append(tmp_path)
        p = ctx.Process(
            target=_worker_eval_mcs_on_gpu,
            kwargs=dict(
                queue=queue,
                run_dir=str(args.run_dir),
                split=str(args.split),
                data_prefix_override=(
                    str(args.data_prefix) if args.data_prefix is not None else None
                ),
                horizons=_parse_floats(args.horizons),
                age_bins=age_bins,
                topk_percents=[float(x) for x in args.topk_percent],
                n_mc=int(args.n_mc),
                seed=int(args.seed),
                batch_size=int(args.batch_size),
                num_workers=int(args.num_workers),
                gpu_id=int(gpu_id),
                mc_indices=mc_idxs,
                out_path=tmp_path,
            ),
        )
        p.start()
        procs.append(p)

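    # Each worker puts exactly one status dict on the queue; collect one per
    # process (arrival order is arbitrary) before joining.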
    results = [queue.get() for _ in range(len(procs))]
    for p in procs:
        p.join()

    for r in results:
        if not r.get("ok", False):
            raise SystemExit(f"Worker failed: {r.get('error')}")

    frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
    df_by_bin = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    df_agg = aggregate_age_bin_results(df_by_bin)

    df_by_bin.to_csv(out_bin, index=False)
    df_agg.to_csv(out_agg, index=False)

    # Best-effort cleanup.
    for p in tmp_paths:
        try:
            if os.path.exists(p):
                os.remove(p)
        except Exception:
            pass

    print(f"Wrote: {out_bin}")
    print(f"Wrote: {out_agg}")


if __name__ == "__main__":
    main()