"""Extract per-patient sequence lengths from a HealthDataset.

Writes a CSV of (eid, seq_len) pairs, optionally a .npy array of the lengths,
prints summary statistics, and saves/shows a histogram of the lengths.
"""

import argparse
import os
from typing import List, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from dataset import HealthDataset


def _default_covariates(full_cov: bool) -> Optional[List[str]]:
    """Return the covariate list to hand to ``HealthDataset``.

    Args:
        full_cov: When true, return ``None`` instead of an explicit list.

    Returns:
        ``None`` if ``full_cov`` is set, otherwise ``["bmi", "smoking", "alcohol"]``.
        NOTE(review): ``None`` presumably tells HealthDataset to use every
        covariate -- confirm against dataset.py.
    """
    # Mirrors train.py default behavior.
    if full_cov:
        return None
    return ["bmi", "smoking", "alcohol"]


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Extract per-patient sequence lengths from the UKB dataset."
    )
    parser.add_argument(
        "--data_prefix",
        type=str,
        default="ukb",
        help="Prefix for dataset files (expects _basic_info.csv, _table.csv, _event_data.npy).",
    )
    parser.add_argument(
        "--full_cov",
        action="store_true",
        help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
    )
    parser.add_argument(
        "--out_csv",
        type=str,
        default=None,
        # The effective default is derived from --data_prefix in main().
        help="Output CSV path. Default: <data_prefix>_sequence_lengths.csv",
    )
    parser.add_argument(
        "--out_npy",
        type=str,
        default=None,
        help="Optional output .npy path for just the lengths array.",
    )
    return parser


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain ``path`` if it does not exist."""
    # A bare filename has an empty dirname; fall back to the current directory.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)


def _print_summary(arr: np.ndarray) -> None:
    """Print count, min, median, mean and max of a non-empty lengths array."""
    print(
        "Summary: "
        f"n={arr.size}, min={arr.min()}, p50={int(np.median(arr))}, mean={arr.mean():.2f}, max={arr.max()}"
    )


def _save_histogram(arr: np.ndarray, plot_path: str) -> None:
    """Draw a 50-bin histogram of ``arr`` on a new figure and save it to ``plot_path``."""
    plt.figure(figsize=(8, 5))
    plt.hist(arr, bins=50, color="#348abd", edgecolor="black", alpha=0.8)
    plt.xlabel("Sequence length (including DOA)")
    plt.ylabel("Number of patients")
    plt.title("Histogram of Patient Sequence Lengths")
    plt.tight_layout()
    plt.savefig(plot_path)


def main() -> None:
    """Run the extraction: load the dataset, write outputs, summarize and plot.

    Outputs:
        * CSV with columns ``eid`` and ``seq_len`` (always written).
        * ``.npy`` file with the lengths as int32 (only when ``--out_npy`` is given).
        * ``<csv stem>_hist.png`` histogram, which is also shown interactively
          (skipped when the dataset yields no sequences).
    """
    args = _build_parser().parse_args()
    out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"

    cov_list = _default_covariates(args.full_cov)
    ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
    lengths = ds.get_sequence_lengths()

    df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": lengths})
    _ensure_parent_dir(out_csv)
    df.to_csv(out_csv, index=False)

    if args.out_npy:
        _ensure_parent_dir(args.out_npy)
        np.save(args.out_npy, np.asarray(lengths, dtype=np.int32))

    arr = np.asarray(lengths, dtype=np.int64)
    print(f"Wrote: {out_csv}")

    # min()/max() raise ValueError on a zero-size array, so report the empty
    # case explicitly instead of crashing after the CSV has been written.
    if arr.size == 0:
        print("Summary: n=0 (no sequences; statistics and histogram skipped)")
        return

    _print_summary(arr)

    # Plot histogram next to the CSV, reusing its path without the extension.
    plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
    _save_histogram(arr, plot_path)
    print(f"Histogram saved to: {plot_path}")
    plt.show()


if __name__ == "__main__":
    main()