summaryrefslogtreecommitdiff
path: root/lib/experiment.rb
blob: 191e76e29c5f7b99d4ab5beb2c69c49fe5a064de (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
module OpenTox

  # An experiment that builds one model for every combination of dataset,
  # model algorithm, prediction algorithm, neighbor algorithm and neighbor
  # algorithm parameter set, crossvalidates each model when the dataset's
  # first feature is nominal or numeric, and records the ids it created.
  #
  # NOTE(review): `field` is the Mongoid document DSL, so this class is
  # presumably given Mongoid::Document elsewhere in the project — confirm.
  class Experiment
    field :dataset_ids, type: Array
    field :model_algorithms, type: Array
    field :model_ids, type: Array, default: []
    field :crossvalidation_ids, type: Array, default: []
    field :prediction_algorithms, type: Array
    field :neighbor_algorithms, type: Array
    field :neighbor_algorithm_parameters, type: Array

    # Builds a new experiment from +params+ and runs it immediately.
    #
    # @param params [Hash] initial field values, e.g. dataset_ids,
    #   model_algorithms, prediction_algorithms, neighbor_algorithms and
    #   neighbor_algorithm_parameters
    # @return [Experiment] the experiment (saved at the end of #run)
    def self.create params
      experiment = self.new params
      $logger.debug "Experiment started ..."
      experiment.run
      experiment
    end

    # Creates and saves a model for every combination of the configured
    # settings. It also creates a ClassificationCrossValidation (nominal first
    # feature) or RegressionCrossValidation (numeric first feature) for each
    # model. New ids are appended to model_ids / crossvalidation_ids, and the
    # experiment is saved once all combinations are done.
    #
    # TODO more sophisticated experimental design
    def run
      dataset_ids.each do |dataset_id|
        dataset = Dataset.find(dataset_id)
        # The crossvalidation type depends only on the dataset, so decide it
        # (and warn about an unsupported feature type) once per dataset.
        cv = crossvalidation_class dataset
        model_algorithms.each do |model_algorithm|
          prediction_algorithms.each do |prediction_algorithm|
            neighbor_algorithms.each do |neighbor_algorithm|
              neighbor_algorithm_parameters.each do |neighbor_algorithm_parameter|
                settings = "with prediction_algorithm #{prediction_algorithm}, neighbor_algorithm #{neighbor_algorithm}, neighbor_algorithm_parameters #{neighbor_algorithm_parameter}."
                $logger.debug "Creating #{model_algorithm} model for dataset #{dataset.name}, #{settings}"
                # NOTE(review): const_get resolves any constant name; this is
                # only safe while model_algorithms comes from trusted config.
                model = Object.const_get(model_algorithm).create dataset
                model.prediction_algorithm = prediction_algorithm
                model.neighbor_algorithm = neighbor_algorithm
                model.neighbor_algorithm_parameters = neighbor_algorithm_parameter
                model.save
                model_ids << model.id
                # Models are still recorded even when no crossvalidation fits.
                next unless cv
                $logger.debug "Creating #{cv} for #{model_algorithm}, dataset #{dataset.name}, #{settings}"
                crossvalidation = cv.create model
                crossvalidation_ids << crossvalidation.id
              end
            end
          end
        end
      end
      save
    end

    # Writes the correlation plot of each recorded crossvalidation to
    # /tmp/<id>.svg and opens it with the external `inkview` program,
    # waiting for the viewer to exit before moving on. Crossvalidations
    # without a correlation_plot are skipped with a warning.
    #
    # TODO create ggplot2 report
    def report
      crossvalidation_ids.each do |id|
        cv = CrossValidation.find(id)
        unless cv.respond_to? :correlation_plot
          $logger.warn "#{cv.class} #{id} has no correlation plot, skipping."
          next
        end
        file = "/tmp/#{id}.svg"
        File.open(file, "w+") { |f| f.puts cv.correlation_plot }
        # Pass the path as its own argument so no shell ever parses it.
        system "inkview", file
      end
    end

    private

    # Picks the crossvalidation class for +dataset+ from its first feature.
    # Returns ClassificationCrossValidation if that feature is nominal and
    # RegressionCrossValidation if it is numeric. Otherwise, including a
    # dataset without features, it logs a warning and returns nil.
    def crossvalidation_class dataset
      feature = dataset.features.first
      if feature.nil?
        $logger.warn "Dataset #{dataset.name} has no features."
        nil
      elsif feature.nominal
        ClassificationCrossValidation
      elsif feature.numeric
        RegressionCrossValidation
      else
        $logger.warn "#{feature} is neither nominal nor numeric."
        nil
      end
    end
  end

end