Module: TensorFlow::Utils
- Included in:
- TensorFlow
- Defined in:
- lib/tensorflow/utils.rb
Class Method Summary
- .check_status(status) ⇒ Object
- .default_context ⇒ Object
- .download_file(url, dest) ⇒ Object
- .execute(op_name, inputs = [], **attrs) ⇒ Object
- .infer_type(value) ⇒ Object
- .load_dataset(path, url) ⇒ Object
Class Method Details
.check_status(status) ⇒ Object
# File 'lib/tensorflow/utils.rb', line 4

def check_status(status)
  if FFI.TF_GetCode(status) != 0
    raise Error, FFI.TF_Message(status)
  end
end
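check_status converts a non-OK TF_Status into a TensorFlow::Error and is called after every status-producing FFI call in this module. A minimal sketch of that pattern, reusing the same FFI bindings shown in execute below (whether TFE_NewOp reports an unknown op name through the status is an assumption here):

status = FFI.TF_NewStatus
begin
  op = FFI.TFE_NewOp(Utils.default_context, "NoSuchOp", status) # hypothetical op name
  Utils.check_status status # raises TensorFlow::Error if the status code is non-zero
ensure
  FFI.TFE_DeleteOp(op) if op
  FFI.TF_DeleteStatus(status)
end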
.default_context ⇒ Object
# File 'lib/tensorflow/utils.rb', line 10

def default_context
  @default_context ||= Context.new
end
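The context is memoized, so every op executed through this module shares one eager Context. A small sketch:

ctx = Utils.default_context
ctx.equal?(Utils.default_context) # => true, the same Context instance is reused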
.download_file(url, dest) ⇒ Object
# File 'lib/tensorflow/utils.rb', line 108

def download_file(url, dest)
  uri = URI(url)

  temp_dir ||= File.dirname(Tempfile.new("tensorflow"))
  temp_path = "#{temp_dir}/#{Time.now.to_f}" # TODO better name

  # Net::HTTP automatically adds Accept-Encoding for compression
  # of response bodies and automatically decompresses gzip
  # and deflate responses unless a Range header was sent.
  # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
  Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http|
    request = Net::HTTP::Get.new(uri)

    print("Downloading dataset")
    i = 0
    File.open(temp_path, "wb") do |f|
      http.request(request) do |response|
        response.read_body do |chunk|
          f.write(chunk)

          # print progress
          putc "." if i % 50 == 0
          i += 1
        end
      end
      puts # newline
    end
  end

  FileUtils.mv(temp_path, dest)
end
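A hedged usage sketch; the URL and destination path below are placeholders, and the method assumes an HTTPS URL because use_ssl is always true. The destination directory must already exist (load_dataset creates it before calling this):

# Hypothetical URL and path, shown only to illustrate the call shape.
Utils.download_file(
  "https://example.com/data/train.npz",
  "#{ENV["HOME"]}/.keras/datasets/train.npz"
)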
.execute(op_name, inputs = [], **attrs) ⇒ Object
# File 'lib/tensorflow/utils.rb', line 14

def execute(op_name, inputs = [], **attrs)
  context = default_context

  status = FFI.TF_NewStatus # TODO reuse status between ops?

  op = FFI.TFE_NewOp(context, op_name, status)
  check_status status

  attrs.each do |attr_name, attr_value|
    next if attr_value.nil?

    attr_name = attr_name.to_s
    is_list = ::FFI::MemoryPointer.new(:int)
    type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
    check_status status

    case FFI::AttrType[type]
    when :string
      FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
    # when :int
    # when :float
    # when :bool
    when :type
      FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
    when :shape
      # TODO set value properly
      FFI.TFE_OpSetAttrShape(op, attr_name, nil, 0, status)
      check_status status
    # when :tensor
    # when :placeholder
    # when :func
    else
      raise "Unknown type: #{FFI::AttrType[type]}"
    end
  end

  inputs.each do |input|
    input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr)
    FFI.TFE_OpAddInput(op, input, status)
    check_status status
  end

  retvals = ::FFI::MemoryPointer.new(:pointer)
  num_retvals = ::FFI::MemoryPointer.new(:int)
  num_retvals.write_int(retvals.size)
  FFI.TFE_Execute(op, retvals, num_retvals, status)
  check_status status

  if num_retvals.read_int > 0
    handle = retvals.read_pointer
    type = FFI.TFE_TensorHandleDataType(handle)

    case FFI::DataType[type]
    when :resource
      handle
    else
      Tensor.new(pointer: handle)
    end
  end
ensure
  FFI.TF_DeleteStatus(status) if status
  FFI.TFE_DeleteOp(op) if op
end
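execute is the generic eager dispatcher: it builds an op by name, applies the attributes it knows how to set (only :string, :type, and :shape are handled above), converts plain Ruby inputs with TensorFlow.convert_to_tensor, and runs the op. A minimal sketch, assuming the eager "Add" op accepts two scalar tensors:

sum = Utils.execute("Add", [1, 2])
# => a TensorFlow::Tensor wrapping the result handle (a raw handle is
#    returned instead for :resource results)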
.infer_type(value) ⇒ Object
# File 'lib/tensorflow/utils.rb', line 77

def infer_type(value)
  if value.all? { |v| v.is_a?(String) }
    :string
  elsif value.all? { |v| v == true || v == false }
    :bool
  elsif value.all? { |v| v.is_a?(Integer) }
    if value.all? { |v| v >= -2147483648 && v <= 2147483647 }
      :int32
    else
      :int64
    end
  elsif value.all? { |v| v.is_a?(Complex) }
    :complex128
  elsif value.all? { |v| v.is_a?(Numeric) }
    :float
  else
    raise Error, "Unable to infer data type"
  end
end
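infer_type expects a flat array of values (it calls value.all?) and picks the narrowest matching dtype. Some illustrative calls that follow directly from the branches above:

Utils.infer_type(["a", "b"])      # => :string
Utils.infer_type([true, false])   # => :bool
Utils.infer_type([1, 2, 3])       # => :int32
Utils.infer_type([2**40])         # => :int64 (outside the int32 range)
Utils.infer_type([Complex(1, 2)]) # => :complex128
Utils.infer_type([1.5, 2])        # => :float
Utils.infer_type([nil])           # raises TensorFlow::Error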
.load_dataset(path, url) ⇒ Object
# File 'lib/tensorflow/utils.rb', line 97

def load_dataset(path, url)
  # TODO handle this better
  raise "No HOME" unless ENV["HOME"]
  datasets_dir = "#{ENV["HOME"]}/.keras/datasets"
  FileUtils.mkdir_p(datasets_dir)
  path = "#{datasets_dir}/#{path}"
  Utils.download_file(url, path) unless File.exist?(path)
  Npy.load_npz(path)
end
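load_dataset caches the file under ~/.keras/datasets, downloads it on first use, and reads the result with Npy.load_npz. A hedged sketch; the MNIST URL below is the commonly used Google-hosted mirror and is given only as an example:

data = Utils.load_dataset(
  "mnist.npz",
  "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
)
# data is whatever Npy.load_npz returns for the archive, e.g. arrays keyed by
# their names in the .npz file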