diff options
Diffstat (limited to 'validation/validation_service.rb')
-rwxr-xr-x | validation/validation_service.rb | 65 |
1 files changed, 30 insertions, 35 deletions
diff --git a/validation/validation_service.rb b/validation/validation_service.rb index 889c652..dceead9 100755 --- a/validation/validation_service.rb +++ b/validation/validation_service.rb @@ -2,6 +2,7 @@ require "lib/validation_db.rb" require "lib/ot_predictions.rb" +require "lib/r-util.rb" require "validation/validation_format.rb" @@ -618,17 +619,17 @@ module Validation # splits a dataset into test and training dataset # returns map with training_dataset_uri and test_dataset_uri - def self.train_test_dataset_split( orig_dataset_uri, prediction_feature, subjectid, split_ratio=nil, random_seed=nil, task=nil ) + def self.train_test_dataset_split( orig_dataset_uri, prediction_feature, subjectid, stratified=false, split_ratio=nil, random_seed=nil, task=nil ) split_ratio=0.67 unless split_ratio split_ratio = split_ratio.to_f random_seed=1 unless random_seed random_seed = random_seed.to_i + raise OpenTox::NotFoundError.new "Split ratio invalid: "+split_ratio.to_s unless split_ratio and split_ratio=split_ratio.to_f + raise OpenTox::NotFoundError.new "Split ratio not >0 and <1 :"+split_ratio.to_s unless split_ratio>0 && split_ratio<1 orig_dataset = Lib::DatasetCache.find orig_dataset_uri, subjectid orig_dataset.load_all subjectid raise OpenTox::NotFoundError.new "Dataset not found: "+orig_dataset_uri.to_s unless orig_dataset - raise OpenTox::NotFoundError.new "Split ratio invalid: "+split_ratio.to_s unless split_ratio and split_ratio=split_ratio.to_f - raise OpenTox::NotFoundError.new "Split ratio not >0 and <1 :"+split_ratio.to_s unless split_ratio>0 && split_ratio<1 if prediction_feature raise OpenTox::NotFoundError.new "Prediction feature '"+prediction_feature.to_s+ "' not found in dataset, features are: \n"+ @@ -637,55 +638,49 @@ module Validation LOGGER.warn "no prediciton feature given, all features included in test dataset" end - compounds = orig_dataset.compounds - raise OpenTox::BadRequestError.new "Cannot split datset, num compounds in dataset < 2 ("+compounds.size.to_s+")" if compounds.size<2 - split = (compounds.size*split_ratio).to_i - split = [split,1].max - split = [split,compounds.size-2].min - - LOGGER.debug "splitting dataset "+orig_dataset_uri+ + if stratified + Lib::RUtil.init_r + df = Lib::RUtil.dataset_to_dataframe( orig_dataset ) + split = Lib::RUtil.stratified_split( df, split_ratio, random_seed ) + Lib::RUtil.quit_r + raise "internal error" unless split.size==orig_dataset.compounds.size + task.progress(33) if task + + training_compounds = [] + split.size.times{|i| training_compounds << orig_dataset.compounds[i] if split[i]==1} + test_compounds = orig_dataset.compounds - training_compounds + else + compounds = orig_dataset.compounds + raise OpenTox::BadRequestError.new "Cannot split datset, num compounds in dataset < 2 ("+compounds.size.to_s+")" if compounds.size<2 + split = (compounds.size*split_ratio).to_i + split = [split,1].max + split = [split,compounds.size-2].min + LOGGER.debug "splitting dataset "+orig_dataset_uri+ " into train:0-"+split.to_s+" and test:"+(split+1).to_s+"-"+(compounds.size-1).to_s+ " (shuffled with seed "+random_seed.to_s+")" - compounds.shuffle!( random_seed ) + compounds.shuffle!( random_seed ) + training_compounds = compounds[0..split] + test_compounds = compounds[(split+1)..-1] + end task.progress(33) if task - - result = {} -# result[:training_dataset_uri] = orig_dataset.create_new_dataset( compounds[0..split], -# orig_dataset.features, -# "Training dataset split of "+orig_dataset.title.to_s, -# $sinatra.url_for('/training_test_split',:full) ) -# orig_dataset.data_entries.each do |k,v| -# puts k.inspect+" =>"+v.inspect -# puts v.values[0].to_s+" "+v.values[0].class.to_s -# end + result = {} - result[:training_dataset_uri] = orig_dataset.split( compounds[0..split], + result[:training_dataset_uri] = orig_dataset.split( training_compounds, orig_dataset.features.keys, { DC.title => "Training dataset split of "+orig_dataset.title.to_s, DC.creator => $url_provider.url_for('/training_test_split',:full) }, subjectid ).uri task.progress(66) if task -# d = Lib::DatasetCache.find(result[:training_dataset_uri]) -# d.data_entries.values.each do |v| -# puts v.inspect -# puts v.values[0].to_s+" "+v.values[0].class.to_s -# end -# raise "stop here" - -# result[:test_dataset_uri] = orig_dataset.create_new_dataset( compounds[(split+1)..-1], -# orig_dataset.features.dclone - [prediction_feature], -# "Test dataset split of "+orig_dataset.title.to_s, -# $sinatra.url_for('/training_test_split',:full) ) - result[:test_dataset_uri] = orig_dataset.split( compounds[(split+1)..-1], + result[:test_dataset_uri] = orig_dataset.split( test_compounds, orig_dataset.features.keys.dclone - [prediction_feature], { DC.title => "Test dataset split of "+orig_dataset.title.to_s, DC.creator => $url_provider.url_for('/training_test_split',:full) }, subjectid ).uri task.progress(100) if task - if ENV['RACK_ENV'] =~ /test|debug/ + if !stratified and ENV['RACK_ENV'] =~ /test|debug/ raise OpenTox::NotFoundError.new "Training dataset not found: '"+result[:training_dataset_uri].to_s+"'" unless Lib::DatasetCache.find(result[:training_dataset_uri],subjectid) test_data = Lib::DatasetCache.find result[:test_dataset_uri],subjectid |