Add R script to plot 1-year AUC comparisons with ggplot2 (256 vs Delphi and 120 vs Delphi)
This commit is contained in:
76
plot_model_comparison_1year.R
Normal file
76
plot_model_comparison_1year.R
Normal file
@@ -0,0 +1,76 @@
|
||||
# Plot AUC comparisons (1-year gap) between models and Delphi using ggplot2
|
||||
# Usage:
|
||||
# Rscript plot_model_comparison_1year.R [path_to_csv] [output_dir]
|
||||
# Defaults:
|
||||
# path_to_csv = "model_comparison_auc_1year.csv"
|
||||
# output_dir = current working directory (".")
|
||||
|
||||
suppressPackageStartupMessages({
|
||||
library(ggplot2)
|
||||
})
|
||||
|
||||
args <- commandArgs(trailingOnly = TRUE)
|
||||
csv_path <- if (length(args) >= 1) args[1] else "model_comparison_auc_1year.csv"
|
||||
out_dir <- if (length(args) >= 2) args[2] else "."
|
||||
|
||||
if (!dir.exists(out_dir)) {
|
||||
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
|
||||
}
|
||||
|
||||
# Read data
|
||||
# Expect columns including: auc_delphi, auc_256, auc_120, Colour (hex color), name, etc.
|
||||
df <- tryCatch({
|
||||
read.csv(csv_path, check.names = FALSE)
|
||||
}, error = function(e) {
|
||||
stop(sprintf("Failed to read CSV at '%s': %s", csv_path, e$message))
|
||||
})
|
||||
|
||||
# Helper to make a scatter plot comparing model vs Delphi AUC
|
||||
make_comparison_plot <- function(data, y_col, title_text, y_label) {
|
||||
# Use shape 21 (filled circle) to allow white stroke and colored fill
|
||||
ggplot(data, aes(x = auc_delphi, y = .data[[y_col]])) +
|
||||
geom_abline(slope = 1, intercept = 0, color = "black", linetype = "dashed", linewidth = 0.5) +
|
||||
geom_vline(xintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
|
||||
geom_hline(yintercept = 0.5, color = "gray50", linetype = "dashed", linewidth = 0.4) +
|
||||
geom_point(aes(fill = Colour), shape = 21, color = "white", stroke = 0.65, size = 2.2, alpha = 0.95, show.legend = FALSE) +
|
||||
scale_fill_identity() +
|
||||
coord_cartesian(xlim = c(0.3, 1.05), ylim = c(0.3, 1.05)) +
|
||||
coord_fixed(ratio = 1) +
|
||||
labs(title = title_text, x = "AUC_Delphi", y = y_label) +
|
||||
theme_minimal(base_size = 10) +
|
||||
theme(
|
||||
plot.title = element_text(hjust = 0.5),
|
||||
panel.grid.minor = element_blank()
|
||||
)
|
||||
}
|
||||
|
||||
# Plot: AUC_256 vs AUC_Delphi (1 year gap)
|
||||
if (!all(c("auc_delphi", "auc_256") %in% names(df))) {
|
||||
stop("Input CSV must contain columns 'auc_delphi' and 'auc_256'.")
|
||||
}
|
||||
|
||||
p256 <- make_comparison_plot(
|
||||
data = df,
|
||||
y_col = "auc_256",
|
||||
title_text = "AUC_256 vs AUC_Delphi 1 year gap",
|
||||
y_label = "AUC_256"
|
||||
)
|
||||
|
||||
out_256 <- file.path(out_dir, "model_comparison_auc_256_vs_delphi_1year.png")
|
||||
ggsave(filename = out_256, plot = p256, width = 7, height = 4, dpi = 600, bg = "white")
|
||||
cat(sprintf("Saved: %s\n", out_256))
|
||||
|
||||
# Plot: AUC_120 vs AUC_Delphi (1 year gap)
|
||||
if (!"auc_120" %in% names(df)) {
|
||||
warning("Column 'auc_120' not found in CSV; skipping AUC_120 vs AUC_Delphi plot.")
|
||||
} else {
|
||||
p120 <- make_comparison_plot(
|
||||
data = df,
|
||||
y_col = "auc_120",
|
||||
title_text = "AUC_120 vs AUC_Delphi 1 year gap",
|
||||
y_label = "AUC_120"
|
||||
)
|
||||
out_120 <- file.path(out_dir, "fig_auc_120_vs_delphi_1year.png")
|
||||
ggsave(filename = out_120, plot = p120, width = 7, height = 4, dpi = 600, bg = "white")
|
||||
cat(sprintf("Saved: %s\n", out_120))
|
||||
}
|
Reference in New Issue
Block a user