5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
# File 'lib/torch/nn/init.rb', line 5
# Computes the fan-in and fan-out of a weight tensor.
#
# Dimension 1 is treated as the number of input feature maps and dimension 0
# as the number of output feature maps. Every dimension after the first two
# multiplies into a "receptive field size" (1 for a 2-D tensor), which scales
# both counts.
#
# @param tensor [#dim, #size] tensor-like object with at least 2 dimensions
# @return [Array(Integer, Integer)] +[fan_in, fan_out]+
# @raise [Error] if +tensor+ has fewer than 2 dimensions
def calculate_fan_in_and_fan_out(tensor)
  dimensions = tensor.dim
  if dimensions < 2
    raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
  end

  num_input_fmaps = tensor.size(1)
  num_output_fmaps = tensor.size(0)
  # Product of the trailing dimensions (empty range => 1 for a 2-D tensor).
  # Derived from sizes instead of indexing tensor[0][0], which raises when
  # dimension 0 or 1 has size zero.
  receptive_field_size = (2...dimensions).reduce(1) { |product, i| product * tensor.size(i) }

  fan_in = num_input_fmaps * receptive_field_size
  fan_out = num_output_fmaps * receptive_field_size
  [fan_in, fan_out]
end
|