Class: Torch::NN::Functional

Inherits:
Object
  • Object
show all
Defined in:
lib/torch/nn/functional.rb

Class Method Summary collapse

Class Method Details

.avg_pool2d(input, kernel_size) ⇒ Object



27
28
29
30
# File 'lib/torch/nn/functional.rb', line 27

# Forwards to Torch.avg_pool2d, expanding a bare Integer kernel size into a
# two-element array first.
#
# @param input value handed unchanged to Torch.avg_pool2d
#   (presumably a tensor -- confirm against callers)
# @param kernel_size [Integer, Object] an Integer is duplicated into
#   [k, k]; any other value is passed through untouched
# @return whatever Torch.avg_pool2d returns
def avg_pool2d(input, kernel_size)
  window = kernel_size.is_a?(Integer) ? [kernel_size, kernel_size] : kernel_size
  Torch.avg_pool2d(input, window)
end

.conv2d(input, weight, bias, stride: 1, padding: 0) ⇒ Object



9
10
11
12
# File 'lib/torch/nn/functional.rb', line 9

# Thin wrapper that hands its arguments to Torch.conv2d positionally.
#
# @param input first positional argument for Torch.conv2d
# @param weight second positional argument for Torch.conv2d
# @param bias third positional argument for Torch.conv2d
# @param stride forwarded as the fourth positional argument (default 1)
# @param padding forwarded as the fifth positional argument (default 0)
# @return whatever Torch.conv2d returns
def conv2d(input, weight, bias, stride: 1, padding: 0)
  # TODO: pair stride and padding when needed (values are forwarded as given)
  convolved = Torch.conv2d(input, weight, bias, stride, padding)
  convolved
end

.cross_entropy(input, target) ⇒ Object



40
41
42
# File 'lib/torch/nn/functional.rb', line 40

# Composes log_softmax (along dimension 1) with nll_loss.
#
# @param input value passed to log_softmax
# @param target value passed to nll_loss alongside the log-softmax result
# @return whatever nll_loss returns
def cross_entropy(input, target)
  log_probs = log_softmax(input, 1)
  nll_loss(log_probs, target)
end

.leaky_relu(input, negative_slope = 0.01) ⇒ Object



18
19
20
# File 'lib/torch/nn/functional.rb', line 18

# Delegates to Torch.leaky_relu.
#
# @param input value forwarded as the first argument
# @param negative_slope forwarded as the second argument; defaults to 0.01
# @return whatever Torch.leaky_relu returns
def leaky_relu(input, negative_slope = 0.01)
  activated = Torch.leaky_relu(input, negative_slope)
  activated
end

.linear(input, weight, bias) ⇒ Object



32
33
34
# File 'lib/torch/nn/functional.rb', line 32

# Delegates to Torch.linear with the same three arguments, in order.
#
# @param input first argument for Torch.linear
# @param weight second argument for Torch.linear
# @param bias third argument for Torch.linear
# @return whatever Torch.linear returns
def linear(input, weight, bias)
  projected = Torch.linear(input, weight, bias)
  projected
end

.log_softmax(input, dim) ⇒ Object



49
50
51
# File 'lib/torch/nn/functional.rb', line 49

# Calls +log_softmax+ on the input object itself, passing the dimension.
#
# @param input any object responding to #log_softmax
# @param dim argument handed to input.log_softmax
# @return whatever input.log_softmax(dim) returns
def log_softmax(input, dim)
  normalized = input.log_softmax(dim)
  normalized
end

.max_pool2d(input, kernel_size) ⇒ Object



22
23
24
25
# File 'lib/torch/nn/functional.rb', line 22

# Forwards to Torch.max_pool2d, expanding a bare Integer kernel size into a
# two-element array first.
#
# @param input value handed unchanged to Torch.max_pool2d
#   (presumably a tensor -- confirm against callers)
# @param kernel_size [Integer, Object] an Integer becomes [k, k]; anything
#   else is forwarded as given
# @return whatever Torch.max_pool2d returns
def max_pool2d(input, kernel_size)
  pooling_window =
    if kernel_size.is_a?(Integer)
      [kernel_size, kernel_size]
    else
      kernel_size
    end
  Torch.max_pool2d(input, pooling_window)
end

.mse_loss(input, target, reduction: "mean") ⇒ Object



36
37
38
# File 'lib/torch/nn/functional.rb', line 36

# Delegates to Torch.mse_loss, passing the reduction as a third positional
# argument.
#
# @param input first argument for Torch.mse_loss
# @param target second argument for Torch.mse_loss
# @param reduction forwarded as given; defaults to the string "mean"
# @return whatever Torch.mse_loss returns
def mse_loss(input, target, reduction: "mean")
  loss = Torch.mse_loss(input, target, reduction)
  loss
end

.nll_loss(input, target) ⇒ Object



44
45
46
47
# File 'lib/torch/nn/functional.rb', line 44

# Delegates to Torch.nll_loss with the same two arguments.
#
# @param input first argument for Torch.nll_loss
# @param target second argument for Torch.nll_loss
# @return whatever Torch.nll_loss returns
def nll_loss(input, target)
  # TODO: fix for non-1d (still open; arguments are forwarded unmodified)
  loss = Torch.nll_loss(input, target)
  loss
end

.prelu(input, weight) ⇒ Object



14
15
16
# File 'lib/torch/nn/functional.rb', line 14

# Delegates to Torch.prelu with the same two arguments.
#
# @param input first argument for Torch.prelu
# @param weight second argument for Torch.prelu
# @return whatever Torch.prelu returns
def prelu(input, weight)
  activated = Torch.prelu(input, weight)
  activated
end

.relu(input) ⇒ Object



5
6
7
# File 'lib/torch/nn/functional.rb', line 5

# Delegates to Torch.relu.
#
# @param input the single argument forwarded to Torch.relu
# @return whatever Torch.relu returns
def relu(input)
  activated = Torch.relu(input)
  activated
end