diff options
Diffstat (limited to 'scripts/roc.R')
-rwxr-xr-x | scripts/roc.R | 65 |
1 files changed, 24 insertions, 41 deletions
diff --git a/scripts/roc.R b/scripts/roc.R index 32e1674..cbd5ea6 100755 --- a/scripts/roc.R +++ b/scripts/roc.R @@ -1,45 +1,28 @@ #!/usr/bin/env Rscript library(ggplot2) -data <- read.csv("figures/roc.csv",header=T) -labels = factor(row.names(data), levels = c("MP2D-LAZAR-HC", "MP2D-LAZAR-ALL", "MP2D-RF", "MP2D-LR", "MP2D-LR2", "MP2D-NN", "MP2D-SVM", "CDK-LAZAR-HC", "CDK-LAZAR-ALL", "CDK-RF", "CDK-LR", "CDK-LR2", "CDK-NN", "CDK-SVM")) -shapes = c( -"MP2D-LAZAR-HC" = 16, -"MP2D-LAZAR-ALL" = 16, -"MP2D-RF" = 16, -"MP2D-LR" = 16, -"MP2D-LR2" = 16, -"MP2D-NN" = 16, -"MP2D-SVM" = 16, -"CDK-LAZAR-HC" = 17, -"CDK-LAZAR-ALL" = 17, -"CDK-RF" = 17, -"CDK-LR" = 17, -"CDK-LR2" = 17, -"CDK-NN" = 17, -"CDK-SVM" = 17) +args = commandArgs(trailingOnly=TRUE) -colors <- c( -"MP2D-LAZAR-HC" = "#E69F00", -"MP2D-LAZAR-ALL" = "#56B4E9", -"MP2D-RF" = "#009E73", -"MP2D-LR" = "#F0E442", -"MP2D-LR2" = "#0072B2", -"MP2D-NN" = "#D55E00", -"MP2D-SVM" = "#CC79A7", -"CDK-LAZAR-HC" = "#E69F00", -"CDK-LAZAR-ALL" = "#56B4E9", -"CDK-RF" = "#009E73", -"CDK-LR" = "#F0E442", -"CDK-LR2" = "#0072B2", -"CDK-NN" = "#D55E00", -"CDK-SVM" = "#CC79A7") +data = read.csv(args[1],header=T) +model_labels = factor(data$model, levels = c("LAZAR-HC", "LAZAR-ALL", "RF", "LR-sgd", "LR-scikit", "NN", "SVM")) +descriptor_labels = factor(data$descriptor, levels = c("MP2D", "CDK")) -p <- ggplot(data) -p <- p + geom_point(aes(x=fpr, y=tpr, color = labels, shape = labels)) -p <- p + geom_abline() -p <- p + theme(legend.title=element_blank()) -p <- p + expand_limits(x=c(0,1),y=c(0,1)) -p <- p + scale_shape_manual(values = shapes) -p <- p + scale_color_manual(values = colors) -p <- p + labs(x = "False positive rate", y = "True positive rate") -ggsave("figures/roc.png") +colors = c( +"LAZAR-HC" = "#0072B2", +"LAZAR-ALL" = "#56B4E9", +"RF" = "#009E73", +"LR-sgd" = "#F0E442", +"LR-scikit" = "#D55E00", +"NN" = "#CC79A7", +"SVM" = "#E69F00" +) + +ggplot(data,aes(fpr, tpr, color = model_labels, shape = descriptor_labels)) + + geom_point(size = 2.5) + + geom_abline() + + expand_limits(x=c(0,1),y=c(0,1)) + + labs(x = "False positive rate", y = "True positive rate") + + scale_color_manual(values = colors) + + theme_minimal() + + theme(legend.title=element_blank())#,legend.position = "bottom",legend.direction="vertical")#,legend.key.height = 7,legend.key.width=2) + + +ggsave(args[2]) |