diff --git a/evaluation_time_dependent.py b/evaluation_time_dependent.py index 9e6f2f2..955b02b 100644 --- a/evaluation_time_dependent.py +++ b/evaluation_time_dependent.py @@ -188,8 +188,9 @@ def evaluate_time_dependent( h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D) - # Context index selection (independent of horizon); keep mask is refined per horizon. - keep0, t_ctx, _ = select_context_indices( + # Select a single fixed context per sample for this batch. + # Horizon-specific eligibility is derived from this context (do not re-select per horizon). + keep0, t_ctx, t_ctx_time = select_context_indices( event_seq=event_seq, time_seq=time_seq, offset_years=float(cfg.offset_years), @@ -211,15 +212,20 @@ def evaluate_time_dependent( f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}" ) + # Follow-up end time per sample = time at last valid token. + valid = event_seq != 0 + lengths = valid.sum(dim=1) + last_idx = torch.clamp(lengths - 1, min=0) + followup_end_time = time_seq[b, last_idx] + for h_idx, tau_y in enumerate(horizons_years): - keep, _, _ = select_context_indices( - event_seq=event_seq, - time_seq=time_seq, - offset_years=float(cfg.offset_years), - tau_years=float(tau_y), + # Horizon-specific eligibility without reselecting context: + # keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau) + keep_tau = keep0 & ( + followup_end_time >= ( + t_ctx_time + (float(tau_y) * DAYS_PER_YEAR)) ) - keep = keep & keep0 - if not keep.any(): + if not keep_tau.any(): continue if cause_ids is None: @@ -230,8 +236,8 @@ def evaluate_time_dependent( tau_years=float(tau_y), n_disease=n_disease, ) - y = y[keep] - preds = cifs_all[keep, :, h_idx] + y = y[keep_tau] + preds = cifs_all[keep_tau, :, h_idx] else: y = multi_hot_selected_causes_within_horizon( event_seq=event_seq, @@ -241,8 +247,8 @@ def evaluate_time_dependent( cause_ids=cause_ids, n_disease=n_disease, ) - y = y[keep] - preds = cifs_all[keep, :, h_idx].index_select( + y = y[keep_tau] + preds = cifs_all[keep_tau, :, h_idx].index_select( dim=1, index=cause_ids) y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())