Files
DeepHealth/age_encoder.py
Jiarui Li 6984b254b3 Add loss functions and model architecture for time-to-event prediction
- Implemented ExponentialNLLLoss and WeibullNLLLoss in losses.py for negative log-likelihood calculations.
- Developed TabularEncoder class in model.py for encoding tabular features.
- Created DelphiFork and SapDelphi classes in model.py for time-to-event prediction using transformer architecture.
- Added data preparation scripts in prepare_data.R and prepare_data.py for processing UK Biobank data, including handling field mappings and event data extraction.
2026-01-07 21:32:00 +08:00

56 lines
1.6 KiB
Python

import torch
import torch.nn as nn
class AgeSinusoidalEncoder(nn.Module):
    """
    Fixed (non-learned) sinusoidal encoder for age.

    Ages are divided by 365.25 to convert them to years (so they are expected
    in days), then encoded with ``n_embd // 2`` frequencies. For frequency
    index ``k`` the divisor is ``10000 ** (2k / n_embd)``, and the output
    interleaves cosines (even positions) and sines (odd positions):

        out[..., 2k]     = cos(years / divisor_k)
        out[..., 2k + 1] = sin(years / divisor_k)

    Args:
        n_embd (int): Embedding dimension. Must be a positive even integer.

    Raises:
        ValueError: If ``n_embd`` is not positive or not even.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        if n_embd <= 0:
            raise ValueError("n_embd must be a positive integer.")
        if n_embd % 2 != 0:
            raise ValueError("n_embd must be even for sinusoidal encoding.")
        self.n_embd = n_embd
        # Use the default floating dtype (float32 unless changed globally) so
        # the buffer matches the dtype nn.Linear parameters are created with.
        i = torch.arange(0, self.n_embd, 2, dtype=torch.get_default_dtype())
        divisor = torch.pow(10000, i / self.n_embd)
        # A buffer follows .to(device/dtype) and is saved in the state_dict
        # under the key "divisor".
        self.register_buffer('divisor', divisor)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """
        Encode ages of any shape.

        Args:
            ages: Tensor of ages in days with arbitrary shape ``S`` (for
                example ``(B, L)``, ``(B,)`` or a 0-d scalar). Integer tensors
                are accepted; true division promotes them to floating point.

        Returns:
            Tensor of shape ``(*S, n_embd)`` on ``ages.device``, with the
            dtype of the module's ``divisor`` buffer. This is float32 by
            default and follows ``.double()`` / ``.half()`` on the module.
        """
        t_years = ages / 365.25
        # (*S, 1) broadcasts against (n_embd / 2,) -> (*S, n_embd / 2), with
        # no assumption about the number of leading dimensions.
        args = t_years.unsqueeze(-1) / self.divisor
        # Stack to (*S, n_embd / 2, 2), then flatten the last two dims so cos
        # lands on even and sin on odd positions (interleaved layout).
        output = torch.stack(
            (torch.cos(args), torch.sin(args)), dim=-1).flatten(-2)
        # Match the module's dtype rather than whatever the arithmetic
        # promoted to. A .double()/.half() module then emits features its
        # downstream layers accept, and float64 ages still yield float32 on a
        # float32 module.
        return output.to(self.divisor.dtype)
class AgeMLPEncoder(nn.Module):
    """
    Learned age embedding: a fixed sinusoidal encoding refined by a 2-layer MLP.

    The sinusoidal features (width ``n_embd``) are expanded to a hidden width
    of ``4 * n_embd``, passed through GELU, and projected back to ``n_embd``.

    Args:
        n_embd (int): Width of both the sinusoidal features and the output.
            AgeSinusoidalEncoder requires it to be even.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        hidden_dim = 4 * n_embd
        self.sin_encoder = AgeSinusoidalEncoder(n_embd=n_embd)
        # The attribute names and the Sequential layout determine the
        # state_dict keys (mlp.0.*, mlp.2.*); keep them stable for checkpoints.
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, n_embd),
        )

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """
        Embed ``ages`` by sinusoidally encoding them and applying the MLP.

        Args:
            ages: Age tensor in whatever shape AgeSinusoidalEncoder accepts
                (it divides by 365.25, i.e. expects days).

        Returns:
            Tensor whose last dimension is ``n_embd``.
        """
        return self.mlp(self.sin_encoder(ages))