Class: SparseTensor

Inherits:
Hash
  • Object
show all
Defined in:
lib/graphkit.rb

Overview

A simple sparse tensor

Defined Under Namespace

Classes: RankError

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from Hash

#modify

Constructor Details

#initialize(rank = 2) ⇒ SparseTensor

Create a new tensor.



55
56
57
58
59
# File 'lib/graphkit.rb', line 55

# Build an empty tensor of the given rank (default 2). The shape starts
# with one zero per dimension; #[]= grows it as elements are assigned.
def initialize(rank = 2)
	@rank, @shape = rank, [0] * rank
	super()
end

Instance Attribute Details

#rank ⇒ Object (readonly)

Returns the value of attribute rank.



50
51
52
# File 'lib/graphkit.rb', line 50

# Reader for the tensor's rank, i.e. how many indices address one element.
def rank; @rank; end

#shape ⇒ Object (readonly)

Returns the value of attribute shape.



50
51
52
# File 'lib/graphkit.rb', line 50

# Reader for the tensor's shape: per dimension, one more than the largest
# index assigned so far (zeros for a fresh tensor).
def shape; @shape; end

Class Method Details

.diagonal(rank, array) ⇒ Object

Create a new diagonal tensor from an array. E.g. if rank was 2, then tensor[0, 0] = array[0], tensor[1, 1] = array[1], etc.



66
67
68
69
70
71
72
# File 'lib/graphkit.rb', line 66

# Build a tensor of the given rank whose only entries lie on the main
# diagonal: element i of +array+ is stored at index [i, i, ..., i].
def self.diagonal(rank, array)
	result = new(rank)
	array.size.times do |position|
		result[[position] * rank] = array[position]
	end
	result
end

.from_hash(hash) ⇒ Object



161
162
163
164
165
# File 'lib/graphkit.rb', line 161

# Build a tensor from a plain Hash mapping index arrays to values, e.g.
#   SparseTensor.from_hash({[0, 1] => 2.0, [3, 4] => 5.0})
# The rank is taken from the length of the first key unless given
# explicitly. An empty hash (what #inspect prints for an empty tensor) has
# no key to infer it from, so it falls back to rank 2, the constructor's
# default, instead of raising NoMethodError.
#
# @param hash [Hash] index-array => value pairs
# @param rank [Integer] rank of the new tensor (inferred when omitted)
# @return [SparseTensor]
def self.from_hash(hash, rank = hash.empty? ? 2 : hash.each_key.first.size)
	tensor = new(rank)
	hash.each { |index, value| tensor[index] = value }
	tensor
end

Instance Method Details

#+(other) ⇒ Object



135
136
137
# File 'lib/graphkit.rb', line 135

# Element-wise sum with another tensor of the same rank; #scalar_binary
# decides what happens to indices stored in only one of the operands.
def +(other)
	scalar_binary(other) do |left, right|
		left + right
	end
end

#-(other) ⇒ Object



138
139
140
# File 'lib/graphkit.rb', line 138

# Element-wise difference (self - other) with another tensor of the same
# rank. An index missing from a tensor is treated as an absent (zero)
# entry, so an index stored only in +other+ yields its negated value.
#
# @raise [ArgumentError, RankError] as raised by #scalar_binary
def -(other)
	difference = scalar_binary(other) { |mine, theirs| mine - theirs }
	# scalar_binary copies entries found only in +other+ unchanged, which is
	# right for + but wrong here: 0 - b is -b, not b.
	other.each do |key, theirs|
		next if key?(key) || !theirs
		difference[key] = -theirs
	end
	difference
end

#[](*args) ⇒ Object

Access an element of the tensor. E.g. for a rank 2 tensor

a = tensor[1, 3]

Raises:

  • (RankError) if the number of indices does not match the tensor's rank.



78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# File 'lib/graphkit.rb', line 78

# Access an element of the tensor. The index may be passed as separate
# arguments or as a single array, e.g. for a rank 2 tensor
#   a = tensor[1, 3]    # same as tensor[[1, 3]]
# Returns nil for an index that has not been assigned.
#
# @raise [RankError] if the number of indices differs from the rank
def [](*args)
	# Unwrap a single Array index. Testing is_a?(Array) rather than excluding
	# args.size == @rank lets rank-1 tensors be read with a one-element array
	# (the form of their stored keys, as #alter! and #scalar_binary pass them),
	# and stops Integer#size (a byte count such as 8) passing for an index length.
	args = args.first if args.size == 1 && args.first.is_a?(Array) && args.first.size == @rank
	raise RankError, "Rank is #{@rank}, not #{args.size}" unless args.size == @rank
	# key? is a hash lookup; keys.include? would scan every key on each read.
	return nil unless key?(args)
	super(args)
end

#[]=(*args) ⇒ Object

Set an element of the tensor. E.g. for a rank 2 tensor

tensor[1, 3] = a_variable

Raises:

  • (RankError) if the number of indices does not match the tensor's rank.



99
100
101
102
103
104
105
106
107
# File 'lib/graphkit.rb', line 99

# Set an element of the tensor. The index may be passed as separate
# arguments or as a single array, e.g. for a rank 2 tensor
#   tensor[1, 3] = value    # same as tensor[[1, 3]] = value
# Each index also grows #shape so that it covers the new element.
#
# @raise [RankError] if the number of indices differs from the rank
def []=(*args)
	value = args.pop
	# Only a genuine Array is unwrapped: Integer#size is a byte count, so a
	# bare integer index must not be mistaken for a full index array.
	if args.size == 1 && args.first.is_a?(Array) && args.first.size == @rank
		# Store a copy: Hash does not copy Array keys, so a caller that later
		# mutates its index array would otherwise corrupt this entry's key.
		args = args.first.dup
	end
	raise RankError, "Rank is #{@rank}, not #{args.size}" unless args.size == @rank
	args.each_with_index do |arg, dim|
		@shape[dim] = [@shape[dim], arg + 1].max
	end
	super(args, value)
end

#alter!(&block) ⇒ Object



156
157
158
159
160
# File 'lib/graphkit.rb', line 156

# Replace every stored element, in place, with the block's result for it.
# Returns the array of keys that was iterated.
def alter!(&block)
	keys.each { |index| self[index] = yield(self[index]) }
end

#inspect ⇒ Object



166
167
168
# File 'lib/graphkit.rb', line 166

# Render the tensor as a "SparseTensor.from_hash(...)" expression wrapping
# the inherited inspect output of its elements.
def inspect
	hash_source = super
	"SparseTensor.from_hash(#{hash_source})"
end

#max(&block) ⇒ Object

Find the maximum element of the tensor. See Enumerable#max.



144
145
146
# File 'lib/graphkit.rb', line 144

# Find the maximum element of the tensor. Arguments and block are forwarded
# to Enumerable#max over the stored values, so max(n) returns the n largest
# elements exactly as it would for any Enumerable.
#
# @return the largest value (nil when empty), or an Array when n is given
def max(*args, &block)
	values.max(*args, &block)
end

#min(&block) ⇒ Object

Find the minimum element of the tensor. See Enumerable#min.



150
151
152
# File 'lib/graphkit.rb', line 150

# Find the minimum element of the tensor. Arguments and block are forwarded
# to Enumerable#min over the stored values, so min(n) returns the n smallest
# elements exactly as it would for any Enumerable.
#
# @return the smallest value (nil when empty), or an Array when n is given
def min(*args, &block)
	values.min(*args, &block)
end

#scalar_binary(other, &block) ⇒ Object

Raises:

  • (ArgumentError)


120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# File 'lib/graphkit.rb', line 120

# Combine this tensor element-wise with +other+ into a new tensor. Where
# +other+ holds a truthy value at an index stored here, the block's result
# for (mine, theirs) is stored; otherwise this tensor's value is kept.
# Indices stored only in +other+ take its value.
#
# @param other [SparseTensor] a tensor of the same class and rank
# @return [SparseTensor] a new tensor; neither operand is modified
# @raise [ArgumentError] if +other+ is not of the same class
# @raise [RankError] if the ranks differ
def scalar_binary(other, &block)
	raise ArgumentError, "expected a #{self.class}, got #{other.class}" unless other.class == self.class
	raise RankError, "Different ranks: #{@rank}, #{other.rank}" unless other.rank == @rank

	result = self.class.new(@rank)
	# Walk the stored pairs and use plain Hash lookups (fetch/key?) instead of
	# the index-unwrapping #[]: every stored key is already a full index array,
	# this works for rank-1 keys, and the merge stays linear in the number of
	# stored elements rather than rescanning the key list per element.
	each do |key, mine|
		theirs = other.fetch(key, nil)
		result[key] = theirs ? yield(mine, theirs) : mine
	end
	other.each { |key, theirs| result[key] = theirs unless key?(key) }
	result
end