Add multi-GPU support for age-bin evaluation and refactor aggregation logic
This commit is contained in:
239
evaluate_age.py
239
evaluate_age.py
@@ -4,13 +4,19 @@ import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import multiprocessing as mp
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from evaluation_age_time_dependent import EvalAgeConfig, evaluate_time_dependent_age_bins
|
||||
from evaluation_age_time_dependent import (
|
||||
EvalAgeConfig,
|
||||
aggregate_age_bin_results,
|
||||
evaluate_time_dependent_age_bins,
|
||||
)
|
||||
from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss
|
||||
from model import DelphiFork, SapDelphi, SimpleHead
|
||||
|
||||
@@ -39,6 +45,144 @@ def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
|
||||
return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]
|
||||
|
||||
|
||||
def _parse_gpus(gpus: str | None) -> List[int]:
|
||||
if gpus is None:
|
||||
return []
|
||||
s = gpus.strip()
|
||||
if not s:
|
||||
return []
|
||||
parts = [p.strip() for p in s.split(",") if p.strip()]
|
||||
out: List[int] = []
|
||||
for p in parts:
|
||||
out.append(int(p))
|
||||
return out
|
||||
|
||||
|
||||
def _worker_eval_mcs_on_gpu(
|
||||
queue: "mp.Queue",
|
||||
*,
|
||||
run_dir: str,
|
||||
split: str,
|
||||
data_prefix_override: str | None,
|
||||
horizons: List[float],
|
||||
age_bins: List[Tuple[float, float]],
|
||||
topk_percents: List[float],
|
||||
n_mc: int,
|
||||
seed: int,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
gpu_id: int,
|
||||
mc_indices: List[int],
|
||||
out_path: str,
|
||||
) -> None:
|
||||
"""Worker process: evaluate a subset of MC indices on a single GPU."""
|
||||
try:
|
||||
ckpt_path = os.path.join(run_dir, "best_model.pt")
|
||||
cfg_path = os.path.join(run_dir, "train_config.json")
|
||||
with open(cfg_path, "r") as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
data_prefix = (
|
||||
data_prefix_override
|
||||
if data_prefix_override is not None
|
||||
else cfg.get("data_prefix", "ukb")
|
||||
)
|
||||
|
||||
full_cov = bool(cfg.get("full_cov", False))
|
||||
cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
|
||||
dataset = HealthDataset(data_prefix=data_prefix,
|
||||
covariate_list=cov_list)
|
||||
|
||||
train_ratio = float(cfg.get("train_ratio", 0.7))
|
||||
val_ratio = float(cfg.get("val_ratio", 0.15))
|
||||
seed_split = int(cfg.get("random_seed", 42))
|
||||
|
||||
n_total = len(dataset)
|
||||
n_train = int(n_total * train_ratio)
|
||||
n_val = int(n_total * val_ratio)
|
||||
n_test = n_total - n_train - n_val
|
||||
|
||||
train_ds, val_ds, test_ds = random_split(
|
||||
dataset,
|
||||
[n_train, n_val, n_test],
|
||||
generator=torch.Generator().manual_seed(seed_split),
|
||||
)
|
||||
|
||||
if split == "train":
|
||||
ds = train_ds
|
||||
elif split == "val":
|
||||
ds = val_ds
|
||||
elif split == "test":
|
||||
ds = test_ds
|
||||
else:
|
||||
ds = dataset
|
||||
|
||||
loader = DataLoader(
|
||||
ds,
|
||||
batch_size=int(batch_size),
|
||||
shuffle=False,
|
||||
collate_fn=health_collate_fn,
|
||||
num_workers=int(num_workers),
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
criterion, out_dims = build_criterion_and_out_dims(
|
||||
loss_type=str(cfg["loss_type"]),
|
||||
n_disease=int(dataset.n_disease),
|
||||
bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
|
||||
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
|
||||
)
|
||||
|
||||
model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
|
||||
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)
|
||||
|
||||
device = torch.device(f"cuda:{int(gpu_id)}")
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
|
||||
head.load_state_dict(checkpoint["head_state_dict"], strict=True)
|
||||
if "criterion_state_dict" in checkpoint:
|
||||
try:
|
||||
criterion.load_state_dict(
|
||||
checkpoint["criterion_state_dict"], strict=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model.to(device)
|
||||
head.to(device)
|
||||
criterion.to(device)
|
||||
|
||||
frames: List[pd.DataFrame] = []
|
||||
for mc_idx in mc_indices:
|
||||
eval_cfg = EvalAgeConfig(
|
||||
horizons_years=horizons,
|
||||
age_bins=age_bins,
|
||||
topk_percents=topk_percents,
|
||||
n_mc=1,
|
||||
seed=int(seed),
|
||||
cause_ids=None,
|
||||
)
|
||||
|
||||
df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins(
|
||||
model=model,
|
||||
head=head,
|
||||
criterion=criterion,
|
||||
dataloader=loader,
|
||||
n_disease=int(dataset.n_disease),
|
||||
cfg=eval_cfg,
|
||||
device=device,
|
||||
mc_offset=int(mc_idx),
|
||||
)
|
||||
frames.append(df_by_bin)
|
||||
|
||||
df_all = pd.concat(frames, ignore_index=True) if len(
|
||||
frames) else pd.DataFrame()
|
||||
df_all.to_csv(out_path, index=False)
|
||||
queue.put({"ok": True, "out_path": out_path})
|
||||
except Exception as e:
|
||||
queue.put({"ok": False, "error": repr(e)})
|
||||
|
||||
|
||||
def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
|
||||
if loss_type == "exponential":
|
||||
criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
|
||||
@@ -134,6 +278,13 @@ def main() -> None:
|
||||
parser.add_argument("--n_mc", type=int, default=5)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
|
||||
parser.add_argument(
|
||||
"--gpus",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated GPU ids to parallelize MC runs (one worker per GPU; one MC per GPU at a time). Example: --gpus 0,1,3",
|
||||
)
|
||||
|
||||
parser.add_argument("--device", type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--batch_size", type=int, default=256)
|
||||
@@ -232,6 +383,17 @@ def main() -> None:
|
||||
cause_ids=None,
|
||||
)
|
||||
|
||||
if args.out_prefix is None:
|
||||
out_prefix = os.path.join(
|
||||
args.run_dir, f"age_bin_time_dependent_{args.split}")
|
||||
else:
|
||||
out_prefix = args.out_prefix
|
||||
|
||||
out_bin = out_prefix + "_by_bin.csv"
|
||||
out_agg = out_prefix + "_agg.csv"
|
||||
|
||||
gpus = _parse_gpus(args.gpus)
|
||||
if len(gpus) <= 1:
|
||||
df_by_bin, df_agg = evaluate_time_dependent_age_bins(
|
||||
model=model,
|
||||
head=head,
|
||||
@@ -242,18 +404,79 @@ def main() -> None:
|
||||
device=device,
|
||||
)
|
||||
|
||||
if args.out_prefix is None:
|
||||
out_prefix = os.path.join(
|
||||
args.run_dir, f"age_bin_time_dependent_{args.split}")
|
||||
else:
|
||||
out_prefix = args.out_prefix
|
||||
df_by_bin.to_csv(out_bin, index=False)
|
||||
df_agg.to_csv(out_agg, index=False)
|
||||
print(f"Wrote: {out_bin}")
|
||||
print(f"Wrote: {out_agg}")
|
||||
return
|
||||
|
||||
out_bin = out_prefix + "_by_bin.csv"
|
||||
out_agg = out_prefix + "_agg.csv"
|
||||
if not torch.cuda.is_available():
|
||||
raise SystemExit("--gpus was provided but CUDA is not available")
|
||||
|
||||
# Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU).
|
||||
mc_indices_all = list(range(int(args.n_mc)))
|
||||
per_gpu: List[Tuple[int, List[int]]] = []
|
||||
for pos, gpu_id in enumerate(gpus):
|
||||
assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
|
||||
if assigned:
|
||||
per_gpu.append((int(gpu_id), assigned))
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
queue: "mp.Queue" = ctx.Queue()
|
||||
procs: List[mp.Process] = []
|
||||
tmp_paths: List[str] = []
|
||||
|
||||
for gpu_id, mc_idxs in per_gpu:
|
||||
tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
|
||||
tmp_paths.append(tmp_path)
|
||||
p = ctx.Process(
|
||||
target=_worker_eval_mcs_on_gpu,
|
||||
kwargs=dict(
|
||||
queue=queue,
|
||||
run_dir=str(args.run_dir),
|
||||
split=str(args.split),
|
||||
data_prefix_override=(
|
||||
str(args.data_prefix) if args.data_prefix is not None else None
|
||||
),
|
||||
horizons=_parse_floats(args.horizons),
|
||||
age_bins=age_bins,
|
||||
topk_percents=[float(x) for x in args.topk_percent],
|
||||
n_mc=int(args.n_mc),
|
||||
seed=int(args.seed),
|
||||
batch_size=int(args.batch_size),
|
||||
num_workers=int(args.num_workers),
|
||||
gpu_id=int(gpu_id),
|
||||
mc_indices=mc_idxs,
|
||||
out_path=tmp_path,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
procs.append(p)
|
||||
|
||||
results = [queue.get() for _ in range(len(procs))]
|
||||
for p in procs:
|
||||
p.join()
|
||||
|
||||
for r in results:
|
||||
if not r.get("ok", False):
|
||||
raise SystemExit(f"Worker failed: {r.get('error')}")
|
||||
|
||||
frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
|
||||
df_by_bin = pd.concat(frames, ignore_index=True) if len(
|
||||
frames) else pd.DataFrame()
|
||||
df_agg = aggregate_age_bin_results(df_by_bin)
|
||||
|
||||
df_by_bin.to_csv(out_bin, index=False)
|
||||
df_agg.to_csv(out_agg, index=False)
|
||||
|
||||
# Best-effort cleanup.
|
||||
for p in tmp_paths:
|
||||
try:
|
||||
if os.path.exists(p):
|
||||
os.remove(p)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"Wrote: {out_bin}")
|
||||
print(f"Wrote: {out_agg}")
|
||||
|
||||
|
||||
@@ -18,6 +18,158 @@ from utils import (
|
||||
)
|
||||
|
||||
|
||||
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Aggregate per-bin age evaluation results.
|
||||
|
||||
Produces both:
|
||||
- macro: unweighted mean over bins with n_samples>0
|
||||
- weighted: weighted mean over bins using weights=n_samples
|
||||
|
||||
Then aggregates across MC repetitions (mean/std).
|
||||
|
||||
Requires df_by_bin to include:
|
||||
mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
|
||||
n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score
|
||||
|
||||
Returns:
|
||||
DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id)
|
||||
"""
|
||||
if df_by_bin is None or len(df_by_bin) == 0:
|
||||
return pd.DataFrame(
|
||||
columns=[
|
||||
"agg_type",
|
||||
"horizon_tau",
|
||||
"topk_percent",
|
||||
"cause_id",
|
||||
"n_mc",
|
||||
"n_bins_used_mean",
|
||||
"n_samples_total_mean",
|
||||
"n_positives_total_mean",
|
||||
"auc_mean",
|
||||
"auc_std",
|
||||
"auprc_mean",
|
||||
"auprc_std",
|
||||
"recall_at_K_mean",
|
||||
"recall_at_K_std",
|
||||
"precision_at_K_mean",
|
||||
"precision_at_K_std",
|
||||
"brier_score_mean",
|
||||
"brier_score_std",
|
||||
]
|
||||
)
|
||||
|
||||
def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
|
||||
g = group[group["n_samples"] > 0]
|
||||
if len(g) == 0:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=0,
|
||||
n_samples_total=0,
|
||||
n_positives_total=0,
|
||||
auc=float("nan"),
|
||||
auprc=float("nan"),
|
||||
recall_at_K=float("nan"),
|
||||
precision_at_K=float("nan"),
|
||||
brier_score=float("nan"),
|
||||
)
|
||||
)
|
||||
|
||||
n_bins_used = int(g["age_bin_id"].nunique())
|
||||
n_samples_total = int(g["n_samples"].sum())
|
||||
n_positives_total = int(g["n_positives"].sum())
|
||||
|
||||
if not weighted:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=float(g["auc"].mean()),
|
||||
auprc=float(g["auprc"].mean()),
|
||||
recall_at_K=float(g["recall_at_K"].mean()),
|
||||
precision_at_K=float(g["precision_at_K"].mean()),
|
||||
brier_score=float(g["brier_score"].mean()),
|
||||
)
|
||||
)
|
||||
|
||||
w = g["n_samples"].to_numpy(dtype=float)
|
||||
w_sum = float(w.sum())
|
||||
if w_sum <= 0.0:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=float("nan"),
|
||||
auprc=float("nan"),
|
||||
recall_at_K=float("nan"),
|
||||
precision_at_K=float("nan"),
|
||||
brier_score=float("nan"),
|
||||
)
|
||||
)
|
||||
|
||||
def _wavg(col: str) -> float:
|
||||
return float(np.average(g[col].to_numpy(dtype=float), weights=w))
|
||||
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=_wavg("auc"),
|
||||
auprc=_wavg("auprc"),
|
||||
recall_at_K=_wavg("recall_at_K"),
|
||||
precision_at_K=_wavg("precision_at_K"),
|
||||
brier_score=_wavg("brier_score"),
|
||||
)
|
||||
)
|
||||
|
||||
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
|
||||
|
||||
df_mc_macro = (
|
||||
df_by_bin.groupby(group_keys)
|
||||
.apply(lambda g: _bin_aggregate(g, weighted=False))
|
||||
.reset_index()
|
||||
)
|
||||
df_mc_macro["agg_type"] = "macro"
|
||||
|
||||
df_mc_weighted = (
|
||||
df_by_bin.groupby(group_keys)
|
||||
.apply(lambda g: _bin_aggregate(g, weighted=True))
|
||||
.reset_index()
|
||||
)
|
||||
df_mc_weighted["agg_type"] = "weighted"
|
||||
|
||||
df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
|
||||
|
||||
df_agg = (
|
||||
df_mc_binagg.groupby(
|
||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
|
||||
)
|
||||
.agg(
|
||||
n_mc=("mc_idx", "nunique"),
|
||||
n_bins_used_mean=("n_bins_used", "mean"),
|
||||
n_samples_total_mean=("n_samples_total", "mean"),
|
||||
n_positives_total_mean=("n_positives_total", "mean"),
|
||||
auc_mean=("auc", "mean"),
|
||||
auc_std=("auc", "std"),
|
||||
auprc_mean=("auprc", "mean"),
|
||||
auprc_std=("auprc", "std"),
|
||||
recall_at_K_mean=("recall_at_K", "mean"),
|
||||
recall_at_K_std=("recall_at_K", "std"),
|
||||
precision_at_K_mean=("precision_at_K", "mean"),
|
||||
precision_at_K_std=("precision_at_K", "std"),
|
||||
brier_score_mean=("brier_score", "mean"),
|
||||
brier_score_std=("brier_score", "std"),
|
||||
)
|
||||
.sort_values(
|
||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"],
|
||||
ignore_index=True,
|
||||
)
|
||||
)
|
||||
return df_agg
|
||||
|
||||
|
||||
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
|
||||
"""Compute ROC AUC for binary labels with tie-aware ranking.
|
||||
|
||||
@@ -138,6 +290,7 @@ def evaluate_time_dependent_age_bins(
|
||||
n_disease: int,
|
||||
cfg: EvalAgeConfig,
|
||||
device: str | torch.device,
|
||||
mc_offset: int = 0,
|
||||
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Delphi-2M-style age-bin evaluation with strict horizon alignment.
|
||||
|
||||
@@ -196,6 +349,7 @@ def evaluate_time_dependent_age_bins(
|
||||
]
|
||||
|
||||
for mc_idx in range(int(cfg.n_mc)):
|
||||
global_mc_idx = int(mc_offset) + int(mc_idx)
|
||||
# tqdm over batches; include MC idx in description.
|
||||
for batch_idx, batch in enumerate(
|
||||
tqdm(dataloader,
|
||||
@@ -216,7 +370,7 @@ def evaluate_time_dependent_age_bins(
|
||||
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
|
||||
seed = (
|
||||
int(cfg.seed)
|
||||
+ (100_000 * int(mc_idx))
|
||||
+ (100_000 * int(global_mc_idx))
|
||||
+ (1_000 * int(tau_idx))
|
||||
+ (10 * int(bin_idx))
|
||||
+ int(batch_idx)
|
||||
@@ -277,6 +431,7 @@ def evaluate_time_dependent_age_bins(
|
||||
rows_by_bin: List[Dict[str, float | int]] = []
|
||||
|
||||
for mc_idx in range(int(cfg.n_mc)):
|
||||
global_mc_idx = int(mc_offset) + int(mc_idx)
|
||||
for h_idx, tau_y in enumerate(horizons_years):
|
||||
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
|
||||
if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
|
||||
@@ -287,7 +442,7 @@ def evaluate_time_dependent_age_bins(
|
||||
for k_percent in topk_percents:
|
||||
rows_by_bin.append(
|
||||
dict(
|
||||
mc_idx=mc_idx,
|
||||
mc_idx=global_mc_idx,
|
||||
age_bin_id=bin_idx,
|
||||
age_bin_low=float(a_lo),
|
||||
age_bin_high=float(a_hi),
|
||||
@@ -332,7 +487,7 @@ def evaluate_time_dependent_age_bins(
|
||||
yk, pk, float(k_percent))
|
||||
rows_by_bin.append(
|
||||
dict(
|
||||
mc_idx=mc_idx,
|
||||
mc_idx=global_mc_idx,
|
||||
age_bin_id=bin_idx,
|
||||
age_bin_low=float(a_lo),
|
||||
age_bin_high=float(a_hi),
|
||||
@@ -351,115 +506,6 @@ def evaluate_time_dependent_age_bins(
|
||||
|
||||
df_by_bin = pd.DataFrame(rows_by_bin)
|
||||
|
||||
def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
|
||||
g = group[group["n_samples"] > 0]
|
||||
if len(g) == 0:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=0,
|
||||
n_samples_total=0,
|
||||
n_positives_total=0,
|
||||
auc=float("nan"),
|
||||
auprc=float("nan"),
|
||||
recall_at_K=float("nan"),
|
||||
precision_at_K=float("nan"),
|
||||
brier_score=float("nan"),
|
||||
)
|
||||
)
|
||||
|
||||
n_bins_used = int(g["age_bin_id"].nunique())
|
||||
n_samples_total = int(g["n_samples"].sum())
|
||||
n_positives_total = int(g["n_positives"].sum())
|
||||
|
||||
if not weighted:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=float(g["auc"].mean()),
|
||||
auprc=float(g["auprc"].mean()),
|
||||
recall_at_K=float(g["recall_at_K"].mean()),
|
||||
precision_at_K=float(g["precision_at_K"].mean()),
|
||||
brier_score=float(g["brier_score"].mean()),
|
||||
)
|
||||
)
|
||||
|
||||
w = g["n_samples"].to_numpy(dtype=float)
|
||||
w_sum = float(w.sum())
|
||||
if w_sum <= 0.0:
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=float("nan"),
|
||||
auprc=float("nan"),
|
||||
recall_at_K=float("nan"),
|
||||
precision_at_K=float("nan"),
|
||||
brier_score=float("nan"),
|
||||
)
|
||||
)
|
||||
|
||||
def _wavg(col: str) -> float:
|
||||
return float(np.average(g[col].to_numpy(dtype=float), weights=w))
|
||||
|
||||
return pd.Series(
|
||||
dict(
|
||||
n_bins_used=n_bins_used,
|
||||
n_samples_total=n_samples_total,
|
||||
n_positives_total=n_positives_total,
|
||||
auc=_wavg("auc"),
|
||||
auprc=_wavg("auprc"),
|
||||
recall_at_K=_wavg("recall_at_K"),
|
||||
precision_at_K=_wavg("precision_at_K"),
|
||||
brier_score=_wavg("brier_score"),
|
||||
)
|
||||
)
|
||||
|
||||
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
|
||||
|
||||
df_mc_macro = (
|
||||
df_by_bin.groupby(group_keys)
|
||||
.apply(lambda g: _bin_aggregate(g, weighted=False))
|
||||
.reset_index()
|
||||
)
|
||||
df_mc_macro["agg_type"] = "macro"
|
||||
|
||||
df_mc_weighted = (
|
||||
df_by_bin.groupby(group_keys)
|
||||
.apply(lambda g: _bin_aggregate(g, weighted=True))
|
||||
.reset_index()
|
||||
)
|
||||
df_mc_weighted["agg_type"] = "weighted"
|
||||
|
||||
df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
|
||||
|
||||
# Then average over MC repetitions.
|
||||
df_agg = (
|
||||
df_mc_binagg.groupby(
|
||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
|
||||
)
|
||||
.agg(
|
||||
n_mc=("mc_idx", "nunique"),
|
||||
n_bins_used_mean=("n_bins_used", "mean"),
|
||||
n_samples_total_mean=("n_samples_total", "mean"),
|
||||
n_positives_total_mean=("n_positives_total", "mean"),
|
||||
auc_mean=("auc", "mean"),
|
||||
auc_std=("auc", "std"),
|
||||
auprc_mean=("auprc", "mean"),
|
||||
auprc_std=("auprc", "std"),
|
||||
recall_at_K_mean=("recall_at_K", "mean"),
|
||||
recall_at_K_std=("recall_at_K", "std"),
|
||||
precision_at_K_mean=("precision_at_K", "mean"),
|
||||
precision_at_K_std=("precision_at_K", "std"),
|
||||
brier_score_mean=("brier_score", "mean"),
|
||||
brier_score_std=("brier_score", "std"),
|
||||
)
|
||||
.sort_values(
|
||||
["agg_type", "horizon_tau", "topk_percent", "cause_id"],
|
||||
ignore_index=True,
|
||||
)
|
||||
)
|
||||
df_agg = aggregate_age_bin_results(df_by_bin)
|
||||
|
||||
return df_by_bin, df_agg
|
||||
|
||||
Reference in New Issue
Block a user