Module: TensorStream::EmbeddingLookup

Includes:
PyPorts
Included in:
NN
Defined in:
lib/tensor_stream/nn/embedding_lookup.rb

Instance Method Summary

Methods included from PyPorts

#floor_div

Instance Method Details

#_clip(params, ids, max_norm) ⇒ Object



# File 'lib/tensor_stream/nn/embedding_lookup.rb', line 96

def _clip(params, ids, max_norm)
  return params if max_norm.nil?

  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)

  # Clip along the embedding dimensions only (every axis of params beyond
  # the rank of ids), using static axes when both ranks are known.
  axes = if ids_static && params_static
    (ids_rank...params_rank).to_a
  else
    TensorStream.range(ids_rank, params_rank)
  end

  TensorStream.clip_by_norm(params, max_norm, axes: axes)
end
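
For the common case of rank-1 ids gathering rows from a rank-2 embedding matrix, this clips along axis 1 only, so each looked-up row is normalized independently. A minimal sketch of that behaviour using clip_by_norm directly (illustrative values, assuming the gem's constant/session API):

ts = TensorStream
rows = ts.constant([[3.0, 4.0], [6.0, 8.0]])  # row norms: 5.0 and 10.0
clipped = ts.clip_by_norm(rows, 5.0, axes: [1])
ts.session.run(clipped)  # => [[3.0, 4.0], [3.0, 4.0]] (second row rescaled)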

#_embedding_lookup_and_transform(params, ids, partition_strategy: "mod", name: nil, max_norm: nil, transform_fn: nil) ⇒ Object

Helper function for embedding_lookup and _compute_sampled_logits.



# File 'lib/tensor_stream/nn/embedding_lookup.rb', line 17

def _embedding_lookup_and_transform(params, ids, partition_strategy: "mod", name: nil, max_norm: nil, transform_fn: nil)
  raise TensorStream::ValueError, "Need at least one param" if params.nil?

  params = [params] unless params.is_a?(Array)

  TensorStream.name_scope(name, "embedding_lookup", values: params + [ids]) do |name|
    np = params.size
    ids = TensorStream.convert_to_tensor(ids, name: "ids")
    if (np == 1) && (transform_fn.nil? || (ids.shape.size == 1))
      # Fast path: a single parameter shard needs no partitioning.
      result = nil
      TensorStream.colocate_with(params[0]) do
        result = _clip(TensorStream.gather(params[0], ids, name: name), ids, max_norm)
        result = transform_fn.call(result) if transform_fn
      end

      return TensorStream.identity(result)
    else
      flat_ids = TensorStream.reshape(ids, [-1])
      original_indices = TensorStream.range(TensorStream.size(flat_ids))

      p_assignments = nil
      new_ids = nil

      if partition_strategy == "mod"
        # "mod" strategy: id i lives in shard (i % np), at row (i / np).
        p_assignments = flat_ids % np
        new_ids = floor_div(flat_ids, np)
      elsif partition_strategy == "div"
        raise "div partition strategy is not yet supported!"
      else
        raise TensorStream::ValueError, "Unrecognized partition strategy: #{partition_strategy}"
      end

      p_assignments = TensorStream.cast(p_assignments, :int32)
      # Group the shard-local rows, and the original positions of the ids,
      # by the shard each id was assigned to.
      gather_ids = TensorStream.dynamic_partition(new_ids, p_assignments, np)
      pindices = TensorStream.dynamic_partition(original_indices, p_assignments, np)
      partitioned_result = []
      (0...np).each do |p|
        pids = gather_ids[p]
        result = nil
        TensorStream.colocate_with(params[p]) do
          result = TensorStream.gather(params[p], pids)
          if transform_fn
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn.call(_clip(result, pids, max_norm))
          end
        end
        partitioned_result << result
      end
      # Stitch the per-shard results back into the original id order.
      ret = TensorStream.dynamic_stitch(pindices, partitioned_result, name: name)

      # Determine the static element shape (the per-id embedding shape,
      # i.e. everything past the leading vocabulary dimension).
      if transform_fn.nil?
        element_shape_s = params[0].shape[1..-1]
        params[1..-1].each { |p| element_shape_s = element_shape_s.merge_with(p.shape[1..-1]) }
      else
        element_shape_s = ret.shape[1..-1]
      end

      # Compute the dynamic element shape.
      element_shape_d = if element_shape_s.fully_defined?
                          element_shape_s
                        elsif transform_fn.nil?
                          # It's important that we compute params[0].shape on the right device
                          # to avoid data motion.
                          TensorStream.colocate_with(params[0]) do
                            params_shape = TensorStream.shape(params[0])
                            params_shape[1..-1]
                          end
                        else
                          TensorStream.shape(ret)[1..-1]
                        end
      # Reshape back to the original ids shape with the element shape appended.
      ret = TensorStream.reshape(ret, TensorStream.concat([TensorStream.shape(ids), element_shape_d], 0))
      ret = _clip(ret, ids, max_norm) unless transform_fn
      ret
    end
  end
end
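
The shard arithmetic of the "mod" strategy reduces to plain integer math, which the following plain-Ruby illustration (no TensorStream required, values are illustrative) makes concrete:

np = 3                          # number of parameter shards
flat_ids = [0, 1, 2, 3, 4, 5, 6]
p_assignments = flat_ids.map { |i| i % np }  # => [0, 1, 2, 0, 1, 2, 0]
new_ids = flat_ids.map { |i| i / np }        # => [0, 0, 0, 1, 1, 1, 2]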

#_rank(x) ⇒ Object



# File 'lib/tensor_stream/nn/embedding_lookup.rb', line 105

def _rank(x)
  # Returns a [rank, static] pair: the rank of x, plus a flag telling
  # whether it came from the statically known shape or from a graph op.
  rank = TensorStream.convert_to_tensor(x).shape.ndims
  if rank
    [rank, true]
  else
    [TensorStream.rank(x), false]
  end
end
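
The flag is what lets _clip above pass a plain Ruby array of axes instead of building a range op. A quick sketch of the static case (hypothetical tensor, assuming the constant API):

ts = TensorStream
x = ts.constant([[1.0, 2.0], [3.0, 4.0]])
ts.convert_to_tensor(x).shape.ndims  # => 2, so _rank(x) yields [2, true]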

#embedding_lookup(params, ids, partition_strategy: "mod", name: nil, validate_indices: true, max_norm: nil) ⇒ Object

Looks up `ids` in a list of embedding tensors.



# File 'lib/tensor_stream/nn/embedding_lookup.rb', line 11

def embedding_lookup(params, ids, partition_strategy: "mod", name: nil, validate_indices: true, max_norm: nil)
  _embedding_lookup_and_transform(params, ids, partition_strategy: partition_strategy, name: name, max_norm: max_norm, transform_fn: nil)
end
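
A minimal usage sketch (illustrative values; assumes the module's methods are reachable through ts.nn, since NN includes this module, and the gem's constant/session API). Each id selects the matching row of the embedding matrix:

require "tensor_stream"

ts = TensorStream
embeddings = ts.constant([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])  # 3 ids, dim 2
ids = ts.constant([2, 0, 2])
lookup = ts.nn.embedding_lookup(embeddings, ids)
ts.session.run(lookup)  # => [[0.5, 0.5], [1.0, 0.0], [0.5, 0.5]]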