Add evaluation and utility functions for time-dependent metrics
- Introduced `evaluate.py` for time-dependent evaluation of models, including data loading and model inference. - Added `evaluation_time_dependent.py` to compute various evaluation metrics such as AUC, average precision, and precision/recall at specified thresholds. - Implemented CIF calculation methods in `losses.py` for different loss types, including exponential and piecewise exponential models. - Created utility functions in `utils.py` for context selection and multi-hot encoding of events within specified horizons.
This commit is contained in:
234
evaluate.py
Normal file
234
evaluate.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from typing import List, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, random_split
|
||||||
|
|
||||||
|
from dataset import HealthDataset, health_collate_fn
|
||||||
|
from evaluation_time_dependent import EvalConfig, evaluate_time_dependent
|
||||||
|
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
|
||||||
|
from model import DelphiFork, SapDelphi, SimpleHead
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_floats(items: Sequence[str]) -> List[float]:
|
||||||
|
out: List[float] = []
|
||||||
|
for x in items:
|
||||||
|
x = x.strip()
|
||||||
|
if not x:
|
||||||
|
continue
|
||||||
|
out.append(float(x))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
|
||||||
|
if loss_type == "exponential":
|
||||||
|
criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
|
||||||
|
out_dims = [n_disease]
|
||||||
|
return criterion, out_dims
|
||||||
|
|
||||||
|
if loss_type == "discrete_time_cif":
|
||||||
|
criterion = DiscreteTimeCIFNLLLoss(
|
||||||
|
bin_edges=bin_edges, lambda_reg=lambda_reg)
|
||||||
|
out_dims = [n_disease + 1, len(bin_edges)]
|
||||||
|
return criterion, out_dims
|
||||||
|
|
||||||
|
if loss_type == "pwe_cif":
|
||||||
|
pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
|
||||||
|
if len(pwe_edges) < 2:
|
||||||
|
raise ValueError(
|
||||||
|
"pwe_cif requires at least 2 finite bin edges (including 0)")
|
||||||
|
if float(pwe_edges[0]) != 0.0:
|
||||||
|
raise ValueError("pwe_cif requires bin_edges[0]==0.0")
|
||||||
|
criterion = PiecewiseExponentialCIFNLLLoss(
|
||||||
|
bin_edges=pwe_edges, lambda_reg=lambda_reg)
|
||||||
|
n_bins = len(pwe_edges) - 1
|
||||||
|
out_dims = [n_disease, n_bins]
|
||||||
|
return criterion, out_dims
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported loss_type: {loss_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
|
||||||
|
if model_type == "delphi_fork":
|
||||||
|
return DelphiFork(
|
||||||
|
n_disease=dataset.n_disease,
|
||||||
|
n_tech_tokens=2,
|
||||||
|
n_embd=int(cfg["n_embd"]),
|
||||||
|
n_head=int(cfg["n_head"]),
|
||||||
|
n_layer=int(cfg["n_layer"]),
|
||||||
|
pdrop=float(cfg.get("pdrop", 0.0)),
|
||||||
|
age_encoder_type=str(cfg["age_encoder"]),
|
||||||
|
n_cont=int(dataset.n_cont),
|
||||||
|
n_cate=int(dataset.n_cate),
|
||||||
|
cate_dims=list(dataset.cate_dims),
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type == "sap_delphi":
|
||||||
|
return SapDelphi(
|
||||||
|
n_disease=dataset.n_disease,
|
||||||
|
n_tech_tokens=2,
|
||||||
|
n_embd=int(cfg["n_embd"]),
|
||||||
|
n_head=int(cfg["n_head"]),
|
||||||
|
n_layer=int(cfg["n_layer"]),
|
||||||
|
pdrop=float(cfg.get("pdrop", 0.0)),
|
||||||
|
age_encoder_type=str(cfg["age_encoder"]),
|
||||||
|
n_cont=int(dataset.n_cont),
|
||||||
|
n_cate=int(dataset.n_cate),
|
||||||
|
cate_dims=list(dataset.cate_dims),
|
||||||
|
pretrained_weights_path=str(
|
||||||
|
cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
|
||||||
|
freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported model_type: {model_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Time-dependent evaluation for DeepHealth")
|
||||||
|
parser.add_argument(
|
||||||
|
"--run_dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Training run directory (contains best_model.pt and train_config.json)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--data_prefix", type=str, default=None,
|
||||||
|
help="Dataset prefix (overrides config if provided)")
|
||||||
|
parser.add_argument("--split", type=str,
|
||||||
|
choices=["train", "val", "test", "all"], default="val")
|
||||||
|
|
||||||
|
parser.add_argument("--horizons", type=str, nargs="+",
|
||||||
|
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="One or more horizons (years)")
|
||||||
|
parser.add_argument("--offset_years", type=float, default=0.0,
|
||||||
|
help="Context selection offset (years before follow-up end)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--topk_percent",
|
||||||
|
type=float,
|
||||||
|
nargs="+",
|
||||||
|
default=[1, 5, 10, 20, 50],
|
||||||
|
help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--device", type=str,
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=256)
|
||||||
|
parser.add_argument("--num_workers", type=int,
|
||||||
|
default=0, help="Keep 0 on Windows")
|
||||||
|
|
||||||
|
parser.add_argument("--out_csv", type=str, default=None,
|
||||||
|
help="Optional output CSV path")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ckpt_path = os.path.join(args.run_dir, "best_model.pt")
|
||||||
|
cfg_path = os.path.join(args.run_dir, "train_config.json")
|
||||||
|
|
||||||
|
if not os.path.exists(ckpt_path):
|
||||||
|
raise SystemExit(f"Missing checkpoint: {ckpt_path}")
|
||||||
|
if not os.path.exists(cfg_path):
|
||||||
|
raise SystemExit(f"Missing config: {cfg_path}")
|
||||||
|
|
||||||
|
with open(cfg_path, "r") as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
|
||||||
|
data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
|
||||||
|
"data_prefix", "ukb")
|
||||||
|
|
||||||
|
# Match training covariate selection.
|
||||||
|
full_cov = bool(cfg.get("full_cov", False))
|
||||||
|
cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
|
||||||
|
dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
|
||||||
|
|
||||||
|
# Recreate the same split scheme as train.py
|
||||||
|
train_ratio = float(cfg.get("train_ratio", 0.7))
|
||||||
|
val_ratio = float(cfg.get("val_ratio", 0.15))
|
||||||
|
seed = int(cfg.get("random_seed", 42))
|
||||||
|
|
||||||
|
n_total = len(dataset)
|
||||||
|
n_train = int(n_total * train_ratio)
|
||||||
|
n_val = int(n_total * val_ratio)
|
||||||
|
n_test = n_total - n_train - n_val
|
||||||
|
train_ds, val_ds, test_ds = random_split(
|
||||||
|
dataset,
|
||||||
|
[n_train, n_val, n_test],
|
||||||
|
generator=torch.Generator().manual_seed(seed),
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.split == "train":
|
||||||
|
ds = train_ds
|
||||||
|
elif args.split == "val":
|
||||||
|
ds = val_ds
|
||||||
|
elif args.split == "test":
|
||||||
|
ds = test_ds
|
||||||
|
else:
|
||||||
|
ds = dataset
|
||||||
|
|
||||||
|
loader = DataLoader(
|
||||||
|
ds,
|
||||||
|
batch_size=int(args.batch_size),
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=health_collate_fn,
|
||||||
|
num_workers=int(args.num_workers),
|
||||||
|
pin_memory=str(args.device).startswith("cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
criterion, out_dims = build_criterion_and_out_dims(
|
||||||
|
loss_type=str(cfg["loss_type"]),
|
||||||
|
n_disease=int(dataset.n_disease),
|
||||||
|
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
|
||||||
|
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
|
||||||
|
)
|
||||||
|
|
||||||
|
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
|
||||||
|
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
|
||||||
|
|
||||||
|
device = torch.device(args.device)
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||||
|
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
|
||||||
|
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
|
||||||
|
if "criterion_state_dict" in checkpoint:
|
||||||
|
try:
|
||||||
|
criterion.load_state_dict(
|
||||||
|
checkpoint["criterion_state_dict"], strict=False)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
head.to(device)
|
||||||
|
criterion.to(device)
|
||||||
|
|
||||||
|
eval_cfg = EvalConfig(
|
||||||
|
horizons_years=_parse_floats(args.horizons),
|
||||||
|
offset_years=float(args.offset_years),
|
||||||
|
topk_percents=[float(x) for x in args.topk_percent],
|
||||||
|
cause_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
df = evaluate_time_dependent(
|
||||||
|
model=model,
|
||||||
|
head=head,
|
||||||
|
criterion=criterion,
|
||||||
|
dataloader=loader,
|
||||||
|
n_disease=int(dataset.n_disease),
|
||||||
|
cfg=eval_cfg,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.out_csv is None:
|
||||||
|
out_csv = os.path.join(
|
||||||
|
args.run_dir, f"time_dependent_metrics_{args.split}.csv")
|
||||||
|
else:
|
||||||
|
out_csv = args.out_csv
|
||||||
|
|
||||||
|
df.to_csv(out_csv, index=False)
|
||||||
|
print(f"Wrote: {out_csv}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
316
evaluation_time_dependent.py
Normal file
316
evaluation_time_dependent.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from utils import (
|
||||||
|
DAYS_PER_YEAR,
|
||||||
|
multi_hot_ever_within_horizon,
|
||||||
|
multi_hot_selected_causes_within_horizon,
|
||||||
|
select_context_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
|
||||||
|
"""Compute ROC AUC for binary labels with tie-aware ranking.
|
||||||
|
|
||||||
|
Returns NaN if y_true has no positives or no negatives.
|
||||||
|
|
||||||
|
Uses the Mann–Whitney U statistic with average ranks for ties.
|
||||||
|
"""
|
||||||
|
y_true = np.asarray(y_true).astype(bool)
|
||||||
|
y_score = np.asarray(y_score).astype(float)
|
||||||
|
|
||||||
|
n = y_true.size
|
||||||
|
if n == 0:
|
||||||
|
return float("nan")
|
||||||
|
|
||||||
|
n_pos = int(y_true.sum())
|
||||||
|
n_neg = n - n_pos
|
||||||
|
if n_pos == 0 or n_neg == 0:
|
||||||
|
return float("nan")
|
||||||
|
|
||||||
|
# Rank scores ascending, average ranks for ties.
|
||||||
|
order = np.argsort(y_score, kind="mergesort")
|
||||||
|
sorted_scores = y_score[order]
|
||||||
|
|
||||||
|
ranks = np.empty(n, dtype=float)
|
||||||
|
i = 0
|
||||||
|
# 1-based ranks
|
||||||
|
while i < n:
|
||||||
|
j = i + 1
|
||||||
|
while j < n and sorted_scores[j] == sorted_scores[i]:
|
||||||
|
j += 1
|
||||||
|
avg_rank = 0.5 * ((i + 1) + j) # ranks i+1 .. j
|
||||||
|
ranks[order[i:j]] = avg_rank
|
||||||
|
i = j
|
||||||
|
|
||||||
|
sum_ranks_pos = float(ranks[y_true].sum())
|
||||||
|
u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
|
||||||
|
return float(u / (n_pos * n_neg))
|
||||||
|
|
||||||
|
|
||||||
|
def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
|
||||||
|
"""Average precision (area under PR curve using step-wise interpolation).
|
||||||
|
|
||||||
|
Returns NaN if no positives.
|
||||||
|
"""
|
||||||
|
y_true = np.asarray(y_true).astype(bool)
|
||||||
|
y_score = np.asarray(y_score).astype(float)
|
||||||
|
|
||||||
|
n = y_true.size
|
||||||
|
if n == 0:
|
||||||
|
return float("nan")
|
||||||
|
|
||||||
|
n_pos = int(y_true.sum())
|
||||||
|
if n_pos == 0:
|
||||||
|
return float("nan")
|
||||||
|
|
||||||
|
order = np.argsort(-y_score, kind="mergesort")
|
||||||
|
y = y_true[order]
|
||||||
|
|
||||||
|
tp = np.cumsum(y).astype(float)
|
||||||
|
fp = np.cumsum(~y).astype(float)
|
||||||
|
precision = tp / np.maximum(tp + fp, 1.0)
|
||||||
|
recall = tp / n_pos
|
||||||
|
|
||||||
|
# AP = sum over each positive of precision at that point / n_pos
|
||||||
|
# (equivalent to ∑ Δrecall * precision)
|
||||||
|
ap = float(np.sum(precision[y]) / n_pos)
|
||||||
|
# handle potential tiny numerical overshoots
|
||||||
|
return float(max(0.0, min(1.0, ap)))
|
||||||
|
|
||||||
|
|
||||||
|
def _precision_recall_at_k_percent(
|
||||||
|
y_true: np.ndarray,
|
||||||
|
y_score: np.ndarray,
|
||||||
|
k_percent: float,
|
||||||
|
) -> Tuple[float, float]:
|
||||||
|
"""Precision@K% and Recall@K% for binary labels.
|
||||||
|
|
||||||
|
Returns (precision, recall). Returns NaN for recall if no positives.
|
||||||
|
Returns NaN for precision if k leads to 0 selected.
|
||||||
|
"""
|
||||||
|
y_true = np.asarray(y_true).astype(bool)
|
||||||
|
y_score = np.asarray(y_score).astype(float)
|
||||||
|
|
||||||
|
n = y_true.size
|
||||||
|
if n == 0:
|
||||||
|
return float("nan"), float("nan")
|
||||||
|
|
||||||
|
n_pos = int(y_true.sum())
|
||||||
|
|
||||||
|
k = int(math.ceil((float(k_percent) / 100.0) * n))
|
||||||
|
if k <= 0:
|
||||||
|
return float("nan"), float("nan")
|
||||||
|
|
||||||
|
order = np.argsort(-y_score, kind="mergesort")
|
||||||
|
top = order[:k]
|
||||||
|
tp_top = int(y_true[top].sum())
|
||||||
|
|
||||||
|
precision = tp_top / k
|
||||||
|
recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
|
||||||
|
return float(precision), float(recall)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvalConfig:
|
||||||
|
horizons_years: Sequence[float]
|
||||||
|
offset_years: float = 0.0
|
||||||
|
topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
|
||||||
|
cause_ids: Optional[Sequence[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evaluate_time_dependent(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
head: torch.nn.Module,
|
||||||
|
criterion,
|
||||||
|
dataloader: torch.utils.data.DataLoader,
|
||||||
|
n_disease: int,
|
||||||
|
cfg: EvalConfig,
|
||||||
|
device: str | torch.device,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Evaluate time-dependent metrics per cause and per horizon.
|
||||||
|
|
||||||
|
Assumptions:
|
||||||
|
- time_seq is in days
|
||||||
|
- horizons_years and the loss CIF times are in years
|
||||||
|
- disease token ids in event_seq are >= 2 and map to cause_id = token_id - 2
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with columns:
|
||||||
|
cause_id, horizon_tau, topk_percent, n_samples, n_positives, auc, auprc,
|
||||||
|
recall_at_K, precision_at_K, brier_score
|
||||||
|
"""
|
||||||
|
device = torch.device(device)
|
||||||
|
model.eval()
|
||||||
|
head.eval()
|
||||||
|
|
||||||
|
horizons_years = [float(x) for x in cfg.horizons_years]
|
||||||
|
if len(horizons_years) == 0:
|
||||||
|
raise ValueError("cfg.horizons_years must be non-empty")
|
||||||
|
|
||||||
|
topk_percents = [float(x) for x in cfg.topk_percents]
|
||||||
|
if len(topk_percents) == 0:
|
||||||
|
raise ValueError("cfg.topk_percents must be non-empty")
|
||||||
|
if any((p <= 0.0 or p > 100.0) for p in topk_percents):
|
||||||
|
raise ValueError(
|
||||||
|
f"All topk_percents must be in (0,100]; got {topk_percents}")
|
||||||
|
|
||||||
|
taus_tensor = torch.tensor(
|
||||||
|
horizons_years, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
if cfg.cause_ids is None:
|
||||||
|
cause_ids = None
|
||||||
|
n_causes_eval = int(n_disease)
|
||||||
|
else:
|
||||||
|
cause_ids = torch.tensor(
|
||||||
|
list(cfg.cause_ids), dtype=torch.long, device=device)
|
||||||
|
n_causes_eval = int(cause_ids.numel())
|
||||||
|
|
||||||
|
# Accumulate per horizon
|
||||||
|
y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
|
||||||
|
y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
|
||||||
|
|
||||||
|
for batch in dataloader:
|
||||||
|
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
|
||||||
|
event_seq = event_seq.to(device)
|
||||||
|
time_seq = time_seq.to(device)
|
||||||
|
cont_feats = cont_feats.to(device)
|
||||||
|
cate_feats = cate_feats.to(device)
|
||||||
|
sexes = sexes.to(device)
|
||||||
|
|
||||||
|
h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D)
|
||||||
|
|
||||||
|
# Context index selection (independent of horizon); keep mask is refined per horizon.
|
||||||
|
keep0, t_ctx, _ = select_context_indices(
|
||||||
|
event_seq=event_seq,
|
||||||
|
time_seq=time_seq,
|
||||||
|
offset_years=float(cfg.offset_years),
|
||||||
|
tau_years=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not keep0.any():
|
||||||
|
continue
|
||||||
|
|
||||||
|
b = torch.arange(event_seq.size(0), device=device)
|
||||||
|
c = h[b, t_ctx] # (B,D)
|
||||||
|
logits = head(c)
|
||||||
|
|
||||||
|
# CIFs for all horizons at once
|
||||||
|
cifs_all = criterion.calculate_cifs(
|
||||||
|
logits, taus=taus_tensor) # (B,K,T) or (B,K)
|
||||||
|
if cifs_all.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for h_idx, tau_y in enumerate(horizons_years):
|
||||||
|
keep, _, _ = select_context_indices(
|
||||||
|
event_seq=event_seq,
|
||||||
|
time_seq=time_seq,
|
||||||
|
offset_years=float(cfg.offset_years),
|
||||||
|
tau_years=float(tau_y),
|
||||||
|
)
|
||||||
|
keep = keep & keep0
|
||||||
|
if not keep.any():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cause_ids is None:
|
||||||
|
y = multi_hot_ever_within_horizon(
|
||||||
|
event_seq=event_seq,
|
||||||
|
time_seq=time_seq,
|
||||||
|
t_ctx=t_ctx,
|
||||||
|
tau_years=float(tau_y),
|
||||||
|
n_disease=n_disease,
|
||||||
|
)
|
||||||
|
y = y[keep]
|
||||||
|
preds = cifs_all[keep, :, h_idx]
|
||||||
|
else:
|
||||||
|
y = multi_hot_selected_causes_within_horizon(
|
||||||
|
event_seq=event_seq,
|
||||||
|
time_seq=time_seq,
|
||||||
|
t_ctx=t_ctx,
|
||||||
|
tau_years=float(tau_y),
|
||||||
|
cause_ids=cause_ids,
|
||||||
|
n_disease=n_disease,
|
||||||
|
)
|
||||||
|
y = y[keep]
|
||||||
|
preds = cifs_all[keep, :, h_idx].index_select(
|
||||||
|
dim=1, index=cause_ids)
|
||||||
|
|
||||||
|
y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
|
||||||
|
y_pred_by_h[h_idx].append(
|
||||||
|
preds.detach().to(torch.float32).cpu().numpy())
|
||||||
|
|
||||||
|
rows: List[Dict[str, float | int]] = []
|
||||||
|
|
||||||
|
for h_idx, tau_y in enumerate(horizons_years):
|
||||||
|
if len(y_true_by_h[h_idx]) == 0:
|
||||||
|
# No eligible samples for this horizon.
|
||||||
|
for k in range(n_causes_eval):
|
||||||
|
cause_id = int(k) if cause_ids is None else int(
|
||||||
|
cfg.cause_ids[k])
|
||||||
|
for k_percent in topk_percents:
|
||||||
|
rows.append(
|
||||||
|
dict(
|
||||||
|
cause_id=cause_id,
|
||||||
|
horizon_tau=float(tau_y),
|
||||||
|
topk_percent=float(k_percent),
|
||||||
|
n_samples=0,
|
||||||
|
n_positives=0,
|
||||||
|
auc=float("nan"),
|
||||||
|
auprc=float("nan"),
|
||||||
|
recall_at_K=float("nan"),
|
||||||
|
precision_at_K=float("nan"),
|
||||||
|
brier_score=float("nan"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
|
||||||
|
y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
|
||||||
|
|
||||||
|
if y_true.shape != y_pred.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
n_samples = int(y_true.shape[0])
|
||||||
|
|
||||||
|
for k in range(n_causes_eval):
|
||||||
|
yk = y_true[:, k]
|
||||||
|
pk = y_pred[:, k]
|
||||||
|
n_pos = int(yk.sum())
|
||||||
|
|
||||||
|
auc = _binary_roc_auc(yk, pk)
|
||||||
|
auprc = _average_precision(yk, pk)
|
||||||
|
brier = float(np.mean((yk.astype(float) - pk.astype(float))
|
||||||
|
** 2)) if n_samples > 0 else float("nan")
|
||||||
|
|
||||||
|
cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
|
||||||
|
for k_percent in topk_percents:
|
||||||
|
precision_k, recall_k = _precision_recall_at_k_percent(
|
||||||
|
yk, pk, float(k_percent))
|
||||||
|
rows.append(
|
||||||
|
dict(
|
||||||
|
cause_id=cause_id,
|
||||||
|
horizon_tau=float(tau_y),
|
||||||
|
topk_percent=float(k_percent),
|
||||||
|
n_samples=n_samples,
|
||||||
|
n_positives=n_pos,
|
||||||
|
auc=float(auc),
|
||||||
|
auprc=float(auprc),
|
||||||
|
recall_at_K=float(recall_k),
|
||||||
|
precision_at_K=float(precision_k),
|
||||||
|
brier_score=float(brier),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return pd.DataFrame(rows)
|
||||||
386
losses.py
386
losses.py
@@ -131,6 +131,96 @@ class ExponentialNLLLoss(nn.Module):
|
|||||||
reduction="mean") * self.lambda_reg
|
reduction="mean") * self.lambda_reg
|
||||||
return nll, reg
|
return nll, reg
|
||||||
|
|
||||||
|
def calculate_cifs(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
taus: torch.Tensor,
|
||||||
|
eps: Optional[float] = None,
|
||||||
|
return_survival: bool = False,
|
||||||
|
):
|
||||||
|
"""Compute CIFs for a competing-risks exponential model.
|
||||||
|
|
||||||
|
Model assumptions:
|
||||||
|
- cause-specific hazards are constant in time within a sample.
|
||||||
|
- hazards are obtained via softplus(logits) + eps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: (M, K) or (M, K, 1) tensor.
|
||||||
|
taus: scalar, (T,), (M,), or (M, T) times (>=0 recommended).
|
||||||
|
eps: overrides self.eps for numerical stability.
|
||||||
|
return_survival: if True, also return survival S(tau).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
|
||||||
|
survival (optional): (M,) if taus is scalar or (M,), else (M, T).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
|
||||||
|
t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
|
||||||
|
scalar_out = False
|
||||||
|
kind = "T" # one of: 'T', 'per_sample', 'MT'
|
||||||
|
if t.ndim == 0:
|
||||||
|
t = t.view(1)
|
||||||
|
scalar_out = True
|
||||||
|
t = t.view(1, 1) # (1,1)
|
||||||
|
kind = "T"
|
||||||
|
elif t.ndim == 1:
|
||||||
|
if t.shape[0] == batch_size:
|
||||||
|
t = t.view(batch_size, 1) # (M,1)
|
||||||
|
kind = "per_sample"
|
||||||
|
else:
|
||||||
|
t = t.view(1, -1) # (1,T)
|
||||||
|
kind = "T"
|
||||||
|
elif t.ndim == 2:
|
||||||
|
if t.shape[0] != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
|
||||||
|
)
|
||||||
|
kind = "MT"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
|
||||||
|
return t, kind, scalar_out
|
||||||
|
|
||||||
|
logits = logits.squeeze(-1) if logits.dim() == 3 else logits
|
||||||
|
if logits.ndim != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"logits must be 2D (M,K) (or 3D with last dim 1); got shape={tuple(logits.shape)}")
|
||||||
|
|
||||||
|
M, K = logits.shape
|
||||||
|
used_eps = float(self.eps if eps is None else eps)
|
||||||
|
|
||||||
|
hazards = F.softplus(logits) + used_eps # (M, K)
|
||||||
|
total_hazard = hazards.sum(dim=1, keepdim=True) # (M, 1)
|
||||||
|
total_hazard = torch.clamp(total_hazard, min=used_eps)
|
||||||
|
|
||||||
|
frac = hazards / total_hazard # (M, K)
|
||||||
|
|
||||||
|
taus_t, kind, scalar_out = _prepare_taus(
|
||||||
|
taus, M, logits.device, hazards.dtype)
|
||||||
|
taus_t = torch.clamp(taus_t, min=0)
|
||||||
|
|
||||||
|
if kind == "T":
|
||||||
|
# taus_t: (1,T)
|
||||||
|
exp_term = 1.0 - torch.exp(-total_hazard * taus_t) # (M,T)
|
||||||
|
cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1) # (M,K,T)
|
||||||
|
survival = torch.exp(-total_hazard * taus_t) # (M,T)
|
||||||
|
else:
|
||||||
|
# taus_t: (M,1) or (M,T)
|
||||||
|
exp_term = 1.0 - torch.exp(-total_hazard * taus_t) # (M,1) or (M,T)
|
||||||
|
# (M,K,1) or (M,K,T)
|
||||||
|
cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1)
|
||||||
|
survival = torch.exp(-total_hazard * taus_t) # (M,1) or (M,T)
|
||||||
|
|
||||||
|
if kind == "per_sample":
|
||||||
|
cifs = cifs.squeeze(-1) # (M,K)
|
||||||
|
survival = survival.squeeze(-1) # (M,)
|
||||||
|
elif scalar_out:
|
||||||
|
cifs = cifs.squeeze(-1) # (M,K)
|
||||||
|
survival = survival.squeeze(-1) # (M,)
|
||||||
|
|
||||||
|
return (cifs, survival) if return_survival else cifs
|
||||||
|
|
||||||
|
|
||||||
class DiscreteTimeCIFNLLLoss(nn.Module):
|
class DiscreteTimeCIFNLLLoss(nn.Module):
|
||||||
"""Direct discrete-time CIF negative log-likelihood (no censoring).
|
"""Direct discrete-time CIF negative log-likelihood (no censoring).
|
||||||
@@ -259,6 +349,122 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
|
|||||||
|
|
||||||
return nll, reg
|
return nll, reg
|
||||||
|
|
||||||
|
def calculate_cifs(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
taus: torch.Tensor,
|
||||||
|
eps: Optional[float] = None,
|
||||||
|
return_survival: bool = False,
|
||||||
|
):
|
||||||
|
"""Compute discrete-time CIFs implied by per-bin (K causes + complement) logits.
|
||||||
|
|
||||||
|
This matches the likelihood used in forward():
|
||||||
|
p(event=cause k at bin j) = Π_{u=1}^{j-1} p(comp at u) * p(k at j)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: (M, K+1, n_bins+1) where channel K is complement.
|
||||||
|
taus: scalar, (T,), (M,), or (M,T) continuous times.
|
||||||
|
eps: unused (kept for signature compatibility).
|
||||||
|
return_survival: if True, also return survival probability up to the mapped bin.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
|
||||||
|
survival (optional): (M,) if taus is scalar or (M,), else (M, T).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
|
||||||
|
t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
|
||||||
|
scalar_out = False
|
||||||
|
kind = "T"
|
||||||
|
if t.ndim == 0:
|
||||||
|
t = t.view(1)
|
||||||
|
scalar_out = True
|
||||||
|
t = t.view(1, 1)
|
||||||
|
kind = "T"
|
||||||
|
elif t.ndim == 1:
|
||||||
|
if t.shape[0] == batch_size:
|
||||||
|
t = t.view(batch_size, 1)
|
||||||
|
kind = "per_sample"
|
||||||
|
else:
|
||||||
|
t = t.view(1, -1)
|
||||||
|
kind = "T"
|
||||||
|
elif t.ndim == 2:
|
||||||
|
if t.shape[0] != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
|
||||||
|
)
|
||||||
|
kind = "MT"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
|
||||||
|
return t, kind, scalar_out
|
||||||
|
|
||||||
|
if logits.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"logits must have shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
M, k_plus_1, n_bins_plus_1 = logits.shape
|
||||||
|
K = k_plus_1 - 1
|
||||||
|
if K < 1:
|
||||||
|
raise ValueError(
|
||||||
|
"logits.shape[1] must be at least 2 (K>=1 plus complement)")
|
||||||
|
|
||||||
|
n_bins = int(self.bin_edges.numel() - 1)
|
||||||
|
if n_bins_plus_1 != n_bins + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# probs over causes+complement per bin
|
||||||
|
probs = F.softmax(logits, dim=1) # (M, K+1, n_bins+1)
|
||||||
|
p_causes = probs[:, :K, 1:] # (M, K, n_bins)
|
||||||
|
p_comp = probs[:, K, 1:] # (M, n_bins)
|
||||||
|
|
||||||
|
# survival up to end of each bin (1..n_bins)
|
||||||
|
surv_end = torch.cumprod(p_comp, dim=1) # (M, n_bins)
|
||||||
|
ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
|
||||||
|
surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1) # (M, n_bins)
|
||||||
|
|
||||||
|
inc = surv_start.unsqueeze(1) * p_causes # (M, K, n_bins)
|
||||||
|
cif_full = torch.cumsum(inc, dim=2) # (M, K, n_bins)
|
||||||
|
|
||||||
|
taus_t, kind, scalar_out = _prepare_taus(
|
||||||
|
taus, M, logits.device, surv_end.dtype)
|
||||||
|
taus_t = torch.clamp(taus_t, min=0)
|
||||||
|
|
||||||
|
bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
|
||||||
|
time_bin = torch.bucketize(taus_t, bin_edges) # (..)
|
||||||
|
time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(torch.long)
|
||||||
|
|
||||||
|
if kind == "T":
|
||||||
|
# (1,T) -> expand to (M,T)
|
||||||
|
time_bin = time_bin.expand(M, -1)
|
||||||
|
# kind per_sample gives (M,1), MT gives (M,T)
|
||||||
|
|
||||||
|
idx = torch.clamp(time_bin - 1, min=0) # (M,T)
|
||||||
|
|
||||||
|
gathered_cif = cif_full.gather(
|
||||||
|
dim=2,
|
||||||
|
index=idx.unsqueeze(1).expand(-1, K, -1),
|
||||||
|
) # (M,K,T)
|
||||||
|
gathered_surv = surv_end.gather(dim=1, index=idx) # (M,T)
|
||||||
|
|
||||||
|
# tau mapped to bin 0 => CIF=0, survival=1
|
||||||
|
zero_mask = (time_bin == 0)
|
||||||
|
if zero_mask.any():
|
||||||
|
gathered_cif = gathered_cif.masked_fill(zero_mask.unsqueeze(1), 0.0)
|
||||||
|
gathered_surv = gathered_surv.masked_fill(zero_mask, 1.0)
|
||||||
|
|
||||||
|
if kind == "per_sample":
|
||||||
|
gathered_cif = gathered_cif.squeeze(-1) # (M,K)
|
||||||
|
gathered_surv = gathered_surv.squeeze(-1) # (M,)
|
||||||
|
elif scalar_out:
|
||||||
|
gathered_cif = gathered_cif.squeeze(-1) # (M,K)
|
||||||
|
gathered_surv = gathered_surv.squeeze(-1) # (M,)
|
||||||
|
|
||||||
|
return (gathered_cif, gathered_surv) if return_survival else gathered_cif
|
||||||
|
|
||||||
|
|
||||||
class PiecewiseExponentialCIFNLLLoss(nn.Module):
|
class PiecewiseExponentialCIFNLLLoss(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -404,3 +610,183 @@ class PiecewiseExponentialCIFNLLLoss(nn.Module):
|
|||||||
reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)
|
reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)
|
||||||
|
|
||||||
return nll, reg
|
return nll, reg
|
||||||
|
|
||||||
|
def calculate_cifs(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
taus: torch.Tensor,
|
||||||
|
eps: Optional[float] = None,
|
||||||
|
return_survival: bool = False,
|
||||||
|
):
|
||||||
|
"""Compute CIFs for piecewise-constant cause-specific hazards.
|
||||||
|
|
||||||
|
Uses the same binning convention as forward(): taus are mapped to a bin via
|
||||||
|
torch.bucketize(taus, bin_edges), clamped to [0, n_bins]. tau<=0 maps to 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: (M, K, n_bins) hazard logits per cause per bin.
|
||||||
|
taus: scalar, (T,), (M,), or (M,T) times.
|
||||||
|
eps: overrides self.eps for numerical stability.
|
||||||
|
return_survival: if True, also return survival S(tau).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
|
||||||
|
survival (optional): (M,) if taus is scalar or (M,), else (M, T).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
|
||||||
|
t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
|
||||||
|
scalar_out = False
|
||||||
|
kind = "T"
|
||||||
|
if t.ndim == 0:
|
||||||
|
t = t.view(1)
|
||||||
|
scalar_out = True
|
||||||
|
t = t.view(1, 1)
|
||||||
|
kind = "T"
|
||||||
|
elif t.ndim == 1:
|
||||||
|
if t.shape[0] == batch_size:
|
||||||
|
t = t.view(batch_size, 1)
|
||||||
|
kind = "per_sample"
|
||||||
|
else:
|
||||||
|
t = t.view(1, -1)
|
||||||
|
kind = "T"
|
||||||
|
elif t.ndim == 2:
|
||||||
|
if t.shape[0] != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
|
||||||
|
)
|
||||||
|
kind = "MT"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
|
||||||
|
return t, kind, scalar_out
|
||||||
|
|
||||||
|
if logits.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"logits must be 3D (M,K,n_bins); got shape={tuple(logits.shape)}")
|
||||||
|
|
||||||
|
M, K, n_bins = logits.shape
|
||||||
|
if self.bin_edges.numel() != n_bins + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"bin_edges length must be n_bins+1={n_bins+1}; got {self.bin_edges.numel()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
used_eps = float(self.eps if eps is None else eps)
|
||||||
|
|
||||||
|
taus_t, kind, scalar_out = _prepare_taus(
|
||||||
|
taus, M, logits.device, logits.dtype)
|
||||||
|
taus_t = torch.clamp(taus_t, min=0)
|
||||||
|
|
||||||
|
bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
|
||||||
|
dt_bins = (bin_edges[1:] - bin_edges[:-1]
|
||||||
|
).to(device=logits.device, dtype=logits.dtype) # (n_bins,)
|
||||||
|
|
||||||
|
hazards = F.softplus(logits) + used_eps # (M, K, n_bins)
|
||||||
|
total_h = hazards.sum(dim=1) # (M, n_bins)
|
||||||
|
total_h = torch.clamp(total_h, min=used_eps)
|
||||||
|
|
||||||
|
# Precompute full-bin CIF increments
|
||||||
|
H_total_bin = total_h * dt_bins.view(1, n_bins) # (M, n_bins)
|
||||||
|
cum_H_end = torch.cumsum(H_total_bin, dim=1) # (M, n_bins)
|
||||||
|
surv_end = torch.exp(-cum_H_end) # (M, n_bins)
|
||||||
|
ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
|
||||||
|
surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1) # (M, n_bins)
|
||||||
|
|
||||||
|
frac = hazards / total_h.unsqueeze(1) # (M, K, n_bins)
|
||||||
|
one_minus = 1.0 - \
|
||||||
|
torch.exp(-total_h * dt_bins.view(1, n_bins)) # (M, n_bins)
|
||||||
|
inc_full = surv_start.unsqueeze(
|
||||||
|
1) * frac * one_minus.unsqueeze(1) # (M, K, n_bins)
|
||||||
|
cif_full = torch.cumsum(inc_full, dim=2) # (M, K, n_bins)
|
||||||
|
|
||||||
|
# Map taus -> bin index b in [0..n_bins]
|
||||||
|
time_bin = torch.bucketize(taus_t, bin_edges)
|
||||||
|
time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(
|
||||||
|
torch.long) # (...)
|
||||||
|
|
||||||
|
if kind == "T":
|
||||||
|
time_bin = time_bin.expand(M, -1) # (M,T)
|
||||||
|
|
||||||
|
# Compute within-bin length l and indices
|
||||||
|
b = time_bin # (M,T)
|
||||||
|
idx_bin0 = torch.clamp(b - 1, min=0) # 0..n_bins-1
|
||||||
|
|
||||||
|
# Start-of-bin survival for the current bin (for b==0 it's unused)
|
||||||
|
S_start_b = surv_start.gather(dim=1, index=idx_bin0) # (M,T)
|
||||||
|
|
||||||
|
# Length into bin: l = tau - edge[b-1], clamped to [0, dt_bin]
|
||||||
|
left_edge = bin_edges.gather(
|
||||||
|
dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0).to(taus_t.dtype)
|
||||||
|
l = taus_t.expand_as(b) - left_edge
|
||||||
|
l = torch.clamp(l, min=0)
|
||||||
|
width_b = dt_bins.gather(
|
||||||
|
dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0)
|
||||||
|
l = torch.min(l, width_b.to(l.dtype))
|
||||||
|
|
||||||
|
# CIF up to previous full bins
|
||||||
|
# if b<=1 => 0 else cif_full at (b-2)
|
||||||
|
prev_idx = torch.clamp(b - 2, min=0)
|
||||||
|
cif_before = cif_full.gather(
|
||||||
|
dim=2,
|
||||||
|
index=prev_idx.unsqueeze(1).expand(-1, K, -1),
|
||||||
|
) # (M,K,T)
|
||||||
|
if (b <= 1).any():
|
||||||
|
cif_before = cif_before.masked_fill((b <= 1).unsqueeze(1), 0.0)
|
||||||
|
|
||||||
|
# Partial increment in current bin
|
||||||
|
total_h_b = total_h.gather(dim=1, index=idx_bin0) # (M,T)
|
||||||
|
haz_b = hazards.gather(
|
||||||
|
dim=2,
|
||||||
|
index=idx_bin0.unsqueeze(1).expand(-1, K, -1),
|
||||||
|
) # (M,K,T)
|
||||||
|
frac_b = haz_b / total_h_b.unsqueeze(1) # (M,K,T)
|
||||||
|
|
||||||
|
one_minus_partial = 1.0 - torch.exp(-total_h_b * l) # (M,T)
|
||||||
|
inc_partial = S_start_b.unsqueeze(
|
||||||
|
1) * frac_b * one_minus_partial.unsqueeze(1) # (M,K,T)
|
||||||
|
|
||||||
|
cifs = cif_before + inc_partial
|
||||||
|
|
||||||
|
survival = S_start_b * torch.exp(-total_h_b * l) # (M,T)
|
||||||
|
|
||||||
|
# Inference-only tail extension beyond the last finite edge.
|
||||||
|
# For tau > t_B (t_B = bin_edges[-1]), extend survival and CIFs using
|
||||||
|
# constant hazards from the final bin B:
|
||||||
|
# S(tau)=S(t_B) * exp(-Λ_B * (tau - t_B))
|
||||||
|
# F_k(tau)=F_k(t_B) + S(t_B) * (λ_{k,B}/Λ_B) * (1 - exp(-Λ_B*(tau-t_B)))
|
||||||
|
last_edge = bin_edges[-1]
|
||||||
|
tau_full = taus_t.expand_as(b) # (M,T)
|
||||||
|
tail_mask = tau_full > last_edge
|
||||||
|
if tail_mask.any():
|
||||||
|
delta = torch.clamp(tau_full - last_edge, min=0) # (M,T)
|
||||||
|
|
||||||
|
S_B = surv_end[:, -1].unsqueeze(1) # (M,1)
|
||||||
|
F_B = cif_full[:, :, -1].unsqueeze(-1) # (M,K,1)
|
||||||
|
|
||||||
|
lambda_last = hazards[:, :, -1] # (M,K)
|
||||||
|
Lambda_last = torch.clamp(
|
||||||
|
total_h[:, -1], min=used_eps).unsqueeze(1) # (M,1)
|
||||||
|
|
||||||
|
exp_tail = torch.exp(-Lambda_last * delta) # (M,T)
|
||||||
|
survival_tail = S_B * exp_tail # (M,T)
|
||||||
|
cifs_tail = F_B + \
|
||||||
|
S_B.unsqueeze(
|
||||||
|
1) * (lambda_last / Lambda_last).unsqueeze(-1) * (1.0 - exp_tail).unsqueeze(1)
|
||||||
|
|
||||||
|
survival = torch.where(tail_mask, survival_tail, survival)
|
||||||
|
cifs = torch.where(tail_mask.unsqueeze(1), cifs_tail, cifs)
|
||||||
|
|
||||||
|
# tau mapped to bin 0 => CIF=0, survival=1
|
||||||
|
zero_mask = (b == 0)
|
||||||
|
if zero_mask.any():
|
||||||
|
cifs = cifs.masked_fill(zero_mask.unsqueeze(1), 0.0)
|
||||||
|
survival = survival.masked_fill(zero_mask, 1.0)
|
||||||
|
|
||||||
|
if kind == "per_sample":
|
||||||
|
cifs = cifs.squeeze(-1) # (M,K)
|
||||||
|
survival = survival.squeeze(-1) # (M,)
|
||||||
|
elif scalar_out:
|
||||||
|
cifs = cifs.squeeze(-1) # (M,K)
|
||||||
|
survival = survival.squeeze(-1) # (M,)
|
||||||
|
|
||||||
|
return (cifs, survival) if return_survival else cifs
|
||||||
|
|||||||
130
utils.py
Normal file
130
utils.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
DAYS_PER_YEAR = 365.25
|
||||||
|
|
||||||
|
|
||||||
|
def select_context_indices(
|
||||||
|
event_seq: torch.Tensor,
|
||||||
|
time_seq: torch.Tensor,
|
||||||
|
offset_years: float,
|
||||||
|
tau_years: float = 0.0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Select per-sample prediction context index.
|
||||||
|
|
||||||
|
IMPORTANT SEMANTICS:
|
||||||
|
- The last observed token time is treated as the FOLLOW-UP END time.
|
||||||
|
- We pick the last valid token with time <= (followup_end_time - offset).
|
||||||
|
- We do NOT interpret followup_end_time as an event time.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
keep_mask: (B,) bool, which samples have a valid context
|
||||||
|
t_ctx: (B,) long, index into sequence
|
||||||
|
t_ctx_time: (B,) float, time (days) at context
|
||||||
|
"""
|
||||||
|
# valid tokens are event != 0 (padding is 0)
|
||||||
|
valid = event_seq != 0
|
||||||
|
lengths = valid.sum(dim=1)
|
||||||
|
last_idx = torch.clamp(lengths - 1, min=0)
|
||||||
|
|
||||||
|
b = torch.arange(event_seq.size(0), device=event_seq.device)
|
||||||
|
followup_end_time = time_seq[b, last_idx]
|
||||||
|
t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
|
||||||
|
|
||||||
|
eligible = valid & (time_seq <= t_cut.unsqueeze(1))
|
||||||
|
eligible_counts = eligible.sum(dim=1)
|
||||||
|
keep = eligible_counts > 0
|
||||||
|
|
||||||
|
t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
|
||||||
|
t_ctx_time = time_seq[b, t_ctx]
|
||||||
|
|
||||||
|
# Horizon-aligned eligibility: require enough follow-up time after the selected context.
|
||||||
|
# All times are in days.
|
||||||
|
keep = keep & (followup_end_time >= (
|
||||||
|
t_ctx_time + (tau_years * DAYS_PER_YEAR)))
|
||||||
|
|
||||||
|
return keep, t_ctx, t_ctx_time
|
||||||
|
|
||||||
|
|
||||||
|
def multi_hot_ever_within_horizon(
|
||||||
|
event_seq: torch.Tensor,
|
||||||
|
time_seq: torch.Tensor,
|
||||||
|
t_ctx: torch.Tensor,
|
||||||
|
tau_years: float,
|
||||||
|
n_disease: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Binary labels: disease k occurs within tau after context (any occurrence)."""
|
||||||
|
B, L = event_seq.shape
|
||||||
|
b = torch.arange(B, device=event_seq.device)
|
||||||
|
t0 = time_seq[b, t_ctx]
|
||||||
|
t1 = t0 + (tau_years * DAYS_PER_YEAR)
|
||||||
|
|
||||||
|
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
|
||||||
|
# Include same-day events after context, exclude any token at/before context index.
|
||||||
|
in_window = (
|
||||||
|
(idxs > t_ctx.unsqueeze(1))
|
||||||
|
& (time_seq >= t0.unsqueeze(1))
|
||||||
|
& (time_seq <= t1.unsqueeze(1))
|
||||||
|
& (event_seq >= 2)
|
||||||
|
& (event_seq != 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not in_window.any():
|
||||||
|
return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
|
||||||
|
|
||||||
|
b_idx, t_idx = in_window.nonzero(as_tuple=True)
|
||||||
|
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
|
||||||
|
|
||||||
|
y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
|
||||||
|
y[b_idx, disease_ids] = True
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def multi_hot_selected_causes_within_horizon(
|
||||||
|
event_seq: torch.Tensor,
|
||||||
|
time_seq: torch.Tensor,
|
||||||
|
t_ctx: torch.Tensor,
|
||||||
|
tau_years: float,
|
||||||
|
cause_ids: torch.Tensor,
|
||||||
|
n_disease: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Labels for selected causes only: does cause k occur within tau after context?"""
|
||||||
|
B, L = event_seq.shape
|
||||||
|
device = event_seq.device
|
||||||
|
b = torch.arange(B, device=device)
|
||||||
|
t0 = time_seq[b, t_ctx]
|
||||||
|
t1 = t0 + (tau_years * DAYS_PER_YEAR)
|
||||||
|
|
||||||
|
idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
|
||||||
|
in_window = (
|
||||||
|
(idxs > t_ctx.unsqueeze(1))
|
||||||
|
& (time_seq >= t0.unsqueeze(1))
|
||||||
|
& (time_seq <= t1.unsqueeze(1))
|
||||||
|
& (event_seq >= 2)
|
||||||
|
& (event_seq != 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
|
||||||
|
if not in_window.any():
|
||||||
|
return out
|
||||||
|
|
||||||
|
b_idx, t_idx = in_window.nonzero(as_tuple=True)
|
||||||
|
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
|
||||||
|
|
||||||
|
# Filter to selected causes via a boolean membership mask over the global disease space.
|
||||||
|
selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
|
||||||
|
selected[cause_ids] = True
|
||||||
|
keep = selected[disease_ids]
|
||||||
|
if not keep.any():
|
||||||
|
return out
|
||||||
|
|
||||||
|
b_idx = b_idx[keep]
|
||||||
|
disease_ids = disease_ids[keep]
|
||||||
|
|
||||||
|
# Map disease_id -> local index in cause_ids
|
||||||
|
# Build a lookup table (global disease space) where lookup[disease_id] = local_index
|
||||||
|
lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
|
||||||
|
lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
|
||||||
|
local = lookup[disease_ids]
|
||||||
|
out[b_idx, local] = True
|
||||||
|
return out
|
||||||
Reference in New Issue
Block a user