From 29913106cb30fc1883f2f583f21c99c7f2487c61 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Mon, 19 Jan 2026 14:39:13 +0800
Subject: [PATCH] Add get_sequence_lengths method to HealthDataset and create extract_sequence_lengths script for per-patient sequence length extraction

---
Review notes (below the three-dash separator, so not part of the commit
message and not part of the diff):

* Layout: this patch had been collapsed onto two physical lines. The line
  structure and Python indentation below are reconstructed; the hunk sizes
  match the recorded headers (6 context + 11 added lines for dataset.py,
  85 added lines for the new script). No payload token was changed.
* dataset.py: get_sequence_lengths() is annotated with List[int]. The hunk
  does not show the imports of dataset.py -- confirm that typing.List is
  imported there.
* dataset.py: a patient id missing from patient_events is reported with
  length 1, because of the .get(pid, ()) fallback plus the +1 for the DOA
  event described in the docstring.
* extract_sequence_lengths.py: the --out_csv help says
  "Default: _sequence_lengths.csv", while the code defaults to
  f"{args.data_prefix}_sequence_lengths.csv". A prefix placeholder appears
  to be missing from that help text, and presumably also from the file
  names listed in the --data_prefix help -- confirm against the author's
  copy.
* extract_sequence_lengths.py: with an empty dataset, arr.min() and
  arr.max() raise ValueError on a zero-size array, after the CSV has
  already been written. Consider guarding the summary and the histogram.
* extract_sequence_lengths.py: the .npy output is stored as int32 while the
  summary statistics use an int64 copy of the same lengths.
* extract_sequence_lengths.py: plt.show() runs unconditionally after
  savefig(); NOTE(review): this may block unattended or batch runs until
  the window is closed -- confirm the intended usage.
* The From: header carries no e-mail address here; NOTE(review): git-am may
  reject the message without one -- confirm against the original mail.

 dataset.py                  | 11 +++++
 extract_sequence_lengths.py | 85 +++++++++++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+)
 create mode 100644 extract_sequence_lengths.py

diff --git a/dataset.py b/dataset.py
index 6cecc01..106cd92 100644
--- a/dataset.py
+++ b/dataset.py
@@ -88,6 +88,17 @@ class HealthDataset(Dataset):
     def __len__(self) -> int:
         return len(self.patient_ids)
 
+    def get_sequence_lengths(self) -> List[int]:
+        """Return the sequence length for each patient.
+
+        Lengths correspond to what :meth:`__getitem__` returns: the number of
+        patient events plus the inserted DOA event.
+
+        Returns:
+            List[int]: Sequence lengths aligned with dataset indices.
+        """
+        return [len(self.patient_events.get(pid, ())) + 1 for pid in self.patient_ids]
+
     def __getitem__(self, idx):
         if self._cache_event_tensors:
             cached_e = self._cached_event_tensors[idx]
diff --git a/extract_sequence_lengths.py b/extract_sequence_lengths.py
new file mode 100644
index 0000000..6d0ae48
--- /dev/null
+++ b/extract_sequence_lengths.py
@@ -0,0 +1,85 @@
+import argparse
+import os
+from typing import List, Optional
+
+import numpy as np
+import pandas as pd
+
+import matplotlib.pyplot as plt
+
+from dataset import HealthDataset
+
+
+def _default_covariates(full_cov: bool) -> Optional[List[str]]:
+    # Mirrors train.py default behavior.
+    if full_cov:
+        return None
+    return ["bmi", "smoking", "alcohol"]
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Extract per-patient sequence lengths from the UKB dataset."
+    )
+    parser.add_argument(
+        "--data_prefix",
+        type=str,
+        default="ukb",
+        help="Prefix for dataset files (expects _basic_info.csv, _table.csv, _event_data.npy).",
+    )
+    parser.add_argument(
+        "--full_cov",
+        action="store_true",
+        help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
+    )
+    parser.add_argument(
+        "--out_csv",
+        type=str,
+        default=None,
+        help="Output CSV path. Default: _sequence_lengths.csv",
+    )
+    parser.add_argument(
+        "--out_npy",
+        type=str,
+        default=None,
+        help="Optional output .npy path for just the lengths array.",
+    )
+    args = parser.parse_args()
+
+    out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"
+
+    cov_list = _default_covariates(args.full_cov)
+    ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
+
+    lengths = ds.get_sequence_lengths()
+
+    df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": lengths})
+    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
+    df.to_csv(out_csv, index=False)
+
+    if args.out_npy:
+        os.makedirs(os.path.dirname(args.out_npy) or ".", exist_ok=True)
+        np.save(args.out_npy, np.asarray(lengths, dtype=np.int32))
+
+    arr = np.asarray(lengths, dtype=np.int64)
+    print(f"Wrote: {out_csv}")
+    print(
+        "Summary: "
+        f"n={arr.size}, min={arr.min()}, p50={int(np.median(arr))}, mean={arr.mean():.2f}, max={arr.max()}"
+    )
+
+    # Plot histogram
+    plt.figure(figsize=(8, 5))
+    plt.hist(arr, bins=50, color="#348abd", edgecolor="black", alpha=0.8)
+    plt.xlabel("Sequence length (including DOA)")
+    plt.ylabel("Number of patients")
+    plt.title("Histogram of Patient Sequence Lengths")
+    plt.tight_layout()
+    plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
+    plt.savefig(plot_path)
+    print(f"Histogram saved to: {plot_path}")
+    plt.show()
+
+
+if __name__ == "__main__":
+    main()