diff --git a/model.py b/model.py
index 467b53f..ed21378 100644
--- a/model.py
+++ b/model.py
@@ -16,6 +16,7 @@ class TabularEncoder(nn.Module):
         n_cont (int): Number of continuous features.
         n_cate (int): Number of categorical features.
         cate_dims (List[int]): List of dimensions for each categorical feature.
+        n_bins (int): Number of soft bins for continuous AutoDiscretization.
     """
 
     def __init__(
@@ -24,21 +25,26 @@ class TabularEncoder(nn.Module):
         n_cont: int,
         n_cate: int,
         cate_dims: List[int],
+        n_bins: int = 16,
     ):
         super().__init__()
         self.n_embd = n_embd
         self.n_cont = n_cont
         self.n_cate = n_cate
 
+        # Continuous feature path
+        # - BatchNorm on raw (NaN-filled) continuous values
+        # - AutoDiscretization (soft binning) per feature
         if n_cont > 0:
-            hidden = 2 * n_embd
-            self.cont_mlp = nn.Sequential(
-                nn.Linear(2 * n_cont, hidden),
-                nn.GELU(),
-                nn.Linear(hidden, n_embd),
+            self.cont_bn = nn.BatchNorm1d(n_cont)
+            self.cont_discretizer = AutoDiscretization(
+                n_features=n_cont,
+                n_bins=n_bins,
+                n_embd=n_embd,
             )
         else:
-            self.cont_mlp = None
+            self.cont_bn = None
+            self.cont_discretizer = None
 
         if n_cate > 0:
             assert len(cate_dims) == n_cate, \
@@ -57,21 +63,16 @@ class TabularEncoder(nn.Module):
             nn.Linear(n_cont, n_embd) if n_cont > 0 else None
         )
 
-        self.film = nn.Sequential(
-            nn.Linear(n_embd, 2 * n_embd),
-            nn.GELU(),
+        # Fuse aggregated value + aggregated mask via MLP
+        self.fuse_mlp = nn.Sequential(
             nn.Linear(2 * n_embd, 2 * n_embd),
+            nn.GELU(),
+            nn.Linear(2 * n_embd, n_embd),
         )
 
         self.apply(self._init_weights)
         self.out_ln = nn.LayerNorm(n_embd)
 
-        # Zero-init the last layer of FiLM to start with identity modulation
-        with torch.no_grad():
-            last_linear = self.film[-1]
-            last_linear.weight.zero_()
-            last_linear.bias.zero_()
-
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -85,6 +86,15 @@ class TabularEncoder(nn.Module):
         cont_features: Optional[torch.Tensor],
         cate_features: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        """Encode tabular features into a per-timestep embedding.
+
+        Inputs:
+            cont_features: (B, L, n_cont) float tensor; NaN indicates missing.
+            cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad.
+
+        Returns:
+            (B, L, n_embd) encoded embedding.
+        """
 
         if self.n_cont == 0 and self.n_cate == 0:
             # infer (B, L) from whichever input is not None
@@ -112,14 +122,32 @@ class TabularEncoder(nn.Module):
                 raise ValueError(
                     f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")
 
-            cont_mask = (~torch.isnan(cont_features)).float()
-            cont_filled = torch.nan_to_num(cont_features, nan=0.0)
-            cont_joint = torch.cat([cont_filled, cont_mask], dim=-1)
-            h_cont_value = self.cont_mlp(cont_joint)
+            # Missingness mask: 1 for valid, 0 for missing
+            cont_mask = (~torch.isnan(cont_features)).float()  # (B, L, n_cont)
+
+            # BatchNorm cannot handle NaNs; fill missing with 0 before BN.
+            cont_filled = torch.nan_to_num(
+                cont_features, nan=0.0)  # (B, L, n_cont)
+
+            # Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C)
+            cont_flat = cont_filled.reshape(-1, self.n_cont)
+            cont_norm_flat = self.cont_bn(cont_flat)  # (B*L, n_cont)
+
+            # Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd)
+            cont_emb_flat = self.cont_discretizer(cont_norm_flat)
+            cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd)
+
+            # Mask out missing continuous features before aggregating across features
+            # (B, L, n_cont, n_embd)
+            cont_emb = cont_emb * cont_mask.unsqueeze(-1)
+            denom = cont_mask.sum(
+                dim=-1, keepdim=True).clamp(min=1.0)  # (B, L, 1)
+            h_cont_value = cont_emb.sum(dim=2) / denom  # (B, L, n_embd)
             value_parts.append(h_cont_value)
 
+            # Explicit continuous mask embedding (fused later)
             if self.cont_mask_proj is not None:
-                h_cont_mask = self.cont_mask_proj(cont_mask)
+                h_cont_mask = self.cont_mask_proj(cont_mask)  # (B, L, n_embd)
                 mask_parts.append(h_cont_mask)
 
         if self.n_cate > 0 and cate_features is not None:
@@ -155,19 +183,82 @@ class TabularEncoder(nn.Module):
                 raise ValueError("No features provided to TabularEncoder.")
             return torch.zeros(B, L, self.n_embd, device=device)
 
-        h_value = torch.stack(value_parts, dim=0).mean(dim=0)
-        h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)
-        h_mask_flat = h_mask.view(-1, self.n_embd)
-        film_params = self.film(h_mask_flat)
-        gamma_delta, beta = film_params.chunk(2, dim=-1)
-        gamma = 1.0 + gamma_delta
-        h_value_flat = h_value.view(-1, self.n_embd)
-        h_out = gamma * h_value_flat + beta
-        h_out = h_out.view(B, L, self.n_embd)
+        # Aggregate across feature groups (continuous block counts as one part;
+        # each categorical feature counts as one part).
+        h_value = torch.stack(value_parts, dim=0).mean(dim=0)  # (B, L, n_embd)
+
+        if mask_parts:
+            h_mask = torch.stack(mask_parts, dim=0).mean(
+                dim=0)  # (B, L, n_embd)
+        else:
+            h_mask = torch.zeros_like(h_value)
+
+        # Fuse by concatenation + MLP projection
+        h_fused = torch.cat([h_value, h_mask], dim=-1)  # (B, L, 2*n_embd)
+        h_out = self.fuse_mlp(h_fused)  # (B, L, n_embd)
         h_out = self.out_ln(h_out)
         return h_out
 
 
+class AutoDiscretization(nn.Module):
+    """AutoDiscretization / soft-binning for continuous tabular scalars.
+
+    For each feature scalar x, compute a soft assignment over `n_bins`:
+        p = softmax(x * w + b)
+    Then compute the embedding as a weighted sum of learnable bin embeddings:
+        emb = sum_k p_k * E_k
+
+    Shapes:
+        Input: (N, n_features)
+        Output: (N, n_features, n_embd)
+    """
+
+    def __init__(self, n_features: int, n_bins: int, n_embd: int):
+        super().__init__()
+        if n_features <= 0:
+            raise ValueError("n_features must be > 0")
+        if n_bins <= 1:
+            raise ValueError("n_bins must be > 1")
+        if n_embd <= 0:
+            raise ValueError("n_embd must be > 0")
+
+        self.n_features = n_features
+        self.n_bins = n_bins
+        self.n_embd = n_embd
+
+        # Per-feature, per-bin affine transform to produce logits
+        self.weight = nn.Parameter(torch.empty(n_features, n_bins))
+        self.bias = nn.Parameter(torch.empty(n_features, n_bins))
+
+        # Learnable embeddings for each (feature, bin)
+        self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        nn.init.normal_(self.weight, mean=0.0, std=0.02)
+        nn.init.zeros_(self.bias)
+        nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.dim() != 2:
+            raise ValueError(
+                "AutoDiscretization expects input of shape (N, n_features)")
+        if x.size(1) != self.n_features:
+            raise ValueError(
+                f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
+            )
+
+        # logits: (N, n_features, n_bins)
+        logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \
+            self.bias.unsqueeze(0)
+        probs = torch.softmax(logits, dim=-1)
+
+        # Weighted sum over bins -> (N, n_features, n_embd)
+        emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2)
+        return emb
+
+
 def _build_time_padding_mask(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,