From 46c628f1757ce8274a0b277b3ec3306609b38c14 Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Mon, 25 Jul 2016 15:53:22 +0200 Subject: local_weighted_average fallback fixed, cv predictions pulled from validations to avoid mongo document size errors --- lib/crossvalidation.rb | 10 ++++++++-- lib/regression.rb | 4 ++-- test/nanoparticles.rb | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb index 7aae3d2..d7a1f08 100644 --- a/lib/crossvalidation.rb +++ b/lib/crossvalidation.rb @@ -18,7 +18,7 @@ module OpenTox cv.save # set created_at nr_instances = 0 nr_unpredicted = 0 - predictions = {} + #predictions = {} training_dataset = Dataset.find model.training_dataset_id training_dataset.folds(n).each_with_index do |fold,fold_nr| #fork do # parallel execution of validations can lead to Rserve and memory problems @@ -28,7 +28,7 @@ module OpenTox cv.validation_ids << validation.id cv.nr_instances += validation.nr_instances cv.nr_unpredicted += validation.nr_unpredicted - cv.predictions.merge! validation.predictions + #cv.predictions.merge! validation.predictions $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}: #{Time.now-t} seconds" #end end @@ -47,6 +47,12 @@ module OpenTox def validations validation_ids.collect{|vid| TrainTest.find vid} end + + def predictions + predictions = {} + validations.each{|v| predictions.merge!(v.predictions)} + predictions + end end class ClassificationCrossValidation < CrossValidation diff --git a/lib/regression.rb b/lib/regression.rb index d034d0b..269a743 100644 --- a/lib/regression.rb +++ b/lib/regression.rb @@ -48,7 +48,7 @@ module OpenTox end if variables.empty? - prediction = local_weighted_average substance, neighbors + prediction = local_weighted_average(substance: substance, neighbors: neighbors) prediction[:warning] = "No variables for regression model. Using weighted average of similar substances." prediction else @@ -104,7 +104,7 @@ module OpenTox pc_ids.compact! if pc_ids.empty? - prediction = local_weighted_average substance, neighbors + prediction = local_weighted_average(substance: substance, neighbors: neighbors) prediction[:warning] = "No relevant variables for regression model. Using weighted average of similar substances." prediction else diff --git a/test/nanoparticles.rb b/test/nanoparticles.rb index 0446086..b427eb0 100644 --- a/test/nanoparticles.rb +++ b/test/nanoparticles.rb @@ -6,8 +6,8 @@ class NanoparticleTest < MiniTest::Test def setup # TODO: multiple runs create duplicates - #$mongo.database.drop - #Import::Enanomapper.import File.join(File.dirname(__FILE__),"data","enm") + $mongo.database.drop + Import::Enanomapper.import File.join(File.dirname(__FILE__),"data","enm") end def test_create_model -- cgit v1.2.3