diff --git a/losses.py b/losses.py
index 3b54c49..acd0847 100644
--- a/losses.py
+++ b/losses.py
@@ -133,18 +133,13 @@ class ExponentialNLLLoss(nn.Module):
 class PiecewiseExponentialLoss(nn.Module):
-    """Piecewise-constant competing risks exponential likelihood.
+    """
+    Piecewise-constant competing risks exponential likelihood.
 
-    Uses B time bins defined by `bin_edges` (length B+1, strictly increasing, starting at 0).
-    Within each bin b, hazards are constant and parameterized as:
-
-        hazards = softplus(logits) + eps with logits shape (M, K, B)
-
-    For each sample i, dt is bucketized to bin b* and the NLL is:
-
-        nll_i = -log(hazard_{k*}(b*)) + \int_0^{dt} sum_k hazard_k(u) du
-
-    The integral is computed in closed form by summing full bins plus the partial bin b*.
+    Lightweight numerical protections:
+    - Does NOT mask/skip any samples.
+    - Uses nan_to_num for dt/logits/targets to avoid NaN/Inf propagation.
+    - Clamps logits and dt to keep softplus/log operations finite.
     """
 
     def __init__(
@@ -152,6 +147,7 @@ class PiecewiseExponentialLoss(nn.Module):
         bin_edges: Sequence[float],
         eps: float = 1e-6,
         lambda_reg: float = 0.0,
+        logit_clip: float = 30.0,
     ):
         super().__init__()
 
@@ -165,6 +161,7 @@ class PiecewiseExponentialLoss(nn.Module):
 
         self.eps = float(eps)
         self.lambda_reg = float(lambda_reg)
+        self.logit_clip = float(logit_clip)
 
         edges = torch.tensor(list(bin_edges), dtype=torch.float32)
         self.register_buffer("bin_edges", edges, persistent=False)
@@ -186,75 +183,87 @@ class PiecewiseExponentialLoss(nn.Module):
             )
 
         device = logits.device
-        dt = dt.to(device=device)
+        dt = dt.to(device=device, dtype=torch.float32)
         target_events = target_events.to(device=device)
 
-        # Build a per-sample finite mask to avoid NaN/Inf propagation.
-        logits_finite = torch.isfinite(logits).view(M, -1).all(dim=1)
-        dt_finite = torch.isfinite(dt)
-        target_finite = torch.isfinite(target_events)
-        finite_mask = logits_finite & dt_finite & target_finite
+        # No masking/skipping: coerce invalid values to safe defaults.
+        logits_v = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
+        logits_v = torch.clamp(
+            logits_v, min=-self.logit_clip, max=self.logit_clip)
 
-        nll_full = logits.new_zeros((M,))
+        dt_v = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
+        target_v = torch.nan_to_num(
+            target_events, nan=0.0, posinf=0.0, neginf=0.0)
+        target_v = target_v.to(dtype=torch.long)
+        target_v = torch.clamp(target_v, min=0, max=K - 1)
 
-        if not finite_mask.any():
-            nll_out = nll_full if reduction == "none" else logits.new_zeros(())
-            reg_out = logits.new_zeros(())
-            return nll_out, reg_out
-
-        idx = finite_mask.nonzero(as_tuple=False).squeeze(1)
-        logits_v = logits[idx]
-        target_v = target_events[idx].to(torch.long)
-        dt_v = dt[idx].to(torch.float32)
-
-        # Clamp dt into [eps, max_edge) to keep bucket indices valid.
+        # Keep structural clamping to prevent index-out-of-bounds errors
+        # (Necessary for searchsorted/gather to work at all)
         eps = self.eps
         max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
         dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
         dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
 
-        hazards = F.softplus(logits_v) + eps  # (Mv, K, B)
-        total_hazard = hazards.sum(dim=1)  # (Mv, B)
+        # Hazards: (M, K, B)
+        hazards = F.softplus(logits_v) + eps
+        hazards = torch.clamp(hazards, min=eps)
+        total_hazard = hazards.sum(dim=1)  # (M, B)
 
         edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
         widths = edges[1:] - edges[:-1]  # (B,)
 
-        # Bin index b* in [0, B-1]. boundaries are edges[1:] (length B).
-        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (Mv,)
+        # Bin index b* in [0, B-1].
+        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (M,)
         b_star = torch.clamp(b_star, min=0, max=B - 1)
 
-        ar = torch.arange(logits_v.size(0), device=device)
-        hazard_event = hazards[ar, target_v, b_star]  # (Mv,)
+        # 1. Hazard of the target event in its bin, shape (M,).
+        # Advanced indexing selects hazards[i, target_v[i], b_star[i]] for
+        # each sample i:
+        # hazards (M, K, B) -> pick event k* -> pick bin b* -> (M,)
+        ar = torch.arange(M, device=device)
+        hazard_event = hazards[ar, target_v, b_star]  # (M,)
+        hazard_event = torch.clamp(hazard_event, min=eps)
 
+        # 2. Integral of the total hazard over [0, dt]
         # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
-        weighted = total_hazard * widths.unsqueeze(0)  # (Mv, B)
-        cum = weighted.cumsum(dim=1)  # (Mv, B)
-        full_bins_int = torch.zeros_like(dt_v)
-        has_full = b_star > 0
-        if has_full.any():
-            full_bins_int[has_full] = cum.gather(
-                1, (b_star[has_full] - 1).unsqueeze(1)
-            ).squeeze(1)
-        edge_left = edges[b_star]  # (Mv,)
-        partial = total_hazard.gather(
-            1, b_star.unsqueeze(1)).squeeze(1) * (dt_v - edge_left)
+        # Full bins accumulation
+        weighted = total_hazard * widths.unsqueeze(0)  # (M, B)
+        cum = weighted.cumsum(dim=1)  # (M, B)
+
+        full_bins_int = torch.zeros_like(dt_v)
+
+        # Samples with b_star == 0 have no full bins; gathering at index -1
+        # would wrap around, so use a masked gather instead of branching.
+        has_full = b_star > 0
+
+        # Gather at a clamped (always valid) index, then zero out the result
+        # for samples that have no full bins.
+        safe_indices = (b_star - 1).clamp(min=0)
+        gathered_cum = cum.gather(1, safe_indices.unsqueeze(1)).squeeze(1)
+        full_bins_int = torch.where(has_full, gathered_cum, full_bins_int)
+
+        # Partial bin accumulation
+        edge_left = edges[b_star]  # (M,)
+        partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
+        partial = partial_hazard * (dt_v - edge_left)
+
         integral = full_bins_int + partial
 
-        nll_v = -torch.log(hazard_event) + integral
-        nll_full[idx] = nll_v
+        # Final NLL
+        nll = -torch.log(hazard_event) + integral
 
+        # Reduction
         if reduction == "none":
-            nll_out = nll_full
+            nll_out = nll
         elif reduction == "sum":
-            nll_out = nll_v.sum()
+            nll_out = nll.sum()
        elif reduction == "mean":
-            nll_out = nll_v.mean() if nll_v.numel() > 0 else logits.new_zeros(())
+            nll_out = nll.mean()
         else:
             raise ValueError("reduction must be one of: 'mean', 'sum', 'none'")
 
         reg = logits.new_zeros(())
-
         if self.lambda_reg != 0.0:
             reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
 
@@ -263,77 +272,83 @@ class PiecewiseExponentialLoss(nn.Module):
 
 class WeibullNLLLoss(nn.Module):
     """
-    Weibull hazard in t.
+    Weibull hazard in t with lightweight numerical protections.
 
-    .. math::
-        \\Lambda_k(t) = \\text{scale}_k \\cdot t^{\\text{shape}_k}
-
-        \\lambda_k(t) = \\text{shape}_k \\cdot \\text{scale}_k \\cdot t^{\\text{shape}_k-1}
-
-    Args:
-        eps (float): Small epsilon for numerical stability.
-        lambda_reg (float): Regularization weight.
-        use_interval_near_integer (bool): If True, use interval likelihood for near-integer-year samples.
-        near_integer_eps_years (float): Near-integer threshold in years.
-        interval_half_width_years (float): Half-width Δ for interval [t-Δ, t+Δ] in years.
-        min_integer_year (float): Only apply near-integer logic when round(t) >= min_integer_year.
+    Does NOT mask/skip any samples. Instead:
+    - nan_to_num for logits/dt/targets
+    - clamps logits to keep softplus outputs reasonable
+    - computes t^shape in log-space with clamped exponent to prevent overflow
     """
 
     def __init__(
         self,
         eps: float = 1e-6,
         lambda_reg: float = 0.0,
+        logit_clip: float = 30.0,
+        max_shape: float = 30.0,
+        max_dt: float = 1.0e3,
+        max_exp: float = 80.0,
     ):
         super().__init__()
         self.eps = eps
         self.lambda_reg = lambda_reg
+        self.logit_clip = float(logit_clip)
+        self.max_shape = float(max_shape)
+        self.max_dt = float(max_dt)
+        self.max_exp = float(max_exp)
 
-    def forward(
-        self,
-        logits: torch.Tensor,
-        target_events: torch.Tensor,
-        dt: torch.Tensor,
-        reduction: str = "mean",
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Forward pass.
+    def forward(self, logits, target_events, dt, reduction="mean"):
+        if logits.dim() != 3 or logits.size(-1) != 2:
+            raise ValueError("logits must have shape (M, K, 2)")
 
-        Args:
-            logits (torch.Tensor): (M, K, 2) tensor of logits.
-            target_events (torch.Tensor): (M,) tensor of target events.
-            dt (torch.Tensor): (M,) tensor of time intervals.
-            reduction (str): 'mean', 'sum', or 'none'.
+        M, K, _ = logits.shape
+        device = logits.device
 
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: (nll, regularization).
-        """
-        shapes = F.softplus(logits[..., 0]) + self.eps  # (M,K)
-        scales = F.softplus(logits[..., 1]) + self.eps  # (M,K)
-        eps = self.eps
-        t = torch.clamp(dt, min=eps)
+        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
+        logits = torch.clamp(logits, min=-self.logit_clip, max=self.logit_clip)
 
-        t_mat = t.unsqueeze(1)  # (M,1)
+        dt = dt.to(device=device, dtype=torch.float32)
+        dt = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
+        dt = torch.clamp(dt, min=self.eps, max=self.max_dt)
 
-        # cumulative hazard (M,K)
-        cum_hazard = scales * t_mat.pow(shapes)
+        target_events = target_events.to(device=device)
+        target_events = torch.nan_to_num(
+            target_events, nan=0.0, posinf=0.0, neginf=0.0)
+        target_events = target_events.to(dtype=torch.long)
+        target_events = torch.clamp(target_events, min=0, max=K - 1)
 
-        # hazard (M,K)
-        hazard = shapes * scales * t_mat.pow(shapes - 1.0)
+        shapes = F.softplus(logits[..., 0]) + self.eps
+        scales = F.softplus(logits[..., 1]) + self.eps
+        shapes = torch.clamp(shapes, min=self.eps, max=self.max_shape)
+        scales = torch.clamp(scales, min=self.eps)
 
+        t_mat = dt.unsqueeze(1)  # (M,1)
+        log_t = torch.log(torch.clamp(t_mat, min=self.eps))
+
+        # Compute t^shape and t^(shape-1) in log-space with exponent clamp.
+ pow_shape = torch.exp(torch.clamp(shapes * log_t, max=self.max_exp)) + pow_shape_minus_1 = torch.exp( + torch.clamp((shapes - 1.0) * log_t, max=self.max_exp) + ) + + cum_hazard = scales * pow_shape + hazard = shapes * scales * pow_shape_minus_1 hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1) - # Point-event likelihood: f_k(t) = \lambda_k(t) * exp(-\Lambda_total(t)) - # NLL_point = -log \lambda_{k*}(t) + \Lambda_total(t) - nll = -torch.log(hazard_event + eps) + cum_hazard.sum(dim=1) + hazard_event = torch.clamp(hazard_event, min=self.eps) + + nll = -torch.log(hazard_event) + cum_hazard.sum(dim=1) if reduction == "mean": nll = nll.mean() elif reduction == "sum": nll = nll.sum() + elif reduction != "none": + raise ValueError("reduction must be one of: 'mean', 'sum', 'none'") reg = shapes.new_zeros(()) if self.lambda_reg > 0: reg = self.lambda_reg * ( - (torch.log(scales + eps) ** 2).mean() + - (torch.log(shapes + eps) ** 2).mean() + (torch.log(scales + self.eps) ** 2).mean() + + (torch.log(shapes + self.eps) ** 2).mean() ) return nll, reg
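
For context on how the two changed modules are exercised, here is a minimal smoke-test sketch. It is illustrative only: the `from losses import ...` path and all tensor values below are assumptions; only the constructor arguments and forward signatures are taken from the diff above.

# Hypothetical smoke test for the classes changed in this diff (not part of losses.py).
import torch

from losses import PiecewiseExponentialLoss, WeibullNLLLoss  # assumed import path

M, K, B = 8, 3, 4                          # samples, competing risks, time bins
bin_edges = [0.0, 1.0, 2.0, 5.0, 10.0]     # length B + 1, starting at 0

target_events = torch.randint(0, K, (M,))  # event index per sample, in [0, K)
dt = torch.rand(M) * 10.0                  # time-to-event per sample

# Piecewise-constant hazards: logits of shape (M, K, B).
pw_loss = PiecewiseExponentialLoss(bin_edges, eps=1e-6, logit_clip=30.0)
pw_nll, pw_reg = pw_loss(torch.randn(M, K, B), target_events, dt, reduction="mean")

# Weibull hazards: logits of shape (M, K, 2), channel 0 -> shape, channel 1 -> scale.
wb_loss = WeibullNLLLoss(eps=1e-6, lambda_reg=0.01, max_shape=30.0)
wb_nll, wb_reg = wb_loss(torch.randn(M, K, 2), target_events, dt, reduction="mean")

print(pw_nll.item() + pw_reg.item(), wb_nll.item() + wb_reg.item())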