From 6bde559981fa11ffd265af708956f9d4ee6c9a89 Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Thu, 8 Oct 2015 10:32:31 +0200 Subject: crossvalidation plots, original classification confidence --- test/compound.rb | 13 +++++++++++++ test/setup.rb | 4 +++- test/validation.rb | 36 ++++++++++++++++++++++++++++-------- 3 files changed, 44 insertions(+), 9 deletions(-) (limited to 'test') diff --git a/test/compound.rb b/test/compound.rb index 036f384..24356d3 100644 --- a/test/compound.rb +++ b/test/compound.rb @@ -160,4 +160,17 @@ print c.sdf end end end + + def test_fingerprint_db_neighbors + training_dataset = Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.csv") + [ + "CC(=O)CC(C)C#N", + "CC(=O)CC(C)C", + "C(=O)CC(C)C#N", + ].each do |smi| + c = OpenTox::Compound.from_smiles smi + neighbors = c.db_neighbors(:training_dataset_id => training_dataset.id, :min_sim => 0.2) + p neighbors + end + end end diff --git a/test/setup.rb b/test/setup.rb index 3dad683..ba1b7af 100644 --- a/test/setup.rb +++ b/test/setup.rb @@ -3,5 +3,7 @@ require_relative '../lib/lazar.rb' include OpenTox TEST_DIR ||= File.expand_path(File.dirname(__FILE__)) DATA_DIR ||= File.join(TEST_DIR,"data") +Mongoid.configure.connect_to("test") +$mongo = Mongo::Client.new('mongodb://127.0.0.1:27017/test') #$mongo.database.drop -#$gridfs = $mongo.database.fs # recreate GridFS indexes +$gridfs = $mongo.database.fs diff --git a/test/validation.rb b/test/validation.rb index af5ea60..6764a32 100644 --- a/test/validation.rb +++ b/test/validation.rb @@ -16,11 +16,35 @@ class ValidationTest < MiniTest::Test model = Model::LazarClassification.create dataset#, features cv = ClassificationCrossValidation.create model assert cv.accuracy > 0.7 + File.open("tmp.svg","w+"){|f| f.puts cv.confidence_plot} + `inkview tmp.svg` p cv.nr_unpredicted p cv.accuracy #assert cv.weighted_accuracy > cv.accuracy, "Weighted accuracy should be larger than unweighted accuracy." end + def test_default_regression_crossvalidation + dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv" + model = Model::LazarRegression.create dataset + cv = RegressionCrossValidation.create model + #cv = RegressionCrossValidation.find '561503262b72ed54fd000001' + p cv.id + File.open("tmp.svg","w+"){|f| f.puts cv.correlation_plot} + `inkview tmp.svg` + File.open("tmp.svg","w+"){|f| f.puts cv.confidence_plot} + `inkview tmp.svg` + + #puts cv.misclassifications.to_yaml + p cv.rmse + p cv.weighted_rmse + assert cv.rmse < 1.5, "RMSE > 1.5" + #assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) " + p cv.mae + p cv.weighted_mae + assert cv.mae < 1 + #assert cv.weighted_mae < cv.mae + end + def test_regression_crossvalidation dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv" #dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.csv" @@ -41,13 +65,8 @@ class ValidationTest < MiniTest::Test refute_equal params[:neighbor_algorithm_parameters][:training_dataset_id], model[:neighbor_algorithm_parameters][:training_dataset_id] end - #`inkview #{cv.plot}` - #puts JSON.pretty_generate(cv.misclassifications)#.collect{|l| l.join ", "}.join "\n" - #`inkview #{cv.plot}` - assert cv.rmse < 30, "RMSE > 30" - #assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) " - assert cv.mae < 12 - #assert cv.weighted_mae < cv.mae + assert cv.rmse < 1.5, "RMSE > 30" + assert cv.mae < 1 end def test_repeated_crossvalidation @@ -55,7 +74,8 @@ class ValidationTest < MiniTest::Test model = Model::LazarClassification.create dataset repeated_cv = RepeatedCrossValidation.create model repeated_cv.crossvalidations.each do |cv| - assert cv.accuracy > 0.7 + assert_operator cv.accuracy, :>, 0.7, "model accuracy < 0.7, this may happen by chance due to an unfavorable training/test set split" + assert_operator cv.weighted_accuracy, :>, cv.accuracy end end -- cgit v1.2.3