Class: Kumi::Core::Analyzer::Passes::SNASTPass
- Defined in:
- lib/kumi/core/analyzer/passes/snast_pass.rb
Overview
Semantic NAST Pass (SNAST)
-
Rewrites intrinsic control and reductions into first-class nodes.
-
Attaches a semantic stamp to every node: meta[:stamp] = { axes:, dtype: } (frozen).
-
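Reads types/scopes for the incoming NAST nodes from side tables (:metadata_table, :declaration_table, :input_table) rather than from node meta.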
Reduction rule (default sugar):
If not explicitly annotated, a reduction over arguments reduces the LAST axis:
a = lub_by_prefix(arg_axes_list)
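over = [a.last]; out_axes = a[0...-1]
(This default is implemented by #reduce_last_axis. #visit_call itself takes out_axes from the reducer's metadata :result_scope, which must be a prefix of the argument axes, and reduces over the remaining trailing axes.)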
Inputs (state):
:nast_module => Kumi::Core::NAST::Module (topologically ordered)
:metadata_table => Hash[node_key => { result_scope:, result_type:, function?:, arg_scopes?: ... }] (const, index-ref, hash and pair nodes are read via scope:/type: instead)
:declaration_table => Hash[name => { result_scope:, result_type: }]
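:input_table => [{path_fqn:, axes:, dtype:}] or Hash[path_fqn => { axes:, dtype: }]
:index_table => required by #run
:registry => function registry (function_select?, function_reduce?, resolve_function)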
Output (state):
:snast_module => Kumi::Core::NAST::Module (with NAST::Select / NAST::Fold / NAST::Reduce nodes)
TODO: If downstream never keys by node ids, consider removing dependence on node.id.
TODO: Use Error helpers with provenance.
Instance Method Summary
- #axes_of(n) ⇒ Object
- #dtype_of(n) ⇒ Object
- #lookup_input(fqn) ⇒ Object
-
#lub_by_prefix(list) ⇒ Object
Least upper bound by prefix.
- #meta_for(node) ⇒ Object
- #node_key(n) ⇒ Object
- #prefix?(pre, full) ⇒ Boolean
-
#reduce_last_axis(args_axes_list) ⇒ Object
Default reduce sugar: over last axis of the LUB of argument axes.
- #run(errors) ⇒ Object
-
#stamp!(node, axes, dtype) ⇒ Object
———- Helpers ———-.
- #visit_call(n) ⇒ Object
-
#visit_const(n) ⇒ Object
———- Leaves ———-.
- #visit_declaration(d) ⇒ Object
- #visit_hash(n) ⇒ Object
- #visit_import_call(n) ⇒ Object
- #visit_index_ref(n) ⇒ Object
- #visit_input_ref(n) ⇒ Object
-
#visit_module(mod) ⇒ Object
———- Visitor entry points ———-.
- #visit_pair(n) ⇒ Object
- #visit_ref(n) ⇒ Object
- #visit_tuple(n) ⇒ Object
Methods inherited from PassBase
#debug, #debug_enabled?, #initialize
Methods included from ErrorReporting
#inferred_location, #raise_localized_error, #raise_syntax_error, #raise_type_error, #report_enhanced_error, #report_error, #report_semantic_error, #report_syntax_error, #report_type_error
Constructor Details
This class inherits a constructor from Kumi::Core::Analyzer::Passes::PassBase
Instance Method Details
#axes_of(n) ⇒ Object
206 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 206
# Axes recorded on an already-visited (stamped) SNAST node.
#
# @param n node whose meta may carry a :stamp entry
# @return [Array] the stamped axes, or [] when the node has no stamp
def axes_of(n)
  Array(n.meta.dig(:stamp, :axes))
end
#dtype_of(n) ⇒ Object
207 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 207
# Dtype recorded on an already-visited (stamped) SNAST node.
#
# @param n node whose meta may carry a :stamp entry
# @return the stamped dtype, or nil when the node has no stamp
def dtype_of(n)
  n.meta.dig(:stamp, :dtype)
end
#lookup_input(fqn) ⇒ Object
233 234 235 236 237 238 239 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 233
# Finds the input-table entry for a fully-qualified input path.
#
# @input_table may be an Array of entries ({ path_fqn:, axes:, dtype: }) or a
# Hash keyed by path_fqn. The Hash form must be detected FIRST: Hash includes
# Enumerable and therefore also responds to #find, which yields [key, value]
# pairs. Probing with respond_to?(:find) sent Hash tables down the Array
# branch, where pair[:path_fqn] raised TypeError.
#
# @param fqn fully-qualified input path
# @return [Hash] the matching entry (read by callers via :axes and :dtype)
# @raise [RuntimeError] when no entry exists for +fqn+
def lookup_input(fqn)
  if @input_table.respond_to?(:each_pair)
    @input_table.fetch(fqn) { raise("Input not found for #{fqn}") }
  else
    @input_table.find { |entry| entry[:path_fqn] == fqn } || raise("Input not found for #{fqn}")
  end
end
#lub_by_prefix(list) ⇒ Object
Least upper bound by prefix. All entries must be a prefix of the longest.
210 211 212 213 214 215 216 217 218 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 210
# Least upper bound of a list of axis lists under the prefix order.
#
# The longest entry is the candidate; every entry must be a prefix of it.
#
# @param list [Array<Array>] axis lists
# @return [Array] the longest entry, or [] when +list+ is empty
# @raise [Kumi::Core::Errors::SemanticError] when some entry is not a prefix of the candidate
def lub_by_prefix(list)
  return [] if list.empty?

  longest = list.max_by(&:length)
  offender = list.find { |axes| !prefix?(axes, longest) }
  if offender
    raise Kumi::Core::Errors::SemanticError, "prefix mismatch: #{offender.inspect} vs #{longest.inspect}"
  end

  longest
end
#meta_for(node) ⇒ Object
205 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 205
# Metadata-table entry for a NAST node, looked up by #node_key.
#
# @param node NAST node (needs #id for the key)
# @return [Hash] the entry from the :metadata_table state
# @raise [KeyError] when the node has no entry
def meta_for(node)
  key = node_key(node)
  @metadata_table.fetch(key)
end
#node_key(n) ⇒ Object
241 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 241
# Side-table key for a node: "<ClassName>_<id>".
#
# @param n node responding to #id
# @return [String]
def node_key(n)
  format("%s_%s", n.class, n.id)
end
#prefix?(pre, full) ⇒ Boolean
220 221 222 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 220
# True when +pre+ is a (not necessarily proper) prefix of +full+.
#
# The explicit length check matters. Without it, a +pre+ longer than +full+
# compared its overflow tokens against nil (Array#[] past the end), so e.g.
# prefix?([:a, nil], [:a]) answered true.
#
# @param pre [Array] candidate prefix
# @param full [Array] sequence to test against
# @return [Boolean]
def prefix?(pre, full)
  pre.length <= full.length && full.first(pre.length) == pre
end
#reduce_last_axis(args_axes_list) ⇒ Object
Default reduce sugar: over last axis of the LUB of argument axes. Returns { over:, out_axes: }.
226 227 228 229 230 231 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 226
# Default reduce sugar: reduce over the last axis of the prefix-LUB of the
# argument axes.
#
# @param args_axes_list [Array<Array>] axes of each argument
# @return [Hash] { over: [last_axis], out_axes: remaining_leading_axes }
# @raise [Kumi::Core::Errors::SemanticError] when the LUB is scalar (no axes),
#   or when the argument axes are not prefix-compatible
def reduce_last_axis(args_axes_list)
  axes = lub_by_prefix(args_axes_list)
  raise Kumi::Core::Errors::SemanticError, "cannot reduce scalar" if axes.empty?

  *kept, reduced = axes
  { over: [reduced], out_axes: kept }
end
#run(errors) ⇒ Object
29 30 31 32 33 34 35 36 37 38 39 40 41 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 29
# Pass entry point: rebuilds the NAST module as a semantically stamped SNAST module.
#
# Loads the required state keys, visits the module via #accept (see
# #visit_module) and returns a new state with :snast_module set to the
# (shallowly) frozen result.
#
# @param errors error collector supplied by the caller; stored in @errors
# @return the updated state
def run(errors)
  @nast_module       = get_state(:nast_module, required: true)
  @metadata_table    = get_state(:metadata_table, required: true)
  @declaration_table = get_state(:declaration_table, required: true)
  @input_table       = get_state(:input_table, required: true)
  @index_table       = get_state(:index_table, required: true)
  @registry          = get_state(:registry, required: true)
  @errors            = errors

  debug "Building SNAST from #{@nast_module.decls.size} declarations"

  state.with(:snast_module, @nast_module.accept(self).freeze)
end
#stamp!(node, axes, dtype) ⇒ Object
———- Helpers ———-
200 201 202 203 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 200
# Writes a frozen semantic stamp { axes:, dtype: } into node.meta[:stamp].
#
# @param node node whose meta hash is mutated
# @param axes axes; coerced with Array() (nil => [], scalar => [scalar])
# @param dtype semantic dtype
# @return the same +node+
def stamp!(node, axes, dtype)
  node.tap do |target|
    target.meta[:stamp] = { axes: Array(axes), dtype: dtype }.freeze
  end
end
#visit_call(n) ⇒ Object
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 104
# Visits a Call node. Intrinsic calls are rewritten into first-class SNAST nodes:
#
# * select(cond, on_true, on_false)      => NAST::Select (#snast_select_call)
# * reducer over a collection-typed arg  => NAST::Fold   (#snast_fold_call)
# * reducer over a vectorized arg        => NAST::Reduce (#snast_reduce_call)
# * anything else                        => elementwise call stamped from metadata
#
# @param n NAST Call node
# @return the stamped SNAST node
# @raise [Kumi::Core::Errors::SemanticError] on axis mismatches
# @raise [RuntimeError] when a reducer receives more than one argument
def visit_call(n)
  return snast_select_call(n) if @registry.function_select?(n.fn)
  return snast_reducer_call(n) if @registry.function_reduce?(n.fn)

  # Regular elementwise call.
  args = n.args.map { _1.accept(self) }
  meta = meta_for(n)
  # Use the function ID from metadata (already resolved with type awareness in NASTDimensionalAnalyzerPass)
  fn_id = meta[:function] || @registry.resolve_function(n.fn)
  out = n.class.new(id: n.id, fn: fn_id.to_sym, args:, opts: n.opts, loc: n.loc)
  stamp!(out, meta[:result_scope], meta[:result_type])
end

# Internal helper for #visit_call: builds a NAST::Select.
# Branch axes are joined by prefix-LUB, falling back to the mask's axes when
# both branches are scalar. The mask axes must prefix the result. The dtype
# is taken from the on_true branch.
def snast_select_call(n)
  c = n.args[0].accept(self)
  t = n.args[1].accept(self)
  f = n.args[2].accept(self)

  target_axes = lub_by_prefix([axes_of(t), axes_of(f)])
  target_axes = axes_of(c) if target_axes.empty?
  unless prefix?(axes_of(c), target_axes)
    raise Kumi::Core::Errors::SemanticError,
          "select mask axes #{axes_of(c).inspect} must prefix #{target_axes.inspect} at: #{n.loc}"
  end

  out = NAST::Select.new(id: n.id, cond: c, on_true: t, on_false: f, loc: n.loc, meta: n.meta.dup)
  stamp!(out, target_axes, dtype_of(t))
end

# Internal helper for #visit_call: routes a single-argument reducer to a Fold
# (collection-typed argument) or a Reduce (vectorized argument).
def snast_reducer_call(n)
  # TODO: -> sugar to collapse variadics?
  raise "Reducers should only have one arg" if n.args.size != 1

  visited_arg = n.args.first.accept(self)
  arg_type = dtype_of(visited_arg)
  if Kumi::Core::Types.collection?(arg_type)
    snast_fold_call(n, visited_arg)
  else
    snast_reduce_call(n, visited_arg, arg_type)
  end
end

# Internal helper: a reducer whose argument is semantically a tuple becomes a
# Fold. The fold is element-wise over the container of tuples, so axes and
# dtype come straight from the call's metadata entry.
def snast_fold_call(n, visited_arg)
  fold_node = NAST::Fold.new(
    id: n.id,
    fn: @registry.resolve_function(n.fn),
    arg: visited_arg, # the tuple / reference to the tuple
    loc: n.loc,
    meta: n.meta.dup
  )
  meta = meta_for(n)
  stamp!(fold_node, meta[:result_scope], meta[:result_type])
end

# Internal helper: a reducer over a vectorized argument becomes a Reduce over
# the trailing argument axes that are not kept in the metadata :result_scope.
def snast_reduce_call(n, visited_arg, arg_type)
  in_axes = axes_of(visited_arg)
  if in_axes.empty?
    raise Kumi::Core::Errors::SemanticError, "reduce function called on a non-collection scalar: #{arg_type}"
  end

  meta = meta_for(n)
  out_axes = Array(meta[:result_scope])
  raise Kumi::Core::Errors::SemanticError, "reduce: out axes must prefix arg axes" unless prefix?(out_axes, in_axes)

  reduce_node = NAST::Reduce.new(
    id: n.id,
    fn: @registry.resolve_function(n.fn),
    over: in_axes.drop(out_axes.length),
    arg: visited_arg,
    loc: n.loc,
    meta: n.meta.dup
  )
  stamp!(reduce_node, out_axes, meta[:result_type])
end
#visit_const(n) ⇒ Object
———- Leaves ———-
59 60 61 62 63 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 59
# Leaf: rebuilds a Const node, stamped as scalar (no axes) with the
# metadata entry's :type.
def visit_const(n)
  dtype = meta_for(n)[:type]
  stamp!(n.class.new(id: n.id, value: n.value, loc: n.loc), [], dtype)
end
#visit_declaration(d) ⇒ Object
50 51 52 53 54 55 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 50
# Rebuilds a declaration around its visited body. The node is stamped from the
# declaration table's :result_scope / :result_type, and the declaration kind
# is carried in meta.
#
# @raise [KeyError] when the declaration is missing from the table (checked
#   before the body is visited)
def visit_declaration(d)
  entry = @declaration_table.fetch(d.name)
  out = d.class.new(id: d.id, name: d.name, body: d.body.accept(self), loc: d.loc, meta: { kind: d.kind })
  stamp!(out, entry[:result_scope], entry[:result_type])
end
#visit_hash(n) ⇒ Object
90 91 92 93 94 95 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 90
# Rebuilds a Hash node from its visited pairs.
# NOTE(review): reads :scope/:type from the metadata entry, while Call/Ref/
# Tuple read :result_scope/:result_type. Confirm the producer uses these keys
# for hash nodes.
def visit_hash(n)
  visited_pairs = n.pairs.map { |pair| pair.accept(self) }
  entry = meta_for(n)
  stamp!(n.class.new(id: n.id, pairs: visited_pairs, loc: n.loc), entry[:scope], entry[:type])
end
#visit_import_call(n) ⇒ Object
183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 183
# Rebuilds an imported-function call from its visited arguments. Import
# details and a copy of the original meta are carried over; the node is
# stamped from the metadata entry's :result_scope / :result_type.
def visit_import_call(n)
  visited_args = n.args.map { |arg| arg.accept(self) }
  entry = meta_for(n)
  out = n.class.new(
    id: n.id, fn_name: n.fn_name, args: visited_args,
    input_mapping_keys: n.input_mapping_keys, source_module: n.source_module,
    loc: n.loc, meta: n.meta.dup
  )
  stamp!(out, entry[:result_scope], entry[:result_type])
end
#visit_index_ref(n) ⇒ Object
71 72 73 74 75 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 71
# Leaf: rebuilds an IndexRef, stamped from the metadata entry's :scope / :type.
def visit_index_ref(n)
  entry = meta_for(n)
  rebuilt = n.class.new(id: n.id, name: n.name, input_fqn: n.input_fqn, loc: n.loc)
  stamp!(rebuilt, entry[:scope], entry[:type])
end
#visit_input_ref(n) ⇒ Object
65 66 67 68 69 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 65
# Leaf: rebuilds an InputRef. Axes and dtype come from the input table entry
# for the node's path_fqn (see #lookup_input).
def visit_input_ref(n)
  entry = lookup_input(n.path_fqn)
  axes = entry[:axes]
  dtype = entry[:dtype]
  stamp!(n.class.new(id: n.id, path: n.path, loc: n.loc), axes, dtype)
end
#visit_module(mod) ⇒ Object
———- Visitor entry points ———-
45 46 47 48 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 45
# Visitor entry point: returns a new module of the same class whose decls map
# each name to its visited declaration (insertion order preserved).
# decls is expected to be a Hash[name => Declaration].
def visit_module(mod)
  visited = mod.decls.to_h { |name, decl| [name, decl.accept(self)] }
  mod.class.new(decls: visited)
end
#visit_pair(n) ⇒ Object
97 98 99 100 101 102 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 97
# Rebuilds a hash Pair around its visited value (the key is kept as-is),
# stamped from the metadata entry's :scope / :type.
# NOTE(review): unlike sibling visitors, loc: is not forwarded to the new node.
# Confirm whether Pair nodes accept/need it.
def visit_pair(n)
  visited_value = n.value.accept(self)
  entry = meta_for(n)
  stamp!(n.class.new(id: n.id, key: n.key, value: visited_value), entry[:scope], entry[:type])
end
#visit_ref(n) ⇒ Object
77 78 79 80 81 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 77
# Leaf: rebuilds a declaration Ref, stamped from the metadata entry's
# :result_scope / :result_type.
def visit_ref(n)
  entry = meta_for(n)
  rebuilt = n.class.new(id: n.id, name: n.name, loc: n.loc)
  stamp!(rebuilt, entry[:result_scope], entry[:result_type])
end
#visit_tuple(n) ⇒ Object
83 84 85 86 87 88 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 83
# Rebuilds a Tuple from its visited elements, stamped from the metadata
# entry's :result_scope / :result_type.
def visit_tuple(n)
  visited_args = n.args.map { |arg| arg.accept(self) }
  entry = meta_for(n)
  stamp!(n.class.new(id: n.id, args: visited_args, loc: n.loc), entry[:result_scope], entry[:result_type])
end