- Added `model.py` containing the DelphiBERT architecture, including TabularEncoder and AutoDiscretization classes for handling tabular features. - Introduced `prepare_data.R` for merging disease and other data from UK Biobank, ensuring proper column selection and data integrity. - Created `prepare_data.py` to process UK Biobank data, including mapping field IDs, handling date variables, and preparing tabular features and event data for model training.
56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class AgeSinusoidalEncoder(nn.Module):
    """
    Sinusoidal encoder for age.

    Each age is divided by 365.25 (i.e. it is treated as a number of days and
    converted to years) and then by the frequency divisors
    ``10000 ** (i / n_embd)`` for ``i = 0, 2, ..., n_embd - 2``. The cosine of
    each scaled value is written to the even positions of the last dimension
    and the sine to the odd positions.

    Args:
        n_embd (int): Embedding dimension. Must be even.

    Raises:
        ValueError: If ``n_embd`` is odd.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        if n_embd % 2 != 0:
            raise ValueError("n_embd must be even for sinusoidal encoding.")
        self.n_embd = n_embd
        i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.n_embd)
        # Registered as a buffer so it follows the module across .to()/.cuda()
        # and dtype casts without being a trainable parameter.
        self.register_buffer('divisor', divisor)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """
        Encode ages into sinusoidal features.

        Args:
            ages (torch.Tensor): Ages of any shape ``(...)``; a typical input
                is ``(B, L)``. Values are divided by 365.25 before encoding.

        Returns:
            torch.Tensor: Tensor of shape ``(..., n_embd)`` on the same device
            as the inputs, in the dtype of the module's ``divisor`` buffer, with
            cosines at even indices and sines at odd indices.
        """
        t_years = ages / 365.25
        # Broadcast (..., 1) against (D/2,) to get (..., D/2); no fixed rank
        # is assumed for `ages`.
        args = t_years.unsqueeze(-1) / self.divisor
        # Stacking on a new trailing axis gives (..., D/2, 2) = [cos_k, sin_k];
        # flattening the last two axes interleaves them as
        # [cos_0, sin_0, cos_1, sin_1, ...].
        output = torch.stack((torch.cos(args), torch.sin(args)), dim=-1)
        # Return in the module's dtype so the encoding stays compatible with
        # downstream layers after .half()/.double() casts.
        return output.flatten(-2).to(self.divisor.dtype)
|
|
|
|
|
|
class AgeMLPEncoder(nn.Module):
    """
    Age encoder that feeds a sinusoidal age encoding through a small MLP.

    Ages are first encoded by ``AgeSinusoidalEncoder`` and the result is
    projected ``n_embd -> 4 * n_embd -> n_embd`` with a GELU in between.

    Args:
        n_embd (int): Embedding dimension of both the sinusoidal encoding and
            the MLP output. Must be even, because ``AgeSinusoidalEncoder``
            rejects odd sizes.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        hidden_dim = 4 * n_embd

        self.sin_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        # Layers are created in expand -> activation -> contract order so the
        # Sequential indices (0, 1, 2) stay the same.
        expand = nn.Linear(n_embd, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, n_embd)
        self.mlp = nn.Sequential(expand, activation, contract)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """Return the MLP projection of the sinusoidal encoding of ``ages``."""
        return self.mlp(self.sin_encoder(ages))
|