Refactor data preparation and add loss functions for model training
- Removed `prepare_data.py`, as it is no longer needed.
- Introduced `losses.py`, containing the ExponentialNLLLoss and WeibullLosses classes for calculating negative log-likelihood losses with regularization.
- Added `model.py`, which defines the DelphiFork model architecture, including a tabular encoder for continuous and categorical features and a helper that merges event and covariate sequences in time order.
age_encoder.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn


class AgeSinusoidalEncoder(nn.Module):
    def __init__(self, n_embd: int):
        super().__init__()
        if n_embd % 2 != 0:
            raise ValueError("n_embd must be even for sinusoidal encoding.")
        self.n_embd = n_embd
        i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.n_embd)
        self.register_buffer('divisor', divisor)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        t_years = ages / 365.25
        # Broadcast (B, L, 1) against (1, 1, D/2) to get (B, L, D/2)
        args = t_years.unsqueeze(-1) / self.divisor.view(1, 1, -1)
        # Interleave cos and sin along the last dimension
        output = torch.zeros(
            ages.shape[0], ages.shape[1], self.n_embd, device=ages.device)
        output[:, :, 0::2] = torch.cos(args)
        output[:, :, 1::2] = torch.sin(args)
        return output


class AgeMLPEncoder(nn.Module):
    def __init__(self, n_embd: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        ages = ages.unsqueeze(-1).float()  # (B, L, 1)
        ages_normalized = ages / 365.25  # normalize to years
        log1page = torch.log1p(ages_normalized)  # (B, L, 1)
        ages = torch.cat([ages_normalized, log1page], dim=-1)  # (B, L, 2)
        output = self.mlp(ages)  # (B, L, n_embd)
        return output
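For context, a minimal usage sketch of the two encoders above (not part of this commit); the batch size, sequence length, and embedding width are arbitrary, and ages are assumed to be given in days, matching the 365.25-day conversion in both forward methods:

import torch

from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder

B, L, D = 2, 5, 64                                     # arbitrary batch, sequence length, embedding width
ages_days = torch.rand(B, L) * 365.25 * 80             # ages in days, as both encoders expect

sin_emb = AgeSinusoidalEncoder(n_embd=D)(ages_days)    # (B, L, D), interleaved cos/sin features
mlp_emb = AgeMLPEncoder(n_embd=D)(ages_days)           # (B, L, D), MLP over [years, log1p(years)]

assert sin_emb.shape == mlp_emb.shape == (B, L, D)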
backbones.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class RMSNorm(nn.Module):
    def __init__(
        self,
        n_embd: int,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(n_embd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x * (self.n_embd ** -0.5)
        x_normed = x / (rms_x + self.eps)
        return self.weight * x_normed


class SelfAttention(nn.Module):
    def __init__(
        self,
        n_embd: int,
        n_head: int,
        attn_pdrop: float = 0.1,
    ):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_pdrop = attn_pdrop

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, L, D = x.shape
        qkv = self.qkv_proj(x)  # (B, L, 3D)
        q, k, v = qkv.chunk(3, dim=-1)

        def reshape_heads(t):
            # (B, H, L, d)
            return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)

        q = reshape_heads(q)
        k = reshape_heads(k)
        v = reshape_heads(v)

        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_pdrop if self.training else 0.0,  # apply attention dropout only in training
        )  # (B, H, L, d)

        attn = attn.transpose(1, 2).contiguous().view(B, L, D)  # (B, L, D)
        return self.o_proj(attn)


class SwiGLUMLP(nn.Module):
    def __init__(
        self,
        n_embd: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        hidden_dim = 4 * n_embd
        self.fc1 = nn.Linear(n_embd, 2 * hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, n_embd, bias=False)
        self.dropout = nn.Dropout(pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.fc1(x).chunk(2, dim=-1)
        # SwiGLU: silu(x1) * x2
        x = F.silu(x1) * x2
        x = self.fc2(x)
        return self.dropout(x)


class Block(nn.Module):
    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop

        self.norm_1 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(
            m.c_proj(m.act(m.c_fc(x))))
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)

        # MLP
        h = self.norm_2(x)
        h = self.mlpf(h)
        x = x + self.resid_dropout(h)

        return x


class ModernBlock(nn.Module):
    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop
        mlp_pdrop = pdrop

        self.norm_1 = RMSNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = RMSNorm(n_embd)
        self.mlp = SwiGLUMLP(n_embd=n_embd, pdrop=mlp_pdrop)
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)

        # MLP
        h = self.norm_2(x)
        h = self.mlp(h)
        x = x + self.resid_dropout(h)

        return x
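For context, a minimal sketch of stacking these blocks with a causal mask (not part of this commit); F.scaled_dot_product_attention accepts a boolean mask in which True marks positions that may be attended to, so a lower-triangular mask yields autoregressive attention:

import torch
from backbones import Block, ModernBlock

B, L, D, H = 2, 10, 64, 4
x = torch.randn(B, L, D)

# Lower-triangular boolean mask: position i may attend only to positions <= i.
causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))

layers = torch.nn.ModuleList(
    [Block(n_embd=D, n_head=H, pdrop=0.1) for _ in range(2)] +
    [ModernBlock(n_embd=D, n_head=H, pdrop=0.1)]
)
for layer in layers:
    x = layer(x, attn_mask=causal_mask)
print(x.shape)  # torch.Size([2, 10, 64])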
prepare_data.py (deleted, 216 lines)
@@ -1,216 +0,0 @@
import pandas as pd  # Pandas for data manipulation
import tqdm  # Progress bar for chunk processing
import numpy as np  # Numerical operations

train_frac = 0.7  # Fraction of participants for training split
val_frac = 0.15  # Fraction of participants for validation split
test_frac = 0.15  # Fraction of participants for test split

# CSV mapping field IDs to human-readable names
field_map_file = "../field_ids_enriched.csv"
field_dict = {}  # Map original field ID -> new column name
with open(field_map_file, "r", encoding="utf-8") as f:  # Open the field mapping file
    next(f)  # skip header line
    for line in f:  # Iterate through lines
        parts = line.strip().split(",")  # Split by CSV commas
        if len(parts) >= 3:  # Ensure we have at least id and name columns (fix: was >= 2)
            # Original field identifier (e.g., "34-0.0")
            field_id = parts[0]
            field_name = parts[2]  # Human-readable column name
            field_dict[field_id] = field_name  # Record the mapping
            # Track as a potential tabular feature


# TSV mapping field IDs to ICD10-related date columns
field_to_icd_map = "../icd10_codes_mod.tsv"
# Date-like variables to be converted to offsets
date_vars = []
with open(field_to_icd_map, "r", encoding="utf-8") as f:  # Open ICD10 mapping
    for line in f:  # Iterate each mapping row
        parts = line.strip().split()  # Split on whitespace for TSV
        if len(parts) >= 6:  # Guard against malformed lines
            # Map field ID to the date column name
            field_dict[parts[0]] = parts[5]
            date_vars.append(parts[5])  # Track date column names in order

for j in range(17):  # Map up to 17 cancer entry slots (dates and types)
    # Cancer diagnosis date slot j
    field_dict[f'40005-{j}.0'] = f'cancer_date_{j}'
    field_dict[f'40006-{j}.0'] = f'cancer_type_{j}'  # Cancer type/code slot j

# Number of ICD-related date columns before adding extras
len_icd = len(date_vars)
date_vars.extend(['Death', 'date_of_assessment'] +  # Add outcome date and assessment date
                 # Add cancer date columns
                 [f'cancer_date_{j}' for j in range(17)])

labels_file = "labels.csv"  # File listing label codes
label_dict = {}  # Map code string -> integer label id
with open(labels_file, "r", encoding="utf-8") as f:  # Open labels file
    for idx, line in enumerate(f):  # Enumerate to assign incremental label IDs
        parts = line.strip().split(' ')  # Split by space
        if parts and parts[0]:  # Guard against empty lines
            label_dict[parts[0]] = idx

event_list = []  # Accumulator for event arrays across chunks
ukb_iterator = pd.read_csv(  # Stream UK Biobank data in chunks
    "../ukb_data.csv",
    sep=',',
    chunksize=10000,  # Stream file in manageable chunks to reduce memory footprint
    # First column (participant ID) becomes DataFrame index
    index_col=0,
    low_memory=False  # Disable type inference optimization for consistent dtypes
)
# Iterate chunks with progress
for ukb_chunk in tqdm.tqdm(ukb_iterator, desc="Processing UK Biobank data"):
    # Rename columns to friendly names
    ukb_chunk = ukb_chunk.rename(columns=field_dict)
    # Require sex to be present
    ukb_chunk.dropna(subset=['sex'], inplace=True)
    ukb_chunk['sex'] += 2  # Recode sex: 0 -> 2, 1 -> 3

    # Construct date of birth from year and month (day fixed to 1)
    ukb_chunk['dob'] = pd.to_datetime(
        # Guard against malformed dates
        ukb_chunk[['year', 'month']].assign(DAY=1), errors='coerce'
    )

    # Use only date variables that actually exist in the current chunk
    present_date_vars = [c for c in date_vars if c in ukb_chunk.columns]

    # Convert date-like columns to datetime and compute day offsets from dob
    if present_date_vars:
        date_cols = ukb_chunk[present_date_vars].apply(
            pd.to_datetime, format="%Y-%m-%d", errors='coerce'  # Parse dates safely
        )
        date_cols_days = date_cols.sub(
            ukb_chunk['dob'], axis=0)  # Timedelta relative to dob
        ukb_chunk[present_date_vars] = date_cols_days.apply(
            lambda x: x.dt.days)  # Store days since dob

    # Process disease events from ICD10-related date columns
    # Take ICD date cols plus 'Death' if present by order
    icd10_cols = present_date_vars[:len_icd + 1]
    # Melt to long form: participant id, event code (column name), and days offset
    melted_df = ukb_chunk.reset_index().melt(
        id_vars=['eid'],
        value_vars=icd10_cols,
        var_name='event_code',
        value_name='days',
    )
    # Require non-missing day offsets
    melted_df.dropna(subset=['days'], inplace=True)
    if not melted_df.empty:
        melted_df['label'] = melted_df['event_code'].map(
            label_dict)  # Map event code to numeric label
        # Fix: ensure labels exist before int cast
        melted_df.dropna(subset=['label'], inplace=True)
        if not melted_df.empty:
            event_list.append(
                melted_df[['eid', 'days', 'label']]
                .astype(int)  # Safe now since label and days are non-null
                .to_numpy()
            )

    df_res = ukb_chunk.reset_index()  # Bring participant ID out of index
    # Simplify stub names for wide_to_long
    # Rename date stubs
    rename_dict = {f'cancer_date_{j}': f'cancerdate{j}' for j in range(17)}
    rename_dict.update(
        # Rename type stubs
        {f'cancer_type_{j}': f'cancertype{j}' for j in range(17)})
    df_renamed = df_res.rename(columns=rename_dict)  # Apply renaming
    stubs_to_use = []  # Collect available stubs
    if any('cancerdate' in col for col in df_renamed.columns):
        stubs_to_use.append('cancerdate')  # Date stub present
    if any('cancertype' in col for col in df_renamed.columns):
        stubs_to_use.append('cancertype')  # Type stub present

    if len(stubs_to_use) == 2:  # Only proceed if both date and type columns exist
        long_cancer = pd.wide_to_long(
            df_renamed,
            stubnames=stubs_to_use,
            i=['eid'],  # Participant ID identifier
            j='cancer_num'  # Index over cancer record number (0..16)
        ).dropna()  # Remove rows missing either date or type
        if not long_cancer.empty:
            long_cancer['cancer'] = long_cancer['cancertype'].str.slice(
                0, 3)  # Use first 3 chars as code
            long_cancer['cancer_label'] = long_cancer['cancer'].map(
                label_dict)  # Map to label id
            cancer_array = (
                long_cancer.reset_index(
                )[['eid', 'cancerdate', 'cancer_label']]
                .dropna()
                .astype(int)
                .to_numpy()
            )
            if cancer_array.size > 0:
                event_list.append(cancer_array)  # Append cancer events

    # Process BMI, smoking, and alcohol status
    ukb_bmi = ukb_chunk[['date_of_assessment', 'bmi']].dropna().reset_index()
    if not ukb_bmi.empty:
        ukb_bmi['bmi_status'] = np.select(
            [ukb_bmi['bmi'] > 28, ukb_bmi['bmi'] > 22],
            [6, 5],
            default=4
        )
        event_list.append(
            ukb_bmi[['eid', 'date_of_assessment', 'bmi_status']]
            .astype(int)
            .to_numpy()
        )

    ukb_sm = ukb_chunk[['date_of_assessment', 'smoking']].dropna().reset_index()
    ukb_sm = ukb_sm[ukb_sm['smoking'] != -3]  # Exclude unknown smoking status
    if not ukb_sm.empty:
        ukb_sm['smoking_status'] = np.select(
            [ukb_sm['smoking'] == 1, ukb_sm['smoking'] == 2],
            [9, 8],
            default=7
        )
        event_list.append(
            ukb_sm[['eid', 'date_of_assessment', 'smoking_status']]
            .astype(int)
            .to_numpy()
        )

    ukb_al = ukb_chunk[['date_of_assessment', 'alcohol']].dropna().reset_index()
    ukb_al = ukb_al[ukb_al['alcohol'] != -3]  # Exclude unknown alcohol status
    if not ukb_al.empty:
        ukb_al['alcohol_status'] = np.select(
            [ukb_al['alcohol'] == 1, ukb_al['alcohol'] < 4],
            [12, 11],
            default=10
        )
        event_list.append(
            ukb_al[['eid', 'date_of_assessment', 'alcohol_status']]
            .astype(int)
            .to_numpy()
        )

# Combine event arrays from all chunks
data = np.vstack(event_list)  # Stack all event arrays into one

# Sort by participant then day
data = data[np.lexsort((data[:, 1], data[:, 0]))]

# Keep only events with non-negative day offsets
data = data[data[:, 1] >= 0]

# Remove duplicate (participant_id, label) pairs keeping first occurrence
data = pd.DataFrame(data).drop_duplicates([0, 2]).values

# Store compactly using unsigned 32-bit integers
data = data.astype(np.uint32)

# Split data into train/val/test based on unique participant IDs
unique_ids = np.unique(data[:, 0])  # Unique participant IDs
train_split_id = unique_ids[int(len(unique_ids) * train_frac)]
val_split_id = unique_ids[int(len(unique_ids) * (train_frac + val_frac))]

train_data = data[data[:, 0] <= train_split_id].tofile("ukb_real_train.bin")
val_data = data[(data[:, 0] > train_split_id) & (
    data[:, 0] <= val_split_id)].tofile("ukb_real_val.bin")
test_data = data[data[:, 0] > val_split_id].tofile("ukb_real_test.bin")
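For reference, a sketch of how the binary splits written by the removed script can be read back (not part of this commit); each event was stored as three consecutive uint32 values, (eid, days since birth, label), so the flat file reshapes to an (N, 3) array:

import numpy as np

# Each event was written as three uint32 values: eid, days, label.
events = np.fromfile("ukb_real_train.bin", dtype=np.uint32).reshape(-1, 3)
eids, days, labels = events[:, 0], events[:, 1], events[:, 2]
print(events.shape)   # (n_events, 3)
print(events[:5])     # first few (eid, days, label) rows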
losses.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExponentialNLLLoss(nn.Module):

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        logits: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the exponential distribution

        # 1. shift event_seqs to remove technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. create a mask to filter out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return logits.new_zeros(())

        # 3. compute time differences
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. filter target events
        target_events = target_event_seqs[mask]  # (N,)
        # 5. compute hazard and total hazard (logits are used directly as per-token hazard rates)
        hazard = logits[:, :-1, :]  # (B, L-1, vocab_size)
        hazard_at_events = hazard[mask].gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        total_hazard = hazard[mask].sum(dim=-1)  # (N,)
        # 6. compute negative log-likelihood
        nll = torch.log(hazard_at_events + 1e-6) - total_hazard * dt
        nll = -nll.mean()
        # 7. compute cross-entropy regularization
        p_ce = hazard_at_events / total_hazard
        regularization = -self.alpha * torch.log(p_ce + 1e-6).mean()

        return nll + regularization


class WeibullLosses(nn.Module):

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        shapes: torch.Tensor,
        scales: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the Weibull distribution

        # 1. shift event_seqs to remove technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. create a mask to filter out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return shapes.new_zeros(())

        # 3. compute time differences
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. filter target events
        target_events = target_event_seqs[mask]  # (N,)
        # shapes/scales must align with the shifted targets for this masking to work: (B, L-1, vocab_size)
        shapes = shapes[mask]  # (N, vocab_size)
        scales = scales[mask]  # (N, vocab_size)
        # 5. compute shape and scale at events
        shape_at_events = shapes.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        scale_at_events = scales.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        log_shapes = torch.log(shape_at_events)
        log_scales = torch.log(scale_at_events)
        log_dt = torch.log(dt + 1e-6)
        # 6. compute negative log-likelihood
        nll = log_shapes - log_scales + \
            (shape_at_events - 1) * (log_dt - log_scales)
        # cumulative hazard (negative log-survival), summed over all causes
        log_tot_survival = (dt.unsqueeze(-1) /
                            scales) ** shapes  # (N, vocab_size)
        nll -= log_tot_survival.sum(dim=-1)
        nll = -nll.mean()
        # 7. compute cross-entropy regularization over per-cause log-hazards
        log_shapes_all = torch.log(shapes)
        log_scales_all = torch.log(scales)
        log_dt_expanded = log_dt.unsqueeze(-1)

        log_hazards = log_shapes_all - log_scales_all + (shapes - 1) * \
            (log_dt_expanded - log_scales_all)  # (N, vocab_size)
        ce_loss = F.cross_entropy(
            log_hazards, target_events, reduction='mean')

        return nll + self.alpha * ce_loss
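For context, a minimal sketch wiring ExponentialNLLLoss to model outputs (not part of this commit); the per-token hazards are assumed to be positive (here via softplus) and to cover only the non-technical vocabulary, and the toy token counts are illustrative:

import torch
import torch.nn.functional as F
from losses import ExponentialNLLLoss

B, L, n_tech, n_disease = 2, 6, 4, 10

# Toy sequences: token ids drawn from the disease range only, and increasing ages in days.
event_seqs = torch.randint(n_tech, n_tech + n_disease, (B, L))
time_seqs = torch.cumsum(torch.rand(B, L) * 365.0, dim=1)

raw = torch.randn(B, L, n_disease, requires_grad=True)
hazards = F.softplus(raw)  # positive hazards, one per disease token

loss_fn = ExponentialNLLLoss(n_tech_tokens=n_tech, alpha=0.1)
loss = loss_fn(hazards, event_seqs, time_seqs)
loss.backward()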
model.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
from backbones import Block, ModernBlock, RMSNorm


class TabularEncoder(nn.Module):
    def __init__(
        self,
        n_embd: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
    ):
        super().__init__()
        self.continuous_proj = nn.Linear(n_continuous, n_embd) if n_continuous > 0 else None
        self.categorical_embeddings = nn.ModuleList([
            nn.Embedding(cardinality, n_embd) for cardinality in categorical_cardinalities
        ]) if n_categorical > 0 else None

    def forward(
        self,
        continuous_features: torch.Tensor | None,
        categorical_features: list[torch.Tensor] | None,
    ) -> torch.Tensor:
        embeddings = []
        if self.continuous_proj is not None and continuous_features is not None:
            cont_emb = self.continuous_proj(continuous_features)
            embeddings.append(cont_emb)
        if self.categorical_embeddings is not None and categorical_features is not None:
            for emb_layer, cat_feat in zip(self.categorical_embeddings, categorical_features):
                cat_emb = emb_layer(cat_feat)
                embeddings.append(cat_emb)
        if embeddings:
            return torch.sum(torch.stack(embeddings, dim=0), dim=0)
        else:
            raise ValueError("No features provided for TabularEncoder.")


def merge_two_sequences(
    time_seq1: torch.Tensor,  # (B, L1)
    time_seq2: torch.Tensor,  # (B, L2)
    seq1_embd: torch.Tensor,  # (B, L1, D)
    seq2_embd: torch.Tensor,  # (B, L2, D)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Merge two time sequences and their embeddings based on time order."""
    B, L1 = time_seq1.shape
    L2 = time_seq2.shape[1]
    merged_times = torch.cat([time_seq1, time_seq2], dim=1)  # (B, L1 + L2)
    merged_embd = torch.cat([seq1_embd, seq2_embd], dim=1)  # (B, L1 + L2, D)

    sorted_times, indices = torch.sort(merged_times, dim=1)  # (B, L1 + L2)
    batch_indices = torch.arange(
        B, device=indices.device).unsqueeze(-1).expand(-1, L1 + L2)  # (B, L1 + L2)
    sorted_embd = merged_embd[batch_indices, indices]  # (B, L1 + L2, D)

    return sorted_times, sorted_embd


class DelphiFork(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_embd: int,
        n_head: int,
        n_layer: int,
        n_continuous: int,
        n_categorical: int,
        categorical_cardinalities: list[int],
        pdrop: float = 0.1,
        token_pdrop: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.age_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        self.sex_encoder = nn.Embedding(2, n_embd)
        self.token_dropout = nn.Dropout(token_pdrop)
        self.covariate_encoder = TabularEncoder(
            n_embd=n_embd,
            n_continuous=n_continuous,
            n_categorical=n_categorical,
            categorical_cardinalities=categorical_cardinalities,
        )

        self.blocks = nn.ModuleList([
            Block(
                n_embd=n_embd,
                n_head=n_head,
                pdrop=pdrop,
            ) for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        self.head.weight = self.token_embedding.weight  # weight tying

    def forward(
        self,
        sex: torch.Tensor,
        event_seq: torch.Tensor,
        age_seq: torch.Tensor,
        cov_seq_time: torch.Tensor | None = None,
        cont_cov_seq: torch.Tensor | None = None,
        cat_cov_seq: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        event_emb = self.token_embedding(event_seq)
        age_emb = self.age_encoder(age_seq)
        sex_emb = self.sex_encoder(sex.unsqueeze(-1))  # (B, 1) -> (B, 1, n_embd)

        x = event_emb + age_emb + sex_emb
        if cov_seq_time is not None:
            covariate_emb = self.covariate_encoder(
                continuous_features=cont_cov_seq,
                categorical_features=cat_cov_seq,
            )
            covariate_emb = covariate_emb + self.age_encoder(cov_seq_time) + sex_emb
            # merge covariate steps into the event stream by time; keep the time-ordered embeddings
            _, x = merge_two_sequences(age_seq, cov_seq_time, x, covariate_emb)

        x = self.token_dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits
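Finally, a minimal end-to-end sketch of a DelphiFork forward pass (not part of this commit); the vocabulary size, covariate layout, and 0/1 sex coding (chosen to fit the size-2 embedding) are illustrative assumptions:

import torch
from model import DelphiFork

B, L = 2, 8
model = DelphiFork(
    vocab_size=32, n_embd=64, n_head=4, n_layer=2,
    n_continuous=3, n_categorical=2, categorical_cardinalities=[5, 7],
)

sex = torch.randint(0, 2, (B,))                       # 0/1, embedded once and broadcast over the sequence
event_seq = torch.randint(0, 32, (B, L))              # event token ids
age_seq = torch.cumsum(torch.rand(B, L) * 365.0, 1)   # age in days at each event

# One covariate measurement per participant, merged into the event stream by time.
cov_time = torch.rand(B, 1) * 365.25 * 50
cont_cov = torch.randn(B, 1, 3)
cat_cov = [torch.randint(0, 5, (B, 1)), torch.randint(0, 7, (B, 1))]

logits = model(sex, event_seq, age_seq,
               cov_seq_time=cov_time, cont_cov_seq=cont_cov, cat_cov_seq=cat_cov)
print(logits.shape)  # torch.Size([2, 9, 32]): L + 1 positions after merging the covariate step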