real datasets for testing, test data cleanup, Daphnia import, upper and lower similar...
[lazar] / lib / crossvalidation.rb
1 module OpenTox
2
3   module Validation
4
# Crossvalidation: splits a model's training dataset into folds, runs one
# TrainTest validation per fold and aggregates the results.
class CrossValidation < Validation
  field :validation_ids, type: Array, default: []
  field :folds, type: Integer, default: 10

  # Create a crossvalidation.
  #
  # Folds are validated sequentially on purpose: parallel execution of
  # validations (e.g. via fork) can lead to Rserve and memory problems.
  #
  # @param model [OpenTox::Model::Lazar] model to validate; must be a
  #   Model::LazarClassification or a Model::LazarRegression
  # @param n [Integer] number of folds
  # @return [OpenTox::Validation::CrossValidation] a ClassificationCrossValidation
  #   or RegressionCrossValidation, saved, with statistics computed and
  #   finished_at set
  # @raise [ArgumentError] if the model class is not supported
  def self.create model, n=10
    $logger.debug model.algorithms
    # Regression is tested first so it keeps precedence, matching the former
    # pair of modifier-ifs in which the later assignment won.
    klass = case model
            when Model::LazarRegression then RegressionCrossValidation
            when Model::LazarClassification then ClassificationCrossValidation
            end
    raise ArgumentError, "Unknown model class #{model.class}." unless klass

    cv = klass.new(
      name: model.name,
      model_id: model.id,
      folds: n
    )
    cv.save # set created_at

    training_dataset = model.training_dataset
    training_dataset.folds(n).each_with_index do |fold, fold_nr|
      $logger.debug "Dataset #{training_dataset.name}: Fold #{fold_nr} started"
      started = Process.clock_gettime(Process::CLOCK_MONOTONIC)
      # NOTE(review): fold[0] / fold[1] are presumably the training and test
      # parts of the fold — confirm against Dataset#folds.
      validation = TrainTest.create(model, fold[0], fold[1])
      cv.validation_ids << validation.id
      elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started
      $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}:  #{elapsed} seconds"
    end
    cv.save
    cv.statistics
    cv.update_attributes(finished_at: Time.now)
    cv
  end

  # Get execution time.
  # @return [Float, nil] seconds between created_at and finished_at, or nil
  #   when either is unset (e.g. the crossvalidation did not finish);
  #   assumes both are Time-like values
  def time
    return unless finished_at && created_at

    finished_at - created_at
  end

  # Get individual validations, in fold order.
  # @return [Array<TrainTest>]
  def validations
    validation_ids.map { |vid| TrainTest.find vid }
  end

  # Get predictions for all compounds, merged across all folds.
  # @return [Hash] union of every validation's predictions hash; on duplicate
  #   keys the later fold wins (keys presumably identify compounds — confirm)
  def predictions
    validations.each_with_object({}) do |validation, merged|
      merged.merge!(validation.predictions)
    end
  end
end
62
# Crossvalidation of classification models. The statistic fields below are
# presumably filled in by ClassificationStatistics#statistics — confirm.
class ClassificationCrossValidation < CrossValidation
  include ClassificationStatistics

  field :accept_values, type: Array

  # Hash-valued statistics, declared in this exact order; the key layout of
  # each hash is defined outside this file.
  %i[
    confusion_matrix
    accuracy
    true_rate
    predictivity
    nr_predictions
  ].each do |hash_field|
    field hash_field, type: Hash
  end

  field :probability_plot_id, type: BSON::ObjectId
end
74
# Crossvalidation of regression models. The statistic fields below are
# presumably filled in by RegressionStatistics#statistics — confirm.
class RegressionCrossValidation < CrossValidation
  include RegressionStatistics

  # Hash-valued statistics, declared in this exact order; the key layout of
  # each hash is defined outside this file.
  %i[
    rmse
    mae
    r_squared
    within_prediction_interval
    out_of_prediction_interval
    nr_predictions
  ].each do |hash_field|
    field hash_field, type: Hash
  end

  field :warnings, type: Array
  field :correlation_plot_id, type: BSON::ObjectId
end
87
# Independent repeated crossvalidations of the same model.
class RepeatedCrossValidation < Validation
  field :crossvalidation_ids, type: Array, default: []
  field :correlation_plot_id, type: BSON::ObjectId

  # Create repeated crossvalidations.
  #
  # The crossvalidations run one after another; the record itself is saved
  # only after all of them have completed.
  #
  # @param model [OpenTox::Model::Lazar] model to validate
  # @param folds [Integer] number of folds per crossvalidation
  # @param repeats [Integer] number of crossvalidations to run
  # @return [OpenTox::Validation::RepeatedCrossValidation]
  def self.create model, folds=10, repeats=5
    ids = repeats.times.map do |n|
      $logger.debug "Crossvalidation #{n+1} for #{model.name}"
      CrossValidation.create(model, folds).id
    end
    repeated_cross_validation = new(crossvalidation_ids: ids)
    repeated_cross_validation.save
    repeated_cross_validation
  end

  # Get crossvalidations, in the order they were created.
  # @return [Array<OpenTox::Validation::CrossValidation>]
  def crossvalidations
    crossvalidation_ids.map { |id| CrossValidation.find(id) }
  end
end
115   end
116
117 end