summaryrefslogtreecommitdiff
path: root/bin/crossvalidation-folds.rb
blob: 16a410387112ebede630c571f9e585b862e8e281 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env ruby
require "fileutils"
require_relative "../lib/lazar"
# Splits a model's training data into k cross-validation folds and writes a
# train.csv / test.csv pair for every fold to
#   <directory of MODEL>/crossvalidation/<fold number>/
#
# Usage: crossvalidation-folds.rb MODEL [FOLDS]
#   MODEL  argument handed to Model.new; output is written next to it
#   FOLDS  number of folds (default: 10)
#
# Each fold is written by its own child process. The parent waits for all
# children and exits non-zero if any of them failed.

abort "Usage: #{$PROGRAM_NAME} MODEL [FOLDS]" unless ARGV[0]

model = Model.new ARGV[0]
folds = ARGV[1] ? ARGV[1].to_i : 10
# String#to_i yields 0 for non-numeric input, which would divide by zero below.
abort "Number of folds must be a positive integer (got #{ARGV[1].inspect})" unless folds > 0

nr_instances = model.train.size
indices = (0...nr_instances).to_a.shuffle
fold_size, remainder = nr_instances.divmod(folds)

# Partition the shuffled indices into disjoint test sets. This has to happen
# in the parent: a variable updated inside a forked child is never seen by
# the parent or by sibling children, so a running offset kept inside the fork
# block would stay at 0 for every fold.
# The first `remainder` folds receive one extra instance.
offset = 0
test_sets = Array.new(folds) do |fold|
  size = fold < remainder ? fold_size + 1 : fold_size
  slice = indices[offset, size]
  offset += size
  slice
end

cv_root = File.join(File.dirname(ARGV[0]), "crossvalidation")
# Rows of model.train are written as-is under the train header, so they are
# presumably laid out as [SMILES, dependent value, independent values...];
# verify against Model.
train_header = (["Canonical SMILES", model.dependent_variable_name] + model.independent_variable_names).join(",")
test_header = (["Canonical SMILES"] + model.independent_variable_names).join(",")

test_sets.each_with_index do |test_idxs, fold|
  fork do
    puts "Creating fold #{fold}"
    idxs = { train: indices - test_idxs, test: test_idxs }
    p idxs

    cv_dir = File.join(cv_root, fold.to_s)
    # FileUtils instead of a shell call: safe for paths containing spaces or
    # shell metacharacters.
    FileUtils.mkdir_p cv_dir

    File.open(File.join(cv_dir, "train.csv"), "w") do |f|
      f.puts train_header
      idxs[:train].each { |idx| f.puts model.train[idx].join(",") }
    end

    File.open(File.join(cv_dir, "test.csv"), "w") do |f|
      f.puts test_header
      idxs[:test].each do |idx|
        # Drop column 1 (the dependent variable) on a copy so the model's
        # rows are not modified.
        row = model.train[idx].dup
        row.delete_at(1)
        f.puts row.join(",")
      end
    end
  end
end

# Wait in the parent so the script only returns once every fold is on disk.
failed = Process.waitall.reject { |_pid, status| status.success? }
abort "#{failed.size} of #{folds} folds could not be written" unless failed.empty?