Method: TensorStream::Evaluator::BaseEvaluator.query_device

Defined in:
lib/tensor_stream/evaluator/base_evaluator.rb

.query_device(query) ⇒ Object

Selects a device using a URI-like query string (e.g. "/device:GPU:0", "/cpu:0", or "/ts:<evaluator>:...").



61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# File 'lib/tensor_stream/evaluator/base_evaluator.rb', line 61

# Select a device using a URI-like query string.
#
# The query is split on "/" and each segment on ":". The first segment in
# one of the recognised forms decides the result:
#
# * "device:<type>[:<index>]" - TensorFlow convention, e.g. "/device:GPU:1"
#   (the type is matched case-insensitively)
# * "cpu[:<index>]" or "gpu[:<index>]" - shorthand, e.g. "/cpu:0"
# * "ts:<evaluator>[:<arg>...]" - TensorStream specific; only answered when
#   the evaluator registered under <evaluator> is this class, by delegating
#   the remaining components to its +fetch_device+ hook (if defined)
#
# Segments in none of these forms (e.g. "job:localhost") are skipped.
# An index beyond the available devices of that type is clamped into range.
#
# @param query [String, Symbol, nil] device URI; nil or :default selects the default device
# @return [Object, nil] the selected device, or nil when nothing matches
def self.query_device(query)
  return default_device if query.nil? || query == :default

  all_devices = query_supported_devices
  query.split("/").each do |segment|
    components = segment.split(":")
    next if components.empty?

    case components[0]
    when "device" # use tensorflow convention
      device_type = components[1]
      # "/device" with no type cannot name a device
      return nil if device_type.nil?

      index = components[2]
    when "cpu", "gpu"
      device_type = components[0]
      index = components[1]
    when "ts" # tensorstream specific
      entry = TensorStream::Evaluator.evaluators[components[1]]
      # unknown evaluator, or a query aimed at a different evaluator class
      return nil unless entry && self == entry[:class]

      evaluator_class = entry[:class]
      return evaluator_class.respond_to?(:fetch_device) ? evaluator_class.fetch_device(components.drop(2)) : nil
    else
      next
    end

    devices = all_devices.select { |d| d.type == device_type.downcase.to_sym }
    return nil if devices.empty?

    # nil.to_i == 0, so a missing index picks the first device; clamp both
    # ends so negative or oversized indices still select a real device
    return devices[index.to_i.clamp(0, devices.size - 1)]
  end

  # No segment named a device (without this, `each` would return its receiver array)
  nil
end