From 08e5768e9a446db8ab95152d2e9403a0e635ec63 Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Mon, 8 Mar 2021 17:41:26 +0100 Subject: cdk predictions fixed --- bin/batch-prediction.rb | 4 ++ bin/batch_fingerprint_classification.rb | 13 ------ bin/classification-summary.rb | 70 ++++++++++++++++++++++++++++++++- bin/crossvalidation-folds.rb | 54 +++++++++++++++++++++++++ bin/crossvalidation-predictions.rb | 13 ++++++ bin/fingerprints.rb | 2 +- bin/preprocessing.R | 7 ++++ 7 files changed, 147 insertions(+), 16 deletions(-) create mode 100755 bin/batch-prediction.rb delete mode 100755 bin/batch_fingerprint_classification.rb create mode 100755 bin/crossvalidation-folds.rb create mode 100755 bin/crossvalidation-predictions.rb create mode 100644 bin/preprocessing.R (limited to 'bin') diff --git a/bin/batch-prediction.rb b/bin/batch-prediction.rb new file mode 100755 index 0000000..770bc60 --- /dev/null +++ b/bin/batch-prediction.rb @@ -0,0 +1,4 @@ +#!/usr/bin/env ruby +require_relative "../lib/lazar" +model = Model.new ARGV[0] +model.predict ARGV[1] diff --git a/bin/batch_fingerprint_classification.rb b/bin/batch_fingerprint_classification.rb deleted file mode 100755 index 318fae6..0000000 --- a/bin/batch_fingerprint_classification.rb +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env ruby -require_relative "../lib/lazar" -model = ClassificationModel.new ARGV[0] - -File.read(ARGV[1]).each_line do |line| - if line.match(/SMILES/i) - puts "ID,SMILES,experimental,classification,probability(0),probability(1),max_similarity,nr_neighbors" - else - id,smi = line.chomp.split(",") - puts ([id] + model.predict_smiles(smi)).join(",") - end -end - diff --git a/bin/classification-summary.rb b/bin/classification-summary.rb index a3e4172..c6755a1 100755 --- a/bin/classification-summary.rb +++ b/bin/classification-summary.rb @@ -1,4 +1,70 @@ #!/usr/bin/env ruby require_relative "../lib/lazar" -stat = ClassificationStatistics.new ARGV[0] -stat.summary +#stat = ClassificationStatistics.new ARGV[0] +#stat.summary +dir = File.join(File.dirname(ARGV[0]),"crossvalidation") +folds = Dir[File.join(dir,"[0-9]*")] + +predictions = [] +tp=0 +tn=0 +fp=0 +fn=0 +n=0 +experimental = {} + +lines = File.readlines(File.join(ARGV[0])) +lines.shift +lines.each do |line| + items = line.chomp.split(',') + experimental[items[0]] ||= [] + experimental[items[0]] << items[1].to_i +end + +File.open(File.join(dir,"predictions.csv"),"w+") do |f| + folds.each do |fold| + pred = File.readlines(File.join(fold,"test-prediction.csv")).collect{|row| row.chomp.split(",")} + pred.shift + pred.each do |prediction| + smi = prediction[0] + exp = experimental[smi] + unless exp.nil? or prediction[2].empty? or exp.empty? + p = prediction[2].to_i + n+=1 + v = "NA" + exp.each do |e| + if p and e + if p == 1 and e == 1 + v = "TP" + tp+=1 + elsif p == 0 and e == 0 + v = "TN" + tn+=1 + elsif p == 1 and e == 0 + v = "FP" + fp+=1 + elsif p == 0 and e == 1 + v = "FN" + fn+=1 + end + end + predictions << v + end + f.puts([smi,v].join(",")) + end + end + end +end + +File.open(File.join(dir,"confusion-matrix.csv"),"w+") do |f| + f.puts "#{tp},#{fp}\n#{fn},#{tn}" +end + +File.open(File.join(dir,"summary.csv"),"w+") do |f| + f.puts "accuracy,#{(tp+tn)/(tp+fp+tn+fn).to_f}" + f.puts "true_positive_rate,#{tp/(tp+fn).to_f}" + f.puts "true_negative_rate,#{tn/(tn+fp).to_f}" + f.puts "positive_predictive_value,#{tp/(tp+fp).to_f}" + f.puts "negative_predictive_value,#{tn/(tn+fn).to_f}" +end + diff --git a/bin/crossvalidation-folds.rb b/bin/crossvalidation-folds.rb new file mode 100755 index 0000000..0c765f7 --- /dev/null +++ b/bin/crossvalidation-folds.rb @@ -0,0 +1,54 @@ +#!/usr/bin/env ruby +require_relative "../lib/lazar" +model = Model.new ARGV[0] +ARGV[1] ? folds = ARGV[1].to_i : folds = 10 +nr_instances = model.train.size +indices = (0..nr_instances-1).to_a.shuffle +mid = (nr_instances/folds) +start = 0 +0.upto(folds-1) do |i| + fork do + # split train data + puts "Creating fold #{i}" + last = start+mid + last = last-1 unless nr_instances%folds > i + test_idxs = indices[start..last] || [] + idxs = { + :train => indices-test_idxs, + :test => test_idxs + } + start = last+1 + # write training/test data + cv_dir = File.join(File.dirname(ARGV[0]),"crossvalidation",i.to_s) + idxs.each do |t,idx| + file = File.join(cv_dir,t.to_s+".csv") + `mkdir -p #{File.dirname file}` + case t + when :train + File.open(file,"w+") do |f| + f.puts (["Canonical SMILES",model.dependent_variable_name] + model.independent_variable_names).join(",") + idx.collect{|i| model.train[i]}.each do |t| + f.puts t.join(",") + end + end + when :test + File.open(file,"w+") do |f| + f.puts (["Canonical SMILES"] + model.independent_variable_names).join(",") + idx.collect{|i| model.train[i]}.each do |t| + t.delete_at(1) + f.puts t.join(",") + end + end + file = File.join(cv_dir,t.to_s+"-experimental.csv") + File.open(file,"w+") do |f| + f.puts (["Canonical SMILES", model.dependent_variable_name]).join(",") + idx.collect{|i| model.train[i]}.each do |t| + # TODO fix + f.puts t[0..1].join(",") + end + end + end + end + Process.waitall + end +end diff --git a/bin/crossvalidation-predictions.rb b/bin/crossvalidation-predictions.rb new file mode 100755 index 0000000..55ae5a1 --- /dev/null +++ b/bin/crossvalidation-predictions.rb @@ -0,0 +1,13 @@ +#!/usr/bin/env ruby +require_relative "../lib/lazar" + +t = Time.now +Dir["#{File.join(ARGV[0],'[0-9]')}"].each do |fold| + fork do + puts "Crossvalidation #{fold} started" + model = Model.new File.join(fold,"train.csv") + model.predict File.join(fold,"test.csv") + end +end +Process.waitall +puts "Crossvalidation: #{(Time.now-t)/60} min" diff --git a/bin/fingerprints.rb b/bin/fingerprints.rb index 923be8d..862a1aa 100755 --- a/bin/fingerprints.rb +++ b/bin/fingerprints.rb @@ -1,5 +1,5 @@ #!/usr/bin/env ruby -require_relative "../lib/lazar" +require_relative "../lib/compound" File.read(ARGV[0]).each_line do |smi| c = Compound.from_smiles(smi.chomp) puts c.fingerprint.join(",") diff --git a/bin/preprocessing.R b/bin/preprocessing.R new file mode 100644 index 0000000..393bf46 --- /dev/null +++ b/bin/preprocessing.R @@ -0,0 +1,7 @@ +#!/usr/bin/env Rscript +library(caret) +args = commandArgs(trailingOnly=TRUE) +variables = read.csv(args[1]) +scaling = preProcess(variables, method = c("nzv","corr","center", "scale")) +scaled = predict(scaling,variables) +write.csv(scaled,file=args[2], row.names=F, quote=F) -- cgit v1.2.3