Class: Riktoken::Encoding

Inherits:
Object
  • Object
show all
Defined in:
lib/riktoken/encoding.rb

Defined Under Namespace

Classes: DisallowedSpecialTokenError, InvalidTokenError

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(name:, ranks:, pattern:, special_tokens: {}) ⇒ Encoding



21
22
23
24
25
# File 'lib/riktoken/encoding.rb', line 21

# Builds an encoding from a rank table, a splitting pattern and an optional
# special-token table.
#
# @param name [Object] identifier stored as-is and exposed through {#name}
# @param ranks [Object] handed to BPE.new as +encoder:+
#   (presumably a token => rank mapping; confirm against BPE)
# @param pattern [Object] handed to BPE.new as +regex:+
# @param special_tokens [Hash] special-token table; kept on the instance for
#   the checks in {#encode} and also handed to BPE.new as
#   +special_tokens_encoder:+
def initialize(name:, ranks:, pattern:, special_tokens: {})
  @special_tokens = special_tokens
  @name = name
  # The same special-token object is shared between this instance and the BPE.
  @bpe = BPE.new(
    encoder: ranks,
    regex: pattern,
    special_tokens_encoder: special_tokens
  )
end

Instance Attribute Details

#name ⇒ Object (readonly)

Returns the value of attribute name.



11
12
13
# File 'lib/riktoken/encoding.rb', line 11

# Reader for the encoding's name.
#
# @return [Object] the value stored in +@name+, i.e. whatever was passed as
#   +name:+ to the constructor
def name
  @name
end

Instance Method Details

#decode(tokens) ⇒ Object



48
49
50
# File 'lib/riktoken/encoding.rb', line 48

# Decodes +tokens+ by delegating to the underlying BPE object.
#
# @param tokens [Object] forwarded unchanged to +@bpe.decode+
#   (presumably a list of token ids; confirm against BPE)
# @return [Object] whatever +@bpe.decode+ returns, passed through unchanged
def decode(tokens)
  @bpe.decode(tokens)
end

#encode(text, allowed_special: Set.new, disallowed_special: "all") ⇒ Object



31
32
33
34
35
36
37
38
39
40
41
42
43
44
# File 'lib/riktoken/encoding.rb', line 31

# Encodes +text+ after checking it for special tokens that are not permitted.
#
# The string "all" is a sentinel for both keyword arguments:
# * +allowed_special: "all"+ expands to every key of the special-token table.
# * +disallowed_special: "all"+ expands to every key of the special-token
#   table minus +allowed_special+.
#
# When the resulting disallowed collection is empty the check is skipped.
#
# @param text [String] input to encode
# @param allowed_special [Set, String] special tokens to pass to the BPE as
#   +allowed_special_tokens:+, or "all"
# @param disallowed_special [Set, String] special tokens whose literal
#   presence in +text+ is an error, or "all"
# @return [Object] element 0 of the value returned by +@bpe.encode+
#   (presumably the token id list; confirm against BPE)
# @raise [DisallowedSpecialTokenError] if +text+ contains any disallowed
#   special token; the message lists each distinct token once, in order of
#   first appearance
def encode(text, allowed_special: Set.new, disallowed_special: "all")
  if allowed_special == "all"
    allowed_special = Set.new(@special_tokens.keys)
  end
  if disallowed_special == "all"
    disallowed_special = Set.new(@special_tokens.keys) - allowed_special
  end

  offenders =
    if disallowed_special.empty?
      []
    else
      banned = disallowed_special.to_a
      # Regexp.union escapes plain strings, so tokens are matched literally.
      text.scan(Regexp.union(banned)).uniq & banned
    end

  unless offenders.empty?
    raise DisallowedSpecialTokenError, "Disallowed special token(s) found: #{offenders.join(", ")}"
  end

  encoded = @bpe.encode(text, allowed_special_tokens: allowed_special)
  encoded[0]
end