Add support for evaluation indices in LandmarkEvaluator class
evaluate.py | 49
@@ -221,6 +221,7 @@ class LandmarkEvaluator:
         head: torch.nn.Module,
         loss_fn: torch.nn.Module,
         dataset: HealthDataset,
+        eval_indices: Optional[List[int]] = None,
         device: str = 'cuda',
         batch_size: int = 256,
         num_workers: int = 4,
@@ -231,6 +232,7 @@ class LandmarkEvaluator:
         self.head.eval()
         self.loss_fn = loss_fn
         self.dataset = dataset
+        self.eval_indices = eval_indices
         self.device = device
         self.batch_size = batch_size
         self.num_workers = num_workers
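Taken together, the two hunks above let a caller pin evaluation to an explicit subset of dataset rows. A minimal usage sketch, assuming the arguments shown in the hunks are the only required ones besides `model` (any constructor parameters above `head` are not visible in this diff):

```python
# Hypothetical smoke-test usage: evaluate three rows instead of the full dataset.
# eval_indices holds positions into `dataset`, not patient IDs; passing
# None (the default) preserves the old evaluate-everything behaviour.
evaluator = LandmarkEvaluator(
    model=model,                # assumed parameter, not shown in this hunk
    head=head,
    loss_fn=loss_fn,
    dataset=dataset,
    eval_indices=[0, 5, 42],
    device='cuda',
    batch_size=256,
    num_workers=4,
)
```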
@@ -430,7 +432,9 @@ class LandmarkEvaluator:
         labels_rows: List[np.ndarray] = []
         valid_rows: List[np.ndarray] = []

-        for idx in range(len(self.dataset)):
+        candidate_indices = self.eval_indices if self.eval_indices is not None else list(
+            range(len(self.dataset)))
+        for idx in candidate_indices:
             patient_id = self.dataset.patient_ids[idx]
             records = self.dataset.patient_events.get(patient_id, [])
             if not records:
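One caveat with the hunk above: indices arriving from outside are used to index `self.dataset.patient_ids` unchecked. A defensive variant of the same fallback pattern, sketched as a free function (the bounds check is an addition of mine, not part of the commit):

```python
from typing import List, Optional

def resolve_candidate_indices(eval_indices: Optional[List[int]], n: int) -> List[int]:
    # Fall back to every row, mirroring the old range(len(self.dataset)) behaviour.
    if eval_indices is None:
        return list(range(n))
    bad = [i for i in eval_indices if not 0 <= i < n]
    if bad:
        raise IndexError(f"eval_indices out of range [0, {n}): {bad[:5]}")
    return list(eval_indices)
```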
@@ -965,23 +969,45 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
     print(f"Model type: {config['model_type']}")
     print(f"Loss type: {config['loss_type']}")

-    # Load dataset to get dimensions
+    # Load dataset (same as training) and reproduce the train/val/test split.
+    # IMPORTANT: do NOT change data_prefix; train.py reads files like
+    # <data_prefix>_basic_info.csv, <data_prefix>_table.csv, <data_prefix>_event_data.npy
     data_prefix = config['data_prefix']

-    # Determine covariate list based on full_cov
-    if config['full_cov']:
-        covariate_list = None  # Use all covariates
+    if config.get('full_cov', False):
+        covariate_list = None
     else:
-        # Use partial covariates (define your partial list here)
-        covariate_list = ['age_at_assessment',
-                          'bmi', 'smoking_status']  # Example
+        # Match train.py partial-cov settings
+        covariate_list = ["bmi", "smoking", "alcohol"]

     dataset = HealthDataset(
-        data_prefix=f"{data_prefix}_test",
+        data_prefix=data_prefix,
         covariate_list=covariate_list,
         cache_event_tensors=True,
     )

+    # Reproduce the random_split used in train.py to obtain the held-out test subset.
+    n_total = len(dataset)
+    train_ratio = float(config.get('train_ratio', 0.7))
+    val_ratio = float(config.get('val_ratio', 0.15))
+    seed = int(config.get('random_seed', 42))
+
+    n_train = int(n_total * train_ratio)
+    n_val = int(n_total * val_ratio)
+    n_test = n_total - n_train - n_val
+    if n_test < 0:
+        raise ValueError(
+            f"Invalid split sizes from config: n_total={n_total}, train_ratio={train_ratio}, val_ratio={val_ratio}"
+        )
+
+    from torch.utils.data import random_split
+    _, _, test_subset = random_split(
+        dataset,
+        [n_train, n_val, n_test],
+        generator=torch.Generator().manual_seed(seed),
+    )
+    test_indices = list(getattr(test_subset, 'indices', []))
+
     # Determine output dimensions based on loss type
     import math
     if config['loss_type'] == 'exponential':
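The whole scheme rests on `torch.utils.data.random_split` being deterministic for a fixed generator seed, and on the size arithmetic (note the `int()` truncation, with `n_test` absorbing the rounding remainder) matching what train.py computed. A quick self-contained check of the determinism assumption (toy sizes, not the project's data):

```python
import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(100))

def held_out_indices(seed: int):
    # 70/15/15 split, mirroring the default ratios read from the config.
    _, _, test_subset = random_split(
        ds, [70, 15, 15], generator=torch.Generator().manual_seed(seed))
    return list(test_subset.indices)

# Same seed, same held-out rows: the evaluator sees exactly the
# examples that were never used for training or validation.
assert held_out_indices(42) == held_out_indices(42)
```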
@@ -1075,7 +1101,7 @@ def load_model_and_config(run_dir: str, device: str = 'cuda') -> Tuple:
     else:
         raise ValueError(f"Unknown loss type: {config['loss_type']}")

-    return model, head, loss_fn, dataset, config
+    return model, head, loss_fn, dataset, config, test_indices


 def print_summary(results: Dict):
@@ -1180,7 +1206,7 @@ def main():
     args = parser.parse_args()

     # Load model and dataset
-    model, head, loss_fn, dataset, config = load_model_and_config(
+    model, head, loss_fn, dataset, config, test_indices = load_model_and_config(
         args.run_dir, args.device)

     # Create evaluator
@@ -1189,6 +1215,7 @@ def main():
         head=head,
         loss_fn=loss_fn,
         dataset=dataset,
+        eval_indices=test_indices,
         device=args.device,
         batch_size=args.batch_size,
         num_workers=args.num_workers,
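End-to-end, `main()` now threads the recovered split into the evaluator. A condensed sketch of the flow after this commit (constructor arguments follow the hunks above; `model=` is again an assumption rather than something visible in the diff):

```python
# Hypothetical condensed flow of main() after this commit.
model, head, loss_fn, dataset, config, test_indices = load_model_and_config(
    args.run_dir, args.device)

evaluator = LandmarkEvaluator(
    model=model,                 # assumed, not visible in the diff
    head=head,
    loss_fn=loss_fn,
    dataset=dataset,
    eval_indices=test_indices,   # evaluate only the reproduced held-out split
    device=args.device,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)
```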