From ab652ac85036c5b372e7f1a08cdb75a19db5b19a Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Sun, 8 May 2016 12:57:10 +0200 Subject: regression crossvalidation fixed --- lib/compound.rb | 5 ++++- lib/leave-one-out-validation.rb | 6 +++--- test/validation.rb | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/compound.rb b/lib/compound.rb index 3af6f6c..0a9111b 100644 --- a/lib/compound.rb +++ b/lib/compound.rb @@ -288,7 +288,10 @@ module OpenTox training_dataset.compounds.each do |compound| candidate_fingerprint = compound.fingerprint params[:type] sim = (query_fingerprint & candidate_fingerprint).size/(query_fingerprint | candidate_fingerprint).size.to_f - neighbors << {"_id" => compound.id, "toxicities" => {prediction_feature.id.to_s => {training_dataset_id.to_s => compound.toxicities[prediction_feature.id.to_s][training_dataset_id.to_s]}}, "tanimoto" => sim} if sim >= params[:min_sim] + fid = prediction_feature.id.to_s + did = params[:training_dataset_id].to_s + v = compound.toxicities[prediction_feature.id.to_s] + neighbors << {"_id" => compound.id, "toxicities" => {fid => {did => v[params[:training_dataset_id].to_s]}}, "tanimoto" => sim} if sim >= params[:min_sim] and v end neighbors.sort!{|a,b| b["tanimoto"] <=> a["tanimoto"]} end diff --git a/lib/leave-one-out-validation.rb b/lib/leave-one-out-validation.rb index 2306041..7189617 100644 --- a/lib/leave-one-out-validation.rb +++ b/lib/leave-one-out-validation.rb @@ -3,7 +3,6 @@ module OpenTox class LeaveOneOutValidation field :model_id, type: BSON::ObjectId - field :dataset_id, type: BSON::ObjectId field :nr_instances, type: Integer field :nr_unpredicted, type: Integer field :predictions, type: Hash @@ -13,13 +12,14 @@ module OpenTox $logger.debug "#{model.name}: LOO validation started" t = Time.now model.training_dataset.features.first.nominal? ? klass = ClassificationLeaveOneOutValidation : klass = RegressionLeaveOneOutValidation - loo = klass.new :model_id => model.id, :dataset_id => model.training_dataset_id + loo = klass.new :model_id => model.id predictions = model.predict model.training_dataset.compounds predictions.each{|cid,p| p.delete(:neighbors)} nr_unpredicted = 0 predictions.each do |cid,prediction| if prediction[:value] - prediction[:measured] = Substance.find(cid).toxicities[prediction[:prediction_feature_id].to_s][dataset_id.to_s] + tox = Substance.find(cid).toxicities[prediction[:prediction_feature_id].to_s] + prediction[:measured] = tox[model.training_dataset_id.to_s] if tox else nr_unpredicted += 1 end diff --git a/test/validation.rb b/test/validation.rb index 021fac5..8ebb52c 100644 --- a/test/validation.rb +++ b/test/validation.rb @@ -25,7 +25,6 @@ class ValidationTest < MiniTest::Test def test_classification_crossvalidation_parameters dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv" params = { - :training_dataset_id => dataset.id, :neighbor_algorithm_parameters => { :min_sim => 0.3, :type => "FP3" @@ -56,6 +55,7 @@ class ValidationTest < MiniTest::Test } } model = Model::LazarRegression.create dataset.features.first, dataset, params + p model cv = RegressionCrossValidation.create model cv.validation_ids.each do |vid| model = Model::Lazar.find(Validation.find(vid).model_id) -- cgit v1.2.3