274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
|
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 274
# Convolutional sub-layer forward pass (port of the Python DeBERTa-v2 ConvLayer).
#
# Steps:
#   1. Runs @conv over hidden_states with its last two axes swapped.
#   2. Zeroes positions where input_mask is not 1.
#   3. Applies dropout and the configured activation.
#   4. Adds residual_states and applies LayerNorm.
#   5. Multiplies the result by input_mask when one is given.
#
# hidden_states   - tensor; permute(0, 2, 1) implies rank 3, presumably
#                   (batch, seq_len, hidden). TODO confirm @conv expects
#                   channels-first input (e.g. a Conv1d).
# residual_states - tensor added to the activated conv output before LayerNorm.
# input_mask      - tensor or nil. Positions where it is not 1 (presumably 0)
#                   are zeroed before the activation. The mask is also
#                   multiplied into the final output.
#
# Returns the LayerNorm output, multiplied by input_mask unless input_mask is nil.
def forward(hidden_states, residual_states, input_mask)
  out = @conv.(hidden_states.permute(0, 2, 1).contiguous).permute(0, 2, 1).contiguous

  # Guarded so a nil mask reaches the nil branch below instead of raising
  # on `1 - nil`.
  unless input_mask.nil?
    # true wherever the mask is not 1; those conv outputs are zeroed.
    rmask = (1 - input_mask).bool
    out.masked_fill!(rmask.unsqueeze(-1).expand(out.size), 0)
  end
  out = ACT2FN[@conv_act].(@dropout.(out))

  layer_norm_input = residual_states + out
  # .to(tensor) matches the LayerNorm result to layer_norm_input's dtype/device.
  output = @LayerNorm.(layer_norm_input).to(layer_norm_input)

  return output if input_mask.nil?

  if input_mask.dim != layer_norm_input.dim
    # Rank-4 masks: squeeze(1) twice drops axes 1 and 2 if they are size 1.
    input_mask = input_mask.squeeze(1).squeeze(1) if input_mask.dim == 4
    # For a (batch, seq) mask this adds a trailing axis, presumably so it
    # broadcasts over the hidden dimension.
    input_mask = input_mask.unsqueeze(2)
  end

  # Previously missing: apply the mask to the output instead of returning nil.
  output * input_mask.to(output.dtype)
end
|