Class: Wandb::XGBoostCallback

Inherits:
XGBoost::TrainingCallback
  • Object
show all
Defined in:
lib/wandb/xgboost_callback.rb

Defined Under Namespace

Classes: Opts

Constant Summary collapse

MINIMIZE_METRICS = %w[rmse logloss error]

Metrics for which lower values are better. Add other metrics as needed.

MAXIMIZE_METRICS = %w[auc accuracy]

Metrics for which higher values are better. Add other metrics as needed.

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(options = {}) ⇒ XGBoostCallback

Returns a new instance of XGBoostCallback.



27
28
29
30
31
32
33
34
35
36
37
38
# File 'lib/wandb/xgboost_callback.rb', line 27

# Builds a callback that reports XGBoost training progress to Weights & Biases.
#
# Every setting is resolved through Opts#default, using the fallback listed in
# the table below, and stored in an instance variable of the same name.
#
# @param options [Hash] configuration; recognized keys and their fallbacks:
#   :log_model (false), :log_feature_importance (true), :importance_type ("gain"),
#   :normalize_feature_importance (true), :define_metric (true),
#   :api_key (ENV["WANDB_API_KEY"]), :project_name (nil), :sample (1.0),
#   :custom_loggers ([])
def initialize(options = {})
  opts = Opts.new(options)

  # Ordered to match the sequence in which options are resolved. The literal is
  # rebuilt on every call, so each instance gets its own custom_loggers array.
  fallbacks = {
    log_model: false,
    log_feature_importance: true,
    importance_type: "gain",
    normalize_feature_importance: true,
    define_metric: true,
    api_key: ENV["WANDB_API_KEY"],
    project_name: nil,
    sample: 1.0,
    custom_loggers: [],
  }

  fallbacks.each do |name, fallback|
    instance_variable_set(:"@#{name}", opts.default(name, fallback))
  end
end

Instance Attribute Details

#api_key ⇒ Object

Returns the value of attribute api_key.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# W&B API key handed to Wandb when the run starts in #before_training.
# Initialized from the :api_key option, with ENV["WANDB_API_KEY"] as the
# Opts#default fallback.
#
# @return [String, nil]
def api_key
  @api_key
end

#custom_loggers ⇒ Object

Returns the value of attribute custom_loggers.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Extra loggers supplied through the :custom_loggers option (fallback: []).
# NOTE(review): none of the methods in this listing read this value; confirm
# where it is consumed.
#
# @return [Array]
def custom_loggers
  @custom_loggers
end

#define_metric=(value) ⇒ Object

Sets the attribute define_metric

Parameters:

  • value

    the value to set the attribute define_metric to.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Toggles whether #after_iteration calls define_metric(split, metric) for each
# metric on the first epoch (epoch 0).
#
# @param value [Boolean] truthy to enable
# @return [Object] the assigned value
def define_metric=(value)
  @define_metric = value
end

#history ⇒ Object

Returns the value of attribute history.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Reader for @history.
# NOTE(review): no code in this listing assigns @history (#after_iteration only
# receives history as a local argument), so this presumably returns nil unless
# it is set elsewhere; TODO confirm.
#
# @return [Object, nil]
def history
  @history
end

#importance_type ⇒ Object

Returns the value of attribute importance_type.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Feature-importance type from the :importance_type option (fallback: "gain").
# Presumably forwarded to the feature-importance logging helper, which is not
# shown here; TODO confirm.
#
# @return [String]
def importance_type
  @importance_type
end

#log_feature_importance=(value) ⇒ Object

Sets the attribute log_feature_importance

Parameters:

  • value

    the value to set the attribute log_feature_importance to.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Toggles whether #after_training calls log_feature_importance(model) once
# training completes.
#
# @param value [Boolean] truthy to enable
# @return [Object] the assigned value
def log_feature_importance=(value)
  @log_feature_importance = value
end

#log_model ⇒ Object

Returns the value of attribute log_model.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Whether #after_training uploads the model through log_model_as_artifact.
# Taken from the :log_model option (fallback: false).
#
# @return [Boolean]
def log_model
  @log_model
end

#normalize_feature_importance ⇒ Object

Returns the value of attribute normalize_feature_importance.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Value of the :normalize_feature_importance option (fallback: true).
# Presumably consulted by the feature-importance logging helper, which is not
# shown here; TODO confirm.
#
# @return [Boolean]
def normalize_feature_importance
  @normalize_feature_importance
end

#project_name ⇒ Object

Returns the value of attribute project_name.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# W&B project name passed as `project:` to Wandb.init in #before_training.
# Taken from the :project_name option (fallback: nil).
#
# @return [String, nil]
def project_name
  @project_name
end

#sample ⇒ Object

Returns the value of attribute sample.



23
24
25
# File 'lib/wandb/xgboost_callback.rb', line 23

# Fraction of epochs to log. #after_iteration logs roughly every
# (1.0 / sample).round epochs. Taken from the :sample option (fallback: 1.0).
#
# @return [Numeric]
def sample
  @sample
end

Instance Method Details

#after_iteration(model, epoch, history) ⇒ Object



97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# File 'lib/wandb/xgboost_callback.rb', line 97

# Called after every boosting round. On sampled epochs it logs the latest
# value of every evaluation metric, plus the epoch number, to W&B in a single
# Wandb.log call, so they all land on the same W&B step.
#
# @param model [Object] the booster (not used here)
# @param epoch [Integer] iteration index supplied by XGBoost
# @param history [#to_h] evaluation log shaped like
#   { split_name => { metric_name => [value per iteration, ...] } }
# @return [false] never asks XGBoost to stop training
def after_iteration(model, epoch, history)
  # A non-positive sample rate means "log nothing". Previously 1.0 / 0 produced
  # Infinity, and Infinity.round raised FloatDomainError.
  return false unless @sample.positive?

  # Rates above 2 round to an interval of 0; clamp so `epoch % 0` cannot raise.
  log_every = [(1.0 / @sample).round, 1].max
  return false unless (epoch % log_every).zero?

  metrics = {}
  history.to_h.each do |split, metric_scores|
    # Log every metric of the split, not just the first name paired with the
    # last series.
    metric_scores.each do |metric, values|
      # define_metric(split, metric) is a helper not shown in this listing.
      define_metric(split, metric) if @define_metric && epoch.zero?
      # The newest entry is this round's score. Indexing by epoch would return
      # nil when training resumes from a non-zero start iteration.
      metrics["#{split}-#{metric}"] = values.last
    end
  end
  metrics["epoch"] = epoch
  Wandb.log(metrics)
  false
end

#after_training(model) ⇒ Object



66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# File 'lib/wandb/xgboost_callback.rb', line 66

# Wraps up the W&B run once XGBoost finishes boosting.
#
# Optionally logs the model as an artifact and its feature importance, then
# logs best_score/best_iteration when the booster reports a best score. The run
# is always closed via #finish, even if one of the logging steps raises.
#
# @param model [Object] the trained booster; must respond to #best_score and
#   #best_iteration
# @return [Object] the same model, unchanged
def after_training(model)
  log_model_as_artifact(model) if @log_model
  log_feature_importance(model) if @log_feature_importance

  # best_score may be nil (presumably when no early-stopping/eval set was used).
  if model.best_score
    Wandb.log(
      "best_score" => model.best_score.to_f,
      "best_iteration" => model.best_iteration.to_i,
    )
  end

  model
ensure
  # Runs on both the normal and the exceptional path so the W&B run never leaks.
  finish
end

#as_json ⇒ Object



40
41
42
43
44
45
46
47
48
49
50
51
# File 'lib/wandb/xgboost_callback.rb', line 40

# Serializable summary of the callback's configuration.
#
# Each listed setting is read from the instance variable of the same name, and
# the result is tagged with callback_type: :wandb. api_key and custom_loggers
# are not included.
#
# @return [Hash{Symbol => Object}]
def as_json
  setting_names = %i[
    log_model
    log_feature_importance
    importance_type
    define_metric
    normalize_feature_importance
    sample
    project_name
  ]
  settings = setting_names.to_h { |name| [name, instance_variable_get(:"@#{name}")] }
  settings.merge(callback_type: :wandb)
end

#before_iteration(_model, _epoch, _history) ⇒ Object



93
94
95
# File 'lib/wandb/xgboost_callback.rb', line 93

# No per-round setup is needed before an iteration.
#
# Under XGBoost's TrainingCallback convention a true return asks training to
# stop, so this returns false and lets training continue.
#
# @return [false]
def before_iteration(_model, _epoch, _history)
  false
end

#before_training(model) ⇒ Object



53
54
55
56
57
58
59
60
61
62
63
64
# File 'lib/wandb/xgboost_callback.rb', line 53

# Starts the W&B run before boosting begins and logs a few hyperparameters
# read from the booster's saved JSON config.
#
# @param model [Object] booster responding to #save_config (JSON string) and
#   #num_boosted_rounds
# @return [Object] the same model, unchanged
def before_training(model)
  # Was `Wandb.(api_key: api_key)`, which calls Wandb.call rather than logging in.
  Wandb.login(api_key: api_key)
  Wandb.init(project: project_name)

  config = JSON.parse(model.save_config)
  tree_params = config.dig("learner", "gradient_booster", "tree_train_param") || {}
  # XGBoost stores these as strings. Leave out keys that are absent (e.g. other
  # booster layouts) instead of logging a misleading 0.0 from nil.to_f.
  log_conf = {
    learning_rate: tree_params["learning_rate"]&.to_f,
    max_depth: tree_params["max_depth"]&.to_i,
    # NOTE(review): read before training, so this is presumably 0 for a fresh
    # booster; confirm it reflects the intended number of rounds.
    n_estimators: model.num_boosted_rounds,
  }.compact
  Wandb.log(log_conf)
  model
end

#finish ⇒ Object



88
89
90
91
# File 'lib/wandb/xgboost_callback.rb', line 88

# Closes the active W&B run, then deletes the "wandb" directory under the
# current working directory.
#
# NOTE(review): FileUtils.rm_rf is recursive and suppresses errors, so anything
# named "wandb" in Dir.pwd is removed. Presumably this is the W&B client's local
# run directory; confirm no unrelated data can live there.
#
# @return [Object] whatever FileUtils.rm_rf returns
def finish
  Wandb.finish
  local_run_dir = File.join(Dir.pwd, "wandb")
  FileUtils.rm_rf(local_run_dir)
end