Add get_sequence_lengths method to HealthDataset and create extract_sequence_lengths script for per-patient sequence length extraction
This commit is contained in:
85
extract_sequence_lengths.py
Normal file
85
extract_sequence_lengths.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from dataset import HealthDataset
|
||||
|
||||
|
||||
def _default_covariates(full_cov: bool) -> Optional[List[str]]:
|
||||
# Mirrors train.py default behavior.
|
||||
if full_cov:
|
||||
return None
|
||||
return ["bmi", "smoking", "alcohol"]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract per-patient sequence lengths from the UKB dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_prefix",
|
||||
type=str,
|
||||
default="ukb",
|
||||
help="Prefix for dataset files (expects <prefix>_basic_info.csv, <prefix>_table.csv, <prefix>_event_data.npy).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full_cov",
|
||||
action="store_true",
|
||||
help="Use full covariates (otherwise uses the training default: bmi/smoking/alcohol).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_csv",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output CSV path. Default: <data_prefix>_sequence_lengths.csv",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out_npy",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional output .npy path for just the lengths array.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
out_csv = args.out_csv or f"{args.data_prefix}_sequence_lengths.csv"
|
||||
|
||||
cov_list = _default_covariates(args.full_cov)
|
||||
ds = HealthDataset(data_prefix=args.data_prefix, covariate_list=cov_list)
|
||||
|
||||
lengths = ds.get_sequence_lengths()
|
||||
|
||||
df = pd.DataFrame({"eid": ds.patient_ids, "seq_len": lengths})
|
||||
os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
|
||||
df.to_csv(out_csv, index=False)
|
||||
|
||||
if args.out_npy:
|
||||
os.makedirs(os.path.dirname(args.out_npy) or ".", exist_ok=True)
|
||||
np.save(args.out_npy, np.asarray(lengths, dtype=np.int32))
|
||||
|
||||
arr = np.asarray(lengths, dtype=np.int64)
|
||||
print(f"Wrote: {out_csv}")
|
||||
print(
|
||||
"Summary: "
|
||||
f"n={arr.size}, min={arr.min()}, p50={int(np.median(arr))}, mean={arr.mean():.2f}, max={arr.max()}"
|
||||
)
|
||||
|
||||
# Plot histogram
|
||||
plt.figure(figsize=(8, 5))
|
||||
plt.hist(arr, bins=50, color="#348abd", edgecolor="black", alpha=0.8)
|
||||
plt.xlabel("Sequence length (including DOA)")
|
||||
plt.ylabel("Number of patients")
|
||||
plt.title("Histogram of Patient Sequence Lengths")
|
||||
plt.tight_layout()
|
||||
plot_path = os.path.splitext(out_csv)[0] + "_hist.png"
|
||||
plt.savefig(plot_path)
|
||||
print(f"Histogram saved to: {plot_path}")
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user