8719dcaf0161d5f3c1fc188f40d5aaa456d47f46
[lazar] / lib / crossvalidation.rb
1 module OpenTox
2
3   module Validation
4
5     # Crossvalidation
6     class CrossValidation < Validation
7       field :validation_ids, type: Array, default: []
8       field :folds, type: Integer, default: 10
9
10       # Create a crossvalidation
11       # @param [OpenTox::Model::Lazar]
12       # @param [Fixnum] number of folds
13       # @return [OpenTox::Validation::CrossValidation]
14       def self.create model, n=10
15         $logger.debug model.algorithms
16         klass = ClassificationCrossValidation if model.is_a? Model::LazarClassification
17         klass = RegressionCrossValidation if model.is_a? Model::LazarRegression
18         raise ArgumentError, "Unknown model class #{model.class}." unless klass
19
20         cv = klass.new(
21           name: model.name,
22           model_id: model.id,
23           folds: n
24         )
25         cv.save # set created_at
26
27         training_dataset = model.training_dataset
28         training_dataset.folds(n).each_with_index do |fold,fold_nr|
29           #fork do # parallel execution of validations can lead to Rserve and memory problems
30           $logger.debug "Dataset #{training_dataset.name}: Fold #{fold_nr} started"
31           t = Time.now
32           validation = TrainTest.create(model, fold[0], fold[1])
33           cv.validation_ids << validation.id
34           $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}:  #{Time.now-t} seconds"
35         end
36         cv.save
37         cv.statistics
38         cv.update_attributes(finished_at: Time.now)
39         cv
40       end
41
42       # Get execution time
43       # @return [Fixnum]
44       def time
45         finished_at - created_at
46       end
47
48       # Get individual validations
49       # @return [Array<OpenTox::Validation>]
50       def validations
51         validation_ids.collect{|vid| TrainTest.find vid}
52       end
53
54       # Get predictions for all compounds
55       # @return [Array<Hash>]
56       def predictions
57         predictions = {}
58         validations.each{|v| predictions.merge!(v.predictions)}
59         predictions
60       end
61     end
62
63     # Crossvalidation of classification models
64     class ClassificationCrossValidation < CrossValidation
65       include ClassificationStatistics
66       field :accept_values, type: Array
67       field :confusion_matrix, type: Hash
68       field :weighted_confusion_matrix, type: Hash
69       field :accuracy, type: Hash
70       field :weighted_accuracy, type: Hash
71       field :true_rate, type: Hash
72       field :predictivity, type: Hash
73       field :nr_predictions, type: Hash
74       field :probability_plot_id, type: BSON::ObjectId
75     end
76
77     # Crossvalidation of regression models
78     class RegressionCrossValidation < CrossValidation
79       include RegressionStatistics
80       field :rmse, type: Hash
81       field :mae, type: Hash
82       field :r_squared, type: Hash
83       field :within_prediction_interval, type: Hash
84       field :out_of_prediction_interval, type: Hash
85       field :nr_predictions, type: Hash
86       field :warnings, type: Array
87       field :correlation_plot_id, type: BSON::ObjectId
88     end
89
90     # Independent repeated crossvalidations
91     class RepeatedCrossValidation < Validation
92       field :crossvalidation_ids, type: Array, default: []
93       field :correlation_plot_id, type: BSON::ObjectId
94
95       # Create repeated crossvalidations
96       # @param [OpenTox::Model::Lazar]
97       # @param [Fixnum] number of folds
98       # @param [Fixnum] number of repeats
99       # @return [OpenTox::Validation::RepeatedCrossValidation]
100       def self.create model, folds=10, repeats=5
101         repeated_cross_validation = self.new
102         repeats.times do |n|
103           $logger.debug "Crossvalidation #{n+1} for #{model.name}"
104           repeated_cross_validation.crossvalidation_ids << CrossValidation.create(model, folds).id
105         end
106         repeated_cross_validation.save
107         repeated_cross_validation
108       end
109
110       # Get crossvalidations
111       # @return [OpenTox::Validation::CrossValidation]
112       def crossvalidations
113         crossvalidation_ids.collect{|id| CrossValidation.find(id)}
114       end
115
116     end
117   end
118
119 end