diff --git a/losses.py b/losses.py index 3ae8976..69e16b6 100644 --- a/losses.py +++ b/losses.py @@ -132,9 +132,19 @@ class ExponentialNLLLoss(nn.Module): return nll, reg -class PiecewiseExponentialLoss(nn.Module): - """ - Piecewise-constant competing risks exponential likelihood. +class DiscreteTimeCIFNLLLoss(nn.Module): + """Direct discrete-time CIF negative log-likelihood (no censoring). + + This loss assumes the model outputs per-bin logits over (K causes + 1 complement) + channels, where the complement channel (index K) represents survival across bins. + + Per-sample likelihood for observed cause k at time bin j: + p = \prod_{u=1}^{j-1} p(comp at u) * p(k at j) + + Args: + bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0. + eps: Unused; kept for interface compatibility / future numerical tweaks. + lambda_reg: Optional regularization strength. """ def __init__( @@ -146,18 +156,20 @@ class PiecewiseExponentialLoss(nn.Module): super().__init__() if len(bin_edges) < 2: - raise ValueError("bin_edges must have length >= 2") - if bin_edges[0] != 0: - raise ValueError("bin_edges must start at 0") + raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)") + if float(bin_edges[0]) != 0.0: + raise ValueError("bin_edges[0] must equal 0") for i in range(1, len(bin_edges)): - if not (bin_edges[i] > bin_edges[i - 1]): + if not (float(bin_edges[i]) > float(bin_edges[i - 1])): raise ValueError("bin_edges must be strictly increasing") self.eps = float(eps) self.lambda_reg = float(lambda_reg) - - edges = torch.tensor(list(bin_edges), dtype=torch.float32) - self.register_buffer("bin_edges", edges, persistent=False) + self.register_buffer( + "bin_edges", + torch.tensor(bin_edges, dtype=torch.float32), + persistent=False, + ) def forward( self, @@ -166,145 +178,83 @@ class PiecewiseExponentialLoss(nn.Module): dt: torch.Tensor, reduction: str = "mean", ) -> Tuple[torch.Tensor, torch.Tensor]: - if logits.dim() != 3: - raise ValueError("logits must have shape (M, K, B)") - - M, K, B = logits.shape - if self.bin_edges.numel() != B + 1: + if logits.ndim != 3: raise ValueError( - f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})" + f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}" ) + if target_events.ndim != 1 or dt.ndim != 1: + raise ValueError( + f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}" + ) + if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]: + raise ValueError( + "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]" + ) + if reduction not in {"mean", "sum", "none"}: + raise ValueError("reduction must be one of {'mean','sum','none'}") - device = logits.device - dt = dt.to(device=device, dtype=torch.float32) - target_events = target_events.to(device=device) + if not torch.all(dt > 0): + raise ValueError("dt must be strictly positive") + + # Infer K and n_bins from logits and bin_edges. 
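+        # Shape convention (illustrative): with K = 3 causes and bin_edges of
+        # length 7 (so n_bins = 6), logits is (M, 4, 7). Channels 0..2 hold the
+        # per-bin cause logits and channel 3 (index K) is the complement /
+        # "no event in this bin" channel; log_softmax over dim=1 turns each bin
+        # column into a categorical distribution over causes + complement.
+        # Bin column 0 is never selected, since event bins are clamped to 1..n_bins.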
+ m, k_plus_1, n_bins_plus_1 = logits.shape + k_comp = k_plus_1 - 1 + if k_comp < 1: + raise ValueError( + "logits.shape[1] must be at least 2 (K>=1 plus complement channel)") + + n_bins = int(self.bin_edges.numel() - 1) + if n_bins_plus_1 != n_bins + 1: + raise ValueError( + f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}" + ) if target_events.dtype != torch.long: - target_events = target_events.to(dtype=torch.long) - if target_events.min().item() < 0 or target_events.max().item() >= K: - raise ValueError("target_events must be in [0, K)") + target_events = target_events.to(torch.long) - # Hazards: (M, K, B) - hazards = F.softplus(logits) + self.eps - total_hazard = hazards.sum(dim=1) # (M, B) + if (target_events < 0).any() or (target_events >= k_comp).any(): + raise ValueError( + f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}" + ) - edges = self.bin_edges.to(device=device, dtype=dt.dtype) - widths = edges[1:] - edges[:-1] # (B,) + # Map continuous dt to discrete bins j in {1..n_bins}. + bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype) + # (M,), may be n_bins+1 if dt > bin_edges[-1] + time_bin = torch.bucketize(dt, bin_edges) + time_bin = torch.clamp(time_bin, min=1, max=n_bins).to( + torch.long) # ensure valid event bins - if dt.min().item() <= 0: - raise ValueError("dt must be strictly positive") - if dt.max().item() > edges[-1].item(): - raise ValueError("dt must be <= last bin edge") + # Log-probabilities across causes+complement for each bin. + logp = F.log_softmax(logits, dim=1) # (M, K+1, n_bins+1) - # Bin index b* in [0, B-1]. - b_star = torch.searchsorted(edges[1:], dt, right=False) # (M,) + # Previous survival term: sum_{u=1}^{j-1} -log p(comp at u) + bins = torch.arange(n_bins + 1, device=logits.device) # (n_bins+1,) + mask = (bins.unsqueeze(0) >= 1) & (bins.unsqueeze( + 0) < time_bin.unsqueeze(1)) # (M, n_bins+1) + logp_comp = logp[:, k_comp, :] # (M, n_bins+1) + loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1) # (M,) - # 1. Hazard at event (M,) - # gather needs matching dims. - # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,) - # Alternative: hazards[m, k, b] - ar = torch.arange(M, device=device) - hazard_event = hazards[ar, target_events, b_star] # (M,) - hazard_event = torch.clamp(hazard_event, min=self.eps) + # Event term at bin j: -log p(k at j) + m_idx = torch.arange(m, device=logits.device) + loss_event = -logp[m_idx, target_events, time_bin] # (M,) - # 2. Integral part - # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left) - - # Full bins accumulation - weighted = total_hazard * widths.unsqueeze(0) # (M, B) - cum = weighted.cumsum(dim=1) # (M, B) - - full_bins_int = torch.zeros_like(dt) - - # We process 'has_full' logic generally. - # If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional - has_full = b_star > 0 - - # NOTE: Even without protection, we need valid indices for gather. - # We use a temporary index that is safe (0) for the 'False' cases, then mask the result. 
- safe_indices = (b_star - 1).clamp(min=0) - gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1) - full_bins_int = torch.where(has_full, gathered_cum, full_bins_int) - - # Partial bin accumulation - edge_left = edges[b_star] # (M,) - partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1) - partial = partial_hazard * (dt - edge_left) - - integral = full_bins_int + partial - - # Final NLL - nll = -torch.log(hazard_event) + integral - - # Reduction - if reduction == "none": - nll_out = nll - elif reduction == "sum": - nll_out = nll.sum() - elif reduction == "mean": - nll_out = nll.mean() - else: - raise ValueError("reduction must be one of: 'mean', 'sum', 'none'") - - reg = logits.new_zeros(()) - if self.lambda_reg != 0.0: - reg = reg + (self.lambda_reg * logits.pow(2).mean()) - - return nll_out, reg - - -class WeibullNLLLoss(nn.Module): - """ - Weibull hazard in t. - """ - - def __init__( - self, - eps: float = 1e-6, - lambda_reg: float = 0.0, - ): - super().__init__() - self.eps = eps - self.lambda_reg = lambda_reg - - def forward(self, logits, target_events, dt, reduction="mean"): - if logits.dim() != 3 or logits.size(-1) != 2: - raise ValueError("logits must have shape (M, K, 2)") - - M, K, _ = logits.shape - device = logits.device - - dt = dt.to(device=device, dtype=torch.float32) - if dt.min().item() <= 0: - raise ValueError("dt must be strictly positive") - - target_events = target_events.to(device=device) - target_events = target_events.to(dtype=torch.long) - if target_events.min().item() < 0 or target_events.max().item() >= K: - raise ValueError("target_events must be in [0, K)") - - shapes = F.softplus(logits[..., 0]) + self.eps - scales = F.softplus(logits[..., 1]) + self.eps - - t_mat = dt.unsqueeze(1) # (M,1) - cum_hazard = scales * torch.pow(t_mat, shapes) - hazard = shapes * scales * torch.pow(t_mat, shapes - 1.0) - hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1) - hazard_event = torch.clamp(hazard_event, min=self.eps) - - nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1) + loss = loss_prev + loss_event if reduction == "mean": - nll = nll.mean() + nll = loss.mean() elif reduction == "sum": - nll = nll.sum() - elif reduction != "none": - raise ValueError("reduction must be one of: 'mean', 'sum', 'none'") + nll = loss.sum() + else: + nll = loss + + reg = torch.zeros((), device=logits.device, dtype=loss.dtype) + if self.lambda_reg > 0.0: + # Regularize the cause distribution at the event bin using NLL on log-probs. 
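+            # In effect this is an auxiliary cross-entropy over the K cause channels
+            # at the observed event bin (complement channel excluded). The cause
+            # log-probs come from the (K+1)-way softmax and are not renormalized over
+            # causes only; F.nll_loss simply picks -log p(target cause) and averages.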
+ logp_causes = logp[:, :k_comp, :] # (M, K, n_bins+1) + idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1) + logp_at_event_bin = logp_causes.gather( + dim=2, index=idx).squeeze(2) # (M, K) + reg = self.lambda_reg * \ + F.nll_loss(logp_at_event_bin, target_events, reduction="mean") - reg = shapes.new_zeros(()) - if self.lambda_reg > 0: - reg = self.lambda_reg * ( - (torch.log(scales + self.eps) ** 2).mean() + - (torch.log(shapes + self.eps) ** 2).mean() - ) return nll, reg diff --git a/model.py b/model.py index 09f2c86..462cc82 100644 --- a/model.py +++ b/model.py @@ -259,64 +259,26 @@ class AutoDiscretization(nn.Module): return emb -class FactorizedHead(nn.Module): +class SimpleHead(nn.Module): def __init__( self, n_embd: int, - n_disease: int, - n_dim: int, - rank: int = 16, + out_dims: List[int], ): super().__init__() - self.n_embd = n_embd - self.n_disease = n_disease - self.n_dim = n_dim - self.rank = rank - - self.disease_base_proj = nn.Sequential( - nn.LayerNorm(n_embd), - nn.Linear(n_embd, n_dim), + self.out_dims = out_dims + total_out_dims = np.prod(out_dims) + self.net = nn.Sequential( + nn.Linear(n_embd, n_embd), + nn.GELU(), + nn.Linear(n_embd, total_out_dims), + nn.LayerNorm(total_out_dims), ) - self.context_mod_proj = nn.Sequential( - nn.LayerNorm(n_embd), - nn.Linear(n_embd, rank, bias=False), - ) - self.disease_mod_proj = nn.Sequential( - nn.LayerNorm(n_embd), - nn.Linear(n_embd, rank * n_dim, bias=False), - ) - self.delta_scale = nn.Parameter(torch.tensor(1e-3)) - self._init_weights() - - def _init_weights(self): - # init disease_base_proj: [LayerNorm, Linear] - nn.init.normal_(self.disease_base_proj[1].weight, std=0.02) - nn.init.zeros_(self.disease_base_proj[1].bias) - - # init context_mod_proj: [LayerNorm, Linear(bias=False)] - nn.init.zeros_(self.context_mod_proj[1].weight) - - # init disease_mod_proj: [LayerNorm, Linear(bias=False)] - nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02) - - def forward( - self, - c: torch.Tensor, # (M, n_embd) - disease_embedding, # (n_disease, n_embd) - ) -> torch.Tensor: - M = c.shape[0] - K = disease_embedding.shape[0] - assert K == self.n_disease - base_logits = self.disease_base_proj(disease_embedding) # (K, n_dim) - base_logits = base_logits.unsqueeze( - 0).expand(M, -1, -1) # (M, K, n_dim) - u = self.context_mod_proj(c) - v = self.disease_mod_proj(disease_embedding) - v = v.view(K, self.rank, self.n_dim) - delta_logits = torch.einsum('mr, krd -> mkd', u, v) - - return base_logits + self.delta_scale * delta_logits + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.net(x) + x = x.view(x.size(0), -1) + return x.view(-1, *self.out_dims) def _build_time_padding_mask( @@ -363,9 +325,6 @@ class DelphiFork(nn.Module): cate_dims: List[int], age_encoder_type: str = "sinusoidal", pdrop: float = 0.0, - token_pdrop: float = 0.0, - n_dim: int = 1, - rank: int = 16, ): super().__init__() self.vocab_size = n_disease + n_tech_tokens @@ -373,7 +332,6 @@ class DelphiFork(nn.Module): self.n_disease = n_disease self.n_embd = n_embd self.n_head = n_head - self.n_dim = n_dim self.token_embedding = nn.Embedding( self.vocab_size, n_embd, padding_idx=0) @@ -397,15 +355,21 @@ class DelphiFork(nn.Module): ]) self.ln_f = nn.LayerNorm(n_embd) - self.token_dropout = nn.Dropout(token_pdrop) - # Head layers - self.theta_proj = FactorizedHead( - n_embd=n_embd, - n_disease=n_disease, - n_dim=n_dim, - rank=rank, + def get_disease_embedding(self) -> torch.Tensor: + """Get disease token embeddings for head computation. 
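+
+        Disease tokens occupy the contiguous id range
+        [n_tech_tokens, n_tech_tokens + n_disease) of the shared vocabulary,
+        immediately after the technical tokens.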
+ + Returns: + (n_disease, n_embd) tensor of disease token embeddings. + """ + device = self.token_embedding.weight.device + disease_ids = torch.arange( + self.n_tech_tokens, + self.n_tech_tokens + self.n_disease, + device=device, ) + disease_embs = self.token_embedding(disease_ids) + return disease_embs def forward( self, @@ -414,8 +378,6 @@ class DelphiFork(nn.Module): sex: torch.Tensor, # (B,) cont_seq: torch.Tensor, # (B, Lc, n_cont) cate_seq: torch.Tensor, # (B, Lc, n_cate) - b_prev: Optional[torch.Tensor] = None, # (M,) - t_prev: Optional[torch.Tensor] = None, # (M,) ) -> torch.Tensor: token_embds = self.token_embedding(event_seq) # (B, L, D) age_embds = self.age_encoder(time_seq) # (B, L, D) @@ -443,24 +405,13 @@ class DelphiFork(nn.Module): final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds) x = final_embds + age_embds + sex_embds # (B, L, D) - x = self.token_dropout(x) attn_mask = _build_time_padding_mask( event_seq, time_seq) for block in self.blocks: x = block(x, attn_mask=attn_mask) x = self.ln_f(x) - if b_prev is not None and t_prev is not None: - M = b_prev.numel() - c = x[b_prev, t_prev] # (M, D) - disease_embeddings = self.token_embedding.weight[ - self.n_tech_tokens: self.n_tech_tokens + self.n_disease - ] - - theta = self.theta_proj(c, disease_embeddings) - return theta - else: - return x + return x class SapDelphi(nn.Module): @@ -477,9 +428,6 @@ class SapDelphi(nn.Module): cate_dims: List[int], age_encoder_type: str = "sinusoidal", pdrop: float = 0.0, - token_pdrop: float = 0.0, - n_dim: int = 1, - rank: int = 16, pretrained_weights_path: Optional[str] = None, # 新增参数 freeze_embeddings: bool = False, # 新增参数,默认为 False 表示微调 ): @@ -489,8 +437,6 @@ class SapDelphi(nn.Module): self.n_disease = n_disease self.n_embd = n_embd self.n_head = n_head - self.n_dim = n_dim - self.rank = rank if pretrained_weights_path is not None: print( @@ -540,15 +486,22 @@ class SapDelphi(nn.Module): ]) self.ln_f = nn.LayerNorm(n_embd) - self.token_dropout = nn.Dropout(token_pdrop) - # Head layers - self.theta_proj = FactorizedHead( - n_embd=n_embd, - n_disease=n_disease, - n_dim=n_dim, - rank=rank, + def get_disease_embedding(self) -> torch.Tensor: + """Get disease token embeddings for head computation. + + Returns: + (n_disease, n_embd) tensor of disease token embeddings. 
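+
+        Note:
+            Unlike DelphiFork, the raw token embeddings are additionally passed
+            through self.emb_proj, so the returned tensor lives in the
+            transformer width n_embd rather than the raw vocab embedding dim.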
+ """ + device = self.token_embedding.weight.device + disease_ids = torch.arange( + self.n_tech_tokens, + self.n_tech_tokens + self.n_disease, + device=device, ) + disease_embs = self.token_embedding(disease_ids) + disease_embs = self.emb_proj(disease_embs) + return disease_embs def forward( self, @@ -557,8 +510,6 @@ class SapDelphi(nn.Module): sex: torch.Tensor, # (B,) cont_seq: torch.Tensor, # (B, Lc, n_cont) cate_seq: torch.Tensor, # (B, Lc, n_cate) - b_prev: Optional[torch.Tensor] = None, # (M,) - t_prev: Optional[torch.Tensor] = None, # (M,) ) -> torch.Tensor: token_embds = self.token_embedding(event_seq) # (B, L, Vocab_dim) token_embds = self.emb_proj(token_embds) # (B, L, D) @@ -587,22 +538,10 @@ class SapDelphi(nn.Module): final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds) x = final_embds + age_embds + sex_embds # (B, L, D) - x = self.token_dropout(x) attn_mask = _build_time_padding_mask( event_seq, time_seq) for block in self.blocks: x = block(x, attn_mask=attn_mask) x = self.ln_f(x) - if b_prev is not None and t_prev is not None: - M = b_prev.numel() - c = x[b_prev, t_prev] # (M, D) - disease_embeddings_raw = self.token_embedding.weight[ - self.n_tech_tokens: self.n_tech_tokens + self.n_disease - ] # (K, vocab_dim) - - disease_embeddings = self.emb_proj(disease_embeddings_raw) - theta = self.theta_proj(c, disease_embeddings) - return theta - else: - return x + return x diff --git a/train.py b/train.py index f584a05..947bf1b 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,5 @@ -from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt -from model import DelphiFork, SapDelphi +from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt +from model import DelphiFork, SapDelphi, SimpleHead from dataset import HealthDataset, health_collate_fn from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ @@ -22,8 +22,7 @@ from typing import Literal, Sequence class TrainConfig: # Model Parameters model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork' - loss_type: Literal['exponential', 'weibull', - 'piecewise_exponential'] = 'weibull' + loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential' age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal' full_cov: bool = False n_embd: int = 120 @@ -32,7 +31,8 @@ class TrainConfig: pdrop: float = 0.1 lambda_reg: float = 1e-4 bin_edges: Sequence[float] = field( - default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0] + default_factory=lambda: [0.0, 0.24, 0.72, + 1.61, 3.84, 10.0, 31.0, float('inf')] ) rank: int = 16 # SapDelphi specific @@ -61,8 +61,12 @@ def parse_args() -> TrainConfig: parser = argparse.ArgumentParser(description="Train Delphi Model") parser.add_argument("--model_type", type=str, choices=[ 'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.") - parser.add_argument("--loss_type", type=str, choices=[ - 'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.") + parser.add_argument( + "--loss_type", + type=str, + choices=['exponential', 'discrete_time_cif'], + default='exponential', + help="Type of loss function to use.") parser.add_argument("--age_encoder", type=str, choices=[ 'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.") parser.add_argument("--n_embd", type=int, default=120, @@ -193,18 +197,14 @@ class Trainer: self.criterion = ExponentialNLLLoss( lambda_reg=cfg.lambda_reg, 
).to(self.device) - n_dim = 1 - elif cfg.loss_type == "piecewise_exponential": - self.criterion = PiecewiseExponentialLoss( + out_dims = [dataset.n_disease] + elif cfg.loss_type == "discrete_time_cif": + self.criterion = DiscreteTimeCIFNLLLoss( bin_edges=cfg.bin_edges, lambda_reg=cfg.lambda_reg, ).to(self.device) - n_dim = len(cfg.bin_edges) - 1 - elif cfg.loss_type == "weibull": - self.criterion = WeibullNLLLoss( - lambda_reg=cfg.lambda_reg, - ).to(self.device) - n_dim = 2 + # logits shape (M, K+1, n_bins+1) + out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)] else: raise ValueError(f"Unsupported loss type: {cfg.loss_type}") @@ -217,8 +217,6 @@ class Trainer: n_layer=cfg.n_layer, pdrop=cfg.pdrop, age_encoder_type=cfg.age_encoder, - n_dim=n_dim, - rank=cfg.rank, n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, @@ -232,8 +230,6 @@ class Trainer: n_layer=cfg.n_layer, pdrop=cfg.pdrop, age_encoder_type=cfg.age_encoder, - n_dim=n_dim, - rank=cfg.rank, n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, @@ -242,10 +238,25 @@ class Trainer: ).to(self.device) else: raise ValueError(f"Unsupported model type: {cfg.model_type}") + + # Prediction head maps context vectors -> logits with the shape required by the loss. + self.head = SimpleHead( + n_embd=cfg.n_embd, + out_dims=out_dims, + ).to(self.device) + print(f"Model initialized: {cfg.model_type}") - print(f"Number of trainable parameters: {get_num_params(self.model)}") + print( + f"Number of trainable parameters (backbone): {get_num_params(self.model)}") + print( + f"Number of trainable parameters (head): {get_num_params(self.head)}") + + self._optim_params = ( + list(self.model.parameters()) + + list(self.head.parameters()) + ) self.optimizer = AdamW( - self.model.parameters(), + self._optim_params, lr=cfg.max_lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.99), @@ -293,23 +304,11 @@ class Trainer: best_val_score = float('inf') patience_counter = 0 for epoch in range(1, self.cfg.max_epochs + 1): - model_for_logging = self.model.module if hasattr( - self.model, "module") else self.model - delta_scale = None - theta_proj = getattr(model_for_logging, "theta_proj", None) - if theta_proj is not None and hasattr(theta_proj, "delta_scale"): - try: - delta_scale = float( - theta_proj.delta_scale.detach().cpu().item()) - except Exception: - delta_scale = None - self.model.train() + self.head.train() total_train_pairs = 0 total_train_nll = 0.0 total_train_reg = 0.0 - total_train_log_scale_sq = 0.0 - total_train_log_shape_sq = 0.0 pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100) batch_count = 0 @@ -334,25 +333,17 @@ class Trainer: self.optimizer.zero_grad() lr = self.compute_lr(self.global_step) self.optimizer.param_groups[0]['lr'] = lr - logits = self.model( + h = self.model( event_seq, time_seq, sexes, cont_feats, cate_feats, - b_prev=b_prev, - t_prev=t_prev, ) - if isinstance(self.criterion, WeibullNLLLoss): - eps = float(self.criterion.eps) - shapes = torch.nn.functional.softplus(logits[..., 0]) + eps - scales = torch.nn.functional.softplus(logits[..., 1]) + eps - log_scale_sq = (torch.log(scales + eps) ** 2).mean() - log_shape_sq = (torch.log(shapes + eps) ** 2).mean() - else: - log_scale_sq = None - log_shape_sq = None + # Context vectors for selected previous events + c = h[b_prev, t_prev] # (M, D) + logits = self.head(c) target_event = event_seq[b_next, t_next] - 2 nll_vec, reg = self.criterion( @@ -367,10 +358,6 @@ class Trainer: total_train_pairs 
+= num_pairs total_train_nll += nll_vec.sum().item() total_train_reg += reg.item() * num_pairs - if log_scale_sq is not None: - total_train_log_scale_sq += log_scale_sq.item() * num_pairs - if log_shape_sq is not None: - total_train_log_shape_sq += log_shape_sq.item() * num_pairs avg_train_nll = total_train_nll / total_train_pairs avg_train_reg = total_train_reg / total_train_pairs pbar.set_postfix({ @@ -380,8 +367,7 @@ class Trainer: }) loss.backward() if self.cfg.grad_clip > 0: - clip_grad_norm_(self.model.parameters(), - self.cfg.grad_clip) + clip_grad_norm_(self._optim_params, self.cfg.grad_clip) self.optimizer.step() self.global_step += 1 @@ -391,23 +377,12 @@ class Trainer: train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0 train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0 - train_log_scale_sq = ( - total_train_log_scale_sq / total_train_pairs - if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) - else None - ) - train_log_shape_sq = ( - total_train_log_shape_sq / total_train_pairs - if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) - else None - ) self.model.eval() + self.head.eval() total_val_pairs = 0 total_val_nll = 0.0 total_val_reg = 0.0 - total_val_log_scale_sq = 0.0 - total_val_log_shape_sq = 0.0 with torch.no_grad(): val_pbar = tqdm(self.val_loader, desc="Validation") for batch in val_pbar: @@ -428,27 +403,16 @@ class Trainer: continue dt, b_prev, t_prev, b_next, t_next = res num_pairs = dt.size(0) - logits = self.model( + h = self.model( event_seq, time_seq, sexes, cont_feats, cate_feats, - b_prev=b_prev, - t_prev=t_prev ) - if isinstance(self.criterion, WeibullNLLLoss): - eps = float(self.criterion.eps) - shapes = torch.nn.functional.softplus( - logits[..., 0]) + eps - scales = torch.nn.functional.softplus( - logits[..., 1]) + eps - log_scale_sq = (torch.log(scales + eps) ** 2).mean() - log_shape_sq = (torch.log(shapes + eps) ** 2).mean() - else: - log_scale_sq = None - log_shape_sq = None + c = h[b_prev, t_prev] + logits = self.head(c) target_events = event_seq[b_next, t_next] - 2 nll, reg = self.criterion( @@ -460,10 +424,6 @@ class Trainer: batch_nll_sum = nll.sum().item() total_val_nll += batch_nll_sum total_val_reg += reg.item() * num_pairs - if log_scale_sq is not None: - total_val_log_scale_sq += log_scale_sq.item() * num_pairs - if log_shape_sq is not None: - total_val_log_shape_sq += log_shape_sq.item() * num_pairs total_val_pairs += num_pairs current_val_avg_nll = total_val_nll / \ @@ -478,16 +438,6 @@ class Trainer: val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0 val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0 - val_log_scale_sq = ( - total_val_log_scale_sq / total_val_pairs - if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) - else None - ) - val_log_shape_sq = ( - total_val_log_shape_sq / total_val_pairs - if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) - else None - ) history.append({ "epoch": epoch, @@ -495,11 +445,6 @@ class Trainer: "train_reg": train_reg, "val_nll": val_nll, "val_reg": val_reg, - "delta_scale": delta_scale, - "train_log_scale_sq": train_log_scale_sq, - "train_log_shape_sq": train_log_shape_sq, - "val_log_scale_sq": val_log_scale_sq, - "val_log_shape_sq": val_log_shape_sq, }) tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:") @@ -507,18 +452,6 @@ class Trainer: tqdm.write(f" Train Reg: {train_reg:.4f}") tqdm.write(f" Val NLL: 
{val_nll:.4f} ← PRIMARY METRIC") tqdm.write(f" Val Reg: {val_reg:.4f}") - if delta_scale is not None: - tqdm.write(f" Delta scale: {delta_scale:.6g}") - if train_log_scale_sq is not None and train_log_shape_sq is not None: - tqdm.write( - f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}") - tqdm.write( - f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}") - if val_log_scale_sq is not None and val_log_shape_sq is not None: - tqdm.write( - f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}") - tqdm.write( - f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}") with open(os.path.join(self.out_dir, "training_history.json"), "w") as f: json.dump(history, f, indent=4) @@ -533,6 +466,7 @@ class Trainer: "epoch": epoch, "global_step": self.global_step, "model_state_dict": self.model.state_dict(), + "head_state_dict": self.head.state_dict(), "criterion_state_dict": self.criterion.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), }, self.best_path)
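
A minimal smoke-test sketch for the new SimpleHead + DiscreteTimeCIFNLLLoss wiring
(illustrative only, not part of the patch). It assumes losses.py and model.py match
this diff; the batch size, the n_disease value, and the random context vectors that
stand in for h[b_prev, t_prev] from the backbone are made up for illustration.

# smoke_test_head_loss.py (hypothetical file name)
import torch

from losses import DiscreteTimeCIFNLLLoss
from model import SimpleHead

n_embd, n_disease = 120, 10  # n_disease chosen arbitrarily for the example
bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]

# The loss expects logits of shape (M, K+1, n_bins+1) == (M, n_disease+1, len(bin_edges)).
head = SimpleHead(n_embd=n_embd, out_dims=[n_disease + 1, len(bin_edges)])
criterion = DiscreteTimeCIFNLLLoss(bin_edges=bin_edges, lambda_reg=1e-4)

M = 32
c = torch.randn(M, n_embd)          # stand-in for backbone context vectors h[b_prev, t_prev]
logits = head(c)                    # (M, n_disease + 1, len(bin_edges))
target_events = torch.randint(0, n_disease, (M,))
dt = torch.rand(M) * 30.0 + 0.01    # strictly positive time gaps within the finite bin edges
nll, reg = criterion(logits, target_events, dt, reduction="mean")
(nll + reg).backward()              # gradients reach the head parameters
print(float(nll), float(reg))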