Enhance HealthDataset with caching for event tensors and update evaluation scripts to use test subsets

This commit is contained in:
2026-01-17 14:42:02 +08:00
parent 7840a4c53e
commit a90f22a865
4 changed files with 70 additions and 49 deletions

View File

@@ -267,8 +267,7 @@ def load_checkpoint_into(
@dataclass(frozen=True)
class EvalRecord:
patient_idx: int
patient_id: Any
subset_idx: int
doa_days: float
t0_days: float
cutoff_pos: int # baseline position (inclusive)
@@ -285,7 +284,6 @@ def _to_days(x_years: float) -> float:
def build_event_driven_records(
dataset: HealthDataset,
subset: Subset,
age_bins_years: Sequence[float],
seed: int,
@@ -302,34 +300,24 @@ def build_event_driven_records(
records: List[EvalRecord] = []
# Subset.indices is deterministic from random_split
indices = list(getattr(subset, "indices", range(len(subset))))
# Speed: avoid calling dataset.__getitem__ for every patient here.
# We only need DOA + event times/codes to create evaluation records.
# Build records exclusively from the provided subset.
# We intentionally avoid reading from subset.dataset internals so the
# evaluation pipeline does not depend on the full dataset object.
eps = 1e-6
for patient_idx in _progress(
indices,
for subset_idx in _progress(
range(len(subset)),
enabled=show_progress,
desc="Building eval records",
total=len(indices),
total=len(subset),
):
patient_id = dataset.patient_ids[patient_idx]
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
doa_days = float(dataset._doa[patient_idx])
raw_records = dataset.patient_events.get(patient_id, [])
if raw_records:
times = np.asarray([t for t, _ in raw_records], dtype=np.float64)
codes = np.asarray([c for _, c in raw_records], dtype=np.int64)
else:
times = np.zeros((0,), dtype=np.float64)
codes = np.zeros((0,), dtype=np.int64)
# Mirror HealthDataset insertion logic exactly.
insert_pos = int(np.searchsorted(times, doa_days, side="left"))
times_ins = np.insert(times, insert_pos, doa_days)
codes_ins = np.insert(codes, insert_pos, 1)
doa_pos = np.flatnonzero(codes_ins == 1)
if doa_pos.size == 0:
raise ValueError("Expected DOA token (code=1) in event sequence")
doa_days = float(times_ins[int(doa_pos[0])])
is_disease = codes_ins >= N_TECH_TOKENS
disease_times = times_ins[is_disease]
@@ -389,8 +377,7 @@ def build_event_driven_records(
records.append(
EvalRecord(
patient_idx=int(patient_idx),
patient_id=patient_id,
subset_idx=int(subset_idx),
doa_days=float(doa_days),
t0_days=float(t0_days),
cutoff_pos=int(cutoff_pos),
@@ -405,8 +392,8 @@ def build_event_driven_records(
class EvalRecordDataset(Dataset):
def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]):
self.base = base_dataset
def __init__(self, subset: Dataset, records: Sequence[EvalRecord]):
self.subset = subset
self.records = list(records)
self._cache: Dict[int, Tuple[torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
@@ -418,12 +405,12 @@ class EvalRecordDataset(Dataset):
def __getitem__(self, idx: int):
rec = self.records[idx]
cached = self._cache.get(rec.patient_idx)
cached = self._cache.get(rec.subset_idx)
if cached is None:
event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx]
event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx]
cached = (event_seq, time_seq, cont, cate, int(sex))
self._cache[rec.patient_idx] = cached
self._cache_order.append(rec.patient_idx)
self._cache[rec.subset_idx] = cached
self._cache_order.append(rec.subset_idx)
if len(self._cache_order) > self._cache_max:
drop = self._cache_order.pop(0)
self._cache.pop(drop, None)