From d3a4c309d48b794f2f60f44bb9a3d94f402cc82f Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Wed, 16 Sep 2015 13:11:45 +0200 Subject: repeated crossvalidations, improved experiment reports --- lib/experiment.rb | 81 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 41 insertions(+), 40 deletions(-) (limited to 'lib/experiment.rb') diff --git a/lib/experiment.rb b/lib/experiment.rb index 2f51756..7849337 100644 --- a/lib/experiment.rb +++ b/lib/experiment.rb @@ -2,45 +2,22 @@ module OpenTox class Experiment field :dataset_ids, type: Array - field :model_algorithms, type: Array - field :model_ids, type: Array, default: [] - field :crossvalidation_ids, type: Array, default: [] - field :prediction_algorithms, type: Array - field :neighbor_algorithms, type: Array - field :neighbor_algorithm_parameters, type: Array + field :model_settings, type: Array + field :results, type: Hash, default: {} end - # TODO more sophisticated experimental design def run dataset_ids.each do |dataset_id| dataset = Dataset.find(dataset_id) - model_algorithms.each do |model_algorithm| - prediction_algorithms.each do |prediction_algorithm| - neighbor_algorithms.each do |neighbor_algorithm| - neighbor_algorithm_parameters.each do |neighbor_algorithm_parameter| - $logger.debug "Creating #{model_algorithm} model for dataset #{dataset.name}, with prediction_algorithm #{prediction_algorithm}, neighbor_algorithm #{neighbor_algorithm}, neighbor_algorithm_parameters #{neighbor_algorithm_parameter}." - model = Object.const_get(model_algorithm).create dataset - model.prediction_algorithm = prediction_algorithm - model.neighbor_algorithm = neighbor_algorithm - model.neighbor_algorithm_parameters = neighbor_algorithm_parameter - model.save - model_ids << model.id - cv = nil - if dataset.features.first.nominal - cv = ClassificationCrossValidation - elsif dataset.features.first.numeric - cv = RegressionCrossValidation - end - if cv - $logger.debug "Creating #{cv} for #{model_algorithm}, dataset #{dataset.name}, with prediction_algorithm #{prediction_algorithm}, neighbor_algorithm #{neighbor_algorithm}, neighbor_algorithm_parameters #{neighbor_algorithm_parameter}." - crossvalidation = cv.create model - self.crossvalidation_ids << crossvalidation.id - else - $logger.warn "#{dataset.features.first} is neither nominal nor numeric." - end - end - end - end + results[dataset_id.to_s] = [] + model_settings.each do |setting| + model = Object.const_get(setting[:algorithm]).create dataset + model.prediction_algorithm = setting[:prediction_algorithm] if setting[:prediction_algorithm] + model.neighbor_algorithm = setting[:neighbor_algorithm] if setting[:neighbor_algorithm] + model.neighbor_algorithm_parameters = setting[:neighbor_algorithm_parameter] if setting[:neighbor_algorithm_parameter] + model.save + repeated_crossvalidation = RepeatedCrossValidation.create model + results[dataset_id.to_s] << {:model_id => model.id, :repeated_crossvalidation_id => repeated_crossvalidation.id} end end save @@ -54,13 +31,37 @@ module OpenTox end def report - # TODO create ggplot2 report - self.crossvalidation_ids.each do |id| - cv = CrossValidation.find(id) - file = "/tmp/#{id}.svg" - File.open(file,"w+"){|f| f.puts cv.correlation_plot} - `inkview '#{file}'` + # TODO significances + report = {} + report[:name] = name + report[:experiment_id] = self.id.to_s + dataset_ids.each do |dataset_id| + dataset_name = Dataset.find(dataset_id).name + report[dataset_name] = [] + results[dataset_id.to_s].each do |result| + model = Model::Lazar.find(result[:model_id]) + repeated_cv = RepeatedCrossValidation.find(result[:repeated_crossvalidation_id]) + crossvalidations = repeated_cv.crossvalidations + summary = {} + [:neighbor_algorithm, :neighbor_algorithm_parameters, :prediction_algorithm].each do |key| + summary[key] = model[key] + end + summary[:nr_instances] = crossvalidations.first.nr_instances + summary[:nr_unpredicted] = crossvalidations.collect{|cv| cv.nr_unpredicted} + summary[:time] = crossvalidations.collect{|cv| cv.time} + if crossvalidations.first.is_a? ClassificationCrossValidation + summary[:accuracies] = crossvalidations.collect{|cv| cv.accuracy} + elsif crossvalidations.first.is_a? RegressionCrossValidation + summary[:r_squared] = crossvalidations.collect{|cv| cv.r_squared} + end + report[dataset_name] << summary + #p repeated_cv.crossvalidations.collect{|cv| cv.accuracy} + #file = "/tmp/#{id}.svg" + #File.open(file,"w+"){|f| f.puts cv.correlation_plot} + #`inkview '#{file}'` + end end + report end end -- cgit v1.2.3