summaryrefslogtreecommitdiff
path: root/test/classification-validation.rb
blob: b913e1e766a0016d18b39c114434143bea0d8a5f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
require_relative "setup.rb"

# Validation tests for Model::Lazar classification models: crossvalidation,
# leave-one-out validation, repeated crossvalidation and validation models.
#
# The accuracy thresholds (> 0.65) depend on the training/test set split, so
# these tests may occasionally fail by chance (see the assertion messages).
# DATA_DIR and Download::DATA are defined outside this file.
#
# Inherits from Minitest::Test: the legacy MiniTest alias is no longer defined
# by default in recent minitest 5.x releases.
class ClassificationValidationTest < Minitest::Test
  include OpenTox::Validation

  # defaults

  # Crossvalidates a model created with default settings and checks the
  # overall accuracy as well as PDF and PNG rendering of the probability plot.
  def test_default_classification_crossvalidation
    dataset = Dataset.from_csv_file File.join(Download::DATA,"Carcinogenicity-Rodents.csv")
    model = Model::Lazar.create training_dataset: dataset
    cv = ClassificationCrossValidation.create model
    assert cv.accuracy[:all] > 0.65, "Accuracy (#{cv.accuracy[:all]}) should be larger than 0.65, this may occur due to an unfavorable training/test set split"
    assert_probability_plot cv, "pdf", "PDF"
    assert_probability_plot cv, "png", "PNG"
  end

  # parameters

  # Checks that custom algorithm parameters are carried over to the model of
  # every validation, that each validation model has a training dataset of
  # its own, and that all crossvalidation statistics are present and positive.
  def test_classification_crossvalidation_parameters
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    algorithms = {
      :similarity => { :min => [0.9,0.8] },
      :descriptors => { :type => "FP3" }
    }
    model = Model::Lazar.create training_dataset: dataset, algorithms: algorithms
    cv = ClassificationCrossValidation.create model
    # JSON round trip converts symbols to strings for the comparison below
    params = JSON.parse(model.algorithms.to_json)

    # Independent of the individual validations, so checked once up front.
    refute_nil model.training_dataset_id

    cv.validations.each do |validation|
      refute_nil validation.model.training_dataset_id
      refute_equal model.training_dataset_id, validation.model.training_dataset_id
      assert_equal params, validation.model.algorithms
    end

    # Crossvalidation-level statistics do not depend on a single validation,
    # so they are checked once instead of once per validation.
    assert_positive_statistics cv
  end

  # LOO

  # Leave-one-out validation must yield a confusion matrix and an overall
  # accuracy above 0.65.
  def test_classification_loo_validation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    loo = ClassificationLeaveOneOut.create model
    refute_empty loo.confusion_matrix
    # assert_operator reports the actual accuracy on failure
    assert_operator loo.accuracy[:all], :>, 0.650
  end

  # repeated CV

  # Every crossvalidation of a repeated crossvalidation must exceed an
  # overall accuracy of 0.65.
  def test_repeated_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::Lazar.create training_dataset: dataset
    repeated_cv = RepeatedCrossValidation.create model
    repeated_cv.crossvalidations.each do |cv|
      assert_operator cv.accuracy[:all], :>, 0.65, "model accuracy < 0.65, this may happen by chance due to an unfavorable training/test set split"
    end
  end

  # Creates a validation model from a CSV file and checks its metadata,
  # serialisation, model type, crossvalidation accuracies and one prediction.
  def test_validation_model
    m = Model::Validation.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    begin
      [:endpoint,:species,:source].each do |attribute|
        refute_empty m[attribute]
      end
      # Exercise serialisation without dumping the document to the test output.
      refute_empty m.to_json
      assert m.classification?
      refute m.regression?
      m.crossvalidations.each do |cv|
        assert cv.accuracy[:all] > 0.65, "Crossvalidation accuracy (#{cv.accuracy[:all]}) should be larger than 0.65. This may happen due to an unfavorable training/test set split."
      end
      prediction = m.predict Compound.from_smiles("OCC(CN(CC(O)C)N=O)O")
      assert_equal "false", prediction[:value]
    ensure
      # Remove the model even if one of the assertions above failed.
      m.delete
    end
  end

  # Currently skipped; the statements after `skip` are not executed and are
  # kept for manual experiments with the Caret random forest method.
  def test_carcinogenicity_rf_classification
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    dataset = Dataset.from_csv_file File.join(Download::DATA,"Carcinogenicity-Rodents.csv")
    model = Model::Lazar.create training_dataset: dataset, algorithms: caret_rf_algorithms
    cv = ClassificationCrossValidation.create model
    # cv = ClassificationCrossValidation.find "5bbc822dca626919731e2822"
    puts cv.statistics
    puts cv.id
  end

  # Currently skipped; the statements after `skip` are not executed. They
  # merge three mutagenicity datasets and print repeated crossvalidation
  # results for a default model and for a Caret random forest model.
  def test_mutagenicity_classification_algorithms
    skip "Caret rf classification may run into a (endless?) loop for some compounds."
    source_feature = Feature.where(:name => "Ames test categorisation").first
    target_feature = Feature.where(:name => "Mutagenicity").first
    kazius = Dataset.from_sdf_file "#{Download::DATA}/parts/cas_4337.sdf"
    hansen = Dataset.from_csv_file "#{Download::DATA}/parts/hansen.csv"
    efsa = Dataset.from_csv_file "#{Download::DATA}/parts/efsa.csv"
    dataset = Dataset.merge [kazius,hansen,efsa], {source_feature => target_feature}, {1 => "mutagen", 0 => "nonmutagen"}

    model = Model::Lazar.create training_dataset: dataset
    report_repeated_crossvalidation model

    model = Model::Lazar.create training_dataset: dataset, algorithms: caret_rf_algorithms
    report_repeated_crossvalidation model
  end

  private

  # Writes the probability plot of +cv+ in the given +format+ to
  # /tmp/tmp.<format> and asserts that `file -b` reports +signature+ for it.
  def assert_probability_plot(cv, format, signature)
    path = "/tmp/tmp.#{format}"
    File.open(path,"w+"){|f| f.puts cv.probability_plot(format: format)}
    assert_match signature, `file -b #{path}`
  end

  # Asserts that the nr_predictions, predictivity, true_rate and
  # confusion_matrix entries of +cv+ are present and greater than zero for
  # every key of cv.accuracy (and, where applicable, every accept value).
  def assert_positive_statistics(cv)
    keys = cv.accuracy.keys
    accept_values = cv.accept_values
    %w[nr_predictions predictivity true_rate confusion_matrix].each do |type|
      keys.each do |key|
        statistic = cv[type][key]
        case type
        when "confusion_matrix"
          # iterated as an array of arrays
          statistic.each do |row|
            row.each do |count|
              refute_nil count
              assert count > 0, "#{statistic} values should be greater than 0."
            end
          end
        when "predictivity", "true_rate"
          # one entry per accept value
          accept_values.each do |value|
            refute_nil statistic[value]
            assert statistic[value] > 0, "#{statistic} values should be greater than 0."
          end
        else
          refute_nil statistic
          assert statistic > 0, "#{statistic} value should be greater than 0."
        end
      end
    end
  end

  # Returns a fresh algorithms hash selecting the Caret random forest
  # prediction method.
  def caret_rf_algorithms
    {
      :prediction => {
        :method => "Algorithm::Caret.rf",
      },
    }
  end

  # Runs a repeated crossvalidation for +model+ and prints its id followed by
  # the accuracy and confusion matrix of each crossvalidation.
  def report_repeated_crossvalidation(model)
    repeated_cv = RepeatedCrossValidation.create model
    puts repeated_cv.id
    repeated_cv.crossvalidations.each do |cv|
      puts cv.accuracy
      puts cv.confusion_matrix
    end
  end

end