summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristoph Helma <helma@in-silico.ch>2016-06-01 10:37:00 +0200
committerChristoph Helma <helma@in-silico.ch>2016-06-01 10:37:00 +0200
commit65b69d4c35890a7a2d2992108f0cf4eb5202dd1b (patch)
treed8583d3e745cc21a6897081d28e157f5c0c16024
parentb515a0cfedb887a2af753db6e4a08ae1af430cad (diff)
validation tests fixed
-rw-r--r--lib/crossvalidation.rb24
-rw-r--r--lib/leave-one-out-validation.rb1
-rw-r--r--lib/model.rb3
-rw-r--r--lib/validation-statistics.rb19
-rw-r--r--lib/validation.rb6
-rw-r--r--test/all.rb2
-rw-r--r--test/validation.rb2
7 files changed, 24 insertions, 33 deletions
diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb
index 22071d8..15e25a5 100644
--- a/lib/crossvalidation.rb
+++ b/lib/crossvalidation.rb
@@ -3,23 +3,7 @@ module OpenTox
module Validation
class CrossValidation < Validation
field :validation_ids, type: Array, default: []
- field :model_id, type: BSON::ObjectId
field :folds, type: Integer, default: 10
- field :nr_instances, type: Integer, default: 0
- field :nr_unpredicted, type: Integer, default: 0
- field :predictions, type: Hash, default: {}
-
- def time
- finished_at - created_at
- end
-
- def validations
- validation_ids.collect{|vid| TrainTest.find vid}
- end
-
- def model
- Model::Lazar.find model_id
- end
def self.create model, n=10
klass = ClassificationCrossValidation if model.is_a? Model::LazarClassification
@@ -55,6 +39,14 @@ module OpenTox
cv.update_attributes(finished_at: Time.now)
cv
end
+
+ def time
+ finished_at - created_at
+ end
+
+ def validations
+ validation_ids.collect{|vid| TrainTest.find vid}
+ end
end
class ClassificationCrossValidation < CrossValidation
diff --git a/lib/leave-one-out-validation.rb b/lib/leave-one-out-validation.rb
index 7ff65ff..59f43c5 100644
--- a/lib/leave-one-out-validation.rb
+++ b/lib/leave-one-out-validation.rb
@@ -49,7 +49,6 @@ module OpenTox
field :mae, type: Float, default: 0
field :r_squared, type: Float
field :correlation_plot_id, type: BSON::ObjectId
- field :confidence_plot_id, type: BSON::ObjectId
end
end
diff --git a/lib/model.rb b/lib/model.rb
index 988cac9..81f9629 100644
--- a/lib/model.rb
+++ b/lib/model.rb
@@ -33,7 +33,6 @@ module OpenTox
#send(feature_selection_algorithm.to_sym) if feature_selection_algorithm
save
- self
end
def correlation_filter
@@ -203,7 +202,7 @@ module OpenTox
}.each do |key,value|
model.neighbor_algorithm_parameters[key] ||= value
end
- model.neighbor_algorithm_parameters[:type] = "MP2D" if training_dataset.substances.first.is_a? Compound
+ model.neighbor_algorithm_parameters[:type] ||= "MP2D" if training_dataset.substances.first.is_a? Compound
model.save
model
end
diff --git a/lib/validation-statistics.rb b/lib/validation-statistics.rb
index 816824b..e42d298 100644
--- a/lib/validation-statistics.rb
+++ b/lib/validation-statistics.rb
@@ -98,8 +98,8 @@ module OpenTox
def statistics
# TODO: predictions within prediction_interval
- rmse = 0
- mae = 0
+ self.rmse = 0
+ self.mae = 0
x = []
y = []
predictions.each do |cid,pred|
@@ -107,8 +107,8 @@ module OpenTox
x << pred[:measurements].median
y << pred[:value]
error = pred[:value]-pred[:measurements].median
- rmse += error**2
- mae += error.abs
+ self.rmse += error**2
+ self.mae += error.abs
else
warnings << "No training activities for #{Compound.find(compound_id).smiles} in training dataset #{model.training_dataset_id}."
$logger.debug "No training activities for #{Compound.find(compound_id).smiles} in training dataset #{model.training_dataset_id}."
@@ -117,17 +117,18 @@ module OpenTox
R.assign "measurement", x
R.assign "prediction", y
R.eval "r <- cor(measurement,prediction,use='pairwise')"
- r = R.eval("r").to_ruby
+ self.r_squared = R.eval("r").to_ruby**2
- mae = mae/predictions.size
- rmse = Math.sqrt(rmse/predictions.size)
- $logger.debug "R^2 #{r**2}"
+ self.mae = self.mae/predictions.size
+ self.rmse = Math.sqrt(self.rmse/predictions.size)
+ $logger.debug "R^2 #{r_squared}"
$logger.debug "RMSE #{rmse}"
$logger.debug "MAE #{mae}"
+ save
{
:mae => mae,
:rmse => rmse,
- :r_squared => r**2,
+ :r_squared => r_squared,
}
end
diff --git a/lib/validation.rb b/lib/validation.rb
index ff9a971..ced9596 100644
--- a/lib/validation.rb
+++ b/lib/validation.rb
@@ -9,9 +9,9 @@ module OpenTox
store_in collection: "validations"
field :name, type: String
field :model_id, type: BSON::ObjectId
- field :nr_instances, type: Integer
- field :nr_unpredicted, type: Integer
- field :predictions, type: Hash
+ field :nr_instances, type: Integer, default: 0
+ field :nr_unpredicted, type: Integer, default: 0
+ field :predictions, type: Hash, default: {}
field :finished_at, type: Time
def model
diff --git a/test/all.rb b/test/all.rb
index a10bcaa..8e137b4 100644
--- a/test/all.rb
+++ b/test/all.rb
@@ -1,5 +1,5 @@
# "./default_environment.rb" has to be executed separately
-exclude = ["./setup.rb","./all.rb", "./default_environment.rb","./nanoparticles.rb"]
+exclude = ["./setup.rb","./all.rb", "./default_environment.rb",]
(Dir[File.join(File.dirname(__FILE__),"*.rb")]-exclude).each do |test|
require_relative test
end
diff --git a/test/validation.rb b/test/validation.rb
index a259472..4d0c372 100644
--- a/test/validation.rb
+++ b/test/validation.rb
@@ -59,6 +59,7 @@ class ValidationTest < MiniTest::Test
}
}
model = Model::LazarRegression.create dataset.features.first, dataset, params
+ assert_equal params[:neighbor_algorithm_parameters][:type], model[:neighbor_algorithm_parameters][:type]
cv = RegressionCrossValidation.create model
cv.validation_ids.each do |vid|
model = Model::Lazar.find(Validation.find(vid).model_id)
@@ -74,7 +75,6 @@ class ValidationTest < MiniTest::Test
end
def test_physchem_regression_crossvalidation
-
training_dataset = OpenTox::Dataset.from_csv_file File.join(DATA_DIR,"EPAFHM.medi_log10.csv")
model = Model::LazarRegression.create(training_dataset.features.first, training_dataset, :prediction_algorithm => "OpenTox::Algorithm::Regression.local_physchem_regression")
cv = RegressionCrossValidation.create model