# Fixed confidence value for CV stats; added tests.
# lazar/test/classification-validation.rb
require "tmpdir"
require_relative "setup.rb"

# Integration tests for lazar classification validations: k-fold
# crossvalidation, leave-one-out, repeated crossvalidation and
# Model::Validation.
#
# Crossvalidation accuracy is measured on random training/test splits, so the
# accuracy assertions for those can occasionally fail by chance.
class ClassificationValidationTest < Minitest::Test
  include OpenTox::Validation

  # Exclusive lower bound for the accuracies asserted in this file.
  MIN_ACCURACY = 0.65

  # defaults

  def test_default_classification_crossvalidation
    dataset = Dataset.from_csv_file File.join(Download::DATA, "Carcinogenicity-Rodents.csv")
    model = Model::Lazar.create training_dataset: dataset
    cv = ClassificationCrossValidation.create model
    assert_split_accuracy cv.accuracy[:all], "Accuracy"
    # A private temporary directory (removed afterwards) instead of fixed /tmp
    # paths, so concurrent test runs cannot overwrite each other's plots.
    Dir.mktmpdir do |dir|
      { "pdf" => "PDF", "png" => "PNG" }.each do |format, expected_type|
        path = File.join(dir, "probability_plot.#{format}")
        # binwrite: plots are binary data; puts would append a newline and
        # write in text mode.
        File.binwrite(path, cv.probability_plot(format: format))
        assert_match expected_type, describe_file(path)
      end
    end
  end

  # parameters

  def test_classification_crossvalidation_parameters
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    algorithms = {
      :similarity => { :min => [0.9, 0.8] },
      :descriptors => { :type => "FP3" }
    }
    model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
    cv = ClassificationCrossValidation.create model
    params = JSON.parse(model.algorithms.to_json) # convert symbol keys to strings

    # Every fold must be trained on its own training dataset, using the same
    # algorithm settings as the parent model.
    refute_nil model.training_dataset_id
    cv.validations.each do |validation|
      refute_nil validation.model.training_dataset_id
      refute_equal model.training_dataset_id, validation.model.training_dataset_id
      assert_equal params, validation.model.algorithms
    end

    # These statistics are read from the crossvalidation itself, not from a
    # fold, so they are checked once rather than once per fold.
    keys = cv.accuracy.keys
    accept_values = cv.accept_values
    %w[nr_predictions predictivity true_rate confusion_matrix].each do |type|
      keys.each do |key|
        case type
        when "confusion_matrix"
          cv[type][key].each do |row|
            row.each do |count|
              refute_nil count
              assert_operator count, :>, 0
            end
          end
        when "predictivity", "true_rate"
          accept_values.each do |value|
            refute_nil cv[type][key][value]
            assert_operator cv[type][key][value], :>, 0
          end
        else
          refute_nil cv[type][key]
          assert_operator cv[type][key], :>, 0
        end
      end
    end
  end

  # LOO

  def test_classification_loo_validation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    loo = ClassificationLeaveOneOut.create model
    refute_empty loo.confusion_matrix
    # Leave-one-out involves no random split, so no "unfavorable split" hint.
    assert_operator loo.accuracy[:all], :>, MIN_ACCURACY
  end

  # repeated CV

  def test_repeated_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    repeated_cv = RepeatedCrossValidation.create model
    repeated_cv.crossvalidations.each do |cv|
      assert_split_accuracy cv.accuracy[:all], "Model accuracy"
    end
  end

  def test_validation_model
    m = Model::Validation.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    [:endpoint, :species, :source].each do |attribute|
      refute_empty m[attribute]
    end
    assert m.classification?
    refute m.regression?
    m.crossvalidations.each do |cv|
      assert_split_accuracy cv.accuracy[:all], "Crossvalidation accuracy"
    end
    prediction = m.predict Compound.from_smiles("OCC(CN(CC(O)C)N=O)O")
    assert_equal "false", prediction[:value]
  ensure
    # Delete the model even when an assertion fails, so failed runs do not
    # leave it behind.
    m.delete if m
  end

  def test_carcinogenicity_rf_classification
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    dataset = Dataset.from_csv_file File.join(Download::DATA, "Carcinogenicity-Rodents.csv")
    model = Model::Lazar.create training_dataset: dataset, algorithms: rf_algorithms
    cv = ClassificationCrossValidation.create model
    puts cv.statistics
    puts cv.id
  end

  def test_mutagenicity_classification_algorithms
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    source_feature = Feature.where(:name => "Ames test categorisation").first
    target_feature = Feature.where(:name => "Mutagenicity").first
    kazius = Dataset.from_sdf_file "#{Download::DATA}/parts/cas_4337.sdf"
    hansen = Dataset.from_csv_file "#{Download::DATA}/parts/hansen.csv"
    efsa = Dataset.from_csv_file "#{Download::DATA}/parts/efsa.csv"
    dataset = Dataset.merge [kazius, hansen, efsa], {source_feature => target_feature}, {1 => "mutagen", 0 => "nonmutagen"}

    # Default lazar settings first, then caret random forests, on the same data.
    print_repeated_crossvalidation Model::Lazar.create(training_dataset: dataset)
    print_repeated_crossvalidation Model::Lazar.create(training_dataset: dataset, algorithms: rf_algorithms)
  end

  private

  # Asserts that +accuracy+ exceeds MIN_ACCURACY. +label+ starts the failure
  # message, which also notes that a random training/test split may be the cause.
  def assert_split_accuracy(accuracy, label)
    assert_operator accuracy, :>, MIN_ACCURACY,
      "#{label} (#{accuracy}) should be larger than #{MIN_ACCURACY}. " \
      "This may happen by chance due to an unfavorable training/test set split."
  end

  # Returns the brief description printed by the `file` utility for +path+.
  # IO.popen with an argument array runs `file` directly, without a shell.
  def describe_file(path)
    IO.popen(["file", "-b", path], &:read)
  end

  # Algorithm settings that switch lazar predictions to caret random forests.
  # Built fresh on every call in case model creation modifies the hash.
  def rf_algorithms
    { :prediction => { :method => "Algorithm::Caret.rf" } }
  end

  # Runs a repeated crossvalidation for +model+ and prints its id plus the
  # accuracy and confusion matrix of each crossvalidation.
  def print_repeated_crossvalidation(model)
    repeated_cv = RepeatedCrossValidation.create model
    puts repeated_cv.id
    repeated_cv.crossvalidations.each do |cv|
      puts cv.accuracy
      puts cv.confusion_matrix
    end
  end
end