- Removed `prepare_data.py`, as it is no longer needed.
- Introduced `losses.py`, containing the `ExponentialNLLLoss` and `WeibullLosses` classes for computing negative log-likelihood losses with regularization (a hedged sketch of the exponential case is shown below).
- Added `model.py`, which defines the DelphiFork model architecture, including a tabular encoder for continuous and categorical features and logic for merging event sequences in time order.
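For reference, a minimal sketch of what an exponential negative log-likelihood with a simple regularization term could look like. The `forward` signature, the `log_rate`/`delta_t`/`event` argument names, and the L2 penalty on the log-rates are illustrative assumptions, not necessarily how `losses.py` implements `ExponentialNLLLoss`:

```python
import torch
import torch.nn as nn


class ExponentialNLLLoss(nn.Module):
    """Negative log-likelihood of an exponential waiting-time model.

    Sketch only: argument names and the regularization term are assumptions.
    """

    def __init__(self, reg_weight: float = 0.0):
        super().__init__()
        self.reg_weight = reg_weight

    def forward(
        self,
        log_rate: torch.Tensor,  # predicted log event rate
        delta_t: torch.Tensor,   # observed waiting time, same shape
        event: torch.Tensor,     # 1.0 if the event occurred, 0.0 if censored
    ) -> torch.Tensor:
        rate = log_rate.exp()
        # NLL of Exp(rate): -[event * log(rate) - rate * delta_t]
        nll = rate * delta_t - event * log_rate
        # Simple L2 penalty on the log-rates (assumed form of the regularization).
        reg = self.reg_weight * log_rate.pow(2).mean()
        return nll.mean() + reg
```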
164 lines
4.4 KiB
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias)."""

    def __init__(
        self,
        n_embd: int,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(n_embd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the RMS over the last dimension, then apply a learned gain.
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x * (self.n_embd ** -0.5)
        x_normed = x / (rms_x + self.eps)
        return self.weight * x_normed


class SelfAttention(nn.Module):
    """Multi-head self-attention built on F.scaled_dot_product_attention."""

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        attn_pdrop: float = 0.1,
    ):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_pdrop = attn_pdrop

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, L, D = x.shape
        qkv = self.qkv_proj(x)  # (B, L, 3D)
        q, k, v = qkv.chunk(3, dim=-1)

        def reshape_heads(t):
            # (B, L, D) -> (B, H, L, d)
            return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)

        q = reshape_heads(q)
        k = reshape_heads(k)
        v = reshape_heads(v)

        # Only apply attention dropout during training.
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_pdrop if self.training else 0.0,
        )  # (B, H, L, d)

        attn = attn.transpose(1, 2).contiguous().view(B, L, D)  # (B, L, D)
        return self.o_proj(attn)


class SwiGLUMLP(nn.Module):
    """Feed-forward block with a SwiGLU gate: silu(x1) * x2."""

    def __init__(
        self,
        n_embd: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        hidden_dim = 4 * n_embd
        self.fc1 = nn.Linear(n_embd, 2 * hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, n_embd, bias=False)
        self.dropout = nn.Dropout(pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fc1 produces the gate and value halves in a single projection.
        x1, x2 = self.fc1(x).chunk(2, dim=-1)
        # SwiGLU: silu(x1) * x2
        x = F.silu(x1) * x2
        x = self.fc2(x)
        return self.dropout(x)


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm + attention, LayerNorm + GELU MLP."""

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop

        self.norm_1 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        # MLP forward pass: expand, GELU, project back, dropout.
        self.mlpf = lambda x: m.dropout(
            m.c_proj(m.act(m.c_fc(x))))
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer with residual connection
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)

        # MLP sub-layer with residual connection
        h = self.norm_2(x)
        h = self.mlpf(h)
        x = x + self.resid_dropout(h)

        return x


class ModernBlock(nn.Module):
    """Pre-norm transformer block with RMSNorm and a SwiGLU MLP."""

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop
        mlp_pdrop = pdrop

        self.norm_1 = RMSNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = RMSNorm(n_embd)
        self.mlp = SwiGLUMLP(n_embd=n_embd, pdrop=mlp_pdrop)
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer with residual connection
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)

        # MLP sub-layer with residual connection
        h = self.norm_2(x)
        h = self.mlp(h)
        x = x + self.resid_dropout(h)

        return x
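

# Illustrative usage sketch (an addition, not part of the original file): runs a
# random batch through ModernBlock with a causal attention mask. The shapes and
# hyperparameters below are arbitrary choices for the example.
if __name__ == "__main__":
    B, L, D, H = 2, 16, 64, 4
    x = torch.randn(B, L, D)
    # Boolean causal mask: True marks positions that may be attended to.
    causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))
    block = ModernBlock(n_embd=D, n_head=H, pdrop=0.1)
    y = block(x, attn_mask=causal_mask)
    assert y.shape == (B, L, D)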