summaryrefslogtreecommitdiff
path: root/lib/experiment.rb
diff options
context:
space:
mode:
authorChristoph Helma <helma@in-silico.ch>2015-08-25 17:20:55 +0200
committerChristoph Helma <helma@in-silico.ch>2015-08-25 17:20:55 +0200
commitf8faf510b4574df1a00fa61a9f0a1681fc2f4857 (patch)
treeacdbe6666ca5f528be368c6f9fdf4d7fb51d031e /lib/experiment.rb
parent8c6c59980bc82dc2177147f2fe34adf8bfbc1539 (diff)
Experiments added
Diffstat (limited to 'lib/experiment.rb')
-rw-r--r--lib/experiment.rb66
1 files changed, 66 insertions, 0 deletions
diff --git a/lib/experiment.rb b/lib/experiment.rb
new file mode 100644
index 0000000..b3ed174
--- /dev/null
+++ b/lib/experiment.rb
@@ -0,0 +1,66 @@
+module OpenTox
+
+ class Experiment
+ field :dataset_ids, type: Array
+ field :model_algorithms, type: Array
+ field :model_ids, type: Array, default: []
+ field :crossvalidation_ids, type: Array, default: []
+ field :prediction_algorithms, type: Array
+ field :neighbor_algorithms, type: Array
+ field :neighbor_algorithm_parameters, type: Array
+ end
+
+ # TODO more sophisticated experimental design
+ def run
+ dataset_ids.each do |dataset_id|
+ dataset = Dataset.find(dataset_id)
+ model_algorithms.each do |model_algorithm|
+ prediction_algorithms.each do |prediction_algorithm|
+ neighbor_algorithms.each do |neighbor_algorithm|
+ neighbor_algorithm_parameters.each do |neighbor_algorithm_parameter|
+ $logger.debug "Creating #{model_algorithm} model for dataset #{dataset.name}, with prediction_algorithm #{prediction_algorithm}, neighbor_algorithm #{neighbor_algorithm}, neighbor_algorithm_parameters #{neighbor_algorithm_parameter}."
+ model = Object.const_get(model_algorithm).create dataset
+ model.prediction_algorithm = prediction_algorithm
+ model.neighbor_algorithm = neighbor_algorithm
+ model.neighbor_algorithm_parameters = neighbor_algorithm_parameter
+ model.save
+ model_ids << model.id
+ cv = nil
+ if dataset.features.first.nominal
+ cv = ClassificationCrossValidation
+ elsif dataset.features.first.numeric
+ cv = RegressionCrossValidation
+ end
+ if cv
+ $logger.debug "Creating #{cv} for #{model_algorithm}, dataset #{dataset.name}, with prediction_algorithm #{prediction_algorithm}, neighbor_algorithm #{neighbor_algorithm}, neighbor_algorithm_parameters #{neighbor_algorithm_parameter}."
+ crossvalidation = cv.create model
+ crossvalidation_ids << crossvalidation.id
+ else
+ $logger.warn "#{dataset.features.first} is neither nominal nor numeric."
+ end
+ end
+ end
+ end
+ end
+ end
+ save
+ end
+
+ def self.create params
+ experiment = self.new
+ $logge.debug "Experiment started ..."
+ experiment.run params
+ experiment
+ end
+
+ def report
+ crossvalidation_ids.each do |id|
+ cv = CrossValidation.find(id)
+ file = "/tmp/#{cv.name}.svg"
+ File.open(file,"w+"){|f| f.puts cv.correlation_plot}
+ `inkview '#{file}'`
+ #p Crossvalidation.find(id).correlation_plot
+ end
+ end
+
+end