Add loss functions and model architecture for time-to-event prediction

- Implemented ExponentialNLLLoss and WeibullNLLLoss in losses.py for negative log-likelihood calculations.
- Developed TabularEncoder class in model.py for encoding tabular features.
- Created DelphiFork and SapDelphi classes in model.py for time-to-event prediction using transformer architecture.
- Added data preparation scripts in prepare_data.R and prepare_data.py for processing UK Biobank data, including handling field mappings and event data extraction.
This commit is contained in:
2026-01-07 21:32:00 +08:00
parent 5d1d79b908
commit 6984b254b3
12 changed files with 5098 additions and 0 deletions

243
embed_icd10.py Normal file
View File

@@ -0,0 +1,243 @@
import argparse
import csv
import os
import re
from dataclasses import dataclass
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer
@dataclass(frozen=True)
class Icd10Label:
code: str
disease: str
_LABEL_RE = re.compile(r"^\s*([A-Z][0-9][0-9][A-Z0-9]{0,2})\s*\((.+)\)\s*$")
_CODE_RE = re.compile(r"^[A-Z][A-Z0-9]{1,6}$")
def _read_labels(labels_path: str, *, strict_codes: bool) -> list[Icd10Label]:
labels: list[Icd10Label] = []
with open(labels_path, "r", encoding="utf-8") as f:
for raw_line in f:
line = raw_line.strip()
if not line:
continue
match = _LABEL_RE.match(line)
if match is not None:
code, disease = match.group(1), match.group(2)
else:
parts = line.split(maxsplit=1)
if len(parts) == 1:
# Some label lists include non-ICD entries (e.g., "Death").
# Treat these as both code and disease.
code = parts[0].strip()
disease = code
elif len(parts) == 2:
code, disease = parts[0].strip(), parts[1].strip()
else:
raise ValueError(
f"Unrecognized label format: {line!r}. "
"Expected like 'A00 (cholera)', 'CXX Unknown Cancer', or 'Death'."
)
if disease.startswith("(") and disease.endswith(")"):
disease = disease[1:-1].strip()
if strict_codes and not _CODE_RE.match(code):
raise ValueError(
f"Unrecognized ICD10-like code in label: {line!r} (code={code!r}). "
"Re-run without --strict-codes to allow non-ICD labels (e.g., 'Death')."
)
labels.append(Icd10Label(code=code, disease=disease))
if not labels:
raise ValueError(f"No labels found in {labels_path!r}.")
return labels
def embed_texts(
texts: list[str],
*,
model_name: str,
batch_size: int,
max_length: int,
device: str,
) -> np.ndarray:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()
model.to(device)
all_embs: list[np.ndarray] = []
with torch.no_grad():
for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
batch = texts[i: i + batch_size]
toks = tokenizer(
batch,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
toks = {k: v.to(device) for k, v in toks.items()}
# Use CLS token embedding (same as original script).
cls_rep = model(**toks)[0][:, 0, :]
all_embs.append(cls_rep.detach().cpu().to(torch.float32).numpy())
return np.concatenate(all_embs, axis=0)
def save_umap_plot(
embeddings: np.ndarray,
codes: list[str],
*,
out_path: str,
random_state: int = 42,
) -> None:
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
except ImportError as e: # pragma: no cover
raise ImportError(
"UMAP visualization requires matplotlib. Install it with: pip install matplotlib"
) from e
try:
import umap
except ImportError as e: # pragma: no cover
raise ImportError(
"UMAP visualization requires umap-learn. Install it with: pip install umap-learn"
) from e
reducer = umap.UMAP(n_components=2, metric="cosine",
random_state=random_state)
coords = reducer.fit_transform(embeddings)
if len(codes) != coords.shape[0]:
raise ValueError(
f"codes length ({len(codes)}) does not match embeddings rows ({coords.shape[0]})."
)
groups: list[str] = []
for code in codes:
cleaned = code.strip()
if cleaned.lower() == "death":
groups.append("Death")
else:
groups.append(cleaned[:1].upper() if cleaned else "?")
group_names = sorted({g for g in groups if g != "Death"})
cmap = plt.get_cmap("tab20")
group_to_color: dict[str, object] = {
g: cmap(i % cmap.N) for i, g in enumerate(group_names)
}
group_to_color["Death"] = "grey"
colors = [group_to_color.get(g, "black") for g in groups]
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(1, 1, 1)
ax.scatter(coords[:, 0], coords[:, 1], s=6, alpha=0.7, c=colors)
ax.set_title("UMAP of ICD label embeddings")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
fig.tight_layout()
fig.savefig(out_path, dpi=200)
plt.close(fig)
def main() -> int:
parser = argparse.ArgumentParser(
description="Embed ICD-10 disease labels with SapBERT")
parser.add_argument(
"--labels",
default="labels.csv",
help="Path to labels.csv (lines like 'A00 (cholera)')",
)
parser.add_argument(
"--out-dir",
default=".",
help="Output directory for embeddings and metadata",
)
parser.add_argument(
"--model",
default="cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
help="HuggingFace model name",
)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--max-length", type=int, default=25)
parser.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run on (cuda or cpu)",
)
parser.add_argument(
"--strict-codes",
action="store_true",
help="Fail if a label code is not ICD10-like (disallows labels like 'Death')",
)
parser.add_argument(
"--umap",
action="store_true",
help="Also save a 2D UMAP scatterplot of the embeddings",
)
parser.add_argument(
"--umap-out",
default=None,
help="Path to save UMAP PNG (default: <out-dir>/icd10_sapbert_umap.png)",
)
parser.add_argument(
"--umap-random-state",
type=int,
default=42,
help="Random seed for UMAP",
)
args = parser.parse_args()
labels = _read_labels(args.labels, strict_codes=args.strict_codes)
texts = [lbl.disease for lbl in labels]
embs = embed_texts(
texts,
model_name=args.model,
batch_size=args.batch_size,
max_length=args.max_length,
device=args.device,
)
os.makedirs(args.out_dir, exist_ok=True)
embs_path = os.path.join(args.out_dir, "icd10_sapbert_embeddings.npy")
meta_path = os.path.join(args.out_dir, "icd10_sapbert_metadata.tsv")
np.save(embs_path, embs)
with open(meta_path, "w", encoding="utf-8", newline="") as f:
w = csv.writer(f, delimiter="\t")
w.writerow(["index", "icd10_code", "disease"])
for i, lbl in enumerate(labels):
w.writerow([i, lbl.code, lbl.disease])
if args.umap:
umap_path = (
args.umap_out
if args.umap_out is not None
else os.path.join(args.out_dir, "icd10_sapbert_umap.png")
)
save_umap_plot(
embs,
[lbl.code for lbl in labels],
out_path=umap_path,
random_state=args.umap_random_state,
)
print(f"Saved UMAP plot: {umap_path}")
print(f"Saved embeddings: {embs_path} (shape={embs.shape})")
print(f"Saved metadata: {meta_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())