Class: Transformers::BatchEncoding

Inherits:
Object
Defined in:
lib/transformers/tokenization_utils_base.rb

Instance Method Summary

Constructor Details

#initialize(data: nil, encoding: nil, tensor_type: nil, prepend_batch_axis: false, n_sequences: nil) ⇒ BatchEncoding

Returns a new instance of BatchEncoding.



# File 'lib/transformers/tokenization_utils_base.rb', line 35

def initialize(
  data: nil,
  encoding: nil,
  tensor_type: nil,
  prepend_batch_axis: false,
  n_sequences: nil
)
  @data = data

  @encodings = encoding

  convert_to_tensors(tensor_type: tensor_type, prepend_batch_axis: prepend_batch_axis)
end
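
Example (not from the library docs): in typical use a BatchEncoding is returned by a tokenizer call rather than built by hand, but it can be constructed directly from a data hash. The field names below are conventional examples, not requirements; note that n_sequences is accepted by the constructor but is not used in the body shown above.

# Hypothetical example: building a BatchEncoding by hand from plain Ruby arrays.
batch = Transformers::BatchEncoding.new(
  data: {
    "input_ids" => [[101, 7592, 102]],   # example token ids
    "attention_mask" => [[1, 1, 1]]
  },
  tensor_type: nil                       # nil keeps the values as plain arrays
)
batch["input_ids"] # => [[101, 7592, 102]]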

Instance Method Details

#[](item) ⇒ Object



# File 'lib/transformers/tokenization_utils_base.rb', line 79

def [](item)
  if item.is_a?(String)
    @data[item]
  elsif item.is_a?(Symbol)
    @data[item.to_s]
  elsif !@encodings.nil?
    @encodings[item]
  elsif item.is_a?(Range)
    @data.keys.to_h { |key| [key, @data[key][item]] }
  else
    raise KeyError, "Invalid key. Only three types of key are available: (1) string, (2) integers for backend Encoding, and (3) ranges for data subsetting."
  end
end
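
Example of the supported key types, assuming enc is a BatchEncoding returned by a tokenizer (field names are illustrative). As the code above shows, a Range key only reaches the range branch when no backend encodings are present.

enc["input_ids"]   # String key: returns that field from the data hash
enc[:input_ids]    # Symbol key: converted to a String first
enc[0]             # Integer key: the backend Encoding for batch item 0 (fast tokenizers only)
enc[0..1]          # Range key: with no backend encodings, each field is sliced to the range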

#convert_to_tensors(tensor_type: nil, prepend_batch_axis: false) ⇒ Object



# File 'lib/transformers/tokenization_utils_base.rb', line 49

def convert_to_tensors(tensor_type: nil, prepend_batch_axis: false)
  if tensor_type.nil?
    return self
  end

  if !tensor_type.is_a?(TensorType)
    tensor_type = TensorType.new(tensor_type)
  end

  is_tensor = Torch.method(:tensor?)

  as_tensor = lambda do |value, dtype: nil|
    if value.is_a?(Array) && value[0].is_a?(Numo::NArray)
      return Torch.tensor(Numo::NArray.cast(value))
    end
    Torch.tensor(value)
  end

  items.each do |key, value|
    if prepend_batch_axis
      value = [value]
    end

    if !is_tensor.(value)
      tensor = as_tensor.(value)
      @data[key] = tensor
    end
  end
end
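
Example, assuming the torch-rb backend is available. The "pt" tensor type string mirrors the Python library and is an assumption here; check TensorType for the values this port accepts. Conversion replaces the stored field values in place.

# Convert each stored field to a Torch tensor ("pt" is assumed to be an accepted type string).
enc.convert_to_tensors(tensor_type: "pt")
enc["input_ids"].class # => Torch::Tensor

# For a single (unbatched) example, prepend_batch_axis: true wraps each value in an
# extra array before conversion, producing tensors with a leading batch dimension.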

#delete(item) ⇒ Object



# File 'lib/transformers/tokenization_utils_base.rb', line 97

def delete(item)
  @data.delete(item.to_s)
end
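
Example: the key is stringified before removal, so Symbol and String arguments are interchangeable (the field name below is illustrative).

enc.delete(:token_type_ids)  # same effect as enc.delete("token_type_ids"); returns the removed value or nil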

#encodings ⇒ Object



# File 'lib/transformers/tokenization_utils_base.rb', line 105

def encodings
  @encodings
end

#include?(item) ⇒ Boolean

Returns:

  • (Boolean)


# File 'lib/transformers/tokenization_utils_base.rb', line 93

def include?(item)
  @data.include?(item.to_s)
end
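
Example (field name illustrative):

enc.include?("attention_mask")  # => true when the field is present
enc.include?(:attention_mask)   # equivalent: the key is stringified first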

#items ⇒ Object



# File 'lib/transformers/tokenization_utils_base.rb', line 101

def items
  @data
end

#sequence_ids(batch_index = 0) ⇒ Object



# File 'lib/transformers/tokenization_utils_base.rb', line 109

def sequence_ids(batch_index = 0)
  if !@encodings
    raise ArgumentError,
      "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" +
      " class)."
  end
  @encodings[batch_index].sequence_ids
end
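
Example, assuming enc was produced by a fast (Rust-backed) tokenizer from a sentence pair; the output shown is illustrative. nil marks special tokens, while 0 and 1 mark tokens from the first and second sequence.

enc.sequence_ids(0)
# => [nil, 0, 0, 0, nil, 1, 1, 1, nil]   (illustrative output)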

#to_h ⇒ Object Also known as: to_hash



# File 'lib/transformers/tokenization_utils_base.rb', line 118

def to_h
  @data.transform_keys(&:to_sym)
end
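
Example: to_h (aliased as to_hash) returns the underlying data with Symbol keys, which is convenient for passing fields on to a model as keyword arguments. The model call below is illustrative and not part of this class.

inputs = enc.to_h   # => { input_ids: ..., attention_mask: ... }
model.(**inputs)    # hypothetical model call using the symbol-keyed fields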