Class: Transformers::XlmRoberta::XLMRobertaForSequenceClassification

Inherits:
XLMRobertaPreTrainedModel
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from XLMRobertaPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ XLMRobertaForSequenceClassification

Returns a new instance of XLMRobertaForSequenceClassification.



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 953

def initialize(config)
  super(config)
  @num_labels = config.num_labels
  @config = config

  @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
  @classifier = XLMRobertaClassificationHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
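
A minimal usage sketch, assuming the gem is already loaded: the classifier is typically obtained through from_pretrained, which this class inherits from PreTrainedModel (see the method summary above). The checkpoint name below is only illustrative.

# "xlm-roberta-base" stands in for any XLM-RoBERTa checkpoint; a fresh
# classification head is initialized when the checkpoint has none.
model = Transformers::XlmRoberta::XLMRobertaForSequenceClassification.from_pretrained(
  "xlm-roberta-base"
)
model.eval # inference mode, via Torch::NN::Module#eval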

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object


# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 965

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @roberta.(
    input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  sequence_output = outputs[0]
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
    # move labels to correct device to enable model parallelism
    labels = labels.to(logits.device)
    if @config.problem_type.nil?
      if @num_labels == 1
        @config.problem_type = "regression"
      elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
        @config.problem_type = "single_label_classification"
      else
        @config.problem_type = "multi_label_classification"
      end
    end

    if @config.problem_type == "regression"
      loss_fct = Torch::NN::MSELoss.new
      if @num_labels == 1
        loss = loss_fct.(logits.squeeze, labels.squeeze)
      else
        loss = loss_fct.(logits, labels)
      end
    elsif @config.problem_type == "single_label_classification"
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    elsif @config.problem_type == "multi_label_classification"
      loss_fct = Torch::NN::BCEWithLogitsLoss.new
      loss = loss_fct.(logits, labels)
    end
  end

  if !return_dict
    output = [logits] + outputs[2..]
    return !loss.nil? ? [loss] + output : output
  end

  SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
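
A minimal inference sketch for #forward, assuming a model loaded as in the constructor example above. The token ids are made up for illustration; real ids would come from an XLM-RoBERTa tokenizer.

# illustrative inputs: one sequence of hypothetical token ids and one integer label
input_ids = Torch.tensor([[0, 31414, 232, 2]])
attention_mask = Torch.ones_like(input_ids)
labels = Torch.tensor([1])

output = Torch.no_grad do
  model.forward(input_ids: input_ids, attention_mask: attention_mask, labels: labels)
end

output.logits # [batch_size, num_labels] classification scores
output.loss   # cross-entropy loss, since integer labels were supplied

When config.problem_type is unset, the method infers it from the labels: integer labels with num_labels > 1 select "single_label_classification" (CrossEntropyLoss), float labels select "multi_label_classification" (BCEWithLogitsLoss), and num_labels == 1 selects "regression" (MSELoss).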