summaryrefslogtreecommitdiff
path: root/lib/crossvalidation.rb
diff options
context:
space:
mode:
author Christoph Helma <helma@in-silico.ch> 2015-08-25 17:20:55 +0200
committer Christoph Helma <helma@in-silico.ch> 2015-08-25 17:20:55 +0200
commit f8faf510b4574df1a00fa61a9f0a1681fc2f4857 (patch)
tree acdbe6666ca5f528be368c6f9fdf4d7fb51d031e /lib/crossvalidation.rb
parent 8c6c59980bc82dc2177147f2fe34adf8bfbc1539 (diff)
Experiments added
Diffstat (limited to 'lib/crossvalidation.rb')
-rw-r--r--lib/crossvalidation.rb109
1 files changed, 85 insertions, 24 deletions
diff --git a/lib/crossvalidation.rb b/lib/crossvalidation.rb
index 5af75bf..4407aeb 100644
--- a/lib/crossvalidation.rb
+++ b/lib/crossvalidation.rb
@@ -102,6 +102,8 @@ module OpenTox
field :mae, type: Float
field :weighted_rmse, type: Float
field :weighted_mae, type: Float
+ # id under which correlation_plot stores the rendered SVG in $gridfs
+ field :correlation_plot_id, type: BSON::ObjectId
def self.create model, n=10
cv = self.new
@@ -135,10 +137,11 @@ module OpenTox
weighted_rae = 0
n = 0
confidence_sum = 0
+ nil_activities = []
predictions.each do |pred|
compound_id,activity,prediction,confidence = pred
if activity and prediction
- error = prediction-activity
+ error = Math.log(prediction)-Math.log(activity)
rmse += error**2
weighted_rmse += confidence*error**2
mae += error.abs
@@ -147,13 +150,36 @@ module OpenTox
confidence_sum += confidence
else
# TODO: create warnings
- p pred
+ $logger.debug "No training activities for #{Compound.find(compound_id).smiles} in training dataset #{training_dataset.id}."
+ nil_activities << pred
end
end
+ predictions -= nil_activities
+ x = predictions.collect{|p| p[1]}
+ y = predictions.collect{|p| p[2]}
+ R.assign "Measurement", x
+ R.assign "Prediction", y
+ R.eval "corr <- lm(-log(Measurement) ~ -log(Prediction))"
+ s = R.eval "summary <- summary(corr)"
+ p R.eval("summary$r.squared").to_ruby
+ #p s.to_ruby
+ #p s.to_ruby.first
+ s.to_ruby.each_with_index do |l,i|
+ #p i
+ #p l
+ end
mae = mae/n
weighted_mae = weighted_mae/confidence_sum
rmse = Math.sqrt(rmse/n)
weighted_rmse = Math.sqrt(weighted_rmse/confidence_sum)
+ # TODO check!!
+ predictions.sort! do |a,b|
+ relative_error_a = (a[1]-a[2]).abs/a[1].to_f
+ relative_error_a = 1/relative_error_a if relative_error_a < 1
+ relative_error_b = (b[1]-b[2]).abs/b[1].to_f
+ relative_error_b = 1/relative_error_b if relative_error_b < 1
+ [relative_error_b,b[3]] <=> [relative_error_a,a[3]]
+ end
cv.update_attributes(
name: model.name,
model_id: model.id,
@@ -161,7 +187,7 @@ module OpenTox
validation_ids: validation_ids,
nr_instances: nr_instances,
nr_unpredicted: nr_unpredicted,
- predictions: predictions.sort{|a,b| b[3] <=> a[3]},
+ predictions: predictions,#.sort{|a,b| [(b[1]-b[2]).abs/b[1].to_f,b[3]] <=> [(a[1]-a[2]).abs/a[1].to_f,a[3]]},
mae: mae,
rmse: rmse,
weighted_mae: weighted_mae,
@@ -171,27 +197,62 @@ module OpenTox
cv
end
- def plot
- # RMSE
- x = predictions.collect{|p| p[1]}
- y = predictions.collect{|p| p[2]}
- R.assign "Measurement", x
- R.assign "Prediction", y
- R.eval "par(pty='s')" # sets the plot type to be square
- #R.eval "fitline <- lm(log(Prediction) ~ log(Measurement))"
- #R.eval "error <- log(Measurement)-log(Prediction)"
- R.eval "error <- Measurement-Prediction"
- R.eval "rmse <- sqrt(mean(error^2,na.rm=T))"
- R.eval "mae <- mean( abs(error), na.rm = TRUE)"
- R.eval "r <- cor(log(Prediction),log(Measurement))"
- R.eval "svg(filename='/tmp/#{id.to_s}.svg')"
- R.eval "plot(log(Prediction),log(Measurement),main='#{self.name}', sub=paste('RMSE: ',rmse, 'MAE :',mae, 'r^2: ',r^2),asp=1)"
- #R.eval "plot(log(Prediction),log(Measurement),main='#{self.name}', sub=paste('RMSE: ',rmse, 'MAE :',mae, 'r^2: '),asp=1)"
- #R.eval "plot(log(Prediction),log(Measurement),main='#{self.name}', ,asp=1)"
- R.eval "abline(0,1,col='blue')"
- #R.eval "abline(fitline,col='red')"
- R.eval "dev.off()"
- "/tmp/#{id.to_s}.svg"
+ # Describes the first n stored predictions together with their neighbors.
+ # NOTE(review): create appears to store predictions sorted by relative error
+ # (largest first), so these should be the worst predictions - confirm.
+ # @param n [Integer, nil] number of predictions to describe (20 when nil);
+ #   a negative n raises ArgumentError
+ # @return [Array<Hash>] one hash per prediction with the keys :smiles,
+ #   :fingerprint, :measured, :predicted, :relative_error, :confidence and
+ #   :neighbors (an array of hashes with :smiles, :fingerprint, :similarity
+ #   and :measurements)
+ def misclassifications n=nil
+   n = 20 unless n
+   model = Model::Lazar.find(self.model_id)
+   training_dataset = Dataset.find(model.training_dataset_id)
+   prediction_feature = training_dataset.features.first
+   # first(n) rather than [0..n-1]: for n == 0 the range 0..-1 would select all predictions
+   predictions.first(n).collect do |prediction|
+     compound_id, measured, predicted, confidence = prediction
+     compound = Compound.find(compound_id)
+     # block variable deliberately not called n, to avoid shadowing the method parameter
+     neighbors = compound.neighbors.collect do |neighbor_entry|
+       neighbor = Compound.find(neighbor_entry[0])
+       values = training_dataset.values(neighbor,prediction_feature)
+       {
+         :smiles => neighbor.smiles,
+         :fingerprint => neighbor.fp4.collect{|id| Smarts.find(id).name},
+         :similarity => neighbor_entry[1],
+         :measurements => values
+       }
+     end
+     {
+       :smiles => compound.smiles,
+       :fingerprint => compound.fp4.collect{|id| Smarts.find(id).name},
+       :measured => measured,
+       :predicted => predicted,
+       :relative_error => (measured-predicted).abs/measured.to_f,
+       :confidence => confidence,
+       :neighbors => neighbors
+     }
+   end
+ end
+
+ # Returns the data of the correlation plot (SVG) for this crossvalidation.
+ # On the first call the plot is rendered with R (ggplot2, grid, gridExtra)
+ # into a temporary file, stored in $gridfs and its id saved in
+ # correlation_plot_id; later calls only read the stored file.
+ # The plot shows -log(prediction) against -log(measurement) next to a text
+ # panel with RMSE, MAE, r^2 and the model's remaining attribute values.
+ # @return the data of the stored GridFS file
+ def correlation_plot
+   unless correlation_plot_id
+     tmpfile = "/tmp/#{id.to_s}.svg"
+     x = predictions.collect{|p| p[1]}
+     y = predictions.collect{|p| p[2]}
+     # Escapes backslashes and single quotes so text can be embedded in a single-quoted R string
+     r_escape = lambda{|text| text.to_s.gsub(/[\\']/){|c| "\\#{c}"}}
+     # reject instead of delete_if: leaves the loaded model's attribute hash unmodified
+     model_attributes = Model::Lazar.find(self.model_id).attributes.reject{|key,_| key.match(/_id|_at/) or ["_id","creator","name"].include? key}
+     description = model_attributes.values.collect{|v| v.is_a?(String) ? v.sub(/OpenTox::/,'') : v}.join("\n")
+     R.eval "library(ggplot2)"
+     R.eval "library(grid)"
+     R.eval "library(gridExtra)"
+     R.assign "measurement", x
+     R.assign "prediction", y
+     R.eval "r <- cor(-log(prediction),-log(measurement))"
+     R.eval "svg(filename='#{tmpfile}')"
+     # common axis range so that the identity line runs along the diagonal
+     R.eval "all = c(-log(measurement),-log(prediction))"
+     R.eval "range = c(min(all), max(all))"
+     R.eval "image = qplot(-log(prediction),-log(measurement),main='#{r_escape.call(self.name)}',asp=1,xlim=range, ylim=range)"
+     R.eval "image = image + geom_abline(intercept=0, slope=1) + stat_smooth(method='lm', se=FALSE)"
+     R.eval "text = textGrob(paste('RMSE: ', '#{rmse.round(2)},','MAE:','#{mae.round(2)},','r^2: ',round(r^2,2),'\n\n','#{r_escape.call(description)}'),just=c('left','top'),check.overlap = T)"
+     R.eval "grid.arrange(image, text, ncol=2)"
+     R.eval "dev.off()"
+     svg = File.read(tmpfile)
+     # the plot is kept in GridFS, the temporary file is no longer needed
+     File.delete(tmpfile)
+     file = Mongo::Grid::File.new(svg, :filename => "#{self.id.to_s}_correlation_plot.svg")
+     plot_id = $gridfs.insert_one(file)
+     update(:correlation_plot_id => plot_id)
+   end
+   $gridfs.find_one(_id: correlation_plot_id).data
end
end