Module: TensorFlow::Utils

Included in:
TensorFlow
Defined in:
lib/tensorflow/utils.rb

Class Method Summary

Class Method Details

.check_status(status) ⇒ Object



4
5
6
7
8
# File 'lib/tensorflow/utils.rb', line 4

# Raises Error carrying TF's message when the given TF_Status holds a
# non-zero code; returns nil when the status is OK.
def check_status(status)
  return if FFI.TF_GetCode(status) == 0

  raise Error, FFI.TF_Message(status)
end

.default_context ⇒ Object



10
11
12
# File 'lib/tensorflow/utils.rb', line 10

# Returns the shared Context, building it on first use and reusing the
# same instance on every later call.
def default_context
  return @default_context if @default_context

  @default_context = Context.new
end

.download_file(url, dest) ⇒ Object



108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# File 'lib/tensorflow/utils.rb', line 108

# Downloads url to dest, printing a dot-based progress indicator.
#
# The body is streamed into a scratch file in the system temp directory and
# only moved to dest once the whole response has been written, so a failed
# download never leaves a partial file at dest.
#
# url  - URL to fetch (http or https).
# dest - final path for the downloaded file.
#
# Raises Error when the server answers with a non-2xx status.
def download_file(url, dest)
  uri = URI(url)

  # Dir.tmpdir is the directory Tempfile.new would have used; asking for it
  # directly avoids creating (and leaking) a throwaway Tempfile.
  temp_path = File.join(Dir.tmpdir, "tensorflow-#{Process.pid}-#{Time.now.to_f}") # TODO better name

  # Net::HTTP automatically adds Accept-Encoding for compression
  # of response bodies and automatically decompresses gzip
  # and deflate responses unless a Range header was sent.
  # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
    request = Net::HTTP::Get.new(uri)

    print("Downloading dataset")
    i = 0
    File.open(temp_path, "wb") do |f|
      http.request(request) do |response|
        # Net::HTTP neither raises on 4xx/5xx nor follows redirects; without
        # this check an error page would be saved as the dataset.
        unless response.is_a?(Net::HTTPSuccess)
          raise Error, "Download failed: #{response.code} #{response.message} (#{url})"
        end

        response.read_body do |chunk|
          f.write(chunk)

          # print progress
          putc "." if i % 50 == 0
          i += 1
        end
      end
      puts # newline
    end
  end

  FileUtils.mv(temp_path, dest)
ensure
  # No-op after a successful move; removes a partial download otherwise.
  FileUtils.rm_f(temp_path) if temp_path
end

.execute(op_name, inputs = [], **attrs) ⇒ Object



14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# File 'lib/tensorflow/utils.rb', line 14

# Runs a single TensorFlow op eagerly in the default context.
#
# op_name - name of the op, passed to TFE_NewOp.
# inputs  - op inputs; any value that does not respond to #to_ptr is first
#           converted with TensorFlow.convert_to_tensor.
# attrs   - op attributes; nil values are skipped. Only the :string, :type
#           and :shape attribute kinds are handled so far.
#
# Returns the op's first output: the raw handle when its dtype is
# :resource, a Tensor otherwise, or nil when the op produced no outputs.
# Raises Error (via check_status) when a TF call reports a failure, and
# RuntimeError for attribute kinds not handled below.
def execute(op_name, inputs = [], **attrs)
  context = default_context
  status = FFI.TF_NewStatus # TODO reuse status between ops?
  op = FFI.TFE_NewOp(context, op_name, status)
  check_status status

  # Out-parameter for TFE_OpGetAttrType; loop-invariant, so allocate once.
  is_list = ::FFI::MemoryPointer.new(:int)

  attrs.each do |attr_name, attr_value|
    next if attr_value.nil?

    attr_name = attr_name.to_s

    type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
    check_status status

    case FFI::AttrType[type]
    when :string
      FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
    # when :int
    # when :float
    # when :bool
    when :type
      FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
    when :shape
      # TODO set value properly
      FFI.TFE_OpSetAttrShape(op, attr_name, nil, 0, status)
      check_status status
    # when :tensor
    # when :placeholder
    # when :func
    else
      raise "Unknown type: #{FFI::AttrType[type]}"
    end
  end

  inputs.each do |input|
    input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr)
    FFI.TFE_OpAddInput(op, input, status)
    check_status status
  end

  # *num_retvals is in/out for TFE_Execute: on entry it is the number of
  # TFE_TensorHandle* slots in retvals, and on return the number written.
  # Previously it was set to retvals.size, which is the buffer size in
  # *bytes* (8), while only one slot was allocated. Ops with several outputs
  # could therefore write past the end of the buffer. Allocate exactly as
  # many slots as we advertise.
  max_retvals = 8
  retvals = ::FFI::MemoryPointer.new(:pointer, max_retvals)
  num_retvals = ::FFI::MemoryPointer.new(:int)
  num_retvals.write_int(max_retvals)
  FFI.TFE_Execute(op, retvals, num_retvals, status)
  check_status status

  # Only the first output is returned.
  # TODO: handles for any further outputs are neither wrapped nor released.
  if num_retvals.read_int > 0
    handle = retvals.read_pointer
    type = FFI.TFE_TensorHandleDataType(handle)

    case FFI::DataType[type]
    when :resource
      handle
    else
      Tensor.new(pointer: handle)
    end
  end
ensure
  FFI.TF_DeleteStatus(status) if status
  FFI.TFE_DeleteOp(op) if op
end

.infer_type(value) ⇒ Object



77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# File 'lib/tensorflow/utils.rb', line 77

# Picks a TensorFlow dtype symbol for a flat list of Ruby values.
#
# value - an enumerable of scalar values.
#
# Returns :string, :bool, :int32 (every integer fits in a signed 32-bit
# range), :int64, :complex128 (any Complex among numerics) or :float
# (other numerics).
# Raises Error when the elements do not share one of those kinds.
def infer_type(value)
  if value.all? { |v| v.is_a?(String) }
    :string
  elsif value.all? { |v| v == true || v == false }
    :bool
  elsif value.all? { |v| v.is_a?(Integer) }
    if value.all? { |v| v >= -2147483648 && v <= 2147483647 }
      :int32
    else
      :int64
    end
  elsif value.all? { |v| v.is_a?(Numeric) }
    # A single complex element forces a complex dtype: a mixed list such as
    # [1, Complex(0, 1)] used to be inferred as :float, which cannot hold it.
    value.any? { |v| v.is_a?(Complex) } ? :complex128 : :float
  else
    raise Error, "Unable to infer data type"
  end
end

.load_dataset(path, url) ⇒ Object



97
98
99
100
101
102
103
104
105
106
# File 'lib/tensorflow/utils.rb', line 97

# Loads a Keras-style .npz dataset, caching it under ~/.keras/datasets.
# The file is downloaded from url only when it is not already cached.
# Raises RuntimeError when HOME is not set.
def load_dataset(path, url)
  # TODO handle this better
  home = ENV["HOME"]
  raise "No HOME" if home.nil?

  cache_dir = "#{home}/.keras/datasets"
  FileUtils.mkdir_p(cache_dir)

  local_path = "#{cache_dir}/#{path}"
  unless File.exist?(local_path)
    Utils.download_file(url, local_path)
  end

  Npy.load_npz(local_path)
end