From 83591831c6e36c36d87159acba6afdfedab95522 Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Thu, 18 Mar 2021 16:48:36 +0100 Subject: fingerprint predictions added --- bin/lazar | 182 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100755 bin/lazar (limited to 'bin/lazar') diff --git a/bin/lazar b/bin/lazar new file mode 100755 index 0000000..e603b4c --- /dev/null +++ b/bin/lazar @@ -0,0 +1,182 @@ +#!/usr/bin/env ruby +require 'optparse' +require_relative '../lib/lazar' + +ARGV << '-h' if ARGV.empty? +options = {} +options[:folds] = 10 +options[:thresholds] = [0.5,0.2] + +OptionParser.new do |opts| + opts.banner = "Usage: lazar -t TRAIN -x|-p descriptors [options]" + opts.on( '-h', '--help', 'Display this screen' ) do + puts opts + exit + end + opts.on( '-t TRAIN', '-train TRAIN', "Training data in csv format (required). Type 'lazar -f' for format specifications." ) do |t| + options[:train] = t + end + opts.on( '-p descriptors', '--predict descriptors', "Prediction data in csv format. Type 'lazar -f' for format specifications.") do |p| + options[:predict] = p + end + opts.on( '-x', '--crossvalidation', "Run crossvalidation." ) do |c| + options[:cv] = true + end + opts.on( '-f folds', '--folds folds', Integer, "Change crossvalidation folds (default: #{options[:folds]})." ) do |f| + options[:folds] = f + end + opts.on( '-f', '--formats', "Describe input and output formats" ) do |f| + raise OptionParser::InvalidArgument, "Format description not yet implemented." + end +# opts.on( '-d', '--daemon', "Run as daemon in background" ) do |f| +# raise OptionParser::InvalidArgument, "Daemon mode not yet implemented" +# end +end.parse! + +raise OptionParser::MissingArgument, "Training data is required. Type 'lazar -h' for help." if options[:train].nil? +raise OptionParser::InvalidArgument, "Training data file #{options[:train]} does not exist. Type 'lazar -h' for help." unless File.exists? options[:train] +raise OptionParser::InvalidOption, "Choose either --predict or --crossvalidation. Type 'lazar -h' for help." if options[:predict] and options[:cv] +raise OptionParser::InvalidOption, "One of the --predict or --crossvalidation options is required. Type 'lazar -h' for help." unless options[:predict] or options[:cv] +raise OptionParser::InvalidArgument, "Prediction descriptor file #{options[:predict]} does not exist. Type 'lazar -h' for help." if options[:predict] and !File.exists? options[:predict] + +model = Model.new options[:train] + +if options[:predict] # batch predictions + model.predict options[:predict] + +elsif options[:cv] # crossvalidation + + # create folds + cv_dir = File.join(File.dirname(options[:train]),"crossvalidation") + folds = (0..options[:folds]-1).collect{|i| File.join(cv_dir,i.to_s)} + nr_instances = model.train.size + indices = (0..nr_instances-1).to_a.shuffle + mid = (nr_instances/options[:folds]) + start = 0 + 0.upto(options[:folds]-1) do |i| + + # split train data + puts "Creating fold #{i}" + last = start+mid + last = last-1 unless nr_instances%options[:folds] > i + test_idxs = indices[start..last] || [] + idxs = { + :train => indices-test_idxs, + :test => test_idxs + } + start = last+1 + + # write training/test data + idxs.each do |t,idx| + file = File.join(cv_dir,i.to_s,t.to_s+".csv") + `mkdir -p #{File.dirname file}` + case t + when :train + File.open(file,"w+") do |f| + f.puts (["Canonical SMILES",model.dependent_variable_name] + model.independent_variable_names).join(",") + idx.collect{|i| model.train[i]}.each do |t| + f.puts t.join(",") + end + end + when :test + File.open(file,"w+") do |f| + f.puts (["Canonical SMILES"] + model.independent_variable_names).join(",") + idx.collect{|i| model.train[i]}.each do |t| + o = t.clone # keep model.train intact + o.delete_at(1) + f.puts o.join(",") + end + end + end + end + end + + # crossvalidation predictions + t = Time.now + folds.each do |fold| + fork do + puts "Crossvalidation #{fold} started" + m = Model.new File.join(fold,"train.csv") + m.predict File.join(fold,"test.csv") + end + end + Process.waitall + puts "Crossvalidation: #{(Time.now-t)/60} min" + + # crossvalidation summaries + + predictions = [] + tp=0 + tn=0 + fp=0 + fn=0 + hc_tp=0 + hc_tn=0 + hc_fp=0 + hc_fn=0 + + File.open(File.join(cv_dir,"predictions.csv"),"w+") do |f| + folds.each do |fold| + pred = File.readlines(File.join(fold,"test-prediction.csv")).collect{|row| row.chomp.split(",")} + pred.shift + pred.each do |prediction| + smi = prediction[0] + exp = model.train.select{|t| t[0] == smi}.collect{|t| t[1].to_i} + maxsim = prediction[5].to_f + v = "NA" + unless exp.nil? or prediction[2].empty? or exp.empty? + p = prediction[2].to_i + exp.each do |e| + if p and e + if p == 1 and e == 1 + v = "TP" + tp+=1 + hc_tp+=1 if maxsim > model.minsim.max + elsif p == 0 and e == 0 + v = "TN" + tn+=1 + hc_tn+=1 if maxsim > model.minsim.max + elsif p == 1 and e == 0 + v = "FP" + fp+=1 + hc_fp+=1 if maxsim > model.minsim.max + elsif p == 0 and e == 1 + v = "FN" + fn+=1 + hc_fn+=1 if maxsim > model.minsim.max + end + end + predictions << v + end + end + f.puts([smi,v,maxsim].join(",")) + end + end + end + + File.open(File.join(cv_dir,"confusion-matrix-all.csv"),"w+") do |f| + f.puts "#{tp},#{fp}\n#{fn},#{tn}" + end + + File.open(File.join(cv_dir,"confusion-matrix-high-confidence.csv"),"w+") do |f| + f.puts "#{hc_tp},#{hc_fp}\n#{hc_fn},#{hc_tn}" + end + + File.open(File.join(cv_dir,"summary-all.csv"),"w+") do |f| + f.puts "accuracy,#{(tp+tn)/(tp+fp+tn+fn).to_f}" + f.puts "true_positive_rate,#{tp/(tp+fn).to_f}" + f.puts "true_negative_rate,#{tn/(tn+fp).to_f}" + f.puts "positive_predictive_value,#{tp/(tp+fp).to_f}" + f.puts "negative_predictive_value,#{tn/(tn+fn).to_f}" + end + + File.open(File.join(cv_dir,"summary-high-confidence.csv"),"w+") do |f| + f.puts "accuracy,#{(hc_tp+hc_tn)/(hc_tp+hc_fp+hc_tn+hc_fn).to_f}" + f.puts "true_positive_rate,#{hc_tp/(hc_tp+hc_fn).to_f}" + f.puts "true_negative_rate,#{hc_tn/(hc_tn+hc_fp).to_f}" + f.puts "positive_predictive_value,#{hc_tp/(hc_tp+hc_fp).to_f}" + f.puts "negative_predictive_value,#{hc_tn/(hc_tn+hc_fn).to_f}" + end + +end + -- cgit v1.2.3