Enhance TabularEncoder with BatchNorm and AutoDiscretization for continuous features

2026-01-08 00:24:43 +08:00
parent 33ba7e6c1d
commit 8293f7ee24

model.py

@@ -16,6 +16,7 @@ class TabularEncoder(nn.Module):
         n_cont (int): Number of continuous features.
         n_cate (int): Number of categorical features.
         cate_dims (List[int]): List of dimensions for each categorical feature.
+        n_bins (int): Number of soft bins for continuous AutoDiscretization.
     """

     def __init__(
@@ -24,21 +25,26 @@ class TabularEncoder(nn.Module):
         n_cont: int,
         n_cate: int,
         cate_dims: List[int],
+        n_bins: int = 16,
     ):
         super().__init__()
         self.n_embd = n_embd
         self.n_cont = n_cont
         self.n_cate = n_cate

+        # Continuous feature path
+        # - BatchNorm on raw (NaN-filled) continuous values
+        # - AutoDiscretization (soft binning) per feature
         if n_cont > 0:
-            hidden = 2 * n_embd
-            self.cont_mlp = nn.Sequential(
-                nn.Linear(2 * n_cont, hidden),
-                nn.GELU(),
-                nn.Linear(hidden, n_embd),
+            self.cont_bn = nn.BatchNorm1d(n_cont)
+            self.cont_discretizer = AutoDiscretization(
+                n_features=n_cont,
+                n_bins=n_bins,
+                n_embd=n_embd,
             )
         else:
-            self.cont_mlp = None
+            self.cont_bn = None
+            self.cont_discretizer = None

         if n_cate > 0:
             assert len(cate_dims) == n_cate, \
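
Note (reviewer sketch): the continuous branch now owns an nn.BatchNorm1d plus the AutoDiscretization module added at the bottom of this diff. A rough sense of what that allocates, under made-up sizes (4 continuous features, 16 bins, 32-dim embeddings); the parameter shapes follow the class definition below:

import torch.nn as nn

n_cont, n_bins, n_embd = 4, 16, 32            # hypothetical sizes

bn = nn.BatchNorm1d(n_cont)                   # learnable affine: weight + bias
print(sum(p.numel() for p in bn.parameters()))             # 8

# AutoDiscretization (see the class below) holds, per this diff:
#   weight  (n_cont, n_bins)          -> 64
#   bias    (n_cont, n_bins)          -> 64
#   bin_emb (n_cont, n_bins, n_embd)  -> 2048
print(2 * n_cont * n_bins + n_cont * n_bins * n_embd)      # 2176
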
@@ -57,21 +63,16 @@ class TabularEncoder(nn.Module):
             nn.Linear(n_cont, n_embd) if n_cont > 0 else None
         )

-        self.film = nn.Sequential(
-            nn.Linear(n_embd, 2 * n_embd),
-            nn.GELU(),
+        # Fuse aggregated value + aggregated mask via MLP
+        self.fuse_mlp = nn.Sequential(
             nn.Linear(2 * n_embd, 2 * n_embd),
+            nn.GELU(),
+            nn.Linear(2 * n_embd, n_embd),
         )

         self.apply(self._init_weights)
         self.out_ln = nn.LayerNorm(n_embd)
-        # Zero-init the last layer of FiLM to start with identity modulation
-        with torch.no_grad():
-            last_linear = self.film[-1]
-            last_linear.weight.zero_()
-            last_linear.bias.zero_()

     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
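
Note (reviewer sketch): the old FiLM path modulated the value embedding as (1 + gamma) * h_value + beta, with gamma and beta predicted from the mask embedding; the new path simply concatenates the two embeddings and projects back to n_embd. A self-contained sketch of the new fusion with toy shapes:

import torch
import torch.nn as nn

n_embd = 32                                   # hypothetical embedding size
fuse_mlp = nn.Sequential(                     # mirrors fuse_mlp in this diff
    nn.Linear(2 * n_embd, 2 * n_embd),
    nn.GELU(),
    nn.Linear(2 * n_embd, n_embd),
)

h_value = torch.randn(2, 5, n_embd)           # (B, L, n_embd) aggregated values
h_mask = torch.randn(2, 5, n_embd)            # (B, L, n_embd) aggregated mask embedding
h_out = fuse_mlp(torch.cat([h_value, h_mask], dim=-1))     # (B, L, n_embd)
print(h_out.shape)                            # torch.Size([2, 5, 32])

With plain concatenation there is no modulation that needs to start at identity, which is why the FiLM zero-init block is dropped in this hunk.
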
@@ -85,6 +86,15 @@ class TabularEncoder(nn.Module):
         cont_features: Optional[torch.Tensor],
         cate_features: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        """Encode tabular features into a per-timestep embedding.
+
+        Inputs:
+            cont_features: (B, L, n_cont) float tensor; NaN indicates missing.
+            cate_features: (B, L, n_cate) long/int tensor; 0 indicates missing/pad.
+
+        Returns:
+            (B, L, n_embd) encoded embedding.
+        """
         if self.n_cont == 0 and self.n_cate == 0:
             # infer (B, L) from whichever input is not None
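
Note (reviewer sketch): the new docstring pins down the missing-value conventions (NaN for continuous values, 0 for categorical codes). How inputs would be prepared under those conventions, with made-up shapes; the final call assumes an already-constructed encoder instance:

import torch

B, L, n_cont, n_cate = 2, 4, 3, 2             # hypothetical shapes

# Continuous features: float tensor, NaN where a value is missing
cont = torch.randn(B, L, n_cont)
cont[0, 1, 2] = float("nan")

# Categorical features: integer codes, with 0 reserved for missing / padding
cate = torch.randint(1, 5, (B, L, n_cate))
cate[1, 3, 0] = 0

# out = encoder(cont_features=cont, cate_features=cate)    # -> (B, L, n_embd)
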
@@ -112,14 +122,32 @@ class TabularEncoder(nn.Module):
                 raise ValueError(
                     f"Expected cont_features last dim to be {self.n_cont}, got {D_cont}")
-            cont_mask = (~torch.isnan(cont_features)).float()
-            cont_filled = torch.nan_to_num(cont_features, nan=0.0)
-            cont_joint = torch.cat([cont_filled, cont_mask], dim=-1)
-            h_cont_value = self.cont_mlp(cont_joint)
+            # Missingness mask: 1 for valid, 0 for missing
+            cont_mask = (~torch.isnan(cont_features)).float()  # (B, L, n_cont)
+
+            # BatchNorm cannot handle NaNs; fill missing with 0 before BN.
+            cont_filled = torch.nan_to_num(
+                cont_features, nan=0.0)  # (B, L, n_cont)
+
+            # Apply BN over the feature dimension: (B, L, C) -> (B*L, C) -> (B, L, C)
+            cont_flat = cont_filled.reshape(-1, self.n_cont)
+            cont_norm_flat = self.cont_bn(cont_flat)  # (B*L, n_cont)
+
+            # Soft-binning per feature: (B*L, n_cont) -> (B*L, n_cont, n_embd)
+            cont_emb_flat = self.cont_discretizer(cont_norm_flat)
+            cont_emb = cont_emb_flat.view(B, L, self.n_cont, self.n_embd)
+
+            # Mask out missing continuous features before aggregating across features
+            # (B, L, n_cont, n_embd)
+            cont_emb = cont_emb * cont_mask.unsqueeze(-1)
+            denom = cont_mask.sum(
+                dim=-1, keepdim=True).clamp(min=1.0)  # (B, L, 1)
+            h_cont_value = cont_emb.sum(dim=2) / denom  # (B, L, n_embd)
             value_parts.append(h_cont_value)

+            # Explicit continuous mask embedding (fused later)
             if self.cont_mask_proj is not None:
-                h_cont_mask = self.cont_mask_proj(cont_mask)
+                h_cont_mask = self.cont_mask_proj(cont_mask)  # (B, L, n_embd)
                 mask_parts.append(h_cont_mask)

         if self.n_cate > 0 and cate_features is not None:
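
Note (reviewer sketch): nn.BatchNorm1d expects (N, C), which is why the (B, L, n_cont) block is flattened before normalization, soft-binned per feature, and then mean-pooled over the observed features only. The same shape flow, standalone, with a plain softmax-binning stand-in so it runs without the AutoDiscretization class:

import torch
import torch.nn as nn

B, L, n_cont, n_bins, n_embd = 2, 5, 3, 16, 8              # hypothetical sizes

cont = torch.randn(B, L, n_cont)
cont[0, 2, 1] = float("nan")                               # one missing value

cont_mask = (~torch.isnan(cont)).float()                   # (B, L, n_cont)
cont_filled = torch.nan_to_num(cont, nan=0.0)

bn = nn.BatchNorm1d(n_cont)
cont_norm = bn(cont_filled.reshape(-1, n_cont))            # (B*L, n_cont)

# Stand-in for AutoDiscretization: per-feature soft bins -> bin embeddings
w = torch.randn(n_cont, n_bins)
b = torch.zeros(n_cont, n_bins)
bin_emb = torch.randn(n_cont, n_bins, n_embd)
probs = torch.softmax(cont_norm.unsqueeze(-1) * w + b, dim=-1)   # (B*L, n_cont, n_bins)
emb = (probs.unsqueeze(-1) * bin_emb).sum(dim=-2)                # (B*L, n_cont, n_embd)

# Mask out missing features, then average over the observed ones
emb = emb.view(B, L, n_cont, n_embd) * cont_mask.unsqueeze(-1)
denom = cont_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)       # (B, L, 1)
h_cont_value = emb.sum(dim=2) / denom                            # (B, L, n_embd)
assert h_cont_value.shape == (B, L, n_embd)
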
@@ -155,19 +183,82 @@ class TabularEncoder(nn.Module):
                 raise ValueError("No features provided to TabularEncoder.")
             return torch.zeros(B, L, self.n_embd, device=device)

-        h_value = torch.stack(value_parts, dim=0).mean(dim=0)
-        h_mask = torch.stack(mask_parts, dim=0).mean(dim=0)
-        h_mask_flat = h_mask.view(-1, self.n_embd)
-        film_params = self.film(h_mask_flat)
-        gamma_delta, beta = film_params.chunk(2, dim=-1)
-        gamma = 1.0 + gamma_delta
-        h_value_flat = h_value.view(-1, self.n_embd)
-        h_out = gamma * h_value_flat + beta
-        h_out = h_out.view(B, L, self.n_embd)
+        # Aggregate across feature groups (continuous block counts as one part;
+        # each categorical feature counts as one part).
+        h_value = torch.stack(value_parts, dim=0).mean(dim=0)  # (B, L, n_embd)
+
+        if mask_parts:
+            h_mask = torch.stack(mask_parts, dim=0).mean(
+                dim=0)  # (B, L, n_embd)
+        else:
+            h_mask = torch.zeros_like(h_value)
+
+        # Fuse by concatenation + MLP projection
+        h_fused = torch.cat([h_value, h_mask], dim=-1)  # (B, L, 2*n_embd)
+        h_out = self.fuse_mlp(h_fused)  # (B, L, n_embd)

         h_out = self.out_ln(h_out)
         return h_out


+class AutoDiscretization(nn.Module):
+    """AutoDiscretization / soft-binning for continuous tabular scalars.
+
+    For each feature scalar $x$, compute a soft assignment over `n_bins`:
+
+        p = softmax(x * w + b)
+
+    Then compute the embedding as a weighted sum of learnable bin embeddings:
+
+        emb = sum_k p_k * E_k
+
+    Shapes:
+        Input:  (N, n_features)
+        Output: (N, n_features, n_embd)
+    """
+
+    def __init__(self, n_features: int, n_bins: int, n_embd: int):
+        super().__init__()
+        if n_features <= 0:
+            raise ValueError("n_features must be > 0")
+        if n_bins <= 1:
+            raise ValueError("n_bins must be > 1")
+        if n_embd <= 0:
+            raise ValueError("n_embd must be > 0")
+        self.n_features = n_features
+        self.n_bins = n_bins
+        self.n_embd = n_embd
+
+        # Per-feature, per-bin affine transform to produce logits
+        self.weight = nn.Parameter(torch.empty(n_features, n_bins))
+        self.bias = nn.Parameter(torch.empty(n_features, n_bins))
+        # Learnable embeddings for each (feature, bin)
+        self.bin_emb = nn.Parameter(torch.empty(n_features, n_bins, n_embd))
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        nn.init.normal_(self.weight, mean=0.0, std=0.02)
+        nn.init.zeros_(self.bias)
+        nn.init.normal_(self.bin_emb, mean=0.0, std=0.02)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.dim() != 2:
+            raise ValueError(
+                "AutoDiscretization expects input of shape (N, n_features)")
+        if x.size(1) != self.n_features:
+            raise ValueError(
+                f"Expected x.size(1) == {self.n_features}, got {x.size(1)}"
+            )
+        # logits: (N, n_features, n_bins)
+        logits = x.unsqueeze(-1) * self.weight.unsqueeze(0) + \
+            self.bias.unsqueeze(0)
+        probs = torch.softmax(logits, dim=-1)
+        # Weighted sum over bins -> (N, n_features, n_embd)
+        emb = (probs.unsqueeze(-1) * self.bin_emb.unsqueeze(0)).sum(dim=-2)
+        return emb
+
+
 def _build_time_padding_mask(
     event_seq: torch.Tensor,
     time_seq: torch.Tensor,
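
Note (reviewer sketch): a quick sanity check of the new AutoDiscretization module, assuming model.py from this commit is importable as `model`; sizes are made up:

import torch

from model import AutoDiscretization           # class added in this commit

disc = AutoDiscretization(n_features=3, n_bins=16, n_embd=8)
x = torch.randn(32, 3)                         # (N, n_features)
emb = disc(x)                                  # (N, n_features, n_embd)
assert emb.shape == (32, 3, 8)

# Soft assignment p = softmax(x * w + b) sums to 1 over the bins for each scalar
logits = x.unsqueeze(-1) * disc.weight.unsqueeze(0) + disc.bias.unsqueeze(0)
probs = torch.softmax(logits, dim=-1)          # (N, n_features, n_bins)
assert torch.allclose(probs.sum(dim=-1), torch.ones(32, 3))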