Class: Transformers::XlmRoberta::XLMRobertaSdpaSelfAttention
- Inherits: XLMRobertaSelfAttention
  - Object
  - Torch::NN::Module
  - XLMRobertaSelfAttention
  - Transformers::XlmRoberta::XLMRobertaSdpaSelfAttention
- Defined in: lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Method Summary
- #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
  Adapted from XLMRobertaSelfAttention.
- #initialize(config, position_embedding_type: nil) ⇒ XLMRobertaSdpaSelfAttention (constructor)
  A new instance of XLMRobertaSdpaSelfAttention.
Methods inherited from XLMRobertaSelfAttention: #transpose_for_scores
Constructor Details
#initialize(config, position_embedding_type: nil) ⇒ XLMRobertaSdpaSelfAttention
Returns a new instance of XLMRobertaSdpaSelfAttention.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 240

def initialize(config, position_embedding_type: nil)
  super(config, position_embedding_type: position_embedding_type)
  @dropout_prob = config.attention_probs_dropout_prob
  @require_contiguous_qkv = Packaging::Version.parse(Utils.get_torch_version) < Packaging::Version.parse("2.2.0")
end
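As a minimal construction sketch (assuming Transformers::XlmRoberta::XLMRobertaConfig accepts these keyword arguments; the values below are illustrative, not verified defaults), the module can be built directly, although in practice it is created for you inside the encoder layers:

require "transformers"

# Illustrative config values; hidden_size must be divisible by num_attention_heads.
config = Transformers::XlmRoberta::XLMRobertaConfig.new(
  hidden_size: 768,
  num_attention_heads: 12,
  attention_probs_dropout_prob: 0.1
)

attention = Transformers::XlmRoberta::XLMRobertaSdpaSelfAttention.new(
  config,
  position_embedding_type: "absolute"
)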
Instance Method Details
#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
Adapted from XLMRobertaSelfAttention.

# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 247

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  if @position_embedding_type != "absolute" || output_attentions || !head_mask.nil?
    # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
    Transformers.logger.warn("XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support non-absolute `position_embedding_type` or `output_attentions: true` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation: \"eager\"` when loading the model.")
    return super(hidden_states, attention_mask: attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_value: past_key_value, output_attentions: output_attentions)
  end

  bsz, tgt_len, _ = hidden_states.size

  query_layer = transpose_for_scores(@query.(hidden_states))

  # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
  # mask needs to be such that the encoder's padding tokens are not attended to.
  is_cross_attention = !encoder_hidden_states.nil?

  current_states = is_cross_attention ? encoder_hidden_states : hidden_states
  attention_mask = is_cross_attention ? encoder_attention_mask : attention_mask

  # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
  if is_cross_attention && past_key_value && past_key_value[0].shape[2] == current_states.shape[1]
    key_layer, value_layer = past_key_value
  else
    key_layer = transpose_for_scores(@key.(current_states))
    value_layer = transpose_for_scores(@value.(current_states))
    if !past_key_value.nil? && !is_cross_attention
      key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
      value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
    end
  end

  if @is_decoder
    # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
    # Further calls to cross_attention layer can then reuse all cross-attention
    # key/value_states (first "if" case)
    # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
    # all previous decoder key/value_states. Further calls to uni-directional self-attention
    # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
    # if encoder bi-directional self-attention `past_key_value` is always `None`
    past_key_value = [key_layer, value_layer]
  end

  # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
  # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
  # Reference: https://github.com/pytorch/pytorch/issues/112577
  if @require_contiguous_qkv && query_layer.device.type == "cuda" && !attention_mask.nil?
    query_layer = query_layer.contiguous
    key_layer = key_layer.contiguous
    value_layer = value_layer.contiguous
  end

  # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an
  # inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options.
  # An inline conditional prevents dynamic shapes from compiling.
  # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
  # a causal mask in case tgt_len == 1.
  is_causal = @is_decoder && !is_cross_attention && attention_mask.nil? && tgt_len > 1 ? true : false

  attn_output = Torch::NN::Functional.scaled_dot_product_attention(
    query_layer,
    key_layer,
    value_layer,
    attn_mask: attention_mask,
    dropout_p: @training ? @dropout_prob : 0.0,
    is_causal: is_causal
  )

  attn_output = attn_output.transpose(1, 2)
  attn_output = attn_output.reshape(bsz, tgt_len, @all_head_size)

  outputs = [attn_output]
  if @is_decoder
    outputs = outputs + [past_key_value]
  end
  outputs
end
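For a rough usage sketch of the SDPA path (reusing the config and attention objects from the constructor example above; the tensor shapes are illustrative), calling forward with only hidden states, absolute position embeddings, no head_mask, and output_attentions: false keeps execution on scaled_dot_product_attention rather than the parent's manual fallback:

attention.eval  # disable dropout so dropout_p is 0.0 inside SDPA

# batch of 2 sequences, 16 tokens each, hidden_size-dimensional features
hidden_states = Torch.randn(2, 16, config.hidden_size)

outputs = attention.forward(hidden_states)
attn_output = outputs.first
p attn_output.shape  # => [2, 16, 768] with the illustrative config above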