Enhance HealthDataset with caching for event tensors and update evaluation scripts to use test subsets
This commit is contained in:
58
dataset.py
58
dataset.py
@@ -20,6 +20,7 @@ class HealthDataset(Dataset):
|
|||||||
self,
|
self,
|
||||||
data_prefix: str,
|
data_prefix: str,
|
||||||
covariate_list: List[str] | None = None,
|
covariate_list: List[str] | None = None,
|
||||||
|
cache_event_tensors: bool = True,
|
||||||
):
|
):
|
||||||
basic_info = pd.read_csv(
|
basic_info = pd.read_csv(
|
||||||
f"{data_prefix}_basic_info.csv", index_col='eid')
|
f"{data_prefix}_basic_info.csv", index_col='eid')
|
||||||
@@ -73,23 +74,58 @@ class HealthDataset(Dataset):
|
|||||||
self.cont_features = torch.from_numpy(self.cont_features)
|
self.cont_features = torch.from_numpy(self.cont_features)
|
||||||
self.cate_features = torch.from_numpy(self.cate_features)
|
self.cate_features = torch.from_numpy(self.cate_features)
|
||||||
|
|
||||||
|
# Optional cache for the DOA-inserted sequences produced by __getitem__.
|
||||||
|
# This preserves outputs exactly (we reuse the same construction logic),
|
||||||
|
# but avoids re-building Python lists on repeated access.
|
||||||
|
self._cache_event_tensors = bool(cache_event_tensors)
|
||||||
|
self._cached_event_tensors: List[torch.Tensor | None] = [None] * len(
|
||||||
|
self.patient_ids
|
||||||
|
)
|
||||||
|
self._cached_time_tensors: List[torch.Tensor | None] = [None] * len(
|
||||||
|
self.patient_ids
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.patient_ids)
|
return len(self.patient_ids)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
patient_id = self.patient_ids[idx]
|
if self._cache_event_tensors:
|
||||||
records = self.patient_events.get(patient_id, [])
|
cached_e = self._cached_event_tensors[idx]
|
||||||
event_seq = [item[1] for item in records]
|
cached_t = self._cached_time_tensors[idx]
|
||||||
time_seq = [item[0] for item in records]
|
if cached_e is not None and cached_t is not None:
|
||||||
|
event_tensor = cached_e
|
||||||
|
time_tensor = cached_t
|
||||||
|
else:
|
||||||
|
patient_id = self.patient_ids[idx]
|
||||||
|
records = self.patient_events.get(patient_id, [])
|
||||||
|
event_seq = [item[1] for item in records]
|
||||||
|
time_seq = [item[0] for item in records]
|
||||||
|
|
||||||
doa = float(self._doa[idx])
|
doa = float(self._doa[idx])
|
||||||
|
|
||||||
insert_pos = np.searchsorted(time_seq, doa)
|
insert_pos = np.searchsorted(time_seq, doa)
|
||||||
time_seq.insert(insert_pos, doa)
|
time_seq.insert(insert_pos, doa)
|
||||||
# assuming 1 is the code for 'DOA' event
|
# assuming 1 is the code for 'DOA' event
|
||||||
event_seq.insert(insert_pos, 1)
|
event_seq.insert(insert_pos, 1)
|
||||||
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||||
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||||
|
|
||||||
|
self._cached_event_tensors[idx] = event_tensor
|
||||||
|
self._cached_time_tensors[idx] = time_tensor
|
||||||
|
else:
|
||||||
|
patient_id = self.patient_ids[idx]
|
||||||
|
records = self.patient_events.get(patient_id, [])
|
||||||
|
event_seq = [item[1] for item in records]
|
||||||
|
time_seq = [item[0] for item in records]
|
||||||
|
|
||||||
|
doa = float(self._doa[idx])
|
||||||
|
|
||||||
|
insert_pos = np.searchsorted(time_seq, doa)
|
||||||
|
time_seq.insert(insert_pos, doa)
|
||||||
|
# assuming 1 is the code for 'DOA' event
|
||||||
|
event_seq.insert(insert_pos, 1)
|
||||||
|
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||||
|
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||||
|
|
||||||
cont_tensor = self.cont_features[idx, :].to(dtype=torch.float)
|
cont_tensor = self.cont_features[idx, :].to(dtype=torch.float)
|
||||||
cate_tensor = self.cate_features[idx, :].to(dtype=torch.long)
|
cate_tensor = self.cate_features[idx, :].to(dtype=torch.long)
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ def main() -> None:
|
|||||||
horizons = [float(h) for h in horizons]
|
horizons = [float(h) for h in horizons]
|
||||||
|
|
||||||
records = build_event_driven_records(
|
records = build_event_driven_records(
|
||||||
dataset=dataset,
|
|
||||||
subset=test_subset,
|
subset=test_subset,
|
||||||
age_bins_years=age_bins_years,
|
age_bins_years=age_bins_years,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
@@ -136,7 +135,7 @@ def main() -> None:
|
|||||||
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
||||||
load_checkpoint_into(run_dir, model, head, criterion, device)
|
load_checkpoint_into(run_dir, model, head, criterion, device)
|
||||||
|
|
||||||
rec_ds = EvalRecordDataset(dataset, records)
|
rec_ds = EvalRecordDataset(test_subset, records)
|
||||||
|
|
||||||
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
|
|||||||
@@ -202,7 +202,6 @@ def main() -> None:
|
|||||||
|
|
||||||
age_bins_years = parse_float_list(args.age_bins)
|
age_bins_years = parse_float_list(args.age_bins)
|
||||||
records = build_event_driven_records(
|
records = build_event_driven_records(
|
||||||
dataset=dataset,
|
|
||||||
subset=test_subset,
|
subset=test_subset,
|
||||||
age_bins_years=age_bins_years,
|
age_bins_years=age_bins_years,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
@@ -213,7 +212,7 @@ def main() -> None:
|
|||||||
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
||||||
load_checkpoint_into(run_dir, model, head, criterion, device)
|
load_checkpoint_into(run_dir, model, head, criterion, device)
|
||||||
|
|
||||||
rec_ds = EvalRecordDataset(dataset, records)
|
rec_ds = EvalRecordDataset(test_subset, records)
|
||||||
|
|
||||||
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
|
|||||||
55
utils.py
55
utils.py
@@ -267,8 +267,7 @@ def load_checkpoint_into(
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class EvalRecord:
|
class EvalRecord:
|
||||||
patient_idx: int
|
subset_idx: int
|
||||||
patient_id: Any
|
|
||||||
doa_days: float
|
doa_days: float
|
||||||
t0_days: float
|
t0_days: float
|
||||||
cutoff_pos: int # baseline position (inclusive)
|
cutoff_pos: int # baseline position (inclusive)
|
||||||
@@ -285,7 +284,6 @@ def _to_days(x_years: float) -> float:
|
|||||||
|
|
||||||
|
|
||||||
def build_event_driven_records(
|
def build_event_driven_records(
|
||||||
dataset: HealthDataset,
|
|
||||||
subset: Subset,
|
subset: Subset,
|
||||||
age_bins_years: Sequence[float],
|
age_bins_years: Sequence[float],
|
||||||
seed: int,
|
seed: int,
|
||||||
@@ -302,34 +300,24 @@ def build_event_driven_records(
|
|||||||
|
|
||||||
records: List[EvalRecord] = []
|
records: List[EvalRecord] = []
|
||||||
|
|
||||||
# Subset.indices is deterministic from random_split
|
# Build records exclusively from the provided subset.
|
||||||
indices = list(getattr(subset, "indices", range(len(subset))))
|
# We intentionally avoid reading from subset.dataset internals so the
|
||||||
|
# evaluation pipeline does not depend on the full dataset object.
|
||||||
# Speed: avoid calling dataset.__getitem__ for every patient here.
|
|
||||||
# We only need DOA + event times/codes to create evaluation records.
|
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
for patient_idx in _progress(
|
for subset_idx in _progress(
|
||||||
indices,
|
range(len(subset)),
|
||||||
enabled=show_progress,
|
enabled=show_progress,
|
||||||
desc="Building eval records",
|
desc="Building eval records",
|
||||||
total=len(indices),
|
total=len(subset),
|
||||||
):
|
):
|
||||||
patient_id = dataset.patient_ids[patient_idx]
|
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
|
||||||
|
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
|
||||||
|
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
|
||||||
|
|
||||||
doa_days = float(dataset._doa[patient_idx])
|
doa_pos = np.flatnonzero(codes_ins == 1)
|
||||||
|
if doa_pos.size == 0:
|
||||||
raw_records = dataset.patient_events.get(patient_id, [])
|
raise ValueError("Expected DOA token (code=1) in event sequence")
|
||||||
if raw_records:
|
doa_days = float(times_ins[int(doa_pos[0])])
|
||||||
times = np.asarray([t for t, _ in raw_records], dtype=np.float64)
|
|
||||||
codes = np.asarray([c for _, c in raw_records], dtype=np.int64)
|
|
||||||
else:
|
|
||||||
times = np.zeros((0,), dtype=np.float64)
|
|
||||||
codes = np.zeros((0,), dtype=np.int64)
|
|
||||||
|
|
||||||
# Mirror HealthDataset insertion logic exactly.
|
|
||||||
insert_pos = int(np.searchsorted(times, doa_days, side="left"))
|
|
||||||
times_ins = np.insert(times, insert_pos, doa_days)
|
|
||||||
codes_ins = np.insert(codes, insert_pos, 1)
|
|
||||||
|
|
||||||
is_disease = codes_ins >= N_TECH_TOKENS
|
is_disease = codes_ins >= N_TECH_TOKENS
|
||||||
disease_times = times_ins[is_disease]
|
disease_times = times_ins[is_disease]
|
||||||
@@ -389,8 +377,7 @@ def build_event_driven_records(
|
|||||||
|
|
||||||
records.append(
|
records.append(
|
||||||
EvalRecord(
|
EvalRecord(
|
||||||
patient_idx=int(patient_idx),
|
subset_idx=int(subset_idx),
|
||||||
patient_id=patient_id,
|
|
||||||
doa_days=float(doa_days),
|
doa_days=float(doa_days),
|
||||||
t0_days=float(t0_days),
|
t0_days=float(t0_days),
|
||||||
cutoff_pos=int(cutoff_pos),
|
cutoff_pos=int(cutoff_pos),
|
||||||
@@ -405,8 +392,8 @@ def build_event_driven_records(
|
|||||||
|
|
||||||
|
|
||||||
class EvalRecordDataset(Dataset):
|
class EvalRecordDataset(Dataset):
|
||||||
def __init__(self, base_dataset: HealthDataset, records: Sequence[EvalRecord]):
|
def __init__(self, subset: Dataset, records: Sequence[EvalRecord]):
|
||||||
self.base = base_dataset
|
self.subset = subset
|
||||||
self.records = list(records)
|
self.records = list(records)
|
||||||
self._cache: Dict[int, Tuple[torch.Tensor,
|
self._cache: Dict[int, Tuple[torch.Tensor,
|
||||||
torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
|
torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
|
||||||
@@ -418,12 +405,12 @@ class EvalRecordDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
rec = self.records[idx]
|
rec = self.records[idx]
|
||||||
cached = self._cache.get(rec.patient_idx)
|
cached = self._cache.get(rec.subset_idx)
|
||||||
if cached is None:
|
if cached is None:
|
||||||
event_seq, time_seq, cont, cate, sex = self.base[rec.patient_idx]
|
event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx]
|
||||||
cached = (event_seq, time_seq, cont, cate, int(sex))
|
cached = (event_seq, time_seq, cont, cate, int(sex))
|
||||||
self._cache[rec.patient_idx] = cached
|
self._cache[rec.subset_idx] = cached
|
||||||
self._cache_order.append(rec.patient_idx)
|
self._cache_order.append(rec.subset_idx)
|
||||||
if len(self._cache_order) > self._cache_max:
|
if len(self._cache_order) > self._cache_max:
|
||||||
drop = self._cache_order.pop(0)
|
drop = self._cache_order.pop(0)
|
||||||
self._cache.pop(drop, None)
|
self._cache.pop(drop, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user