Class: SparseTensor

Inherits:
Hash
  • Object
show all
Defined in:
lib/graphkit.rb

Overview

A simple sparse tensor

Defined Under Namespace

Classes: RankError

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from Hash

#modify

Constructor Details

#initialize(rank = 2) ⇒ SparseTensor

Create a new tensor.



55
56
57
58
59
# File 'lib/graphkit.rb', line 55

# Create a new, empty tensor.
#
# @param rank [Integer] number of indices that address one element
#   (defaults to 2)
def initialize(rank = 2)
  super()
  # @shape holds one extent per dimension; every extent starts at zero.
  @rank, @shape = rank, [0] * rank
end

Instance Attribute Details

#rank ⇒ Object (readonly)

Returns the value of attribute rank.



50
51
52
# File 'lib/graphkit.rb', line 50

# The number of indices that address one element (set by the constructor).
def rank; @rank; end

#shape ⇒ Object (readonly)

Returns the value of attribute shape.



50
51
52
# File 'lib/graphkit.rb', line 50

# Per-dimension extent of the tensor. It starts as all zeros and, on each
# assignment through #[]=, grows to one past the largest index stored.
# Returns the internal array itself, not a copy.
def shape; @shape; end

Class Method Details

.diagonal(rank, array) ⇒ Object

Create a new diagonal tensor from an array. E.g. if rank is 2, then tensor[0, 0] = array[0], tensor[1, 1] = array[1], etc.



66
67
68
69
70
71
72
# File 'lib/graphkit.rb', line 66

# Create a new diagonal tensor from an array. With rank 2 this sets
# tensor[0, 0] = array[0], tensor[1, 1] = array[1], and so on.
#
# @param rank [Integer] rank of the new tensor
# @param array values for the diagonal. Only #size and #[] are used, so
#   any indexable collection works.
# @return a new tensor (of the receiving class) holding those values
def self.diagonal(rank, array)
  tensor = new(rank)
  array.size.times do |i|
    # [i] * rank is the diagonal index (i, i, ..., i).
    tensor[[i] * rank] = array[i]
  end
  tensor
end

.from_hash(hash) ⇒ Object



161
162
163
164
165
# File 'lib/graphkit.rb', line 161

# Build a tensor from a plain Hash that maps index arrays to values. This is
# the form #inspect prints, e.g. SparseTensor.from_hash({[0, 1] => 3}).
#
# @param hash [Hash] index arrays => values
# @param rank [Integer, nil] rank of the new tensor. When nil, it is taken
#   from the first key, or the constructor default is used if +hash+ is empty.
# @return a new tensor containing every entry of +hash+
# Entries are stored through #[]=, which rejects keys whose length differs
# from the rank.
def self.from_hash(hash, rank = nil)
  # An empty hash has no key to infer the rank from. This case must still
  # work, because #inspect of an empty tensor prints from_hash({}).
  rank ||= hash.each_key.first.size unless hash.empty?
  tensor = rank ? new(rank) : new
  hash.each { |index, value| tensor[index] = value }
  tensor
end

Instance Method Details

#+(other) ⇒ Object



135
136
137
# File 'lib/graphkit.rb', line 135

# Element-wise sum of two tensors of the same class and rank. Delegates to
# #scalar_binary, which copies entries present in only one operand as-is.
def +(other)
  scalar_binary(other, &:+)
end

#-(other) ⇒ Object



138
139
140
# File 'lib/graphkit.rb', line 138

# Element-wise difference of two tensors of the same class and rank.
# Missing entries count as zero, so an entry present only in +other+
# becomes its negation in the result (0 - b).
def -(other)
  difference = scalar_binary(other) { |a, b| a - b }
  # scalar_binary copies entries that exist only in +other+ unchanged.
  # That is right for +, but here it would give +b instead of -b.
  other.each { |key, value| difference[key] = -value unless key?(key) }
  difference
end

#[](*args) ⇒ Object

Access an element of the tensor. E.g. for a rank 2 tensor

a = tensor[1, 3]   # equivalently: a = tensor[[1, 3]]

Raises: (RankError) if the number of indices differs from the tensor's rank.



78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# File 'lib/graphkit.rb', line 78

# Access an element of the tensor. For a rank 2 tensor either form works:
#
#   a = tensor[1, 3]
#   a = tensor[[1, 3]]
#
# Returns nil for an element that has never been set.
#
# @raise [RankError] if the number of indices differs from the rank
def [](*args)
  # A single Array argument is an index list (tensor[[1, 3]]). Check its
  # class rather than its size, for two reasons. For rank 1, the list form and
  # the splatted form have the same length. And Integer#size is a byte width
  # (e.g. 8), not an index count.
  args = args.first if args.size == 1 && args.first.is_a?(Array)
  raise RankError, "Rank is #@rank, not #{args.size}" unless args.size == @rank
  # Hash#key? is a direct hash lookup; no key array is built.
  return nil unless key?(args)

  super(args)
end

#[]=(*args) ⇒ Object

Set an element of the tensor. E.g. for a rank 2 tensor

tensor[1, 3] = a_variable   # equivalently: tensor[[1, 3]] = a_variable

Raises: (RankError) if the number of indices differs from the tensor's rank.



99
100
101
102
103
104
105
106
107
# File 'lib/graphkit.rb', line 99

# Set an element of the tensor. For a rank 2 tensor either form works:
#
#   tensor[1, 3] = a_variable
#   tensor[[1, 3]] = a_variable
#
# Also grows #shape so that each dimension covers the new index.
#
# @raise [RankError] if the number of indices differs from the rank
def []=(*args)
  value = args.pop
  # A single Array argument is the index list. Check its class, not its
  # size: Integer#size is a byte width (e.g. 8), not an index count. The
  # list is duplicated because Hash does not copy Array keys, so a caller
  # later mutating its array would corrupt the stored key.
  args = args.first.dup if args.size == 1 && args.first.is_a?(Array)
  raise RankError, "Rank is #@rank, not #{args.size}" unless args.size == @rank

  args.each_with_index do |index, dim|
    @shape[dim] = [@shape[dim], index + 1].max
  end
  super(args, value)
end

#alter!(&block) ⇒ Object



156
157
158
159
160
# File 'lib/graphkit.rb', line 156

# Replace every element in place with the block's result:
#
#   tensor.alter! { |value| value * 2 }
#
# @yieldparam value the current element
# @yieldreturn the new element to store under the same index
# @return [Array] the keys that were visited
def alter!(&block)
  keys.each do |key|
    # Read with Hash#fetch. The key is known to exist, so the overridden
    # #[] (rank check, index unpacking, existence test) is unnecessary here.
    self[key] = yield(fetch(key))
  end
end

#inspect ⇒ Object



166
167
168
# File 'lib/graphkit.rb', line 166

# Render the tensor as a SparseTensor.from_hash(...) expression that wraps
# the superclass's inspect output.
def inspect
  hash_repr = super
  format('SparseTensor.from_hash(%s)', hash_repr)
end

#max(&block) ⇒ Object

Find the maximum element of the tensor. See Enumerable#max.



144
145
146
# File 'lib/graphkit.rb', line 144

# Largest element of the tensor. Elements are compared with <=>, or with the
# optional block. Returns nil for an empty tensor (same semantics as Array#max).
def max(&block)
  values.max(&block)
end

#min(&block) ⇒ Object

Find the minimum element of the tensor. See Enumerable#min.



150
151
152
# File 'lib/graphkit.rb', line 150

# Smallest element of the tensor. Elements are compared with <=>, or with the
# optional block. Returns nil for an empty tensor (same semantics as Array#min).
def min(&block)
  values.min(&block)
end

#scalar_binary(other, &block) ⇒ Object

Raises:

  • (ArgumentError)


120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# File 'lib/graphkit.rb', line 120

# Combine this tensor element-wise with +other+ using the block, and return
# a new tensor of the same class and rank. The block is called only where
# both operands have a value. Every other key (present in only one operand,
# or whose value in +other+ is nil/false) is copied into the result unchanged.
#
# @param other a tensor of exactly the same class and rank
# @yieldparam mine the element of self
# @yieldparam theirs the element of other
# @return a new tensor holding the combined elements
# @raise [ArgumentError] if +other+ is not of the same class
# @raise [RankError] if the ranks differ
def scalar_binary(other, &block)
  unless other.class == self.class
    raise ArgumentError, "Expected a #{self.class}, got #{other.class}"
  end
  raise RankError, "Different ranks: #@rank, #{other.rank}" unless other.rank == @rank

  result = self.class.new(@rank)
  # Walk the raw Hash storage instead of calling the overridden #[] per key.
  # The keys are known to exist, so its rank check and index unpacking are
  # redundant, and Hash#fetch is a plain O(1) lookup.
  each do |key, mine|
    theirs = other.fetch(key, nil)
    # A nil/false value in +other+ is treated like a missing entry.
    result[key] = theirs ? yield(mine, theirs) : mine
  end
  other.each { |key, theirs| result[key] = theirs unless key?(key) }
  result
end