Method: TensorStream::MathGradients._extract_input_shapes
- Defined in: lib/tensor_stream/math_gradients.rb
._extract_input_shapes(inputs) ⇒ Object
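```ruby
# File 'lib/tensor_stream/math_gradients.rb', line 248

def self._extract_input_shapes(inputs)
  sizes = []
  fully_known = true
  inputs.each do |x|
    input_shape = ts.shape(x)
    # Stop at the first input whose shape is not statically known.
    unless input_shape.is_const
      fully_known = false
      break
    end
    # The shape is constant, so its concrete value can be collected directly.
    sizes << input_shape.value
  end

  if fully_known
    # Every input shape is statically known: return the concrete values.
    sizes
  else
    # At least one shape is dynamic: fall back to computing all shapes via shape_n.
    ts.shape_n(inputs)
  end
end
```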
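The "all static, or fall back for everything" rule above can be illustrated without TensorStream. The sketch below is a minimal, self-contained approximation: `ShapeInfo` and `extract_input_shapes` are hypothetical names, and the block passed to the method stands in for the `ts.shape_n(inputs)` fallback; none of them are part of the library's API.

```ruby
# Hypothetical stand-in for a shape tensor: `is_const` says whether the
# shape is known statically, `value` holds the concrete dimensions.
ShapeInfo = Struct.new(:is_const, :value)

# Hypothetical helper mirroring _extract_input_shapes: returns the concrete
# shapes when every one is static; otherwise hands the whole list to the
# block (a stand-in for a runtime op such as ts.shape_n).
def extract_input_shapes(shapes)
  sizes = []
  shapes.each do |shape|
    # One unknown shape is enough to abandon the static path entirely.
    return yield(shapes) unless shape.is_const

    sizes << shape.value
  end
  sizes
end

known = [ShapeInfo.new(true, [2, 3]), ShapeInfo.new(true, [3])]
p extract_input_shapes(known) { :dynamic }

partial = [ShapeInfo.new(true, [2, 3]), ShapeInfo.new(false, nil)]
p extract_input_shapes(partial) { |s| "shape_n over #{s.size} inputs" }
```

The first call prints `[[2, 3], [3]]` and the second prints `"shape_n over 2 inputs"`. As in the original method, the decision is all-or-nothing: a partially known list is never returned, so callers always receive either plain Ruby arrays or a single dynamic result.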