Module: Tensorflow::Keras::Utils
- Defined in:
- lib/tensorflow/keras/utils.rb
Class Method Summary
- .add_weight(name: nil, shape: [], initializer: nil, dtype: :float) ⇒ Object
- .get_file(fname, origin, file_hash: nil, cache_subdir: "datasets") ⇒ Object
Class Method Details
.add_weight(name: nil, shape: [], initializer: nil, dtype: :float) ⇒ Object
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
# File 'lib/tensorflow/keras/utils.rb', line 5
#
# Creates a Variable and assigns it an initial value chosen by +initializer+.
#
# @param name [String, nil] forwarded to Variable.new
# @param shape [Array<Integer>] forwarded to Variable.new and to the initializer ops
# @param initializer [String, nil] "zeros" or "glorot_uniform"
# @param dtype [Symbol] forwarded to Variable.new
# @return [Object] the Variable, after its value has been assigned
# @raise [Error] for any other initializer, including the default nil
def add_weight(name: nil, shape: [], initializer: nil, dtype: :float)
  variable = Variable.new(shape: shape, name: name, dtype: dtype)

  initial_value =
    case initializer
    when "zeros"
      Tensorflow.fill(shape, 0.0)
    when "glorot_uniform"
      # Fan computation follows the Keras convention: scalars count as 1/1,
      # vectors use their only dimension for both fans, matrices are
      # (fan_in, fan_out), and higher ranks are treated as kernels laid out
      # (..., input_depth, depth) whose leading dimensions form the
      # receptive field.
      fan_in, fan_out =
        case shape.size
        when 0 then [1, 1]
        when 1 then [shape[0], shape[0]]
        when 2 then [shape[0], shape[1]]
        else
          receptive_field = shape[0...-2].reduce(1, :*)
          [shape[-2] * receptive_field, shape[-1] * receptive_field]
        end

      scale = 1.0 / [1.0, (fan_in + fan_out) / 2.0].max
      limit = ::Math.sqrt(3.0 * scale)
      minval = -limit
      maxval = limit

      # NOTE(review): samples are always requested as :float, regardless of
      # the dtype argument - confirm this is intended.
      rnd = RawOps.random_uniform(shape: shape, dtype: :float)
      # Rescale the samples (presumably drawn from [0, 1) - confirm against
      # RawOps) to the range [minval, maxval).
      Math.add(rnd * (maxval - minval), minval)
    else
      raise Error, "Unknown initializer: #{initializer}"
    end

  variable.value = initial_value
  variable
end
.get_file(fname, origin, file_hash: nil, cache_subdir: "datasets") ⇒ Object
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
# File 'lib/tensorflow/keras/utils.rb', line 32
#
# Returns the path of +fname+ inside $HOME/.keras/<cache_subdir>, downloading
# it from +origin+ first when it is not already there.
#
# @param fname [String] file name inside the cache directory
# @param origin [String] URL to download from (http or https)
# @param file_hash [String, nil] expected hex digest; a 32-character value is
#   checked with MD5, any other length with Digest::SHA2
# @param cache_subdir [String] directory under $HOME/.keras
# @return [String] path of the cached file
# @raise [RuntimeError] if the HOME environment variable is not set
# @raise [Error] if the server does not answer with a 2xx status or the
#   digest of the downloaded data does not match +file_hash+
def get_file(fname, origin, file_hash: nil, cache_subdir: "datasets")
  # destination
  # TODO handle this better
  raise "No HOME" unless ENV["HOME"]

  dest = "#{ENV["HOME"]}/.keras/#{cache_subdir}/#{fname}"
  FileUtils.mkdir_p(File.dirname(dest))

  return dest if File.exist?(dest)

  # Dir.tmpdir is the directory Tempfile itself uses (tempfile loads tmpdir),
  # so there is no need to create a throwaway Tempfile to discover it.
  temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name

  digest = file_hash&.size == 32 ? Digest::MD5.new : Digest::SHA2.new

  uri = URI(origin)

  begin
    # Net::HTTP automatically adds Accept-Encoding for compression
    # of response bodies and automatically decompresses gzip
    # and deflate responses unless a Range header was sent.
    # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
    #
    # TLS is enabled only for https origins so plain http URLs work too.
    Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
      request = Net::HTTP::Get.new(uri)

      puts "Downloading data from #{origin}"
      i = 0
      File.open(temp_path, "wb") do |f|
        http.request(request) do |response|
          # Never cache an error page or a redirect body as the dataset.
          unless response.is_a?(Net::HTTPSuccess)
            raise Error, "Bad response: #{response.code} #{response.message}"
          end

          response.read_body do |chunk|
            f.write(chunk)
            digest.update(chunk)

            # print progress
            putc "." if i % 50 == 0
            i += 1
          end
        end
        puts # newline
      end
    end

    if file_hash && digest.hexdigest != file_hash
      raise Error, "Bad hash"
    end

    FileUtils.mv(temp_path, dest)
  ensure
    # Remove a partial or rejected download; no-op after a successful move.
    FileUtils.rm_f(temp_path)
  end

  dest
end