Class: OnnxRuntime::InferenceSession

Inherits:
Object
  • Object
show all
Defined in:
lib/onnxruntime/inference_session.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil) ⇒ InferenceSession

Returns a new instance of InferenceSession.



5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/onnxruntime/inference_session.rb', line 5

# Creates an inference session from a model file path or from in-memory
# model bytes, then records the name and type information of every model
# input and output in @inputs / @outputs.
#
# +path_or_bytes+ is converted with +to_str+. A string whose encoding is
# Encoding::BINARY is handed to CreateSessionFromArray as the model bytes;
# any other string is handed to CreateSession as a path. On Windows a
# non-binary string is first read from disk with File.binread, so the model
# is always loaded from bytes there.
#
# The result of every native call is passed to +check_status+.
#
# @param path_or_bytes [#to_str] model path, or binary-encoded model bytes
# @param enable_cpu_mem_arena [Boolean] true calls EnableCpuMemArena, false calls DisableCpuMemArena
# @param enable_mem_pattern [Boolean] true calls EnableMemPattern, false calls DisableMemPattern
# @param enable_profiling [Boolean] when true, calls EnableProfiling with the prefix "onnxruntime_profile_"
# @param execution_mode [Symbol, nil] :sequential or :parallel; nil skips the call
# @param graph_optimization_level [Object, nil] forwarded unchanged to SetSessionGraphOptimizationLevel when given
# @param inter_op_num_threads [Object, nil] forwarded unchanged to SetInterOpNumThreads when given
# @param intra_op_num_threads [Object, nil] forwarded unchanged to SetIntraOpNumThreads when given
# @param log_severity_level [Object, nil] forwarded unchanged to SetSessionLogSeverityLevel when given
# @param log_verbosity_level [Object, nil] forwarded unchanged to SetSessionLogVerbosityLevel when given
# @param logid [Object, nil] forwarded unchanged to SetSessionLogId when given
# @param optimized_model_filepath [Object, nil] forwarded unchanged to SetOptimizedModelFilePath when given
# @raise [ArgumentError] if execution_mode is given but is neither :sequential nor :parallel
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
  # session options
  session_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateSessionOptions].call(session_options)
  # NOTE(review): the session options handle is never released in this method - confirm whether that is intended
  options = session_options.read_pointer

  # Call the explicit Disable* entry for a false flag; only calling Enable*
  # when true would leave the runtime's own default in place and make
  # `false` a silent no-op.
  # NOTE(review): assumes the API binding exposes DisableCpuMemArena and DisableMemPattern - confirm
  if enable_cpu_mem_arena
    check_status api[:EnableCpuMemArena].call(options)
  else
    check_status api[:DisableCpuMemArena].call(options)
  end
  if enable_mem_pattern
    check_status api[:EnableMemPattern].call(options)
  else
    check_status api[:DisableMemPattern].call(options)
  end
  check_status api[:EnableProfiling].call(options, "onnxruntime_profile_") if enable_profiling

  if execution_mode
    mode =
      case execution_mode
      when :sequential
        0
      when :parallel
        1
      else
        raise ArgumentError, "Invalid execution mode"
      end
    check_status api[:SetSessionExecutionMode].call(options, mode)
  end
  check_status api[:SetSessionGraphOptimizationLevel].call(options, graph_optimization_level) if graph_optimization_level
  check_status api[:SetInterOpNumThreads].call(options, inter_op_num_threads) if inter_op_num_threads
  check_status api[:SetIntraOpNumThreads].call(options, intra_op_num_threads) if intra_op_num_threads
  check_status api[:SetSessionLogSeverityLevel].call(options, log_severity_level) if log_severity_level
  check_status api[:SetSessionLogVerbosityLevel].call(options, log_verbosity_level) if log_verbosity_level
  check_status api[:SetSessionLogId].call(options, logid) if logid
  check_status api[:SetOptimizedModelFilePath].call(options, optimized_model_filepath) if optimized_model_filepath

  # session
  @session = ::FFI::MemoryPointer.new(:pointer)
  path_or_bytes = path_or_bytes.to_str

  # fix for Windows "File doesn't exist"
  if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
    path_or_bytes = File.binread(path_or_bytes)
  end

  # Binary-encoded strings are model bytes; everything else is a path.
  if path_or_bytes.encoding == Encoding::BINARY
    check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, options, @session)
  else
    check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, options, @session)
  end

  # allocator used by the name lookups below
  @allocator = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:GetAllocatorWithDefaultOptions].call(@allocator)

  @inputs = []
  @outputs = []

  # NOTE(review): the name strings returned through name_ptr below are not
  # freed here - confirm whether the allocator expects the caller to free them

  # input
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
  read_size_t(num_input_nodes).times do |i|
    name_ptr = ::FFI::MemoryPointer.new(:string)
    check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
    typeinfo = ::FFI::MemoryPointer.new(:pointer)
    check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo)
    @inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
  end

  # output
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
  read_size_t(num_output_nodes).times do |i|
    name_ptr = ::FFI::MemoryPointer.new(:string)
    check_status api[:SessionGetOutputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
    typeinfo = ::FFI::MemoryPointer.new(:pointer)
    check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
    @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
  end
end

Instance Attribute Details

#inputs ⇒ Object (readonly)

Returns the value of attribute inputs.



3
4
5
# File 'lib/onnxruntime/inference_session.rb', line 3

# Returns the array stored in @inputs. The constructor fills it with one
# hash per model input, each holding a :name key merged with the hash
# returned by node_info for that input.
def inputs
  @inputs
end

#outputs ⇒ Object (readonly)

Returns the value of attribute outputs.



3
4
5
# File 'lib/onnxruntime/inference_session.rb', line 3

# Returns the array stored in @outputs. The constructor fills it with one
# hash per model output, each holding a :name key merged with the hash
# returned by node_info for that output.
def outputs
  @outputs
end

Instance Method Details

#run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) ⇒ Object

TODO support logid



79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# File 'lib/onnxruntime/inference_session.rb', line 79

# Runs the model on +input_feed+ and returns the requested outputs.
#
# The result of every native call is passed to +check_status+.
#
# @param output_names [Array<#to_s>, nil] names of the outputs to fetch; nil selects the :name of every entry in @outputs
# @param input_feed [Hash] inputs keyed by name (keys are converted with to_s); converted to native values by create_input_tensor
# @param log_severity_level [Object, nil] forwarded unchanged to RunOptionsSetRunLogSeverityLevel when given
# @param log_verbosity_level [Object, nil] forwarded unchanged to RunOptionsSetRunLogVerbosityLevel when given
# @param logid [Object, nil] forwarded unchanged to RunOptionsSetRunTag when given
# @param terminate [Object, nil] when truthy, RunOptionsSetTerminate is called before running
# @return [Array] one element per requested output name, in request order, each produced by create_from_onnx_value
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
  input_tensor = create_input_tensor(input_feed)

  output_names ||= @outputs.map { |v| v[:name] }
  output_count = output_names.size

  # One slot per *requested* output. This is the count handed to Run below,
  # so sizing by the model's total output count could leave the buffer too
  # small when more names are requested than the model declares.
  output_tensor = ::FFI::MemoryPointer.new(:pointer, output_count)
  input_node_names = create_node_names(input_feed.keys.map(&:to_s))
  output_node_names = create_node_names(output_names.map(&:to_s))

  # run options
  # NOTE(review): the run options handle is never released in this method - confirm whether that is intended
  run_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateRunOptions].call(run_options)
  check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
  check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
  check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
  check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate

  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_count, output_tensor)

  Array.new(output_count) do |i|
    create_from_onnx_value(output_tensor[i].read_pointer)
  end
end