summaryrefslogtreecommitdiff
path: root/test/validation.rb
blob: 485769c073f5a76a9a2cf26dc4e927b5d7df9d92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
require_relative "setup.rb"

# Cross-validation tests for lazar models.
#
# Each test builds a model from a CSV dataset located in DATA_DIR, runs a
# cross-validation on that model and checks the reported performance metrics
# against fixed thresholds. The metrics are also printed with `p` so they can
# be inspected in the test output.
class ValidationTest < Minitest::Test

  # Cross-validates a Model::LazarFminerClassification built from the hamster
  # carcinogenicity dataset. Expects validation ids to be recorded, an
  # accuracy above 0.8 and a weighted accuracy above the unweighted one.
  def test_fminer_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::LazarFminerClassification.create dataset
    cv = ClassificationCrossValidation.create model
    p cv.accuracy
    p cv.weighted_accuracy
    refute_empty cv.validation_ids
    assert_operator cv.accuracy, :>, 0.8
    # The message is only shown on failure, so it states the expectation.
    assert_operator cv.weighted_accuracy, :>, cv.accuracy, "Weighted accuracy (#{cv.weighted_accuracy}) should be larger than unweighted accuracy (#{cv.accuracy})."
  end

  # Cross-validates a Model::LazarClassification built from the hamster
  # carcinogenicity dataset. Expects an accuracy above 0.7 and a weighted
  # accuracy above the unweighted one.
  def test_classification_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::LazarClassification.create dataset
    cv = ClassificationCrossValidation.create model
    p cv.accuracy
    p cv.weighted_accuracy
    assert_operator cv.accuracy, :>, 0.7
    assert_operator cv.weighted_accuracy, :>, cv.accuracy, "Weighted accuracy should be larger than unweighted accuracy."
  end

  # Cross-validates a Model::LazarRegression built from the EPAFHM.medi
  # dataset. Expects RMSE below 30 and MAE below 12, and the weighted variant
  # of each error to be smaller than the unweighted one.
  def test_regression_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv"
    # Alternative dataset, currently disabled:
    #dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.csv"
    model = Model::LazarRegression.create dataset
    cv = RegressionCrossValidation.create model
    p cv.rmse
    p cv.weighted_rmse
    p cv.mae
    p cv.weighted_mae
    # NOTE(review): presumably opens the validation plot in an external viewer
    # for manual inspection; disabled so the test runs unattended.
    #`inkview #{cv.plot}`
    assert_operator cv.rmse, :<, 30, "RMSE > 30"
    assert_operator cv.weighted_rmse, :<, cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) "
    assert_operator cv.mae, :<, 12
    assert_operator cv.weighted_mae, :<, cv.mae
  end

end