Class: Informers::Utils::Sampler
- Inherits:
-
Object
- Object
- Informers::Utils::Sampler
show all
- Defined in:
- lib/informers/utils/generation.rb
Class Method Summary
collapse
Instance Method Summary
collapse
Constructor Details
#initialize(generation_config) ⇒ Sampler
Returns a new instance of Sampler.
76
77
78
79
|
# File 'lib/informers/utils/generation.rb', line 76
# Builds a sampler bound to +generation_config+.
#
# The config is kept in @generation_config so that instance methods
# (e.g. #get_logits, which reads its "temperature" entry) can consult it.
#
# @param generation_config [#[]] generation settings, indexed with #[]
def initialize(generation_config)
  @generation_config = generation_config
  super()
end
|
Class Method Details
.get_sampler(generation_config) ⇒ Object
105
106
107
108
109
110
111
112
113
114
115
116
|
# File 'lib/informers/utils/generation.rb', line 105
# Chooses and instantiates the sampler implied by +generation_config+.
#
# Selection order:
#   1. :do_sample truthy        -> MultinomialSampler
#   2. :num_beams greater than 1 -> BeamSearchSampler
#   3. otherwise                -> GreedySampler
#
# @param generation_config [#[]] generation settings, indexed with symbol keys
# @return [Object] a new sampler built from +generation_config+
# @raise [Error] when greedy search is selected but :num_return_sequences > 1
def self.get_sampler(generation_config)
  return MultinomialSampler.new(generation_config) if generation_config[:do_sample]
  return BeamSearchSampler.new(generation_config) if generation_config[:num_beams] > 1

  # Greedy decoding yields exactly one sequence, so asking for more is an error.
  requested = generation_config[:num_return_sequences]
  if requested > 1
    raise Error, "num_return_sequences has to be 1 when doing greedy search, but is #{requested}."
  end

  GreedySampler.new(generation_config)
end
|
Instance Method Details
#call(logits, index = -1) ⇒ Object
81
82
83
84
85
|
# File 'lib/informers/utils/generation.rb', line 81
# Samples from +logits+ by delegating to #sample.
#
# #sample is not defined in this part of Sampler; presumably each concrete
# sampler subclass supplies it (NOTE(review): confirm).
#
# @param logits [Object] the logits to sample from
# @param index [Integer] which position to sample; defaults to -1
# @return [Object] whatever #sample returns
def call(logits, index = -1)
  self.sample(logits, index)
end
|
#get_logits(logits, index) ⇒ Object
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
|
# File 'lib/informers/utils/generation.rb', line 87
# Extracts one row of logits and applies temperature scaling.
#
# +logits+ is flattened, and the flat data is treated as consecutive rows of
# length +vocab_size+ (the last entry of Utils.dims(logits)). +index+ selects
# a row of that flattened view; -1 (the common case) is the final row, and
# other negative values count back from the end.
# NOTE(review): presumably logits is [batch, seq, vocab] — confirm with callers.
#
# @param logits [Array] nested logits whose last dimension is the vocabulary
# @param index [Integer] row of the flattened [*, vocab_size] view to extract
# @return [Array<Numeric>] the selected row, divided by the configured
#   "temperature" when that value is positive
# @raise [IndexError] when +index+ does not address a complete row
def get_logits(logits, index)
  vocab_size = Utils.dims(logits)[-1]
  logs = logits.flatten

  if index == -1
    logs = logs.last(vocab_size)
  else
    # Row +index+ starts at index * vocab_size; negative starts count from the end.
    row = logs[index * vocab_size, vocab_size]
    if row.nil? || row.size < vocab_size
      raise IndexError, "logits index #{index} is out of range"
    end
    logs = row
  end

  # Non-positive temperatures leave the logits unscaled.
  temperature = @generation_config["temperature"]
  logs = logs.map { |x| x / temperature } if temperature > 0
  logs
end
|