diff --git a/losses.py b/losses.py index 69e16b6..35a5942 100644 --- a/losses.py +++ b/losses.py @@ -138,8 +138,8 @@ class DiscreteTimeCIFNLLLoss(nn.Module): This loss assumes the model outputs per-bin logits over (K causes + 1 complement) channels, where the complement channel (index K) represents survival across bins. - Per-sample likelihood for observed cause k at time bin j: - p = \prod_{u=1}^{j-1} p(comp at u) * p(k at j) + Per-sample likelihood for observed cause k at time bin j: + p = \\prod_{u=1}^{j-1} p(comp at u) * p(k at j) Args: bin_edges: Increasing sequence of floats of length (n_bins + 1) with bin_edges[0] == 0.