Class: Torch::Distributions::Normal

Inherits:
ExponentialFamily
Defined in:
lib/torch/distributions/normal.rb

Instance Method Summary

Constructor Details

#initialize(loc, scale, validate_args: nil) ⇒ Normal

Returns a new instance of Normal.



Source (lines 4–12):
# File 'lib/torch/distributions/normal.rb', line 4

# Builds a normal distribution parameterized by +loc+ and +scale+.
#
# Both parameters are passed through Utils.broadcast_all and stored as
# @loc and @scale. When both raw arguments are plain Numerics, the batch
# shape is empty. Otherwise it is taken from the broadcast location
# (@loc.size).
#
# @param loc [Numeric, Object] location (mean) of the distribution;
#   presumably a Torch::Tensor when not Numeric — TODO confirm accepted types
# @param scale [Numeric, Object] scale (standard deviation); no positivity
#   check is performed in this method
# @param validate_args [Boolean, nil] forwarded unchanged to the superclass
def initialize(loc, scale, validate_args: nil)
  @loc, @scale = Utils.broadcast_all(loc, scale)
  # Inspect the raw arguments, not the broadcast results: two plain numbers
  # describe a single scalar distribution, so there are no batch dimensions.
  scalar_inputs = [loc, scale].all? { |param| param.is_a?(Numeric) }
  batch_shape = scalar_inputs ? [] : @loc.size
  super(batch_shape:, validate_args:)
end

Instance Method Details

#sample(sample_shape: []) ⇒ Object



Source (lines 14–19):
# File 'lib/torch/distributions/normal.rb', line 14

# Draws samples from the distribution.
#
# The output shape comes from _extended_shape(sample_shape:), which is
# defined elsewhere (presumably sample_shape followed by the batch/event
# shape — TODO confirm). Both @loc and @scale are expanded to that shape
# before calling Torch.normal.
#
# @param sample_shape [Array] leading sample dimensions (default: [])
# @return [Object] the drawn values, produced inside Torch.no_grad
def sample(sample_shape: [])
  target_shape = _extended_shape(sample_shape:)
  # Run inside no_grad so that drawing samples records no autograd history.
  Torch.no_grad do
    mean = @loc.expand(target_shape)
    std = @scale.expand(target_shape)
    Torch.normal(mean, std)
  end
end