diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..6633bb7 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,1245 @@ +""" +Landmark Analysis Evaluation Script for Longitudinal Health Prediction Models + +Implements the comprehensive evaluation framework defined in evaluate_design.md: +- Landmark analysis at age cutoffs {50, 60, 70} +- Prediction horizons {0.25, 0.5, 1, 2, 5, 10} years +- Two tracks: Complete-Case (primary) and Clean Control (academic benchmark) +- Metrics: AUC, Brier Score, Disease-Capture@K, Lift, Yield, DCA +""" + +import argparse +import json +import os +from pathlib import Path +from typing import Dict, List, Tuple, Optional +import warnings + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, Subset +from tqdm import tqdm +from sklearn.metrics import roc_auc_score, brier_score_loss + +# Import model components +from model import DelphiFork, SapDelphi, SimpleHead +from dataset import HealthDataset, health_collate_fn +from losses import ( + ExponentialNLLLoss, + DiscreteTimeCIFNLLLoss, + PiecewiseExponentialCIFNLLLoss +) + +warnings.filterwarnings('ignore') + + +def _ensure_dir(path: str) -> str: + os.makedirs(path, exist_ok=True) + return path + + +def _to_float(x): + try: + return float(x) + except Exception: + return np.nan + + +def save_summary_json(summary: Dict, output_path: str) -> None: + """Save a single JSON summary file.""" + + def convert_to_serializable(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, (np.integer,)): + return int(obj) + if isinstance(obj, (np.floating,)): + return float(obj) + if isinstance(obj, dict): + return {k: convert_to_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [convert_to_serializable(v) for v in obj] + return obj + + summary_serializable = convert_to_serializable(summary) + with open(output_path, 'w') as f: + json.dump(summary_serializable, f, indent=2) + + +def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]: + """Save evaluation results into multiple CSV files. + + Produces long-form tables so they are easy to analyze/plot: + - landmarks_summary.csv + - auc_per_disease.csv + - capture_at_k.csv + - lift_yield.csv + - dca.csv + + Returns: + Mapping from logical name to file path. 
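+
+    Illustrative layout (column names match the code below; the values are
+    made up):
+
+        landmarks_summary.csv
+            age_cutoff,horizon,track,n_patients,n_valid,n_valid_patients,mean_auc,brier_score,brier_skill_score
+            60,1,complete_case,10000,412000,9800,0.71,0.031,0.12
+
+        auc_per_disease.csv
+            age_cutoff,horizon,track,disease_idx,auc
+            60,1,complete_case,42,0.68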
+ """ + + out_dir = _ensure_dir(out_dir) + + summary_rows: List[Dict] = [] + auc_rows: List[Dict] = [] + capture_rows: List[Dict] = [] + lift_rows: List[Dict] = [] + dca_rows: List[Dict] = [] + + landmarks = results.get('landmarks', []) + for lm in landmarks: + age = lm.get('age_cutoff') + horizon = lm.get('horizon') + for track in ['complete_case', 'clean_control']: + track_res = lm.get(track) or {} + if not track_res: + continue + + summary_rows.append({ + 'age_cutoff': age, + 'horizon': horizon, + 'track': track, + 'n_patients': track_res.get('n_patients', np.nan), + 'n_valid': track_res.get('n_valid', np.nan), + 'n_valid_patients': track_res.get('n_valid_patients', np.nan), + 'mean_auc': track_res.get('mean_auc', np.nan), + 'brier_score': track_res.get('brier_score', np.nan), + 'brier_skill_score': track_res.get('brier_skill_score', np.nan), + }) + + auc_per_disease = track_res.get('auc_per_disease') or {} + for disease_idx, auc in auc_per_disease.items(): + auc_rows.append({ + 'age_cutoff': age, + 'horizon': horizon, + 'track': track, + 'disease_idx': int(disease_idx), + 'auc': _to_float(auc), + }) + + if track == 'complete_case': + capture = track_res.get('disease_capture_at_k') or {} + for k_val, per_disease in capture.items(): + k_int = int(k_val) + for disease_idx, rate in (per_disease or {}).items(): + capture_rows.append({ + 'age_cutoff': age, + 'horizon': horizon, + 'track': track, + 'k': k_int, + 'disease_idx': int(disease_idx), + 'capture_rate': _to_float(rate), + }) + + lift_yield = track_res.get('lift_and_yield') or {} + overall = (lift_yield.get('overall') or {}) if isinstance( + lift_yield, dict) else {} + for frac, metrics in overall.items(): + lift_rows.append({ + 'age_cutoff': age, + 'horizon': horizon, + 'track': track, + 'level': 'overall', + 'disease_idx': '', + 'workload_frac': _to_float(frac), + 'lift': _to_float((metrics or {}).get('lift')), + 'yield': _to_float((metrics or {}).get('yield')), + }) + + per_disease = (lift_yield.get('per_disease') or {} + ) if isinstance(lift_yield, dict) else {} + for disease_idx, disease_metrics in per_disease.items(): + for frac, metrics in (disease_metrics or {}).items(): + lift_rows.append({ + 'age_cutoff': age, + 'horizon': horizon, + 'track': track, + 'level': 'per_disease', + 'disease_idx': int(disease_idx), + 'workload_frac': _to_float(frac), + 'lift': _to_float((metrics or {}).get('lift')), + 'yield': _to_float((metrics or {}).get('yield')), + }) + + dca = track_res.get('dca') or {} + thresholds = dca.get('thresholds') + net_benefit = dca.get('net_benefit') + if thresholds is not None and net_benefit is not None: + thresholds_arr = np.asarray(thresholds, dtype=np.float64) + nb_arr = np.asarray(net_benefit, dtype=np.float64) + for thr, nb in zip(thresholds_arr, nb_arr): + dca_rows.append({ + 'age_cutoff': age, + 'horizon': horizon, + 'track': track, + 'threshold': float(thr), + 'net_benefit': float(nb), + }) + + paths: Dict[str, str] = {} + + df_summary = pd.DataFrame(summary_rows) + summary_path = os.path.join(out_dir, 'landmarks_summary.csv') + df_summary.to_csv(summary_path, index=False) + paths['landmarks_summary'] = summary_path + + df_auc = pd.DataFrame(auc_rows) + auc_path = os.path.join(out_dir, 'auc_per_disease.csv') + df_auc.to_csv(auc_path, index=False) + paths['auc_per_disease'] = auc_path + + df_capture = pd.DataFrame(capture_rows) + capture_path = os.path.join(out_dir, 'capture_at_k.csv') + df_capture.to_csv(capture_path, index=False) + paths['capture_at_k'] = capture_path + + df_lift = 
pd.DataFrame(lift_rows) + lift_path = os.path.join(out_dir, 'lift_yield.csv') + df_lift.to_csv(lift_path, index=False) + paths['lift_yield'] = lift_path + + df_dca = pd.DataFrame(dca_rows) + dca_path = os.path.join(out_dir, 'dca.csv') + df_dca.to_csv(dca_path, index=False) + paths['dca'] = dca_path + + return paths + + +class LandmarkEvaluator: + """ + Comprehensive landmark analysis evaluator for survival/competing risks models. + """ + + def __init__( + self, + model: torch.nn.Module, + head: torch.nn.Module, + loss_fn: torch.nn.Module, + dataset: HealthDataset, + device: str = 'cuda', + batch_size: int = 256, + num_workers: int = 4, + ): + self.model = model.to(device) + self.model.eval() + self.head = head.to(device) + self.head.eval() + self.loss_fn = loss_fn + self.dataset = dataset + self.device = device + self.batch_size = batch_size + self.num_workers = num_workers + + # Evaluation parameters from design doc + self.age_cutoffs = [50, 60, 70] + self.horizons = [0.25, 0.5, 1, 2, 5, 10] + self.top_k_values = [5, 10, 20, 50] + self.workload_fracs = [0.01, 0.05, 0.10, 0.20, 0.50] + + # Convert age to days for comparison + self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs] + self.horizons_days = [h * 365.25 for h in self.horizons] + + @torch.no_grad() + def compute_risk_scores( + self, + indices: List[int], + age_cutoff_days: float, + horizon_days: float, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute risk scores for specified patient indices at given landmark. + + Args: + indices: Patient indices to evaluate + age_cutoff_days: Age cutoff in days (T_cut) + horizon_days: Prediction horizon in days (H) + + Returns: + risk_scores: (N, K) array of risk scores per disease + t_anchors: (N,) array of anchor times (days from birth) + valid_mask: (N,) boolean array indicating valid predictions + """ + subset = Subset(self.dataset, indices) + loader = DataLoader( + subset, + batch_size=self.batch_size, + shuffle=False, + collate_fn=health_collate_fn, + num_workers=self.num_workers, + pin_memory=True if self.device == 'cuda' else False, + ) + + all_risk_scores = [] + all_t_anchors = [] + all_valid_mask = [] + + for batch in loader: + event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch + event_batch = event_batch.to(self.device) + time_batch = time_batch.to(self.device) + cont_batch = cont_batch.to(self.device) + cate_batch = cate_batch.to(self.device) + sex_batch = sex_batch.to(self.device) + + # Find anchor point: last valid record before age_cutoff + # Valid records are non-padding events (event >= 1) + valid_mask = event_batch >= 1 # (B, L) + before_cutoff = time_batch < age_cutoff_days # (B, L) + valid_before = valid_mask & before_cutoff # (B, L) + + # Find last valid position for each patient + batch_size = event_batch.size(0) + t_anchor = torch.zeros(batch_size, device=self.device) + anchor_idx = torch.zeros( + batch_size, dtype=torch.long, device=self.device) + has_anchor = torch.zeros( + batch_size, dtype=torch.bool, device=self.device) + + for b in range(batch_size): + valid_positions = valid_before[b].nonzero(as_tuple=True)[0] + if len(valid_positions) > 0: + last_pos = valid_positions[-1] + anchor_idx[b] = last_pos + t_anchor[b] = time_batch[b, last_pos] + has_anchor[b] = True + + # Get model predictions at anchor points + if has_anchor.any(): + # Forward pass + hidden = self.model(event_batch, time_batch, + sex_batch, cont_batch, cate_batch) + + # Get predictions at anchor positions + batch_indices = torch.arange(batch_size, 
device=self.device)
+                # (B, n_embd)
+                hidden_at_anchor = hidden[batch_indices, anchor_idx]
+
+                # Compute logits using the loaded head
+                # (B, n_disease, ...) or (B, K+1, n_bins+1) etc.
+                logits = self.head(hidden_at_anchor)
+
+                # Compute CIF scores
+                # Time gap from anchor to start of horizon
+                t_start = age_cutoff_days - t_anchor  # (B,)
+                # Time gap from anchor to end of horizon
+                t_end = age_cutoff_days + horizon_days - t_anchor  # (B,)
+
+                # Ensure non-negative time gaps
+                t_start = torch.clamp(t_start, min=0)
+                t_end = torch.clamp(t_end, min=0)
+
+                # Calculate CIF at both time points
+                cif_start = self._compute_cif(logits, t_start)  # (B, K)
+                cif_end = self._compute_cif(logits, t_end)  # (B, K)
+
+                # Risk score is the CIF increment within the horizon
+                risk_scores = cif_end - cif_start  # (B, K)
+                risk_scores = torch.clamp(risk_scores, min=0)  # Ensure non-negative
+
+            else:
+                # No valid anchor points in this batch
+                risk_scores = torch.zeros(
+                    batch_size, self.dataset.n_disease, device=self.device)
+
+            all_risk_scores.append(risk_scores.cpu().numpy())
+            all_t_anchors.append(t_anchor.cpu().numpy())
+            all_valid_mask.append(has_anchor.cpu().numpy())
+
+        # Concatenate results
+        risk_scores = np.vstack(all_risk_scores)  # (N, K)
+        t_anchors = np.concatenate(all_t_anchors)  # (N,)
+        valid_mask = np.concatenate(all_valid_mask)  # (N,)
+
+        return risk_scores, t_anchors, valid_mask
+
+    def _compute_cif(self, logits: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """
+        Compute Cumulative Incidence Function at time t.
+
+        Args:
+            logits: Model output logits (B, K, ...) depending on loss type
+            t: Time points (B,) in days from anchor (converted to years internally)
+
+        Returns:
+            cif: (B, K) cumulative incidence probabilities
+        """
+        t_years = t / 365.25  # Convert days to years
+
+        if isinstance(self.loss_fn, ExponentialNLLLoss):
+            # Exponential: logits are (B, K)
+            lambdas = F.softplus(logits) + 1e-6  # (B, K)
+            total_lambda = lambdas.sum(dim=-1, keepdim=True)  # (B, 1)
+
+            # CIF_k(t) = (λ_k / Σλ) * (1 - exp(-Σλ * t))
+            frac = lambdas / total_lambda  # (B, K)
+            exp_term = 1.0 - torch.exp(-total_lambda * t_years.unsqueeze(-1))  # (B, 1)
+            cif = frac * exp_term  # (B, K)
+
+        elif isinstance(self.loss_fn, DiscreteTimeCIFNLLLoss):
+            # Discrete-time CIF: use calculate_cifs method
+            cif = self.loss_fn.calculate_cifs(
+                logits, t_years, return_survival=False)
+
+        elif isinstance(self.loss_fn, PiecewiseExponentialCIFNLLLoss):
+            # PWE CIF: use calculate_cifs method
+            cif = self.loss_fn.calculate_cifs(
+                logits, t_years, return_survival=False)
+
+        else:
+            raise ValueError(f"Unknown loss type: {type(self.loss_fn)}")
+
+        return cif
+
+    def prepare_evaluation_cohort(
+        self,
+        age_cutoff_days: float,
+        horizon_days: float,
+        track: str = 'complete_case',
+    ) -> Tuple[List[int], np.ndarray, np.ndarray]:
+        """Prepare evaluation cohort per design protocol with per-disease validity.
+
+        Key fixes vs the earlier implementation:
+        - Multi-disease labeling: mark *all* diseases that occur within the horizon.
+        - Per-disease validity mask: mask[i, k]=1 if patient i is valid for disease k.
+
+        Returns:
+            indices: list of patient indices included in this cohort table
+            labels: (N, K) float array
+            valid_mask: (N, K) bool array
+        """
+
+        n_tech_tokens = 2  # PAD=0, DOA=1
+        K = int(self.dataset.n_disease)
+
+        # Competing risk token: treat death as competing event.
+        # As requested: DEATH_CODE is the last disease index.
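+        # Minimal sketch of the assumed token layout (PAD=0, DOA=1, diseases
+        # start at token id n_tech_tokens), using a hypothetical K = 5:
+        #   >>> n_tech_tokens, K = 2, 5
+        #   >>> [e - n_tech_tokens for e in range(n_tech_tokens, n_tech_tokens + K)]
+        #   [0, 1, 2, 3, 4]   # disease indices; the last (K - 1 = 4) is death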
+ DEATH_CODE = int(K - 1) + + horizon_end_days = age_cutoff_days + horizon_days + + indices: List[int] = [] + labels_rows: List[np.ndarray] = [] + valid_rows: List[np.ndarray] = [] + + for idx in range(len(self.dataset)): + patient_id = self.dataset.patient_ids[idx] + records = self.dataset.patient_events.get(patient_id, []) + if not records: + continue + + # Must have some information strictly prior to cutoff (anchor existence in the data). + has_pre_cutoff = any(t < age_cutoff_days for t, _ in records) + if not has_pre_cutoff: + continue + + # Events after cutoff (already sorted in dataset init) + events_after = [(t, e) for t, e in records if t >= age_cutoff_days] + + labels = np.zeros(K, dtype=np.float32) + valid = np.zeros(K, dtype=bool) + + # Identify diseases within horizon (multi-label; no early break) + diseases_in_horizon: set[int] = set() + death_in_horizon = False + for t, e in events_after: + if t > horizon_end_days: + break # events are time-sorted + if e < n_tech_tokens: + continue + disease_idx = int(e - n_tech_tokens) + if 0 <= disease_idx < K: + diseases_in_horizon.add(disease_idx) + if disease_idx == DEATH_CODE: + death_in_horizon = True + + for d in diseases_in_horizon: + labels[d] = 1.0 + + last_time = float(records[-1][0]) + + if track == 'complete_case': + # Track A: Complete-Case at Horizon + # - Hit: disease occurs within horizon => valid positive for that disease + # - Healthy: last record > horizon_end => valid negative for all diseases + # - Death: death within horizon => valid negative for diseases not occurred before death + # (we implement as valid for all diseases, with labels marking any hits incl death) + # - Loss: censored within horizon (no hit for disease) => invalid for that disease + + if last_time > horizon_end_days: + # Observed past horizon: negative for all non-hit diseases + valid[:] = True + elif death_in_horizon: + # Competing risk: include as explicit negatives (label 0) for diseases not hit + # and positives for diseases hit before death (we don't model ordering here). + # This matches the requirement: death within window is label=0 for target diseases. + valid[:] = True + else: + # Lost within horizon: only diseases that actually occurred in the horizon are valid positives + # (no assumptions about negatives). + for d in diseases_in_horizon: + valid[d] = True + + elif track == 'clean_control': + # Track B: Clean Control (per-disease) + # For each disease k: + # - Hit: disease k occurs within horizon => label=1, valid + # - Pure clean for k: disease k never occurs in entire record => label=0, valid + # - Death within window => drop (invalid) for all k + # - Loss within window => drop (invalid) for all k + # - Late onset of k (after horizon) => drop (invalid) for k + + if death_in_horizon: + # Drop all diseases for this patient + valid[:] = False + else: + # Precompute per-disease first occurrence time after cutoff (including after horizon) + first_occ_after_cutoff = np.full( + K, np.inf, dtype=np.float64) + ever_has_disease = np.zeros(K, dtype=bool) + for t, e in records: + if e < n_tech_tokens: + continue + disease_idx = int(e - n_tech_tokens) + if 0 <= disease_idx < K: + ever_has_disease[disease_idx] = True + if t >= age_cutoff_days and t < first_occ_after_cutoff[disease_idx]: + first_occ_after_cutoff[disease_idx] = float(t) + + # If lost within horizon and did not have disease k in horizon, it's invalid for k. 
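+                    # Case summary for disease k under clean control (mirrors
+                    # the branches below; "valid" = the (patient, k) pair is scored):
+                    #   k occurs within horizon              -> label 1, valid
+                    #   k never occurs, followed past window -> label 0, valid
+                    #   k never occurs, censored in window   -> invalid
+                    #   first onset of k after horizon       -> invalid (late onset)
+                    #   k present only before cutoff         -> invalid (prevalent)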
+ lost_within_horizon = last_time <= horizon_end_days + + for k in range(K): + if labels[k] == 1.0: + valid[k] = True + continue + + if not ever_has_disease[k]: + # Lifetime clean for k + # Still need to have complete follow-up within horizon for clean-control track. + # If censored within horizon, we cannot be sure k didn't occur in the window. + valid[k] = not lost_within_horizon + continue + + # Has disease k at some point. + t_first = first_occ_after_cutoff[k] + if np.isfinite(t_first) and t_first > horizon_end_days: + # Late onset of k -> invalid for k + valid[k] = False + else: + # Has k before cutoff (prevalent) or has k after cutoff but not in horizon? -> invalid. + valid[k] = False + + else: + raise ValueError(f"Unknown track: {track}") + + # Keep patient row if they are valid for at least one disease. + if valid.any(): + indices.append(idx) + labels_rows.append(labels) + valid_rows.append(valid) + + if not indices: + return [], np.zeros((0, K), dtype=np.float32), np.zeros((0, K), dtype=bool) + + return indices, np.stack(labels_rows, axis=0), np.stack(valid_rows, axis=0) + + def compute_auc_per_disease( + self, + risk_scores: np.ndarray, + labels: np.ndarray, + valid_mask: np.ndarray, + ) -> Dict[int, float]: + """ + Compute time-dependent AUC for each disease. + + Args: + risk_scores: (N, K) risk scores + labels: (N, K) binary labels + valid_mask: (N, K) boolean mask for valid evaluations + + Returns: + auc_scores: Dict mapping disease_idx -> AUC + """ + auc_scores = {} + n_diseases = risk_scores.shape[1] + + for k in range(n_diseases): + mk = valid_mask[:, k] + if not np.any(mk): + auc_scores[k] = np.nan + continue + + y_true = labels[mk, k] + y_score = risk_scores[mk, k] + + # Check if we have both classes + if len(np.unique(y_true)) < 2: + auc_scores[k] = np.nan + else: + try: + auc = roc_auc_score(y_true, y_score) + auc_scores[k] = auc + except Exception: + auc_scores[k] = np.nan + + return auc_scores + + def compute_brier_score( + self, + risk_scores: np.ndarray, + labels: np.ndarray, + valid_mask: np.ndarray, + ) -> Dict[str, float]: + """ + Compute Brier Score and Brier Skill Score. + + Args: + risk_scores: (N, K) risk scores + labels: (N, K) binary labels + valid_mask: (N, K) boolean mask + + Returns: + metrics: Dict with 'brier_score' and 'brier_skill_score' + """ + # Apply per-entry valid mask and flatten + m = valid_mask.astype(bool) + y_true_flat = labels[m] + y_pred_flat = risk_scores[m] + + mask = ~(np.isnan(y_true_flat) | np.isnan(y_pred_flat)) + y_true_flat = y_true_flat[mask] + y_pred_flat = y_pred_flat[mask] + + if len(y_true_flat) == 0: + return {'brier_score': np.nan, 'brier_skill_score': np.nan} + + bs = brier_score_loss(y_true_flat, y_pred_flat) + + # Brier Skill Score: reference is predicting the mean + p_mean = y_true_flat.mean() + bs_ref = ((p_mean - y_true_flat) ** 2).mean() + bss = 1.0 - (bs / bs_ref) if bs_ref > 0 else 0.0 + + return {'brier_score': bs, 'brier_skill_score': bss} + + def compute_disease_capture_at_k( + self, + risk_scores: np.ndarray, + labels: np.ndarray, + valid_mask: np.ndarray, + ) -> Dict[int, Dict[int, float]]: + """ + Compute Disease-Capture@K: fraction of true positives where the true + disease appears in the patient's top-K predicted risks. 
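+
+        For example (hypothetical numbers): with K=10, a patient who truly
+        develops disease 7 within the horizon counts as captured if disease 7
+        ranks among that patient's 10 highest predicted risks; rates are then
+        averaged over the positive patients of each disease.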
+
+        Args:
+            risk_scores: (N, K) risk scores
+            labels: (N, K) binary labels
+            valid_mask: (N, K) boolean mask
+
+        Returns:
+            capture_rates: Dict[K_value][disease_idx] -> capture rate
+        """
+        capture_rates = {k: {} for k in self.top_k_values}
+
+        n_diseases = risk_scores.shape[1]
+
+        for disease_idx in range(n_diseases):
+            mk = valid_mask[:, disease_idx]
+            if not np.any(mk):
+                for k in self.top_k_values:
+                    capture_rates[k][disease_idx] = np.nan
+                continue
+
+            y_true = labels[mk, disease_idx]
+            y_scores = risk_scores[mk]  # (N_valid_k, K)
+
+            # Find patients with positive label for this disease
+            pos_mask = y_true == 1
+            if pos_mask.sum() == 0:
+                for k in self.top_k_values:
+                    capture_rates[k][disease_idx] = np.nan
+                continue
+
+            # For each positive patient, check if the true disease is in their top-K
+            for k_val in self.top_k_values:
+                captures = []
+                for i in np.where(pos_mask)[0]:
+                    # Get top-K disease indices for this patient
+                    top_k_diseases = np.argsort(y_scores[i])[::-1][:k_val]
+                    # Check if the true disease is in the top-K
+                    is_captured = disease_idx in top_k_diseases
+                    captures.append(int(is_captured))
+
+                capture_rate = np.mean(captures) if captures else np.nan
+                capture_rates[k_val][disease_idx] = capture_rate
+
+        return capture_rates
+
+    def compute_lift_and_yield(
+        self,
+        risk_scores: np.ndarray,
+        labels: np.ndarray,
+        valid_mask: np.ndarray,
+    ) -> Dict[str, Dict]:
+        """
+        Compute Lift and Yield at various workload fractions.
+
+        Args:
+            risk_scores: (N, K) risk scores
+            labels: (N, K) binary labels
+            valid_mask: (N, K) boolean mask
+
+        Returns:
+            metrics:
+            {
+                'overall': {workload_frac: {'lift': ..., 'yield': ...}, ...},
+                'per_disease': {disease_idx: {workload_frac: {'lift': ..., 'yield': ...}, ...}, ...}
+            }
+        """
+        # Overall metric uses a patient-level mask to avoid including purely censored negatives.
+        # We include patients who either (a) have known outcomes for all diseases, or (b) have at least one hit.
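+        # Worked example (made-up numbers): with a 2% base event rate, if the
+        # top-ranked 5% of patients contain events at an 8% rate, then
+        #   >>> yield_val, base_rate = 0.08, 0.02
+        #   >>> yield_val / base_rate
+        #   4.0   # lift: 4x enrichment over unranked screening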
+ has_any_hit = labels.max(axis=1) > 0 + has_all_known = valid_mask.all(axis=1) + overall_patient_mask = has_any_hit | has_all_known + + risk_scores_overall = risk_scores[overall_patient_mask] + labels_overall = labels[overall_patient_mask] + valid_overall = valid_mask[overall_patient_mask] + + # Flatten to patient-level: any disease event + # Max risk across all diseases for each patient + max_risk_per_patient = risk_scores_overall.max(axis=1) + any_disease_label = labels_overall.max(axis=1) + + n_patients = len(max_risk_per_patient) + base_rate = any_disease_label.mean() if n_patients > 0 else 0.0 + + overall: Dict[float, Dict[str, float]] = {} + for workload_frac in self.workload_fracs: + n_screen = max(1, int(n_patients * workload_frac)) + top_n_idx = np.argsort(max_risk_per_patient)[::-1][:n_screen] + top_n_labels = any_disease_label[top_n_idx] + yield_val = float(top_n_labels.mean()) if n_screen > 0 else np.nan + lift_val = (yield_val / float(base_rate)) if base_rate > 0 else 0.0 + overall[workload_frac] = {'lift': lift_val, 'yield': yield_val} + + per_disease: Dict[int, Dict[float, Dict[str, float]]] = {} + n_diseases = risk_scores.shape[1] + for disease_idx in range(n_diseases): + mk = valid_mask[:, disease_idx] + if not np.any(mk): + per_disease[disease_idx] = { + frac: {'lift': np.nan, 'yield': np.nan} for frac in self.workload_fracs} + continue + + disease_scores = risk_scores[mk, disease_idx] + disease_labels = labels[mk, disease_idx] + disease_base_rate = disease_labels.mean() if disease_labels.size > 0 else 0.0 + + n_patients_k = disease_scores.shape[0] + + disease_metrics: Dict[float, Dict[str, float]] = {} + for workload_frac in self.workload_fracs: + n_screen = max(1, int(n_patients_k * workload_frac)) + top_n_idx = np.argsort(disease_scores)[::-1][:n_screen] + top_n_labels = disease_labels[top_n_idx] + yield_val = float(top_n_labels.mean() + ) if n_screen > 0 else np.nan + lift_val = (yield_val / float(disease_base_rate) + ) if disease_base_rate > 0 else 0.0 + disease_metrics[workload_frac] = { + 'lift': lift_val, 'yield': yield_val} + + per_disease[disease_idx] = disease_metrics + + return { + 'overall': overall, + 'per_disease': per_disease, + } + + def compute_dca_net_benefit( + self, + risk_scores: np.ndarray, + labels: np.ndarray, + valid_mask: np.ndarray, + threshold_range: np.ndarray = np.linspace(0, 0.5, 51), + ) -> Dict[str, np.ndarray]: + """ + Compute Decision Curve Analysis (DCA) net benefit. 
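+
+        Net benefit at threshold probability pt follows the standard DCA form
+        used below: NB(pt) = TP/N - FP/N * pt / (1 - pt). With made-up counts
+        TP/N = 0.04, FP/N = 0.20 and pt = 0.10, NB = 0.04 - 0.20 * (0.10/0.90)
+        ~= 0.018.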
+
+        Args:
+            risk_scores: (N, K) risk scores
+            labels: (N, K) binary labels
+            valid_mask: (N, K) boolean mask
+            threshold_range: Array of threshold probabilities
+
+        Returns:
+            dca_results: Dict with 'thresholds' and 'net_benefit' arrays
+        """
+        # Use the same overall patient mask as lift/yield (complete-case style)
+        has_any_hit = labels.max(axis=1) > 0
+        has_all_known = valid_mask.all(axis=1)
+        patient_mask = has_any_hit | has_all_known
+
+        risk_scores = risk_scores[patient_mask]
+        labels = labels[patient_mask]
+
+        # Use max risk and any-disease label
+        max_risk = risk_scores.max(axis=1)
+        any_disease = labels.max(axis=1)
+
+        n = len(max_risk)
+        net_benefits = []
+
+        for pt in threshold_range:
+            if pt == 0:
+                # Treat all
+                nb = any_disease.mean()
+            else:
+                # Treat if predicted risk >= threshold
+                treat = max_risk >= pt
+                tp = (treat & (any_disease == 1)).sum()
+                fp = (treat & (any_disease == 0)).sum()
+
+                # Net benefit = (TP/N) - (FP/N) * (pt / (1-pt))
+                nb = (tp / n) - (fp / n) * (pt / (1 - pt))
+
+            net_benefits.append(nb)
+
+        return {
+            'thresholds': threshold_range,
+            'net_benefit': np.array(net_benefits),
+        }
+
+    def evaluate_landmark(
+        self,
+        age_cutoff: float,
+        horizon: float,
+    ) -> Dict:
+        """
+        Evaluate model at a specific landmark (age_cutoff, horizon).
+
+        Args:
+            age_cutoff: Age cutoff in years
+            horizon: Prediction horizon in years
+
+        Returns:
+            results: Dictionary with all metrics
+        """
+        age_cutoff_days = age_cutoff * 365.25
+        horizon_days = horizon * 365.25
+
+        print(f"\nEvaluating Landmark: Age={age_cutoff}, Horizon={horizon}y")
+
+        results = {
+            'age_cutoff': age_cutoff,
+            'horizon': horizon,
+            'complete_case': {},
+            'clean_control': {},
+        }
+
+        for track in ['complete_case', 'clean_control']:
+            print(f"  Track: {track}")
+
+            # Prepare cohort
+            indices, labels_array, valid_mask = self.prepare_evaluation_cohort(
+                age_cutoff_days, horizon_days, track
+            )
+
+            if len(indices) == 0:
+                print(f"    No valid patients for track {track}")
+                continue
+
+            print(f"    Cohort size: {len(indices)}")
+
+            # Compute risk scores
+            risk_scores, t_anchors, anchor_mask = self.compute_risk_scores(
+                indices, age_cutoff_days, horizon_days
+            )
+
+            # Combine anchor availability with per-disease validity
+            valid_mask = valid_mask & anchor_mask.astype(bool)[:, None]
+
+            # Compute metrics
+            print("    Computing AUC...")
+            auc_scores = self.compute_auc_per_disease(
+                risk_scores, labels_array, valid_mask)
+            mean_auc = np.nanmean(list(auc_scores.values()))
+
+            print("    Computing Brier Score...")
+            brier_metrics = self.compute_brier_score(
+                risk_scores, labels_array, valid_mask)
+
+            # Only compute patient-level and population metrics for complete_case
+            if track == 'complete_case':
+                print("    Computing Disease-Capture@K...")
+                capture_metrics = self.compute_disease_capture_at_k(
+                    risk_scores, labels_array, valid_mask
+                )
+
+                print("    Computing Lift & Yield...")
+                lift_yield_metrics = self.compute_lift_and_yield(
+                    risk_scores, labels_array, valid_mask
+                )
+
+                print("    Computing DCA...")
+                dca_metrics = self.compute_dca_net_benefit(
+                    risk_scores, labels_array, valid_mask
+                )
+
+                results[track] = {
+                    'n_patients': len(indices),
+                    'n_valid': int(valid_mask.sum()),
+                    'n_valid_patients': int(valid_mask.any(axis=1).sum()),
+                    'auc_per_disease': auc_scores,
+                    'mean_auc': mean_auc,
+                    'brier_score': brier_metrics['brier_score'],
+                    'brier_skill_score': brier_metrics['brier_skill_score'],
+                    'disease_capture_at_k': capture_metrics,
+                    'lift_and_yield': lift_yield_metrics,
+                    'dca': dca_metrics,
+                }
+            else:
+                # Clean control
track: only discrimination metrics + results[track] = { + 'n_patients': len(indices), + 'n_valid': int(valid_mask.sum()), + 'n_valid_patients': int(valid_mask.any(axis=1).sum()), + 'auc_per_disease': auc_scores, + 'mean_auc': mean_auc, + } + + return results + + def run_full_evaluation(self) -> Dict: + """ + Run complete landmark analysis across all cutoffs and horizons. + + Returns: + all_results: Nested dictionary with all evaluation results + """ + all_results = { + 'age_cutoffs': self.age_cutoffs, + 'horizons': self.horizons, + 'landmarks': [], + } + + # Evaluate each landmark + for age_cutoff in self.age_cutoffs: + for horizon in self.horizons: + landmark_results = self.evaluate_landmark(age_cutoff, horizon) + all_results['landmarks'].append(landmark_results) + + return all_results + + +def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple: + """ + Load trained model and configuration from run directory. + + Args: + run_dir: Path to run directory containing train_config.json and best_model.pt + device: Device to load model on + + Returns: + model, head, loss_fn, dataset, config + """ + run_path = Path(run_dir) + + # Load config + config_path = run_path / 'train_config.json' + with open(config_path, 'r') as f: + config = json.load(f) + + print(f"Loading model from {run_dir}") + print(f"Model type: {config['model_type']}") + print(f"Loss type: {config['loss_type']}") + + # Load dataset to get dimensions + data_prefix = config['data_prefix'] + + # Determine covariate list based on full_cov + if config['full_cov']: + covariate_list = None # Use all covariates + else: + # Use partial covariates (define your partial list here) + covariate_list = ['age_at_assessment', + 'bmi', 'smoking_status'] # Example + + dataset = HealthDataset( + data_prefix=f"{data_prefix}_test", + covariate_list=covariate_list, + cache_event_tensors=True, + ) + + # Determine output dimensions based on loss type + import math + if config['loss_type'] == 'exponential': + out_dims = [dataset.n_disease] + elif config['loss_type'] == 'discrete_time_cif': + # logits shape (M, K+1, n_bins+1) + bin_edges = config.get( + 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]) + out_dims = [dataset.n_disease + 1, len(bin_edges)] + elif config['loss_type'] == 'pwe_cif': + # Piecewise-exponential requires finite edges + bin_edges = config.get( + 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]) + pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))] + n_bins = len(pwe_edges) - 1 + # logits shape (M, K, n_bins) + out_dims = [dataset.n_disease, n_bins] + else: + raise ValueError(f"Unknown loss type: {config['loss_type']}") + + # Build model + if config['model_type'] == 'delphi_fork': + model = DelphiFork( + n_disease=dataset.n_disease, + n_tech_tokens=2, # PAD=0, DOA=1 + n_embd=config['n_embd'], + n_head=config['n_head'], + n_layer=config['n_layer'], + n_cont=dataset.n_cont, + n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, + age_encoder_type=config['age_encoder'], + pdrop=config['pdrop'], + ) + elif config['model_type'] == 'sap_delphi': + model = SapDelphi( + n_disease=dataset.n_disease, + n_tech_tokens=2, + n_embd=config['n_embd'], + n_head=config['n_head'], + n_layer=config['n_layer'], + n_cont=dataset.n_cont, + n_cate=dataset.n_cate, + cate_dims=dataset.cate_dims, + age_encoder_type=config['age_encoder'], + pdrop=config['pdrop'], + pretrained_weights_path=config.get('pretrained_emd_path'), + ) + else: + raise ValueError(f"Unknown model type: {config['model_type']}") + + 
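+    # Head output shape sketch (hypothetical K = 100 diseases, default
+    # bin_edges above):
+    #   exponential       -> out_dims = [100]      one rate per cause
+    #   discrete_time_cif -> out_dims = [101, 8]   (K+1 incl. survival) x len(bin_edges)
+    #   pwe_cif           -> out_dims = [100, 6]   K x n_bins (7 finite edges -> 6 bins)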
# Build head + head = SimpleHead( + n_embd=config['n_embd'], + out_dims=out_dims, + ) + + # Load model weights (checkpoint contains model and head state dicts) + model_path = run_path / 'best_model.pt' + checkpoint = torch.load(model_path, map_location=device) + + # The checkpoint is a dict with 'model_state_dict' and 'head_state_dict' + if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: + model.load_state_dict(checkpoint['model_state_dict']) + head.load_state_dict(checkpoint['head_state_dict']) + print("Loaded model and head from checkpoint") + else: + raise ValueError( + "Checkpoint format not recognized. Expected 'model_state_dict' and 'head_state_dict' keys.") + + model.to(device) + head.to(device) + + # Build loss function + if config['loss_type'] == 'exponential': + loss_fn = ExponentialNLLLoss( + lambda_reg=config.get('lambda_reg', 0.0) + ) + elif config['loss_type'] == 'discrete_time_cif': + loss_fn = DiscreteTimeCIFNLLLoss( + bin_edges=config.get( + 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]), + lambda_reg=config.get('lambda_reg', 0.0), + ) + elif config['loss_type'] == 'pwe_cif': + loss_fn = PiecewiseExponentialCIFNLLLoss( + bin_edges=config.get( + 'bin_edges', [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]), + lambda_reg=config.get('lambda_reg', 0.0), + ) + else: + raise ValueError(f"Unknown loss type: {config['loss_type']}") + + return model, head, loss_fn, dataset, config + + +def print_summary(results: Dict): + """ + Print summary of evaluation results. + + Args: + results: Results dictionary + """ + print("\n" + "=" * 80) + print("EVALUATION SUMMARY") + print("=" * 80) + + for landmark in results['landmarks']: + age = landmark['age_cutoff'] + horizon = landmark['horizon'] + + print(f"\nLandmark: Age {age}, Horizon {horizon}y") + print("-" * 40) + + # Complete-case results + if 'complete_case' in landmark and landmark['complete_case']: + cc = landmark['complete_case'] + print(f" Complete-Case Track:") + print(f" Patients: {cc['n_patients']}") + print(f" Mean AUC: {cc['mean_auc']:.4f}") + print(f" Brier Score: {cc['brier_score']:.4f}") + print(f" Brier Skill Score: {cc['brier_skill_score']:.4f}") + + # Show top-K capture rates (average across diseases) + if 'disease_capture_at_k' in cc: + print(f" Disease Capture:") + for k in [5, 10, 20, 50]: + if k in cc['disease_capture_at_k']: + rates = list(cc['disease_capture_at_k'][k].values()) + mean_rate = np.nanmean(rates) + print(f" Top-{k}: {mean_rate:.3f}") + + # Show lift and yield + if 'lift_and_yield' in cc: + print(f" Lift & Yield:") + overall = cc['lift_and_yield'].get('overall', {}) if isinstance( + cc['lift_and_yield'], dict) else {} + for frac in [0.01, 0.05, 0.10]: + if frac in overall: + lift = overall[frac].get('lift', np.nan) + yield_val = overall[frac].get('yield', np.nan) + print( + f" Overall Top {int(frac*100)}%: Lift={lift:.2f}, Yield={yield_val:.3f}") + + # Clean control results + if 'clean_control' in landmark and landmark['clean_control']: + clean = landmark['clean_control'] + print(f" Clean-Control Track:") + print(f" Patients: {clean['n_patients']}") + print(f" Mean AUC: {clean['mean_auc']:.4f}") + + print("\n" + "=" * 80) + + +def main(): + parser = argparse.ArgumentParser( + description='Evaluate longitudinal health prediction model using landmark analysis' + ) + parser.add_argument( + '--run_dir', + type=str, + required=True, + help='Path to run directory containing train_config.json and best_model.pt' + ) + parser.add_argument( + '--output', + type=str, + 
default=None,
+        help='Output path for the summary JSON (default: <run_dir>/evaluation_summary.json)'
+    )
+    parser.add_argument(
+        '--out_dir',
+        type=str,
+        default=None,
+        help='Directory to write CSV outputs (default: <run_dir>/evaluation_outputs)'
+    )
+    parser.add_argument(
+        '--device',
+        type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu',
+        help='Device to run evaluation on'
+    )
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=256,
+        help='Batch size for evaluation'
+    )
+    parser.add_argument(
+        '--num_workers',
+        type=int,
+        default=4,
+        help='Number of data loader workers'
+    )
+
+    args = parser.parse_args()
+
+    # Load model and dataset
+    model, head, loss_fn, dataset, config = load_model_and_config(
+        args.run_dir, args.device)
+
+    # Create evaluator
+    evaluator = LandmarkEvaluator(
+        model=model,
+        head=head,
+        loss_fn=loss_fn,
+        dataset=dataset,
+        device=args.device,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+    )
+
+    # Run evaluation
+    print("\nStarting landmark analysis evaluation...")
+    print(f"Age cutoffs: {evaluator.age_cutoffs}")
+    print(f"Horizons: {evaluator.horizons}")
+
+    results = evaluator.run_full_evaluation()
+
+    # Add metadata
+    results['metadata'] = {
+        'run_dir': args.run_dir,
+        'config': config,
+        'n_diseases': dataset.n_disease,
+        'device': args.device,
+    }
+
+    # Save results (CSV bundle + single JSON summary)
+    if args.out_dir is None:
+        args.out_dir = os.path.join(args.run_dir, 'evaluation_outputs')
+    csv_paths = save_results_csv_bundle(results, args.out_dir)
+
+    summary = {
+        'metadata': results.get('metadata', {}),
+        'age_cutoffs': results.get('age_cutoffs', []),
+        'horizons': results.get('horizons', []),
+        'csv_outputs': csv_paths,
+        'notes': {
+            'metrics': [
+                'AUC (per disease + mean)',
+                'Brier Score / Brier Skill Score (complete-case only)',
+                'Disease-Capture@K (complete-case only)',
+                'Lift/Yield (complete-case only)',
+                'Decision Curve Analysis (complete-case only)',
+            ],
+        },
+    }
+
+    if args.output is None:
+        args.output = os.path.join(args.run_dir, 'evaluation_summary.json')
+    save_summary_json(summary, args.output)
+
+    print(f"\nWrote CSV outputs to: {args.out_dir}")
+    print(f"Wrote summary JSON to: {args.output}")
+
+    # Print summary
+    print_summary(results)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/run_evaluations_multi_gpu.sh b/run_evaluations_multi_gpu.sh
deleted file mode 100644
index a60f00b..0000000
--- a/run_evaluations_multi_gpu.sh
+++ /dev/null
@@ -1,354 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-usage() {
-  cat <<'USAGE'
-Usage:
-  ./run_evaluations_multi_gpu.sh --gpus 0,1,2 [options] [-- <common eval args>]
-
-Description:
-  Discovers trained run directories (containing best_model.pt + train_config.json)
-  and runs BOTH evaluations on each run:
-    1) evaluate_next_event.py
-    2) evaluate_horizon.py
-
-  Jobs are distributed round-robin across the provided GPU list and each GPU runs
-  at most one job at a time.
-
-Options:
-  --gpus              Comma-separated GPU ids (required), e.g. 0,1,2
-  --runs-root         Root directory containing run subfolders (default: runs)
-  --pattern           Shell glob to filter run folder basenames (default: *)
-  --run-dirs-file     Text file with one run_dir per line (overrides --runs-root)
-  --horizons          Horizon grid in years (space-separated list). If omitted, uses script defaults.
-  --age-bins          Age bin boundaries in years (space-separated list). If omitted, uses script defaults.
- --next-args-file File with one CLI argument per line appended only to evaluate_next_event.py - --horizon-args-file File with one CLI argument per line appended only to evaluate_horizon.py - --python Python executable/command (default: python) - --log-dir Directory for logs (default: eval_logs) - --dry-run Print commands without executing - --help|-h Show this help - -Common eval args: - Anything after `--` is appended to BOTH evaluation commands. - Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --max_cpu_cores, --seed, --min_pos, --no_tqdm). - -Per-eval args: - For eval-specific flags (e.g. evaluate_horizon.py --topk_list / --workload_fracs), use --horizon-args-file. - Args files are "one argument per line"; blank lines are ignored. - -Examples: - ./run_evaluations_multi_gpu.sh --gpus 0,1 - ./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \ - --horizons 0.25 0.5 1 2 5 10 --age-bins 40 45 50 55 60 65 70 75 inf -- --batch_size 512 --num_workers 4 - ./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \ - -- --batch_size 512 --num_workers 4 --max_cpu_cores -1 - -USAGE -} - -runs_root="runs" -pattern="*" -run_dirs_file="" -gpu_list="" -python_cmd="python" -log_dir="eval_logs" -dry_run=0 -horizons=() -age_bins=() -extra_args=() -next_args_file="" -horizon_args_file="" - -while [[ $# -gt 0 ]]; do - case "$1" in - --gpus) - gpu_list="${2-}" - shift 2 - ;; - --runs-root) - runs_root="${2-}" - shift 2 - ;; - --pattern) - pattern="${2-}" - shift 2 - ;; - --run-dirs-file) - run_dirs_file="${2-}" - shift 2 - ;; - --next-args-file) - next_args_file="${2-}" - shift 2 - ;; - --horizon-args-file) - horizon_args_file="${2-}" - shift 2 - ;; - --python) - python_cmd="${2-}" - shift 2 - ;; - --log-dir) - log_dir="${2-}" - shift 2 - ;; - --dry-run) - dry_run=1 - shift - ;; - --horizons) - shift - horizons=() - while [[ $# -gt 0 && "$1" != --* ]]; do - horizons+=("$1") - shift - done - ;; - --age-bins) - shift - age_bins=() - while [[ $# -gt 0 && "$1" != --* ]]; do - age_bins+=("$1") - shift - done - ;; - --help|-h) - usage - exit 0 - ;; - --) - shift - extra_args=("$@") - break - ;; - *) - echo "Unknown argument: $1" >&2 - usage - exit 2 - ;; - esac -done - -if [[ -z "$gpu_list" ]]; then - echo "Error: --gpus is required (e.g. --gpus 0,1,2)." >&2 - exit 2 -fi - -read_args_file() { - local f="${1-}" - if [[ -z "$f" ]]; then - return 0 - fi - if [[ ! -f "$f" ]]; then - echo "Error: args file not found: $f" >&2 - exit 2 - fi - while IFS= read -r line || [[ -n "$line" ]]; do - line="${line%$'\r'}" # handle CRLF - [[ -z "$line" ]] && continue - printf '%s\n' "$line" - done < "$f" -} - -mkdir -p "$log_dir" - -IFS=',' read -r -a gpus <<< "$gpu_list" -if [[ ${#gpus[@]} -lt 1 ]]; then - echo "Error: parsed 0 GPUs from --gpus '$gpu_list'" >&2 - exit 2 -fi - -sanitize() { - # Replace any char outside [A-Za-z0-9._-] with '_' - local s="${1-}" - s="${s//[^A-Za-z0-9._-]/_}" - printf '%s' "$s" -} - -# Discover run directories -run_dirs=() -if [[ -n "$run_dirs_file" ]]; then - if [[ ! -f "$run_dirs_file" ]]; then - echo "Error: --run-dirs-file not found: $run_dirs_file" >&2 - exit 2 - fi - while IFS= read -r line || [[ -n "$line" ]]; do - line="${line%$'\r'}" # handle CRLF - [[ -z "$line" ]] && continue - run_dirs+=("$line") - done < "$run_dirs_file" -else - if [[ ! 
-d "$runs_root" ]]; then - echo "Error: runs root not found: $runs_root" >&2 - exit 2 - fi - shopt -s nullglob - for d in "$runs_root"/$pattern; do - [[ -d "$d" ]] || continue - [[ -f "$d/best_model.pt" ]] || continue - [[ -f "$d/train_config.json" ]] || continue - run_dirs+=("$d") - done - shopt -u nullglob -fi - -if [[ ${#run_dirs[@]} -eq 0 ]]; then - echo "Error: no run directories found." >&2 - exit 1 -fi - -echo "Queued ${#run_dirs[@]} run(s) across ${#gpus[@]} GPU(s): ${gpus[*]}" - -_tmpdir="" -cleanup() { - if [[ -n "${_tmpdir}" && -d "${_tmpdir}" ]]; then - rm -rf "${_tmpdir}" - fi -} -trap cleanup EXIT - -_tmpdir="$(mktemp -d)" - -# Prepare per-GPU queue files (TSV: job_id \t run_dir) -queue_files=() -for i in "${!gpus[@]}"; do - qfile="${_tmpdir}/queue_${i}.tsv" - : > "$qfile" - queue_files+=("$qfile") -done - -job_id=0 -for run_dir in "${run_dirs[@]}"; do - slot=$((job_id % ${#gpus[@]})) - printf '%s\t%s\n' "$job_id" "$run_dir" >> "${queue_files[$slot]}" - job_id=$((job_id + 1)) -done - -pids=() -for i in "${!gpus[@]}"; do - gpu="${gpus[$i]}" - qfile="${queue_files[$i]}" - - ( - export CUDA_VISIBLE_DEVICES="$gpu" - - while IFS=$'\t' read -r jid run_dir || [[ -n "${jid-}" ]]; do - [[ -z "${jid-}" ]] && continue - [[ -z "${run_dir-}" ]] && continue - - ts="$(date +%Y%m%d-%H%M%S)" - safe_run="$(sanitize "$(basename "$run_dir")")" - log_file="${log_dir}/eval_${jid}_gpu${gpu}_${safe_run}_${ts}.log" - - { - echo "===== EVALUATION START =====" - echo "timestamp: $ts" - echo "gpu: $gpu" - echo "job_id: $jid" - echo "run_dir: $run_dir" - if [[ ${#horizons[@]} -gt 0 ]]; then - echo "horizons: ${horizons[*]}" - fi - if [[ ${#age_bins[@]} -gt 0 ]]; then - echo "age_bins: ${age_bins[*]}" - fi - if [[ -n "${next_args_file}" ]]; then - echo "next_args_file: ${next_args_file}" - fi - if [[ -n "${horizon_args_file}" ]]; then - echo "horizon_args_file: ${horizon_args_file}" - fi - if [[ ${#extra_args[@]} -gt 0 ]]; then - echo "extra_args: ${extra_args[*]}" - fi - echo "============================" - } > "$log_file" - - # Build argv arrays - next_cmd=("$python_cmd" evaluate_next_event.py --run_dir "$run_dir") - if [[ ${#age_bins[@]} -gt 0 ]]; then - next_cmd+=(--age_bins "${age_bins[@]}") - fi - if [[ -n "${next_args_file}" ]]; then - while IFS= read -r a; do - next_cmd+=("$a") - done < <(read_args_file "${next_args_file}") - fi - if [[ ${#extra_args[@]} -gt 0 ]]; then - next_cmd+=("${extra_args[@]}") - fi - - hor_cmd=("$python_cmd" evaluate_horizon.py --run_dir "$run_dir") - if [[ ${#horizons[@]} -gt 0 ]]; then - hor_cmd+=(--horizons "${horizons[@]}") - fi - if [[ ${#age_bins[@]} -gt 0 ]]; then - hor_cmd+=(--age_bins "${age_bins[@]}") - fi - if [[ -n "${horizon_args_file}" ]]; then - while IFS= read -r a; do - hor_cmd+=("$a") - done < <(read_args_file "${horizon_args_file}") - fi - if [[ ${#extra_args[@]} -gt 0 ]]; then - hor_cmd+=("${extra_args[@]}") - fi - - echo "[GPU $gpu] START job $jid: $run_dir" - - if [[ $dry_run -eq 1 ]]; then - { - echo "[DRY-RUN] next-event cmd:"; printf ' %q' "${next_cmd[@]}"; echo - echo "[DRY-RUN] horizon cmd:"; printf ' %q' "${hor_cmd[@]}"; echo - echo "[DRY-RUN] log: $log_file" - } | tee -a "$log_file" - echo "[GPU $gpu] DONE job $jid (dry-run)" - continue - fi - - set +e - { - echo "--- RUN evaluate_next_event.py ---" - printf 'cmd:'; printf ' %q' "${next_cmd[@]}"; echo - "${next_cmd[@]}" - rc1=$? - echo "exit_code_next_event: $rc1" - - echo "--- RUN evaluate_horizon.py ---" - printf 'cmd:'; printf ' %q' "${hor_cmd[@]}"; echo - "${hor_cmd[@]}" - rc2=$? 
- echo "exit_code_horizon: $rc2" - - echo "===== EVALUATION END =======" - } >> "$log_file" 2>&1 - - set -e - - if [[ $rc1 -ne 0 || $rc2 -ne 0 ]]; then - echo "[GPU $gpu] FAIL job $jid (next=$rc1 horizon=$rc2). Log: $log_file" >&2 - exit 1 - fi - - echo "[GPU $gpu] DONE job $jid (log: $log_file)" - done < "$qfile" - ) & - - pids+=("$!") -done - -fail=0 -for pid in "${pids[@]}"; do - if ! wait "$pid"; then - fail=1 - fi -done - -if [[ $fail -ne 0 ]]; then - echo "One or more workers failed." >&2 - exit 1 -fi - -echo "All evaluations complete." diff --git a/run_experiments_multi_gpu.sh b/run_experiments_multi_gpu.sh index adcc403..bf3e91b 100644 --- a/run_experiments_multi_gpu.sh +++ b/run_experiments_multi_gpu.sh @@ -4,23 +4,41 @@ set -euo pipefail usage() { cat <<'USAGE' Usage: - ./run_experiments_multi_gpu.sh --gpus 0,1,2 [--experiments experiments.txt] [--cmd "python train.py"] [--log-dir experiment_logs] [--dry-run] [-- ] + ./run_experiments_multi_gpu.sh --gpus 0,1,2 [--runs-file runs_to_eval.txt | --runs-root runs] [--cmd "python evaluate.py"] [--log-dir evaluation_logs] [--out-root eval_outputs] [--skip-existing] [--dry-run] [-- ] Description: - Distributes rows from experiments.txt across multiple GPUs (round-robin) and runs + Distributes evaluation jobs across multiple GPUs (round-robin) and runs at most one job per GPU at a time. + A job is a run directory containing: + - train_config.json + - best_model.pt + + By default, run directories are auto-discovered under --runs-root (default: runs). + Alternatively, provide --runs-file with one run_dir per line. + Examples: + # Auto-discover run dirs under ./runs ./run_experiments_multi_gpu.sh --gpus 0,1,2 - ./run_experiments_multi_gpu.sh --gpus 0,1 --experiments experiments.txt -- --batch_size 64 --max_epochs 50 - ./run_experiments_multi_gpu.sh --gpus 3 --cmd "python train.py" -- --loss_type discrete_time_cif + + # Use an explicit list of run directories + ./run_experiments_multi_gpu.sh --gpus 0,1 --runs-file runs_to_eval.txt + + # Centralize outputs (CSV bundle + summary JSON) under eval_outputs/ + ./run_experiments_multi_gpu.sh --gpus 0,1 --out-root eval_outputs + + # Forward args to evaluate.py + ./run_experiments_multi_gpu.sh --gpus 0,1 -- --batch_size 512 --num_workers 8 USAGE } -experiments_file="experiments.txt" +runs_file="" +runs_root="runs" gpu_list="" -cmd_str="python train.py" -log_dir="experiment_logs" +cmd_str="python evaluate.py" +log_dir="evaluation_logs" +out_root="" +skip_existing=0 dry_run=0 extra_args=() @@ -30,8 +48,12 @@ while [[ $# -gt 0 ]]; do gpu_list="${2-}" shift 2 ;; - --experiments|-f) - experiments_file="${2-}" + --runs-file|-f) + runs_file="${2-}" + shift 2 + ;; + --runs-root) + runs_root="${2-}" shift 2 ;; --cmd) @@ -42,6 +64,14 @@ while [[ $# -gt 0 ]]; do log_dir="${2-}" shift 2 ;; + --out-root) + out_root="${2-}" + shift 2 + ;; + --skip-existing) + skip_existing=1 + shift + ;; --dry-run) dry_run=1 shift @@ -70,11 +100,6 @@ fi mkdir -p "$log_dir" -if [[ ! 
-f "$experiments_file" ]]; then - echo "Error: experiments file not found: $experiments_file" >&2 - exit 2 -fi - IFS=',' read -r -a gpus <<< "$gpu_list" if [[ ${#gpus[@]} -lt 1 ]]; then echo "Error: parsed 0 GPUs from --gpus '$gpu_list'" >&2 @@ -85,7 +110,7 @@ fi # shellcheck disable=SC2206 cmd=($cmd_str) if [[ ${#cmd[@]} -lt 2 ]]; then - echo "Error: --cmd should look like 'python train.py'" >&2 + echo "Error: --cmd should look like 'python evaluate.py'" >&2 exit 2 fi @@ -107,28 +132,81 @@ for i in "${!gpus[@]}"; do queue_files+=("$qfile") done -# Distribute experiments round-robin. -exp_idx=0 -while IFS= read -r line || [[ -n "$line" ]]; do - line="${line%$'\r'}" # handle CRLF - [[ -z "$line" ]] && continue - # Skip header if present - if [[ "$line" == model_type,* ]]; then - continue +discover_runs() { + local root="${1-}" + if [[ -z "$root" ]]; then + return 0 + fi + if [[ ! -d "$root" ]]; then + echo "Error: runs root not found: $root" >&2 + return 2 fi - slot=$((exp_idx % ${#gpus[@]})) - # Prefix a stable experiment index for logging. - printf '%s,%s\n' "$exp_idx" "$line" >> "${queue_files[$slot]}" - exp_idx=$((exp_idx + 1)) -done < "$experiments_file" + # shellcheck disable=SC2016 + find "$root" -mindepth 1 -maxdepth 1 -type d -print 2>/dev/null | + sort +} -if [[ $exp_idx -eq 0 ]]; then - echo "No experiments found in $experiments_file" >&2 +run_dirs=() +if [[ -n "$runs_file" ]]; then + if [[ ! -f "$runs_file" ]]; then + echo "Error: runs file not found: $runs_file" >&2 + exit 2 + fi + + while IFS= read -r line || [[ -n "$line" ]]; do + line="${line%$'\r'}" # handle CRLF + [[ -z "$line" ]] && continue + [[ "$line" == \#* ]] && continue + run_dirs+=("$line") + done < "$runs_file" +else + while IFS= read -r d || [[ -n "${d-}" ]]; do + [[ -z "${d-}" ]] && continue + run_dirs+=("$d") + done < <(discover_runs "$runs_root") +fi + +if [[ ${#run_dirs[@]} -eq 0 ]]; then + if [[ -n "$runs_file" ]]; then + echo "No run directories found in $runs_file" >&2 + else + echo "No run directories found under $runs_root" >&2 + fi exit 1 fi -echo "Queued $exp_idx experiments across ${#gpus[@]} GPU(s): ${gpus[*]}" +is_valid_run_dir() { + local d="${1-}" + [[ -d "$d" ]] || return 1 + [[ -f "$d/train_config.json" ]] || return 1 + [[ -f "$d/best_model.pt" ]] || return 1 + return 0 +} + +valid_run_dirs=() +for d in "${run_dirs[@]}"; do + if is_valid_run_dir "$d"; then + valid_run_dirs+=("$d") + else + echo "Skipping invalid run_dir (missing train_config.json or best_model.pt): $d" >&2 + fi +done + +if [[ ${#valid_run_dirs[@]} -eq 0 ]]; then + echo "No valid run directories found." >&2 + exit 1 +fi + +# Distribute evaluation jobs round-robin. 
+job_idx=0 +for d in "${valid_run_dirs[@]}"; do + slot=$((job_idx % ${#gpus[@]})) + printf '%s,%s\n' "$job_idx" "$d" >> "${queue_files[$slot]}" + job_idx=$((job_idx + 1)) +done + +echo "Queued $job_idx evaluation job(s) across ${#gpus[@]} GPU(s): ${gpus[*]}" sanitize() { # Replace any char outside [A-Za-z0-9._-] with '_' @@ -145,44 +223,53 @@ for i in "${!gpus[@]}"; do ( export CUDA_VISIBLE_DEVICES="$gpu" - while IFS=',' read -r exp_id model_type loss_type age_encoder full_cov || [[ -n "${exp_id-}" ]]; do - # Skip empty lines - [[ -z "${exp_id-}" ]] && continue + while IFS=',' read -r job_id run_dir || [[ -n "${job_id-}" ]]; do + [[ -z "${job_id-}" ]] && continue + [[ -z "${run_dir-}" ]] && continue - # Normalize booleans / strip whitespace - full_cov="${full_cov//[[:space:]]/}" + ts="$(date +%Y%m%d-%H%M%S)" + safe_run="$(sanitize "$(basename "$run_dir")")" + + # Decide output locations. + out_dir_arg=() + out_json="" + if [[ -n "$out_root" ]]; then + job_out_dir="${out_root%/}/run_${job_id}_${safe_run}" + mkdir -p "$job_out_dir" + out_json="$job_out_dir/evaluation_summary.json" + out_dir_arg=(--out_dir "$job_out_dir" --output "$out_json") + else + out_json="$run_dir/evaluation_summary.json" + fi + + if [[ $skip_existing -eq 1 && -f "$out_json" ]]; then + echo "[GPU $gpu] SKIP job $job_id: already exists ($out_json)" + continue + fi run_cmd=("${cmd[@]}" \ - --model_type "$model_type" \ - --loss_type "$loss_type" \ - --age_encoder "$age_encoder") + --run_dir "$run_dir" \ + --device cuda) - if [[ "$full_cov" == "True" || "$full_cov" == "true" || "$full_cov" == "1" ]]; then - run_cmd+=(--full_cov) + if [[ ${#out_dir_arg[@]} -gt 0 ]]; then + run_cmd+=("${out_dir_arg[@]}") fi if [[ ${#extra_args[@]} -gt 0 ]]; then run_cmd+=("${extra_args[@]}") fi - echo "[GPU $gpu] START exp $exp_id: model_type=$model_type loss_type=$loss_type age_encoder=$age_encoder full_cov=$full_cov" - - ts="$(date +%Y%m%d-%H%M%S)" - safe_model="$(sanitize "$model_type")" - safe_loss="$(sanitize "$loss_type")" - safe_age="$(sanitize "$age_encoder")" - safe_cov="$(sanitize "$full_cov")" - log_file="${log_dir}/exp_${exp_id}_gpu${gpu}_${safe_model}_${safe_loss}_${safe_age}_${safe_cov}_${ts}.log" + echo "[GPU $gpu] START job $job_id: run_dir=$run_dir" + log_file="${log_dir}/eval_${job_id}_gpu${gpu}_${safe_run}_${ts}.log" { - echo "===== EXPERIMENT START =====" + echo "===== EVALUATION START =====" echo "timestamp: $ts" echo "gpu: $gpu" - echo "exp_id: $exp_id" - echo "model_type: $model_type" - echo "loss_type: $loss_type" - echo "age_encoder: $age_encoder" - echo "full_cov: $full_cov" + echo "job_id: $job_id" + echo "run_dir: $run_dir" + echo "out_root: ${out_root:-}" + echo "out_json: $out_json" printf 'cmd:' printf ' %q' "${run_cmd[@]}" echo @@ -203,16 +290,16 @@ for i in "${!gpus[@]}"; do { echo "============================" echo "exit_code: $rc" - echo "===== EXPERIMENT END =======" + echo "===== EVALUATION END =======" } >> "$log_file" if [[ $rc -ne 0 ]]; then - echo "[GPU $gpu] FAIL exp $exp_id (exit=$rc). Log: $log_file" >&2 + echo "[GPU $gpu] FAIL job $job_id (exit=$rc). Log: $log_file" >&2 exit "$rc" fi fi - echo "[GPU $gpu] DONE exp $exp_id (log: $log_file)" + echo "[GPU $gpu] DONE job $job_id (log: $log_file)" done < "$qfile" ) & @@ -232,4 +319,4 @@ if [[ $fail -ne 0 ]]; then exit 1 fi -echo "All experiments complete." +echo "All evaluations complete."