Class: Wandb::XGBoostCallback
- Inherits: XGBoost::TrainingCallback
  - Object
    - XGBoost::TrainingCallback
      - Wandb::XGBoostCallback
- Defined in:
- lib/wandb/xgboost_callback.rb
Defined Under Namespace
Classes: Opts
Constant Summary
- MINIMIZE_METRICS =
Add other metrics as needed
%w[rmse logloss error]
- MAXIMIZE_METRICS =
Add other metrics as needed
%w[auc accuracy]
Instance Attribute Summary
-
#api_key ⇒ Object
Returns the value of attribute api_key.
-
#custom_loggers ⇒ Object
Returns the value of attribute custom_loggers.
-
#define_metric ⇒ Object
writeonly
Sets the attribute define_metric.
-
#history ⇒ Object
Returns the value of attribute history.
-
#importance_type ⇒ Object
Returns the value of attribute importance_type.
-
#log_feature_importance ⇒ Object
writeonly
Sets the attribute log_feature_importance.
-
#log_model ⇒ Object
Returns the value of attribute log_model.
-
#normalize_feature_importance ⇒ Object
Returns the value of attribute normalize_feature_importance.
-
#project_name ⇒ Object
Returns the value of attribute project_name.
-
#sample ⇒ Object
Returns the value of attribute sample.
Instance Method Summary
- #after_iteration(model, epoch, history) ⇒ Object
- #after_training(model) ⇒ Object
- #as_json ⇒ Object
- #before_iteration(_model, _epoch, _history) ⇒ Object
- #before_training(model) ⇒ Object
- #finish ⇒ Object
-
#initialize(options = {}) ⇒ XGBoostCallback
constructor
A new instance of XGBoostCallback.
Constructor Details
#initialize(options = {}) ⇒ XGBoostCallback
Returns a new instance of XGBoostCallback.
27 28 29 30 31 32 33 34 35 36 37 38 |
# File 'lib/wandb/xgboost_callback.rb', line 27
# Builds a W&B logging callback for XGBoost training.
#
# Every setting is looked up in +options+ through Opts#default, falling back
# to the value written below. This assumes Opts#default means "the supplied
# value if present, else the fallback"; Opts is not shown here, so verify.
#
# @param options [Hash] callback settings, with defaults:
#   :log_model (false), :log_feature_importance (true),
#   :importance_type ("gain"), :normalize_feature_importance (true),
#   :define_metric (true), :api_key (ENV["WANDB_API_KEY"]),
#   :project_name (nil), :sample (1.0), :custom_loggers ([])
def initialize(options = {})
  opts = Opts.new(options)
  @log_model = opts.default(:log_model, false)
  @log_feature_importance = opts.default(:log_feature_importance, true)
  @importance_type = opts.default(:importance_type, "gain")
  @normalize_feature_importance = opts.default(:normalize_feature_importance, true)
  @define_metric = opts.default(:define_metric, true)
  @api_key = opts.default(:api_key, ENV["WANDB_API_KEY"])
  @project_name = opts.default(:project_name, nil)
  @sample = opts.default(:sample, 1.0)
  @custom_loggers = opts.default(:custom_loggers, [])
end
Instance Attribute Details
#api_key ⇒ Object
Returns the value of attribute api_key.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# The W&B API key given to the constructor. It defaults to
# ENV["WANDB_API_KEY"]. #before_training passes it to Wandb.login.
#
# @return [Object] the stored API key (nil when none was configured)
def api_key
  @api_key
end
#custom_loggers ⇒ Object
Returns the value of attribute custom_loggers.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# The :custom_loggers option given to the constructor (default []).
# None of the methods shown on this page read it.
#
# @return [Object] the stored custom loggers
def custom_loggers
  @custom_loggers
end
#define_metric=(value) ⇒ Object
Sets the attribute define_metric.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# Turns the epoch-0 #define_metric calls made by #after_iteration on or off.
# There is no public reader, because #define_metric is also a method name.
#
# @param value [Object] truthy to enable metric definitions
# @return [Object] +value+
def define_metric=(value)
  @define_metric = value
end
#history ⇒ Object
Returns the value of attribute history.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# Returns the @history instance variable. No method shown on this page
# assigns it, so it is nil unless something elsewhere sets it.
#
# @return [Object, nil] the stored history
def history
  @history
end
#importance_type ⇒ Object
Returns the value of attribute importance_type.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# The feature-importance type given to the constructor (default "gain").
# It is presumably used by the private feature-importance logger, which is
# not shown here; verify.
#
# @return [Object] the stored importance type
def importance_type
  @importance_type
end
#log_feature_importance=(value) ⇒ Object
Sets the attribute log_feature_importance.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# Turns feature-importance logging in #after_training on or off.
# There is no public reader, because #log_feature_importance is also a
# method name.
#
# @param value [Object] truthy to log feature importance after training
# @return [Object] +value+
def log_feature_importance=(value)
  @log_feature_importance = value
end
#log_model ⇒ Object
Returns the value of attribute log_model.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# Whether #after_training should log the model as an artifact
# (default false).
#
# @return [Object] the stored flag
def log_model
  @log_model
end
#normalize_feature_importance ⇒ Object
Returns the value of attribute normalize_feature_importance.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# The :normalize_feature_importance option (default true). It presumably
# controls whether logged importances are normalized, but the code that
# uses it is not shown here; verify.
#
# @return [Object] the stored flag
def normalize_feature_importance
  @normalize_feature_importance
end
#project_name ⇒ Object
Returns the value of attribute project_name.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# The W&B project name (default nil). #before_training passes it to
# Wandb.init.
#
# @return [Object] the stored project name
def project_name
  @project_name
end
#sample ⇒ Object
Returns the value of attribute sample.
23 24 25 |
# File 'lib/wandb/xgboost_callback.rb', line 23
# The fraction of boosting rounds to log (default 1.0).
# #after_iteration derives its logging interval from this value.
#
# @return [Object] the stored sampling rate
def sample
  @sample
end
Instance Method Details
#after_iteration(model, epoch, history) ⇒ Object
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
# File 'lib/wandb/xgboost_callback.rb', line 97
# Logs every evaluation metric for a sampled subset of boosting rounds.
#
# Epochs are logged every +log_every+ rounds, where log_every is
# (1.0 / sample).round, clamped to at least 1. A non-positive sample turns
# per-iteration logging off. On epoch 0, #define_metric is called once for
# each split/metric pair when define_metric is enabled.
#
# @param model [Object] unused
# @param epoch [Integer] the current boosting round, starting at 0
# @param history [#to_h] the evaluation log. Assumed shape is
#   { split => { metric => [score per epoch] } }; TODO confirm against the
#   XGBoost binding.
# @return [false] always false. Under XGBoost's callback convention a truthy
#   return asks training to stop; assumed to hold for the Ruby binding.
def after_iteration(model, epoch, history)
  # If sample <= 0, the interval would be infinite or negative
  # (Infinity.round raises), so treat it as "never log".
  return false unless @sample.positive?

  # For sample > 2 the interval would round down to 0, and `epoch % 0`
  # raises. Never log less often than once per epoch.
  log_every = [(1.0 / @sample).round, 1].max
  return false unless (epoch % log_every).zero?

  history.to_h.each do |split, metric_scores|
    # Iterate name/score pairs so each metric name stays attached to its
    # own score list, even when several eval metrics are configured.
    metric_scores.each do |metric, values|
      define_metric(split, metric) if @define_metric && epoch.zero?
      Wandb.log({ "#{split}-#{metric}" => values[epoch] })
    end
  end
  Wandb.log("epoch" => epoch)
  false
end
#after_training(model) ⇒ Object
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
# File 'lib/wandb/xgboost_callback.rb', line 66
# Finishes the W&B run after training completes.
#
# Steps:
#   - Optionally logs the model as an artifact (log_model).
#   - Optionally logs feature importance (log_feature_importance).
#   - Logs best_score / best_iteration when the booster reports a best score.
#
# #finish runs in an +ensure+ clause, so the run is closed even if one of the
# logging steps raises. The exception still propagates to the caller.
#
# @param model [Object] booster responding to #best_score and
#   #best_iteration
# @return [Object] the same +model+
def after_training(model)
  log_model_as_artifact(model) if @log_model
  log_feature_importance(model) if @log_feature_importance

  # best_score is nil when no early-stopping evaluation was recorded;
  # in that case there is nothing to log.
  if model.best_score
    Wandb.log(
      "best_score" => model.best_score.to_f,
      "best_iteration" => model.best_iteration.to_i,
    )
  end
  model
ensure
  finish
end
#as_json ⇒ Object
40 41 42 43 44 45 46 47 48 49 50 51 |
# File 'lib/wandb/xgboost_callback.rb', line 40
# Serializable summary of the callback's configuration.
#
# The key order is fixed. api_key and custom_loggers are not included.
# The hash ends with callback_type: :wandb.
#
# @return [Hash{Symbol => Object}]
def as_json
  exported = %i[
    log_model
    log_feature_importance
    importance_type
    define_metric
    normalize_feature_importance
    sample
    project_name
  ]
  settings = exported.each_with_object({}) do |name, acc|
    acc[name] = instance_variable_get(:"@#{name}")
  end
  settings.merge(callback_type: :wandb)
end
#before_iteration(_model, _epoch, _history) ⇒ Object
93 94 95 |
# File 'lib/wandb/xgboost_callback.rb', line 93
# No-op hook that runs before each boosting round.
#
# @param _model [Object] unused
# @param _epoch [Integer] unused
# @param _history [Object] unused
# @return [false] always false, i.e. never requests that training stop
#   (assuming XGBoost's convention that a truthy return stops training)
def before_iteration(_model, _epoch, _history)
  false
end
#before_training(model) ⇒ Object
53 54 55 56 57 58 59 60 61 62 63 64 |
# File 'lib/wandb/xgboost_callback.rb', line 53
# Opens the W&B run and records a few hyperparameters from the booster.
#
# Steps:
#   - Logs in with #api_key and starts a run in #project_name.
#   - Logs learning_rate and max_depth, both read from the booster's JSON
#     config and converted with to_f. A missing key yields 0.0.
#   - Logs the booster's boosted-round count as n_estimators.
#
# @param model [Object] booster responding to #save_config (a JSON string)
#   and #num_boosted_rounds
# @return [Object] the same +model+
def before_training(model)
  Wandb.login(api_key: api_key)
  Wandb.init(project: project_name)

  booster_config = JSON.parse(model.save_config)
  tree_param_path = %w[learner gradient_booster tree_train_param]

  hyperparams = {}
  hyperparams[:learning_rate] = booster_config.dig(*tree_param_path, "learning_rate").to_f
  hyperparams[:max_depth] = booster_config.dig(*tree_param_path, "max_depth").to_f
  hyperparams[:n_estimators] = model.num_boosted_rounds
  Wandb.log(hyperparams)

  model
end
#finish ⇒ Object
88 89 90 91 |
# File 'lib/wandb/xgboost_callback.rb', line 88
# Closes the W&B run and then deletes the local run directory.
#
# WARNING: this recursively and forcibly removes the directory named
# "wandb" in the current working directory (Dir.pwd), whatever it contains.
#
# @return [Object] the value returned by FileUtils.rm_rf
def finish
  Wandb.finish
  # The run directory path is resolved after Wandb.finish, against the
  # working directory at that moment.
  local_run_dir = File.join(Dir.pwd, "wandb")
  FileUtils.rm_rf(local_run_dir)
end