diff --git a/losses.py b/losses.py
index acd0847..3ae8976 100644
--- a/losses.py
+++ b/losses.py
@@ -135,11 +135,6 @@ class ExponentialNLLLoss(nn.Module):
 class PiecewiseExponentialLoss(nn.Module):
     """
     Piecewise-constant competing risks exponential likelihood.
-
-    Lightweight numerical protections:
-    - Does NOT mask/skip any samples.
-    - Uses nan_to_num for dt/logits/targets to avoid NaN/Inf propagation.
-    - Clamps logits and dt to keep softplus/log operations finite.
     """
 
     def __init__(
@@ -147,7 +142,6 @@ class PiecewiseExponentialLoss(nn.Module):
         bin_edges: Sequence[float],
         eps: float = 1e-6,
         lambda_reg: float = 0.0,
-        logit_clip: float = 30.0,
     ):
         super().__init__()
 
@@ -161,7 +155,6 @@ class PiecewiseExponentialLoss(nn.Module):
         self.eps = float(eps)
         self.lambda_reg = float(lambda_reg)
-        self.logit_clip = float(logit_clip)
 
         edges = torch.tensor(list(bin_edges), dtype=torch.float32)
         self.register_buffer("bin_edges", edges, persistent=False)
@@ -186,43 +179,33 @@
         dt = dt.to(device=device, dtype=torch.float32)
         target_events = target_events.to(device=device)
 
-        # No masking/skipping: coerce invalid values to safe defaults.
-        logits_v = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
-        logits_v = torch.clamp(
-            logits_v, min=-self.logit_clip, max=self.logit_clip)
-
-        dt_v = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
-        target_v = torch.nan_to_num(
-            target_events, nan=0.0, posinf=0.0, neginf=0.0)
-        target_v = target_v.to(dtype=torch.long)
-        target_v = torch.clamp(target_v, min=0, max=K - 1)
-
-        # Keep structural clamping to prevent index-out-of-bounds errors
-        # (Necessary for searchsorted/gather to work at all)
-        eps = self.eps
-        max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
-        dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
-        dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
+        if target_events.dtype != torch.long:
+            target_events = target_events.to(dtype=torch.long)
+        if target_events.min().item() < 0 or target_events.max().item() >= K:
+            raise ValueError("target_events must be in [0, K)")
 
         # Hazards: (M, K, B)
-        hazards = F.softplus(logits_v) + eps
-        hazards = torch.clamp(hazards, min=eps)
+        hazards = F.softplus(logits) + self.eps
         total_hazard = hazards.sum(dim=1)  # (M, B)
 
-        edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
+        edges = self.bin_edges.to(device=device, dtype=dt.dtype)
         widths = edges[1:] - edges[:-1]  # (B,)
 
+        if dt.min().item() <= 0:
+            raise ValueError("dt must be strictly positive")
+        if dt.max().item() > edges[-1].item():
+            raise ValueError("dt must be <= last bin edge")
+
         # Bin index b* in [0, B-1].
-        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (M,)
-        b_star = torch.clamp(b_star, min=0, max=B - 1)
+        b_star = torch.searchsorted(edges[1:], dt, right=False)  # (M,)
 
         # 1. Hazard at event (M,)
         # gather needs matching dims.
         # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,)
         # Alternative: hazards[m, k, b]
         ar = torch.arange(M, device=device)
-        hazard_event = hazards[ar, target_v, b_star]  # (M,)
-        hazard_event = torch.clamp(hazard_event, min=eps)
+        hazard_event = hazards[ar, target_events, b_star]  # (M,)
+        hazard_event = torch.clamp(hazard_event, min=self.eps)
 
         # 2. Integral part
         # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
@@ -231,7 +214,7 @@ class PiecewiseExponentialLoss(nn.Module):
         weighted = total_hazard * widths.unsqueeze(0)  # (M, B)
         cum = weighted.cumsum(dim=1)  # (M, B)
 
-        full_bins_int = torch.zeros_like(dt_v)
+        full_bins_int = torch.zeros_like(dt)
 
         # We process 'has_full' logic generally.
         # If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional
@@ -246,7 +229,7 @@ class PiecewiseExponentialLoss(nn.Module):
         # Partial bin accumulation
         edge_left = edges[b_star]  # (M,)
         partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
-        partial = partial_hazard * (dt_v - edge_left)
+        partial = partial_hazard * (dt - edge_left)
 
         integral = full_bins_int + partial
 
@@ -265,37 +248,24 @@ class PiecewiseExponentialLoss(nn.Module):
         reg = logits.new_zeros(())
         if self.lambda_reg != 0.0:
-            reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
+            reg = reg + (self.lambda_reg * logits.pow(2).mean())
 
         return nll_out, reg
 
 
 class WeibullNLLLoss(nn.Module):
     """
-    Weibull hazard in t with lightweight numerical protections.
-    Does NOT mask/skip any samples. Instead:
-    - nan_to_num for logits/dt/targets
-    - clamps logits to keep softplus outputs reasonable
-    - computes t^shape in log-space with clamped exponent to prevent overflow
+    Weibull hazard in t.
     """
 
     def __init__(
         self,
        eps: float = 1e-6,
        lambda_reg: float = 0.0,
-        logit_clip: float = 30.0,
-        max_shape: float = 30.0,
-        max_dt: float = 1.0e3,
-        max_exp: float = 80.0,
     ):
         super().__init__()
         self.eps = eps
         self.lambda_reg = lambda_reg
-        self.logit_clip = float(logit_clip)
-        self.max_shape = float(max_shape)
-        self.max_dt = float(max_dt)
-        self.max_exp = float(max_exp)
 
     def forward(self, logits, target_events, dt, reduction="mean"):
         if logits.dim() != 3 or logits.size(-1) != 2:
@@ -304,35 +274,21 @@ class WeibullNLLLoss(nn.Module):
         M, K, _ = logits.shape
         device = logits.device
 
-        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
-        logits = torch.clamp(logits, min=-self.logit_clip, max=self.logit_clip)
-
         dt = dt.to(device=device, dtype=torch.float32)
-        dt = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
-        dt = torch.clamp(dt, min=self.eps, max=self.max_dt)
+        if dt.min().item() <= 0:
+            raise ValueError("dt must be strictly positive")
 
         target_events = target_events.to(device=device)
-        target_events = torch.nan_to_num(
-            target_events, nan=0.0, posinf=0.0, neginf=0.0)
         target_events = target_events.to(dtype=torch.long)
-        target_events = torch.clamp(target_events, min=0, max=K - 1)
+        if target_events.min().item() < 0 or target_events.max().item() >= K:
+            raise ValueError("target_events must be in [0, K)")
 
         shapes = F.softplus(logits[..., 0]) + self.eps
         scales = F.softplus(logits[..., 1]) + self.eps
-        shapes = torch.clamp(shapes, min=self.eps, max=self.max_shape)
-        scales = torch.clamp(scales, min=self.eps)
 
         t_mat = dt.unsqueeze(1)  # (M,1)
-        log_t = torch.log(torch.clamp(t_mat, min=self.eps))
-
-        # Compute t^shape and t^(shape-1) in log-space with exponent clamp.
-        pow_shape = torch.exp(torch.clamp(shapes * log_t, max=self.max_exp))
-        pow_shape_minus_1 = torch.exp(
-            torch.clamp((shapes - 1.0) * log_t, max=self.max_exp)
-        )
-
-        cum_hazard = scales * pow_shape
-        hazard = shapes * scales * pow_shape_minus_1
+        cum_hazard = scales * torch.pow(t_mat, shapes)
+        hazard = shapes * scales * torch.pow(t_mat, shapes - 1.0)
 
         hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
         hazard_event = torch.clamp(hazard_event, min=self.eps)
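Note: with the nan_to_num/clamp paths removed, both losses now raise ValueError on out-of-range inputs instead of silently repairing them, so any NaN/Inf or range filtering has to happen in the caller. A minimal sketch of such a guard for PiecewiseExponentialLoss inputs, under the post-change contract; for WeibullNLLLoss the last_edge bound would be dropped. The helper name filter_valid and its signature are illustrative, not part of losses.py.

    import torch

    def filter_valid(logits, target_events, dt, last_edge, num_events):
        # Keep only rows that satisfy the losses' new preconditions:
        # finite inputs, 0 < dt <= last bin edge, labels in [0, num_events).
        finite = torch.isfinite(dt) & torch.isfinite(logits).flatten(1).all(dim=1)
        in_time = (dt > 0) & (dt <= last_edge)
        in_label = (target_events >= 0) & (target_events < num_events)
        keep = finite & in_time & in_label
        return logits[keep], target_events[keep], dt[keep]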
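For reference, the removed log-space path in WeibullNLLLoss existed to keep t^shape finite in float32; this diff computes torch.pow(t_mat, shapes) directly. A tiny illustrative comparison of the two forms (the t/shape/scale values are made up; 80.0 is the old max_exp default):

    import torch

    t = torch.tensor([2.0, 200.0])
    shape = torch.tensor([1.5, 25.0])
    scale = torch.tensor([0.1, 0.1])

    # Direct power, as in this diff.
    new_cum_hazard = scale * torch.pow(t, shape)
    # Removed clamped log-space form.
    old_cum_hazard = scale * torch.exp(torch.clamp(shape * torch.log(t), max=80.0))

    print(new_cum_hazard)  # second entry overflows to inf in float32
    print(old_cum_hazard)  # exponent clamp keeps it finite (at the cost of bias)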