always render new correlation plot; keep same handling as for probability plot
[lazar] lib/validation-statistics.rb
module OpenTox
  module Validation
    # Statistical evaluation of classification validations
    module ClassificationStatistics

      # Get statistics
      # @return [Hash]
      def statistics
        self.accept_values = model.prediction_feature.accept_values
        # one confusion matrix and prediction counter per confidence bucket
        self.confusion_matrix = {:all => Array.new(accept_values.size){Array.new(accept_values.size,0)}, :confidence_high => Array.new(accept_values.size){Array.new(accept_values.size,0)}, :confidence_low => Array.new(accept_values.size){Array.new(accept_values.size,0)}}
        self.nr_predictions = {:all => 0, :confidence_high => 0, :confidence_low => 0}
        predictions.each do |cid,pred|
          # TODO: use measured majority class or all measurements?
          if pred[:measurements].uniq.size == 1 and pred[:probabilities]
            m = pred[:measurements].first
            if pred[:value] == m
              # correct prediction: increment the diagonal cell
              accept_values.each_with_index do |v,i|
                if pred[:value] == v
                  confusion_matrix[:all][i][i] += 1
                  self.nr_predictions[:all] += 1
                  if pred[:confidence].match(/Similar/i)
                    confusion_matrix[:confidence_high][i][i] += 1
                    self.nr_predictions[:confidence_high] += 1
                  elsif pred[:confidence].match(/Low/i)
                    confusion_matrix[:confidence_low][i][i] += 1
                    self.nr_predictions[:confidence_low] += 1
                  end
                end
              end
            elsif pred[:value] != m
              # misclassification: increment the off-diagonal cell
              # (the (i+1)%2 index assumes binary classification)
              accept_values.each_with_index do |v,i|
                if pred[:value] == v
                  confusion_matrix[:all][i][(i+1)%2] += 1
                  self.nr_predictions[:all] += 1
                  if pred[:confidence].match(/Similar/i)
                    confusion_matrix[:confidence_high][i][(i+1)%2] += 1
                    self.nr_predictions[:confidence_high] += 1
                  elsif pred[:confidence].match(/Low/i)
                    confusion_matrix[:confidence_low][i][(i+1)%2] += 1
                    self.nr_predictions[:confidence_low] += 1
                  end
                end
              end
            end
          end
        end

        self.true_rate = {:all => {}, :confidence_high => {}, :confidence_low => {}}
        self.predictivity = {:all => {}, :confidence_high => {}, :confidence_low => {}}
        accept_values.each_with_index do |v,i|
          [:all,:confidence_high,:confidence_low].each do |a|
            # true rate: diagonal cell divided by its row sum
            self.true_rate[a][v] = confusion_matrix[a][i][i]/confusion_matrix[a][i].reduce(:+).to_f
            # predictivity: diagonal cell divided by its column sum
            self.predictivity[a][v] = confusion_matrix[a][i][i]/confusion_matrix[a].collect{|n| n[i]}.reduce(:+).to_f
          end
        end
        self.accuracy = {}
        [:all,:confidence_high,:confidence_low].each do |a|
          # sum of the two diagonal cells over all predictions (binary classification)
          self.accuracy[a] = (confusion_matrix[a][0][0]+confusion_matrix[a][1][1])/nr_predictions[a].to_f
        end
        $logger.debug "Accuracy #{accuracy}"
        $logger.debug "Nr Predictions #{nr_predictions}"
        save
        {
          :accept_values => accept_values,
          :confusion_matrix => confusion_matrix,
          :accuracy => accuracy,
          :true_rate => self.true_rate,
          :predictivity => self.predictivity,
          :nr_predictions => nr_predictions,
        }
      end

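      # Worked example (illustration only, not part of the library): with
      # accept_values ["active","inactive"] and a hypothetical confusion
      # matrix [[40, 10], [5, 45]] the formulas above give
      #
      #   accuracy                 = (40 + 45) / 100.0    #=> 0.85
      #   true_rate["active"]      = 40 / (40 + 10).to_f  #=> 0.8   (row sum)
      #   true_rate["inactive"]    = 45 / (5 + 45).to_f   #=> 0.9   (row sum)
      #   predictivity["active"]   = 40 / (40 + 5).to_f   #=> ~0.89 (column sum)
      #   predictivity["inactive"] = 45 / (10 + 45).to_f  #=> ~0.82 (column sum)
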
      # Plot accuracy vs prediction probability
      # @param [String,nil] format
      # @return [Blob]
      def probability_plot format: "pdf"
        #unless probability_plot_id

          #tmpdir = File.join(ENV["HOME"], "tmp")
          tmpdir = "/tmp"
          FileUtils.mkdir_p tmpdir
          tmpfile = File.join(tmpdir,"#{id.to_s}_probability.#{format}")
          accuracies = []
          probabilities = []
          correct_predictions = 0
          incorrect_predictions = 0
          pp = []
          # collect [probability of the predicted class, prediction correct?] pairs
          predictions.values.select{|p| p["probabilities"]}.compact.each do |p|
            p["measurements"].each do |m|
              pp << [ p["probabilities"][p["value"]], p["value"] == m ]
            end
          end
          # accumulate accuracy from the most to the least probable prediction
          pp.sort_by!{|p| 1-p.first}
          pp.each do |p|
            p[1] ? correct_predictions += 1 : incorrect_predictions += 1
            accuracies << correct_predictions/(correct_predictions+incorrect_predictions).to_f
            probabilities << p[0]
          end
          R.assign "accuracy", accuracies
          R.assign "probability", probabilities
          R.eval "image = qplot(probability,accuracy)+ylab('Accumulated accuracy')+xlab('Prediction probability')+ylim(c(0,1))+scale_x_reverse()+geom_line()"
          R.eval "ggsave(file='#{tmpfile}', plot=image)"
          file = Mongo::Grid::File.new(File.read(tmpfile), :filename => "#{self.id.to_s}_probability_plot.#{format}")
          plot_id = $gridfs.insert_one(file)
          update(:probability_plot_id => plot_id)
        #end
        $gridfs.find_one(_id: probability_plot_id).data
      end
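
      # Example usage (sketch, illustration only): the returned blob can be
      # written straight to disk, assuming `validation` is a persisted
      # classification validation object.
      #
      #   File.binwrite("probability_plot.pdf", validation.probability_plot(format: "pdf"))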
    end

    # Statistical evaluation of regression validations
    module RegressionStatistics

      attr_accessor :x, :y

      # Get statistics
      # @return [Hash]
      def statistics
        self.warnings = []
        self.rmse = {:all => 0, :confidence_high => 0, :confidence_low => 0}
        self.r_squared = {:all => 0, :confidence_high => 0, :confidence_low => 0}
        self.mae = {:all => 0, :confidence_high => 0, :confidence_low => 0}
        self.within_prediction_interval = {:all => 0, :confidence_high => 0, :confidence_low => 0}
        self.out_of_prediction_interval = {:all => 0, :confidence_high => 0, :confidence_low => 0}
        @x = {:all => [], :confidence_high => [], :confidence_low => []}
        @y = {:all => [], :confidence_high => [], :confidence_low => []}
        self.nr_predictions = {:all => 0, :confidence_high => 0, :confidence_low => 0}
        predictions.each do |cid,pred|
          if pred[:value] and pred[:measurements] and !pred[:measurements].empty?
            insert_prediction pred, :all
            if pred[:confidence].match(/Similar/i)
              insert_prediction pred, :confidence_high
            elsif pred[:confidence].match(/Low/i)
              insert_prediction pred, :confidence_low
            end
          else
            trd_id = model.training_dataset_id
            smiles = Compound.find(cid).smiles
            self.warnings << "No training activities for #{smiles} in training dataset #{trd_id}."
            $logger.debug "No training activities for #{smiles} in training dataset #{trd_id}."
          end
        end
        [:all,:confidence_high,:confidence_low].each do |a|
          if @x[a].size > 2
            R.assign "measurement", @x[a]
            R.assign "prediction", @y[a]
            # squared Pearson correlation between measurements and predictions
            R.eval "r <- cor(measurement,prediction,use='pairwise')"
            self.r_squared[a] = R.eval("r").to_ruby**2
          else
            self.r_squared[a] = 0
          end
          if self.nr_predictions[a] > 0
            self.mae[a] = self.mae[a]/self.nr_predictions[a]
            self.rmse[a] = Math.sqrt(self.rmse[a]/self.nr_predictions[a])
          else
            self.mae[a] = nil
            self.rmse[a] = nil
          end
        end
        $logger.debug "R^2 #{r_squared}"
        $logger.debug "RMSE #{rmse}"
        $logger.debug "MAE #{mae}"
        $logger.debug "Nr predictions #{nr_predictions}"
        $logger.debug "#{within_prediction_interval} measurements within prediction interval"
        save
        {
          :mae => mae,
          :rmse => rmse,
          :r_squared => r_squared,
          :within_prediction_interval => self.within_prediction_interval,
          :out_of_prediction_interval => out_of_prediction_interval,
          :nr_predictions => nr_predictions,
        }
      end

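      # Worked example (illustration only): insert_prediction accumulates the
      # squared and absolute errors, so for hypothetical errors [1.0, -2.0, 0.5]
      # the aggregation above gives
      #
      #   mae  = (1.0 + 2.0 + 0.5) / 3             #=> ~1.17
      #   rmse = Math.sqrt((1.0 + 4.0 + 0.25) / 3) #=> ~1.32
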
      # Plot predicted vs measured values
      # @param [String,nil] format
      # @return [Blob]
      def correlation_plot format: "png"
        # the plot is always re-rendered (same handling as probability_plot)
        #unless correlation_plot_id
          #tmpfile = "/tmp/#{id.to_s}_correlation.#{format}"
          tmpdir = "/tmp"
          FileUtils.mkdir_p tmpdir
          tmpfile = File.join(tmpdir,"#{id.to_s}_correlation.#{format}")
          x = []
          y = []
          feature = Feature.find(predictions.first.last["prediction_feature_id"])
          predictions.each do |sid,p|
            x << p["measurements"].median
            y << p["value"]
          end
          R.assign "measurement", x
          R.assign "prediction", y
          R.eval "all = c(measurement,prediction)"
          R.eval "range = c(min(all), max(all))"
          if feature.name.match(/Net cell association/) # ad hoc fix for awkward units
            title = "log2(Net cell association [mL/ug(Mg)])"
          else
            title = feature.name
            title += "-log10(#{feature.unit})" if feature.unit and !feature.unit.blank?
          end
          R.eval "image = qplot(prediction,measurement,main='#{title}',xlab='Prediction',ylab='Measurement',asp=1,xlim=range, ylim=range)"
          R.eval "image = image + geom_abline(intercept=0, slope=1)"
          R.eval "ggsave(file='#{tmpfile}', plot=image)"
          file = Mongo::Grid::File.new(File.read(tmpfile), :filename => "#{id.to_s}_correlation_plot.#{format}")
          plot_id = $gridfs.insert_one(file)
          update(:correlation_plot_id => plot_id)
        #end
        $gridfs.find_one(_id: correlation_plot_id).data
      end

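      # Example usage (sketch, analogous to probability_plot above), assuming
      # `validation` is a persisted regression validation object:
      #
      #   File.binwrite("correlation_plot.png", validation.correlation_plot(format: "png"))
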
      # Get predictions with measurements outside of the prediction interval
      # @return [Hash]
      def worst_predictions
        worst_predictions = predictions.select do |sid,p|
          p["prediction_interval"] and p["value"] and (p["measurements"].max < p["prediction_interval"][0] or p["measurements"].min > p["prediction_interval"][1])
        end.compact.to_h
        worst_predictions.each do |sid,p|
          p["error"] = (p["value"] - p["measurements"].median).abs
          if p["measurements"].max < p["prediction_interval"][0]
            p["distance_prediction_interval"] = (p["measurements"].max - p["prediction_interval"][0]).abs
          elsif p["measurements"].min > p["prediction_interval"][1]
            p["distance_prediction_interval"] = (p["measurements"].min - p["prediction_interval"][1]).abs
          end
        end
        worst_predictions.sort_by{|sid,p| p["distance_prediction_interval"] }.to_h
      end
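
      # Example (sketch, illustration only): listing the worst predictions,
      # assuming `validation` is a persisted regression validation object.
      #
      #   validation.worst_predictions.each do |sid, p|
      #     puts "#{sid}: error #{p["error"]}, distance #{p["distance_prediction_interval"]}"
      #   end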

      private

      # accumulate running error sums (normalized to MAE/RMSE in #statistics)
      def insert_prediction prediction, type
        self.nr_predictions[type] += 1
        @x[type] << prediction[:measurements].median
        @y[type] << prediction[:value]
        error = prediction[:value] - prediction[:measurements].median
        self.rmse[type] += error**2
        self.mae[type] += error.abs
        if prediction[:prediction_interval]
          if prediction[:measurements].median >= prediction[:prediction_interval][0] and prediction[:measurements].median <= prediction[:prediction_interval][1]
            self.within_prediction_interval[type] += 1
          else
            self.out_of_prediction_interval[type] += 1
          end
        end
      end
    end
  end
end