blob: 009c337b19d71bdd135d801581686140f5083592 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
|
require_relative "setup.rb"
# Cross-validation tests for the lazar classification and regression models.
#
# Each test builds a model from a CSV dataset under DATA_DIR, runs a
# cross-validation on it and checks the reported statistics against fixed
# thresholds. The metric values are printed so they show up in the test log.
class ValidationTest < Minitest::Test
  # Cross-validates a LazarFminerClassification model built from the hamster
  # carcinogenicity dataset.
  def test_fminer_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::LazarFminerClassification.create dataset
    cv = ClassificationCrossValidation.create model

    # Read each metric once and reuse it for printing, asserting and messages.
    accuracy = cv.accuracy
    weighted_accuracy = cv.weighted_accuracy
    p accuracy
    p weighted_accuracy

    refute_empty cv.validation_ids
    assert_operator accuracy, :>, 0.8
    # The message is shown on failure, so it states the expectation rather
    # than claiming the weighted value is larger.
    assert_operator weighted_accuracy, :>, accuracy,
                    "Weighted accuracy (#{weighted_accuracy}) should be larger than unweighted accuracy (#{accuracy})"
  end

  # Cross-validates a LazarClassification model built from the hamster
  # carcinogenicity dataset.
  def test_classification_crossvalidation
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::LazarClassification.create dataset # , features
    cv = ClassificationCrossValidation.create model

    accuracy = cv.accuracy
    weighted_accuracy = cv.weighted_accuracy
    p accuracy
    p weighted_accuracy

    assert_operator accuracy, :>, 0.7
    assert_operator weighted_accuracy, :>, accuracy,
                    "Weighted accuracy should be larger than unweighted accuracy."
  end

  # Cross-validates a LazarRegression model built from the EPAFHM dataset and
  # checks RMSE/MAE against fixed upper bounds.
  #
  # The plot is always generated. Set the SHOW_PLOTS environment variable to
  # also open it in inkview; by default no viewer is started, so the test can
  # run unattended and on machines without inkview installed.
  def test_regression_crossvalidation
    # Alternative dataset: "#{DATA_DIR}/EPAFHM.medi.csv"
    dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.csv"
    model = Model::LazarRegression.create dataset
    cv = RegressionCrossValidation.create model

    rmse = cv.rmse
    weighted_rmse = cv.weighted_rmse
    mae = cv.mae
    weighted_mae = cv.weighted_mae
    p rmse
    p weighted_rmse
    p mae
    p weighted_mae

    # Debugging aid: puts JSON.pretty_generate(cv.misclassifications)
    p cv.misclassifications.map { |l| l[:neighbors].size }

    plot = cv.plot
    # Separate arguments keep the plot value away from the shell; system
    # returns nil instead of raising when inkview is not available.
    system("inkview", plot.to_s) if ENV["SHOW_PLOTS"]

    assert_operator rmse, :<, 30, "RMSE > 30"
    assert_operator weighted_rmse, :<, rmse,
                    "Weighted RMSE (#{weighted_rmse}) larger than unweighted RMSE(#{rmse}) "
    assert_operator mae, :<, 12
    assert_operator weighted_mae, :<, mae
  end
end
|