Add evaluation scripts for age-bin time-dependent metrics and remove obsolete evaluation_time_dependent.py
@@ -4,13 +4,13 @@ import argparse
 import json
 import math
 import os
-from typing import List, Sequence
+from typing import List, Sequence, Tuple

 import torch
 from torch.utils.data import DataLoader, random_split

 from dataset import HealthDataset, health_collate_fn
-from evaluation_time_dependent import EvalConfig, evaluate_time_dependent
+from evaluation_age_time_dependent import EvalAgeConfig, evaluate_time_dependent_age_bins
 from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
 from model import DelphiFork, SapDelphi, SimpleHead

@@ -25,6 +25,20 @@ def _parse_floats(items: Sequence[str]) -> List[float]:
     return out


+def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
+    vals = _parse_floats(edges)
+    if len(vals) < 2:
+        raise ValueError("--age_bin_edges must have at least 2 values")
+    for i in range(1, len(vals)):
+        if not (vals[i] > vals[i - 1]):
+            raise ValueError("--age_bin_edges must be strictly increasing")
+    return vals
+
+
+def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
+    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]
+
+
 def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
     if loss_type == "exponential":
         criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
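A minimal sketch of how the two new helpers compose (values are illustrative; `_parse_floats` is assumed to map numeric strings to floats, as its name suggests):

    edges = _parse_age_bin_edges(["40", "45", "50"])  # -> [40.0, 45.0, 50.0]
    bins = _edges_to_bins(edges)                      # -> [(40.0, 45.0), (45.0, 50.0)]

Each consecutive pair of edges yields one half-open bin [edge[i], edge[i+1]).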
@@ -90,44 +104,48 @@ def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):

 def main() -> None:
     parser = argparse.ArgumentParser(
-        description="Time-dependent evaluation for DeepHealth")
+        description="Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and aggregated CSVs; aggregated includes agg_type={macro,weighted})")
     parser.add_argument(
         "--run_dir",
         type=str,
         required=True,
         help="Training run directory (contains best_model.pt and train_config.json)",
     )
-    parser.add_argument("--data_prefix", type=str, default=None,
-                        help="Dataset prefix (overrides config if provided)")
+    parser.add_argument("--data_prefix", type=str, default=None)
     parser.add_argument("--split", type=str,
                         choices=["train", "val", "test", "all"], default="val")

     parser.add_argument("--horizons", type=str, nargs="+",
-                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="One or more horizons (years)")
-    parser.add_argument("--offset_years", type=float, default=0.0,
-                        help="Context selection offset (years before follow-up end)")
+                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"])
+    parser.add_argument(
+        "--age_bin_edges",
+        type=str,
+        nargs="+",
+        default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
+        help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).",
+    )
     parser.add_argument(
         "--topk_percent",
         type=float,
         nargs="+",
         default=[1, 5, 10, 20, 50],
-        help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)",
+        help="One or more K%% values for recall/precision@K%%",
     )
+    parser.add_argument("--n_mc", type=int, default=5)
+    parser.add_argument("--seed", type=int, default=0)

     parser.add_argument("--device", type=str,
                         default="cuda" if torch.cuda.is_available() else "cpu")
     parser.add_argument("--batch_size", type=int, default=256)
-    parser.add_argument("--num_workers", type=int,
-                        default=0, help="Keep 0 on Windows")
+    parser.add_argument("--num_workers", type=int, default=0)

-    parser.add_argument("--out_csv", type=str, default=None,
-                        help="Optional output CSV path")
+    parser.add_argument("--out_prefix", type=str,
+                        default=None, help="Output prefix for CSVs")

     args = parser.parse_args()

     ckpt_path = os.path.join(args.run_dir, "best_model.pt")
     cfg_path = os.path.join(args.run_dir, "train_config.json")

     if not os.path.exists(ckpt_path):
         raise SystemExit(f"Missing checkpoint: {ckpt_path}")
     if not os.path.exists(cfg_path):
@@ -139,24 +157,23 @@ def main() -> None:
     data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
         "data_prefix", "ukb")

-    # Match training covariate selection.
     full_cov = bool(cfg.get("full_cov", False))
     cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
     dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)

-    # Recreate the same split scheme as train.py
     train_ratio = float(cfg.get("train_ratio", 0.7))
     val_ratio = float(cfg.get("val_ratio", 0.15))
-    seed = int(cfg.get("random_seed", 42))
+    seed_split = int(cfg.get("random_seed", 42))

     n_total = len(dataset)
     n_train = int(n_total * train_ratio)
     n_val = int(n_total * val_ratio)
     n_test = n_total - n_train - n_val

     train_ds, val_ds, test_ds = random_split(
         dataset,
         [n_train, n_val, n_test],
-        generator=torch.Generator().manual_seed(seed),
+        generator=torch.Generator().manual_seed(seed_split),
     )

     if args.split == "train":
@@ -203,14 +220,19 @@ def main() -> None:
     head.to(device)
     criterion.to(device)

-    eval_cfg = EvalConfig(
+    age_edges = _parse_age_bin_edges(args.age_bin_edges)
+    age_bins = _edges_to_bins(age_edges)
+
+    eval_cfg = EvalAgeConfig(
         horizons_years=_parse_floats(args.horizons),
-        offset_years=float(args.offset_years),
+        age_bins=age_bins,
         topk_percents=[float(x) for x in args.topk_percent],
+        n_mc=int(args.n_mc),
+        seed=int(args.seed),
         cause_ids=None,
     )

-    df = evaluate_time_dependent(
+    df_by_bin, df_agg = evaluate_time_dependent_age_bins(
         model=model,
         head=head,
         criterion=criterion,
@@ -220,14 +242,20 @@ def main() -> None:
         device=device,
     )

-    if args.out_csv is None:
-        out_csv = os.path.join(
-            args.run_dir, f"time_dependent_metrics_{args.split}.csv")
+    if args.out_prefix is None:
+        out_prefix = os.path.join(
+            args.run_dir, f"age_bin_time_dependent_{args.split}")
     else:
-        out_csv = args.out_csv
+        out_prefix = args.out_prefix

-    df.to_csv(out_csv, index=False)
-    print(f"Wrote: {out_csv}")
+    out_bin = out_prefix + "_by_bin.csv"
+    out_agg = out_prefix + "_agg.csv"
+
+    df_by_bin.to_csv(out_bin, index=False)
+    df_agg.to_csv(out_agg, index=False)
+
+    print(f"Wrote: {out_bin}")
+    print(f"Wrote: {out_agg}")


 if __name__ == "__main__":
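With the defaults above, a run directory gains two CSVs, e.g. `age_bin_time_dependent_val_by_bin.csv` and `age_bin_time_dependent_val_agg.csv`. A minimal sketch of downstream use, assuming a hypothetical run directory `runs/example` (column names follow the new module below):

    import pandas as pd

    agg = pd.read_csv("runs/example/age_bin_time_dependent_val_agg.csv")
    # One row per (agg_type, horizon_tau, topk_percent, cause_id); pick the
    # sample-weighted aggregation and summarize mean AUC per horizon.
    weighted = agg[agg["agg_type"] == "weighted"]
    print(weighted.groupby("horizon_tau")["auc_mean"].mean())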
evaluation_age_time_dependent.py (new file, +469)
@@ -0,0 +1,469 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+
+try:
+    from tqdm import tqdm
+except Exception:  # pragma: no cover
+
+    def tqdm(x, **kwargs):
+        return x
+
+from utils import (
+    multi_hot_ever_within_horizon,
+    multi_hot_selected_causes_within_horizon,
+    sample_context_in_fixed_age_bin,
+)
+
+
+def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    """Compute ROC AUC for binary labels with tie-aware ranking.
+
+    Returns NaN if y_true has no positives or no negatives.
+
+    Uses the Mann–Whitney U statistic with average ranks for ties.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan")
+
+    n_pos = int(y_true.sum())
+    n_neg = n - n_pos
+    if n_pos == 0 or n_neg == 0:
+        return float("nan")
+
+    # Rank scores ascending, average ranks for ties.
+    order = np.argsort(y_score, kind="mergesort")
+    sorted_scores = y_score[order]
+
+    ranks = np.empty(n, dtype=float)
+    i = 0
+    # 1-based ranks
+    while i < n:
+        j = i + 1
+        while j < n and sorted_scores[j] == sorted_scores[i]:
+            j += 1
+        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
+        ranks[order[i:j]] = avg_rank
+        i = j
+
+    sum_ranks_pos = float(ranks[y_true].sum())
+    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
+    return float(u / (n_pos * n_neg))
+
+
+def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    """Average precision (area under PR curve using step-wise interpolation).
+
+    Returns NaN if no positives.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan")
+
+    n_pos = int(y_true.sum())
+    if n_pos == 0:
+        return float("nan")
+
+    order = np.argsort(-y_score, kind="mergesort")
+    y = y_true[order]
+
+    tp = np.cumsum(y).astype(float)
+    fp = np.cumsum(~y).astype(float)
+    precision = tp / np.maximum(tp + fp, 1.0)
+
+    # AP = sum over each positive of precision at that point / n_pos
+    # (equivalent to ∑ Δrecall * precision)
+    ap = float(np.sum(precision[y]) / n_pos)
+    # handle potential tiny numerical overshoots
+    return float(max(0.0, min(1.0, ap)))
+
+
+def _precision_recall_at_k_percent(
+    y_true: np.ndarray,
+    y_score: np.ndarray,
+    k_percent: float,
+) -> Tuple[float, float]:
+    """Precision@K% and Recall@K% for binary labels.
+
+    Returns (precision, recall). Returns NaN for recall if no positives.
+    Returns NaN for precision if k leads to 0 selected.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan"), float("nan")
+
+    n_pos = int(y_true.sum())
+
+    k = int(math.ceil((float(k_percent) / 100.0) * n))
+    if k <= 0:
+        return float("nan"), float("nan")
+
+    order = np.argsort(-y_score, kind="mergesort")
+    top = order[:k]
+    tp_top = int(y_true[top].sum())
+
+    precision = tp_top / k
+    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
+    return float(precision), float(recall)
+
+
+@dataclass
+class EvalAgeConfig:
+    horizons_years: Sequence[float]
+    age_bins: Sequence[Tuple[float, float]]
+    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
+    n_mc: int = 5
+    seed: int = 0
+    cause_ids: Optional[Sequence[int]] = None
+
+
+@torch.no_grad()
+def evaluate_time_dependent_age_bins(
+    model: torch.nn.Module,
+    head: torch.nn.Module,
+    criterion,
+    dataloader: torch.utils.data.DataLoader,
+    n_disease: int,
+    cfg: EvalAgeConfig,
+    device: str | torch.device,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Delphi-2M-style age-bin evaluation with strict horizon alignment.
+
+    Semantics (strict): for each (MC, horizon tau, age bin) we independently:
+    - build the eligible token set within that bin
+    - enforce follow-up coverage: t_ctx + tau <= t_end
+    - randomly sample exactly one token per individual within the bin (de-dup)
+    - recompute context representations and predictions for that (tau, bin)
+
+    Returns:
+        df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id)
+        df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted}
+    """
+    device = torch.device(device)
+    model.eval()
+    head.eval()
+
+    horizons_years = [float(x) for x in cfg.horizons_years]
+    if len(horizons_years) == 0:
+        raise ValueError("cfg.horizons_years must be non-empty")
+
+    age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
+    if len(age_bins) == 0:
+        raise ValueError("cfg.age_bins must be non-empty")
+    for (a, b) in age_bins:
+        if not (b > a):
+            raise ValueError(
+                f"age_bins must be (low, high) with high>low; got {(a, b)}")
+
+    topk_percents = [float(x) for x in cfg.topk_percents]
+    if len(topk_percents) == 0:
+        raise ValueError("cfg.topk_percents must be non-empty")
+    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
+        raise ValueError(
+            f"All topk_percents must be in (0,100]; got {topk_percents}")
+
+    if int(cfg.n_mc) <= 0:
+        raise ValueError("cfg.n_mc must be >= 1")
+
+    if cfg.cause_ids is None:
+        cause_ids = None
+        n_causes_eval = int(n_disease)
+    else:
+        cause_ids = torch.tensor(
+            list(cfg.cause_ids), dtype=torch.long, device=device)
+        n_causes_eval = int(cause_ids.numel())
+
+    # Storage: (mc, h, bin) -> list of arrays
+    y_true: List[List[List[List[np.ndarray]]]] = [
+        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
+        for _ in range(int(cfg.n_mc))
+    ]
+    y_pred: List[List[List[List[np.ndarray]]]] = [
+        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
+        for _ in range(int(cfg.n_mc))
+    ]
+
+    for mc_idx in range(int(cfg.n_mc)):
+        # tqdm over batches; include MC idx in description.
+        for batch_idx, batch in enumerate(
+            tqdm(dataloader,
+                 desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
+        ):
+            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
+            event_seq = event_seq.to(device)
+            time_seq = time_seq.to(device)
+            cont_feats = cont_feats.to(device)
+            cate_feats = cate_feats.to(device)
+            sexes = sexes.to(device)
+
+            B = int(event_seq.size(0))
+            b = torch.arange(B, device=device)
+
+            for tau_idx, tau_y in enumerate(horizons_years):
+                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
+                    # Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
+                    seed = (
+                        int(cfg.seed)
+                        + (100_000 * int(mc_idx))
+                        + (1_000 * int(tau_idx))
+                        + (10 * int(bin_idx))
+                        + int(batch_idx)
+                    )
+
+                    keep, t_ctx = sample_context_in_fixed_age_bin(
+                        event_seq=event_seq,
+                        time_seq=time_seq,
+                        tau_years=float(tau_y),
+                        age_bin=(float(a_lo), float(a_hi)),
+                        seed=seed,
+                    )
+                    if not keep.any():
+                        continue
+
+                    # Strict bin-specific prediction: recompute representations and logits per (tau, bin).
+                    h = model(event_seq, time_seq, sexes,
+                              cont_feats, cate_feats)  # (B,L,D)
+                    c = h[b, t_ctx]
+                    logits = head(c)
+
+                    cifs = criterion.calculate_cifs(
+                        logits, taus=torch.tensor(float(tau_y), device=device)
+                    )
+                    if cifs.ndim != 2:
+                        raise ValueError(
+                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
+                            f"got shape={tuple(cifs.shape)}"
+                        )
+
+                    if cause_ids is None:
+                        y = multi_hot_ever_within_horizon(
+                            event_seq=event_seq,
+                            time_seq=time_seq,
+                            t_ctx=t_ctx,
+                            tau_years=float(tau_y),
+                            n_disease=n_disease,
+                        )
+                        preds = cifs
+                    else:
+                        y = multi_hot_selected_causes_within_horizon(
+                            event_seq=event_seq,
+                            time_seq=time_seq,
+                            t_ctx=t_ctx,
+                            tau_years=float(tau_y),
+                            cause_ids=cause_ids,
+                            n_disease=n_disease,
+                        )
+                        preds = cifs.index_select(dim=1, index=cause_ids)
+
+                    y_true[mc_idx][tau_idx][bin_idx].append(
+                        y[keep].detach().to(torch.bool).cpu().numpy()
+                    )
+                    y_pred[mc_idx][tau_idx][bin_idx].append(
+                        preds[keep].detach().to(torch.float32).cpu().numpy()
+                    )
+
+    rows_by_bin: List[Dict[str, float | int]] = []
+
+    for mc_idx in range(int(cfg.n_mc)):
+        for h_idx, tau_y in enumerate(horizons_years):
+            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
+                if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
+                    # No samples in this bin for this (mc, tau)
+                    for cause_k in range(n_causes_eval):
+                        cause_id = int(cause_k) if cause_ids is None else int(
+                            cfg.cause_ids[cause_k])
+                        for k_percent in topk_percents:
+                            rows_by_bin.append(
+                                dict(
+                                    mc_idx=mc_idx,
+                                    age_bin_id=bin_idx,
+                                    age_bin_low=float(a_lo),
+                                    age_bin_high=float(a_hi),
+                                    horizon_tau=float(tau_y),
+                                    topk_percent=float(k_percent),
+                                    cause_id=cause_id,
+                                    n_samples=0,
+                                    n_positives=0,
+                                    auc=float("nan"),
+                                    auprc=float("nan"),
+                                    recall_at_K=float("nan"),
+                                    precision_at_K=float("nan"),
+                                    brier_score=float("nan"),
+                                )
+                            )
+                    continue
+
+                yb = np.concatenate(y_true[mc_idx][h_idx][bin_idx], axis=0)
+                pb = np.concatenate(y_pred[mc_idx][h_idx][bin_idx], axis=0)
+                if yb.shape != pb.shape:
+                    raise ValueError(
+                        f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
+                    )
+
+                n_samples = int(yb.shape[0])
+
+                for cause_k in range(n_causes_eval):
+                    yk = yb[:, cause_k]
+                    pk = pb[:, cause_k]
+                    n_pos = int(yk.sum())
+
+                    auc = _binary_roc_auc(yk, pk)
+                    auprc = _average_precision(yk, pk)
+                    brier = float(np.mean(
+                        (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan")
+
+                    cause_id = int(cause_k) if cause_ids is None else int(
+                        cfg.cause_ids[cause_k])
+
+                    for k_percent in topk_percents:
+                        precision_k, recall_k = _precision_recall_at_k_percent(
+                            yk, pk, float(k_percent))
+                        rows_by_bin.append(
+                            dict(
+                                mc_idx=mc_idx,
+                                age_bin_id=bin_idx,
+                                age_bin_low=float(a_lo),
+                                age_bin_high=float(a_hi),
+                                horizon_tau=float(tau_y),
+                                topk_percent=float(k_percent),
+                                cause_id=cause_id,
+                                n_samples=n_samples,
+                                n_positives=n_pos,
+                                auc=float(auc),
+                                auprc=float(auprc),
+                                recall_at_K=float(recall_k),
+                                precision_at_K=float(precision_k),
+                                brier_score=float(brier),
+                            )
+                        )
+
+    df_by_bin = pd.DataFrame(rows_by_bin)
+
+    def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
+        g = group[group["n_samples"] > 0]
+        if len(g) == 0:
+            return pd.Series(
+                dict(
+                    n_bins_used=0,
+                    n_samples_total=0,
+                    n_positives_total=0,
+                    auc=float("nan"),
+                    auprc=float("nan"),
+                    recall_at_K=float("nan"),
+                    precision_at_K=float("nan"),
+                    brier_score=float("nan"),
+                )
+            )
+
+        n_bins_used = int(g["age_bin_id"].nunique())
+        n_samples_total = int(g["n_samples"].sum())
+        n_positives_total = int(g["n_positives"].sum())
+
+        if not weighted:
+            return pd.Series(
+                dict(
+                    n_bins_used=n_bins_used,
+                    n_samples_total=n_samples_total,
+                    n_positives_total=n_positives_total,
+                    auc=float(g["auc"].mean()),
+                    auprc=float(g["auprc"].mean()),
+                    recall_at_K=float(g["recall_at_K"].mean()),
+                    precision_at_K=float(g["precision_at_K"].mean()),
+                    brier_score=float(g["brier_score"].mean()),
+                )
+            )
+
+        w = g["n_samples"].to_numpy(dtype=float)
+        w_sum = float(w.sum())
+        if w_sum <= 0.0:
+            return pd.Series(
+                dict(
+                    n_bins_used=n_bins_used,
+                    n_samples_total=n_samples_total,
+                    n_positives_total=n_positives_total,
+                    auc=float("nan"),
+                    auprc=float("nan"),
+                    recall_at_K=float("nan"),
+                    precision_at_K=float("nan"),
+                    brier_score=float("nan"),
+                )
+            )
+
+        def _wavg(col: str) -> float:
+            return float(np.average(g[col].to_numpy(dtype=float), weights=w))
+
+        return pd.Series(
+            dict(
+                n_bins_used=n_bins_used,
+                n_samples_total=n_samples_total,
+                n_positives_total=n_positives_total,
+                auc=_wavg("auc"),
+                auprc=_wavg("auprc"),
+                recall_at_K=_wavg("recall_at_K"),
+                precision_at_K=_wavg("precision_at_K"),
+                brier_score=_wavg("brier_score"),
+            )
+        )
+
+    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
+
+    df_mc_macro = (
+        df_by_bin.groupby(group_keys)
+        .apply(lambda g: _bin_aggregate(g, weighted=False))
+        .reset_index()
+    )
+    df_mc_macro["agg_type"] = "macro"
+
+    df_mc_weighted = (
+        df_by_bin.groupby(group_keys)
+        .apply(lambda g: _bin_aggregate(g, weighted=True))
+        .reset_index()
+    )
+    df_mc_weighted["agg_type"] = "weighted"
+
+    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
+
+    # Then average over MC repetitions.
+    df_agg = (
+        df_mc_binagg.groupby(
+            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
+        )
+        .agg(
+            n_mc=("mc_idx", "nunique"),
+            n_bins_used_mean=("n_bins_used", "mean"),
+            n_samples_total_mean=("n_samples_total", "mean"),
+            n_positives_total_mean=("n_positives_total", "mean"),
+            auc_mean=("auc", "mean"),
+            auc_std=("auc", "std"),
+            auprc_mean=("auprc", "mean"),
+            auprc_std=("auprc", "std"),
+            recall_at_K_mean=("recall_at_K", "mean"),
+            recall_at_K_std=("recall_at_K", "std"),
+            precision_at_K_mean=("precision_at_K", "mean"),
+            precision_at_K_std=("precision_at_K", "std"),
+            brier_score_mean=("brier_score", "mean"),
+            brier_score_std=("brier_score", "std"),
+        )
+        .sort_values(
+            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
+            ignore_index=True,
+        )
+    )
+
+    return df_by_bin, df_agg
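The hand-rolled ranking metrics above admit a quick hand check; a minimal sketch with illustrative values:

    import numpy as np
    from evaluation_age_time_dependent import _average_precision, _binary_roc_auc

    y = np.array([0, 0, 1, 1])
    s = np.array([0.1, 0.4, 0.35, 0.8])

    # 3 of the 4 (negative, positive) pairs are ordered correctly -> AUC = 0.75.
    assert abs(_binary_roc_auc(y, s) - 0.75) < 1e-12

    # By descending score the positives land at ranks 1 and 3, so
    # AP = (1/1 + 2/3) / 2 = 5/6.
    assert abs(_average_precision(y, s) - 5.0 / 6.0) < 1e-12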
evaluation_time_dependent.py (deleted, -322)
@@ -1,322 +0,0 @@
-from __future__ import annotations
-
-import math
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
-
-import numpy as np
-import pandas as pd
-import torch
-
-from utils import (
-    DAYS_PER_YEAR,
-    multi_hot_ever_within_horizon,
-    multi_hot_selected_causes_within_horizon,
-    select_context_indices,
-)
-
-
-def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
-    """Compute ROC AUC for binary labels with tie-aware ranking.
-
-    Returns NaN if y_true has no positives or no negatives.
-
-    Uses the Mann–Whitney U statistic with average ranks for ties.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-
-    n = y_true.size
-    if n == 0:
-        return float("nan")
-
-    n_pos = int(y_true.sum())
-    n_neg = n - n_pos
-    if n_pos == 0 or n_neg == 0:
-        return float("nan")
-
-    # Rank scores ascending, average ranks for ties.
-    order = np.argsort(y_score, kind="mergesort")
-    sorted_scores = y_score[order]
-
-    ranks = np.empty(n, dtype=float)
-    i = 0
-    # 1-based ranks
-    while i < n:
-        j = i + 1
-        while j < n and sorted_scores[j] == sorted_scores[i]:
-            j += 1
-        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
-        ranks[order[i:j]] = avg_rank
-        i = j
-
-    sum_ranks_pos = float(ranks[y_true].sum())
-    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
-    return float(u / (n_pos * n_neg))
-
-
-def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
-    """Average precision (area under PR curve using step-wise interpolation).
-
-    Returns NaN if no positives.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-
-    n = y_true.size
-    if n == 0:
-        return float("nan")
-
-    n_pos = int(y_true.sum())
-    if n_pos == 0:
-        return float("nan")
-
-    order = np.argsort(-y_score, kind="mergesort")
-    y = y_true[order]
-
-    tp = np.cumsum(y).astype(float)
-    fp = np.cumsum(~y).astype(float)
-    precision = tp / np.maximum(tp + fp, 1.0)
-    recall = tp / n_pos
-
-    # AP = sum over each positive of precision at that point / n_pos
-    # (equivalent to ∑ Δrecall * precision)
-    ap = float(np.sum(precision[y]) / n_pos)
-    # handle potential tiny numerical overshoots
-    return float(max(0.0, min(1.0, ap)))
-
-
-def _precision_recall_at_k_percent(
-    y_true: np.ndarray,
-    y_score: np.ndarray,
-    k_percent: float,
-) -> Tuple[float, float]:
-    """Precision@K% and Recall@K% for binary labels.
-
-    Returns (precision, recall). Returns NaN for recall if no positives.
-    Returns NaN for precision if k leads to 0 selected.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-
-    n = y_true.size
-    if n == 0:
-        return float("nan"), float("nan")
-
-    n_pos = int(y_true.sum())
-
-    k = int(math.ceil((float(k_percent) / 100.0) * n))
-    if k <= 0:
-        return float("nan"), float("nan")
-
-    order = np.argsort(-y_score, kind="mergesort")
-    top = order[:k]
-    tp_top = int(y_true[top].sum())
-
-    precision = tp_top / k
-    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
-    return float(precision), float(recall)
-
-
-@dataclass
-class EvalConfig:
-    horizons_years: Sequence[float]
-    offset_years: float = 0.0
-    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
-    cause_ids: Optional[Sequence[int]] = None
-
-
-@torch.no_grad()
-def evaluate_time_dependent(
-    model: torch.nn.Module,
-    head: torch.nn.Module,
-    criterion,
-    dataloader: torch.utils.data.DataLoader,
-    n_disease: int,
-    cfg: EvalConfig,
-    device: str | torch.device,
-) -> pd.DataFrame:
-    """Evaluate time-dependent metrics per cause and per horizon.
-
-    Assumptions:
-    - time_seq is in days
-    - horizons_years and the loss CIF times are in years
-    - disease token ids in event_seq are >= 2 and map to cause_id = token_id - 2
-
-    Returns:
-        DataFrame with columns:
-        cause_id, horizon_tau, topk_percent, n_samples, n_positives, auc, auprc,
-        recall_at_K, precision_at_K, brier_score
-    """
-    device = torch.device(device)
-    model.eval()
-    head.eval()
-
-    horizons_years = [float(x) for x in cfg.horizons_years]
-    if len(horizons_years) == 0:
-        raise ValueError("cfg.horizons_years must be non-empty")
-
-    topk_percents = [float(x) for x in cfg.topk_percents]
-    if len(topk_percents) == 0:
-        raise ValueError("cfg.topk_percents must be non-empty")
-    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
-        raise ValueError(
-            f"All topk_percents must be in (0,100]; got {topk_percents}")
-
-    taus_tensor = torch.tensor(
-        horizons_years, device=device, dtype=torch.float32)
-
-    if cfg.cause_ids is None:
-        cause_ids = None
-        n_causes_eval = int(n_disease)
-    else:
-        cause_ids = torch.tensor(
-            list(cfg.cause_ids), dtype=torch.long, device=device)
-        n_causes_eval = int(cause_ids.numel())
-
-    # Accumulate per horizon
-    y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
-    y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
-
-    for batch in dataloader:
-        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
-        event_seq = event_seq.to(device)
-        time_seq = time_seq.to(device)
-        cont_feats = cont_feats.to(device)
-        cate_feats = cate_feats.to(device)
-        sexes = sexes.to(device)
-
-        h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)  # (B,L,D)
-
-        # Select a single fixed context per sample for this batch.
-        # Horizon-specific eligibility is derived from this context (do not re-select per horizon).
-        keep0, t_ctx, t_ctx_time = select_context_indices(
-            event_seq=event_seq,
-            time_seq=time_seq,
-            offset_years=float(cfg.offset_years),
-            tau_years=0.0,
-        )
-
-        if not keep0.any():
-            continue
-
-        b = torch.arange(event_seq.size(0), device=device)
-        c = h[b, t_ctx]  # (B,D)
-        logits = head(c)
-
-        # CIFs for all horizons at once
-        cifs_all = criterion.calculate_cifs(
-            logits, taus=taus_tensor)  # (B,K,T) or (B,K)
-        if cifs_all.ndim != 3:
-            raise ValueError(
-                f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
-            )
-
-        # Follow-up end time per sample = time at last valid token.
-        valid = event_seq != 0
-        lengths = valid.sum(dim=1)
-        last_idx = torch.clamp(lengths - 1, min=0)
-        followup_end_time = time_seq[b, last_idx]
-
-        for h_idx, tau_y in enumerate(horizons_years):
-            # Horizon-specific eligibility without reselecting context:
-            # keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau)
-            keep_tau = keep0 & (
-                followup_end_time >= (
-                    t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
-            )
-            if not keep_tau.any():
-                continue
-
-            if cause_ids is None:
-                y = multi_hot_ever_within_horizon(
-                    event_seq=event_seq,
-                    time_seq=time_seq,
-                    t_ctx=t_ctx,
-                    tau_years=float(tau_y),
-                    n_disease=n_disease,
-                )
-                y = y[keep_tau]
-                preds = cifs_all[keep_tau, :, h_idx]
-            else:
-                y = multi_hot_selected_causes_within_horizon(
-                    event_seq=event_seq,
-                    time_seq=time_seq,
-                    t_ctx=t_ctx,
-                    tau_years=float(tau_y),
-                    cause_ids=cause_ids,
-                    n_disease=n_disease,
-                )
-                y = y[keep_tau]
-                preds = cifs_all[keep_tau, :, h_idx].index_select(
-                    dim=1, index=cause_ids)
-
-            y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
-            y_pred_by_h[h_idx].append(
-                preds.detach().to(torch.float32).cpu().numpy())
-
-    rows: List[Dict[str, float | int]] = []
-
-    for h_idx, tau_y in enumerate(horizons_years):
-        if len(y_true_by_h[h_idx]) == 0:
-            # No eligible samples for this horizon.
-            for k in range(n_causes_eval):
-                cause_id = int(k) if cause_ids is None else int(
-                    cfg.cause_ids[k])
-                for k_percent in topk_percents:
-                    rows.append(
-                        dict(
-                            cause_id=cause_id,
-                            horizon_tau=float(tau_y),
-                            topk_percent=float(k_percent),
-                            n_samples=0,
-                            n_positives=0,
-                            auc=float("nan"),
-                            auprc=float("nan"),
-                            recall_at_K=float("nan"),
-                            precision_at_K=float("nan"),
-                            brier_score=float("nan"),
-                        )
-                    )
-            continue
-
-        y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
-        y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
-
-        if y_true.shape != y_pred.shape:
-            raise ValueError(
-                f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}"
-            )
-
-        n_samples = int(y_true.shape[0])
-
-        for k in range(n_causes_eval):
-            yk = y_true[:, k]
-            pk = y_pred[:, k]
-            n_pos = int(yk.sum())
-
-            auc = _binary_roc_auc(yk, pk)
-            auprc = _average_precision(yk, pk)
-            brier = float(np.mean((yk.astype(float) - pk.astype(float))
-                          ** 2)) if n_samples > 0 else float("nan")
-
-            cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
-            for k_percent in topk_percents:
-                precision_k, recall_k = _precision_recall_at_k_percent(
-                    yk, pk, float(k_percent))
-                rows.append(
-                    dict(
-                        cause_id=cause_id,
-                        horizon_tau=float(tau_y),
-                        topk_percent=float(k_percent),
-                        n_samples=n_samples,
-                        n_positives=n_pos,
-                        auc=float(auc),
-                        auprc=float(auprc),
-                        recall_at_K=float(recall_k),
-                        precision_at_K=float(precision_k),
-                        brier_score=float(brier),
-                    )
-                )
-
-    return pd.DataFrame(rows)
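The key semantic shift from this removed evaluator: it fixed one context per individual and only masked eligibility per horizon, whereas the new module resamples an in-bin context for every (MC, tau, bin). A condensed sketch of the old eligibility rule, with illustrative numbers:

    import torch

    DAYS_PER_YEAR = 365.25
    t_end = torch.tensor([70.0, 62.0]) * DAYS_PER_YEAR       # follow-up ends (days)
    t_ctx_time = torch.tensor([58.0, 58.0]) * DAYS_PER_YEAR  # fixed contexts (days)
    tau_days = 5.0 * DAYS_PER_YEAR

    keep_tau = t_end >= t_ctx_time + tau_days  # -> tensor([True, False])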
utils.py (+73)
@@ -4,6 +4,79 @@ from typing import Tuple
 DAYS_PER_YEAR = 365.25


+def sample_context_in_fixed_age_bin(
+    event_seq: torch.Tensor,
+    time_seq: torch.Tensor,
+    tau_years: float,
+    age_bin: Tuple[float, float],
+    seed: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Sample one context token per individual within a fixed age bin.
+
+    Delphi-2M semantics for a specific (tau, age_bin):
+    - Token times are interpreted as age in *days* (converted to years).
+    - Follow-up end time is the last valid token time per individual.
+    - A token index j is eligible iff:
+        (token is valid)
+        AND (age_years in [age_low, age_high))
+        AND (time_seq[i, j] + tau_days <= followup_end_time[i])
+    - For each individual, randomly select exactly one eligible token in this bin.
+
+    Args:
+        event_seq: (B, L) token ids, 0 is padding.
+        time_seq: (B, L) token times in days.
+        tau_years: horizon length in years.
+        age_bin: (low, high) bounds in years, interpreted as [low, high).
+        seed: RNG seed for deterministic sampling.
+
+    Returns:
+        keep: (B,) bool, True if a context was sampled for this bin.
+        t_ctx: (B,) long, sampled context index (undefined when keep=False; set to 0).
+    """
+    low, high = float(age_bin[0]), float(age_bin[1])
+    if not (high > low):
+        raise ValueError(f"age_bin must satisfy high>low; got {(low, high)}")
+
+    device = event_seq.device
+    B, _ = event_seq.shape
+
+    valid = event_seq != 0
+    lengths = valid.sum(dim=1)
+    last_idx = torch.clamp(lengths - 1, min=0)
+    b = torch.arange(B, device=device)
+    followup_end_time = time_seq[b, last_idx]  # (B,)
+
+    tau_days = float(tau_years) * DAYS_PER_YEAR
+    age_years = time_seq / DAYS_PER_YEAR
+
+    in_bin = (age_years >= low) & (age_years < high)
+    eligible = valid & in_bin & (
+        (time_seq + tau_days) <= followup_end_time.unsqueeze(1))
+
+    keep = torch.zeros((B,), dtype=torch.bool, device=device)
+    t_ctx = torch.zeros((B,), dtype=torch.long, device=device)
+
+    gen = torch.Generator(device="cpu")
+    gen.manual_seed(int(seed))
+
+    for i in range(B):
+        m = eligible[i]
+        if not m.any():
+            continue
+
+        idxs = m.nonzero(as_tuple=False).view(-1).cpu()
+        chosen_idx_pos = int(
+            torch.randint(low=0, high=int(idxs.numel()),
+                          size=(1,), generator=gen).item()
+        )
+        chosen_t = int(idxs[chosen_idx_pos].item())
+
+        keep[i] = True
+        t_ctx[i] = chosen_t
+
+    return keep, t_ctx
+
+
 def select_context_indices(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,
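A toy check of the sampling semantics (illustrative tensors): one individual with tokens at ages 50, 60, and 70 years. For bin [55, 65) and tau = 5, only the age-60 token is in-bin, and 60 + 5 <= 70 keeps it inside follow-up, so it is the forced choice:

    import torch
    from utils import DAYS_PER_YEAR, sample_context_in_fixed_age_bin

    event_seq = torch.tensor([[3, 4, 5]])  # no padding tokens
    time_seq = torch.tensor([[50.0, 60.0, 70.0]]) * DAYS_PER_YEAR

    keep, t_ctx = sample_context_in_fixed_age_bin(
        event_seq=event_seq, time_seq=time_seq,
        tau_years=5.0, age_bin=(55.0, 65.0), seed=0,
    )
    assert bool(keep[0]) and int(t_ctx[0]) == 1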