- Removed `prepare_data.py` as it is no longer needed. - Introduced `losses.py` containing ExponentialNLLLoss and WeibullLosses classes for calculating negative log-likelihood losses with regularization. - Added `model.py` which defines the DelphiFork model architecture, including a tabular encoder for handling continuous and categorical features, and merging sequences based on time order.
40 lines
1.5 KiB
Python
40 lines
1.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class AgeSinusoidalEncoder(nn.Module):
    """Fixed (non-learned) sinusoidal encoding of ages.

    Each age is divided by 365.25 (so the input is presumably expressed in
    days and converted to years -- confirm against the caller) and then by a
    geometric series of divisors ``10000 ** (i / n_embd)`` for
    ``i = 0, 2, 4, ...``.  The cosine of each scaled value is written to the
    even feature indices and the sine to the odd feature indices.

    Args:
        n_embd: Size of the output feature dimension.  Must be even.

    Raises:
        ValueError: If ``n_embd`` is odd.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        if n_embd % 2 != 0:
            raise ValueError("n_embd must be even for sinusoidal encoding.")
        self.n_embd = n_embd
        i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.n_embd)
        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer('divisor', divisor)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """Encode ``ages`` into interleaved cos/sin features.

        Args:
            ages: Tensor of ages of any shape ``(*S)``; typically ``(B, L)``.

        Returns:
            Tensor of shape ``(*S, n_embd)`` in the default floating dtype,
            on the same device as ``ages``.
        """
        t_years = ages / 365.25
        # Broadcast (*S, 1) against (D/2,) to get (*S, D/2); relying on
        # trailing-dimension broadcasting keeps this rank-agnostic.
        args = t_years.unsqueeze(-1) / self.divisor
        # Interleave cos and sin along the last dimension.
        output = torch.zeros((*ages.shape, self.n_embd), device=ages.device)
        output[..., 0::2] = torch.cos(args)
        output[..., 1::2] = torch.sin(args)
        return output
|
|
|
|
class AgeMLPEncoder(nn.Module):
    """Learned age embedding computed by a two-layer MLP.

    Every age is turned into two input features -- the age divided by 365.25
    (normalized to years) and ``log1p`` of that value -- which are fed through
    ``Linear(2, 4 * n_embd) -> ReLU -> Linear(4 * n_embd, n_embd)``.

    Args:
        n_embd: Size of the output embedding dimension.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        hidden_dim = 4 * n_embd
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_embd),
        )

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """Embed ``ages``.

        Args:
            ages: Tensor of ages; cast to float32 before use.  Expected shape
                is (B, L) -- TODO confirm against the caller.

        Returns:
            Tensor with a trailing ``n_embd`` dimension, i.e. (B, L, n_embd)
            for a (B, L) input.
        """
        years = ages.float() / 365.25  # (B, L), normalized to years
        log_years = torch.log1p(years)  # (B, L)
        # Pair the linear and log-scaled views of the age: (B, L, 2).
        features = torch.stack((years, log_years), dim=-1)
        return self.mlp(features)  # (B, L, n_embd)