90 lines
2.6 KiB
Python
90 lines
2.6 KiB
Python
import argparse
|
|
import os
|
|
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
from dataset import HealthDataset
|
|
|
|
|
|
def _default_covariates(full_cov: bool) -> Optional[List[str]]:
|
|
# Mirrors train.py default behavior.
|
|
if full_cov:
|
|
return None
|
|
return ["bmi", "smoking", "alcohol"]
|
|
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Extract per-patient sequence lengths from the UKB dataset."
    )
    parser.add_argument(
        "--data_prefix",
        type=str,
        default="ukb",
        help="Prefix for dataset files (expects <prefix>_basic_info.csv, <prefix>_table.csv, <prefix>_event_data.npy).",
    )
    parser.add_argument(
        "--full_cov",
        action="store_true",
        help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
    )
    parser.add_argument(
        "--out_csv",
        type=str,
        default=None,
        help="Output CSV path. Default: <data_prefix>_sequence_lengths.csv",
    )
    parser.add_argument(
        "--out_npy",
        type=str,
        default=None,
        help="Optional output .npy path for just the lengths array.",
    )
    # New, opt-in flag: default behavior (show the plot window) is unchanged.
    parser.add_argument(
        "--no_show",
        action="store_true",
        help="Save the histogram without opening an interactive window (e.g. on headless machines).",
    )
    return parser.parse_args()


def _print_summary(arr: np.ndarray) -> None:
    """Print count, min/max/mean and selected percentiles of *arr*.

    An empty *arr* only reports ``n=0``: ``min``/``max``/``percentile`` raise
    on zero-size arrays.
    """
    percentiles = [5, 10, 25, 50, 75, 90, 95, 99]
    print("Summary:")
    print(f" n={arr.size}")
    if arr.size == 0:
        print(" (no patients; skipping statistics)")
        return
    print(f" min={arr.min()} max={arr.max()} mean={arr.mean():.2f}")
    for p, v in zip(percentiles, np.percentile(arr, percentiles)):
        # np.percentile interpolates, so values may be fractional; show whole
        # values as integers but do not truncate fractional ones.
        shown = str(int(v)) if float(v).is_integer() else f"{v:.2f}"
        print(f" p{p:02d}={shown}")


def _plot_histogram(
    arr: np.ndarray, plot_path: str, show: bool = True, max_bins: int = 50
) -> None:
    """Save a histogram of the non-empty integer array *arr* to *plot_path*.

    When the value range spans at most *max_bins* integers, one bin is centred
    on each integer so bars do not straddle or skip values. Otherwise
    *max_bins* equal-width bins are used. The figure is always closed
    afterwards.
    """
    lo, hi = int(arr.min()), int(arr.max())
    bins = np.arange(lo, hi + 2) - 0.5 if hi - lo + 1 <= max_bins else max_bins

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(arr, bins=bins, color="#348abd", edgecolor="black", alpha=0.8)
    ax.set_xlabel("Sequence length (including DOA)")
    ax.set_ylabel("Number of patients")
    ax.set_title("Histogram of Patient Sequence Lengths")
    fig.tight_layout()
    fig.savefig(plot_path)
    print(f"Histogram saved to: {plot_path}")
    if show:
        plt.show()
    plt.close(fig)


def main() -> None:
    """Extract per-patient sequence lengths and report on them.

    Writes an ``eid``/``seq_len`` CSV (and optionally a ``.npy`` of the
    lengths), prints summary statistics, and saves a histogram PNG next to
    the CSV as ``<csv stem>_hist.png``.
    """
    args = _parse_args()

    out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"

    cov_list = _default_covariates(args.full_cov)
    ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)

    lengths = ds.get_sequence_lengths()

    # NOTE(review): assumes ds.patient_ids is aligned 1:1 with lengths;
    # pandas raises if the two differ in length.
    df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": lengths})
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    df.to_csv(out_csv, index=False)

    if args.out_npy:
        os.makedirs(os.path.dirname(args.out_npy) or ".", exist_ok=True)
        np.save(args.out_npy, np.asarray(lengths, dtype=np.int32))

    arr = np.asarray(lengths, dtype=np.int64)
    print(f"Wrote: {out_csv}")

    _print_summary(arr)

    if arr.size == 0:
        print("No sequence lengths; skipping histogram.")
        return

    plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
    _plot_histogram(arr, plot_path, show=not args.no_show)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|