Remove evaluation_age_time_dependent.py and utils.py files
- Deleted evaluate_age.py, the command-line entry point that loaded a trained run and wrote per-bin and aggregated age-bin evaluation CSVs.
- Deleted evaluation_age_time_dependent.py, which contained the age-bin, time-dependent evaluation routines, including the statistical calculations and result aggregation.
- Deleted utils.py, which provided utility functions for sampling a context token within fixed age bins and for multi-hot encoding of disease occurrences within a horizon.
evaluate_age.py (507 lines deleted)
@@ -1,507 +0,0 @@
from __future__ import annotations

import argparse
import json
import math
import os
import multiprocessing as mp
from typing import List, Sequence, Tuple

import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from evaluation_age_time_dependent import (
    EvalAgeConfig,
    aggregate_age_bin_results,
    evaluate_time_dependent_age_bins,
)
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead


def _parse_floats(items: Sequence[str]) -> List[float]:
    out: List[float] = []
    for x in items:
        x = x.strip()
        if not x:
            continue
        out.append(float(x))
    return out


def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
    vals = _parse_floats(edges)
    if len(vals) < 2:
        raise ValueError("--age_bin_edges must have at least 2 values")
    for i in range(1, len(vals)):
        if not (vals[i] > vals[i - 1]):
            raise ValueError("--age_bin_edges must be strictly increasing")
    return vals


def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]


def _parse_gpus(gpus: str | None) -> List[int]:
    if gpus is None:
        return []
    s = gpus.strip()
    if not s:
        return []
    parts = [p.strip() for p in s.split(",") if p.strip()]
    out: List[int] = []
    for p in parts:
        out.append(int(p))
    return out


def _worker_eval_mcs_on_gpu(
|
|
||||||
queue: "mp.Queue",
|
|
||||||
*,
|
|
||||||
run_dir: str,
|
|
||||||
split: str,
|
|
||||||
data_prefix_override: str | None,
|
|
||||||
horizons: List[float],
|
|
||||||
age_bins: List[Tuple[float, float]],
|
|
||||||
topk_percents: List[float],
|
|
||||||
n_mc: int,
|
|
||||||
seed: int,
|
|
||||||
batch_size: int,
|
|
||||||
num_workers: int,
|
|
||||||
gpu_id: int,
|
|
||||||
mc_indices: List[int],
|
|
||||||
out_path: str,
|
|
||||||
) -> None:
|
|
||||||
"""Worker process: evaluate a subset of MC indices on a single GPU."""
|
|
||||||
try:
|
|
||||||
ckpt_path = os.path.join(run_dir, "best_model.pt")
|
|
||||||
cfg_path = os.path.join(run_dir, "train_config.json")
|
|
||||||
with open(cfg_path, "r") as f:
|
|
||||||
cfg = json.load(f)
|
|
||||||
|
|
||||||
data_prefix = (
|
|
||||||
data_prefix_override
|
|
||||||
if data_prefix_override is not None
|
|
||||||
else cfg.get("data_prefix", "ukb")
|
|
||||||
)
|
|
||||||
|
|
||||||
full_cov = bool(cfg.get("full_cov", False))
|
|
||||||
cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
|
|
||||||
dataset = HealthDataset(data_prefix=data_prefix,
|
|
||||||
covariate_list=cov_list)
|
|
||||||
|
|
||||||
train_ratio = float(cfg.get("train_ratio", 0.7))
|
|
||||||
val_ratio = float(cfg.get("val_ratio", 0.15))
|
|
||||||
seed_split = int(cfg.get("random_seed", 42))
|
|
||||||
|
|
||||||
n_total = len(dataset)
|
|
||||||
n_train = int(n_total * train_ratio)
|
|
||||||
n_val = int(n_total * val_ratio)
|
|
||||||
n_test = n_total - n_train - n_val
|
|
||||||
|
|
||||||
train_ds, val_ds, test_ds = random_split(
|
|
||||||
dataset,
|
|
||||||
[n_train, n_val, n_test],
|
|
||||||
generator=torch.Generator().manual_seed(seed_split),
|
|
||||||
)
|
|
||||||
|
|
||||||
if split == "train":
|
|
||||||
ds = train_ds
|
|
||||||
elif split == "val":
|
|
||||||
ds = val_ds
|
|
||||||
elif split == "test":
|
|
||||||
ds = test_ds
|
|
||||||
else:
|
|
||||||
ds = dataset
|
|
||||||
|
|
||||||
loader = DataLoader(
|
|
||||||
ds,
|
|
||||||
batch_size=int(batch_size),
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=health_collate_fn,
|
|
||||||
num_workers=int(num_workers),
|
|
||||||
pin_memory=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
criterion, out_dims = build_criterion_and_out_dims(
|
|
||||||
loss_type=str(cfg["loss_type"]),
|
|
||||||
n_disease=int(dataset.n_disease),
|
|
||||||
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
|
|
||||||
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
|
|
||||||
)
|
|
||||||
|
|
||||||
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
|
|
||||||
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
|
|
||||||
|
|
||||||
device = torch.device(f"cuda:{int(gpu_id)}")
|
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
|
||||||
|
|
||||||
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
|
|
||||||
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
|
|
||||||
if "criterion_state_dict" in checkpoint:
|
|
||||||
try:
|
|
||||||
criterion.load_state_dict(
|
|
||||||
checkpoint["criterion_state_dict"], strict=False)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
head.to(device)
|
|
||||||
criterion.to(device)
|
|
||||||
|
|
||||||
frames: List[pd.DataFrame] = []
|
|
||||||
for mc_idx in mc_indices:
|
|
||||||
eval_cfg = EvalAgeConfig(
|
|
||||||
horizons_years=horizons,
|
|
||||||
age_bins=age_bins,
|
|
||||||
topk_percents=topk_percents,
|
|
||||||
n_mc=1,
|
|
||||||
seed=int(seed),
|
|
||||||
cause_ids=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins(
|
|
||||||
model=model,
|
|
||||||
head=head,
|
|
||||||
criterion=criterion,
|
|
||||||
dataloader=loader,
|
|
||||||
n_disease=int(dataset.n_disease),
|
|
||||||
cfg=eval_cfg,
|
|
||||||
device=device,
|
|
||||||
mc_offset=int(mc_idx),
|
|
||||||
)
|
|
||||||
frames.append(df_by_bin)
|
|
||||||
|
|
||||||
df_all = pd.concat(frames, ignore_index=True) if len(
|
|
||||||
frames) else pd.DataFrame()
|
|
||||||
df_all = _drop_zero_positives_rows(df_all, "n_positives")
|
|
||||||
df_all.to_csv(out_path, index=False)
|
|
||||||
queue.put({"ok": True, "out_path": out_path})
|
|
||||||
except Exception as e:
|
|
||||||
queue.put({"ok": False, "error": repr(e)})
|
|
||||||
|
|
||||||
|
|
||||||
def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
        out_dims = [n_disease]
        return criterion, out_dims

    if loss_type == "discrete_time_cif":
        criterion = DiscreteTimeCIFNLLLoss(
            bin_edges=bin_edges, lambda_reg=lambda_reg)
        out_dims = [n_disease + 1, len(bin_edges)]
        return criterion, out_dims

    if loss_type == "pwe_cif":
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        if len(pwe_edges) < 2:
            raise ValueError(
                "pwe_cif requires at least 2 finite bin edges (including 0)")
        if float(pwe_edges[0]) != 0.0:
            raise ValueError("pwe_cif requires bin_edges[0]==0.0")
        criterion = PiecewiseExponentialCIFNLLLoss(
            bin_edges=pwe_edges, lambda_reg=lambda_reg)
        n_bins = len(pwe_edges) - 1
        out_dims = [n_disease, n_bins]
        return criterion, out_dims

    raise ValueError(f"Unsupported loss_type: {loss_type}")


def _drop_zero_positives_rows(df: pd.DataFrame, positive_col: str) -> pd.DataFrame:
    """Drop rows where the provided positives column is <= 0.

    Intended to reduce CSV size by omitting (cause, horizon, bin) rows that have
    no positives, which otherwise yield undefined/NaN metrics.
    """
    if df is None or len(df) == 0:
        return df
    if positive_col not in df.columns:
        return df
    pos = pd.to_numeric(df[positive_col], errors="coerce")
    return df[pos > 0].copy()


def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
    if model_type == "delphi_fork":
        return DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
        )

    if model_type == "sap_delphi":
        return SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
            pretrained_weights_path=str(
                cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
            freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
        )

    raise ValueError(f"Unsupported model_type: {model_type}")


def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and aggregated CSVs; aggregated includes agg_type={macro,weighted})")
|
|
||||||
parser.add_argument(
|
|
||||||
"--run_dir",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Training run directory (contains best_model.pt and train_config.json)",
|
|
||||||
)
|
|
||||||
parser.add_argument("--data_prefix", type=str, default=None)
|
|
||||||
parser.add_argument("--split", type=str,
|
|
||||||
choices=["train", "val", "test", "all"], default="val")
|
|
||||||
|
|
||||||
parser.add_argument("--horizons", type=str, nargs="+",
|
|
||||||
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"])
|
|
||||||
parser.add_argument(
|
|
||||||
"--age_bin_edges",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
|
|
||||||
help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--topk_percent",
|
|
||||||
type=float,
|
|
||||||
nargs="+",
|
|
||||||
default=[1, 5, 10, 20, 50],
|
|
||||||
help="One or more K%% values for recall/precision@K%%",
|
|
||||||
)
|
|
||||||
parser.add_argument("--n_mc", type=int, default=5)
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--gpus",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). Example: --gpus 0,1,3",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--device", type=str,
|
|
||||||
default="cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
parser.add_argument("--batch_size", type=int, default=256)
|
|
||||||
parser.add_argument("--num_workers", type=int, default=0)
|
|
||||||
|
|
||||||
parser.add_argument("--out_prefix", type=str,
|
|
||||||
default=None, help="Output prefix for CSVs")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ckpt_path = os.path.join(args.run_dir, "best_model.pt")
|
|
||||||
cfg_path = os.path.join(args.run_dir, "train_config.json")
|
|
||||||
if not os.path.exists(ckpt_path):
|
|
||||||
raise SystemExit(f"Missing checkpoint: {ckpt_path}")
|
|
||||||
if not os.path.exists(cfg_path):
|
|
||||||
raise SystemExit(f"Missing config: {cfg_path}")
|
|
||||||
|
|
||||||
with open(cfg_path, "r") as f:
|
|
||||||
cfg = json.load(f)
|
|
||||||
|
|
||||||
data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
|
|
||||||
"data_prefix", "ukb")
|
|
||||||
|
|
||||||
full_cov = bool(cfg.get("full_cov", False))
|
|
||||||
cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
|
|
||||||
dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
|
|
||||||
|
|
||||||
train_ratio = float(cfg.get("train_ratio", 0.7))
|
|
||||||
val_ratio = float(cfg.get("val_ratio", 0.15))
|
|
||||||
seed_split = int(cfg.get("random_seed", 42))
|
|
||||||
|
|
||||||
n_total = len(dataset)
|
|
||||||
n_train = int(n_total * train_ratio)
|
|
||||||
n_val = int(n_total * val_ratio)
|
|
||||||
n_test = n_total - n_train - n_val
|
|
||||||
|
|
||||||
train_ds, val_ds, test_ds = random_split(
|
|
||||||
dataset,
|
|
||||||
[n_train, n_val, n_test],
|
|
||||||
generator=torch.Generator().manual_seed(seed_split),
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.split == "train":
|
|
||||||
ds = train_ds
|
|
||||||
elif args.split == "val":
|
|
||||||
ds = val_ds
|
|
||||||
elif args.split == "test":
|
|
||||||
ds = test_ds
|
|
||||||
else:
|
|
||||||
ds = dataset
|
|
||||||
|
|
||||||
loader = DataLoader(
|
|
||||||
ds,
|
|
||||||
batch_size=int(args.batch_size),
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=health_collate_fn,
|
|
||||||
num_workers=int(args.num_workers),
|
|
||||||
pin_memory=str(args.device).startswith("cuda"),
|
|
||||||
)
|
|
||||||
|
|
||||||
criterion, out_dims = build_criterion_and_out_dims(
|
|
||||||
loss_type=str(cfg["loss_type"]),
|
|
||||||
n_disease=int(dataset.n_disease),
|
|
||||||
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
|
|
||||||
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
|
|
||||||
)
|
|
||||||
|
|
||||||
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
|
|
||||||
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
|
|
||||||
|
|
||||||
device = torch.device(args.device)
|
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
|
||||||
|
|
||||||
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
|
|
||||||
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
|
|
||||||
if "criterion_state_dict" in checkpoint:
|
|
||||||
try:
|
|
||||||
criterion.load_state_dict(
|
|
||||||
checkpoint["criterion_state_dict"], strict=False)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
head.to(device)
|
|
||||||
criterion.to(device)
|
|
||||||
|
|
||||||
age_edges = _parse_age_bin_edges(args.age_bin_edges)
|
|
||||||
age_bins = _edges_to_bins(age_edges)
|
|
||||||
|
|
||||||
eval_cfg = EvalAgeConfig(
|
|
||||||
horizons_years=_parse_floats(args.horizons),
|
|
||||||
age_bins=age_bins,
|
|
||||||
topk_percents=[float(x) for x in args.topk_percent],
|
|
||||||
n_mc=int(args.n_mc),
|
|
||||||
seed=int(args.seed),
|
|
||||||
cause_ids=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.out_prefix is None:
|
|
||||||
out_prefix = os.path.join(
|
|
||||||
args.run_dir, f"age_bin_time_dependent_{args.split}")
|
|
||||||
else:
|
|
||||||
out_prefix = args.out_prefix
|
|
||||||
|
|
||||||
out_bin = out_prefix + "_by_bin.csv"
|
|
||||||
out_agg = out_prefix + "_agg.csv"
|
|
||||||
|
|
||||||
gpus = _parse_gpus(args.gpus)
|
|
||||||
if len(gpus) <= 1:
|
|
||||||
df_by_bin, df_agg = evaluate_time_dependent_age_bins(
|
|
||||||
model=model,
|
|
||||||
head=head,
|
|
||||||
criterion=criterion,
|
|
||||||
dataloader=loader,
|
|
||||||
n_disease=int(dataset.n_disease),
|
|
||||||
cfg=eval_cfg,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
df_by_bin_csv = _drop_zero_positives_rows(df_by_bin, "n_positives")
|
|
||||||
df_agg_csv = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
|
|
||||||
df_by_bin_csv.to_csv(out_bin, index=False)
|
|
||||||
df_agg_csv.to_csv(out_agg, index=False)
|
|
||||||
print(f"Wrote: {out_bin}")
|
|
||||||
print(f"Wrote: {out_agg}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
raise SystemExit("--gpus was provided but CUDA is not available")
|
|
||||||
|
|
||||||
# Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU).
|
|
||||||
mc_indices_all = list(range(int(args.n_mc)))
|
|
||||||
per_gpu: List[Tuple[int, List[int]]] = []
|
|
||||||
for pos, gpu_id in enumerate(gpus):
|
|
||||||
assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
|
|
||||||
if assigned:
|
|
||||||
per_gpu.append((int(gpu_id), assigned))
|
|
||||||
|
|
||||||
ctx = mp.get_context("spawn")
|
|
||||||
queue: "mp.Queue" = ctx.Queue()
|
|
||||||
procs: List[mp.Process] = []
|
|
||||||
tmp_paths: List[str] = []
|
|
||||||
|
|
||||||
for gpu_id, mc_idxs in per_gpu:
|
|
||||||
tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
|
|
||||||
tmp_paths.append(tmp_path)
|
|
||||||
p = ctx.Process(
|
|
||||||
target=_worker_eval_mcs_on_gpu,
|
|
||||||
kwargs=dict(
|
|
||||||
queue=queue,
|
|
||||||
run_dir=str(args.run_dir),
|
|
||||||
split=str(args.split),
|
|
||||||
data_prefix_override=(
|
|
||||||
str(args.data_prefix) if args.data_prefix is not None else None
|
|
||||||
),
|
|
||||||
horizons=_parse_floats(args.horizons),
|
|
||||||
age_bins=age_bins,
|
|
||||||
topk_percents=[float(x) for x in args.topk_percent],
|
|
||||||
n_mc=int(args.n_mc),
|
|
||||||
seed=int(args.seed),
|
|
||||||
batch_size=int(args.batch_size),
|
|
||||||
num_workers=int(args.num_workers),
|
|
||||||
gpu_id=int(gpu_id),
|
|
||||||
mc_indices=mc_idxs,
|
|
||||||
out_path=tmp_path,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
p.start()
|
|
||||||
procs.append(p)
|
|
||||||
|
|
||||||
results = [queue.get() for _ in range(len(procs))]
|
|
||||||
for p in procs:
|
|
||||||
p.join()
|
|
||||||
|
|
||||||
for r in results:
|
|
||||||
if not r.get("ok", False):
|
|
||||||
raise SystemExit(f"Worker failed: {r.get('error')}")
|
|
||||||
|
|
||||||
frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
|
|
||||||
df_by_bin = pd.concat(frames, ignore_index=True) if len(
|
|
||||||
frames) else pd.DataFrame()
|
|
||||||
|
|
||||||
# Ensure we don't keep zero-positive rows even if a temp file was produced
|
|
||||||
# by an older version of the worker.
|
|
||||||
df_by_bin = _drop_zero_positives_rows(df_by_bin, "n_positives")
|
|
||||||
df_agg = aggregate_age_bin_results(df_by_bin)
|
|
||||||
|
|
||||||
df_agg = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
|
|
||||||
df_by_bin.to_csv(out_bin, index=False)
|
|
||||||
df_agg.to_csv(out_agg, index=False)
|
|
||||||
|
|
||||||
# Best-effort cleanup.
|
|
||||||
for p in tmp_paths:
|
|
||||||
try:
|
|
||||||
if os.path.exists(p):
|
|
||||||
os.remove(p)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
print(f"Wrote: {out_bin}")
|
|
||||||
print(f"Wrote: {out_agg}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
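For reference, the removed script converted the --age_bin_edges values into half-open [low, high) age bins before building EvalAgeConfig. A minimal standalone sketch of that conversion, using the CLI defaults (illustration only, not part of the deleted file):

def edges_to_bins(edges):
    # Consecutive edges become half-open [low, high) age bins, in years.
    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]

# Default edges 40 45 50 ... 80 yield eight 5-year bins:
print(edges_to_bins([40, 45, 50, 55, 60, 65, 70, 75, 80]))
# [(40.0, 45.0), (45.0, 50.0), ..., (75.0, 80.0)]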
evaluation_age_time_dependent.py (852 lines deleted)
@@ -1,852 +0,0 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

try:
    from tqdm import tqdm
except Exception:  # pragma: no cover

    def tqdm(x, **kwargs):
        return x
from utils import (
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    sample_context_in_fixed_age_bin,
)

from torch_metrics import compute_binary_metrics_torch


def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray:
    with np.errstate(invalid="ignore"):
        return np.nanmean(x, axis=axis)


def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray:
    """NaN-aware sample std (ddof=1), matching pandas std() semantics."""
    x = np.asarray(x, dtype=float)
    mask = np.isfinite(x)
    cnt = mask.sum(axis=axis)
    # mean over finite entries
    x0 = np.where(mask, x, 0.0)
    mean = x0.sum(axis=axis) / np.maximum(cnt, 1)
    # sum of squared deviations over finite entries
    dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0)
    ss = dev2.sum(axis=axis)
    denom = cnt - 1
    out = np.sqrt(ss / np.maximum(denom, 1))
    out = np.where(denom > 0, out, np.nan)
    return out


def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray:
    """NaN-aware weighted mean.

    Only bins with finite x contribute to both numerator and denominator.
    If denom==0 -> NaN.
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)

    if axis != 0:
        raise ValueError("_weighted_mean_np currently supports axis=0 only")

    # Broadcast weights along trailing dims of x.
    while w.ndim < x.ndim:
        w = w[..., None]
    w = np.broadcast_to(w, x.shape)

    mask = np.isfinite(x)
    num = np.where(mask, x * w, 0.0).sum(axis=0)
    denom = np.where(mask, w, 0.0).sum(axis=0)
    return np.where(denom > 0.0, num / denom, np.nan)


def _blocks_to_df_by_bin(
|
|
||||||
blocks: List[Dict[str, Any]],
|
|
||||||
*,
|
|
||||||
topk_percents: np.ndarray,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""Convert per-block column vectors into the long-format per-bin DataFrame.
|
|
||||||
|
|
||||||
This does a single vectorized reshape per block (cause-major ordering), and
|
|
||||||
concatenates columns once at the end.
|
|
||||||
"""
|
|
||||||
if len(blocks) == 0:
|
|
||||||
return pd.DataFrame(
|
|
||||||
columns=[
|
|
||||||
"mc_idx",
|
|
||||||
"age_bin_id",
|
|
||||||
"age_bin_low",
|
|
||||||
"age_bin_high",
|
|
||||||
"horizon_tau",
|
|
||||||
"topk_percent",
|
|
||||||
"cause_id",
|
|
||||||
"n_samples",
|
|
||||||
"n_positives",
|
|
||||||
"auc",
|
|
||||||
"auprc",
|
|
||||||
"recall_at_K",
|
|
||||||
"precision_at_K",
|
|
||||||
"brier_score",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
P = int(topk_percents.size)
|
|
||||||
|
|
||||||
cols: Dict[str, List[np.ndarray]] = {
|
|
||||||
"mc_idx": [],
|
|
||||||
"age_bin_id": [],
|
|
||||||
"age_bin_low": [],
|
|
||||||
"age_bin_high": [],
|
|
||||||
"horizon_tau": [],
|
|
||||||
"topk_percent": [],
|
|
||||||
"cause_id": [],
|
|
||||||
"n_samples": [],
|
|
||||||
"n_positives": [],
|
|
||||||
"auc": [],
|
|
||||||
"auprc": [],
|
|
||||||
"recall_at_K": [],
|
|
||||||
"precision_at_K": [],
|
|
||||||
"brier_score": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
for blk in blocks:
|
|
||||||
cause_id = np.asarray(blk["cause_id"], dtype=int)
|
|
||||||
K = int(cause_id.size)
|
|
||||||
n_rows = K * P
|
|
||||||
|
|
||||||
cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int))
|
|
||||||
cols["age_bin_id"].append(
|
|
||||||
np.full(n_rows, int(blk["age_bin_id"]), dtype=int))
|
|
||||||
cols["age_bin_low"].append(
|
|
||||||
np.full(n_rows, float(blk["age_bin_low"]), dtype=float))
|
|
||||||
cols["age_bin_high"].append(
|
|
||||||
np.full(n_rows, float(blk["age_bin_high"]), dtype=float))
|
|
||||||
cols["horizon_tau"].append(
|
|
||||||
np.full(n_rows, float(blk["horizon_tau"]), dtype=float))
|
|
||||||
|
|
||||||
cols["cause_id"].append(np.repeat(cause_id, P))
|
|
||||||
cols["topk_percent"].append(np.tile(topk_percents.astype(float), K))
|
|
||||||
cols["n_samples"].append(
|
|
||||||
np.full(n_rows, int(blk["n_samples"]), dtype=int))
|
|
||||||
|
|
||||||
n_pos = np.asarray(blk["n_positives"], dtype=int)
|
|
||||||
cols["n_positives"].append(np.repeat(n_pos, P))
|
|
||||||
|
|
||||||
auc = np.asarray(blk["auc"], dtype=float)
|
|
||||||
auprc = np.asarray(blk["auprc"], dtype=float)
|
|
||||||
brier = np.asarray(blk["brier_score"], dtype=float)
|
|
||||||
cols["auc"].append(np.repeat(auc, P))
|
|
||||||
cols["auprc"].append(np.repeat(auprc, P))
|
|
||||||
cols["brier_score"].append(np.repeat(brier, P))
|
|
||||||
|
|
||||||
# precision/recall are stored as (P,K); we want cause-major rows, i.e.
|
|
||||||
# (K,P) then flatten.
|
|
||||||
prec = np.asarray(blk["precision_at_K"], dtype=float)
|
|
||||||
rec = np.asarray(blk["recall_at_K"], dtype=float)
|
|
||||||
if prec.shape != (P, K) or rec.shape != (P, K):
|
|
||||||
raise ValueError(
|
|
||||||
f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}"
|
|
||||||
)
|
|
||||||
cols["precision_at_K"].append(prec.T.reshape(-1))
|
|
||||||
cols["recall_at_K"].append(rec.T.reshape(-1))
|
|
||||||
|
|
||||||
out = {k: np.concatenate(v, axis=0) for k, v in cols.items()}
|
|
||||||
return pd.DataFrame(out)
|
|
||||||
|
|
||||||
|
|
||||||
def aggregate_metrics_columnar(
|
|
||||||
blocks: List[Dict[str, Any]],
|
|
||||||
*,
|
|
||||||
topk_percents: np.ndarray,
|
|
||||||
cause_id: np.ndarray,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std).
|
|
||||||
|
|
||||||
This is a vectorized, columnar replacement for the old pandas groupby/apply.
|
|
||||||
Semantics match the previous implementation:
|
|
||||||
- bins with n_samples==0 are excluded from bin-aggregation
|
|
||||||
- macro: unweighted mean over bins (NaN-aware)
|
|
||||||
- weighted: weighted mean over bins using weights=n_samples (NaN-aware)
|
|
||||||
- across MC: mean/std (ddof=1), NaN-aware
|
|
||||||
"""
|
|
||||||
if len(blocks) == 0:
|
|
||||||
return pd.DataFrame(
|
|
||||||
columns=[
|
|
||||||
"agg_type",
|
|
||||||
"horizon_tau",
|
|
||||||
"topk_percent",
|
|
||||||
"cause_id",
|
|
||||||
"n_mc",
|
|
||||||
"n_bins_used_mean",
|
|
||||||
"n_samples_total_mean",
|
|
||||||
"n_positives_total_mean",
|
|
||||||
"auc_mean",
|
|
||||||
"auc_std",
|
|
||||||
"auprc_mean",
|
|
||||||
"auprc_std",
|
|
||||||
"recall_at_K_mean",
|
|
||||||
"recall_at_K_std",
|
|
||||||
"precision_at_K_mean",
|
|
||||||
"precision_at_K_std",
|
|
||||||
"brier_score_mean",
|
|
||||||
"brier_score_std",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
P = int(topk_percents.size)
|
|
||||||
cause_id = np.asarray(cause_id, dtype=int)
|
|
||||||
K = int(cause_id.size)
|
|
||||||
|
|
||||||
# Group blocks by (mc_idx, horizon_tau)
|
|
||||||
keys: List[Tuple[int, float]] = []
|
|
||||||
grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {}
|
|
||||||
for blk in blocks:
|
|
||||||
key = (int(blk["mc_idx"]), float(blk["horizon_tau"]))
|
|
||||||
if key not in grouped:
|
|
||||||
grouped[key] = []
|
|
||||||
keys.append(key)
|
|
||||||
grouped[key].append(blk)
|
|
||||||
|
|
||||||
mc_vals = sorted({k[0] for k in keys})
|
|
||||||
tau_vals = sorted({k[1] for k in keys})
|
|
||||||
M = len(mc_vals)
|
|
||||||
T = len(tau_vals)
|
|
||||||
|
|
||||||
mc_index = {mc: i for i, mc in enumerate(mc_vals)}
|
|
||||||
tau_index = {tau: i for i, tau in enumerate(tau_vals)}
|
|
||||||
|
|
||||||
# Per (agg_type, mc, tau): store arrays
|
|
||||||
# metrics: (M,T,K) and (M,T,P,K)
|
|
||||||
auc_macro = np.full((M, T, K), np.nan, dtype=float)
|
|
||||||
auc_weighted = np.full((M, T, K), np.nan, dtype=float)
|
|
||||||
ap_macro = np.full((M, T, K), np.nan, dtype=float)
|
|
||||||
ap_weighted = np.full((M, T, K), np.nan, dtype=float)
|
|
||||||
brier_macro = np.full((M, T, K), np.nan, dtype=float)
|
|
||||||
brier_weighted = np.full((M, T, K), np.nan, dtype=float)
|
|
||||||
|
|
||||||
prec_macro = np.full((M, T, P, K), np.nan, dtype=float)
|
|
||||||
prec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
|
|
||||||
rec_macro = np.full((M, T, P, K), np.nan, dtype=float)
|
|
||||||
rec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
|
|
||||||
|
|
||||||
n_bins_used = np.zeros((M, T), dtype=float)
|
|
||||||
n_samples_total = np.zeros((M, T), dtype=float)
|
|
||||||
n_pos_total = np.zeros((M, T, K), dtype=float)
|
|
||||||
|
|
||||||
for (mc, tau), blks in grouped.items():
|
|
||||||
mi = mc_index[mc]
|
|
||||||
ti = tau_index[tau]
|
|
||||||
|
|
||||||
# keep only bins with n_samples>0
|
|
||||||
blks_nz = [b for b in blks if int(b["n_samples"]) > 0]
|
|
||||||
if len(blks_nz) == 0:
|
|
||||||
n_bins_used[mi, ti] = 0.0
|
|
||||||
n_samples_total[mi, ti] = 0.0
|
|
||||||
n_pos_total[mi, ti, :] = 0.0
|
|
||||||
continue
|
|
||||||
|
|
||||||
w = np.asarray([int(b["n_samples"])
|
|
||||||
for b in blks_nz], dtype=float) # (B,)
|
|
||||||
n_bins_used[mi, ti] = float(len(w))
|
|
||||||
n_samples_total[mi, ti] = float(w.sum())
|
|
||||||
|
|
||||||
npos = np.stack([np.asarray(b["n_positives"], dtype=float)
|
|
||||||
for b in blks_nz], axis=0) # (B,K)
|
|
||||||
n_pos_total[mi, ti, :] = npos.sum(axis=0)
|
|
||||||
|
|
||||||
auc_b = np.stack([np.asarray(b["auc"], dtype=float)
|
|
||||||
for b in blks_nz], axis=0) # (B,K)
|
|
||||||
ap_b = np.stack([np.asarray(b["auprc"], dtype=float)
|
|
||||||
for b in blks_nz], axis=0)
|
|
||||||
brier_b = np.stack([np.asarray(b["brier_score"], dtype=float)
|
|
||||||
for b in blks_nz], axis=0)
|
|
||||||
|
|
||||||
auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0)
|
|
||||||
ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0)
|
|
||||||
brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0)
|
|
||||||
|
|
||||||
auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0)
|
|
||||||
ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0)
|
|
||||||
brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0)
|
|
||||||
|
|
||||||
prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float)
|
|
||||||
for b in blks_nz], axis=0) # (B,P,K)
|
|
||||||
rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float)
|
|
||||||
for b in blks_nz], axis=0)
|
|
||||||
|
|
||||||
# macro mean over bins
|
|
||||||
prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0)
|
|
||||||
rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0)
|
|
||||||
|
|
||||||
# weighted mean over bins (weights along bin axis)
|
|
||||||
w3 = w.reshape(-1, 1, 1)
|
|
||||||
prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0)
|
|
||||||
rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0)
|
|
||||||
|
|
||||||
# Across-MC aggregation (mean/std), then emit long-format df keyed by
|
|
||||||
# (agg_type, horizon_tau, topk_percent, cause_id)
|
|
||||||
rows: Dict[str, List[np.ndarray]] = {
|
|
||||||
"agg_type": [],
|
|
||||||
"horizon_tau": [],
|
|
||||||
"topk_percent": [],
|
|
||||||
"cause_id": [],
|
|
||||||
"n_mc": [],
|
|
||||||
"n_bins_used_mean": [],
|
|
||||||
"n_samples_total_mean": [],
|
|
||||||
"n_positives_total_mean": [],
|
|
||||||
"auc_mean": [],
|
|
||||||
"auc_std": [],
|
|
||||||
"auprc_mean": [],
|
|
||||||
"auprc_std": [],
|
|
||||||
"recall_at_K_mean": [],
|
|
||||||
"recall_at_K_std": [],
|
|
||||||
"precision_at_K_mean": [],
|
|
||||||
"precision_at_K_std": [],
|
|
||||||
"brier_score_mean": [],
|
|
||||||
"brier_score_std": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
cause_long = np.repeat(cause_id, P)
|
|
||||||
topk_long = np.tile(topk_percents.astype(float), K)
|
|
||||||
n_mc_val = float(M)
|
|
||||||
|
|
||||||
for ti, tau in enumerate(tau_vals):
|
|
||||||
# scalar totals (repeat across causes/topk)
|
|
||||||
n_bins_mean = float(
|
|
||||||
np.mean(n_bins_used[:, ti])) if M > 0 else float("nan")
|
|
||||||
n_samp_mean = float(
|
|
||||||
np.mean(n_samples_total[:, ti])) if M > 0 else float("nan")
|
|
||||||
n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0) # (K,)
|
|
||||||
|
|
||||||
for agg_type in ("macro", "weighted"):
|
|
||||||
if agg_type == "macro":
|
|
||||||
auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0)
|
|
||||||
auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0)
|
|
||||||
ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0)
|
|
||||||
ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0)
|
|
||||||
brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0)
|
|
||||||
brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0)
|
|
||||||
prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0) # (P,K)
|
|
||||||
prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0)
|
|
||||||
rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0)
|
|
||||||
rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0)
|
|
||||||
else:
|
|
||||||
auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0)
|
|
||||||
auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0)
|
|
||||||
ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0)
|
|
||||||
ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0)
|
|
||||||
brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0)
|
|
||||||
brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0)
|
|
||||||
prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0)
|
|
||||||
prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0)
|
|
||||||
rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0)
|
|
||||||
rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0)
|
|
||||||
|
|
||||||
n_rows = K * P
|
|
||||||
rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object))
|
|
||||||
rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float))
|
|
||||||
rows["topk_percent"].append(topk_long)
|
|
||||||
rows["cause_id"].append(cause_long)
|
|
||||||
rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float))
|
|
||||||
rows["n_bins_used_mean"].append(
|
|
||||||
np.full(n_rows, n_bins_mean, dtype=float))
|
|
||||||
rows["n_samples_total_mean"].append(
|
|
||||||
np.full(n_rows, n_samp_mean, dtype=float))
|
|
||||||
rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P))
|
|
||||||
|
|
||||||
rows["auc_mean"].append(np.repeat(auc_m, P))
|
|
||||||
rows["auc_std"].append(np.repeat(auc_s, P))
|
|
||||||
rows["auprc_mean"].append(np.repeat(ap_m, P))
|
|
||||||
rows["auprc_std"].append(np.repeat(ap_s, P))
|
|
||||||
rows["brier_score_mean"].append(np.repeat(brier_m, P))
|
|
||||||
rows["brier_score_std"].append(np.repeat(brier_s, P))
|
|
||||||
|
|
||||||
rows["precision_at_K_mean"].append(prec_m.T.reshape(-1))
|
|
||||||
rows["precision_at_K_std"].append(prec_s.T.reshape(-1))
|
|
||||||
rows["recall_at_K_mean"].append(rec_m.T.reshape(-1))
|
|
||||||
rows["recall_at_K_std"].append(rec_s.T.reshape(-1))
|
|
||||||
|
|
||||||
out = {k: np.concatenate(v, axis=0) for k, v in rows.items()}
|
|
||||||
df = pd.DataFrame(out)
|
|
||||||
return df.sort_values(
|
|
||||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
|
|
||||||
"""Aggregate per-bin age evaluation results.
|
|
||||||
|
|
||||||
Produces both:
|
|
||||||
- macro: unweighted mean over bins with n_samples>0
|
|
||||||
- weighted: weighted mean over bins using weights=n_samples
|
|
||||||
|
|
||||||
Then aggregates across MC repetitions (mean/std).
|
|
||||||
|
|
||||||
Requires df_by_bin to include:
|
|
||||||
mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
|
|
||||||
n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id)
|
|
||||||
"""
|
|
||||||
if df_by_bin is None or len(df_by_bin) == 0:
|
|
||||||
return pd.DataFrame(
|
|
||||||
columns=[
|
|
||||||
"agg_type",
|
|
||||||
"horizon_tau",
|
|
||||||
"topk_percent",
|
|
||||||
"cause_id",
|
|
||||||
"n_mc",
|
|
||||||
"n_bins_used_mean",
|
|
||||||
"n_samples_total_mean",
|
|
||||||
"n_positives_total_mean",
|
|
||||||
"auc_mean",
|
|
||||||
"auc_std",
|
|
||||||
"auprc_mean",
|
|
||||||
"auprc_std",
|
|
||||||
"recall_at_K_mean",
|
|
||||||
"recall_at_K_std",
|
|
||||||
"precision_at_K_mean",
|
|
||||||
"precision_at_K_std",
|
|
||||||
"brier_score_mean",
|
|
||||||
"brier_score_std",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
|
|
||||||
g = group[group["n_samples"] > 0]
|
|
||||||
if len(g) == 0:
|
|
||||||
return pd.Series(
|
|
||||||
dict(
|
|
||||||
n_bins_used=0,
|
|
||||||
n_samples_total=0,
|
|
||||||
n_positives_total=0,
|
|
||||||
auc=float("nan"),
|
|
||||||
auprc=float("nan"),
|
|
||||||
recall_at_K=float("nan"),
|
|
||||||
precision_at_K=float("nan"),
|
|
||||||
brier_score=float("nan"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
n_bins_used = int(g["age_bin_id"].nunique())
|
|
||||||
n_samples_total = int(g["n_samples"].sum())
|
|
||||||
n_positives_total = int(g["n_positives"].sum())
|
|
||||||
|
|
||||||
if not weighted:
|
|
||||||
return pd.Series(
|
|
||||||
dict(
|
|
||||||
n_bins_used=n_bins_used,
|
|
||||||
n_samples_total=n_samples_total,
|
|
||||||
n_positives_total=n_positives_total,
|
|
||||||
auc=float(g["auc"].mean()),
|
|
||||||
auprc=float(g["auprc"].mean()),
|
|
||||||
recall_at_K=float(g["recall_at_K"].mean()),
|
|
||||||
precision_at_K=float(g["precision_at_K"].mean()),
|
|
||||||
brier_score=float(g["brier_score"].mean()),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
w = g["n_samples"].to_numpy(dtype=float)
|
|
||||||
w_sum = float(w.sum())
|
|
||||||
if w_sum <= 0.0:
|
|
||||||
return pd.Series(
|
|
||||||
dict(
|
|
||||||
n_bins_used=n_bins_used,
|
|
||||||
n_samples_total=n_samples_total,
|
|
||||||
n_positives_total=n_positives_total,
|
|
||||||
auc=float("nan"),
|
|
||||||
auprc=float("nan"),
|
|
||||||
recall_at_K=float("nan"),
|
|
||||||
precision_at_K=float("nan"),
|
|
||||||
brier_score=float("nan"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _wavg(col: str) -> float:
|
|
||||||
return float(np.average(g[col].to_numpy(dtype=float), weights=w))
|
|
||||||
|
|
||||||
return pd.Series(
|
|
||||||
dict(
|
|
||||||
n_bins_used=n_bins_used,
|
|
||||||
n_samples_total=n_samples_total,
|
|
||||||
n_positives_total=n_positives_total,
|
|
||||||
auc=_wavg("auc"),
|
|
||||||
auprc=_wavg("auprc"),
|
|
||||||
recall_at_K=_wavg("recall_at_K"),
|
|
||||||
precision_at_K=_wavg("precision_at_K"),
|
|
||||||
brier_score=_wavg("brier_score"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Kept for backward compatibility (e.g., if callers load a CSV and need to
|
|
||||||
# aggregate). Prefer `aggregate_metrics_columnar` during evaluation.
|
|
||||||
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
|
|
||||||
|
|
||||||
df = df_by_bin[df_by_bin["n_samples"] > 0].copy()
|
|
||||||
if len(df) == 0:
|
|
||||||
return pd.DataFrame(
|
|
||||||
columns=[
|
|
||||||
"agg_type",
|
|
||||||
"horizon_tau",
|
|
||||||
"topk_percent",
|
|
||||||
"cause_id",
|
|
||||||
"n_mc",
|
|
||||||
"n_bins_used_mean",
|
|
||||||
"n_samples_total_mean",
|
|
||||||
"n_positives_total_mean",
|
|
||||||
"auc_mean",
|
|
||||||
"auc_std",
|
|
||||||
"auprc_mean",
|
|
||||||
"auprc_std",
|
|
||||||
"recall_at_K_mean",
|
|
||||||
"recall_at_K_std",
|
|
||||||
"precision_at_K_mean",
|
|
||||||
"precision_at_K_std",
|
|
||||||
"brier_score_mean",
|
|
||||||
"brier_score_std",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Macro: mean over bins
|
|
||||||
df_mc_macro = (
|
|
||||||
df.groupby(group_keys, as_index=False)
|
|
||||||
.agg(
|
|
||||||
n_bins_used=("age_bin_id", "nunique"),
|
|
||||||
n_samples_total=("n_samples", "sum"),
|
|
||||||
n_positives_total=("n_positives", "sum"),
|
|
||||||
auc=("auc", "mean"),
|
|
||||||
auprc=("auprc", "mean"),
|
|
||||||
recall_at_K=("recall_at_K", "mean"),
|
|
||||||
precision_at_K=("precision_at_K", "mean"),
|
|
||||||
brier_score=("brier_score", "mean"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
df_mc_macro["agg_type"] = "macro"
|
|
||||||
|
|
||||||
# Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric
|
|
||||||
w = df["n_samples"].astype(float)
|
|
||||||
df_w = df.copy()
|
|
||||||
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
|
|
||||||
m = df_w[col].astype(float)
|
|
||||||
ww = w.where(m.notna(), other=0.0)
|
|
||||||
df_w[f"__num_{col}"] = (m.fillna(0.0) * w)
|
|
||||||
df_w[f"__den_{col}"] = ww
|
|
||||||
|
|
||||||
df_mc_w = df_w.groupby(group_keys, as_index=False).agg(
|
|
||||||
n_bins_used=("age_bin_id", "nunique"),
|
|
||||||
n_samples_total=("n_samples", "sum"),
|
|
||||||
n_positives_total=("n_positives", "sum"),
|
|
||||||
**{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
|
|
||||||
**{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
|
|
||||||
)
|
|
||||||
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
|
|
||||||
num = df_mc_w[f"__num_{col}"].astype(float)
|
|
||||||
den = df_mc_w[f"__den_{col}"].astype(float)
|
|
||||||
df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan"))
|
|
||||||
df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True)
|
|
||||||
df_mc_w["agg_type"] = "weighted"
|
|
||||||
|
|
||||||
df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True)
|
|
||||||
|
|
||||||
df_agg = (
|
|
||||||
df_mc_binagg.groupby(
|
|
||||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False)
|
|
||||||
.agg(
|
|
||||||
n_mc=("mc_idx", "nunique"),
|
|
||||||
n_bins_used_mean=("n_bins_used", "mean"),
|
|
||||||
n_samples_total_mean=("n_samples_total", "mean"),
|
|
||||||
n_positives_total_mean=("n_positives_total", "mean"),
|
|
||||||
auc_mean=("auc", "mean"),
|
|
||||||
auc_std=("auc", "std"),
|
|
||||||
auprc_mean=("auprc", "mean"),
|
|
||||||
auprc_std=("auprc", "std"),
|
|
||||||
recall_at_K_mean=("recall_at_K", "mean"),
|
|
||||||
recall_at_K_std=("recall_at_K", "std"),
|
|
||||||
precision_at_K_mean=("precision_at_K", "mean"),
|
|
||||||
precision_at_K_std=("precision_at_K", "std"),
|
|
||||||
brier_score_mean=("brier_score", "mean"),
|
|
||||||
brier_score_std=("brier_score", "std"),
|
|
||||||
)
|
|
||||||
.sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True)
|
|
||||||
)
|
|
||||||
return df_agg
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`.
# NumPy/Pandas are only used for final CSV formatting/aggregation.


@dataclass
class EvalAgeConfig:
    horizons_years: Sequence[float]
    age_bins: Sequence[Tuple[float, float]]
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    n_mc: int = 5
    seed: int = 0
    cause_ids: Optional[Sequence[int]] = None
    store_per_cause: bool = True


@torch.inference_mode()
|
|
||||||
def evaluate_time_dependent_age_bins(
|
|
||||||
model: torch.nn.Module,
|
|
||||||
head: torch.nn.Module,
|
|
||||||
criterion,
|
|
||||||
dataloader: torch.utils.data.DataLoader,
|
|
||||||
n_disease: int,
|
|
||||||
cfg: EvalAgeConfig,
|
|
||||||
device: str | torch.device,
|
|
||||||
mc_offset: int = 0,
|
|
||||||
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
|
||||||
"""Delphi-2M-style age-bin evaluation with strict horizon alignment.
|
|
||||||
|
|
||||||
Semantics (strict): for each (MC, horizon tau, age bin) we independently:
|
|
||||||
- build the eligible token set within that bin
|
|
||||||
- enforce follow-up coverage: t_ctx + tau <= t_end
|
|
||||||
- randomly sample exactly one token per individual within the bin (de-dup)
|
|
||||||
- recompute context representations and predictions for that (tau, bin)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id)
|
|
||||||
df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted}
|
|
||||||
"""
|
|
||||||
device = torch.device(device)
|
|
||||||
model.eval()
|
|
||||||
head.eval()
|
|
||||||
|
|
||||||
horizons_years = [float(x) for x in cfg.horizons_years]
|
|
||||||
if len(horizons_years) == 0:
|
|
||||||
raise ValueError("cfg.horizons_years must be non-empty")
|
|
||||||
|
|
||||||
age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
|
|
||||||
if len(age_bins) == 0:
|
|
||||||
raise ValueError("cfg.age_bins must be non-empty")
|
|
||||||
for (a, b) in age_bins:
|
|
||||||
if not (b > a):
|
|
||||||
raise ValueError(
|
|
||||||
f"age_bins must be (low, high) with high>low; got {(a, b)}")
|
|
||||||
|
|
||||||
topk_percents = [float(x) for x in cfg.topk_percents]
|
|
||||||
if len(topk_percents) == 0:
|
|
||||||
raise ValueError("cfg.topk_percents must be non-empty")
|
|
||||||
if any((p <= 0.0 or p > 100.0) for p in topk_percents):
|
|
||||||
raise ValueError(
|
|
||||||
f"All topk_percents must be in (0,100]; got {topk_percents}")
|
|
||||||
|
|
||||||
if int(cfg.n_mc) <= 0:
|
|
||||||
raise ValueError("cfg.n_mc must be >= 1")
|
|
||||||
|
|
||||||
if cfg.cause_ids is None:
|
|
||||||
cause_ids = None
|
|
||||||
n_causes_eval = int(n_disease)
|
|
||||||
cause_id_vec = np.arange(n_causes_eval, dtype=int)
|
|
||||||
else:
|
|
||||||
cause_ids = torch.tensor(
|
|
||||||
list(cfg.cause_ids), dtype=torch.long, device=device)
|
|
||||||
n_causes_eval = int(cause_ids.numel())
|
|
||||||
cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int)
|
|
||||||
|
|
||||||
topk_percents_np = np.asarray(topk_percents, dtype=float)
|
|
||||||
|
|
||||||
# Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends.
|
|
||||||
blocks: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
for mc_idx in range(int(cfg.n_mc)):
|
|
||||||
global_mc_idx = int(mc_offset) + int(mc_idx)
|
|
||||||
|
|
||||||
# Storage for this MC only: (tau, bin) -> list of GPU tensors.
|
|
||||||
# This keeps computations GPU-first while preventing a factor-n_mc
|
|
||||||
# blow-up in GPU memory.
|
|
||||||
y_true_mc: List[List[List[torch.Tensor]]] = [
|
|
||||||
[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
|
|
||||||
]
|
|
||||||
y_pred_mc: List[List[List[torch.Tensor]]] = [
|
|
||||||
[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
|
|
||||||
]
|
|
||||||
|
|
||||||
# tqdm over batches; include MC idx in description.
|
|
||||||
for batch_idx, batch in enumerate(
|
|
||||||
tqdm(dataloader,
|
|
||||||
desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
|
|
||||||
):
|
|
||||||
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
|
|
||||||
event_seq = event_seq.to(device)
|
|
||||||
time_seq = time_seq.to(device)
|
|
||||||
cont_feats = cont_feats.to(device)
|
|
||||||
cate_feats = cate_feats.to(device)
|
|
||||||
sexes = sexes.to(device)
|
|
||||||
|
|
||||||
B = int(event_seq.size(0))
|
|
||||||
b = torch.arange(B, device=device)
|
|
||||||
|
|
||||||
# Hoist backbone forward pass: inputs are identical across (tau, age_bin)
|
|
||||||
# within this batch, so this is safe and numerically identical.
|
|
||||||
h = model(event_seq, time_seq, sexes,
|
|
||||||
cont_feats, cate_feats) # (B,L,D)
|
|
||||||
|
|
||||||
for tau_idx, tau_y in enumerate(horizons_years):
|
|
||||||
tau_tensor = torch.tensor(float(tau_y), device=device)
|
|
||||||
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
|
|
||||||
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
|
|
||||||
seed = (
|
|
||||||
int(cfg.seed)
|
|
||||||
+ (100_000 * int(global_mc_idx))
|
|
||||||
+ (1_000 * int(tau_idx))
|
|
||||||
+ (10 * int(bin_idx))
|
|
||||||
+ int(batch_idx)
|
|
||||||
)
|
|
||||||
|
|
||||||
keep, t_ctx = sample_context_in_fixed_age_bin(
|
|
||||||
event_seq=event_seq,
|
|
||||||
time_seq=time_seq,
|
|
||||||
tau_years=float(tau_y),
|
|
||||||
age_bin=(float(a_lo), float(a_hi)),
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
if not keep.any():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Bin-specific prediction: context indices differ per (tau, bin)
|
|
||||||
# but the backbone features do not.
|
|
||||||
c = h[b, t_ctx]
|
|
||||||
logits = head(c)
|
|
||||||
|
|
||||||
cifs = criterion.calculate_cifs(
|
|
||||||
logits, taus=tau_tensor
|
|
||||||
)
|
|
||||||
if cifs.ndim != 2:
|
|
||||||
raise ValueError(
|
|
||||||
"criterion.calculate_cifs must return (B,K) for scalar tau; "
|
|
||||||
f"got shape={tuple(cifs.shape)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cause_ids is None:
|
|
||||||
y = multi_hot_ever_within_horizon(
|
|
||||||
event_seq=event_seq,
|
|
||||||
time_seq=time_seq,
|
|
||||||
t_ctx=t_ctx,
|
|
||||||
tau_years=float(tau_y),
|
|
||||||
n_disease=n_disease,
|
|
||||||
)
|
|
||||||
preds = cifs
|
|
||||||
else:
|
|
||||||
y = multi_hot_selected_causes_within_horizon(
|
|
||||||
event_seq=event_seq,
|
|
||||||
time_seq=time_seq,
|
|
||||||
t_ctx=t_ctx,
|
|
||||||
tau_years=float(tau_y),
|
|
||||||
cause_ids=cause_ids,
|
|
||||||
n_disease=n_disease,
|
|
||||||
)
|
|
||||||
preds = cifs.index_select(dim=1, index=cause_ids)
|
|
||||||
|
|
||||||
y_true_mc[tau_idx][bin_idx].append(
|
|
||||||
y[keep].detach().to(dtype=torch.bool)
|
|
||||||
)
|
|
||||||
y_pred_mc[tau_idx][bin_idx].append(
|
|
||||||
preds[keep].detach().to(dtype=torch.float32)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Aggregate this MC immediately (frees GPU memory before next MC).
|
|
||||||
for h_idx, tau_y in enumerate(horizons_years):
|
|
||||||
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
|
|
||||||
if len(y_true_mc[h_idx][bin_idx]) == 0:
|
|
||||||
# No samples in this bin for this (mc, tau): store a single
|
|
||||||
# block with NaN metric vectors.
|
|
||||||
K = int(n_causes_eval)
|
|
||||||
P = int(topk_percents_np.size)
|
|
||||||
blocks.append(
|
|
||||||
dict(
|
|
||||||
mc_idx=global_mc_idx,
|
|
||||||
age_bin_id=bin_idx,
|
|
||||||
age_bin_low=float(a_lo),
|
|
||||||
age_bin_high=float(a_hi),
|
|
||||||
horizon_tau=float(tau_y),
|
|
||||||
n_samples=0,
|
|
||||||
cause_id=cause_id_vec,
|
|
||||||
n_positives=np.zeros((K,), dtype=int),
|
|
||||||
auc=np.full((K,), np.nan, dtype=float),
|
|
||||||
auprc=np.full((K,), np.nan, dtype=float),
|
|
||||||
brier_score=np.full((K,), np.nan, dtype=float),
|
|
||||||
precision_at_K=np.full((P, K), np.nan, dtype=float),
|
|
||||||
recall_at_K=np.full((P, K), np.nan, dtype=float),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
|
|
||||||
pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0)
|
|
||||||
if tuple(yb_t.shape) != tuple(pb_t.shape):
|
|
||||||
raise ValueError(
|
|
||||||
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
n_samples = int(yb_t.size(0))
|
|
||||||
|
|
||||||
metrics = compute_binary_metrics_torch(
|
|
||||||
y_true=yb_t,
|
|
||||||
y_pred=pb_t,
|
|
||||||
k_percents=topk_percents,
|
|
||||||
tie_mode="exact",
|
|
||||||
chunk_size=128,
|
|
||||||
compute_ici=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect a single columnar block (vectors, not per-row dicts).
|
|
||||||
blocks.append(
|
|
||||||
dict(
|
|
||||||
mc_idx=global_mc_idx,
|
|
||||||
age_bin_id=bin_idx,
|
|
||||||
age_bin_low=float(a_lo),
|
|
||||||
age_bin_high=float(a_hi),
|
|
||||||
horizon_tau=float(tau_y),
|
|
||||||
n_samples=int(n_samples),
|
|
||||||
cause_id=cause_id_vec,
|
|
||||||
n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int),
|
|
||||||
auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float),
|
|
||||||
auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float),
|
|
||||||
brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float),
|
|
||||||
precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float),
|
|
||||||
recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Aggregation is computed from columnar blocks (fast, no pandas apply).
|
|
||||||
df_agg = aggregate_metrics_columnar(
|
|
||||||
blocks,
|
|
||||||
topk_percents=topk_percents_np,
|
|
||||||
cause_id=cause_id_vec,
|
|
||||||
)
|
|
||||||
|
|
||||||
if bool(cfg.store_per_cause):
|
|
||||||
df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np)
|
|
||||||
else:
|
|
||||||
df_by_bin = pd.DataFrame(
|
|
||||||
columns=[
|
|
||||||
"mc_idx",
|
|
||||||
"age_bin_id",
|
|
||||||
"age_bin_low",
|
|
||||||
"age_bin_high",
|
|
||||||
"horizon_tau",
|
|
||||||
"topk_percent",
|
|
||||||
"cause_id",
|
|
||||||
"n_samples",
|
|
||||||
"n_positives",
|
|
||||||
"auc",
|
|
||||||
"auprc",
|
|
||||||
"recall_at_K",
|
|
||||||
"precision_at_K",
|
|
||||||
"brier_score",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return df_by_bin, df_agg
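For reference, the "weighted" aggregation in the module above averaged per-bin metrics with weights equal to each bin's sample count, skipping bins whose metric is NaN. A minimal standalone NumPy sketch of that rule (illustration only, not part of the deleted module):

import numpy as np

def weighted_mean_nan_aware(values, weights):
    # Bins whose metric is NaN contribute to neither numerator nor denominator.
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    mask = np.isfinite(values)
    num = np.where(mask, values * weights, 0.0).sum()
    den = np.where(mask, weights, 0.0).sum()
    return num / den if den > 0 else float("nan")

# Example: per-bin AUCs with sample counts; the NaN bin is ignored.
print(weighted_mean_nan_aware([0.80, np.nan, 0.70], [100, 50, 300]))  # 0.725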
utils.py (207 lines deleted)
@@ -1,207 +0,0 @@
import torch
from typing import Tuple

DAYS_PER_YEAR = 365.25


def sample_context_in_fixed_age_bin(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    tau_years: float,
    age_bin: Tuple[float, float],
    seed: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one context token per individual within a fixed age bin.

    Delphi-2M semantics for a specific (tau, age_bin):
      - Token times are interpreted as age in *days* (converted to years).
      - Follow-up end time is the last valid token time per individual.
      - A token index j is eligible iff:
          (token is valid)
          AND (age_years in [age_low, age_high))
          AND (time_seq[i, j] + tau_days <= followup_end_time[i])
      - For each individual, randomly select exactly one eligible token in this bin.

    Args:
        event_seq: (B, L) token ids, 0 is padding.
        time_seq: (B, L) token times in days.
        tau_years: horizon length in years.
        age_bin: (low, high) bounds in years, interpreted as [low, high).
        seed: RNG seed for deterministic sampling.

    Returns:
        keep: (B,) bool, True if a context was sampled for this bin.
        t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0).
    """
    low, high = float(age_bin[0]), float(age_bin[1])
    if not (high > low):
        raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")

    device = event_seq.device
    B, _ = event_seq.shape

    valid = event_seq != 0
    lengths = valid.sum(dim=1)
    last_idx = torch.clamp(lengths - 1, min=0)
    b = torch.arange(B, device=device)
    followup_end_time = time_seq[b, last_idx]  # (B,)

    tau_days = float(tau_years) * DAYS_PER_YEAR
    age_years = time_seq / DAYS_PER_YEAR

    in_bin = (age_years >= low) & (age_years < high)
    eligible = valid & in_bin & (
        (time_seq + tau_days) <= followup_end_time.unsqueeze(1))

    # Vectorized, uniform sampling over eligible indices per sample.
    # Using argmax of i.i.d. Uniform(0,1) over eligible positions yields a uniform
    # choice among eligible indices by symmetry (ties have probability ~0).
    keep = eligible.any(dim=1)

    # Prefer a per-call generator on the target device for reproducibility without
    # touching global RNG state. If unavailable, fall back to seeding the global
    # CUDA RNG for this call.
    gen = None
    if device.type == "cuda":
        try:
            gen = torch.Generator(device=device)
            gen.manual_seed(int(seed))
        except Exception:
            gen = None
            torch.cuda.manual_seed(int(seed))
    else:
        gen = torch.Generator()
        gen.manual_seed(int(seed))

    r = torch.rand((B, eligible.size(1)), device=device, generator=gen)
    r = r.masked_fill(~eligible, -1.0)
    t_ctx = r.argmax(dim=1).to(torch.long)

    # When keep=False, t_ctx is arbitrary (argmax over all -1 yields 0).
    return keep, t_ctx


def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
    tau_years: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select per-sample prediction context index.

    IMPORTANT SEMANTICS:
      - The last observed token time is treated as the FOLLOW-UP END time.
      - We pick the last valid token with time <= (followup_end_time - offset).
      - We do NOT interpret followup_end_time as an event time.

    Returns:
        keep_mask: (B,) bool, which samples have a valid context
        t_ctx: (B,) long, index into sequence
        t_ctx_time: (B,) float, time (days) at context
    """
    # valid tokens are event != 0 (padding is 0)
    valid = event_seq != 0
    lengths = valid.sum(dim=1)
    last_idx = torch.clamp(lengths - 1, min=0)

    b = torch.arange(event_seq.size(0), device=event_seq.device)
    followup_end_time = time_seq[b, last_idx]
    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)

    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    eligible_counts = eligible.sum(dim=1)
    keep = eligible_counts > 0

    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[b, t_ctx]

    # Horizon-aligned eligibility: require enough follow-up time after the selected context.
    # All times are in days.
    keep = keep & (followup_end_time >= (
        t_ctx_time + (tau_years * DAYS_PER_YEAR)))

    return keep, t_ctx, t_ctx_time


def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after context (any occurrence)."""
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)

    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    # Include same-day events after context, exclude any token at/before context index.
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )

    if not in_window.any():
        return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)

    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)

    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    y[b_idx, disease_ids] = True
    return y


def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after context?"""
    B, L = event_seq.shape
    device = event_seq.device
    b = torch.arange(B, device=device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)

    idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )

    out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
    if not in_window.any():
        return out

    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)

    # Filter to selected causes via a boolean membership mask over the global disease space.
    selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
    selected[cause_ids] = True
    keep = selected[disease_ids]
    if not keep.any():
        return out

    b_idx = b_idx[keep]
    disease_ids = disease_ids[keep]

    # Map disease_id -> local index in cause_ids
    # Build a lookup table (global disease space) where lookup[disease_id] = local_index
    lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
    lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
    local = lookup[disease_ids]
    out[b_idx, local] = True
    return out
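The uniform per-individual sampling in sample_context_in_fixed_age_bin above relies on taking the argmax of i.i.d. Uniform(0,1) draws restricted to eligible positions. A small standalone sketch of that trick on toy tensors (illustration only, not part of the deleted module):

import torch

eligible = torch.tensor([[True, False, True, True],
                         [False, False, True, False],
                         [False, False, False, False]])

gen = torch.Generator().manual_seed(0)
r = torch.rand(eligible.shape, generator=gen)
r = r.masked_fill(~eligible, -1.0)  # ineligible positions can never win the argmax
t_ctx = r.argmax(dim=1)             # uniform choice among the eligible columns per row
keep = eligible.any(dim=1)          # rows with no eligible column fall back to index 0
print(keep.tolist(), t_ctx.tolist())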