diff options
author | Martin Gütlein <martin.guetlein@gmail.com> | 2009-11-26 18:45:06 +0100 |
---|---|---|
committer | Martin Gütlein <martin.guetlein@gmail.com> | 2009-11-26 18:45:06 +0100 |
commit | 3b57006e49a3f77025c80bd73072f1f4f3d9e261 (patch) | |
tree | a5ca6e1598783a4a79a4c722699dbaf9fe806fa2 /report | |
parent | 4e76efe724aa723e1f79fb09d4e908323720f1de (diff) |
fixed variance calculation, added ranking plots for compare algs on multiple datasets, ..
Diffstat (limited to 'report')
-rw-r--r-- | report/environment.rb | 1 | ||||
-rw-r--r-- | report/plot_factory.rb | 64 | ||||
-rw-r--r-- | report/r_plot_factory.rb | 2 | ||||
-rw-r--r-- | report/report_factory.rb | 68 | ||||
-rw-r--r-- | report/util.rb | 9 | ||||
-rw-r--r-- | report/validation_data.rb | 97 |
6 files changed, 216 insertions, 25 deletions
diff --git a/report/environment.rb b/report/environment.rb index 71a32fe..47d70b2 100644 --- a/report/environment.rb +++ b/report/environment.rb @@ -14,6 +14,7 @@ module Reports end load "report/r_plot_factory.rb" +load "report/plot_factory.rb" load "report/xml_report.rb" load "report/xml_report_util.rb" load "report/report_persistance.rb" diff --git a/report/plot_factory.rb b/report/plot_factory.rb new file mode 100644 index 0000000..1c229a9 --- /dev/null +++ b/report/plot_factory.rb @@ -0,0 +1,64 @@ +ENV['JAVA_HOME'] = "/usr/bin" unless ENV['JAVA_HOME'] +ENV['PATH'] = ENV['JAVA_HOME']+":"+ENV['PATH'] unless ENV['PATH'].split(":").index(ENV['JAVA_HOME']) +ENV['RANK_PLOTTER_JAR'] = "RankPlotter/RankPlotter.jar" unless ENV['RANK_PLOTTER_JAR'] + +module Reports + + module PlotFactory + + def self.create_ranking_plot( svg_out_file, validation_set, compare_attribute, equal_attribute, rank_attribute ) + + #compute ranks + rank_set = validation_set.compute_ranking([equal_attribute],rank_attribute) + #puts rank_set.to_array([:algorithm_uri, :dataset_uri, :acc, :acc_ranking]).collect{|a| a.inspect}.join("\n") + + #compute avg ranks + merge_set = rank_set.merge([compare_attribute]) + #puts merge_set.to_array([:algorithm_uri, :dataset_uri, :acc, :acc_ranking]).collect{|a| a.inspect}.join("\n") + + comparables = merge_set.get_values(compare_attribute) + ranks = merge_set.get_values((rank_attribute.to_s+"_ranking").to_sym) + + plot_ranking( rank_attribute.to_s+" ranking", + comparables, + ranks, + 0.1, + validation_set.num_different_values(equal_attribute), + svg_out_file) + end + + protected + def self.plot_ranking( title, comparables_array, ranks_array, confidence = nil, numdatasets = nil, svg_out_file = nil ) + + (confidence and numdatasets) ? conf = "-q "+confidence.to_s+" -k "+numdatasets.to_s : conf = "" + svg_out_file ? show = "-o" : show = "" + (title and title.length > 0) ? tit = '-t "'+title+'"' : tit = "" + #title = "-t \""+ranking_value_prop+"-Ranking ("+comparables.size.to_s+" "+comparable_prop+"s, "+num_groups.to_s+" "+ranking_group_prop+"s, p < "+p.to_s+")\" " + + cmd = "java -jar "+ENV['RANK_PLOTTER_JAR']+" "+tit+" -c '"+ + comparables_array.join(",")+"' -r '"+ranks_array.join(",")+"' "+conf+" "+show #+" > /home/martin/tmp/test.svg" + #puts "\nplotting: "+cmd + + res = "" + IO.popen(cmd) do |f| + while line = f.gets do + res += line + end + end + if svg_out_file + f = File.new(svg_out_file, "w") + f.puts res + end + + svg_out_file ? svg_out_file : res + end + + def self.demo_ranking_plot + puts plot_ranking( nil, ["naive bayes", "svm", "decision tree"], [1.9, 3, 1.5], 0.1, 50) #, "/home/martin/tmp/test.svg") + end + + + end +end + +#Reports::PlotFactory::demo_ranking_plot diff --git a/report/r_plot_factory.rb b/report/r_plot_factory.rb index 3fe7bcf..134385e 100644 --- a/report/r_plot_factory.rb +++ b/report/r_plot_factory.rb @@ -24,7 +24,7 @@ module Reports::RPlotFactory validation_set.validations.each do |v| b.add_data(v.send(title_attribute), value_attributes.collect{|a| v.send(a)}) end - b.build_plot(value_attributes) + b.build_plot(value_attributes.collect{|a| a.to_s}) end # creates a roc plot (result is plotted into out_file) diff --git a/report/report_factory.rb b/report/report_factory.rb index 837df1d..16a4e3d 100644 --- a/report/report_factory.rb +++ b/report/report_factory.rb @@ -1,10 +1,10 @@ # selected attributes of interest when generating the report for a train-/test-evaluation -VAL_ATTR_TRAIN_TEST = [ "model_uri", "training_dataset_uri", "test_dataset_uri" ] +VAL_ATTR_TRAIN_TEST = [ :model_uri, :training_dataset_uri, :test_dataset_uri ] # selected attributes of interest when generating the crossvalidation report -VAL_ATTR_CV = [ "algorithm_uri", "dataset_uri", "num_folds", "crossvalidation_fold" ] +VAL_ATTR_CV = [ :algorithm_uri, :dataset_uri, :num_folds, :crossvalidation_fold ] # selected attributes of interest when performing classification -VAL_ATTR_CLASS = [ "auc", "acc", "sens", "spec" ] +VAL_ATTR_CLASS = [ :auc, :acc, :sens, :spec ] # = Reports::ReportFactory @@ -63,14 +63,16 @@ module Reports::ReportFactory raise Reports::BadRequest.new("num validations ("+validation_set.size.to_s+") is not equal to num folds ("+validation_set.first.num_folds.to_s+")") unless validation_set.first.num_folds==validation_set.size raise Reports::BadRequest.new("num different folds is not equal to num validations") unless validation_set.num_different_values(:crossvalidation_fold)==validation_set.size merged = validation_set.merge([:crossvalidation_id]) + + puts merged.get_values(:acc_variance, false).inspect report = Reports::ReportContent.new("Crossvalidation report") report.add_section_result(merged, VAL_ATTR_CV+VAL_ATTR_CLASS-["crossvalidation_fold"],"Mean Results","Mean Results") report.add_section_roc_plot(validation_set) report.add_section_confusion_matrix(merged.first) - report.add_section_result(validation_set, VAL_ATTR_CV+VAL_ATTR_CLASS-["CV_num_folds"], "Results","Results") + report.add_section_result(validation_set, VAL_ATTR_CV+VAL_ATTR_CLASS-[:num_folds], "Results","Results") report.add_section_result(validation_set, OpenTox::Validation::ALL_PROPS, "All Results", "All Results") - report.add_section_predictions( validation_set, ["crossvalidation_fold"] ) + report.add_section_predictions( validation_set, [:crossvalidation_fold] ) return report end @@ -87,19 +89,33 @@ module Reports::ReportFactory raise Reports::BadRequest.new("number of different algorithms <2") if validation_set.num_different_values(:algorithm_uri)<2 if validation_set.num_different_values(:dataset_uri)>1 - raise Reports::BadRequest.new("so far, algorithm comparison is only supported for 1 dataset") + # groups results into sets with equal dataset + dataset_grouping = Reports::Util.group(validation_set.validations, [:dataset_uri]) + # check if equal values in each group exist + Reports::Util.check_group_matching(dataset_grouping, [:algorithm_uri, :crossvalidation_fold, :num_folds, :stratified, :random_seed]) + # we only checked that equal validations exist in each dataset group, now check for each algorithm + dataset_grouping.each do |validations| + algorithm_grouping = Reports::Util.group(validations, [:algorithm_uri]) + Reports::Util.check_group_matching(algorithm_grouping, [:crossvalidation_fold, :num_folds, :stratified, :random_seed]) + end + + merged = validation_set.merge([:algorithm_uri, :dataset_uri]) + report = Reports::ReportContent.new("Algorithm comparison report - Many datasets") + report.add_section_result(merged,VAL_ATTR_CV+VAL_ATTR_CLASS-[:crossvalidation_fold],"Mean Results","Mean Results") + report.add_section_ranking_plots(merged, :algorithm_uri, :dataset_uri, [:acc, :auc, :sens, :spec]) + return report else # this groups all validations in x different groups (arrays) according to there algorithm-uri - algorithm_grouping = Reports::Util.group(validation_set.validations, ["algorithm_uri"]) + algorithm_grouping = Reports::Util.group(validation_set.validations, [:algorithm_uri]) # we check if there are corresponding validations in each group that have equal attributes (folds, num-folds,..) Reports::Util.check_group_matching(algorithm_grouping, [:crossvalidation_fold, :num_folds, :dataset_uri, :stratified, :random_seed]) merged = validation_set.merge([:algorithm_uri]) report = Reports::ReportContent.new("Algorithm comparison report") - report.add_section_bar_plot(merged,"algorithm_uri",VAL_ATTR_CLASS) - report.add_section_roc_plot(validation_set, "algorithm_uri") - report.add_section_result(merged,VAL_ATTR_CV+VAL_ATTR_CLASS-["crossvalidation_fold"],"Mean Results","Mean Results") - report.add_section_result(validation_set,VAL_ATTR_CV+VAL_ATTR_CLASS-["num_folds"],"Results","Results") + report.add_section_bar_plot(merged,:algorithm_uri,VAL_ATTR_CLASS) + report.add_section_roc_plot(validation_set, :algorithm_uri) + report.add_section_result(merged,VAL_ATTR_CV+VAL_ATTR_CLASS-[:crossvalidation_fold],"Mean Results","Mean Results") + report.add_section_result(validation_set,VAL_ATTR_CV+VAL_ATTR_CLASS-[:num_folds],"Results","Results") return report end @@ -147,7 +163,7 @@ class Reports::ReportContent #PENDING rexml strings in tables not working when >66 vals = vals.collect{|a| a.collect{|v| v.to_s[0,66] }} #transpose values if there more than 8 columns - transpose = vals[0].size>8 + transpose = vals[0].size>8 && vals[0].size>vals.size @xml_report.add_table(section_table, table_title, vals, !transpose, transpose) end @@ -184,6 +200,34 @@ class Reports::ReportContent end + def add_section_ranking_plots( validation_set, + compare_attribute, + equal_attribute, + rank_attributes, + section_title="Ranking Plots", + section_text="This section contains the ranking plots.") + + section_rank = @xml_report.add_section(@xml_report.get_root_element, section_title) + @xml_report.add_paragraph(section_rank, section_text) if section_text + + rank_attributes.each{|a| add_ranking_plot(section_rank, validation_set, compare_attribute, equal_attribute, a, a.to_s+"-ranking.svg")} + end + + def add_ranking_plot( report_section, + validation_set, + compare_attribute, + equal_attribute, + rank_attribute, + plot_file_name="ranking.svg", + image_title="Ranking Plot", + image_caption=nil) + + plot_file_path = add_tmp_file(plot_file_name) + Reports::PlotFactory::create_ranking_plot(plot_file_path, validation_set, compare_attribute, equal_attribute, rank_attribute) + @xml_report.add_imagefigure(report_section, image_title, plot_file_name, "SVG", image_caption) + + end + def add_section_bar_plot(validation_set, title_attribute, value_attributes, diff --git a/report/util.rb b/report/util.rb index b460d48..5934064 100644 --- a/report/util.rb +++ b/report/util.rb @@ -78,7 +78,7 @@ module Reports::Util break end end - raise Reports::BadRequest.new("no match found for "+o.to_s) unless match + raise Reports::BadRequest.new("no match found for "+inspect_attributes(o, match_attributes)) unless match end end end @@ -98,5 +98,12 @@ module Reports::Util end return tmp_file_path end + + protected + def self.inspect_attributes(object, attributes) + res = object.class.to_s+" (" + res += attributes.collect{ |a| a.to_s+"->"+object.send(a).inspect }.join(", ") + res += ")" + end end
\ No newline at end of file diff --git a/report/validation_data.rb b/report/validation_data.rb index 9cb7cd1..eb092a1 100644 --- a/report/validation_data.rb +++ b/report/validation_data.rb @@ -1,6 +1,7 @@ # the variance is computed when merging results for these attributes VAL_ATTR_VARIANCE = [ :auc, :acc ] +VAL_ATTR_RANKING = [ :auc, :acc, :spec, :sens ] class Object @@ -54,9 +55,11 @@ module Reports # class Validation - #VAL_ATTR.each{ |a| attr_accessor a } - OpenTox::Validation::ALL_PROPS.each{ |a| attr_accessor a } - VAL_ATTR_VARIANCE.each{ |a| attr_accessor (a.to_s+"_variance").to_sym } + @@validation_attributes = OpenTox::Validation::ALL_PROPS + + VAL_ATTR_VARIANCE.collect{ |a| (a.to_s+"_variance").to_sym } + + VAL_ATTR_RANKING.collect{ |a| (a.to_s+"_ranking").to_sym } + + @@validation_attributes.each{ |a| attr_accessor a } attr_reader :predictions, :merge_count @@ -85,6 +88,13 @@ module Reports Reports.validation_access.init_cv(self) end + def clone_validation + new_val = clone + VAL_ATTR_VARIANCE.each { |a| new_val.send((a.to_s+"_variance=").to_sym,nil) } + new_val.set_merge_count(1) + return new_val + end + # merges this validation and another validation object to a new validation object # * v1.att = "a", v2.att = "a" => r.att = "a" # * v1.att = "a", v2.att = "b" => r.att = "a / b" @@ -98,10 +108,10 @@ module Reports new_validation = Reports::Validation.new raise "not working" if validation.merge_count > 1 + + @@validation_attributes.each do |a| + next if a.to_s =~ /_variance$/ - OpenTox::Validation::ALL_PROPS.each do |a| - next if a =~ /_variance$/ - if (equal_attributes.index(a) != nil) new_validation.send("#{a.to_s}=".to_sym, send(a)) else @@ -109,12 +119,16 @@ module Reports variance = nil if (send(a).is_a?(Float) || send(a).is_a?(Integer)) + old_value = send(a) value = (send(a) * @merge_count + validation.send(a)) / (@merge_count + 1).to_f; if (VAL_ATTR_VARIANCE.index(a) != nil) - old_std_dev = 0; - old_std_dev = send((a.to_s+"_variance").to_sym) ** 2 if send((a.to_s+"_variance").to_sym) - std_dev = (old_std_dev * (@merge_count / (@merge_count + 1.0))) + (((validation.send(a) - value) ** 2) * (1 / @merge_count)) - variance = Math.sqrt(std_dev); + # use revursiv formular for computing the variance + # ( see Tysiak, Folgen: explizit und rekursiv, ISSN: 0025-5866 + # http://www.frl.de/tysiakpapers/07_TY_Papers.pdf ) + old_variance = 0 unless (old_variance = send((a.to_s+"_variance").to_sym)) + variance = old_variance*(@merge_count-1)/@merge_count + + (value-old_value)**2 + + (validation.send(a)-value)**2/@merge_count end else if send(a).to_s != validation.send(a).to_s @@ -245,7 +259,7 @@ module Reports #merge grouping.each do |g| - new_set.validations.push(g[0].clone) + new_set.validations.push(g[0].clone_validation) g[1..-1].each do |v| new_set.validations[-1] = new_set.validations[-1].merge(v, equal_attributes) end @@ -253,7 +267,68 @@ module Reports return new_set end + + # creates a new validaiton set, that contains a ranking for __ranking_attribute__ + # (i.e. for ranking attribute :acc, :acc_ranking is calculated) + # all validation with equal values for __equal_attributes__ are compared + # (the one with highest value of __ranking_attribute__ has rank 1, and so on) + # + # call-seq: + # compute_ranking(equal_attributes, ranking_attribute) => array + # + def compute_ranking(equal_attributes, ranking_attribute) + + new_set = Reports::ValidationSet.new + (0..@validations.size-1).each do |i| + new_set.validations.push(@validations[i].clone_validation) + end + + grouping = Reports::Util.group(new_set.validations, equal_attributes) + grouping.each do |group| + # put indices and ranking values for current group into hash + rank_hash = {} + (0..group.size-1).each do |i| + rank_hash[i] = group[i].send(ranking_attribute) + end + + # sort group accrording to second value (= ranking value) + rank_array = rank_hash.sort { |a, b| b[1] <=> a[1] } + + # create ranks array + ranks = Array.new + (0..rank_array.size-1).each do |j| + + val = rank_array.at(j)[1] + rank = j+1 + ranks.push(rank.to_f) + + # check if previous ranks have equal value + equal_count = 1; + equal_rank_sum = rank; + + while ( j - equal_count >= 0 && (val - rank_array.at(j - equal_count)[1]).abs < 0.0001 ) + equal_rank_sum += ranks.at(j - equal_count); + equal_count += 1; + end + + # if previous ranks have equal values -> replace with avg rank + if (equal_count > 1) + (0..equal_count-1).each do |k| + ranks[j-k] = equal_rank_sum / equal_count.to_f; + end + end + end + + # set rank as validation value + (0..rank_array.size-1).each do |j| + index = rank_array.at(j)[0] + group[index].send( (ranking_attribute.to_s+"_ranking=").to_sym, ranks[j]) + end + end + + return new_set + end def size return @validations.size |