Refactor PiecewiseExponentialLoss and WeibullNLLLoss: remove lightweight numerical protections and improve error handling for input validation
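The removed nan_to_num/clamp guards are replaced by explicit checks that raise ValueError when dt is non-positive (or exceeds the last bin edge in the piecewise loss) or when target_events falls outside [0, K). A minimal sketch of what callers now see; the import path and tensor values are illustrative only, and the happy-path call assumes the rest of forward (not shown in this diff) completes as before:

```python
import torch
from losses import WeibullNLLLoss  # assumes losses.py is importable

loss_fn = WeibullNLLLoss(eps=1e-6, lambda_reg=0.0)

M, K = 4, 3                        # samples, competing risks
logits = torch.randn(M, K, 2)      # (..., 0) -> shape, (..., 1) -> scale
target_events = torch.tensor([0, 2, 1, 0])   # long, in [0, K)
dt = torch.rand(M) + 0.1           # strictly positive durations

loss_fn(logits, target_events, dt)           # passes the new validation checks

bad_dt = dt.clone()
bad_dt[0] = 0.0                              # previously clamped to eps, now rejected
try:
    loss_fn(logits, target_events, bad_dt)
except ValueError as e:
    print(e)                                 # "dt must be strictly positive"
```

Previously such inputs were silently coerced (dt clamped to eps, targets clamped into [0, K-1]); they now fail fast at the loss boundary.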
losses.py
@@ -135,11 +135,6 @@ class ExponentialNLLLoss(nn.Module):
 class PiecewiseExponentialLoss(nn.Module):
     """
     Piecewise-constant competing risks exponential likelihood.
-
-    Lightweight numerical protections:
-    - Does NOT mask/skip any samples.
-    - Uses nan_to_num for dt/logits/targets to avoid NaN/Inf propagation.
-    - Clamps logits and dt to keep softplus/log operations finite.
     """

     def __init__(
@@ -147,7 +142,6 @@ class PiecewiseExponentialLoss(nn.Module):
         bin_edges: Sequence[float],
         eps: float = 1e-6,
         lambda_reg: float = 0.0,
-        logit_clip: float = 30.0,
     ):
         super().__init__()

@@ -161,7 +155,6 @@ class PiecewiseExponentialLoss(nn.Module):

         self.eps = float(eps)
         self.lambda_reg = float(lambda_reg)
-        self.logit_clip = float(logit_clip)

         edges = torch.tensor(list(bin_edges), dtype=torch.float32)
         self.register_buffer("bin_edges", edges, persistent=False)
@@ -186,43 +179,33 @@ class PiecewiseExponentialLoss(nn.Module):
         dt = dt.to(device=device, dtype=torch.float32)
         target_events = target_events.to(device=device)

-        # No masking/skipping: coerce invalid values to safe defaults.
-        logits_v = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
-        logits_v = torch.clamp(
-            logits_v, min=-self.logit_clip, max=self.logit_clip)
+        if target_events.dtype != torch.long:
+            target_events = target_events.to(dtype=torch.long)
+        if target_events.min().item() < 0 or target_events.max().item() >= K:
+            raise ValueError("target_events must be in [0, K)")

-        dt_v = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
-        target_v = torch.nan_to_num(
-            target_events, nan=0.0, posinf=0.0, neginf=0.0)
-        target_v = target_v.to(dtype=torch.long)
-        target_v = torch.clamp(target_v, min=0, max=K - 1)
-
-        # Keep structural clamping to prevent index-out-of-bounds errors
-        # (Necessary for searchsorted/gather to work at all)
-        eps = self.eps
-        max_edge = self.bin_edges[-1].to(device=device, dtype=dt_v.dtype)
-        dt_max = torch.nextafter(max_edge, max_edge.new_zeros(()))
-        dt_v = torch.clamp(dt_v, min=eps, max=dt_max)
-
         # Hazards: (M, K, B)
-        hazards = F.softplus(logits_v) + eps
-        hazards = torch.clamp(hazards, min=eps)
+        hazards = F.softplus(logits) + self.eps
         total_hazard = hazards.sum(dim=1)  # (M, B)

-        edges = self.bin_edges.to(device=device, dtype=dt_v.dtype)
+        edges = self.bin_edges.to(device=device, dtype=dt.dtype)
         widths = edges[1:] - edges[:-1]  # (B,)

+        if dt.min().item() <= 0:
+            raise ValueError("dt must be strictly positive")
+        if dt.max().item() > edges[-1].item():
+            raise ValueError("dt must be <= last bin edge")
+
         # Bin index b* in [0, B-1].
-        b_star = torch.searchsorted(edges[1:], dt_v, right=False)  # (M,)
-        b_star = torch.clamp(b_star, min=0, max=B - 1)
+        b_star = torch.searchsorted(edges[1:], dt, right=False)  # (M,)

         # 1. Hazard at event (M,)
         # gather needs matching dims.
         # hazards: (M, K, B) -> select target_event -> (M, B) -> select b_star -> (M,)
         # Alternative: hazards[m, k, b]
         ar = torch.arange(M, device=device)
-        hazard_event = hazards[ar, target_v, b_star]  # (M,)
-        hazard_event = torch.clamp(hazard_event, min=eps)
+        hazard_event = hazards[ar, target_events, b_star]  # (M,)
+        hazard_event = torch.clamp(hazard_event, min=self.eps)

         # 2. Integral part
         # Integral: sum_{b < b*} total_hazard[:,b]*width_b + total_hazard[:,b*]*(dt-edge_left)
@@ -231,7 +214,7 @@ class PiecewiseExponentialLoss(nn.Module):
         weighted = total_hazard * widths.unsqueeze(0)  # (M, B)
         cum = weighted.cumsum(dim=1)  # (M, B)

-        full_bins_int = torch.zeros_like(dt_v)
+        full_bins_int = torch.zeros_like(dt)

         # We process 'has_full' logic generally.
         # If b_star is 0, gather on index -1 would fail or wrap, so we mask carefully or use conditional
@@ -246,7 +229,7 @@ class PiecewiseExponentialLoss(nn.Module):
         # Partial bin accumulation
         edge_left = edges[b_star]  # (M,)
         partial_hazard = total_hazard.gather(1, b_star.unsqueeze(1)).squeeze(1)
-        partial = partial_hazard * (dt_v - edge_left)
+        partial = partial_hazard * (dt - edge_left)

         integral = full_bins_int + partial

@@ -265,37 +248,24 @@ class PiecewiseExponentialLoss(nn.Module):

         reg = logits.new_zeros(())
         if self.lambda_reg != 0.0:
-            reg = reg + (self.lambda_reg * logits_v.pow(2).mean())
+            reg = reg + (self.lambda_reg * logits.pow(2).mean())

         return nll_out, reg


 class WeibullNLLLoss(nn.Module):
     """
-    Weibull hazard in t with lightweight numerical protections.
-
-    Does NOT mask/skip any samples. Instead:
-    - nan_to_num for logits/dt/targets
-    - clamps logits to keep softplus outputs reasonable
-    - computes t^shape in log-space with clamped exponent to prevent overflow
+    Weibull hazard in t.
     """

     def __init__(
         self,
         eps: float = 1e-6,
         lambda_reg: float = 0.0,
-        logit_clip: float = 30.0,
-        max_shape: float = 30.0,
-        max_dt: float = 1.0e3,
-        max_exp: float = 80.0,
     ):
         super().__init__()
         self.eps = eps
         self.lambda_reg = lambda_reg
-        self.logit_clip = float(logit_clip)
-        self.max_shape = float(max_shape)
-        self.max_dt = float(max_dt)
-        self.max_exp = float(max_exp)

     def forward(self, logits, target_events, dt, reduction="mean"):
         if logits.dim() != 3 or logits.size(-1) != 2:
@@ -304,35 +274,21 @@ class WeibullNLLLoss(nn.Module):
         M, K, _ = logits.shape
         device = logits.device

-        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
-        logits = torch.clamp(logits, min=-self.logit_clip, max=self.logit_clip)
-
         dt = dt.to(device=device, dtype=torch.float32)
-        dt = torch.nan_to_num(dt, nan=0.0, posinf=0.0, neginf=0.0)
-        dt = torch.clamp(dt, min=self.eps, max=self.max_dt)
+        if dt.min().item() <= 0:
+            raise ValueError("dt must be strictly positive")

         target_events = target_events.to(device=device)
-        target_events = torch.nan_to_num(
-            target_events, nan=0.0, posinf=0.0, neginf=0.0)
         target_events = target_events.to(dtype=torch.long)
-        target_events = torch.clamp(target_events, min=0, max=K - 1)
+        if target_events.min().item() < 0 or target_events.max().item() >= K:
+            raise ValueError("target_events must be in [0, K)")

         shapes = F.softplus(logits[..., 0]) + self.eps
         scales = F.softplus(logits[..., 1]) + self.eps
-        shapes = torch.clamp(shapes, min=self.eps, max=self.max_shape)
-        scales = torch.clamp(scales, min=self.eps)

         t_mat = dt.unsqueeze(1)  # (M,1)
-        log_t = torch.log(torch.clamp(t_mat, min=self.eps))
-        # Compute t^shape and t^(shape-1) in log-space with exponent clamp.
-        pow_shape = torch.exp(torch.clamp(shapes * log_t, max=self.max_exp))
-        pow_shape_minus_1 = torch.exp(
-            torch.clamp((shapes - 1.0) * log_t, max=self.max_exp)
-        )
-
-        cum_hazard = scales * pow_shape
-        hazard = shapes * scales * pow_shape_minus_1
+        cum_hazard = scales * torch.pow(t_mat, shapes)
+        hazard = shapes * scales * torch.pow(t_mat, shapes - 1.0)
         hazard_event = hazard.gather(1, target_events.unsqueeze(1)).squeeze(1)
         hazard_event = torch.clamp(hazard_event, min=self.eps)
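For context on the `# Integral:` comment in the piecewise loss above: with a piecewise-constant total hazard, the cumulative hazard up to dt is the sum of full-bin contributions plus a partial contribution in the bin containing dt. A tiny standalone check, not tied to the class, with made-up bin edges and rates:

```python
import torch

# Piecewise-constant cumulative hazard:
#   integral = sum_{b < b*} rate[b] * width[b] + rate[b*] * (dt - left_edge[b*])
edges = torch.tensor([0.0, 1.0, 2.0, 4.0])   # B = 3 bins
rate = torch.tensor([0.5, 1.0, 0.25])        # total hazard per bin (made up)
dt = torch.tensor(2.5)

b_star = int(torch.searchsorted(edges[1:], dt, right=False))  # -> 2 (third bin)
widths = edges[1:] - edges[:-1]
full_bins = (rate[:b_star] * widths[:b_star]).sum()   # 0.5*1.0 + 1.0*1.0 = 1.5
partial = rate[b_star] * (dt - edges[b_star])          # 0.25*(2.5-2.0) = 0.125
print(float(full_bins + partial))                       # 1.625
```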
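Similarly, the refactored WeibullNLLLoss evaluates the hazard and cumulative hazard directly with torch.pow instead of the former clamped log-space exponentials. The hazard is the derivative of the cumulative hazard, which a quick autograd check confirms (toy values, illustrative only):

```python
import torch

# Weibull parameterisation used above:
#   cumulative hazard H(t) = scale * t**shape
#   hazard            h(t) = shape * scale * t**(shape - 1) = dH/dt
shape = torch.tensor(1.5)
scale = torch.tensor(0.8)
t = torch.tensor(2.0, requires_grad=True)

cum_hazard = scale * torch.pow(t, shape)
hazard = shape * scale * torch.pow(t, shape - 1.0)

cum_hazard.backward()
print(torch.allclose(t.grad, hazard))  # True: the hazard is dH/dt
```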