Add get_sequence_lengths method to HealthDataset and create extract_sequence_lengths script for per-patient sequence length extraction

This commit is contained in:
2026-01-19 14:39:13 +08:00
parent 76d3fed76f
commit 29913106cb
2 changed files with 96 additions and 0 deletions

View File

@@ -88,6 +88,17 @@ class HealthDataset(Dataset):
def __len__(self) -> int:
    """Return the dataset size, i.e. the number of entries in ``patient_ids``."""
    patient_count = len(self.patient_ids)
    return patient_count
def get_sequence_lengths(self) -> List[int]:
    """Compute the per-patient sequence length, in dataset-index order.

    Each length is the number of events stored for the patient in
    ``patient_events`` plus one slot for the inserted DOA event, which is
    intended to match the sequence that :meth:`__getitem__` produces.

    Returns:
        List[int]: One length per entry of ``patient_ids``, in the same order.
    """
    lengths: List[int] = []
    for patient_id in self.patient_ids:
        # A patient with no stored events still gets the single DOA slot.
        recorded_events = self.patient_events.get(patient_id, ())
        lengths.append(len(recorded_events) + 1)
    return lengths
def __getitem__(self, idx):
if self._cache_event_tensors:
cached_e = self._cached_event_tensors[idx]