302b2c8c524e9a0c22efdbddd897998116ae675a
[lazar] / test / classification-validation.rb
require "tempfile"
require_relative "setup.rb"
2
# Validation tests for lazar classification models: crossvalidation,
# leave-one-out validation, repeated crossvalidation and validation models.
#
# The accuracy thresholds are empirical. As the assertion messages state, a
# run may fall below them because of an unfavorable training/test set split.
# DATA_DIR and the OpenTox classes are expected to be provided by setup.rb.
class ClassificationValidationTest < Minitest::Test
  include OpenTox::Validation

  # Lowest accuracy (for predictions without warnings) accepted by these tests.
  MIN_ACCURACY = 0.65

  # defaults

  # A crossvalidation of a model with default algorithms must reach the
  # minimum accuracy, benefit from weighting and produce PDF and PNG plots.
  def test_default_classification_crossvalidation
    # Alternative training data: "#{DATA_DIR}/hamster_carcinogenicity.csv"
    dataset = Dataset.from_csv_file "#{DATA_DIR}/multi_cell_call.csv"
    model = Model::Lazar.create training_dataset: dataset
    cv = ClassificationCrossValidation.create model

    accuracy = cv.accuracy[:without_warnings]
    assert_operator accuracy, :>, MIN_ACCURACY, "Accuracy (#{accuracy}) should be larger than #{MIN_ACCURACY}, this may occur due to an unfavorable training/test set split"
    assert_weighted_accuracy_above_accuracy cv
    assert_probability_plot_format cv, "pdf", "PDF"
    assert_probability_plot_format cv, "png", "PNG"
  end

  # parameters

  # Every validation model of a crossvalidation must use the algorithms of the
  # validated model, but be trained on a dataset of its own.
  def test_classification_crossvalidation_parameters
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    algorithms = {
      similarity: { min: 0.3 },
      descriptors: { type: "FP3" }
    }
    model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
    cv = ClassificationCrossValidation.create model

    # The JSON round trip converts symbols to strings, so that the hash can be
    # compared with the algorithms reported by the validation models.
    expected_algorithms = JSON.parse(model.algorithms.to_json)
    refute_nil model.training_dataset_id

    validations = cv.validations
    # Without this check an empty crossvalidation would pass vacuously.
    refute_empty validations, "Crossvalidation should contain at least one validation."
    validations.each do |validation|
      refute_nil validation.model.training_dataset_id
      refute_equal model.training_dataset_id, validation.model.training_dataset_id
      assert_equal expected_algorithms, validation.model.algorithms
    end
  end

  # LOO

  # Leave-one-out validation must yield a confusion matrix, reach the minimum
  # accuracy and benefit from weighting.
  def test_classification_loo_validation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    loo = ClassificationLeaveOneOut.create model

    refute_empty loo.confusion_matrix
    assert_operator loo.accuracy[:without_warnings], :>, MIN_ACCURACY
    assert_weighted_accuracy_above_accuracy loo
  end

  # repeated CV

  # Each crossvalidation of a repeated crossvalidation must reach the minimum
  # accuracy.
  def test_repeated_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    repeated_cv = RepeatedCrossValidation.create model

    crossvalidations = repeated_cv.crossvalidations
    refute_empty crossvalidations, "Repeated crossvalidation should contain at least one crossvalidation."
    crossvalidations.each do |cv|
      assert_operator cv.accuracy[:without_warnings], :>, MIN_ACCURACY, "model accuracy < #{MIN_ACCURACY}, this may happen by chance due to an unfavorable training/test set split"
    end
  end

  # A validation model created from a CSV file must expose its metadata, be a
  # classification model, reach the minimum accuracy in every crossvalidation
  # and predict "false" for the reference compound.
  def test_validation_model
    m = Model::Validation.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    %i[endpoint species source].each do |property|
      refute_empty m[property]
    end
    assert m.classification?
    refute m.regression?

    crossvalidations = m.crossvalidations
    refute_empty crossvalidations, "Validation model should contain at least one crossvalidation."
    crossvalidations.each do |cv|
      accuracy = cv.accuracy[:without_warnings]
      assert_operator accuracy, :>, MIN_ACCURACY, "Crossvalidation accuracy (#{accuracy}) should be larger than #{MIN_ACCURACY}. This may happen due to an unfavorable training/test set split."
    end

    prediction = m.predict Compound.from_smiles("OCC(CN(CC(O)C)N=O)O")
    assert_equal "false", prediction[:value]
  ensure
    # Remove the model even when an assertion above fails.
    m.delete if m
  end

  # Skipped: crossvalidation with the Caret random forest prediction method.
  # The code after `skip` is kept for manual experiments and prints the
  # statistics and the id of the crossvalidation.
  def test_carcinogenicity_rf_classification
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    dataset = Dataset.from_csv_file "#{DATA_DIR}/multi_cell_call.csv"
    algorithms = {
      prediction: {
        method: "Algorithm::Caret.rf"
      }
    }
    model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
    # To inspect a stored run instead of creating a new one:
    #   cv = ClassificationCrossValidation.find "5bbc822dca626919731e2822"
    cv = ClassificationCrossValidation.create model
    puts cv.statistics
    puts cv.id
  end

  # Skipped: repeated crossvalidation on the merged mutagenicity datasets,
  # first with the default algorithms, then with Caret random forest
  # predictions. The code after `skip` is kept for manual experiments.
  def test_mutagenicity_classification_algorithms
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    source_feature = Feature.where(:name => "Ames test categorisation").first
    target_feature = Feature.where(:name => "Mutagenicity").first
    kazius = Dataset.from_sdf_file "#{DATA_DIR}/cas_4337.sdf"
    hansen = Dataset.from_csv_file "#{DATA_DIR}/hansen.csv"
    efsa = Dataset.from_csv_file "#{DATA_DIR}/efsa.csv"
    dataset = Dataset.merge [kazius, hansen, efsa], { source_feature => target_feature }, { 1 => "mutagen", 0 => "nonmutagen" }

    report_repeated_crossvalidation Model::Lazar.create(training_dataset: dataset)

    algorithms = {
      prediction: {
        method: "Algorithm::Caret.rf"
      }
    }
    report_repeated_crossvalidation Model::Lazar.create(training_dataset: dataset, algorithms: algorithms)
  end

  private

  # Asserts that the weighted accuracy over all predictions of +validation+
  # exceeds its plain accuracy over all predictions.
  def assert_weighted_accuracy_above_accuracy(validation)
    weighted = validation.weighted_accuracy[:all]
    plain = validation.accuracy[:all]
    assert_operator weighted, :>, plain, "Weighted accuracy (#{weighted}) should be larger than accuracy (#{plain})."
  end

  # Writes the probability plot of +cv+ in +format+ to a temporary file and
  # asserts that the `file` utility reports +expected_type+ for it.
  #
  # A unique temporary file avoids collisions between concurrent runs and is
  # removed when the block ends. The plot is written in binary mode without a
  # trailing newline (assumes probability_plot returns the raw file content).
  def assert_probability_plot_format(cv, format, expected_type)
    Tempfile.create(["probability_plot", ".#{format}"]) do |file|
      file.binmode
      file.write cv.probability_plot(format: format)
      file.flush
      # The argument array bypasses the shell, so the path needs no quoting.
      detected_type = IO.popen(["file", "-b", file.path], &:read)
      assert_match expected_type, detected_type
    end
  end

  # Runs a repeated crossvalidation for +model+ and prints its id together
  # with the accuracy and confusion matrix of each crossvalidation.
  def report_repeated_crossvalidation(model)
    repeated_cv = RepeatedCrossValidation.create model
    puts repeated_cv.id
    repeated_cv.crossvalidations.each do |cv|
      puts cv.accuracy
      puts cv.confusion_matrix
    end
  end
end