summary | refs | log | tree | commit | diff
path: root/lib
diff options
context:
space:
mode:
author Christoph Helma <helma@in-silico.ch> 2015-08-26 14:20:23 +0200
committer Christoph Helma <helma@in-silico.ch> 2015-08-26 14:20:23 +0200
commit d542e9fe92567c54423f39904111bd5293236416 (patch)
tree 68d04fe73e7012a2732a15703b25f5934c7e7dad /lib
parent f8faf510b4574df1a00fa61a9f0a1681fc2f4857 (diff)
Parallel Crossvalidations
Diffstat (limited to 'lib')
-rw-r--r-- lib/crossvalidation.rb 81
-rw-r--r-- lib/dataset.rb 18
-rw-r--r-- lib/experiment.rb 4
-rw-r--r-- lib/lazar.rb 2
-rw-r--r-- lib/opentox.rb 1
-rw-r--r-- lib/validation.rb 5
6 files changed, 61 insertions, 50 deletions
diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb
index 4407aeb..58a9664 100644
--- a/lib/crossvalidation.rb
+++ b/lib/crossvalidation.rb
@@ -6,13 +6,16 @@ module OpenTox
field :folds, type: Integer
field :nr_instances, type: Integer
field :nr_unpredicted, type: Integer
- field :predictions, type: Array
+ field :predictions, type: Array, default: []
field :finished_at, type: Time
- #belongs_to :prediction
def time
finished_at - created_at
end
+
+ def validations
+ validation_ids.collect{|vid| Validation.find vid}
+ end
end
class ClassificationCrossValidation < CrossValidation
@@ -45,7 +48,7 @@ module OpenTox
t = Time.now
$logger.debug "Fold #{fold_nr}"
validation = validation_class.create(model, fold[0], fold[1])
- validation_ids << validation.id
+ #validation_ids << validation.id
nr_instances += validation.nr_instances
nr_unpredicted += validation.nr_unpredicted
predictions += validation.predictions
@@ -74,7 +77,7 @@ module OpenTox
name: model.name,
model_id: model.id,
folds: n,
- validation_ids: validation_ids,
+ #validation_ids: validation_ids,
nr_instances: nr_instances,
nr_unpredicted: nr_unpredicted,
accept_values: accept_values,
@@ -103,29 +106,33 @@ module OpenTox
field :weighted_rmse, type: Float
field :weighted_mae, type: Float
field :weighted_mae, type: Float
+ field :r_squared, type: Float
field :correlation_plot_id, type: BSON::ObjectId
def self.create model, n=10
cv = self.new
cv.save # set created_at
- validation_ids = []
+ #validation_ids = []
nr_instances = 0
nr_unpredicted = 0
predictions = []
validation_class = Object.const_get(self.to_s.sub(/Cross/,''))
fold_nr = 1
training_dataset = Dataset.find model.training_dataset_id
- training_dataset.folds(n).each do |fold|
- t = Time.now
- $logger.debug "Predicting fold #{fold_nr}"
-
- validation = validation_class.create(model, fold[0], fold[1])
- validation_ids << validation.id
+ training_dataset.folds(n).each_with_index do |fold,fold_nr|
+ fork do # parallel execution of validations
+ $logger.debug "Dataset #{training_dataset.name}: Fold #{fold_nr} started"
+ t = Time.now
+ validation = validation_class.create(model, fold[0], fold[1],cv)
+ $logger.debug "Dataset #{training_dataset.name}, Fold #{fold_nr}: #{Time.now-t} seconds"
+ end
+ end
+ Process.waitall
+ cv.validation_ids = Validation.where(:crossvalidation_id => cv.id).distinct(:_id)
+ cv.validations.each do |validation|
nr_instances += validation.nr_instances
nr_unpredicted += validation.nr_unpredicted
predictions += validation.predictions
- $logger.debug "Fold #{fold_nr}: #{Time.now-t} seconds"
- fold_nr +=1
end
rmse = 0
weighted_rmse = 0
@@ -135,9 +142,8 @@ module OpenTox
weighted_mae = 0
rae = 0
weighted_rae = 0
- n = 0
confidence_sum = 0
- nil_activities = []
+ #nil_activities = []
predictions.each do |pred|
compound_id,activity,prediction,confidence = pred
if activity and prediction
@@ -146,34 +152,29 @@ module OpenTox
weighted_rmse += confidence*error**2
mae += error.abs
weighted_mae += confidence*error.abs
- n += 1
confidence_sum += confidence
+ cv.predictions << pred
else
# TODO: create warnings
+ cv.warnings << "No training activities for #{Compound.find(compound_id).smiles} in training dataset #{training_dataset.id}."
$logger.debug "No training activities for #{Compound.find(compound_id).smiles} in training dataset #{training_dataset.id}."
- nil_activities << pred
+ #nil_activities << pred
end
end
- predictions -= nil_activities
- x = predictions.collect{|p| p[1]}
- y = predictions.collect{|p| p[2]}
- R.assign "Measurement", x
- R.assign "Prediction", y
- R.eval "corr <- lm(-log(Measurement) ~ -log(Prediction))"
- s = R.eval "summary <- summary(corr)"
- p R.eval("summary$r.squared").to_ruby
- #p s.to_ruby
- #p s.to_ruby.first
- s.to_ruby.each_with_index do |l,i|
- #p i
- #p l
- end
- mae = mae/n
+ #predictions -= nil_activities
+ x = cv.predictions.collect{|p| p[1]}
+ y = cv.predictions.collect{|p| p[2]}
+ R.assign "measurement", x
+ R.assign "prediction", y
+ R.eval "r <- cor(-log(measurement),-log(prediction))"
+ r = R.eval("r").to_ruby
+
+ mae = mae/cv.predictions.size
weighted_mae = weighted_mae/confidence_sum
- rmse = Math.sqrt(rmse/n)
+ rmse = Math.sqrt(rmse/cv.predictions.size)
weighted_rmse = Math.sqrt(weighted_rmse/confidence_sum)
# TODO check!!
- predictions.sort! do |a,b|
+ cv.predictions.sort! do |a,b|
relative_error_a = (a[1]-a[2]).abs/a[1].to_f
relative_error_a = 1/relative_error_a if relative_error_a < 1
relative_error_b = (b[1]-b[2]).abs/b[1].to_f
@@ -184,14 +185,15 @@ module OpenTox
name: model.name,
model_id: model.id,
folds: n,
- validation_ids: validation_ids,
+ #validation_ids: validation_ids,
nr_instances: nr_instances,
nr_unpredicted: nr_unpredicted,
- predictions: predictions,#.sort{|a,b| [(b[1]-b[2]).abs/b[1].to_f,b[3]] <=> [(a[1]-a[2]).abs/a[1].to_f,a[3]]},
+ #predictions: predictions,#.sort{|a,b| [(b[1]-b[2]).abs/b[1].to_f,b[3]] <=> [(a[1]-a[2]).abs/a[1].to_f,a[3]]},
mae: mae,
rmse: rmse,
weighted_mae: weighted_mae,
- weighted_rmse: weighted_rmse
+ weighted_rmse: weighted_rmse,
+ r_squared: r**2
)
cv.save
cv
@@ -239,19 +241,20 @@ module OpenTox
#R.eval "error <- log(Measurement)-log(Prediction)"
#R.eval "rmse <- sqrt(mean(error^2, na.rm=T))"
#R.eval "mae <- mean(abs(error), na.rm=T)"
- R.eval "r <- cor(-log(prediction),-log(measurement))"
+ #R.eval "r <- cor(-log(prediction),-log(measurement))"
R.eval "svg(filename='#{tmpfile}')"
R.eval "all = c(-log(measurement),-log(prediction))"
R.eval "range = c(min(all), max(all))"
R.eval "image = qplot(-log(prediction),-log(measurement),main='#{self.name}',asp=1,xlim=range, ylim=range)"
R.eval "image = image + geom_abline(intercept=0, slope=1) + stat_smooth(method='lm', se=FALSE)"
- R.eval "text = textGrob(paste('RMSE: ', '#{rmse.round(2)},','MAE:','#{mae.round(2)},','r^2: ',round(r^2,2),'\n\n','#{attributes}'),just=c('left','top'),check.overlap = T)"
+ R.eval "text = textGrob(paste('RMSE: ', '#{rmse.round(2)},','MAE:','#{mae.round(2)},','r^2: ','#{r_squared.round(2)}','\n\n','#{attributes}'),just=c('left','top'),check.overlap = T)"
R.eval "grid.arrange(image, text, ncol=2)"
R.eval "dev.off()"
file = Mongo::Grid::File.new(File.read(tmpfile), :filename => "#{self.id.to_s}_correlation_plot.svg")
plot_id = $gridfs.insert_one(file)
update(:correlation_plot_id => plot_id)
end
+ p correlation_plot_id
$gridfs.find_one(_id: correlation_plot_id).data
end
end
diff --git a/lib/dataset.rb b/lib/dataset.rb
index b3f5392..979753c 100644
--- a/lib/dataset.rb
+++ b/lib/dataset.rb
@@ -12,7 +12,6 @@ module OpenTox
field :compound_ids, type: Array, default: []
field :data_entries_id, type: BSON::ObjectId#, default: []
field :source, type: String
- field :warnings, type: Array, default: []
# Save all data including data_entries
# Should be used instead of save
@@ -21,7 +20,6 @@ module OpenTox
file = Mongo::Grid::File.new(dump, :filename => "#{self.id.to_s}.data_entries")
entries_id = $gridfs.insert_one(file)
update(:data_entries_id => entries_id)
- #save
end
# Readers
@@ -50,7 +48,7 @@ module OpenTox
bad_request_error "Data entries (#{data_entries_id}) are not a 2D-Array" unless @data_entries.is_a? Array and @data_entries.first.is_a? Array
bad_request_error "Data entries (#{data_entries_id}) have #{@data_entries.size} rows, but dataset (#{id}) has #{compound_ids.size} compounds" unless @data_entries.size == compound_ids.size
bad_request_error "Data entries (#{data_entries_id}) have #{@data_entries.first.size} columns, but dataset (#{id}) has #{feature_ids.size} features" unless @data_entries.first.size == feature_ids.size
- $logger.debug "Retrieving data: #{Time.now-t}"
+ #$logger.debug "Retrieving data: #{Time.now-t}"
end
end
@data_entries
@@ -149,11 +147,17 @@ module OpenTox
# Create a dataset from CSV file
# TODO: document structure
def self.from_csv_file file, source=nil, bioassay=true
- $logger.debug "Parsing #{file}."
source ||= file
- table = CSV.read file, :skip_blanks => true
- dataset = self.new(:source => source, :name => File.basename(file,".*"))
- dataset.parse_table table, bioassay
+ name = File.basename(file,".*")
+ dataset = self.find_by(:source => source, :name => name)
+ if dataset
+ $logger.debug "#{file} already in database."
+ else
+ $logger.debug "Parsing #{file}."
+ table = CSV.read file, :skip_blanks => true
+ dataset = self.new(:source => source, :name => name)
+ dataset.parse_table table, bioassay
+ end
dataset
end
diff --git a/lib/experiment.rb b/lib/experiment.rb
index b3ed174..191e76e 100644
--- a/lib/experiment.rb
+++ b/lib/experiment.rb
@@ -54,12 +54,12 @@ module OpenTox
end
def report
+ # TODO create ggplot2 report
crossvalidation_ids.each do |id|
cv = CrossValidation.find(id)
- file = "/tmp/#{cv.name}.svg"
+ file = "/tmp/#{id}.svg"
File.open(file,"w+"){|f| f.puts cv.correlation_plot}
`inkview '#{file}'`
- #p Crossvalidation.find(id).correlation_plot
end
end
diff --git a/lib/lazar.rb b/lib/lazar.rb
index 5903556..decbe69 100644
--- a/lib/lazar.rb
+++ b/lib/lazar.rb
@@ -9,12 +9,12 @@ require 'rserve'
require "nokogiri"
require "base64"
-
# Mongo setup
# TODO retrieve correct environment from Rack/Sinatra
ENV["MONGOID_ENV"] ||= "development"
# TODO remove config files, change default via ENV or directly in Mongoid class
Mongoid.load!("#{File.expand_path(File.join(File.dirname(__FILE__),'..','mongoid.yml'))}")
+Mongoid.raise_not_found_error = false # return nil if no document is found
$mongo = Mongoid.default_client
$gridfs = $mongo.database.fs
diff --git a/lib/opentox.rb b/lib/opentox.rb
index 53b34e9..875487c 100644
--- a/lib/opentox.rb
+++ b/lib/opentox.rb
@@ -13,6 +13,7 @@ module OpenTox
include Mongoid::Timestamps
store_in collection: klass.downcase.pluralize
field :name, type: String
+ field :warnings, type: Array, default: []
end
OpenTox.const_set klass,c
diff --git a/lib/validation.rb b/lib/validation.rb
index bcbe49a..445f897 100644
--- a/lib/validation.rb
+++ b/lib/validation.rb
@@ -1,8 +1,10 @@
module OpenTox
class Validation
+ #include Celluloid
field :prediction_dataset_id, type: BSON::ObjectId
+ field :crossvalidation_id, type: BSON::ObjectId
field :test_dataset_id, type: BSON::ObjectId
field :nr_instances, type: Integer
field :nr_unpredicted, type: Integer
@@ -81,7 +83,7 @@ module OpenTox
end
class RegressionValidation < Validation
- def self.create model, training_set, test_set
+ def self.create model, training_set, test_set, crossvalidation=nil
validation_model = Model::LazarRegression.create training_set
test_set_without_activities = Dataset.new(:compound_ids => test_set.compound_ids) # just to be sure that activities cannot be used
@@ -106,6 +108,7 @@ module OpenTox
:nr_unpredicted => nr_unpredicted,
:predictions => predictions.sort{|a,b| b[3] <=> a[3]} # sort according to confidence
)
+ validation.crossvalidation_id = crossvalidation.id if crossvalidation
validation.save
validation
end