From bdc6b5b40437896384561d74a510560e9e592364 Mon Sep 17 00:00:00 2001 From: "helma@in-silico.ch" Date: Tue, 9 Oct 2018 18:20:27 +0200 Subject: tentative random forest classification: hangs unpredictably during caret model generation/optimization for some (inorganic?) compounds. --- lib/caret-classification.rb | 107 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 lib/caret-classification.rb (limited to 'lib/caret-classification.rb') diff --git a/lib/caret-classification.rb b/lib/caret-classification.rb new file mode 100644 index 0000000..fefe6b6 --- /dev/null +++ b/lib/caret-classification.rb @@ -0,0 +1,107 @@ +module OpenTox + module Algorithm + + # Ruby interface for the R caret package + # Caret model list: https://topepo.github.io/caret/modelList.html + class Caret + + # Create a local R caret model and make a prediction + # @param [Array] dependent_variables + # @param [Array>] independent_variables + # @param [Array] weights + # @param [String] Caret method + # @param [Array] query_variables + # @return [Hash] + def self.create_model_and_predict dependent_variables:, independent_variables:, weights:, method:, query_variables: + remove = [] + # remove independent_variables with single values + independent_variables.each_with_index { |values,i| remove << i if values.uniq.size == 1} + remove.sort.reverse.each do |i| + independent_variables.delete_at i + query_variables.delete_at i + end + if independent_variables.flatten.uniq == ["NA"] or independent_variables.flatten.uniq == [] + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "No variables for classification model. Using weighted average of similar substances." + elsif dependent_variables.uniq.size == 1 + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "All neighbors have the same measured activity. Cannot create random forest model, using weighted average of similar substances." + elsif dependent_variables.size < 3 + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "Insufficient number of neighbors (#{dependent_variables.size}) for classification model. Using weighted average of similar substances." + else + dependent_variables.collect!{|v| to_r(v)} + independent_variables.each_with_index do |c,i| + c.each_with_index do |v,j| + independent_variables[i][j] = to_r(v) + end + end +# query_variables.collect!{|v| to_r(v)} + begin + R.assign "weights", weights + #r_data_frame = "data.frame(#{([dependent_variables.collect{|v| to_r(v)}]+independent_variables).collect{|r| "c(#{r.collect{|v| to_r(v)}.join(',')})"}.join(', ')})" + r_data_frame = "data.frame(#{([dependent_variables]+independent_variables).collect{|r| "c(#{r.join(',')})"}.join(', ')})" + #p r_data_frame + R.eval "data <- #{r_data_frame}" + R.assign "features", (0..independent_variables.size-1).to_a + R.eval "names(data) <- append(c('activities'),features)" # + p "train" + R.eval "model <- train(activities ~ ., data = data, method = '#{method}', na.action = na.pass, allowParallel=TRUE)" + p "done" + rescue => e + $logger.debug "R caret model creation error for: #{e.message}" + $logger.debug dependent_variables + $logger.debug independent_variables + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "R caret model creation error. Using weighted average of similar substances." + return prediction + end + begin + R.eval "query <- data.frame(rbind(c(#{query_variables.collect{|v| to_r(v)}.join ','})))" + R.eval "names(query) <- features" + R.eval "prediction <- predict(model,query, type=\"prob\")" + names = R.eval("names(prediction)").to_ruby + probs = R.eval("prediction").to_ruby + probabilities = {} + names.each_with_index { |n,i| probabilities[n] = probs[i] } + value = probabilities.sort_by{|n,p| -p }[0][0] + prediction = { + :value => value, + :probabilities => probabilities, + :warnings => [], + } + rescue => e + $logger.debug "R caret prediction error for: #{e.inspect}" + $logger.debug self.inspect + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "R caret prediction error. Using weighted average of similar substances" + return prediction + end + if prediction.nil? or prediction[:value].nil? + prediction = Algorithm::Classification::weighted_majority_vote dependent_variables:dependent_variables, weights:weights + prediction[:warnings] << "Empty R caret prediction. Using weighted average of similar substances." + end + end + prediction + + end + + # Call caret methods dynamically, e.g. Caret.pls + def self.method_missing(sym, *args, &block) + args.first[:method] = sym.to_s + self.create_model_and_predict args.first + end + + # Convert Ruby values to R values + def self.to_r v + return "F" if v == false + return "T" if v == true + return nil if v.is_a? Float and v.nan? + return "\"#{v}\"" if v.is_a? String + v + end + + end + end +end + -- cgit v1.2.3