Class: Transformers::Bert::BertAttention

Inherits:
Torch::NN::Module
  • Object
Defined in:
lib/transformers/models/bert/modeling_bert.rb

Instance Method Summary

  #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
  #initialize(config, position_embedding_type: nil) ⇒ BertAttention constructor

Constructor Details

#initialize(config, position_embedding_type: nil) ⇒ BertAttention

Returns a new instance of BertAttention.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 222

def initialize(config, position_embedding_type: nil)
  super()
  # Pick the self-attention implementation (e.g. eager or SDPA) selected by
  # the config's _attn_implementation.
  @self = BERT_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(
    config, position_embedding_type: position_embedding_type
  )
  # Dense projection, dropout, and residual layer norm over the attention output.
  @output = BertSelfOutput.new(config)
  # Tracks attention heads pruned from this layer.
  @pruned_heads = Set.new
end
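
As a usage sketch (not part of this file), the layer can be constructed directly from a BertConfig. The config values below are illustrative assumptions; in practice the config comes from a pretrained checkpoint and the model builds this layer for you:

require "transformers"

# Illustrative values; keyword names assumed to mirror the standard BERT config.
config = Transformers::Bert::BertConfig.new(
  hidden_size: 768,
  num_attention_heads: 12
)

# _attn_implementation is expected to default to "eager" when not set explicitly.
attention = Transformers::Bert::BertAttention.new(config)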

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object



# File 'lib/transformers/models/bert/modeling_bert.rb', line 231

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  # Run the configured self-attention. When encoder_hidden_states is given,
  # the layer attends to the encoder states (cross-attention) instead of
  # attending to hidden_states alone.
  self_outputs = @self.(
    hidden_states,
    attention_mask: attention_mask,
    head_mask: head_mask,
    encoder_hidden_states: encoder_hidden_states,
    encoder_attention_mask: encoder_attention_mask,
    past_key_value: past_key_value,
    output_attentions: output_attentions
  )
  # Project the attention context and apply the residual connection and
  # layer norm around the original hidden_states.
  attention_output = @output.(self_outputs[0], hidden_states)
  outputs = [attention_output] + self_outputs[1..]  # add attentions if we output them
  outputs
end
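
Continuing the construction sketch above (tensor shapes assumed, not taken from this file): like any Torch::NN::Module, the layer is invoked with .(), and the returned array leads with the attention output, followed by the attention weights when output_attentions is true:

# Illustrative shape: [batch, seq_len, hidden_size], matching config.hidden_size.
hidden_states = Torch.randn(1, 16, 768)

outputs = attention.(hidden_states, output_attentions: true)
attention_output = outputs[0]  # same shape as hidden_states
attention_probs = outputs[1]   # attention weights, present because output_attentions: true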