Add support for evaluation indices in LandmarkEvaluator class
This commit is contained in:
49
evaluate.py
49
evaluate.py
@@ -221,6 +221,7 @@ class LandmarkEvaluator:
|
||||
head: torch.nn.Module,
|
||||
loss_fn: torch.nn.Module,
|
||||
dataset: HealthDataset,
|
||||
eval_indices: Optional[List[int]] = None,
|
||||
device: str = 'cuda',
|
||||
batch_size: int = 256,
|
||||
num_workers: int = 4,
|
||||
@@ -231,6 +232,7 @@ class LandmarkEvaluator:
|
||||
self.head.eval()
|
||||
self.loss_fn = loss_fn
|
||||
self.dataset = dataset
|
||||
self.eval_indices = eval_indices
|
||||
self.device = device
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
@@ -430,7 +432,9 @@ class LandmarkEvaluator:
|
||||
labels_rows: List[np.ndarray] = []
|
||||
valid_rows: List[np.ndarray] = []
|
||||
|
||||
for idx in range(len(self.dataset)):
|
||||
candidate_indices = self.eval_indices if self.eval_indices is not None else list(
|
||||
range(len(self.dataset)))
|
||||
for idx in candidate_indices:
|
||||
patient_id = self.dataset.patient_ids[idx]
|
||||
records = self.dataset.patient_events.get(patient_id, [])
|
||||
if not records:
|
||||
@@ -965,23 +969,45 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
|
||||
print(f"Model type: {config['model_type']}")
|
||||
print(f"Loss type: {config['loss_type']}")
|
||||
|
||||
# Load dataset to get dimensions
|
||||
# Load dataset (same as training) and reproduce the train/val/test split.
|
||||
# IMPORTANT: do NOT change data_prefix; train.py reads files like
|
||||
# <data_prefix>_basic_info.csv, <data_prefix>_table.csv, <data_prefix>_event_data.npy
|
||||
data_prefix = config['data_prefix']
|
||||
|
||||
# Determine covariate list based on full_cov
|
||||
if config['full_cov']:
|
||||
covariate_list = None # Use all covariates
|
||||
if config.get('full_cov', False):
|
||||
covariate_list = None
|
||||
else:
|
||||
# Use partial covariates (define your partial list here)
|
||||
covariate_list = ['age_at_assessment',
|
||||
'bmi', 'smoking_status'] # Example
|
||||
# Match the partial-covariate settings used in train.py
|
||||
covariate_list = ["bmi", "smoking", "alcohol"]
|
||||
|
||||
dataset = HealthDataset(
|
||||
data_prefix=f"{data_prefix}_test",
|
||||
data_prefix=data_prefix,
|
||||
covariate_list=covariate_list,
|
||||
cache_event_tensors=True,
|
||||
)
|
||||
|
||||
# Reproduce the random_split used in train.py to obtain the held-out test subset.
|
||||
n_total = len(dataset)
|
||||
train_ratio = float(config.get('train_ratio', 0.7))
|
||||
val_ratio = float(config.get('val_ratio', 0.15))
|
||||
seed = int(config.get('random_seed', 42))
|
||||
|
||||
n_train = int(n_total * train_ratio)
|
||||
n_val = int(n_total * val_ratio)
|
||||
n_test = n_total - n_train - n_val
|
||||
if n_test < 0:
|
||||
raise ValueError(
|
||||
f"Invalid split sizes from config: n_total={n_total}, train_ratio={train_ratio}, val_ratio={val_ratio}"
|
||||
)
|
||||
|
||||
from torch.utils.data import random_split
|
||||
_, _, test_subset = random_split(
|
||||
dataset,
|
||||
[n_train, n_val, n_test],
|
||||
generator=torch.Generator().manual_seed(seed),
|
||||
)
|
||||
test_indices = list(getattr(test_subset, 'indices', []))
|
||||
|
||||
# Determine output dimensions based on loss type
|
||||
import math
|
||||
if config['loss_type'] == 'exponential':
|
||||
@@ -1075,7 +1101,7 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {config['loss_type']}")
|
||||
|
||||
return model, head, loss_fn, dataset, config
|
||||
return model, head, loss_fn, dataset, config, test_indices
|
||||
|
||||
|
||||
def print_summary(results: Dict):
|
||||
@@ -1180,7 +1206,7 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load model and dataset
|
||||
model, head, loss_fn, dataset, config = load_model_and_config(
|
||||
model, head, loss_fn, dataset, config, test_indices = load_model_and_config(
|
||||
args.run_dir, args.device)
|
||||
|
||||
# Create evaluator
|
||||
@@ -1189,6 +1215,7 @@ def main():
|
||||
head=head,
|
||||
loss_fn=loss_fn,
|
||||
dataset=dataset,
|
||||
eval_indices=test_indices,
|
||||
device=args.device,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
|
||||
Reference in New Issue
Block a user