Refactor loss functions and model architecture: replace PiecewiseExponentialLoss and WeibullNLLLoss with DiscreteTimeCIFNLLLoss, replace the models' FactorizedHead with a standalone SimpleHead driven by the Trainer, and update argument parsing for the new loss type.
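For orientation, a minimal sketch of how the refactored pieces are meant to fit together, based on the diff below. The sizes (K, M) and the random inputs are made up for illustration; only DiscreteTimeCIFNLLLoss, SimpleHead, and the default bin_edges come from this commit.

    import torch

    from losses import DiscreteTimeCIFNLLLoss
    from model import SimpleHead

    K, n_embd, M = 10, 120, 4                   # illustrative sizes only
    bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]  # TrainConfig default

    # The head maps a context vector to (K causes + 1 complement) x (n_bins + 1) logits.
    head = SimpleHead(n_embd=n_embd, out_dims=[K + 1, len(bin_edges)])
    criterion = DiscreteTimeCIFNLLLoss(bin_edges=bin_edges, lambda_reg=1e-4)

    c = torch.randn(M, n_embd)                  # context vectors for M (prev, next) event pairs
    logits = head(c)                            # (M, K+1, n_bins+1)
    target_events = torch.randint(0, K, (M,))   # observed cause per pair
    dt = torch.rand(M) * 5 + 0.1                # strictly positive waiting times
    nll, reg = criterion(logits, target_events, dt, reduction="none")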
losses.py
@@ -132,9 +132,19 @@ class ExponentialNLLLoss(nn.Module):
         return nll, reg


-class PiecewiseExponentialLoss(nn.Module):
-    """
-    Piecewise-constant competing risks exponential likelihood.
+class DiscreteTimeCIFNLLLoss(nn.Module):
+    """Direct discrete-time CIF negative log-likelihood (no censoring).
+
+    This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
+    channels, where the complement channel (index K) represents survival across bins.
+
+    Per-sample likelihood for observed cause k at time bin j:
+        p = \prod_{u=1}^{j-1} p(comp at u) * p(k at j)
+
+    Args:
+        bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
+        eps: Unused; kept for interface compatibility / future numerical tweaks.
+        lambda_reg: Optional regularization strength.
+    """

     def __init__(
@@ -146,18 +156,20 @@ class PiecewiseExponentialLoss(nn.Module):
         super().__init__()

         if len(bin_edges) < 2:
-            raise ValueError("bin_edges must have length >= 2")
-        if bin_edges[0] != 0:
-            raise ValueError("bin_edges must start at 0")
+            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
+        if float(bin_edges[0]) != 0.0:
+            raise ValueError("bin_edges[0] must equal 0")
         for i in range(1, len(bin_edges)):
-            if not (bin_edges[i] > bin_edges[i - 1]):
+            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                 raise ValueError("bin_edges must be strictly increasing")

         self.eps = float(eps)
         self.lambda_reg = float(lambda_reg)

-        edges = torch.tensor(list(bin_edges), dtype=torch.float32)
-        self.register_buffer("bin_edges", edges, persistent=False)
+        self.register_buffer(
+            "bin_edges",
+            torch.tensor(bin_edges, dtype=torch.float32),
+            persistent=False,
+        )

     def forward(
         self,
@@ -166,145 +178,83 @@ class PiecewiseExponentialLoss(nn.Module):
         dt: torch.Tensor,
         reduction: str = "mean",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if logits.dim() != 3:
-            raise ValueError("logits must have shape (M, K, B)")
-
-        M, K, B = logits.shape
-        if self.bin_edges.numel() != B + 1:
+        if logits.ndim != 3:
             raise ValueError(
-                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
+                f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
             )
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError(
+                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
+            )
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
+            )
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean','sum','none'}")

-        device = logits.device
-        dt = dt.to(device=device, dtype=torch.float32)
-        target_events = target_events.to(device=device)
+        if not torch.all(dt > 0):
+            raise ValueError("dt must be strictly positive")

+        # Infer K and n_bins from logits and bin_edges.
+        m, k_plus_1, n_bins_plus_1 = logits.shape
+        k_comp = k_plus_1 - 1
+        if k_comp < 1:
+            raise ValueError(
+                "logits.shape[1] must be at least 2 (K>=1 plus complement channel)")

+        n_bins = int(self.bin_edges.numel() - 1)
+        if n_bins_plus_1 != n_bins + 1:
+            raise ValueError(
+                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
+            )

-        if target_events.dtype != torch.long:
-            target_events = target_events.to(dtype=torch.long)
-        if target_events.min().item() < 0 or target_events.max().item() >= K:
-            raise ValueError("target_events must be in [0, K)")
+        target_events = target_events.to(torch.long)

-        # Hazards: (M, K, B)
-        hazards = F.softplus(logits) + self.eps
-        total_hazard = hazards.sum(dim=1)  # (M, B)
+        if (target_events < 0).any() or (target_events >= k_comp).any():
+            raise ValueError(
+                f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
+            )

-        edges = self.bin_edges.to(device=device, dtype=dt.dtype)
-        widths = edges[1:] - edges[:-1]  # (B,)
+        # Map continuous dt to discrete bins j in {1..n_bins}.
+        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
+        # (M,), may be n_bins+1 if dt > bin_edges[-1]
+        time_bin = torch.bucketize(dt, bin_edges)
+        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(
+            torch.long)  # ensure valid event bins

-        if dt.min().item() <= 0:
-            raise ValueError("dt must be strictly positive")
-        if dt.max().item() > edges[-1].item():
-            raise ValueError("dt must be <= last bin edge")
+        # Log-probabilities across causes+complement for each bin.
+        logp = F.log_softmax(logits, dim=1)  # (M, K+1, n_bins+1)

-        # Bin index b* in [0, B-1].
-        b_star = torch.searchsorted(edges[1:], dt, right=False)  # (M,)
+        # Previous survival term: sum_{u=1}^{j-1} -log p(comp at u)
+        bins = torch.arange(n_bins + 1, device=logits.device)  # (n_bins+1,)
+        mask = (bins.unsqueeze(0) >= 1) & (bins.unsqueeze(
+            0) < time_bin.unsqueeze(1))  # (M, n_bins+1)
+        logp_comp = logp[:, k_comp, :]  # (M, n_bins+1)
+        loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1)  # (M,)

-        # 1. Hazard at event (M,)
-        # gather needs matching dims.
-        # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,)
-        # Alternative: hazards[m, k, b]
-        ar = torch.arange(M, device=device)
-        hazard_event = hazards[ar, target_events, b_star]  # (M,)
-        hazard_event = torch.clamp(hazard_event, min=self.eps)
+        # Event term at bin j: -log p(k at j)
+        m_idx = torch.arange(m, device=logits.device)
+        loss_event = -logp[m_idx, target_events, time_bin]  # (M,)

-        # 2. Integral part
-        # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
-
-        # Full bins accumulation
-        weighted = total_hazard * widths.unsqueeze(0)  # (M, B)
-        cum = weighted.cumsum(dim=1)  # (M, B)
-
-        full_bins_int = torch.zeros_like(dt)
-
-        # We process 'has_full' logic generally.
-        # If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional
-        has_full = b_star > 0
-
-        # NOTE: Even without protection, we need valid indices for gather.
-        # We use a temporary index that is safe (0) for the 'False' cases, then mask the result.
-        safe_indices = (b_star - 1).clamp(min=0)
-        gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1)
-        full_bins_int = torch.where(has_full, gathered_cum, full_bins_int)
-
-        # Partial bin accumulation
-        edge_left = edges[b_star]  # (M,)
-        partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
-        partial = partial_hazard * (dt - edge_left)
-
-        integral = full_bins_int + partial
-
-        # Final NLL
-        nll = -torch.log(hazard_event) + integral
-
-        # Reduction
-        if reduction == "none":
-            nll_out = nll
-        elif reduction == "sum":
-            nll_out = nll.sum()
-        elif reduction == "mean":
-            nll_out = nll.mean()
-        else:
-            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
-
-        reg = logits.new_zeros(())
-        if self.lambda_reg != 0.0:
-            reg = reg + (self.lambda_reg * logits.pow(2).mean())
-
-        return nll_out, reg
-
-
-class WeibullNLLLoss(nn.Module):
-    """
-    Weibull hazard in t.
-    """
-
-    def __init__(
-        self,
-        eps: float = 1e-6,
-        lambda_reg: float = 0.0,
-    ):
-        super().__init__()
-        self.eps = eps
-        self.lambda_reg = lambda_reg
-
-    def forward(self, logits, target_events, dt, reduction="mean"):
-        if logits.dim() != 3 or logits.size(-1) != 2:
-            raise ValueError("logits must have shape (M, K, 2)")
-
-        M, K, _ = logits.shape
-        device = logits.device
-
-        dt = dt.to(device=device, dtype=torch.float32)
-        if dt.min().item() <= 0:
-            raise ValueError("dt must be strictly positive")
-
-        target_events = target_events.to(device=device)
-        target_events = target_events.to(dtype=torch.long)
-        if target_events.min().item() < 0 or target_events.max().item() >= K:
-            raise ValueError("target_events must be in [0, K)")
-
-        shapes = F.softplus(logits[..., 0]) + self.eps
-        scales = F.softplus(logits[..., 1]) + self.eps
-
-        t_mat = dt.unsqueeze(1)  # (M,1)
-        cum_hazard = scales * torch.pow(t_mat, shapes)
-        hazard = shapes * scales * torch.pow(t_mat, shapes - 1.0)
-        hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
-        hazard_event = torch.clamp(hazard_event, min=self.eps)
-
-        nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1)
+        loss = loss_prev + loss_event

         if reduction == "mean":
-            nll = nll.mean()
+            nll = loss.mean()
         elif reduction == "sum":
-            nll = nll.sum()
-        elif reduction != "none":
-            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
+            nll = loss.sum()
+        else:
+            nll = loss

+        reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
+        if self.lambda_reg > 0.0:
+            # Regularize the cause distribution at the event bin using NLL on log-probs.
+            logp_causes = logp[:, :k_comp, :]  # (M, K, n_bins+1)
+            idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
+            logp_at_event_bin = logp_causes.gather(
+                dim=2, index=idx).squeeze(2)  # (M, K)
+            reg = self.lambda_reg * \
+                F.nll_loss(logp_at_event_bin, target_events, reduction="mean")

-        reg = shapes.new_zeros(())
-        if self.lambda_reg > 0:
-            reg = self.lambda_reg * (
-                (torch.log(scales + self.eps) ** 2).mean() +
-                (torch.log(shapes + self.eps) ** 2).mean()
-            )

         return nll, reg
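As a reading aid (not part of the commit), a hand-check of the docstring formula against the new forward pass. It assumes the class above is importable from losses; the sizes and the chosen cause/bin are arbitrary.

    import torch
    import torch.nn.functional as F

    from losses import DiscreteTimeCIFNLLLoss

    torch.manual_seed(0)
    K, n_bins = 3, 4
    logits = torch.randn(1, K + 1, n_bins + 1)   # one pair, K causes + complement channel
    logp = F.log_softmax(logits, dim=1)

    k, j = 1, 3                                  # observed cause k in time bin j
    # Survival through bins 1..j-1 via the complement channel (index K),
    # plus the event term for cause k at bin j:
    nll_manual = -(logp[0, K, 1:j].sum() + logp[0, k, j])

    crit = DiscreteTimeCIFNLLLoss(bin_edges=[0.0, 1.0, 2.0, 3.0, 4.0], lambda_reg=0.0)
    # dt = 2.5 falls in bin 3 for these edges, so the two computations should agree.
    nll, _ = crit(logits, torch.tensor([k]), torch.tensor([2.5]), reduction="none")
    assert torch.allclose(nll[0], nll_manual)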
model.py
@@ -259,64 +259,26 @@ class AutoDiscretization(nn.Module):
         return emb


-class FactorizedHead(nn.Module):
+class SimpleHead(nn.Module):
     def __init__(
         self,
         n_embd: int,
-        n_disease: int,
-        n_dim: int,
-        rank: int = 16,
+        out_dims: List[int],
     ):
         super().__init__()
         self.n_embd = n_embd
-        self.n_disease = n_disease
-        self.n_dim = n_dim
-        self.rank = rank
-
-        self.disease_base_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, n_dim),
+        self.out_dims = out_dims
+        total_out_dims = np.prod(out_dims)
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, n_embd),
+            nn.GELU(),
+            nn.Linear(n_embd, total_out_dims),
+            nn.LayerNorm(total_out_dims),
         )
-        self.context_mod_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, rank, bias=False),
-        )
-        self.disease_mod_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, rank * n_dim, bias=False),
-        )
-        self.delta_scale = nn.Parameter(torch.tensor(1e-3))
-
-        self._init_weights()
-
-    def _init_weights(self):
-        # init disease_base_proj: [LayerNorm, Linear]
-        nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
-        nn.init.zeros_(self.disease_base_proj[1].bias)
-
-        # init context_mod_proj: [LayerNorm, Linear(bias=False)]
-        nn.init.zeros_(self.context_mod_proj[1].weight)
-
-        # init disease_mod_proj: [LayerNorm, Linear(bias=False)]
-        nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)
-
-    def forward(
-        self,
-        c: torch.Tensor,  # (M, n_embd)
-        disease_embedding,  # (n_disease, n_embd)
-    ) -> torch.Tensor:
-        M = c.shape[0]
-        K = disease_embedding.shape[0]
-        assert K == self.n_disease
-        base_logits = self.disease_base_proj(disease_embedding)  # (K, n_dim)
-        base_logits = base_logits.unsqueeze(
-            0).expand(M, -1, -1)  # (M, K, n_dim)
-        u = self.context_mod_proj(c)
-        v = self.disease_mod_proj(disease_embedding)
-        v = v.view(K, self.rank, self.n_dim)
-        delta_logits = torch.einsum('mr, krd -> mkd', u, v)
-
-        return base_logits + self.delta_scale * delta_logits
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.net(x)
+        x = x.view(x.size(0), -1)
+        return x.view(-1, *self.out_dims)


 def _build_time_padding_mask(
@@ -363,9 +325,6 @@ class DelphiFork(nn.Module):
         cate_dims: List[int],
         age_encoder_type: str = "sinusoidal",
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
-        n_dim: int = 1,
-        rank: int = 16,
     ):
         super().__init__()
         self.vocab_size = n_disease + n_tech_tokens
@@ -373,7 +332,6 @@ class DelphiFork(nn.Module):
         self.n_disease = n_disease
         self.n_embd = n_embd
         self.n_head = n_head
-        self.n_dim = n_dim

         self.token_embedding = nn.Embedding(
             self.vocab_size, n_embd, padding_idx=0)
@@ -397,15 +355,21 @@ class DelphiFork(nn.Module):
         ])

         self.ln_f = nn.LayerNorm(n_embd)
         self.token_dropout = nn.Dropout(token_pdrop)

-        # Head layers
-        self.theta_proj = FactorizedHead(
-            n_embd=n_embd,
-            n_disease=n_disease,
-            n_dim=n_dim,
-            rank=rank,
+    def get_disease_embedding(self) -> torch.Tensor:
+        """Get disease token embeddings for head computation.
+
+        Returns:
+            (n_disease, n_embd) tensor of disease token embeddings.
+        """
+        device = self.token_embedding.weight.device
+        disease_ids = torch.arange(
+            self.n_tech_tokens,
+            self.n_tech_tokens + self.n_disease,
+            device=device,
         )
+        disease_embs = self.token_embedding(disease_ids)
+        return disease_embs

     def forward(
         self,
@@ -414,8 +378,6 @@ class DelphiFork(nn.Module):
         sex: torch.Tensor,  # (B,)
         cont_seq: torch.Tensor,  # (B, Lc, n_cont)
         cate_seq: torch.Tensor,  # (B, Lc, n_cate)
-        b_prev: Optional[torch.Tensor] = None,  # (M,)
-        t_prev: Optional[torch.Tensor] = None,  # (M,)
     ) -> torch.Tensor:
         token_embds = self.token_embedding(event_seq)  # (B, L, D)
         age_embds = self.age_encoder(time_seq)  # (B, L, D)
@@ -443,23 +405,12 @@ class DelphiFork(nn.Module):
         final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)

         x = final_embds + age_embds + sex_embds  # (B, L, D)
         x = self.token_dropout(x)
         attn_mask = _build_time_padding_mask(
             event_seq, time_seq)
         for block in self.blocks:
             x = block(x, attn_mask=attn_mask)
         x = self.ln_f(x)

-        if b_prev is not None and t_prev is not None:
-            M = b_prev.numel()
-            c = x[b_prev, t_prev]  # (M, D)
-            disease_embeddings = self.token_embedding.weight[
-                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
-            ]
-
-            theta = self.theta_proj(c, disease_embeddings)
-            return theta
-        else:
-            return x
+        return x


@@ -477,9 +428,6 @@ class SapDelphi(nn.Module):
         cate_dims: List[int],
         age_encoder_type: str = "sinusoidal",
         pdrop: float = 0.0,
         token_pdrop: float = 0.0,
-        n_dim: int = 1,
-        rank: int = 16,
         pretrained_weights_path: Optional[str] = None,  # new parameter
         freeze_embeddings: bool = False,  # new parameter; default False means fine-tune rather than freeze
     ):
@@ -489,8 +437,6 @@ class SapDelphi(nn.Module):
         self.n_disease = n_disease
         self.n_embd = n_embd
         self.n_head = n_head
-        self.n_dim = n_dim
-        self.rank = rank

         if pretrained_weights_path is not None:
             print(
@@ -540,15 +486,22 @@ class SapDelphi(nn.Module):
         ])

         self.ln_f = nn.LayerNorm(n_embd)
         self.token_dropout = nn.Dropout(token_pdrop)

-        # Head layers
-        self.theta_proj = FactorizedHead(
-            n_embd=n_embd,
-            n_disease=n_disease,
-            n_dim=n_dim,
-            rank=rank,
+    def get_disease_embedding(self) -> torch.Tensor:
+        """Get disease token embeddings for head computation.
+
+        Returns:
+            (n_disease, n_embd) tensor of disease token embeddings.
+        """
+        device = self.token_embedding.weight.device
+        disease_ids = torch.arange(
+            self.n_tech_tokens,
+            self.n_tech_tokens + self.n_disease,
+            device=device,
         )
+        disease_embs = self.token_embedding(disease_ids)
+        disease_embs = self.emb_proj(disease_embs)
+        return disease_embs

     def forward(
         self,
@@ -557,8 +510,6 @@ class SapDelphi(nn.Module):
         sex: torch.Tensor,  # (B,)
         cont_seq: torch.Tensor,  # (B, Lc, n_cont)
         cate_seq: torch.Tensor,  # (B, Lc, n_cate)
-        b_prev: Optional[torch.Tensor] = None,  # (M,)
-        t_prev: Optional[torch.Tensor] = None,  # (M,)
     ) -> torch.Tensor:
         token_embds = self.token_embedding(event_seq)  # (B, L, Vocab_dim)
         token_embds = self.emb_proj(token_embds)  # (B, L, D)
@@ -587,22 +538,10 @@ class SapDelphi(nn.Module):
         final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)

         x = final_embds + age_embds + sex_embds  # (B, L, D)
         x = self.token_dropout(x)
         attn_mask = _build_time_padding_mask(
             event_seq, time_seq)
         for block in self.blocks:
             x = block(x, attn_mask=attn_mask)
         x = self.ln_f(x)

-        if b_prev is not None and t_prev is not None:
-            M = b_prev.numel()
-            c = x[b_prev, t_prev]  # (M, D)
-            disease_embeddings_raw = self.token_embedding.weight[
-                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
-            ]  # (K, vocab_dim)
-
-            disease_embeddings = self.emb_proj(disease_embeddings_raw)
-            theta = self.theta_proj(c, disease_embeddings)
-            return theta
-        else:
-            return x
+        return x
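A quick shape sketch for the new SimpleHead (illustrative only; n_embd=120 matches the TrainConfig default, the other sizes are made up). It simply projects a context vector and reshapes the flat output into out_dims, e.g. [K + 1, n_bins + 1] as the Trainer builds them for the new loss:

    import torch

    from model import SimpleHead

    head = SimpleHead(n_embd=120, out_dims=[11, 8])  # e.g. K+1 = 11, n_bins+1 = 8
    c = torch.randn(32, 120)                         # a batch of context vectors
    theta = head(c)
    print(theta.shape)                               # torch.Size([32, 11, 8])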
train.py
@@ -1,5 +1,5 @@
-from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
-from model import DelphiFork, SapDelphi
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
+from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
@@ -22,8 +22,7 @@ from typing import Literal, Sequence
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'weibull',
-                       'piecewise_exponential'] = 'weibull'
+    loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -32,7 +31,8 @@ class TrainConfig:
     pdrop: float = 0.1
     lambda_reg: float = 1e-4
     bin_edges: Sequence[float] = field(
-        default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
+        default_factory=lambda: [0.0, 0.24, 0.72,
+                                 1.61, 3.84, 10.0, 31.0, float('inf')]
     )
     rank: int = 16
     # SapDelphi specific
@@ -61,8 +61,12 @@ def parse_args() -> TrainConfig:
     parser = argparse.ArgumentParser(description="Train Delphi Model")
     parser.add_argument("--model_type", type=str, choices=[
                         'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
-    parser.add_argument("--loss_type", type=str, choices=[
-                        'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
+    parser.add_argument(
+        "--loss_type",
+        type=str,
+        choices=['exponential', 'discrete_time_cif'],
+        default='exponential',
+        help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
                         'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
     parser.add_argument("--n_embd", type=int, default=120,
@@ -193,18 +197,14 @@ class Trainer:
             self.criterion = ExponentialNLLLoss(
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            n_dim = 1
-        elif cfg.loss_type == "piecewise_exponential":
-            self.criterion = PiecewiseExponentialLoss(
+            out_dims = [dataset.n_disease]
+        elif cfg.loss_type == "discrete_time_cif":
+            self.criterion = DiscreteTimeCIFNLLLoss(
                 bin_edges=cfg.bin_edges,
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            n_dim = len(cfg.bin_edges) - 1
-        elif cfg.loss_type == "weibull":
-            self.criterion = WeibullNLLLoss(
-                lambda_reg=cfg.lambda_reg,
-            ).to(self.device)
-            n_dim = 2
+            # logits shape (M, K+1, n_bins+1)
+            out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
         else:
             raise ValueError(f"Unsupported loss type: {cfg.loss_type}")

@@ -217,8 +217,6 @@ class Trainer:
                 n_layer=cfg.n_layer,
                 pdrop=cfg.pdrop,
                 age_encoder_type=cfg.age_encoder,
-                n_dim=n_dim,
-                rank=cfg.rank,
                 n_cont=dataset.n_cont,
                 n_cate=dataset.n_cate,
                 cate_dims=dataset.cate_dims,
@@ -232,8 +230,6 @@ class Trainer:
                 n_layer=cfg.n_layer,
                 pdrop=cfg.pdrop,
                 age_encoder_type=cfg.age_encoder,
-                n_dim=n_dim,
-                rank=cfg.rank,
                 n_cont=dataset.n_cont,
                 n_cate=dataset.n_cate,
                 cate_dims=dataset.cate_dims,
@@ -242,10 +238,25 @@ class Trainer:
             ).to(self.device)
         else:
             raise ValueError(f"Unsupported model type: {cfg.model_type}")

+        # Prediction head maps context vectors -> logits with the shape required by the loss.
+        self.head = SimpleHead(
+            n_embd=cfg.n_embd,
+            out_dims=out_dims,
+        ).to(self.device)
+
         print(f"Model initialized: {cfg.model_type}")
-        print(f"Number of trainable parameters: {get_num_params(self.model)}")
+        print(
+            f"Number of trainable parameters (backbone): {get_num_params(self.model)}")
+        print(
+            f"Number of trainable parameters (head): {get_num_params(self.head)}")

+        self._optim_params = (
+            list(self.model.parameters())
+            + list(self.head.parameters())
+        )
         self.optimizer = AdamW(
-            self.model.parameters(),
+            self._optim_params,
             lr=cfg.max_lr,
             weight_decay=cfg.weight_decay,
             betas=(0.9, 0.99),
@@ -293,23 +304,11 @@ class Trainer:
         best_val_score = float('inf')
         patience_counter = 0
         for epoch in range(1, self.cfg.max_epochs + 1):
-            model_for_logging = self.model.module if hasattr(
-                self.model, "module") else self.model
-            delta_scale = None
-            theta_proj = getattr(model_for_logging, "theta_proj", None)
-            if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
-                try:
-                    delta_scale = float(
-                        theta_proj.delta_scale.detach().cpu().item())
-                except Exception:
-                    delta_scale = None
-
             self.model.train()
+            self.head.train()
             total_train_pairs = 0
             total_train_nll = 0.0
             total_train_reg = 0.0
-            total_train_log_scale_sq = 0.0
-            total_train_log_shape_sq = 0.0
             pbar = tqdm(self.train_loader,
                         desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
             batch_count = 0
@@ -334,25 +333,17 @@ class Trainer:
                 self.optimizer.zero_grad()
                 lr = self.compute_lr(self.global_step)
                 self.optimizer.param_groups[0]['lr'] = lr
-                logits = self.model(
+                h = self.model(
                     event_seq,
                     time_seq,
                     sexes,
                     cont_feats,
                     cate_feats,
-                    b_prev=b_prev,
-                    t_prev=t_prev,
                 )

-                if isinstance(self.criterion, WeibullNLLLoss):
-                    eps = float(self.criterion.eps)
-                    shapes = torch.nn.functional.softplus(logits[..., 0]) + eps
-                    scales = torch.nn.functional.softplus(logits[..., 1]) + eps
-                    log_scale_sq = (torch.log(scales + eps) ** 2).mean()
-                    log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
-                else:
-                    log_scale_sq = None
-                    log_shape_sq = None
+                # Context vectors for selected previous events
+                c = h[b_prev, t_prev]  # (M, D)
+                logits = self.head(c)

                 target_event = event_seq[b_next, t_next] - 2
                 nll_vec, reg = self.criterion(
@@ -367,10 +358,6 @@ class Trainer:
                 total_train_pairs += num_pairs
                 total_train_nll += nll_vec.sum().item()
                 total_train_reg += reg.item() * num_pairs
-                if log_scale_sq is not None:
-                    total_train_log_scale_sq += log_scale_sq.item() * num_pairs
-                if log_shape_sq is not None:
-                    total_train_log_shape_sq += log_shape_sq.item() * num_pairs
                 avg_train_nll = total_train_nll / total_train_pairs
                 avg_train_reg = total_train_reg / total_train_pairs
                 pbar.set_postfix({
@@ -380,8 +367,7 @@ class Trainer:
                 })
                 loss.backward()
                 if self.cfg.grad_clip > 0:
-                    clip_grad_norm_(self.model.parameters(),
-                                    self.cfg.grad_clip)
+                    clip_grad_norm_(self._optim_params, self.cfg.grad_clip)
                 self.optimizer.step()
                 self.global_step += 1

@@ -391,23 +377,12 @@ class Trainer:

             train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
             train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
-            train_log_scale_sq = (
-                total_train_log_scale_sq / total_train_pairs
-                if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )
-            train_log_shape_sq = (
-                total_train_log_shape_sq / total_train_pairs
-                if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )

             self.model.eval()
+            self.head.eval()
             total_val_pairs = 0
             total_val_nll = 0.0
             total_val_reg = 0.0
-            total_val_log_scale_sq = 0.0
-            total_val_log_shape_sq = 0.0
             with torch.no_grad():
                 val_pbar = tqdm(self.val_loader, desc="Validation")
                 for batch in val_pbar:
@@ -428,27 +403,16 @@ class Trainer:
                         continue
                     dt, b_prev, t_prev, b_next, t_next = res
                    num_pairs = dt.size(0)
-                    logits = self.model(
+                    h = self.model(
                        event_seq,
                        time_seq,
                        sexes,
                        cont_feats,
                        cate_feats,
-                        b_prev=b_prev,
-                        t_prev=t_prev
                    )

-                    if isinstance(self.criterion, WeibullNLLLoss):
-                        eps = float(self.criterion.eps)
-                        shapes = torch.nn.functional.softplus(
-                            logits[..., 0]) + eps
-                        scales = torch.nn.functional.softplus(
-                            logits[..., 1]) + eps
-                        log_scale_sq = (torch.log(scales + eps) ** 2).mean()
-                        log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
-                    else:
-                        log_scale_sq = None
-                        log_shape_sq = None
+                    c = h[b_prev, t_prev]
+                    logits = self.head(c)

                    target_events = event_seq[b_next, t_next] - 2
                    nll, reg = self.criterion(
@@ -460,10 +424,6 @@ class Trainer:
                    batch_nll_sum = nll.sum().item()
                    total_val_nll += batch_nll_sum
                    total_val_reg += reg.item() * num_pairs
-                    if log_scale_sq is not None:
-                        total_val_log_scale_sq += log_scale_sq.item() * num_pairs
-                    if log_shape_sq is not None:
-                        total_val_log_shape_sq += log_shape_sq.item() * num_pairs
                    total_val_pairs += num_pairs

                    current_val_avg_nll = total_val_nll / \
@@ -478,16 +438,6 @@ class Trainer:

             val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
             val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
-            val_log_scale_sq = (
-                total_val_log_scale_sq / total_val_pairs
-                if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )
-            val_log_shape_sq = (
-                total_val_log_shape_sq / total_val_pairs
-                if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )

             history.append({
                 "epoch": epoch,
@@ -495,11 +445,6 @@ class Trainer:
                 "train_reg": train_reg,
                 "val_nll": val_nll,
                 "val_reg": val_reg,
-                "delta_scale": delta_scale,
-                "train_log_scale_sq": train_log_scale_sq,
-                "train_log_shape_sq": train_log_shape_sq,
-                "val_log_scale_sq": val_log_scale_sq,
-                "val_log_shape_sq": val_log_shape_sq,
             })

             tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -507,18 +452,6 @@ class Trainer:
             tqdm.write(f" Train Reg: {train_reg:.4f}")
             tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
             tqdm.write(f" Val Reg: {val_reg:.4f}")
-            if delta_scale is not None:
-                tqdm.write(f" Delta scale: {delta_scale:.6g}")
-            if train_log_scale_sq is not None and train_log_shape_sq is not None:
-                tqdm.write(
-                    f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}")
-                tqdm.write(
-                    f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}")
-            if val_log_scale_sq is not None and val_log_shape_sq is not None:
-                tqdm.write(
-                    f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}")
-                tqdm.write(
-                    f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}")

             with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
                 json.dump(history, f, indent=4)
@@ -533,6 +466,7 @@ class Trainer:
                     "epoch": epoch,
                     "global_step": self.global_step,
                     "model_state_dict": self.model.state_dict(),
+                    "head_state_dict": self.head.state_dict(),
                     "criterion_state_dict": self.criterion.state_dict(),
                     "optimizer_state_dict": self.optimizer.state_dict(),
                 }, self.best_path)
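Note that checkpoints written by the updated Trainer now carry the head weights separately from the backbone. A sketch of consuming the new format; "best_model.pt" is a placeholder for the Trainer's self.best_path, and the head sizes must match the training config:

    import torch

    from model import SimpleHead

    ckpt = torch.load("best_model.pt", map_location="cpu")
    print(sorted(ckpt.keys()))
    # ['criterion_state_dict', 'epoch', 'global_step', 'head_state_dict',
    #  'model_state_dict', 'optimizer_state_dict']

    head = SimpleHead(n_embd=120, out_dims=[11, 8])  # illustrative sizes
    head.load_state_dict(ckpt["head_state_dict"])    # SimpleHead, new in this commit
    # The backbone (DelphiFork / SapDelphi) is restored from ckpt["model_state_dict"].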