33f0353706914ffed76c959f96554c43187354c0
[lazar] / test / classification-validation.rb
1 require_relative "setup.rb"
2
# Integration tests for lazar classification validations: k-fold
# crossvalidation (default and custom algorithm parameters), leave-one-out,
# repeated crossvalidation and Model::Validation.
#
# The accuracy thresholds are statistical; as the failure messages note, a
# test may occasionally fail because of an unfavorable training/test set split.
#
# Uses the canonical `Minitest` constant: the legacy `MiniTest` alias is no
# longer defined by default in recent minitest releases.
class ClassificationValidationTest < Minitest::Test
  include OpenTox::Validation

  # Leading bytes ("magic numbers") that identify PDF and PNG data.
  PDF_MAGIC = "%PDF".b.freeze
  PNG_MAGIC = "\x89PNG\r\n\x1A\n".b.freeze

  # Minimum accuracy expected from every classification validation below.
  MIN_ACCURACY = 0.65

  # defaults

  # Crossvalidates a lazar model built with the default algorithms and checks
  # that its probability plot can be rendered as PDF and as PNG.
  def test_default_classification_crossvalidation
    dataset = Dataset.from_csv_file File.join(Download::DATA, "Carcinogenicity-Rodents.csv")
    model = Model::Lazar.create training_dataset: dataset
    cv = ClassificationCrossValidation.create model
    assert_operator cv.accuracy[:all], :>, MIN_ACCURACY,
      "Accuracy (#{cv.accuracy[:all]}) should be larger than #{MIN_ACCURACY}, this may occur due to an unfavorable training/test set split"
    # Inspect the returned bytes directly instead of writing them to a fixed,
    # shared /tmp path and shelling out to file(1). This needs no external
    # tool, cannot clash with concurrent test runs, and does not corrupt the
    # binary data with the newline that `puts` appends.
    assert_format cv.probability_plot(format: "pdf"), PDF_MAGIC, "PDF"
    assert_format cv.probability_plot(format: "png"), PNG_MAGIC, "PNG"
  end

  # parameters

  # Custom algorithm parameters must be propagated unchanged to the model of
  # every crossvalidation fold. Each fold model must be trained on its own
  # training dataset, not the one of the original model.
  def test_classification_crossvalidation_parameters
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    algorithms = {
      similarity: { min: [0.9, 0.8] },
      descriptors: { type: "FP3" }
    }
    model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
    cv = ClassificationCrossValidation.create model
    # Round-trip through JSON to turn symbol keys into strings, so the hash is
    # comparable with the algorithms reported by the fold models.
    params = JSON.parse(model.algorithms.to_json)
    refute_nil model.training_dataset_id

    # Load the validations once; an empty list would make the loop vacuous.
    validations = cv.validations
    refute_empty validations, "Crossvalidation produced no validations"
    validations.each do |validation|
      fold_model = validation.model
      refute_nil fold_model.training_dataset_id
      refute_equal model.training_dataset_id, fold_model.training_dataset_id
      assert_equal params, fold_model.algorithms
    end
  end

  # LOO

  # Leave-one-out validation must yield a confusion matrix and acceptable accuracy.
  def test_classification_loo_validation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    loo = ClassificationLeaveOneOut.create model
    refute_empty loo.confusion_matrix
    assert_operator loo.accuracy[:all], :>, MIN_ACCURACY,
      "Leave-one-out accuracy (#{loo.accuracy[:all]}) should be larger than #{MIN_ACCURACY}"
  end

  # repeated CV

  # Every crossvalidation of a repeated crossvalidation must reach the minimum accuracy.
  def test_repeated_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    repeated_cv = RepeatedCrossValidation.create model
    # Load once; an empty list would make the loop below vacuous.
    crossvalidations = repeated_cv.crossvalidations
    refute_empty crossvalidations, "Repeated crossvalidation produced no crossvalidations"
    crossvalidations.each do |cv|
      assert_operator cv.accuracy[:all], :>, MIN_ACCURACY,
        "model accuracy < #{MIN_ACCURACY}, this may happen by chance due to an unfavorable training/test set split"
    end
  end

  # Builds a Model::Validation from a CSV file and checks its metadata, its
  # model type, the accuracy of its crossvalidations and one known prediction.
  def test_validation_model
    m = Model::Validation.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    %i[endpoint species source].each { |key| refute_empty m[key] }
    assert m.classification?
    refute m.regression?
    m.crossvalidations.each do |cv|
      assert_operator cv.accuracy[:all], :>, MIN_ACCURACY,
        "Crossvalidation accuracy (#{cv.accuracy[:all]}) should be larger than #{MIN_ACCURACY}. This may happen due to an unfavorable training/test set split."
    end
    prediction = m.predict Compound.from_smiles("OCC(CN(CC(O)C)N=O)O")
    assert_equal "false", prediction[:value]
  ensure
    # Delete the stored model even when an assertion above failed. `m` is nil
    # if from_csv_file itself raised.
    m.delete if m
  end

  # Exploratory crossvalidation of a Caret rf classifier. It prints statistics
  # instead of asserting and is currently skipped.
  def test_carcinogenicity_rf_classification
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    dataset = Dataset.from_csv_file File.join(Download::DATA, "Carcinogenicity-Rodents.csv")
    algorithms = { prediction: { method: "Algorithm::Caret.rf" } }
    model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
    cv = ClassificationCrossValidation.create model
    puts cv.statistics
    puts cv.id
  end

  # Exploratory comparison of default lazar algorithms against Caret rf on a
  # merged mutagenicity dataset. It prints results instead of asserting and is
  # currently skipped.
  def test_mutagenicity_classification_algorithms
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    source_feature = Feature.where(name: "Ames test categorisation").first
    target_feature = Feature.where(name: "Mutagenicity").first
    kazius = Dataset.from_sdf_file "#{Download::DATA}/parts/cas_4337.sdf"
    hansen = Dataset.from_csv_file "#{Download::DATA}/parts/hansen.csv"
    efsa = Dataset.from_csv_file "#{Download::DATA}/parts/efsa.csv"
    dataset = Dataset.merge([kazius, hansen, efsa], { source_feature => target_feature }, { 1 => "mutagen", 0 => "nonmutagen" })

    # default lazar algorithms
    report_repeated_crossvalidation Model::Lazar.create(training_dataset: dataset)

    # Caret rf prediction algorithm
    rf_algorithms = { prediction: { method: "Algorithm::Caret.rf" } }
    report_repeated_crossvalidation Model::Lazar.create(training_dataset: dataset, algorithms: rf_algorithms)
  end

  private

  # Fails unless +data+ (the raw bytes of a rendered file) starts with +magic+.
  # Both sides are compared as binary strings, so non-ASCII magic bytes such
  # as PNG's 0x89 cannot raise an encoding compatibility error.
  def assert_format(data, magic, format_name)
    bytes = data.to_s.b
    assert bytes.start_with?(magic), "Expected #{format_name} data, got #{bytes[0, 8].inspect}"
  end

  # Runs a repeated crossvalidation for +model+. It prints the id of the
  # repeated crossvalidation, then the accuracy and confusion matrix of each
  # contained crossvalidation.
  def report_repeated_crossvalidation(model)
    repeated_cv = RepeatedCrossValidation.create model
    puts repeated_cv.id
    repeated_cv.crossvalidations.each do |cv|
      puts cv.accuracy
      puts cv.confusion_matrix
    end
  end
end