Class: OnnxRuntime::InferenceSession
- Inherits:
-
Object
- Object
- OnnxRuntime::InferenceSession
- Defined in:
- lib/onnxruntime/inference_session.rb
Instance Attribute Summary collapse
-
#inputs ⇒ Object
readonly
Returns the value of attribute inputs.
-
#outputs ⇒ Object
readonly
Returns the value of attribute outputs.
Instance Method Summary collapse
-
#end_profiling ⇒ Object
The return value contains a double underscore, matching the Python API.
-
#initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, free_dimension_overrides_by_denotation: nil, free_dimension_overrides_by_name: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil, profile_file_prefix: nil, session_config_entries: nil, providers: []) ⇒ InferenceSession
constructor
A new instance of InferenceSession.
- #modelmeta ⇒ Object
-
#providers ⇒ Object
The C API does not yet provide a way to set providers, so this returns all available providers.
- #run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby) ⇒ Object
-
#run_with_ort_values(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) ⇒ Object
TODO support logid.
Constructor Details
#initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, free_dimension_overrides_by_denotation: nil, free_dimension_overrides_by_name: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil, profile_file_prefix: nil, session_config_entries: nil, providers: []) ⇒ InferenceSession
Returns a new instance of InferenceSession.
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# File 'lib/onnxruntime/inference_session.rb', line 5
#
# Builds the session options from the keyword arguments, appends the requested
# execution providers, loads the model, and caches its input/output descriptions.
#
# NOTE(review): the rendered listing had lost the names of two locals (the
# session-options pointer and the CUDA provider-options pointer). They are
# restored here as `session_options` and `cuda_options`; confirm against the
# original source file.
#
# @param path_or_bytes passed straight to load_session
# @param enable_cpu_mem_arena [Boolean] selects EnableCpuMemArena / DisableCpuMemArena
# @param enable_mem_pattern [Boolean] selects EnableMemPattern / DisableMemPattern
# @param enable_profiling [Boolean] selects EnableProfiling / DisableProfiling
# @param execution_mode [Symbol, nil] :sequential or :parallel
# @param free_dimension_overrides_by_denotation [Hash, nil] each pair is sent to AddFreeDimensionOverride
# @param free_dimension_overrides_by_name [Hash, nil] each pair is sent to AddFreeDimensionOverrideByName
# @param graph_optimization_level [Symbol, nil] :none, :basic, :extended or :all
# @param inter_op_num_threads [Integer, nil] forwarded to SetInterOpNumThreads when given
# @param intra_op_num_threads [Integer, nil] forwarded to SetIntraOpNumThreads when given
# @param log_severity_level [Integer, nil] forwarded to SetSessionLogSeverityLevel when given
# @param log_verbosity_level [Integer, nil] forwarded to SetSessionLogVerbosityLevel when given
# @param logid [String, nil] forwarded to SetSessionLogId when given
# @param optimized_model_filepath [String, nil] forwarded to SetOptimizedModelFilePath when given
# @param profile_file_prefix [String, nil] profiling prefix; "onnxruntime_profile_" when nil
# @param session_config_entries [Hash, nil] each pair is stringified and sent to AddSessionConfigEntry
# @param providers [Array<String>] execution provider names, processed in order
# @raise [ArgumentError] for an unknown execution mode, an unknown graph
#   optimization level, or a provider that is available but not handled here
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, free_dimension_overrides_by_denotation: nil, free_dimension_overrides_by_name: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil, profile_file_prefix: nil, session_config_entries: nil, providers: [])
  # session options
  session_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateSessionOptions].call(session_options)
  # Re-wrap so ReleaseSessionOptions is invoked when the pointer is collected.
  session_options = ::FFI::AutoPointer.new(session_options.read_pointer, api[:ReleaseSessionOptions])

  if enable_cpu_mem_arena
    check_status api[:EnableCpuMemArena].call(session_options)
  else
    check_status api[:DisableCpuMemArena].call(session_options)
  end

  if enable_mem_pattern
    check_status api[:EnableMemPattern].call(session_options)
  else
    check_status api[:DisableMemPattern].call(session_options)
  end

  if enable_profiling
    check_status api[:EnableProfiling].call(session_options, ort_string(profile_file_prefix || "onnxruntime_profile_"))
  else
    check_status api[:DisableProfiling].call(session_options)
  end

  if execution_mode
    execution_modes = {sequential: 0, parallel: 1}
    mode = execution_modes[execution_mode]
    raise ArgumentError, "Invalid execution mode" unless mode
    check_status api[:SetSessionExecutionMode].call(session_options, mode)
  end

  if free_dimension_overrides_by_denotation
    free_dimension_overrides_by_denotation.each do |k, v|
      check_status api[:AddFreeDimensionOverride].call(session_options, k.to_s, v)
    end
  end

  if free_dimension_overrides_by_name
    free_dimension_overrides_by_name.each do |k, v|
      check_status api[:AddFreeDimensionOverrideByName].call(session_options, k.to_s, v)
    end
  end

  if graph_optimization_level
    optimization_levels = {none: 0, basic: 1, extended: 2, all: 99}
    level = optimization_levels[graph_optimization_level]
    raise ArgumentError, "Invalid graph optimization level" unless level
    check_status api[:SetSessionGraphOptimizationLevel].call(session_options, level)
  end

  check_status api[:SetInterOpNumThreads].call(session_options, inter_op_num_threads) if inter_op_num_threads
  check_status api[:SetIntraOpNumThreads].call(session_options, intra_op_num_threads) if intra_op_num_threads
  check_status api[:SetSessionLogSeverityLevel].call(session_options, log_severity_level) if log_severity_level
  check_status api[:SetSessionLogVerbosityLevel].call(session_options, log_verbosity_level) if log_verbosity_level
  check_status api[:SetSessionLogId].call(session_options, logid) if logid
  check_status api[:SetOptimizedModelFilePath].call(session_options, ort_string(optimized_model_filepath)) if optimized_model_filepath

  if session_config_entries
    session_config_entries.each do |k, v|
      check_status api[:AddSessionConfigEntry].call(session_options, k.to_s, v.to_s)
    end
  end

  providers.each do |provider|
    # `self.providers` is the instance method; the bare name is the keyword argument.
    unless self.providers.include?(provider)
      warn "Provider not available: #{provider}"
      next
    end

    case provider
    when "CUDAExecutionProvider"
      cuda_options = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateCUDAProviderOptions].call(cuda_options)
      cuda_options = ::FFI::AutoPointer.new(cuda_options.read_pointer, api[:ReleaseCUDAProviderOptions])
      check_status api[:SessionOptionsAppendExecutionProvider_CUDA_V2].call(session_options, cuda_options)
    when "CoreMLExecutionProvider"
      unless FFI.respond_to?(:OrtSessionOptionsAppendExecutionProvider_CoreML)
        raise ArgumentError, "Provider not available: #{provider}"
      end

      coreml_flags = 0
      check_status FFI.OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags)
    when "CPUExecutionProvider"
      # Stops processing: any providers listed after the CPU provider are ignored.
      break
    else
      raise ArgumentError, "Provider not supported: #{provider}"
    end
  end

  @session = load_session(path_or_bytes, session_options)
  @allocator = Utils.allocator
  @inputs = load_inputs
  @outputs = load_outputs
end
Instance Attribute Details
#inputs ⇒ Object (readonly)
Returns the value of attribute inputs.
3 4 5 |
# File 'lib/onnxruntime/inference_session.rb', line 3
#
# Read-only accessor for the @inputs instance variable.
#
# @return [Object] the current value of @inputs
def inputs
  @inputs
end
#outputs ⇒ Object (readonly)
Returns the value of attribute outputs.
3 4 5 |
# File 'lib/onnxruntime/inference_session.rb', line 3
#
# Read-only accessor for the @outputs instance variable.
#
# @return [Object] the current value of @outputs
def outputs
  @outputs
end
Instance Method Details
#end_profiling ⇒ Object
The return value contains a double underscore, matching the Python API.
191 192 193 194 195 196 197 198 199 |
# File 'lib/onnxruntime/inference_session.rb', line 191
#
# Calls SessionEndProfiling for this session and returns the string it produces.
# The native buffer holding that string is handed back via allocator_free even
# when reading the string raises.
#
# @return [String] the string written by SessionEndProfiling
def end_profiling
  out = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:SessionEndProfiling].call(@session, @allocator, out)

  buffer = out.read_pointer
  begin
    buffer.read_string
  ensure
    allocator_free buffer
  end
end
#modelmeta ⇒ Object
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
# File 'lib/onnxruntime/inference_session.rb', line 132
#
# Collects the model metadata of this session into a Hash.
#
# NOTE(review): the rendered listing had lost the method name and the names of
# the metadata pointer and the custom-metadata hash. They are restored here as
# `modelmeta`, `metadata` and `custom_metadata_map`; confirm against the
# original source file.
#
# The cleanup in `ensure` only touches values that were actually obtained, so a
# failure in an early API call surfaces as that failure rather than as a
# NoMethodError raised from the cleanup itself.
#
# @return [Hash] with keys :custom_metadata_map, :description, :domain,
#   :graph_name, :graph_description, :producer_name and :version
def modelmeta
  metadata = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:SessionGetModelMetadata].call(@session, metadata)
  metadata = ::FFI::AutoPointer.new(metadata.read_pointer, api[:ReleaseModelMetadata])

  # Keep the out-parameter separate from the array it receives, so `keys` is
  # only ever the allocator-owned array (or nil) when the ensure block runs.
  keys_ptr = ::FFI::MemoryPointer.new(:pointer)
  num_keys = ::FFI::MemoryPointer.new(:int64_t)
  check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata, @allocator, keys_ptr, num_keys)
  keys = keys_ptr.read_pointer

  custom_metadata_map = {}
  num_keys.read(:int64_t).times do |i|
    key_ptr = keys.get_pointer(i * ::FFI::Pointer.size)
    key = key_ptr.read_string
    value = ::FFI::MemoryPointer.new(:pointer)
    check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata, @allocator, key, value)
    custom_metadata_map[key] = value.read_pointer.read_string
    allocator_free key_ptr
    allocator_free value.read_pointer
  end

  description = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetDescription].call(metadata, @allocator, description)

  domain = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetDomain].call(metadata, @allocator, domain)

  graph_name = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetGraphName].call(metadata, @allocator, graph_name)

  graph_description = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetGraphDescription].call(metadata, @allocator, graph_description)

  producer_name = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetProducerName].call(metadata, @allocator, producer_name)

  version = ::FFI::MemoryPointer.new(:int64_t)
  check_status api[:ModelMetadataGetVersion].call(metadata, version)

  {
    custom_metadata_map: custom_metadata_map,
    description: description.read_pointer.read_string,
    domain: domain.read_pointer.read_string,
    graph_name: graph_name.read_pointer.read_string,
    graph_description: graph_description.read_pointer.read_string,
    producer_name: producer_name.read_pointer.read_string,
    version: version.read(:int64_t)
  }
ensure
  # Locals assigned after the point of failure are still nil here.
  allocator_free keys if keys
  allocator_free description.read_pointer if description
  allocator_free domain.read_pointer if domain
  allocator_free graph_name.read_pointer if graph_name
  allocator_free graph_description.read_pointer if graph_description
  allocator_free producer_name.read_pointer if producer_name
end
#providers ⇒ Object
The C API does not yet provide a way to set providers, so this returns all available providers.
203 204 205 206 207 208 209 210 211 |
# File 'lib/onnxruntime/inference_session.rb', line 203
#
# Asks GetAvailableProviders for the provider list, copies the names into Ruby
# strings, then gives the native array back through ReleaseAvailableProviders.
#
# @return [Array<String>] the names reported by GetAvailableProviders
def providers
  out_ptr = ::FFI::MemoryPointer.new(:pointer)
  length_ptr = ::FFI::MemoryPointer.new(:int)
  check_status api[:GetAvailableProviders].call(out_ptr, length_ptr)

  length = length_ptr.read_int
  # Copy the strings out before the native array is released.
  names = out_ptr.read_pointer.read_array_of_pointer(length).map(&:read_string)
  api[:ReleaseAvailableProviders].call(out_ptr.read_pointer, length)
  names
end
#run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby) ⇒ Object
91 92 93 94 95 96 97 98 99 100 101 |
# File 'lib/onnxruntime/inference_session.rb', line 91
#
# Converts the input feed with create_input_tensor, runs the session through
# run_with_ort_values, and converts every result according to output_type.
#
# @param output_names forwarded to run_with_ort_values
# @param input_feed [Hash] its keys are paired, in order, with the values
#   returned by create_input_tensor
# @param log_severity_level forwarded to run_with_ort_values
# @param log_verbosity_level forwarded to run_with_ort_values
# @param logid forwarded to run_with_ort_values
# @param terminate forwarded to run_with_ort_values
# @param output_type [Symbol] :ruby (calls #to_ruby on each result), :numo
#   (calls #numo) or :ort_value (returns the results untouched)
# @return [Array] one converted entry per result
# @raise [ArgumentError] when output_type is not one of the three symbols above
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
  unless [:ruby, :numo, :ort_value].include?(output_type)
    raise ArgumentError, "Invalid output type: #{output_type}"
  end

  ort_values = input_feed.keys.zip(create_input_tensor(input_feed)).to_h

  results = run_with_ort_values(output_names, ort_values, log_severity_level: log_severity_level, log_verbosity_level: log_verbosity_level, logid: logid, terminate: terminate)

  results.map do |value|
    case output_type
    when :numo then value.numo
    when :ort_value then value
    else value.to_ruby
    end
  end
end
#run_with_ort_values(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) ⇒ Object
TODO support logid
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# File 'lib/onnxruntime/inference_session.rb', line 104
#
# Runs the session on inputs that already respond to #to_ptr and wraps every
# produced output pointer in an OrtValue.
#
# NOTE(review): the rendered listing had lost the name of the run-options
# local; it is restored here as `run_options`. Confirm against the original
# source file.
#
# The output buffer holds one slot per requested output name, because that is
# the count passed to Run and the number of slots read back afterwards.
#
# @param output_names [Array, nil] names to fetch; when nil, the :name of every
#   entry in @outputs is used
# @param input_feed [Hash] name => object responding to #to_ptr
# @param log_severity_level forwarded to RunOptionsSetRunLogSeverityLevel when given
# @param log_verbosity_level forwarded to RunOptionsSetRunLogVerbosityLevel when given
# @param logid forwarded to RunOptionsSetRunTag when given
# @param terminate when truthy, RunOptionsSetTerminate is called before Run
# @return [Array<OrtValue>] one entry per requested output name
def run_with_ort_values(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
  input_feed.each_with_index do |(_, input), i|
    input_tensor[i].write_pointer(input.to_ptr)
  end

  output_names ||= @outputs.map { |v| v[:name] }

  output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)
  # `refs` is handed to create_node_names, presumably to keep the name buffers
  # alive for the duration of the call - TODO confirm.
  refs = []
  input_node_names = create_node_names(input_feed.keys.map(&:to_s), refs)
  output_node_names = create_node_names(output_names.map(&:to_s), refs)

  # run options
  run_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateRunOptions].call(run_options)
  run_options = ::FFI::AutoPointer.new(run_options.read_pointer, api[:ReleaseRunOptions])
  check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options, log_severity_level) if log_severity_level
  check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options, log_verbosity_level) if log_verbosity_level
  check_status api[:RunOptionsSetRunTag].call(run_options, logid) if logid
  check_status api[:RunOptionsSetTerminate].call(run_options) if terminate

  check_status api[:Run].call(@session, run_options, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

  output_names.size.times.map { |i| OrtValue.new(output_tensor[i].read_pointer) }
end