- Removed `prepare_data.py` as it is no longer needed.
- Introduced `losses.py`, containing the `ExponentialNLLLoss` and `WeibullLosses` classes for computing negative log-likelihood losses with regularization.
- Added `model.py`, which defines the DelphiFork model architecture, including a tabular encoder for continuous and categorical features and logic for merging sequences in time order.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExponentialNLLLoss(nn.Module):
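    """Negative log-likelihood of an exponential (constant-hazard) waiting-time
    model over competing event types, plus a cross-entropy term (weighted by
    `alpha`) regularizing the implied event-type distribution.
    """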

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        logits: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the exponential
        # distribution. `logits` are assumed to be non-negative per-event
        # hazard rates (e.g. produced by a softplus in the model head).

        # 1. shift event_seqs so positions align with the next event,
        #    re-indexing to drop technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. create a mask to filter out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return logits.new_zeros(())

        # 3. compute time differences to the next event
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. filter target events
        target_events = target_event_seqs[mask]  # (N,)
        # 5. compute hazard and total hazard
        hazard = logits[:, :-1, :]  # (B, L-1, vocab_size)
        hazard_at_events = hazard[mask].gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        total_hazard = hazard[mask].sum(dim=-1)  # (N,)
        # 6. log-likelihood log h_k - H_total * dt, then negate and average
        log_likelihood = torch.log(hazard_at_events + 1e-6) - total_hazard * dt
        nll = -log_likelihood.mean()
        # 7. cross-entropy regularization on the event-type distribution
        #    h_k / sum_j h_j implied by the hazards
        p_ce = hazard_at_events / total_hazard
        regularization = -self.alpha * torch.log(p_ce + 1e-6).mean()

        return nll + regularization


class WeibullLosses(nn.Module):
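    """Negative log-likelihood of a Weibull waiting-time model over competing
    event types, plus a cross-entropy term (weighted by `alpha`) regularizing
    the implied event-type distribution.
    """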

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        shapes: torch.Tensor,
        scales: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the Weibull distribution.
        # `shapes` and `scales` are assumed to be positive tensors of shape
        # (B, L, vocab_size), matching `logits` in ExponentialNLLLoss.

        # 1. shift event_seqs so positions align with the next event,
        #    re-indexing to drop technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2. create a mask to filter out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return shapes.new_zeros(())

        # 3. compute time differences to the next event
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4. filter target events
        target_events = target_event_seqs[mask]  # (N,)
        # drop the last position so the parameters line up with the shifted
        # targets, mirroring the `logits[:, :-1, :]` slice above
        shapes = shapes[:, :-1, :][mask]  # (N, vocab_size)
        scales = scales[:, :-1, :][mask]  # (N, vocab_size)
        # 5. compute shape and scale at events
        shape_at_events = shapes.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        scale_at_events = scales.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        log_shapes = torch.log(shape_at_events + 1e-6)
        log_scales = torch.log(scale_at_events + 1e-6)
        log_dt = torch.log(dt + 1e-6)
        # 6. log-likelihood: log hazard at the observed event,
        #    log h_k(dt) = log(shape) - log(scale)
        #                  + (shape - 1) * (log dt - log(scale)) ...
        log_hazard_at_events = log_shapes - log_scales + \
            (shape_at_events - 1) * (log_dt - log_scales)
        #    ... minus the total cumulative hazard sum_j (dt / scale_j) ** shape_j
        cum_hazard = (dt.unsqueeze(-1) /
                      scales) ** shapes  # (N, vocab_size)
        log_likelihood = log_hazard_at_events - cum_hazard.sum(dim=-1)
        nll = -log_likelihood.mean()
        # 7. compute cross-entropy regularization
        log_shapes_all = torch.log(shapes + 1e-6)
        log_scales_all = torch.log(scales + 1e-6)
        log_dt_expanded = log_dt.unsqueeze(-1)

        log_hazards = log_shapes_all - log_scales_all + (shapes - 1) * \
            (log_dt_expanded - log_scales_all)  # (N, vocab_size)
        # cross_entropy applies a softmax, so passing log-hazards scores each
        # event type by its normalized hazard h_k / sum_j h_j
        ce_loss = F.cross_entropy(
            log_hazards, target_events, reduction='mean')

        return nll + self.alpha * ce_loss
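

# A minimal smoke test, not part of the training pipeline: the batch/sequence
# sizes, the softplus parameterization, and n_tech below are illustrative
# assumptions, not values taken from the repository.
if __name__ == "__main__":
    B, L, vocab_size, n_tech = 2, 8, 10, 3
    # hazards and Weibull parameters must be positive, e.g. via softplus
    logits = F.softplus(torch.randn(B, L, vocab_size))
    shapes = F.softplus(torch.randn(B, L, vocab_size)) + 1e-3
    scales = F.softplus(torch.randn(B, L, vocab_size)) + 1e-3
    event_seqs = torch.randint(0, vocab_size + n_tech, (B, L))
    time_seqs = torch.cumsum(torch.rand(B, L), dim=-1)

    exp_loss = ExponentialNLLLoss(n_tech_tokens=n_tech)
    wei_loss = WeibullLosses(n_tech_tokens=n_tech)
    print("exponential NLL:", exp_loss(logits, event_seqs, time_seqs).item())
    print("weibull NLL:", wei_loss(shapes, scales, event_seqs, time_seqs).item())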