Class: Rumale::NaiveBayes::BaseNaiveBayes

Inherits:
Object
Includes:
Base::BaseEstimator, Base::Classifier
Defined in:
lib/rumale/naive_bayes/naive_bayes.rb

Overview

BaseNaiveBayes is a base class that provides the methods shared by the naive Bayes classifiers.

Direct Known Subclasses

BernoulliNB, GaussianNB, MultinomialNB

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Methods included from Base::Classifier

#fit, #score

Instance Method Details

#predict(x) ⇒ Numo::Int32

Predict class labels for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples for which to predict labels.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted class label per sample.

# File 'lib/rumale/naive_bayes/naive_bayes.rb', line 18

def predict(x)
  check_sample_array(x)
  n_samples = x.shape.first
  decision_values = decision_function(x)
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
end
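`predict` takes the column index of the largest decision value in each row and looks it up in `@classes`, so the returned labels are the original training labels rather than column positions. A minimal pure-Ruby sketch of that mapping, using nested arrays in place of `Numo::DFloat`; `predict_labels` is a hypothetical helper, not part of Rumale:

```ruby
# Hypothetical helper mirroring the argmax-then-lookup step in #predict.
# decision_values: one array of per-class scores per sample.
# classes: class labels in the same column order as the scores.
def predict_labels(decision_values, classes)
  decision_values.map do |row|
    # Column index of the highest score in this row.
    best_index = row.each_index.max_by { |j| row[j] }
    classes[best_index]
  end
end

# Labels need not be 0..n_classes-1; here they are -1 and 1.
scores = [[-2.3, -0.1], [-0.7, -1.9], [-1.0, -1.2]]
p predict_labels(scores, [-1, 1]) # prints [1, -1, -1]
```

Looking labels up through the class list is what lets the same base method serve datasets whose labels are, for example, -1 and 1.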

#predict_log_proba(x) ⇒ Numo::DFloat

Predict log-probability for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples for which to predict log-probabilities.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.

# File 'lib/rumale/naive_bayes/naive_bayes.rb', line 29

def predict_log_proba(x)
  check_sample_array(x)
  n_samples, = x.shape
  log_likelihoods = decision_function(x)
  log_likelihoods - Numo::NMath.log(Numo::NMath.exp(log_likelihoods).sum(1)).reshape(n_samples, 1)
end
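The method normalizes each row of joint log-likelihoods by subtracting the log of the row's summed exponentials, so that exponentiating a row gives probabilities that sum to one. Because the listed code exponentiates the raw values directly, rows whose log-likelihoods are all very negative can underflow to zero, making the normalizer `-Infinity`. A common alternative, shown here as a pure-Ruby sketch rather than as Rumale's code, is the log-sum-exp trick, which subtracts the row maximum before exponentiating; `log_normalize` is a hypothetical name:

```ruby
# Hypothetical sketch: normalize per-sample joint log-likelihoods into
# log-probabilities with the log-sum-exp trick.
def log_normalize(log_likelihoods)
  log_likelihoods.map do |row|
    max = row.max
    # Shifting by the row maximum keeps every exponent <= 0, and the
    # maximum term contributes exp(0) = 1, so the sum never underflows to 0.
    log_sum = max + Math.log(row.sum { |v| Math.exp(v - max) })
    row.map { |v| v - log_sum }
  end
end

rows = [[-1000.0, -1001.0], [0.0, 0.0]]
log_probs = log_normalize(rows)
p log_probs.map { |r| r.map { |v| Math.exp(v).round(4) } }
# prints [[0.7311, 0.2689], [0.5, 0.5]]
```

Without the shift, `Math.exp(-1000.0)` evaluates to `0.0`, so the first row could not be normalized.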

#predict_proba(x) ⇒ Numo::DFloat

Predict probability for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples for which to predict probabilities.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted probability of each class per sample.

# File 'lib/rumale/naive_bayes/naive_bayes.rb', line 40

def predict_proba(x)
  check_sample_array(x)
  Numo::NMath.exp(predict_log_proba(x)).abs
end
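`predict_proba` exponentiates the log-probabilities returned by `predict_log_proba`. Because exponentiation is monotonic and each row was shifted by a single constant, each row of the result sums to one (up to rounding) and peaks in the same column that `predict` selects. A minimal pure-Ruby sketch of that step, with nested arrays standing in for `Numo::DFloat`; `probabilities_from_log` is a hypothetical helper:

```ruby
# Hypothetical sketch of the exponentiation step in #predict_proba,
# applied to rows of per-class log-probabilities.
def probabilities_from_log(log_probs)
  log_probs.map { |row| row.map { |v| Math.exp(v) } }
end

log_probs = [[Math.log(0.2), Math.log(0.8)], [Math.log(0.6), Math.log(0.4)]]
probs = probabilities_from_log(log_probs)

probs.each do |row|
  # Each row sums to one up to floating-point rounding.
  puts row.sum.round(10)
  # The argmax column matches the argmax of the log-probabilities.
  puts row.each_index.max_by { |j| row[j] }
end
```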