Class: Transformers::XlmRoberta::XLMRobertaSdpaSelfAttention

Inherits:
XLMRobertaSelfAttention
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Method Summary

Methods inherited from XLMRobertaSelfAttention

#transpose_for_scores

Constructor Details

#initialize(config, position_embedding_type: nil) ⇒ XLMRobertaSdpaSelfAttention

Returns a new instance of XLMRobertaSdpaSelfAttention.



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 240

def initialize(config, position_embedding_type: nil)
  super(config, position_embedding_type: position_embedding_type)
  @dropout_prob = config.attention_probs_dropout_prob
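  # torch < 2.2.0 needs contiguous q/k/v when SDPA receives a custom attn_mask on CUDA (see the workaround in #forward)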
  @require_contiguous_qkv = Packaging::Version.parse(Utils.get_torch_version) < Packaging::Version.parse("2.2.0")
end
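
A minimal construction sketch (not part of this file): it assumes the gem's Transformers::XlmRoberta::XLMRobertaConfig accepts keyword overrides the way its Python counterpart does, and the values shown are illustrative. Real code normally obtains the config from a pretrained checkpoint rather than building one by hand.

require "transformers"

# Illustrative config values; a real model loads these from the checkpoint's config.json.
config = Transformers::XlmRoberta::XLMRobertaConfig.new(
  hidden_size: 768,
  num_attention_heads: 12,
  attention_probs_dropout_prob: 0.1
)

# position_embedding_type: "absolute" keeps the SDPA fast path in #forward available.
attention = Transformers::XlmRoberta::XLMRobertaSdpaSelfAttention.new(
  config,
  position_embedding_type: "absolute"
)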

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object

Adapted from XLMRobertaSelfAttention



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 247

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  if @position_embedding_type != "absolute" || output_attentions || !head_mask.nil?
    # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
    Transformers.logger.warn("XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support non-absolute `position_embedding_type` or `output_attentions: true` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation: \"eager\"` when loading the model.")
    return super(hidden_states, attention_mask: attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_value: past_key_value, output_attentions: output_attentions)
  end

  bsz, tgt_len, _ = hidden_states.size

  query_layer = transpose_for_scores(@query.(hidden_states))

  # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
  # mask needs to be such that the encoder's padding tokens are not attended to.
  is_cross_attention = !encoder_hidden_states.nil?

  current_states = is_cross_attention ? encoder_hidden_states : hidden_states
  attention_mask = is_cross_attention ? encoder_attention_mask : attention_mask

  # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
  if is_cross_attention && past_key_value && past_key_value[0].shape[2] == current_states.shape[1]
    key_layer, value_layer = past_key_value
  else
    key_layer = transpose_for_scores(@key.(current_states))
    value_layer = transpose_for_scores(@value.(current_states))
    if !past_key_value.nil? && !is_cross_attention
      key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
      value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
    end
  end

  if @is_decoder
    # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
    # Further calls to cross_attention layer can then reuse all cross-attention
    # key/value_states (first "if" case)
    # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
    # all previous decoder key/value_states. Further calls to uni-directional self-attention
    # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
    # if encoder bi-directional self-attention `past_key_value` is always `None`
    past_key_value = [key_layer, value_layer]
  end

  # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
  # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
  # Reference: https://github.com/pytorch/pytorch/issues/112577
  if @require_contiguous_qkv && query_layer.device.type == "cuda" && !attention_mask.nil?
    query_layer = query_layer.contiguous
    key_layer = key_layer.contiguous
    value_layer = value_layer.contiguous
  end

  # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
  # a causal mask in case tgt_len == 1.
  is_causal = @is_decoder && !is_cross_attention && attention_mask.nil? && tgt_len > 1 ? true : false

  attn_output = Torch::NN::Functional.scaled_dot_product_attention(
    query_layer,
    key_layer,
    value_layer,
    attn_mask: attention_mask,
    dropout_p: @training ? @dropout_prob : 0.0,
    is_causal: is_causal
  )

  attn_output = attn_output.transpose(1, 2)
  attn_output = attn_output.reshape(bsz, tgt_len, @all_head_size)

  outputs = [attn_output]
  if @is_decoder
    outputs = outputs + [past_key_value]
  end
  outputs
end
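
A minimal call sketch, continuing the construction example above; the tensor shapes are illustrative. With no attention mask, head mask, or encoder states, and an "absolute" position embedding type, the call takes the scaled_dot_product_attention path, and a non-decoder module returns a single-element array holding the attention output.

hidden_states = Torch.randn(2, 16, 768) # [batch_size, seq_len, hidden_size]

outputs = attention.(hidden_states)     # Module#call dispatches to #forward
attn_output = outputs[0]
attn_output.shape                       # => [2, 16, 768], i.e. [batch_size, seq_len, all_head_size]

As the warning in the source notes, requesting output_attentions: true, passing a head_mask, or using a non-absolute position_embedding_type falls back to the parent's eager implementation; the eager path can also be selected up front with attn_implementation: "eager" when loading the model.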