# File: DeepHealth/extract_sequence_lengths.py
# (90 lines, 2.6 KiB, Python)

import argparse
import os
from typing import List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataset import HealthDataset
def _default_covariates(full_cov: bool) -> Optional[List[str]]:
# Mirrors train.py default behavior.
if full_cov:
return None
return ["bmi", "smoking", "alcohol"]
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Extract per-patient sequence lengths from the UKB dataset."
    )
    parser.add_argument(
        "--data_prefix",
        type=str,
        default="ukb",
        help="Prefix for dataset files (expects <prefix>_basic_info.csv, <prefix>_table.csv, <prefix>_event_data.npy).",
    )
    parser.add_argument(
        "--full_cov",
        action="store_true",
        help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
    )
    parser.add_argument(
        "--out_csv",
        type=str,
        default=None,
        help="Output CSV path. Default: <data_prefix>_sequence_lengths.csv",
    )
    parser.add_argument(
        "--out_npy",
        type=str,
        default=None,
        help="Optional output .npy path for just the lengths array.",
    )
    return parser


def _print_distribution(arr: np.ndarray) -> None:
    """Print min/max/mean and selected percentiles of a non-empty array."""
    percentiles = [5, 10, 25, 50, 75, 90, 95, 99]
    pct_values = np.percentile(arr, percentiles)
    print(f" min={arr.min()} max={arr.max()} mean={arr.mean():.2f}")
    for p, v in zip(percentiles, pct_values):
        # Percentiles are interpolated floats; they are truncated for display.
        print(f" p{p:02d}={int(v)}")


def main(argv: Optional[List[str]] = None) -> None:
    """Extract sequence lengths, write them to disk, and report a summary.

    Steps: parse the command line, build a ``HealthDataset``, write an
    ``eid``/``seq_len`` CSV (and optionally a ``.npy`` array of the lengths),
    print summary statistics, then save and show a histogram next to the CSV.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.
    """
    args = _build_parser().parse_args(argv)
    out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"

    cov_list = _default_covariates(args.full_cov)
    ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
    lengths = ds.get_sequence_lengths()

    df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": lengths})
    # dirname is "" for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    df.to_csv(out_csv, index=False)

    if args.out_npy:
        os.makedirs(os.path.dirname(args.out_npy) or ".", exist_ok=True)
        np.save(args.out_npy, np.asarray(lengths, dtype=np.int32))

    arr = np.asarray(lengths, dtype=np.int64)
    print(f"Wrote: {out_csv}")

    print("Summary:")
    print(f" n={arr.size}")
    if arr.size == 0:
        # min()/max() raise ValueError on a zero-size array, so stop here
        # instead of crashing after the outputs have already been written.
        print(" (no patients; skipping statistics and histogram)")
        return
    _print_distribution(arr)

    # Plot histogram
    fig = plt.figure(figsize=(8, 5))
    plt.hist(arr, bins=50, color="#348abd", edgecolor="black", alpha=0.8)
    plt.xlabel("Sequence length (including DOA)")
    plt.ylabel("Number of patients")
    plt.title("Histogram of Patient Sequence Lengths")
    plt.tight_layout()

    # The image is written beside the CSV, reusing its base name.
    plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
    plt.savefig(plot_path)
    print(f"Histogram saved to: {plot_path}")
    plt.show()
    # Release the figure explicitly so it does not linger in pyplot's
    # global figure registry after the window is dismissed.
    plt.close(fig)
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()