135 lines
4.8 KiB
R
135 lines
4.8 KiB
R
# Compare AUC distributions between models by ICD-10 chapter (1-year and no-gap)
#
# Usage:
#   Rscript plot_auc_boxplots_by_chapter.R [one_year_csv] [no_gap_csv] [output_dir]
#
# Defaults:
#   one_year_csv = "model_comparison_auc_1year.csv"
#   no_gap_csv   = "model_comparison_auc_no_gap.csv"
#   output_dir   = current working directory (".")
|
|
|
|
suppressPackageStartupMessages({
|
|
library(ggplot2)
|
|
library(cowplot)
|
|
})
|
|
|
|
# Read the optional positional arguments; any argument that is not supplied
# falls back to its default value.
args <- commandArgs(trailingOnly = TRUE)
one_year_csv <- if (length(args) >= 1) {
  args[1]
} else {
  "model_comparison_auc_1year.csv"
}
no_gap_csv <- if (length(args) >= 2) {
  args[2]
} else {
  "model_comparison_auc_no_gap.csv"
}
out_dir <- if (length(args) >= 3) args[3] else "."

# Create the output directory when it is missing. dir.create() only reports
# failure through its (invisible) return value, so re-check and stop early
# instead of letting a later ggsave() call fail with a less obvious message
# (e.g. when `out_dir` exists as a regular file or is not writable).
if (!dir.exists(out_dir)) {
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
  if (!dir.exists(out_dir)) {
    stop(
      sprintf("Could not create output directory '%s'.", out_dir),
      call. = FALSE
    )
  }
}
|
|
|
|
# Read a CSV file into a data.frame, keeping the column names exactly as they
# appear in the file (check.names = FALSE).
#
# Args:
#   path: path of the CSV file to read.
#
# Returns:
#   The data.frame produced by read.csv().
#
# Raises:
#   An error whose message names the offending path together with the
#   underlying read error.
read_csv_safe <- function(path) {
  rethrow_with_path <- function(e) {
    detail <- conditionMessage(e)
    stop(sprintf("Failed to read CSV at '%s': %s", path, detail))
  }
  tryCatch(
    read.csv(path, check.names = FALSE),
    error = rethrow_with_path
  )
}
|
|
|
|
# Find which column of `df` holds the ICD-10 chapter label.
#
# The candidate names are tried in priority order and the first one that is a
# column of `df` is returned; NA_character_ is returned when none is present.
get_chapter_col <- function(df) {
  known_names <- c(
    "ICD-10 Chapter (short)",
    "ICD-10 Chapter",
    "ICD10_chapter",
    "chapter",
    "ICD_chapter"
  )
  matched <- known_names[known_names %in% names(df)]
  if (length(matched) > 0) matched[1] else NA_character_
}
|
|
|
|
# Reshape a wide AUC table into long format for plotting.
#
# Args:
#   df: data.frame holding one or more of the known AUC columns (auc_120,
#       auc_120_l, auc_256, auc_256_l, auc_delphi) and, optionally, a chapter
#       column recognised by get_chapter_col().
#
# Returns:
#   data.frame with columns `chapter`, `model` (factor whose levels follow the
#   fixed model order defined below) and `auc` (numeric). Rows whose AUC is
#   missing, non-finite or outside [0, 1] are dropped. A zero-row input yields
#   a zero-row result.
#
# Raises:
#   An error when none of the known AUC columns is present. A warning is
#   issued when no chapter column is found; every row is then labelled "All".
build_long_df <- function(df) {
  # Input column name -> legend label; the order also fixes the factor levels.
  model_labels <- c(
    auc_120 = "GPT-2 120",
    auc_120_l = "GPT-2 120_L",
    auc_256 = "GPT-2 256",
    auc_256_l = "GPT-2 256_L",
    auc_delphi = "Delphi"
  )
  present <- names(model_labels)[names(model_labels) %in% names(df)]
  if (length(present) == 0) stop("No known AUC columns found in input data.")

  n_rows <- nrow(df)
  chap_col <- get_chapter_col(df)
  if (is.na(chap_col)) {
    warning("No chapter column found; using a single 'All' group.")
    chapters <- rep("All", n_rows)
  } else {
    chapters <- df[[chap_col]]
  }

  pieces <- lapply(present, function(col) {
    values <- df[[col]]
    # as.numeric() on a factor returns the integer level codes, not the
    # values the labels spell out, so convert through character first.
    if (is.factor(values)) values <- as.character(values)
    data.frame(
      chapter = chapters,
      # Repeat explicitly: data.frame() cannot recycle a length-1 label to
      # zero rows and would fail on an empty input table.
      model = rep(model_labels[[col]], n_rows),
      auc = as.numeric(values),
      stringsAsFactors = FALSE
    )
  })
  long_df <- do.call(rbind, pieces)

  # Keep only usable AUC values (finite and within [0, 1]).
  keep <- is.finite(long_df$auc) & long_df$auc >= 0 & long_df$auc <= 1
  long_df <- long_df[keep, ]
  long_df$model <- factor(long_df$model, levels = unname(model_labels))
  long_df
}
|
|
|
|
# Draw horizontal box plots of AUC per ICD-10 chapter, one box per model.
#
# Chapters are ordered by decreasing median AUC of the "Delphi" model.
# Chapters that have no Delphi rows (or all chapters, when Delphi is absent)
# are ordered by their median AUC over all models and placed after the
# Delphi-ranked ones, so no chapter is turned into NA by the factor
# conversion.
#
# Args:
#   long_df: data.frame with columns `chapter`, `model` and `auc`, in the
#            long format returned by build_long_df().
#   title_text: plot title.
#
# Returns:
#   A ggplot object.
make_boxplot <- function(long_df, title_text) {
  # Chapter names of `d`, sorted by decreasing median AUC.
  rank_chapters <- function(d) {
    if (nrow(d) == 0) {
      return(character(0))
    }
    meds <- tapply(d$auc, as.character(d$chapter), median, na.rm = TRUE)
    names(meds)[order(meds, decreasing = TRUE)]
  }

  is_delphi <- !is.na(long_df$model) & long_df$model == "Delphi"
  delphi_ranked <- rank_chapters(long_df[is_delphi, , drop = FALSE])
  # union() keeps the Delphi ranking first and appends the remaining chapters
  # in overall-median order.
  chap_levels <- union(delphi_ranked, rank_chapters(long_df))
  long_df$chapter <- factor(as.character(long_df$chapter), levels = chap_levels)

  ggplot(long_df, aes(x = chapter, y = auc, fill = model)) +
    geom_boxplot(
      outlier.shape = 19, outlier.size = 0.7, width = 0.75, alpha = 0.95
    ) +
    # Zoom with the coordinate system rather than with scale limits: scale
    # limits discard AUC values below 0.3 before the box statistics are
    # computed, which would distort medians and quartiles.
    coord_flip(ylim = c(0.3, 1.0)) +
    scale_y_continuous(breaks = seq(0.3, 1.0, by = 0.1)) +
    labs(title = title_text, x = "ICD-10 Chapter", y = "AUC") +
    theme_minimal(base_size = 11) +
    theme(
      plot.title = element_text(hjust = 0.5),
      panel.grid.minor = element_blank(),
      legend.position = "bottom"
    ) +
    guides(fill = guide_legend(nrow = 1))
}
|
|
|
|
# Load both result tables and reshape them to long format
one_year_df <- read_csv_safe(one_year_csv)
no_gap_df <- read_csv_safe(no_gap_csv)
one_year_long <- build_long_df(one_year_df)
no_gap_long <- build_long_df(no_gap_df)

# One box plot per evaluation setting
p1 <- make_boxplot(one_year_long, "AUC by ICD-10 Chapter (1-year gap)")
p2 <- make_boxplot(no_gap_long, "AUC by ICD-10 Chapter (no gap)")

# Write `plot` under out_dir as a 10-inch-high PNG on a white background,
# print a confirmation line built from `fmt`, and return the file path.
write_png <- function(plot, file_name, width, dpi, fmt = "Saved: %s\n") {
  target <- file.path(out_dir, file_name)
  ggsave(target, plot, width = width, height = 10, dpi = dpi, bg = "white")
  cat(sprintf(fmt, target))
  target
}

# Individual figures
out_1year <- write_png(
  p1, "auc_boxplot_by_chapter_1year.png",
  width = 12, dpi = 300
)
out_nogap <- write_png(
  p2, "auc_boxplot_by_chapter_no_gap.png",
  width = 12, dpi = 300
)

# Side-by-side panel (A = 1-year gap, B = no gap) for quick comparison
grid <- plot_grid(p1, p2, labels = c("A", "B"), ncol = 2, align = "hv")
out_grid <- write_png(
  grid, "auc_boxplot_by_chapter_grid.png",
  width = 18, dpi = 250, fmt = "Saved grid: %s\n"
)
|