Refactor data preparation and add loss functions for model training
- Removed `prepare_data.py`, as it is no longer needed.
- Introduced `losses.py`, containing the ExponentialNLLLoss and WeibullLosses classes for calculating negative log-likelihood losses with regularization (a usage sketch follows below).
- Added `model.py`, which defines the DelphiFork model architecture, including a tabular encoder for continuous and categorical features and logic for merging sequences in time order.
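A minimal sketch of how one of these loss modules might be wired into a training step. The model call signature, the `batch` keys, and `n_tech_tokens=4` are illustrative assumptions (model.py is not shown here); the loss expects the model output to be non-negative per-event hazard rates of shape (B, L, n_event_types), despite the `logits` parameter name, and event/time sequences of shape (B, L).

from losses import ExponentialNLLLoss

criterion = ExponentialNLLLoss(n_tech_tokens=4, alpha=0.1)  # hypothetical token count

def training_step(model, optimizer, batch):
    # assumed batch layout: integer event ids and monotonically increasing times
    event_seqs, time_seqs = batch["event_seqs"], batch["time_seqs"]  # (B, L)
    rates = model(event_seqs, time_seqs)  # assumed output: (B, L, n_event_types), non-negative
    loss = criterion(rates, event_seqs, time_seqs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()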
losses.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExponentialNLLLoss(nn.Module):

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        logits: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the exponential distribution

        # 1, shift event_seqs to remove technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2, create a mask to filter out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return logits.new_zeros(())

        # 3, compute time differences
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4, filter target events
        target_events = target_event_seqs[mask]  # (N,)
        # 5, compute hazard and total hazard
        # logits are treated as per-event hazard rates and are assumed to be
        # non-negative (e.g. softplus-activated upstream)
        hazard = logits[:, :-1, :]  # (B, L-1, vocab_size)
        hazard_at_events = hazard[mask].gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        total_hazard = hazard[mask].sum(dim=-1)  # (N,)
        # 6, compute negative log-likelihood
        nll = torch.log(hazard_at_events + 1e-6) - total_hazard * dt
        nll = -nll.mean()
        # 7, compute cross-entropy regularization
        p_ce = hazard_at_events / total_hazard
        regularization = -self.alpha * torch.log(p_ce + 1e-6).mean()

        return nll + regularization


class WeibullLosses(nn.Module):

    def __init__(
        self,
        n_tech_tokens: int,
        alpha: float = 0.1,
    ):
        super().__init__()
        self.n_tech_tokens = n_tech_tokens
        self.alpha = alpha

    def forward(
        self,
        shapes: torch.Tensor,
        scales: torch.Tensor,
        event_seqs: torch.Tensor,
        time_seqs: torch.Tensor,
    ) -> torch.Tensor:
        # Calculate the negative log-likelihood for the Weibull distribution

        # 1, shift event_seqs to remove technical tokens
        target_event_seqs = event_seqs[:, 1:] - self.n_tech_tokens
        # 2, create a mask to filter out technical tokens
        mask = target_event_seqs >= 0
        if not mask.any():
            # if there are no valid events, return zero loss
            return shapes.new_zeros(())

        # 3, compute time differences
        dt = time_seqs[:, 1:] - time_seqs[:, :-1]
        dt = dt[mask]  # (N,)
        # 4, filter target events
        target_events = target_event_seqs[mask]  # (N,)
        shapes = shapes[mask]  # (N, vocab_size)
        scales = scales[mask]  # (N, vocab_size)
        # 5, compute shape and scale at events
        shape_at_events = shapes.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        scale_at_events = scales.gather(
            dim=-1, index=target_events.unsqueeze(-1)).squeeze(-1)  # (N,)
        log_shapes = torch.log(shape_at_events)
        log_scales = torch.log(scale_at_events)
        log_dt = torch.log(dt + 1e-6)
        # 6, compute negative log-likelihood:
        # log-hazard of the observed event minus the summed cumulative hazards
        nll = log_shapes - log_scales + \
            (shape_at_events - 1) * (log_dt - log_scales)
        # cumulative hazard (i.e. negative log-survival) per event type
        log_tot_survival = (dt.unsqueeze(-1) /
                            scales) ** shapes  # (N, vocab_size)
        nll -= log_tot_survival.sum(dim=-1)
        nll = -nll.mean()
        # 7, compute cross-entropy regularization
        log_shapes_all = torch.log(shapes)
        log_scales_all = torch.log(scales)
        log_dt_expanded = log_dt.unsqueeze(-1)

        # softmax over log-hazards yields hazard-proportional event probabilities
        log_hazards = log_shapes_all - log_scales_all + (shapes - 1) * \
            (log_dt_expanded - log_scales_all)  # (N, vocab_size)
        ce_loss = F.cross_entropy(
            log_hazards, target_events, reduction='mean')

        return nll + self.alpha * ce_loss
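For reference, a shape-level smoke test of WeibullLosses under assumed dimensions (batch of 2, sequence length 5, 3 technical tokens, 10 event types); the per-step Weibull parameters are expected to cover positions 1..L-1, i.e. shape (B, L-1, n_event_types), and softplus keeps them positive, which the log calls above rely on.

import torch
import torch.nn.functional as F
from losses import WeibullLosses

B, L, n_tech, n_events = 2, 5, 3, 10  # illustrative sizes
event_seqs = torch.randint(0, n_tech + n_events, (B, L))
time_seqs = torch.cumsum(torch.rand(B, L), dim=-1)  # increasing event times
# per-step, per-event Weibull shape/scale parameters, kept strictly positive
shapes = F.softplus(torch.randn(B, L - 1, n_events)) + 1e-3
scales = F.softplus(torch.randn(B, L - 1, n_events)) + 1e-3
loss = WeibullLosses(n_tech_tokens=n_tech)(shapes, scales, event_seqs, time_seqs)
print(loss.shape, loss.item())  # scalar loss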