summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorChristoph Helma <helma@in-silico.ch>2015-10-08 10:32:31 +0200
committerChristoph Helma <helma@in-silico.ch>2015-10-08 10:32:31 +0200
commit6bde559981fa11ffd265af708956f9d4ee6c9a89 (patch)
tree0fdeff56c476bb2eb0e6a2af895a1e9306645904 /test
parentc974ddec27b8e505a8dc22a7c99f2e4b8682aa48 (diff)
crossvalidation plots, original classification confidence
Diffstat (limited to 'test')
-rw-r--r--test/compound.rb13
-rw-r--r--test/setup.rb4
-rw-r--r--test/validation.rb36
3 files changed, 44 insertions, 9 deletions
diff --git a/test/compound.rb b/test/compound.rb
index 036f384..24356d3 100644
--- a/test/compound.rb
+++ b/test/compound.rb
@@ -160,4 +160,17 @@ print c.sdf
end
end
end
+
+ def test_fingerprint_db_neighbors
+ training_dataset = Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.csv")
+ [
+ "CC(=O)CC(C)C#N",
+ "CC(=O)CC(C)C",
+ "C(=O)CC(C)C#N",
+ ].each do |smi|
+ c = OpenTox::Compound.from_smiles smi
+ neighbors = c.db_neighbors(:training_dataset_id => training_dataset.id, :min_sim => 0.2)
+ p neighbors
+ end
+ end
end
diff --git a/test/setup.rb b/test/setup.rb
index 3dad683..ba1b7af 100644
--- a/test/setup.rb
+++ b/test/setup.rb
@@ -3,5 +3,7 @@ require_relative '../lib/lazar.rb'
include OpenTox
TEST_DIR ||= File.expand_path(File.dirname(__FILE__))
DATA_DIR ||= File.join(TEST_DIR,"data")
+Mongoid.configure.connect_to("test")
+$mongo = Mongo::Client.new('mongodb://127.0.0.1:27017/test')
#$mongo.database.drop
-#$gridfs = $mongo.database.fs # recreate GridFS indexes
+$gridfs = $mongo.database.fs
diff --git a/test/validation.rb b/test/validation.rb
index af5ea60..6764a32 100644
--- a/test/validation.rb
+++ b/test/validation.rb
@@ -16,11 +16,35 @@ class ValidationTest < MiniTest::Test
model = Model::LazarClassification.create dataset#, features
cv = ClassificationCrossValidation.create model
assert cv.accuracy > 0.7
+ File.open("tmp.svg","w+"){|f| f.puts cv.confidence_plot}
+ `inkview tmp.svg`
p cv.nr_unpredicted
p cv.accuracy
#assert cv.weighted_accuracy > cv.accuracy, "Weighted accuracy should be larger than unweighted accuracy."
end
+ def test_default_regression_crossvalidation
+ dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv"
+ model = Model::LazarRegression.create dataset
+ cv = RegressionCrossValidation.create model
+ #cv = RegressionCrossValidation.find '561503262b72ed54fd000001'
+ p cv.id
+ File.open("tmp.svg","w+"){|f| f.puts cv.correlation_plot}
+ `inkview tmp.svg`
+ File.open("tmp.svg","w+"){|f| f.puts cv.confidence_plot}
+ `inkview tmp.svg`
+
+ #puts cv.misclassifications.to_yaml
+ p cv.rmse
+ p cv.weighted_rmse
+ assert cv.rmse < 1.5, "RMSE > 1.5"
+ #assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) "
+ p cv.mae
+ p cv.weighted_mae
+ assert cv.mae < 1
+ #assert cv.weighted_mae < cv.mae
+ end
+
def test_regression_crossvalidation
dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.medi.csv"
#dataset = Dataset.from_csv_file "#{DATA_DIR}/EPAFHM.csv"
@@ -41,13 +65,8 @@ class ValidationTest < MiniTest::Test
refute_equal params[:neighbor_algorithm_parameters][:training_dataset_id], model[:neighbor_algorithm_parameters][:training_dataset_id]
end
- #`inkview #{cv.plot}`
- #puts JSON.pretty_generate(cv.misclassifications)#.collect{|l| l.join ", "}.join "\n"
- #`inkview #{cv.plot}`
- assert cv.rmse < 30, "RMSE > 30"
- #assert cv.weighted_rmse < cv.rmse, "Weighted RMSE (#{cv.weighted_rmse}) larger than unweighted RMSE(#{cv.rmse}) "
- assert cv.mae < 12
- #assert cv.weighted_mae < cv.mae
+ assert cv.rmse < 1.5, "RMSE > 30"
+ assert cv.mae < 1
end
def test_repeated_crossvalidation
@@ -55,7 +74,8 @@ class ValidationTest < MiniTest::Test
model = Model::LazarClassification.create dataset
repeated_cv = RepeatedCrossValidation.create model
repeated_cv.crossvalidations.each do |cv|
- assert cv.accuracy > 0.7
+ assert_operator cv.accuracy, :>, 0.7, "model accuracy < 0.7, this may happen by chance due to an unfavorable training/test set split"
+ assert_operator cv.weighted_accuracy, :>, cv.accuracy
end
end