summary | refs | log | tree | commit | diff
path: root/test/regression-validation.rb
diff options
context:
space:
mode:
author helma@in-silico.ch <helma@in-silico.ch> 2018-10-12 21:58:36 +0200
committer helma@in-silico.ch <helma@in-silico.ch> 2018-10-12 21:58:36 +0200
commit 9d17895ab9e8cd31e0f32e8e622e13612ea5ff77 (patch)
tree d6984f0bd81679228d0dfd903aad09c7005f1c4c /test/regression-validation.rb
parent de763211bd2b6451e3a8dc20eb95a3ecf72bef17 (diff)
validation statistic fixes
Diffstat (limited to 'test/regression-validation.rb')
-rw-r--r-- test/regression-validation.rb 91
1 files changed, 91 insertions, 0 deletions
diff --git a/test/regression-validation.rb b/test/regression-validation.rb
new file mode 100644
index 0000000..44162c0
--- /dev/null
+++ b/test/regression-validation.rb
@@ -0,0 +1,91 @@
+require_relative "setup.rb"
+
+class ValidationRegressionTest < MiniTest::Test
+ include OpenTox::Validation
+
+ # defaults
+
+ def test_default_regression_crossvalidation
+ dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM_log10.csv"
+ model = Model::Lazar.create training_dataset: dataset
+ cv = RegressionCrossValidation.create model
+ assert cv.rmse[:all] < 1.5, "RMSE #{cv.rmse[:all]} should be smaller than 1.5, this may occur due to unfavorable training/test set splits"
+ assert cv.mae[:all] < 1.1, "MAE #{cv.mae[:all]} should be smaller than 1.1, this may occur due to unfavorable training/test set splits"
+ assert cv.within_prediction_interval[:all]/cv.nr_predictions[:all] > 0.8, "Only #{(100*cv.within_prediction_interval[:all]/cv.nr_predictions[:all]).round(2)}% of measurement within prediction interval. This may occur due to unfavorable training/test set splits"
+ end
+
+ # parameters
+
+ def test_regression_crossvalidation_params
+ dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi_log10.csv"
+ algorithms = {
+ :prediction => { :method => "OpenTox::Algorithm::Regression.weighted_average" },
+ :descriptors => { :type => "MACCS", },
+ :similarity => {:min => 0.7}
+ }
+ model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
+ assert_equal algorithms[:descriptors][:type], model.algorithms[:descriptors][:type]
+ cv = RegressionCrossValidation.create model
+ cv.validation_ids.each do |vid|
+ model = Model::Lazar.find(Validation.find(vid).model_id)
+ assert_equal algorithms[:descriptors][:type], model.algorithms[:descriptors][:type]
+ assert_equal algorithms[:similarity][:min], model.algorithms[:similarity][:min]
+ refute_nil model.training_dataset_id
+ refute_equal dataset.id, model.training_dataset_id
+ end
+
+ refute_nil cv.rmse[:all]
+ refute_nil cv.mae[:all]
+ end
+
+ def test_physchem_regression_crossvalidation
+ training_dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
+ model = Model::Lazar.create training_dataset:training_dataset
+ cv = RegressionCrossValidation.create model
+ refute_nil cv.rmse[:all]
+ refute_nil cv.mae[:all]
+ end
+
+ # LOO
+
+ def test_regression_loo_validation
+ dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
+ model = Model::Lazar.create training_dataset: dataset
+ loo = RegressionLeaveOneOut.create model
+ assert loo.r_squared[:all] > 0.34, "R^2 (#{loo.r_squared[:all]}) should be larger than 0.034"
+ end
+
+ def test_regression_loo_validation_with_feature_selection
+ dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
+ algorithms = {
+ :descriptors => {
+ :method => "calculate_properties",
+ :features => PhysChem.openbabel_descriptors,
+ },
+ :similarity => {
+ :method => "Algorithm::Similarity.weighted_cosine",
+ :min => 0.5
+ },
+ :feature_selection => {
+ :method => "Algorithm::FeatureSelection.correlation_filter",
+ },
+ }
+ model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
+ assert_raises OpenTox::BadRequestError do
+ loo = RegressionLeaveOneOut.create model
+ end
+ end
+
+ # repeated CV
+
+ def test_repeated_crossvalidation
+ dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
+ model = Model::Lazar.create training_dataset: dataset
+ repeated_cv = RepeatedCrossValidation.create model
+ repeated_cv.crossvalidations.each do |cv|
+ assert cv.r_squared[:all] > 0.34, "R^2 (#{cv.r_squared[:all]}) should be larger than 0.034"
+ assert cv.rmse[:all] < 1.5, "RMSE (#{cv.rmse[:all]}) should be smaller than 0.5"
+ end
+ end
+
+end