summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristoph Helma <helma@in-silico.ch>2017-01-11 09:20:40 +0100
committerChristoph Helma <helma@in-silico.ch>2017-01-11 09:20:40 +0100
commit04ebe0640ab6e566dfc316f80a020d1e78b10924 (patch)
tree945f89f33fed9f868318e678af029f1c491eb2e2
parented0d7edee4ac9831b58a01555de8bdba3534495e (diff)
validation documentation
-rw-r--r--lib/crossvalidation.rb22
-rw-r--r--lib/model.rb12
-rw-r--r--lib/validation-statistics.rb19
-rw-r--r--lib/validation.rb3
4 files changed, 54 insertions, 2 deletions
diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb
index bcb3ccf..75c5db5 100644
--- a/lib/crossvalidation.rb
+++ b/lib/crossvalidation.rb
@@ -1,10 +1,16 @@
module OpenTox
module Validation
+
+ # Crossvalidation
class CrossValidation < Validation
field :validation_ids, type: Array, default: []
field :folds, type: Integer, default: 10
+ # Create a crossvalidation
+ # @param [OpenTox::Model::Lazar] model
+ # @param [Fixnum] n number of folds
+ # @return [OpenTox::Validation::CrossValidation]
def self.create model, n=10
$logger.debug model.algorithms
klass = ClassificationCrossValidation if model.is_a? Model::LazarClassification
@@ -41,14 +47,20 @@ module OpenTox
cv
end
+ # Get execution time
+ # @return [Float] seconds elapsed (finished_at - created_at)
def time
  # Elapsed span: completion timestamp minus creation timestamp.
  completed = finished_at
  started = created_at
  completed - started
end
+ # Get individual validations
+ # @return [Array<OpenTox::Validation>]
def validations
  # Resolve each stored validation id to its TrainTest record, keeping the id order.
  validation_ids.map do |validation_id|
    TrainTest.find(validation_id)
  end
end
+ # Get predictions for all compounds
+ # @return [Array<Hash>]
def predictions
predictions = {}
validations.each{|v| predictions.merge!(v.predictions)}
@@ -56,6 +68,7 @@ module OpenTox
end
end
+ # Crossvalidation of classification models
class ClassificationCrossValidation < CrossValidation
include ClassificationStatistics
field :accept_values, type: Array
@@ -68,6 +81,7 @@ module OpenTox
field :probability_plot_id, type: BSON::ObjectId
end
+ # Crossvalidation of regression models
class RegressionCrossValidation < CrossValidation
include RegressionStatistics
field :rmse, type: Float, default:0
@@ -78,10 +92,16 @@ module OpenTox
field :correlation_plot_id, type: BSON::ObjectId
end
+ # Independent repeated crossvalidations
class RepeatedCrossValidation < Validation
field :crossvalidation_ids, type: Array, default: []
field :correlation_plot_id, type: BSON::ObjectId
+ # Create repeated crossvalidations
+ # @param [OpenTox::Model::Lazar] model
+ # @param [Fixnum] folds number of folds
+ # @param [Fixnum] repeats number of repeats
+ # @return [OpenTox::Validation::RepeatedCrossValidation]
def self.create model, folds=10, repeats=3
repeated_cross_validation = self.new
repeats.times do |n|
@@ -92,6 +112,8 @@ module OpenTox
repeated_cross_validation
end
+ # Get crossvalidations
+ # @return [Array<OpenTox::Validation::CrossValidation>]
def crossvalidations
  # Look up one CrossValidation per stored id, in the order the ids are stored.
  crossvalidation_ids.map do |cv_id|
    CrossValidation.find(cv_id)
  end
end
diff --git a/lib/model.rb b/lib/model.rb
index 64edb76..b18610d 100644
--- a/lib/model.rb
+++ b/lib/model.rb
@@ -320,7 +320,9 @@ module OpenTox
end
- def save # store independent_variables in GridFS to avoid Mongo database size limit problems
+ # Save the model
+ # Stores independent_variables in GridFS to avoid Mongo database size limit problems
+ def save
file = Mongo::Grid::File.new(Marshal.dump(@independent_variables), :filename => "#{id}.independent_variables")
self.independent_variables_id = $gridfs.insert_one(file)
super
@@ -357,6 +359,8 @@ module OpenTox
substance_ids.collect{|id| Substance.find(id)}
end
+ # Are fingerprints used as descriptors
+ # @return [TrueClass, FalseClass]
def fingerprints?
  # True when the configured descriptor method is "fingerprint".
  # The comparison already evaluates to true/false, so the former
  # `? true : false` ternary was redundant and has been dropped.
  algorithms[:descriptors][:method] == "fingerprint"
end
@@ -428,10 +432,14 @@ module OpenTox
repeated_crossvalidation.crossvalidations
end
+ # Is it a regression model
+ # @return [TrueClass, FalseClass]
def regression?
  # Type check on the object returned by `model`.
  candidate = model
  candidate.is_a?(LazarRegression)
end
+ # Is it a classification model
+ # @return [TrueClass, FalseClass]
def classification?
  # Type check on the object returned by `model`.
  candidate = model
  candidate.is_a?(LazarClassification)
end
@@ -452,7 +460,7 @@ module OpenTox
end
# Create and validate a nano-lazar model, import data from eNanoMapper if necessary
- # nano-lazar methods are described in detail in https://github.com/enanomapper/nano-lazar-paper/blob/master/nano-lazar.pdf
+ # nano-lazar methods are described in detail in https://github.com/enanomapper/nano-lazar-paper/blob/master/nano-lazar.pdf
# @param [OpenTox::Dataset, nil] training_dataset
# @param [OpenTox::Feature, nil] prediction_feature
# @param [Hash, nil] algorithms
diff --git a/lib/validation-statistics.rb b/lib/validation-statistics.rb
index 2202b79..553e6ac 100644
--- a/lib/validation-statistics.rb
+++ b/lib/validation-statistics.rb
@@ -1,7 +1,10 @@
module OpenTox
module Validation
+ # Statistical evaluation of classification validations
module ClassificationStatistics
+ # Get statistics
+ # @return [Hash]
def statistics
self.accept_values = model.prediction_feature.accept_values
self.confusion_matrix = Array.new(accept_values.size){Array.new(accept_values.size,0)}
@@ -63,6 +66,9 @@ module OpenTox
}
end
+ # Plot accuracy vs prediction probability
+ # @param [String,nil] format
+ # @return [Blob]
def probability_plot format: "pdf"
#unless probability_plot_id
@@ -99,8 +105,11 @@ module OpenTox
end
end
+ # Statistical evaluation of regression validations
module RegressionStatistics
+ # Get statistics
+ # @return [Hash]
def statistics
self.rmse = 0
self.mae = 0
@@ -147,10 +156,15 @@ module OpenTox
}
end
+ # Get percentage of measurements within the prediction interval
+ # @return [Float]
def percent_within_prediction_interval
  # Share of counted measurements that fall inside the prediction interval,
  # expressed as a percentage. The numerator is converted to Float before
  # dividing, so a zero total yields NaN rather than raising.
  inside = within_prediction_interval
  total = inside + out_of_prediction_interval
  100 * inside.to_f / total
end
+ # Plot predicted vs measured values
+ # @param [String,nil] format
+ # @return [Blob]
def correlation_plot format: "png"
unless correlation_plot_id
tmpfile = "/tmp/#{id.to_s}_correlation.#{format}"
@@ -177,6 +191,11 @@ module OpenTox
$gridfs.find_one(_id: correlation_plot_id).data
end
+ # Get predictions with the largest difference between predicted and measured values
+ # @param [Fixnum] n number of predictions
+ # @param [TrueClass,FalseClass,nil] show_neigbors include neighbors
+ # @param [TrueClass,FalseClass,nil] show_common_descriptors show common descriptors
+ # @return [Hash]
def worst_predictions n: 5, show_neigbors: true, show_common_descriptors: false
worst_predictions = predictions.sort_by{|sid,p| -(p["value"] - p["measurements"].median).abs}[0,n]
worst_predictions.collect do |p|
diff --git a/lib/validation.rb b/lib/validation.rb
index ced9596..c9954b6 100644
--- a/lib/validation.rb
+++ b/lib/validation.rb
@@ -2,6 +2,7 @@ module OpenTox
module Validation
+ # Base validation class
class Validation
include OpenTox
include Mongoid::Document
@@ -14,6 +15,8 @@ module OpenTox
field :predictions, type: Hash, default: {}
field :finished_at, type: Time
+ # Get model
+ # @return [OpenTox::Model::Lazar]
def model
  # Fetch the Model::Lazar record whose id is stored in model_id.
  Model::Lazar.find(model_id)
end