import torch
import torch.nn as nn

# Divisor used by both encoders to turn their input ages into years
# (the input is therefore expected to be expressed in days).
_DAYS_PER_YEAR = 365.25


class AgeSinusoidalEncoder(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of ages.

    Each age is converted to years and divided by a geometric series of
    divisors ``10000 ** (i / n_embd)`` for ``i = 0, 2, ..., n_embd - 2``.
    The cosine of each resulting angle is written to the even output
    channels and the sine to the odd output channels.

    Args:
        n_embd: Size of the embedding (last output dimension). Must be even
            so that every frequency gets one cosine and one sine channel.

    Raises:
        ValueError: If ``n_embd`` is odd.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        if n_embd % 2 != 0:
            raise ValueError("n_embd must be even for sinusoidal encoding.")
        self.n_embd = n_embd

        # One divisor per (cos, sin) channel pair: 10000 ** (i / n_embd).
        i = torch.arange(0, self.n_embd, 2, dtype=torch.float32)
        divisor = torch.pow(10000, i / self.n_embd)
        # Registered as a buffer so it follows the module across devices and
        # is stored in the state dict without being a trainable parameter.
        self.register_buffer('divisor', divisor)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """Encode ages as interleaved cosine/sine features.

        Args:
            ages: Tensor of ages in days, of any shape ``(...)`` (for
                example ``(B, L)``). Integer and floating dtypes are both
                accepted.

        Returns:
            Tensor of shape ``(..., n_embd)`` in the default floating dtype,
            on the same device as ``ages``. Even indices of the last
            dimension hold cosines, odd indices hold sines.
        """
        t_years = ages / _DAYS_PER_YEAR

        # Broadcast (..., 1) against (D/2,) to get (..., D/2). Relying on
        # trailing-dimension broadcasting (instead of a fixed (1, 1, D/2)
        # view) is what lets inputs of any rank through.
        args = t_years.unsqueeze(-1) / self.divisor

        # Interleave cos and sin along the last dimension.
        output = torch.zeros(*ages.shape, self.n_embd, device=ages.device)
        output[..., 0::2] = torch.cos(args)
        output[..., 1::2] = torch.sin(args)
        return output


class AgeMLPEncoder(nn.Module):
    """Learned embedding of ages through a two-layer MLP.

    Every age is converted to years and expanded into two input features,
    the age in years and ``log1p`` of it, which are then projected to
    ``n_embd`` channels by ``Linear -> ReLU -> Linear``.

    Args:
        n_embd: Size of the embedding (last output dimension). The hidden
            layer has ``4 * n_embd`` units.
    """

    def __init__(self, n_embd: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        """Encode ages with the MLP.

        Args:
            ages: Tensor of ages in days, of any shape ``(...)`` (for
                example ``(B, L)``). It is cast to float32 before use.

        Returns:
            Tensor of shape ``(..., n_embd)``.
        """
        # (..., 1), normalized to years.
        years = ages.unsqueeze(-1).float() / _DAYS_PER_YEAR
        # (..., 1); log1p adds a compressed view of the same quantity.
        log_years = torch.log1p(years)
        # (..., 2): the two input features expected by the first Linear.
        features = torch.cat([years, log_years], dim=-1)
        return self.mlp(features)  # (..., n_embd)