Files
DeepHealth/embed_icd10.py

244 lines
7.6 KiB
Python
Raw Normal View History

import argparse
import csv
import os
import re
from dataclasses import dataclass
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer
@dataclass(frozen=True)
class Icd10Label:
code: str
disease: str
_LABEL_RE = re.compile(r"^\s*([A-Z][0-9][0-9][A-Z0-9]{0,2})\s*\((.+)\)\s*$")
_CODE_RE = re.compile(r"^[A-Z][A-Z0-9]{1,6}$")
def _read_labels(labels_path: str, *, strict_codes: bool) -> list[Icd10Label]:
labels: list[Icd10Label] = []
with open(labels_path, "r", encoding="utf-8") as f:
for raw_line in f:
line = raw_line.strip()
if not line:
continue
match = _LABEL_RE.match(line)
if match is not None:
code, disease = match.group(1), match.group(2)
else:
parts = line.split(maxsplit=1)
if len(parts) == 1:
# Some label lists include non-ICD entries (e.g., "Death").
# Treat these as both code and disease.
code = parts[0].strip()
disease = code
elif len(parts) == 2:
code, disease = parts[0].strip(), parts[1].strip()
else:
raise ValueError(
f"Unrecognized label format: {line!r}. "
"Expected like 'A00 (cholera)', 'CXX Unknown Cancer', or 'Death'."
)
if disease.startswith("(") and disease.endswith(")"):
disease = disease[1:-1].strip()
if strict_codes and not _CODE_RE.match(code):
raise ValueError(
f"Unrecognized ICD10-like code in label: {line!r} (code={code!r}). "
"Re-run without --strict-codes to allow non-ICD labels (e.g., 'Death')."
)
labels.append(Icd10Label(code=code, disease=disease))
if not labels:
raise ValueError(f"No labels found in {labels_path!r}.")
return labels
def embed_texts(
texts: list[str],
*,
model_name: str,
batch_size: int,
max_length: int,
device: str,
) -> np.ndarray:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()
model.to(device)
all_embs: list[np.ndarray] = []
with torch.no_grad():
for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
batch = texts[i: i + batch_size]
toks = tokenizer(
batch,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
toks = {k: v.to(device) for k, v in toks.items()}
# Use CLS token embedding (same as original script).
cls_rep = model(**toks)[0][:, 0, :]
all_embs.append(cls_rep.detach().cpu().to(torch.float32).numpy())
return np.concatenate(all_embs, axis=0)
def save_umap_plot(
embeddings: np.ndarray,
codes: list[str],
*,
out_path: str,
random_state: int = 42,
) -> None:
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
except ImportError as e: # pragma: no cover
raise ImportError(
"UMAP visualization requires matplotlib. Install it with: pip install matplotlib"
) from e
try:
import umap
except ImportError as e: # pragma: no cover
raise ImportError(
"UMAP visualization requires umap-learn. Install it with: pip install umap-learn"
) from e
reducer = umap.UMAP(n_components=2, metric="cosine",
random_state=random_state)
coords = reducer.fit_transform(embeddings)
if len(codes) != coords.shape[0]:
raise ValueError(
f"codes length ({len(codes)}) does not match embeddings rows ({coords.shape[0]})."
)
groups: list[str] = []
for code in codes:
cleaned = code.strip()
if cleaned.lower() == "death":
groups.append("Death")
else:
groups.append(cleaned[:1].upper() if cleaned else "?")
group_names = sorted({g for g in groups if g != "Death"})
cmap = plt.get_cmap("tab20")
group_to_color: dict[str, object] = {
g: cmap(i % cmap.N) for i, g in enumerate(group_names)
}
group_to_color["Death"] = "grey"
colors = [group_to_color.get(g, "black") for g in groups]
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(1, 1, 1)
ax.scatter(coords[:, 0], coords[:, 1], s=6, alpha=0.7, c=colors)
ax.set_title("UMAP of ICD label embeddings")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
fig.tight_layout()
fig.savefig(out_path, dpi=200)
plt.close(fig)
def main() -> int:
parser = argparse.ArgumentParser(
description="Embed ICD-10 disease labels with SapBERT")
parser.add_argument(
"--labels",
default="labels.csv",
help="Path to labels.csv (lines like 'A00 (cholera)')",
)
parser.add_argument(
"--out-dir",
default=".",
help="Output directory for embeddings and metadata",
)
parser.add_argument(
"--model",
default="cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
help="HuggingFace model name",
)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--max-length", type=int, default=25)
parser.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run on (cuda or cpu)",
)
parser.add_argument(
"--strict-codes",
action="store_true",
help="Fail if a label code is not ICD10-like (disallows labels like 'Death')",
)
parser.add_argument(
"--umap",
action="store_true",
help="Also save a 2D UMAP scatterplot of the embeddings",
)
parser.add_argument(
"--umap-out",
default=None,
help="Path to save UMAP PNG (default: <out-dir>/icd10_sapbert_umap.png)",
)
parser.add_argument(
"--umap-random-state",
type=int,
default=42,
help="Random seed for UMAP",
)
args = parser.parse_args()
labels = _read_labels(args.labels, strict_codes=args.strict_codes)
texts = [lbl.disease for lbl in labels]
embs = embed_texts(
texts,
model_name=args.model,
batch_size=args.batch_size,
max_length=args.max_length,
device=args.device,
)
os.makedirs(args.out_dir, exist_ok=True)
embs_path = os.path.join(args.out_dir, "icd10_sapbert_embeddings.npy")
meta_path = os.path.join(args.out_dir, "icd10_sapbert_metadata.tsv")
np.save(embs_path, embs)
with open(meta_path, "w", encoding="utf-8", newline="") as f:
w = csv.writer(f, delimiter="\t")
w.writerow(["index", "icd10_code", "disease"])
for i, lbl in enumerate(labels):
w.writerow([i, lbl.code, lbl.disease])
if args.umap:
umap_path = (
args.umap_out
if args.umap_out is not None
else os.path.join(args.out_dir, "icd10_sapbert_umap.png")
)
save_umap_plot(
embs,
[lbl.code for lbl in labels],
out_path=umap_path,
random_state=args.umap_random_state,
)
print(f"Saved UMAP plot: {umap_path}")
print(f"Saved embeddings: {embs_path} (shape={embs.shape})")
print(f"Saved metadata: {meta_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())