Class: TorchVision::Transforms::Functional

Inherits: Object
Defined in:
lib/torchvision/transforms/functional.rb

Class Method Summary

  .hflip(img) ⇒ Object
  .normalize(tensor, mean, std, inplace: false) ⇒ Object
  .resize(img, size) ⇒ Object
  .to_tensor(pic) ⇒ Object
  .vflip(img) ⇒ Object

Class Method Details

.hflip(img) ⇒ Object



# File 'lib/torchvision/transforms/functional.rb', line 91

def hflip(img)
  if img.is_a?(Torch::Tensor)
    # tensors: flip along the last dimension (width) of a (..., H, W) layout
    img.flip(-1)
  else
    # Vips images: mirror horizontally
    img.flip(:horizontal)
  end
end
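
A minimal usage sketch, assuming torch-rb is available and the method is called as a class method on TorchVision::Transforms::Functional:

require "torchvision"

x = Torch.arange(6).reshape(1, 2, 3)  # tiny (C, H, W) tensor: rows [0, 1, 2] and [3, 4, 5]
flipped = TorchVision::Transforms::Functional.hflip(x)
# each row is reversed along the last (width) dimension: [2, 1, 0] and [5, 4, 3]
# a Vips::Image argument is flipped with img.flip(:horizontal) instead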

.normalize(tensor, mean, std, inplace: false) ⇒ Object



# File 'lib/torchvision/transforms/functional.rb', line 5

def normalize(tensor, mean, std, inplace: false)
  unless Torch.tensor?(tensor)
    raise ArgumentError, "tensor should be a torch tensor. Got #{tensor.class.name}"
  end

  if tensor.ndimension != 3
    raise ArgumentError, "Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = #{tensor.size}"
  end

  tensor = tensor.clone unless inplace

  dtype = tensor.dtype
  # TODO Torch.as_tensor
  mean = Torch.tensor(mean, dtype: dtype, device: tensor.device)
  std = Torch.tensor(std, dtype: dtype, device: tensor.device)

  # TODO
  if std.to_a.any? { |v| v == 0 }
    raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
  end
  if mean.ndim == 1
    mean = mean[0...mean.size(0), nil, nil]
  end
  if std.ndim == 1
    std = std[0...std.size(0), nil, nil]
  end
  tensor.sub!(mean).div!(std)
  tensor
end
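
A short usage sketch, assuming torch-rb is available; the mean and std values below are the common ImageNet statistics and are only illustrative:

require "torchvision"

img = Torch.rand(3, 4, 4)  # fake (C, H, W) image with values in [0, 1]
out = TorchVision::Transforms::Functional.normalize(
  img,
  [0.485, 0.456, 0.406],   # per-channel means
  [0.229, 0.224, 0.225]    # per-channel standard deviations
)
# out[c] = (img[c] - mean[c]) / std[c], broadcast over H and W;
# img itself is left unchanged because inplace defaults to false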

.resize(img, size) ⇒ Object



# File 'lib/torchvision/transforms/functional.rb', line 35

def resize(img, size)
  raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)
  # TODO support array size
  raise "Got inappropriate size arg: #{size}" unless size.is_a?(Integer)

  w, h = img.size
  # already at the requested size on the shorter edge: nothing to do
  if (w <= h && w == size) || (h <= w && h == size)
    return img
  end
  # scale so the shorter edge becomes size, preserving the aspect ratio
  if w < h
    ow = size
    oh = (size * h / w).to_i
    img.thumbnail_image(ow, height: oh)
  else
    oh = size
    ow = (size * w / h).to_i
    img.thumbnail_image(ow, height: oh)
  end
end
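
A usage sketch, assuming ruby-vips is available; "photo.jpg" is a hypothetical input file:

require "torchvision"

img = Vips::Image.new_from_file("photo.jpg")
small = TorchVision::Transforms::Functional.resize(img, 256)
# the shorter edge is scaled to 256 pixels and the aspect ratio is kept;
# if the shorter edge already equals 256, the original image is returned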

.to_tensor(pic) ⇒ Object

TODO improve



# File 'lib/torchvision/transforms/functional.rb', line 56

def to_tensor(pic)
  if !pic.is_a?(Numo::NArray) && !pic.is_a?(Vips::Image)
    raise ArgumentError, "pic should be Vips::Image or Numo::NArray. Got #{pic.class.name}"
  end

  if pic.is_a?(Numo::NArray) && ![2, 3].include?(pic.ndim)
    raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.dim} dimensions."
  end

  if pic.is_a?(Numo::NArray)
    if pic.ndim == 2
      pic = pic.reshape(*pic.shape, 1)
    end

    img = Torch.from_numo(pic.transpose(2, 0, 1))
    if img.dtype == :uint8
      return img.float.div(255)
    else
      return img
    end
  end

  case pic.format
  when :uchar
    img = Torch::ByteTensor.new(Torch::ByteStorage.from_buffer(pic.write_to_memory))
  else
    raise Error, "Format not supported yet: #{pic.format}"
  end

  img = img.view(pic.height, pic.width, pic.bands)
  # put it from HWC to CHW format
  img = img.permute([2, 0, 1]).contiguous
  img.float.div(255)
end
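
A usage sketch, assuming an 8-bit (uchar) image loaded with ruby-vips; "photo.jpg" is a hypothetical file:

require "torchvision"

img = Vips::Image.new_from_file("photo.jpg")
t = TorchVision::Transforms::Functional.to_tensor(img)
t.shape  # => [bands, height, width]; values are floats scaled into [0, 1]

# a Numo::NArray of shape (H, W) or (H, W, C) is also accepted and is
# transposed to (C, H, W) before conversion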

.vflip(img) ⇒ Object



# File 'lib/torchvision/transforms/functional.rb', line 99

def vflip(img)
  if img.is_a?(Torch::Tensor)
    # tensors: flip along dimension -2 (height) of a (..., H, W) layout
    img.flip(-2)
  else
    # Vips images: mirror vertically
    img.flip(:vertical)
  end
end
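
A usage sketch mirroring hflip; "photo.jpg" is a hypothetical file:

require "torchvision"

img = Vips::Image.new_from_file("photo.jpg")
flipped = TorchVision::Transforms::Functional.vflip(img)
# a Torch::Tensor input is instead flipped along dimension -2 (height)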