Class: Informers::Utils::Sampler

Inherits:
Object
  • Object
show all
Defined in:
lib/informers/utils/generation.rb

Direct Known Subclasses

BeamSearchSampler, GreedySampler

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(generation_config) ⇒ Sampler

Returns a new instance of Sampler.



76
77
78
79
# File 'lib/informers/utils/generation.rb', line 76

# Builds a sampler bound to a generation configuration.
#
# @param generation_config [#[]] the generation settings; later read by the
#   sampling helpers (e.g. +get_logits+ reads its "temperature" entry)
def initialize(generation_config)
  @generation_config = generation_config
  # The parent is plain Object, so the no-argument super is the only valid form.
  super()
end

Class Method Details

.get_sampler(generation_config) ⇒ Object



105
106
107
108
109
110
111
112
113
114
115
116
# File 'lib/informers/utils/generation.rb', line 105

# Picks the sampler implementation that matches a generation configuration.
#
# Precedence: sampling (+:do_sample+) wins over beam search (+:num_beams+ > 1);
# otherwise greedy search is used, which only supports a single returned
# sequence. NOTE(review): MultinomialSampler is not among the direct known
# subclasses listed for Sampler on this page — confirm it is defined.
#
# @param generation_config [#[]] settings read via :do_sample, :num_beams and
#   :num_return_sequences
# @return [Object] a MultinomialSampler, BeamSearchSampler or GreedySampler
# @raise [Error] when greedy search is selected but more than one return
#   sequence is requested
def self.get_sampler(generation_config)
  return MultinomialSampler.new(generation_config) if generation_config[:do_sample]
  return BeamSearchSampler.new(generation_config) if generation_config[:num_beams] > 1

  requested = generation_config[:num_return_sequences]
  if requested > 1
    raise Error, "num_return_sequences has to be 1 when doing greedy search, but is #{requested}."
  end

  GreedySampler.new(generation_config)
end

Instance Method Details

#call(logits, index = -1) ⇒ Object



81
82
83
84
85
# File 'lib/informers/utils/generation.rb', line 81

# Entry point: samples from +logits+, delegating directly to +sample+.
#
# Per the original notes, +logits+ have dims [batch, sequence_length,
# vocab_size]; when +index+ is given, sampling uses [batch, index, vocab_size].
#
# @param logits [Object] the logits to sample from
# @param index [Integer] sequence position to sample at (defaults to -1)
# @return [Object] whatever +sample+ returns
def call(logits, index = -1)
  sample(logits, index)
end

#get_logits(logits, index) ⇒ Object



87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# File 'lib/informers/utils/generation.rb', line 87

# Extracts the logits row for one sequence position and applies temperature.
#
# The nested +logits+ are flattened, and a slice of +vocab_size+ values is
# taken, where +vocab_size+ is the size of the last dimension.
# NOTE(review): this assumes logits are [batch, sequence_length, vocab_size]
# with batch == 1. With a larger batch, a non-negative +index+ selects a row
# from the first batch item, and -1 selects the last row overall.
#
# @param logits [Array] nested logits whose last dimension is the vocabulary
# @param index [Integer] row (sequence position) to select; negative values
#   count from the end, -1 being the last row
# @return [Array] the selected row, divided by the configured temperature
#   when that temperature is positive
# @raise [IndexError] if +index+ is outside the available rows
def get_logits(logits, index)
  vocab_size = Utils.dims(logits)[-1]
  logs = logits.flatten

  if index == -1
    logs = logs.last(vocab_size)
  else
    # Previously unimplemented (raised Todo): pick row +index+ of the flattened data.
    num_rows = logs.size / vocab_size
    row = index.negative? ? index + num_rows : index
    unless row.between?(0, num_rows - 1)
      raise IndexError, "index #{index} out of range for #{num_rows} rows"
    end
    logs = logs[row * vocab_size, vocab_size]
  end

  # Temperature scaling; a non-positive temperature leaves the logits unchanged.
  temperature = @generation_config["temperature"]
  logs = logs.map { |x| x / temperature } if temperature > 0
  logs
end