Enhance HealthDataset with caching for event tensors and update evaluation scripts to use test subsets

Date: 2026-01-17 14:42:02 +08:00
parent 7840a4c53e
commit a90f22a865
4 changed files with 70 additions and 49 deletions

View File

@@ -20,6 +20,7 @@ class HealthDataset(Dataset):
         self,
         data_prefix: str,
         covariate_list: List[str] | None = None,
+        cache_event_tensors: bool = True,
     ):
         basic_info = pd.read_csv(
             f"{data_prefix}_basic_info.csv", index_col='eid')
@@ -73,23 +74,58 @@ class HealthDataset(Dataset):
         self.cont_features = torch.from_numpy(self.cont_features)
         self.cate_features = torch.from_numpy(self.cate_features)
+
+        # Optional cache for the DOA-inserted sequences produced by __getitem__.
+        # This preserves outputs exactly (we reuse the same construction logic),
+        # but avoids re-building Python lists on repeated access.
+        self._cache_event_tensors = bool(cache_event_tensors)
+        self._cached_event_tensors: List[torch.Tensor | None] = [None] * len(
+            self.patient_ids
+        )
+        self._cached_time_tensors: List[torch.Tensor | None] = [None] * len(
+            self.patient_ids
+        )

     def __len__(self) -> int:
         return len(self.patient_ids)

     def __getitem__(self, idx):
-        patient_id = self.patient_ids[idx]
-        records = self.patient_events.get(patient_id, [])
-        event_seq = [item[1] for item in records]
-        time_seq = [item[0] for item in records]
-        doa = float(self._doa[idx])
-        insert_pos = np.searchsorted(time_seq, doa)
-        time_seq.insert(insert_pos, doa)
-        # assuming 1 is the code for 'DOA' event
-        event_seq.insert(insert_pos, 1)
-        event_tensor = torch.tensor(event_seq, dtype=torch.long)
-        time_tensor = torch.tensor(time_seq, dtype=torch.float)
+        if self._cache_event_tensors:
+            cached_e = self._cached_event_tensors[idx]
+            cached_t = self._cached_time_tensors[idx]
+            if cached_e is not None and cached_t is not None:
+                event_tensor = cached_e
+                time_tensor = cached_t
+            else:
+                patient_id = self.patient_ids[idx]
+                records = self.patient_events.get(patient_id, [])
+                event_seq = [item[1] for item in records]
+                time_seq = [item[0] for item in records]
+                doa = float(self._doa[idx])
+                insert_pos = np.searchsorted(time_seq, doa)
+                time_seq.insert(insert_pos, doa)
+                # assuming 1 is the code for 'DOA' event
+                event_seq.insert(insert_pos, 1)
+                event_tensor = torch.tensor(event_seq, dtype=torch.long)
+                time_tensor = torch.tensor(time_seq, dtype=torch.float)
+                self._cached_event_tensors[idx] = event_tensor
+                self._cached_time_tensors[idx] = time_tensor
+        else:
+            patient_id = self.patient_ids[idx]
+            records = self.patient_events.get(patient_id, [])
+            event_seq = [item[1] for item in records]
+            time_seq = [item[0] for item in records]
+            doa = float(self._doa[idx])
+            insert_pos = np.searchsorted(time_seq, doa)
+            time_seq.insert(insert_pos, doa)
+            # assuming 1 is the code for 'DOA' event
+            event_seq.insert(insert_pos, 1)
+            event_tensor = torch.tensor(event_seq, dtype=torch.long)
+            time_tensor = torch.tensor(time_seq, dtype=torch.float)

         cont_tensor = self.cont_features[idx, :].to(dtype=torch.float)
         cate_tensor = self.cate_features[idx, :].to(dtype=torch.long)
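
Note: a minimal, self-contained sketch of the per-index memoization pattern this hunk introduces (ToyDataset and its fields are illustrative stand-ins, not code from this repo):

    import torch
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        """Stand-in mirroring HealthDataset's one-slot-per-index cache."""

        def __init__(self, sequences, cache=True):
            self.sequences = sequences              # list of event-code lists
            self._cache_enabled = cache
            self._cached = [None] * len(sequences)  # one slot per index

        def __len__(self):
            return len(self.sequences)

        def __getitem__(self, idx):
            if self._cache_enabled and self._cached[idx] is not None:
                return self._cached[idx]            # hit: reuse the tensor built earlier
            tensor = torch.tensor(self.sequences[idx], dtype=torch.long)
            if self._cache_enabled:
                self._cached[idx] = tensor
            return tensor

    ds = ToyDataset([[1, 5, 9], [1, 3]])
    assert ds[0] is ds[0]  # the second access returns the cached object

One caveat: with DataLoader(num_workers > 0), each worker process holds its own copy of the cache, so hits only accrue within a worker and memory cost multiplies accordingly.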

View File

@@ -125,7 +125,6 @@ def main() -> None:
     horizons = [float(h) for h in horizons]
     records = build_event_driven_records(
-        dataset=dataset,
         subset=test_subset,
         age_bins_years=age_bins_years,
         seed=args.seed,
@@ -136,7 +135,7 @@ def main() -> None:
     model, head, criterion = build_model_head_criterion(cfg, dataset, device)
     load_checkpoint_into(run_dir, model, head, criterion, device)
-    rec_ds = EvalRecordDataset(dataset, records)
+    rec_ds = EvalRecordDataset(test_subset, records)
     dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
     loader = DataLoader(
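
Note: both eval scripts now pass test_subset rather than the full dataset. This is sound because torch.utils.data.Subset resolves subset-local indices against the base dataset; a quick illustration with toy data:

    import torch
    from torch.utils.data import Subset, TensorDataset

    base = TensorDataset(torch.arange(10))
    test_subset = Subset(base, indices=[7, 2, 5])

    # Subset.__getitem__(i) delegates to base[indices[i]],
    # so subset-local index 0 resolves to base item 7.
    assert test_subset[0][0].item() == base[7][0].item() == 7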

View File

@@ -202,7 +202,6 @@ def main() -> None:
     age_bins_years = parse_float_list(args.age_bins)
     records = build_event_driven_records(
-        dataset=dataset,
         subset=test_subset,
         age_bins_years=age_bins_years,
         seed=args.seed,
@@ -213,7 +212,7 @@ def main() -> None:
     model, head, criterion = build_model_head_criterion(cfg, dataset, device)
     load_checkpoint_into(run_dir, model, head, criterion, device)
-    rec_ds = EvalRecordDataset(dataset, records)
+    rec_ds = EvalRecordDataset(test_subset, records)
     dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
     loader = DataLoader(

View File

@@ -267,8 +267,7 @@ def load_checkpoint_into(
 @dataclass(frozen=True)
 class EvalRecord:
-    patient_idx: int
-    patient_id: Any
+    subset_idx: int
     doa_days: float
     t0_days: float
     cutoff_pos: int  # baseline position (inclusive)
@@ -285,7 +284,6 @@ def _to_days(x_years: float) -> float:
 def build_event_driven_records(
-    dataset: HealthDataset,
     subset: Subset,
     age_bins_years: Sequence[float],
     seed: int,
@@ -302,34 +300,24 @@ def build_event_driven_records(
     records: List[EvalRecord] = []

-    # Subset.indices is deterministic from random_split
-    indices = list(getattr(subset, "indices", range(len(subset))))
-
-    # Speed: avoid calling dataset.__getitem__ for every patient here.
-    # We only need DOA + event times/codes to create evaluation records.
+    # Build records exclusively from the provided subset.
+    # We intentionally avoid reading from subset.dataset internals so the
+    # evaluation pipeline does not depend on the full dataset object.
     eps = 1e-6
-    for patient_idx in _progress(
-        indices,
+    for subset_idx in _progress(
+        range(len(subset)),
         enabled=show_progress,
         desc="Building eval records",
-        total=len(indices),
+        total=len(subset),
     ):
-        patient_id = dataset.patient_ids[patient_idx]
-        doa_days = float(dataset._doa[patient_idx])
-        raw_records = dataset.patient_events.get(patient_id, [])
-        if raw_records:
-            times = np.asarray([t for t, _ in raw_records], dtype=np.float64)
-            codes = np.asarray([c for _, c in raw_records], dtype=np.int64)
-        else:
-            times = np.zeros((0,), dtype=np.float64)
-            codes = np.zeros((0,), dtype=np.int64)
-
-        # Mirror HealthDataset insertion logic exactly.
-        insert_pos = int(np.searchsorted(times, doa_days, side="left"))
-        times_ins = np.insert(times, insert_pos, doa_days)
-        codes_ins = np.insert(codes, insert_pos, 1)
+        event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
+        codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
+        times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
+
+        doa_pos = np.flatnonzero(codes_ins == 1)
+        if doa_pos.size == 0:
+            raise ValueError("Expected DOA token (code=1) in event sequence")
+        doa_days = float(times_ins[int(doa_pos[0])])

         is_disease = codes_ins >= N_TECH_TOKENS
         disease_times = times_ins[is_disease]
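
Note: the DOA recovery above relies on HealthDataset.__getitem__ having inserted exactly one code-1 token at the assessment date; a toy check of the extraction (all values illustrative):

    import numpy as np

    codes_ins = np.array([3, 1, 7, 9], dtype=np.int64)  # 1 == DOA token
    times_ins = np.array([120.0, 365.0, 400.5, 910.0])  # event times in days

    doa_pos = np.flatnonzero(codes_ins == 1)            # positions of the DOA code
    assert doa_pos.size > 0, "DOA token missing"
    doa_days = float(times_ins[int(doa_pos[0])])        # first occurrence wins
    assert doa_days == 365.0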
@@ -389,8 +377,7 @@ def build_event_driven_records(
         records.append(
             EvalRecord(
-                patient_idx=int(patient_idx),
-                patient_id=patient_id,
+                subset_idx=int(subset_idx),
                 doa_days=float(doa_days),
                 t0_days=float(t0_days),
                 cutoff_pos=int(cutoff_pos),
@@ -405,8 +392,8 @@ def build_event_driven_records(
 class EvalRecordDataset(Dataset):
-    def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]):
-        self.base = base_dataset
+    def __init__(self, subset: Dataset, records: Sequence[EvalRecord]):
+        self.subset = subset
         self.records = list(records)
         self._cache: Dict[int, Tuple[torch.Tensor,
                           torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
@@ -418,12 +405,12 @@ class EvalRecordDataset(Dataset):
     def __getitem__(self, idx: int):
         rec = self.records[idx]
-        cached = self._cache.get(rec.patient_idx)
+        cached = self._cache.get(rec.subset_idx)
         if cached is None:
-            event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx]
+            event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx]
             cached = (event_seq, time_seq, cont, cate, int(sex))
-            self._cache[rec.patient_idx] = cached
-            self._cache_order.append(rec.patient_idx)
+            self._cache[rec.subset_idx] = cached
+            self._cache_order.append(rec.subset_idx)
             if len(self._cache_order) > self._cache_max:
                 drop = self._cache_order.pop(0)
                 self._cache.pop(drop, None)
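
Note: the eviction above is a bounded first-in-first-out cache keyed by subset index; a self-contained sketch of the same pattern (FifoCache and max_items=2 are illustrative, not from this repo):

    from collections import OrderedDict

    class FifoCache:
        """Bounded FIFO mirroring EvalRecordDataset's eviction."""

        def __init__(self, max_items=2):
            self._data = OrderedDict()
            self._max = max_items

        def get_or_build(self, key, build):
            if key in self._data:
                return self._data[key]          # hit: no reorder, so FIFO rather than LRU
            value = build(key)
            self._data[key] = value
            if len(self._data) > self._max:
                self._data.popitem(last=False)  # evict the oldest insertion
            return value

    cache = FifoCache(max_items=2)
    for k in [0, 1, 0, 2]:
        cache.get_or_build(k, lambda key: key * 10)
    assert 0 not in cache._data and list(cache._data) == [1, 2]

Since records for the same subset index are typically adjacent in the records list, FIFO eviction is enough here; an LRU would add bookkeeping without improving the hit rate for sequential access.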