summaryrefslogtreecommitdiff
path: root/scripts/roc.R
blob: cbd5ea66b7a99a6fe6af094310727af5064f413e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/usr/bin/env Rscript
library(ggplot2)
# Positional command-line arguments: args[1] = input CSV, args[2] = output
# file path (handed to ggsave() at the end of the script).
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 2L) {
  stop("usage: roc.R <input.csv> <output-file>", call. = FALSE)
}

# This script reads the columns `model`, `descriptor`, `fpr` and `tpr`;
# fail early with a clear message if any of them is absent.
data <- read.csv(args[1], header = TRUE)
missing_cols <- setdiff(c("model", "descriptor", "fpr", "tpr"), names(data))
if (length(missing_cols) > 0L) {
  stop("input CSV lacks column(s): ", paste(missing_cols, collapse = ", "),
       call. = FALSE)
}

# Explicit level order fixes the legend order. Values not listed here become
# NA in factor(), so warn about them instead of dropping them silently.
model_levels <- c(
  "LAZAR-HC", "LAZAR-ALL", "RF", "LR-sgd", "LR-scikit", "NN", "SVM"
)
descriptor_levels <- c("MP2D", "CDK")

# Warn about non-missing values of `column` that are not among `levels`.
warn_unknown_levels <- function(values, levels, column) {
  unknown <- setdiff(unique(values[!is.na(values)]), levels)
  if (length(unknown) > 0L) {
    warning("unrecognised ", column, " value(s) will be plotted as NA: ",
            paste(unknown, collapse = ", "), call. = FALSE)
  }
  invisible(unknown)
}
warn_unknown_levels(data[["model"]], model_levels, "model")
warn_unknown_levels(data[["descriptor"]], descriptor_levels, "descriptor")

model_labels <- factor(data[["model"]], levels = model_levels)
descriptor_labels <- factor(data[["descriptor"]], levels = descriptor_levels)

# Fixed colour per model, named by model so that scale_color_manual() matches
# each colour to its model level by name (independent of level order).
# NOTE(review): these hex codes appear to be the Okabe-Ito colour-blind-safe
# palette — confirm before adding models.
colors <- c(
  "LAZAR-HC"  = "#0072B2",
  "LAZAR-ALL" = "#56B4E9",
  "RF"        = "#009E73",
  "LR-sgd"    = "#F0E442",
  "LR-scikit" = "#D55E00",
  "NN"        = "#CC79A7",
  "SVM"       = "#E69F00"
)

# One point per row at (false positive rate, true positive rate):
# colour encodes the model, shape encodes the descriptor set.
# geom_abline() with its defaults (slope 1, intercept 0) draws the diagonal
# reference line; expand_limits() keeps both axes spanning at least [0, 1].
roc_plot <- ggplot(
  data,
  aes(x = fpr, y = tpr, color = model_labels, shape = descriptor_labels)
) +
  geom_point(size = 2.5) +
  geom_abline() +
  expand_limits(x = c(0, 1), y = c(0, 1)) +
  labs(x = "False positive rate", y = "True positive rate") +
  scale_color_manual(values = colors) +
  theme_minimal() +
  theme(legend.title = element_blank())

# Save the plot object explicitly rather than relying on last_plot(); binding
# it to a name also avoids Rscript auto-printing it to a stray default device.
# The output format is inferred by ggsave() from the file extension of args[2].
ggsave(args[2], plot = roc_plot)