Add multi-GPU support for age-bin evaluation and refactor aggregation logic

2026-01-16 16:27:02 +08:00
parent 7a1210b5b0
commit e47a7ce4d6
2 changed files with 393 additions and 124 deletions
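A hypothetical invocation exercising the new flag (script name and run directory are illustrative, not part of this diff; the flags are the ones added or read below):

    python evaluate_age_bins.py --run_dir runs/exp01 --split test --n_mc 6 --gpus 0,1,2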

View File

@@ -4,13 +4,19 @@ import argparse
import json
import math
import os
import multiprocessing as mp
from typing import List, Sequence, Tuple

import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from evaluation_age_time_dependent import (
    EvalAgeConfig,
    aggregate_age_bin_results,
    evaluate_time_dependent_age_bins,
)
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead
@@ -39,6 +45,144 @@ def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]


def _parse_gpus(gpus: str | None) -> List[int]:
    if gpus is None:
        return []
    s = gpus.strip()
    if not s:
        return []
    parts = [p.strip() for p in s.split(",") if p.strip()]
    out: List[int] = []
    for p in parts:
        out.append(int(p))
    return out
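# Example (illustrative): _parse_gpus("0, 1,3") -> [0, 1, 3];
# _parse_gpus(None) and _parse_gpus("") both return [].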
def _worker_eval_mcs_on_gpu(
    queue: "mp.Queue",
    *,
    run_dir: str,
    split: str,
    data_prefix_override: str | None,
    horizons: List[float],
    age_bins: List[Tuple[float, float]],
    topk_percents: List[float],
    n_mc: int,
    seed: int,
    batch_size: int,
    num_workers: int,
    gpu_id: int,
    mc_indices: List[int],
    out_path: str,
) -> None:
    """Worker process: evaluate a subset of MC indices on a single GPU."""
    try:
        ckpt_path = os.path.join(run_dir, "best_model.pt")
        cfg_path = os.path.join(run_dir, "train_config.json")
        with open(cfg_path, "r") as f:
            cfg = json.load(f)
        data_prefix = (
            data_prefix_override
            if data_prefix_override is not None
            else cfg.get("data_prefix", "ukb")
        )
        full_cov = bool(cfg.get("full_cov", False))
        cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
        train_ratio = float(cfg.get("train_ratio", 0.7))
        val_ratio = float(cfg.get("val_ratio", 0.15))
        seed_split = int(cfg.get("random_seed", 42))
        n_total = len(dataset)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        n_test = n_total - n_train - n_val
        train_ds, val_ds, test_ds = random_split(
            dataset,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(seed_split),
        )
        if split == "train":
            ds = train_ds
        elif split == "val":
            ds = val_ds
        elif split == "test":
            ds = test_ds
        else:
            ds = dataset
        loader = DataLoader(
            ds,
            batch_size=int(batch_size),
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=int(num_workers),
            pin_memory=True,
        )
        criterion, out_dims = build_criterion_and_out_dims(
            loss_type=str(cfg["loss_type"]),
            n_disease=int(dataset.n_disease),
            bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        )
        model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
        head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
        device = torch.device(f"cuda:{int(gpu_id)}")
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"], strict=True)
        head.load_state_dict(checkpoint["head_state_dict"], strict=True)
        if "criterion_state_dict" in checkpoint:
            try:
                criterion.load_state_dict(checkpoint["criterion_state_dict"], strict=False)
            except Exception:
                pass
        model.to(device)
        head.to(device)
        criterion.to(device)
        frames: List[pd.DataFrame] = []
        for mc_idx in mc_indices:
            eval_cfg = EvalAgeConfig(
                horizons_years=horizons,
                age_bins=age_bins,
                topk_percents=topk_percents,
                n_mc=1,
                seed=int(seed),
                cause_ids=None,
            )
            df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins(
                model=model,
                head=head,
                criterion=criterion,
                dataloader=loader,
                n_disease=int(dataset.n_disease),
                cfg=eval_cfg,
                device=device,
                mc_offset=int(mc_idx),
            )
            frames.append(df_by_bin)
        df_all = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        df_all.to_csv(out_path, index=False)
        queue.put({"ok": True, "out_path": out_path})
    except Exception as e:
        queue.put({"ok": False, "error": repr(e)})
def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
@@ -134,6 +278,13 @@ def main() -> None:
parser.add_argument("--n_mc", type=int, default=5) parser.add_argument("--n_mc", type=int, default=5)
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--gpus",
type=str,
default=None,
help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). Example: --gpus 0,1,3",
)
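    # Parsed by _parse_gpus; with multiple ids, MC indices are assigned
    # round-robin across the listed devices (see the multi-GPU path below).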
parser.add_argument("--device", type=str, parser.add_argument("--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu") default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--batch_size", type=int, default=256)
@@ -232,16 +383,6 @@ def main() -> None:
        cause_ids=None,
    )
    if args.out_prefix is None:
        out_prefix = os.path.join(
            args.run_dir, f"age_bin_time_dependent_{args.split}")
@@ -251,9 +392,91 @@ def main() -> None:
    out_bin = out_prefix + "_by_bin.csv"
    out_agg = out_prefix + "_agg.csv"

    gpus = _parse_gpus(args.gpus)
    if len(gpus) <= 1:
        df_by_bin, df_agg = evaluate_time_dependent_age_bins(
            model=model,
            head=head,
            criterion=criterion,
            dataloader=loader,
            n_disease=int(dataset.n_disease),
            cfg=eval_cfg,
            device=device,
        )
        df_by_bin.to_csv(out_bin, index=False)
        df_agg.to_csv(out_agg, index=False)
        print(f"Wrote: {out_bin}")
        print(f"Wrote: {out_agg}")
        return

    if not torch.cuda.is_available():
        raise SystemExit("--gpus was provided but CUDA is not available")

    # Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU).
    mc_indices_all = list(range(int(args.n_mc)))
    per_gpu: List[Tuple[int, List[int]]] = []
    for pos, gpu_id in enumerate(gpus):
        assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
        if assigned:
            per_gpu.append((int(gpu_id), assigned))
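    # Illustration: with n_mc=5 and --gpus 0,1, GPU 0 gets MC indices [0, 2, 4]
    # and GPU 1 gets [1, 3].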
    ctx = mp.get_context("spawn")
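    # "spawn" rather than "fork": CUDA cannot be safely re-initialized in a
    # forked child process, so each worker starts from a fresh interpreter.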
queue: "mp.Queue" = ctx.Queue()
procs: List[mp.Process] = []
tmp_paths: List[str] = []
for gpu_id, mc_idxs in per_gpu:
tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
tmp_paths.append(tmp_path)
p = ctx.Process(
target=_worker_eval_mcs_on_gpu,
kwargs=dict(
queue=queue,
run_dir=str(args.run_dir),
split=str(args.split),
data_prefix_override=(
str(args.data_prefix) if args.data_prefix is not None else None
),
horizons=_parse_floats(args.horizons),
age_bins=age_bins,
topk_percents=[float(x) for x in args.topk_percent],
n_mc=int(args.n_mc),
seed=int(args.seed),
batch_size=int(args.batch_size),
num_workers=int(args.num_workers),
gpu_id=int(gpu_id),
mc_indices=mc_idxs,
out_path=tmp_path,
),
)
p.start()
procs.append(p)
results = [queue.get() for _ in range(len(procs))]
for p in procs:
p.join()
for r in results:
if not r.get("ok", False):
raise SystemExit(f"Worker failed: {r.get('error')}")
frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
df_by_bin = pd.concat(frames, ignore_index=True) if len(
frames) else pd.DataFrame()
df_agg = aggregate_age_bin_results(df_by_bin)
    df_by_bin.to_csv(out_bin, index=False)
    df_agg.to_csv(out_agg, index=False)

    # Best-effort cleanup.
    for p in tmp_paths:
        try:
            if os.path.exists(p):
                os.remove(p)
        except Exception:
            pass
    print(f"Wrote: {out_bin}")
    print(f"Wrote: {out_agg}")

View File

@@ -18,6 +18,158 @@ from utils import (
)


def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-bin age evaluation results.

    Produces both:
      - macro: unweighted mean over bins with n_samples > 0
      - weighted: mean over bins weighted by n_samples
    Then aggregates across MC repetitions (mean/std).

    Requires df_by_bin to include:
      mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
      n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score

    Returns:
        A DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id).
    """
    if df_by_bin is None or len(df_by_bin) == 0:
        return pd.DataFrame(
            columns=[
                "agg_type",
                "horizon_tau",
                "topk_percent",
                "cause_id",
                "n_mc",
                "n_bins_used_mean",
                "n_samples_total_mean",
                "n_positives_total_mean",
                "auc_mean",
                "auc_std",
                "auprc_mean",
                "auprc_std",
                "recall_at_K_mean",
                "recall_at_K_std",
                "precision_at_K_mean",
                "precision_at_K_std",
                "brier_score_mean",
                "brier_score_std",
            ]
        )

    def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
        g = group[group["n_samples"] > 0]
        if len(g) == 0:
            return pd.Series(
                dict(
                    n_bins_used=0,
                    n_samples_total=0,
                    n_positives_total=0,
                    auc=float("nan"),
                    auprc=float("nan"),
                    recall_at_K=float("nan"),
                    precision_at_K=float("nan"),
                    brier_score=float("nan"),
                )
            )
        n_bins_used = int(g["age_bin_id"].nunique())
        n_samples_total = int(g["n_samples"].sum())
        n_positives_total = int(g["n_positives"].sum())
        if not weighted:
            return pd.Series(
                dict(
                    n_bins_used=n_bins_used,
                    n_samples_total=n_samples_total,
                    n_positives_total=n_positives_total,
                    auc=float(g["auc"].mean()),
                    auprc=float(g["auprc"].mean()),
                    recall_at_K=float(g["recall_at_K"].mean()),
                    precision_at_K=float(g["precision_at_K"].mean()),
                    brier_score=float(g["brier_score"].mean()),
                )
            )
        w = g["n_samples"].to_numpy(dtype=float)
        w_sum = float(w.sum())
        if w_sum <= 0.0:
            return pd.Series(
                dict(
                    n_bins_used=n_bins_used,
                    n_samples_total=n_samples_total,
                    n_positives_total=n_positives_total,
                    auc=float("nan"),
                    auprc=float("nan"),
                    recall_at_K=float("nan"),
                    precision_at_K=float("nan"),
                    brier_score=float("nan"),
                )
            )

        def _wavg(col: str) -> float:
            return float(np.average(g[col].to_numpy(dtype=float), weights=w))

        return pd.Series(
            dict(
                n_bins_used=n_bins_used,
                n_samples_total=n_samples_total,
                n_positives_total=n_positives_total,
                auc=_wavg("auc"),
                auprc=_wavg("auprc"),
                recall_at_K=_wavg("recall_at_K"),
                precision_at_K=_wavg("precision_at_K"),
                brier_score=_wavg("brier_score"),
            )
        )

    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
    df_mc_macro = (
        df_by_bin.groupby(group_keys)
        .apply(lambda g: _bin_aggregate(g, weighted=False))
        .reset_index()
    )
    df_mc_macro["agg_type"] = "macro"
    df_mc_weighted = (
        df_by_bin.groupby(group_keys)
        .apply(lambda g: _bin_aggregate(g, weighted=True))
        .reset_index()
    )
    df_mc_weighted["agg_type"] = "weighted"
    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)

    # Then average over MC repetitions (mean/std).
    df_agg = (
        df_mc_binagg.groupby(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
        )
        .agg(
            n_mc=("mc_idx", "nunique"),
            n_bins_used_mean=("n_bins_used", "mean"),
            n_samples_total_mean=("n_samples_total", "mean"),
            n_positives_total_mean=("n_positives_total", "mean"),
            auc_mean=("auc", "mean"),
            auc_std=("auc", "std"),
            auprc_mean=("auprc", "mean"),
            auprc_std=("auprc", "std"),
            recall_at_K_mean=("recall_at_K", "mean"),
            recall_at_K_std=("recall_at_K", "std"),
            precision_at_K_mean=("precision_at_K", "mean"),
            precision_at_K_std=("precision_at_K", "std"),
            brier_score_mean=("brier_score", "mean"),
            brier_score_std=("brier_score", "std"),
        )
        .sort_values(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
            ignore_index=True,
        )
    )
    return df_agg
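# Hypothetical usage when merging per-GPU shards (mirrors the dispatcher in the
# evaluation script; `shard_paths` is illustrative):
#     df_by_bin = pd.concat([pd.read_csv(p) for p in shard_paths], ignore_index=True)
#     df_agg = aggregate_age_bin_results(df_by_bin)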
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.
@@ -138,6 +290,7 @@ def evaluate_time_dependent_age_bins(
    n_disease: int,
    cfg: EvalAgeConfig,
    device: str | torch.device,
    mc_offset: int = 0,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Delphi-2M-style age-bin evaluation with strict horizon alignment.
@@ -196,6 +349,7 @@ def evaluate_time_dependent_age_bins(
    ]

    for mc_idx in range(int(cfg.n_mc)):
        global_mc_idx = int(mc_offset) + int(mc_idx)
        # tqdm over batches; include MC idx in description.
        for batch_idx, batch in enumerate(
            tqdm(dataloader,
@@ -216,7 +370,7 @@ def evaluate_time_dependent_age_bins(
            # Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
            seed = (
                int(cfg.seed)
                + (100_000 * int(global_mc_idx))
                + (1_000 * int(tau_idx))
                + (10 * int(bin_idx))
                + int(batch_idx)
@@ -277,6 +431,7 @@ def evaluate_time_dependent_age_bins(
    rows_by_bin: List[Dict[str, float | int]] = []
    for mc_idx in range(int(cfg.n_mc)):
        global_mc_idx = int(mc_offset) + int(mc_idx)
        for h_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
@@ -287,7 +442,7 @@ def evaluate_time_dependent_age_bins(
                for k_percent in topk_percents:
                    rows_by_bin.append(
                        dict(
                            mc_idx=global_mc_idx,
                            age_bin_id=bin_idx,
                            age_bin_low=float(a_lo),
                            age_bin_high=float(a_hi),
@@ -332,7 +487,7 @@ def evaluate_time_dependent_age_bins(
                        yk, pk, float(k_percent))
                    rows_by_bin.append(
                        dict(
                            mc_idx=global_mc_idx,
                            age_bin_id=bin_idx,
                            age_bin_low=float(a_lo),
                            age_bin_high=float(a_hi),
@@ -351,115 +506,6 @@ def evaluate_time_dependent_age_bins(
    df_by_bin = pd.DataFrame(rows_by_bin)
    df_agg = aggregate_age_bin_results(df_by_bin)
g = group[group["n_samples"] > 0]
if len(g) == 0:
return pd.Series(
dict(
n_bins_used=0,
n_samples_total=0,
n_positives_total=0,
auc=float("nan"),
auprc=float("nan"),
recall_at_K=float("nan"),
precision_at_K=float("nan"),
brier_score=float("nan"),
)
)
n_bins_used = int(g["age_bin_id"].nunique())
n_samples_total = int(g["n_samples"].sum())
n_positives_total = int(g["n_positives"].sum())
if not weighted:
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=float(g["auc"].mean()),
auprc=float(g["auprc"].mean()),
recall_at_K=float(g["recall_at_K"].mean()),
precision_at_K=float(g["precision_at_K"].mean()),
brier_score=float(g["brier_score"].mean()),
)
)
w = g["n_samples"].to_numpy(dtype=float)
w_sum = float(w.sum())
if w_sum <= 0.0:
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=float("nan"),
auprc=float("nan"),
recall_at_K=float("nan"),
precision_at_K=float("nan"),
brier_score=float("nan"),
)
)
def _wavg(col: str) -> float:
return float(np.average(g[col].to_numpy(dtype=float), weights=w))
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=_wavg("auc"),
auprc=_wavg("auprc"),
recall_at_K=_wavg("recall_at_K"),
precision_at_K=_wavg("precision_at_K"),
brier_score=_wavg("brier_score"),
)
)
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
df_mc_macro = (
df_by_bin.groupby(group_keys)
.apply(lambda g: _bin_aggregate(g, weighted=False))
.reset_index()
)
df_mc_macro["agg_type"] = "macro"
df_mc_weighted = (
df_by_bin.groupby(group_keys)
.apply(lambda g: _bin_aggregate(g, weighted=True))
.reset_index()
)
df_mc_weighted["agg_type"] = "weighted"
df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
# Then average over MC repetitions.
df_agg = (
df_mc_binagg.groupby(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
)
.agg(
n_mc=("mc_idx", "nunique"),
n_bins_used_mean=("n_bins_used", "mean"),
n_samples_total_mean=("n_samples_total", "mean"),
n_positives_total_mean=("n_positives_total", "mean"),
auc_mean=("auc", "mean"),
auc_std=("auc", "std"),
auprc_mean=("auprc", "mean"),
auprc_std=("auprc", "std"),
recall_at_K_mean=("recall_at_K", "mean"),
recall_at_K_std=("recall_at_K", "std"),
precision_at_K_mean=("precision_at_K", "mean"),
precision_at_K_std=("precision_at_K", "std"),
brier_score_mean=("brier_score", "mean"),
brier_score_std=("brier_score", "std"),
)
.sort_values(
["agg_type", "horizon_tau", "topk_percent", "cause_id"],
ignore_index=True,
)
)
    return df_by_bin, df_agg