Class: Ai4r::NeuralNetwork::Transformer
- Inherits:
- Object
- Includes:
- Data::Parameterizable
- Defined in:
- lib/ai4r/neural_network/transformer.rb
Overview
A tiny Transformer with embeddings, positional encoding, multi-head attention and a feed-forward layer. Depending on the architecture option (:encoder, :decoder or :seq2seq) it operates as an encoder, decoder or encoder-decoder model. Weights are initialized randomly, optionally from a seed, and the model is not trainable.
Instance Method Summary
- #eval(*args) ⇒ Object
  Evaluate a sequence of integer token ids.
- #initialize(vocab_size:, max_len:, embed_dim: 8, num_heads: 2, ff_dim: 32, architecture: :encoder, seed: nil) ⇒ Transformer (constructor)
  Initialize the Transformer with given hyperparameters.
Methods included from Data::Parameterizable
#get_parameters, included, #set_parameters
Constructor Details
#initialize(vocab_size:, max_len:, embed_dim: 8, num_heads: 2, ff_dim: 32, architecture: :encoder, seed: nil) ⇒ Transformer
Initialize the Transformer with given hyperparameters.
# File 'lib/ai4r/neural_network/transformer.rb', line 31
def initialize(vocab_size:, max_len:, embed_dim: 8, num_heads: 2, ff_dim: 32, architecture: :encoder, seed: nil)
  @seed = seed
  # A fixed seed makes the random weight initialization reproducible.
  @rng = seed ? Random.new(seed) : Random.new
  @vocab_size = vocab_size
  @max_len = max_len
  @embed_dim = embed_dim
  @num_heads = num_heads
  @ff_dim = ff_dim
  @architecture = architecture
  if embed_dim % num_heads != 0
    raise ArgumentError, 'embed_dim must be divisible by num_heads'
  end
  raise ArgumentError, 'invalid architecture' unless %i[encoder decoder seq2seq].include?(@architecture)

  init_weights
  build_positional_encoding
end
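The two argument checks in the constructor can be tried in isolation. The following is a minimal sketch, not part of ai4r: `check_transformer_config` is a hypothetical helper that repeats the same two conditions, and its return value (the per-head dimension) is an assumption added for illustration.

```ruby
# Hypothetical standalone helper mirroring the argument checks in
# Transformer#initialize. It is not part of ai4r.
def check_transformer_config(embed_dim:, num_heads:, architecture:)
  valid_architectures = %i[encoder decoder seq2seq]

  # Same condition as the constructor: embed_dim must split evenly across heads.
  unless (embed_dim % num_heads).zero?
    raise ArgumentError, 'embed_dim must be divisible by num_heads'
  end
  unless valid_architectures.include?(architecture)
    raise ArgumentError, 'invalid architecture'
  end

  # Assumption for illustration: return the size of each head's slice.
  embed_dim / num_heads
end

# The documented defaults (embed_dim: 8, num_heads: 2) pass the check.
puts check_transformer_config(embed_dim: 8, num_heads: 2, architecture: :encoder)

# 8 is not divisible by 3, so this combination is rejected.
begin
  check_transformer_config(embed_dim: 8, num_heads: 3, architecture: :encoder)
rescue ArgumentError => e
  puts e.message
end
```

With the documented defaults the check passes; any `num_heads` that does not divide `embed_dim`, or any architecture symbol outside the three listed, raises `ArgumentError` before weights are built.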
Instance Method Details
#eval(*args) ⇒ Object
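Evaluate a sequence of integer token ids. In :encoder and :decoder mode, pass a single token sequence; in :seq2seq mode, pass the source sequence followed by the target sequence. Every sequence must be at most max_len tokens long, otherwise an ArgumentError is raised. Returns an array of seq_len vectors, each of size embed_dim.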
# File 'lib/ai4r/neural_network/transformer.rb', line 53
def eval(*args)
  case @architecture
  when :encoder
    tokens = args.first
    raise ArgumentError, 'sequence too long' if tokens.length > @max_len

    encode(tokens)
  when :decoder
    tokens = args.first
    raise ArgumentError, 'sequence too long' if tokens.length > @max_len

    decode(tokens)
  when :seq2seq
    src, tgt = args
    raise ArgumentError, 'sequence too long' if src.length > @max_len || tgt.length > @max_len

    memory = encode(src)
    decode(tgt, memory)
  else
    raise ArgumentError, 'invalid architecture'
  end
end
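To make the argument shapes of #eval concrete, here is a rough sketch using a hypothetical `StubTransformer`. It is not the ai4r class: the real `encode` and `decode` are replaced by a placeholder that returns zero vectors, and the assumption that :seq2seq yields one vector per target token is mine, not stated in the source.

```ruby
# Hypothetical stand-in for illustration only; this is NOT the ai4r class.
# encode/decode are replaced by a placeholder that returns zero vectors.
class StubTransformer
  def initialize(max_len:, embed_dim: 8, architecture: :encoder)
    @max_len = max_len
    @embed_dim = embed_dim
    @architecture = architecture
  end

  # Accepts the same argument shapes as the documented #eval.
  def eval(*args)
    case @architecture
    when :encoder, :decoder
      tokens = args.first
      check_length(tokens)
      placeholder(tokens)
    when :seq2seq
      src, tgt = args
      check_length(src, tgt)
      placeholder(tgt) # assumption: one output vector per target token
    else
      raise ArgumentError, 'invalid architecture'
    end
  end

  private

  # Rejects any sequence longer than max_len, as the documented method does.
  def check_length(*sequences)
    raise ArgumentError, 'sequence too long' if sequences.any? { |s| s.length > @max_len }
  end

  # Returns one zero vector of size embed_dim per token.
  def placeholder(tokens)
    tokens.map { Array.new(@embed_dim, 0.0) }
  end
end

encoder = StubTransformer.new(max_len: 4)
puts encoder.eval([1, 2, 3]).length # one vector per input token

seq2seq = StubTransformer.new(max_len: 4, architecture: :seq2seq)
puts seq2seq.eval([1, 2], [3]).first.length # each vector has embed_dim entries
```

With the real class, the same calls would go through `Ai4r::NeuralNetwork::Transformer.new(vocab_size: ..., max_len: ...)`, and the returned vectors would come from the randomly initialized weights rather than zeros.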