Refactor loss functions and model architecture: replace PiecewiseExponentialLoss with DiscreteTimeCIFNLLLoss and drop WeibullNLLLoss, replace FactorizedHead with a standalone SimpleHead driven by the Trainer, and update argument parsing for the new loss type.

commit 209dde2299 (parent 880fd53a4b)
Date: 2026-01-09 18:31:38 +08:00
3 changed files with 172 additions and 349 deletions

losses.py (222 changed lines)

@@ -132,9 +132,19 @@ class ExponentialNLLLoss(nn.Module):
         return nll, reg


-class PiecewiseExponentialLoss(nn.Module):
-    """
-    Piecewise-constant competing risks exponential likelihood.
+class DiscreteTimeCIFNLLLoss(nn.Module):
+    """Direct discrete-time CIF negative log-likelihood (no censoring).
+
+    This loss assumes the model outputs per-bin logits over (K causes + 1 complement)
+    channels, where the complement channel (index K) represents survival across bins.
+
+    Per-sample likelihood for observed cause k at time bin j:
+        p = \prod_{u=1}^{j-1} p(comp at u) * p(k at j)
+
+    Args:
+        bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.
+        eps: Unused; kept for interface compatibility / future numerical tweaks.
+        lambda_reg: Optional regularization strength.
     """

     def __init__(
@@ -146,18 +156,20 @@ class PiecewiseExponentialLoss(nn.Module):
         super().__init__()
         if len(bin_edges) < 2:
-            raise ValueError("bin_edges must have length >= 2")
-        if bin_edges[0] != 0:
-            raise ValueError("bin_edges must start at 0")
+            raise ValueError("bin_edges must have length >= 2 (n_bins >= 1)")
+        if float(bin_edges[0]) != 0.0:
+            raise ValueError("bin_edges[0] must equal 0")
         for i in range(1, len(bin_edges)):
-            if not (bin_edges[i] > bin_edges[i - 1]):
+            if not (float(bin_edges[i]) > float(bin_edges[i - 1])):
                 raise ValueError("bin_edges must be strictly increasing")
         self.eps = float(eps)
         self.lambda_reg = float(lambda_reg)

-        edges = torch.tensor(list(bin_edges), dtype=torch.float32)
-        self.register_buffer("bin_edges", edges, persistent=False)
+        self.register_buffer(
+            "bin_edges",
+            torch.tensor(bin_edges, dtype=torch.float32),
+            persistent=False,
+        )

     def forward(
         self,
@@ -166,145 +178,83 @@ class PiecewiseExponentialLoss(nn.Module):
         dt: torch.Tensor,
         reduction: str = "mean",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if logits.dim() != 3:
-            raise ValueError("logits must have shape (M, K, B)")
-        M, K, B = logits.shape
-        if self.bin_edges.numel() != B + 1:
-            raise ValueError(
-                f"bin_edges length ({self.bin_edges.numel()}) must equal B+1 ({B+1})"
-            )
-
-        device = logits.device
-        dt = dt.to(device=device, dtype=torch.float32)
-        target_events = target_events.to(device=device)
-        if target_events.dtype != torch.long:
-            target_events = target_events.to(dtype=torch.long)
-        if target_events.min().item() < 0 or target_events.max().item() >= K:
-            raise ValueError("target_events must be in [0, K)")
-
-        # Hazards: (M, K, B)
-        hazards = F.softplus(logits) + self.eps
-        total_hazard = hazards.sum(dim=1)  # (M, B)
-
-        edges = self.bin_edges.to(device=device, dtype=dt.dtype)
-        widths = edges[1:] - edges[:-1]  # (B,)
-
-        if dt.min().item() <= 0:
-            raise ValueError("dt must be strictly positive")
-        if dt.max().item() > edges[-1].item():
-            raise ValueError("dt must be <= last bin edge")
-
-        # Bin index b* in [0, B-1].
-        b_star = torch.searchsorted(edges[1:], dt, right=False)  # (M,)
-
-        # 1. Hazard at event (M,)
-        # gather needs matching dims.
-        # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,)
-        # Alternative: hazards[m, k, b]
-        ar = torch.arange(M, device=device)
-        hazard_event = hazards[ar, target_events, b_star]  # (M,)
-        hazard_event = torch.clamp(hazard_event, min=self.eps)
-
-        # 2. Integral part
-        # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
-        # Full bins accumulation
-        weighted = total_hazard * widths.unsqueeze(0)  # (M, B)
-        cum = weighted.cumsum(dim=1)  # (M, B)
-        full_bins_int = torch.zeros_like(dt)
-        # We process 'has_full' logic generally.
-        # If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional
-        has_full = b_star > 0
-        # NOTE: Even without protection, we need valid indices for gather.
-        # We use a temporary index that is safe (0) for the 'False' cases, then mask the result.
-        safe_indices = (b_star - 1).clamp(min=0)
-        gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1)
-        full_bins_int = torch.where(has_full, gathered_cum, full_bins_int)
-
-        # Partial bin accumulation
-        edge_left = edges[b_star]  # (M,)
-        partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
-        partial = partial_hazard * (dt - edge_left)
-        integral = full_bins_int + partial
-
-        # Final NLL
-        nll = -torch.log(hazard_event) + integral
-
-        # Reduction
-        if reduction == "none":
-            nll_out = nll
-        elif reduction == "sum":
-            nll_out = nll.sum()
-        elif reduction == "mean":
-            nll_out = nll.mean()
-        else:
-            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
-
-        reg = logits.new_zeros(())
-        if self.lambda_reg != 0.0:
-            reg = reg + (self.lambda_reg * logits.pow(2).mean())
-        return nll_out, reg
-
-
-class WeibullNLLLoss(nn.Module):
-    """
-    Weibull hazard in t.
-    """
-
-    def __init__(
-        self,
-        eps: float = 1e-6,
-        lambda_reg: float = 0.0,
-    ):
-        super().__init__()
-        self.eps = eps
-        self.lambda_reg = lambda_reg
-
-    def forward(self, logits, target_events, dt, reduction="mean"):
-        if logits.dim() != 3 or logits.size(-1) != 2:
-            raise ValueError("logits must have shape (M, K, 2)")
-        M, K, _ = logits.shape
-        device = logits.device
-        dt = dt.to(device=device, dtype=torch.float32)
-        if dt.min().item() <= 0:
-            raise ValueError("dt must be strictly positive")
-        target_events = target_events.to(device=device)
-        target_events = target_events.to(dtype=torch.long)
-        if target_events.min().item() < 0 or target_events.max().item() >= K:
-            raise ValueError("target_events must be in [0, K)")
-
-        shapes = F.softplus(logits[..., 0]) + self.eps
-        scales = F.softplus(logits[..., 1]) + self.eps
-        t_mat = dt.unsqueeze(1)  # (M,1)
-        cum_hazard = scales * torch.pow(t_mat, shapes)
-        hazard = shapes * scales * torch.pow(t_mat, shapes - 1.0)
-        hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
-        hazard_event = torch.clamp(hazard_event, min=self.eps)
-        nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1)
-
-        if reduction == "mean":
-            nll = nll.mean()
-        elif reduction == "sum":
-            nll = nll.sum()
-        elif reduction != "none":
-            raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
-
-        reg = shapes.new_zeros(())
-        if self.lambda_reg > 0:
-            reg = self.lambda_reg * (
-                (torch.log(scales + self.eps) ** 2).mean() +
-                (torch.log(shapes + self.eps) ** 2).mean()
-            )
+        if logits.ndim != 3:
+            raise ValueError(
+                f"logits must have ndim==3 with shape (M, K+1, n_bins+1); got {tuple(logits.shape)}"
+            )
+        if target_events.ndim != 1 or dt.ndim != 1:
+            raise ValueError(
+                f"target_events and dt must be 1D tensors; got target_events.ndim={target_events.ndim}, dt.ndim={dt.ndim}"
+            )
+        if logits.shape[0] != target_events.shape[0] or logits.shape[0] != dt.shape[0]:
+            raise ValueError(
+                "Batch size mismatch: logits.shape[0] must equal target_events.shape[0] and dt.shape[0]"
+            )
+        if reduction not in {"mean", "sum", "none"}:
+            raise ValueError("reduction must be one of {'mean','sum','none'}")
+        if not torch.all(dt > 0):
+            raise ValueError("dt must be strictly positive")
+
+        # Infer K and n_bins from logits and bin_edges.
+        m, k_plus_1, n_bins_plus_1 = logits.shape
+        k_comp = k_plus_1 - 1
+        if k_comp < 1:
+            raise ValueError(
+                "logits.shape[1] must be at least 2 (K>=1 plus complement channel)")
+        n_bins = int(self.bin_edges.numel() - 1)
+        if n_bins_plus_1 != n_bins + 1:
+            raise ValueError(
+                f"logits.shape[2] must equal n_bins+1={n_bins + 1} based on bin_edges; got {n_bins_plus_1}"
+            )
+        if target_events.dtype != torch.long:
+            target_events = target_events.to(torch.long)
+        if (target_events < 0).any() or (target_events >= k_comp).any():
+            raise ValueError(
+                f"target_events must be in [0, K-1] where K={k_comp}; got min={int(target_events.min())}, max={int(target_events.max())}"
+            )
+
+        # Map continuous dt to discrete bins j in {1..n_bins}.
+        bin_edges = self.bin_edges.to(device=dt.device, dtype=dt.dtype)
+        # (M,), may be n_bins+1 if dt > bin_edges[-1]
+        time_bin = torch.bucketize(dt, bin_edges)
+        time_bin = torch.clamp(time_bin, min=1, max=n_bins).to(
+            torch.long)  # ensure valid event bins
+
+        # Log-probabilities across causes+complement for each bin.
+        logp = F.log_softmax(logits, dim=1)  # (M, K+1, n_bins+1)
+
+        # Previous survival term: sum_{u=1}^{j-1} -log p(comp at u)
+        bins = torch.arange(n_bins + 1, device=logits.device)  # (n_bins+1,)
+        mask = (bins.unsqueeze(0) >= 1) & (bins.unsqueeze(
+            0) < time_bin.unsqueeze(1))  # (M, n_bins+1)
+        logp_comp = logp[:, k_comp, :]  # (M, n_bins+1)
+        loss_prev = -(logp_comp * mask.to(logp_comp.dtype)).sum(dim=1)  # (M,)
+
+        # Event term at bin j: -log p(k at j)
+        m_idx = torch.arange(m, device=logits.device)
+        loss_event = -logp[m_idx, target_events, time_bin]  # (M,)
+
+        loss = loss_prev + loss_event
+
+        if reduction == "mean":
+            nll = loss.mean()
+        elif reduction == "sum":
+            nll = loss.sum()
+        else:
+            nll = loss
+
+        reg = torch.zeros((), device=logits.device, dtype=loss.dtype)
+        if self.lambda_reg > 0.0:
+            # Regularize the cause distribution at the event bin using NLL on log-probs.
+            logp_causes = logp[:, :k_comp, :]  # (M, K, n_bins+1)
+            idx = time_bin.view(m, 1, 1).expand(-1, k_comp, 1)
+            logp_at_event_bin = logp_causes.gather(
+                dim=2, index=idx).squeeze(2)  # (M, K)
+            reg = self.lambda_reg * \
+                F.nll_loss(logp_at_event_bin, target_events, reduction="mean")
         return nll, reg
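
For reference, a minimal usage sketch of the new loss outside the Trainer. The sample shapes and values below are invented; the shape contract mirrors the checks in forward above, and K stands in for dataset.n_disease.

# Minimal sketch, not from the repo: exercises DiscreteTimeCIFNLLLoss with toy shapes.
# Assumes it is run from the repo root so `losses` is importable.
import torch
from losses import DiscreteTimeCIFNLLLoss

bin_edges = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')]  # n_bins = 7
K = 4   # hypothetical number of causes (dataset.n_disease)
M = 3   # number of (prev, next) event pairs

criterion = DiscreteTimeCIFNLLLoss(bin_edges=bin_edges, lambda_reg=1e-4)

# One channel per cause plus a complement channel, one column per bin edge.
logits = torch.randn(M, K + 1, len(bin_edges))   # (M, K+1, n_bins+1)
target_events = torch.tensor([0, 2, 3])          # observed cause per pair, in [0, K)
dt = torch.tensor([0.1, 2.5, 12.0])              # strictly positive waiting times

nll, reg = criterion(logits, target_events, dt, reduction="mean")
total_loss = nll + reg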

model.py (145 changed lines)

@@ -259,64 +259,26 @@ class AutoDiscretization(nn.Module):
         return emb


-class FactorizedHead(nn.Module):
+class SimpleHead(nn.Module):
     def __init__(
         self,
         n_embd: int,
-        n_disease: int,
-        n_dim: int,
-        rank: int = 16,
+        out_dims: List[int],
     ):
         super().__init__()
-        self.n_embd = n_embd
-        self.n_disease = n_disease
-        self.n_dim = n_dim
-        self.rank = rank
-
-        self.disease_base_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, n_dim),
-        )
-        self.context_mod_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, rank, bias=False),
-        )
-        self.disease_mod_proj = nn.Sequential(
-            nn.LayerNorm(n_embd),
-            nn.Linear(n_embd, rank * n_dim, bias=False),
-        )
-        self.delta_scale = nn.Parameter(torch.tensor(1e-3))
-
-        self._init_weights()
-
-    def _init_weights(self):
-        # init disease_base_proj: [LayerNorm, Linear]
-        nn.init.normal_(self.disease_base_proj[1].weight, std=0.02)
-        nn.init.zeros_(self.disease_base_proj[1].bias)
-        # init context_mod_proj: [LayerNorm, Linear(bias=False)]
-        nn.init.zeros_(self.context_mod_proj[1].weight)
-        # init disease_mod_proj: [LayerNorm, Linear(bias=False)]
-        nn.init.normal_(self.disease_mod_proj[1].weight, std=0.02)
-
-    def forward(
-        self,
-        c: torch.Tensor,  # (M, n_embd)
-        disease_embedding,  # (n_disease, n_embd)
-    ) -> torch.Tensor:
-        M = c.shape[0]
-        K = disease_embedding.shape[0]
-        assert K == self.n_disease
-
-        base_logits = self.disease_base_proj(disease_embedding)  # (K, n_dim)
-        base_logits = base_logits.unsqueeze(
-            0).expand(M, -1, -1)  # (M, K, n_dim)
-
-        u = self.context_mod_proj(c)
-        v = self.disease_mod_proj(disease_embedding)
-        v = v.view(K, self.rank, self.n_dim)
-        delta_logits = torch.einsum('mr, krd -> mkd', u, v)
-
-        return base_logits + self.delta_scale * delta_logits
+        self.out_dims = out_dims
+        total_out_dims = np.prod(out_dims)
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, n_embd),
+            nn.GELU(),
+            nn.Linear(n_embd, total_out_dims),
+            nn.LayerNorm(total_out_dims),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.net(x)
+        x = x.view(x.size(0), -1)
+        return x.view(-1, *self.out_dims)


 def _build_time_padding_mask(
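
As a quick illustration of the reshaping contract, a sketch with invented numbers (n_embd=120 matches the TrainConfig default in train.py below):

# Illustrative sketch, not from the repo: SimpleHead flattens its MLP output and
# reshapes it to (batch, *out_dims) per sample.
import torch
from model import SimpleHead

head = SimpleHead(n_embd=120, out_dims=[5, 8])  # e.g. K+1 = 5 channels, n_bins+1 = 8 columns
c = torch.randn(32, 120)                        # (M, n_embd) context vectors
theta = head(c)
print(theta.shape)                              # torch.Size([32, 5, 8])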
@@ -363,9 +325,6 @@ class DelphiFork(nn.Module):
         cate_dims: List[int],
         age_encoder_type: str = "sinusoidal",
         pdrop: float = 0.0,
-        token_pdrop: float = 0.0,
-        n_dim: int = 1,
-        rank: int = 16,
     ):
         super().__init__()
         self.vocab_size = n_disease + n_tech_tokens
@@ -373,7 +332,6 @@ class DelphiFork(nn.Module):
         self.n_disease = n_disease
         self.n_embd = n_embd
         self.n_head = n_head
-        self.n_dim = n_dim

         self.token_embedding = nn.Embedding(
             self.vocab_size, n_embd, padding_idx=0)
@@ -397,15 +355,21 @@ class DelphiFork(nn.Module):
         ])
         self.ln_f = nn.LayerNorm(n_embd)
-        self.token_dropout = nn.Dropout(token_pdrop)
-
-        # Head layers
-        self.theta_proj = FactorizedHead(
-            n_embd=n_embd,
-            n_disease=n_disease,
-            n_dim=n_dim,
-            rank=rank,
-        )
+
+    def get_disease_embedding(self) -> torch.Tensor:
+        """Get disease token embeddings for head computation.
+
+        Returns:
+            (n_disease, n_embd) tensor of disease token embeddings.
+        """
+        device = self.token_embedding.weight.device
+        disease_ids = torch.arange(
+            self.n_tech_tokens,
+            self.n_tech_tokens + self.n_disease,
+            device=device,
+        )
+        disease_embs = self.token_embedding(disease_ids)
+        return disease_embs

     def forward(
         self,
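
Note that the new SimpleHead does not consume these embeddings; the accessor is kept for head computation by callers. Purely as a hypothetical illustration of a cause-aware consumer (none of this is in the commit):

# Hypothetical sketch: score each disease by a dot product between a context
# vector and the disease-token embeddings. `model` is an instantiated backbone,
# `h` its (B, L, D) output, and b_prev/t_prev index the selected event pairs.
disease_embs = model.get_disease_embedding()   # (n_disease, n_embd)
c = h[b_prev, t_prev]                          # (M, n_embd)
scores = c @ disease_embs.T                    # (M, n_disease)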
@@ -414,8 +378,6 @@ class DelphiFork(nn.Module):
         sex: torch.Tensor,  # (B,)
         cont_seq: torch.Tensor,  # (B, Lc, n_cont)
         cate_seq: torch.Tensor,  # (B, Lc, n_cate)
-        b_prev: Optional[torch.Tensor] = None,  # (M,)
-        t_prev: Optional[torch.Tensor] = None,  # (M,)
     ) -> torch.Tensor:
         token_embds = self.token_embedding(event_seq)  # (B, L, D)
         age_embds = self.age_encoder(time_seq)  # (B, L, D)
@@ -443,24 +405,13 @@ class DelphiFork(nn.Module):
         final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
         x = final_embds + age_embds + sex_embds  # (B, L, D)
-        x = self.token_dropout(x)

         attn_mask = _build_time_padding_mask(
             event_seq, time_seq)
         for block in self.blocks:
             x = block(x, attn_mask=attn_mask)
         x = self.ln_f(x)

-        if b_prev is not None and t_prev is not None:
-            M = b_prev.numel()
-            c = x[b_prev, t_prev]  # (M, D)
-            disease_embeddings = self.token_embedding.weight[
-                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
-            ]
-            theta = self.theta_proj(c, disease_embeddings)
-            return theta
-        else:
-            return x
+        return x


 class SapDelphi(nn.Module):
@@ -477,9 +428,6 @@ class SapDelphi(nn.Module):
         cate_dims: List[int],
         age_encoder_type: str = "sinusoidal",
         pdrop: float = 0.0,
-        token_pdrop: float = 0.0,
-        n_dim: int = 1,
-        rank: int = 16,
         pretrained_weights_path: Optional[str] = None,  # new parameter
         freeze_embeddings: bool = False,  # new parameter; False (default) means fine-tune the embeddings
     ):
@@ -489,8 +437,6 @@ class SapDelphi(nn.Module):
         self.n_disease = n_disease
         self.n_embd = n_embd
         self.n_head = n_head
-        self.n_dim = n_dim
-        self.rank = rank

         if pretrained_weights_path is not None:
             print(
@@ -540,15 +486,22 @@ class SapDelphi(nn.Module):
         ])
         self.ln_f = nn.LayerNorm(n_embd)
-        self.token_dropout = nn.Dropout(token_pdrop)
-
-        # Head layers
-        self.theta_proj = FactorizedHead(
-            n_embd=n_embd,
-            n_disease=n_disease,
-            n_dim=n_dim,
-            rank=rank,
-        )
+
+    def get_disease_embedding(self) -> torch.Tensor:
+        """Get disease token embeddings for head computation.
+
+        Returns:
+            (n_disease, n_embd) tensor of disease token embeddings.
+        """
+        device = self.token_embedding.weight.device
+        disease_ids = torch.arange(
+            self.n_tech_tokens,
+            self.n_tech_tokens + self.n_disease,
+            device=device,
+        )
+        disease_embs = self.token_embedding(disease_ids)
+        disease_embs = self.emb_proj(disease_embs)
+        return disease_embs

     def forward(
         self,
@@ -557,8 +510,6 @@ class SapDelphi(nn.Module):
         sex: torch.Tensor,  # (B,)
         cont_seq: torch.Tensor,  # (B, Lc, n_cont)
         cate_seq: torch.Tensor,  # (B, Lc, n_cate)
-        b_prev: Optional[torch.Tensor] = None,  # (M,)
-        t_prev: Optional[torch.Tensor] = None,  # (M,)
     ) -> torch.Tensor:
         token_embds = self.token_embedding(event_seq)  # (B, L, Vocab_dim)
         token_embds = self.emb_proj(token_embds)  # (B, L, D)
@@ -587,22 +538,10 @@ class SapDelphi(nn.Module):
         final_embds = torch.where(mask.unsqueeze(-1), tab_inject, token_embds)
         x = final_embds + age_embds + sex_embds  # (B, L, D)
-        x = self.token_dropout(x)

         attn_mask = _build_time_padding_mask(
             event_seq, time_seq)
         for block in self.blocks:
             x = block(x, attn_mask=attn_mask)
         x = self.ln_f(x)

-        if b_prev is not None and t_prev is not None:
-            M = b_prev.numel()
-            c = x[b_prev, t_prev]  # (M, D)
-            disease_embeddings_raw = self.token_embedding.weight[
-                self.n_tech_tokens: self.n_tech_tokens + self.n_disease
-            ]  # (K, vocab_dim)
-            disease_embeddings = self.emb_proj(disease_embeddings_raw)
-            theta = self.theta_proj(c, disease_embeddings)
-            return theta
-        else:
-            return x
+        return x
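
With b_prev/t_prev removed from both backbones, context gathering and the head now live in the caller. A condensed sketch of the new flow, mirroring the Trainer changes in train.py below (the batch tensors and index vectors are placeholders):

# Sketch of the new calling convention; batch tensors and index vectors are
# assumed to exist as in the Trainer's training loop.
h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)   # (B, L, D)
c = h[b_prev, t_prev]                                           # (M, D) context per event pair
logits = head(c)                                                # (M, K+1, n_bins+1) for discrete_time_cif
nll, reg = criterion(logits, target_events, dt)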

train.py (154 changed lines)

@@ -1,5 +1,5 @@
-from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
-from model import DelphiFork, SapDelphi
+from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, get_valid_pairs_and_dt
+from model import DelphiFork, SapDelphi, SimpleHead
 from dataset import HealthDataset, health_collate_fn
 from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
@@ -22,8 +22,7 @@ from typing import Literal, Sequence
 class TrainConfig:
     # Model Parameters
     model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
-    loss_type: Literal['exponential', 'weibull',
-                       'piecewise_exponential'] = 'weibull'
+    loss_type: Literal['exponential', 'discrete_time_cif'] = 'exponential'
     age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
     full_cov: bool = False
     n_embd: int = 120
@@ -32,7 +31,8 @@ class TrainConfig:
     pdrop: float = 0.1
     lambda_reg: float = 1e-4
     bin_edges: Sequence[float] = field(
-        default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0]
+        default_factory=lambda: [0.0, 0.24, 0.72,
+                                 1.61, 3.84, 10.0, 31.0, float('inf')]
     )
     rank: int = 16
     # SapDelphi specific
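
The new trailing float('inf') edge gives n_bins = len(bin_edges) - 1 = 7, so the loss expects n_bins + 1 = 8 logit columns (hence out_dims = [n_disease + 1, len(cfg.bin_edges)] in the Trainer below), and it keeps waiting times beyond 31.0 from falling outside the last bin. A quick, illustrative check of the implied binning:

# Illustrative check of torch.bucketize on the new default edges.
import torch

edges = torch.tensor([0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float('inf')])
dt = torch.tensor([0.1, 5.0, 40.0])
print(torch.bucketize(dt, edges))   # tensor([1, 5, 7]) -> bins in 1..n_bins with n_bins = 7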
@@ -61,8 +61,12 @@ def parse_args() -> TrainConfig:
     parser = argparse.ArgumentParser(description="Train Delphi Model")
     parser.add_argument("--model_type", type=str, choices=[
                         'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
-    parser.add_argument("--loss_type", type=str, choices=[
-                        'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.")
+    parser.add_argument(
+        "--loss_type",
+        type=str,
+        choices=['exponential', 'discrete_time_cif'],
+        default='exponential',
+        help="Type of loss function to use.")
     parser.add_argument("--age_encoder", type=str, choices=[
                         'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
     parser.add_argument("--n_embd", type=int, default=120,
@@ -193,18 +197,14 @@ class Trainer:
             self.criterion = ExponentialNLLLoss(
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            n_dim = 1
-        elif cfg.loss_type == "piecewise_exponential":
-            self.criterion = PiecewiseExponentialLoss(
+            out_dims = [dataset.n_disease]
+        elif cfg.loss_type == "discrete_time_cif":
+            self.criterion = DiscreteTimeCIFNLLLoss(
                 bin_edges=cfg.bin_edges,
                 lambda_reg=cfg.lambda_reg,
             ).to(self.device)
-            n_dim = len(cfg.bin_edges) - 1
-        elif cfg.loss_type == "weibull":
-            self.criterion = WeibullNLLLoss(
-                lambda_reg=cfg.lambda_reg,
-            ).to(self.device)
-            n_dim = 2
+            # logits shape (M, K+1, n_bins+1)
+            out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
         else:
             raise ValueError(f"Unsupported loss type: {cfg.loss_type}")
@@ -217,8 +217,6 @@ class Trainer:
                 n_layer=cfg.n_layer,
                 pdrop=cfg.pdrop,
                 age_encoder_type=cfg.age_encoder,
-                n_dim=n_dim,
-                rank=cfg.rank,
                 n_cont=dataset.n_cont,
                 n_cate=dataset.n_cate,
                 cate_dims=dataset.cate_dims,
@@ -232,8 +230,6 @@ class Trainer:
                 n_layer=cfg.n_layer,
                 pdrop=cfg.pdrop,
                 age_encoder_type=cfg.age_encoder,
-                n_dim=n_dim,
-                rank=cfg.rank,
                 n_cont=dataset.n_cont,
                 n_cate=dataset.n_cate,
                 cate_dims=dataset.cate_dims,
@@ -242,10 +238,25 @@ class Trainer:
             ).to(self.device)
         else:
             raise ValueError(f"Unsupported model type: {cfg.model_type}")
+
+        # Prediction head maps context vectors -> logits with the shape required by the loss.
+        self.head = SimpleHead(
+            n_embd=cfg.n_embd,
+            out_dims=out_dims,
+        ).to(self.device)
+
         print(f"Model initialized: {cfg.model_type}")
-        print(f"Number of trainable parameters: {get_num_params(self.model)}")
+        print(
+            f"Number of trainable parameters (backbone): {get_num_params(self.model)}")
+        print(
+            f"Number of trainable parameters (head): {get_num_params(self.head)}")
+
+        self._optim_params = (
+            list(self.model.parameters())
+            + list(self.head.parameters())
+        )
         self.optimizer = AdamW(
-            self.model.parameters(),
+            self._optim_params,
             lr=cfg.max_lr,
             weight_decay=cfg.weight_decay,
             betas=(0.9, 0.99),
@@ -293,23 +304,11 @@ class Trainer:
         best_val_score = float('inf')
         patience_counter = 0
         for epoch in range(1, self.cfg.max_epochs + 1):
-            model_for_logging = self.model.module if hasattr(
-                self.model, "module") else self.model
-            delta_scale = None
-            theta_proj = getattr(model_for_logging, "theta_proj", None)
-            if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
-                try:
-                    delta_scale = float(
-                        theta_proj.delta_scale.detach().cpu().item())
-                except Exception:
-                    delta_scale = None
             self.model.train()
+            self.head.train()
             total_train_pairs = 0
             total_train_nll = 0.0
             total_train_reg = 0.0
-            total_train_log_scale_sq = 0.0
-            total_train_log_shape_sq = 0.0
             pbar = tqdm(self.train_loader,
                         desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
             batch_count = 0
@@ -334,25 +333,17 @@ class Trainer:
                 self.optimizer.zero_grad()
                 lr = self.compute_lr(self.global_step)
                 self.optimizer.param_groups[0]['lr'] = lr
-                logits = self.model(
+                h = self.model(
                     event_seq,
                     time_seq,
                     sexes,
                     cont_feats,
                     cate_feats,
-                    b_prev=b_prev,
-                    t_prev=t_prev,
                 )
-                if isinstance(self.criterion, WeibullNLLLoss):
-                    eps = float(self.criterion.eps)
-                    shapes = torch.nn.functional.softplus(logits[..., 0]) + eps
-                    scales = torch.nn.functional.softplus(logits[..., 1]) + eps
-                    log_scale_sq = (torch.log(scales + eps) ** 2).mean()
-                    log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
-                else:
-                    log_scale_sq = None
-                    log_shape_sq = None
+                # Context vectors for selected previous events
+                c = h[b_prev, t_prev]  # (M, D)
+                logits = self.head(c)

                 target_event = event_seq[b_next, t_next] - 2
                 nll_vec, reg = self.criterion(
@@ -367,10 +358,6 @@ class Trainer:
                 total_train_pairs += num_pairs
                 total_train_nll += nll_vec.sum().item()
                 total_train_reg += reg.item() * num_pairs
-                if log_scale_sq is not None:
-                    total_train_log_scale_sq += log_scale_sq.item() * num_pairs
-                if log_shape_sq is not None:
-                    total_train_log_shape_sq += log_shape_sq.item() * num_pairs
                 avg_train_nll = total_train_nll / total_train_pairs
                 avg_train_reg = total_train_reg / total_train_pairs
                 pbar.set_postfix({
@@ -380,8 +367,7 @@ class Trainer:
                 })
                 loss.backward()
                 if self.cfg.grad_clip > 0:
-                    clip_grad_norm_(self.model.parameters(),
-                                    self.cfg.grad_clip)
+                    clip_grad_norm_(self._optim_params, self.cfg.grad_clip)
                 self.optimizer.step()
                 self.global_step += 1
@@ -391,23 +377,12 @@ class Trainer:
             train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
             train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
-            train_log_scale_sq = (
-                total_train_log_scale_sq / total_train_pairs
-                if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )
-            train_log_shape_sq = (
-                total_train_log_shape_sq / total_train_pairs
-                if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )

             self.model.eval()
+            self.head.eval()
             total_val_pairs = 0
             total_val_nll = 0.0
             total_val_reg = 0.0
-            total_val_log_scale_sq = 0.0
-            total_val_log_shape_sq = 0.0
             with torch.no_grad():
                 val_pbar = tqdm(self.val_loader, desc="Validation")
                 for batch in val_pbar:
@@ -428,27 +403,16 @@ class Trainer:
                         continue
                     dt, b_prev, t_prev, b_next, t_next = res
                     num_pairs = dt.size(0)
-                    logits = self.model(
+                    h = self.model(
                         event_seq,
                         time_seq,
                         sexes,
                         cont_feats,
                         cate_feats,
-                        b_prev=b_prev,
-                        t_prev=t_prev
                     )
-                    if isinstance(self.criterion, WeibullNLLLoss):
-                        eps = float(self.criterion.eps)
-                        shapes = torch.nn.functional.softplus(
-                            logits[..., 0]) + eps
-                        scales = torch.nn.functional.softplus(
-                            logits[..., 1]) + eps
-                        log_scale_sq = (torch.log(scales + eps) ** 2).mean()
-                        log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
-                    else:
-                        log_scale_sq = None
-                        log_shape_sq = None
+                    c = h[b_prev, t_prev]
+                    logits = self.head(c)

                     target_events = event_seq[b_next, t_next] - 2
                     nll, reg = self.criterion(
@@ -460,10 +424,6 @@ class Trainer:
                     batch_nll_sum = nll.sum().item()
                     total_val_nll += batch_nll_sum
                     total_val_reg += reg.item() * num_pairs
-                    if log_scale_sq is not None:
-                        total_val_log_scale_sq += log_scale_sq.item() * num_pairs
-                    if log_shape_sq is not None:
-                        total_val_log_shape_sq += log_shape_sq.item() * num_pairs
                     total_val_pairs += num_pairs
                     current_val_avg_nll = total_val_nll / \
@@ -478,16 +438,6 @@ class Trainer:
             val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
             val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
-            val_log_scale_sq = (
-                total_val_log_scale_sq / total_val_pairs
-                if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )
-            val_log_shape_sq = (
-                total_val_log_shape_sq / total_val_pairs
-                if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
-                else None
-            )

             history.append({
                 "epoch": epoch,
@@ -495,11 +445,6 @@ class Trainer:
                 "train_reg": train_reg,
                 "val_nll": val_nll,
                 "val_reg": val_reg,
-                "delta_scale": delta_scale,
-                "train_log_scale_sq": train_log_scale_sq,
-                "train_log_shape_sq": train_log_shape_sq,
-                "val_log_scale_sq": val_log_scale_sq,
-                "val_log_shape_sq": val_log_shape_sq,
             })

             tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -507,18 +452,6 @@ class Trainer:
             tqdm.write(f"  Train Reg: {train_reg:.4f}")
             tqdm.write(f"  Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
             tqdm.write(f"  Val Reg: {val_reg:.4f}")
-            if delta_scale is not None:
-                tqdm.write(f"  Delta scale: {delta_scale:.6g}")
-            if train_log_scale_sq is not None and train_log_shape_sq is not None:
-                tqdm.write(
-                    f"  Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}")
-                tqdm.write(
-                    f"  Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}")
-            if val_log_scale_sq is not None and val_log_shape_sq is not None:
-                tqdm.write(
-                    f"  Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}")
-                tqdm.write(
-                    f"  Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}")

             with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
                 json.dump(history, f, indent=4)
@@ -533,6 +466,7 @@ class Trainer:
                 "epoch": epoch,
                 "global_step": self.global_step,
                 "model_state_dict": self.model.state_dict(),
+                "head_state_dict": self.head.state_dict(),
                 "criterion_state_dict": self.criterion.state_dict(),
                 "optimizer_state_dict": self.optimizer.state_dict(),
             }, self.best_path)
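
Since the checkpoint now stores the head separately, loading must restore both modules. A hedged restore sketch: the path is a placeholder, the keys match the torch.save call above, and model, head, and criterion are assumed to be constructed exactly as in Trainer.__init__.

# Restore sketch for the checkpoint written above; "out/best.pt" is a placeholder path.
import torch

ckpt = torch.load("out/best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
head.load_state_dict(ckpt["head_state_dict"])
criterion.load_state_dict(ckpt["criterion_state_dict"])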