Class: SparseTensor

Inherits:
Hash
  • Object
show all
Defined in:
lib/graphkit.rb

Overview

A simple sparse tensor

Defined Under Namespace

Classes: RankError

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from Hash

#modify

Constructor Details

#initialize(rank = 2) ⇒ SparseTensor

Create a new tensor.



53
54
55
56
57
# File 'lib/graphkit.rb', line 53

# Build an empty tensor with +rank+ indices per element.
#
# The shape starts as all zeros, one entry per dimension, and grows as
# elements are assigned.
#
# @param rank [Integer] number of indices per element (default 2)
def initialize(rank = 2)
	super()
	@shape = [0] * rank
	@rank = rank
end

Instance Attribute Details

#rank ⇒ Object (readonly)

Returns the value of attribute rank.



49
50
51
# File 'lib/graphkit.rb', line 49

# Number of indices needed to address one element (read-only).
def rank; @rank; end

#shape ⇒ Object (readonly)

Returns the value of attribute shape.



49
50
51
# File 'lib/graphkit.rb', line 49

# Per-dimension extent: one more than the largest index assigned so far (read-only).
def shape; @shape; end

Class Method Details

.diagonal(rank, array) ⇒ Object

Create a new diagonal tensor from an array. E.g. if rank is 2, then tensor[0, 0] = array[0], tensor[1, 1] = array[1], etc.



64
65
66
67
68
69
70
# File 'lib/graphkit.rb', line 64

# Build a diagonal tensor from a list of values.
#
# The i-th value is stored at the index made of +rank+ copies of i. For
# example, with rank 2: tensor[0, 0] = array[0], tensor[1, 1] = array[1],
# and so on. Off-diagonal elements are left unset.
#
# @param rank [Integer] rank of the new tensor
# @param array [Enumerable] diagonal values, in order
# @return [SparseTensor] the new tensor
def self.diagonal(rank, array)
	tensor = new(rank)
	# each_with_index replaces the old index-based for loop. It does not leak
	# a loop variable and accepts any Enumerable, not only objects that
	# support both #size and #[].
	array.each_with_index { |value, index| tensor[[index] * rank] = value }
	tensor
end

Instance Method Details

#+(other) ⇒ Object



124
125
126
# File 'lib/graphkit.rb', line 124

# Element-wise sum with another tensor of the same class and rank.
#
# Elements set in only one operand are passed through unchanged by
# scalar_binary.
def +(other)
	scalar_binary(other, &:+)
end

#-(other) ⇒ Object



127
128
129
# File 'lib/graphkit.rb', line 127

# Element-wise difference (self - other) with a tensor of the same class and rank.
#
# An element missing from one operand is treated as zero:
# - keys set only in +self+ keep their value;
# - keys set only in +other+ become the negation of other's value.
#
# @param other [SparseTensor] tensor to subtract
# @return [SparseTensor] a new tensor
# @raise [ArgumentError] if +other+ is not an instance of the same class
# @raise [RankError] if the ranks differ
def -(other)
	scalar_binary(other) { |a, b| a - b }.tap do |difference|
		# scalar_binary copies keys that exist only in +other+ unchanged. That is
		# right for + but wrong for -: 0 - b is -b, not b.
		(other.keys - keys).each { |key| difference[key] = -other.fetch(key) }
	end
end

#[](*args) ⇒ Object

Access an element of the tensor. E.g. for a rank 2 tensor

a = tensor[0, 1]   # or equivalently tensor[[0, 1]]

Raises:



76
77
78
79
80
81
82
# File 'lib/graphkit.rb', line 76

# Read one element of the tensor.
#
# Indices may be given individually or as a single Array. For a rank-2
# tensor, tensor[0, 1] and tensor[[0, 1]] are equivalent. Unset elements
# return the hash default (nil unless one was configured).
#
# @raise [RankError] if the number of indices differs from the rank
def [](*args)
	# Unwrap only a single *Array* of exactly +rank+ indices. The old check
	# looked only at #size. Integer#size is a byte count, so a lone integer
	# index on a tensor whose rank equals that count was mistaken for an
	# index list. Also, for rank 1, tensor[[i]] was never unwrapped, even
	# though []= stores it under [i].
	args = args[0] if args.size == 1 && args[0].is_a?(Array) && args[0].size == @rank
	raise RankError.new("Rank is #@rank, not #{args.size}") unless args.size == @rank
	super(args)
end

#[]=(*args) ⇒ Object

Set an element of the tensor. E.g. for a rank 2 tensor

tensor[0, 1] = a_variable

Raises:



88
89
90
91
92
93
94
95
96
# File 'lib/graphkit.rb', line 88

# Set one element of the tensor.
#
# Indices may be given individually or as a single Array. For a rank-2
# tensor, tensor[0, 1] = v and tensor[[0, 1]] = v are equivalent. The
# tensor's shape grows so each dimension covers the largest index assigned.
#
# @raise [RankError] if the number of indices differs from the rank
def []=(*args)
	value = args.pop
	# Unwrap only a single *Array* of exactly +rank+ indices. Checking #size
	# alone misreads a lone Integer index, because Integer#size is a byte
	# count. That case then crashed below with NoMethodError instead of
	# raising RankError.
	args = args[0] if args.size == 1 && args[0].is_a?(Array) && args[0].size == @rank
	raise RankError.new("Rank is #@rank, not #{args.size}") unless args.size == @rank
	# Grow each dimension to one past the largest index seen on it.
	args.each_with_index do |arg, index|
		@shape[index] = [@shape[index], arg + 1].max
	end
	super(args, value)
end

#max(&block) ⇒ Object

Find the maximum element of the tensor. See Enumerable#max.



133
134
135
# File 'lib/graphkit.rb', line 133

# Find the largest element value in the tensor; keys are ignored.
#
# Accepts the same arguments as Enumerable#max: an optional count +n+ to
# get the n largest values, and an optional comparison block.
#
# @return [Object, Array, nil] the largest value, the n largest values,
#   or nil for an empty tensor
def max(*args, &block)
	values.max(*args, &block)
end

#min(&block) ⇒ Object

Find the minimum element of the tensor. See Enumerable#min.



139
140
141
# File 'lib/graphkit.rb', line 139

# Find the smallest element value in the tensor; keys are ignored.
#
# Accepts the same arguments as Enumerable#min: an optional count +n+ to
# get the n smallest values, and an optional comparison block.
#
# @return [Object, Array, nil] the smallest value, the n smallest values,
#   or nil for an empty tensor
def min(*args, &block)
	values.min(*args, &block)
end

#scalar_binary(other, &block) ⇒ Object

Raises:

  • (ArgumentError)


109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# File 'lib/graphkit.rb', line 109

# Combine this tensor with +other+ element by element.
#
# For every key set in both tensors, the result holds the block's value for
# (self's value, other's value). Keys set in only one tensor are copied into
# the result unchanged.
#
# @param other [SparseTensor] tensor of the same class and rank
# @yieldparam a [Object] element value from self
# @yieldparam b [Object] element value from other
# @return [SparseTensor] a new tensor of the same class and rank
# @raise [ArgumentError] if +other+ is not an instance of the same class
# @raise [RankError] if the ranks differ
def scalar_binary(other, &block)
	raise ArgumentError, "Expected a #{self.class}, got #{other.class}" unless other.class == self.class
	raise RankError.new("Different ranks: #@rank, #{other.rank}") unless other.rank == @rank

	result = self.class.new(@rank)
	# Read stored keys directly through Hash#each / #key? / #fetch, not the
	# overridden #[]. That method's argument unwrapping is meant for index
	# lists, not stored keys (previously a rank-1 key such as [3] was looked
	# up as [[3]] and read as nil). A presence check, rather than a
	# truthiness check, also keeps elements that hold nil or false.
	each do |key, value|
		result[key] = other.key?(key) ? yield(value, other.fetch(key)) : value
	end
	other.each do |key, value|
		result[key] = value unless key?(key)
	end
	result
end