summaryrefslogtreecommitdiff
path: root/lib/crossvalidation.rb
diff options
context:
space:
mode:
Diffstat (limited to 'lib/crossvalidation.rb')
-rw-r--r--lib/crossvalidation.rb22
1 files changed, 22 insertions, 0 deletions
diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb
index bcb3ccf..75c5db5 100644
--- a/lib/crossvalidation.rb
+++ b/lib/crossvalidation.rb
@@ -1,10 +1,16 @@
module OpenTox
module Validation
+
+ # Crossvalidation
class CrossValidation < Validation
field :validation_ids, type: Array, default: []
field :folds, type: Integer, default: 10
+ # Create a crossvalidation
+ # @param [OpenTox::Model::Lazar]
+ # @param [Fixnum] number of folds
+ # @return [OpenTox::Validation::CrossValidation]
def self.create model, n=10
$logger.debug model.algorithms
klass = ClassificationCrossValidation if model.is_a? Model::LazarClassification
@@ -41,14 +47,20 @@ module OpenTox
cv
end
+ # Get execution time
+ # @return [Fixnum]
def time
finished_at - created_at
end
+ # Get individual validations
+ # @return [Array<OpenTox::Validation>]
def validations
validation_ids.collect{|vid| TrainTest.find vid}
end
+ # Get predictions for all compounds
+ # @return [Array<Hash>]
def predictions
predictions = {}
validations.each{|v| predictions.merge!(v.predictions)}
@@ -56,6 +68,7 @@ module OpenTox
end
end
+ # Crossvalidation of classification models
class ClassificationCrossValidation < CrossValidation
include ClassificationStatistics
field :accept_values, type: Array
@@ -68,6 +81,7 @@ module OpenTox
field :probability_plot_id, type: BSON::ObjectId
end
+ # Crossvalidation of regression models
class RegressionCrossValidation < CrossValidation
include RegressionStatistics
field :rmse, type: Float, default:0
@@ -78,10 +92,16 @@ module OpenTox
field :correlation_plot_id, type: BSON::ObjectId
end
+ # Independent repeated crossvalidations
class RepeatedCrossValidation < Validation
field :crossvalidation_ids, type: Array, default: []
field :correlation_plot_id, type: BSON::ObjectId
+ # Create repeated crossvalidations
+ # @param [OpenTox::Model::Lazar]
+ # @param [Fixnum] number of folds
+ # @param [Fixnum] number of repeats
+ # @return [OpenTox::Validation::RepeatedCrossValidation]
def self.create model, folds=10, repeats=3
repeated_cross_validation = self.new
repeats.times do |n|
@@ -92,6 +112,8 @@ module OpenTox
repeated_cross_validation
end
+ # Get crossvalidations
+ # @return [OpenTox::Validation::CrossValidation]
def crossvalidations
crossvalidation_ids.collect{|id| CrossValidation.find(id)}
end