Method: Transformers::BatchFeature#to

Defined in:
lib/transformers/feature_extraction_utils.rb

#to(*args, **kwargs) ⇒ Object



87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# File 'lib/transformers/feature_extraction_utils.rb', line 87

# Sends every value yielded by +items+ through its own +to+ and replaces
# +@data+ with the results.
#
# Only values for which +Torch.floating_point?+ is true receive the full
# +*args+ / +**kwargs+ (so a dtype cast is never applied to the others);
# the remaining values are only moved, and only when a +:device+ keyword
# was supplied.
#
# @param args [Array] positional arguments forwarded to the +to+ of
#   floating point values
# @param kwargs [Hash] keyword arguments forwarded to the +to+ of floating
#   point values; +:device+ is additionally applied to the other values
# @return [self]
# @raise [Todo] when positional arguments are given without a +:device+
#   keyword (telling a positional device from a dtype is not implemented)
def to(*args, **kwargs)
  converted = {}
  target_device = kwargs[:device]

  # A positional argument could be either a device or a dtype; that case is
  # not handled yet.
  raise Todo if target_device.nil? && !args.empty?

  # Only floating point values get the full cast, so integer values are
  # never turned into floats by a dtype argument.
  items.each do |name, value|
    converted[name] =
      if Torch.floating_point?(value)
        # Cast and move in one call.
        value.to(*args, **kwargs)
      elsif target_device.nil?
        # Nothing to do: no device requested and no cast allowed.
        value
      else
        # Move only; never cast.
        value.to(target_device)
      end
  end

  @data = converted
  self
end