Enhance HealthDataset with caching for event tensors and update evaluation scripts to use test subsets
This commit is contained in:
36
dataset.py
36
dataset.py
@@ -20,6 +20,7 @@ class HealthDataset(Dataset):
|
||||
self,
|
||||
data_prefix: str,
|
||||
covariate_list: List[str] | None = None,
|
||||
cache_event_tensors: bool = True,
|
||||
):
|
||||
basic_info = pd.read_csv(
|
||||
f"{data_prefix}_basic_info.csv", index_col='eid')
|
||||
@@ -73,10 +74,45 @@ class HealthDataset(Dataset):
|
||||
self.cont_features = torch.from_numpy(self.cont_features)
|
||||
self.cate_features = torch.from_numpy(self.cate_features)
|
||||
|
||||
# Optional cache for the DOA-inserted sequences produced by __getitem__.
|
||||
# This preserves outputs exactly (we reuse the same construction logic),
|
||||
# but avoids re-building Python lists on repeated access.
|
||||
self._cache_event_tensors = bool(cache_event_tensors)
|
||||
self._cached_event_tensors: List[torch.Tensor | None] = [None] * len(
|
||||
self.patient_ids
|
||||
)
|
||||
self._cached_time_tensors: List[torch.Tensor | None] = [None] * len(
|
||||
self.patient_ids
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.patient_ids)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self._cache_event_tensors:
|
||||
cached_e = self._cached_event_tensors[idx]
|
||||
cached_t = self._cached_time_tensors[idx]
|
||||
if cached_e is not None and cached_t is not None:
|
||||
event_tensor = cached_e
|
||||
time_tensor = cached_t
|
||||
else:
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = self.patient_events.get(patient_id, [])
|
||||
event_seq = [item[1] for item in records]
|
||||
time_seq = [item[0] for item in records]
|
||||
|
||||
doa = float(self._doa[idx])
|
||||
|
||||
insert_pos = np.searchsorted(time_seq, doa)
|
||||
time_seq.insert(insert_pos, doa)
|
||||
# assuming 1 is the code for 'DOA' event
|
||||
event_seq.insert(insert_pos, 1)
|
||||
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||
|
||||
self._cached_event_tensors[idx] = event_tensor
|
||||
self._cached_time_tensors[idx] = time_tensor
|
||||
else:
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = self.patient_events.get(patient_id, [])
|
||||
event_seq = [item[1] for item in records]
|
||||
|
||||
@@ -125,7 +125,6 @@ def main() -> None:
|
||||
horizons = [float(h) for h in horizons]
|
||||
|
||||
records = build_event_driven_records(
|
||||
dataset=dataset,
|
||||
subset=test_subset,
|
||||
age_bins_years=age_bins_years,
|
||||
seed=args.seed,
|
||||
@@ -136,7 +135,7 @@ def main() -> None:
|
||||
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
||||
load_checkpoint_into(run_dir, model, head, criterion, device)
|
||||
|
||||
rec_ds = EvalRecordDataset(dataset, records)
|
||||
rec_ds = EvalRecordDataset(test_subset, records)
|
||||
|
||||
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
||||
loader = DataLoader(
|
||||
|
||||
@@ -202,7 +202,6 @@ def main() -> None:
|
||||
|
||||
age_bins_years = parse_float_list(args.age_bins)
|
||||
records = build_event_driven_records(
|
||||
dataset=dataset,
|
||||
subset=test_subset,
|
||||
age_bins_years=age_bins_years,
|
||||
seed=args.seed,
|
||||
@@ -213,7 +212,7 @@ def main() -> None:
|
||||
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
||||
load_checkpoint_into(run_dir, model, head, criterion, device)
|
||||
|
||||
rec_ds = EvalRecordDataset(dataset, records)
|
||||
rec_ds = EvalRecordDataset(test_subset, records)
|
||||
|
||||
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
||||
loader = DataLoader(
|
||||
|
||||
55
utils.py
55
utils.py
@@ -267,8 +267,7 @@ def load_checkpoint_into(
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EvalRecord:
|
||||
patient_idx: int
|
||||
patient_id: Any
|
||||
subset_idx: int
|
||||
doa_days: float
|
||||
t0_days: float
|
||||
cutoff_pos: int # baseline position (inclusive)
|
||||
@@ -285,7 +284,6 @@ def _to_days(x_years: float) -> float:
|
||||
|
||||
|
||||
def build_event_driven_records(
|
||||
dataset: HealthDataset,
|
||||
subset: Subset,
|
||||
age_bins_years: Sequence[float],
|
||||
seed: int,
|
||||
@@ -302,34 +300,24 @@ def build_event_driven_records(
|
||||
|
||||
records: List[EvalRecord] = []
|
||||
|
||||
# Subset.indices is deterministic from random_split
|
||||
indices = list(getattr(subset, "indices", range(len(subset))))
|
||||
|
||||
# Speed: avoid calling dataset.__getitem__ for every patient here.
|
||||
# We only need DOA + event times/codes to create evaluation records.
|
||||
# Build records exclusively from the provided subset.
|
||||
# We intentionally avoid reading from subset.dataset internals so the
|
||||
# evaluation pipeline does not depend on the full dataset object.
|
||||
eps = 1e-6
|
||||
for patient_idx in _progress(
|
||||
indices,
|
||||
for subset_idx in _progress(
|
||||
range(len(subset)),
|
||||
enabled=show_progress,
|
||||
desc="Building eval records",
|
||||
total=len(indices),
|
||||
total=len(subset),
|
||||
):
|
||||
patient_id = dataset.patient_ids[patient_idx]
|
||||
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
|
||||
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
|
||||
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
|
||||
|
||||
doa_days = float(dataset._doa[patient_idx])
|
||||
|
||||
raw_records = dataset.patient_events.get(patient_id, [])
|
||||
if raw_records:
|
||||
times = np.asarray([t for t, _ in raw_records], dtype=np.float64)
|
||||
codes = np.asarray([c for _, c in raw_records], dtype=np.int64)
|
||||
else:
|
||||
times = np.zeros((0,), dtype=np.float64)
|
||||
codes = np.zeros((0,), dtype=np.int64)
|
||||
|
||||
# Mirror HealthDataset insertion logic exactly.
|
||||
insert_pos = int(np.searchsorted(times, doa_days, side="left"))
|
||||
times_ins = np.insert(times, insert_pos, doa_days)
|
||||
codes_ins = np.insert(codes, insert_pos, 1)
|
||||
doa_pos = np.flatnonzero(codes_ins == 1)
|
||||
if doa_pos.size == 0:
|
||||
raise ValueError("Expected DOA token (code=1) in event sequence")
|
||||
doa_days = float(times_ins[int(doa_pos[0])])
|
||||
|
||||
is_disease = codes_ins >= N_TECH_TOKENS
|
||||
disease_times = times_ins[is_disease]
|
||||
@@ -389,8 +377,7 @@ def build_event_driven_records(
|
||||
|
||||
records.append(
|
||||
EvalRecord(
|
||||
patient_idx=int(patient_idx),
|
||||
patient_id=patient_id,
|
||||
subset_idx=int(subset_idx),
|
||||
doa_days=float(doa_days),
|
||||
t0_days=float(t0_days),
|
||||
cutoff_pos=int(cutoff_pos),
|
||||
@@ -405,8 +392,8 @@ def build_event_driven_records(
|
||||
|
||||
|
||||
class EvalRecordDataset(Dataset):
|
||||
def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]):
|
||||
self.base = base_dataset
|
||||
def __init__(self, subset: Dataset, records: Sequence[EvalRecord]):
|
||||
self.subset = subset
|
||||
self.records = list(records)
|
||||
self._cache: Dict[int, Tuple[torch.Tensor,
|
||||
torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
|
||||
@@ -418,12 +405,12 @@ class EvalRecordDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
rec = self.records[idx]
|
||||
cached = self._cache.get(rec.patient_idx)
|
||||
cached = self._cache.get(rec.subset_idx)
|
||||
if cached is None:
|
||||
event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx]
|
||||
event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx]
|
||||
cached = (event_seq, time_seq, cont, cate, int(sex))
|
||||
self._cache[rec.patient_idx] = cached
|
||||
self._cache_order.append(rec.patient_idx)
|
||||
self._cache[rec.subset_idx] = cached
|
||||
self._cache_order.append(rec.subset_idx)
|
||||
if len(self._cache_order) > self._cache_max:
|
||||
drop = self._cache_order.pop(0)
|
||||
self._cache.pop(drop, None)
|
||||
|
||||
Reference in New Issue
Block a user