Refactor aggregation logic in age-bin results to handle pandas version compatibility

This commit is contained in:
2026-01-16 17:24:53 +08:00
parent 810c75e6d1
commit 4068310a12

View File

@@ -131,18 +131,24 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"] group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
df_mc_macro = ( gb = df_by_bin.groupby(group_keys)
df_by_bin.groupby(group_keys)
.apply(lambda g: _bin_aggregate(g, weighted=False)) try:
.reset_index() df_mc_macro = gb.apply(
) lambda g: _bin_aggregate(g, weighted=False), include_groups=False
).reset_index()
except TypeError: # pandas<2.2 (no include_groups)
df_mc_macro = gb.apply(lambda g: _bin_aggregate(
g, weighted=False)).reset_index()
df_mc_macro["agg_type"] = "macro" df_mc_macro["agg_type"] = "macro"
df_mc_weighted = ( try:
df_by_bin.groupby(group_keys) df_mc_weighted = gb.apply(
.apply(lambda g: _bin_aggregate(g, weighted=True)) lambda g: _bin_aggregate(g, weighted=True), include_groups=False
.reset_index() ).reset_index()
) except TypeError: # pandas<2.2 (no include_groups)
df_mc_weighted = gb.apply(
lambda g: _bin_aggregate(g, weighted=True)).reset_index()
df_mc_weighted["agg_type"] = "weighted" df_mc_weighted["agg_type"] = "weighted"
df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True) df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)