OnnxRuntime::Torch::Tensor
Torch::Tensor support for ONNX Runtime Ruby.
This gem provides conversion between OnnxRuntime::OrtValue and Torch::Tensor, so you can pass Torch::Tensor inputs to ONNX Runtime and work with its outputs as Torch::Tensor.
Conversion is zero-copy in most cases: no data is copied when the receiver is a row-major, contiguous tensor on the CPU.
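"Row-major and contiguous" means the elements sit in one dense block with the last dimension varying fastest, so the strides (in elements) are fully determined by the shape. The pure-Ruby sketch below checks only that layout condition from a shape and a stride array; it does not check the CPU requirement. The helpers `row_major_strides` and `row_major_contiguous?` are hypothetical and not part of this gem.

```ruby
# Hypothetical helpers, not part of onnxruntime-torch-tensor:
# they only illustrate the "row-major and contiguous" layout condition.

# Strides (in elements) of a dense row-major tensor with the given shape:
# the last dimension has stride 1, and each earlier dimension's stride is
# the product of all later dimension sizes.
def row_major_strides(shape)
  strides = Array.new(shape.size, 1)
  (shape.size - 2).downto(0) do |i|
    strides[i] = strides[i + 1] * shape[i + 1]
  end
  strides
end

# A layout is row-major contiguous when its strides match the dense
# row-major strides. Dimensions of size 1 are skipped because their
# stride never changes which element is addressed.
def row_major_contiguous?(shape, strides)
  expected = row_major_strides(shape)
  shape.each_index.all? { |i| shape[i] == 1 || strides[i] == expected[i] }
end

p row_major_strides([2, 3, 4])                 # => [12, 4, 1]
p row_major_contiguous?([2, 3, 4], [12, 4, 1]) # => true
# A transposed 3x2 view of a 2x3 tensor shares its buffer but is column-major:
p row_major_contiguous?([3, 2], [1, 3])        # => false
```

A tensor that fails this check (for example a transposed view) cannot be handed over as-is, because its memory order differs from the dense row-major order.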
SYNOPSIS
require "torchaudio"
require "onnxruntime"
require "onnxruntime/torch/tensor"
session = OnnxRuntime::InferenceSession.new("path/to/model")
waveform, sample_rate = TorchAudio.load("path/to/file") # => [Torch::Tensor, Integer]
input_torch_tensor = pre_process(waveform) # => Torch::Tensor (pre_process is your own code)
input = input_torch_tensor.to_ort_value # => OnnxRuntime::OrtValue, zero-copy
outputs = session.run(
  [:output_name],
  {input_name: input},
  output_type: :ort_value # required to get Torch::Tensor at the next step
) # => Array[OnnxRuntime::OrtValue]
output = outputs[0] # => OnnxRuntime::OrtValue
output_torch_tensor = output.to_torch_tensor # => Torch::Tensor, zero-copy
output_waveform = post_process(output_torch_tensor) # post_process is your own code
TorchAudio.save("path/to/output", output_waveform, sample_rate)
INSTALLATION
% gem install onnxruntime-torch-tensor
Or add it to your Gemfile:
gem "onnxruntime-torch-tensor"
API
OnnxRuntime::OrtValue#to_torch_tensor # => Torch::Tensor
Torch::Tensor#to_ort_value # => OnnxRuntime::OrtValue
OnnxRuntime::OrtValue.from_torch_tensor(torch_tensor) # => OnnxRuntime::OrtValue
# Not Torch::Tensor.from_ort_value; named after Torch.from_numo
Torch.from_ort_value(ort_value) # => Torch::Tensor
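The four entry points above come in pairs: each direction is available both as an instance method and as a factory-style method. The sketch below shows one plausible way to arrange that in Ruby, by reopening classes so that every spelling delegates to a single converter. `DemoOrtValue`, `DemoTensor`, `DemoTorch` and `DemoConversions` are hypothetical stand-ins so the code runs without the real gems; this is not the gem's actual implementation.

```ruby
# Hypothetical stand-ins for OnnxRuntime::OrtValue and Torch::Tensor.
# Both wrap a Ruby Array; sharing it plays the role of zero-copy.
class DemoOrtValue
  attr_reader :buffer

  def initialize(buffer)
    @buffer = buffer
  end
end

class DemoTensor
  attr_reader :buffer

  def initialize(buffer)
    @buffer = buffer
  end
end

# One function per direction does the actual work; every public
# spelling below delegates to it.
module DemoConversions
  def self.to_tensor(ort_value)
    DemoTensor.new(ort_value.buffer) # reuses the buffer, no copy
  end

  def self.to_ort_value(tensor)
    DemoOrtValue.new(tensor.buffer)
  end
end

# Reopen the existing classes to add the conversion entry points.
class DemoOrtValue
  def to_demo_tensor # like OrtValue#to_torch_tensor
    DemoConversions.to_tensor(self)
  end

  def self.from_demo_tensor(tensor) # like OrtValue.from_torch_tensor
    DemoConversions.to_ort_value(tensor)
  end
end

class DemoTensor
  def to_demo_ort_value # like Torch::Tensor#to_ort_value
    DemoConversions.to_ort_value(self)
  end
end

module DemoTorch
  def self.from_demo_ort_value(ort_value) # like Torch.from_ort_value
    DemoConversions.to_tensor(ort_value)
  end
end

ort = DemoOrtValue.new([1.0, 2.0, 3.0])
t1 = ort.to_demo_tensor
t2 = DemoTorch.from_demo_ort_value(ort)
p t1.buffer                    # => [1.0, 2.0, 3.0]
p t1.buffer.equal?(ort.buffer) # => true (same object, nothing copied)
p DemoOrtValue.from_demo_tensor(t2).buffer.equal?(ort.buffer) # => true
```

Keeping the conversion logic in one place means the instance-method and factory spellings cannot drift apart.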
This gem uses NDAV internally, so it also provides OrtValue() and TorchTensor() conversion methods based on it.
NDAV::Converter::OrtValue(torch_tensor) # => OnnxRuntime::OrtValue
NDAV::Converter::TorchTensor(ort_value) # => Torch::Tensor
module YourApp
  include NDAV::Converter

  def your_method(torch_tensor, ort_value)
    OrtValue(torch_tensor) # => OnnxRuntime::OrtValue
    TorchTensor(ort_value) # => Torch::Tensor
  end
end
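Capitalized conversion functions like OrtValue() follow the same Ruby convention as Kernel#Integer and Kernel#Array. A common way to make such functions callable both as `Module::Name(...)` and, after `include`, as a bare `Name(...)` is `module_function`. The sketch below illustrates that pattern with a hypothetical `DemoConverter` module and `Vec` wrapper type; it is an assumption about the general technique, not NDAV's actual code.

```ruby
# Hypothetical sketch of the conversion-function pattern; DemoConverter
# and Vec are illustrative names, not part of NDAV.
module DemoConverter
  # A simple wrapper type, standing in for OnnxRuntime::OrtValue.
  Vec = Struct.new(:items)

  # Every method defined below becomes both a public module method
  # (DemoConverter::Vec(...)) and a private instance method for includers.
  module_function

  # Conversion function named after its target type, like Kernel#Integer.
  # Without parentheses, `Vec` still refers to the Struct constant.
  def Vec(value)
    return value if value.is_a?(Vec) # already converted: return as-is
    Vec.new(value.to_a)
  end

  # The reverse direction: unwrap back to a plain Array.
  def Items(value)
    value.is_a?(Vec) ? value.items : Array(value)
  end
end

# Called with an explicit receiver, like NDAV::Converter::OrtValue(...).
v = DemoConverter::Vec([1, 2, 3])
p v.items # => [1, 2, 3]

# Called bare after include, like OrtValue(...) inside YourApp.
class DemoApp
  include DemoConverter

  def round_trip(array)
    Items(Vec(array)) # private included methods, no receiver needed
  end
end

p DemoApp.new.round_trip([4, 5]) # => [4, 5]
```

Because `module_function` makes the included copies private, they work as helpers inside the including class without becoming part of its public interface.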
LICENSE
BSD-3-Clause. See the LICENSE.txt file.