Class: OnnxRuntime::InferenceSession

Inherits:
Object
  • Object
show all
Defined in:
lib/onnxruntime/inference_session.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, free_dimension_overrides_by_denotation: nil, free_dimension_overrides_by_name: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil, profile_file_prefix: nil, session_config_entries: nil, providers: []) ⇒ InferenceSession

Returns a new instance of InferenceSession.



5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# File 'lib/onnxruntime/inference_session.rb', line 5

# Creates the session options through the C API, configures them from the
# keyword arguments, and loads the model with them.
#
# @param path_or_bytes passed unchanged to load_session
# @param enable_cpu_mem_arena selects EnableCpuMemArena or DisableCpuMemArena
# @param enable_mem_pattern selects EnableMemPattern or DisableMemPattern
# @param enable_profiling selects EnableProfiling or DisableProfiling
# @param execution_mode [Symbol, nil] :sequential or :parallel
# @param free_dimension_overrides_by_denotation [Hash, nil] denotation => value
# @param free_dimension_overrides_by_name [Hash, nil] name => value
# @param graph_optimization_level [Symbol, nil] :none, :basic, :extended or :all
# @param profile_file_prefix [String, nil] defaults to "onnxruntime_profile_"
# @param session_config_entries [Hash, nil] entries are stringified before being added
# @param providers [Array<String>] execution providers to append, in order;
#   unavailable ones are skipped with a warning, and iteration stops at
#   "CPUExecutionProvider"
# @raise [ArgumentError] for an unknown execution mode, optimization level or
#   provider, or when the CoreML binding is missing
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, free_dimension_overrides_by_denotation: nil, free_dimension_overrides_by_name: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil, profile_file_prefix: nil, session_config_entries: nil, providers: [])
  # session options
  options_ref = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateSessionOptions].call(options_ref)
  session_options = ::FFI::AutoPointer.new(options_ref.read_pointer, api[:ReleaseSessionOptions])

  # on/off switches: pick the matching C API entry point by name
  arena_toggle = enable_cpu_mem_arena ? :EnableCpuMemArena : :DisableCpuMemArena
  check_status api[arena_toggle].call(session_options)

  pattern_toggle = enable_mem_pattern ? :EnableMemPattern : :DisableMemPattern
  check_status api[pattern_toggle].call(session_options)

  if enable_profiling
    prefix = profile_file_prefix || "onnxruntime_profile_"
    check_status api[:EnableProfiling].call(session_options, ort_string(prefix))
  else
    check_status api[:DisableProfiling].call(session_options)
  end

  if execution_mode
    mode = {sequential: 0, parallel: 1}.fetch(execution_mode) do
      raise ArgumentError, "Invalid execution mode"
    end
    check_status api[:SetSessionExecutionMode].call(session_options, mode)
  end

  (free_dimension_overrides_by_denotation || {}).each do |denotation, value|
    check_status api[:AddFreeDimensionOverride].call(session_options, denotation.to_s, value)
  end

  (free_dimension_overrides_by_name || {}).each do |name, value|
    check_status api[:AddFreeDimensionOverrideByName].call(session_options, name.to_s, value)
  end

  if graph_optimization_level
    level = {none: 0, basic: 1, extended: 2, all: 99}.fetch(graph_optimization_level) do
      raise ArgumentError, "Invalid graph optimization level"
    end
    check_status api[:SetSessionGraphOptimizationLevel].call(session_options, level)
  end

  # plain setters: each one is applied only when a value was supplied
  {
    SetInterOpNumThreads: inter_op_num_threads,
    SetIntraOpNumThreads: intra_op_num_threads,
    SetSessionLogSeverityLevel: log_severity_level,
    SetSessionLogVerbosityLevel: log_verbosity_level,
    SetSessionLogId: logid
  }.each do |setter, value|
    check_status api[setter].call(session_options, value) if value
  end

  if optimized_model_filepath
    check_status api[:SetOptimizedModelFilePath].call(session_options, ort_string(optimized_model_filepath))
  end

  (session_config_entries || {}).each do |entry_key, entry_value|
    check_status api[:AddSessionConfigEntry].call(session_options, entry_key.to_s, entry_value.to_s)
  end

  providers.each do |requested|
    unless self.providers.include?(requested)
      warn "Provider not available: #{requested}"
      next
    end

    case requested
    when "CUDAExecutionProvider"
      cuda_ref = ::FFI::MemoryPointer.new(:pointer)
      check_status api[:CreateCUDAProviderOptions].call(cuda_ref)
      cuda_options = ::FFI::AutoPointer.new(cuda_ref.read_pointer, api[:ReleaseCUDAProviderOptions])
      check_status api[:SessionOptionsAppendExecutionProvider_CUDA_V2].call(session_options, cuda_options)
    when "CoreMLExecutionProvider"
      raise ArgumentError, "Provider not available: #{requested}" unless FFI.respond_to?(:OrtSessionOptionsAppendExecutionProvider_CoreML)

      coreml_flags = 0
      check_status FFI.OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags)
    when "CPUExecutionProvider"
      # nothing to append; providers listed after this one are not processed
      break
    else
      raise ArgumentError, "Provider not supported: #{requested}"
    end
  end

  @session = load_session(path_or_bytes, session_options)
  @allocator = Utils.allocator
  @inputs = load_inputs
  @outputs = load_outputs
end

Instance Attribute Details

#inputs ⇒ Object (readonly)

Returns the value of attribute inputs.



3
4
5
# File 'lib/onnxruntime/inference_session.rb', line 3

# Reader for @inputs, which #initialize assigns from load_inputs.
def inputs
  @inputs
end

#outputs ⇒ Object (readonly)

Returns the value of attribute outputs.



3
4
5
# File 'lib/onnxruntime/inference_session.rb', line 3

# Reader for @outputs, which #initialize assigns from load_outputs.
def outputs
  @outputs
end

Instance Method Details

#end_profiling ⇒ Object

return value has double underscore like Python



191
192
193
194
195
196
197
198
199
# File 'lib/onnxruntime/inference_session.rb', line 191

# Calls SessionEndProfiling and returns the string it hands back
# (presumably the profile file name; confirm against the C API docs).
# The native string is released through allocator_free even if decoding
# it raises.
def end_profiling
  result_ref = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:SessionEndProfiling].call(@session, @allocator, result_ref)

  native_string = result_ref.read_pointer
  begin
    native_string.read_string
  ensure
    allocator_free native_string
  end
end

#modelmeta ⇒ Object



132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# File 'lib/onnxruntime/inference_session.rb', line 132

# Reads the model's metadata through the C API.
#
# @return [Hash] :custom_metadata_map (String => String), :description,
#   :domain, :graph_name, :graph_description, :producer_name (Strings) and
#   :version (Integer)
# @raise whatever check_status raises when a C API call fails; cleanup never
#   replaces that error with one of its own
def modelmeta
  metadata = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:SessionGetModelMetadata].call(@session, metadata)
  metadata = ::FFI::AutoPointer.new(metadata.read_pointer, api[:ReleaseModelMetadata])

  # keys_ptr is the out-parameter; keys is the allocator-owned array it
  # receives. Keeping them apart means cleanup can never hand the
  # Ruby-owned out-parameter to the allocator.
  keys_ptr = ::FFI::MemoryPointer.new(:pointer)
  num_keys = ::FFI::MemoryPointer.new(:int64_t)
  check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata, @allocator, keys_ptr, num_keys)
  keys = keys_ptr.read_pointer

  custom_metadata_map = {}
  num_keys.read(:int64_t).times do |i|
    key_ptr = keys.get_pointer(i * ::FFI::Pointer.size)
    key = key_ptr.read_string
    value = ::FFI::MemoryPointer.new(:pointer)
    check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata, @allocator, key, value)
    custom_metadata_map[key] = value.read_pointer.read_string

    # each key and value string is released as soon as it has been copied
    allocator_free key_ptr
    allocator_free value.read_pointer
  end

  description = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetDescription].call(metadata, @allocator, description)

  domain = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetDomain].call(metadata, @allocator, domain)

  graph_name = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetGraphName].call(metadata, @allocator, graph_name)

  graph_description = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetGraphDescription].call(metadata, @allocator, graph_description)

  producer_name = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:ModelMetadataGetProducerName].call(metadata, @allocator, producer_name)

  version = ::FFI::MemoryPointer.new(:int64_t)
  check_status api[:ModelMetadataGetVersion].call(metadata, version)

  {
    custom_metadata_map: custom_metadata_map,
    description: description.read_pointer.read_string,
    domain: domain.read_pointer.read_string,
    graph_name: graph_name.read_pointer.read_string,
    graph_description: graph_description.read_pointer.read_string,
    producer_name: producer_name.read_pointer.read_string,
    version: version.read(:int64_t)
  }
ensure
  # Locals assigned after a failing call are still nil here, and an
  # out-parameter whose call failed may never have been written, so only
  # non-null native pointers are released.
  allocator_free keys if keys && !keys.null?
  [description, domain, graph_name, graph_description, producer_name].each do |out_param|
    next unless out_param

    native_string = out_param.read_pointer
    allocator_free native_string unless native_string.null?
  end
end

#providers ⇒ Object

no way to set providers with C API yet so we can return all available providers



203
204
205
206
207
208
209
210
211
# File 'lib/onnxruntime/inference_session.rb', line 203

# Asks the C API for the available execution providers and returns their
# names as Ruby strings. The native list is handed back to
# ReleaseAvailableProviders once the names have been copied.
def providers
  list_ref = ::FFI::MemoryPointer.new(:pointer)
  count_ref = ::FFI::MemoryPointer.new(:int)
  check_status api[:GetAvailableProviders].call(list_ref, count_ref)

  count = count_ref.read_int
  native_list = list_ref.read_pointer
  names = native_list.read_array_of_pointer(count).map { |name_ptr| name_ptr.read_string }

  api[:ReleaseAvailableProviders].call(native_list, count)
  names
end

#run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby) ⇒ Object



91
92
93
94
95
96
97
98
99
100
101
# File 'lib/onnxruntime/inference_session.rb', line 91

# Converts input_feed with create_input_tensor, runs the model through
# run_with_ort_values, and converts each result according to output_type.
#
# @param output_names forwarded to run_with_ort_values
# @param input_feed [Hash] input name => value accepted by create_input_tensor
# @param output_type [Symbol] :ruby (calls to_ruby), :numo (calls numo) or
#   :ort_value (results returned as they are)
# @return [Array] one converted result per value returned by run_with_ort_values
# @raise [ArgumentError] when output_type is not one of the three symbols
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil, output_type: :ruby)
  unless [:ruby, :numo, :ort_value].include?(output_type)
    raise ArgumentError, "Invalid output type: #{output_type}"
  end

  # pair every input name with the tensor built for it
  input_names = input_feed.keys
  tensors = create_input_tensor(input_feed)
  ort_values = input_names.zip(tensors).to_h

  results = run_with_ort_values(
    output_names,
    ort_values,
    log_severity_level: log_severity_level,
    log_verbosity_level: log_verbosity_level,
    logid: logid,
    terminate: terminate
  )

  case output_type
  when :numo
    results.map(&:numo)
  when :ort_value
    results.map(&:itself)
  else
    results.map(&:to_ruby)
  end
end

#run_with_ort_values(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) ⇒ Object

TODO support logid



104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# File 'lib/onnxruntime/inference_session.rb', line 104

# Runs the model on inputs that already respond to to_ptr and wraps every
# result pointer in an OrtValue.
#
# @param output_names [Array, nil] names of the outputs to compute; when nil,
#   the :name of every entry in @outputs is used
# @param input_feed [Hash] input name => object responding to to_ptr
# @param log_severity_level passed to RunOptionsSetRunLogSeverityLevel when set
# @param log_verbosity_level passed to RunOptionsSetRunLogVerbosityLevel when set
# @param logid passed to RunOptionsSetRunTag when set
# @param terminate when truthy, RunOptionsSetTerminate is called before running
# @return [Array<OrtValue>] one value per requested output name, in order
def run_with_ort_values(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
  input_tensor = ::FFI::MemoryPointer.new(:pointer, input_feed.size)
  input_feed.each_with_index do |(_, input), i|
    input_tensor[i].write_pointer(input.to_ptr)
  end

  output_names ||= @outputs.map { |v| v[:name] }

  # Run is told to write output_names.size pointers and that many are read
  # back below, so the buffer is sized by the requested names rather than by
  # the number of outputs the model declares.
  output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)

  # refs is handed to create_node_names for both name lists
  # (presumably to keep the native name strings referenced during the call; confirm in create_node_names)
  refs = []
  input_node_names = create_node_names(input_feed.keys.map(&:to_s), refs)
  output_node_names = create_node_names(output_names.map(&:to_s), refs)

  # run options
  run_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateRunOptions].call(run_options)
  run_options = ::FFI::AutoPointer.new(run_options.read_pointer, api[:ReleaseRunOptions])

  check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options, log_severity_level) if log_severity_level
  check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options, log_verbosity_level) if log_verbosity_level
  check_status api[:RunOptionsSetRunTag].call(run_options, logid) if logid
  check_status api[:RunOptionsSetTerminate].call(run_options) if terminate

  check_status api[:Run].call(@session, run_options, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

  Array.new(output_names.size) { |i| OrtValue.new(output_tensor[i].read_pointer) }
end