Class: Transformers::Vit::ViTLayer
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Vit::ViTLayer
- Defined in: lib/transformers/models/vit/modeling_vit.rb
Instance Method Summary
- #forward(hidden_states, head_mask: nil, output_attentions: false) ⇒ Object
  Applies pre-norm self-attention and the MLP block, each with a residual connection.
- #initialize(config) ⇒ ViTLayer (constructor)
  A new instance of ViTLayer.
Constructor Details
#initialize(config) ⇒ ViTLayer
Returns a new instance of ViTLayer.
```ruby
# File 'lib/transformers/models/vit/modeling_vit.rb', line 247

def initialize(config)
  super()
  @chunk_size_feed_forward = config.chunk_size_feed_forward
  @seq_len_dim = 1
  # attention implementation is chosen by key from the config
  @attention = VIT_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config)
  @intermediate = ViTIntermediate.new(config)
  @output = ViTOutput.new(config)
  # ViT is pre-norm: one layer norm before attention, one before the MLP
  @layernorm_before = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  @layernorm_after = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
end
```
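A minimal construction sketch, assuming Transformers::Vit::ViTConfig is defined alongside this class in the gem and that its defaults populate the fields initialize reads (hidden_size, layer_norm_eps, _attn_implementation, chunk_size_feed_forward):

```ruby
require "transformers"

# Hedged sketch: assumes ViTConfig.new with no arguments yields usable
# defaults for hidden_size, layer_norm_eps, and _attn_implementation.
config = Transformers::Vit::ViTConfig.new
layer = Transformers::Vit::ViTLayer.new(config)
```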
Instance Method Details
#forward(hidden_states, head_mask: nil, output_attentions: false) ⇒ Object
```ruby
# File 'lib/transformers/models/vit/modeling_vit.rb', line 258

def forward(
  hidden_states,
  head_mask: nil,
  output_attentions: false
)
  self_attention_outputs = @attention.(
    @layernorm_before.(hidden_states), # in ViT, layernorm is applied before self-attention
    head_mask: head_mask,
    output_attentions: output_attentions
  )
  attention_output = self_attention_outputs[0]
  outputs = self_attention_outputs[1..] # add self attentions if we output attention weights

  # first residual connection
  hidden_states = attention_output + hidden_states

  # in ViT, layernorm is also applied after self-attention
  layer_output = @layernorm_after.(hidden_states)
  layer_output = @intermediate.(layer_output)

  # second residual connection is done here
  layer_output = @output.(layer_output, hidden_states)

  outputs = [layer_output] + outputs

  outputs
end
```
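A hedged usage sketch of forward: the shape below (197 tokens, i.e. 196 patches plus the CLS token) is an assumption matching the standard ViT-Base setup, not something this class requires; any [batch, seq, hidden] tensor works. The return value is an array whose first element is the transformed hidden states, followed by attention weights when output_attentions is true.

```ruby
require "transformers"

# Hedged sketch: assumes ViTConfig.new defaults work as in the construction
# example above; the 1 x 197 x hidden_size shape is illustrative only.
config = Transformers::Vit::ViTConfig.new
layer = Transformers::Vit::ViTLayer.new(config)

hidden_states = Torch.randn(1, 197, config.hidden_size)
outputs = layer.forward(hidden_states, output_attentions: true)

layer_output = outputs[0] # same shape as hidden_states
attentions = outputs[1]   # present because output_attentions: true
```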