Refactor context index selection in evaluate_time_dependent to improve horizon-specific eligibility handling

This commit is contained in:
2026-01-16 15:01:35 +08:00
parent 34d8d8ce9d
commit 502ddd153b

View File

@@ -188,8 +188,9 @@ def evaluate_time_dependent(
h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D) h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D)
# Context index selection (independent of horizon); keep mask is refined per horizon. # Select a single fixed context per sample for this batch.
keep0, t_ctx, _ = select_context_indices( # Horizon-specific eligibility is derived from this context (do not re-select per horizon).
keep0, t_ctx, t_ctx_time = select_context_indices(
event_seq=event_seq, event_seq=event_seq,
time_seq=time_seq, time_seq=time_seq,
offset_years=float(cfg.offset_years), offset_years=float(cfg.offset_years),
@@ -211,15 +212,20 @@ def evaluate_time_dependent(
f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}" f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
) )
# Follow-up end time per sample = time at last valid token.
valid = event_seq != 0
lengths = valid.sum(dim=1)
last_idx = torch.clamp(lengths - 1, min=0)
followup_end_time = time_seq[b, last_idx]
for h_idx, tau_y in enumerate(horizons_years): for h_idx, tau_y in enumerate(horizons_years):
keep, _, _ = select_context_indices( # Horizon-specific eligibility without reselecting context:
event_seq=event_seq, # keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau)
time_seq=time_seq, keep_tau = keep0 & (
offset_years=float(cfg.offset_years), followup_end_time >= (
tau_years=float(tau_y), t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
) )
keep = keep & keep0 if not keep_tau.any():
if not keep.any():
continue continue
if cause_ids is None: if cause_ids is None:
@@ -230,8 +236,8 @@ def evaluate_time_dependent(
tau_years=float(tau_y), tau_years=float(tau_y),
n_disease=n_disease, n_disease=n_disease,
) )
y = y[keep] y = y[keep_tau]
preds = cifs_all[keep, :, h_idx] preds = cifs_all[keep_tau, :, h_idx]
else: else:
y = multi_hot_selected_causes_within_horizon( y = multi_hot_selected_causes_within_horizon(
event_seq=event_seq, event_seq=event_seq,
@@ -241,8 +247,8 @@ def evaluate_time_dependent(
cause_ids=cause_ids, cause_ids=cause_ids,
n_disease=n_disease, n_disease=n_disease,
) )
y = y[keep] y = y[keep_tau]
preds = cifs_all[keep, :, h_idx].index_select( preds = cifs_all[keep_tau, :, h_idx].index_select(
dim=1, index=cause_ids) dim=1, index=cause_ids)
y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy()) y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())