import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean centering, no bias)."""

    def __init__(
        self,
        n_embd: int,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.n_embd = n_embd
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(n_embd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm_x = x.norm(2, dim=-1, keepdim=True)
        rms_x = norm_x * (self.n_embd ** -0.5)
        x_normed = x / (rms_x + self.eps)
        return self.weight * x_normed


class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection."""

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        attn_pdrop: float = 0.1,
    ):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.o_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_pdrop = attn_pdrop

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, L, D = x.shape
        qkv = self.qkv_proj(x)  # (B, L, 3D)
        q, k, v = qkv.chunk(3, dim=-1)

        def reshape_heads(t):
            # (B, L, D) -> (B, H, L, d)
            return t.view(B, L, self.n_head, self.head_dim).transpose(1, 2)

        q = reshape_heads(q)
        k = reshape_heads(k)
        v = reshape_heads(v)

        # F.scaled_dot_product_attention applies dropout_p unconditionally
        # (it has no notion of module training mode), so only pass a non-zero
        # rate while training.
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.attn_pdrop if self.training else 0.0,
        )  # (B, H, L, d)
        attn = attn.transpose(1, 2).contiguous().view(B, L, D)  # (B, L, D)
        return self.o_proj(attn)


class SwiGLUMLP(nn.Module):
    """Feed-forward block with a SwiGLU gate: silu(x1) * x2."""

    def __init__(
        self,
        n_embd: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        hidden_dim = 4 * n_embd
        # fc1 produces both the gate branch and the value branch in one matmul.
        self.fc1 = nn.Linear(n_embd, 2 * hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, n_embd, bias=False)
        self.dropout = nn.Dropout(pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.fc1(x).chunk(2, dim=-1)
        # SwiGLU: silu(x1) * x2
        x = F.silu(x1) * x2
        x = self.fc2(x)
        return self.dropout(x)


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm + attention, LayerNorm + GELU MLP."""

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop
        self.norm_1 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(n_embd, 4 * n_embd),
            c_proj=nn.Linear(4 * n_embd, n_embd),
            act=nn.GELU(),
            dropout=nn.Dropout(pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-block with residual connection
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)
        # MLP sub-block with residual connection
        h = self.norm_2(x)
        h = self.mlpf(h)
        x = x + self.resid_dropout(h)
        return x


class ModernBlock(nn.Module):
    """Pre-norm transformer block with RMSNorm and a SwiGLU MLP."""

    def __init__(
        self,
        n_embd: int,
        n_head: int,
        pdrop: float = 0.0,
    ):
        super().__init__()
        attn_pdrop = pdrop
        mlp_pdrop = pdrop
        self.norm_1 = RMSNorm(n_embd)
        self.attn = SelfAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
        )
        self.norm_2 = RMSNorm(n_embd)
        self.mlp = SwiGLUMLP(n_embd=n_embd, pdrop=mlp_pdrop)
        self.resid_dropout = nn.Dropout(pdrop)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-block with residual connection
        h = self.norm_1(x)
        h = self.attn(h, attn_mask=attn_mask)
        x = x + self.resid_dropout(h)
        # MLP sub-block with residual connection
        h = self.norm_2(x)
        h = self.mlp(h)
        x = x + self.resid_dropout(h)
        return x
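

# --- Usage sketch (not part of the original modules) ------------------------
# A minimal, hedged example of running both block variants with a causal
# boolean mask, where True means "attend", as accepted by
# F.scaled_dot_product_attention. The shapes (batch=2, seq_len=16, n_embd=64,
# n_head=4) are arbitrary and chosen only for illustration.
if __name__ == "__main__":
    B, L, D, H = 2, 16, 64, 4
    x = torch.randn(B, L, D)

    # Lower-triangular mask enforces causality; it broadcasts over (B, H, L, L).
    causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))

    block = Block(n_embd=D, n_head=H, pdrop=0.1)
    modern_block = ModernBlock(n_embd=D, n_head=H, pdrop=0.1)

    y_classic = block(x, attn_mask=causal_mask)
    y_modern = modern_block(x, attn_mask=causal_mask)
    print(y_classic.shape, y_modern.shape)  # torch.Size([2, 16, 64]) for both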