From bdc6b5b40437896384561d74a510560e9e592364 Mon Sep 17 00:00:00 2001 From: "helma@in-silico.ch" Date: Tue, 9 Oct 2018 18:20:27 +0200 Subject: tentative random forest classification: hangs unpredictably during caret model generation/optimization for some (inorganic?) compounds. --- lib/caret-classification.rb | 107 ++++++++++++++++++++++++++++++++++++++++++++ lib/classification.rb | 3 +- lib/compound.rb | 1 - lib/dataset.rb | 14 +++--- lib/model.rb | 7 ++- 5 files changed, 123 insertions(+), 9 deletions(-) create mode 100644 lib/caret-classification.rb (limited to 'lib') diff --git a/lib/caret-classification.rb b/lib/caret-classification.rb new file mode 100644 index 0000000..fefe6b6 --- /dev/null +++ b/lib/caret-classification.rb @@ -0,0 +1,107 @@ +module OpenTox + module Algorithm + + # Ruby interface for the R caret package + # Caret model list: https://topepo.github.io/caret/modelList.html + class Caret + + # Create a local R caret model and make a prediction + # @param [Array] dependent_variables + # @param [Array>] independent_variables + # @param [Array] weights + # @param [String] Caret method + # @param [Array] query_variables + # @return [Hash] + def self.create_model_and_predict dependent_variables:, independent_variables:, weights:, method:, query_variables: + remove = [] + # remove independent_variables with single values + independent_variables.each_with_index { |values,i| remove << i if values.uniq.size == 1} + remove.sort.reverse.each do |i| + independent_variables.delete_at i + query_variables.delete_at i + end + if independent_variables.flatten.uniq == ["NA"] or independent_variables.flatten.uniq == [] + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "No variables for classification model. Using weighted average of similar substances." + elsif dependent_variables.uniq.size == 1 + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "All neighbors have the same measured activity. Cannot create random forest model, using weighted average of similar substances." + elsif dependent_variables.size < 3 + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "Insufficient number of neighbors (#{dependent_variables.size}) for classification model. Using weighted average of similar substances." + else + dependent_variables.collect!{|v| to_r(v)} + independent_variables.each_with_index do |c,i| + c.each_with_index do |v,j| + independent_variables[i][j] = to_r(v) + end + end +# query_variables.collect!{|v| to_r(v)} + begin + R.assign "weights", weights + #r_data_frame = "data.frame(#{([dependent_variables.collect{|v| to_r(v)}]+independent_variables).collect{|r| "c(#{r.collect{|v| to_r(v)}.join(',')})"}.join(', ')})" + r_data_frame = "data.frame(#{([dependent_variables]+independent_variables).collect{|r| "c(#{r.join(',')})"}.join(', ')})" + #p r_data_frame + R.eval "data <- #{r_data_frame}" + R.assign "features", (0..independent_variables.size-1).to_a + R.eval "names(data) <- append(c('activities'),features)" # + p "train" + R.eval "model <- train(activities ~ ., data = data, method = '#{method}', na.action = na.pass, allowParallel=TRUE)" + p "done" + rescue => e + $logger.debug "R caret model creation error for: #{e.message}" + $logger.debug dependent_variables + $logger.debug independent_variables + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "R caret model creation error. Using weighted average of similar substances." + return prediction + end + begin + R.eval "query <- data.frame(rbind(c(#{query_variables.collect{|v| to_r(v)}.join ','})))" + R.eval "names(query) <- features" + R.eval "prediction <- predict(model,query, type=\"prob\")" + names = R.eval("names(prediction)").to_ruby + probs = R.eval("prediction").to_ruby + probabilities = {} + names.each_with_index { |n,i| probabilities[n] = probs[i] } + value = probabilities.sort_by{|n,p| -p }[0][0] + prediction = { + :value => value, + :probabilities => probabilities, + :warnings => [], + } + rescue => e + $logger.debug "R caret prediction error for: #{e.inspect}" + $logger.debug self.inspect + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "R caret prediction error. Using weighted average of similar substances" + return prediction + end + if prediction.nil? or prediction[:value].nil? + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "Empty R caret prediction. Using weighted average of similar substances." + end + end + prediction + + end + + # Call caret methods dynamically, e.g. Caret.pls + def self.method_missing(sym, *args, &block) + args.first[:method] = sym.to_s + self.create_model_and_predict args.first + end + + # Convert Ruby values to R values + def self.to_r v + return "F" if v == false + return "T" if v == true + return nil if v.is_a? Float and v.nan? + return "\"#{v}\"" if v.is_a? String + v + end + + end + end +end + diff --git a/lib/classification.rb b/lib/classification.rb index a875903..2668e4a 100644 --- a/lib/classification.rb +++ b/lib/classification.rb @@ -19,6 +19,7 @@ module OpenTox probabilities[a] = w.sum/weights.sum end # DG: hack to ensure always two probability values + # TODO: does not work for arbitrary feature names FIX!! if probabilities.keys.uniq.size == 1 missing_key = probabilities.keys.uniq[0].match(/^non/) ? probabilities.keys.uniq[0].sub(/non-/,"") : "non-"+probabilities.keys.uniq[0] probabilities[missing_key] = 0.0 @@ -26,7 +27,7 @@ module OpenTox probabilities = probabilities.collect{|a,p| [a,weights.max*p]}.to_h p_max = probabilities.collect{|a,p| p}.max prediction = probabilities.key(p_max) - {:value => prediction,:probabilities => probabilities} + {:value => prediction,:probabilities => probabilities,:warnings => []} end end diff --git a/lib/compound.rb b/lib/compound.rb index d80f579..8dc53a1 100644 --- a/lib/compound.rb +++ b/lib/compound.rb @@ -319,7 +319,6 @@ module OpenTox obconversion.read_string obmol, identifier case output_format when /smi|can|inchi/ - #obconversion.write_string(obmol).gsub(/\s/,'').chomp obconversion.write_string(obmol).split(/\s/).first when /sdf/ # TODO: find disconnected structures diff --git a/lib/dataset.rb b/lib/dataset.rb index b7d9d4e..6ad3215 100644 --- a/lib/dataset.rb +++ b/lib/dataset.rb @@ -71,6 +71,8 @@ module OpenTox # Merge an array of datasets # @param [Array] OpenTox::Dataset Array to be merged + # @param [Hash] feature modifications + # @param [Hash] value modifications # @return [OpenTox::Dataset] merged dataset def self.merge datasets, feature_map=nil, value_map=nil dataset = self.new(:source => datasets.collect{|d| d.source}.join(", "), :name => datasets.collect{|d| d.name}.uniq.join(", ")) @@ -205,7 +207,7 @@ module OpenTox md5 = Digest::MD5.hexdigest(File.read(file)) # use hash to identify identical files dataset = self.find_by(:md5 => md5) if dataset - $logger.debug "Skipping import of #{file}, it is already in the database (id: #{dataset.id})." + $logger.debug "Found #{file} in the database (id: #{dataset.id}, md5: #{dataset.md5}), skipping import." else $logger.debug "Parsing #{file}." table = nil @@ -234,10 +236,10 @@ module OpenTox if read_result value = line.chomp if value.numeric? - feature = NumericFeature.find_or_create_by(:name => feature_name) + feature = NumericFeature.find_or_create_by(:name => feature_name, :measured => true) value = value.to_f else - feature = NominalFeature.find_or_create_by(:name => feature_name) + feature = NominalFeature.find_or_create_by(:name => feature_name, :measured => true) end features[feature] = value read_result = false @@ -259,7 +261,7 @@ module OpenTox md5 = Digest::MD5.hexdigest(File.read(file)) # use hash to identify identical files dataset = self.find_by(:md5 => md5) if dataset - $logger.debug "Skipping import of #{file}, it is already in the database (id: #{dataset.id})." + $logger.debug "Found #{file} in the database (id: #{dataset.id}, md5: #{dataset.md5}), skipping import." else $logger.debug "Parsing #{file}." table = nil @@ -301,7 +303,7 @@ module OpenTox # guess feature types feature_names.each_with_index do |f,i| - metadata = {:name => f} + metadata = {:name => f, :measured => true} original_id ? j = i+2 : j = i+1 values = table.collect{|row| val=row[j].to_s.strip; val.blank? ? nil : val }.uniq.compact types = values.collect{|v| v.numeric? ? true : false}.uniq @@ -424,7 +426,7 @@ module OpenTox name = File.basename(file,".*") batch = self.find_by(:source => source, :name => name) if batch - $logger.debug "Skipping import of #{file}, it is already in the database (id: #{batch.id})." + $logger.debug "Found #{file} in the database (id: #{dataset.id}, md5: #{dataset.md5}), skipping import." else $logger.debug "Parsing #{file}." # check delimiter diff --git a/lib/model.rb b/lib/model.rb index 0ed70f2..8901a2c 100644 --- a/lib/model.rb +++ b/lib/model.rb @@ -37,7 +37,7 @@ module OpenTox # @return [OpenTox::Model::Lazar] def self.create prediction_feature:nil, training_dataset:, algorithms:{} bad_request_error "Please provide a prediction_feature and/or a training_dataset." unless prediction_feature or training_dataset - prediction_feature = training_dataset.features.first unless prediction_feature + prediction_feature = training_dataset.features.select{|f| f.measured}.first unless prediction_feature # TODO: prediction_feature without training_dataset: use all available data # guess model type @@ -199,6 +199,8 @@ module OpenTox # @return [Hash] def predict_substance substance, threshold = self.algorithms[:similarity][:min] + p substance.smiles + t = Time.now @independent_variables = Marshal.load $gridfs.find_one(_id: self.independent_variables_id).data case algorithms[:similarity][:method] when /tanimoto/ # binary features @@ -284,6 +286,9 @@ module OpenTox else # try again with a lower threshold predict_substance substance, 0.2 end + p prediction + p Time.now - t + prediction end # Predict a substance (compound or nanoparticle), an array of substances or a dataset -- cgit v1.2.3