f3e3af81151ed6cfb32fd68ea52d2d118c7658b0
[lazar] / lib / validation-statistics.rb
1 module OpenTox
2   module Validation
3     # Statistical evaluation of classification validations
4     module ClassificationStatistics
5
6       # Get statistics
7       # @return [Hash]
8       def statistics 
9         self.accept_values = model.prediction_feature.accept_values
10         self.confusion_matrix = {:all => Array.new(accept_values.size){Array.new(accept_values.size,0)}, :without_warnings => Array.new(accept_values.size){Array.new(accept_values.size,0)}}
11         self.weighted_confusion_matrix = {:all => Array.new(accept_values.size){Array.new(accept_values.size,0)}, :without_warnings => Array.new(accept_values.size){Array.new(accept_values.size,0)}}
12         self.nr_predictions = {:all => 0,:without_warnings => 0}
13         predictions.each do |cid,pred|
14           # TODO
15           # use predictions without probabilities (single neighbor)??
16           # use measured majority class??
17           if pred[:measurements].uniq.size == 1 and pred[:probabilities]
18             m = pred[:measurements].first
19             if pred[:value] == m
20               if pred[:value] == accept_values[0]
21                 confusion_matrix[:all][0][0] += 1
22                 weighted_confusion_matrix[:all][0][0] += pred[:probabilities][pred[:value]]
23                 self.nr_predictions[:all] += 1
24                 if pred[:warnings].empty?
25                   confusion_matrix[:without_warnings][0][0] += 1
26                   weighted_confusion_matrix[:without_warnings][0][0] += pred[:probabilities][pred[:value]]
27                   self.nr_predictions[:without_warnings] += 1
28                 end
29               elsif pred[:value] == accept_values[1]
30                 confusion_matrix[:all][1][1] += 1
31                 weighted_confusion_matrix[:all][1][1] += pred[:probabilities][pred[:value]]
32                 self.nr_predictions[:all] += 1
33                 if pred[:warnings].empty?
34                   confusion_matrix[:without_warnings][1][1] += 1
35                   weighted_confusion_matrix[:without_warnings][1][1] += pred[:probabilities][pred[:value]]
36                   self.nr_predictions[:without_warnings] += 1
37                 end
38               end
39             elsif pred[:value] != m
40               if pred[:value] == accept_values[0]
41                 confusion_matrix[:all][0][1] += 1
42                 weighted_confusion_matrix[:all][0][1] += pred[:probabilities][pred[:value]]
43                 self.nr_predictions[:all] += 1
44                 if pred[:warnings].empty?
45                   confusion_matrix[:without_warnings][0][1] += 1
46                   weighted_confusion_matrix[:without_warnings][0][1] += pred[:probabilities][pred[:value]]
47                   self.nr_predictions[:without_warnings] += 1
48                 end
49               elsif pred[:value] == accept_values[1]
50                 confusion_matrix[:all][1][0] += 1
51                 weighted_confusion_matrix[:all][1][0] += pred[:probabilities][pred[:value]]
52                 self.nr_predictions[:all] += 1
53                 if pred[:warnings].empty?
54                   confusion_matrix[:without_warnings][1][0] += 1
55                   weighted_confusion_matrix[:without_warnings][1][0] += pred[:probabilities][pred[:value]]
56                   self.nr_predictions[:without_warnings] += 1
57                 end
58               end
59             end
60           end
61         end
62         self.true_rate = {:all => {}, :without_warnings => {}}
63         self.predictivity = {:all => {}, :without_warnings => {}}
64         accept_values.each_with_index do |v,i|
65           [:all,:without_warnings].each do |a|
66             self.true_rate[a][v] = confusion_matrix[a][i][i]/confusion_matrix[a][i].reduce(:+).to_f
67             self.predictivity[a][v] = confusion_matrix[a][i][i]/confusion_matrix[a].collect{|n| n[i]}.reduce(:+).to_f
68           end
69         end
70         confidence_sum = {:all => 0, :without_warnings => 0}
71         [:all,:without_warnings].each do |a|
72           weighted_confusion_matrix[a].each do |r|
73             r.each do |c|
74               confidence_sum[a] += c
75             end
76           end
77         end
78         self.accuracy = {}
79         self.weighted_accuracy = {}
80         [:all,:without_warnings].each do |a|
81           self.accuracy[a] = (confusion_matrix[a][0][0]+confusion_matrix[a][1][1])/nr_predictions[a].to_f
82           self.weighted_accuracy[a] = (weighted_confusion_matrix[a][0][0]+weighted_confusion_matrix[a][1][1])/confidence_sum[a].to_f
83         end
84         $logger.debug "Accuracy #{accuracy}"
85         $logger.debug "Nr Predictions #{nr_predictions}"
86         save
87         {
88           :accept_values => accept_values,
89           :confusion_matrix => confusion_matrix,
90           :weighted_confusion_matrix => weighted_confusion_matrix,
91           :accuracy => accuracy,
92           :weighted_accuracy => weighted_accuracy,
93           :true_rate => self.true_rate,
94           :predictivity => self.predictivity,
95           :nr_predictions => nr_predictions,
96         }
97       end
98
99       # Plot accuracy vs prediction probability
100       # @param [String,nil] format
101       # @return [Blob]
102       def probability_plot format: "pdf"
103         #unless probability_plot_id
104
105           #tmpdir = File.join(ENV["HOME"], "tmp")
106           tmpdir = "/tmp"
107           #p tmpdir
108           FileUtils.mkdir_p tmpdir
109           tmpfile = File.join(tmpdir,"#{id.to_s}_probability.#{format}")
110           accuracies = []
111           probabilities = []
112           correct_predictions = 0
113           incorrect_predictions = 0
114           pp = []
115           predictions.values.select{|p| p["probabilities"]}.compact.each do |p|
116             p["measurements"].each do |m|
117               pp << [ p["probabilities"][p["value"]], p["value"] == m ]
118             end
119           end
120           pp.sort_by!{|p| 1-p.first}
121           pp.each do |p|
122             p[1] ? correct_predictions += 1 : incorrect_predictions += 1
123             accuracies << correct_predictions/(correct_predictions+incorrect_predictions).to_f
124             probabilities << p[0]
125           end
126           R.assign "accuracy", accuracies
127           R.assign "probability", probabilities
128           R.eval "image = qplot(probability,accuracy)+ylab('Accumulated accuracy')+xlab('Prediction probability')+ylim(c(0,1))+scale_x_reverse()+geom_line()"
129           R.eval "ggsave(file='#{tmpfile}', plot=image)"
130           file = Mongo::Grid::File.new(File.read(tmpfile), :filename => "#{self.id.to_s}_probability_plot.svg")
131           plot_id = $gridfs.insert_one(file)
132           update(:probability_plot_id => plot_id)
133         #end
134         $gridfs.find_one(_id: probability_plot_id).data
135       end
136     end
137
138     # Statistical evaluation of regression validations
139     module RegressionStatistics
140
141       # Get statistics
142       # @return [Hash]
143       def statistics
144         self.warnings = []
145         self.rmse = {:all =>0,:without_warnings => 0}
146         self.r_squared  = {:all =>0,:without_warnings => 0}
147         self.mae = {:all =>0,:without_warnings => 0}
148         self.within_prediction_interval = {:all =>0,:without_warnings => 0}
149         self.out_of_prediction_interval = {:all =>0,:without_warnings => 0}
150         x = {:all => [],:without_warnings => []}
151         y = {:all => [],:without_warnings => []}
152         self.nr_predictions = {:all =>0,:without_warnings => 0}
153         predictions.each do |cid,pred|
154           !if pred[:value] and pred[:measurements] and !pred[:measurements].empty?
155             self.nr_predictions[:all] +=1
156             x[:all] << pred[:measurements].median
157             y[:all] << pred[:value]
158             error = pred[:value]-pred[:measurements].median
159             self.rmse[:all] += error**2
160             self.mae[:all] += error.abs
161             if pred[:prediction_interval]
162               if pred[:measurements].median >= pred[:prediction_interval][0] and pred[:measurements].median <= pred[:prediction_interval][1]
163                 self.within_prediction_interval[:all] += 1
164               else
165                 self.out_of_prediction_interval[:all] += 1
166               end
167             end
168             if pred[:warnings].empty?
169               self.nr_predictions[:without_warnings] +=1
170               x[:without_warnings] << pred[:measurements].median
171               y[:without_warnings] << pred[:value]
172               error = pred[:value]-pred[:measurements].median
173               self.rmse[:without_warnings] += error**2
174               self.mae[:without_warnings] += error.abs
175               if pred[:prediction_interval]
176                 if pred[:measurements].median >= pred[:prediction_interval][0] and pred[:measurements].median <= pred[:prediction_interval][1]
177                   self.within_prediction_interval[:without_warnings] += 1
178                 else
179                   self.out_of_prediction_interval[:without_warnings] += 1
180                 end
181               end
182             end
183           else
184             trd_id = model.training_dataset_id
185             smiles = Compound.find(cid).smiles
186             self.warnings << "No training activities for #{smiles} in training dataset #{trd_id}."
187             $logger.debug "No training activities for #{smiles} in training dataset #{trd_id}."
188           end
189         end
190         [:all,:without_warnings].each do |a|
191           if x[a].size > 2
192             R.assign "measurement", x[a]
193             R.assign "prediction", y[a]
194             R.eval "r <- cor(measurement,prediction,use='pairwise')"
195             self.r_squared[a] = R.eval("r").to_ruby**2
196           else
197             self.r_squared[a] = 0
198           end
199           if self.nr_predictions[a] > 0
200             self.mae[a] = self.mae[a]/self.nr_predictions[a]
201             self.rmse[a] = Math.sqrt(self.rmse[a]/self.nr_predictions[a])
202           else
203             self.mae[a] = nil
204             self.rmse[a] = nil
205           end
206         end
207         $logger.debug "R^2 #{r_squared}"
208         $logger.debug "RMSE #{rmse}"
209         $logger.debug "MAE #{mae}"
210         $logger.debug "Nr predictions #{nr_predictions}"
211         $logger.debug "#{within_prediction_interval} measurements within prediction interval"
212         $logger.debug "#{warnings}"
213         save
214         {
215           :mae => mae,
216           :rmse => rmse,
217           :r_squared => r_squared,
218           :within_prediction_interval => self.within_prediction_interval,
219           :out_of_prediction_interval => out_of_prediction_interval,
220           :nr_predictions => nr_predictions,
221         }
222       end
223
224       # Plot predicted vs measured values
225       # @param [String,nil] format
226       # @return [Blob]
227       def correlation_plot format: "png"
228         unless correlation_plot_id
229           tmpfile = "/tmp/#{id.to_s}_correlation.#{format}"
230           x = []
231           y = []
232           feature = Feature.find(predictions.first.last["prediction_feature_id"])
233           predictions.each do |sid,p|
234             x << p["measurements"].median
235             y << p["value"]
236           end
237           R.assign "measurement", x
238           R.assign "prediction", y
239           R.eval "all = c(measurement,prediction)"
240           R.eval "range = c(min(all), max(all))"
241           if feature.name.match /Net cell association/ # ad hoc fix for awkward units
242             title = "log2(Net cell association [mL/ug(Mg)])"
243           else
244             title = feature.name
245             title += "-log10(#{feature.unit})" if feature.unit and !feature.unit.blank?
246           end
247           R.eval "image = qplot(prediction,measurement,main='#{title}',xlab='Prediction',ylab='Measurement',asp=1,xlim=range, ylim=range)"
248           R.eval "image = image + geom_abline(intercept=0, slope=1)"
249           R.eval "ggsave(file='#{tmpfile}', plot=image)"
250           file = Mongo::Grid::File.new(File.read(tmpfile), :filename => "#{id.to_s}_correlation_plot.#{format}")
251           plot_id = $gridfs.insert_one(file)
252           update(:correlation_plot_id => plot_id)
253         end
254         $gridfs.find_one(_id: correlation_plot_id).data
255       end
256
257       # Get predictions with measurements outside of the prediction interval
258       # @return [Hash]
259       def worst_predictions
260         worst_predictions = predictions.select do |sid,p|
261           p["prediction_interval"] and p["value"] and (p["measurements"].max < p["prediction_interval"][0] or p["measurements"].min > p["prediction_interval"][1])
262         end.compact.to_h
263         worst_predictions.each do |sid,p|
264           p["error"] = (p["value"] - p["measurements"].median).abs
265           if p["measurements"].max < p["prediction_interval"][0]
266             p["distance_prediction_interval"] = (p["measurements"].max - p["prediction_interval"][0]).abs
267           elsif p["measurements"].min > p["prediction_interval"][1]
268             p["distance_prediction_interval"] = (p["measurements"].min - p["prediction_interval"][1]).abs
269           end
270         end
271         worst_predictions.sort_by{|sid,p| p["distance_prediction_interval"] }.to_h
272       end
273     end
274   end
275 end