Add support for evaluation indices in LandmarkEvaluator class

This commit is contained in:
2026-01-18 17:43:47 +08:00
parent 0057bc0dd9
commit 014393a33f

View File

@@ -221,6 +221,7 @@ class LandmarkEvaluator:
head: torch.nn.Module,
loss_fn: torch.nn.Module,
dataset: HealthDataset,
eval_indices: Optional[List[int]] = None,
device: str = 'cuda',
batch_size: int = 256,
num_workers: int = 4,
@@ -231,6 +232,7 @@ class LandmarkEvaluator:
self.head.eval()
self.loss_fn = loss_fn
self.dataset = dataset
self.eval_indices = eval_indices
self.device = device
self.batch_size = batch_size
self.num_workers = num_workers
@@ -430,7 +432,9 @@ class LandmarkEvaluator:
labels_rows: List[np.ndarray] = []
valid_rows: List[np.ndarray] = []
for idx in range(len(self.dataset)):
candidate_indices = self.eval_indices if self.eval_indices is not None else list(
range(len(self.dataset)))
for idx in candidate_indices:
patient_id = self.dataset.patient_ids[idx]
records = self.dataset.patient_events.get(patient_id, [])
if not records:
@@ -965,23 +969,45 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
print(f"Model type: {config['model_type']}")
print(f"Loss type: {config['loss_type']}")
# Load dataset to get dimensions
# Load dataset (same as training) and reproduce the train/val/test split.
# IMPORTANT: do NOT change data_prefix; train.py reads files like
# <data_prefix>_basic_info.csv, <data_prefix>_table.csv, <data_prefix>_event_data.npy
data_prefix = config['data_prefix']
# Determine covariate list based on full_cov
if config['full_cov']:
covariate_list = None # Use all covariates
if config.get('full_cov', False):
covariate_list = None
else:
# Use partial covariates (define your partial list here)
covariate_list = ['age_at_assessment',
'bmi', 'smoking_status'] # Example
# Match train.py partial-cov settings
covariate_list = ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(
data_prefix=f"{data_prefix}_test",
data_prefix=data_prefix,
covariate_list=covariate_list,
cache_event_tensors=True,
)
# Reproduce the random_split used in train.py to obtain the held-out test subset.
n_total = len(dataset)
train_ratio = float(config.get('train_ratio', 0.7))
val_ratio = float(config.get('val_ratio', 0.15))
seed = int(config.get('random_seed', 42))
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
if n_test < 0:
raise ValueError(
f"Invalid split sizes from config: n_total={n_total}, train_ratio={train_ratio}, val_ratio={val_ratio}"
)
from torch.utils.data import random_split
_, _, test_subset = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
test_indices = list(getattr(test_subset, 'indices', []))
# Determine output dimensions based on loss type
import math
if config['loss_type'] == 'exponential':
@@ -1075,7 +1101,7 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
else:
raise ValueError(f"Unknown loss type: {config['loss_type']}")
return model, head, loss_fn, dataset, config
return model, head, loss_fn, dataset, config, test_indices
def print_summary(results: Dict):
@@ -1180,7 +1206,7 @@ def main():
args = parser.parse_args()
# Load model and dataset
model, head, loss_fn, dataset, config = load_model_and_config(
model, head, loss_fn, dataset, config, test_indices = load_model_and_config(
args.run_dir, args.device)
# Create evaluator
@@ -1189,6 +1215,7 @@ def main():
head=head,
loss_fn=loss_fn,
dataset=dataset,
eval_indices=test_indices,
device=args.device,
batch_size=args.batch_size,
num_workers=args.num_workers,