Method: TensorStream::MathGradients._extract_input_shapes

Defined in:
lib/tensor_stream/math_gradients.rb

._extract_input_shapes(inputs) ⇒ Object



248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# File 'lib/tensor_stream/math_gradients.rb', line 248

# Collects the shapes of +inputs+, preferring statically known values.
#
# Each input is passed to +ts.shape+. While every resulting shape op reports
# +is_const+, its +value+ is accumulated. At the first shape that is not
# constant, scanning stops (later inputs are not passed to +ts.shape+) and the
# method instead delegates the whole list to +ts.shape_n+.
#
# NOTE(review): +is_const+ presumably means the shape is known at graph
# construction time — confirm against the Tensor/op implementation.
#
# @param inputs [Enumerable] the tensors whose shapes are needed
# @return [Array, Object] an Array of constant shape values when all shapes are
#   constant (an empty Array for empty +inputs+); otherwise whatever
#   +ts.shape_n(inputs)+ returns
def self._extract_input_shapes(inputs)
  known_shapes = []

  inputs.each do |tensor|
    shape_op = ts.shape(tensor)
    # `return` inside this block exits the method, just like the original `break`
    # followed by the shape_n fallback.
    return ts.shape_n(inputs) unless shape_op.is_const

    known_shapes << shape_op.value
  end

  known_shapes
end