Class: Transformers::XlmRoberta::XLMRobertaForTokenClassification
- Inherits: XLMRobertaPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - XLMRobertaPreTrainedModel
  - Transformers::XlmRoberta::XLMRobertaForTokenClassification
- Defined in:
- lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ XLMRobertaForTokenClassification (constructor)
  A new instance of XLMRobertaForTokenClassification.
Methods inherited from XLMRobertaPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ XLMRobertaForTokenClassification
Returns a new instance of XLMRobertaForTokenClassification.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 1080

def initialize(config)
  super(config)
  @num_labels = config.num_labels

  @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
  classifier_dropout = !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
  @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
  @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
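As the constructor shows, the head wraps an XLMRobertaModel backbone (built without a pooling layer) in dropout plus a linear layer sized by config.num_labels, falling back to hidden_dropout_prob when classifier_dropout is unset. A minimal loading sketch using the inherited from_pretrained class method; the Hub checkpoint id is illustrative:

# Loads weights and config; num_labels, dropout, and hidden size are read
# from the checkpoint's config by #initialize above.
model = Transformers::XlmRoberta::XLMRobertaForTokenClassification.from_pretrained("FacebookAI/xlm-roberta-base")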
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 1093

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @roberta.(
    input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]

  sequence_output = @dropout.(sequence_output)
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
    # move labels to correct device to enable model parallelism
    labels = labels.to(logits.device)
    loss_fct = Torch::NN::CrossEntropyLoss.new
    loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
  end

  if !return_dict
    output = [logits] + outputs[2..]
    return !loss.nil? ? [loss] + output : output
  end

  TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
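A usage sketch, assuming a tokenizer loaded via Transformers::AutoTokenizer and hash-style access to its output; the checkpoint id, tokenizer call, and all-zero label tensor are illustrative assumptions, not part of this class. When labels are given, forward computes a cross-entropy loss over every token position; otherwise only the per-token logits are returned:

tokenizer = Transformers::AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
model = Transformers::XlmRoberta::XLMRobertaForTokenClassification.from_pretrained("FacebookAI/xlm-roberta-base")

encoded = tokenizer.("Transformers for Ruby", return_tensors: "pt")

# Inference only: output.logits has shape [batch, sequence_length, num_labels]
output = model.(**encoded)
logits = output.logits

# With labels (one class id per token, integer dtype); all zeros purely for illustration
labels = Torch.zeros_like(encoded[:input_ids])
output = model.(**encoded, labels: labels)
loss = output.loss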