Enhance HealthDataset with caching for event tensors and update evaluation scripts to use test subsets
This commit is contained in:
58
dataset.py
58
dataset.py
@@ -20,6 +20,7 @@ class HealthDataset(Dataset):
|
||||
self,
|
||||
data_prefix: str,
|
||||
covariate_list: List[str] | None = None,
|
||||
cache_event_tensors: bool = True,
|
||||
):
|
||||
basic_info = pd.read_csv(
|
||||
f"{data_prefix}_basic_info.csv", index_col='eid')
|
||||
@@ -73,23 +74,58 @@ class HealthDataset(Dataset):
|
||||
self.cont_features = torch.from_numpy(self.cont_features)
|
||||
self.cate_features = torch.from_numpy(self.cate_features)
|
||||
|
||||
# Optional per-index cache for the DOA-inserted event and time tensors built in __getitem__.
|
||||
# Cached entries are produced by the same construction logic as the uncached path, so outputs are unchanged;
|
||||
# the cache only avoids re-building the Python lists and tensors on repeated access to the same index.
|
||||
self._cache_event_tensors = bool(cache_event_tensors)
|
||||
self._cached_event_tensors: List[torch.Tensor | None] = [None] * len(
|
||||
self.patient_ids
|
||||
)
|
||||
self._cached_time_tensors: List[torch.Tensor | None] = [None] * len(
|
||||
self.patient_ids
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.patient_ids)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = self.patient_events.get(patient_id, [])
|
||||
event_seq = [item[1] for item in records]
|
||||
time_seq = [item[0] for item in records]
|
||||
if self._cache_event_tensors:
|
||||
cached_e = self._cached_event_tensors[idx]
|
||||
cached_t = self._cached_time_tensors[idx]
|
||||
if cached_e is not None and cached_t is not None:
|
||||
event_tensor = cached_e
|
||||
time_tensor = cached_t
|
||||
else:
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = self.patient_events.get(patient_id, [])
|
||||
event_seq = [item[1] for item in records]
|
||||
time_seq = [item[0] for item in records]
|
||||
|
||||
doa = float(self._doa[idx])
|
||||
doa = float(self._doa[idx])
|
||||
|
||||
insert_pos = np.searchsorted(time_seq, doa)
|
||||
time_seq.insert(insert_pos, doa)
|
||||
# assuming 1 is the code for 'DOA' event
|
||||
event_seq.insert(insert_pos, 1)
|
||||
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||
insert_pos = np.searchsorted(time_seq, doa)
|
||||
time_seq.insert(insert_pos, doa)
|
||||
# assuming 1 is the code for 'DOA' event
|
||||
event_seq.insert(insert_pos, 1)
|
||||
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||
|
||||
self._cached_event_tensors[idx] = event_tensor
|
||||
self._cached_time_tensors[idx] = time_tensor
|
||||
else:
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = self.patient_events.get(patient_id, [])
|
||||
event_seq = [item[1] for item in records]
|
||||
time_seq = [item[0] for item in records]
|
||||
|
||||
doa = float(self._doa[idx])
|
||||
|
||||
insert_pos = np.searchsorted(time_seq, doa)
|
||||
time_seq.insert(insert_pos, doa)
|
||||
# assuming 1 is the code for 'DOA' event
|
||||
event_seq.insert(insert_pos, 1)
|
||||
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||
|
||||
cont_tensor = self.cont_features[idx, :].to(dtype=torch.float)
|
||||
cate_tensor = self.cate_features[idx, :].to(dtype=torch.long)
|
||||
|
||||
Reference in New Issue
Block a user