diff options
author | Martin Gütlein <martin.guetlein@gmail.com> | 2010-03-23 14:07:14 +0100 |
---|---|---|
committer | Martin Gütlein <martin.guetlein@gmail.com> | 2010-03-23 14:07:14 +0100 |
commit | 14d2a68564061d63166cd409bf4fd30dc841d2b8 (patch) | |
tree | e07cbc10883ee8c116caa443048df0471efe02b2 /report/report_factory.rb | |
parent | 6a5ebb67493ab2c30121ae26fb75d6a24c36eafc (diff) |
added predictedValues feature, some report changes (true feature prediction class), some hacks to validate IDEA/AMBIT
Diffstat (limited to 'report/report_factory.rb')
-rw-r--r-- | report/report_factory.rb | 57 |
1 file changed, 34 insertions, 23 deletions
diff --git a/report/report_factory.rb b/report/report_factory.rb index fbcd2eb..52c0642 100644 --- a/report/report_factory.rb +++ b/report/report_factory.rb @@ -4,8 +4,8 @@ VAL_ATTR_TRAIN_TEST = [ :model_uri, :training_dataset_uri, :test_dataset_uri, :p # selected attributes of interest when generating the crossvalidation report VAL_ATTR_CV = [ :algorithm_uri, :dataset_uri, :num_folds, :crossvalidation_fold ] # selected attributes of interest when performing classification -VAL_ATTR_CLASS = [ :area_under_roc, :percent_correct, :true_positive_rate, :true_negative_rate ] -VAL_ATTR_BAR_PLOT_CLASS = [ :area_under_roc, :accuracy, :true_positive_rate, :true_negative_rate ] +VAL_ATTR_CLASS = [ :percent_correct, :weighted_area_under_roc, :area_under_roc, :f_measure, :true_positive_rate, :true_negative_rate ] +VAL_ATTR_BAR_PLOT_CLASS = [ :accuracy, :weighted_area_under_roc, :area_under_roc, :f_measure, :true_positive_rate, :true_negative_rate ] VAL_ATTR_REGR = [ :root_mean_squared_error, :mean_absolute_error, :r_square ] @@ -56,7 +56,7 @@ module Reports::ReportFactory #val.get_prediction_feature_values.each do |class_value| #report.add_section_roc_plot(validation_set, class_value, nil, "roc-plot-"+class_value+".svg") #end - report.add_section_confusion_matrix(validation_set.first) + report.add_section_confusion_matrix(val) else #regression report.add_section_result(validation_set, VAL_ATTR_TRAIN_TEST + VAL_ATTR_REGR, "Results", "Results") end @@ -69,28 +69,29 @@ module Reports::ReportFactory def self.create_report_crossvalidation(validation_set) raise Reports::BadRequest.new("num validations is not >1") unless validation_set.size>1 - raise Reports::BadRequest.new("crossvalidation-id not set in all validations") if validation_set.has_nil_values?(:crossvalidation_id) - raise Reports::BadRequest.new("num different cross-validation-id's must be equal to 1") unless validation_set.num_different_values(:crossvalidation_id)==1 + raise Reports::BadRequest.new("crossvalidation-id not 
unique and != nil") if validation_set.unique_value(:crossvalidation_id)==nil validation_set.load_cv_attributes - raise Reports::BadRequest.new("num validations ("+validation_set.size.to_s+") is not equal to num folds ("+validation_set.first.num_folds.to_s+")") unless validation_set.first.num_folds==validation_set.size + raise Reports::BadRequest.new("num validations ("+validation_set.size.to_s+") is not equal to num folds ("+ + validation_set.unique_value(:num_folds).to_s+")") unless validation_set.unique_value(:num_folds)==validation_set.size raise Reports::BadRequest.new("num different folds is not equal to num validations") unless validation_set.num_different_values(:crossvalidation_fold)==validation_set.size raise Reports::BadRequest.new("validations must be either all regression, "+ +"or all classification validations") unless validation_set.all_classification? or validation_set.all_regression? merged = validation_set.merge([:crossvalidation_id]) + raise unless merged.size==1 #puts merged.get_values(:percent_correct_variance, false).inspect report = Reports::ReportContent.new("Crossvalidation report") - if (validation_set.first.classification?) + if (validation_set.all_classification?) 
report.add_section_result(merged, VAL_ATTR_CV+VAL_ATTR_CLASS-[:crossvalidation_fold],"Mean Results","Mean Results") report.add_section_roc_plot(validation_set, nil, nil, "roc-plot.svg", nil, nil, "Roc plot") report.add_section_roc_plot(validation_set, nil, :crossvalidation_fold, "roc-plot-folds.svg", nil, nil, "Roc plots for folds") - #validation_set.validations[0].get_prediction_feature_values.each do |class_value| + #validation_set.first.get_prediction_feature_values.each do |class_value| #report.add_section_roc_plot(validation_set, class_value, nil, "roc-plot-"+class_value+".svg") #end - report.add_section_confusion_matrix(merged.first) + report.add_section_confusion_matrix(merged.validations[0]) report.add_section_result(validation_set, VAL_ATTR_CV+VAL_ATTR_CLASS-[:num_folds], "Results","Results") else #regression report.add_section_result(merged, VAL_ATTR_CV+VAL_ATTR_REGR-[:crossvalidation_fold],"Mean Results","Mean Results") @@ -121,10 +122,10 @@ module Reports::ReportFactory #merged = validation_set.merge([:algorithm_uri, :dataset_uri]) report = Reports::ReportContent.new("Algorithm comparison report - Many datasets") - if (validation_set.first.classification?) + if (validation_set.all_classification?) report.add_section_result(validation_set,[:algorithm_uri, :test_dataset_uri]+VAL_ATTR_CLASS,"Mean Results","Mean Results") report.add_section_ranking_plots(validation_set, :algorithm_uri, :test_dataset_uri, - [:accuracy, :true_positive_rate, :true_negative_rate], "true") + [:percent_correct, :true_positive_rate, :true_negative_rate], "true") else # regression raise Reports::BadRequest.new("not implemented yet for regression") end @@ -137,10 +138,10 @@ module Reports::ReportFactory report = Reports::ReportContent.new("Algorithm comparison report") - if (validation_set.first.classification?) + if (validation_set.all_classification?) 
report.add_section_bar_plot(validation_set,nil,:algorithm_uri,VAL_ATTR_BAR_PLOT_CLASS, "bar-plot.svg") report.add_section_roc_plot(validation_set,nil, :algorithm_uri, "roc-plot.svg") - #validation_set.validations[0].get_prediction_feature_values.each do |class_value| + #validation_set.first.get_prediction_feature_values.each do |class_value| #report.add_section_bar_plot(validation_set,class_value,:algorithm_uri,VAL_ATTR_CLASS, "bar-plot-"+class_value+".svg") #report.add_section_roc_plot(validation_set, class_value, :algorithm_uri, "roc-plot-"+class_value+".svg") #end @@ -169,7 +170,7 @@ module Reports::ReportFactory merged = validation_set.merge([:algorithm_uri, :dataset_uri]) report = Reports::ReportContent.new("Algorithm comparison report - Many datasets") - if (validation_set.first.classification?) + if (validation_set.all_classification?) report.add_section_result(merged,VAL_ATTR_CV+VAL_ATTR_CLASS-[:crossvalidation_fold],"Mean Results","Mean Results") report.add_section_ranking_plots(merged, :algorithm_uri, :dataset_uri, [:acc, :auc, :sens, :spec], "true") else # regression @@ -186,12 +187,21 @@ module Reports::ReportFactory report = Reports::ReportContent.new("Algorithm comparison report") - if (validation_set.first.classification?) - validation_set.validations[0].get_prediction_feature_values.each do |class_value| - report.add_section_bar_plot(merged,class_value,:algorithm_uri,VAL_ATTR_CLASS, "bar-plot-"+class_value+".svg") - report.add_section_roc_plot(validation_set, class_value, :algorithm_uri, "roc-plot-"+class_value+".svg") - end + if (validation_set.all_classification?) 
+ report.add_section_result(merged,VAL_ATTR_CV+VAL_ATTR_CLASS-[:crossvalidation_fold],"Mean Results","Mean Results") + + true_class = validation_set.get_true_prediction_feature_value + if true_class!=nil + report.add_section_bar_plot(merged,true_class,:algorithm_uri,VAL_ATTR_BAR_PLOT_CLASS, "bar-plot.svg") + report.add_section_roc_plot(validation_set, nil, :algorithm_uri, "roc-plot.svg") + else + validation_set.get_prediction_feature_values.each do |class_value| + report.add_section_bar_plot(merged,class_value,:algorithm_uri,VAL_ATTR_BAR_PLOT_CLASS, "bar-plot-"+class_value+".svg") + report.add_section_roc_plot(validation_set, class_value, :algorithm_uri, "roc-plot-"+class_value+".svg") + end + end + report.add_section_result(validation_set,VAL_ATTR_CV+VAL_ATTR_CLASS-[:num_folds],"Results","Results") else #regression report.add_section_result(merged, VAL_ATTR_CV+VAL_ATTR_REGR-[:crossvalidation_fold],"Mean Results","Mean Results") @@ -224,7 +234,7 @@ class Reports::ReportContent table_title="Predictions") section_table = @xml_report.add_section(@xml_report.get_root_element, section_title) - if validation_set.first.get_predictions + if validation_set.validations[0].get_predictions @xml_report.add_paragraph(section_table, section_text) if section_text @xml_report.add_table(section_table, table_title, Reports::PredictionUtil.predictions_to_array(validation_set, validation_attributes)) else @@ -240,7 +250,7 @@ class Reports::ReportContent section_table = @xml_report.add_section(xml_report.get_root_element, section_title) @xml_report.add_paragraph(section_table, section_text) if section_text - vals = validation_set.to_array(validation_attributes) + vals = validation_set.to_array(validation_attributes,false,validation_set.get_true_prediction_feature_value) #PENDING rexml strings in tables not working when >66 vals = vals.collect{|a| a.collect{|v| v.to_s[0,66] }} #PENDING transpose values if there more than 4 columns, and there are more than columns than rows @@ -318,9 
+328,10 @@ class Reports::ReportContent rank_attribute, class_value=nil, plot_file_name="ranking.svg", - image_title="Ranking Plot", + image_title=nil, image_caption=nil) - + + image_title = "Ranking Plot for class value: '"+class_value.to_s+"'" if image_title==nil plot_file_path = add_tmp_file(plot_file_name) Reports::PlotFactory::create_ranking_plot(plot_file_path, validation_set, compare_attribute, equal_attribute, rank_attribute, class_value) @xml_report.add_imagefigure(report_section, image_title, plot_file_name, "SVG", image_caption) |