Class: OnnxRuntime::InferenceSession

Inherits:
Object
  • Object
show all
Defined in:
lib/onnxruntime/inference_session.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil) ⇒ InferenceSession



5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# File 'lib/onnxruntime/inference_session.rb', line 5

# Creates an ONNX Runtime session for a model and records the model's input
# and output descriptions in @inputs and @outputs.
#
# @param path_or_bytes [String, #read] an object responding to #read (its
#   contents are used as the model bytes), a BINARY-encoded String (treated as
#   model bytes), or any other String-like value (treated as a file path)
# @param enable_cpu_mem_arena [Boolean] calls EnableCpuMemArena when truthy,
#   DisableCpuMemArena otherwise
# @param enable_mem_pattern [Boolean] calls EnableMemPattern when truthy,
#   DisableMemPattern otherwise
# @param enable_profiling [Boolean] when truthy, enables profiling with the
#   file prefix "onnxruntime_profile_"
# @param execution_mode [Symbol, nil] :sequential or :parallel
# @param graph_optimization_level [Integer, nil] forwarded unchanged to
#   SetSessionGraphOptimizationLevel
# @param inter_op_num_threads [Integer, nil] forwarded to SetInterOpNumThreads
# @param intra_op_num_threads [Integer, nil] forwarded to SetIntraOpNumThreads
# @param log_severity_level [Integer, nil] forwarded to SetSessionLogSeverityLevel
# @param log_verbosity_level [Integer, nil] forwarded to SetSessionLogVerbosityLevel
# @param logid [String, nil] forwarded to SetSessionLogId
# @param optimized_model_filepath [String, nil] forwarded to SetOptimizedModelFilePath
# @raise [ArgumentError] if execution_mode is neither :sequential nor :parallel
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
  # session options
  # NOTE(review): the options handle is never released here; presumably a
  # ReleaseSessionOptions call is needed once the session exists - confirm.
  session_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateSessionOptions].call(session_options)
  options = session_options.read_pointer

  # Both settings are on by default in ONNX Runtime, so a false value has to be
  # applied explicitly; otherwise passing false would silently do nothing.
  check_status api[enable_cpu_mem_arena ? :EnableCpuMemArena : :DisableCpuMemArena].call(options)
  check_status api[enable_mem_pattern ? :EnableMemPattern : :DisableMemPattern].call(options)
  check_status api[:EnableProfiling].call(options, "onnxruntime_profile_") if enable_profiling
  if execution_mode
    mode =
      case execution_mode
      when :sequential
        0
      when :parallel
        1
      else
        raise ArgumentError, "Invalid execution mode"
      end
    check_status api[:SetSessionExecutionMode].call(options, mode)
  end
  check_status api[:SetSessionGraphOptimizationLevel].call(options, graph_optimization_level) if graph_optimization_level
  check_status api[:SetInterOpNumThreads].call(options, inter_op_num_threads) if inter_op_num_threads
  check_status api[:SetIntraOpNumThreads].call(options, intra_op_num_threads) if intra_op_num_threads
  check_status api[:SetSessionLogSeverityLevel].call(options, log_severity_level) if log_severity_level
  check_status api[:SetSessionLogVerbosityLevel].call(options, log_verbosity_level) if log_verbosity_level
  check_status api[:SetSessionLogId].call(options, logid) if logid
  check_status api[:SetOptimizedModelFilePath].call(options, optimized_model_filepath) if optimized_model_filepath

  # session
  @session = ::FFI::MemoryPointer.new(:pointer)
  from_memory =
    if path_or_bytes.respond_to?(:read)
      path_or_bytes = path_or_bytes.read
      true
    else
      path_or_bytes = path_or_bytes.to_str
      # a BINARY-encoded string is taken to be model bytes rather than a path
      path_or_bytes.encoding == Encoding::BINARY
    end

  # fix for Windows "File doesn't exist": load the file ourselves and hand
  # the bytes to the runtime instead of the path
  if Gem.win_platform? && !from_memory
    path_or_bytes = File.binread(path_or_bytes)
    from_memory = true
  end

  if from_memory
    check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, options, @session)
  else
    check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, options, @session)
  end

  # allocator used for the name strings returned by the runtime
  @allocator = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:GetAllocatorWithDefaultOptions].call(@allocator)

  @inputs = []
  @outputs = []

  # Inputs and outputs are described by parallel API functions, so one loop
  # serves both; inputs are processed first, then outputs.
  # NOTE(review): the name strings are allocated by the runtime and are not
  # freed here - confirm whether an allocator free call is required.
  [
    [@inputs, :SessionGetInputCount, :SessionGetInputName, :SessionGetInputTypeInfo],
    [@outputs, :SessionGetOutputCount, :SessionGetOutputName, :SessionGetOutputTypeInfo]
  ].each do |nodes, count_fn, name_fn, typeinfo_fn|
    num_nodes = ::FFI::MemoryPointer.new(:size_t)
    check_status api[count_fn].call(read_pointer, num_nodes)
    read_size_t(num_nodes).times do |i|
      name_ptr = ::FFI::MemoryPointer.new(:string)
      check_status api[name_fn].call(read_pointer, i, @allocator.read_pointer, name_ptr)
      typeinfo = ::FFI::MemoryPointer.new(:pointer)
      check_status api[typeinfo_fn].call(read_pointer, i, typeinfo)
      nodes << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
    end
  end
end

Instance Attribute Details

#inputs ⇒ Object (readonly)

Returns the value of attribute inputs.



3
4
5
# File 'lib/onnxruntime/inference_session.rb', line 3

# Reader for the model input descriptions stored in @inputs.
#
# @return [Array<Hash>] the array populated by #initialize: one hash per model
#   input, holding a :name key merged with whatever node_info returns for it
def inputs
  @inputs
end

#outputs ⇒ Object (readonly)

Returns the value of attribute outputs.



3
4
5
# File 'lib/onnxruntime/inference_session.rb', line 3

# Reader for the model output descriptions stored in @outputs.
#
# @return [Array<Hash>] the array populated by #initialize: one hash per model
#   output, holding a :name key merged with whatever node_info returns for it
def outputs
  @outputs
end

Instance Method Details

#end_profiling ⇒ Object



149
150
151
152
153
# File 'lib/onnxruntime/inference_session.rb', line 149

# Calls SessionEndProfiling on this session and returns the string that the
# runtime writes into the out-parameter.
#
# @return [String] the string produced by SessionEndProfiling (presumably the
#   path of the profile file - confirm against the ONNX Runtime docs)
def end_profiling
  result_ptr = ::FFI::MemoryPointer.new(:string)
  status = api[:SessionEndProfiling].call(read_pointer, @allocator.read_pointer, result_ptr)
  check_status(status)
  # the runtime stores a char* in result_ptr; follow it to get the Ruby string
  result_ptr.read_pointer.read_string
end

#modelmeta ⇒ Object



111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# File 'lib/onnxruntime/inference_session.rb', line 111

# Reads the model metadata for this session through the ONNX Runtime C API.
#
# @return [Hash] with the keys :custom_metadata_map (Hash of String keys to
#   String values), :description, :domain, :graph_name, :producer_name
#   (Strings) and :version (Integer)
def modelmeta
  keys = ::FFI::MemoryPointer.new(:pointer)
  num_keys = ::FFI::MemoryPointer.new(:int64_t)
  description = ::FFI::MemoryPointer.new(:string)
  domain = ::FFI::MemoryPointer.new(:string)
  graph_name = ::FFI::MemoryPointer.new(:string)
  producer_name = ::FFI::MemoryPointer.new(:string)
  version = ::FFI::MemoryPointer.new(:int64_t)

  metadata = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:SessionGetModelMetadata].call(read_pointer, metadata)

  begin
    custom_metadata_map = {}
    # The status must be passed to check_status; assigning it to a local
    # named check_status would silently discard any error.
    check_status api[:ModelMetadataGetCustomMetadataMapKeys].call(metadata.read_pointer, @allocator.read_pointer, keys, num_keys)
    num_keys.read(:int64_t).times do |i|
      # keys points at an array of char*; step through it one pointer at a time
      key = keys.read_pointer[i * ::FFI::Pointer.size].read_pointer.read_string
      value = ::FFI::MemoryPointer.new(:string)
      check_status api[:ModelMetadataLookupCustomMetadataMap].call(metadata.read_pointer, @allocator.read_pointer, key, value)
      custom_metadata_map[key] = value.read_pointer.read_string
    end

    check_status api[:ModelMetadataGetDescription].call(metadata.read_pointer, @allocator.read_pointer, description)
    check_status api[:ModelMetadataGetDomain].call(metadata.read_pointer, @allocator.read_pointer, domain)
    check_status api[:ModelMetadataGetGraphName].call(metadata.read_pointer, @allocator.read_pointer, graph_name)
    check_status api[:ModelMetadataGetProducerName].call(metadata.read_pointer, @allocator.read_pointer, producer_name)
    check_status api[:ModelMetadataGetVersion].call(metadata.read_pointer, version)

    {
      custom_metadata_map: custom_metadata_map,
      description: description.read_pointer.read_string,
      domain: domain.read_pointer.read_string,
      graph_name: graph_name.read_pointer.read_string,
      producer_name: producer_name.read_pointer.read_string,
      version: version.read(:int64_t)
    }
  ensure
    # release the metadata handle even when one of the lookups above raises
    api[:ReleaseModelMetadata].call(metadata.read_pointer)
  end
end

#run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) ⇒ Object

TODO support logid



87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# File 'lib/onnxruntime/inference_session.rb', line 87

# Runs the model on input_feed and returns the requested outputs.
#
# @param output_names [Array<#to_s>, nil] names of the outputs to compute;
#   nil selects every output listed in @outputs
# @param input_feed [Hash] maps input names (converted with #to_s) to input
#   values; the values are converted by create_input_tensor
# @param log_severity_level [Integer, nil] forwarded to RunOptionsSetRunLogSeverityLevel
# @param log_verbosity_level [Integer, nil] forwarded to RunOptionsSetRunLogVerbosityLevel
# @param logid [String, nil] forwarded to RunOptionsSetRunTag
# @param terminate [Boolean, nil] when truthy, RunOptionsSetTerminate is called
# @return [Array] one converted value per entry of output_names, in the same order
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
  input_tensor = create_input_tensor(input_feed)

  output_names ||= @outputs.map { |v| v[:name] }

  # Run is told to produce output_names.size values, so the buffer must hold
  # exactly that many pointers. Sizing it from the model's output count would
  # overflow when more names are requested than the model has outputs.
  output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)
  input_node_names = create_node_names(input_feed.keys.map(&:to_s))
  output_node_names = create_node_names(output_names.map(&:to_s))

  # run options
  # NOTE(review): run_options is not released after the call; presumably a
  # ReleaseRunOptions call is needed - confirm.
  run_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateRunOptions].call(run_options)
  check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
  check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
  check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
  check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate

  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

  output_names.size.times.map do |i|
    create_from_onnx_value(output_tensor[i].read_pointer)
  end
end