diff options
Diffstat (limited to 'lib/crossvalidation.rb')
-rw-r--r-- | lib/crossvalidation.rb | 42 |
1 files changed, 17 insertions, 25 deletions
diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb index 06a1e2a..e1761bc 100644 --- a/lib/crossvalidation.rb +++ b/lib/crossvalidation.rb @@ -15,7 +15,7 @@ module OpenTox $logger.debug model.algorithms klass = ClassificationCrossValidation if model.is_a? Model::LazarClassification klass = RegressionCrossValidation if model.is_a? Model::LazarRegression - bad_request_error "Unknown model class #{model.class}." unless klass + raise ArgumentError, "Unknown model class #{model.class}." unless klass cv = klass.new( name: model.name, @@ -24,24 +24,16 @@ module OpenTox ) cv.save # set created_at - nr_instances = 0 - nr_unpredicted = 0 training_dataset = model.training_dataset training_dataset.folds(n).each_with_index do |fold,fold_nr| #fork do # parallel execution of validations can lead to Rserve and memory problems - $logger.debug "Dataset #{training_dataset.name}: Fold #{fold_nr} started" - t = Time.now - validation = TrainTest.create(model, fold[0], fold[1]) - cv.validation_ids << validation.id - cv.nr_instances += validation.nr_instances - cv.nr_unpredicted += validation.nr_unpredicted - #cv.predictions.merge! validation.predictions - $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}: #{Time.now-t} seconds" - #end + $logger.debug "Dataset #{training_dataset.name}: Fold #{fold_nr} started" + t = Time.now + validation = TrainTest.create(model, fold[0], fold[1]) + cv.validation_ids << validation.id + $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}: #{Time.now-t} seconds" end - #Process.waitall cv.save - $logger.debug "Nr unpredicted: #{nr_unpredicted}" cv.statistics cv.update_attributes(finished_at: Time.now) cv @@ -72,25 +64,25 @@ module OpenTox class ClassificationCrossValidation < CrossValidation include ClassificationStatistics field :accept_values, type: Array - field :confusion_matrix, type: Array - field :weighted_confusion_matrix, type: Array - field :accuracy, type: Float - field :weighted_accuracy, type: Float + field :confusion_matrix, type: Hash + field :accuracy, type: Hash field :true_rate, type: Hash field :predictivity, type: Hash + field :nr_predictions, type: Hash field :probability_plot_id, type: BSON::ObjectId end # Crossvalidation of regression models class RegressionCrossValidation < CrossValidation include RegressionStatistics - field :rmse, type: Float, default:0 - field :mae, type: Float, default:0 - field :r_squared, type: Float - field :within_prediction_interval, type: Integer, default:0 - field :out_of_prediction_interval, type: Integer, default:0 - field :correlation_plot_id, type: BSON::ObjectId + field :rmse, type: Hash + field :mae, type: Hash + field :r_squared, type: Hash + field :within_prediction_interval, type: Hash + field :out_of_prediction_interval, type: Hash + field :nr_predictions, type: Hash field :warnings, type: Array + field :correlation_plot_id, type: BSON::ObjectId end # Independent repeated crossvalidations @@ -103,7 +95,7 @@ module OpenTox # @param [Fixnum] number of folds # @param [Fixnum] number of repeats # @return [OpenTox::Validation::RepeatedCrossValidation] - def self.create model, folds=10, repeats=3 + def self.create model, folds=10, repeats=5 repeated_cross_validation = self.new repeats.times do |n| $logger.debug "Crossvalidation #{n+1} for #{model.name}" |