# DeepHealth/losses.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class ExponentialNLLLoss(nn.Module):
    """Negative log-likelihood for a competing-risks exponential model.

    Each candidate event type gets a constant hazard rate, so the per-step
    log-likelihood is log(lambda_target) - sum_c(lambda_c) * dt. A
    cross-entropy term over the normalized hazards, weighted by ``alpha``,
    regularizes the choice of event type.
    """

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        logits: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the exponential distribution.
        # Note: logits are used directly as hazard rates, so they are expected
        # to be non-negative (e.g. produced by a softplus head).
        # 1. shift event_seqs to remove technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. create a mask that filters out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return logits.new_zeros(())
        # 3. compute time differences
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. filter target events
        target_events = target_event_seqs[mask]  # (N,)
        # 5. compute the hazard at the target event and the total hazard
        hazard = logits[:, :-1, :]  # (B, L-1, vocab_size)
        hazard_at_events = hazard[mask].gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        total_hazard = hazard[mask].sum(dim=-1)  # (N,)
        # 6. compute the negative log-likelihood
        nll = torch.log(hazard_at_events + 1e-6) - total_hazard * dt
        nll = -nll.mean()
        # 7. compute the cross-entropy regularization
        p_ce = hazard_at_events / total_hazard
        regularization = -self.alpha * torch.log(p_ce + 1e-6).mean()
        return nll + regularization
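
# Usage sketch (illustrative only; the shapes and the softplus head are
# assumptions, not taken from this file):
#
#   loss_fn = ExponentialNLLLoss(n_tech_tokens=3, alpha=0.1)
#   logits = F.softplus(model_output)             # (B, L, vocab_size) hazard rates
#   loss = loss_fn(logits, event_seqs, time_seqs) # event_seqs, time_seqs: (B, L)
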

class WeibullLosses(nn.Module):
    """Negative log-likelihood for a competing-risks Weibull model.

    Each candidate event type has its own shape k and scale s; the hazard is
    h(t) = (k / s) * (t / s) ** (k - 1) and the cumulative hazard is
    (t / s) ** k. The per-step log-likelihood is the log-hazard of the target
    event minus the summed cumulative hazards over all event types, and a
    cross-entropy term over the log-hazards, weighted by ``alpha``, acts as a
    regularizer.
    """

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        shapes: torch.Tensor,
        scales: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the Weibull distribution.
        # Note: shapes and scales are indexed with the same (B, L-1) mask as
        # the targets, i.e. they are expected to be aligned with the
        # next-event positions.
        # 1. shift event_seqs to remove technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. create a mask that filters out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return shapes.new_zeros(())
        # 3. compute time differences
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. filter target events
        target_events = target_event_seqs[mask]  # (N,)
        shapes = shapes[mask]  # (N, vocab_size)
        scales = scales[mask]  # (N, vocab_size)
        # 5. compute shape and scale at the target events
        shape_at_events = shapes.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        scale_at_events = scales.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        log_shapes = torch.log(shape_at_events)
        log_scales = torch.log(scale_at_events)
        log_dt = torch.log(dt + 1e-6)
        # 6. compute the negative log-likelihood: log-hazard of the target
        #    event minus the summed cumulative hazards over all event types
        nll = (log_shapes - log_scales
               + (shape_at_events - 1) * (log_dt - log_scales))
        cum_hazard = (dt.unsqueeze(-1) / scales) ** shapes  # (N, vocab_size), equals -log(survival)
        nll -= cum_hazard.sum(dim=-1)
        nll = -nll.mean()
        # 7. compute the cross-entropy regularization; cross_entropy applies a
        #    softmax to the log-hazards, which normalizes the hazards over
        #    event types as in ExponentialNLLLoss
        log_shapes_all = torch.log(shapes)
        log_scales_all = torch.log(scales)
        log_dt_expanded = log_dt.unsqueeze(-1)
        log_hazards = (log_shapes_all - log_scales_all
                       + (shapes - 1) * (log_dt_expanded - log_scales_all))  # (N, vocab_size)
        ce_loss = F.cross_entropy(
            log_hazards, target_events, reduction='mean')
        return nll + self.alpha * ce_loss
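

# A minimal smoke test with random tensors (not part of the training code;
# the tensor shapes, the softplus heads, and the constants below are
# assumptions chosen only to exercise both losses).
if __name__ == "__main__":
    torch.manual_seed(0)
    B, L, vocab_size, n_tech_tokens = 2, 6, 5, 3

    # event codes below n_tech_tokens are "technical" and get masked out
    event_seqs = torch.randint(0, vocab_size + n_tech_tokens, (B, L))
    # increasing event times so that dt >= 0
    time_seqs = torch.sort(torch.rand(B, L), dim=-1).values

    # Exponential loss: logits are interpreted as non-negative hazard rates
    exp_loss_fn = ExponentialNLLLoss(n_tech_tokens=n_tech_tokens)
    hazards = F.softplus(torch.randn(B, L, vocab_size))
    print("exponential loss:", exp_loss_fn(hazards, event_seqs, time_seqs).item())

    # Weibull loss: per-event shape/scale aligned with the next-event targets
    weibull_loss_fn = WeibullLosses(n_tech_tokens=n_tech_tokens)
    shapes = F.softplus(torch.randn(B, L - 1, vocab_size)) + 1e-3
    scales = F.softplus(torch.randn(B, L - 1, vocab_size)) + 1e-3
    print("weibull loss:", weibull_loss_fn(shapes, scales, event_seqs, time_seqs).item())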