Class: ReplicateClient::Prediction

Inherits:
Object
  • Object
show all
Defined in:
lib/replicate-client/prediction.rb

Defined Under Namespace

Modules: Status

Constant Summary collapse

INDEX_PATH =
"/predictions"

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(attributes) ⇒ Prediction

Returns a new instance of Prediction.



209
210
211
# File 'lib/replicate-client/prediction.rb', line 209

# Build a prediction from an API attributes hash; all assignment is
# delegated to reset_attributes.
def initialize(attributes) = reset_attributes(attributes)

Instance Attribute Details

#completed_at ⇒ Time

The date the prediction was completed.

Returns:

  • (Time)


192
193
194
# File 'lib/replicate-client/prediction.rb', line 192

# @return [Time] the date the prediction was completed
def completed_at = @completed_at

#created_at ⇒ Time

The date the prediction was created.

Returns:

  • (Time)


177
178
179
# File 'lib/replicate-client/prediction.rb', line 177

# @return [Time] the date the prediction was created
def created_at = @created_at

#data_removed ⇒ Time

The date the prediction was removed.

Returns:

  • (Time)


182
183
184
# File 'lib/replicate-client/prediction.rb', line 182

# @return [Time] the date the prediction was removed
def data_removed = @data_removed

#error ⇒ String

The error message for the prediction.

Returns:

  • (String)


167
168
169
# File 'lib/replicate-client/prediction.rb', line 167

# @return [String] the error message for the prediction
def error = @error

#id ⇒ String

The ID of the prediction.

Returns:

  • (String)


142
143
144
# File 'lib/replicate-client/prediction.rb', line 142

# @return [String] the ID of the prediction
def id = @id

#input ⇒ Hash

The input data for the prediction.

Returns:

  • (Hash)


157
158
159
# File 'lib/replicate-client/prediction.rb', line 157

# @return [Hash] the input data for the prediction
def input = @input

#logs ⇒ String

The logs for the prediction.

Returns:

  • (String)


207
208
209
# File 'lib/replicate-client/prediction.rb', line 207

# @return [String] the logs for the prediction
def logs = @logs

#metrics ⇒ Hash

The metrics for the prediction.

Returns:

  • (Hash)


197
198
199
# File 'lib/replicate-client/prediction.rb', line 197

# @return [Hash] the metrics for the prediction
def metrics = @metrics

#model_name ⇒ String

The model used for the prediction.

Returns:

  • (String)


152
153
154
# File 'lib/replicate-client/prediction.rb', line 152

# @return [String] the model used for the prediction
def model_name = @model_name

#output ⇒ Hash

The output data for the prediction.

Returns:

  • (Hash)


162
163
164
# File 'lib/replicate-client/prediction.rb', line 162

# @return [Hash] the output data for the prediction
def output = @output

#started_at ⇒ Time

The date the prediction was started.

Returns:

  • (Time)


187
188
189
# File 'lib/replicate-client/prediction.rb', line 187

# @return [Time] the date the prediction was started
def started_at = @started_at

#status ⇒ String

The status of the prediction.

Returns:

  • (String)


172
173
174
# File 'lib/replicate-client/prediction.rb', line 172

# @return [String] the status of the prediction
def status = @status

#urls ⇒ Hash

The URLs for the prediction.

Returns:

  • (Hash)


202
203
204
# File 'lib/replicate-client/prediction.rb', line 202

# @return [Hash] the URLs for the prediction
def urls = @urls

#version_id ⇒ String

The version of the model used for the prediction.

Returns:

  • (String)


147
148
149
# File 'lib/replicate-client/prediction.rb', line 147

# @return [String] the version of the model used for the prediction
def version_id = @version_id

Class Method Details

.build_path(id) ⇒ String

Build the path for the prediction.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (String)


125
126
127
# File 'lib/replicate-client/prediction.rb', line 125

# Build the API path for a single prediction, e.g. "/predictions/<id>".
#
# @param id [String] the ID of the prediction
# @return [String]
def build_path(id) = "#{INDEX_PATH}/#{id}"

.cancel!(id) ⇒ void

This method returns an undefined value.

Cancel a prediction.

Parameters:

  • id (String)

    The ID of the prediction.



134
135
136
# File 'lib/replicate-client/prediction.rb', line 134

# Cancel a prediction by POSTing to its "/cancel" sub-path.
#
# @param id [String] the ID of the prediction
# @return [void] (the client's response is passed through unchanged)
def cancel!(id)
  http = ReplicateClient.client
  http.post("#{build_path(id)}/cancel")
end

.create!(version:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false) ⇒ ReplicateClient::Prediction

Create a new prediction for a version.

Parameters:

  • version (String, ReplicateClient::Model::Version)

    The version of the model to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

  • sync (Boolean) (defaults to: false)

    Whether to wait for the prediction to complete.

Returns:

  • (ReplicateClient::Prediction)



25
26
27
28
29
30
31
32
33
34
35
36
37
38
# File 'lib/replicate-client/prediction.rb', line 25

# Create a new prediction for a specific model version.
#
# @param version [String, ReplicateClient::Model::Version] a version ID, or a
#   Version object whose #id is used
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] falls back to the configured webhook_url
# @param webhook_events_filter [Array<Symbol>, nil] stringified before sending
# @param sync [Boolean] when true, sends "Prefer: wait"
# @return [ReplicateClient::Prediction]
def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
  resolved_version = version.is_a?(Model::Version) ? version.id : version

  payload = {
    version: resolved_version,
    input: input,
    webhook: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter&.map(&:to_s)
  }

  request_headers = {}
  request_headers["Prefer"] = "wait" if sync

  new(ReplicateClient.client.post(INDEX_PATH, payload, headers: request_headers))
end

.create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false) ⇒ ReplicateClient::Prediction

Create a new prediction for a deployment.

Parameters:

  • deployment (String, ReplicateClient::Deployment)

    The deployment to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

  • sync (Boolean) (defaults to: false)

    Whether to wait for the prediction to complete.

Returns:

  • (ReplicateClient::Prediction)



49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# File 'lib/replicate-client/prediction.rb', line 49

# Create a new prediction against a deployment.
#
# @param deployment [String, ReplicateClient::Deployment] a deployment
#   identifier (appended to Deployment::INDEX_PATH) or a Deployment object
#   whose #path is used
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] falls back to the configured webhook_url
# @param webhook_events_filter [Array<Symbol>, nil] stringified before sending
# @param sync [Boolean] when true, sends "Prefer: wait"
# @return [ReplicateClient::Prediction]
def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
  payload = {
    input: input,
    webhook: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter&.map(&:to_s)
  }

  request_headers = {}
  request_headers["Prefer"] = "wait" if sync

  base_path =
    if deployment.is_a?(Deployment)
      deployment.path
    else
      "#{Deployment::INDEX_PATH}/#{deployment}"
    end

  new(ReplicateClient.client.post("#{base_path}#{INDEX_PATH}", payload, headers: request_headers))
end

.create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false) ⇒ ReplicateClient::Prediction

Create a new prediction for a model.

Parameters:

  • model (String, ReplicateClient::Model)

    The model to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

  • sync (Boolean) (defaults to: false)

    Whether to wait for the prediction to complete.

Returns:

  • (ReplicateClient::Prediction)



74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# File 'lib/replicate-client/prediction.rb', line 74

# Create a new prediction against an official model (no explicit version).
#
# @param model [String, ReplicateClient::Model] a model name parsed via
#   Model.parse_model_name, or a Model object whose #path is used
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] falls back to the configured webhook_url
# @param webhook_events_filter [Array<Symbol>, nil] stringified before sending
# @param sync [Boolean] when true, sends "Prefer: wait"
# @return [ReplicateClient::Prediction]
def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
  model_path =
    if model.is_a?(Model)
      model.path
    else
      Model.build_path(**Model.parse_model_name(model))
    end

  payload = {
    input:,
    webhook: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter&.map(&:to_s)
  }

  request_headers = {}
  request_headers["Prefer"] = "wait" if sync

  new(ReplicateClient.client.post("#{model_path}#{INDEX_PATH}", payload, headers: request_headers))
end

.find(id) ⇒ ReplicateClient::Prediction

Find a prediction.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (ReplicateClient::Prediction)



95
96
97
98
# File 'lib/replicate-client/prediction.rb', line 95

# Fetch a prediction by ID and wrap the returned attributes.
#
# @param id [String] the ID of the prediction
# @return [ReplicateClient::Prediction]
def find(id) = new(ReplicateClient.client.get(build_path(id)))

.find_by(id:) ⇒ ReplicateClient::Prediction

Find a prediction, returning nil if it is not found.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (ReplicateClient::Prediction, nil)



114
115
116
117
118
# File 'lib/replicate-client/prediction.rb', line 114

# Find a prediction, returning nil instead of raising when it does not exist.
#
# @param id [String] the ID of the prediction
# @return [ReplicateClient::Prediction, nil]
def find_by(id:)
  begin
    find_by!(id:)
  rescue ReplicateClient::NotFoundError
    nil
  end
end

.find_by!(id:) ⇒ ReplicateClient::Prediction

Find a prediction, raising an error (rather than returning nil, as .find_by does) if it is not found.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (ReplicateClient::Prediction)



105
106
107
# File 'lib/replicate-client/prediction.rb', line 105

# Keyword-argument alias for .find; errors from .find propagate.
#
# @param id [String] the ID of the prediction
# @return [ReplicateClient::Prediction]
def find_by!(id:) = find(id)

Instance Method Details

#cancel! ⇒ void

This method returns an undefined value.

Cancel the prediction.



238
239
240
# File 'lib/replicate-client/prediction.rb', line 238

# Cancel this prediction via the class-level Prediction.cancel!.
#
# @return [void]
def cancel! = Prediction.cancel!(id)

#canceled? ⇒ Boolean

Check if the prediction is canceled.

Returns:

  • (Boolean)


259
260
261
# File 'lib/replicate-client/prediction.rb', line 259

# @return [Boolean] whether #status equals Status::CANCELED
def canceled? = status == Status::CANCELED

#failed? ⇒ Boolean

Check if the prediction has failed.

Returns:

  • (Boolean)


252
253
254
# File 'lib/replicate-client/prediction.rb', line 252

# @return [Boolean] whether #status equals Status::FAILED
def failed? = status == Status::FAILED

#model ⇒ ReplicateClient::Model

The model used for the prediction.



224
225
226
# File 'lib/replicate-client/prediction.rb', line 224

# The model used for the prediction, looked up once via Model.find and
# then cached on the instance.
#
# @return [ReplicateClient::Model]
def model
  return @model if @model

  @model = Model.find(@model_name, version_id: @version_id)
end

#processing? ⇒ Boolean

Check if the prediction is processing.

Returns:

  • (Boolean)


273
274
275
# File 'lib/replicate-client/prediction.rb', line 273

# @return [Boolean] whether #status equals Status::PROCESSING
def processing? = status == Status::PROCESSING

#reload! ⇒ ReplicateClient::Prediction

Reload the prediction's attributes from the API.



216
217
218
219
# File 'lib/replicate-client/prediction.rb', line 216

# Re-fetch this prediction from the API and overwrite its attributes in place.
#
# Returns the receiver explicitly. Otherwise the method would return whatever
# reset_attributes happens to return, which need not be the Prediction
# promised by the documented return type.
#
# @return [ReplicateClient::Prediction] self
def reload!
  attributes = ReplicateClient.client.get(Prediction.build_path(@id))
  reset_attributes(attributes)
  self
end

#starting? ⇒ Boolean

Check if the prediction is starting.

Returns:

  • (Boolean)


266
267
268
# File 'lib/replicate-client/prediction.rb', line 266

# @return [Boolean] whether #status equals Status::STARTING
def starting? = status == Status::STARTING

#succeeded? ⇒ Boolean

Check if the prediction has succeeded.

Returns:

  • (Boolean)


245
246
247
# File 'lib/replicate-client/prediction.rb', line 245

# @return [Boolean] whether #status equals Status::SUCCEEDED
def succeeded? = status == Status::SUCCEEDED

#version ⇒ ReplicateClient::Model::Version

The version of the model used for the prediction.



231
232
233
# File 'lib/replicate-client/prediction.rb', line 231

# The model version used for the prediction, taken from #model and cached
# on the instance after the first call.
#
# @return [ReplicateClient::Model::Version]
def version
  return @version if @version

  @version = model.version
end