summary refs log tree commit diff
path: root/lib/crossvalidation.rb
diff options
context:
space:
mode:
author Christoph Helma <helma@in-silico.ch> 2016-05-31 18:08:08 +0200
committer Christoph Helma <helma@in-silico.ch> 2016-05-31 18:08:08 +0200
commit b515a0cfedb887a2af753db6e4a08ae1af430cad (patch)
tree 5d69d89d0031d581e932272aeb741ee38a0106d6 /lib/crossvalidation.rb
parent f46ba3b7262f5b551c81fc9396c5b7f0cac7f030 (diff)
cleanup of validation modules/classes
Diffstat (limited to 'lib/crossvalidation.rb')
-rw-r--r-- lib/crossvalidation.rb 251
1 files changed, 77 insertions, 174 deletions
diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb
index 420dd8c..22071d8 100644
--- a/lib/crossvalidation.rb
+++ b/lib/crossvalidation.rb
@@ -1,193 +1,96 @@
module OpenTox
- class CrossValidation
- field :validation_ids, type: Array, default: []
- field :model_id, type: BSON::ObjectId
- field :folds, type: Integer
- field :nr_instances, type: Integer
- field :nr_unpredicted, type: Integer
- field :predictions, type: Hash, default: {}
- field :finished_at, type: Time
-
- def time
- finished_at - created_at
- end
-
- def validations
- validation_ids.collect{|vid| Validation.find vid}
- end
-
- def model
- Model::Lazar.find model_id
- end
-
- def self.create model, n=10
- klass = ClassificationCrossValidation if model.is_a? Model::LazarClassification
- klass = RegressionCrossValidation if model.is_a? Model::LazarRegression
- bad_request_error "Unknown model class #{model.class}." unless klass
-
- cv = klass.new(
- name: model.name,
- model_id: model.id,
- folds: n
- )
- cv.save # set created_at
- nr_instances = 0
- nr_unpredicted = 0
- predictions = {}
- training_dataset = Dataset.find model.training_dataset_id
- training_dataset.folds(n).each_with_index do |fold,fold_nr|
- #fork do # parallel execution of validations can lead to Rserve and memory problems
- $logger.debug "Dataset #{training_dataset.name}: Fold #{fold_nr} started"
- t = Time.now
- validation = Validation.create(model, fold[0], fold[1],cv)
- #p validation
- $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}: #{Time.now-t} seconds"
- #end
- end
- #Process.waitall
- cv.validation_ids = Validation.where(:crossvalidation_id => cv.id).distinct(:_id)
- cv.validations.each do |validation|
- nr_instances += validation.nr_instances
- nr_unpredicted += validation.nr_unpredicted
- predictions.merge! validation.predictions
+ module Validation
+ class CrossValidation < Validation
+ field :validation_ids, type: Array, default: []
+ field :model_id, type: BSON::ObjectId
+ field :folds, type: Integer, default: 10
+ field :nr_instances, type: Integer, default: 0
+ field :nr_unpredicted, type: Integer, default: 0
+ field :predictions, type: Hash, default: {}
+
+ def time
+ finished_at - created_at
end
- cv.update_attributes(
- nr_instances: nr_instances,
- nr_unpredicted: nr_unpredicted,
- predictions: predictions
- )
- $logger.debug "Nr unpredicted: #{nr_unpredicted}"
- cv.statistics
- cv
- end
- end
- class ClassificationCrossValidation < CrossValidation
-
- field :accept_values, type: Array
- field :confusion_matrix, type: Array
- field :weighted_confusion_matrix, type: Array
- field :accuracy, type: Float
- field :weighted_accuracy, type: Float
- field :true_rate, type: Hash
- field :predictivity, type: Hash
- field :confidence_plot_id, type: BSON::ObjectId
- # TODO auc, f-measure (usability??)
-
- def statistics
- stat = ValidationStatistics.classification(predictions, Feature.find(model.prediction_feature_id).accept_values)
- update_attributes(stat)
- stat
- end
+ def validations
+ validation_ids.collect{|vid| TrainTest.find vid}
+ end
- def confidence_plot
- unless confidence_plot_id
- tmpfile = "/tmp/#{id.to_s}_confidence.png"
- accuracies = []
- confidences = []
- correct_predictions = 0
- incorrect_predictions = 0
- predictions.each do |p|
- if p[1] and p[2]
- p[1] == p[2] ? correct_predictions += 1 : incorrect_predictions += 1
- accuracies << correct_predictions/(correct_predictions+incorrect_predictions).to_f
- confidences << p[3]
+ def model
+ Model::Lazar.find model_id
+ end
- end
+ def self.create model, n=10
+ klass = ClassificationCrossValidation if model.is_a? Model::LazarClassification
+ klass = RegressionCrossValidation if model.is_a? Model::LazarRegression
+ bad_request_error "Unknown model class #{model.class}." unless klass
+
+ cv = klass.new(
+ name: model.name,
+ model_id: model.id,
+ folds: n
+ )
+ cv.save # set created_at
+ nr_instances = 0
+ nr_unpredicted = 0
+ predictions = {}
+ training_dataset = Dataset.find model.training_dataset_id
+ training_dataset.folds(n).each_with_index do |fold,fold_nr|
+ #fork do # parallel execution of validations can lead to Rserve and memory problems
+ $logger.debug "Dataset #{training_dataset.name}: Fold #{fold_nr} started"
+ t = Time.now
+ validation = TrainTest.create(model, fold[0], fold[1])
+ cv.validation_ids << validation.id
+ cv.nr_instances += validation.nr_instances
+ cv.nr_unpredicted += validation.nr_unpredicted
+ cv.predictions.merge! validation.predictions
+ $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}: #{Time.now-t} seconds"
+ #end
end
- R.assign "accuracy", accuracies
- R.assign "confidence", confidences
- R.eval "image = qplot(confidence,accuracy)+ylab('accumulated accuracy')+scale_x_reverse()"
- R.eval "ggsave(file='#{tmpfile}', plot=image)"
- file = Mongo::Grid::File.new(File.read(tmpfile), :filename => "#{self.id.to_s}_confidence_plot.png")
- plot_id = $gridfs.insert_one(file)
- update(:confidence_plot_id => plot_id)
+ #Process.waitall
+ cv.save
+ $logger.debug "Nr unpredicted: #{nr_unpredicted}"
+ cv.statistics
+ cv.update_attributes(finished_at: Time.now)
+ cv
end
- $gridfs.find_one(_id: confidence_plot_id).data
- end
-
- #Average area under roc 0.646
- #Area under roc 0.646
- #F measure carcinogen: 0.769, noncarcinogen: 0.348
- end
-
- class RegressionCrossValidation < CrossValidation
-
- field :rmse, type: Float
- field :mae, type: Float
- field :r_squared, type: Float
- field :correlation_plot_id, type: BSON::ObjectId
-
- def statistics
- stat = ValidationStatistics.regression predictions
- update_attributes(stat)
- stat
end
- def misclassifications n=nil
- n ||= 10
- model = Model::Lazar.find(self.model_id)
- training_dataset = Dataset.find(model.training_dataset_id)
- prediction_feature = training_dataset.features.first
- predictions.collect do |p|
- unless p.include? nil
- compound = Compound.find(p[0])
- neighbors = compound.send(model.neighbor_algorithm,model.neighbor_algorithm_parameters)
- neighbors.collect! do |n|
- neighbor = Compound.find(n[0])
- { :smiles => neighbor.smiles, :similarity => n[1], :measurements => neighbor.toxicities[prediction_feature.id.to_s][training_dataset.id.to_s]}
- end
- {
- :smiles => compound.smiles,
- :measured => p[1],
- :predicted => p[2],
- :error => (p[1]-p[2]).abs,
- :relative_error => (p[1]-p[2]).abs/p[1],
- :confidence => p[3],
- :neighbors => neighbors
- }
- end
- end.compact.sort{|a,b| b[:relative_error] <=> a[:relative_error]}[0..n-1]
+ class ClassificationCrossValidation < CrossValidation
+ include ClassificationStatistics
+ field :accept_values, type: Array
+ field :confusion_matrix, type: Array
+ field :weighted_confusion_matrix, type: Array
+ field :accuracy, type: Float
+ field :weighted_accuracy, type: Float
+ field :true_rate, type: Hash
+ field :predictivity, type: Hash
+ field :confidence_plot_id, type: BSON::ObjectId
end
- def confidence_plot
- tmpfile = "/tmp/#{id.to_s}_confidence.png"
- sorted_predictions = predictions.collect{|p| [(p[1]-p[2]).abs,p[3]] if p[1] and p[2]}.compact
- R.assign "error", sorted_predictions.collect{|p| p[0]}
- R.assign "confidence", sorted_predictions.collect{|p| p[1]}
- # TODO fix axis names
- R.eval "image = qplot(confidence,error)"
- R.eval "image = image + stat_smooth(method='lm', se=FALSE)"
- R.eval "ggsave(file='#{tmpfile}', plot=image)"
- file = Mongo::Grid::File.new(File.read(tmpfile), :filename => "#{self.id.to_s}_confidence_plot.png")
- plot_id = $gridfs.insert_one(file)
- update(:confidence_plot_id => plot_id)
- $gridfs.find_one(_id: confidence_plot_id).data
+ class RegressionCrossValidation < CrossValidation
+ include RegressionStatistics
+ field :rmse, type: Float
+ field :mae, type: Float
+ field :r_squared, type: Float
+ field :correlation_plot_id, type: BSON::ObjectId
end
- def correlation_plot
- unless correlation_plot_id
- plot_id = ValidationStatistics.correlation_plot id, predictions
- update(:correlation_plot_id => plot_id)
+ class RepeatedCrossValidation < Validation
+ field :crossvalidation_ids, type: Array, default: []
+ def self.create model, folds=10, repeats=3
+ repeated_cross_validation = self.new
+ repeats.times do |n|
+ $logger.debug "Crossvalidation #{n+1} for #{model.name}"
+ repeated_cross_validation.crossvalidation_ids << CrossValidation.create(model, folds).id
+ end
+ repeated_cross_validation.save
+ repeated_cross_validation
end
- $gridfs.find_one(_id: correlation_plot_id).data
- end
- end
-
- class RepeatedCrossValidation
- field :crossvalidation_ids, type: Array, default: []
- def self.create model, folds=10, repeats=3
- repeated_cross_validation = self.new
- repeats.times do |n|
- $logger.debug "Crossvalidation #{n+1} for #{model.name}"
- repeated_cross_validation.crossvalidation_ids << CrossValidation.create(model, folds).id
+ def crossvalidations
+ crossvalidation_ids.collect{|id| CrossValidation.find(id)}
end
- repeated_cross_validation.save
- repeated_cross_validation
- end
- def crossvalidations
- crossvalidation_ids.collect{|id| CrossValidation.find(id)}
end
end