Class: Transformers::XlmRoberta::XLMRobertaForSequenceClassification
- Inherits: XLMRobertaPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - XLMRobertaPreTrainedModel
  - Transformers::XlmRoberta::XLMRobertaForSequenceClassification
- Defined in: lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ XLMRobertaForSequenceClassification (constructor)
  A new instance of XLMRobertaForSequenceClassification.
Methods inherited from XLMRobertaPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ XLMRobertaForSequenceClassification
Returns a new instance of XLMRobertaForSequenceClassification.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 953

def initialize(config)
  super(config)
  @num_labels = config.num_labels
  @config = config

  @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
  @classifier = XLMRobertaClassificationHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
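As a usage sketch (not part of the generated API docs): you would rarely call .new directly. The inherited from_pretrained class method builds the config, runs this constructor, and then loads the checkpoint weights. The checkpoint id and the tokenizer class below are illustrative assumptions, not values taken from this page.

# Assumes the transformers-rb gem (and its torch-rb dependency) is installed and loaded.

# Hypothetical checkpoint id; any XLM-RoBERTa sequence-classification checkpoint should work the same way.
model_id = "cardiffnlp/twitter-xlm-roberta-base-sentiment"

model = Transformers::XlmRoberta::XLMRobertaForSequenceClassification.from_pretrained(model_id)
tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id) # assumed companion tokenizer API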
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 965

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
  sequence_output = outputs[0]
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
    # move labels to correct device to enable model parallelism
    labels = labels.to(logits.device)

    # infer the problem type on first use so the matching loss is selected below
    if @config.problem_type.nil?
      if @num_labels == 1
        @config.problem_type = "regression"
      elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
        @config.problem_type = "single_label_classification"
      else
        @config.problem_type = "multi_label_classification"
      end
    end

    if @config.problem_type == "regression"
      loss_fct = Torch::NN::MSELoss.new
      if @num_labels == 1
        loss = loss_fct.(logits.squeeze, labels.squeeze)
      else
        loss = loss_fct.(logits, labels)
      end
    elsif @config.problem_type == "single_label_classification"
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    elsif @config.problem_type == "multi_label_classification"
      loss_fct = Torch::NN::BCEWithLogitsLoss.new
      loss = loss_fct.(logits, labels)
    end
  end

  if !return_dict
    output = [logits] + outputs[2..]
    return !loss.nil? ? [loss] + output : output
  end

  SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
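A minimal sketch of calling #forward, continuing the hypothetical model and tokenizer loaded above. The tokenizer call style, the encoding lookup, and the tensor helpers (argmax, item, id2label) are assumptions drawn from the wider transformers-rb / torch-rb APIs rather than from this page; verify against your installed versions. Note that when labels: is passed, the method also returns a loss, choosing MSELoss, CrossEntropyLoss, or BCEWithLogitsLoss according to config.problem_type as shown in the source above.

# Sketch only: lines marked "assumed" are not documented on this page.
enc = tokenizer.("XLM-RoBERTa handles many languages.", return_tensors: "pt") # assumed tokenizer call style

output = nil
Torch.no_grad do
  output = model.forward(
    input_ids: enc[:input_ids],           # assumed encoding lookup
    attention_mask: enc[:attention_mask]
  )
end

logits = output.logits                    # SequenceClassifierOutput#logits, shape [batch, num_labels]
pred = logits.argmax(-1).item             # highest-scoring class index (assumed torch.rb helpers)
puts model.config.id2label ? model.config.id2label[pred] : pred # label name if the checkpoint provides one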