Class: Torch::NN::MultiheadAttention

Inherits:
Module
  • Object
Defined in:
lib/torch/nn/multihead_attention.rb

Instance Attribute Summary

Attributes inherited from Module

#training

Instance Method Summary

Methods inherited from Module

#_apply, #add_module, #apply, #buffers, #call, #children, #cpu, #cuda, #deep_dup, #double, #eval, #float, #half, #inspect, #load_state_dict, #method_missing, #modules, #named_buffers, #named_children, #named_modules, #named_parameters, #parameters, #register_buffer, #register_parameter, #requires_grad!, #respond_to?, #share_memory, #state_dict, #to, #train, #type, #zero_grad

Methods included from Utils

#_activation_fn, #_clones, #_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize(embed_dim, num_heads, dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false, kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil) ⇒ MultiheadAttention

Returns a new instance of MultiheadAttention.

Raises:

  • (ArgumentError)


# File 'lib/torch/nn/multihead_attention.rb', line 4

def initialize(
  embed_dim, num_heads,
  dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
  kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
)

  super()

  @embed_dim = embed_dim
  @kdim = kdim || @embed_dim
  @vdim = vdim || @embed_dim

  @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim

  @num_heads = num_heads
  @dropout = dropout
  @batch_first = batch_first

  @head_dim = @embed_dim.div @num_heads

  raise ArgumentError, "embed_dim must be divisible by num_heads" unless @head_dim * @num_heads == @embed_dim

  if @qkv_same_embed_dim
    @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
    %w(q k v).each { |x| register_parameter("#{x}_proj_weight", nil) }
  else
    @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
    @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
    @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))

    register_parameter('in_proj_weight', nil)
  end

  if bias
    @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
  else
    register_parameter('in_proj_bias', nil)
  end

  @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)

  if add_bias_kv
    @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
    @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
  else
    @bias_k = @bias_v = nil
  end

  @add_zero_attn = add_zero_attn

  reset_parameters
end
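
A minimal construction sketch (assumes the torch-rb gem is loaded with require "torch"; the dimensions are illustrative): an 8-head attention layer over 64-dimensional embeddings, so each head operates on 64 / 8 = 8 dimensions. kdim: and vdim: only need to be passed when keys and values have a different feature size than the queries.

require "torch"

# embed_dim = 64 must be divisible by num_heads = 8, otherwise ArgumentError is raised.
mha = Torch::NN::MultiheadAttention.new(64, 8, dropout: 0.1, batch_first: true)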

Dynamic Method Handling

This class handles dynamic methods through the method_missing method in the class Torch::NN::Module.

Instance Method Details

#batch_first? ⇒ Boolean

Returns:

  • (Boolean)


# File 'lib/torch/nn/multihead_attention.rb', line 57

def batch_first?
  !!@batch_first
end

#forward(query, key, value, key_padding_mask: nil, need_weights: true, attn_mask: nil) ⇒ Object



# File 'lib/torch/nn/multihead_attention.rb', line 79

def forward(
  query, key, value,
  key_padding_mask: nil, need_weights: true, attn_mask: nil
)

  if batch_first?
    query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
  end

  attn_output, attn_output_weights =
    if @qkv_same_embed_dim
      F.multi_head_attention_forward(
        query, key, value,
        @embed_dim, @num_heads,
        @in_proj_weight, @in_proj_bias,
        @bias_k, @bias_v, @add_zero_attn,
        @dropout, @out_proj.weight, @out_proj.bias,
        training: @training,
        key_padding_mask: key_padding_mask,
        need_weights: need_weights,
        attn_mask: attn_mask
      )
    else
      F.multi_head_attention_forward(
        query, key, value,
        @embed_dim, @num_heads,
        @in_proj_weight, @in_proj_bias,
        @bias_k, @bias_v, @add_zero_attn,
        @dropout, @out_proj.weight, @out_proj.bias,
        training: @training,
        key_padding_mask: key_padding_mask,
        need_weights: need_weights,
        attn_mask: attn_mask,
        use_separate_proj_weight: true,
        q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight, v_proj_weight: @v_proj_weight
      )
    end

  attn_output = attn_output.transpose(1, 0) if batch_first?

  [attn_output, attn_output_weights]
end
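
A hedged self-attention sketch (tensor sizes are illustrative): with the default batch_first: false, inputs are laid out as (seq_len, batch, embed_dim), and when need_weights is true the second return value holds the attention weights averaged over the heads.

mha = Torch::NN::MultiheadAttention.new(64, 8)
x = Torch.randn([10, 2, 64])  # (seq_len, batch, embed_dim)

# Self-attention: query, key, and value are the same tensor.
attn_output, attn_weights = mha.forward(x, x, x, need_weights: true)
attn_output.shape   # => [10, 2, 64]
attn_weights.shape  # => [2, 10, 10]  (batch, target_len, source_len)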

#reset_parameters ⇒ Object



# File 'lib/torch/nn/multihead_attention.rb', line 61

def reset_parameters
  if @qkv_same_embed_dim
    Init.xavier_uniform!(@in_proj_weight)
  else
    Init.xavier_uniform!(@q_proj_weight)
    Init.xavier_uniform!(@k_proj_weight)
    Init.xavier_uniform!(@v_proj_weight)
  end

  if @in_proj_bias
    Init.constant!(@in_proj_bias, 0.0)
    Init.constant!(@out_proj.bias, 0.0)
  end

  Init.xavier_uniform!(@bias_k) if @bias_k
  Init.xavier_uniform!(@bias_v) if @bias_v
end
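
reset_parameters is already called at the end of #initialize, so it only needs to be invoked manually to re-randomize an existing module (a small illustrative sketch):

mha = Torch::NN::MultiheadAttention.new(64, 8)
mha.reset_parameters  # re-applies Xavier-uniform init to the projection weights and zeroes the biases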