From 6bde559981fa11ffd265af708956f9d4ee6c9a89 Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Thu, 8 Oct 2015 10:32:31 +0200 Subject: crossvalidation plots, original classification confidence --- test/validation.rb | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) (limited to 'test/validation.rb') diff --git a/test/validation.rb b/test/validation.rb index af5ea60..6764a32 100644 --- a/test/validation.rb +++ b/test/validation.rb @@ -16,11 +16,35 @@ class ValidationTest < MiniTest::Test model = Model::LazarClassification.create dataset#, features cv = ClassificationCrossValidation.create model assert cv.accuracy > 0.7 + File.open("tmp.svg","w+"){|f| f.puts cv.confidence_plot} + `inkview tmp.svg` p cv.nr_unpredicted p cv.accuracy #assert cv.weighted_accuracy > cv.accuracy, "Weighted accuracy should be larger than unweighted accuracy." end + def test_default_regression_crossvalidation + dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv" + model = Model::LazarRegression.create dataset + cv = RegressionCrossValidation.create model + #cv = RegressionCrossValidation.find '561503262b72ed54fd000001' + p cv.id + File.open("tmp.svg","w+"){|f| f.puts cv.correlation_plot} + `inkview tmp.svg` + File.open("tmp.svg","w+"){|f| f.puts cv.confidence_plot} + `inkview tmp.svg` + + #puts cv.misclassifications.to_yaml + p cv.rmse + p cv.weighted_rmse + assert cv.rmse < 1.5, "RMSE > 1.5" + #assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) " + p cv.mae + p cv.weighted_mae + assert cv.mae < 1 + #assert cv.weighted_mae < cv.mae + end + def test_regression_crossvalidation dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv" #dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.csv" @@ -41,13 +65,8 @@ class ValidationTest < MiniTest::Test refute_equal params[:neighbor_algorithm_parameters][:training_dataset_id], model[:neighbor_algorithm_parameters][:training_dataset_id] end - #`inkview #{cv.plot}` - #puts JSON.pretty_generate(cv.misclassifications)#.collect{|l| l.join ", "}.join "\n" - #`inkview #{cv.plot}` - assert cv.rmse < 30, "RMSE > 30" - #assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) " - assert cv.mae < 12 - #assert cv.weighted_mae < cv.mae + assert cv.rmse < 1.5, "RMSE > 30" + assert cv.mae < 1 end def test_repeated_crossvalidation @@ -55,7 +74,8 @@ class ValidationTest < MiniTest::Test model = Model::LazarClassification.create dataset repeated_cv = RepeatedCrossValidation.create model repeated_cv.crossvalidations.each do |cv| - assert cv.accuracy > 0.7 + assert_operator cv.accuracy, :>, 0.7, "model accuracy < 0.7, this may happen by chance due to an unfavorable training/test set split" + assert_operator cv.weighted_accuracy, :>, cv.accuracy end end -- cgit v1.2.3