"""Shared evaluation utilities: config loading, dataset/model reconstruction,
event-driven evaluation-record building, batched inference and AUC/DeLong
statistics."""

import json
import math
import os
import random
import re
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Any, Deque, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split

try:
    from tqdm import tqdm as _tqdm
except Exception:  # pragma: no cover
    _tqdm = None

try:
    from joblib import Parallel, delayed  # type: ignore
except Exception:  # pragma: no cover
    Parallel = None
    delayed = None

from dataset import HealthDataset
from losses import (
    DiscreteTimeCIFNLLLoss,
    ExponentialNLLLoss,
    PiecewiseExponentialCIFNLLLoss,
)
from model import DelphiFork, SapDelphi, SimpleHead

DAYS_PER_YEAR = 365.25
N_TECH_TOKENS = 2  # pad=0, DOA=1, diseases start at 2


def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None):
    """Wrap ``iterable`` in a tqdm bar when enabled and tqdm is importable.

    Otherwise the iterable is returned untouched.
    """
    if enabled and _tqdm is not None:
        return _tqdm(iterable, desc=desc, total=total)
    return iterable


def make_inference_dataloader_kwargs(
    device: torch.device,
    num_workers: int,
) -> Dict[str, Any]:
    """DataLoader kwargs tuned for inference throughput.

    Behavior/metrics are unchanged; this only impacts speed.

    Note that ``num_workers`` itself is *not* placed in the returned dict; it
    only decides whether the worker-related options are added.
    """
    use_cuda = device.type == "cuda" and torch.cuda.is_available()
    kwargs: Dict[str, Any] = {
        "pin_memory": bool(use_cuda),
    }
    if num_workers > 0:
        kwargs["persistent_workers"] = True
        # default prefetch is 2; set explicitly for clarity.
        kwargs["prefetch_factor"] = 2
    return kwargs


# -------------------------
# Config + determinism
# -------------------------

# Matches either a complete JSON string literal or a bare non-finite token.
# The string alternative comes first so that text inside quotes is consumed
# as a whole and never rewritten.
# NOTE(review): the original pattern was truncated in the source this file was
# recovered from; this is a reconstruction that follows the stated intent
# ("bare tokens, not within quotes") - confirm against version control.
_JSON_STRING_OR_NONFINITE = re.compile(r'"(?:\\.|[^"\\])*"|-?Infinity|NaN')


def _replace_nonstandard_json_numbers(text: str) -> str:
    """Replace bare ``Infinity``/``-Infinity``/``NaN`` with string placeholders.

    Python's json.dump writes Infinity/-Infinity/NaN for non-finite floats.
    Bare tokens (not within quotes) are replaced with the placeholders
    ``"__INF__"``, ``"__NINF__"`` and ``"__NAN__"``; everything else is
    returned unchanged.
    """

    def repl(match: "re.Match[str]") -> str:
        token = match.group(0)
        if token == "-Infinity":
            return '"__NINF__"'
        if token == "Infinity":
            return '"__INF__"'
        if token == "NaN":
            return '"__NAN__"'
        # A quoted JSON string: leave it exactly as it was.
        return token

    return _JSON_STRING_OR_NONFINITE.sub(repl, text)


def _restore_placeholders(obj: Any) -> Any:
    """Recursively turn the string placeholders back into non-finite floats.

    Dicts and lists are rebuilt; any other value is returned as-is unless it
    equals one of the three placeholder strings.
    """
    if isinstance(obj, dict):
        return {k: _restore_placeholders(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_restore_placeholders(v) for v in obj]
    if obj == "__INF__":
        return float("inf")
    if obj == "__NINF__":
        return float("-inf")
    if obj == "__NAN__":
        return float("nan")
    return obj


def load_train_config(run_dir: str) -> Dict[str, Any]:
    """Load ``<run_dir>/train_config.json``, tolerating non-finite numbers."""
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r", encoding="utf-8") as f:
        raw = f.read()
    raw = _replace_nonstandard_json_numbers(raw)
    cfg = json.loads(raw)
    cfg = _restore_placeholders(cfg)
    return cfg


def seed_everything(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU + CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def parse_float_list(values: Sequence[str]) -> List[float]:
    """Convert strings to floats, accepting several spellings of infinity.

    Raises:
        ValueError: if an entry is neither a recognised infinity spelling nor
            parseable by ``float``.
    """
    out: List[float] = []
    for v in values:
        s = str(v).strip().lower()
        if s in {"inf", "+inf", "infty", "infinity", "+infinity"}:
            out.append(float("inf"))
        elif s in {"-inf", "-infty", "-infinity"}:
            out.append(float("-inf"))
        else:
            out.append(float(v))
    return out


# -------------------------
# Dataset + split (match train.py)
# -------------------------

def build_dataset_from_config(cfg: Dict[str, Any]) -> HealthDataset:
    """Build the ``HealthDataset`` described by a training config.

    When ``cfg["full_cov"]`` is truthy the covariate list is passed as
    ``None``; otherwise it is restricted to bmi/smoking/alcohol.
    """
    data_prefix = cfg["data_prefix"]
    full_cov = bool(cfg.get("full_cov", False))
    cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
    return HealthDataset(
        data_prefix=data_prefix,
        covariate_list=cov_list,
    )


def get_test_subset(dataset: HealthDataset, cfg: Dict[str, Any]) -> Subset:
    """Return the test part of a seeded train/val/test ``random_split``.

    Split sizes are ``int(n * ratio)`` for train and val; the test split
    receives the remainder. The generator is seeded with ``cfg["random_seed"]``.
    """
    n_total = len(dataset)
    train_ratio = float(cfg["train_ratio"])
    val_ratio = float(cfg["val_ratio"])
    seed = int(cfg["random_seed"])

    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val

    _, _, test_subset = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    return test_subset


# -------------------------
# Model + head + criterion (match train.py)
# -------------------------

def build_model_head_criterion(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: torch.device,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    """Instantiate model, head and criterion from a training config.

    Returns:
        ``(model, head, criterion)``, all moved to ``device``.

    Raises:
        ValueError: for an unsupported ``loss_type`` / ``model_type`` or for
            invalid ``pwe_cif`` bin edges.
    """
    loss_type = cfg["loss_type"]
    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(
            lambda_reg=float(cfg.get("lambda_reg", 0.0))).to(device)
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = [float(x) for x in cfg["bin_edges"]]
        criterion = DiscreteTimeCIFNLLLoss(
            bin_edges=bin_edges,
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        ).to(device)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif loss_type == "pwe_cif":
        # training drops +inf for PWE
        raw_edges = [float(x) for x in cfg["bin_edges"]]
        pwe_edges = [x for x in raw_edges if math.isfinite(x)]
        if len(pwe_edges) < 2:
            raise ValueError(
                "pwe_cif requires at least 2 finite bin edges (including 0). "
                f"Got bin_edges={raw_edges}"
            )
        if float(pwe_edges[0]) != 0.0:
            raise ValueError(
                f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}")
        criterion = PiecewiseExponentialCIFNLLLoss(
            bin_edges=pwe_edges,
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        ).to(device)
        n_bins = len(pwe_edges) - 1
        out_dims = [dataset.n_disease, n_bins]
    else:
        raise ValueError(f"Unsupported loss_type: {loss_type}")

    model_type = cfg["model_type"]
    # Validate first so an unknown type is reported before any other config
    # key is read.
    if model_type not in ("delphi_fork", "sap_delphi"):
        raise ValueError(f"Unsupported model_type: {model_type}")

    # Constructor arguments shared by both supported architectures.
    common_kwargs: Dict[str, Any] = dict(
        n_disease=dataset.n_disease,
        n_tech_tokens=N_TECH_TOKENS,
        n_embd=int(cfg["n_embd"]),
        n_head=int(cfg["n_head"]),
        n_layer=int(cfg["n_layer"]),
        pdrop=float(cfg.get("pdrop", 0.0)),
        age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
        n_cont=int(dataset.n_cont),
        n_cate=int(dataset.n_cate),
        cate_dims=list(dataset.cate_dims),
    )
    if model_type == "delphi_fork":
        model = DelphiFork(**common_kwargs).to(device)
    else:
        model = SapDelphi(
            **common_kwargs,
            pretrained_weights_path=str(
                cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
            freeze_embeddings=True,
        ).to(device)

    head = SimpleHead(
        n_embd=int(cfg["n_embd"]),
        out_dims=list(out_dims),
    ).to(device)

    return model, head, criterion


def load_checkpoint_into(
    run_dir: str,
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion: Optional[torch.nn.Module],
    device: torch.device,
) -> Dict[str, Any]:
    """Load ``<run_dir>/best_model.pt`` into ``model`` and ``head``.

    Model and head weights are loaded strictly. The criterion state, when
    present, is loaded on a best-effort basis: a failure is reported as a
    ``RuntimeWarning`` instead of aborting.

    Returns:
        The raw checkpoint dict.
    """
    ckpt_path = os.path.join(run_dir, "best_model.pt")
    # NOTE(security): depending on the installed torch version's
    # `weights_only` default, torch.load may unpickle arbitrary objects.
    # Only load checkpoints from trusted run directories.
    ckpt = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(ckpt["model_state_dict"], strict=True)
    head.load_state_dict(ckpt["head_state_dict"], strict=True)
    if criterion is not None and "criterion_state_dict" in ckpt:
        try:
            criterion.load_state_dict(
                ckpt["criterion_state_dict"], strict=False)
        except Exception as exc:
            # Criterion state is not essential for inference, so keep going,
            # but do not hide the problem completely.
            warnings.warn(
                f"Could not load criterion_state_dict from {ckpt_path}: {exc}",
                RuntimeWarning,
            )
    return ckpt


# -------------------------
# Evaluation record construction (event-driven)
# -------------------------

@dataclass(frozen=True)
class EvalRecord:
    """One evaluation sample: a subject truncated at a sampled baseline event."""

    subset_idx: int
    doa_days: float
    t0_days: float
    cutoff_pos: int  # baseline position (inclusive)
    next_event_cause: Optional[int]
    next_event_dt_years: Optional[float]
    # (U,) unique causes ever observed (clean-control filtering)
    lifetime_causes: np.ndarray
    future_causes: np.ndarray  # (E,) in [0..K-1]
    future_dt_years: np.ndarray  # (E,) strictly > 0


def _to_days(x_years: float) -> float:
    """Convert years to days, mapping any infinity to ``+inf``."""
    if math.isinf(float(x_years)):
        return float("inf")
    return float(x_years) * DAYS_PER_YEAR


def build_event_driven_records(
    subset: Subset,
    age_bins_years: Sequence[float],
    seed: int,
    show_progress: bool = False,
    n_jobs: int = 1,
    chunk_size: int = 256,
    prefer: str = "threads",
) -> List[EvalRecord]:
    """Build at most one ``EvalRecord`` per (subject, age bin).

    For every subject and every age bin ``[lo, hi)`` one disease event that
    falls inside the bin and not before the subject's DOA time is sampled
    uniformly at random as the baseline; all disease events strictly after
    it are stored as future events.

    Args:
        subset: indexable dataset; each item must unpack to five values of
            which the first two are the event-code and time tensors.
        age_bins_years: strictly increasing bin boundaries in years.
        seed: RNG seed.
        show_progress: show a progress bar (sequential path only).
        n_jobs: ``1`` runs sequentially; any other value uses joblib when it
            is installed.
        chunk_size: number of subjects per parallel task; validated even when
            running sequentially.
        prefer: joblib backend hint.

    Note:
        The two paths draw from different RNG streams: the sequential path
        uses one generator seeded with ``seed`` for all subjects, the
        parallel path uses one generator per subject seeded with
        ``[seed, subset_idx]``. Sampled baselines therefore differ between
        ``n_jobs == 1`` and ``n_jobs != 1``.

    Raises:
        ValueError: for fewer than two bins, non-increasing bins,
            ``chunk_size < 1`` or a sequence without a DOA token.
    """
    if len(age_bins_years) < 2:
        raise ValueError("age_bins must have at least 2 boundaries")

    age_bins_days = [_to_days(b) for b in age_bins_years]
    if any(lo >= hi for lo, hi in zip(age_bins_days, age_bins_days[1:])):
        raise ValueError("age_bins must be strictly increasing")

    def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
        """Split ``range(n)`` into consecutive index arrays of length <= size."""
        if size <= 0:
            raise ValueError("chunk_size must be >= 1")
        if n == 0:
            return []
        idx = np.arange(n, dtype=np.int64)
        return [idx[i:i + size] for i in range(0, n, size)]

    def _build_records_for_index(
        subset_idx: int,
        *,
        age_bins_days_local: Sequence[float],
        rng_local: np.random.Generator,
    ) -> List[EvalRecord]:
        """Build the records of a single subject, drawing from ``rng_local``."""
        event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
        codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
        times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)

        doa_pos = np.flatnonzero(codes_ins == 1)
        if doa_pos.size == 0:
            raise ValueError("Expected DOA token (code=1) in event sequence")
        doa_days = float(times_ins[int(doa_pos[0])])

        is_disease = codes_ins >= N_TECH_TOKENS

        # Lifetime (ever) disease history for Clean Control filtering.
        if np.any(is_disease):
            lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
                np.int64, copy=False
            )
            lifetime_causes = np.unique(lifetime_causes)
        else:
            lifetime_causes = np.zeros((0,), dtype=np.int64)

        disease_pos_all = np.flatnonzero(is_disease)
        disease_times_all = (
            times_ins[disease_pos_all]
            if disease_pos_all.size > 0
            else np.zeros((0,), dtype=np.float64)
        )

        eps = 1e-6
        out: List[EvalRecord] = []
        for b in range(len(age_bins_days_local) - 1):
            lo = float(age_bins_days_local[b])
            hi = float(age_bins_days_local[b + 1])

            # Inclusion rule:
            # 1) DOA <= bin_upper
            if not (doa_days <= hi):
                continue

            # 2) at least one disease event within bin, and baseline must
            #    satisfy t0>=DOA.
            # Random Single-Point Sampling: choose exactly one valid event
            # *index* per (patient, age_bin).
            if disease_pos_all.size == 0:
                continue

            in_bin = (
                (disease_times_all >= lo)
                & (disease_times_all < hi)
                & (disease_times_all >= doa_days)
            )
            cand_pos = disease_pos_all[in_bin]
            if cand_pos.size == 0:
                continue

            cutoff_pos = int(rng_local.choice(cand_pos))
            t0_days = float(times_ins[cutoff_pos])

            # Future disease events strictly after t0
            future_mask = (times_ins > (t0_days + eps)) & is_disease
            future_pos = np.flatnonzero(future_mask)
            if future_pos.size == 0:
                next_cause = None
                next_dt_years = None
                future_causes = np.zeros((0,), dtype=np.int64)
                future_dt_years_arr = np.zeros((0,), dtype=np.float32)
            else:
                future_times_days = times_ins[future_pos]
                future_tokens = codes_ins[future_pos]
                future_causes = (
                    future_tokens - N_TECH_TOKENS).astype(np.int64)
                future_dt_years_arr = (
                    (future_times_days - t0_days) / DAYS_PER_YEAR
                ).astype(np.float32)

                # next-event = minimal time > t0 (tie broken by earliest position)
                next_idx = int(np.argmin(future_times_days))
                next_cause = int(future_causes[next_idx])
                next_dt_years = float(future_dt_years_arr[next_idx])

            out.append(
                EvalRecord(
                    subset_idx=int(subset_idx),
                    doa_days=float(doa_days),
                    t0_days=float(t0_days),
                    cutoff_pos=int(cutoff_pos),
                    next_event_cause=next_cause,
                    next_event_dt_years=next_dt_years,
                    lifetime_causes=lifetime_causes,
                    future_causes=future_causes,
                    future_dt_years=future_dt_years_arr,
                )
            )
        return out

    def _process_chunk(
        chunk_indices: Sequence[int],
        *,
        age_bins_days_local: Sequence[float],
        seed_local: int,
    ) -> List[EvalRecord]:
        """Build the records for one chunk of subject indices."""
        out: List[EvalRecord] = []
        for subset_idx in chunk_indices:
            # Ensure each subject has its own deterministic RNG stream, so
            # parallel workers do not share identical seeds.
            ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
            rng_local = np.random.default_rng(ss)
            out.extend(
                _build_records_for_index(
                    int(subset_idx),
                    age_bins_days_local=age_bins_days_local,
                    rng_local=rng_local,
                )
            )
        return out

    n = int(len(subset))
    # Computed unconditionally so that an invalid chunk_size is rejected in
    # both execution modes.
    chunks = _iter_chunks(n, int(chunk_size))

    do_parallel = (
        int(n_jobs) != 1
        and Parallel is not None
        and delayed is not None
        and n > 0
    )

    if do_parallel:
        # Note: on Windows, process-based parallelism may require the
        # underlying dataset to be pickleable. `prefer="threads"` is the
        # default for safety.
        parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
            delayed(_process_chunk)(
                chunk,
                age_bins_days_local=age_bins_days,
                seed_local=int(seed),
            )
            for chunk in chunks
        )
        return [r for part in parts for r in part]

    # Sequential (preserve prior behavior/progress reporting): a single
    # generator is shared by all subjects, consumed in index order.
    rng = np.random.default_rng(seed)
    records: List[EvalRecord] = []
    for subset_idx in _progress(
        range(n),
        enabled=show_progress,
        desc="Building eval records",
        total=n,
    ):
        records.extend(
            _build_records_for_index(
                int(subset_idx),
                age_bins_days_local=age_bins_days,
                rng_local=rng,
            )
        )
    return records


class EvalRecordDataset(Dataset):
    """Dataset view that yields each record's subject truncated at its baseline.

    The most recently fetched subjects (up to ``_cache_max``) are kept in a
    FIFO cache, so several records of the same subject do not re-read the
    underlying dataset.
    """

    def __init__(self, subset: Dataset, records: Sequence[EvalRecord]):
        self.subset = subset
        self.records = list(records)
        self._cache: Dict[int, Tuple[torch.Tensor,
                                     torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
        # Insertion order of cached keys; deque gives O(1) eviction.
        self._cache_order: Deque[int] = deque()
        self._cache_max = 2048

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int):
        """Return ``(event_seq, time_seq, cont, cate, sex, baseline_pos)``.

        The sequences are cut after ``cutoff_pos`` (inclusive); ``sex`` is
        always a Python ``int``.
        """
        rec = self.records[idx]
        cached = self._cache.get(rec.subset_idx)
        if cached is None:
            event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx]
            cached = (event_seq, time_seq, cont, cate, int(sex))
            self._cache[rec.subset_idx] = cached
            self._cache_order.append(rec.subset_idx)
            if len(self._cache_order) > self._cache_max:
                self._cache.pop(self._cache_order.popleft(), None)

        # Unpack from the cached tuple on hit *and* miss so `sex` has the
        # same type (int) on both paths.
        event_seq, time_seq, cont, cate, sex = cached

        cutoff = rec.cutoff_pos + 1
        event_seq = event_seq[:cutoff]
        time_seq = time_seq[:cutoff]
        baseline_pos = rec.cutoff_pos  # same index in truncated sequence
        return event_seq, time_seq, cont, cate, sex, baseline_pos


def eval_collate_fn(batch):
    """Collate ``EvalRecordDataset`` items into padded batch tensors.

    Event codes are padded with 0 and times with 36525.0 (which equals
    100 * DAYS_PER_YEAR); continuous/categorical features get an extra
    axis at dim 1.
    """
    from torch.nn.utils.rnn import pad_sequence

    event_seqs, time_seqs, cont_feats, cate_feats, sexes, baseline_pos = zip(
        *batch)
    event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0)
    time_batch = pad_sequence(
        time_seqs, batch_first=True, padding_value=36525.0)
    cont_batch = torch.stack(cont_feats, dim=0).unsqueeze(1)
    cate_batch = torch.stack(cate_feats, dim=0).unsqueeze(1)
    sex_batch = torch.tensor(sexes, dtype=torch.long)
    baseline_pos = torch.tensor(baseline_pos, dtype=torch.long)
    return event_batch, time_batch, cont_batch, cate_batch, sex_batch, baseline_pos


# -------------------------
# Inference utilities
# -------------------------

def _baseline_logits(
    model: torch.nn.Module,
    head: torch.nn.Module,
    batch,
    device: torch.device,
) -> torch.Tensor:
    """Run model + head on one collated batch and return baseline-position logits."""
    event_seq, time_seq, cont, cate, sex, baseline_pos = (
        t.to(device, non_blocking=True) for t in batch
    )
    h = model(event_seq, time_seq, sex, cont, cate)
    b_idx = torch.arange(h.size(0), device=device)
    # Pick, for each row, the hidden state at that row's baseline position.
    return head(h[b_idx, baseline_pos])


def predict_cifs(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion: torch.nn.Module,
    loader: DataLoader,
    taus_years: Sequence[float],
    device: torch.device,
    show_progress: bool = False,
    progress_desc: str = "Inference",
) -> np.ndarray:
    """Compute ``criterion.calculate_cifs`` at the horizons ``taus_years``.

    Returns:
        The per-batch outputs concatenated along axis 0, or an empty 1-D
        array when the loader yields no batches.
    """
    model.eval()
    head.eval()

    taus_t = torch.tensor(list(taus_years), dtype=torch.float32, device=device)

    all_out: List[np.ndarray] = []
    with torch.no_grad():
        for batch in _progress(
            loader,
            enabled=show_progress,
            desc=progress_desc,
            total=len(loader) if hasattr(loader, "__len__") else None,
        ):
            logits = _baseline_logits(model, head, batch, device)
            cifs = criterion.calculate_cifs(logits, taus_t)
            all_out.append(cifs.detach().cpu().numpy())

    return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))


def flatten_future_events(
    records: Sequence[EvalRecord],
    n_causes: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Flatten (record_idx, cause, dt_years) across all future events.

    Used to build horizon labels via vectorized masking + scatter. Events
    whose cause id lies outside ``[0, n_causes)`` are dropped.

    Returns:
        ``(record_idx int32, cause int64, dt_years float32)``, three 1-D
        arrays of equal length.
    """
    rec_idx_parts: List[np.ndarray] = []
    cause_parts: List[np.ndarray] = []
    dt_parts: List[np.ndarray] = []

    for i, r in enumerate(records):
        if r.future_causes.size == 0:
            continue
        causes = r.future_causes
        dts = r.future_dt_years
        # Keep only valid cause ids.
        m = (causes >= 0) & (causes < n_causes)
        if not np.any(m):
            continue
        causes = causes[m].astype(np.int64, copy=False)
        dts = dts[m].astype(np.float32, copy=False)
        rec_idx_parts.append(np.full((causes.size,), i, dtype=np.int32))
        cause_parts.append(causes)
        dt_parts.append(dts)

    if not rec_idx_parts:
        return (
            np.zeros((0,), dtype=np.int32),
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.float32),
        )

    return (
        np.concatenate(rec_idx_parts, axis=0),
        np.concatenate(cause_parts, axis=0),
        np.concatenate(dt_parts, axis=0),
    )


# -------------------------
# Metrics helpers
# -------------------------

def roc_auc_ovr(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Binary ROC AUC with tie-aware average ranks.

    Inputs are flattened to 1-D. Returns NaN if y_true has no positives or
    no negatives.
    """
    y_true = np.asarray(y_true).astype(np.int32).ravel()
    y_score = np.asarray(y_score).astype(np.float64).ravel()

    n_pos = int(y_true.sum())
    n_neg = int(y_true.size) - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Rank-sum (Mann-Whitney) form: average ranks for ties, ranks are 1..n.
    ranks = compute_midrank(y_score)
    sum_ranks_pos = float((ranks * y_true).sum())
    auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)
    return float(auc)


def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
    """Return indices of top-k scores per row (descending).

    ``k`` is clipped to the number of columns.

    Raises:
        ValueError: if ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    n_rows, n_cols = scores.shape
    if n_cols == 0:
        # Nothing to rank; argpartition would fail with kth=-1.
        return np.zeros((n_rows, 0), dtype=np.int64)
    k = min(k, n_cols)
    # argpartition gives arbitrary order within topk; sort those by score
    part = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
    part_scores = np.take_along_axis(scores, part, axis=1)
    order = np.argsort(-part_scores, axis=1, kind="mergesort")
    return np.take_along_axis(part, order, axis=1)


# -------------------------
# Statistical evaluation (DeLong)
# -------------------------

def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Compute midranks of a 1D array (1-based ranks, tie-aware).

    Equal values share the average of the ranks they occupy. NaNs never
    compare equal, so each NaN forms its own group.

    Raises:
        ValueError: if ``x`` is not 1-D.
    """
    x = np.asarray(x, dtype=np.float64)
    if x.ndim != 1:
        raise ValueError("compute_midrank expects a 1D array")
    n = int(x.size)
    if n == 0:
        return np.empty((0,), dtype=np.float64)

    order = np.argsort(x, kind="mergesort")
    x_sorted = x[order]

    # Start index of every run of equal values in the sorted array.
    starts = np.concatenate(
        ([0], np.flatnonzero(x_sorted[1:] != x_sorted[:-1]) + 1))
    ends = np.concatenate((starts[1:], [n]))
    # A run [i, j) occupies ranks i+1..j; their average is (i + 1 + j) / 2.
    mid = 0.5 * (starts + 1 + ends)

    out = np.empty((n,), dtype=np.float64)
    out[order] = np.repeat(mid, ends - starts)
    return out


def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for AUC covariance.

    Args:
        predictions_sorted_transposed: shape (n_classifiers, n_examples), where the first
            label_1_count examples are positives.
        label_1_count: number of positive examples.

    Returns:
        (aucs, delong_cov)

    Raises:
        ValueError: if the input is not 2D or a class is empty.
    """
    preds = np.asarray(predictions_sorted_transposed, dtype=np.float64)
    if preds.ndim != 2:
        raise ValueError("predictions_sorted_transposed must be 2D")

    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        raise ValueError("DeLong requires at least 1 positive and 1 negative")

    k = int(preds.shape[0])
    tx = np.empty((k, m), dtype=np.float64)
    ty = np.empty((k, n), dtype=np.float64)
    tz = np.empty((k, m + n), dtype=np.float64)

    for r in range(k):
        tx[r] = compute_midrank(preds[r, :m])
        ty[r] = compute_midrank(preds[r, m:])
        tz[r] = compute_midrank(preds[r, :])

    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)

    v01 = (tz[:, :m] - tx) / float(n)
    v10 = 1.0 - (tz[:, m:] - ty) / float(m)

    # np.cov expects variables in rows by default when rowvar=True.
    sx = np.cov(v01, rowvar=True, bias=False)
    sy = np.cov(v10, rowvar=True, bias=False)
    delong_cov = sx / float(m) + sy / float(n)
    return aucs, delong_cov


def compute_ground_truth_statistics(ground_truth: np.ndarray) -> Tuple[np.ndarray, int]:
    """Return ordering that places positives first and label_1_count.

    Raises:
        ValueError: if ``ground_truth`` is not 1D.
    """
    y = np.asarray(ground_truth, dtype=np.int32)
    if y.ndim != 1:
        raise ValueError("ground_truth must be 1D")
    label_1_count = int(y.sum())
    # Stable sort keeps the original relative order inside each class.
    order = np.argsort(-y, kind="mergesort")
    return order, label_1_count


def get_auc_delong_var(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float]:
    """Compute AUC and its DeLong variance.

    Args:
        healthy_scores: scores for controls (label=0)
        diseased_scores: scores for cases (label=1)

    Returns:
        (auc, auc_variance); both NaN when either group is empty.
    """
    h = np.asarray(healthy_scores, dtype=np.float64).reshape(-1)
    d = np.asarray(diseased_scores, dtype=np.float64).reshape(-1)
    n0 = int(h.size)
    n1 = int(d.size)
    if n0 == 0 or n1 == 0:
        return float("nan"), float("nan")

    # Arrange positives first as required by fastDeLong.
    scores = np.concatenate([d, h], axis=0)
    gt = np.concatenate([
        np.ones((n1,), dtype=np.int32),
        np.zeros((n0,), dtype=np.int32),
    ])
    order, label_1_count = compute_ground_truth_statistics(gt)
    preds_sorted = scores[order][None, :]
    aucs, cov = fastDeLong(preds_sorted, label_1_count)
    auc = float(aucs[0])
    cov = np.asarray(cov)
    # With a single classifier np.cov yields a 0-d array.
    var = float(cov[0, 0]) if cov.ndim == 2 else float(cov)
    return auc, var


# -------------------------
# Next-token inference helper
# -------------------------

def predict_next_token_logits(
    model: torch.nn.Module,
    head: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    show_progress: bool = False,
    progress_desc: str = "Inference (next-token)",
    return_probs: bool = True,
) -> np.ndarray:
    """Predict per-cause next-token scores at baseline positions.

    Returns:
        np.ndarray of shape (N, K) where K is number of diseases (causes),
        or an empty 1-D array when the loader yields no batches.

    Notes:
        - For loss types with time/bin dimensions (e.g., discrete-time CIF), this uses the
          *first* time/bin (index 0) and drops the complement channel when present.
        - If return_probs=True, applies softmax over causes for probability-like scores.

    Raises:
        ValueError: if the head output is neither 2-D nor 3-D.
    """
    model.eval()
    head.eval()

    all_out: List[np.ndarray] = []
    with torch.no_grad():
        for batch in _progress(
            loader,
            enabled=show_progress,
            desc=progress_desc,
            total=len(loader) if hasattr(loader, "__len__") else None,
        ):
            logits = _baseline_logits(model, head, batch, device)

            # logits can be (B, K) or (B, K, T) or (B, K+1, T)
            if logits.ndim == 2:
                cause_logits = logits
            elif logits.ndim == 3:
                # Use the first time/bin.
                cause_logits = logits[..., 0]
            else:
                raise ValueError(
                    f"Unsupported logits shape for next-token inference: {tuple(logits.shape)}"
                )

            # If a complement/survival channel exists (discrete-time CIF),
            # drop it: keep only the first n_disease channels.
            if hasattr(model, "n_disease"):
                n_disease = int(getattr(model, "n_disease"))
                if cause_logits.size(1) > n_disease:
                    cause_logits = cause_logits[:, :n_disease]

            if return_probs:
                scores = torch.softmax(cause_logits, dim=1)
            else:
                scores = cause_logits

            all_out.append(scores.detach().cpu().numpy())

    return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))