Method: Transformers::Mpnet::MPNetForTokenClassification#forward

Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ TokenClassifierOutput, Array
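
A minimal usage sketch, not from the library's own docs: the checkpoint name is
hypothetical (substitute a real MPNet token-classification fine-tune), and the
token ids are toy values rather than real tokenizer output.

model = Transformers::Mpnet::MPNetForTokenClassification.from_pretrained(
  "example/mpnet-ner" # hypothetical checkpoint
)

input_ids = Torch.tensor([[0, 2023, 2003, 2019, 2742, 2]]) # toy token ids
attention_mask = Torch.ones_like(input_ids)

output = model.forward(input_ids: input_ids, attention_mask: attention_mask)
output.logits.shape # => [1, 6, num_labels]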



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 672

def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  # Run the base MPNet encoder
  outputs = @mpnet.(
    input_ids,
    attention_mask: attention_mask,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  # Per-token hidden states: [batch_size, sequence_length, hidden_size]
  sequence_output = outputs[0]

  sequence_output = @dropout.(sequence_output)
  # Score each token against every label: [batch_size, sequence_length, num_labels]
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
    # Flatten tokens across the batch and compute token-level cross-entropy
    loss_fct = Torch::NN::CrossEntropyLoss.new
    loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
  end

  if !return_dict
    # Tuple-style output: logits first, then any extra encoder outputs;
    # the loss is prepended when it was computed
    output = [logits] + outputs[2..]
    return !loss.nil? ? [loss] + output : output
  end

  TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
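
A sketch of the loss path and the tuple-style return, continuing the
hypothetical model and inputs above: when labels are given, the logits are
flattened to [batch_size * sequence_length, num_labels] and scored with
cross-entropy, and with return_dict: false the method returns an array
([loss, logits, ...]) instead of a TokenClassifierOutput.

labels = Torch.tensor([[0, 1, 0, 2, 0, 0]]) # one label id per token

loss, logits = model.forward(
  input_ids: input_ids,
  attention_mask: attention_mask,
  labels: labels,
  return_dict: false
)
loss.item # scalar cross-entropy averaged over all six token positions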