Class: Einsum
- Inherits:
-
Object
- Object
- Einsum
- Defined in:
- lib/einsum.rb,
lib/einsum/version.rb
Overview
typed: strong frozen_string_literal: true
Defined Under Namespace
Classes: Label
Constant Summary collapse
- FormatError =
Class.new(StandardError)
- VERSION =
'0.1.4'
Class Method Summary collapse
-
.einsum(format, *operands) ⇒ Object
Evaluates the (extended) Einstein summation convention on the operands.
Class Method Details
.einsum(format, *operands) ⇒ Object
Evaluates the (extended) Einstein summation convention on the operands.
Operands must be Array-like, and their elements must respond to * and +.
Examples:
Implicit mode:
Einsum.einsum('ij,jk', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => dot product: [[7, 10], [15, 22]]
Einsum.einsum('ij,kj', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => inner product: [[ 5, 11], [11, 25]]
Explicit mode:
Einsum.einsum('ij,jk->ik', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => dot product: [[7, 10], [15, 22]]
Einsum.einsum('ij,kj->ik', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => inner product: [[ 5, 11], [11, 25]]
Einsum.einsum('ij,jk->', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => 54
Einsum.einsum('ij,kj->', [[1, 2], [3, 4]], [[1, 2], [3, 4]]) # => 52
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# File 'lib/einsum.rb', line 44

# Evaluates the (extended) Einstein summation convention on the operands.
#
# The format string is a comma-separated list of input label strings (one
# per operand, one lowercase letter per axis), optionally followed by `->`
# and an output label string (explicit mode). Without `->` (implicit mode)
# the output labels are the labels that occur exactly once across all
# inputs, in sorted order.
#
# @param format [String] the summation specification, e.g. 'ij,jk->ik'
# @param operands [Array] one Array-like operand per input label string;
#   elements must respond to * and +
# @return [Object] the accumulated result: the scalar accumulator (starting
#   at 0) when there are no output labels, otherwise the structure produced
#   by `empty(shape, 0)` with every cell accumulated in place
# @raise [FormatError] if the format string is malformed, the number of
#   operands does not match the number of input label strings, an operand
#   lacks an axis for a label, a label is used with two different
#   dimensions, an output label is missing from the inputs, or an output
#   label is repeated
def einsum(format, *operands)
  # check syntax of format string
  unless format.match?(/\A([a-z]+(,[a-z]+)*)(->[a-z]*)?\z/)
    raise FormatError, "invalid format: #{format}"
  end

  # chop up format string
  inputs, explicit, output = format.partition('->')
  inputs = inputs.split(',')

  if operands.length != inputs.length
    raise FormatError, "provides #{operands.length} operands for #{inputs.length} input labels strings"
  end

  # check labels and operands: every label must map to an existing axis and
  # must have the same dimension everywhere it is used
  labels = {}
  inputs.zip(operands).each.with_index do |(input, operand), pos|
    input.chars.each.with_index do |label, axis|
      unless (size = dim(operand, axis))
        raise FormatError, "no axis in operand #{pos} corresponds to label #{label}"
      end

      if labels[label] && labels[label].dimension != size
        raise FormatError, "inconsistent dimension for label #{label}: #{labels[label].dimension} and #{size}"
      end

      labels[label] ||= Label.new(size)
      labels[label].increment
    end
  end

  # if implicit mode, generate output labels string from all
  # labels mentioned only once in the input labels strings.
  # when no label occurs exactly once, output stays '' (scalar result).
  if explicit.empty? && (groups = labels.group_by { |_, l| l.count }[1])
    output = groups.map(&:first).sort.join
  end

  external = output.chars

  # compute shape of the result
  shape = external.map do |label|
    unless labels[label]
      raise FormatError, "output label #{label} not present in input labels"
    end

    labels[label].dimension
  end

  # a repeated output label would generate two nested loops sharing one
  # block variable, so every result[i][i] cell would be accumulated once
  # per outer iteration; reject it instead of returning a wrong result
  if (repeated = external.find { |label| external.count(label) > 1 })
    raise FormatError, "output label #{repeated} appears more than once"
  end

  # generate template for result
  result = 0
  result = empty(shape, result) unless shape.empty?

  # generate code for the specified operations. first, loop over
  # each output axis in the order specified by the output labels.
  # then, loop over the remaining input axes and compute the
  # result for each cell in the output matrix.
  internal = inputs.join.chars.sort.uniq - external
  loop_labels = external + internal

  code = loop_labels.map do |label|
    "#{labels[label].dimension}.times do |#{label}|"
  end

  external_labels = external.map { |l| "[#{l}]" }.join
  code.push("result#{external_labels} +=")

  inputs.each.with_index do |input, i|
    input_labels = input.chars.map { |l| "[#{l}]" }.join
    # a trailing '*' continues the product onto the next generated line
    suffix = i < inputs.length - 1 ? ' *' : ''
    code.push("operands[#{i}]#{input_labels}#{suffix}")
  end

  # close every loop opened above
  code.concat(Array.new(loop_labels.length, 'end'))

  # evaluate the generated code in the current context. this would
  # be considered dangerous, except we are in control of generated
  # code except loop variable names, which are derived from input
  # and output labels, which are constrained to be individual,
  # lowercase characters, which are bound in their respective
  # loops. the generated code reads the locals `operands` and
  # `result` through the binding, so those names must not change.
  # rubocop:disable Security/Eval
  binding.eval(code.join("\n"))
  # rubocop:enable Security/Eval

  result
end