Add loss functions and model architecture for time-to-event prediction

- Implemented ExponentialNLLLoss and WeibullNLLLoss in losses.py for negative log-likelihood calculations.
- Developed TabularEncoder class in model.py for encoding tabular features.
- Created DelphiFork and SapDelphi classes in model.py for time-to-event prediction using a transformer architecture.
- Added data preparation scripts in prepare_data.R and prepare_data.py for processing UK Biobank data, including handling field mappings and event data extraction.
2026-01-07 21:32:00 +08:00
parent 5d1d79b908
commit 6984b254b3
12 changed files with 5098 additions and 0 deletions

55
age_encoder.py Normal file

@@ -0,0 +1,55 @@
import torch
import torch.nn as nn
class AgeSinusoidalEncoder(nn.Module):
"""
Sinusoidal encoder for age.
Args:
n_embd (int): Embedding dimension. Must be even.
"""
def __init__(self, n_embd: int):
super().__init__()
if n_embd % 2 != 0:
raise ValueError("n_embd must be even for sinusoidal encoding.")
self.n_embd = n_embd
i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
divisor = torch.pow(10000, i / self.n_embd)
self.register_buffer('divisor', divisor)
    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        t_years = ages / 365.25  # ages are given in days
# Broadcast (B, L, 1) against (1, 1, D/2) to get (B, L, D/2)
args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
# Interleave cos and sin along the last dimension
output = torch.zeros(
ages.shape[0], ages.shape[1], self.n_embd, device=ages.device)
output[:, :, 0::2] = torch.cos(args)
output[:, :, 1::2] = torch.sin(args)
return output
class AgeMLPEncoder(nn.Module):
"""
MLP encoder for age, using sinusoidal encoding as input.
Args:
n_embd (int): Embedding dimension.
"""
def __init__(self, n_embd: int):
super().__init__()
self.sin_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
self.mlp = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.GELU(),
nn.Linear(4 * n_embd, n_embd),
)
def forward(self, ages: torch.Tensor) -> torch.Tensor:
x = self.sin_encoder(ages)
output = self.mlp(x)
return output
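
A quick shape check for the two encoders (a sketch; it assumes ages arrive in days, as in the dataset's time sequences):

import torch
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder

enc = AgeSinusoidalEncoder(n_embd=64)
ages = torch.rand(2, 10) * 365.25 * 80  # (B, L) ages in days, up to ~80 years
out = enc(ages)
assert out.shape == (2, 10, 64)  # even channels hold cos, odd channels hold sin

mlp_enc = AgeMLPEncoder(n_embd=64)
assert mlp_enc(ages).shape == (2, 10, 64)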

111
backbones.py Normal file

@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class SelfAttention(nn.Module):
"""
Multi-head self-attention mechanism.
Args:
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
attn_pdrop (float): Attention dropout probability.
"""
def __init__(
self,
n_embd: int,
n_head: int,
attn_pdrop: float = 0.1,
):
super().__init__()
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
self.n_head = n_head
self.head_dim = n_embd // n_head
self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
self.attn_pdrop = attn_pdrop
def forward(
self,
x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,  # broadcastable to (B, H, L, L)
) -> torch.Tensor:
B, L, D = x.shape
qkv = self.qkv_proj(x) # (B, L, 3D)
q, k, v = qkv.chunk(3, dim=-1)
def reshape_heads(t):
# (B, H, L, d)
return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)
q = reshape_heads(q)
k = reshape_heads(k)
v = reshape_heads(v)
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_pdrop if self.training else 0.0,  # no dropout at eval time
        )  # (B, H, L, d)
attn = attn.transpose(1, 2).contiguous().view(B, L, D) # (B, L, D)
return self.o_proj(attn)
class Block(nn.Module):
"""
Transformer block consisting of self-attention and MLP.
Args:
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
pdrop (float): Dropout probability.
"""
def __init__(
self,
n_embd: int,
n_head: int,
pdrop: float = 0.0,
):
super().__init__()
attn_pdrop = pdrop
self.norm_1 = nn.LayerNorm(n_embd)
self.attn = SelfAttention(
n_embd=n_embd,
n_head=n_head,
attn_pdrop=attn_pdrop,
)
self.norm_2 = nn.LayerNorm(n_embd)
self.mlp = nn.ModuleDict(dict(
c_fc=nn.Linear(n_embd, 4 * n_embd),
c_proj=nn.Linear(4 * n_embd, n_embd),
act=nn.GELU(),
dropout=nn.Dropout(pdrop),
))
m = self.mlp
self.mlpf = lambda x: m.dropout(
m.c_proj(m.act(m.c_fc(x))))
self.resid_dropout = nn.Dropout(pdrop)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Attention
h = self.norm_1(x)
h = self.attn(h, attn_mask=attn_mask)
x = x + self.resid_dropout(h)
# MLP
h = self.norm_2(x)
h = self.mlpf(h)
x = x + self.resid_dropout(h)
return x
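
A minimal smoke test for Block (illustrative sizes). Note the boolean mask convention: F.scaled_dot_product_attention treats True as "allowed to attend":

import torch
from backbones import Block

B, L, D, H = 2, 8, 32, 4
block = Block(n_embd=D, n_head=H, pdrop=0.1)
x = torch.randn(B, L, D)
# lower-triangular causal mask: position i may attend to j <= i
causal = torch.tril(torch.ones(L, L, dtype=torch.bool)).view(1, 1, L, L)
y = block(x, attn_mask=causal)
assert y.shape == (B, L, D)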

101
dataset.py Normal file

@@ -0,0 +1,101 @@
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from collections import defaultdict
from typing import List
class HealthDataset(Dataset):
"""
Dataset for health records.
Args:
data_prefix (str): Prefix for data files.
covariate_list (List[str] | None): List of covariates to include.
"""
def __init__(
self,
data_prefix: str,
covariate_list: List[str] | None = None,
):
basic_info = pd.read_csv(
f"{data_prefix}_basic_info.csv", index_col='eid')
tabular_data = pd.read_csv(f"{data_prefix}_table.csv", index_col='eid')
event_data = np.load(f"{data_prefix}_event_data.npy")
        patient_events = defaultdict(list)
        max_code = 0
        for patient_id, time_in_days, event_code in event_data:
            patient_events[patient_id].append((time_in_days, event_code))
            if event_code > max_code:
                max_code = event_code
        # codes 0 (pad) and 1 (DOA) are technical, so disease codes span [2, max_code]
        self.n_disease = max_code - 1
self.basic_info = basic_info.convert_dtypes()
self.patient_ids = self.basic_info.index.tolist()
self.patient_events = dict(patient_events)
tabular_data = tabular_data.convert_dtypes()
cont_cols = []
cate_cols = []
self.cate_dims = []
if covariate_list is not None:
tabular_data = tabular_data[covariate_list]
for col in tabular_data.columns:
if pd.api.types.is_float_dtype(tabular_data[col]):
cont_cols.append(col)
elif pd.api.types.is_integer_dtype(tabular_data[col]):
series = tabular_data[col]
                unique_vals = series.dropna().unique()
                # prepare_data.py remaps categorical codes to [1..C] (0 = missing)
                # with C <= 10, so more than 11 distinct values means continuous
                if len(unique_vals) > 11:
                    cont_cols.append(col)
                else:
                    cate_cols.append(col)
                    self.cate_dims.append(int(series.max()) + 1)
self.cont_features = tabular_data[cont_cols].to_numpy(
dtype=np.float32).copy()
self.cate_features = tabular_data[cate_cols].to_numpy(
dtype=np.int64).copy()
self.n_cont = self.cont_features.shape[1]
self.n_cate = self.cate_features.shape[1]
def __len__(self) -> int:
return len(self.patient_ids)
def __getitem__(self, idx):
patient_id = self.patient_ids[idx]
records = sorted(self.patient_events.get(
patient_id, []), key=lambda x: x[0])
event_seq = [item[1] for item in records]
time_seq = [item[0] for item in records]
doa = self.basic_info.loc[patient_id, 'date_of_assessment']
insert_pos = np.searchsorted(time_seq, doa)
time_seq.insert(insert_pos, doa)
# assuming 1 is the code for 'DOA' event
event_seq.insert(insert_pos, 1)
event_tensor = torch.tensor(event_seq, dtype=torch.long)
time_tensor = torch.tensor(time_seq, dtype=torch.float)
cont_tensor = torch.tensor(
self.cont_features[idx, :], dtype=torch.float)
cate_tensor = torch.tensor(
self.cate_features[idx, :], dtype=torch.long)
sex = self.basic_info.loc[patient_id, 'sex']
return (event_tensor, time_tensor, cont_tensor, cate_tensor, sex)
def health_collate_fn(batch):
event_seqs, time_seqs, cont_feats, cate_feats, sexes = zip(*batch)
event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0)
    time_batch = pad_sequence(
        time_seqs, batch_first=True, padding_value=36525.0)  # 100 years in days; padded positions are masked downstream
cont_batch = torch.stack(cont_feats, dim=0)
cont_batch = cont_batch.unsqueeze(1) # (B, 1, n_cont)
cate_batch = torch.stack(cate_feats, dim=0)
cate_batch = cate_batch.unsqueeze(1) # (B, 1, n_cate)
sex_batch = torch.tensor(sexes, dtype=torch.long)
return event_batch, time_batch, cont_batch, cate_batch, sex_batch
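
A sketch of the intended DataLoader wiring (the "ukb" prefix matches the file names written by prepare_data.py):

from torch.utils.data import DataLoader
from dataset import HealthDataset, health_collate_fn

ds = HealthDataset(data_prefix="ukb")  # reads ukb_basic_info.csv, ukb_table.csv, ukb_event_data.npy
loader = DataLoader(ds, batch_size=32, shuffle=True, collate_fn=health_collate_fn)
events, times, cont, cate, sex = next(iter(loader))
# events: (B, L) padded with 0; times: (B, L) padded with 36525.0
# cont: (B, 1, n_cont); cate: (B, 1, n_cate); sex: (B,)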

243
embed_icd10.py Normal file

@@ -0,0 +1,243 @@
import argparse
import csv
import os
import re
from dataclasses import dataclass
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer
@dataclass(frozen=True)
class Icd10Label:
code: str
disease: str
_LABEL_RE = re.compile(r"^\s*([A-Z][0-9][0-9][A-Z0-9]{0,2})\s*\((.+)\)\s*$")
_CODE_RE = re.compile(r"^[A-Z][A-Z0-9]{1,6}$")
def _read_labels(labels_path: str, *, strict_codes: bool) -> list[Icd10Label]:
labels: list[Icd10Label] = []
with open(labels_path, "r", encoding="utf-8") as f:
for raw_line in f:
line = raw_line.strip()
if not line:
continue
match = _LABEL_RE.match(line)
if match is not None:
code, disease = match.group(1), match.group(2)
else:
parts = line.split(maxsplit=1)
if len(parts) == 1:
# Some label lists include non-ICD entries (e.g., "Death").
# Treat these as both code and disease.
code = parts[0].strip()
disease = code
elif len(parts) == 2:
code, disease = parts[0].strip(), parts[1].strip()
else:
raise ValueError(
f"Unrecognized label format: {line!r}. "
"Expected like 'A00 (cholera)', 'CXX Unknown Cancer', or 'Death'."
)
if disease.startswith("(") and disease.endswith(")"):
disease = disease[1:-1].strip()
if strict_codes and not _CODE_RE.match(code):
raise ValueError(
f"Unrecognized ICD10-like code in label: {line!r} (code={code!r}). "
"Re-run without --strict-codes to allow non-ICD labels (e.g., 'Death')."
)
labels.append(Icd10Label(code=code, disease=disease))
if not labels:
raise ValueError(f"No labels found in {labels_path!r}.")
return labels
def embed_texts(
texts: list[str],
*,
model_name: str,
batch_size: int,
max_length: int,
device: str,
) -> np.ndarray:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()
model.to(device)
all_embs: list[np.ndarray] = []
with torch.no_grad():
for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
batch = texts[i: i + batch_size]
toks = tokenizer(
batch,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
toks = {k: v.to(device) for k, v in toks.items()}
# Use CLS token embedding (same as original script).
cls_rep = model(**toks)[0][:, 0, :]
all_embs.append(cls_rep.detach().cpu().to(torch.float32).numpy())
return np.concatenate(all_embs, axis=0)
def save_umap_plot(
embeddings: np.ndarray,
codes: list[str],
*,
out_path: str,
random_state: int = 42,
) -> None:
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
except ImportError as e: # pragma: no cover
raise ImportError(
"UMAP visualization requires matplotlib. Install it with: pip install matplotlib"
) from e
try:
import umap
except ImportError as e: # pragma: no cover
raise ImportError(
"UMAP visualization requires umap-learn. Install it with: pip install umap-learn"
) from e
reducer = umap.UMAP(n_components=2, metric="cosine",
random_state=random_state)
coords = reducer.fit_transform(embeddings)
if len(codes) != coords.shape[0]:
raise ValueError(
f"codes length ({len(codes)}) does not match embeddings rows ({coords.shape[0]})."
)
groups: list[str] = []
for code in codes:
cleaned = code.strip()
if cleaned.lower() == "death":
groups.append("Death")
else:
groups.append(cleaned[:1].upper() if cleaned else "?")
group_names = sorted({g for g in groups if g != "Death"})
cmap = plt.get_cmap("tab20")
group_to_color: dict[str, object] = {
g: cmap(i % cmap.N) for i, g in enumerate(group_names)
}
group_to_color["Death"] = "grey"
colors = [group_to_color.get(g, "black") for g in groups]
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(1, 1, 1)
ax.scatter(coords[:, 0], coords[:, 1], s=6, alpha=0.7, c=colors)
ax.set_title("UMAP of ICD label embeddings")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
fig.tight_layout()
fig.savefig(out_path, dpi=200)
plt.close(fig)
def main() -> int:
parser = argparse.ArgumentParser(
description="Embed ICD-10 disease labels with SapBERT")
parser.add_argument(
"--labels",
default="labels.csv",
help="Path to labels.csv (lines like 'A00 (cholera)')",
)
parser.add_argument(
"--out-dir",
default=".",
help="Output directory for embeddings and metadata",
)
parser.add_argument(
"--model",
default="cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
help="HuggingFace model name",
)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--max-length", type=int, default=25)
parser.add_argument(
"--device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run on (cuda or cpu)",
)
parser.add_argument(
"--strict-codes",
action="store_true",
help="Fail if a label code is not ICD10-like (disallows labels like 'Death')",
)
parser.add_argument(
"--umap",
action="store_true",
help="Also save a 2D UMAP scatterplot of the embeddings",
)
parser.add_argument(
"--umap-out",
default=None,
help="Path to save UMAP PNG (default: <out-dir>/icd10_sapbert_umap.png)",
)
parser.add_argument(
"--umap-random-state",
type=int,
default=42,
help="Random seed for UMAP",
)
args = parser.parse_args()
labels = _read_labels(args.labels, strict_codes=args.strict_codes)
texts = [lbl.disease for lbl in labels]
embs = embed_texts(
texts,
model_name=args.model,
batch_size=args.batch_size,
max_length=args.max_length,
device=args.device,
)
os.makedirs(args.out_dir, exist_ok=True)
embs_path = os.path.join(args.out_dir, "icd10_sapbert_embeddings.npy")
meta_path = os.path.join(args.out_dir, "icd10_sapbert_metadata.tsv")
np.save(embs_path, embs)
with open(meta_path, "w", encoding="utf-8", newline="") as f:
w = csv.writer(f, delimiter="\t")
w.writerow(["index", "icd10_code", "disease"])
for i, lbl in enumerate(labels):
w.writerow([i, lbl.code, lbl.disease])
if args.umap:
umap_path = (
args.umap_out
if args.umap_out is not None
else os.path.join(args.out_dir, "icd10_sapbert_umap.png")
)
save_umap_plot(
embs,
[lbl.code for lbl in labels],
out_path=umap_path,
random_state=args.umap_random_state,
)
print(f"Saved UMAP plot: {umap_path}")
print(f"Saved embeddings: {embs_path} (shape={embs.shape})")
print(f"Saved metadata: {meta_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())
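
The parser accepts three line shapes; a small sketch of what _read_labels returns for each (hypothetical label lines):

import os
import tempfile
from embed_icd10 import _read_labels

with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write("A00 (cholera)\nCXX Unknown Cancer\nDeath\n")
labels = _read_labels(f.name, strict_codes=False)
assert [(l.code, l.disease) for l in labels] == [
    ("A00", "cholera"), ("CXX", "Unknown Cancer"), ("Death", "Death")]
os.unlink(f.name)

A typical invocation would be: python embed_icd10.py --labels labels.csv --out-dir embeddings --umap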

1237
field_id.txt Normal file

File diff suppressed because it is too large

74
field_ids_enriched.csv Normal file

@@ -0,0 +1,74 @@
field_instance,full_name,var_name
31-0.0,Sex,sex
34-0.0,Year of birth,year
48-0.0,Waist circumference,waist_circumference
49-0.0,Hip circumference,hip_circumference
50-0.0,Standing height,standing_height
52-0.0,Month of birth,month
53-0.0,Date of attending assessment centre,date_of_assessment
74-0.0,Fasting time,fasting_time
102-0.0,Pulse rate automated reading,pulse_rate
1239-0.0,Current tobacco smoking,smoking
1558-0.0,Alcohol intake frequency.,alcohol
4079-0.0,Diastolic blood pressure automated reading,dbp
4080-0.0,Systolic blood pressure automated reading,sbp
20150-0.0,Forced expiratory volume in 1-second (FEV1) Best measure,fev1_best
20151-0.0,Forced vital capacity (FVC) Best measure,fvc_best
20258-0.0,FEV1/ FVC ratio Z-score,fev1_fvc_ratio
21001-0.0,Body mass index (BMI),bmi
21003-0.0,Age when attended assessment centre,age_at_assessment
30000-0.0,White blood cell (leukocyte) count,WBC
30010-0.0,Red blood cell (erythrocyte) count,RBC
30020-0.0,Haemoglobin concentration,hemoglobin
30030-0.0,Haematocrit percentage,hematocrit
30040-0.0,Mean corpuscular volume,MCV
30050-0.0,Mean corpuscular haemoglobin,MCH
30060-0.0,Mean corpuscular haemoglobin concentration,MCHC
30080-0.0,Platelet count,Pc
30100-0.0,Mean platelet (thrombocyte) volume,MPV
30120-0.0,Lymphocyte count,LymC
30130-0.0,Monocyte count,MonC
30140-0.0,Neutrophill count,NeuC
30150-0.0,Eosinophill count,EosC
30160-0.0,Basophill count,BasC
30170-0.0,Nucleated red blood cell count,nRBC
30250-0.0,Reticulocyte count,RC
30260-0.0,Mean reticulocyte volume,MRV
30270-0.0,Mean sphered cell volume,MSCV
30280-0.0,Immature reticulocyte fraction,IRF
30300-0.0,High light scatter reticulocyte count,HLSRC
30500-0.0,Microalbumin in urine,MicU
30510-0.0,Creatinine (enzymatic) in urine,CreaU
30520-0.0,Potassium in urine,PotU
30530-0.0,Sodium in urine,SodU
30600-0.0,Albumin,Alb
30610-0.0,Alkaline phosphatase,ALP
30620-0.0,Alanine aminotransferase,Alanine
30630-0.0,Apolipoprotein A,ApoA
30640-0.0,Apolipoprotein B,ApoB
30650-0.0,Aspartate aminotransferase,AA
30660-0.0,Direct bilirubin,DBil
30670-0.0,Urea,Urea
30680-0.0,Calcium,Calcium
30690-0.0,Cholesterol,Cholesterol
30700-0.0,Creatinine,Creatinine
30710-0.0,C-reactive protein,CRP
30720-0.0,Cystatin C,CystatinC
30730-0.0,Gamma glutamyltransferase,GGT
30740-0.0,Glucose,Glu
30750-0.0,Glycated haemoglobin (HbA1c),HbA1c
30760-0.0,HDL cholesterol,HDL
30770-0.0,IGF-1,IGF1
30780-0.0,LDL direct,LDL
30790-0.0,Lipoprotein A,LpA
30800-0.0,Oestradiol,Oestradiol
30810-0.0,Phosphate,Phosphate
30820-0.0,Rheumatoid factor,Rheu
30830-0.0,SHBG,SHBG
30840-0.0,Total bilirubin,TotalBil
30850-0.0,Testosterone,Testosterone
30860-0.0,Total protein,TotalProtein
30870-0.0,Triglycerides,Tri
30880-0.0,Urate,Urate
30890-0.0,Vitamin D,VitaminD
40000-0.0,Date of death,Death

1129
icd10_codes_mod.tsv Normal file

File diff suppressed because it is too large

1257
labels.csv Normal file

File diff suppressed because it is too large

210
losses.py Normal file

@@ -0,0 +1,210 @@
import math
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# Pair extraction (utility; not used by the losses below)
# ============================================================
def get_valid_pairs_and_dt(
event_seqs: torch.Tensor,
time_seqs: torch.Tensor,
n_tech_tokens: int
) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Extract valid event pairs (prev -> next) and compute dt in years.
Args:
event_seqs (torch.Tensor): Event sequences.
time_seqs (torch.Tensor): Time sequences.
n_tech_tokens (int): Number of technical tokens.
Returns:
Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
(dt, b_prev, t_prev, b_next, t_next) if valid pairs exist, else None.
Notes:
- Assumes strict right-padding.
- Filters to next events that are disease tokens: token_id >= n_tech_tokens.
- Filters to strictly positive dt.
"""
real_mask = event_seqs >= 1
idx = real_mask.nonzero(as_tuple=False)
if idx.size(0) <= 1:
return None
same_batch = idx[1:, 0] == idx[:-1, 0]
if not same_batch.any():
return None
prev_idx = idx[:-1][same_batch]
next_idx = idx[1:][same_batch]
b_next, t_next = next_idx[:, 0], next_idx[:, 1]
valid_target = event_seqs[b_next, t_next] >= n_tech_tokens
if not valid_target.any():
return None
prev_idx = prev_idx[valid_target]
next_idx = next_idx[valid_target]
b_prev, t_prev = prev_idx[:, 0], prev_idx[:, 1]
b_next, t_next = next_idx[:, 0], next_idx[:, 1]
dt = (time_seqs[b_next, t_next] -
time_seqs[b_prev, t_prev]).to(torch.float32) / 365.25
valid_dt = dt > 0
if not valid_dt.any():
return None
dt = dt[valid_dt]
b_prev = b_prev[valid_dt]
t_prev = t_prev[valid_dt]
b_next = b_next[valid_dt]
t_next = t_next[valid_dt]
return dt, b_prev, t_prev, b_next, t_next
# ============================================================
# Losses (clean interface): loss_fn(preds, target_events, dt) -> (nll, regularization)
# ============================================================
class ExponentialNLLLoss(nn.Module):
"""
Competing risks exponential likelihood.
The negative log-likelihood is given by:
.. math::
\\text{nll} = -\\log \\lambda_{k^*} + \\left(\\sum_k \\lambda_k\\right) \\cdot dt
    Args:
        lambda_reg (float): Weight for a cross-entropy regularizer on the
            implied next-event distribution.
        eps (float): Small epsilon for numerical stability.
"""
def __init__(
self,
lambda_reg: float = 0.0,
eps: float = 1e-6,
):
super().__init__()
self.eps = eps
self.lambda_reg = lambda_reg
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass.
Args:
logits (torch.Tensor): (M, K) tensor of logits.
target_events (torch.Tensor): (M,) int64 tensor of target events in [0, K).
dt (torch.Tensor): (M,) float tensor of time intervals (years), strictly positive.
reduction (str): 'mean', 'sum', or 'none'.
Returns:
            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization); the
            regularization term is zero when lambda_reg == 0.
"""
logits = logits.squeeze(-1) if logits.dim() == 3 else logits
hazards = F.softplus(logits) + self.eps # (M,K)
hazard_event = hazards.gather(
1, target_events.unsqueeze(1)).squeeze(1) # (M,)
total_hazard = hazards.sum(dim=1) # (M,)
log_hazards = torch.log(hazards) # (M, K)
nll = -torch.log(hazard_event) + total_hazard * dt
if reduction == "mean":
nll = nll.mean()
elif reduction == "sum":
nll = nll.sum()
        if self.lambda_reg > 0:
            # softmax over log-hazards gives the probability that event k fires
            # first under competing exponential risks, so this is a categorical NLL
            reg = F.cross_entropy(log_hazards, target_events,
                                  reduction="mean") * self.lambda_reg
        else:
            reg = logits.new_zeros(())
        return nll, reg
class WeibullNLLLoss(nn.Module):
"""
Weibull hazard in t.
.. math::
\\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k}
\\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k-1}
Args:
eps (float): Small epsilon for numerical stability.
lambda_reg (float): Regularization weight.
"""
def __init__(
self,
eps: float = 1e-6,
lambda_reg: float = 0.0,
):
super().__init__()
self.eps = eps
self.lambda_reg = lambda_reg
def forward(
self,
logits: torch.Tensor,
target_events: torch.Tensor,
dt: torch.Tensor,
reduction: str = "mean",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass.
Args:
logits (torch.Tensor): (M, K, 2) tensor of logits.
target_events (torch.Tensor): (M,) tensor of target events.
dt (torch.Tensor): (M,) tensor of time intervals.
reduction (str): 'mean', 'sum', or 'none'.
Returns:
Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
"""
shapes = F.softplus(logits[..., 0]) + self.eps # (M,K)
scales = F.softplus(logits[..., 1]) + self.eps # (M,K)
eps = self.eps
t = torch.clamp(dt, min=eps)
t_mat = t.unsqueeze(1) # (M,1)
# cumulative hazard (M,K)
cum_hazard = scales * t_mat.pow(shapes)
# hazard (M,K)
hazard = shapes * scales * t_mat.pow(shapes - 1.0)
hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
# Point-event likelihood: f_k(t) = \lambda_k(t) * exp(-\Lambda_total(t))
# NLL_point = -log \lambda_{k*}(t) + \Lambda_total(t)
nll = -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1)
if reduction == "mean":
nll = nll.mean()
elif reduction == "sum":
nll = nll.sum()
reg = shapes.new_zeros(())
if self.lambda_reg > 0:
reg = self.lambda_reg * (
(torch.log(scales + eps) ** 2).mean() +
(torch.log(shapes + eps) ** 2).mean()
)
return nll, reg
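
A numeric sanity check of ExponentialNLLLoss against its closed form, nll = -log lambda_{k*} + (sum_k lambda_k) * dt (illustrative values):

import torch
import torch.nn.functional as F
from losses import ExponentialNLLLoss

loss_fn = ExponentialNLLLoss(lambda_reg=0.0, eps=1e-6)
logits = torch.randn(4, 3)             # M=4 pairs, K=3 competing events
targets = torch.tensor([0, 2, 1, 0])   # observed next event per pair
dt = torch.rand(4) + 0.1               # years, strictly positive

nll, reg = loss_fn(logits, targets, dt)
hazards = F.softplus(logits) + 1e-6
expected = (-torch.log(hazards.gather(1, targets.unsqueeze(1)).squeeze(1))
            + hazards.sum(dim=1) * dt).mean()
assert torch.allclose(nll, expected)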

440
model.py Normal file

@@ -0,0 +1,440 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block
from typing import Optional, List
import numpy as np
class TabularEncoder(nn.Module):
"""
Encoder for tabular features (continuous and categorical).
Args:
n_embd (int): Embedding dimension.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
"""
def __init__(
self,
n_embd: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
):
super().__init__()
self.n_embd = n_embd
self.n_cont = n_cont
self.n_cate = n_cate
if n_cont > 0:
hidden = 2 * n_embd
self.cont_mlp = nn.Sequential(
nn.Linear(2 * n_cont, hidden),
nn.GELU(),
nn.Linear(hidden, n_embd),
)
else:
self.cont_mlp = None
if n_cate > 0:
assert len(cate_dims) == n_cate, \
"Length of cate_dims must match n_cate"
self.cate_embds = nn.ModuleList([
nn.Embedding(dim, n_embd) for dim in cate_dims
])
self.cate_mask_embds = nn.ModuleList([
nn.Embedding(2, n_embd) for _ in range(n_cate)
])
else:
self.cate_embds = None
self.cate_mask_embds = None
self.cont_mask_proj = (
nn.Linear(n_cont, n_embd) if n_cont > 0 else None
)
self.film = nn.Sequential(
nn.Linear(n_embd, 2 * n_embd),
nn.GELU(),
nn.Linear(2 * n_embd, 2 * n_embd),
)
self.apply(self._init_weights)
self.out_ln = nn.LayerNorm(n_embd)
# Zero-init the last layer of FiLM to start with identity modulation
with torch.no_grad():
last_linear = self.film[-1]
last_linear.weight.zero_()
last_linear.bias.zero_()
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
cont_features: Optional[torch.Tensor],
cate_features: Optional[torch.Tensor],
) -> torch.Tensor:
if self.n_cont == 0 and self.n_cate == 0:
# infer (B, L) from whichever input is not None
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError(
"TabularEncoder received no features but cannot infer (B, L)."
)
return torch.zeros(B, L, self.n_embd, device=device)
value_parts: List[torch.Tensor] = []
mask_parts: List[torch.Tensor] = []
if self.n_cont > 0 and cont_features is not None:
if cont_features.dim() != 3:
raise ValueError(
"cont_features must be 3D tensor (B, L, n_cont)")
B, L, D_cont = cont_features.shape
if D_cont != self.n_cont:
raise ValueError(
f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")
cont_mask = (~torch.isnan(cont_features)).float()
cont_filled = torch.nan_to_num(cont_features, nan=0.0)
cont_joint = torch.cat([cont_filled, cont_mask], dim=-1)
h_cont_value = self.cont_mlp(cont_joint)
value_parts.append(h_cont_value)
if self.cont_mask_proj is not None:
h_cont_mask = self.cont_mask_proj(cont_mask)
mask_parts.append(h_cont_mask)
if self.n_cate > 0 and cate_features is not None:
if cate_features.dim() != 3:
raise ValueError(
"cate_features must be 3D tensor (B, L, n_cate)")
B, L, D_cate = cate_features.shape
if D_cate != self.n_cate:
raise ValueError(
f"Expected cate_features last dim to be {self.n_cate}, got {D_cate}")
for i in range(self.n_cate):
cate_feat = cate_features[:, :, i]
cate_embd = self.cate_embds[i]
cate_mask_embd = self.cate_mask_embds[i]
cate_value = cate_embd(
torch.clamp(cate_feat, min=0))
cate_mask = (cate_feat > 0).long()
cate_mask_value = cate_mask_embd(cate_mask)
value_parts.append(cate_value)
mask_parts.append(cate_mask_value)
if not value_parts:
if cont_features is not None:
B, L = cont_features.shape[:2]
device = cont_features.device
elif cate_features is not None:
B, L = cate_features.shape[:2]
device = cate_features.device
else:
raise ValueError("No features provided to TabularEncoder.")
return torch.zeros(B, L, self.n_embd, device=device)
h_value = torch.stack(value_parts, dim=0).mean(dim=0)
h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)
h_mask_flat = h_mask.view(-1, self.n_embd)
film_params = self.film(h_mask_flat)
gamma_delta, beta = film_params.chunk(2, dim=-1)
gamma = 1.0 + gamma_delta
h_value_flat = h_value.view(-1, self.n_embd)
h_out = gamma * h_value_flat + beta
h_out = h_out.view(B, L, self.n_embd)
h_out = self.out_ln(h_out)
return h_out
def _build_time_padding_mask(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
) -> torch.Tensor:
t_i = time_seq.unsqueeze(-1)
t_j = time_seq.unsqueeze(1)
    time_mask = (t_j <= t_i)  # allow attending only to past or current events
    key_is_valid = (event_seq != 0)  # disallow padded key positions
    allowed = time_mask & key_is_valid.unsqueeze(1)
    # F.scaled_dot_product_attention expects a boolean mask that is True where
    # attention IS allowed, so return `allowed` directly rather than its negation
    return allowed.unsqueeze(1)  # (B, 1, L, L)
class DelphiFork(nn.Module):
"""
DelphiFork model for time-to-event prediction.
Args:
n_disease (int): Number of disease tokens.
n_tech_tokens (int): Number of technical tokens.
n_embd (int): Embedding dimension.
n_head (int): Number of attention heads.
n_layer (int): Number of transformer layers.
n_cont (int): Number of continuous features.
n_cate (int): Number of categorical features.
cate_dims (List[int]): List of dimensions for each categorical feature.
age_encoder_type (str): Type of age encoder ("sinusoidal" or "mlp").
pdrop (float): Dropout probability.
token_pdrop (float): Token dropout probability.
n_dim (int): Dimension of theta parameters.
"""
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
        # occ[b, t] = 0-based occurrence index of the DOA token up to position t;
        # values at non-mask positions are meaningless and are zeroed below
        occ = torch.cumsum(mask.to(torch.long), dim=1) - 1  # (B, L), DOA: 0,1,2,...
        # clamp indices to at most Lc-1 and force non-mask positions to 0
        # so the gather below never reads out of range
        tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
        tab_idx = tab_idx.masked_fill(~mask, 0)  # (B, L)
        # gather along dim=1 from (B, Lc, D) the tabular embedding to inject
        # at each position -> (B, L, D)
        tab_inject = table_embds.gather(
            dim=1,
            index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
        )
        # replace token embeddings only at DOA positions (mask == True)
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
class SapDelphi(nn.Module):
    """
    DelphiFork variant whose token embeddings can be initialized from
    pretrained SapBERT ICD-10 label embeddings (see embed_icd10.py).
    """
def __init__(
self,
n_disease: int,
n_tech_tokens: int,
n_embd: int,
n_head: int,
n_layer: int,
n_cont: int,
n_cate: int,
cate_dims: List[int],
age_encoder_type: str = "sinusoidal",
pdrop: float = 0.0,
token_pdrop: float = 0.0,
n_dim: int = 1,
        pretrained_weights_path: Optional[str] = None,  # path to pretrained token embeddings (.npy)
        freeze_embeddings: bool = False,  # False = fine-tune the pretrained embeddings
):
super().__init__()
self.vocab_size = n_disease + n_tech_tokens
self.n_tech_tokens = n_tech_tokens
self.n_disease = n_disease
self.n_embd = n_embd
self.n_head = n_head
self.n_dim = n_dim
if pretrained_weights_path is not None:
print(
f"Loading pretrained embeddings from {pretrained_weights_path}...")
bert_weights = np.load(pretrained_weights_path)
bert_weights = torch.tensor(bert_weights, dtype=torch.float32)
            vocab_dim = bert_weights.shape[1]  # typically 768
            pad_emb = torch.zeros(1, vocab_dim)  # padding token stays zero
            tech_embs = nn.init.normal_(torch.empty(
                n_tech_tokens - 1, vocab_dim))  # random init for the remaining technical tokens
            full_emb_weights = torch.cat(
                [pad_emb, tech_embs, bert_weights], dim=0)
self.token_embedding = nn.Embedding.from_pretrained(
full_emb_weights, freeze=freeze_embeddings)
print("Pretrained embeddings loaded.")
if vocab_dim != n_embd:
self.emb_proj = nn.Sequential(
nn.Linear(vocab_dim, n_embd, bias=False),
nn.LayerNorm(n_embd),
nn.Dropout(pdrop),
)
else:
self.emb_proj = nn.Identity()
else:
self.token_embedding = nn.Embedding(
self.vocab_size, n_embd, padding_idx=0)
self.emb_proj = nn.Identity()
if age_encoder_type == "sinusoidal":
self.age_encoder = AgeSinusoidalEncoder(n_embd)
elif age_encoder_type == "mlp":
self.age_encoder = AgeMLPEncoder(n_embd)
else:
raise ValueError(
f"Unsupported age_encoder_type: {age_encoder_type}")
self.sex_encoder = nn.Embedding(2, n_embd)
self.tabular_encoder = TabularEncoder(
n_embd, n_cont, n_cate, cate_dims)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
# Head layers
self.theta_proj = nn.Linear(n_embd, n_disease * n_dim)
def forward(
self,
event_seq: torch.Tensor, # (B, L)
time_seq: torch.Tensor, # (B, L)
sex: torch.Tensor, # (B,)
cont_seq: torch.Tensor, # (B, Lc, n_cont)
cate_seq: torch.Tensor, # (B, Lc, n_cate)
b_prev: Optional[torch.Tensor] = None, # (M,)
t_prev: Optional[torch.Tensor] = None, # (M,)
) -> torch.Tensor:
token_embds = self.token_embedding(event_seq) # (B, L, Vocab_dim)
token_embds = self.emb_proj(token_embds) # (B, L, D)
age_embds = self.age_encoder(time_seq) # (B, L, D)
sex_embds = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1, D)
table_embds = self.tabular_encoder(cont_seq, cate_seq) # (B, Lc, D)
mask = (event_seq == 1) # (B, L)
B, L = event_seq.shape
Lc = table_embds.size(1)
D = table_embds.size(2)
        # occ[b, t] = 0-based occurrence index of the DOA token up to position t;
        # values at non-mask positions are meaningless and are zeroed below
        occ = torch.cumsum(mask.to(torch.long), dim=1) - 1  # (B, L), DOA: 0,1,2,...
        # clamp indices to at most Lc-1 and force non-mask positions to 0
        # so the gather below never reads out of range
        tab_idx = occ.clamp(min=0, max=max(Lc - 1, 0))
        tab_idx = tab_idx.masked_fill(~mask, 0)  # (B, L)
        # gather along dim=1 from (B, Lc, D) the tabular embedding to inject
        # at each position -> (B, L, D)
        tab_inject = table_embds.gather(
            dim=1,
            index=tab_idx.unsqueeze(-1).expand(-1, -1, D)
        )
        # replace token embeddings only at DOA positions (mask == True)
final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
x = final_embds + age_embds + sex_embds # (B, L, D)
x = self.token_dropout(x)
attn_mask = _build_time_padding_mask(
event_seq, time_seq)
for block in self.blocks:
x = block(x, attn_mask=attn_mask)
x = self.ln_f(x)
if b_prev is not None and t_prev is not None:
M = b_prev.numel()
c = x[b_prev, t_prev] # (M, D)
theta = self.theta_proj(c) # (M, N_disease * n_dim)
theta = theta.view(M, self.n_disease, self.n_dim)
return theta
else:
return x
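
A sketch of one training step wiring model and loss together (all sizes illustrative; n_tech_tokens=2 assumes codes 0 = padding and 1 = DOA, matching dataset.py):

import torch
from model import DelphiFork
from losses import ExponentialNLLLoss, get_valid_pairs_and_dt

model = DelphiFork(n_disease=10, n_tech_tokens=2, n_embd=32, n_head=4,
                   n_layer=2, n_cont=5, n_cate=3, cate_dims=[4, 6, 3], n_dim=1)
loss_fn = ExponentialNLLLoss()

event_seq = torch.tensor([[1, 3, 7, 0], [2, 1, 5, 9]])       # (B, L), 0 = pad
time_seq = torch.tensor([[18000., 18400., 19000., 36525.],
                         [17000., 18100., 18500., 19100.]])  # days since birth
sex = torch.tensor([0, 1])
cont_seq = torch.randn(2, 1, 5)            # one tabular row per patient (the DOA row)
cate_seq = torch.randint(0, 3, (2, 1, 3))

pairs = get_valid_pairs_and_dt(event_seq, time_seq, n_tech_tokens=2)
if pairs is not None:
    dt, b_prev, t_prev, b_next, t_next = pairs
    theta = model(event_seq, time_seq, sex, cont_seq, cate_seq, b_prev, t_prev)  # (M, 10, 1)
    targets = event_seq[b_next, t_next] - 2  # shift disease codes into [0, n_disease)
    nll, reg = loss_fn(theta, targets, dt)
    (nll + reg).backward()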

26
prepare_data.R Normal file

@@ -0,0 +1,26 @@
library(data.table)
setDTthreads(40)
library(readr)
field_id <- read.csv("field_id.txt", header = FALSE)
uid <- field_id$V1
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HHdata_221103_0512.csv"
header_dt <- fread(big_path, nrows = 0) # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names,uid)
ukb_disease <- fread(big_path,
select = keep_names,
showProgress = TRUE)
# field_id/uid already loaded above; reuse them for the second extract
big_path <- "/mnt/storage/shared_data/UKBB/20230518-from-zhourong/HH_data_220812_0512.csv"
header_dt <- fread(big_path, nrows = 0) # Read 0 rows => only column names
all_names <- names(header_dt)
keep_names <- intersect(all_names,uid)
ukb_others <- fread(big_path,
select = keep_names,
showProgress = TRUE)
# merge disease and other data by "eid"
ukb_data <- merge(ukb_disease, ukb_others, by = "eid", all = TRUE)
fwrite(ukb_data, "ukb_data.csv")

215
prepare_data.py Normal file

@@ -0,0 +1,215 @@
import pandas as pd # Pandas for data manipulation
import tqdm # Progress bar for chunk processing
import numpy as np # Numerical operations
# CSV mapping field IDs to human-readable names
field_map_file = "field_ids_enriched.csv"
field_dict = {} # Map original field ID -> new column name
tabular_fields = [] # List of tabular feature column names
with open(field_map_file, "r", encoding="utf-8") as f: # Open the field mapping file
next(f) # skip header line
for line in f: # Iterate through lines
parts = line.strip().split(",") # Split by CSV commas
        if len(parts) >= 3:  # Need field_instance, full_name, and var_name columns
# Original field identifier (e.g., "34-0.0")
field_id = parts[0]
field_name = parts[2] # Human-readable column name
field_dict[field_id] = field_name # Record the mapping
# Track as a potential tabular feature
tabular_fields.append(field_name)
# Exclude raw date parts and target columns
exclude_fields = ['year', 'month', 'Death', 'age_at_assessment']
tabular_fields = [
# Filter out excluded columns
field for field in tabular_fields if field not in exclude_fields]
# TSV mapping field IDs to ICD10-related date columns
field_to_icd_map = "icd10_codes_mod.tsv"
# Date-like variables to be converted to offsets
date_vars = []
with open(field_to_icd_map, "r", encoding="utf-8") as f: # Open ICD10 mapping
for line in f: # Iterate each mapping row
parts = line.strip().split() # Split on whitespace for TSV
if len(parts) >= 6: # Guard against malformed lines
# Map field ID to the date column name
field_dict[parts[0]] = parts[5]
date_vars.append(parts[5]) # Track date column names in order
for j in range(17): # Map up to 17 cancer entry slots (dates and types)
# Cancer diagnosis date slot j
field_dict[f'40005-{j}.0'] = f'cancer_date_{j}'
field_dict[f'40006-{j}.0'] = f'cancer_type_{j}' # Cancer type/code slot j
# Number of ICD-related date columns before adding extras
len_icd = len(date_vars)
date_vars.extend(['Death', 'date_of_assessment'] + # Add outcome date and assessment date
# Add cancer date columns
[f'cancer_date_{j}' for j in range(17)])
labels_file = "labels.csv" # File listing label codes
label_dict = {} # Map code string -> integer label id
with open(labels_file, "r", encoding="utf-8") as f: # Open labels file
for idx, line in enumerate(f): # Enumerate to assign incremental label IDs
parts = line.strip().split(' ') # Split by space
if parts and parts[0]: # Guard against empty lines
            # Map code to index; 0 is reserved for padding, 1 for the DOA (assessment) event
label_dict[parts[0]] = idx + 2
event_list = [] # Accumulator for event arrays across chunks
tabular_list = [] # Accumulator for tabular feature DataFrames across chunks
ukb_iterator = pd.read_csv( # Stream UK Biobank data in chunks
"ukb_data.csv",
sep=',',
chunksize=10000, # Stream file in manageable chunks to reduce memory footprint
# First column (participant ID) becomes DataFrame index
index_col=0,
low_memory=False # Disable type inference optimization for consistent dtypes
)
# Iterate chunks with progress
for ukb_chunk in tqdm.tqdm(ukb_iterator, desc="Processing UK Biobank data"):
# Rename columns to friendly names
ukb_chunk = ukb_chunk.rename(columns=field_dict)
# Require sex to be present
ukb_chunk.dropna(subset=['sex'], inplace=True)
# Construct date of birth from year and month (day fixed to 1)
ukb_chunk['day'] = 1
ukb_chunk['dob'] = pd.to_datetime(
# Guard against malformed dates
ukb_chunk[['year', 'month', 'day']], errors='coerce'
)
del ukb_chunk['day']
# Use only date variables that actually exist in the current chunk
present_date_vars = [c for c in date_vars if c in ukb_chunk.columns]
# Convert date-like columns to datetime and compute day offsets from dob
if present_date_vars:
date_cols = ukb_chunk[present_date_vars].apply(
pd.to_datetime, format="%Y-%m-%d", errors='coerce' # Parse dates safely
)
date_cols_days = date_cols.sub(
ukb_chunk['dob'], axis=0) # Timedelta relative to dob
ukb_chunk[present_date_vars] = date_cols_days.apply(
lambda x: x.dt.days) # Store days since dob
# Append tabular features (use only columns that exist)
present_tabular_fields = [
c for c in tabular_fields if c in ukb_chunk.columns]
tabular_list.append(ukb_chunk[present_tabular_fields].copy())
    # Process disease events from ICD10-related date columns
    # Take the ICD date cols plus 'Death'; filter from date_vars directly so a
    # missing ICD column cannot shift 'date_of_assessment' into the slice
    icd10_cols = [c for c in date_vars[:len_icd] + ['Death']
                  if c in ukb_chunk.columns]
# Melt to long form: participant id, event code (column name), and days offset
melted_df = ukb_chunk.reset_index().melt(
id_vars=['eid'],
value_vars=icd10_cols,
var_name='event_code',
value_name='days',
)
# Require non-missing day offsets
melted_df.dropna(subset=['days'], inplace=True)
if not melted_df.empty:
melted_df['label'] = melted_df['event_code'].map(
label_dict) # Map event code to numeric label
# Fix: ensure labels exist before int cast
melted_df.dropna(subset=['label'], inplace=True)
if not melted_df.empty:
event_list.append(
melted_df[['eid', 'days', 'label']]
.astype(int) # Safe now since label and days are non-null
.to_numpy()
)
# Optimized cancer processing without wide_to_long
cancer_frames = []
for j in range(17):
d_col = f'cancer_date_{j}'
t_col = f'cancer_type_{j}'
if d_col in ukb_chunk.columns and t_col in ukb_chunk.columns:
# Filter rows where both date and type are present
mask = ukb_chunk[d_col].notna() & ukb_chunk[t_col].notna()
if mask.any():
subset_idx = ukb_chunk.index[mask]
subset_days = ukb_chunk.loc[mask, d_col]
subset_type = ukb_chunk.loc[mask, t_col]
# Map cancer type to label
# Use first 3 chars
cancer_codes = subset_type.str.slice(0, 3)
labels = cancer_codes.map(label_dict)
# Filter valid labels
valid_label_mask = labels.notna()
if valid_label_mask.any():
# Create array: eid, days, label
# Ensure types are correct for numpy
c_eids = subset_idx[valid_label_mask].values
c_days = subset_days[valid_label_mask].values
c_labels = labels[valid_label_mask].values
# Stack
chunk_cancer_data = np.column_stack(
(c_eids, c_days, c_labels))
cancer_frames.append(chunk_cancer_data)
if cancer_frames:
event_list.append(np.vstack(cancer_frames))
# Combine tabular chunks
final_tabular = pd.concat(tabular_list, axis=0, ignore_index=False)
final_tabular.index.name = 'eid' # Ensure index named consistently
data = np.vstack(event_list) # Stack all event arrays into one
# Sort by participant then day
data = data[np.lexsort((data[:, 1], data[:, 0]))]
# Keep only events with non-negative day offsets
data = data[data[:, 1] >= 0]
# Remove duplicate (participant_id, label) pairs keeping first occurrence.
data = pd.DataFrame(data).drop_duplicates([0, 2]).values
# Store compactly using unsigned 32-bit integers
data = data.astype(np.uint32)
# Select eid in both data and tabular
valid_eids = np.intersect1d(data[:, 0], final_tabular.index)
data = data[np.isin(data[:, 0], valid_eids)]
final_tabular = final_tabular.loc[valid_eids]
final_tabular = final_tabular.convert_dtypes()
# Save [eid, sex, date_of_assessment] for basic info
basic_info = final_tabular[['sex', 'date_of_assessment']]
basic_info.to_csv("ukb_basic_info.csv")
# Drop sex and date_of_assessment from tabular features
final_tabular = final_tabular.drop(columns=['sex', 'date_of_assessment'])
# Process categorical columns in tabular features
# If a column is integer type with few unique values, treat as categorical. For each integer column:
# Count unique values (exclude NaN, and negative values if any) as C, set NaN or negative to 0, remap original values to [1..C].
for col in final_tabular.select_dtypes(include=['Int64', 'int64']).columns:
# Get unique values efficiently
series = final_tabular[col]
unique_vals = series.dropna().unique()
# Filter negatives from unique values
valid_vals = sorted([v for v in unique_vals if v >= 0])
if len(valid_vals) <= 10: # Threshold for categorical
# Create mapping
val_map = {val: idx + 1 for idx, val in enumerate(valid_vals)}
# Map values. Values not in val_map (negatives, NaNs) become NaN
mapped_col = series.map(val_map)
# Fill NaN with 0 and convert to uint32
final_tabular[col] = mapped_col.fillna(0).astype(np.uint32)
# Save processed tabular features
final_tabular.to_csv("ukb_table.csv")
# Save event data
np.save("ukb_event_data.npy", data)
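
The script leaves three artifacts on disk with the prefix HealthDataset expects; a quick sketch of loading them back:

import numpy as np
import pandas as pd

events = np.load("ukb_event_data.npy")   # (N_events, 3) uint32: eid, days since birth, label code
basic = pd.read_csv("ukb_basic_info.csv", index_col="eid")  # sex, date_of_assessment
table = pd.read_csv("ukb_table.csv", index_col="eid")       # remaining tabular features
assert events.shape[1] == 3 and events.dtype == np.uint32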