Module: Rager::Utils::Replicate

Extended by:
T::Sig
Defined in:
lib/rager/utils/replicate.rb

Class Method Summary

Class Method Details

.download_prediction(prediction_url, key = nil, http_adapter: nil) ⇒ Object



# File 'lib/rager/utils/replicate.rb', lines 20-25

# Downloads the output of a Replicate prediction once it is available.
#
# @param prediction_url [String] URL of the prediction resource to inspect
# @param key [String, nil] output key to prefer when the prediction output is a Hash
# @param http_adapter [Object, nil] adapter used for the status request
#   (get_download_url falls back to Rager.config.http_adapter when nil)
# @return [Object, nil] the value returned by Rager::Utils::Http.download, or nil
#   when get_download_url reports no downloadable output yet
# @raise [StandardError] any error raised by get_download_url (credentials, HTTP, parse)
def self.download_prediction(prediction_url, key = nil, http_adapter: nil)
  url = get_download_url(prediction_url, key, http_adapter: http_adapter)
  Rager::Utils::Http.download(url) unless url.nil?
end

.download_prediction_to_file(prediction_url, key = nil, path:, http_adapter: nil) ⇒ Object



# File 'lib/rager/utils/replicate.rb', lines 35-40

# Downloads the output of a Replicate prediction straight to a file once it is available.
#
# @param prediction_url [String] URL of the prediction resource to inspect
# @param key [String, nil] output key to prefer when the prediction output is a Hash
# @param path [String] destination passed to Rager::Utils::Http.download_to_file
# @param http_adapter [Object, nil] adapter used for the status request
#   (get_download_url falls back to Rager.config.http_adapter when nil)
# @return [Object, nil] the value returned by Rager::Utils::Http.download_to_file, or nil
#   when get_download_url reports no downloadable output yet
# @raise [StandardError] any error raised by get_download_url (credentials, HTTP, parse)
def self.download_prediction_to_file(prediction_url, key = nil, path:, http_adapter: nil)
  source_url = get_download_url(prediction_url, key, http_adapter: http_adapter)
  Rager::Utils::Http.download_to_file(source_url, path) unless source_url.nil?
end

.get_download_url(prediction_url, key = nil, http_adapter: nil) ⇒ Object



# File 'lib/rager/utils/replicate.rb', lines 43-84

# Looks up a Replicate prediction and returns the URL of its output.
#
# @param prediction_url [String] URL of the prediction resource, requested with
#   a bearer token taken from REPLICATE_API_KEY
# @param key [String, nil] output key to prefer when the prediction output is a Hash;
#   if absent or falsy there, the first non-nil Hash value is used instead
# @param http_adapter [Object, nil] adapter for the request; defaults to
#   Rager.config.http_adapter
# @return [Object, nil] the selected output entry (expected to be a URL string), or nil
#   while the prediction is starting/processing, has an unrecognized status, or has
#   no usable output
# @raise [Rager::Errors::CredentialsError] if REPLICATE_API_KEY is unset or blank
# @raise [Rager::Errors::HttpError] on a non-success HTTP response, or when the
#   prediction failed or was canceled (reported with status 422)
# @raise [Rager::Errors::ParseError] if the response body is not valid JSON
def self.get_download_url(prediction_url, key = nil, http_adapter: nil)
  api_key = ENV["REPLICATE_API_KEY"]
  # A blank value is as unusable as a missing one; fail before making a request.
  if api_key.nil? || api_key.empty?
    raise Rager::Errors::CredentialsError.new("Replicate", env_var: ["REPLICATE_API_KEY"])
  end

  adapter = http_adapter || Rager.config.http_adapter
  response = adapter.make_request(
    Rager::Http::Request.new(
      url: prediction_url,
      headers: {"Authorization" => "Bearer #{api_key}", "Content-Type" => "application/json"}
    )
  )

  unless response.success?
    raise Rager::Errors::HttpError.new(adapter, prediction_url, response.status, body: response.body&.to_s)
  end

  data = JSON.parse(T.cast(T.must(response.body), String))
  status = data["status"]

  # Still running: nil tells callers to poll again later.
  return nil if ["starting", "processing"].include?(status)

  # "canceled" is terminal just like "failed". Returning nil for it would make
  # poll_prediction keep polling a prediction that can never complete.
  if ["failed", "canceled"].include?(status)
    error_msg = data["error"] || "Prediction #{status}"
    raise Rager::Errors::HttpError.new(adapter, prediction_url, 422, body: error_msg)
  end

  return nil unless status == "succeeded"

  output = data["output"]
  download_url =
    case output
    when Hash
      (key && output[key]) || output.values.compact.first
    when Array
      # Skip nil entries, and let an empty list yield nil rather than the
      # literal string "[]" that output.to_s would produce.
      output.compact.first
    when nil
      nil
    else
      output.to_s
    end

  return nil if download_url.nil? || download_url.empty?

  download_url
rescue JSON::ParserError => e
  raise Rager::Errors::ParseError.new("Failed to parse prediction response", details: e.message)
end

.poll_prediction(prediction_url, key: nil, path: nil, max_attempts: 30, sleep_interval: 10, http_adapter: nil) ⇒ Object



# File 'lib/rager/utils/replicate.rb', lines 96-113

# Repeatedly checks a Replicate prediction until its output can be downloaded.
#
# Each attempt delegates to download_prediction_to_file (when +path+ is given) or
# download_prediction; a nil result is treated as "not ready yet".
# NOTE(review): if Rager::Utils::Http.download_to_file can return nil on success,
# the +path+ branch would keep polling until timeout; confirm its return value.
#
# @param prediction_url [String] URL of the prediction resource
# @param key [String, nil] output key to prefer when the prediction output is a Hash
# @param path [String, nil] when set, the output is written to this path
# @param max_attempts [Integer] number of checks before giving up
# @param sleep_interval [Numeric] seconds to wait between consecutive checks
# @param http_adapter [Object, nil] adapter forwarded to the download helpers
# @return [Object] the first non-nil result returned by the download helper
# @raise [Rager::Errors::TimeoutError] if no attempt produced a result
def self.poll_prediction(prediction_url, key: nil, path: nil, max_attempts: 30, sleep_interval: 10, http_adapter: nil)
  last_attempt = max_attempts - 1

  max_attempts.times do |attempt|
    result = if path
      download_prediction_to_file(prediction_url, key, path: path, http_adapter: http_adapter)
    else
      download_prediction(prediction_url, key, http_adapter: http_adapter)
    end
    return result unless result.nil?

    # Only wait when another attempt follows; sleeping after the final one just
    # delays the timeout error.
    Rager::Utils::Runtime.sleep(sleep_interval) if attempt < last_attempt
  end

  raise Rager::Errors::TimeoutError.new(
    "Replicate prediction polling",
    # Reports the configured polling budget, unchanged from earlier versions.
    timeout_seconds: max_attempts * sleep_interval,
    attempts: max_attempts,
    details: "Prediction at #{prediction_url} did not complete within #{max_attempts} attempts"
  )
end