summaryrefslogtreecommitdiff
path: root/scripts/roc.R
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/roc.R')
-rwxr-xr-xscripts/roc.R65
1 files changed, 24 insertions, 41 deletions
diff --git a/scripts/roc.R b/scripts/roc.R
index 32e1674..cbd5ea6 100755
--- a/scripts/roc.R
+++ b/scripts/roc.R
@@ -1,45 +1,28 @@
#!/usr/bin/env Rscript
library(ggplot2)
-data <- read.csv("figures/roc.csv",header=T)
-labels = factor(row.names(data), levels = c("MP2D-LAZAR-HC", "MP2D-LAZAR-ALL", "MP2D-RF", "MP2D-LR", "MP2D-LR2", "MP2D-NN", "MP2D-SVM", "CDK-LAZAR-HC", "CDK-LAZAR-ALL", "CDK-RF", "CDK-LR", "CDK-LR2", "CDK-NN", "CDK-SVM"))
-shapes = c(
-"MP2D-LAZAR-HC" = 16,
-"MP2D-LAZAR-ALL" = 16,
-"MP2D-RF" = 16,
-"MP2D-LR" = 16,
-"MP2D-LR2" = 16,
-"MP2D-NN" = 16,
-"MP2D-SVM" = 16,
-"CDK-LAZAR-HC" = 17,
-"CDK-LAZAR-ALL" = 17,
-"CDK-RF" = 17,
-"CDK-LR" = 17,
-"CDK-LR2" = 17,
-"CDK-NN" = 17,
-"CDK-SVM" = 17)
+args = commandArgs(trailingOnly=TRUE)
-colors <- c(
-"MP2D-LAZAR-HC" = "#E69F00",
-"MP2D-LAZAR-ALL" = "#56B4E9",
-"MP2D-RF" = "#009E73",
-"MP2D-LR" = "#F0E442",
-"MP2D-LR2" = "#0072B2",
-"MP2D-NN" = "#D55E00",
-"MP2D-SVM" = "#CC79A7",
-"CDK-LAZAR-HC" = "#E69F00",
-"CDK-LAZAR-ALL" = "#56B4E9",
-"CDK-RF" = "#009E73",
-"CDK-LR" = "#F0E442",
-"CDK-LR2" = "#0072B2",
-"CDK-NN" = "#D55E00",
-"CDK-SVM" = "#CC79A7")
+data = read.csv(args[1],header=T)
+model_labels = factor(data$model, levels = c("LAZAR-HC", "LAZAR-ALL", "RF", "LR-sgd", "LR-scikit", "NN", "SVM"))
+descriptor_labels = factor(data$descriptor, levels = c("MP2D", "CDK"))
-p <- ggplot(data)
-p <- p + geom_point(aes(x=fpr, y=tpr, color = labels, shape = labels))
-p <- p + geom_abline()
-p <- p + theme(legend.title=element_blank())
-p <- p + expand_limits(x=c(0,1),y=c(0,1))
-p <- p + scale_shape_manual(values = shapes)
-p <- p + scale_color_manual(values = colors)
-p <- p + labs(x = "False positive rate", y = "True positive rate")
-ggsave("figures/roc.png")
+colors = c(
+"LAZAR-HC" = "#0072B2",
+"LAZAR-ALL" = "#56B4E9",
+"RF" = "#009E73",
+"LR-sgd" = "#F0E442",
+"LR-scikit" = "#D55E00",
+"NN" = "#CC79A7",
+"SVM" = "#E69F00"
+)
+
+ggplot(data,aes(fpr, tpr, color = model_labels, shape = descriptor_labels)) +
+ geom_point(size = 2.5) +
+ geom_abline() +
+ expand_limits(x=c(0,1),y=c(0,1)) +
+ labs(x = "False positive rate", y = "True positive rate") +
+ scale_color_manual(values = colors) +
+ theme_minimal() +
+ theme(legend.title=element_blank())#,legend.position = "bottom",legend.direction="vertical")#,legend.key.height = 7,legend.key.width=2) +
+
+ggsave(args[2])