Class: Eps::BaseEstimator

Inherits: Object
Defined in:
lib/eps/base_estimator.rb

Direct Known Subclasses

LightGBM, LinearRegression, NaiveBayes

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(data = nil, y = nil, **options) ⇒ BaseEstimator


# File 'lib/eps/base_estimator.rb', lines 3-5

# Creates an estimator. When training data is supplied, the estimator is
# trained immediately by forwarding every argument to #train; otherwise an
# untrained instance is returned (as done by .load_pmml).
#
# @param data training data, or nil to skip training
# @param y optional target values passed through to #train
# @param options extra keyword options passed through to #train
def initialize(data = nil, y = nil, **options)
  return unless data

  train(data, y, **options)
end

Class Method Details

.extract_text_features(data, features) ⇒ Object

Note: this method is part of the private API.


# File 'lib/eps/base_estimator.rb', lines 74-111

# Extracts text-feature metadata from a PMML document.
#
# DerivedField elements (under LocalTransformations or the
# TransformationDictionary) that carry a Constant are treated as per-term
# fields: the Constant is the term, FieldRef names the source field, and
# Apply names the text function. DefineFunction/TextIndex elements supply
# each function's tokenizer regexp and case sensitivity.
#
# @param data parsed PMML document (anything responding to #css, e.g. a
#   Nokogiri document)
# @param features [Hash] feature name => type; mutated in place so every
#   text field is marked as "text"
# @return [Array(Hash, Hash)] text_features (field => {tokenizer:,
#   case_sensitive:, vocabulary:}) and derived_fields (derived field name =>
#   [source field, term])
def self.extract_text_features(data, features)
  vocabulary = {}
  function_mapping = {}
  derived_fields = {}

  data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n|
    # Check for the term Constant before touching FieldRef. Derived fields of
    # other shapes (e.g. with no FieldRef child) are then skipped instead of
    # raising NoMethodError on a nil attribute.
    value = n.css("Constant").text
    next if value.empty?

    name = n.attribute("name")&.value
    field = n.css("FieldRef").attribute("field").value
    # Unwrap lowercase(field): drop the 10-character "lowercase(" prefix and the trailing ")"
    field = field[10..-2] if field.match?(/\Alowercase\(.+\)\z/)

    (vocabulary[field] ||= []) << value
    function_mapping[field] = n.css("Apply").attribute("function").value
    derived_fields[name] = [field, value]
  end

  functions = {}
  data.css("TransformationDictionary DefineFunction").each do |n|
    text_index = n.css("TextIndex")
    # The PMML TextIndex schema defaults wordSeparatorCharacterRE to "\s+"
    separator = text_index.attribute("wordSeparatorCharacterRE")&.value || "\\s+"
    functions[n.attribute("name").value] = {
      tokenizer: Regexp.new(separator),
      case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
    }
  end

  text_features = {}
  function_mapping.each do |field, function|
    text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
    features[field] = "text"
  end

  [text_features, derived_fields]
end

.load_pmml(data) ⇒ Object


# File 'lib/eps/base_estimator.rb', lines 40-48

# Builds an untrained estimator backed by a PMML document.
#
# A String argument is parsed as XML in strict mode; any other value is used
# as an already-parsed document. The document is cached in @pmml, and the
# value returned by the given block (which receives the document) becomes
# the estimator's @evaluator.
#
# @param data [String, Object] PMML XML text or a parsed document
# @yieldparam doc the parsed document
# @yieldreturn the evaluator to install on the new estimator
# @return the new estimator
def self.load_pmml(data)
  data = Nokogiri::XML(data) { |config| config.strict } if data.is_a?(String)

  new.tap do |model|
    model.instance_variable_set("@pmml", data) # cache data
    model.instance_variable_set("@evaluator", yield(data))
  end
end

Instance Method Details

#evaluate(data, y = nil, target: nil) ⇒ Object


# File 'lib/eps/base_estimator.rb', lines 31-34

# Scores the model against labeled data.
#
# The data is prepared with #prep_data, using the explicit target when given
# and otherwise the estimator's @target. Its labels are then compared with
# this model's predictions via Eps.metrics.
#
# @param data observations to evaluate
# @param y optional target values
# @param target [Object, nil] target column override
# @return whatever Eps.metrics returns for (labels, predictions)
def evaluate(data, y = nil, target: nil)
  prepared, = prep_data(data, y, target || @target)
  Eps.metrics(prepared.label, predict(prepared))
end

#predict(data) ⇒ Object


# File 'lib/eps/base_estimator.rb', lines 7-29

# Predicts with the loaded evaluator after validating the input columns.
#
# Every feature the evaluator declares must be present. A feature declared
# "numeric" must have a "numeric" column type; any other declared type must
# have a "categorical" column type. Columns whose type comes back nil from
# Utils.column_type are not type-checked.
#
# @param data [Hash, Object] a single observation (Hash) or anything
#   Eps::DataFrame.new accepts
# @return a single prediction when given a Hash, otherwise all predictions
# @raise [ArgumentError] on a missing feature column or a type mismatch
def predict(data)
  single_row = data.is_a?(Hash)
  frame = Eps::DataFrame.new(single_row ? [data] : data)

  @evaluator.features.each do |name, expected|
    column = frame.columns[name]
    raise ArgumentError, "Missing column: #{name}" unless column

    actual = Utils.column_type(column.compact, name)
    next if actual.nil?

    compatible = expected == "numeric" ? actual == "numeric" : actual == "categorical"
    raise ArgumentError, "Bad type for column #{name}: Expected #{expected} but got #{actual}" unless compatible
    # TODO check for unknown values for categorical features
  end

  results = @evaluator.predict(frame)
  single_row ? results.first : results
end

#summary(extended: false) ⇒ Object


# File 'lib/eps/base_estimator.rb', lines 50-71

# Returns a human-readable model summary.
#
# When a validation set is held in @validation_set, the summary is prefixed
# with a validation metric:
# - RMSE for a "numeric" @target_type (shown as a rounded integer when it
#   rounds to 1000 or more, otherwise with "%.3g");
# - accuracy as a percentage for any other target type.
# The remainder comes from #_summary.
#
# @param extended [Boolean] forwarded to #_summary
# @return [String]
def summary(extended: false)
  report = String.new("")

  if @validation_set
    actual = @validation_set.label
    predicted = predict(@validation_set)

    metric_name, metric_value =
      case @target_type
      when "numeric"
        rmse = Metrics.rmse(actual, predicted)
        ["RMSE", rmse.round >= 1000 ? rmse.round.to_s : "%.3g" % rmse]
      else
        ["accuracy", "%.1f%%" % (100 * Metrics.accuracy(actual, predicted)).round(1)]
      end

    report << "Validation %s: %s\n\n" % [metric_name, metric_value]
  end

  report << _summary(extended: extended)
end

#to_pmml ⇒ Object


# File 'lib/eps/base_estimator.rb', lines 36-38

# Serializes the model to PMML XML.
#
# The document is built by #generate_pmml on first use (unless one was
# already cached, e.g. by .load_pmml) and memoized in @pmml.
#
# @return [String] the PMML document as XML
def to_pmml
  @pmml = generate_pmml unless @pmml
  @pmml.to_xml
end