summaryrefslogtreecommitdiff
path: root/test/prediction_models.rb
blob: b4ad415e71d0bc193af76832339968b7a4201e93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
require_relative "setup.rb"

# End-to-end test for Model::Prediction: builds a classification model and
# its cross-validation from the hamster carcinogenicity data, wraps both in a
# prediction model together with the JSON metadata, and checks the stored
# metadata, the cross-validation accuracy and a single prediction.
class PredictionModelTest < Minitest::Test

  def test_prediction_model
    dataset = Dataset.from_csv_file "#{DATA_DIR}/hamster_carcinogenicity.csv"
    model = Model::LazarFminerClassification.create dataset
    cv = ClassificationCrossValidation.create model

    # JSON.parse yields string keys; the two ids are added under symbol keys,
    # exactly as before, so the resulting hash has mixed key types.
    metadata = JSON.parse(File.read("#{DATA_DIR}/hamster_carcinogenicity.json"))
    metadata[:model_id] = model.id
    metadata[:crossvalidation_id] = cv.id

    pm = Model::Prediction.new(metadata)
    pm.save
    begin
      # These fields are expected to be supplied by the JSON metadata file.
      [:endpoint, :species, :source].each do |field|
        refute_empty pm[field]
      end

      # assert_operator reports the actual accuracy when the check fails.
      assert_operator pm.crossvalidation.accuracy, :>, 0.8

      prediction = pm.predict Compound.from_smiles("CCCC(NN)C")
      assert_equal "true", prediction[:value]
    ensure
      # Remove the saved prediction model even when an assertion above fails,
      # so a failing run does not leave the record behind.
      pm.delete
    end
  end
end