summaryrefslogtreecommitdiff
path: root/lib/caret.rb
blob: b999b06782021b6a6bc5bb6281cbb5220fca94ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
module OpenTox
  module Algorithm

    # Regression models fitted with the R caret package (`train`) on the
    # neighbors of a query substance.
    #
    # Training data is collected from +neighbors+, an Array of Hashes that are
    # read with the keys "descriptors", "measurements" and "similarity".
    # If no usable variables remain, or the R model cannot be created,
    # {Caret.regression} falls back to Algorithm::Regression.weighted_average.
    #
    # NOTE(review): the R connection constant `R`, the global `$logger`,
    # `bad_request_error` and `Algorithm::Regression` are defined outside this
    # file and are assumed to be loaded by the surrounding project.
    class Caret
      # TODO classification
      # model list: https://topepo.github.io/caret/modelList.html

      attr_accessor :descriptors, :neighbors, :method, :relevant_features, :data_frame, :feature_names, :weights, :query_features

      # @param descriptors [Array, Hash] query descriptors: an Array is treated
      #   as a fingerprint, a Hash as property values keyed by feature
      # @param neighbors [Array<Hash>] neighbor records (see class comment)
      # @param method [String] caret model name, interpolated into the R
      #   `train(..., method = '...')` call
      # @param relevant_features [Hash, nil] only its keys are read (by
      #   #properties2R); their order must match the positions of each
      #   neighbor's "descriptors" array
      def initialize descriptors:, neighbors:, method:, relevant_features:
        @descriptors = descriptors
        @neighbors = neighbors
        @method = method
        @relevant_features = relevant_features
      end

      # Predict a value for the query described by +descriptors+.
      #
      # @return [Hash] the result of #r_model_prediction, or the result of
      #   Algorithm::Regression.weighted_average with a :warning entry added
      #   when no model could be built
      # Calls bad_request_error when +descriptors+ is neither Array nor Hash.
      def self.regression descriptors:, neighbors:, method:, relevant_features:nil

        caret = new(descriptors: descriptors, neighbors: neighbors, method: method, relevant_features: relevant_features)
        # collect training data for R
        case descriptors
        when Array
          caret.fingerprint2R
        when Hash
          caret.properties2R
        else
          bad_request_error "Descriptors should be a fingerprint (Array) or properties (Hash). Cannot handle '#{descriptors.class}'."
        end

        warning = nil
        if caret.feature_names.empty? || caret.data_frame.flatten.uniq == ["NA"]
          warning = "No variables for regression model. Using weighted average of similar substances."
        else
          prediction = caret.r_model_prediction
          if prediction.nil? || prediction[:value].nil?
            warning = "Could not create local caret model. Using weighted average of similar substances."
          end
        end

        if warning
          # Pass the keyword argument: inside a class method `@descriptors`
          # would be the (never assigned) class-level ivar, i.e. always nil.
          prediction = Algorithm::Regression.weighted_average(descriptors: descriptors, neighbors: neighbors)
          prediction[:warning] = warning
        end
        prediction

      end

      # Build the training data for fingerprint (Array) descriptors.
      #
      # Sets @weights, @feature_names, @data_frame and @query_features.
      # @data_frame[0] holds the measurements; each following entry is one
      # feature column of "T"/"F" strings. Every measurement of a neighbor
      # becomes its own instance; neighbors without measurements are ignored.
      #
      # @return [Array<String>] the "T"/"F" query feature vector
      def fingerprint2R

        values = []
        features = {}
        @weights = []
        descriptor_ids = neighbors.collect { |n| n["descriptors"] }.flatten.uniq.sort

        neighbors.each do |n|
          activities = n["measurements"]
          next unless activities

          activities.each do |act|
            values << act
            @weights << n["similarity"]
            descriptor_ids.each do |id|
              (features[id] ||= []) << n["descriptors"].include?(id)
            end
          end
        end

        @feature_names = []
        @data_frame = [values]

        features.each do |id, presence|
          # a feature with the same value for every instance carries no information
          next if presence.uniq.size == 1

          @data_frame << presence.collect { |present| present ? "T" : "F" }
          @feature_names << id
        end
        @query_features = @feature_names.collect { |f| descriptors.include?(f) ? "T" : "F" }

      end


      # Build the training data for property (Hash) descriptors.
      #
      # Sets @weights, @feature_names, @query_features and @data_frame.
      # Only features for which the query has a value are used; the query
      # value of a feature is the `median` of @descriptors[feature].
      # @data_frame is pre-filled with "NA" (one column per measurement of all
      # neighbors); neighbors with a nil value in any used feature are skipped,
      # which leaves trailing "NA" columns.
      #
      # Requires @relevant_features to be a Hash. The neighbors' "descriptors"
      # arrays are read but not modified.
      def properties2R

        @weights = []
        @feature_names = []
        @query_features = []

        # keep only descriptors with values; remember their positions in the
        # neighbors' descriptor arrays (aligned with @relevant_features.keys)
        kept_positions = []
        @relevant_features.keys.each_with_index do |f, position|
          next unless @descriptors[f]

          @feature_names << f
          @query_features << @descriptors[f].median
          kept_positions << position
        end

        measurements = neighbors.collect { |n| n["measurements"] }.flatten
        # initialize data frame with 'NA' defaults
        @data_frame = Array.new(@feature_names.size + 1) { Array.new(measurements.size, "NA") }

        i = 0
        # parse neighbor activities and descriptors
        neighbors.each do |n|
          activities = n["measurements"]
          next unless activities # ignore neighbors without measurements

          # Project onto the kept columns without touching n["descriptors"]:
          # deleting by original index while iterating shifts later indices.
          row = n["descriptors"].values_at(*kept_positions)
          next if row.include?(nil)

          activities.each do |act| # multiple measurements are treated as separate instances
            @data_frame[0][i] = act
            @weights << n["similarity"]
            row.each_with_index do |d, j|
              @data_frame[j + 1][i] = d
            end
            i += 1
          end
        end

      end

      # Train the caret model in R and predict the query.
      #
      # Expects @weights, @data_frame, @feature_names and @query_features to
      # have been set by #fingerprint2R or #properties2R.
      #
      # @return [Hash, nil] :value, :rmse, :r_squared and :prediction_interval
      #   (value -/+ 1.96 * rmse), or nil if model creation or prediction
      #   raised a StandardError (details are logged at debug level)
      def r_model_prediction
        begin
          # `weights` is made available in R; the train call below does not reference it
          R.assign "weights", @weights
          columns = @data_frame.collect { |column| "c(#{column.join(',')})" }
          R.eval "data <- data.frame(#{columns.join(', ')})"
          R.assign "features", @feature_names
          R.eval "names(data) <- append(c('activities'),features)" #
          R.eval "model <- train(activities ~ ., data = data, method = '#{method}', na.action = na.pass, allowParallel=TRUE)"
        rescue => e
          $logger.debug "R caret model creation error for:"
          # inspect is already a String; log it directly so the handler
          # itself cannot raise and defeat the caller's fallback
          $logger.debug inspect
          $logger.debug "#{e.class}: #{e.message}"
          return nil
        end
        begin
          R.eval "query <- data.frame(rbind(c(#{@query_features.join ','})))"
          R.eval "names(query) <- features"
          R.eval "prediction <- predict(model,query)"
          value = R.eval("prediction").to_f
          rmse = R.eval("getTrainPerf(model)$TrainRMSE").to_f
          r_squared = R.eval("getTrainPerf(model)$TrainRsquared").to_f
          # presumably an approximate 95% interval assuming normal errors — TODO confirm
          prediction_interval = [value - 1.96 * rmse, value + 1.96 * rmse]
          {
            :value => value,
            :rmse => rmse,
            :r_squared => r_squared,
            :prediction_interval => prediction_interval
          }
        rescue => e
          $logger.debug "R caret prediction error for:"
          $logger.debug inspect
          $logger.debug "#{e.class}: #{e.message}"
          return nil
        end
      end

    end
  end
end