Class: TensorFlow::Keras::Layers::Flatten

Inherits:
Object
Defined in:
lib/tensorflow/keras/layers/flatten.rb

Constructor Details

#initialize(input_shape: nil) ⇒ Flatten

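Returns a new instance of Flatten. The optional `input_shape` is stored in `@input_shape` and is read only by `#output_shape`. It should be the shape of a single sample, excluding the batch dimension (for example `[28, 28]` for 28 x 28 images).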

# File 'lib/tensorflow/keras/layers/flatten.rb', line 5

def initialize(input_shape: nil)
  @input_shape = input_shape
end

Instance Method Details

#call(inputs) ⇒ Object
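
Flattens every dimension after the first (batch) dimension into one, reshaping an input of shape `[batch, d1, d2, ...]` to `[batch, d1 * d2 * ...]` with `TensorFlow.reshape`. The flattened size is computed from `inputs.shape` at call time, so `input_shape` is not used here.
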
# File 'lib/tensorflow/keras/layers/flatten.rb', line 18

def call(inputs)
  flattened_dim = inputs.shape[1..-1].inject(&:*)
  TensorFlow.reshape(inputs, [-1, flattened_dim])
end
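The reshape in `#call` can be pictured with plain Ruby arrays. The sketch below is a hypothetical, TensorFlow-free illustration (the helper `flatten_batch` is not part of the gem): each sample, meaning everything after the first axis, is flattened in row-major order, which is the element order `TensorFlow.reshape` preserves. One difference to keep in mind: for a rank-1 input, `inputs.shape[1..-1]` in `#call` is `[]`, and `[].inject(&:*)` evaluates to `nil`, whereas this sketch turns each scalar sample into a one-element array.

```ruby
# Hypothetical, TensorFlow-free illustration of Flatten#call (not part of the gem).
# Each sample (everything after the first axis) is flattened in row-major order,
# so a batch of shape [batch, d1, d2, ...] becomes [batch, d1 * d2 * ...].
def flatten_batch(batch)
  # Array() wraps a scalar sample, so a rank-1 batch becomes [batch, 1].
  batch.map { |sample| Array(sample).flatten }
end

batch = [
  [[1, 2, 3], [4, 5, 6]],
  [[7, 8, 9], [10, 11, 12]]
] # shape [2, 2, 3]

# Shape [2, 6]: prints [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]
p flatten_batch(batch)

# Shape [3, 1]: prints [[1], [2], [3]]
p flatten_batch([1, 2, 3])
```

The sketch works on nested arrays only to make the element ordering visible; the layer itself never materializes nested Ruby arrays and delegates the work to `TensorFlow.reshape`.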

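#count_params ⇒ Object

Returns the number of weights in the layer. Flatten only reshapes its input and has no weights, so this always returns 0.
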
# File 'lib/tensorflow/keras/layers/flatten.rb', line 14

def count_params
  0
end

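#output_shape ⇒ Object

Returns the output shape `[-1, n]`, where `n` is the product of the dimensions in `input_shape` and `-1` stands for the batch size. This relies on the layer having been constructed with `input_shape:`; with the default `nil`, `@input_shape.inject` raises `NoMethodError`.
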
# File 'lib/tensorflow/keras/layers/flatten.rb', line 9

def output_shape
  flattened_dim = @input_shape.inject(&:*)
  [-1, flattened_dim]
end
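The arithmetic in `#output_shape` can be checked without TensorFlow. The sketch below uses a hypothetical helper, `flatten_output_shape` (not part of the gem), that mirrors the method with two deliberate differences: it starts `inject` from 1, so an empty shape yields `[-1, 1]` rather than `[-1, nil]`, and it raises `ArgumentError` up front when `input_shape` is `nil` instead of failing with `NoMethodError`.

```ruby
# Hypothetical helper mirroring Flatten#output_shape (not part of the gem).
# input_shape is the per-sample shape, excluding the batch dimension.
def flatten_output_shape(input_shape)
  # Unlike the gem, fail with a descriptive error when no shape was given.
  raise ArgumentError, "input_shape is required" if input_shape.nil?

  # Collapse all per-sample dimensions into one; starting from 1 makes an
  # empty shape produce 1 instead of nil. -1 stands in for the batch size.
  [-1, input_shape.inject(1, :*)]
end

p flatten_output_shape([28, 28])     # prints [-1, 784]
p flatten_output_shape([32, 32, 3])  # prints [-1, 3072]
```

Starting the product at 1 is the usual identity for multiplication, which is why it is a safer default than relying on `inject` with no initial value.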