# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 410
def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil,
  **kwargs
)
  # Fall back to the model config when output options are not given explicitly
  output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
  output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  # Exactly one of input_ids and inputs_embeds must be provided
  if !input_ids.nil? && !inputs_embeds.nil?
    raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
  elsif !input_ids.nil?
    warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    input_shape = input_ids.size
  elsif !inputs_embeds.nil?
    input_shape = inputs_embeds.size[...-1]
  else
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end

  device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

  # Default to attending to every token when no mask is supplied
  if attention_mask.nil?
    attention_mask = Torch.ones(input_shape, device: device)
  end

  # Broadcast the 2D padding mask to the shape expected by the attention layers
  # and resolve the per-layer head mask
  extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
  head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

  # Embed the inputs, run the encoder stack, then optionally pool the first token
  embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, inputs_embeds: inputs_embeds)
  encoder_outputs = @encoder.(
    embedding_output,
    attention_mask: extended_attention_mask,
    head_mask: head_mask,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  sequence_output = encoder_outputs[0]
  pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

  if !return_dict
    return [sequence_output, pooled_output] + encoder_outputs[1..]
  end

  BaseModelOutputWithPooling.new(
    last_hidden_state: sequence_output,
    pooler_output: pooled_output,
    hidden_states: encoder_outputs.hidden_states,
    attentions: encoder_outputs.attentions
  )
end
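
A minimal usage sketch of this method, assuming the gem exposes Transformers::AutoTokenizer and Transformers::AutoModel with from_pretrained mirroring the Python library, that tokenizers are callable and return Torch tensors via return_tensors: "pt", and that calling the model dispatches to #forward; the checkpoint name is illustrative, not part of this file.

require "transformers-rb"

# Assumption: AutoTokenizer/AutoModel mirror the Python transformers API
tokenizer = Transformers::AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = Transformers::AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

# Tokenize a sentence and run it through MPNetModel#forward
enc = tokenizer.("MPNet combines masked and permuted language modeling.", return_tensors: "pt")
output = model.(input_ids: enc[:input_ids], attention_mask: enc[:attention_mask])

sequence_output = output.last_hidden_state  # per-token hidden states [batch, seq_len, hidden_size]
pooled_output   = output.pooler_output      # pooled first-token representation [batch, hidden_size]

With return_dict: false passed to the call, the return value would instead be an array whose first two elements are the sequence output and the pooled output.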