From db21f62399669be741e4e692278344e16cb39580 Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Sun, 9 Aug 2015 13:43:07 +0200 Subject: customized prediction algorithms implemented --- test/compound.rb | 1 + test/validation.rb | 36 ++++++++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/test/compound.rb b/test/compound.rb index 2bf0204..d1b1520 100644 --- a/test/compound.rb +++ b/test/compound.rb @@ -85,6 +85,7 @@ class CompoundTest < MiniTest::Test d.compounds.each do |c| refute_nil c.fp4 end + # TODO assert neighbor size c = d.compounds[371] #p c p c.neighbors diff --git a/test/validation.rb b/test/validation.rb index 80055f2..9348c6f 100644 --- a/test/validation.rb +++ b/test/validation.rb @@ -2,12 +2,40 @@ require_relative "setup.rb" class ValidationTest < MiniTest::Test + def test_fminer_crossvalidation + dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv" + model = Model::LazarFminerClassification.create dataset#, features + cv = ClassificationCrossValidation.create model + p cv.accuracy + p cv.weighted_accuracy + assert cv.accuracy > 0.8 + assert cv.weighted_accuracy > cv.accuracy, "Weighted accuracy (#{cv.weighted_accuracy}) larger than unweighted accuracy(#{cv.accuracy}) " + end + def test_classification_crossvalidation dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv" - features = Algorithm::Fminer.bbrc dataset - model = Model::Lazar.create dataset, features - cv = CrossValidation.create model - p cv + model = Model::LazarClassification.create dataset#, features + cv = ClassificationCrossValidation.create model + p cv.accuracy + p cv.weighted_accuracy + assert cv.accuracy > 0.7 + assert cv.weighted_accuracy > cv.accuracy + end + + def test_regression_crossvalidation + dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv" + #dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.csv" + model = Model::LazarRegression.create dataset + cv = RegressionCrossValidation.create model + p cv.rmse + p cv.weighted_rmse + p cv.mae + p cv.weighted_mae + `inkview #{cv.plot}` + assert cv.rmse < 30, "RMSE > 30" + assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) " + assert cv.mae < 12 + assert cv.weighted_mae < cv.mae end end -- cgit v1.2.3