Class: FastText::Classifier

Inherits:
Model
  • Object
show all
Defined in:
lib/fasttext/classifier.rb

Constant Summary collapse

# Default training options for the classifier; #fit hands this hash to
# build_args. Frozen so the shared defaults cannot be mutated by accident
# at runtime (a mutation would otherwise leak into every later #fit call).
DEFAULT_OPTIONS =
{
  lr: 0.1,
  lr_update_rate: 100,
  dim: 100,
  ws: 5,
  epoch: 5,
  min_count: 1,
  min_count_label: 0,
  neg: 5,
  word_ngrams: 1,
  loss: "softmax",
  model: "supervised",
  bucket: 2_000_000,
  minn: 0,
  maxn: 0,
  thread: 3,
  t: 0.0001,
  label_prefix: "__label__",
  verbose: 2,
  pretrained_vectors: "",
  save_output: false,
  seed: 0,
  autotune_validation_file: "",
  autotune_metric: "f1",
  autotune_predictions: 1,
  # 300; presumably seconds (i.e. 5 minutes) - confirm against fastText autotune docs
  autotune_duration: 60 * 5,
  autotune_model_size: ""
}.freeze

Instance Method Summary collapse

Methods inherited from Model

#dimension, #initialize, #quantized?, #save_model, #sentence_vector, #subword_id, #subwords, #word_id, #word_vector, #words

Constructor Details

This class inherits a constructor from FastText::Model

Instance Method Details

#fit(x, y = nil, autotune_set: nil) ⇒ Object



32
33
34
35
36
37
38
39
40
41
42
43
# File 'lib/fasttext/classifier.rb', line 32

# Trains the underlying model in "supervised" mode.
#
# @param x training input, forwarded to input_path together with y
# @param y optional second argument for input_path (defaults to nil)
# @param autotune_set optional validation data; destructured into an (x, y)
#   pair and resolved through input_path to set the autotune validation file
# @return whatever the underlying model's #train returns
def fit(x, y = nil, autotune_set: nil)
  # The second value returned by input_path is deliberately held in a local
  # for the duration of training (presumably a temp-file handle that must
  # not be garbage-collected while training runs - confirm in input_path).
  train_path, _train_ref = input_path(x, y)
  @m ||= Ext::Model.new

  args = build_args(DEFAULT_OPTIONS)
  args.input = train_path
  # Always force supervised mode, whatever build_args produced.
  args.model = "supervised"

  if autotune_set
    valid_x, valid_y = autotune_set
    valid_path, _valid_ref = input_path(valid_x, valid_y)
    args.autotune_validation_file = valid_path
  end

  m.train(args)
end

#labels(include_freq: false) ⇒ Object



76
77
78
79
80
81
82
83
84
# File 'lib/fasttext/classifier.rb', line 76

# Returns the model's labels with the label prefix stripped via remove_prefix.
#
# @param include_freq when truthy, return a Hash of label => frequency
#   instead of a plain Array of labels
# @return [Array, Hash] label names, or a label => frequency mapping
def labels(include_freq: false)
  names, counts = m.labels
  names.map! { |name| remove_prefix(name) }
  return names unless include_freq

  names.zip(counts).to_h
end

#predict(text, k: 1, threshold: 0.0) ⇒ Object



45
46
47
48
49
50
51
52
53
54
55
56
57
58
# File 'lib/fasttext/classifier.rb', line 45

# Predicts labels for one text or for an Array of texts.
#
# @param text a single input, or an Array of inputs
# @param k forwarded unchanged to the underlying model's #predict
# @param threshold forwarded unchanged to the underlying model's #predict
# @return [Hash, Array<Hash>] for a single input, one Hash mapping each
#   prefix-stripped label to its value; for an Array input, one such Hash per
#   element (in input order)
def predict(text, k: 1, threshold: 0.0)
  batch = text.is_a?(Array)
  inputs = batch ? text : [text]

  # TODO predict multiple in C++ for performance
  predictions = inputs.map do |input|
    pairs = m.predict(prep_text(input), k, threshold)
    # Each pair arrives as [value, label]; flip it into label => value.
    pairs.to_h { |score, label| [remove_prefix(label), score] }
  end

  batch ? predictions : predictions.first
end

#quantize ⇒ Object

TODO support options



71
72
73
74
# File 'lib/fasttext/classifier.rb', line 71

# Quantizes the underlying model using a freshly constructed Ext::Args
# (no options are applied to it here).
#
# TODO support options
#
# @return whatever the underlying model's #quantize returns
def quantize
  m.quantize(Ext::Args.new)
end

#test(x, y = nil, k: 1) ⇒ Object



60
61
62
63
64
65
66
67
68
# File 'lib/fasttext/classifier.rb', line 60

# Evaluates the model on the given data.
#
# @param x evaluation input, forwarded to input_path together with y
# @param y optional second argument for input_path (defaults to nil)
# @param k forwarded unchanged to the underlying model's #test
# @return [Hash] with keys :examples, :precision and :recall, taken in that
#   order from the first three values the underlying model returns
def test(x, y = nil, k: 1)
  # The second value from input_path is kept in a local while the model runs
  # (presumably a temp-file handle - confirm in input_path).
  eval_path, _eval_ref = input_path(x, y)
  examples, precision, recall = m.test(eval_path, k)

  { examples: examples, precision: precision, recall: recall }
end