Class: OnnxRuntime::InferenceSession

Inherits: Object
Defined in:
lib/onnxruntime/inference_session.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(path_or_bytes) ⇒ InferenceSession

Returns a new instance of InferenceSession.



(source lines 5–54)
# File 'lib/onnxruntime/inference_session.rb', line 5

# Creates an ONNX Runtime session and records metadata for every model input
# and output.
#
# @param path_or_bytes [#to_str] a model file path, or serialized model bytes
#   when the string is BINARY-encoded
def initialize(path_or_bytes)
  # Default session options; ORT writes the new handle into this out-pointer.
  # NOTE(review): this options handle is never released in this method —
  # confirm whether it should be freed once the session has been created.
  options_ref = ::FFI::MemoryPointer.new(:pointer)
  check_status FFI.OrtCreateSessionOptions(options_ref)

  # Out-pointer that receives the session handle.
  @session = ::FFI::MemoryPointer.new(:pointer)
  model = path_or_bytes.to_str

  # Windows workaround for ORT reporting "File doesn't exist" when given a
  # path: read the file here and hand ORT the raw bytes instead.
  model = File.binread(model) if Gem.win_platform? && model.encoding != Encoding::BINARY

  # A BINARY-encoded string is treated as serialized model bytes; any other
  # encoding is treated as a filesystem path.
  from_bytes = model.encoding == Encoding::BINARY
  status =
    if from_bytes
      FFI.OrtCreateSessionFromArray(env.read_pointer, model, model.bytesize, options_ref.read_pointer, @session)
    else
      FFI.OrtCreateSession(env.read_pointer, model, options_ref.read_pointer, @session)
    end
  check_status status

  # Default allocator handle, passed to the name-lookup calls below.
  @allocator = ::FFI::MemoryPointer.new(:pointer)
  check_status FFI.OrtCreateDefaultAllocator(@allocator)

  # NOTE(review): the name strings and type-info handles obtained below are not
  # freed in this method — confirm whether node_info or the bindings release them.

  # One Hash per model input: its name merged with what node_info extracts
  # from the input's type info.
  input_count = ::FFI::MemoryPointer.new(:size_t)
  check_status FFI.OrtSessionGetInputCount(read_pointer, input_count)
  @inputs = read_size_t(input_count).times.map do |index|
    name_ref = ::FFI::MemoryPointer.new(:string)
    check_status FFI.OrtSessionGetInputName(read_pointer, index, @allocator.read_pointer, name_ref)
    type_ref = ::FFI::MemoryPointer.new(:pointer)
    check_status FFI.OrtSessionGetInputTypeInfo(read_pointer, index, type_ref)
    {name: name_ref.read_pointer.read_string}.merge(node_info(type_ref))
  end

  # Same shape of metadata for every model output.
  output_count = ::FFI::MemoryPointer.new(:size_t)
  check_status FFI.OrtSessionGetOutputCount(read_pointer, output_count)
  @outputs = read_size_t(output_count).times.map do |index|
    name_ref = ::FFI::MemoryPointer.new(:string)
    check_status FFI.OrtSessionGetOutputName(read_pointer, index, @allocator.read_pointer, name_ref)
    type_ref = ::FFI::MemoryPointer.new(:pointer)
    check_status FFI.OrtSessionGetOutputTypeInfo(read_pointer, index, type_ref)
    {name: name_ref.read_pointer.read_string}.merge(node_info(type_ref))
  end
end

Instance Attribute Details

#inputs ⇒ Object (readonly)

Returns the value of attribute inputs.



(source lines 3–5)
# File 'lib/onnxruntime/inference_session.rb', line 3

def inputs
  @inputs
end

#outputs ⇒ Object (readonly)

Returns the value of attribute outputs.



(source lines 3–5)
# File 'lib/onnxruntime/inference_session.rb', line 3

# Metadata gathered for each model output when the session was created:
# one Hash per output holding :name merged with the fields from node_info.
def outputs; @outputs; end

Instance Method Details

#run(output_names, input_feed) ⇒ Object



(source lines 56–70)
# File 'lib/onnxruntime/inference_session.rb', line 56

# Runs inference and returns the requested outputs.
#
# @param output_names [Array, nil] names of the outputs to fetch (each converted
#   with #to_s); nil requests every output listed in #outputs, in that order
# @param input_feed [Hash] input name (converted with #to_s) => input value;
#   the values are turned into ORT tensors by create_input_tensor
# @return [Array] one value per requested output, in the order of output_names
def run(output_names, input_feed)
  input_tensor = create_input_tensor(input_feed)

  output_names ||= @outputs.map { |v| v[:name] }

  # OrtRun writes one value pointer per requested output (count passed below),
  # and the loop at the end reads that many. Size the buffer by the request,
  # not by the model's output count, so it can never be overrun.
  output_tensor = ::FFI::MemoryPointer.new(:pointer, output_names.size)
  input_node_names = create_node_names(input_feed.keys.map(&:to_s))
  output_node_names = create_node_names(output_names.map(&:to_s))
  # TODO support run options
  # NOTE(review): input tensors and the returned output values are not released
  # here — confirm who owns them and whether they should be freed after conversion.
  check_status FFI.OrtRun(read_pointer, nil, input_node_names, input_tensor, input_feed.size, output_node_names, output_names.size, output_tensor)

  output_names.size.times.map do |i|
    create_from_onnx_value(output_tensor[i].read_pointer)
  end
end