Add support for evaluation indices in LandmarkEvaluator class

commit 014393a33f
parent 0057bc0dd9
Date: 2026-01-18 17:43:47 +08:00


@@ -221,6 +221,7 @@ class LandmarkEvaluator:
         head: torch.nn.Module,
         loss_fn: torch.nn.Module,
         dataset: HealthDataset,
+        eval_indices: Optional[List[int]] = None,
         device: str = 'cuda',
         batch_size: int = 256,
         num_workers: int = 4,
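
Note: the new keyword defaults to None, so existing call sites are unaffected. A minimal sketch of how a caller might pass a subset (argument names taken from the signature above; model, head, loss_fn, and dataset stand for whatever objects the calling script already built, and any leading constructor parameters not shown in this hunk are assumed unchanged):

# Sketch only -- assumes model/head/loss_fn/dataset were constructed elsewhere.
evaluator = LandmarkEvaluator(
    model=model,
    head=head,
    loss_fn=loss_fn,
    dataset=dataset,
    eval_indices=[0, 5, 42],  # evaluate only these dataset indices
    device='cuda',
)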
@@ -231,6 +232,7 @@ class LandmarkEvaluator:
         self.head.eval()
         self.loss_fn = loss_fn
         self.dataset = dataset
+        self.eval_indices = eval_indices
         self.device = device
         self.batch_size = batch_size
         self.num_workers = num_workers
@@ -430,7 +432,9 @@ class LandmarkEvaluator:
         labels_rows: List[np.ndarray] = []
         valid_rows: List[np.ndarray] = []
-        for idx in range(len(self.dataset)):
+        candidate_indices = self.eval_indices if self.eval_indices is not None else list(
+            range(len(self.dataset)))
+        for idx in candidate_indices:
             patient_id = self.dataset.patient_ids[idx]
             records = self.dataset.patient_events.get(patient_id, [])
             if not records:
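
One behavioral detail worth calling out: the fallback triggers only when eval_indices is None; an explicit empty list is respected and evaluates no patients. A self-contained illustration of the selection rule:

def select_indices(eval_indices, dataset_len):
    # Mirrors the logic above: only None falls back to the full index range.
    return eval_indices if eval_indices is not None else list(range(dataset_len))

assert select_indices(None, 3) == [0, 1, 2]  # default: evaluate everyone
assert select_indices([1], 3) == [1]         # explicit subset
assert select_indices([], 3) == []           # empty list is not treated as None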
@@ -965,23 +969,45 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
print(f"Model type: {config['model_type']}") print(f"Model type: {config['model_type']}")
print(f"Loss type: {config['loss_type']}") print(f"Loss type: {config['loss_type']}")
# Load dataset to get dimensions # Load dataset (same as training) and reproduce the train/val/test split.
# IMPORTANT: do NOT change data_prefix; train.py reads files like
# <data_prefix>_basic_info.csv, <data_prefix>_table.csv, <data_prefix>_event_data.npy
data_prefix = config['data_prefix'] data_prefix = config['data_prefix']
# Determine covariate list based on full_cov if config.get('full_cov', False):
if config['full_cov']: covariate_list = None
covariate_list = None # Use all covariates
else: else:
# Use partial covariates (define your partial list here) # Match train.py partial-cov settings
covariate_list = ['age_at_assessment', covariate_list = ["bmi", "smoking", "alcohol"]
'bmi', 'smoking_status'] # Example
dataset = HealthDataset( dataset = HealthDataset(
data_prefix=f"{data_prefix}_test", data_prefix=data_prefix,
covariate_list=covariate_list, covariate_list=covariate_list,
cache_event_tensors=True, cache_event_tensors=True,
) )
# Reproduce the random_split used in train.py to obtain the held-out test subset.
n_total = len(dataset)
train_ratio = float(config.get('train_ratio', 0.7))
val_ratio = float(config.get('val_ratio', 0.15))
seed = int(config.get('random_seed', 42))
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
if n_test < 0:
raise ValueError(
f"Invalid split sizes from config: n_total={n_total}, train_ratio={train_ratio}, val_ratio={val_ratio}"
)
from torch.utils.data import random_split
_, _, test_subset = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
test_indices = list(getattr(test_subset, 'indices', []))
# Determine output dimensions based on loss type # Determine output dimensions based on loss type
import math import math
if config['loss_type'] == 'exponential': if config['loss_type'] == 'exponential':
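
Recovering test_indices this way is sound only because random_split is deterministic for a fixed generator seed: re-running it with the same lengths and the same seed reproduces the partition from train.py exactly. A quick self-contained check of that property (a toy TensorDataset stands in for HealthDataset):

import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(100))

def test_indices_for(seed):
    _, _, test = random_split(ds, [70, 15, 15],
                              generator=torch.Generator().manual_seed(seed))
    return list(test.indices)

# Same seed -> same held-out subset, so the training split is reproducible.
assert test_indices_for(42) == test_indices_for(42)

Computing n_test by subtraction rather than as int(n_total * test_ratio) also guarantees the three lengths sum to n_total, which random_split requires when given integer lengths.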
@@ -1075,7 +1101,7 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
     else:
         raise ValueError(f"Unknown loss type: {config['loss_type']}")
-    return model, head, loss_fn, dataset, config
+    return model, head, loss_fn, dataset, config, test_indices

 def print_summary(results: Dict):
@@ -1180,7 +1206,7 @@ def main():
     args = parser.parse_args()

     # Load model and dataset
-    model, head, loss_fn, dataset, config = load_model_and_config(
+    model, head, loss_fn, dataset, config, test_indices = load_model_and_config(
         args.run_dir, args.device)

     # Create evaluator
@@ -1189,6 +1215,7 @@ def main():
         head=head,
         loss_fn=loss_fn,
         dataset=dataset,
+        eval_indices=test_indices,
         device=args.device,
         batch_size=args.batch_size,
         num_workers=args.num_workers,