Enhance DataLoader configuration and improve tensor transfer efficiency in Trainer class
This commit is contained in:
dataset.py (26 lines changed)
@@ -35,6 +35,8 @@ class HealthDataset(Dataset):
|
||||
self.basic_info = basic_info.convert_dtypes()
|
||||
self.patient_ids = self.basic_info.index.tolist()
|
||||
self.patient_events = dict(patient_events)
|
||||
for patient_id, records in self.patient_events.items():
|
||||
records.sort(key=lambda x: x[0])
|
||||
|
||||
tabular_data = tabular_data.convert_dtypes()
|
||||
cont_cols = []
|
||||
@@ -61,17 +63,26 @@ class HealthDataset(Dataset):
|
||||
self.n_cont = self.cont_features.shape[1]
|
||||
self.n_cate = self.cate_features.shape[1]
|
||||
|
||||
self._doa = self.basic_info.loc[
|
||||
self.patient_ids, 'date_of_assessment'
|
||||
].to_numpy(dtype=np.float32)
|
||||
self._sex = self.basic_info.loc[
|
||||
self.patient_ids, 'sex'
|
||||
].to_numpy(dtype=np.int64)
|
||||
|
||||
self.cont_features = torch.from_numpy(self.cont_features)
|
||||
self.cate_features = torch.from_numpy(self.cate_features)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.patient_ids)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = sorted(self.patient_events.get(
|
||||
patient_id, []), key=lambda x: x[0])
|
||||
records = self.patient_events.get(patient_id, [])
|
||||
event_seq = [item[1] for item in records]
|
||||
time_seq = [item[0] for item in records]
|
||||
|
||||
doa = self.basic_info.loc[patient_id, 'date_of_assessment']
|
||||
doa = float(self._doa[idx])
|
||||
|
||||
insert_pos = np.searchsorted(time_seq, doa)
|
||||
time_seq.insert(insert_pos, doa)
|
||||
@@ -79,11 +90,10 @@ class HealthDataset(Dataset):
|
||||
event_seq.insert(insert_pos, 1)
|
||||
event_tensor = torch.tensor(event_seq, dtype=torch.long)
|
||||
time_tensor = torch.tensor(time_seq, dtype=torch.float)
|
||||
cont_tensor = torch.tensor(
|
||||
self.cont_features[idx, :], dtype=torch.float)
|
||||
cate_tensor = torch.tensor(
|
||||
self.cate_features[idx, :], dtype=torch.long)
|
||||
sex = self.basic_info.loc[patient_id, 'sex']
|
||||
|
||||
cont_tensor = self.cont_features[idx, :].to(dtype=torch.float)
|
||||
cate_tensor = self.cate_features[idx, :].to(dtype=torch.long)
|
||||
sex = int(self._sex[idx])
|
||||
|
||||
return (event_tensor, time_tensor, cont_tensor, cate_tensor, sex)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user