Refactor data preparation and add loss functions for model training

- Removed `prepare_data.py` as it is no longer needed.
- Introduced `losses.py`, containing ExponentialNLLLoss and WeibullLosses classes that compute negative log-likelihood losses with a cross-entropy regularization term.
- Added `model.py`, which defines the DelphiFork model architecture, including a tabular encoder for continuous and categorical covariates and a utility that merges event and covariate sequences by time order.
2025-12-05 00:54:56 +08:00
parent 9ca8909e3a
commit cb7adb70d9
6 changed files with 445 additions and 1486 deletions

40
age_encoder.py Normal file

@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
class AgeSinusoidalEncoder(nn.Module):
def __init__(self, n_embd: int):
super().__init__()
if n_embd % 2 != 0:
raise ValueError("n_embd must be even for sinusoidal encoding.")
self.n_embd = n_embd
i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
divisor = torch.pow(10000, i / self.n_embd)
self.register_buffer('divisor', divisor)
def forward(self, ages: torch.Tensor) -> torch.Tensor:
t_years = ages / 365.25
# Broadcast (B, L, 1) against (1, 1, D/2) to get (B, L, D/2)
args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
# Interleave cos and sin along the last dimension
output = torch.zeros(
ages.shape[0], ages.shape[1], self.n_embd, device=ages.device)
output[:, :, 0::2] = torch.cos(args)
output[:, :, 1::2] = torch.sin(args)
return output
class AgeMLPEncoder(nn.Module):
def __init__(self, n_embd: int):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(2, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
)
def forward(self, ages: torch.Tensor) -> torch.Tensor:
ages = ages.unsqueeze(-1).float() # (B, L, 1)
ages_normalized = ages / 365.25 # normalize to years
log1page = torch.log1p(ages_normalized) # (B, L, 1)
ages = torch.cat([ages_normalized, log1page], dim=-1) # (B, L, 2)
output = self.mlp(ages) # (B, L, n_embd)
return output
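
For reference, a minimal usage sketch of the two encoders (the batch size, sequence length, and n_embd below are illustrative, not taken from any training config):

import torch
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder

ages_in_days = torch.tensor([[0.0, 365.25, 3652.5]])  # (B=1, L=3) ages in days
sin_enc = AgeSinusoidalEncoder(n_embd=64)
mlp_enc = AgeMLPEncoder(n_embd=64)
print(sin_enc(ages_in_days).shape)  # torch.Size([1, 3, 64])
print(mlp_enc(ages_in_days).shape)  # torch.Size([1, 3, 64])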

164
backbones.py Normal file

@@ -0,0 +1,164 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class RMSNorm(nn.Module):
def __init__(
self,
n_embd: int,
eps: float = 1e-8,
):
super().__init__()
self.n_embd = n_embd
self.eps = eps
self.weight = nn.Parameter(torch.ones(n_embd))
def forward(self, x: torch.Tensor) -> torch.Tensor:
norm_x = x.norm(2, dim=-1, keepdim=True)
rms_x = norm_x * (self.n_embd ** -0.5)
x_normed = x / (rms_x + self.eps)
return self.weight * x_normed
class SelfAttention(nn.Module):
def __init__(
self,
n_embd: int,
n_head: int,
attn_pdrop: float = 0.1,
):
super().__init__()
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
self.n_head = n_head
self.head_dim = n_embd // n_head
self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
self.attn_pdrop = attn_pdrop
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, L, D = x.shape
qkv = self.qkv_proj(x) # (B, L, 3D)
q, k, v = qkv.chunk(3, dim=-1)
def reshape_heads(t):
# (B, H, L, d)
return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)
q = reshape_heads(q)
k = reshape_heads(k)
v = reshape_heads(v)
attn = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_pdrop if self.training else 0.0, # no attention dropout at eval time
) # (B, H, L, d)
attn = attn.transpose(1, 2).contiguous().view(B, L, D) # (B, L, D)
return self.o_proj(attn)
class SwiGLUMLP(nn.Module):
def __init__(
self,
n_embd: int,
pdrop: float = 0.0,
):
super().__init__()
hidden_dim = 4 * n_embd
self.fc1 = nn.Linear(n_embd, 2 * hidden_dim, bias=False)
self.fc2 = nn.Linear(hidden_dim, n_embd, bias=False)
self.dropout = nn.Dropout(pdrop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1, x2 = self.fc1(x).chunk(2, dim=-1)
# SwiGLU: silu(x1) * x2
x = F.silu(x1) * x2
x = self.fc2(x)
return self.dropout(x)
class Block(nn.Module):
def __init__(
self,
n_embd: int,
n_head: int,
pdrop: float = 0.0,
):
super().__init__()
attn_pdrop = pdrop
self.norm_1 = nn.LayerNorm(n_embd)
self.attn = SelfAttention(
n_embd=n_embd,
n_head=n_head,
attn_pdrop=attn_pdrop,
)
self.norm_2 = nn.LayerNorm(n_embd)
self.mlp = nn.ModuleDict(dict(
c_fc=nn.Linear(n_embd, 4 * n_embd),
c_proj=nn.Linear(4 * n_embd, n_embd),
act=nn.GELU(),
dropout=nn.Dropout(pdrop),
))
m = self.mlp
self.mlpf = lambda x: m.dropout(
m.c_proj(m.act(m.c_fc(x))))
self.resid_dropout = nn.Dropout(pdrop)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Attention
h = self.norm_1(x)
h = self.attn(h, attn_mask=attn_mask)
x = x + self.resid_dropout(h)
# MLP
h = self.norm_2(x)
h = self.mlpf(h)
x = x + self.resid_dropout(h)
return x
class ModernBlock(nn.Module):
def __init__(
self,
n_embd: int,
n_head: int,
pdrop: float = 0.0,
):
super().__init__()
attn_pdrop = pdrop
mlp_pdrop = pdrop
self.norm_1 = RMSNorm(n_embd)
self.attn = SelfAttention(
n_embd=n_embd,
n_head=n_head,
attn_pdrop=attn_pdrop,
)
self.norm_2 = RMSNorm(n_embd)
self.mlp = SwiGLUMLP(n_embd=n_embd, pdrop=mlp_pdrop)
self.resid_dropout = nn.Dropout(pdrop)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
h = self.norm_1(x)
h = self.attn(h, attn_mask=attn_mask)
x = x + self.resid_dropout(h)
# MLP
h = self.norm_2(x)
h = self.mlp(h)
x = x + self.resid_dropout(h)
return x
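
A small sketch of stacking these blocks under a causal attention mask (shapes and hyperparameters are illustrative assumptions):

import torch
from backbones import Block, ModernBlock

B, L, D = 2, 16, 64
x = torch.randn(B, L, D)
# Boolean lower-triangular mask: each position attends only to itself and the past
causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))
block = Block(n_embd=D, n_head=4, pdrop=0.1)
modern_block = ModernBlock(n_embd=D, n_head=4, pdrop=0.1)
y = modern_block(block(x, attn_mask=causal_mask), attn_mask=causal_mask)
print(y.shape)  # torch.Size([2, 16, 64])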

216
prepare_data.py Deleted file

@@ -1,216 +0,0 @@
import pandas as pd # Pandas for data manipulation
import tqdm # Progress bar for chunk processing
import numpy as np # Numerical operations
train_frac = 0.7 # Fraction of participants for training split
val_frac = 0.15 # Fraction of participants for validation split
test_frac = 0.15 # Fraction of participants for test split
# CSV mapping field IDs to human-readable names
field_map_file = "../field_ids_enriched.csv"
field_dict = {} # Map original field ID -> new column name
with open(field_map_file, "r", encoding="utf-8") as f: # Open the field mapping file
next(f) # skip header line
for line in f: # Iterate through lines
parts = line.strip().split(",") # Split by CSV commas
if len(parts) >= 3: # Ensure we have at least id and name columns (fix: was >=2)
# Original field identifier (e.g., "34-0.0")
field_id = parts[0]
field_name = parts[2] # Human-readable column name
field_dict[field_id] = field_name # Record the mapping
# Track as a potential tabular feature
# TSV mapping field IDs to ICD10-related date columns
field_to_icd_map = "../icd10_codes_mod.tsv"
# Date-like variables to be converted to offsets
date_vars = []
with open(field_to_icd_map, "r", encoding="utf-8") as f: # Open ICD10 mapping
for line in f: # Iterate each mapping row
parts = line.strip().split() # Split on whitespace for TSV
if len(parts) >= 6: # Guard against malformed lines
# Map field ID to the date column name
field_dict[parts[0]] = parts[5]
date_vars.append(parts[5]) # Track date column names in order
for j in range(17): # Map up to 17 cancer entry slots (dates and types)
# Cancer diagnosis date slot j
field_dict[f'40005-{j}.0'] = f'cancer_date_{j}'
field_dict[f'40006-{j}.0'] = f'cancer_type_{j}' # Cancer type/code slot j
# Number of ICD-related date columns before adding extras
len_icd = len(date_vars)
date_vars.extend(['Death', 'date_of_assessment'] + # Add outcome date and assessment date
# Add cancer date columns
[f'cancer_date_{j}' for j in range(17)])
labels_file = "labels.csv" # File listing label codes
label_dict = {} # Map code string -> integer label id
with open(labels_file, "r", encoding="utf-8") as f: # Open labels file
for idx, line in enumerate(f): # Enumerate to assign incremental label IDs
parts = line.strip().split(' ') # Split by space
if parts and parts[0]: # Guard against empty lines
label_dict[parts[0]] = idx
event_list = [] # Accumulator for event arrays across chunks
ukb_iterator = pd.read_csv( # Stream UK Biobank data in chunks
"../ukb_data.csv",
sep=',',
chunksize=10000, # Stream file in manageable chunks to reduce memory footprint
# First column (participant ID) becomes DataFrame index
index_col=0,
low_memory=False # Disable type inference optimization for consistent dtypes
)
# Iterate chunks with progress
for ukb_chunk in tqdm.tqdm(ukb_iterator, desc="Processing UK Biobank data"):
# Rename columns to friendly names
ukb_chunk = ukb_chunk.rename(columns=field_dict)
# Require sex to be present
ukb_chunk.dropna(subset=['sex'], inplace=True)
ukb_chunk['sex'] += 2 # Recode sex: 0-> 2, 1 -> 3
# Construct date of birth from year and month (day fixed to 1)
ukb_chunk['dob'] = pd.to_datetime(
# Guard against malformed dates
ukb_chunk[['year', 'month']].assign(DAY=1), errors='coerce'
)
# Use only date variables that actually exist in the current chunk
present_date_vars = [c for c in date_vars if c in ukb_chunk.columns]
# Convert date-like columns to datetime and compute day offsets from dob
if present_date_vars:
date_cols = ukb_chunk[present_date_vars].apply(
pd.to_datetime, format="%Y-%m-%d", errors='coerce' # Parse dates safely
)
date_cols_days = date_cols.sub(
ukb_chunk['dob'], axis=0) # Timedelta relative to dob
ukb_chunk[present_date_vars] = date_cols_days.apply(
lambda x: x.dt.days) # Store days since dob
# Process disease events from ICD10-related date columns
# Take ICD date cols plus 'Death' if present by order
icd10_cols = present_date_vars[:len_icd + 1]
# Melt to long form: participant id, event code (column name), and days offset
melted_df = ukb_chunk.reset_index().melt(
id_vars=['eid'],
value_vars=icd10_cols,
var_name='event_code',
value_name='days',
)
# Require non-missing day offsets
melted_df.dropna(subset=['days'], inplace=True)
if not melted_df.empty:
melted_df['label'] = melted_df['event_code'].map(
label_dict) # Map event code to numeric label
# Fix: ensure labels exist before int cast
melted_df.dropna(subset=['label'], inplace=True)
if not melted_df.empty:
event_list.append(
melted_df[['eid', 'days', 'label']]
.astype(int) # Safe now since label and days are non-null
.to_numpy()
)
df_res = ukb_chunk.reset_index() # Bring participant ID out of index
# Simplify stub names for wide_to_long
# Rename date stubs
rename_dict = {f'cancer_date_{j}': f'cancerdate{j}' for j in range(17)}
rename_dict.update(
# Rename type stubs
{f'cancer_type_{j}': f'cancertype{j}' for j in range(17)})
df_renamed = df_res.rename(columns=rename_dict) # Apply renaming
stubs_to_use = [] # Collect available stubs
if any('cancerdate' in col for col in df_renamed.columns):
stubs_to_use.append('cancerdate') # Date stub present
if any('cancertype' in col for col in df_renamed.columns):
stubs_to_use.append('cancertype') # Type stub present
if len(stubs_to_use) == 2: # Only proceed if both date and type columns exist
long_cancer = pd.wide_to_long(
df_renamed,
stubnames=stubs_to_use,
i=['eid'], # Participant ID identifier
j='cancer_num' # Index over cancer record number (0..16)
).dropna() # Remove rows missing either date or type
if not long_cancer.empty:
long_cancer['cancer'] = long_cancer['cancertype'].str.slice(
0, 3) # Use first 3 chars as code
long_cancer['cancer_label'] = long_cancer['cancer'].map(
label_dict) # Map to label id
cancer_array = (
long_cancer.reset_index(
)[['eid', 'cancerdate', 'cancer_label']]
.dropna()
.astype(int)
.to_numpy()
)
if cancer_array.size > 0:
event_list.append(cancer_array) # Append cancer events
# Process BMI, smoking, and alcohol status
ukb_bmi = ukb_chunk[['date_of_assessment', 'bmi']].dropna().reset_index()
if not ukb_bmi.empty:
ukb_bmi['bmi_status'] = np.select(
[ukb_bmi['bmi'] > 28, ukb_bmi['bmi'] > 22],
[6, 5],
default=4
)
event_list.append(
ukb_bmi[['eid', 'date_of_assessment', 'bmi_status']]
.astype(int)
.to_numpy()
)
ukb_sm = ukb_chunk[['date_of_assessment', 'smoking']].dropna().reset_index()
ukb_sm = ukb_sm[ukb_sm['smoking'] != -3] # Exclude unknown smoking status
if not ukb_sm.empty:
ukb_sm['smoking_status'] = np.select(
[ukb_sm['smoking'] == 1, ukb_sm['smoking'] == 2],
[9, 8],
default=7
)
event_list.append(
ukb_sm[['eid', 'date_of_assessment', 'smoking_status']]
.astype(int)
.to_numpy()
)
ukb_al = ukb_chunk[['date_of_assessment', 'alcohol']].dropna().reset_index()
ukb_al = ukb_al[ukb_al['alcohol'] != -3] # Exclude unknown alcohol status
if not ukb_al.empty:
ukb_al['alcohol_status'] = np.select(
[ukb_al['alcohol'] == 1, ukb_al['alcohol'] < 4],
[12, 11],
default=10
)
event_list.append(
ukb_al[['eid', 'date_of_assessment', 'alcohol_status']]
.astype(int)
.to_numpy()
)
# Combine tabular chunks
data = np.vstack(event_list) # Stack all event arrays into one
# Sort by participant then day
data = data[np.lexsort((data[:, 1], data[:, 0]))]
# Keep only events with non-negative day offsets
data = data[data[:, 1] >= 0]
# Remove duplicate (participant_id, label) pairs keeping first occurrence.
data = pd.DataFrame(data).drop_duplicates([0, 2]).values
# Store compactly using unsigned 32-bit integers
data = data.astype(np.uint32)
# Split data into train/val/test based on unique participant IDs
unique_ids = np.unique(data[:, 0]) # Unique participant IDs
train_split_id = unique_ids[int(len(unique_ids) * train_frac)]
val_split_id = unique_ids[int(len(unique_ids) * (train_frac + val_frac))]
train_data = data[data[:, 0] <= train_split_id].tofile("ukb_real_train.bin")
val_data = data[(data[:, 0] > train_split_id) & (
data[:, 0] <= val_split_id)].tofile("ukb_real_val.bin")
test_data = data[data[:, 0] > val_split_id].tofile("ukb_real_test.bin")
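
The removed script wrote each split as a flat binary of uint32 triples (participant id, days since birth, label id). A hedged sketch of reading one of those files back, assuming that record layout:

import numpy as np

# Each record is (eid, days_since_birth, label_id), stored as consecutive uint32 values
events = np.fromfile("ukb_real_train.bin", dtype=np.uint32).reshape(-1, 3)
print(events.shape, events[:5])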

112
losses.py Normal file

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ExponentialNLLLoss(nn.Module):
def __init__(
self,
n_tech_tokens: int,
alpha: float = 0.1,
):
super().__init__()
self.n_tech_tokens = n_tech_tokens
self.alpha = alpha
def forward(
self,
logits: torch.Tensor,
event_seqs: torch.Tensor,
time_seqs: torch.Tensor,
) -> torch.Tensor:
# Calculate the negative log-likelihood for the exponential distribution
# 1, shift event_seqs to remove technical tokens
target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
# 2, create a mask to filter out technical tokens
mask = target_event_seqs >= 0
if not mask.any():
# if there are no valid events, return zero loss
return logits.new_zeros(())
# 3, compute time differences
dt = time_seqs[:, 1:] - time_seqs[:, :-1]
dt = dt[mask] # (N,)
# 4, filter target events
target_events = target_event_seqs[mask] # (N,)
# 5, compute hazard and total hazard
hazard = logits[:, :-1, :] # (B, L-1, vocab_size); assumed to be non-negative hazard rates (e.g. produced by a softplus/exp upstream)
hazard_at_events = hazard[mask].gather(
dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1) # (N,)
total_hazard = hazard[mask].sum(dim=-1) # (N,)
# 6, compute negative log-likelihood
nll = torch.log(hazard_at_events + 1e-6) - total_hazard * dt
nll = -nll.mean()
# 7, compute cross-entropy regularization
p_ce = hazard_at_events / total_hazard
regularization = -self.alpha * torch.log(p_ce + 1e-6).mean()
return nll + regularization
class WeibullLosses(nn.Module):
def __init__(
self,
n_tech_tokens: int,
alpha: float = 0.1,
):
super().__init__()
self.n_tech_tokens = n_tech_tokens
self.alpha = alpha
def forward(
self,
shapes: torch.Tensor,
scales: torch.Tensor,
event_seqs: torch.Tensor,
time_seqs: torch.Tensor,
) -> torch.Tensor:
# Calculate the negative log-likelihood for the Weibull distribution
# 1, shift event_seqs to remove technical tokens
target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
# 2, create a mask to filter out technical tokens
mask = target_event_seqs >= 0
if not mask.any():
# if there are no valid events, return zero loss
return shapes.new_zeros(())
# 3, compute time differences
dt = time_seqs[:, 1:] - time_seqs[:, :-1]
dt = dt[mask] # (N,)
# 4, filter target events
target_events = target_event_seqs[mask] # (N,)
shapes = shapes[mask] # (N, vocab_size)
scales = scales[mask] # (N, vocab_size)
# 5, compute shape and scale at events
shape_at_events = shapes.gather(
dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1) # (N,)
scale_at_events = scales.gather(
dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1) # (N,)
log_shapes = torch.log(shape_at_events)
log_scales = torch.log(scale_at_events)
log_dt = torch.log(dt + 1e-6)
# 6, compute negative log-likelihood
nll = log_shapes - log_scales + \
(shape_at_events - 1) * (log_dt - log_scales)
cum_hazard = (dt.unsqueeze(-1) /
scales) ** shapes # (N, vocab_size) Weibull cumulative hazard, i.e. -log survival
nll -= cum_hazard.sum(dim=-1)
nll = -nll.mean()
# 7, compute cross-entropy regularization
log_shapes_all = torch.log(shapes)
log_scales_all = torch.log(scales)
log_dt_expanded = log_dt.unsqueeze(-1)
log_hazards = log_shapes_all - log_scales_all + (shapes - 1) * \
(log_dt_expanded - log_scales_all) # (N, vocab_size)
ce_loss = F.cross_entropy(
log_hazards, target_events, reduction='mean')
return nll + self.alpha * ce_loss
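
A minimal sketch of exercising the exponential loss on dummy tensors; the vocabulary size, number of technical tokens, and the softplus used to obtain positive rates are assumptions, not fixed by this commit:

import torch
import torch.nn.functional as F
from losses import ExponentialNLLLoss

B, L, vocab, n_tech = 2, 8, 10, 3
raw = torch.randn(B, L, vocab, requires_grad=True)
rates = F.softplus(raw)  # the loss expects non-negative hazard rates
event_seqs = torch.randint(0, vocab + n_tech, (B, L))
time_seqs = torch.cumsum(torch.rand(B, L), dim=-1)  # strictly increasing times
loss = ExponentialNLLLoss(n_tech_tokens=n_tech, alpha=0.1)(rates, event_seqs, time_seqs)
loss.backward()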

129
model.py Normal file

@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block, ModernBlock, RMSNorm
class TabularEncoder(nn.Module):
def __init__(
self,
n_embd: int,
n_continuous: int,
n_categorical: int,
categorical_cardinalities: list[int],
):
super().__init__()
self.continuous_proj = nn.Linear(n_continuous, n_embd) if n_continuous > 0 else None
self.categorical_embeddings = nn.ModuleList([
nn.Embedding(cardinality, n_embd) for cardinality in categorical_cardinalities
]) if n_categorical > 0 else None
def forward(
self,
continuous_features: torch.Tensor | None,
categorical_features: list[torch.Tensor] | None,
) -> torch.Tensor:
embeddings = []
if self.continuous_proj is not None and continuous_features is not None:
cont_emb = self.continuous_proj(continuous_features)
embeddings.append(cont_emb)
if self.categorical_embeddings is not None and categorical_features is not None:
for emb_layer, cat_feat in zip(self.categorical_embeddings, categorical_features):
cat_emb = emb_layer(cat_feat)
embeddings.append(cat_emb)
if embeddings:
return torch.sum(torch.stack(embeddings, dim=0), dim=0)
else:
raise ValueError("No features provided for TabularEncoder.")
def merge_two_sequences(
time_seq1: torch.Tensor, # (B, L1)
time_seq2: torch.Tensor, # (B, L2)
seq1_embd: torch.Tensor, # (B, L1, D)
seq2_embd: torch.Tensor, # (B, L2, D)
) -> tuple[torch.Tensor, torch.Tensor]:
"""Merge two time sequences and their embeddings, jointly sorted by time; returns (sorted_times, sorted_embd)."""
B, L1 = time_seq1.shape
L2 = time_seq2.shape[1]
merged_times = torch.cat([time_seq1, time_seq2], dim=1) # (B, L1 + L2)
merged_embd = torch.cat([seq1_embd, seq2_embd], dim=1) # (B, L1 + L2, D)
sorted_times, indices = torch.sort(merged_times, dim=1) # (B, L1 + L2)
batch_indices = torch.arange(B, device=indices.device).unsqueeze(-1).expand(-1, L1 + L2) # (B, L1 + L2)
sorted_embd = merged_embd[batch_indices, indices] # (B, L1 + L2, D)
return sorted_times, sorted_embd
class DelphiFork(nn.Module):
def __init__(
self,
vocab_size: int,
n_embd: int,
n_head: int,
n_layer: int,
n_continuous: int,
n_categorical: int,
categorical_cardinalities: list[int],
pdrop: float = 0.1,
token_pdrop: float = 0.1,
):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, n_embd)
self.age_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
self.sex_encoder = nn.Embedding(2, n_embd)
self.token_dropout = nn.Dropout(token_pdrop)
self.covariate_encoder = TabularEncoder(
n_embd=n_embd,
n_continuous=n_continuous,
n_categorical=n_categorical,
categorical_cardinalities=categorical_cardinalities,
)
self.blocks = nn.ModuleList([
Block(
n_embd=n_embd,
n_head=n_head,
pdrop=pdrop,
) for _ in range(n_layer)
])
self.ln_f = nn.LayerNorm(n_embd)
self.head = nn.Linear(n_embd, vocab_size, bias=False)
self.head.weight = self.token_embedding.weight
def forward(
self,
sex: torch.Tensor,
event_seq: torch.Tensor,
age_seq: torch.Tensor,
cov_seq_time: torch.Tensor | None = None,
cont_cov_seq: torch.Tensor | None = None,
cat_cov_seq: list[torch.Tensor] | None = None,
) -> torch.Tensor:
event_emb = self.token_embedding(event_seq)
age_emb = self.age_encoder(age_seq)
sex_emb = self.sex_encoder(sex.unsqueeze(-1)) # (B, 1) -> (B, 1, n_embd)
x = event_emb + age_emb + sex_emb
if cov_seq_time is not None:
covariate_emb = self.covariate_encoder(
continuous_features=cont_cov_seq,
categorical_features=cat_cov_seq,
)
covariate_emb = covariate_emb + self.age_encoder(cov_seq_time) + sex_emb
_, x = merge_two_sequences(age_seq, cov_seq_time, x, covariate_emb) # keep the time-sorted embeddings
x = self.token_dropout(x)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
logits = self.head(x)
return logits
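
A hedged end-to-end sketch of running DelphiFork on toy inputs; the hyperparameters and the single continuous and categorical covariate are illustrative only:

import torch
from model import DelphiFork

model = DelphiFork(
    vocab_size=32, n_embd=64, n_head=4, n_layer=2,
    n_continuous=1, n_categorical=1, categorical_cardinalities=[5],
)
B, L, Lc = 2, 12, 3
sex = torch.randint(0, 2, (B,))
event_seq = torch.randint(0, 32, (B, L))
age_seq = torch.rand(B, L).cumsum(dim=-1) * 365.25   # event ages in days, increasing
cov_time = torch.rand(B, Lc).cumsum(dim=-1) * 365.25  # covariate measurement ages in days
cont_cov = torch.randn(B, Lc, 1)
cat_cov = [torch.randint(0, 5, (B, Lc))]
logits = model(sex, event_seq, age_seq, cov_time, cont_cov, cat_cov)
print(logits.shape)  # torch.Size([2, 15, 32]): event and covariate sequences merged by time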