Class: Transformers::DebertaV2::ConvLayer

Inherits:
Torch::NN::Module
  • Object
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ ConvLayer

Returns a new instance of ConvLayer.



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 263

def initialize(config)
  super()
  kernel_size = config.getattr("conv_kernel_size", 3)
  groups = config.getattr("conv_groups", 1)
  @conv_act = config.getattr("conv_act", "tanh")
  @conv = Torch::NN::Conv1d.new(config.hidden_size, config.hidden_size, kernel_size, padding: (kernel_size - 1) / 2, groups: groups)
  @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  @dropout = StableDropout.new(config.hidden_dropout_prob)
  @config = config
end
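
The constructor reads conv_kernel_size (default 3), conv_groups (default 1), and conv_act (default "tanh") from the config and builds a 1-D convolution over the hidden dimension; for the default odd kernel size, the padding of (kernel_size - 1) / 2 keeps the sequence length unchanged. A minimal construction sketch, assuming a Transformers::DebertaV2::DebertaV2Config that exposes hidden_size, hidden_dropout_prob, and layer_norm_eps (the exact config API is an assumption, not taken from this page):

# Hypothetical config values; a real config normally comes from a pretrained checkpoint.
config = Transformers::DebertaV2::DebertaV2Config.new(
  hidden_size: 768,
  hidden_dropout_prob: 0.1,
  layer_norm_eps: 1e-7
)

conv_layer = Transformers::DebertaV2::ConvLayer.new(config)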

Instance Method Details

#forward(hidden_states, residual_states, input_mask) ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 274

def forward(hidden_states, residual_states, input_mask)
  out = @conv.(hidden_states.permute(0, 2, 1).contiguous).permute(0, 2, 1).contiguous
  rmask = (1 - input_mask).bool
  out.masked_fill!(rmask.unsqueeze(-1).expand(out.size), 0)
  out = ACT2FN[@conv_act].(@dropout.(out))

  layer_norm_input = residual_states + out
  output = @LayerNorm.(layer_norm_input).to(layer_norm_input)

  if input_mask.nil?
    output_states = output
  else
    if input_mask.dim != layer_norm_input.dim
      if input_mask.dim == 4
        input_mask = input_mask.squeeze(1).squeeze(1)
      end
      input_mask = input_mask.unsqueeze(2)
    end

    # apply the mask to the normalized output, as in the reference DeBERTa-v2 implementation
    input_mask = input_mask.to(output.dtype)
    output_states = output * input_mask
  end

  output_states
end
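
The method expects hidden_states and residual_states of shape [batch, seq_len, hidden_size] and an input_mask of shape [batch, seq_len] (1 for real tokens, 0 for padding); the returned tensor has the same shape as the inputs. A rough call sketch using torch-rb tensors, reusing the config and conv_layer from the construction sketch above (shapes and values are illustrative only):

batch_size, seq_len = 2, 16

hidden_states   = Torch.randn(batch_size, seq_len, config.hidden_size)
residual_states = Torch.randn(batch_size, seq_len, config.hidden_size)
input_mask      = Torch.ones(batch_size, seq_len, dtype: :int64)  # 1 = keep, 0 = padding

output_states = conv_layer.forward(hidden_states, residual_states, input_mask)
output_states.shape  # => [2, 16, config.hidden_size]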