Class: ClassicBandit::ThompsonSampling

Inherits:
Object
Includes:
ArmUpdatable
Defined in:
lib/classic_bandit/thompson_sampling.rb

Overview

Implements the Thompson Sampling algorithm for multi-armed bandit problems. It uses a Beta-Bernoulli conjugate model and draws samples from the Beta distribution by way of Gamma random variables.

Examples:

Create and use Thompson Sampling

arms = [
  ClassicBandit::Arm.new(id: 1, name: "banner_a"),
  ClassicBandit::Arm.new(id: 2, name: "banner_b")
]
bandit = ClassicBandit::ThompsonSampling.new(arms: arms)

Instance Attribute Summary

Instance Method Summary

Methods included from ArmUpdatable

#update

Constructor Details

#initialize(arms:, alpha_prior: 1.0, beta_prior: 1.0) ⇒ ThompsonSampling

Returns a new instance of ThompsonSampling.

Parameters:

  • arms (Array<Arm>)

    Array of arms to choose from

  • alpha_prior (Float) (defaults to: 1.0)

    Prior parameter for successes; must be positive

  • beta_prior (Float) (defaults to: 1.0)

    Prior parameter for failures; must be positive

Raises:

  • (ArgumentError)

    if alpha_prior or beta_prior is not positive


# File 'lib/classic_bandit/thompson_sampling.rb', line 22

def initialize(arms:, alpha_prior: 1.0, beta_prior: 1.0)
  raise ArgumentError, "alpha_prior must be positive" unless alpha_prior.positive?
  raise ArgumentError, "beta_prior must be positive" unless beta_prior.positive?

  @arms = arms
  @alpha_prior = alpha_prior
  @beta_prior = beta_prior
end
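The two priors act as pseudo-counts in the Beta-Bernoulli model: after observing some successes and failures on an arm, the posterior is a Beta distribution whose parameters are the priors plus those counts. The sketch below illustrates that arithmetic only. The helper posterior_params and its keyword names are hypothetical and not part of ClassicBandit; how the library stores per-arm counts is not shown on this page.

```ruby
# Hypothetical helper, not part of ClassicBandit: applies the standard
# Beta-Bernoulli conjugate update to a pair of prior parameters.
def posterior_params(alpha_prior:, beta_prior:, successes:, failures:)
  alpha = alpha_prior + successes # prior pseudo-successes plus observed successes
  beta = beta_prior + failures    # prior pseudo-failures plus observed failures
  { alpha: alpha, beta: beta, mean: alpha / (alpha + beta) }
end

# With the default priors (1.0, 1.0), 3 successes and 7 failures
# give a posterior with alpha = 4.0 and beta = 8.0.
uniform = posterior_params(alpha_prior: 1.0, beta_prior: 1.0, successes: 3, failures: 7)

# A larger alpha_prior makes an untried arm look more promising at the start.
optimistic = posterior_params(alpha_prior: 5.0, beta_prior: 1.0, successes: 0, failures: 0)

puts uniform
puts optimistic
```

Larger priors take more observations to override, so the defaults of 1.0 (a uniform prior) are a reasonable starting point when nothing is known about the arms.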

Instance Attribute Details

#alpha_prior ⇒ Object (readonly)

Returns the value of attribute alpha_prior.



# File 'lib/classic_bandit/thompson_sampling.rb', line 17

def alpha_prior
  @alpha_prior
end

#arms ⇒ Object (readonly)

Returns the value of attribute arms.



# File 'lib/classic_bandit/thompson_sampling.rb', line 17

def arms
  @arms
end

#beta_prior ⇒ Object (readonly)

Returns the value of attribute beta_prior.



# File 'lib/classic_bandit/thompson_sampling.rb', line 17

def beta_prior
  @beta_prior
end

Instance Method Details

#select_arm ⇒ Object

Returns the arm whose ts_score is highest.



# File 'lib/classic_bandit/thompson_sampling.rb', line 31

def select_arm
  @arms.max_by { |arm| ts_score(arm) }
end
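The ts_score helper is not shown on this page. A minimal sketch of what such a score could look like, assuming it draws one sample from each arm's Beta posterior: a Beta(a, b) variate can be formed as X / (X + Y) with X ~ Gamma(a) and Y ~ Gamma(b). Everything below is illustrative. SketchArm, its successes and failures fields, gamma_sample, beta_sample and select_arm_sketch are hypothetical names, and the Gamma sampler uses the Marsaglia-Tsang method, which may differ from what the library does.

```ruby
# Stand-in for an arm; the real ClassicBandit::Arm fields are not shown here.
SketchArm = Struct.new(:name, :successes, :failures)

# Standard normal draw via the Box-Muller transform.
def standard_normal(rng)
  u1 = 1.0 - rng.rand # in (0, 1], so Math.log never sees 0
  u2 = rng.rand
  Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math::PI * u2)
end

# Gamma(shape, 1) draw using the Marsaglia-Tsang squeeze method.
def gamma_sample(shape, rng)
  if shape < 1.0
    # Boost small shapes: Gamma(a) = Gamma(a + 1) * U**(1 / a)
    u = 1.0 - rng.rand
    return gamma_sample(shape + 1.0, rng) * (u**(1.0 / shape))
  end

  d = shape - 1.0 / 3.0
  c = 1.0 / Math.sqrt(9.0 * d)
  loop do
    x = standard_normal(rng)
    v = (1.0 + c * x)**3
    next if v <= 0.0 # reject proposals outside the support

    u = 1.0 - rng.rand
    return d * v if Math.log(u) < 0.5 * x * x + d - d * v + d * Math.log(v)
  end
end

# Beta(alpha, beta) draw built from two Gamma draws.
def beta_sample(alpha, beta, rng)
  x = gamma_sample(alpha, rng)
  y = gamma_sample(beta, rng)
  x / (x + y)
end

# Pick the arm with the highest posterior sample, mirroring select_arm.
def select_arm_sketch(arms, alpha_prior: 1.0, beta_prior: 1.0, rng: Random.new)
  arms.max_by do |arm|
    beta_sample(alpha_prior + arm.successes, beta_prior + arm.failures, rng)
  end
end

sketch_arms = [
  SketchArm.new("banner_a", 90, 10),
  SketchArm.new("banner_b", 10, 90)
]
puts select_arm_sketch(sketch_arms, rng: Random.new(42)).name
```

Because each call draws fresh samples, an arm with few observations still wins occasionally. That randomness is what gives Thompson Sampling its exploration without a separate tuning parameter.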