Class: NanoGPT::Layers::CausalSelfAttention
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - NanoGPT::Layers::CausalSelfAttention
- Defined in: lib/nano_gpt/layers/causal_self_attention.rb
Overview
Multi-head causal self-attention: each position attends only to itself and earlier positions in the sequence.
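For orientation, a minimal usage sketch follows. The Struct-based config is an assumption for illustration only; it supplies just the fields this class reads (n_embd, n_head, dropout, bias, block_size) and is not the library's actual config class.

require "torch"
require "nano_gpt"   # assumed require path for this library

# Hypothetical stand-in config; the real NanoGPT config class may differ.
FakeConfig = Struct.new(:n_embd, :n_head, :dropout, :bias, :block_size, keyword_init: true)
config = FakeConfig.new(n_embd: 64, n_head: 4, dropout: 0.0, bias: true, block_size: 128)

attn = NanoGPT::Layers::CausalSelfAttention.new(config)
x = Torch.randn(2, 16, 64)   # (batch, time, channels) with channels == n_embd
y = attn.call(x)             # same shape as x: (2, 16, 64)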
Instance Method Summary
- #forward(x) ⇒ Object
- #initialize(config) ⇒ CausalSelfAttention (constructor)
  A new instance of CausalSelfAttention.
Constructor Details
#initialize(config) ⇒ CausalSelfAttention
Returns a new instance of CausalSelfAttention.
# File 'lib/nano_gpt/layers/causal_self_attention.rb', line 7

def initialize(config)
  super()
  raise ArgumentError, "n_embd must be divisible by n_head" unless (config.n_embd % config.n_head).zero?

  @n_head = config.n_head
  @n_embd = config.n_embd
  @head_size = config.n_embd / config.n_head
  @dropout_p = config.dropout

  # Key, query, value projections for all heads, combined
  @c_attn = Torch::NN::Linear.new(config.n_embd, 3 * config.n_embd, bias: config.bias)
  # Output projection
  @c_proj = Torch::NN::Linear.new(config.n_embd, config.n_embd, bias: config.bias)
  # Regularization
  @attn_dropout = Torch::NN::Dropout.new(p: config.dropout)
  @resid_dropout = Torch::NN::Dropout.new(p: config.dropout)

  # Use native scaled_dot_product_attention with is_causal=true when dropout is 0
  # Native SDPA is ~5x faster but doesn't support dropout with is_causal mode
  @flash = config.dropout == 0.0

  # Causal mask for manual attention (only used when @flash is false)
  unless @flash
    mask = Torch.tril(Torch.ones(config.block_size, config.block_size))
    register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
  end
end
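Not part of the file above: a quick, standalone shape check of the combined Q/K/V projection the constructor sets up, assuming n_embd = 8 for illustration.

require "torch"

c_attn = Torch::NN::Linear.new(8, 3 * 8, bias: true)   # combined Q/K/V projection
x = Torch.randn(1, 4, 8)                               # (B=1, T=4, C=8)
qkv = c_attn.call(x)                                   # (1, 4, 24)
q, k, v = qkv.split(8, 2)                              # three tensors of shape (1, 4, 8)
puts q.shape.inspect                                   # => [1, 4, 8]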
Instance Method Details
#forward(x) ⇒ Object
# File 'lib/nano_gpt/layers/causal_self_attention.rb', line 35

def forward(x)
  b, t, c = x.shape

  # Calculate Q, K, V
  qkv = @c_attn.call(x)
  q, k, v = qkv.split(@n_embd, 2)

  # Reshape: (B, T, C) -> (B, nh, T, hs)
  q = q.view(b, t, @n_head, @head_size).transpose(1, 2)
  k = k.view(b, t, @n_head, @head_size).transpose(1, 2)
  v = v.view(b, t, @n_head, @head_size).transpose(1, 2)

  y = if @flash
        # Native scaled_dot_product_attention with is_causal=true
        # Uses Flash Attention on CUDA, optimized kernel on MPS
        Torch::NN.scaled_dot_product_attention(q, k, v, nil, 0.0, true)
      else
        # Manual attention implementation with causal mask
        scale = 1.0 / Math.sqrt(@head_size)
        att = q.matmul(k.transpose(-2, -1))
        att.mul!(scale)

        # Apply causal mask - slice mask to current sequence length
        mask_slice = @mask.narrow(2, 0, t).narrow(3, 0, t)
        att.masked_fill!(mask_slice.eq(0), -Float::INFINITY)

        att = Torch::NN::Functional.softmax(att, dim: -1)
        att = @attn_dropout.call(att)
        att.matmul(v)
      end

  # Reassemble heads: (B, nh, T, hs) -> (B, T, C)
  y = y.transpose(1, 2).contiguous.view(b, t, c)

  # Output projection
  @resid_dropout.call(@c_proj.call(y))
end
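To see why the masking step works, here is a standalone sketch of the manual (non-flash) path on tiny tensors. Shapes and values are illustrative only, and it uses the non-destructive masked_fill rather than the in-place variant above.

require "torch"

q = Torch.randn(1, 1, 3, 2)   # (B, nh, T, hs) with T = 3, hs = 2
k = Torch.randn(1, 1, 3, 2)
v = Torch.randn(1, 1, 3, 2)

att = q.matmul(k.transpose(-2, -1)) * (1.0 / Math.sqrt(2))   # (1, 1, 3, 3) raw scores
mask = Torch.tril(Torch.ones(3, 3)).view(1, 1, 3, 3)         # causal: 1s at and below the diagonal
att = att.masked_fill(mask.eq(0), -Float::INFINITY)          # future positions become -inf
att = Torch::NN::Functional.softmax(att, dim: -1)            # -inf entries get zero attention weight
y = att.matmul(v)                                            # (1, 1, 3, 2)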