From 34d8d8ce9da6d59e2666841f1e03124dee3a4eec Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Fri, 16 Jan 2026 14:55:09 +0800
Subject: [PATCH] Add evaluation and utility functions for time-dependent metrics

- Introduced `evaluate.py` for time-dependent evaluation of models,
  including data loading and model inference.
- Added `evaluation_time_dependent.py` to compute various evaluation
  metrics such as AUC, average precision, and precision/recall at
  specified thresholds.
- Implemented CIF calculation methods in `losses.py` for different loss
  types, including exponential and piecewise exponential models.
- Created utility functions in `utils.py` for context selection and
  multi-hot encoding of events within specified horizons.
---
 evaluate.py                  | 242 +++++++++++++++++++++
 evaluation_time_dependent.py | 318 ++++++++++++++++++++++++++++
 losses.py                    | 386 +++++++++++++++++++++++++++++++++++
 utils.py                     | 130 ++++++++++++
 4 files changed, 1076 insertions(+)
 create mode 100644 evaluate.py
 create mode 100644 evaluation_time_dependent.py
 create mode 100644 utils.py

diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..0b5002e
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,242 @@
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import os
+from typing import List, Sequence
+
+import torch
+from torch.utils.data import DataLoader, random_split
+
+from dataset import HealthDataset, health_collate_fn
+from evaluation_time_dependent import EvalConfig, evaluate_time_dependent
+from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
+from model import DelphiFork, SapDelphi, SimpleHead
+
+
+def _parse_floats(items: Sequence[str]) -> List[float]:
+    """Convert CLI strings to floats, skipping blank entries."""
+    out: List[float] = []
+    for x in items:
+        x = x.strip()
+        if not x:
+            continue
+        out.append(float(x))
+    return out
+
+
+def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
+    """Build the loss module and the matching ``SimpleHead`` output dims.
+
+    Raises ValueError for an unknown ``loss_type`` or invalid ``pwe_cif`` edges.
+    """
+    if loss_type == "exponential":
+        criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
+        out_dims = [n_disease]
+        return criterion, out_dims
+
+    if loss_type == "discrete_time_cif":
+        criterion = DiscreteTimeCIFNLLLoss(
+            bin_edges=bin_edges, lambda_reg=lambda_reg)
+        out_dims = [n_disease + 1, len(bin_edges)]
+        return criterion, out_dims
+
+    if loss_type == "pwe_cif":
+        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
+        if len(pwe_edges) < 2:
+            raise ValueError(
+                "pwe_cif requires at least 2 finite bin edges (including 0)")
+        if float(pwe_edges[0]) != 0.0:
+            raise ValueError("pwe_cif requires bin_edges[0]==0.0")
+        criterion = PiecewiseExponentialCIFNLLLoss(
+            bin_edges=pwe_edges, lambda_reg=lambda_reg)
+        n_bins = len(pwe_edges) - 1
+        out_dims = [n_disease, n_bins]
+        return criterion, out_dims
+
+    raise ValueError(f"Unsupported loss_type: {loss_type}")
+
+
+def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
+    """Instantiate the backbone selected by ``model_type`` from the run config."""
+    if model_type == "delphi_fork":
+        return DelphiFork(
+            n_disease=dataset.n_disease,
+            n_tech_tokens=2,
+            n_embd=int(cfg["n_embd"]),
+            n_head=int(cfg["n_head"]),
+            n_layer=int(cfg["n_layer"]),
+            pdrop=float(cfg.get("pdrop", 0.0)),
+            age_encoder_type=str(cfg["age_encoder"]),
+            n_cont=int(dataset.n_cont),
+            n_cate=int(dataset.n_cate),
+            cate_dims=list(dataset.cate_dims),
+        )
+
+    if model_type == "sap_delphi":
+        return SapDelphi(
+            n_disease=dataset.n_disease,
+            n_tech_tokens=2,
+            n_embd=int(cfg["n_embd"]),
+            n_head=int(cfg["n_head"]),
+            n_layer=int(cfg["n_layer"]),
+            pdrop=float(cfg.get("pdrop", 0.0)),
+            age_encoder_type=str(cfg["age_encoder"]),
+            n_cont=int(dataset.n_cont),
+            n_cate=int(dataset.n_cate),
+            cate_dims=list(dataset.cate_dims),
+            pretrained_weights_path=str(
+                cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
+            freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
+        )
+
+    raise ValueError(f"Unsupported model_type: {model_type}")
+
+
+def main() -> None:
+    """CLI entry point: rebuild the trained run, evaluate it, and write a CSV."""
+    parser = argparse.ArgumentParser(
+        description="Time-dependent evaluation for DeepHealth")
+    parser.add_argument(
+        "--run_dir",
+        type=str,
+        required=True,
+        help="Training run directory (contains best_model.pt and train_config.json)",
+    )
+    parser.add_argument("--data_prefix", type=str, default=None,
+                        help="Dataset prefix (overrides config if provided)")
+    parser.add_argument("--split", type=str,
+                        choices=["train", "val", "test", "all"], default="val")
+
+    parser.add_argument("--horizons", type=str, nargs="+",
+                        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="One or more horizons (years)")
+    parser.add_argument("--offset_years", type=float, default=0.0,
+                        help="Context selection offset (years before follow-up end)")
+    parser.add_argument(
+        "--topk_percent",
+        type=float,
+        nargs="+",
+        default=[1, 5, 10, 20, 50],
+        help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)",
+    )
+
+    parser.add_argument("--device", type=str,
+                        default="cuda" if torch.cuda.is_available() else "cpu")
+    parser.add_argument("--batch_size", type=int, default=256)
+    parser.add_argument("--num_workers", type=int,
+                        default=0, help="Keep 0 on Windows")
+
+    parser.add_argument("--out_csv", type=str, default=None,
+                        help="Optional output CSV path")
+
+    args = parser.parse_args()
+
+    ckpt_path = os.path.join(args.run_dir, "best_model.pt")
+    cfg_path = os.path.join(args.run_dir, "train_config.json")
+
+    if not os.path.exists(ckpt_path):
+        raise SystemExit(f"Missing checkpoint: {ckpt_path}")
+    if not os.path.exists(cfg_path):
+        raise SystemExit(f"Missing config: {cfg_path}")
+
+    with open(cfg_path, "r") as f:
+        cfg = json.load(f)
+
+    data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get(
+        "data_prefix", "ukb")
+
+    # Match training covariate selection.
+    full_cov = bool(cfg.get("full_cov", False))
+    cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
+    dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)
+
+    # Recreate the same split scheme as train.py
+    train_ratio = float(cfg.get("train_ratio", 0.7))
+    val_ratio = float(cfg.get("val_ratio", 0.15))
+    seed = int(cfg.get("random_seed", 42))
+
+    n_total = len(dataset)
+    n_train = int(n_total * train_ratio)
+    n_val = int(n_total * val_ratio)
+    n_test = n_total - n_train - n_val
+    train_ds, val_ds, test_ds = random_split(
+        dataset,
+        [n_train, n_val, n_test],
+        generator=torch.Generator().manual_seed(seed),
+    )
+
+    if args.split == "train":
+        ds = train_ds
+    elif args.split == "val":
+        ds = val_ds
+    elif args.split == "test":
+        ds = test_ds
+    else:
+        ds = dataset
+
+    loader = DataLoader(
+        ds,
+        batch_size=int(args.batch_size),
+        shuffle=False,
+        collate_fn=health_collate_fn,
+        num_workers=int(args.num_workers),
+        pin_memory=str(args.device).startswith("cuda"),
+    )
+
+    criterion, out_dims = build_criterion_and_out_dims(
+        loss_type=str(cfg["loss_type"]),
+        n_disease=int(dataset.n_disease),
+        bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
+        lambda_reg=float(cfg.get("lambda_reg", 0.0)),
+    )
+
+    model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
+    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
+
+    device = torch.device(args.device)
+    checkpoint = torch.load(ckpt_path, map_location=device)
+
+    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
+    head.load_state_dict(checkpoint["head_state_dict"], strict=True)
+    if "criterion_state_dict" in checkpoint:
+        try:
+            criterion.load_state_dict(
+                checkpoint["criterion_state_dict"], strict=False)
+        except Exception as exc:
+            # Best-effort: criterion state is optional, but report the failure.
+            print(f"Warning: could not load criterion_state_dict: {exc}")
+
+    model.to(device)
+    head.to(device)
+    criterion.to(device)
+
+    eval_cfg = EvalConfig(
+        horizons_years=_parse_floats(args.horizons),
+        offset_years=float(args.offset_years),
+        topk_percents=[float(x) for x in args.topk_percent],
+        cause_ids=None,
+    )
+
+    df = evaluate_time_dependent(
+        model=model,
+        head=head,
+        criterion=criterion,
+        dataloader=loader,
+        n_disease=int(dataset.n_disease),
+        cfg=eval_cfg,
+        device=device,
+    )
+
+    if args.out_csv is None:
+        out_csv = os.path.join(
+            args.run_dir, f"time_dependent_metrics_{args.split}.csv")
+    else:
+        out_csv = args.out_csv
+
+    df.to_csv(out_csv, index=False)
+    print(f"Wrote: {out_csv}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/evaluation_time_dependent.py b/evaluation_time_dependent.py
new file mode 100644
index 0000000..9e6f2f2
--- /dev/null
+++ b/evaluation_time_dependent.py
@@ -0,0 +1,318 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+
+from utils import (
+    DAYS_PER_YEAR,
+    multi_hot_ever_within_horizon,
+    multi_hot_selected_causes_within_horizon,
+    select_context_indices,
+)
+
+
+def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    """Compute ROC AUC for binary labels with tie-aware ranking.
+
+    Returns NaN if y_true has no positives or no negatives.
+
+    Uses the Mann–Whitney U statistic with average ranks for ties.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan")
+
+    n_pos = int(y_true.sum())
+    n_neg = n - n_pos
+    if n_pos == 0 or n_neg == 0:
+        return float("nan")
+
+    # Rank scores ascending, average ranks for ties.
+    order = np.argsort(y_score, kind="mergesort")
+    sorted_scores = y_score[order]
+
+    ranks = np.empty(n, dtype=float)
+    i = 0
+    # 1-based ranks
+    while i < n:
+        j = i + 1
+        while j < n and sorted_scores[j] == sorted_scores[i]:
+            j += 1
+        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
+        ranks[order[i:j]] = avg_rank
+        i = j
+
+    sum_ranks_pos = float(ranks[y_true].sum())
+    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
+    return float(u / (n_pos * n_neg))
+
+
+def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    """Average precision (area under PR curve using step-wise interpolation).
+
+    Returns NaN if no positives.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan")
+
+    n_pos = int(y_true.sum())
+    if n_pos == 0:
+        return float("nan")
+
+    order = np.argsort(-y_score, kind="mergesort")
+    y = y_true[order]
+
+    tp = np.cumsum(y).astype(float)
+    fp = np.cumsum(~y).astype(float)
+    precision = tp / np.maximum(tp + fp, 1.0)
+
+    # AP = sum over each positive of precision at that point / n_pos
+    # (equivalent to ∑ Δrecall * precision)
+    ap = float(np.sum(precision[y]) / n_pos)
+    # handle potential tiny numerical overshoots
+    return float(max(0.0, min(1.0, ap)))
+
+
+def _precision_recall_at_k_percent(
+    y_true: np.ndarray,
+    y_score: np.ndarray,
+    k_percent: float,
+) -> Tuple[float, float]:
+    """Precision@K% and Recall@K% for binary labels.
+
+    Returns (precision, recall). Returns NaN for recall if no positives.
+    Returns NaN for precision if k leads to 0 selected.
+    """
+    y_true = np.asarray(y_true).astype(bool)
+    y_score = np.asarray(y_score).astype(float)
+
+    n = y_true.size
+    if n == 0:
+        return float("nan"), float("nan")
+
+    n_pos = int(y_true.sum())
+
+    k = int(math.ceil((float(k_percent) / 100.0) * n))
+    if k <= 0:
+        return float("nan"), float("nan")
+
+    order = np.argsort(-y_score, kind="mergesort")
+    top = order[:k]
+    tp_top = int(y_true[top].sum())
+
+    precision = tp_top / k
+    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
+    return float(precision), float(recall)
+
+
+@dataclass
+class EvalConfig:
+    """Settings for evaluate_time_dependent (horizons in years, top-K in percent)."""
+    horizons_years: Sequence[float]
+    offset_years: float = 0.0
+    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
+    cause_ids: Optional[Sequence[int]] = None
+
+
+@torch.no_grad()
+def evaluate_time_dependent(
+    model: torch.nn.Module,
+    head: torch.nn.Module,
+    criterion,
+    dataloader: torch.utils.data.DataLoader,
+    n_disease: int,
+    cfg: EvalConfig,
+    device: str | torch.device,
+) -> pd.DataFrame:
+    """Evaluate time-dependent metrics per cause and per horizon.
+
+    Assumptions:
+    - time_seq is in days
+    - horizons_years and the loss CIF times are in years
+    - disease token ids in event_seq are >= 2 and map to cause_id = token_id - 2
+
+    Returns:
+        DataFrame with columns:
+          cause_id, horizon_tau, topk_percent, n_samples, n_positives, auc, auprc,
+          recall_at_K, precision_at_K, brier_score
+    """
+    device = torch.device(device)
+    model.eval()
+    head.eval()
+
+    horizons_years = [float(x) for x in cfg.horizons_years]
+    if len(horizons_years) == 0:
+        raise ValueError("cfg.horizons_years must be non-empty")
+
+    topk_percents = [float(x) for x in cfg.topk_percents]
+    if len(topk_percents) == 0:
+        raise ValueError("cfg.topk_percents must be non-empty")
+    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
+        raise ValueError(
+            f"All topk_percents must be in (0,100]; got {topk_percents}")
+
+    taus_tensor = torch.tensor(
+        horizons_years, device=device, dtype=torch.float32)
+
+    if cfg.cause_ids is None:
+        cause_ids = None
+        n_causes_eval = int(n_disease)
+    else:
+        cause_ids = torch.tensor(
+            list(cfg.cause_ids), dtype=torch.long, device=device)
+        n_causes_eval = int(cause_ids.numel())
+
+    # Accumulate per horizon
+    y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
+    y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
+
+    for batch in dataloader:
+        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
+        event_seq = event_seq.to(device)
+        time_seq = time_seq.to(device)
+        cont_feats = cont_feats.to(device)
+        cate_feats = cate_feats.to(device)
+        sexes = sexes.to(device)
+
+        h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)  # (B,L,D)
+
+        # Context index selection (independent of horizon); keep mask is refined per horizon.
+        keep0, t_ctx, _ = select_context_indices(
+            event_seq=event_seq,
+            time_seq=time_seq,
+            offset_years=float(cfg.offset_years),
+            tau_years=0.0,
+        )
+
+        if not keep0.any():
+            continue
+
+        b = torch.arange(event_seq.size(0), device=device)
+        c = h[b, t_ctx]  # (B,D)
+        logits = head(c)
+
+        # CIFs for all horizons at once. Pass taus as an explicit (B,T) grid:
+        # calculate_cifs reads a 1D taus of length B as per-sample times, so a
+        # bare (T,) vector is misread (and rejected below) whenever T == B.
+        taus_bt = taus_tensor.unsqueeze(0).expand(logits.size(0), -1)
+        cifs_all = criterion.calculate_cifs(logits, taus=taus_bt)  # (B,K,T)
+        if cifs_all.ndim != 3:
+            raise ValueError(
+                f"criterion.calculate_cifs must return (B,K,T) when taus is (B,T), got shape={tuple(cifs_all.shape)}"
+            )
+
+        for h_idx, tau_y in enumerate(horizons_years):
+            keep, _, _ = select_context_indices(
+                event_seq=event_seq,
+                time_seq=time_seq,
+                offset_years=float(cfg.offset_years),
+                tau_years=float(tau_y),
+            )
+            keep = keep & keep0
+            if not keep.any():
+                continue
+
+            if cause_ids is None:
+                y = multi_hot_ever_within_horizon(
+                    event_seq=event_seq,
+                    time_seq=time_seq,
+                    t_ctx=t_ctx,
+                    tau_years=float(tau_y),
+                    n_disease=n_disease,
+                )
+                y = y[keep]
+                preds = cifs_all[keep, :, h_idx]
+            else:
+                y = multi_hot_selected_causes_within_horizon(
+                    event_seq=event_seq,
+                    time_seq=time_seq,
+                    t_ctx=t_ctx,
+                    tau_years=float(tau_y),
+                    cause_ids=cause_ids,
+                    n_disease=n_disease,
+                )
+                y = y[keep]
+                preds = cifs_all[keep, :, h_idx].index_select(
+                    dim=1, index=cause_ids)
+
+            y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
+            y_pred_by_h[h_idx].append(
+                preds.detach().to(torch.float32).cpu().numpy())
+
+    rows: List[Dict[str, float | int]] = []
+
+    for h_idx, tau_y in enumerate(horizons_years):
+        if len(y_true_by_h[h_idx]) == 0:
+            # No eligible samples for this horizon.
+            for k in range(n_causes_eval):
+                cause_id = int(k) if cause_ids is None else int(
+                    cfg.cause_ids[k])
+                for k_percent in topk_percents:
+                    rows.append(
+                        dict(
+                            cause_id=cause_id,
+                            horizon_tau=float(tau_y),
+                            topk_percent=float(k_percent),
+                            n_samples=0,
+                            n_positives=0,
+                            auc=float("nan"),
+                            auprc=float("nan"),
+                            recall_at_K=float("nan"),
+                            precision_at_K=float("nan"),
+                            brier_score=float("nan"),
+                        )
+                    )
+            continue
+
+        y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
+        y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
+
+        if y_true.shape != y_pred.shape:
+            raise ValueError(
+                f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}"
+            )
+
+        n_samples = int(y_true.shape[0])
+
+        for k in range(n_causes_eval):
+            yk = y_true[:, k]
+            pk = y_pred[:, k]
+            n_pos = int(yk.sum())
+
+            auc = _binary_roc_auc(yk, pk)
+            auprc = _average_precision(yk, pk)
+            brier = float(np.mean((yk.astype(float) - pk.astype(float))
+                                  ** 2)) if n_samples > 0 else float("nan")
+
+            cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
+            for k_percent in topk_percents:
+                precision_k, recall_k = _precision_recall_at_k_percent(
+                    yk, pk, float(k_percent))
+                rows.append(
+                    dict(
+                        cause_id=cause_id,
+                        horizon_tau=float(tau_y),
+                        topk_percent=float(k_percent),
+                        n_samples=n_samples,
+                        n_positives=n_pos,
+                        auc=float(auc),
+                        auprc=float(auprc),
+                        recall_at_K=float(recall_k),
+                        precision_at_K=float(precision_k),
+                        brier_score=float(brier),
+                    )
+                )
+
+    return pd.DataFrame(rows)
diff --git a/losses.py b/losses.py
index 3ef4de7..33323f5 100644
--- a/losses.py
+++ b/losses.py
@@ -131,6 +131,96 @@ class ExponentialNLLLoss(nn.Module):
                 reduction="mean") * self.lambda_reg
         return nll, reg
 
+    def calculate_cifs(
+        self,
+        logits: torch.Tensor,
+        taus: torch.Tensor,
+        eps: Optional[float] = None,
+        return_survival: bool = False,
+    ):
+        """Compute CIFs for a competing-risks exponential model.
+
+        Model assumptions:
+        - cause-specific hazards are constant in time within a sample.
+        - hazards are obtained via softplus(logits) + eps.
+
+        Args:
+            logits: (M, K) or (M, K, 1) tensor.
+            taus: scalar, (T,), (M,), or (M, T) times (>=0 recommended).
+            eps: overrides self.eps for numerical stability.
+            return_survival: if True, also return survival S(tau).
+
+        Returns:
+            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
+            survival (optional): (M,) if taus is scalar or (M,), else (M, T).
+        """
+
+        def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
+            t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
+            scalar_out = False
+            kind = "T"  # one of: 'T', 'per_sample', 'MT'
+            if t.ndim == 0:
+                t = t.view(1)
+                scalar_out = True
+                t = t.view(1, 1)  # (1,1)
+                kind = "T"
+            elif t.ndim == 1:
+                if t.shape[0] == batch_size:
+                    t = t.view(batch_size, 1)  # (M,1)
+                    kind = "per_sample"
+                else:
+                    t = t.view(1, -1)  # (1,T)
+                    kind = "T"
+            elif t.ndim == 2:
+                if t.shape[0] != batch_size:
+                    raise ValueError(
+                        f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
+                    )
+                kind = "MT"
+            else:
+                raise ValueError(
+                    f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
+            return t, kind, scalar_out
+
+        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
+        if logits.ndim != 2:
+            raise ValueError(
+                f"logits must be 2D (M,K) (or 3D with last dim 1); got shape={tuple(logits.shape)}")
+
+        M, K = logits.shape
+        used_eps = float(self.eps if eps is None else eps)
+
+        hazards = F.softplus(logits) + used_eps  # (M, K)
+        total_hazard = hazards.sum(dim=1, keepdim=True)  # (M, 1)
+        total_hazard = torch.clamp(total_hazard, min=used_eps)
+
+        frac = hazards / total_hazard  # (M, K)
+
+        taus_t, kind, scalar_out = _prepare_taus(
+            taus, M, logits.device, hazards.dtype)
+        taus_t = torch.clamp(taus_t, min=0)
+
+        if kind == "T":
+            # taus_t: (1,T)
+            exp_term = 1.0 - torch.exp(-total_hazard * taus_t)  # (M,T)
+            cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1)  # (M,K,T)
+            survival = torch.exp(-total_hazard * taus_t)  # (M,T)
+        else:
+            # taus_t: (M,1) or (M,T)
+            exp_term = 1.0 - torch.exp(-total_hazard * taus_t)  # (M,1) or (M,T)
+            # (M,K,1) or (M,K,T)
+            cifs = frac.unsqueeze(-1) * exp_term.unsqueeze(1)
+            survival = torch.exp(-total_hazard * taus_t)  # (M,1) or (M,T)
+
+        if kind == "per_sample":
+            cifs = cifs.squeeze(-1)  # (M,K)
+            survival = survival.squeeze(-1)  # (M,)
+        elif scalar_out:
+            cifs = cifs.squeeze(-1)  # (M,K)
+            survival = survival.squeeze(-1)  # (M,)
+
+        return (cifs, survival) if return_survival else cifs
+
 
 class DiscreteTimeCIFNLLLoss(nn.Module):
     """Direct discrete-time CIF negative log-likelihood (no censoring).
@@ -259,6 +349,122 @@ class DiscreteTimeCIFNLLLoss(nn.Module):
 
         return nll, reg
 
+    def calculate_cifs(
+        self,
+        logits: torch.Tensor,
+        taus: torch.Tensor,
+        eps: Optional[float] = None,
+        return_survival: bool = False,
+    ):
+        """Compute discrete-time CIFs implied by per-bin (K causes + complement) logits.
+
+        This matches the likelihood used in forward():
+          p(event=cause k at bin j) = Π_{u=1}^{j-1} p(comp at u) * p(k at j)
+
+        Args:
+            logits: (M, K+1, n_bins+1) where channel K is complement.
+            taus: scalar, (T,), (M,), or (M,T) continuous times.
+            eps: unused (kept for signature compatibility).
+            return_survival: if True, also return survival probability up to the mapped bin.
+
+        Returns:
+            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
+            survival (optional): (M,) if taus is scalar or (M,), else (M, T).
+        """
+
+        def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
+            t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
+            scalar_out = False
+            kind = "T"
+            if t.ndim == 0:
+                t = t.view(1)
+                scalar_out = True
+                t = t.view(1, 1)
+                kind = "T"
+            elif t.ndim == 1:
+                if t.shape[0] == batch_size:
+                    t = t.view(batch_size, 1)
+                    kind = "per_sample"
+                else:
+                    t = t.view(1, -1)
+                    kind = "T"
+            elif t.ndim == 2:
+                if t.shape[0] != batch_size:
+                    raise ValueError(
+                        f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
+                    )
+                kind = "MT"
+            else:
+                raise ValueError(
+                    f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
+            return t, kind, scalar_out
+
+        if logits.ndim != 3:
+            raise ValueError(
+                f"logits must have shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
+            )
+
+        M, k_plus_1, n_bins_plus_1 = logits.shape
+        K = k_plus_1 - 1
+        if K < 1:
+            raise ValueError(
+                "logits.shape[1] must be at least 2 (K>=1 plus complement)")
+
+        n_bins = int(self.bin_edges.numel() - 1)
+        if n_bins_plus_1 != n_bins + 1:
+            raise ValueError(
+                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
+            )
+
+        # probs over causes+complement per bin
+        probs = F.softmax(logits, dim=1)  # (M, K+1, n_bins+1)
+        p_causes = probs[:, :K, 1:]  # (M, K, n_bins)
+        p_comp = probs[:, K, 1:]  # (M, n_bins)
+
+        # survival up to end of each bin (1..n_bins)
+        surv_end = torch.cumprod(p_comp, dim=1)  # (M, n_bins)
+        ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
+        surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1)  # (M, n_bins)
+
+        inc = surv_start.unsqueeze(1) * p_causes  # (M, K, n_bins)
+        cif_full = torch.cumsum(inc, dim=2)  # (M, K, n_bins)
+
+        taus_t, kind, scalar_out = _prepare_taus(
+            taus, M, logits.device, surv_end.dtype)
+        taus_t = torch.clamp(taus_t, min=0)
+
+        bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
+        time_bin = torch.bucketize(taus_t, bin_edges)  # (..)
+        time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(torch.long)
+
+        if kind == "T":
+            # (1,T) -> expand to (M,T)
+            time_bin = time_bin.expand(M, -1)
+        # kind per_sample gives (M,1), MT gives (M,T)
+
+        idx = torch.clamp(time_bin - 1, min=0)  # (M,T)
+
+        gathered_cif = cif_full.gather(
+            dim=2,
+            index=idx.unsqueeze(1).expand(-1, K, -1),
+        )  # (M,K,T)
+        gathered_surv = surv_end.gather(dim=1, index=idx)  # (M,T)
+
+        # tau mapped to bin 0 => CIF=0, survival=1
+        zero_mask = (time_bin == 0)
+        if zero_mask.any():
+            gathered_cif = gathered_cif.masked_fill(zero_mask.unsqueeze(1), 0.0)
+            gathered_surv = gathered_surv.masked_fill(zero_mask, 1.0)
+
+        if kind == "per_sample":
+            gathered_cif = gathered_cif.squeeze(-1)  # (M,K)
+            gathered_surv = gathered_surv.squeeze(-1)  # (M,)
+        elif scalar_out:
+            gathered_cif = gathered_cif.squeeze(-1)  # (M,K)
+            gathered_surv = gathered_surv.squeeze(-1)  # (M,)
+
+        return (gathered_cif, gathered_surv) if return_survival else gathered_cif
+
 
 class PiecewiseExponentialCIFNLLLoss(nn.Module):
     """
@@ -404,3 +610,183 @@ class PiecewiseExponentialCIFNLLLoss(nn.Module):
             reg = torch.zeros((), device=logits.device, dtype=loss_vec.dtype)
 
         return nll, reg
+
+    def calculate_cifs(
+        self,
+        logits: torch.Tensor,
+        taus: torch.Tensor,
+        eps: Optional[float] = None,
+        return_survival: bool = False,
+    ):
+        """Compute CIFs for piecewise-constant cause-specific hazards.
+
+        Uses the same binning convention as forward(): taus are mapped to a bin via
+        torch.bucketize(taus, bin_edges), clamped to [0, n_bins]. tau<=0 maps to 0.
+
+        Args:
+            logits: (M, K, n_bins) hazard logits per cause per bin.
+            taus: scalar, (T,), (M,), or (M,T) times.
+            eps: overrides self.eps for numerical stability.
+            return_survival: if True, also return survival S(tau).
+
+        Returns:
+            cifs: (M, K) if taus is scalar or (M,), else (M, K, T).
+            survival (optional): (M,) if taus is scalar or (M,), else (M, T).
+        """
+
+        def _prepare_taus(taus_tensor: torch.Tensor, batch_size: int, device, dtype):
+            t = torch.as_tensor(taus_tensor, device=device, dtype=dtype)
+            scalar_out = False
+            kind = "T"
+            if t.ndim == 0:
+                t = t.view(1)
+                scalar_out = True
+                t = t.view(1, 1)
+                kind = "T"
+            elif t.ndim == 1:
+                if t.shape[0] == batch_size:
+                    t = t.view(batch_size, 1)
+                    kind = "per_sample"
+                else:
+                    t = t.view(1, -1)
+                    kind = "T"
+            elif t.ndim == 2:
+                if t.shape[0] != batch_size:
+                    raise ValueError(
+                        f"taus with ndim==2 must have shape (M,T); got {tuple(t.shape)} for M={batch_size}"
+                    )
+                kind = "MT"
+            else:
+                raise ValueError(
+                    f"taus must be scalar, 1D, or 2D; got taus.ndim={t.ndim}")
+            return t, kind, scalar_out
+
+        if logits.ndim != 3:
+            raise ValueError(
+                f"logits must be 3D (M,K,n_bins); got shape={tuple(logits.shape)}")
+
+        M, K, n_bins = logits.shape
+        if self.bin_edges.numel() != n_bins + 1:
+            raise ValueError(
+                f"bin_edges length must be n_bins+1={n_bins+1}; got {self.bin_edges.numel()}"
+            )
+
+        used_eps = float(self.eps if eps is None else eps)
+
+        taus_t, kind, scalar_out = _prepare_taus(
+            taus, M, logits.device, logits.dtype)
+        taus_t = torch.clamp(taus_t, min=0)
+
+        bin_edges = self.bin_edges.to(device=logits.device, dtype=taus_t.dtype)
+        dt_bins = (bin_edges[1:] - bin_edges[:-1]
+                   ).to(device=logits.device, dtype=logits.dtype)  # (n_bins,)
+
+        hazards = F.softplus(logits) + used_eps  # (M, K, n_bins)
+        total_h = hazards.sum(dim=1)  # (M, n_bins)
+        total_h = torch.clamp(total_h, min=used_eps)
+
+        # Precompute full-bin CIF increments
+        H_total_bin = total_h * dt_bins.view(1, n_bins)  # (M, n_bins)
+        cum_H_end = torch.cumsum(H_total_bin, dim=1)  # (M, n_bins)
+        surv_end = torch.exp(-cum_H_end)  # (M, n_bins)
+        ones = torch.ones((M, 1), device=logits.device, dtype=surv_end.dtype)
+        surv_start = torch.cat([ones, surv_end[:, :-1]], dim=1)  # (M, n_bins)
+
+        frac = hazards / total_h.unsqueeze(1)  # (M, K, n_bins)
+        one_minus = 1.0 - \
+            torch.exp(-total_h * dt_bins.view(1, n_bins))  # (M, n_bins)
+        inc_full = surv_start.unsqueeze(
+            1) * frac * one_minus.unsqueeze(1)  # (M, K, n_bins)
+        cif_full = torch.cumsum(inc_full, dim=2)  # (M, K, n_bins)
+
+        # Map taus -> bin index b in [0..n_bins]
+        time_bin = torch.bucketize(taus_t, bin_edges)
+        time_bin = torch.clamp(time_bin, min=0, max=n_bins).to(
+            torch.long)  # (...)
+
+        if kind == "T":
+            time_bin = time_bin.expand(M, -1)  # (M,T)
+
+        # Compute within-bin length l and indices
+        b = time_bin  # (M,T)
+        idx_bin0 = torch.clamp(b - 1, min=0)  # 0..n_bins-1
+
+        # Start-of-bin survival for the current bin (for b==0 it's unused)
+        S_start_b = surv_start.gather(dim=1, index=idx_bin0)  # (M,T)
+
+        # Length into bin: l = tau - edge[b-1], clamped to [0, dt_bin]
+        left_edge = bin_edges.gather(
+            dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0).to(taus_t.dtype)
+        l = taus_t.expand_as(b) - left_edge
+        l = torch.clamp(l, min=0)
+        width_b = dt_bins.gather(
+            dim=0, index=idx_bin0.view(-1)).view_as(idx_bin0)
+        l = torch.min(l, width_b.to(l.dtype))
+
+        # CIF up to previous full bins
+        # if b<=1 => 0 else cif_full at (b-2)
+        prev_idx = torch.clamp(b - 2, min=0)
+        cif_before = cif_full.gather(
+            dim=2,
+            index=prev_idx.unsqueeze(1).expand(-1, K, -1),
+        )  # (M,K,T)
+        if (b <= 1).any():
+            cif_before = cif_before.masked_fill((b <= 1).unsqueeze(1), 0.0)
+
+        # Partial increment in current bin
+        total_h_b = total_h.gather(dim=1, index=idx_bin0)  # (M,T)
+        haz_b = hazards.gather(
+            dim=2,
+            index=idx_bin0.unsqueeze(1).expand(-1, K, -1),
+        )  # (M,K,T)
+        frac_b = haz_b / total_h_b.unsqueeze(1)  # (M,K,T)
+
+        one_minus_partial = 1.0 - torch.exp(-total_h_b * l)  # (M,T)
+        inc_partial = S_start_b.unsqueeze(
+            1) * frac_b * one_minus_partial.unsqueeze(1)  # (M,K,T)
+
+        cifs = cif_before + inc_partial
+
+        survival = S_start_b * torch.exp(-total_h_b * l)  # (M,T)
+
+        # Inference-only tail extension beyond the last finite edge.
+        # For tau > t_B (t_B = bin_edges[-1]), extend survival and CIFs using
+        # constant hazards from the final bin B:
+        #   S(tau)=S(t_B) * exp(-Λ_B * (tau - t_B))
+        #   F_k(tau)=F_k(t_B) + S(t_B) * (λ_{k,B}/Λ_B) * (1 - exp(-Λ_B*(tau-t_B)))
+        last_edge = bin_edges[-1]
+        tau_full = taus_t.expand_as(b)  # (M,T)
+        tail_mask = tau_full > last_edge
+        if tail_mask.any():
+            delta = torch.clamp(tau_full - last_edge, min=0)  # (M,T)
+
+            S_B = surv_end[:, -1].unsqueeze(1)  # (M,1)
+            F_B = cif_full[:, :, -1].unsqueeze(-1)  # (M,K,1)
+
+            lambda_last = hazards[:, :, -1]  # (M,K)
+            Lambda_last = torch.clamp(
+                total_h[:, -1], min=used_eps).unsqueeze(1)  # (M,1)
+
+            exp_tail = torch.exp(-Lambda_last * delta)  # (M,T)
+            survival_tail = S_B * exp_tail  # (M,T)
+            cifs_tail = F_B + \
+                S_B.unsqueeze(
+                    1) * (lambda_last / Lambda_last).unsqueeze(-1) * (1.0 - exp_tail).unsqueeze(1)
+
+            survival = torch.where(tail_mask, survival_tail, survival)
+            cifs = torch.where(tail_mask.unsqueeze(1), cifs_tail, cifs)
+
+        # tau mapped to bin 0 => CIF=0, survival=1
+        zero_mask = (b == 0)
+        if zero_mask.any():
+            cifs = cifs.masked_fill(zero_mask.unsqueeze(1), 0.0)
+            survival = survival.masked_fill(zero_mask, 1.0)
+
+        if kind == "per_sample":
+            cifs = cifs.squeeze(-1)  # (M,K)
+            survival = survival.squeeze(-1)  # (M,)
+        elif scalar_out:
+            cifs = cifs.squeeze(-1)  # (M,K)
+            survival = survival.squeeze(-1)  # (M,)
+
+        return (cifs, survival) if return_survival else cifs
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..25b6d3f
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,130 @@
+import torch
+from typing import Tuple
+
+DAYS_PER_YEAR = 365.25
+
+
+def select_context_indices(
+    event_seq: torch.Tensor,
+    time_seq: torch.Tensor,
+    offset_years: float,
+    tau_years: float = 0.0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Select per-sample prediction context index.
+
+    IMPORTANT SEMANTICS:
+    - The last observed token time is treated as the FOLLOW-UP END time.
+    - We pick the last valid token with time <= (followup_end_time - offset).
+    - We do NOT interpret followup_end_time as an event time.
+
+    Returns:
+        keep_mask: (B,) bool, which samples have a valid context
+        t_ctx: (B,) long, index into sequence
+        t_ctx_time: (B,) float, time (days) at context
+    """
+    # valid tokens are event != 0 (padding is 0)
+    valid = event_seq != 0
+    lengths = valid.sum(dim=1)
+    last_idx = torch.clamp(lengths - 1, min=0)
+
+    b = torch.arange(event_seq.size(0), device=event_seq.device)
+    followup_end_time = time_seq[b, last_idx]
+    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
+
+    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
+    eligible_counts = eligible.sum(dim=1)
+    keep = eligible_counts > 0
+
+    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
+    t_ctx_time = time_seq[b, t_ctx]
+
+    # Horizon-aligned eligibility: require enough follow-up time after the selected context.
+    # All times are in days.
+    keep = keep & (followup_end_time >= (
+        t_ctx_time + (tau_years * DAYS_PER_YEAR)))
+
+    return keep, t_ctx, t_ctx_time
+
+
+def multi_hot_ever_within_horizon(
+    event_seq: torch.Tensor,
+    time_seq: torch.Tensor,
+    t_ctx: torch.Tensor,
+    tau_years: float,
+    n_disease: int,
+) -> torch.Tensor:
+    """Binary labels: disease k occurs within tau after context (any occurrence)."""
+    B, L = event_seq.shape
+    b = torch.arange(B, device=event_seq.device)
+    t0 = time_seq[b, t_ctx]
+    t1 = t0 + (tau_years * DAYS_PER_YEAR)
+
+    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
+    # Include same-day events after context, exclude any token at/before context index.
+    in_window = (
+        (idxs > t_ctx.unsqueeze(1))
+        & (time_seq >= t0.unsqueeze(1))
+        & (time_seq <= t1.unsqueeze(1))
+        & (event_seq >= 2)
+        & (event_seq != 0)
+    )
+
+    if not in_window.any():
+        return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
+
+    b_idx, t_idx = in_window.nonzero(as_tuple=True)
+    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
+
+    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
+    y[b_idx, disease_ids] = True
+    return y
+
+
+def multi_hot_selected_causes_within_horizon(
+    event_seq: torch.Tensor,
+    time_seq: torch.Tensor,
+    t_ctx: torch.Tensor,
+    tau_years: float,
+    cause_ids: torch.Tensor,
+    n_disease: int,
+) -> torch.Tensor:
+    """Labels for selected causes only: does cause k occur within tau after context?"""
+    B, L = event_seq.shape
+    device = event_seq.device
+    b = torch.arange(B, device=device)
+    t0 = time_seq[b, t_ctx]
+    t1 = t0 + (tau_years * DAYS_PER_YEAR)
+
+    idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
+    in_window = (
+        (idxs > t_ctx.unsqueeze(1))
+        & (time_seq >= t0.unsqueeze(1))
+        & (time_seq <= t1.unsqueeze(1))
+        & (event_seq >= 2)
+        & (event_seq != 0)
+    )
+
+    out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
+    if not in_window.any():
+        return out
+
+    b_idx, t_idx = in_window.nonzero(as_tuple=True)
+    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
+
+    # Filter to selected causes via a boolean membership mask over the global disease space.
+    selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
+    selected[cause_ids] = True
+    keep = selected[disease_ids]
+    if not keep.any():
+        return out
+
+    b_idx = b_idx[keep]
+    disease_ids = disease_ids[keep]
+
+    # Map disease_id -> local index in cause_ids
+    # Build a lookup table (global disease space) where lookup[disease_id] = local_index
+    lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
+    lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
+    local = lookup[disease_ids]
+    out[b_idx, local] = True
+    return out