diff --git a/dataset.py b/dataset.py index 34528f2..6cecc01 100644 --- a/dataset.py +++ b/dataset.py @@ -20,6 +20,7 @@ class HealthDataset(Dataset): self, data_prefix: str, covariate_list: List[str] | None = None, + cache_event_tensors: bool = True, ): basic_info = pd.read_csv( f"{data_prefix}_basic_info.csv", index_col='eid') @@ -73,23 +74,58 @@ class HealthDataset(Dataset): self.cont_features = torch.from_numpy(self.cont_features) self.cate_features = torch.from_numpy(self.cate_features) + # Optional cache for the DOA-inserted sequences produced by __getitem__. + # This preserves outputs exactly (we reuse the same construction logic), + # but avoids re-building Python lists on repeated access. + self._cache_event_tensors = bool(cache_event_tensors) + self._cached_event_tensors: List[torch.Tensor | None] = [None] * len( + self.patient_ids + ) + self._cached_time_tensors: List[torch.Tensor | None] = [None] * len( + self.patient_ids + ) + def __len__(self) -> int: return len(self.patient_ids) def __getitem__(self, idx): - patient_id = self.patient_ids[idx] - records = self.patient_events.get(patient_id, []) - event_seq = [item[1] for item in records] - time_seq = [item[0] for item in records] + if self._cache_event_tensors: + cached_e = self._cached_event_tensors[idx] + cached_t = self._cached_time_tensors[idx] + if cached_e is not None and cached_t is not None: + event_tensor = cached_e + time_tensor = cached_t + else: + patient_id = self.patient_ids[idx] + records = self.patient_events.get(patient_id, []) + event_seq = [item[1] for item in records] + time_seq = [item[0] for item in records] - doa = float(self._doa[idx]) + doa = float(self._doa[idx]) - insert_pos = np.searchsorted(time_seq, doa) - time_seq.insert(insert_pos, doa) - # assuming 1 is the code for 'DOA' event - event_seq.insert(insert_pos, 1) - event_tensor = torch.tensor(event_seq, dtype=torch.long) - time_tensor = torch.tensor(time_seq, dtype=torch.float) + insert_pos = np.searchsorted(time_seq, doa) + time_seq.insert(insert_pos, doa) + # assuming 1 is the code for 'DOA' event + event_seq.insert(insert_pos, 1) + event_tensor = torch.tensor(event_seq, dtype=torch.long) + time_tensor = torch.tensor(time_seq, dtype=torch.float) + + self._cached_event_tensors[idx] = event_tensor + self._cached_time_tensors[idx] = time_tensor + else: + patient_id = self.patient_ids[idx] + records = self.patient_events.get(patient_id, []) + event_seq = [item[1] for item in records] + time_seq = [item[0] for item in records] + + doa = float(self._doa[idx]) + + insert_pos = np.searchsorted(time_seq, doa) + time_seq.insert(insert_pos, doa) + # assuming 1 is the code for 'DOA' event + event_seq.insert(insert_pos, 1) + event_tensor = torch.tensor(event_seq, dtype=torch.long) + time_tensor = torch.tensor(time_seq, dtype=torch.float) cont_tensor = self.cont_features[idx, :].to(dtype=torch.float) cate_tensor = self.cate_features[idx, :].to(dtype=torch.long) diff --git a/evaluate_horizon.py b/evaluate_horizon.py index b4c0bc5..ff4392d 100644 --- a/evaluate_horizon.py +++ b/evaluate_horizon.py @@ -125,7 +125,6 @@ def main() -> None: horizons = [float(h) for h in horizons] records = build_event_driven_records( - dataset=dataset, subset=test_subset, age_bins_years=age_bins_years, seed=args.seed, @@ -136,7 +135,7 @@ def main() -> None: model, head, criterion = build_model_head_criterion(cfg, dataset, device) load_checkpoint_into(run_dir, model, head, criterion, device) - rec_ds = EvalRecordDataset(dataset, records) + rec_ds = EvalRecordDataset(test_subset, records) dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers) loader = DataLoader( diff --git a/evaluate_next_event.py b/evaluate_next_event.py index 130ee69..65741af 100644 --- a/evaluate_next_event.py +++ b/evaluate_next_event.py @@ -202,7 +202,6 @@ def main() -> None: age_bins_years = parse_float_list(args.age_bins) records = build_event_driven_records( - dataset=dataset, subset=test_subset, age_bins_years=age_bins_years, seed=args.seed, @@ -213,7 +212,7 @@ def main() -> None: model, head, criterion = build_model_head_criterion(cfg, dataset, device) load_checkpoint_into(run_dir, model, head, criterion, device) - rec_ds = EvalRecordDataset(dataset, records) + rec_ds = EvalRecordDataset(test_subset, records) dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers) loader = DataLoader( diff --git a/utils.py b/utils.py index c17ec7d..593072b 100644 --- a/utils.py +++ b/utils.py @@ -267,8 +267,7 @@ def load_checkpoint_into( @dataclass(frozen=True) class EvalRecord: - patient_idx: int - patient_id: Any + subset_idx: int doa_days: float t0_days: float cutoff_pos: int # baseline position (inclusive) @@ -285,7 +284,6 @@ def _to_days(x_years: float) -> float: def build_event_driven_records( - dataset: HealthDataset, subset: Subset, age_bins_years: Sequence[float], seed: int, @@ -302,34 +300,24 @@ def build_event_driven_records( records: List[EvalRecord] = [] - # Subset.indices is deterministic from random_split - indices = list(getattr(subset, "indices", range(len(subset)))) - - # Speed: avoid calling dataset.__getitem__ for every patient here. - # We only need DOA + event times/codes to create evaluation records. + # Build records exclusively from the provided subset. + # We intentionally avoid reading from subset.dataset internals so the + # evaluation pipeline does not depend on the full dataset object. eps = 1e-6 - for patient_idx in _progress( - indices, + for subset_idx in _progress( + range(len(subset)), enabled=show_progress, desc="Building eval records", - total=len(indices), + total=len(subset), ): - patient_id = dataset.patient_ids[patient_idx] + event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)] + codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False) + times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False) - doa_days = float(dataset._doa[patient_idx]) - - raw_records = dataset.patient_events.get(patient_id, []) - if raw_records: - times = np.asarray([t for t, _ in raw_records], dtype=np.float64) - codes = np.asarray([c for _, c in raw_records], dtype=np.int64) - else: - times = np.zeros((0,), dtype=np.float64) - codes = np.zeros((0,), dtype=np.int64) - - # Mirror HealthDataset insertion logic exactly. - insert_pos = int(np.searchsorted(times, doa_days, side="left")) - times_ins = np.insert(times, insert_pos, doa_days) - codes_ins = np.insert(codes, insert_pos, 1) + doa_pos = np.flatnonzero(codes_ins == 1) + if doa_pos.size == 0: + raise ValueError("Expected DOA token (code=1) in event sequence") + doa_days = float(times_ins[int(doa_pos[0])]) is_disease = codes_ins >= N_TECH_TOKENS disease_times = times_ins[is_disease] @@ -389,8 +377,7 @@ def build_event_driven_records( records.append( EvalRecord( - patient_idx=int(patient_idx), - patient_id=patient_id, + subset_idx=int(subset_idx), doa_days=float(doa_days), t0_days=float(t0_days), cutoff_pos=int(cutoff_pos), @@ -405,8 +392,8 @@ def build_event_driven_records( class EvalRecordDataset(Dataset): - def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]): - self.base = base_dataset + def __init__(self, subset: Dataset, records: Sequence[EvalRecord]): + self.subset = subset self.records = list(records) self._cache: Dict[int, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]] = {} @@ -418,12 +405,12 @@ class EvalRecordDataset(Dataset): def __getitem__(self, idx: int): rec = self.records[idx] - cached = self._cache.get(rec.patient_idx) + cached = self._cache.get(rec.subset_idx) if cached is None: - event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx] + event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx] cached = (event_seq, time_seq, cont, cate, int(sex)) - self._cache[rec.patient_idx] = cached - self._cache_order.append(rec.patient_idx) + self._cache[rec.subset_idx] = cached + self._cache_order.append(rec.subset_idx) if len(self._cache_order) > self._cache_max: drop = self._cache_order.pop(0) self._cache.pop(drop, None)