import torch
import torch.nn as nn
import torch.nn.functional as F


class ExponentialNLLLoss(nn.Module):
    """Negative log-likelihood loss for an exponential (constant-hazard) TPP head.

    Expects per-type hazards that are already non-negative (e.g. produced by a
    softplus over the model's raw outputs).
    """

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        logits: torch.Tensor,      # (B, L, vocab_size), non-negative hazards
        event_seqs: torch.Tensor,  # (B, L), integer event types
        time_seqs: torch.Tensor,   # (B, L), event timestamps
    ) -> torch.Tensor:
        # 1. Shift event_seqs by one step and subtract the technical-token
        #    offset so that position i predicts the event at i + 1.
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. Mask out positions whose target is a technical token.
        mask = target_event_seqs >= 0
        if not mask.any():
            # No valid events in the batch: return a zero loss.
            return logits.new_zeros(())
        # 3. Inter-event times.
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. Target event types at the valid positions.
        target_events = target_event_seqs[mask]  # (N,)
        # 5. Hazard of the observed event type and total hazard across all types.
        hazard = logits[:, :-1, :]  # (B, L - 1, vocab_size)
        hazard_at_events = hazard[mask].gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        total_hazard = hazard[mask].sum(dim=-1)  # (N,)
        # 6. Negative log-likelihood: for the exponential model,
        #    log p(type, dt) = log h_type - H_total * dt.
        loglik = torch.log(hazard_at_events + 1e-6) - total_hazard * dt
        nll = -loglik.mean()
        # 7. Cross-entropy regularization on the hazard-normalized type
        #    distribution p(type) = h_type / H_total.
        p_ce = hazard_at_events / (total_hazard + 1e-6)
        regularization = -self.alpha * torch.log(p_ce + 1e-6).mean()
        return nll + regularization


class WeibullLosses(nn.Module):
    """Negative log-likelihood loss for a Weibull TPP head.

    Expects strictly positive per-type shape and scale parameters (e.g.
    produced by a softplus over the model's raw outputs).
    """

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        shapes: torch.Tensor,      # (B, L, vocab_size), positive shape params
        scales: torch.Tensor,      # (B, L, vocab_size), positive scale params
        event_seqs: torch.Tensor,  # (B, L), integer event types
        time_seqs: torch.Tensor,   # (B, L), event timestamps
    ) -> torch.Tensor:
        # 1. Shift event_seqs by one step and subtract the technical-token offset.
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. Mask out positions whose target is a technical token.
        mask = target_event_seqs >= 0
        if not mask.any():
            # No valid events in the batch: return a zero loss.
            return shapes.new_zeros(())
        # 3. Inter-event times.
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. Target event types and the parameters at the valid positions.
        #    As in ExponentialNLLLoss, drop the last position so that the
        #    parameters at step i are paired with the event at i + 1.
        target_events = target_event_seqs[mask]  # (N,)
        shapes = shapes[:, :-1, :][mask]  # (N, vocab_size)
        scales = scales[:, :-1, :][mask]  # (N, vocab_size)
        # 5. Shape and scale of the observed event type.
        shape_at_events = shapes.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        scale_at_events = scales.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        log_shapes = torch.log(shape_at_events + 1e-6)
        log_scales = torch.log(scale_at_events + 1e-6)
        log_dt = torch.log(dt + 1e-6)
        # 6. Negative log-likelihood. The log-hazard of Weibull(k, lambda) at dt is
        #    log k - log lambda + (k - 1) * (log dt - log lambda); the cumulative
        #    hazard is (dt / lambda) ** k, summed over all event types.
        log_hazard_at_events = log_shapes - log_scales + \
            (shape_at_events - 1) * (log_dt - log_scales)
        cum_hazard = (dt.unsqueeze(-1) / scales) ** shapes  # (N, vocab_size)
        loglik = log_hazard_at_events - cum_hazard.sum(dim=-1)
        nll = -loglik.mean()
        # 7. Cross-entropy regularization. Softmax over the log-hazards recovers
        #    the hazard-normalized type distribution h_type / sum_k h_k, so
        #    passing the log-hazards to F.cross_entropy as logits yields its
        #    negative log directly.
        log_shapes_all = torch.log(shapes + 1e-6)
        log_scales_all = torch.log(scales + 1e-6)
        log_dt_expanded = log_dt.unsqueeze(-1)
        log_hazards = log_shapes_all - log_scales_all + (shapes - 1) * \
            (log_dt_expanded - log_scales_all)  # (N, vocab_size)
        ce_loss = F.cross_entropy(
            log_hazards, target_events, reduction='mean')
        return nll + self.alpha * ce_loss
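

if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The tensor shapes, the softplus
    # parameterization, and the B / L / vocab sizes below are assumptions
    # about the surrounding training code, not part of this module. Hazards,
    # shapes, and scales must be strictly positive, which is why raw model
    # outputs are passed through softplus here.
    torch.manual_seed(0)
    B, L, n_tech_tokens, n_event_types = 2, 8, 4, 10

    # Event types drawn over the full vocabulary (technical tokens included)
    # and strictly increasing timestamps.
    event_seqs = torch.randint(0, n_tech_tokens + n_event_types, (B, L))
    time_seqs = torch.rand(B, L).cumsum(dim=-1)

    exp_loss = ExponentialNLLLoss(n_tech_tokens=n_tech_tokens)
    hazards = F.softplus(torch.randn(B, L, n_event_types))
    print("exponential loss:", exp_loss(hazards, event_seqs, time_seqs).item())

    wb_loss = WeibullLosses(n_tech_tokens=n_tech_tokens)
    shapes = F.softplus(torch.randn(B, L, n_event_types)) + 1e-3
    scales = F.softplus(torch.randn(B, L, n_event_types)) + 1e-3
    print("weibull loss:", wb_loss(shapes, scales, event_seqs, time_seqs).item())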