Class: Ai4r::NeuralNetwork::Transformer

Inherits:
Object
Includes:
Data::Parameterizable
Defined in:
lib/ai4r/neural_network/transformer.rb

Overview

A tiny Transformer with embeddings, positional encoding, multi-head attention and a feed-forward layer. Depending on the architecture option (:encoder, :decoder or :seq2seq), it operates as an encoder, a decoder or an encoder-decoder model. Weights are initialized randomly and the model is not trainable.

Instance Method Summary

Methods included from Data::Parameterizable

#get_parameters, included, #set_parameters

Constructor Details

#initialize(vocab_size:, max_len:, embed_dim: 8, num_heads: 2, ff_dim: 32, architecture: :encoder, seed: nil) ⇒ Transformer

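Initializes the Transformer with the given hyperparameters. When seed is given, it seeds the model's random number generator; otherwise a fresh Random instance is used.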

Raises:

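  • (ArgumentError) if embed_dim is not divisible by num_heads, or if architecture is not one of :encoder, :decoder or :seq2seq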
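
The two conditions that raise ArgumentError can be read off the constructor source. The following is a minimal standalone sketch of the same checks. `validate_hyperparameters` is a hypothetical helper written for illustration and is not part of Ai4r. The per-head size it returns assumes the common multi-head convention of splitting embed_dim evenly across heads, which the divisibility check suggests but the source shown here does not confirm.

```ruby
# Hypothetical helper (not part of Ai4r) that mirrors the constructor's checks.
def validate_hyperparameters(embed_dim:, num_heads:, architecture:)
  unless (embed_dim % num_heads).zero?
    raise ArgumentError, 'embed_dim must be divisible by num_heads'
  end
  unless %i[encoder decoder seq2seq].include?(architecture)
    raise ArgumentError, 'invalid architecture'
  end

  # Assumption: each attention head works on an equal slice of the embedding.
  embed_dim / num_heads
end

# The default hyperparameters (embed_dim: 8, num_heads: 2) pass both checks.
head_dim = validate_hyperparameters(embed_dim: 8, num_heads: 2, architecture: :encoder)
puts head_dim

# 8 is not divisible by 3, so this combination is rejected.
begin
  validate_hyperparameters(embed_dim: 8, num_heads: 3, architecture: :encoder)
rescue ArgumentError => e
  puts e.message
end
```

With the defaults, the sketch yields a per-head size of 4. Any architecture symbol outside the three listed is rejected with 'invalid architecture'.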


# File 'lib/ai4r/neural_network/transformer.rb', line 31

def initialize(vocab_size:, max_len:, embed_dim: 8, num_heads: 2, ff_dim: 32,
               architecture: :encoder, seed: nil)
  @seed = seed
  @rng = seed ? Random.new(seed) : Random.new
  @vocab_size = vocab_size
  @max_len = max_len
  @embed_dim = embed_dim
  @num_heads = num_heads
  @ff_dim = ff_dim
  @architecture = architecture
  if embed_dim % num_heads != 0
    raise ArgumentError,
          'embed_dim must be divisible by num_heads'
  end
  raise ArgumentError, 'invalid architecture' unless %i[encoder decoder seq2seq].include?(@architecture)

  init_weights
  build_positional_encoding
end
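
The constructor builds its random number generator with `seed ? Random.new(seed) : Random.new`, so passing the same seed should make the random initialization repeatable. The sketch below illustrates that pattern in isolation. `build_rng` and `random_matrix` are hypothetical stand-ins; the library's own init_weights is not shown in this page, and its value range and matrix shapes may differ.

```ruby
# Same construction as in the constructor: a seeded RNG when a seed is given.
def build_rng(seed)
  seed ? Random.new(seed) : Random.new
end

# Hypothetical stand-in for weight initialization (not Ai4r's init_weights):
# fills a rows x cols matrix with values drawn uniformly from [-1, 1).
def random_matrix(rows, cols, rng)
  Array.new(rows) { Array.new(cols) { rng.rand * 2 - 1 } }
end

a = random_matrix(2, 3, build_rng(42))
b = random_matrix(2, 3, build_rng(42))

# Two generators built from the same seed produce identical draws.
puts a == b # prints true
```

Because 0 is truthy in Ruby, `seed: 0` also selects the seeded branch; only nil (the default) or false falls back to an unseeded generator.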

Instance Method Details

#eval(*args) ⇒ Object

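Evaluates a sequence of integer token ids. For the :encoder and :decoder architectures, pass a single token array; for :seq2seq, pass the source and target arrays, in that order. Returns an array of seq_len vectors, each of size embed_dim. Raises ArgumentError if any input sequence is longer than max_len.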



# File 'lib/ai4r/neural_network/transformer.rb', line 53

def eval(*args)
  case @architecture
  when :encoder
    tokens = args.first
    raise ArgumentError, 'sequence too long' if tokens.length > @max_len

    encode(tokens)
  when :decoder
    tokens = args.first
    raise ArgumentError, 'sequence too long' if tokens.length > @max_len

    decode(tokens)
  when :seq2seq
    src, tgt = args
    raise ArgumentError, 'sequence too long' if src.length > @max_len || tgt.length > @max_len

    memory = encode(src)
    decode(tgt, memory)
  else
    raise ArgumentError, 'invalid architecture'
  end
end
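
The call shape of #eval depends on the architecture: one token array for :encoder and :decoder, two for :seq2seq. The sketch below mirrors that dispatch with hypothetical stand-ins so the argument handling and the documented output shape can be tried without the library. `stub_encode`, `stub_decode` and `stub_eval` are illustrative names, not Ai4r methods, and the stubs return zero vectors rather than real attention output.

```ruby
# Hypothetical stand-ins for the library's encode/decode, which are not shown
# on this page. Each returns one zero vector of size embed_dim per token,
# reproducing only the documented output shape.
def stub_encode(tokens, embed_dim)
  tokens.map { Array.new(embed_dim, 0.0) }
end

def stub_decode(tokens, embed_dim, _memory = nil)
  tokens.map { Array.new(embed_dim, 0.0) }
end

# Mirrors the dispatch and length checks in #eval.
def stub_eval(architecture, max_len, embed_dim, *args)
  case architecture
  when :encoder, :decoder
    tokens = args.first
    raise ArgumentError, 'sequence too long' if tokens.length > max_len

    architecture == :encoder ? stub_encode(tokens, embed_dim) : stub_decode(tokens, embed_dim)
  when :seq2seq
    src, tgt = args
    raise ArgumentError, 'sequence too long' if src.length > max_len || tgt.length > max_len

    stub_decode(tgt, embed_dim, stub_encode(src, embed_dim))
  else
    raise ArgumentError, 'invalid architecture'
  end
end

# Encoder: a single token array in, one vector per token out.
out = stub_eval(:encoder, 4, 8, [1, 2, 3])
puts out.length
puts out.first.length

# Seq2seq: source first, then target.
seq = stub_eval(:seq2seq, 4, 8, [1, 2, 3], [0, 1])
puts seq.length
```

In the stub, the seq2seq output has one vector per target token. Whether the library behaves the same way depends on its decode method, which is not shown here.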