Add get_sequence_lengths method to HealthDataset and create extract_sequence_lengths script for per-patient sequence length extraction

This commit is contained in:
2026-01-19 14:39:13 +08:00
parent 76d3fed76f
commit 29913106cb
2 changed files with 96 additions and 0 deletions

View File

@@ -88,6 +88,17 @@ class HealthDataset(Dataset):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.patient_ids) return len(self.patient_ids)
def get_sequence_lengths(self) -> List[int]:
"""Return the sequence length for each patient.
Lengths correspond to what :meth:`__getitem__` returns: the number of
patient events plus the inserted DOA event.
Returns:
List[int]: Sequence lengths aligned with dataset indices.
"""
return [len(self.patient_events.get(pid, ())) + 1 for pid in self.patient_ids]
def __getitem__(self, idx): def __getitem__(self, idx):
if self._cache_event_tensors: if self._cache_event_tensors:
cached_e = self._cached_event_tensors[idx] cached_e = self._cached_event_tensors[idx]

View File

@@ -0,0 +1,85 @@
import argparse
import os
from typing import List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataset import HealthDataset
def _default_covariates(full_cov: bool) -> Optional[List[str]]:
# Mirrors train.py default behavior.
if full_cov:
return None
return ["bmi", "smoking", "alcohol"]
def main() -> None:
parser = argparse.ArgumentParser(
description="Extract per-patient sequence lengths from the UKB dataset."
)
parser.add_argument(
"--data_prefix",
type=str,
default="ukb",
help="Prefix for dataset files (expects <prefix>_basic_info.csv, <prefix>_table.csv, <prefix>_event_data.npy).",
)
parser.add_argument(
"--full_cov",
action="store_true",
help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
)
parser.add_argument(
"--out_csv",
type=str,
default=None,
help="Output CSV path. Default: <data_prefix>_sequence_lengths.csv",
)
parser.add_argument(
"--out_npy",
type=str,
default=None,
help="Optional output .npy path for just the lengths array.",
)
args = parser.parse_args()
out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"
cov_list = _default_covariates(args.full_cov)
ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
lengths = ds.get_sequence_lengths()
df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": lengths})
os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
df.to_csv(out_csv, index=False)
if args.out_npy:
os.makedirs(os.path.dirname(args.out_npy) or ".", exist_ok=True)
np.save(args.out_npy, np.asarray(lengths, dtype=np.int32))
arr = np.asarray(lengths, dtype=np.int64)
print(f"Wrote: {out_csv}")
print(
"Summary: "
f"n={arr.size}, min={arr.min()}, p50={int(np.median(arr))}, mean={arr.mean():.2f}, max={arr.max()}"
)
# Plot histogram
plt.figure(figsize=(8, 5))
plt.hist(arr, bins=50, color="#348abd", edgecolor="black", alpha=0.8)
plt.xlabel("Sequence length (including DOA)")
plt.ylabel("Number of patients")
plt.title("Histogram of Patient Sequence Lengths")
plt.tight_layout()
plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
plt.savefig(plot_path)
print(f"Histogram saved to: {plot_path}")
plt.show()
if __name__ == "__main__":
main()