Method: Transformers::BatchFeature#_get_is_as_tensor_fns

Defined in:
lib/transformers/feature_extraction_utils.rb

#_get_is_as_tensor_fns(tensor_type: nil) ⇒ Object



Source lines 42–57 of lib/transformers/feature_extraction_utils.rb:
# File 'lib/transformers/feature_extraction_utils.rb', line 42

# Builds the pair of helper callables used to detect and produce tensors.
#
# A nil +tensor_type+ means no conversion was requested, so both helpers are nil.
# Any other value selects Torch: the checker is +Torch.tensor?+ and the converter
# wraps +Torch.tensor+. The specific non-nil value is not inspected here.
#
# @param tensor_type [Object, nil] requested tensor type; only its nil-ness matters
# @return [Array(Method, Proc), Array(nil, nil)] the pair +[is_tensor, as_tensor]+
def _get_is_as_tensor_fns(tensor_type: nil)
  return [nil, nil] if tensor_type.nil?

  # Converter lambda (arity 1). An Array whose first element is a Numo::NArray is
  # merged into one NArray first, presumably so Torch receives a single numeric
  # array rather than a Ruby Array of NArray objects.
  to_tensor = ->(data) do
    data = Numo::NArray.cast(data) if data.is_a?(Array) && data.first.is_a?(Numo::NArray)
    Torch.tensor(data)
  end

  [Torch.method(:tensor?), to_tensor]
end