Class: Torch::NN::Transformer

Inherits:
  Module < Object

Defined in:
  lib/torch/nn/transformer.rb

Instance Attribute Summary

#d_model ⇒ Object (readonly), #decoder ⇒ Object (readonly), #encoder ⇒ Object (readonly), #nhead ⇒ Object (readonly)

Attributes inherited from Module

#training

Instance Method Summary

#batch_first? ⇒ Boolean, #forward ⇒ Object, #generate_square_subsequent_mask ⇒ Object, #initialize ⇒ Transformer (constructor), #reset_parameters ⇒ Object

Methods inherited from Module

#_apply, #add_module, #apply, #buffers, #call, #children, #cpu, #cuda, #deep_dup, #double, #eval, #float, #half, #inspect, #load_state_dict, #method_missing, #modules, #named_buffers, #named_children, #named_modules, #named_parameters, #parameters, #register_buffer, #register_parameter, #requires_grad!, #respond_to?, #share_memory, #state_dict, #to, #train, #type, #zero_grad

Methods included from Utils

#_activation_fn, #_clones, #_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize(d_model: 512, nhead: 8, num_encoder_layers: 6, num_decoder_layers: 6, dim_feedforward: 2048, dropout: 0.1, activation: :relu, custom_encoder: nil, custom_decoder: nil, layer_norm_eps: 1e-5, batch_first: false) ⇒ Transformer

Returns a new instance of Transformer.



# File 'lib/torch/nn/transformer.rb', line 9

def initialize(
  d_model: 512, nhead: 8,
  num_encoder_layers: 6, num_decoder_layers: 6,
  dim_feedforward: 2048, dropout: 0.1, activation: :relu,
  custom_encoder: nil, custom_decoder: nil,
  layer_norm_eps: 1e-5, batch_first: false
)

  super()

  @encoder =
    if custom_encoder
      custom_encoder
    else
      encoder_layer = TransformerEncoderLayer.new(
        d_model, nhead,
        dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
        layer_norm_eps: layer_norm_eps, batch_first: batch_first
      )
      encoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
      TransformerEncoder.new(encoder_layer, num_encoder_layers, norm: encoder_norm)
    end

  @decoder =
    if custom_decoder
      custom_decoder
    else
      decoder_layer = TransformerDecoderLayer.new(
        d_model, nhead,
        dim_feedforward: dim_feedforward, dropout: dropout, activation: activation,
        layer_norm_eps: layer_norm_eps, batch_first: batch_first
      )
      decoder_norm = LayerNorm.new(d_model, eps: layer_norm_eps)
      TransformerDecoder.new(decoder_layer, num_decoder_layers, norm: decoder_norm)
    end

  reset_parameters

  @d_model = d_model
  @nhead = nhead
  @batch_first = batch_first
end
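
A minimal usage sketch (not from the library docs; shapes follow the default batch_first: false convention of (sequence length, batch, d_model)):

require "torch"

model = Torch::NN::Transformer.new(nhead: 16, num_encoder_layers: 12)
src = Torch.rand(10, 32, 512) # (source length, batch, d_model)
tgt = Torch.rand(20, 32, 512) # (target length, batch, d_model)
out = model.call(src, tgt)    # same shape as tgt: [20, 32, 512]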

Dynamic Method Handling

This class handles dynamic methods through the method_missing method in the class Torch::NN::Module

Instance Attribute Details

#d_model ⇒ Object (readonly)

Returns the value of attribute d_model.



# File 'lib/torch/nn/transformer.rb', line 52

def d_model
  @d_model
end

#decoder ⇒ Object (readonly)

Returns the value of attribute decoder.



# File 'lib/torch/nn/transformer.rb', line 52

def decoder
  @decoder
end

#encoder ⇒ Object (readonly)

Returns the value of attribute encoder.



# File 'lib/torch/nn/transformer.rb', line 52

def encoder
  @encoder
end

#nhead ⇒ Object (readonly)

Returns the value of attribute nhead.



# File 'lib/torch/nn/transformer.rb', line 52

def nhead
  @nhead
end

Instance Method Details

#batch_first? ⇒ Boolean

Returns:

  • (Boolean)


# File 'lib/torch/nn/transformer.rb', line 54

def batch_first?
  !!@batch_first
end
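
When constructed with batch_first: true, this returns true and the model expects (batch, sequence, feature) inputs instead. A sketch under that assumption:

model = Torch::NN::Transformer.new(batch_first: true)
model.batch_first?            # => true
src = Torch.rand(32, 10, 512) # (batch, source length, d_model)
tgt = Torch.rand(32, 20, 512) # (batch, target length, d_model)
out = model.call(src, tgt)    # => shape [32, 20, 512]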

#forward(src, tgt, src_mask: nil, tgt_mask: nil, memory_mask: nil, src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil) ⇒ Object



# File 'lib/torch/nn/transformer.rb', line 62

def forward(
  src, tgt,
  src_mask: nil, tgt_mask: nil, memory_mask: nil,
  src_key_padding_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil
)

  if (!batch_first? && src.size(1) != tgt.size(1)) ||
    (batch_first? && src.size(0) != tgt.size(0))

    raise ArgumentError, "The batch number of src and tgt must be equal"
  end

  if src.size(2) != d_model || tgt.size(2) != d_model
    raise ArgumentError, "The feature number of src and tgt must be equal to d_model"
  end

  memory = @encoder.(src, mask: src_mask, src_key_padding_mask: src_key_padding_mask)
  @decoder.(
    tgt, memory,
    tgt_mask: tgt_mask, memory_mask: memory_mask,
    tgt_key_padding_mask: tgt_key_padding_mask, memory_key_padding_mask: memory_key_padding_mask
  )
end
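
A sketch of a decode-time call with a causal mask over the target (assumes the seq-first tensors from the constructor example above, and that the inherited #call forwards keyword arguments to #forward):

tgt_mask = model.generate_square_subsequent_mask(tgt.size(0))
out = model.call(src, tgt, tgt_mask: tgt_mask)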

#generate_square_subsequent_mask(sz) ⇒ Object



# File 'lib/torch/nn/transformer.rb', line 86

def generate_square_subsequent_mask(sz)
  mask = Torch.triu(Torch.ones([sz, sz])).eq(1).transpose(0, 1)
  mask.float.masked_fill!(mask.eq(0), -Float::INFINITY).masked_fill!(mask.eq(1), 0.0)
end
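
The returned mask is additive: positions a query may attend to hold 0.0, and future positions hold -Infinity, so row i permits attention only to positions 0..i. For example:

model.generate_square_subsequent_mask(3)
# => [[0.0, -Inf, -Inf],
#     [0.0,  0.0, -Inf],
#     [0.0,  0.0,  0.0]]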

#reset_parameters ⇒ Object



# File 'lib/torch/nn/transformer.rb', line 58

def reset_parameters
  parameters.each { |p| Init.xavier_uniform!(p) if p.dim > 1 }
end
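
Note that only parameter tensors with more than one dimension are reinitialized (Xavier uniform); one-dimensional parameters such as biases and LayerNorm weights keep their default initialization.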