Class: Einsum

Inherits:
Object
  • Object
show all
Defined in:
lib/einsum.rb,
lib/einsum/version.rb

Overview

typed: strong frozen_string_literal: true

Defined Under Namespace

Classes: Label

Constant Summary collapse

FormatError =
Class.new(StandardError)
VERSION =
'0.1.4'

Class Method Summary collapse

Class Method Details

.einsum(format, *operands) ⇒ Object

Evaluates the (extended) Einstein summation convention on the operands.

Operands must be Array-like. Array elements must respond to * and +.

Examples:

Implicit mode:

Einsum.einsum('ij,jk', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => dot product: [[7, 10], [15, 22]]
Einsum.einsum('ij,kj', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => inner product: [[ 5, 11], [11, 25]]

Explicit mode:

Einsum.einsum('ij,jk->ik', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => dot product: [[7, 10], [15, 22]]
Einsum.einsum('ij,kj->ik', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => inner product: [[ 5, 11], [11, 25]]
Einsum.einsum('ij,jk->', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => 54
Einsum.einsum('ij,kj->', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => 52


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# File 'lib/einsum.rb', line 44

# Evaluates the (extended) Einstein summation convention on the operands.
#
# The format string holds one labels string per operand (lowercase letters,
# one per axis), separated by commas, optionally followed by '->' and the
# output labels string. Without '->' (implicit mode) the output labels are
# every label that occurs exactly once across the inputs, sorted
# alphabetically. Labels that are not in the output are summed over.
#
# @param format [String] subscripts, e.g. 'ij,jk->ik'
# @param operands [Array] one operand per input labels string; elements
#   must respond to * and +
# @return [Object] the accumulated value (starting from 0) when the output
#   labels string is empty; otherwise the container built by
#   `empty(shape, 0)`, indexed by the output labels, each cell holding its sum
# @raise [FormatError] if the format string is malformed, the number of
#   operands does not match the number of input labels strings, a label has
#   no corresponding axis, a label is bound to axes of different sizes, an
#   output label is absent from the inputs, or an output label is repeated
def einsum(format, *operands)
  labels = {}

  # check syntax of format string: comma-separated input labels strings,
  # optionally followed by '->' and an (possibly empty) output string
  unless format.match?(/\A([a-z]+(,[a-z]+)*)(->[a-z]*)?\z/)
    raise FormatError, "invalid format: #{format}"
  end

  # chop up format string; `explicit` is '->' in explicit mode, '' otherwise
  inputs, explicit, output = format.partition('->')
  inputs = inputs.split(',')
  if operands.length != inputs.length
    raise FormatError, "provides #{operands.length} operands for #{inputs.length} input labels strings"
  end

  # check labels and operands: every label needs a matching axis (as reported
  # by the dim helper), and a label used more than once must always name axes
  # of the same size. The local is `size` so it does not shadow `dim`.
  inputs.zip(operands).each.with_index do |(input, operand), pos|
    input.each_char.with_index do |label, axis|
      unless (size = dim(operand, axis))
        raise FormatError, "no axis in operand #{pos} corresponds to label #{label}"
      end
      if labels[label] && labels[label].dimension != size
        raise FormatError, "inconsistent dimension for label #{label}: #{labels[label].dimension} and #{size}"
      end

      labels[label] ||= Label.new(size)
      labels[label].increment
    end
  end

  # if implicit mode, generate output labels string from all
  # labels mentioned only once in the input labels strings
  if explicit.empty? && (groups = labels.group_by { |_, l| l.count }[1])
    output = groups.map(&:first).sort.join
  end

  # compute shape of the result, validating the output labels
  external = output.chars
  shape = external.map.with_index do |label, index|
    unless labels[label]
      raise FormatError, "output label #{label} not present in input labels"
    end
    # a repeated output label would generate two nested loops binding the
    # same variable (the inner one shadowing the outer) and silently yield a
    # wrong result, so reject it outright
    unless external.index(label) == index
      raise FormatError, "output label #{label} repeated in output labels"
    end

    labels[label].dimension
  end

  # generate template for result: a scalar accumulator for an empty output,
  # otherwise the container produced by the empty helper for this shape
  result = shape.empty? ? 0 : empty(shape, 0)

  # generate code for the specified operations. first, loop over
  # each output axis in the order specified by the output labels.
  # then, loop over the remaining input axes and accumulate the
  # product of the operands into the corresponding result cell.
  internal = inputs.join.chars.sort.uniq - external
  loops = external + internal

  code = loops.map { |label| "#{labels[label].dimension}.times do |#{label}|" }
  code << "result#{external.map { |l| "[#{l}]" }.join} +="
  factors = inputs.map.with_index do |input, i|
    "operands[#{i}]#{input.chars.map { |l| "[#{l}]" }.join}"
  end
  code << factors.join(" *\n")
  code.concat(Array.new(loops.length, 'end'))

  # evaluate the generated code in the current context. this would
  # be considered dangerous, except we are in control of generated
  # code except loop variable names, which are derived from input
  # and output labels, which are constrained by the format regex to be
  # individual lowercase characters, each bound by its own loop (output
  # labels are unique, so no loop variable shadows another). assignments
  # to `result` in the evaluated code update this method's local.

  # rubocop:disable Security/Eval
  binding.eval(code.join("\n"))
  # rubocop:enable Security/Eval

  result
end