Refactor evaluate_time_dependent to select the context once per batch and derive horizon-specific eligibility from it instead of re-selecting per horizon

This commit is contained in:
2026-01-16 15:01:35 +08:00
parent 34d8d8ce9d
commit 502ddd153b

View File

@@ -188,8 +188,9 @@ def evaluate_time_dependent(
h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D)
# Context index selection (independent of horizon); keep mask is refined per horizon.
keep0, t_ctx, _ = select_context_indices(
# Select a single fixed context per sample for this batch.
# Horizon-specific eligibility is derived from this context (do not re-select per horizon).
keep0, t_ctx, t_ctx_time = select_context_indices(
event_seq=event_seq,
time_seq=time_seq,
offset_years=float(cfg.offset_years),
@@ -211,15 +212,20 @@ def evaluate_time_dependent(
f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
)
# Follow-up end time per sample = time at last valid token.
valid = event_seq != 0
lengths = valid.sum(dim=1)
last_idx = torch.clamp(lengths - 1, min=0)
followup_end_time = time_seq[b, last_idx]
for h_idx, tau_y in enumerate(horizons_years):
keep, _, _ = select_context_indices(
event_seq=event_seq,
time_seq=time_seq,
offset_years=float(cfg.offset_years),
tau_years=float(tau_y),
# Horizon-specific eligibility without reselecting context (tau_y is in years):
# keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau_y * DAYS_PER_YEAR)
keep_tau = keep0 & (
followup_end_time >= (
t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
)
keep = keep & keep0
if not keep.any():
if not keep_tau.any():
continue
if cause_ids is None:
@@ -230,8 +236,8 @@ def evaluate_time_dependent(
tau_years=float(tau_y),
n_disease=n_disease,
)
y = y[keep]
preds = cifs_all[keep, :, h_idx]
y = y[keep_tau]
preds = cifs_all[keep_tau, :, h_idx]
else:
y = multi_hot_selected_causes_within_horizon(
event_seq=event_seq,
@@ -241,8 +247,8 @@ def evaluate_time_dependent(
cause_ids=cause_ids,
n_disease=n_disease,
)
y = y[keep]
preds = cifs_all[keep, :, h_idx].index_select(
y = y[keep_tau]
preds = cifs_all[keep_tau, :, h_idx].index_select(
dim=1, index=cause_ids)
y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())