Class: TensorStream::Constant

Inherits:
Tensor
  • Object
show all
Defined in:
lib/tensor_stream/constant.rb

Overview

Class that defines a TensorStream constant

Instance Attribute Summary

Attributes inherited from Tensor

#data_type, #given_name, #graph, #internal, #is_const, #name, #native_buffer, #op, #outputs, #rank, #shape, #source, #value

Instance Method Summary collapse

Methods inherited from Tensor

#auto_math, #breakpoint!, cast_dtype, #collect, #consumers, detect_type, #device, #dtype, #eval, #first, #internal?, #print!, reset_counters, #to_a, #to_f, #to_h, #to_i, #to_math, #to_s

Methods included from OpHelper

#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal

Methods included from TensorMixins

#!=, #%, #*, #**, #+, #-, #-@, #/, #<, #<=, #==, #>, #>=, #[], #and, #cast, #ceil, #dot, #floor, #log, #matmul, #reduce, #reshape, #round, #var, #zero?

Constructor Details

#initialize(data_type, rank, shape, options = {}) ⇒ Constant

Returns a new instance of Constant.



4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# File 'lib/tensor_stream/constant.rb', line 4

# Builds a constant tensor and registers a :const op for it on the default graph.
#
# @param data_type [Object] element type, passed to Tensor.cast_dtype and to the
#   registered op (NOTE(review): presumably a dtype symbol such as :float32 — confirm)
# @param rank [Object] rank used to build the initial TensorShape
# @param shape [Array] declared dimensions; used to reshape options[:value]
# @param options [Hash] construction options. Keys read directly here:
#   :value    - literal value of the constant (Array or scalar)
#   :name     - explicit name; build_name is used when absent
#   :internal - copied into @internal
#   setup_initial_state(options) may read further keys (not visible here).
#   NOTE(review): when :value is a flat Array and shape has >= 2 dims,
#   options[:value] is overwritten with the reshaped array, so the caller's
#   hash is mutated (and @options, which aliases it, sees the change).
def initialize(data_type, rank, shape, options = {})
  setup_initial_state(options)
  @data_type = data_type
  @rank = rank
  @breakpoint = false
  # Provisional shape from the declared shape/rank; replaced below when a value is given.
  @shape = TensorShape.new(shape, rank)
  @value = nil
  @options = options
  @is_const = true
  @internal = options[:internal]
  # Prefix the name with the current name scope of @graph; nil and empty parts are dropped.
  @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join("/")
  # Remember the requested name; @name is replaced by the registered op's name at the end.
  @given_name = @name

  # NOTE(review): truthiness check — a literal `false` (or nil) value skips this
  # branch and leaves @value nil; confirm this is intended for boolean constants.
  if options[:value]
    if options[:value].is_a?(Array)
      # check if a single-dimension (flat) array was passed together with a
      # multi-dimensional shape; if so, reshape it. Note the shape is reversed
      # before _reshape (presumably _reshape consumes dims from the end — confirm).
      options[:value] = _reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)

      # Only Tensor elements go through Tensor.cast_dtype; other elements are kept as-is.
      # NOTE(review): this looks like it may be inverted (raw values are not cast) — confirm.
      @value = options[:value].map { |v| v.is_a?(Tensor) ? Tensor.cast_dtype(v, @data_type) : v }
    elsif !shape.empty?
      # Scalar value with a non-empty declared shape: cast it, then _reshape it to
      # that shape (presumably fills/broadcasts the scalar — confirm).
      @value = _reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
    else
      # Scalar constant with no declared shape: just cast to data_type.
      @value = Tensor.cast_dtype(options[:value], @data_type)
    end
    # Re-derive the shape from the actual value, overriding the provisional one.
    @shape = TensorShape.new(shape_eval(@value))
  end

  # NOTE(review): the name scope above came from @graph (set by
  # setup_initial_state), but the op is added to Graph.get_default_graph —
  # presumably these are the same graph; confirm.
  @op = Graph.get_default_graph.add_op!(:const, value: @value, data_type: @data_type, internal_name: @name, shape: @shape)
  # Adopt the op's final name (presumably the graph may adjust it, e.g. to keep it unique).
  @name = @op.name
end

Instance Method Details

#inspect ⇒ Object



35
36
37
# File 'lib/tensor_stream/constant.rb', line 35

# Human-readable summary of this constant: its value followed by
# name, shape and data type.
def inspect
  details = "name: #{@name}, shape: #{@shape}, data_type: #{@data_type}"
  "Constant(#{@value}, #{details})"
end