Files
DeepHealth/extract_sequence_lengths.py

90 lines
2.6 KiB
Python
Raw Normal View History

import argparse
import os
from typing import List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataset import HealthDataset
def _default_covariates(full_cov: bool) -> Optional[List[str]]:
# Mirrors train.py default behavior.
if full_cov:
return None
return ["bmi", "smoking", "alcohol"]
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Extract per-patient sequence lengths from the UKB dataset."
    )
    parser.add_argument(
        "--data_prefix",
        type=str,
        default="ukb",
        help="Prefix for dataset files (expects <prefix>_basic_info.csv, <prefix>_table.csv, <prefix>_event_data.npy).",
    )
    parser.add_argument(
        "--full_cov",
        action="store_true",
        help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
    )
    parser.add_argument(
        "--out_csv",
        type=str,
        default=None,
        help="Output CSV path. Default: <data_prefix>_sequence_lengths.csv",
    )
    parser.add_argument(
        "--out_npy",
        type=str,
        default=None,
        help="Optional output .npy path for just the lengths array.",
    )
    # Opt-out rather than opt-in so the default behaviour (showing the plot)
    # is unchanged for existing users; batch/headless runs pass --no_show.
    parser.add_argument(
        "--no_show",
        action="store_true",
        help="Do not open an interactive window for the histogram (it is still saved to disk).",
    )
    return parser.parse_args()


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of ``path`` if it does not exist yet."""
    # dirname() is "" for a bare filename; fall back to the CWD in that case.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)


def _print_summary(arr: np.ndarray) -> None:
    """Print count, min/max/mean and selected percentiles of a non-empty array."""
    percentiles = [5, 10, 25, 50, 75, 90, 95, 99]
    pct_values = np.percentile(arr, percentiles)
    print("Summary:")
    print(f" n={arr.size}")
    print(f" min={arr.min()} max={arr.max()} mean={arr.mean():.2f}")
    for p, v in zip(percentiles, pct_values):
        # np.percentile interpolates between samples, so a value may be
        # fractional; show it exactly instead of truncating (12.5 -> 12).
        value = float(v)
        label = str(int(value)) if value.is_integer() else f"{value:.2f}"
        print(f" p{p:02d}={label}")


def _plot_histogram(arr: np.ndarray, plot_path: str, show: bool) -> None:
    """Save a histogram of ``arr`` to ``plot_path`` and optionally display it."""
    fig = plt.figure(figsize=(8, 5))
    plt.hist(arr, bins=50, color="#348abd", edgecolor="black", alpha=0.8)
    plt.xlabel("Sequence length (including DOA)")
    plt.ylabel("Number of patients")
    plt.title("Histogram of Patient Sequence Lengths")
    plt.tight_layout()
    plt.savefig(plot_path)
    print(f"Histogram saved to: {plot_path}")
    if show:
        # Blocks until the window is closed on interactive backends.
        plt.show()
    # Release the figure explicitly; pyplot otherwise keeps it alive.
    plt.close(fig)


def main() -> None:
    """Extract per-patient sequence lengths and report on them.

    Loads ``HealthDataset`` for ``--data_prefix`` and writes a CSV with
    ``eid``/``seq_len`` columns (default ``<data_prefix>_sequence_lengths.csv``).
    Optionally writes the raw lengths as an int32 ``.npy`` file. It then prints
    summary statistics and saves a histogram next to the CSV as
    ``<csv stem>_hist.png``, showing it interactively unless ``--no_show`` is
    given. An empty dataset still produces the (empty) outputs, but the
    statistics and the plot are skipped instead of crashing.
    """
    args = _parse_args()
    out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"

    cov_list = _default_covariates(args.full_cov)
    ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
    lengths = ds.get_sequence_lengths()
    # Convert once; every later consumer works from this array.
    arr = np.asarray(lengths, dtype=np.int64)

    df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": arr})
    _ensure_parent_dir(out_csv)
    df.to_csv(out_csv, index=False)

    if args.out_npy:
        _ensure_parent_dir(args.out_npy)
        # int32 keeps the on-disk format identical to earlier versions.
        np.save(args.out_npy, arr.astype(np.int32))

    print(f"Wrote: {out_csv}")

    if arr.size == 0:
        # min/max/percentile raise on zero-size arrays; report and stop.
        print("Summary:")
        print(" n=0")
        print("No sequences found; skipping statistics and histogram.")
        return

    _print_summary(arr)
    plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
    _plot_histogram(arr, plot_path, show=not args.no_show)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()