Add get_sequence_lengths method to HealthDataset and create extract_sequence_lengths script for per-patient sequence length extraction
This commit is contained in:
11
dataset.py
11
dataset.py
@@ -88,6 +88,17 @@ class HealthDataset(Dataset):
|
|||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.patient_ids)
|
return len(self.patient_ids)
|
||||||
|
|
||||||
|
def get_sequence_lengths(self) -> List[int]:
|
||||||
|
"""Return the sequence length for each patient.
|
||||||
|
|
||||||
|
Lengths correspond to what :meth:`__getitem__` returns: the number of
|
||||||
|
patient events plus the inserted DOA event.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[int]: Sequence lengths aligned with dataset indices.
|
||||||
|
"""
|
||||||
|
return [len(self.patient_events.get(pid, ())) + 1 for pid in self.patient_ids]
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
if self._cache_event_tensors:
|
if self._cache_event_tensors:
|
||||||
cached_e = self._cached_event_tensors[idx]
|
cached_e = self._cached_event_tensors[idx]
|
||||||
|
|||||||
85
extract_sequence_lengths.py
Normal file
85
extract_sequence_lengths.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from dataset import HealthDataset
|
||||||
|
|
||||||
|
|
||||||
|
def _default_covariates(full_cov: bool) -> Optional[List[str]]:
|
||||||
|
# Mirrors train.py default behavior.
|
||||||
|
if full_cov:
|
||||||
|
return None
|
||||||
|
return ["bmi", "smoking", "alcohol"]
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Extract per-patient sequence lengths from the UKB dataset."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_prefix",
|
||||||
|
type=str,
|
||||||
|
default="ukb",
|
||||||
|
help="Prefix for dataset files (expects <prefix>_basic_info.csv, <prefix>_table.csv, <prefix>_event_data.npy).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--full_cov",
|
||||||
|
action="store_true",
|
||||||
|
help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out_csv",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Output CSV path. Default: <data_prefix>_sequence_lengths.csv",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out_npy",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Optional output .npy path for just the lengths array.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"
|
||||||
|
|
||||||
|
cov_list = _default_covariates(args.full_cov)
|
||||||
|
ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
|
||||||
|
|
||||||
|
lengths = ds.get_sequence_lengths()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": lengths})
|
||||||
|
os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
|
||||||
|
df.to_csv(out_csv, index=False)
|
||||||
|
|
||||||
|
if args.out_npy:
|
||||||
|
os.makedirs(os.path.dirname(args.out_npy) or ".", exist_ok=True)
|
||||||
|
np.save(args.out_npy, np.asarray(lengths, dtype=np.int32))
|
||||||
|
|
||||||
|
arr = np.asarray(lengths, dtype=np.int64)
|
||||||
|
print(f"Wrote: {out_csv}")
|
||||||
|
print(
|
||||||
|
"Summary: "
|
||||||
|
f"n={arr.size}, min={arr.min()}, p50={int(np.median(arr))}, mean={arr.mean():.2f}, max={arr.max()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot histogram
|
||||||
|
plt.figure(figsize=(8, 5))
|
||||||
|
plt.hist(arr, bins=50, color="#348abd", edgecolor="black", alpha=0.8)
|
||||||
|
plt.xlabel("Sequence length (including DOA)")
|
||||||
|
plt.ylabel("Number of patients")
|
||||||
|
plt.title("Histogram of Patient Sequence Lengths")
|
||||||
|
plt.tight_layout()
|
||||||
|
plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
|
||||||
|
plt.savefig(plot_path)
|
||||||
|
print(f"Histogram saved to: {plot_path}")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user