Module: Torch::Hub

Defined in:
lib/torch/hub.rb

Class Method Summary

Class Method Details

.download_url_to_file(url, dst) ⇒ Object



Source lines 8–40
# File 'lib/torch/hub.rb', line 8

# Downloads +url+ to the local path +dst+, following HTTP redirects.
#
# The response body is streamed into a uniquely named temporary file in the
# same directory as +dst+ and then moved into place. Keeping both paths in
# one directory lets FileUtils.mv do a plain rename rather than a
# cross-filesystem copy, so +dst+ does not appear half-written if the
# process is interrupted.
#
# @param url [String] HTTP or HTTPS URL to fetch
# @param dst [String] destination file path (its directory must exist)
# @param max_redirects [Integer] how many redirects to follow before giving up
# @return [nil]
# @raise [Error] on a response that is neither success nor redirect, on a
#   redirect without a Location header, or after more than +max_redirects+
#   redirects
def download_url_to_file(url, dst, max_redirects: 10)
  uri = URI(url)
  tmp = nil
  location = nil

  puts "Downloading #{url}..."
  begin
    Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
      request = Net::HTTP::Get.new(uri)

      http.request(request) do |response|
        case response
        when Net::HTTPRedirection
          location = response["location"]
          # e.g. 304 Not Modified carries no Location; there is nothing to follow.
          raise Error, "Bad response: #{response.code} without Location header" unless location
        when Net::HTTPSuccess
          # PID + random suffix avoids clashes between concurrent downloads.
          tmp = "#{dst}.#{Process.pid}.#{rand(1 << 32).to_s(16)}.part"
          File.open(tmp, "wb") do |f|
            response.read_body do |chunk|
              f.write(chunk)
            end
          end
        else
          raise Error, "Bad response: #{response.code}"
        end
      end
    end

    unless location
      FileUtils.mv(tmp, dst)
      return nil
    end
  ensure
    # Drop a partially written temp file if anything above raised
    # (a no-op after a successful move, since tmp no longer exists).
    FileUtils.rm_f(tmp) if tmp
  end

  raise Error, "Too many redirects" if max_redirects <= 0

  # Location may be a relative reference, so resolve it against the current URL.
  download_url_to_file((uri + location).to_s, dst, max_redirects: max_redirects - 1)
end

.list(github, force_reload: false) ⇒ Object

Raises: NotImplementedYet (on every call; the method is not implemented yet)



Source lines 4–6
# File 'lib/torch/hub.rb', line 4

# Placeholder that is not supported yet; every call raises.
#
# Both arguments are accepted but ignored, so the signature is already in
# place for callers. NOTE(review): presumably intended to mirror PyTorch's
# +torch.hub.list+ — confirm before implementing.
#
# @param github [String] presumably an "owner/repo" identifier (currently unused)
# @param force_reload [Boolean] currently unused
# @raise [NotImplementedYet] always
def list(github, force_reload: false)
  raise(NotImplementedYet)
end

.load_state_dict_from_url(url, model_dir: nil) ⇒ Object



Source lines 42–59
# File 'lib/torch/hub.rb', line 42

# Fetches the file at +url+ into a local cache directory (only if it is not
# already there) and returns the result of Torch.load on the cached copy.
#
# When +model_dir+ is nil, the cache directory is derived from the
# environment. Empty variables are treated as unset, as the XDG Base
# Directory spec requires for XDG_CACHE_HOME. The lookup order is:
#   $TORCH_HOME/checkpoints
#   $XDG_CACHE_HOME/torch/checkpoints
#   <home>/.cache/torch/checkpoints
#
# @param url [String] URL of the file to fetch
# @param model_dir [String, nil] cache directory; created if missing
# @param file_name [String, nil] name to cache the file under; defaults to
#   the last component of the URL path
# @return [Object] whatever Torch.load returns for the cached file
# @raise [ArgumentError] if +file_name+ is nil and the URL path has no file name
def load_state_dict_from_url(url, model_dir: nil, file_name: nil)
  unless model_dir
    torch_home = ENV["TORCH_HOME"].to_s
    if torch_home.empty?
      cache_home = ENV["XDG_CACHE_HOME"].to_s
      # Dir.home falls back to the passwd entry when HOME is unset.
      cache_home = File.join(Dir.home, ".cache") if cache_home.empty?
      torch_home = File.join(cache_home, "torch")
    end
    model_dir = File.join(torch_home, "checkpoints")
  end

  file_name ||= File.basename(URI(url).path.to_s)
  # A URL like "http://host/" has no file component. Without this check the
  # cache path would be model_dir itself, which exists and would be handed
  # to Torch.load.
  if file_name.empty? || file_name == "/"
    raise ArgumentError, "Cannot determine a file name from #{url}; pass file_name:"
  end

  FileUtils.mkdir_p(model_dir)
  cached_file = File.join(model_dir, file_name)
  # TODO support hash_prefix
  download_url_to_file(url, cached_file) unless File.exist?(cached_file)

  Torch.load(cached_file)
end