Refactor context index selection in evaluate_time_dependent to improve horizon-specific eligibility handling
This commit is contained in:
@@ -188,8 +188,9 @@ def evaluate_time_dependent(
|
|||||||
|
|
||||||
h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D)
|
h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D)
|
||||||
|
|
||||||
# Context index selection (independent of horizon); keep mask is refined per horizon.
|
# Select a single fixed context per sample for this batch.
|
||||||
keep0, t_ctx, _ = select_context_indices(
|
# Horizon-specific eligibility is derived from this context (do not re-select per horizon).
|
||||||
|
keep0, t_ctx, t_ctx_time = select_context_indices(
|
||||||
event_seq=event_seq,
|
event_seq=event_seq,
|
||||||
time_seq=time_seq,
|
time_seq=time_seq,
|
||||||
offset_years=float(cfg.offset_years),
|
offset_years=float(cfg.offset_years),
|
||||||
@@ -211,15 +212,20 @@ def evaluate_time_dependent(
|
|||||||
f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
|
f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Follow-up end time per sample = time at last valid token.
|
||||||
|
valid = event_seq != 0
|
||||||
|
lengths = valid.sum(dim=1)
|
||||||
|
last_idx = torch.clamp(lengths - 1, min=0)
|
||||||
|
followup_end_time = time_seq[b, last_idx]
|
||||||
|
|
||||||
for h_idx, tau_y in enumerate(horizons_years):
|
for h_idx, tau_y in enumerate(horizons_years):
|
||||||
keep, _, _ = select_context_indices(
|
# Horizon-specific eligibility without reselecting context:
|
||||||
event_seq=event_seq,
|
# keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau)
|
||||||
time_seq=time_seq,
|
keep_tau = keep0 & (
|
||||||
offset_years=float(cfg.offset_years),
|
followup_end_time >= (
|
||||||
tau_years=float(tau_y),
|
t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
|
||||||
)
|
)
|
||||||
keep = keep & keep0
|
if not keep_tau.any():
|
||||||
if not keep.any():
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if cause_ids is None:
|
if cause_ids is None:
|
||||||
@@ -230,8 +236,8 @@ def evaluate_time_dependent(
|
|||||||
tau_years=float(tau_y),
|
tau_years=float(tau_y),
|
||||||
n_disease=n_disease,
|
n_disease=n_disease,
|
||||||
)
|
)
|
||||||
y = y[keep]
|
y = y[keep_tau]
|
||||||
preds = cifs_all[keep, :, h_idx]
|
preds = cifs_all[keep_tau, :, h_idx]
|
||||||
else:
|
else:
|
||||||
y = multi_hot_selected_causes_within_horizon(
|
y = multi_hot_selected_causes_within_horizon(
|
||||||
event_seq=event_seq,
|
event_seq=event_seq,
|
||||||
@@ -241,8 +247,8 @@ def evaluate_time_dependent(
|
|||||||
cause_ids=cause_ids,
|
cause_ids=cause_ids,
|
||||||
n_disease=n_disease,
|
n_disease=n_disease,
|
||||||
)
|
)
|
||||||
y = y[keep]
|
y = y[keep_tau]
|
||||||
preds = cifs_all[keep, :, h_idx].index_select(
|
preds = cifs_all[keep_tau, :, h_idx].index_select(
|
||||||
dim=1, index=cause_ids)
|
dim=1, index=cause_ids)
|
||||||
|
|
||||||
y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
|
y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
|
||||||
|
|||||||
Reference in New Issue
Block a user