Refactor context index selection in evaluate_time_dependent to improve horizon-specific eligibility handling
This commit is contained in:
@@ -188,8 +188,9 @@ def evaluate_time_dependent(
|
||||
|
||||
h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D)
|
||||
|
||||
# Context index selection (independent of horizon); keep mask is refined per horizon.
|
||||
keep0, t_ctx, _ = select_context_indices(
|
||||
# Select a single fixed context per sample for this batch.
|
||||
# Horizon-specific eligibility is derived from this context (do not re-select per horizon).
|
||||
keep0, t_ctx, t_ctx_time = select_context_indices(
|
||||
event_seq=event_seq,
|
||||
time_seq=time_seq,
|
||||
offset_years=float(cfg.offset_years),
|
||||
@@ -211,15 +212,20 @@ def evaluate_time_dependent(
|
||||
f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
|
||||
)
|
||||
|
||||
# Follow-up end time per sample = time at last valid token.
|
||||
valid = event_seq != 0
|
||||
lengths = valid.sum(dim=1)
|
||||
last_idx = torch.clamp(lengths - 1, min=0)
|
||||
followup_end_time = time_seq[b, last_idx]
|
||||
|
||||
for h_idx, tau_y in enumerate(horizons_years):
|
||||
keep, _, _ = select_context_indices(
|
||||
event_seq=event_seq,
|
||||
time_seq=time_seq,
|
||||
offset_years=float(cfg.offset_years),
|
||||
tau_years=float(tau_y),
|
||||
# Horizon-specific eligibility without reselecting context:
|
||||
# keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau)
|
||||
keep_tau = keep0 & (
|
||||
followup_end_time >= (
|
||||
t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
|
||||
)
|
||||
keep = keep & keep0
|
||||
if not keep.any():
|
||||
if not keep_tau.any():
|
||||
continue
|
||||
|
||||
if cause_ids is None:
|
||||
@@ -230,8 +236,8 @@ def evaluate_time_dependent(
|
||||
tau_years=float(tau_y),
|
||||
n_disease=n_disease,
|
||||
)
|
||||
y = y[keep]
|
||||
preds = cifs_all[keep, :, h_idx]
|
||||
y = y[keep_tau]
|
||||
preds = cifs_all[keep_tau, :, h_idx]
|
||||
else:
|
||||
y = multi_hot_selected_causes_within_horizon(
|
||||
event_seq=event_seq,
|
||||
@@ -241,8 +247,8 @@ def evaluate_time_dependent(
|
||||
cause_ids=cause_ids,
|
||||
n_disease=n_disease,
|
||||
)
|
||||
y = y[keep]
|
||||
preds = cifs_all[keep, :, h_idx].index_select(
|
||||
y = y[keep_tau]
|
||||
preds = cifs_all[keep_tau, :, h_idx].index_select(
|
||||
dim=1, index=cause_ids)
|
||||
|
||||
y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
|
||||
|
||||
Reference in New Issue
Block a user