summaryrefslogtreecommitdiff
path: root/lib/caret.rb
diff options
context:
space:
mode:
authorChristoph Helma <helma@in-silico.ch>2016-10-07 10:25:58 +0200
committerChristoph Helma <helma@in-silico.ch>2016-10-07 10:25:58 +0200
commit91787edb3682900bc5a2feeca66e5142f387fcc6 (patch)
treefcb189bb8e84c9b3727b29f0846a5f820b596aee /lib/caret.rb
parent4348eec89033e6677c9f628646fc67bd03c73fe6 (diff)
unified interface for prediction algorithms
Diffstat (limited to 'lib/caret.rb')
-rw-r--r--lib/caret.rb152
1 files changed, 152 insertions, 0 deletions
diff --git a/lib/caret.rb b/lib/caret.rb
new file mode 100644
index 0000000..b999b06
--- /dev/null
+++ b/lib/caret.rb
@@ -0,0 +1,152 @@
+module OpenTox
+ module Algorithm
+
+ class Caret
+ # TODO classification
+ # model list: https://topepo.github.io/caret/modelList.html
+
+ attr_accessor :descriptors, :neighbors, :method, :relevant_features, :data_frame, :feature_names, :weights, :query_features
+
+ def initialize descriptors:, neighbors:, method:, relevant_features:
+ @descriptors = descriptors
+ @neighbors = neighbors
+ @method = method
+ @relevant_features = relevant_features
+ end
+
+ def self.regression descriptors:, neighbors:, method:, relevant_features:nil
+
+ caret = new(descriptors:descriptors, neighbors:neighbors, method:method, relevant_features:relevant_features)
+ # collect training data for R
+ if descriptors.is_a? Array
+ caret.fingerprint2R
+ elsif descriptors.is_a? Hash
+ caret.properties2R
+ else
+ bad_request_error "Descriptors should be a fingerprint (Array) or properties (Hash). Cannot handle '#{descriptors.class}'."
+ end
+ if caret.feature_names.empty? or caret.data_frame.flatten.uniq == ["NA"]
+ prediction = Algorithm::Regression::weighted_average(descriptors: @descriptors, neighbors: neighbors)
+ prediction[:warning] = "No variables for regression model. Using weighted average of similar substances."
+ else
+ prediction = caret.r_model_prediction
+ if prediction.nil? or prediction[:value].nil?
+ prediction = Algorithm::Regression::weighted_average(descriptors: @descriptors, neighbors: neighbors)
+ prediction[:warning] = "Could not create local caret model. Using weighted average of similar substances."
+ end
+ end
+ prediction
+
+ end
+
+ def fingerprint2R
+
+ values = []
+ features = {}
+ @weights = []
+ descriptor_ids = neighbors.collect{|n| n["descriptors"]}.flatten.uniq.sort
+
+ neighbors.each do |n|
+ activities = n["measurements"]
+ activities.each do |act|
+ values << act
+ @weights << n["similarity"]
+ descriptor_ids.each do |id|
+ features[id] ||= []
+ features[id] << n["descriptors"].include?(id)
+ end
+ end if activities
+ end
+
+ @feature_names = []
+ @data_frame = [values]
+
+ features.each do |k,v|
+ unless v.uniq.size == 1
+ @data_frame << v.collect{|m| m ? "T" : "F"}
+ @feature_names << k
+ end
+ end
+ @query_features = @feature_names.collect{|f| descriptors.include?(f) ? "T" : "F"}
+
+ end
+
+
+ def properties2R
+
+ @weights = []
+ @feature_names = []
+ @query_features = []
+
+ # keep only descriptors with values
+ @relevant_features.keys.each_with_index do |f,i|
+ if @descriptors[f]
+ @feature_names << f
+ @query_features << @descriptors[f].median
+ else
+ neighbors.each do |n|
+ n["descriptors"].delete_at i
+ end
+ end
+ end
+
+ measurements = neighbors.collect{|n| n["measurements"]}.flatten
+ # initialize data frame with 'NA' defaults
+ @data_frame = Array.new(@feature_names.size+1){Array.new(measurements.size,"NA") }
+
+ i = 0
+ # parse neighbor activities and descriptors
+ neighbors.each do |n|
+ activities = n["measurements"]
+ activities.each do |act| # multiple measurements are treated as separate instances
+ unless n["descriptors"].include?(nil)
+ data_frame[0][i] = act
+ @weights << n["similarity"]
+ n["descriptors"].each_with_index do |d,j|
+ @data_frame[j+1][i] = d
+ end
+ i += 1
+ end
+ end if activities # ignore neighbors without measurements
+ end
+
+ end
+
+ def r_model_prediction
+ begin
+ R.assign "weights", @weights
+ r_data_frame = "data.frame(#{@data_frame.collect{|r| "c(#{r.join(',')})"}.join(', ')})"
+ R.eval "data <- #{r_data_frame}"
+ R.assign "features", @feature_names
+ R.eval "names(data) <- append(c('activities'),features)" #
+ R.eval "model <- train(activities ~ ., data = data, method = '#{method}', na.action = na.pass, allowParallel=TRUE)"
+ rescue => e
+ $logger.debug "R caret model creation error for:"
+ $logger.debug JSON.pretty_generate(self.inspect)
+ return nil
+ end
+ begin
+ R.eval "query <- data.frame(rbind(c(#{@query_features.join ','})))"
+ R.eval "names(query) <- features"
+ R.eval "prediction <- predict(model,query)"
+ value = R.eval("prediction").to_f
+ rmse = R.eval("getTrainPerf(model)$TrainRMSE").to_f
+ r_squared = R.eval("getTrainPerf(model)$TrainRsquared").to_f
+ prediction_interval = value-1.96*rmse, value+1.96*rmse
+ {
+ :value => value,
+ :rmse => rmse,
+ :r_squared => r_squared,
+ :prediction_interval => prediction_interval
+ }
+ rescue => e
+ $logger.debug "R caret prediction error for:"
+ $logger.debug self.inspect
+ return nil
+ end
+ end
+
+ end
+ end
+end
+