summaryrefslogtreecommitdiff
path: root/bin/crossvalidation-folds.rb
blob: 0c765f7667a3d2f65c1771fab32da14c2a552392 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#!/usr/bin/env ruby
require "fileutils"
require_relative "../lib/lazar"

# Splits the training data of a model into cross-validation folds.
#
# Usage: crossvalidation-folds.rb MODEL [FOLDS]
#
#   MODEL  argument handed to Model.new; output is written below
#          File.dirname(MODEL)/crossvalidation/<fold number>/
#   FOLDS  number of folds (default: 10)
#
# Three files are written per fold:
#   train.csv              every row that is not part of the fold's test set
#   test.csv               the test rows with column 1 removed
#   test-experimental.csv  the first two columns of the test rows
#
# Rows are written under the header "Canonical SMILES, <dependent variable>,
# <independent variables...>", so column 1 is treated as the dependent variable.
# NOTE(review): rows are assumed to be Array-like (dup, delete_at, join) and
# model.train is assumed to return the same data on every call - confirm
# against lib/lazar.

# Writes +header+ and then every entry of +rows+ to +path+, one comma-joined
# line each. An existing file is overwritten.
def write_csv(path, header, rows)
  File.open(path, "w") do |f|
    f.puts header.join(",")
    rows.each { |row| f.puts row.join(",") }
  end
end

abort "Usage: #{File.basename($PROGRAM_NAME)} MODEL [FOLDS]" unless ARGV[0]

model = Model.new ARGV[0]
folds = ARGV[1] ? ARGV[1].to_i : 10
# to_i turns non-numeric input into 0, which would otherwise surface as a
# ZeroDivisionError further down.
abort "Number of folds must be a positive integer (got #{ARGV[1].inspect})" unless folds > 0

train = model.train
nr_instances = train.size
indices = (0...nr_instances).to_a.shuffle
dependent = model.dependent_variable_name
independent = model.independent_variable_names

# The first `remainder` folds receive one extra instance so that every
# instance ends up in exactly one test set.
base_size, remainder = nr_instances.divmod(folds)

start = 0
folds.times do |fold|
  # The slice has to be computed in the parent: a forked child works on a copy
  # of the parent's variables, so advancing `start` inside the child would
  # never reach the next iteration.
  size = fold < remainder ? base_size + 1 : base_size
  test_idxs = indices[start, size]
  start += size

  fork do
    puts "Creating fold #{fold}"
    cv_dir = File.join(File.dirname(ARGV[0]), "crossvalidation", fold.to_s)
    FileUtils.mkdir_p cv_dir

    train_rows = (indices - test_idxs).map { |idx| train[idx] }
    test_rows = test_idxs.map { |idx| train[idx] }

    write_csv File.join(cv_dir, "train.csv"),
              ["Canonical SMILES", dependent] + independent,
              train_rows

    # Drop column 1 from a copy; the untouched rows are still needed for the
    # experimental file below.
    feature_rows = test_rows.map do |row|
      copy = row.dup
      copy.delete_at(1)
      copy
    end
    write_csv File.join(cv_dir, "test.csv"),
              ["Canonical SMILES"] + independent,
              feature_rows

    write_csv File.join(cv_dir, "test-experimental.csv"),
              ["Canonical SMILES", dependent],
              test_rows.map { |row| row[0..1] }
  end
end

# Wait in the parent until every fold has been written and report failures.
failed = Process.waitall.reject { |_pid, status| status.success? }
abort "#{failed.size} fold(s) could not be written" unless failed.empty?