Class: Eps::NaiveBayes

Inherits:
BaseEstimator
Defined in:
lib/eps/naive_bayes.rb

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Methods inherited from BaseEstimator

#evaluate, extract_text_features, #initialize, #predict, #summary, #to_pmml

Constructor Details

This class inherits a constructor from Eps::BaseEstimator

Instance Attribute Details

#probabilities ⇒ Object (readonly)

Returns the value of attribute probabilities.



# File 'lib/eps/naive_bayes.rb', lines 3-5

# Reader for the model's probability tables.
#
# Returns whatever is stored in @probabilities, or nil when that instance
# variable has never been assigned.
def probabilities
  instance_variable_get(:@probabilities)
end

Class Method Details

.load_pmml(data) ⇒ Object

Builds a Naive Bayes model from PMML data.



# File 'lib/eps/naive_bayes.rb', lines 11-91

# Builds a Naive Bayes evaluator from PMML.
#
# `data` is forwarded unchanged to the superclass's load_pmml by the implicit
# `super`. The block receives a document that responds to `css`, presumably
# the parsed PMML XML (TODO confirm in BaseEstimator.load_pmml). The block
# returns an Evaluators::NaiveBayes built from:
#   - prior:       class label => count (Float), from BayesOutput
#   - conditional: field name => per-field table
#       numeric fields:     class label => { mean:, stdev: }
#       categorical fields: field value => { class label => count (BigDecimal) }
#   - features:    field name => "numeric" | "categorical" | "derived"
#   - derived:     source field => { derived field name => value }
#   - legacy:      true if any BayesInput used the pre-0.3 Eps layout
#
# @param data PMML input, passed through to the superclass's load_pmml
# @return the result of the superclass's load_pmml (NOTE(review): presumably
#   the model wrapping the evaluator built here; confirm in BaseEstimator)
def self.load_pmml(data)
  super do |doc|
    # TODO more validation
    node = doc.css("NaiveBayesModel")

    # class label => count
    prior = {}
    node.css("BayesOutput TargetValueCount").each do |n|
      prior[n.attribute("value").value] = n.attribute("count").value.to_f
    end

    # Becomes true once any input is detected in the Eps < 0.3 layout.
    legacy = false

    conditional = {}
    features = {}
    node.css("BayesInput").each do |n|
      prob = {}
      # Each child collection is queried once and reused below.
      target_stats = n.css("TargetValueStat")
      pair_counts = n.css("PairCounts")

      # numeric: class label => Gaussian parameters
      target_stats.each do |n2|
        n3 = n2.css("GaussianDistribution")
        prob[n2.attribute("value").value] = {
          mean: n3.attribute("mean").value.to_f,
          stdev: Math.sqrt(n3.attribute("variance").value.to_f)
        }
      end

      # Detect the bad form written by Eps < 0.3. In that layout the
      # PairCounts `value` attributes are exactly the class labels, in prior
      # order, so the roles of the two `value` attributes are swapped.
      bad_format = pair_counts.map { |n2| n2.attribute("value").value } == prior.keys

      # categorical: field value => { class label => count }
      pair_counts.each do |n2|
        if bad_format
          # Legacy layout: the TargetValueCount value is the field value and
          # the PairCounts value is the class label.
          n2.css("TargetValueCount").each do |n3|
            prob[n3.attribute("value").value] ||= {}
            prob[n3.attribute("value").value][n2.attribute("value").value] = BigDecimal(n3.attribute("count").value)
          end
        else
          counts = {}
          n2.css("TargetValueCount").each do |n3|
            counts[n3.attribute("value").value] = BigDecimal(n3.attribute("count").value)
          end
          prob[n2.attribute("value").value] = counts
        end
      end

      if bad_format
        legacy = true
        # Make sure every table row has an entry for every class label;
        # missing ones default to 0.0.
        prob.each_value do |label_counts|
          prior.each_key do |label|
            label_counts[label] ||= 0.0
          end
        end
      end

      name = n.attribute("fieldName").value
      conditional[name] = prob
      features[name] = target_stats.any? ? "numeric" : "categorical"
    end

    # The target field name itself is not needed below. The lookup is kept
    # (without the unused local) so behavior on malformed input is unchanged.
    # NOTE(review): presumably this raises when BayesOutput or its fieldName
    # is missing — confirm.
    node.css("BayesOutput").attribute("fieldName").value

    probabilities = {
      prior: prior,
      conditional: conditional
    }

    # Each NormDiscrete derived field stands for one (source field, value)
    # pair. The derived column is replaced in `features` by its source field,
    # which is marked "derived".
    derived = {}
    doc.css("DerivedField").each do |n|
      name = n.attribute("name").value
      norm = n.css("NormDiscrete")
      field = norm.attribute("field").value
      value = norm.attribute("value").value
      features.delete(name)
      features[field] = "derived"
      derived[field] ||= {}
      derived[field][name] = value
    end

    Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
  end
end

Instance Method Details

#accuracy ⇒ Object



# File 'lib/eps/naive_bayes.rb', lines 5-7

# Accuracy of the model's predictions on its own training set.
#
# This is an in-sample score: it compares @train_set.label against
# predict(@train_set) using Eps::Metrics.accuracy.
def accuracy
  actual = @train_set.label
  predicted = predict(@train_set)
  Eps::Metrics.accuracy(actual, predicted)
end