Class: ReplicateClient::Training

Inherits:
Object
  • Object
show all
Defined in:
lib/replicate-client/training.rb

Defined Under Namespace

Modules: Status

Constant Summary

INDEX_PATH =
"/trainings"

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(attributes, client: ReplicateClient.client) ⇒ ReplicateClient::Training

Initialize a new training instance.

Parameters:

  • attributes (Hash)

    The attributes of the training.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for requests.



211
212
213
214
# File 'lib/replicate-client/training.rb', line 211

# Initialize a new training instance.
#
# @param attributes [Hash] The attributes of the training; passed straight to
#   #reset_attributes (defined elsewhere in this class).
# @param client [ReplicateClient::Client] The client to use for requests.
def initialize(attributes, client: ReplicateClient.client)
  # Store the client before applying attributes. reset_attributes is not visible
  # here, so keep this order in case it depends on @client.
  @client = client
  reset_attributes(attributes)
end

Instance Attribute Details

#client ⇒ ReplicateClient::Client

The client used to make API requests for this training.



203
204
205
# File 'lib/replicate-client/training.rb', line 203

# API client this training uses for follow-up requests (reload, cancel, model lookup).
#
# @return [ReplicateClient::Client]
def client; @client; end

#completed_at ⇒ Time?

The timestamp when the training was completed.

Returns:

  • (Time, nil)


168
169
170
# File 'lib/replicate-client/training.rb', line 168

# Moment the training finished, or nil while it has not completed.
#
# @return [Time, nil]
def completed_at; @completed_at; end

#created_at ⇒ String

The timestamp when the training was created.

Returns:

  • (String)


163
164
165
# File 'lib/replicate-client/training.rb', line 163

# Moment the training was created.
#
# NOTE(review): documented as String, while started_at/completed_at are
# documented as Time — confirm which type reset_attributes actually stores.
#
# @return [String]
def created_at; @created_at; end

#error ⇒ String?

The error message, if any, encountered during the training process.

Returns:

  • (String, nil)


183
184
185
# File 'lib/replicate-client/training.rb', line 183

# Error message reported for the training, or nil when none was encountered.
#
# @return [String, nil]
def error; @error; end

#id ⇒ String

The unique identifier of the training.

Returns:

  • (String)


137
138
139
# File 'lib/replicate-client/training.rb', line 137

# Unique identifier of this training, as assigned by the API.
#
# @return [String]
def id; @id; end

#input ⇒ Hash

The input data provided for the training.

Returns:

  • (Hash)


152
153
154
# File 'lib/replicate-client/training.rb', line 152

# Input payload that was supplied when the training was created.
#
# @return [Hash]
def input; @input; end

#logs ⇒ String

The logs generated during the training process.

Returns:

  • (String)


178
179
180
# File 'lib/replicate-client/training.rb', line 178

# Log output produced while the training ran.
#
# @return [String]
def logs; @logs; end

#metrics ⇒ Hash?

The metrics generated during the training process.

Returns:

  • (Hash, nil)


198
199
200
# File 'lib/replicate-client/training.rb', line 198

# Metrics gathered during the training, or nil if none are available.
#
# @return [Hash, nil]
def metrics; @metrics; end

#model_full_name ⇒ String

The full model name in the format “owner/name”.

Returns:

  • (String)


142
143
144
# File 'lib/replicate-client/training.rb', line 142

# Fully qualified name of the trained model, written as "owner/name".
#
# @return [String]
def model_full_name; @model_full_name; end

#output ⇒ Hash?

The output data generated during the training process.

Returns:

  • (Hash, nil)


193
194
195
# File 'lib/replicate-client/training.rb', line 193

# Output produced by the training, or nil when nothing has been produced yet.
#
# @return [Hash, nil]
def output; @output; end

#started_at ⇒ Time?

The timestamp when the training was started.

Returns:

  • (Time, nil)


173
174
175
# File 'lib/replicate-client/training.rb', line 173

# Moment the training began running, or nil if it has not started.
#
# @return [Time, nil]
def started_at; @started_at; end

#status ⇒ String

The current status of the training. Possible values: “starting”, “processing”, “succeeded”, “failed”, “canceled”.

Returns:

  • (String)


158
159
160
# File 'lib/replicate-client/training.rb', line 158

# Current lifecycle state of the training: "starting", "processing",
# "succeeded", "failed" or "canceled".
#
# @return [String]
def status; @status; end

#urls ⇒ Hash

URLs related to the training, such as those for retrieving or canceling it.

Returns:

  • (Hash)


188
189
190
# File 'lib/replicate-client/training.rb', line 188

# Related URLs for this training, e.g. endpoints to fetch or cancel it.
#
# @return [Hash]
def urls; @urls; end

#version_id ⇒ String

The version ID of the model being trained.

Returns:

  • (String)


147
148
149
# File 'lib/replicate-client/training.rb', line 147

# Identifier of the model version that is being trained.
#
# @return [String]
def version_id; @version_id; end

Class Method Details

.auto_paging_each(client: ReplicateClient.client) {|ReplicateClient::Training| ... } ⇒ void

This method returns an undefined value.

List all trainings.

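Parameters:

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for requests.

Yields:

  • (ReplicateClient::Training)

    Each training, across all pages of results.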



23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# File 'lib/replicate-client/training.rb', line 23

# List all trainings, following the API's pagination cursors until the last page.
#
# @param client [ReplicateClient::Client] The client to use for requests.
# @yieldparam training [ReplicateClient::Training] Each training, in the order the API returns them.
# @return [Enumerator, nil] An Enumerator over all trainings when no block is given; nil otherwise.
def auto_paging_each(client: ReplicateClient.client, &block)
  # Without a block there is nothing to yield to. Return an Enumerator instead of
  # issuing every page request and discarding the results.
  return enum_for(:auto_paging_each, client: client) unless block

  cursor = nil

  loop do
    # The cursor was decoded from the previous `next` URL, so it must be
    # form-encoded again; raw "+", "/" or "=" would otherwise corrupt the query.
    url_params = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
    attributes = client.get("#{INDEX_PATH}#{url_params}")

    # Tolerate a page without a "results" key instead of crashing on nil.
    Array(attributes["results"]).each { |training| yield new(training, client: client) }

    # Stop when there is no `next` URL, or it carries no query / cursor.
    next_url = attributes["next"]
    query = next_url && URI.parse(next_url).query
    cursor = query && URI.decode_www_form(query).to_h["cursor"]
    break if cursor.nil?
  end
end

.build_path(id:) ⇒ String

Build the path for a specific training.

Parameters:

  • id (String)

    The id of the training.

Returns:

  • (String)


129
130
131
# File 'lib/replicate-client/training.rb', line 129

# Compose the API path of one training, appending its id to the collection path.
#
# @param id [String] The id of the training.
# @return [String] e.g. "<INDEX_PATH>/<id>".
def build_path(id:)
  format("%<base>s/%<id>s", base: INDEX_PATH, id: id)
end

.cancel!(id, client: ReplicateClient.client) ⇒ void

This method returns an undefined value.

Cancel a training.

Parameters:

  • id (String)

    The id of the training.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for requests.



119
120
121
122
# File 'lib/replicate-client/training.rb', line 119

# Ask the API to cancel the training with the given id.
#
# Sends an empty-bodied POST to the training's "/cancel" sub-path.
#
# @param id [String] The id of the training.
# @param client [ReplicateClient::Client] The client to use for requests.
# @return [void]
def cancel!(id, client: ReplicateClient.client)
  client.post("#{build_path(id: id)}/cancel", {})
end

.create!(owner:, name:, version:, destination:, input:, webhook_url: nil, webhook_events_filter: nil, client: ReplicateClient.client) ⇒ ReplicateClient::Training

Create a new training.

Parameters:

  • owner (String)

    The owner of the model.

  • name (String)

    The name of the model.

  • version (ReplicateClient::Model::Version, String)

    The version of the model to train.

  • destination (ReplicateClient::Model, String)

    The destination model instance or string in “owner/name” format.

  • input (Hash)

    The input data for the training.

  • webhook_url (String, nil) (defaults to: nil)

    A URL to receive webhook notifications.

  • webhook_events_filter (Array, nil) (defaults to: nil)

    The events to trigger webhook requests.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for requests.

Returns:

  • (ReplicateClient::Training)

    The newly created training.



52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# File 'lib/replicate-client/training.rb', line 52

# Start a new training of owner/name at the given version.
#
# @param owner [String] The owner of the model.
# @param name [String] The name of the model.
# @param version [ReplicateClient::Model::Version, String] A version object or a version id.
# @param destination [ReplicateClient::Model, String] A model object or its "owner/name" string.
# @param input [Hash] The input data for the training.
# @param webhook_url [String, nil] Webhook target; falls back to client.configuration.webhook_url.
# @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
# @param client [ReplicateClient::Client] The client to use for requests.
# @return [ReplicateClient::Training] The created training.
def create!(owner:, name:, version:, destination:, input:, webhook_url: nil, webhook_events_filter: nil,
            client: ReplicateClient.client)
  # Accept either rich objects or their plain string identifiers.
  destination_name =
    if destination.is_a?(ReplicateClient::Model)
      destination.full_name
    else
      destination
    end
  version_ref =
    if version.is_a?(ReplicateClient::Model::Version)
      version.id
    else
      version
    end

  endpoint = "/models/#{owner}/#{name}/versions/#{version_ref}/trainings"
  payload = {
    destination: destination_name,
    input: input,
    webhook: webhook_url || client.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter
  }

  new(client.post(endpoint, payload), client: client)
end

.create_for_model!(model:, destination:, input:, webhook_url: nil, webhook_events_filter: nil, client: ReplicateClient.client) ⇒ ReplicateClient::Training

Create a new training for a specific model.

Parameters:

  • model (ReplicateClient::Model, String)

    The model instance or a string representing the model ID.

  • destination (ReplicateClient::Model, String)

    The destination model or full name in “owner/name” format.

  • input (Hash)

    The input data for the training.

  • webhook_url (String, nil) (defaults to: nil)

    A URL to receive webhook notifications.

  • webhook_events_filter (Array, nil) (defaults to: nil)

    The events to trigger webhook requests.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for requests.

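Returns:

  • (ReplicateClient::Training)

    The newly created training.

Raises:

  • (ArgumentError)

    If no model instance could be resolved from model.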


79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# File 'lib/replicate-client/training.rb', line 79

# Create a new training for a specific model.
#
# @param model [ReplicateClient::Model, String] A model instance, or an identifier
#   that is resolved through ReplicateClient::Model.find.
# @param destination [ReplicateClient::Model, String] The destination model or full name in "owner/name" format.
# @param input [Hash] The input data for the training.
# @param webhook_url [String, nil] A URL to receive webhook notifications; when nil,
#   create! falls back to client.configuration.webhook_url.
# @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
# @param client [ReplicateClient::Client] The client to use for requests.
# @return [ReplicateClient::Training] The created training.
# @raise [ArgumentError] If no model instance could be resolved.
def create_for_model!(model:, destination:, input:, webhook_url: nil, webhook_events_filter: nil,
                      client: ReplicateClient.client)
  model_instance = model.is_a?(ReplicateClient::Model) ? model : ReplicateClient::Model.find(model, client: client)
  raise ArgumentError, "Invalid model" unless model_instance

  create!(
    owner: model_instance.owner,
    name: model_instance.name,
    version: model_instance.version_id,
    destination: destination,
    input: input,
    # Forward the caller's value unchanged: create! owns the fallback to the
    # client's configured webhook, so resolving it here as well is redundant.
    webhook_url: webhook_url,
    webhook_events_filter: webhook_events_filter,
    client: client
  )
end

.find(id, client: ReplicateClient.client) ⇒ ReplicateClient::Training

Find a training by id.

Parameters:

  • id (String)

    The id of the training.

  • client (ReplicateClient::Client) (defaults to: ReplicateClient.client)

    The client to use for requests.

Returns:

  • (ReplicateClient::Training)

    The training with the given id.



107
108
109
110
111
# File 'lib/replicate-client/training.rb', line 107

# Fetch a single training from the API and wrap it in a Training instance.
#
# @param id [String] The id of the training.
# @param client [ReplicateClient::Client] The client to use for requests.
# @return [ReplicateClient::Training]
def find(id, client: ReplicateClient.client)
  new(client.get(build_path(id: id)), client: client)
end

Instance Method Details

#cancel! ⇒ void

This method returns an undefined value.

Cancel the training.



254
255
256
# File 'lib/replicate-client/training.rb', line 254

# Cancel this training, delegating to the class-level cancel! with this
# instance's id and client.
#
# @return [void]
def cancel!
  training_id = id
  ReplicateClient::Training.cancel!(training_id, client: @client)
end

#canceled? ⇒ Boolean

Check if the training was canceled.

Returns:

  • (Boolean)


247
248
249
# File 'lib/replicate-client/training.rb', line 247

# Whether the training's status is Status::CANCELED.
#
# @return [Boolean]
def canceled?
  Status::CANCELED == status
end

#failed? ⇒ Boolean

Check if the training has failed.

Returns:

  • (Boolean)


240
241
242
# File 'lib/replicate-client/training.rb', line 240

# Whether the training's status is Status::FAILED.
#
# @return [Boolean]
def failed?
  Status::FAILED == status
end

#model ⇒ ReplicateClient::Model

The model instance of the training.



269
270
271
# File 'lib/replicate-client/training.rb', line 269

# The model this training belongs to, looked up by full name and version id.
#
# The lookup result is cached in @model; a falsy result is not cached and
# will be looked up again on the next call.
#
# @return [ReplicateClient::Model]
def model
  return @model if @model

  @model = ReplicateClient::Model.find(model_full_name, version_id: version_id, client: @client)
end

#processing? ⇒ Boolean

Check if the training is processing.

Returns:

  • (Boolean)


226
227
228
# File 'lib/replicate-client/training.rb', line 226

# Whether the training's status is Status::PROCESSING.
#
# @return [Boolean]
def processing?
  Status::PROCESSING == status
end

#reload! ⇒ void

This method returns an undefined value.

Reload the training.



261
262
263
264
# File 'lib/replicate-client/training.rb', line 261

# Re-fetch this training from the API and re-apply the returned attributes.
#
# @return [void]
def reload!
  reset_attributes(@client.get(Training.build_path(id: id)))
end

#starting? ⇒ Boolean

Check if the training is starting.

Returns:

  • (Boolean)


219
220
221
# File 'lib/replicate-client/training.rb', line 219

# Whether the training's status is Status::STARTING.
#
# @return [Boolean]
def starting?
  Status::STARTING == status
end

#succeeded? ⇒ Boolean

Check if the training has succeeded.

Returns:

  • (Boolean)


233
234
235
# File 'lib/replicate-client/training.rb', line 233

# Whether the training's status is Status::SUCCEEDED.
#
# @return [Boolean]
def succeeded?
  Status::SUCCEEDED == status
end

#version ⇒ ReplicateClient::Model::Version

The version instance of the training.



276
277
278
# File 'lib/replicate-client/training.rb', line 276

# The model version this training refers to, taken from #model.
#
# The result is cached in @version; a falsy result is not cached.
#
# @return [ReplicateClient::Model::Version]
def version
  return @version if @version

  @version = model.version
end