Class: Ai4r::Search::MCTS

Inherits:
Object
Includes:
Data::Parameterizable
Defined in:
lib/ai4r/search/mcts.rb

Overview

Basic UCT-style Monte Carlo Tree Search.

This generic implementation expects four callbacks:

  • actions_fn.call(state) returns available actions for a state.

  • transition_fn.call(state, action) computes the next state.

  • terminal_fn.call(state) returns true if the state has no children.

  • reward_fn.call(state) yields a numeric payoff for terminal states.

Example:

env = {
  actions_fn: ->(s) { s == :root ? %i[a b] : [] },
  transition_fn: ->(s, a) { a == :a ? :win : :lose },
  terminal_fn: ->(s) { %i[win lose].include?(s) },
  reward_fn: ->(s) { s == :win ? 1.0 : 0.0 }
}
mcts = Ai4r::Search::MCTS.new(**env)
best = mcts.search(:root, 50)
# => :a
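
For a problem with more than one step, the same four callbacks might look like the sketch below. The counting task (start at 0, add 1 or 2 per move, succeed only by landing exactly on the target) and the name `target` are hypothetical, chosen for illustration; only the callback keywords come from the constructor signature.

```ruby
# Hypothetical single-agent task: reach exactly `target` by adding 1 or 2.
target = 5

env = {
  # No actions once the target is reached or overshot.
  actions_fn: ->(n) { n >= target ? [] : [1, 2] },
  # The next state is simply the running total.
  transition_fn: ->(n, step) { n + step },
  # Terminal as soon as the total reaches or passes the target.
  terminal_fn: ->(n) { n >= target },
  # Payoff 1.0 for an exact hit, 0.0 for an overshoot.
  reward_fn: ->(n) { n == target ? 1.0 : 0.0 }
}

# With the gem loaded, the hash would be passed as keyword arguments:
#   mcts = Ai4r::Search::MCTS.new(**env)
#   mcts.search(0, 200)
```

States here are plain integers, so the callbacks stay free of side effects and can be checked in isolation before they are handed to the search.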

Defined Under Namespace

Classes: Node

Instance Method Summary

Methods included from Data::Parameterizable

#get_parameters, included, #set_parameters

Constructor Details

#initialize(actions_fn:, transition_fn:, terminal_fn:, reward_fn:, exploration: Math.sqrt(2)) ⇒ MCTS

Create a new search object.

actions_fn

returns available actions for a state

transition_fn

computes the next state given a state and action

terminal_fn

predicate to detect terminal states

reward_fn

numeric payoff for terminal states

exploration

exploration constant, presumably weighting the exploration term during UCT selection; defaults to Math.sqrt(2)



# File 'lib/ai4r/search/mcts.rb', line 48

def initialize(actions_fn:, transition_fn:, terminal_fn:, reward_fn:,
               exploration: Math.sqrt(2))
  @actions_fn = actions_fn
  @transition_fn = transition_fn
  @terminal_fn = terminal_fn
  @reward_fn = reward_fn
  @exploration = exploration
end
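
The default of Math.sqrt(2) for exploration matches the constant used in the textbook UCB1 rule, which scores a child by its average reward plus a bonus that shrinks as the child is visited more often. The selection code of this class is not shown on this page, so the helper below is only a sketch of that standard formula; the name `ucb1` and its parameters are assumptions, not part of the library.

```ruby
# Textbook UCB1 score (illustrative, not the library's own code):
#   average reward + exploration * sqrt(ln(parent_visits) / visits)
def ucb1(total_reward, visits, parent_visits, exploration = Math.sqrt(2))
  # An unvisited child gets an infinite score so it is tried first.
  return Float::INFINITY if visits.zero?

  average = total_reward / visits.to_f
  bonus = exploration * Math.sqrt(Math.log(parent_visits) / visits)
  average + bonus
end

# Two children with the same average reward (0.5): the one with fewer
# visits receives the larger exploration bonus.
rarely_visited = ucb1(1.0, 2, 100)
often_visited = ucb1(5.0, 10, 100)
```

Setting the constant to 0 removes the bonus and leaves the plain average, which is consistent with `search` calling `best_child(root, 0)` to pick the final action.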

Instance Method Details

#search(root_state, iterations) ⇒ Object

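Runs MCTS from root_state for the given number of iterations and returns the action judged best at the root. The result is read through the safe-navigation operator, so the method returns nil when no best child exists (for example, when the root has not been expanded).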



# File 'lib/ai4r/search/mcts.rb', line 59

def search(root_state, iterations)
  root = Node.new(root_state)
  iterations.times do
    node = tree_policy(root)
    reward = default_policy(node.state)
    backup(node, reward)
  end
  best_child(root, 0)&.action
end
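
The loop above calls tree_policy, default_policy, backup and best_child, whose bodies are not listed on this page. The following self-contained sketch shows one conventional way those four steps fit together (selection and expansion, random rollout, backpropagation, final choice). The class name `TinyMCTS`, its node layout and the `rng` parameter are hypothetical and may differ from the library's implementation.

```ruby
# Minimal UCT sketch; not the Ai4r implementation.
class TinyMCTS
  Node = Struct.new(:state, :parent, :action, :children, :untried, :visits, :total)

  def initialize(actions_fn:, transition_fn:, terminal_fn:, reward_fn:,
                 exploration: Math.sqrt(2), rng: Random.new(42))
    @actions_fn = actions_fn
    @transition_fn = transition_fn
    @terminal_fn = terminal_fn
    @reward_fn = reward_fn
    @exploration = exploration
    @rng = rng
  end

  def search(root_state, iterations)
    root = new_node(root_state, nil, nil)
    iterations.times do
      node = tree_policy(root)            # selection + expansion
      reward = default_policy(node.state) # random rollout
      backup(node, reward)                # backpropagation
    end
    best_child(root, 0)&.action           # exploit only for the final pick
  end

  private

  def new_node(state, parent, action)
    untried = @terminal_fn.call(state) ? [] : @actions_fn.call(state).dup
    Node.new(state, parent, action, [], untried, 0, 0.0)
  end

  # Descend by UCB1 until a node with an untried action or a terminal state.
  def tree_policy(node)
    until @terminal_fn.call(node.state)
      return expand(node) unless node.untried.empty?
      break if node.children.empty? # non-terminal state without actions

      node = best_child(node, @exploration)
    end
    node
  end

  def expand(node)
    action = node.untried.shift
    child = new_node(@transition_fn.call(node.state, action), node, action)
    node.children << child
    child
  end

  # Every child has at least one visit here: it is backed up right after expansion.
  def best_child(node, c)
    node.children.max_by do |ch|
      ch.total / ch.visits + c * Math.sqrt(Math.log(node.visits) / ch.visits)
    end
  end

  # Play uniformly random actions until a terminal state, then score it.
  def default_policy(state)
    until @terminal_fn.call(state)
      actions = @actions_fn.call(state)
      break if actions.empty?

      state = @transition_fn.call(state, actions.sample(random: @rng))
    end
    @reward_fn.call(state)
  end

  # Add the rollout reward to every node on the path back to the root.
  def backup(node, reward)
    while node
      node.visits += 1
      node.total += reward
      node = node.parent
    end
  end
end

env = {
  actions_fn: ->(s) { s == :root ? %i[a b] : [] },
  transition_fn: ->(_s, a) { a == :a ? :win : :lose },
  terminal_fn: ->(s) { %i[win lose].include?(s) },
  reward_fn: ->(s) { s == :win ? 1.0 : 0.0 }
}
best = TinyMCTS.new(**env).search(:root, 50)
```

On the two-leaf environment from the Overview, the child reached by :a always returns 1.0 and the child reached by :b always returns 0.0, so the final exploitation-only pick is :a regardless of the random seed.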