Class: ReplicateClient::Prediction

Inherits:
Object
  • Object
show all
Defined in:
lib/replicate-client/prediction.rb

Defined Under Namespace

Modules: Status

Constant Summary collapse

INDEX_PATH =
"/predictions"

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(attributes, client: ReplicateClient.client) ⇒ Prediction

Returns a new instance of Prediction.



224
225
226
227
# File 'lib/replicate-client/prediction.rb', line 224

# Build a prediction from a set of attributes.
#
# @param attributes [Object] attribute data handed straight to
#   reset_attributes (presumably the API response hash — TODO confirm)
# @param client [ReplicateClient::Client] client stored on the instance;
#   defaults to ReplicateClient.client
def initialize(attributes, client: ReplicateClient.client)
  # The client is stored before the attributes are applied.
  @client = client
  reset_attributes(attributes)
end

Instance Attribute Details

#client ⇒ ReplicateClient::Client

The client for the prediction.



222
223
224
# File 'lib/replicate-client/prediction.rb', line 222

# Reader for the client stored on this prediction.
#
# @return [ReplicateClient::Client] the current value of @client
def client = @client

#completed_at ⇒ Time

The date the prediction was completed.

Returns:

  • (Time)


202
203
204
# File 'lib/replicate-client/prediction.rb', line 202

# Reader for the completion timestamp.
#
# @return [Time] the current value of @completed_at
def completed_at = @completed_at

#created_at ⇒ Time

The date the prediction was created.

Returns:

  • (Time)


187
188
189
# File 'lib/replicate-client/prediction.rb', line 187

# Reader for the creation timestamp.
#
# @return [Time] the current value of @created_at
def created_at = @created_at

#data_removed ⇒ Time

The date the prediction was removed.

Returns:

  • (Time)


192
193
194
# File 'lib/replicate-client/prediction.rb', line 192

# Reader for the data-removal attribute.
#
# @return [Object] the current value of @data_removed (the generated docs
#   list this as Time — NOTE(review): confirm the actual stored type)
def data_removed = @data_removed

#error ⇒ String

The error message for the prediction.

Returns:

  • (String)


177
178
179
# File 'lib/replicate-client/prediction.rb', line 177

# Reader for the prediction's error message.
#
# @return [String] the current value of @error
def error = @error

#id ⇒ String

The ID of the prediction.

Returns:

  • (String)


152
153
154
# File 'lib/replicate-client/prediction.rb', line 152

# Reader for the prediction ID.
#
# @return [String] the current value of @id
def id = @id

#input ⇒ Hash

The input data for the prediction.

Returns:

  • (Hash)


167
168
169
# File 'lib/replicate-client/prediction.rb', line 167

# Reader for the prediction's input data.
#
# @return [Hash] the current value of @input
def input = @input

#logs ⇒ String

The logs for the prediction.

Returns:

  • (String)


217
218
219
# File 'lib/replicate-client/prediction.rb', line 217

# Reader for the prediction's logs.
#
# @return [String] the current value of @logs
def logs = @logs

#metrics ⇒ Hash

The metrics for the prediction.

Returns:

  • (Hash)


207
208
209
# File 'lib/replicate-client/prediction.rb', line 207

# Reader for the prediction's metrics.
#
# @return [Hash] the current value of @metrics
def metrics = @metrics

#model_name ⇒ String

The model used for the prediction.

Returns:

  • (String)


162
163
164
# File 'lib/replicate-client/prediction.rb', line 162

# Reader for the name of the model used for the prediction.
#
# @return [String] the current value of @model_name
def model_name = @model_name

#output ⇒ Hash

The output data for the prediction.

Returns:

  • (Hash)


172
173
174
# File 'lib/replicate-client/prediction.rb', line 172

# Reader for the prediction's output data.
#
# @return [Hash] the current value of @output
def output = @output

#started_at ⇒ Time

The date the prediction was started.

Returns:

  • (Time)


197
198
199
# File 'lib/replicate-client/prediction.rb', line 197

# Reader for the start timestamp.
#
# @return [Time] the current value of @started_at
def started_at = @started_at

#status ⇒ String

The status of the prediction.

Returns:

  • (String)


182
183
184
# File 'lib/replicate-client/prediction.rb', line 182

# Reader for the prediction's status.
#
# @return [String] the current value of @status
def status = @status

#urls ⇒ Hash

The URLs for the prediction.

Returns:

  • (Hash)


212
213
214
# File 'lib/replicate-client/prediction.rb', line 212

# Reader for the prediction's URLs.
#
# @return [Hash] the current value of @urls
def urls = @urls

#version_id ⇒ String

The version of the model used for the prediction.

Returns:

  • (String)


157
158
159
# File 'lib/replicate-client/prediction.rb', line 157

# Reader for the ID of the model version used for the prediction.
#
# @return [String] the current value of @version_id
def version_id = @version_id

Class Method Details

.build_path(id) ⇒ String

Build the path for the prediction.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (String)


134
135
136
# File 'lib/replicate-client/prediction.rb', line 134

# Build the request path for a single prediction by appending the ID to
# INDEX_PATH.
#
# @param id [String] the ID of the prediction
# @return [String] INDEX_PATH, a slash, and the interpolated ID
def build_path(id) = "#{INDEX_PATH}/#{id}"

.cancel!(id, client: ReplicateClient.client) ⇒ void

This method returns an undefined value.

Cancel a prediction.

Parameters:

  • id (String)

    The ID of the prediction.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for the prediction.



144
145
146
# File 'lib/replicate-client/prediction.rb', line 144

# Cancel a prediction by POSTing to its "/cancel" sub-path.
#
# @param id [String] the ID of the prediction
# @param client [ReplicateClient::Client] the client used for the request
# @return [void] documented as void; the value of client.post is passed through
def cancel!(id, client: ReplicateClient.client)
  cancel_path = "#{build_path(id)}/cancel"
  client.post(cancel_path)
end

.create!(version:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false, client: ReplicateClient.client) ⇒ ReplicateClient::Prediction

Create a new prediction for a version.

Parameters:

  • version (String, ReplicateClient::Version)

    The version of the model to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

  • sync (Boolean) (defaults to: false)

    Whether to wait for the prediction to complete.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for the prediction.

Returns:

  • (ReplicateClient::Prediction)



26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# File 'lib/replicate-client/prediction.rb', line 26

# Create a new prediction for a model version.
#
# POSTs the request body to INDEX_PATH and wraps the response in a new
# instance bound to the same client.
#
# @param version [String, Model::Version] a version ID, or a version object
#   whose #id is used
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] webhook URL; falls back to
#   client.configuration.webhook_url when not given
# @param webhook_events_filter [Array<Symbol>, nil] events to send to the
#   webhook; each entry is converted with #to_s
# @param sync [Boolean] when truthy, a "Prefer: wait" header is sent
# @param client [ReplicateClient::Client] the client used for the request
# @return [ReplicateClient::Prediction]
def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false,
            client: ReplicateClient.client)
  # Accept either a version object or a raw version ID.
  version_id = version.is_a?(Model::Version) ? version.id : version
  webhook = webhook_url || client.configuration.webhook_url
  event_names = webhook_events_filter&.map(&:to_s)

  payload = {
    version: version_id,
    input: input,
    webhook: webhook,
    webhook_events_filter: event_names
  }

  request_headers = {}
  request_headers["Prefer"] = "wait" if sync

  response = client.post(INDEX_PATH, payload, headers: request_headers)
  new(response, client: client)
end

.create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false, client: ReplicateClient.client) ⇒ ReplicateClient::Prediction

Create a new prediction for a deployment.

Parameters:

  • deployment (String, ReplicateClient::Deployment)

    The deployment to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

  • sync (Boolean) (defaults to: false)

    Whether to wait for the prediction to complete.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for the prediction.

Returns:

  • (ReplicateClient::Prediction)



52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# File 'lib/replicate-client/prediction.rb', line 52

# Create a new prediction for a deployment.
#
# POSTs to the deployment's path followed by INDEX_PATH and wraps the
# response in a new instance bound to the same client.
#
# @param deployment [String, Deployment] a Deployment (its #path is used) or
#   a value appended to Deployment::INDEX_PATH
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] webhook URL; falls back to
#   client.configuration.webhook_url when not given
# @param webhook_events_filter [Array<Symbol>, nil] events to send to the
#   webhook; each entry is converted with #to_s
# @param sync [Boolean] when truthy, a "Prefer: wait" header is sent
# @param client [ReplicateClient::Client] the client used for the request
# @return [ReplicateClient::Prediction]
def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false,
                           client: ReplicateClient.client)
  webhook = webhook_url || client.configuration.webhook_url
  event_names = webhook_events_filter&.map(&:to_s)
  payload = { input: input, webhook: webhook, webhook_events_filter: event_names }

  request_headers = {}
  request_headers["Prefer"] = "wait" if sync

  # A Deployment object knows its own path; anything else is treated as an identifier.
  base_path =
    if deployment.is_a?(Deployment)
      deployment.path
    else
      "#{Deployment::INDEX_PATH}/#{deployment}"
    end

  response = client.post("#{base_path}#{INDEX_PATH}", payload, headers: request_headers)
  new(response, client: client)
end

.create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false, client: ReplicateClient.client) ⇒ ReplicateClient::Prediction

Create a new prediction for a model.

Parameters:

  • model (String, ReplicateClient::Model)

    The model to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

  • sync (Boolean) (defaults to: false)

    Whether to wait for the prediction to complete.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for the prediction.

Returns:

  • (ReplicateClient::Prediction)



79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# File 'lib/replicate-client/prediction.rb', line 79

# Create a new prediction for a model.
#
# POSTs to the model's path followed by INDEX_PATH and wraps the response in
# a new instance bound to the same client.
#
# @param model [String, Model] a Model (its #path is used) or a model name
#   that is parsed by Model.parse_model_name and passed to Model.build_path
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] webhook URL; falls back to
#   client.configuration.webhook_url when not given
# @param webhook_events_filter [Array<Symbol>, nil] events to send to the
#   webhook; each entry is converted with #to_s
# @param sync [Boolean] when truthy, a "Prefer: wait" header is sent
# @param client [ReplicateClient::Client] the client used for the request
# @return [ReplicateClient::Prediction]
def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false,
                               client: ReplicateClient.client)
  base_path =
    if model.is_a?(Model)
      model.path
    else
      name_parts = Model.parse_model_name(model)
      Model.build_path(**name_parts)
    end

  webhook = webhook_url || client.configuration.webhook_url
  event_names = webhook_events_filter&.map(&:to_s)
  payload = { input: input, webhook: webhook, webhook_events_filter: event_names }

  request_headers = {}
  request_headers["Prefer"] = "wait" if sync

  response = client.post("#{base_path}#{INDEX_PATH}", payload, headers: request_headers)
  new(response, client: client)
end

.find(id, client: ReplicateClient.client) ⇒ ReplicateClient::Prediction

Find a prediction.

Parameters:

  • id (String)

    The ID of the prediction.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for the prediction.

Returns:

  • (ReplicateClient::Prediction)



102
103
104
105
# File 'lib/replicate-client/prediction.rb', line 102

# Fetch a prediction by ID and wrap the response in a new instance.
#
# @param id [String] the ID of the prediction
# @param client [ReplicateClient::Client] the client used for the request
# @return [ReplicateClient::Prediction]
def find(id, client: ReplicateClient.client)
  new(client.get(build_path(id)), client: client)
end

.find_by(id:, client: ReplicateClient.client) ⇒ ReplicateClient::Prediction

Find a prediction, returning nil when the lookup raises ReplicateClient::NotFoundError.

Parameters:

  • id (String)

    The ID of the prediction.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for the prediction.

Returns:

  • (ReplicateClient::Prediction, nil)



123
124
125
126
127
# File 'lib/replicate-client/prediction.rb', line 123

# Look up a prediction by ID, returning nil instead of raising when it
# cannot be found.
#
# @param id [String] the ID of the prediction
# @param client [ReplicateClient::Client] the client used for the request
# @return [ReplicateClient::Prediction, nil] nil when find_by! raises
#   ReplicateClient::NotFoundError; any other error propagates
def find_by(id:, client: ReplicateClient.client)
  find_by!(id:, client:)
rescue ReplicateClient::NotFoundError
  # Only "not found" is translated to nil.
  nil
end

.find_by!(id:, client: ReplicateClient.client) ⇒ ReplicateClient::Prediction

Find a prediction.

Parameters:

  • id (String)

    The ID of the prediction.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for the prediction.

Returns:

  • (ReplicateClient::Prediction)



113
114
115
# File 'lib/replicate-client/prediction.rb', line 113

# Keyword-argument form of find: look up a prediction by ID.
#
# @param id [String] the ID of the prediction
# @param client [ReplicateClient::Client] the client used for the request
# @return [ReplicateClient::Prediction]
def find_by!(id:, client: ReplicateClient.client) = find(id, client:)

Instance Method Details

#cancel! ⇒ void

This method returns an undefined value.

Cancel the prediction.



254
255
256
# File 'lib/replicate-client/prediction.rb', line 254

# Cancel this prediction by delegating to Prediction.cancel! with this
# instance's ID and client.
#
# @return [void] documented as void; the value of Prediction.cancel! is passed through
def cancel! = Prediction.cancel!(id, client: @client)

#canceled? ⇒ Boolean

Check if the prediction is canceled.

Returns:

  • (Boolean)


275
276
277
# File 'lib/replicate-client/prediction.rb', line 275

# Whether the prediction's status equals Status::CANCELED.
#
# @return [Boolean]
def canceled? = status == Status::CANCELED

#failed? ⇒ Boolean

Check if the prediction has failed.

Returns:

  • (Boolean)


268
269
270
# File 'lib/replicate-client/prediction.rb', line 268

# Whether the prediction's status equals Status::FAILED.
#
# @return [Boolean]
def failed? = status == Status::FAILED

#model ⇒ ReplicateClient::Model

The model used for the prediction.



240
241
242
# File 'lib/replicate-client/prediction.rb', line 240

# The model used for the prediction, looked up once and memoized in @model.
#
# NOTE(review): the lookup does not pass @client, so it presumably uses
# Model.find's default client — confirm this is intended.
#
# @return [ReplicateClient::Model]
def model
  return @model if @model

  @model = Model.find(@model_name, version_id: @version_id)
end

#processing? ⇒ Boolean

Check if the prediction is processing.

Returns:

  • (Boolean)


289
290
291
# File 'lib/replicate-client/prediction.rb', line 289

# Whether the prediction's status equals Status::PROCESSING.
#
# @return [Boolean]
def processing? = status == Status::PROCESSING

#reload! ⇒ ReplicateClient::Prediction

Reload the prediction.



232
233
234
235
# File 'lib/replicate-client/prediction.rb', line 232

# Reload the prediction: fetch its current attributes through the stored
# client and apply them with reset_attributes.
#
# @return [ReplicateClient::Prediction] self, so the documented return type
#   holds regardless of what reset_attributes returns
def reload!
  attributes = @client.get(Prediction.build_path(@id))
  reset_attributes(attributes)
  self
end

#starting? ⇒ Boolean

Check if the prediction is starting.

Returns:

  • (Boolean)


282
283
284
# File 'lib/replicate-client/prediction.rb', line 282

# Whether the prediction's status equals Status::STARTING.
#
# @return [Boolean]
def starting? = status == Status::STARTING

#succeeded? ⇒ Boolean

Check if the prediction has succeeded.

Returns:

  • (Boolean)


261
262
263
# File 'lib/replicate-client/prediction.rb', line 261

# Whether the prediction's status equals Status::SUCCEEDED.
#
# @return [Boolean]
def succeeded? = status == Status::SUCCEEDED

#version ⇒ ReplicateClient::Model::Version

The version of the model used for the prediction.



247
248
249
# File 'lib/replicate-client/prediction.rb', line 247

# The version of the model used for the prediction, taken from
# model.version once and memoized in @version.
#
# @return [ReplicateClient::Model::Version]
def version
  return @version if @version

  @version = model.version
end