summaryrefslogtreecommitdiff
path: root/lib/caret.rb
diff options
context:
space:
mode:
authorChristoph Helma <helma@in-silico.ch>2016-10-12 21:32:27 +0200
committerChristoph Helma <helma@in-silico.ch>2016-10-12 21:32:27 +0200
commitdc4ab1f4e64d738d6c0b70f0b690a2359685080f (patch)
tree054ae887bf978b519a95dce5dbead59bbc67a2bb /lib/caret.rb
parent1ec5ad2c67f270287499980a794e51bc9a6bbd84 (diff)
physchem regression, correlation_filter for fingerprints
Diffstat (limited to 'lib/caret.rb')
-rw-r--r--lib/caret.rb184
1 files changed, 57 insertions, 127 deletions
diff --git a/lib/caret.rb b/lib/caret.rb
index b999b06..59e02da 100644
--- a/lib/caret.rb
+++ b/lib/caret.rb
@@ -5,33 +5,56 @@ module OpenTox
# TODO classification
# model list: https://topepo.github.io/caret/modelList.html
- attr_accessor :descriptors, :neighbors, :method, :relevant_features, :data_frame, :feature_names, :weights, :query_features
-
- def initialize descriptors:, neighbors:, method:, relevant_features:
- @descriptors = descriptors
- @neighbors = neighbors
- @method = method
- @relevant_features = relevant_features
- end
-
- def self.regression descriptors:, neighbors:, method:, relevant_features:nil
-
- caret = new(descriptors:descriptors, neighbors:neighbors, method:method, relevant_features:relevant_features)
- # collect training data for R
- if descriptors.is_a? Array
- caret.fingerprint2R
- elsif descriptors.is_a? Hash
- caret.properties2R
- else
- bad_request_error "Descriptors should be a fingerprint (Array) or properties (Hash). Cannot handle '#{descriptors.class}'."
- end
- if caret.feature_names.empty? or caret.data_frame.flatten.uniq == ["NA"]
- prediction = Algorithm::Regression::weighted_average(descriptors: @descriptors, neighbors: neighbors)
+ def self.create_model_and_predict dependent_variables:, independent_variables:, weights:, method:, query_variables:
+ if independent_variables.flatten.uniq == ["NA"]
+ prediction = Algorithm::Regression::weighted_average dependent_variables:dependent_variables, weights:weights
prediction[:warning] = "No variables for regression model. Using weighted average of similar substances."
else
- prediction = caret.r_model_prediction
+ dependent_variables.each_with_index do |v,i|
+ dependent_variables[i] = to_r(v)
+ end
+ independent_variables.each_with_index do |c,i|
+ c.each_with_index do |v,j|
+ independent_variables[i][j] = to_r(v)
+ end
+ end
+ query_variables.each_with_index do |v,i|
+ query_variables[i] = to_r(v)
+ end
+ begin
+ R.assign "weights", weights
+ r_data_frame = "data.frame(#{([dependent_variables]+independent_variables).collect{|r| "c(#{r.join(',')})"}.join(', ')})"
+ R.eval "data <- #{r_data_frame}"
+ R.assign "features", (0..independent_variables.size-1).to_a
+ R.eval "names(data) <- append(c('activities'),features)" #
+ R.eval "model <- train(activities ~ ., data = data, method = '#{method}', na.action = na.pass, allowParallel=TRUE)"
+ rescue => e
+ $logger.debug "R caret model creation error for:"
+ $logger.debug JSON.pretty_generate(dependent_variables)
+ $logger.debug JSON.pretty_generate(independent_variables)
+ return {:value => nil, :warning => "R caret model cration error."}
+ end
+ begin
+ R.eval "query <- data.frame(rbind(c(#{query_variables.join ','})))"
+ R.eval "names(query) <- features"
+ R.eval "prediction <- predict(model,query)"
+ value = R.eval("prediction").to_f
+ rmse = R.eval("getTrainPerf(model)$TrainRMSE").to_f
+ r_squared = R.eval("getTrainPerf(model)$TrainRsquared").to_f
+ prediction_interval = value-1.96*rmse, value+1.96*rmse
+ prediction = {
+ :value => value,
+ :rmse => rmse,
+ :r_squared => r_squared,
+ :prediction_interval => prediction_interval
+ }
+ rescue => e
+ $logger.debug "R caret prediction error for:"
+ $logger.debug self.inspect
+ return nil
+ end
if prediction.nil? or prediction[:value].nil?
- prediction = Algorithm::Regression::weighted_average(descriptors: @descriptors, neighbors: neighbors)
+ prediction = Algorithm::Regression::weighted_average dependent_variables:dependent_variables, weights:weights
prediction[:warning] = "Could not create local caret model. Using weighted average of similar substances."
end
end
@@ -39,111 +62,18 @@ module OpenTox
end
- def fingerprint2R
-
- values = []
- features = {}
- @weights = []
- descriptor_ids = neighbors.collect{|n| n["descriptors"]}.flatten.uniq.sort
-
- neighbors.each do |n|
- activities = n["measurements"]
- activities.each do |act|
- values << act
- @weights << n["similarity"]
- descriptor_ids.each do |id|
- features[id] ||= []
- features[id] << n["descriptors"].include?(id)
- end
- end if activities
- end
-
- @feature_names = []
- @data_frame = [values]
-
- features.each do |k,v|
- unless v.uniq.size == 1
- @data_frame << v.collect{|m| m ? "T" : "F"}
- @feature_names << k
- end
- end
- @query_features = @feature_names.collect{|f| descriptors.include?(f) ? "T" : "F"}
-
+ # call caret methods dynamically, e.g. Caret.pls
+ def self.method_missing(sym, *args, &block)
+ args.first[:method] = sym.to_s
+ self.create_model_and_predict args.first
end
-
- def properties2R
-
- @weights = []
- @feature_names = []
- @query_features = []
-
- # keep only descriptors with values
- @relevant_features.keys.each_with_index do |f,i|
- if @descriptors[f]
- @feature_names << f
- @query_features << @descriptors[f].median
- else
- neighbors.each do |n|
- n["descriptors"].delete_at i
- end
- end
- end
-
- measurements = neighbors.collect{|n| n["measurements"]}.flatten
- # initialize data frame with 'NA' defaults
- @data_frame = Array.new(@feature_names.size+1){Array.new(measurements.size,"NA") }
-
- i = 0
- # parse neighbor activities and descriptors
- neighbors.each do |n|
- activities = n["measurements"]
- activities.each do |act| # multiple measurements are treated as separate instances
- unless n["descriptors"].include?(nil)
- data_frame[0][i] = act
- @weights << n["similarity"]
- n["descriptors"].each_with_index do |d,j|
- @data_frame[j+1][i] = d
- end
- i += 1
- end
- end if activities # ignore neighbors without measurements
- end
-
- end
-
- def r_model_prediction
- begin
- R.assign "weights", @weights
- r_data_frame = "data.frame(#{@data_frame.collect{|r| "c(#{r.join(',')})"}.join(', ')})"
- R.eval "data <- #{r_data_frame}"
- R.assign "features", @feature_names
- R.eval "names(data) <- append(c('activities'),features)" #
- R.eval "model <- train(activities ~ ., data = data, method = '#{method}', na.action = na.pass, allowParallel=TRUE)"
- rescue => e
- $logger.debug "R caret model creation error for:"
- $logger.debug JSON.pretty_generate(self.inspect)
- return nil
- end
- begin
- R.eval "query <- data.frame(rbind(c(#{@query_features.join ','})))"
- R.eval "names(query) <- features"
- R.eval "prediction <- predict(model,query)"
- value = R.eval("prediction").to_f
- rmse = R.eval("getTrainPerf(model)$TrainRMSE").to_f
- r_squared = R.eval("getTrainPerf(model)$TrainRsquared").to_f
- prediction_interval = value-1.96*rmse, value+1.96*rmse
- {
- :value => value,
- :rmse => rmse,
- :r_squared => r_squared,
- :prediction_interval => prediction_interval
- }
- rescue => e
- $logger.debug "R caret prediction error for:"
- $logger.debug self.inspect
- return nil
- end
+ def self.to_r v
+ return "F" if v == false
+ return "T" if v == true
+ return "NA" if v.nil?
+ return "NA" if v.is_a? Float and v.nan?
+ v
end
end