Class: Transformers::DebertaV2::DebertaV2ForTokenClassification
- Inherits: DebertaV2PreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - DebertaV2PreTrainedModel
  - Transformers::DebertaV2::DebertaV2ForTokenClassification
- Defined in:
- lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ DebertaV2ForTokenClassification (constructor): A new instance of DebertaV2ForTokenClassification.
Methods inherited from DebertaV2PreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ DebertaV2ForTokenClassification
Returns a new instance of DebertaV2ForTokenClassification.
```ruby
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1025
def initialize(config)
  super(config)
  @num_labels = config.num_labels

  @deberta = DebertaV2Model.new(config)
  @dropout = Torch::NN::Dropout.new(config.hidden_dropout_prob)
  @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
```
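The head built in #initialize applies dropout and then a Torch::NN::Linear(hidden_size, num_labels) layer to every token's hidden vector independently. The pure-Ruby sketch below (no torch-rb dependency) shows what that projection computes for a single sequence. It assumes inference mode, where dropout is the identity, so dropout is omitted; the token_logits helper and the sample numbers are hypothetical.

```ruby
# Pure-Ruby sketch of the per-token linear head: logits[t] = W * hidden[t] + b.
# W is shaped [num_labels][hidden_size], matching Torch::NN::Linear's weight layout.
# ASSUMPTION: inference mode, so dropout (the identity at eval time) is omitted.

# Hypothetical helper.
# hidden_states: [seq_len][hidden_size] (one sequence, no batch dimension)
# weight:        [num_labels][hidden_size]
# bias:          [num_labels]
# Returns logits shaped [seq_len][num_labels].
def token_logits(hidden_states, weight, bias)
  hidden_states.map do |token_vec|
    weight.each_with_index.map do |row, label|
      row.zip(token_vec).sum { |w, x| w * x } + bias[label]
    end
  end
end

hidden = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]] # 2 tokens, hidden_size = 3
weight = [[1.0, 1.0, 1.0], [0.0, 2.0, -1.0]] # num_labels = 2
bias   = [0.0, 0.5]

logits = token_logits(hidden, weight, bias)
p logits # => [[3.0, -1.5], [-0.5, -1.5]]
```

Because the same weights are applied at every position, each token receives its own row of num_labels scores, which is what makes this a token-level (rather than sequence-level) classifier.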
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
```ruby
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1037
def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  # Fall back to the config default when return_dict is not given explicitly
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @deberta.(
    input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  # Per-token hidden states -> dropout -> linear classifier = per-token logits
  sequence_output = outputs[0]
  sequence_output = @dropout.(sequence_output)
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
    # Flatten batch and sequence dimensions so each token is one classification example
    loss_fct = Torch::NN::CrossEntropyLoss.new
    loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
  end

  if !return_dict
    # Array form: [loss (only if labels were given), logits, *remaining base-model outputs]
    output = [logits] + outputs[1..]
    return !loss.nil? ? [loss] + output : output
  end

  TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
```
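When labels are passed, #forward reshapes logits to (batch × seq_len, num_labels) and labels to (batch × seq_len) before applying Torch::NN::CrossEntropyLoss, so every token is scored as an independent classification example; when return_dict is false it returns a plain array with the loss first. The pure-Ruby sketch below re-creates both steps on nested arrays. The helpers token_cross_entropy and tuple_output are hypothetical, and treating label -100 as ignored is an assumption based on PyTorch's CrossEntropyLoss default (ignore_index = -100), which the loss constructed with no arguments is assumed to share.

```ruby
# Pure-Ruby sketch (no torch-rb) of the loss and array-output steps in #forward.
# ASSUMPTION: label -100 is ignored, mirroring PyTorch's CrossEntropyLoss default.
IGNORE_INDEX = -100

# Numerically stable log-softmax over one row of logits.
def log_softmax(row)
  max = row.max
  log_sum = Math.log(row.sum { |v| Math.exp(v - max) }) + max
  row.map { |v| v - log_sum }
end

# Hypothetical helper: logits [batch][seq_len][num_labels], labels [batch][seq_len].
# Mirrors loss_fct.(logits.view(-1, num_labels), labels.view(-1)) with mean reduction.
def token_cross_entropy(logits, labels, num_labels)
  flat_logits = logits.flatten(1) # one row of num_labels scores per token
  flat_labels = labels.flatten    # one class id per token
  raise ArgumentError, "logit/label count mismatch" unless flat_logits.size == flat_labels.size

  losses = flat_logits.zip(flat_labels).filter_map do |row, label|
    next nil if label == IGNORE_INDEX
    raise ArgumentError, "label #{label} out of range" unless (0...num_labels).cover?(label)

    -log_softmax(row)[label] # negative log-likelihood of the gold label
  end
  return Float::NAN if losses.empty? # nothing to average when every token is ignored
  losses.sum / losses.size
end

# Hypothetical helper mirroring the return_dict == false branch.
def tuple_output(loss, logits, extra_outputs)
  output = [logits] + extra_outputs
  loss.nil? ? output : [loss] + output
end

logits = [[[0.0, 0.0], [2.0, 0.0], [1.0, 1.0]]] # batch 1, seq_len 3, 2 labels
labels = [[1, 0, IGNORE_INDEX]]                 # last token is ignored

loss = token_cross_entropy(logits, labels, 2)
p loss.round(4)                            # => 0.41
p tuple_output(loss.round(4), :logits, []) # => [0.41, :logits]
```

In this sketch the mean is taken only over non-ignored tokens, so positions labeled -100 (for example, padding) do not contribute to the loss.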