Add loss functions and model architecture for time-to-event prediction
- Implemented ExponentialNLLLoss and WeibullNLLLoss in losses.py for negative log-likelihood calculations.
- Developed the TabularEncoder class in model.py for encoding tabular features.
- Created the DelphiFork and SapDelphi classes in model.py for time-to-event prediction using a transformer architecture.
- Added data preparation scripts prepare_data.R and prepare_data.py for processing UK Biobank data, including field-ID mapping and event-data extraction.
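For orientation, a minimal sketch of how these pieces are meant to compose into one training step. All hyperparameters are hypothetical, the "ukb" prefix assumes the files written by prepare_data.py, and the target shift assumes disease token ids start at n_tech_tokens = 2 (pad = 0, assessment = 1), as in dataset.py:

import torch
from torch.utils.data import DataLoader

from dataset import HealthDataset, health_collate_fn
from losses import ExponentialNLLLoss, get_valid_pairs_and_dt
from model import DelphiFork

ds = HealthDataset(data_prefix="ukb")
dl = DataLoader(ds, batch_size=32, shuffle=True, collate_fn=health_collate_fn)
model = DelphiFork(
    n_disease=ds.n_disease, n_tech_tokens=2, n_embd=128, n_head=4, n_layer=4,
    n_cont=ds.n_cont, n_cate=ds.n_cate, cate_dims=ds.cate_dims,
)
loss_fn = ExponentialNLLLoss(lambda_reg=0.1)

events, times, cont, cate, sex = next(iter(dl))
pairs = get_valid_pairs_and_dt(events, times, n_tech_tokens=2)
if pairs is not None:
    dt, b_prev, t_prev, b_next, t_next = pairs
    theta = model(events, times, sex, cont, cate,
                  b_prev=b_prev, t_prev=t_prev)   # (M, n_disease, 1)
    targets = events[b_next, t_next] - 2          # shift codes into [0, n_disease)
    nll, reg = loss_fn(theta, targets, dt)
    (nll + reg).backward()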
age_encoder.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn


class AgeSinusoidalEncoder(nn.Module):
    """
    Sinusoidal encoder for age.

    Args:
        n_embd (int): Embedding dimension. Must be even.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        if n_embd % 2 != 0:
            raise ValueError("n_embd must be even for sinusoidal encoding.")
        self.n_embd = n_embd
        i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.n_embd)
        self.register_buffer('divisor', divisor)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        t_years = ages / 365.25
        # Broadcast (B, L, 1) against (1, 1, D/2) to get (B, L, D/2)
        args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
        # Interleave cos and sin along the last dimension
        output = torch.zeros(
            ages.shape[0], ages.shape[1], self.n_embd, device=ages.device)
        output[:, :, 0::2] = torch.cos(args)
        output[:, :, 1::2] = torch.sin(args)
        return output


class AgeMLPEncoder(nn.Module):
    """
    MLP encoder for age, using sinusoidal encoding as input.

    Args:
        n_embd (int): Embedding dimension.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        self.sin_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        x = self.sin_encoder(ages)
        output = self.mlp(x)
        return output
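A quick shape check for the encoders above (a sketch; assumes ages are given in days, as elsewhere in this commit):

import torch

ages = torch.full((2, 8), 40 * 365.25)  # (B, L) ages in days, roughly 40 years
encoder = AgeSinusoidalEncoder(n_embd=64)
print(encoder(ages).shape)  # torch.Size([2, 8, 64])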
backbones.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class SelfAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    Args:
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        attn_pdrop (float): Attention dropout probability.
    """

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        attn_pdrop: float = 0.1,
    ):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_pdrop = attn_pdrop

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,  # (B, L, L)
    ) -> torch.Tensor:
        B, L, D = x.shape
        qkv = self.qkv_proj(x)  # (B, L, 3D)
        q, k, v = qkv.chunk(3, dim=-1)

        def reshape_heads(t):
            # (B, H, L, d)
            return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)

        q = reshape_heads(q)
        k = reshape_heads(k)
        v = reshape_heads(v)

        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            # Only apply attention dropout during training
            dropout_p=self.attn_pdrop if self.training else 0.0,
        )  # (B, H, L, d)

        attn = attn.transpose(1, 2).contiguous().view(B, L, D)  # (B, L, D)
        return self.o_proj(attn)


class Block(nn.Module):
    """
    Transformer block consisting of self-attention and MLP.

    Args:
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        pdrop (float): Dropout probability.
    """

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop

        self.norm_1 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(
            m.c_proj(m.act(m.c_fc(x))))
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)

        # MLP
        h = self.norm_2(x)
        h = self.mlpf(h)
        x = x + self.resid_dropout(h)

        return x
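A small usage sketch for Block with a boolean causal mask (for boolean masks, scaled_dot_product_attention treats True as "may attend"):

import torch

block = Block(n_embd=64, n_head=4, pdrop=0.1)
x = torch.randn(2, 8, 64)
causal = torch.tril(torch.ones(8, 8, dtype=torch.bool))  # (L, L), True = keep
print(block(x, attn_mask=causal).shape)  # torch.Size([2, 8, 64])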
dataset.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import List


class HealthDataset(Dataset):
    """
    Dataset for health records.

    Args:
        data_prefix (str): Prefix for data files.
        covariate_list (List[str] | None): List of covariates to include.
    """

    def __init__(
        self,
        data_prefix: str,
        covariate_list: List[str] | None = None,
    ):
        basic_info = pd.read_csv(
            f"{data_prefix}_basic_info.csv", index_col='eid')
        tabular_data = pd.read_csv(f"{data_prefix}_table.csv", index_col='eid')
        event_data = np.load(f"{data_prefix}_event_data.npy")
        patient_events = defaultdict(list)
        vocab_size = 0
        for patient_id, time_in_days, event_code in event_data:
            patient_events[patient_id].append((time_in_days, event_code))
            if event_code > vocab_size:
                vocab_size = event_code
        self.n_disease = vocab_size - 1
        self.basic_info = basic_info.convert_dtypes()
        self.patient_ids = self.basic_info.index.tolist()
        self.patient_events = dict(patient_events)

        tabular_data = tabular_data.convert_dtypes()
        cont_cols = []
        cate_cols = []
        self.cate_dims = []
        if covariate_list is not None:
            tabular_data = tabular_data[covariate_list]
        for col in tabular_data.columns:
            if pd.api.types.is_float_dtype(tabular_data[col]):
                cont_cols.append(col)
            elif pd.api.types.is_integer_dtype(tabular_data[col]):
                series = tabular_data[col]
                unique_vals = series.dropna().unique()
                if len(unique_vals) > 11:
                    cont_cols.append(col)
                else:
                    cate_cols.append(col)
                    self.cate_dims.append(int(series.max()) + 1)

        self.cont_features = tabular_data[cont_cols].to_numpy(
            dtype=np.float32).copy()
        self.cate_features = tabular_data[cate_cols].to_numpy(
            dtype=np.int64).copy()
        self.n_cont = self.cont_features.shape[1]
        self.n_cate = self.cate_features.shape[1]

    def __len__(self) -> int:
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        records = sorted(self.patient_events.get(
            patient_id, []), key=lambda x: x[0])
        event_seq = [item[1] for item in records]
        time_seq = [item[0] for item in records]

        doa = self.basic_info.loc[patient_id, 'date_of_assessment']

        insert_pos = np.searchsorted(time_seq, doa)
        time_seq.insert(insert_pos, doa)
        # assuming 1 is the code for the 'DOA' (date of assessment) event
        event_seq.insert(insert_pos, 1)
        event_tensor = torch.tensor(event_seq, dtype=torch.long)
        time_tensor = torch.tensor(time_seq, dtype=torch.float)
        cont_tensor = torch.tensor(
            self.cont_features[idx, :], dtype=torch.float)
        cate_tensor = torch.tensor(
            self.cate_features[idx, :], dtype=torch.long)
        sex = self.basic_info.loc[patient_id, 'sex']

        return (event_tensor, time_tensor, cont_tensor, cate_tensor, sex)


def health_collate_fn(batch):
    event_seqs, time_seqs, cont_feats, cate_feats, sexes = zip(*batch)
    event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0)
    time_batch = pad_sequence(
        time_seqs, batch_first=True, padding_value=36525.0)
    cont_batch = torch.stack(cont_feats, dim=0)
    cont_batch = cont_batch.unsqueeze(1)  # (B, 1, n_cont)
    cate_batch = torch.stack(cate_feats, dim=0)
    cate_batch = cate_batch.unsqueeze(1)  # (B, 1, n_cate)
    sex_batch = torch.tensor(sexes, dtype=torch.long)
    return event_batch, time_batch, cont_batch, cate_batch, sex_batch
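A usage sketch for the dataset and collate function (assumes the ukb_* files written by prepare_data.py are on disk):

from torch.utils.data import DataLoader

ds = HealthDataset(data_prefix="ukb")
dl = DataLoader(ds, batch_size=16, shuffle=True, collate_fn=health_collate_fn)
events, times, cont, cate, sex = next(iter(dl))
# events/times: (16, L_max); cont: (16, 1, n_cont); cate: (16, 1, n_cate); sex: (16,)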
embed_icd10.py (new file, 243 lines)
@@ -0,0 +1,243 @@
import argparse
import csv
import os
import re
from dataclasses import dataclass

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer


@dataclass(frozen=True)
class Icd10Label:
    code: str
    disease: str


_LABEL_RE = re.compile(r"^\s*([A-Z][0-9][0-9][A-Z0-9]{0,2})\s*\((.+)\)\s*$")
_CODE_RE = re.compile(r"^[A-Z][A-Z0-9]{1,6}$")


def _read_labels(labels_path: str, *, strict_codes: bool) -> list[Icd10Label]:
    labels: list[Icd10Label] = []
    with open(labels_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            match = _LABEL_RE.match(line)
            if match is not None:
                code, disease = match.group(1), match.group(2)
            else:
                parts = line.split(maxsplit=1)
                if len(parts) == 1:
                    # Some label lists include non-ICD entries (e.g., "Death").
                    # Treat these as both code and disease.
                    code = parts[0].strip()
                    disease = code
                elif len(parts) == 2:
                    code, disease = parts[0].strip(), parts[1].strip()
                else:
                    raise ValueError(
                        f"Unrecognized label format: {line!r}. "
                        "Expected like 'A00 (cholera)', 'CXX Unknown Cancer', or 'Death'."
                    )
                if disease.startswith("(") and disease.endswith(")"):
                    disease = disease[1:-1].strip()

            if strict_codes and not _CODE_RE.match(code):
                raise ValueError(
                    f"Unrecognized ICD10-like code in label: {line!r} (code={code!r}). "
                    "Re-run without --strict-codes to allow non-ICD labels (e.g., 'Death')."
                )
            labels.append(Icd10Label(code=code, disease=disease))
    if not labels:
        raise ValueError(f"No labels found in {labels_path!r}.")
    return labels


def embed_texts(
    texts: list[str],
    *,
    model_name: str,
    batch_size: int,
    max_length: int,
    device: str,
) -> np.ndarray:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    model.to(device)

    all_embs: list[np.ndarray] = []
    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
            batch = texts[i: i + batch_size]
            toks = tokenizer(
                batch,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            toks = {k: v.to(device) for k, v in toks.items()}
            # Use CLS token embedding (same as original script).
            cls_rep = model(**toks)[0][:, 0, :]
            all_embs.append(cls_rep.detach().cpu().to(torch.float32).numpy())

    return np.concatenate(all_embs, axis=0)


def save_umap_plot(
    embeddings: np.ndarray,
    codes: list[str],
    *,
    out_path: str,
    random_state: int = 42,
) -> None:
    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "UMAP visualization requires matplotlib. Install it with: pip install matplotlib"
        ) from e

    try:
        import umap
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "UMAP visualization requires umap-learn. Install it with: pip install umap-learn"
        ) from e

    reducer = umap.UMAP(n_components=2, metric="cosine",
                        random_state=random_state)
    coords = reducer.fit_transform(embeddings)

    if len(codes) != coords.shape[0]:
        raise ValueError(
            f"codes length ({len(codes)}) does not match embeddings rows ({coords.shape[0]})."
        )

    groups: list[str] = []
    for code in codes:
        cleaned = code.strip()
        if cleaned.lower() == "death":
            groups.append("Death")
        else:
            groups.append(cleaned[:1].upper() if cleaned else "?")

    group_names = sorted({g for g in groups if g != "Death"})
    cmap = plt.get_cmap("tab20")
    group_to_color: dict[str, object] = {
        g: cmap(i % cmap.N) for i, g in enumerate(group_names)
    }
    group_to_color["Death"] = "grey"
    colors = [group_to_color.get(g, "black") for g in groups]

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(coords[:, 0], coords[:, 1], s=6, alpha=0.7, c=colors)
    ax.set_title("UMAP of ICD label embeddings")
    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Embed ICD-10 disease labels with SapBERT")
    parser.add_argument(
        "--labels",
        default="labels.csv",
        help="Path to labels.csv (lines like 'A00 (cholera)')",
    )
    parser.add_argument(
        "--out-dir",
        default=".",
        help="Output directory for embeddings and metadata",
    )
    parser.add_argument(
        "--model",
        default="cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
        help="HuggingFace model name",
    )
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--max-length", type=int, default=25)
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on (cuda or cpu)",
    )
    parser.add_argument(
        "--strict-codes",
        action="store_true",
        help="Fail if a label code is not ICD10-like (disallows labels like 'Death')",
    )
    parser.add_argument(
        "--umap",
        action="store_true",
        help="Also save a 2D UMAP scatterplot of the embeddings",
    )
    parser.add_argument(
        "--umap-out",
        default=None,
        help="Path to save UMAP PNG (default: <out-dir>/icd10_sapbert_umap.png)",
    )
    parser.add_argument(
        "--umap-random-state",
        type=int,
        default=42,
        help="Random seed for UMAP",
    )
    args = parser.parse_args()

    labels = _read_labels(args.labels, strict_codes=args.strict_codes)
    texts = [lbl.disease for lbl in labels]
    embs = embed_texts(
        texts,
        model_name=args.model,
        batch_size=args.batch_size,
        max_length=args.max_length,
        device=args.device,
    )

    os.makedirs(args.out_dir, exist_ok=True)
    embs_path = os.path.join(args.out_dir, "icd10_sapbert_embeddings.npy")
    meta_path = os.path.join(args.out_dir, "icd10_sapbert_metadata.tsv")

    np.save(embs_path, embs)

    with open(meta_path, "w", encoding="utf-8", newline="") as f:
        w = csv.writer(f, delimiter="\t")
        w.writerow(["index", "icd10_code", "disease"])
        for i, lbl in enumerate(labels):
            w.writerow([i, lbl.code, lbl.disease])

    if args.umap:
        umap_path = (
            args.umap_out
            if args.umap_out is not None
            else os.path.join(args.out_dir, "icd10_sapbert_umap.png")
        )
        save_umap_plot(
            embs,
            [lbl.code for lbl in labels],
            out_path=umap_path,
            random_state=args.umap_random_state,
        )
        print(f"Saved UMAP plot: {umap_path}")

    print(f"Saved embeddings: {embs_path} (shape={embs.shape})")
    print(f"Saved metadata: {meta_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
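The embedding step can also be called directly, bypassing the CLI (a sketch; downloads the model on first use):

embs = embed_texts(
    ["cholera", "typhoid fever"],
    model_name="cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
    batch_size=2,
    max_length=25,
    device="cpu",
)
print(embs.shape)  # (2, 768) for this BERT-base model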
field_id.txt (new file, 1237 lines)
File diff suppressed because it is too large.
field_ids_enriched.csv (new file, 74 lines)
@@ -0,0 +1,74 @@
field_instance,full_name,var_name
31-0.0,Sex,sex
34-0.0,Year of birth,year
48-0.0,Waist circumference,waist_circumference
49-0.0,Hip circumference,hip_circumference
50-0.0,Standing height,standing_height
52-0.0,Month of birth,month
53-0.0,Date of attending assessment centre,date_of_assessment
74-0.0,Fasting time,fasting_time
102-0.0,Pulse rate automated reading,pulse_rate
1239-0.0,Current tobacco smoking,smoking
1558-0.0,Alcohol intake frequency.,alcohol
4079-0.0,Diastolic blood pressure automated reading,dbp
4080-0.0,Systolic blood pressure automated reading,sbp
20150-0.0,Forced expiratory volume in 1-second (FEV1) Best measure,fev1_best
20151-0.0,Forced vital capacity (FVC) Best measure,fvc_best
20258-0.0,FEV1/ FVC ratio Z-score,fev1_fvc_ratio
21001-0.0,Body mass index (BMI),bmi
21003-0.0,Age when attended assessment centre,age_at_assessment
30000-0.0,White blood cell (leukocyte) count,WBC
30010-0.0,Red blood cell (erythrocyte) count,RBC
30020-0.0,Haemoglobin concentration,hemoglobin
30030-0.0,Haematocrit percentage,hematocrit
30040-0.0,Mean corpuscular volume,MCV
30050-0.0,Mean corpuscular haemoglobin,MCH
30060-0.0,Mean corpuscular haemoglobin concentration,MCHC
30080-0.0,Platelet count,Pc
30100-0.0,Mean platelet (thrombocyte) volume,MPV
30120-0.0,Lymphocyte count,LymC
30130-0.0,Monocyte count,MonC
30140-0.0,Neutrophill count,NeuC
30150-0.0,Eosinophill count,EosC
30160-0.0,Basophill count,BasC
30170-0.0,Nucleated red blood cell count,nRBC
30250-0.0,Reticulocyte count,RC
30260-0.0,Mean reticulocyte volume,MRV
30270-0.0,Mean sphered cell volume,MSCV
30280-0.0,Immature reticulocyte fraction,IRF
30300-0.0,High light scatter reticulocyte count,HLSRC
30500-0.0,Microalbumin in urine,MicU
30510-0.0,Creatinine (enzymatic) in urine,CreaU
30520-0.0,Potassium in urine,PotU
30530-0.0,Sodium in urine,SodU
30600-0.0,Albumin,Alb
30610-0.0,Alkaline phosphatase,ALP
30620-0.0,Alanine aminotransferase,Alanine
30630-0.0,Apolipoprotein A,ApoA
30640-0.0,Apolipoprotein B,ApoB
30650-0.0,Aspartate aminotransferase,AA
30660-0.0,Direct bilirubin,DBil
30670-0.0,Urea,Urea
30680-0.0,Calcium,Calcium
30690-0.0,Cholesterol,Cholesterol
30700-0.0,Creatinine,Creatinine
30710-0.0,C-reactive protein,CRP
30720-0.0,Cystatin C,CystatinC
30730-0.0,Gamma glutamyltransferase,GGT
30740-0.0,Glucose,Glu
30750-0.0,Glycated haemoglobin (HbA1c),HbA1c
30760-0.0,HDL cholesterol,HDL
30770-0.0,IGF-1,IGF1
30780-0.0,LDL direct,LDL
30790-0.0,Lipoprotein A,LpA
30800-0.0,Oestradiol,Oestradiol
30810-0.0,Phosphate,Phosphate
30820-0.0,Rheumatoid factor,Rheu
30830-0.0,SHBG,SHBG
30840-0.0,Total bilirubin,TotalBil
30850-0.0,Testosterone,Testosterone
30860-0.0,Total protein,TotalProtein
30870-0.0,Triglycerides,Tri
30880-0.0,Urate,Urate
30890-0.0,Vitamin D,VitaminD
40000-0.0,Date of death,Death
icd10_codes_mod.tsv (new file, 1129 lines)
File diff suppressed because it is too large.
labels.csv (new file, 1257 lines)
File diff suppressed because it is too large.
losses.py (new file, 210 lines)
@@ -0,0 +1,210 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
    event_seqs: torch.Tensor,
    time_seqs: torch.Tensor,
    n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Extract valid event pairs (prev -> next) and compute dt in years.

    Args:
        event_seqs (torch.Tensor): Event sequences.
        time_seqs (torch.Tensor): Time sequences.
        n_tech_tokens (int): Number of technical tokens.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
            (dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.

    Notes:
        - Assumes strict right-padding.
        - Filters to next events that are disease tokens: token_id >= n_tech_tokens.
        - Filters to strictly positive dt.
    """
    real_mask = event_seqs >= 1
    idx = real_mask.nonzero(as_tuple=False)

    if idx.size(0) <= 1:
        return None

    same_batch = idx[1:, 0] == idx[:-1, 0]
    if not same_batch.any():
        return None

    prev_idx = idx[:-1][same_batch]
    next_idx = idx[1:][same_batch]

    b_next, t_next = next_idx[:, 0], next_idx[:, 1]
    valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
    if not valid_target.any():
        return None

    prev_idx = prev_idx[valid_target]
    next_idx = next_idx[valid_target]

    b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
    b_next, t_next = next_idx[:, 0], next_idx[:, 1]

    dt = (time_seqs[b_next, t_next] -
          time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
    valid_dt = dt > 0
    if not valid_dt.any():
        return None

    dt = dt[valid_dt]
    b_prev = b_prev[valid_dt]
    t_prev = t_prev[valid_dt]
    b_next = b_next[valid_dt]
    t_next = t_next[valid_dt]

    return dt, b_prev, t_prev, b_next, t_next


# ============================================================
# Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
    """
    Competing risks exponential likelihood.

    The negative log-likelihood is given by:

    .. math::
        \\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt

    Args:
        lambda_reg (float): Weight of the cross-entropy regularization term.
        eps (float): Small epsilon for numerical stability.
    """

    def __init__(
        self,
        lambda_reg: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K) tensor of logits.
            target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
            dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization), where the
            regularization is a cross-entropy term over the hazards, weighted
            by lambda_reg.
        """
        logits = logits.squeeze(-1) if logits.dim() == 3 else logits
        hazards = F.softplus(logits) + self.eps  # (M, K)
        hazard_event = hazards.gather(
            1, target_events.unsqueeze(1)).squeeze(1)  # (M,)
        total_hazard = hazards.sum(dim=1)  # (M,)
        log_hazards = torch.log(hazards)  # (M, K)
        nll = -torch.log(hazard_event) + total_hazard * dt

        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()

        # softmax(log lambda) = lambda_k / sum_j lambda_j, the probability that
        # event k fires first under competing exponential risks, so this acts
        # as a first-event classification regularizer.
        reg = F.cross_entropy(log_hazards, target_events,
                              reduction="mean") * self.lambda_reg
        return nll, reg


class WeibullNLLLoss(nn.Module):
    """
    Weibull hazard in t.

    .. math::
        \\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k}

        \\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k - 1}

    Args:
        eps (float): Small epsilon for numerical stability.
        lambda_reg (float): Regularization weight.
    """

    def __init__(
        self,
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
    ):
        super().__init__()
        self.eps = eps
        self.lambda_reg = lambda_reg

    def forward(
        self,
        logits: torch.Tensor,
        target_events: torch.Tensor,
        dt: torch.Tensor,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            logits (torch.Tensor): (M, K, 2) tensor of logits.
            target_events (torch.Tensor): (M,) tensor of target events.
            dt (torch.Tensor): (M,) tensor of time intervals.
            reduction (str): 'mean', 'sum', or 'none'.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
        """
        shapes = F.softplus(logits[..., 0]) + self.eps  # (M, K)
        scales = F.softplus(logits[..., 1]) + self.eps  # (M, K)
        eps = self.eps
        t = torch.clamp(dt, min=eps)

        t_mat = t.unsqueeze(1)  # (M, 1)

        # cumulative hazard (M, K)
        cum_hazard = scales * t_mat.pow(shapes)

        # hazard (M, K)
        hazard = shapes * scales * t_mat.pow(shapes - 1.0)

        hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
        # Point-event likelihood: f_k(t) = lambda_k(t) * exp(-Lambda_total(t))
        # NLL_point = -log lambda_{k*}(t) + Lambda_total(t)
        nll = -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1)

        if reduction == "mean":
            nll = nll.mean()
        elif reduction == "sum":
            nll = nll.sum()

        reg = shapes.new_zeros(())
        if self.lambda_reg > 0:
            reg = self.lambda_reg * (
                (torch.log(scales + eps) ** 2).mean() +
                (torch.log(shapes + eps) ** 2).mean()
            )
        return nll, reg
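A worked call for ExponentialNLLLoss matching the formula in its docstring (hypothetical sizes):

import torch

loss_fn = ExponentialNLLLoss(lambda_reg=0.1)
logits = torch.randn(4, 10)              # (M, K): 4 pairs, 10 competing risks
targets = torch.randint(0, 10, (4,))     # observed event index k* per pair
dt = torch.rand(4) + 0.1                 # strictly positive gaps in years
nll, reg = loss_fn(logits, targets, dt)  # both scalars under reduction="mean"
loss = nll + reg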
model.py (new file, 440 lines)
@@ -0,0 +1,440 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block
from typing import Optional, List
import numpy as np


class TabularEncoder(nn.Module):
    """
    Encoder for tabular features (continuous and categorical).

    Args:
        n_embd (int): Embedding dimension.
        n_cont (int): Number of continuous features.
        n_cate (int): Number of categorical features.
        cate_dims (List[int]): List of dimensions for each categorical feature.
    """

    def __init__(
        self,
        n_embd: int,
        n_cont: int,
        n_cate: int,
        cate_dims: List[int],
    ):
        super().__init__()
        self.n_embd = n_embd
        self.n_cont = n_cont
        self.n_cate = n_cate

        if n_cont > 0:
            hidden = 2 * n_embd
            self.cont_mlp = nn.Sequential(
                nn.Linear(2 * n_cont, hidden),
                nn.GELU(),
                nn.Linear(hidden, n_embd),
            )
        else:
            self.cont_mlp = None

        if n_cate > 0:
            assert len(cate_dims) == n_cate, \
                "Length of cate_dims must match n_cate"
            self.cate_embds = nn.ModuleList([
                nn.Embedding(dim, n_embd) for dim in cate_dims
            ])
            self.cate_mask_embds = nn.ModuleList([
                nn.Embedding(2, n_embd) for _ in range(n_cate)
            ])
        else:
            self.cate_embds = None
            self.cate_mask_embds = None

        self.cont_mask_proj = (
            nn.Linear(n_cont, n_embd) if n_cont > 0 else None
        )

        self.film = nn.Sequential(
            nn.Linear(n_embd, 2 * n_embd),
            nn.GELU(),
            nn.Linear(2 * n_embd, 2 * n_embd),
        )

        self.apply(self._init_weights)
        self.out_ln = nn.LayerNorm(n_embd)

        # Zero-init the last layer of FiLM to start with identity modulation
        with torch.no_grad():
            last_linear = self.film[-1]
            last_linear.weight.zero_()
            last_linear.bias.zero_()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        cont_features: Optional[torch.Tensor],
        cate_features: Optional[torch.Tensor],
    ) -> torch.Tensor:

        if self.n_cont == 0 and self.n_cate == 0:
            # infer (B, L) from whichever input is not None
            if cont_features is not None:
                B, L = cont_features.shape[:2]
                device = cont_features.device
            elif cate_features is not None:
                B, L = cate_features.shape[:2]
                device = cate_features.device
            else:
                raise ValueError(
                    "TabularEncoder received no features but cannot infer (B, L)."
                )
            return torch.zeros(B, L, self.n_embd, device=device)

        value_parts: List[torch.Tensor] = []
        mask_parts: List[torch.Tensor] = []

        if self.n_cont > 0 and cont_features is not None:
            if cont_features.dim() != 3:
                raise ValueError(
                    "cont_features must be 3D tensor (B, L, n_cont)")
            B, L, D_cont = cont_features.shape
            if D_cont != self.n_cont:
                raise ValueError(
                    f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")

            cont_mask = (~torch.isnan(cont_features)).float()
            cont_filled = torch.nan_to_num(cont_features, nan=0.0)
            cont_joint = torch.cat([cont_filled, cont_mask], dim=-1)
            h_cont_value = self.cont_mlp(cont_joint)
            value_parts.append(h_cont_value)

            if self.cont_mask_proj is not None:
                h_cont_mask = self.cont_mask_proj(cont_mask)
                mask_parts.append(h_cont_mask)

        if self.n_cate > 0 and cate_features is not None:
            if cate_features.dim() != 3:
                raise ValueError(
                    "cate_features must be 3D tensor (B, L, n_cate)")
            B, L, D_cate = cate_features.shape
            if D_cate != self.n_cate:
                raise ValueError(
                    f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")

            for i in range(self.n_cate):
                cate_feat = cate_features[:, :, i]
                cate_embd = self.cate_embds[i]
                cate_mask_embd = self.cate_mask_embds[i]

                cate_value = cate_embd(
                    torch.clamp(cate_feat, min=0))
                cate_mask = (cate_feat > 0).long()
                cate_mask_value = cate_mask_embd(cate_mask)

                value_parts.append(cate_value)
                mask_parts.append(cate_mask_value)

        if not value_parts:
            if cont_features is not None:
                B, L = cont_features.shape[:2]
                device = cont_features.device
            elif cate_features is not None:
                B, L = cate_features.shape[:2]
                device = cate_features.device
            else:
                raise ValueError("No features provided to TabularEncoder.")
            return torch.zeros(B, L, self.n_embd, device=device)

        h_value = torch.stack(value_parts, dim=0).mean(dim=0)
        h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)
        h_mask_flat = h_mask.view(-1, self.n_embd)
        film_params = self.film(h_mask_flat)
        gamma_delta, beta = film_params.chunk(2, dim=-1)
        gamma = 1.0 + gamma_delta
        h_value_flat = h_value.view(-1, self.n_embd)
        h_out = gamma * h_value_flat + beta
        h_out = h_out.view(B, L, self.n_embd)
        h_out = self.out_ln(h_out)
        return h_out


def _build_time_padding_mask(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
) -> torch.Tensor:
    t_i = time_seq.unsqueeze(-1)
    t_j = time_seq.unsqueeze(1)
    time_mask = (t_j <= t_i)  # allow attending only to past or current
    key_is_valid = (event_seq != 0)  # disallow padded positions
    allowed = time_mask & key_is_valid.unsqueeze(1)
    # Boolean masks for scaled_dot_product_attention treat True as "may take
    # part in attention", so return the allowed positions directly.
    return allowed.unsqueeze(1)  # (B, 1, L, L)


class DelphiFork(nn.Module):
    """
    DelphiFork model for time-to-event prediction.

    Args:
        n_disease (int): Number of disease tokens.
        n_tech_tokens (int): Number of technical tokens.
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        n_layer (int): Number of transformer layers.
        n_cont (int): Number of continuous features.
        n_cate (int): Number of categorical features.
        cate_dims (List[int]): List of dimensions for each categorical feature.
        age_encoder_type (str): Type of age encoder ("sinusoidal" or "mlp").
        pdrop (float): Dropout probability.
        token_pdrop (float): Token dropout probability.
        n_dim (int): Dimension of theta parameters.
    """

    def __init__(
        self,
        n_disease: int,
        n_tech_tokens: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_cont: int,
        n_cate: int,
        cate_dims: List[int],
        age_encoder_type: str = "sinusoidal",
        pdrop: float = 0.0,
        token_pdrop: float = 0.0,
        n_dim: int = 1,
    ):
        super().__init__()
        self.vocab_size = n_disease + n_tech_tokens
        self.n_tech_tokens = n_tech_tokens
        self.n_disease = n_disease
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_dim = n_dim

        self.token_embedding = nn.Embedding(
            self.vocab_size, n_embd, padding_idx=0)
        if age_encoder_type == "sinusoidal":
            self.age_encoder = AgeSinusoidalEncoder(n_embd)
        elif age_encoder_type == "mlp":
            self.age_encoder = AgeMLPEncoder(n_embd)
        else:
            raise ValueError(
                f"Unsupported age_encoder_type: {age_encoder_type}")
        self.sex_encoder = nn.Embedding(2, n_embd)
        self.tabular_encoder = TabularEncoder(
            n_embd, n_cont, n_cate, cate_dims)

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.token_dropout = nn.Dropout(token_pdrop)

        # Head layers
        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)

    def forward(
        self,
        event_seq: torch.Tensor,  # (B, L)
        time_seq: torch.Tensor,  # (B, L)
        sex: torch.Tensor,  # (B,)
        cont_seq: torch.Tensor,  # (B, Lc, n_cont)
        cate_seq: torch.Tensor,  # (B, Lc, n_cate)
        b_prev: Optional[torch.Tensor] = None,  # (M,)
        t_prev: Optional[torch.Tensor] = None,  # (M,)
    ) -> torch.Tensor:
        token_embds = self.token_embedding(event_seq)  # (B, L, D)
        age_embds = self.age_encoder(time_seq)  # (B, L, D)
        sex_embds = self.sex_encoder(sex.unsqueeze(-1))  # (B, 1, D)
        table_embds = self.tabular_encoder(cont_seq, cate_seq)  # (B, Lc, D)
        mask = (event_seq == 1)  # (B, L)
        B, L = event_seq.shape
        Lc = table_embds.size(1)
        D = table_embds.size(2)

        # occ[b, t] = 0-based occurrence index of the assessment (DOA) token;
        # values at non-mask positions are meaningless and are zeroed below.
        occ = torch.cumsum(mask.to(torch.long), dim=1) - 1

        # Clamp indices beyond Lc - 1 and force non-mask positions to 0
        # (avoids gathering with meaningless indices).
        tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
        tab_idx = tab_idx.masked_fill(~mask, 0)  # (B, L)

        # Gather along dim=1 from (B, Lc, D) the tabular embedding to inject
        # at each position -> (B, L, D)
        tab_inject = table_embds.gather(
            dim=1,
            index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
        )
        # Replace token embeddings only where mask == True
        final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)

        x = final_embds + age_embds + sex_embds  # (B, L, D)
        x = self.token_dropout(x)
        attn_mask = _build_time_padding_mask(
            event_seq, time_seq)
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)
        x = self.ln_f(x)

        if b_prev is not None and t_prev is not None:
            M = b_prev.numel()
            c = x[b_prev, t_prev]  # (M, D)

            theta = self.theta_proj(c)  # (M, n_disease * n_dim)
            theta = theta.view(M, self.n_disease, self.n_dim)
            return theta
        else:
            return x


class SapDelphi(nn.Module):
    """
    DelphiFork variant whose token embeddings can be initialized from
    pretrained SapBERT label embeddings.

    Args:
        Same as DelphiFork, plus:
        pretrained_weights_path (Optional[str]): Path to a .npy file of
            pretrained disease-label embeddings.
        freeze_embeddings (bool): If True, freeze the pretrained embeddings;
            if False (the default), fine-tune them.
    """

    def __init__(
        self,
        n_disease: int,
        n_tech_tokens: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_cont: int,
        n_cate: int,
        cate_dims: List[int],
        age_encoder_type: str = "sinusoidal",
        pdrop: float = 0.0,
        token_pdrop: float = 0.0,
        n_dim: int = 1,
        pretrained_weights_path: Optional[str] = None,
        freeze_embeddings: bool = False,
    ):
        super().__init__()
        self.vocab_size = n_disease + n_tech_tokens
        self.n_tech_tokens = n_tech_tokens
        self.n_disease = n_disease
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_dim = n_dim

        if pretrained_weights_path is not None:
            print(
                f"Loading pretrained embeddings from {pretrained_weights_path}...")
            bert_weights = np.load(pretrained_weights_path)
            bert_weights = torch.tensor(bert_weights, dtype=torch.float32)

            vocab_dim = bert_weights.shape[1]  # typically 768

            pad_emb = torch.zeros(1, vocab_dim)
            tech_embs = nn.init.normal_(torch.empty(
                n_tech_tokens - 1, vocab_dim))
            full_emb_weights = torch.cat(
                [pad_emb, tech_embs, bert_weights], dim=0)
            self.token_embedding = nn.Embedding.from_pretrained(
                full_emb_weights, freeze=freeze_embeddings)
            print("Pretrained embeddings loaded.")
            if vocab_dim != n_embd:
                self.emb_proj = nn.Sequential(
                    nn.Linear(vocab_dim, n_embd, bias=False),
                    nn.LayerNorm(n_embd),
                    nn.Dropout(pdrop),
                )
            else:
                self.emb_proj = nn.Identity()
        else:
            self.token_embedding = nn.Embedding(
                self.vocab_size, n_embd, padding_idx=0)
            self.emb_proj = nn.Identity()

        if age_encoder_type == "sinusoidal":
            self.age_encoder = AgeSinusoidalEncoder(n_embd)
        elif age_encoder_type == "mlp":
            self.age_encoder = AgeMLPEncoder(n_embd)
        else:
            raise ValueError(
                f"Unsupported age_encoder_type: {age_encoder_type}")
        self.sex_encoder = nn.Embedding(2, n_embd)
        self.tabular_encoder = TabularEncoder(
            n_embd, n_cont, n_cate, cate_dims)

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.token_dropout = nn.Dropout(token_pdrop)

        # Head layers
        self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)

    def forward(
        self,
        event_seq: torch.Tensor,  # (B, L)
        time_seq: torch.Tensor,  # (B, L)
        sex: torch.Tensor,  # (B,)
        cont_seq: torch.Tensor,  # (B, Lc, n_cont)
        cate_seq: torch.Tensor,  # (B, Lc, n_cate)
        b_prev: Optional[torch.Tensor] = None,  # (M,)
        t_prev: Optional[torch.Tensor] = None,  # (M,)
    ) -> torch.Tensor:
        token_embds = self.token_embedding(event_seq)  # (B, L, vocab_dim)
        token_embds = self.emb_proj(token_embds)  # (B, L, D)
        age_embds = self.age_encoder(time_seq)  # (B, L, D)
        sex_embds = self.sex_encoder(sex.unsqueeze(-1))  # (B, 1, D)
        table_embds = self.tabular_encoder(cont_seq, cate_seq)  # (B, Lc, D)
        mask = (event_seq == 1)  # (B, L)
        B, L = event_seq.shape
        Lc = table_embds.size(1)
        D = table_embds.size(2)

        # occ[b, t] = 0-based occurrence index of the assessment (DOA) token;
        # values at non-mask positions are meaningless and are zeroed below.
        occ = torch.cumsum(mask.to(torch.long), dim=1) - 1

        # Clamp indices beyond Lc - 1 and force non-mask positions to 0
        # (avoids gathering with meaningless indices).
        tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
        tab_idx = tab_idx.masked_fill(~mask, 0)  # (B, L)

        # Gather along dim=1 from (B, Lc, D) the tabular embedding to inject
        # at each position -> (B, L, D)
        tab_inject = table_embds.gather(
            dim=1,
            index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
        )
        # Replace token embeddings only where mask == True
        final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)

        x = final_embds + age_embds + sex_embds  # (B, L, D)
        x = self.token_dropout(x)
        attn_mask = _build_time_padding_mask(
            event_seq, time_seq)
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)
        x = self.ln_f(x)

        if b_prev is not None and t_prev is not None:
            M = b_prev.numel()
            c = x[b_prev, t_prev]  # (M, D)

            theta = self.theta_proj(c)  # (M, n_disease * n_dim)
            theta = theta.view(M, self.n_disease, self.n_dim)
            return theta
        else:
            return x
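A shape sketch for DelphiFork's two output modes (hypothetical sizes; n_tech_tokens=2 assumes pad = 0 and the assessment token = 1, as in dataset.py):

import torch

model = DelphiFork(n_disease=5, n_tech_tokens=2, n_embd=32, n_head=4,
                   n_layer=2, n_cont=3, n_cate=0, cate_dims=[])
events = torch.tensor([[1, 2, 3, 0]])    # (B=1, L=4), right-padded
times = torch.tensor([[18000., 18100., 18400., 36525.]])  # ages in days
sex = torch.tensor([1])
cont = torch.randn(1, 1, 3)              # one tabular row per assessment token
cate = torch.zeros(1, 1, 0, dtype=torch.long)
h = model(events, times, sex, cont, cate)            # hidden states
theta = model(events, times, sex, cont, cate,
              b_prev=torch.tensor([0]), t_prev=torch.tensor([1]))
print(h.shape, theta.shape)              # (1, 4, 32) and (1, 5, 1)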
prepare_data.R (new file, 26 lines)
@@ -0,0 +1,26 @@
library(data.table)
setDTthreads(40)
library(readr)

# First source file: disease-related fields
field_id <- read.csv("field_id.txt", header = FALSE)
uid <- field_id$V1
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HHdata_221103_0512.csv"
header_dt <- fread(big_path, nrows = 0)  # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names, uid)
ukb_disease <- fread(big_path,
                     select = keep_names,
                     showProgress = TRUE)

# Second source file: remaining (non-disease) fields
field_id <- read.csv("field_id.txt", header = FALSE)
uid <- field_id$V1
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HH_data_220812_0512.csv"
header_dt <- fread(big_path, nrows = 0)  # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names, uid)
ukb_others <- fread(big_path,
                    select = keep_names,
                    showProgress = TRUE)

# Merge disease and other data by "eid"
ukb_data <- merge(ukb_disease, ukb_others, by = "eid", all = TRUE)
fwrite(ukb_data, "ukb_data.csv")
prepare_data.py (new file, 215 lines)
@@ -0,0 +1,215 @@
import pandas as pd  # Pandas for data manipulation
import tqdm  # Progress bar for chunk processing
import numpy as np  # Numerical operations

# CSV mapping field IDs to human-readable names
field_map_file = "field_ids_enriched.csv"
field_dict = {}  # Map original field ID -> new column name
tabular_fields = []  # List of tabular feature column names
with open(field_map_file, "r", encoding="utf-8") as f:
    next(f)  # skip header line
    for line in f:
        parts = line.strip().split(",")  # Split by CSV commas
        if len(parts) >= 3:  # Ensure the id and var_name columns are present
            field_id = parts[0]  # Original field identifier (e.g., "34-0.0")
            field_name = parts[2]  # Human-readable column name
            field_dict[field_id] = field_name  # Record the mapping
            tabular_fields.append(field_name)  # Track as a potential tabular feature

# Exclude raw date parts and target columns
exclude_fields = ['year', 'month', 'Death', 'age_at_assessment']
tabular_fields = [
    field for field in tabular_fields if field not in exclude_fields]

# TSV mapping field IDs to ICD10-related date columns
field_to_icd_map = "icd10_codes_mod.tsv"
date_vars = []  # Date-like variables to be converted to offsets
with open(field_to_icd_map, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split()  # Split on whitespace for TSV
        if len(parts) >= 6:  # Guard against malformed lines
            field_dict[parts[0]] = parts[5]  # Map field ID to the date column name
            date_vars.append(parts[5])  # Track date column names in order

for j in range(17):  # Map up to 17 cancer entry slots (dates and types)
    field_dict[f'40005-{j}.0'] = f'cancer_date_{j}'  # Cancer diagnosis date slot j
    field_dict[f'40006-{j}.0'] = f'cancer_type_{j}'  # Cancer type/code slot j

# Number of ICD-related date columns before adding extras
len_icd = len(date_vars)
date_vars.extend(['Death', 'date_of_assessment'] +  # Add outcome date and assessment date
                 [f'cancer_date_{j}' for j in range(17)])  # Add cancer date columns

labels_file = "labels.csv"  # File listing label codes
label_dict = {}  # Map code string -> integer label id
with open(labels_file, "r", encoding="utf-8") as f:
    for idx, line in enumerate(f):  # Enumerate to assign incremental label IDs
        parts = line.strip().split(' ')  # Split by space
        if parts and parts[0]:  # Guard against empty lines
            # Map code to index (0 is reserved for padding, 1 for the check-up token)
            label_dict[parts[0]] = idx + 2

event_list = []  # Accumulator for event arrays across chunks
tabular_list = []  # Accumulator for tabular feature DataFrames across chunks
ukb_iterator = pd.read_csv(  # Stream UK Biobank data in chunks
    "ukb_data.csv",
    sep=',',
    chunksize=10000,  # Stream the file in manageable chunks to reduce memory footprint
    index_col=0,  # First column (participant ID) becomes the DataFrame index
    low_memory=False  # Disable type inference optimization for consistent dtypes
)
# Iterate chunks with progress
for ukb_chunk in tqdm.tqdm(ukb_iterator, desc="Processing UK Biobank data"):
    # Rename columns to friendly names
    ukb_chunk = ukb_chunk.rename(columns=field_dict)
    # Require sex to be present
    ukb_chunk.dropna(subset=['sex'], inplace=True)

    # Construct date of birth from year and month (day fixed to 1)
    ukb_chunk['day'] = 1
    ukb_chunk['dob'] = pd.to_datetime(
        ukb_chunk[['year', 'month', 'day']], errors='coerce'  # Guard against malformed dates
    )
    del ukb_chunk['day']

    # Use only date variables that actually exist in the current chunk
    present_date_vars = [c for c in date_vars if c in ukb_chunk.columns]

    # Convert date-like columns to datetime and compute day offsets from dob
    if present_date_vars:
        date_cols = ukb_chunk[present_date_vars].apply(
            pd.to_datetime, format="%Y-%m-%d", errors='coerce'  # Parse dates safely
        )
        date_cols_days = date_cols.sub(
            ukb_chunk['dob'], axis=0)  # Timedelta relative to dob
        ukb_chunk[present_date_vars] = date_cols_days.apply(
            lambda x: x.dt.days)  # Store days since dob

    # Append tabular features (use only columns that exist)
    present_tabular_fields = [
        c for c in tabular_fields if c in ukb_chunk.columns]
    tabular_list.append(ukb_chunk[present_tabular_fields].copy())

    # Process disease events from ICD10-related date columns
    # Take ICD date cols plus 'Death' if present, by order
    icd10_cols = present_date_vars[:len_icd + 1]
    # Melt to long form: participant id, event code (column name), and days offset
    melted_df = ukb_chunk.reset_index().melt(
        id_vars=['eid'],
        value_vars=icd10_cols,
        var_name='event_code',
        value_name='days',
    )
    # Require non-missing day offsets
    melted_df.dropna(subset=['days'], inplace=True)
    if not melted_df.empty:
        melted_df['label'] = melted_df['event_code'].map(
            label_dict)  # Map event code to numeric label
        # Ensure labels exist before the int cast
        melted_df.dropna(subset=['label'], inplace=True)
        if not melted_df.empty:
            event_list.append(
                melted_df[['eid', 'days', 'label']]
                .astype(int)  # Safe now since label and days are non-null
                .to_numpy()
            )

    # Optimized cancer processing without wide_to_long
    cancer_frames = []
    for j in range(17):
        d_col = f'cancer_date_{j}'
        t_col = f'cancer_type_{j}'
        if d_col in ukb_chunk.columns and t_col in ukb_chunk.columns:
            # Filter rows where both date and type are present
            mask = ukb_chunk[d_col].notna() & ukb_chunk[t_col].notna()
            if mask.any():
                subset_idx = ukb_chunk.index[mask]
                subset_days = ukb_chunk.loc[mask, d_col]
                subset_type = ukb_chunk.loc[mask, t_col]

                # Map cancer type to label using the first 3 characters
                cancer_codes = subset_type.str.slice(0, 3)
                labels = cancer_codes.map(label_dict)

                # Filter valid labels
                valid_label_mask = labels.notna()
                if valid_label_mask.any():
                    # Create array: eid, days, label (types must suit numpy)
                    c_eids = subset_idx[valid_label_mask].values
                    c_days = subset_days[valid_label_mask].values
                    c_labels = labels[valid_label_mask].values

                    # Stack
                    chunk_cancer_data = np.column_stack(
                        (c_eids, c_days, c_labels))
                    cancer_frames.append(chunk_cancer_data)

    if cancer_frames:
        event_list.append(np.vstack(cancer_frames))

# Combine tabular chunks
final_tabular = pd.concat(tabular_list, axis=0, ignore_index=False)
final_tabular.index.name = 'eid'  # Ensure the index is named consistently
data = np.vstack(event_list)  # Stack all event arrays into one

# Sort by participant then day
data = data[np.lexsort((data[:, 1], data[:, 0]))]

# Keep only events with non-negative day offsets
data = data[data[:, 1] >= 0]

# Remove duplicate (participant_id, label) pairs keeping the first occurrence.
data = pd.DataFrame(data).drop_duplicates([0, 2]).values

# Store compactly using unsigned 32-bit integers
data = data.astype(np.uint32)

# Select eids present in both the event data and the tabular data
valid_eids = np.intersect1d(data[:, 0], final_tabular.index)
data = data[np.isin(data[:, 0], valid_eids)]
final_tabular = final_tabular.loc[valid_eids]
final_tabular = final_tabular.convert_dtypes()

# Save [eid, sex, date_of_assessment] for basic info
basic_info = final_tabular[['sex', 'date_of_assessment']]
basic_info.to_csv("ukb_basic_info.csv")

# Drop sex and date_of_assessment from tabular features
final_tabular = final_tabular.drop(columns=['sex', 'date_of_assessment'])

# Process categorical columns in tabular features.
# If a column is integer-typed with few unique values, treat it as categorical.
# For each such column: count the unique non-negative values as C, set NaN or
# negative entries to 0, and remap the original values to [1..C].
for col in final_tabular.select_dtypes(include=['Int64', 'int64']).columns:
    # Get unique values efficiently
    series = final_tabular[col]
    unique_vals = series.dropna().unique()

    # Filter negatives from unique values
    valid_vals = sorted([v for v in unique_vals if v >= 0])

    if len(valid_vals) <= 10:  # Threshold for categorical
        # Create mapping
        val_map = {val: idx + 1 for idx, val in enumerate(valid_vals)}

        # Map values. Values not in val_map (negatives, NaNs) become NaN
        mapped_col = series.map(val_map)

        # Fill NaN with 0 and convert to uint32
        final_tabular[col] = mapped_col.fillna(0).astype(np.uint32)

# Save processed tabular features
final_tabular.to_csv("ukb_table.csv")

# Save event data
np.save("ukb_event_data.npy", data)
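A quick sanity check on the artifacts this script writes (a sketch):

import numpy as np
import pandas as pd

events = np.load("ukb_event_data.npy")   # columns: eid, days since birth, label
table = pd.read_csv("ukb_table.csv", index_col="eid")
basic = pd.read_csv("ukb_basic_info.csv", index_col="eid")
assert set(events[:, 0]).issubset(set(basic.index))  # events only for known eids
print(events.shape, table.shape, basic.columns.tolist())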