Class: Transformers::Bert::BertAttention
- Inherits: Torch::NN::Module
  - Object
    - Torch::NN::Module
      - Transformers::Bert::BertAttention
- Defined in: lib/transformers/models/bert/modeling_bert.rb
Instance Method Summary
- #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
- #initialize(config, position_embedding_type: nil) ⇒ BertAttention (constructor): A new instance of BertAttention.
Constructor Details
#initialize(config, position_embedding_type: nil) ⇒ BertAttention
Returns a new instance of BertAttention.
Source (lib/transformers/models/bert/modeling_bert.rb, lines 222–229):
# File 'lib/transformers/models/bert/modeling_bert.rb', line 222
# Builds the attention block from two sub-modules:
# - a self-attention sub-module whose class is looked up by
#   +config._attn_implementation+;
# - a BertSelfOutput built from the same config.
#
# @param config [Object] model configuration. It must respond to
#   +_attn_implementation+ and is passed to both sub-module constructors.
# @param position_embedding_type [Object, nil] forwarded unchanged to the
#   self-attention constructor.
# @raise [KeyError] when the implementation key is missing from
#   BERT_SELF_ATTENTION_CLASSES (assumes that constant is a Hash — confirm).
def initialize(config, position_embedding_type: nil)
  super()

  impl_key = config._attn_implementation
  attention_class = BERT_SELF_ATTENTION_CLASSES.fetch(impl_key)

  # NOTE(review): keep the @self -> @output -> @pruned_heads order unchanged;
  # sub-module discovery may depend on instance-variable order — confirm.
  @self = attention_class.new(config, position_embedding_type: position_embedding_type)
  @output = BertSelfOutput.new(config)
  # Starts empty; nothing visible in this class adds to it (presumably filled by head pruning).
  @pruned_heads = Set.new
end
Instance Method Details
#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
Source (lib/transformers/models/bert/modeling_bert.rb, lines 231–252):
# File 'lib/transformers/models/bert/modeling_bert.rb', line 231
# Runs the self-attention sub-module on +hidden_states+. Its first result,
# together with the original +hidden_states+, then goes through the output
# sub-module.
#
# @param hidden_states [Object] input to self-attention; also passed to the
#   output sub-module as its second argument.
# @param attention_mask [Object, nil] forwarded unchanged to self-attention.
# @param head_mask [Object, nil] forwarded unchanged to self-attention.
# @param encoder_hidden_states [Object, nil] forwarded unchanged to self-attention.
# @param encoder_attention_mask [Object, nil] forwarded unchanged to self-attention.
# @param past_key_value [Object, nil] forwarded unchanged to self-attention.
# @param output_attentions [Boolean] forwarded unchanged to self-attention.
# @return [Array] the output sub-module's result, followed by every element
#   self-attention returned after its first one.
def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  attention_kwargs = {
    attention_mask: attention_mask,
    head_mask: head_mask,
    encoder_hidden_states: encoder_hidden_states,
    encoder_attention_mask: encoder_attention_mask,
    past_key_value: past_key_value,
    output_attentions: output_attentions
  }
  self_attention_result = @self.call(hidden_states, **attention_kwargs)

  projected = @output.call(self_attention_result[0], hidden_states)
  # Extra self-attention outputs (attentions, if requested) follow the projection.
  [projected] + self_attention_result[1..]
end