Class: Transformers::Mpnet::MPNetEmbeddings

Inherits:
  Torch::NN::Module < Object
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb
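
MPNetEmbeddings computes MPNet's input embeddings: the sum of word and position embeddings, normalized with LayerNorm and passed through dropout.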

Instance Method Summary

  • #create_position_ids_from_input_ids(input_ids, padding_idx) ⇒ Object
  • #create_position_ids_from_inputs_embeds(inputs_embeds) ⇒ Object
  • #forward(input_ids: nil, position_ids: nil, inputs_embeds: nil, **kwargs) ⇒ Object

Constructor Details

#initialize(config) ⇒ MPNetEmbeddings

Returns a new instance of MPNetEmbeddings.



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 43

def initialize(config)
  super()
  # MPNet uses padding token id 1 (as in RoBERTa), so position numbering
  # starts at padding_idx + 1.
  @padding_idx = 1
  @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: @padding_idx)
  @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size, padding_idx: @padding_idx)

  @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  # Precomputed position ids; non-persistent, so the buffer is not saved in checkpoints.
  register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
end
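
The constructor sizes the word and position embedding tables from the config, adds LayerNorm and dropout, and registers a precomputed position_ids buffer. A minimal construction sketch, assuming a hypothetical config object that exposes just the five fields read above (the Struct and all values are illustrative, not the gem's config API):

require "torch"
# (with the transformers-ruby gem loaded)

config = Struct.new(
  :vocab_size, :hidden_size, :max_position_embeddings,
  :layer_norm_eps, :hidden_dropout_prob, keyword_init: true
).new(
  vocab_size: 30_527,            # illustrative MPNet-sized vocabulary
  hidden_size: 768,
  max_position_embeddings: 514,  # leaves room for the padding_idx offset
  layer_norm_eps: 1e-5,
  hidden_dropout_prob: 0.1
)

embeddings = Transformers::Mpnet::MPNetEmbeddings.new(config)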

Instance Method Details

#create_position_ids_from_input_ids(input_ids, padding_idx) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 94

def create_position_ids_from_input_ids(input_ids, padding_idx)
  # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  mask = input_ids.ne(padding_idx).int
  incremental_indices = Torch.cumsum(mask, dim: 1).type_as(mask) * mask
  incremental_indices.long + padding_idx
end
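
The mask marks non-padding tokens, the cumulative sum numbers them from 1, and adding padding_idx shifts the result so the first real token sits at padding_idx + 1 while padding tokens keep position padding_idx (fairseq's make_positions convention). A worked example, assuming the illustrative embeddings instance from the constructor sketch and arbitrary token ids with 1 as the padding id:

input_ids = Torch.tensor([[0, 2435, 232, 2, 1, 1]])  # trailing 1s are padding
# mask:          [[1, 1, 1, 1, 0, 0]]
# cumsum * mask: [[1, 2, 3, 4, 0, 0]]
# + padding_idx: [[2, 3, 4, 5, 1, 1]]
embeddings.create_position_ids_from_input_ids(input_ids, 1)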

#create_position_ids_from_inputs_embeds(inputs_embeds) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 86

def create_position_ids_from_inputs_embeds(inputs_embeds)
  # We are given embeddings directly and cannot infer which tokens are
  # padded, so just generate sequential position ids.
  input_shape = inputs_embeds.size[...-1]
  sequence_length = input_shape[1]

  position_ids = Torch.arange(@padding_idx + 1, sequence_length + @padding_idx + 1, dtype: Torch.long, device: inputs_embeds.device)
  position_ids.unsqueeze(0).expand(input_shape)
end
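
With only embeddings there is no way to tell which positions are padding, so every position is numbered sequentially starting at padding_idx + 1. For example, with the same illustrative embeddings instance:

inputs_embeds = Torch.randn(1, 4, 768)  # (batch, sequence, hidden) dummy values
embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
# => a [1, 4] tensor containing [[2, 3, 4, 5]]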

#forward(input_ids: nil, position_ids: nil, inputs_embeds: nil, **kwargs) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 54

def forward(input_ids: nil, position_ids: nil, inputs_embeds: nil, **kwargs)
  # Derive position ids when they are not supplied: from input_ids when
  # available (so padding tokens keep the padding position), otherwise
  # sequentially from the shape of inputs_embeds.
  if position_ids.nil?
    if !input_ids.nil?
      position_ids = create_position_ids_from_input_ids(input_ids, @padding_idx)
    else
      position_ids = create_position_ids_from_inputs_embeds(inputs_embeds)
    end
  end

  if !input_ids.nil?
    input_shape = input_ids.size
  else
    input_shape = inputs_embeds.size[...-1]
  end

  seq_length = input_shape[1]

  # Unreachable after the block above; kept for parity with the upstream
  # Python implementation.
  if position_ids.nil?
    position_ids = @position_ids[0.., ...seq_length]
  end

  # Look up word embeddings unless precomputed embeddings were passed in.
  if inputs_embeds.nil?
    inputs_embeds = @word_embeddings.(input_ids)
  end
  position_embeddings = @position_embeddings.(position_ids)

  # Sum token and position embeddings, then normalize and apply dropout.
  embeddings = inputs_embeds + position_embeddings
  embeddings = @LayerNorm.(embeddings)
  embeddings = @dropout.(embeddings)
  embeddings
end
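
forward adds word and position embeddings, then applies LayerNorm and dropout, producing a (batch, sequence, hidden_size) tensor. A usage sketch with the illustrative embeddings instance from the constructor example (token ids are arbitrary; 1 is the padding id):

input_ids = Torch.tensor([[0, 2435, 232, 2, 1]])
output = embeddings.(input_ids: input_ids)
output.shape  # => [1, 5, 768]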