from __future__ import annotations import math from dataclasses import dataclass from typing import Dict, List, Optional, Sequence, Tuple import numpy as np import pandas as pd import torch from utils import ( DAYS_PER_YEAR, multi_hot_ever_within_horizon, multi_hot_selected_causes_within_horizon, select_context_indices, ) def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float: """Compute ROC AUC for binary labels with tie-aware ranking. Returns NaN if y_true has no positives or no negatives. Uses the Mann–Whitney U statistic with average ranks for ties. """ y_true = np.asarray(y_true).astype(bool) y_score = np.asarray(y_score).astype(float) n = y_true.size if n == 0: return float("nan") n_pos = int(y_true.sum()) n_neg = n - n_pos if n_pos == 0 or n_neg == 0: return float("nan") # Rank scores ascending, average ranks for ties. order = np.argsort(y_score, kind="mergesort") sorted_scores = y_score[order] ranks = np.empty(n, dtype=float) i = 0 # 1-based ranks while i < n: j = i + 1 while j < n and sorted_scores[j] == sorted_scores[i]: j += 1 avg_rank = 0.5 * ((i + 1) + j) # ranks i+1 .. j ranks[order[i:j]] = avg_rank i = j sum_ranks_pos = float(ranks[y_true].sum()) u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0) return float(u / (n_pos * n_neg)) def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float: """Average precision (area under PR curve using step-wise interpolation). Returns NaN if no positives. """ y_true = np.asarray(y_true).astype(bool) y_score = np.asarray(y_score).astype(float) n = y_true.size if n == 0: return float("nan") n_pos = int(y_true.sum()) if n_pos == 0: return float("nan") order = np.argsort(-y_score, kind="mergesort") y = y_true[order] tp = np.cumsum(y).astype(float) fp = np.cumsum(~y).astype(float) precision = tp / np.maximum(tp + fp, 1.0) recall = tp / n_pos # AP = sum over each positive of precision at that point / n_pos # (equivalent to ∑ Δrecall * precision) ap = float(np.sum(precision[y]) / n_pos) # handle potential tiny numerical overshoots return float(max(0.0, min(1.0, ap))) def _precision_recall_at_k_percent( y_true: np.ndarray, y_score: np.ndarray, k_percent: float, ) -> Tuple[float, float]: """Precision@K% and Recall@K% for binary labels. Returns (precision, recall). Returns NaN for recall if no positives. Returns NaN for precision if k leads to 0 selected. """ y_true = np.asarray(y_true).astype(bool) y_score = np.asarray(y_score).astype(float) n = y_true.size if n == 0: return float("nan"), float("nan") n_pos = int(y_true.sum()) k = int(math.ceil((float(k_percent) / 100.0) * n)) if k <= 0: return float("nan"), float("nan") order = np.argsort(-y_score, kind="mergesort") top = order[:k] tp_top = int(y_true[top].sum()) precision = tp_top / k recall = float("nan") if n_pos == 0 else (tp_top / n_pos) return float(precision), float(recall) @dataclass class EvalConfig: horizons_years: Sequence[float] offset_years: float = 0.0 topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0) cause_ids: Optional[Sequence[int]] = None @torch.no_grad() def evaluate_time_dependent( model: torch.nn.Module, head: torch.nn.Module, criterion, dataloader: torch.utils.data.DataLoader, n_disease: int, cfg: EvalConfig, device: str | torch.device, ) -> pd.DataFrame: """Evaluate time-dependent metrics per cause and per horizon. Assumptions: - time_seq is in days - horizons_years and the loss CIF times are in years - disease token ids in event_seq are >= 2 and map to cause_id = token_id - 2 Returns: DataFrame with columns: cause_id, horizon_tau, topk_percent, n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score """ device = torch.device(device) model.eval() head.eval() horizons_years = [float(x) for x in cfg.horizons_years] if len(horizons_years) == 0: raise ValueError("cfg.horizons_years must be non-empty") topk_percents = [float(x) for x in cfg.topk_percents] if len(topk_percents) == 0: raise ValueError("cfg.topk_percents must be non-empty") if any((p <= 0.0 or p > 100.0) for p in topk_percents): raise ValueError( f"All topk_percents must be in (0,100]; got {topk_percents}") taus_tensor = torch.tensor( horizons_years, device=device, dtype=torch.float32) if cfg.cause_ids is None: cause_ids = None n_causes_eval = int(n_disease) else: cause_ids = torch.tensor( list(cfg.cause_ids), dtype=torch.long, device=device) n_causes_eval = int(cause_ids.numel()) # Accumulate per horizon y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years] y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years] for batch in dataloader: event_seq, time_seq, cont_feats, cate_feats, sexes = batch event_seq = event_seq.to(device) time_seq = time_seq.to(device) cont_feats = cont_feats.to(device) cate_feats = cate_feats.to(device) sexes = sexes.to(device) h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D) # Context index selection (independent of horizon); keep mask is refined per horizon. keep0, t_ctx, _ = select_context_indices( event_seq=event_seq, time_seq=time_seq, offset_years=float(cfg.offset_years), tau_years=0.0, ) if not keep0.any(): continue b = torch.arange(event_seq.size(0), device=device) c = h[b, t_ctx] # (B,D) logits = head(c) # CIFs for all horizons at once cifs_all = criterion.calculate_cifs( logits, taus=taus_tensor) # (B,K,T) or (B,K) if cifs_all.ndim != 3: raise ValueError( f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}" ) for h_idx, tau_y in enumerate(horizons_years): keep, _, _ = select_context_indices( event_seq=event_seq, time_seq=time_seq, offset_years=float(cfg.offset_years), tau_years=float(tau_y), ) keep = keep & keep0 if not keep.any(): continue if cause_ids is None: y = multi_hot_ever_within_horizon( event_seq=event_seq, time_seq=time_seq, t_ctx=t_ctx, tau_years=float(tau_y), n_disease=n_disease, ) y = y[keep] preds = cifs_all[keep, :, h_idx] else: y = multi_hot_selected_causes_within_horizon( event_seq=event_seq, time_seq=time_seq, t_ctx=t_ctx, tau_years=float(tau_y), cause_ids=cause_ids, n_disease=n_disease, ) y = y[keep] preds = cifs_all[keep, :, h_idx].index_select( dim=1, index=cause_ids) y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy()) y_pred_by_h[h_idx].append( preds.detach().to(torch.float32).cpu().numpy()) rows: List[Dict[str, float | int]] = [] for h_idx, tau_y in enumerate(horizons_years): if len(y_true_by_h[h_idx]) == 0: # No eligible samples for this horizon. for k in range(n_causes_eval): cause_id = int(k) if cause_ids is None else int( cfg.cause_ids[k]) for k_percent in topk_percents: rows.append( dict( cause_id=cause_id, horizon_tau=float(tau_y), topk_percent=float(k_percent), n_samples=0, n_positives=0, auc=float("nan"), auprc=float("nan"), recall_at_K=float("nan"), precision_at_K=float("nan"), brier_score=float("nan"), ) ) continue y_true = np.concatenate(y_true_by_h[h_idx], axis=0) y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0) if y_true.shape != y_pred.shape: raise ValueError( f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}" ) n_samples = int(y_true.shape[0]) for k in range(n_causes_eval): yk = y_true[:, k] pk = y_pred[:, k] n_pos = int(yk.sum()) auc = _binary_roc_auc(yk, pk) auprc = _average_precision(yk, pk) brier = float(np.mean((yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan") cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k]) for k_percent in topk_percents: precision_k, recall_k = _precision_recall_at_k_percent( yk, pk, float(k_percent)) rows.append( dict( cause_id=cause_id, horizon_tau=float(tau_y), topk_percent=float(k_percent), n_samples=n_samples, n_positives=n_pos, auc=float(auc), auprc=float(auprc), recall_at_K=float(recall_k), precision_at_K=float(precision_k), brier_score=float(brier), ) ) return pd.DataFrame(rows)