Add get_sequence_lengths method to HealthDataset and create extract_sequence_lengths script for per-patient sequence length extraction
This commit is contained in:
11
dataset.py
11
dataset.py
@@ -88,6 +88,17 @@ class HealthDataset(Dataset):
|
||||
def __len__(self) -> int:
    """Return the dataset size, i.e. the number of entries in ``self.patient_ids``."""
    patient_count = len(self.patient_ids)
    return patient_count
|
||||
|
||||
def get_sequence_lengths(self) -> List[int]:
    """Compute the sequence length of every patient in the dataset.

    Each length is the number of events stored for the patient in
    ``self.patient_events`` plus one. The extra 1 accounts for the DOA
    event that :meth:`__getitem__` is documented to insert, so the values
    match the sequences that :meth:`__getitem__` returns. A patient with
    no entry in ``self.patient_events`` gets a length of 1.

    Returns:
        List[int]: One length per entry of ``self.patient_ids``, in the
        same order (aligned with dataset indices).
    """
    lengths: List[int] = []
    for patient_id in self.patient_ids:
        # Missing patients fall back to an empty tuple, i.e. zero events.
        events = self.patient_events.get(patient_id, ())
        lengths.append(len(events) + 1)
    return lengths
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self._cache_event_tensors:
|
||||
cached_e = self._cached_event_tensors[idx]
|
||||
|
||||
Reference in New Issue
Block a user