summary | refs | log | tree | commit | diff
path: root/lib/experiment.rb
diff options
context:
space:
mode:
Diffstat (limited to 'lib/experiment.rb')
-rw-r--r-- lib/experiment.rb 81
1 files changed, 41 insertions, 40 deletions
diff --git a/lib/experiment.rb b/lib/experiment.rb
index 2f51756..7849337 100644
--- a/lib/experiment.rb
+++ b/lib/experiment.rb
@@ -2,45 +2,22 @@ module OpenTox
class Experiment
field :dataset_ids, type: Array
- field :model_algorithms, type: Array
- field :model_ids, type: Array, default: []
- field :crossvalidation_ids, type: Array, default: []
- field :prediction_algorithms, type: Array
- field :neighbor_algorithms, type: Array
- field :neighbor_algorithm_parameters, type: Array
+ field :model_settings, type: Array
+ field :results, type: Hash, default: {}
end
- # TODO more sophisticated experimental design
def run
dataset_ids.each do |dataset_id|
dataset = Dataset.find(dataset_id)
- model_algorithms.each do |model_algorithm|
- prediction_algorithms.each do |prediction_algorithm|
- neighbor_algorithms.each do |neighbor_algorithm|
- neighbor_algorithm_parameters.each do |neighbor_algorithm_parameter|
- $logger.debug "Creating #{model_algorithm} model for dataset #{dataset.name}, with prediction_algorithm #{prediction_algorithm}, neighbor_algorithm #{neighbor_algorithm}, neighbor_algorithm_parameters #{neighbor_algorithm_parameter}."
- model = Object.const_get(model_algorithm).create dataset
- model.prediction_algorithm = prediction_algorithm
- model.neighbor_algorithm = neighbor_algorithm
- model.neighbor_algorithm_parameters = neighbor_algorithm_parameter
- model.save
- model_ids << model.id
- cv = nil
- if dataset.features.first.nominal
- cv = ClassificationCrossValidation
- elsif dataset.features.first.numeric
- cv = RegressionCrossValidation
- end
- if cv
- $logger.debug "Creating #{cv} for #{model_algorithm}, dataset #{dataset.name}, with prediction_algorithm #{prediction_algorithm}, neighbor_algorithm #{neighbor_algorithm}, neighbor_algorithm_parameters #{neighbor_algorithm_parameter}."
- crossvalidation = cv.create model
- self.crossvalidation_ids << crossvalidation.id
- else
- $logger.warn "#{dataset.features.first} is neither nominal nor numeric."
- end
- end
- end
- end
+ results[dataset_id.to_s] = []
+ model_settings.each do |setting|
+ model = Object.const_get(setting[:algorithm]).create dataset
+ model.prediction_algorithm = setting[:prediction_algorithm] if setting[:prediction_algorithm]
+ model.neighbor_algorithm = setting[:neighbor_algorithm] if setting[:neighbor_algorithm]
+ model.neighbor_algorithm_parameters = setting[:neighbor_algorithm_parameter] if setting[:neighbor_algorithm_parameter]
+ model.save
+ repeated_crossvalidation = RepeatedCrossValidation.create model
+ results[dataset_id.to_s] << {:model_id => model.id, :repeated_crossvalidation_id => repeated_crossvalidation.id}
end
end
save
@@ -54,13 +31,37 @@ module OpenTox
end
def report
- # TODO create ggplot2 report
- self.crossvalidation_ids.each do |id|
- cv = CrossValidation.find(id)
- file = "/tmp/#{id}.svg"
- File.open(file,"w+"){|f| f.puts cv.correlation_plot}
- `inkview '#{file}'`
+ # TODO significances
+ report = {}
+ report[:name] = name
+ report[:experiment_id] = self.id.to_s
+ dataset_ids.each do |dataset_id|
+ dataset_name = Dataset.find(dataset_id).name
+ report[dataset_name] = []
+ results[dataset_id.to_s].each do |result|
+ model = Model::Lazar.find(result[:model_id])
+ repeated_cv = RepeatedCrossValidation.find(result[:repeated_crossvalidation_id])
+ crossvalidations = repeated_cv.crossvalidations
+ summary = {}
+ [:neighbor_algorithm, :neighbor_algorithm_parameters, :prediction_algorithm].each do |key|
+ summary[key] = model[key]
+ end
+ summary[:nr_instances] = crossvalidations.first.nr_instances
+ summary[:nr_unpredicted] = crossvalidations.collect{|cv| cv.nr_unpredicted}
+ summary[:time] = crossvalidations.collect{|cv| cv.time}
+ if crossvalidations.first.is_a? ClassificationCrossValidation
+ summary[:accuracies] = crossvalidations.collect{|cv| cv.accuracy}
+ elsif crossvalidations.first.is_a? RegressionCrossValidation
+ summary[:r_squared] = crossvalidations.collect{|cv| cv.r_squared}
+ end
+ report[dataset_name] << summary
+ #p repeated_cv.crossvalidations.collect{|cv| cv.accuracy}
+ #file = "/tmp/#{id}.svg"
+ #File.open(file,"w+"){|f| f.puts cv.correlation_plot}
+ #`inkview '#{file}'`
+ end
end
+ report
end
end