Class: Riktoken::Encoding
- Inherits:
- Object
- Inheritance chain: Object → Riktoken::Encoding
- Defined in:
- lib/riktoken/encoding.rb
Defined Under Namespace
Classes: DisallowedSpecialTokenError, InvalidTokenError
Instance Attribute Summary collapse
-
#name ⇒ Object
readonly
Returns the value of attribute name.
Instance Method Summary collapse
- #decode(tokens) ⇒ Object
- #encode(text, allowed_special: Set.new, disallowed_special: "all") ⇒ Object
-
#initialize(name:, ranks:, pattern:, special_tokens: {}) ⇒ Encoding
constructor
A new instance of Encoding.
Constructor Details
#initialize(name:, ranks:, pattern:, special_tokens: {}) ⇒ Encoding
21 22 23 24 25 |
# File 'lib/riktoken/encoding.rb', line 21
#
# Builds an encoding backed by a BPE instance.
#
# @param name [Object] identifier exposed through the +name+ reader
# @param ranks [Object] handed to +BPE.new+ as +encoder:+
#   (presumably the merge-rank table; confirm against BPE)
# @param pattern [Object] handed to +BPE.new+ as +regex:+
# @param special_tokens [Hash] handed to +BPE.new+ as +special_tokens_encoder:+
#   and kept on the instance; its keys are what +encode+ treats as special tokens
def initialize(name:, ranks:, pattern:, special_tokens: {})
  @bpe = BPE.new(
    encoder: ranks,
    regex: pattern,
    special_tokens_encoder: special_tokens
  )
  @special_tokens = special_tokens
  @name = name
end
Instance Attribute Details
#name ⇒ Object (readonly)
Returns the value of attribute name.
11 12 13 |
# File 'lib/riktoken/encoding.rb', line 11
#
# Returns the value of attribute name.
#
# @return [Object] the value stored in +@name+
def name
  @name
end
Instance Method Details
#decode(tokens) ⇒ Object
48 49 50 |
# File 'lib/riktoken/encoding.rb', line 48
#
# Delegates decoding to the underlying BPE object.
#
# @param tokens [Object] forwarded unchanged to +@bpe.decode+
#   (presumably an Array of token ids; confirm against BPE)
# @return [Object] whatever +@bpe.decode+ returns
def decode(tokens)
  codec = @bpe
  codec.decode(tokens)
end
#encode(text, allowed_special: Set.new, disallowed_special: "all") ⇒ Object
31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
# File 'lib/riktoken/encoding.rb', line 31
#
# Encodes +text+ with the underlying BPE object, first rejecting text that
# contains any special token the caller has not permitted.
#
# @param text [String] the text to encode
# @param allowed_special [Set<String>, Array<String>, String] special tokens
#   that may appear in +text+; the string "all" expands to every key of the
#   special-token table given to the constructor
# @param disallowed_special [Set<String>, Array<String>, String, nil] special
#   tokens whose presence in +text+ is an error; the string "all" expands to
#   every special token that is not in +allowed_special+; +nil+ or an empty
#   collection disables the check
# @return [Object] the first element of the value returned by +@bpe.encode+
#   (presumably the list of token ids; confirm against BPE)
# @raise [DisallowedSpecialTokenError] if +text+ contains a disallowed special token
def encode(text, allowed_special: Set.new, disallowed_special: "all")
  allowed_special = Set.new(@special_tokens.keys) if allowed_special == "all"
  disallowed_special = Set.new(@special_tokens.keys) - allowed_special if disallowed_special == "all"

  # Array() turns nil into [] and a Set into its members, so "nothing
  # disallowed" can be expressed as nil, [] or an empty Set.
  forbidden = Array(disallowed_special)

  unless forbidden.empty?
    # Regexp.union escapes each string, so every match is literally one of the
    # forbidden tokens; no further filtering of the matches is needed.
    offenders = text.scan(Regexp.union(forbidden)).uniq
    unless offenders.empty?
      raise DisallowedSpecialTokenError, "Disallowed special token(s) found: #{offenders.join(", ")}"
    end
  end

  @bpe.encode(text, allowed_special_tokens: allowed_special)[0]
end