Class: Riktoken::BPE

Inherits:
Object
  • Object
show all
Defined in:
lib/riktoken/bpe.rb

Defined Under Namespace

Classes: TextEncodingError

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(encoder:, regex:, special_tokens_encoder:) ⇒ BPE

Returns a new instance of BPE.



18
19
20
21
22
23
24
25
# File 'lib/riktoken/bpe.rb', line 18

# Sets up a BPE codec from its lookup tables.
#
# Besides storing the three arguments, this derives:
# * a regexp matching any of the special token strings, and
# * rank => string tables that are the inverse of both encoders.
#
# @param encoder [Hash] token string => rank
# @param regex [Regexp] pattern stored for later use when encoding
# @param special_tokens_encoder [Hash] special token string => rank
def initialize(encoder:, regex:, special_tokens_encoder:)
  @encoder = encoder
  @regex = regex
  @special_tokens_encoder = special_tokens_encoder
  # Regexp.union escapes each key, so special tokens are matched literally.
  @special_regex = Regexp.union(special_tokens_encoder.keys)
  @decoder = encoder.to_h { |token, rank| [rank, token] }
  @special_tokens_decoder = special_tokens_encoder.to_h { |token, rank| [rank, token] }
end

Instance Attribute Details

#decoder ⇒ Object (readonly)

: Hash[rank, String]



8
9
10
# File 'lib/riktoken/bpe.rb', line 8

# Reader for the rank => token string table (Hash[rank, String]).
# The constructor builds it by swapping the keys and values of the encoder.
#
# @return [Hash] the stored @decoder
def decoder
  @decoder
end

#encoder ⇒ Object (readonly)

: Hash[String, rank] – parameter like parsed *.tiktoken file



7
8
9
# File 'lib/riktoken/bpe.rb', line 7

# Reader for the token string => rank table (Hash[String, rank]),
# exactly as it was passed to the constructor.
#
# @return [Hash] the stored @encoder
def encoder
  @encoder
end

#regex ⇒ Object (readonly)

: Regexp



11
12
13
# File 'lib/riktoken/bpe.rb', line 11

# Reader for the Regexp passed to the constructor; #encode scans ordinary
# text with it to obtain the pieces that are turned into tokens.
#
# @return [Regexp] the stored @regex
def regex
  @regex
end

#special_regex ⇒ Object (readonly)

: Regexp



12
13
14
# File 'lib/riktoken/bpe.rb', line 12

# Reader for the Regexp that the constructor builds with Regexp.union over
# the keys of the special tokens encoder.
#
# @return [Regexp] the stored @special_regex
def special_regex
  @special_regex
end

#special_tokens_decoder ⇒ Object (readonly)

: Hash[rank, String]



10
11
12
# File 'lib/riktoken/bpe.rb', line 10

# Reader for the rank => special token string table (Hash[rank, String]).
# The constructor builds it by swapping the keys and values of the special
# tokens encoder.
#
# @return [Hash] the stored @special_tokens_decoder
def special_tokens_decoder
  @special_tokens_decoder
end

#special_tokens_encoder ⇒ Object (readonly)

: Hash[String, rank]



9
10
11
# File 'lib/riktoken/bpe.rb', line 9

# Reader for the special token string => rank table (Hash[String, rank]),
# exactly as it was passed to the constructor.
#
# @return [Hash] the stored @special_tokens_encoder
def special_tokens_encoder
  @special_tokens_encoder
end

Class Method Details

.byte_pair_encode(piece, ranks) ⇒ Object



113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# File 'lib/riktoken/bpe.rb', line 113

# Splits +piece+ into ranks by repeatedly merging adjacent byte sequences.
#
# If the whole piece is present in +ranks+ its rank is returned directly.
# Otherwise the piece starts as one single-byte string per byte; on every
# round the adjacent pair whose concatenation has the lowest rank is merged
# (the leftmost such pair wins a tie), until no adjacent pair is in +ranks+.
#
# @param piece [String] text fragment to encode
# @param ranks [Hash] byte sequence => rank
# @return [Array] the rank looked up for each remaining part (nil for a part
#   that has no entry in +ranks+)
def self.byte_pair_encode(piece, ranks)
  whole_rank = ranks[piece]
  return [whole_rank] if whole_rank

  parts = piece.each_byte.map(&:chr)

  loop do
    best_rank = nil
    best_index = nil

    # Scan every adjacent pair and remember the first one with the lowest rank.
    parts.each_cons(2).with_index do |(left, right), index|
      candidate = left + right
      next unless ranks.key?(candidate)

      rank = ranks[candidate]
      next unless best_rank.nil? || rank < best_rank

      best_rank = rank
      best_index = index
    end

    break if best_index.nil?

    # Replace the two winning parts with their concatenation, then rescan
    # from the beginning because the merge may enable new pairs.
    parts[best_index, 2] = parts[best_index] + parts[best_index + 1]
  end

  parts.map { |part| ranks[part] }
end

Instance Method Details

#decode(tokens) ⇒ Object

Decode given tokens back into text encoded as UTF-8.



100
101
102
103
104
105
106
107
108
# File 'lib/riktoken/bpe.rb', line 100

# Decode given tokens back into text encoded as UTF-8.
#
# Each token is looked up in @decoder first and in @special_tokens_decoder
# second. A token found in neither table contributes nothing to the output.
#
# The looked-up chunks are concatenated as raw bytes in a binary buffer and
# only the final result is tagged as UTF-8. Joining the strings directly can
# raise Encoding::CompatibilityError when binary chunks containing high bytes
# are mixed with non-ASCII UTF-8 strings, even though the combined bytes form
# perfectly valid UTF-8.
#
# @param tokens [Array] ranks to decode
# @return [String] the decoded text in UTF-8 ("" for an empty token list)
# @raise [TextEncodingError] if the concatenated bytes are not valid UTF-8
def decode(tokens)
  return "" if tokens.empty?

  buffer = String.new(encoding: Encoding::BINARY)
  tokens.each do |token|
    chunk = @decoder[token] || @special_tokens_decoder[token]
    # `to_s` turns a missing token (nil) into "", `b` gives a binary copy so
    # appending never depends on the chunk's own encoding.
    buffer << chunk.to_s.b
  end

  decoded = buffer.force_encoding("UTF-8")
  return decoded if decoded.valid_encoding?

  raise TextEncodingError, "failed to apply the text encoding to decoded tokens as valid UTF-8"
end

#encode(text, allowed_special_tokens: Set.new) ⇒ Object

Encode given text into tokens using the BPE encoding, allowing for given special tokens.



36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# File 'lib/riktoken/bpe.rb', line 36

# Encode given text into tokens using the BPE encoding, allowing for given special tokens.
#
# The text is processed in segments. Each segment runs up to the next
# occurrence of a special token that is listed in +allowed_special_tokens+
# (occurrences of special tokens that are not allowed are treated as ordinary
# text). A segment is scanned with @regex and every matched piece is turned
# into tokens, either by a direct @encoder lookup or by byte pair encoding.
# The allowed special token that ended the segment is then appended via
# @special_tokens_encoder and processing resumes right after it.
#
# @param text [String] text to encode
# @param allowed_special_tokens [Set, #include?] special token strings that
#   may be emitted as special tokens
# @return [Array] two elements: the token list, and the number of tokens
#   produced by the last ordinary piece (0 if the text ended with a special
#   token or produced no ordinary piece)
def encode(text, allowed_special_tokens: Set.new)
  tokens = []
  start = 0
  last_piece_token_len = 0
  text_length = text.length # loop-invariant, so computed once

  loop do
    # Look for the next special token that the caller allowed.
    next_special = nil
    start_find = start
    while start_find < text_length
      m = @special_regex.match(text, start_find)
      break if m.nil?

      if allowed_special_tokens.include?(m[0])
        next_special = m
        break
      end

      # Not allowed: keep searching one character past where it started.
      start_find = m.begin(0) + 1
    end

    end_pos = next_special ? next_special.begin(0) : text_length

    text[start...end_pos].scan(@regex) do
      # Always take the whole match. When @regex contains capture groups,
      # scan yields the array of groups, whose first entry is only the first
      # group (possibly nil) rather than the matched piece.
      piece = Regexp.last_match(0)
      if @encoder.key?(piece)
        last_piece_token_len = 1
        tokens << @encoder[piece]
      else
        bpe_tokens = self.class.byte_pair_encode(piece, @encoder)
        last_piece_token_len = bpe_tokens.size
        tokens.concat(bpe_tokens)
      end
    end

    break unless next_special

    tokens << @special_tokens_encoder[next_special[0]]
    start = next_special.end(0)
    last_piece_token_len = 0
  end

  [tokens, last_piece_token_len]
end

#encode_ordinary(text) ⇒ Object

Encode given text into tokens using the BPE encoding without considering special tokens.



86
87
88
# File 'lib/riktoken/bpe.rb', line 86

# Encode given text into tokens using the BPE encoding without considering special tokens.
#
# Calls #encode with its default (empty) set of allowed special tokens and
# keeps only the first element of its result, the token list.
#
# @param text [String] text to encode
# @return [Array] the tokens
def encode_ordinary(text)
  tokens, = encode(text)
  tokens
end

#encode_with_special_tokens(text) ⇒ Object

Encode given text into tokens using the BPE encoding, allowing for all special tokens.



93
94
95
# File 'lib/riktoken/bpe.rb', line 93

# Encode given text using the BPE encoding, allowing for all special tokens.
#
# Passes every key of the special tokens encoder (see #special_tokens) as the
# allowed set and returns the result of #encode unchanged, i.e. the
# two-element array of token list and last piece token count.
#
# NOTE(review): #encode_ordinary returns only the token list, while this
# method returns the whole pair; confirm with callers which one is intended.
#
# @param text [String] text to encode
# @return [Array] whatever #encode returns for this text
def encode_with_special_tokens(text)
  every_special_token = special_tokens
  encode(text, allowed_special_tokens: every_special_token)
end

#special_tokens ⇒ Object



28
29
30
# File 'lib/riktoken/bpe.rb', line 28

# Collects the special token strings known to this codec.
#
# @return [Set] a new set holding the keys of the special tokens encoder
def special_tokens
  @special_tokens_encoder.keys.to_set
end