5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
# File 'lib/torchvision/transforms/functional.rb', line 5
# Normalizes a tensor image channel-wise, computing (tensor - mean) / std.
#
# +mean+ and +std+ are converted to tensors with the image's dtype and device.
# One-dimensional values are reshaped to (C, 1, 1) so that each entry applies to
# one channel. Without that reshape, broadcasting would line them up with the
# last (width) dimension instead of the channel dimension. Values with any other
# number of dimensions are broadcast as given.
#
# @param tensor [Torch::Tensor] image tensor of size (C, H, W)
# @param mean [Array<Numeric>] per-channel means
# @param std [Array<Numeric>] per-channel standard deviations
# @param inplace [Boolean] when true, +tensor+ itself is modified; otherwise a clone is normalized
# @return [Torch::Tensor] the normalized tensor (the same object as +tensor+ when +inplace+ is true)
# @raise [ArgumentError] if +tensor+ is not a tensor, is not 3-dimensional, or
#   if any +std+ value is zero after conversion to the tensor's dtype
def normalize(tensor, mean, std, inplace: false)
  unless Torch.tensor?(tensor)
    raise ArgumentError, "tensor should be a torch tensor. Got #{tensor.class.name}"
  end
  if tensor.ndimension != 3
    raise ArgumentError, "Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = #{tensor.size}"
  end

  tensor = tensor.clone unless inplace

  dtype = tensor.dtype
  mean = Torch.tensor(mean, dtype: dtype, device: tensor.device)
  std = Torch.tensor(std, dtype: dtype, device: tensor.device)

  # Flatten so that a zero is found regardless of how many dimensions std has;
  # a multi-dimensional std otherwise yields nested arrays whose elements never == 0.
  # The check runs after dtype conversion because e.g. 0.5 becomes 0 in an integer dtype.
  if Array(std.to_a).flatten.any? { |v| v == 0 }
    raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
  end

  # Per-channel vectors must be (C, 1, 1) to broadcast over H and W rather than W.
  mean = mean.view(-1, 1, 1) if mean.ndimension == 1
  std = std.view(-1, 1, 1) if std.ndimension == 1

  tensor.sub!(mean).div!(std)
  tensor
end
|