Add multi-GPU support for age-bin evaluation and refactor aggregation logic
This commit is contained in:
@@ -18,6 +18,158 @@ from utils import (
|
||||
)
|
||||
|
||||
|
||||
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-bin age evaluation results.

    The aggregation runs in two stages:

    1. Within every ``(mc_idx, horizon_tau, topk_percent, cause_id)`` group,
       collapse the age bins that have ``n_samples > 0`` into two rows:

       - ``macro``: unweighted mean of each metric over the bins
       - ``weighted``: mean of each metric weighted by ``n_samples``

       Bins whose metric value is NaN are skipped for that metric in *both*
       variants (the weights are renormalised over the remaining bins), so
       a single undefined bin no longer turns the weighted aggregate into
       NaN while the macro aggregate stays finite. A metric is NaN only when
       no usable bin has a value for it.

    2. Across MC repetitions, take the mean and the sample standard
       deviation (pandas ``"std"``, which is NaN for a single repetition).

    Args:
        df_by_bin: Per-bin results. Must contain the columns
            ``mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
            n_samples, n_positives, auc, auprc, recall_at_K,
            precision_at_K, brier_score``. ``None`` or an empty frame is
            accepted.

    Returns:
        DataFrame keyed by ``(agg_type, horizon_tau, topk_percent,
        cause_id)`` and sorted by those keys, with columns ``n_mc``,
        ``n_bins_used_mean``, ``n_samples_total_mean``,
        ``n_positives_total_mean`` and ``<metric>_mean`` / ``<metric>_std``
        for every metric. An empty frame with the same columns is returned
        when there is nothing to aggregate.
    """
    metric_cols = ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]
    count_cols = ["n_bins_used", "n_samples_total", "n_positives_total"]
    mc_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
    out_keys = ["agg_type", "horizon_tau", "topk_percent", "cause_id"]

    # Named-aggregation spec; insertion order defines the output column order.
    agg_spec = {"n_mc": ("mc_idx", "nunique")}
    for col in count_cols:
        agg_spec[f"{col}_mean"] = (col, "mean")
    for col in metric_cols:
        agg_spec[f"{col}_mean"] = (col, "mean")
        agg_spec[f"{col}_std"] = (col, "std")
    out_cols = out_keys + list(agg_spec)

    if df_by_bin is None or len(df_by_bin) == 0:
        return pd.DataFrame(columns=out_cols)

    def _weighted_mean(values: np.ndarray, weights: np.ndarray) -> float:
        # Ignore bins where the metric is undefined, mirroring the NaN
        # skipping that pandas ``mean`` applies on the macro path.
        valid = ~np.isnan(values)
        if not valid.any():
            return float("nan")
        return float(np.average(values[valid], weights=weights[valid]))

    # Stage 1: one pass over the groups, emitting a macro and a weighted row
    # for each. Iterating explicitly (instead of ``GroupBy.apply``) avoids
    # grouping the same frame twice and the deprecated handling of grouping
    # columns inside ``apply``.
    rows = []
    for key, group in df_by_bin.groupby(mc_keys):
        used = group[group["n_samples"] > 0]
        head = dict(zip(mc_keys, key))
        head["n_bins_used"] = int(used["age_bin_id"].nunique())
        head["n_samples_total"] = int(used["n_samples"].sum())
        head["n_positives_total"] = int(used["n_positives"].sum())

        macro = dict(head, agg_type="macro")
        weighted = dict(head, agg_type="weighted")
        if len(used) == 0:
            for col in metric_cols:
                macro[col] = float("nan")
                weighted[col] = float("nan")
        else:
            # All weights are strictly positive thanks to the filter above.
            weights = used["n_samples"].to_numpy(dtype=float)
            for col in metric_cols:
                macro[col] = float(used[col].mean())
                weighted[col] = _weighted_mean(
                    used[col].to_numpy(dtype=float), weights
                )
        rows.append(macro)
        rows.append(weighted)

    if not rows:
        # Every key combination was dropped by groupby (e.g. NaN keys).
        return pd.DataFrame(columns=out_cols)

    # Stage 2: mean/std across MC repetitions.
    df_mc_binagg = pd.DataFrame(rows)
    return (
        df_mc_binagg.groupby(out_keys, as_index=False)
        .agg(**agg_spec)
        .sort_values(out_keys, ignore_index=True)
    )
|
||||
|
||||
|
||||
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
|
||||
"""Compute ROC AUC for binary labels with tie-aware ranking.
|
||||
|
||||
@@ -138,6 +290,7 @@ def evaluate_time_dependent_age_bins(
|
||||
n_disease: int,
|
||||
cfg: EvalAgeConfig,
|
||||
device: str | torch.device,
|
||||
mc_offset: int = 0,
|
||||
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Delphi-2M-style age-bin evaluation with strict horizon alignment.
|
||||
|
||||
@@ -196,6 +349,7 @@ def evaluate_time_dependent_age_bins(
|
||||
]
|
||||
|
||||
for mc_idx in range(int(cfg.n_mc)):
|
||||
global_mc_idx = int(mc_offset) + int(mc_idx)
|
||||
# tqdm over batches; include MC idx in description.
|
||||
for batch_idx, batch in enumerate(
|
||||
tqdm(dataloader,
|
||||
@@ -216,7 +370,7 @@ def evaluate_time_dependent_age_bins(
|
||||
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
|
||||
seed = (
|
||||
int(cfg.seed)
|
||||
+ (100_000 * int(mc_idx))
|
||||
+ (100_000 * int(global_mc_idx))
|
||||
+ (1_000 * int(tau_idx))
|
||||
+ (10 * int(bin_idx))
|
||||
+ int(batch_idx)
|
||||
@@ -277,6 +431,7 @@ def evaluate_time_dependent_age_bins(
|
||||
rows_by_bin: List[Dict[str, float | int]] = []
|
||||
|
||||
for mc_idx in range(int(cfg.n_mc)):
|
||||
global_mc_idx = int(mc_offset) + int(mc_idx)
|
||||
for h_idx, tau_y in enumerate(horizons_years):
|
||||
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
|
||||
if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
|
||||
@@ -287,7 +442,7 @@ def evaluate_time_dependent_age_bins(
|
||||
for k_percent in topk_percents:
|
||||
rows_by_bin.append(
|
||||
dict(
|
||||
mc_idx=mc_idx,
|
||||
mc_idx=global_mc_idx,
|
||||
age_bin_id=bin_idx,
|
||||
age_bin_low=float(a_lo),
|
||||
age_bin_high=float(a_hi),
|
||||
@@ -332,7 +487,7 @@ def evaluate_time_dependent_age_bins(
|
||||
yk, pk, float(k_percent))
|
||||
rows_by_bin.append(
|
||||
dict(
|
||||
mc_idx=mc_idx,
|
||||
mc_idx=global_mc_idx,
|
||||
age_bin_id=bin_idx,
|
||||
age_bin_low=float(a_lo),
|
||||
age_bin_high=float(a_hi),
|
||||
@@ -351,115 +506,6 @@ def evaluate_time_dependent_age_bins(
|
||||
|
||||
df_by_bin = pd.DataFrame(rows_by_bin)
|
||||
|
||||
def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
|
||||
g = group[group["n_samples"] > 0]
|
||||
if len(g) == 0:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=0,
|
||||
n_samples_total=0,
|
||||
n_positives_total=0,
|
||||
auc=float("nan"),
|
||||
auprc=float("nan"),
|
||||
recall_at_K=float("nan"),
|
||||
precision_at_K=float("nan"),
|
||||
brier_score=float("nan"),
|
||||
)
|
||||
)
|
||||
|
||||
n_bins_used = int(g["age_bin_id"].nunique())
|
||||
n_samples_total = int(g["n_samples"].sum())
|
||||
n_positives_total = int(g["n_positives"].sum())
|
||||
|
||||
if not weighted:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=float(g["auc"].mean()),
|
||||
auprc=float(g["auprc"].mean()),
|
||||
recall_at_K=float(g["recall_at_K"].mean()),
|
||||
precision_at_K=float(g["precision_at_K"].mean()),
|
||||
brier_score=float(g["brier_score"].mean()),
|
||||
)
|
||||
)
|
||||
|
||||
w = g["n_samples"].to_numpy(dtype=float)
|
||||
w_sum = float(w.sum())
|
||||
if w_sum <= 0.0:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=float("nan"),
|
||||
auprc=float("nan"),
|
||||
recall_at_K=float("nan"),
|
||||
precision_at_K=float("nan"),
|
||||
brier_score=float("nan"),
|
||||
)
|
||||
)
|
||||
|
||||
def _wavg(col: str) -> float:
|
||||
return float(np.average(g[col].to_numpy(dtype=float), weights=w))
|
||||
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=_wavg("auc"),
|
||||
auprc=_wavg("auprc"),
|
||||
recall_at_K=_wavg("recall_at_K"),
|
||||
precision_at_K=_wavg("precision_at_K"),
|
||||
brier_score=_wavg("brier_score"),
|
||||
)
|
||||
)
|
||||
|
||||
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
|
||||
|
||||
df_mc_macro = (
|
||||
df_by_bin.groupby(group_keys)
|
||||
.apply(lambda g: _bin_aggregate(g, weighted=False))
|
||||
.reset_index()
|
||||
)
|
||||
df_mc_macro["agg_type"] = "macro"
|
||||
|
||||
df_mc_weighted = (
|
||||
df_by_bin.groupby(group_keys)
|
||||
.apply(lambda g: _bin_aggregate(g, weighted=True))
|
||||
.reset_index()
|
||||
)
|
||||
df_mc_weighted["agg_type"] = "weighted"
|
||||
|
||||
df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
|
||||
|
||||
# Then average over MC repetitions.
|
||||
df_agg = (
|
||||
df_mc_binagg.groupby(
|
||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
|
||||
)
|
||||
.agg(
|
||||
n_mc=("mc_idx", "nunique"),
|
||||
n_bins_used_mean=("n_bins_used", "mean"),
|
||||
n_samples_total_mean=("n_samples_total", "mean"),
|
||||
n_positives_total_mean=("n_positives_total", "mean"),
|
||||
auc_mean=("auc", "mean"),
|
||||
auc_std=("auc", "std"),
|
||||
auprc_mean=("auprc", "mean"),
|
||||
auprc_std=("auprc", "std"),
|
||||
recall_at_K_mean=("recall_at_K", "mean"),
|
||||
recall_at_K_std=("recall_at_K", "std"),
|
||||
precision_at_K_mean=("precision_at_K", "mean"),
|
||||
precision_at_K_std=("precision_at_K", "std"),
|
||||
brier_score_mean=("brier_score", "mean"),
|
||||
brier_score_std=("brier_score", "std"),
|
||||
)
|
||||
.sort_values(
|
||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"],
|
||||
ignore_index=True,
|
||||
)
|
||||
)
|
||||
df_agg = aggregate_age_bin_results(df_by_bin)
|
||||
|
||||
return df_by_bin, df_agg
|
||||
|
||||
Reference in New Issue
Block a user