Class: Transformers::BatchEncoding
- Inherits: Object
- Hierarchy: Object → Transformers::BatchEncoding
- Defined in:
- lib/transformers/tokenization_utils_base.rb
Instance Method Summary collapse
- #[](item) ⇒ Object
- #convert_to_tensors(tensor_type: nil, prepend_batch_axis: false) ⇒ Object
- #delete(item) ⇒ Object
- #encodings ⇒ Object
- #include?(item) ⇒ Boolean
-
#initialize(data: nil, encoding: nil, tensor_type: nil, prepend_batch_axis: false, n_sequences: nil) ⇒ BatchEncoding
constructor
A new instance of BatchEncoding.
- #items ⇒ Object
- #sequence_ids(batch_index = 0) ⇒ Object
- #to_h ⇒ Object (also: #to_hash)
Constructor Details
#initialize(data: nil, encoding: nil, tensor_type: nil, prepend_batch_axis: false, n_sequences: nil) ⇒ BatchEncoding
Returns a new instance of BatchEncoding.
35 36 37 38 39 40 41 42 43 44 45 46 47 |
# Build a batch-encoding wrapper around tokenizer output.
#
# @param data [Hash, nil] tokenizer output keyed by field name (e.g. input ids)
# @param encoding [Object, nil] backend (fast tokenizer) encodings, if any
# @param tensor_type [Object, nil] when given, stored values are converted to tensors
# @param prepend_batch_axis [Boolean] wrap each value in an extra batch dimension
# @param n_sequences [Integer, nil] NOTE(review): accepted but not stored in the
#   visible body — presumably kept for API parity with the Python original; confirm.
def initialize(
  data: nil,
  encoding: nil,
  tensor_type: nil,
  prepend_batch_axis: false,
  n_sequences: nil
)
  @data = data
  @encodings = encoding
  # Eagerly materialize tensors once, so readers see converted values.
  convert_to_tensors(tensor_type: tensor_type, prepend_batch_axis: prepend_batch_axis)
end
Instance Method Details
#[](item) ⇒ Object
79 80 81 82 83 84 85 86 87 88 89 90 91 |
# Fetch a value by key or index.
#
# String and Symbol keys look up the underlying data hash (symbols are
# stringified first). Any other key is delegated to the backend encodings
# when present; otherwise a Range slices every stored value. Anything
# else raises.
#
# @raise [KeyError] when the key type is unsupported
def [](item)
  case item
  when String
    @data[item]
  when Symbol
    @data[item.to_s]
  else
    # Branch order matters: backend encodings take precedence over
    # Range-based data subsetting, mirroring the upstream implementation.
    if !@encodings.nil?
      @encodings[item]
    elsif item.is_a?(Range)
      @data.keys.to_h { |key| [key, @data[key][item]] }
    else
      raise KeyError, "Invalid key. Only three types of key are available: (1) string, (2) integers for backend Encoding, and (3) ranges for data subsetting."
    end
  end
end
#convert_to_tensors(tensor_type: nil, prepend_batch_axis: false) ⇒ Object
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
# File 'lib/transformers/tokenization_utils_base.rb', line 49 def convert_to_tensors(tensor_type: nil, prepend_batch_axis: false) if tensor_type.nil? return self end if !tensor_type.is_a?(TensorType) tensor_type = TensorType.new(tensor_type) end is_tensor = Torch.method(:tensor?) as_tensor = lambda do |value, dtype: nil| if value.is_a?(Array) && value[0].is_a?(Numo::NArray) return Torch.tensor(Numo::NArray.cast(value)) end Torch.tensor(value) end items.each do |key, value| if prepend_batch_axis value = [value] end if !is_tensor.(value) tensor = as_tensor.(value) @data[key] = tensor end end end |
#delete(item) ⇒ Object
97 98 99 |
# Remove +item+ (stringified) from the underlying data hash.
#
# @return [Object, nil] the deleted value, or nil when the key was absent
def delete(item)
  key = item.to_s
  @data.delete(key)
end
#encodings ⇒ Object
105 106 107 |
# File 'lib/transformers/tokenization_utils_base.rb', line 105 def encodings @encodings end |
#include?(item) ⇒ Boolean
93 94 95 |
# Whether the underlying data hash contains +item+ (stringified, so
# symbols and strings are interchangeable).
#
# @return [Boolean]
def include?(item)
  key = item.to_s
  @data.include?(key)
end
#items ⇒ Object
101 102 103 |
# File 'lib/transformers/tokenization_utils_base.rb', line 101 def items @data end |
#sequence_ids(batch_index = 0) ⇒ Object
109 110 111 112 113 114 115 116 |
# Delegate to the backend encoding's sequence ids for the batch item at
# +batch_index+. Only available when backend encodings exist.
#
# @param batch_index [Integer] which item of the batch to inspect
# @raise [ArgumentError] when there are no backend encodings
def sequence_ids(batch_index = 0)
  unless @encodings
    raise ArgumentError,
      "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" +
      " class)."
  end
  @encodings[batch_index].sequence_ids
end
#to_h ⇒ Object Also known as: to_hash
118 119 120 |
# @return [Hash{Symbol=>Object}] a new hash with the stored data keyed
#   by symbols instead of strings; values are shared, not copied
def to_h
  @data.each_pair.to_h { |key, value| [key.to_sym, value] }
end