From 014393a33f9889a184747a7225273829e95fc1a1 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sun, 18 Jan 2026 17:43:47 +0800 Subject: [PATCH] Add support for evaluation indices in LandmarkEvaluator class --- evaluate.py | 49 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/evaluate.py b/evaluate.py index 6633bb7..64ab22e 100644 --- a/evaluate.py +++ b/evaluate.py @@ -221,6 +221,7 @@ class LandmarkEvaluator: head: torch.nn.Module, loss_fn: torch.nn.Module, dataset: HealthDataset, + eval_indices: Optional[List[int]] = None, device: str = 'cuda', batch_size: int = 256, num_workers: int = 4, @@ -231,6 +232,7 @@ class LandmarkEvaluator: self.head.eval() self.loss_fn = loss_fn self.dataset = dataset + self.eval_indices = eval_indices self.device = device self.batch_size = batch_size self.num_workers = num_workers @@ -430,7 +432,9 @@ class LandmarkEvaluator: labels_rows: List[np.ndarray] = [] valid_rows: List[np.ndarray] = [] - for idx in range(len(self.dataset)): + candidate_indices = self.eval_indices if self.eval_indices is not None else list( + range(len(self.dataset))) + for idx in candidate_indices: patient_id = self.dataset.patient_ids[idx] records = self.dataset.patient_events.get(patient_id, []) if not records: @@ -965,23 +969,45 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple: print(f"Model type: {config['model_type']}") print(f"Loss type: {config['loss_type']}") - # Load dataset to get dimensions + # Load dataset (same as training) and reproduce the train/val/test split. + # IMPORTANT: do NOT change data_prefix; train.py reads files like + # _basic_info.csv, _table.csv, _event_data.npy data_prefix = config['data_prefix'] - # Determine covariate list based on full_cov - if config['full_cov']: - covariate_list = None # Use all covariates + if config.get('full_cov', False): + covariate_list = None else: - # Use partial covariates (define your partial list here) - covariate_list = ['age_at_assessment', - 'bmi', 'smoking_status'] # Example + # Match train.py partial-cov settings + covariate_list = ["bmi", "smoking", "alcohol"] dataset = HealthDataset( - data_prefix=f"{data_prefix}_test", + data_prefix=data_prefix, covariate_list=covariate_list, cache_event_tensors=True, ) + # Reproduce the random_split used in train.py to obtain the held-out test subset. + n_total = len(dataset) + train_ratio = float(config.get('train_ratio', 0.7)) + val_ratio = float(config.get('val_ratio', 0.15)) + seed = int(config.get('random_seed', 42)) + + n_train = int(n_total * train_ratio) + n_val = int(n_total * val_ratio) + n_test = n_total - n_train - n_val + if n_test < 0: + raise ValueError( + f"Invalid split sizes from config: n_total={n_total}, train_ratio={train_ratio}, val_ratio={val_ratio}" + ) + + from torch.utils.data import random_split + _, _, test_subset = random_split( + dataset, + [n_train, n_val, n_test], + generator=torch.Generator().manual_seed(seed), + ) + test_indices = list(getattr(test_subset, 'indices', [])) + # Determine output dimensions based on loss type import math if config['loss_type'] == 'exponential': @@ -1075,7 +1101,7 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple: else: raise ValueError(f"Unknown loss type: {config['loss_type']}") - return model, head, loss_fn, dataset, config + return model, head, loss_fn, dataset, config, test_indices def print_summary(results: Dict): @@ -1180,7 +1206,7 @@ def main(): args = parser.parse_args() # Load model and dataset - model, head, loss_fn, dataset, config = load_model_and_config( + model, head, loss_fn, dataset, config, test_indices = load_model_and_config( args.run_dir, args.device) # Create evaluator @@ -1189,6 +1215,7 @@ def main(): head=head, loss_fn=loss_fn, dataset=dataset, + eval_indices=test_indices, device=args.device, batch_size=args.batch_size, num_workers=args.num_workers,