summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Christoph Helma <helma@in-silico.ch> 2015-08-09 13:43:07 +0200
committer Christoph Helma <helma@in-silico.ch> 2015-08-09 13:43:07 +0200
commit db21f62399669be741e4e692278344e16cb39580 (patch)
tree 7c882382c70fbfc9dfeb563a59c01def2da9a935
parent 2dc460ee78325c731ab12b375c2bd12e4e3393e8 (diff)
customized prediction algorithms implemented [mongodb]
-rw-r--r-- test/compound.rb 1
-rw-r--r-- test/validation.rb 36
2 files changed, 33 insertions, 4 deletions
diff --git a/test/compound.rb b/test/compound.rb
index 2bf0204..d1b1520 100644
--- a/test/compound.rb
+++ b/test/compound.rb
@@ -85,6 +85,7 @@ class CompoundTest < MiniTest::Test
d.compounds.each do |c|
refute_nil c.fp4
end
+ # TODO assert neighbor size
c = d.compounds[371]
#p c
p c.neighbors
diff --git a/test/validation.rb b/test/validation.rb
index 80055f2..9348c6f 100644
--- a/test/validation.rb
+++ b/test/validation.rb
@@ -2,12 +2,40 @@ require_relative "setup.rb"
class ValidationTest < MiniTest::Test
+ def test_fminer_crossvalidation
+ dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
+ model = Model::LazarFminerClassification.create dataset#, features
+ cv = ClassificationCrossValidation.create model
+ p cv.accuracy
+ p cv.weighted_accuracy
+ assert cv.accuracy > 0.8
+ assert cv.weighted_accuracy > cv.accuracy, "Weighted accuracy (#{cv.weighted_accuracy}) larger than unweighted accuracy(#{cv.accuracy}) "
+ end
+
def test_classification_crossvalidation
dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
- features = Algorithm::Fminer.bbrc dataset
- model = Model::Lazar.create dataset, features
- cv = CrossValidation.create model
- p cv
+ model = Model::LazarClassification.create dataset#, features
+ cv = ClassificationCrossValidation.create model
+ p cv.accuracy
+ p cv.weighted_accuracy
+ assert cv.accuracy > 0.7
+ assert cv.weighted_accuracy > cv.accuracy
+ end
+
+ def test_regression_crossvalidation
+ dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv"
+ #dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.csv"
+ model = Model::LazarRegression.create dataset
+ cv = RegressionCrossValidation.create model
+ p cv.rmse
+ p cv.weighted_rmse
+ p cv.mae
+ p cv.weighted_mae
+ `inkview #{cv.plot}`
+ assert cv.rmse < 30, "RMSE > 30"
+ assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) "
+ assert cv.mae < 12
+ assert cv.weighted_mae < cv.mae
end
end