From 4348eec89033e6677c9f628646fc67bd03c73fe6 Mon Sep 17 00:00:00 2001 From: Christoph Helma Date: Thu, 6 Oct 2016 19:14:10 +0200 Subject: nano caret regression fixed --- lib/model.rb | 64 +++++++++++++++++++++++++++--------------------------------- 1 file changed, 29 insertions(+), 35 deletions(-) (limited to 'lib/model.rb') diff --git a/lib/model.rb b/lib/model.rb index a272580..290309a 100644 --- a/lib/model.rb +++ b/lib/model.rb @@ -23,10 +23,12 @@ module OpenTox # explicit prediction algorithm if algorithms[:prediction] and algorithms[:prediction][:method] case algorithms[:prediction][:method] - when /Classifiction/ + when /Classification/i model = LazarClassification.new - when /Regression/ + when /Regression/i model = LazarRegression.new + else + bad_request_error "Prediction method '#{algorithms[:prediction][:method]}' not implemented." end # guess model type @@ -36,6 +38,10 @@ module OpenTox model = LazarClassification.new end + model.prediction_feature_id = prediction_feature.id + model.training_dataset_id = training_dataset.id + model.name = "#{training_dataset.name} #{prediction_feature.name}" + # set defaults substance_classes = training_dataset.substances.collect{|s| s.class.to_s}.uniq bad_request_error "Cannot create models for mixed substance classes '#{substance_classes.join ', '}'." unless substance_classes.size == 1 @@ -60,7 +66,7 @@ module OpenTox } elsif model.class == LazarRegression model.algorithms[:prediction] = { - :method => "Algorithm::Regression.caret", + :method => "Algorithm::Caret.regression", :parameters => "pls", } end @@ -77,7 +83,7 @@ module OpenTox :min => 0.5 }, :prediction => { - :method => "Algorithm::Regression.caret", + :method => "Algorithm::Caret.regression", :parameters => "rf", }, :feature_selection => { @@ -100,10 +106,6 @@ module OpenTox end end - model.prediction_feature_id = prediction_feature.id - model.training_dataset_id = training_dataset.id - model.name = "#{training_dataset.name} #{prediction_feature.name}" - if model.algorithms[:feature_selection] and model.algorithms[:feature_selection][:method] model.relevant_features = Algorithm.run model.algorithms[:feature_selection][:method], dataset: training_dataset, prediction_feature: prediction_feature, types: model.algorithms[:descriptors][:types] end @@ -151,8 +153,12 @@ module OpenTox else bad_request_error "Descriptor method '#{algorithms[:descriptors][:method]}' not available." end - params = algorithms[:prediction].merge({:descriptors => descriptors, :neighbors => neighbors}) - params.delete :method + params = { + :method => algorithms[:prediction][:parameters], + :descriptors => descriptors, + :neighbors => neighbors, + :relevant_features => relevant_features + } result = Algorithm.run algorithms[:prediction][:method], params prediction.merge! result prediction[:neighbors] = neighbors @@ -218,11 +224,9 @@ module OpenTox end class LazarClassification < Lazar - end class LazarRegression < Lazar - end class Prediction @@ -240,7 +244,7 @@ module OpenTox field :leave_one_out_validation_id, type: BSON::ObjectId def predict object - Lazar.find(model_id).predict object + model.predict object end def training_dataset @@ -251,6 +255,10 @@ module OpenTox Lazar.find model_id end + def prediction_feature + model.prediction_feature + end + def repeated_crossvalidation Validation::RepeatedCrossValidation.find repeated_crossvalidation_id end @@ -276,15 +284,8 @@ module OpenTox bad_request_error "No metadata file #{metadata_file}" unless File.exist? metadata_file prediction_model = self.new JSON.parse(File.read(metadata_file)) training_dataset = Dataset.from_csv_file file - prediction_feature = training_dataset.features.first - model = nil - if prediction_feature.nominal? - model = LazarClassification.create prediction_feature, training_dataset - elsif prediction_feature.numeric? - model = LazarRegression.create prediction_feature, training_dataset - end + model = Lazar.create training_dataset: training_dataset prediction_model[:model_id] = model.id - prediction_model[:prediction_feature_id] = prediction_feature.id prediction_model[:repeated_crossvalidation_id] = Validation::RepeatedCrossValidation.create(model).id #prediction_model[:leave_one_out_validation_id] = Validation::LeaveOneOut.create(model).id prediction_model.save @@ -297,26 +298,19 @@ module OpenTox def self.from_json_dump dir, category Import::Enanomapper.import dir - + training_dataset = Dataset.where(:name => "Protein Corona Fingerprinting Predicts the Cellular Interaction of Gold and Silver Nanoparticles").first + unless training_dataset + Import::Enanomapper.import File.join(File.dirname(__FILE__),"data","enm") + training_dataset = Dataset.where(name: "Protein Corona Fingerprinting Predicts the Cellular Interaction of Gold and Silver Nanoparticles").first + end prediction_model = self.new( :endpoint => "log2(Net cell association)", :source => "https://data.enanomapper.net/", :species => "A549 human lung epithelial carcinoma cells", :unit => "log2(ug/Mg)" ) - params = { - :feature_selection_algorithm => :correlation_filter, - :feature_selection_algorithm_parameters => {:category => category}, - :neighbor_algorithm => "physchem_neighbors", - :neighbor_algorithm_parameters => {:min_sim => 0.5}, - :prediction_algorithm => "OpenTox::Algorithm::Regression.physchem_regression", - :prediction_algorithm_parameters => {:method => 'rf'}, # random forests - } - training_dataset = Dataset.find_or_create_by(:name => "Protein Corona Fingerprinting Predicts the Cellular Interaction of Gold and Silver Nanoparticles") - prediction_feature = Feature.find_or_create_by(name: "log2(Net cell association)", category: "TOX") - #prediction_feature = Feature.find("579621b84de73e267b414e55") - prediction_model[:prediction_feature_id] = prediction_feature.id - model = Model::LazarRegression.create(prediction_feature, training_dataset, params) + prediction_feature = Feature.where(name: "log2(Net cell association)", category: "TOX").first + model = Model::LazarRegression.create(prediction_feature: prediction_feature, training_dataset: training_dataset) prediction_model[:model_id] = model.id repeated_cv = Validation::RepeatedCrossValidation.create model prediction_model[:repeated_crossvalidation_id] = Validation::RepeatedCrossValidation.create(model).id -- cgit v1.2.3