Class: OnnxRuntime::InferenceSession
- Inherits: Object
  - Object
    - OnnxRuntime::InferenceSession
- Defined in:
- lib/onnxruntime/inference_session.rb
Instance Attribute Summary
-
#inputs ⇒ Object
readonly
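Returns the value of attribute inputs: an Array with one Hash per model input, holding the input's :name merged with the type details returned by node_info.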
-
#outputs ⇒ Object
readonly
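Returns the value of attribute outputs: an Array with one Hash per model output, holding the output's :name merged with the type details returned by node_info.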
Instance Method Summary
-
#initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil) ⇒ InferenceSession
constructor
A new instance of InferenceSession.
-
#run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) ⇒ Object
Runs the model on input_feed and returns the values of the requested outputs. TODO: support logid.
Constructor Details
#initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil) ⇒ InferenceSession
Returns a new instance of InferenceSession.
Source lines 5–76:
# File 'lib/onnxruntime/inference_session.rb', line 5
#
# Creates an inference session from a model file path or from the model's raw bytes.
#
# A String whose encoding is Encoding::BINARY is treated as serialized model
# bytes and loaded with CreateSessionFromArray; any other String is treated as a
# file path and loaded with CreateSession. On Windows a path is first read into
# memory with File.binread, so it also goes through CreateSessionFromArray.
#
# @param path_or_bytes [#to_str] model file path, or model bytes in a BINARY-encoded String
# @param enable_cpu_mem_arena when truthy, EnableCpuMemArena is called on the session options
# @param enable_mem_pattern when truthy, EnableMemPattern is called on the session options
# @param enable_profiling when truthy, EnableProfiling is called with the "onnxruntime_profile_" prefix
# @param execution_mode [Symbol, nil] :sequential (sent as 0) or :parallel (sent as 1)
# @param graph_optimization_level forwarded unchanged to SetSessionGraphOptimizationLevel when non-nil
# @param inter_op_num_threads forwarded to SetInterOpNumThreads when non-nil
# @param intra_op_num_threads forwarded to SetIntraOpNumThreads when non-nil
# @param log_severity_level forwarded to SetSessionLogSeverityLevel when non-nil
# @param log_verbosity_level forwarded to SetSessionLogVerbosityLevel when non-nil
# @param logid forwarded to SetSessionLogId when non-nil
# @param optimized_model_filepath forwarded to SetOptimizedModelFilePath when non-nil
# @raise [ArgumentError] if execution_mode is neither :sequential nor :parallel
#   (failed native calls are reported through check_status, defined elsewhere)
#
# After construction, @inputs and @outputs each hold one Hash per model
# input/output: {name: String} merged with the Hash returned by node_info.
def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: true, enable_profiling: false, execution_mode: nil, graph_optimization_level: nil, inter_op_num_threads: nil, intra_op_num_threads: nil, log_severity_level: nil, log_verbosity_level: nil, logid: nil, optimized_model_filepath: nil)
  # session options
  # NOTE(review): this handle is never passed to ReleaseSessionOptions, so each
  # session leaks it. Release it in an `ensure` once the FFI binding's signature
  # for ReleaseSessionOptions is confirmed.
  session_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateSessionOptions].call(session_options)
  check_status api[:EnableCpuMemArena].call(session_options.read_pointer) if enable_cpu_mem_arena
  check_status api[:EnableMemPattern].call(session_options.read_pointer) if enable_mem_pattern
  check_status api[:EnableProfiling].call(session_options.read_pointer, "onnxruntime_profile_") if enable_profiling
  if execution_mode
    mode =
      case execution_mode
      when :sequential then 0
      when :parallel then 1
      else raise ArgumentError, "Invalid execution mode"
      end
    check_status api[:SetSessionExecutionMode].call(session_options.read_pointer, mode)
  end
  check_status api[:SetSessionGraphOptimizationLevel].call(session_options.read_pointer, graph_optimization_level) if graph_optimization_level
  check_status api[:SetInterOpNumThreads].call(session_options.read_pointer, inter_op_num_threads) if inter_op_num_threads
  check_status api[:SetIntraOpNumThreads].call(session_options.read_pointer, intra_op_num_threads) if intra_op_num_threads
  check_status api[:SetSessionLogSeverityLevel].call(session_options.read_pointer, log_severity_level) if log_severity_level
  check_status api[:SetSessionLogVerbosityLevel].call(session_options.read_pointer, log_verbosity_level) if log_verbosity_level
  check_status api[:SetSessionLogId].call(session_options.read_pointer, logid) if logid
  check_status api[:SetOptimizedModelFilePath].call(session_options.read_pointer, optimized_model_filepath) if optimized_model_filepath

  # session
  @session = ::FFI::MemoryPointer.new(:pointer)
  path_or_bytes = path_or_bytes.to_str

  # fix for Windows "File doesn't exist": read the file here and hand ONNX
  # Runtime the bytes (File.binread returns a BINARY-encoded String)
  if Gem.win_platform? && path_or_bytes.encoding != Encoding::BINARY
    path_or_bytes = File.binread(path_or_bytes)
  end

  # BINARY encoding marks the String as model bytes rather than a path
  if path_or_bytes.encoding == Encoding::BINARY
    check_status api[:CreateSessionFromArray].call(env.read_pointer, path_or_bytes, path_or_bytes.bytesize, session_options.read_pointer, @session)
  else
    check_status api[:CreateSession].call(env.read_pointer, path_or_bytes, session_options.read_pointer, @session)
  end

  # input info
  allocator = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:GetAllocatorWithDefaultOptions].call(allocator)
  @allocator = allocator

  @inputs = []
  @outputs = []

  # NOTE(review): read_pointer is a helper defined elsewhere in the class;
  # presumably it dereferences @session — confirm.
  # NOTE(review): the name strings below are allocated through @allocator and
  # never freed (ONNX Runtime documents AllocatorFree for this); confirm and release.

  # input
  num_input_nodes = ::FFI::MemoryPointer.new(:size_t)
  check_status api[:SessionGetInputCount].call(read_pointer, num_input_nodes)
  read_size_t(num_input_nodes).times do |i|
    name_ptr = ::FFI::MemoryPointer.new(:string)
    check_status api[:SessionGetInputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
    typeinfo = ::FFI::MemoryPointer.new(:pointer)
    check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo)
    @inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
  end

  # output
  num_output_nodes = ::FFI::MemoryPointer.new(:size_t)
  check_status api[:SessionGetOutputCount].call(read_pointer, num_output_nodes)
  read_size_t(num_output_nodes).times do |i|
    name_ptr = ::FFI::MemoryPointer.new(:string)
    check_status api[:SessionGetOutputName].call(read_pointer, i, @allocator.read_pointer, name_ptr)
    typeinfo = ::FFI::MemoryPointer.new(:pointer)
    check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
    @outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
  end
end
Instance Attribute Details
#inputs ⇒ Object (readonly)
Returns the value of attribute inputs.
Source lines 3–5:
# File 'lib/onnxruntime/inference_session.rb', line 3
#
# Read-only metadata for the model's inputs, filled in by #initialize:
# an Array with one Hash per input, {name: String} merged with node_info.
def inputs
  @inputs
end
#outputs ⇒ Object (readonly)
Returns the value of attribute outputs.
Source lines 3–5:
# File 'lib/onnxruntime/inference_session.rb', line 3
#
# Read-only metadata for the model's outputs, filled in by #initialize:
# an Array with one Hash per output, {name: String} merged with node_info.
def outputs
  @outputs
end
Instance Method Details
#run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil) ⇒ Object
Runs the model on input_feed and returns the values of the requested outputs. TODO: support logid.
Source lines 79–101:
# File 'lib/onnxruntime/inference_session.rb', line 79
#
# Runs the model on +input_feed+ and returns the values of the requested outputs.
#
# @param output_names [Array, nil] names of the outputs to fetch (converted with
#   to_s); nil means every model output, in the order of #outputs
# @param input_feed [Hash] input name (converted with to_s) => input data, which
#   create_input_tensor (defined elsewhere) turns into input tensors
# @param log_severity_level forwarded to RunOptionsSetRunLogSeverityLevel when non-nil
# @param log_verbosity_level forwarded to RunOptionsSetRunLogVerbosityLevel when non-nil
# @param logid forwarded to RunOptionsSetRunTag when non-nil
# @param terminate when truthy, RunOptionsSetTerminate is called before Run
# @return [Array] one value per requested output name, in request order, each
#   converted by create_from_onnx_value
def run(output_names, input_feed, log_severity_level: nil, log_verbosity_level: nil, logid: nil, terminate: nil)
  input_tensor = create_input_tensor(input_feed)

  output_names ||= @outputs.map { |v| v[:name] }

  # Run writes exactly one value pointer per requested name (output_names.size
  # is passed as the output count below), so size the buffer by the request
  # rather than by the model's output count. Sizing by outputs.size overflowed
  # when more names were requested than the model has outputs (e.g. repeats).
  output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)
  input_node_names = create_node_names(input_feed.keys.map(&:to_s))
  output_node_names = create_node_names(output_names.map(&:to_s))

  # run options
  # NOTE(review): run_options is never passed to ReleaseRunOptions, so every call
  # leaks one handle. Release it in an `ensure` once the FFI binding's signature
  # is confirmed.
  run_options = ::FFI::MemoryPointer.new(:pointer)
  check_status api[:CreateRunOptions].call(run_options)
  check_status api[:RunOptionsSetRunLogSeverityLevel].call(run_options.read_pointer, log_severity_level) if log_severity_level
  check_status api[:RunOptionsSetRunLogVerbosityLevel].call(run_options.read_pointer, log_verbosity_level) if log_verbosity_level
  check_status api[:RunOptionsSetRunTag].call(run_options.read_pointer, logid) if logid
  check_status api[:RunOptionsSetTerminate].call(run_options.read_pointer) if terminate

  check_status api[:Run].call(read_pointer, run_options.read_pointer, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

  output_names.size.times.map do |i|
    create_from_onnx_value(output_tensor[i].read_pointer)
  end
end