65bec635bc55751b61bb874b41e268c5097cdcd8
[lazar] / test / regression-validation.rb
1 require_relative "setup.rb"
2
3 class RegressionValidationTest < MiniTest::Test
4   include OpenTox::Validation
5
6   # defaults
7   
8   def test_default_regression_crossvalidation
9     training_dataset = Dataset.from_csv_file File.join(Download::DATA, "Acute_toxicity-Fathead_minnow.csv")
10     dataset = Dataset.from_csv_file File.join(Download::DATA, "Acute_toxicity-Fathead_minnow.csv")
11     model = Model::Lazar.create training_dataset: dataset
12     cv = RegressionCrossValidation.create model
13     assert cv.rmse[:all] < 1.5, "RMSE #{cv.rmse[:all]} should be smaller than 1.5, this may occur due to unfavorable training/test set splits"
14     assert cv.mae[:all] < 1.1, "MAE #{cv.mae[:all]} should be smaller than 1.1, this may occur due to unfavorable training/test set splits"
15     assert cv.within_prediction_interval[:all]/cv.nr_predictions[:all].to_f > 0.8, "Only #{(100.0*cv.within_prediction_interval[:all]/cv.nr_predictions[:all]).round(2)}% of measurement within prediction interval. This may occur due to unfavorable training/test set splits"
16   end
17
18   # parameters
19   
20   def test_regression_crossvalidation_params
21     dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi_log10.csv"
22     algorithms = {
23       :prediction => { :method => "OpenTox::Algorithm::Regression.weighted_average" },
24       :descriptors => { :type => "MACCS", },
25       :similarity => {:min => [0.9,0.1]}
26     }
27     model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
28     assert_equal algorithms[:descriptors][:type], model.algorithms[:descriptors][:type]
29     cv = RegressionCrossValidation.create model
30     cv.validation_ids.each do |vid|
31       model = Model::Lazar.find(Validation.find(vid).model_id)
32       assert_equal algorithms[:descriptors][:type], model.algorithms[:descriptors][:type]
33       assert_equal algorithms[:similarity][:min], model.algorithms[:similarity][:min]
34       refute_nil model.training_dataset_id
35       refute_equal dataset.id, model.training_dataset_id
36     end
37
38     refute_nil cv.rmse[:all]
39     refute_nil cv.mae[:all]
40   end
41
42   def test_physchem_regression_crossvalidation
43     training_dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
44     model = Model::Lazar.create training_dataset:training_dataset
45     cv = RegressionCrossValidation.create model
46     refute_nil cv.rmse[:all]
47     refute_nil cv.mae[:all]
48   end
49
50   # LOO
51
52   def test_regression_loo_validation
53     dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
54     model = Model::Lazar.create training_dataset: dataset
55     loo = RegressionLeaveOneOut.create model
56     assert loo.r_squared[:all] > 0.34, "R^2 (#{loo.r_squared[:all]}) should be larger than 0.034"
57   end
58
59   def test_regression_loo_validation_with_feature_selection
60     dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
61     algorithms = {
62       :descriptors => {
63         :method => "calculate_properties",
64         :features => PhysChem.openbabel_descriptors,
65       },
66       :similarity => {
67         :method => "Algorithm::Similarity.weighted_cosine",
68         :min => [0.5,0.1]
69       },
70       :feature_selection => {
71         :method => "Algorithm::FeatureSelection.correlation_filter",
72       },
73     }
74     model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
75     assert_raises ArgumentError do
76       loo = RegressionLeaveOneOut.create model
77     end
78   end
79
80   # repeated CV
81
82   def test_repeated_crossvalidation
83     dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
84     model = Model::Lazar.create training_dataset: dataset
85     repeated_cv = RepeatedCrossValidation.create model
86     repeated_cv.crossvalidations.each do |cv|
87       assert cv.r_squared[:all] > 0.34, "R^2 (#{cv.r_squared[:all]}) should be larger than 0.34"
88       assert cv.rmse[:all] < 1.5, "RMSE (#{cv.rmse[:all]}) should be smaller than 0.5"
89     end
90   end
91
92 end