Module: Neighbor::Model

Defined in:
lib/neighbor/model.rb

Instance Method Summary

Instance Method Details

#has_neighbors(*attribute_names, dimensions: nil, normalize: nil) ⇒ Object



(Source listing below covers lines 3–142 of lib/neighbor/model.rb.)
# File 'lib/neighbor/model.rb', line 3

# Declares vector attributes for nearest neighbor search on an Active Record model.
#
# Every call registers +attribute_names+ in the class-level +@neighbor_attributes+
# hash and declares each one as an attribute typed with Neighbor::Vector. The
# first call on a class also defines the +nearest_neighbors+ scope and the
# +nearest_neighbors+ instance method; later calls only register attributes.
#
# @param attribute_names [Array<Symbol, String>] attributes to register
#   (converted with +to_sym+). Calling with no names is deprecated and
#   registers +:neighbor_vector+.
# @param dimensions [Integer, nil] stored per attribute and forwarded to
#   Neighbor::Vector (presumably the expected vector length — confirm there)
# @param normalize [Boolean, nil] stored per attribute and forwarded to
#   Neighbor::Vector; must be non-nil to query a cube column by cosine distance
# @return [void]
# @raise [Neighbor::Error] if an attribute is already registered on this class
#   or inherited from a superclass
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
  if attribute_names.empty?
    warn "[neighbor] has_neighbors without an attribute name is deprecated"
    attribute_names << :neighbor_vector
  else
    attribute_names.map!(&:to_sym)
  end

  class_eval do
    @neighbor_attributes ||= {}

    # Define the reader the first time this class registers attributes. It
    # merges in the superclass's attributes, so inherited ones are visible
    # (and are rejected as duplicates below).
    if @neighbor_attributes.empty?
      def self.neighbor_attributes
        parent_attributes =
          if superclass.respond_to?(:neighbor_attributes)
            superclass.neighbor_attributes
          else
            {}
          end

        parent_attributes.merge(@neighbor_attributes || {})
      end
    end

    attribute_names.each do |attribute_name|
      raise Neighbor::Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
      @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}

      attribute attribute_name, Neighbor::Vector.new(dimensions: dimensions, normalize: normalize, model: self, attribute_name: attribute_name)
    end

    # The scope and instance method are defined only on the first call for
    # this class (when everything registered came from this call). The block
    # is not a lambda, so this +return+ exits has_neighbors.
    return if @neighbor_attributes.size != attribute_names.size

    # Orders records by distance from +vector+ to +attribute_name+ and adds a
    # +neighbor_distance+ column to the select, skipping rows where the
    # attribute is NULL. Returns +none+ when +vector+ is nil (before the
    # distance is validated).
    #
    #   Item.nearest_neighbors(:embedding, [1, 2, 3], distance: "cosine")
    #
    # Distances: "inner_product", "cosine", "euclidean" for vector columns;
    # "taxicab", "chebyshev", "euclidean", "cosine" for cube columns.
    #
    # Raises ArgumentError for a missing/unknown option, an unregistered
    # attribute, or an unsupported distance; raises Neighbor::Error for cosine
    # on a cube column whose +normalize+ setting is nil.
    #
    # This lambda closes over has_neighbors' locals, so it must not assign to
    # +dimensions+ or +normalize+: that would write a variable shared by every
    # call of the scope (including concurrent ones). Per-call values use
    # distinct local names instead.
    scope :nearest_neighbors, ->(attribute_name, vector = nil, options = nil) {
      # cannot use keyword arguments with scope with Ruby 3.2 and Active Record 6.1
      # https://github.com/rails/rails/issues/46934
      if options.nil? && vector.is_a?(Hash)
        options = vector
        vector = nil
      end
      raise ArgumentError, "missing keyword: :distance" unless options.is_a?(Hash) && options.key?(:distance)
      # read instead of delete so a hash passed in by the caller is left untouched
      distance = options[:distance]
      unknown_keys = options.keys - [:distance]
      raise ArgumentError, "unknown keywords: #{unknown_keys.map(&:inspect).join(", ")}" unless unknown_keys.empty?

      # deprecated form: nearest_neighbors(vector, distance: ...)
      if vector.nil? && !attribute_name.nil? && attribute_name.respond_to?(:to_a)
        warn "[neighbor] nearest_neighbors without an attribute name is deprecated"
        vector = attribute_name
        attribute_name = :neighbor_vector
      end
      attribute_name = attribute_name.to_sym

      neighbor_options = neighbor_attributes[attribute_name]
      raise ArgumentError, "Invalid attribute" unless neighbor_options
      neighbor_normalize = neighbor_options[:normalize]
      neighbor_dimensions = neighbor_options[:dimensions]

      return none if vector.nil?

      distance = distance.to_s

      quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"

      column_info = klass.type_for_attribute(attribute_name).column_info

      operator =
        if column_info[:type] == :vector
          case distance
          when "inner_product"
            "<#>"
          when "cosine"
            "<=>"
          when "euclidean"
            "<->"
          end
        else
          case distance
          when "taxicab"
            "<#>"
          when "chebyshev"
            "<=>"
          when "euclidean", "cosine"
            "<->"
          end
        end

      raise ArgumentError, "Invalid distance: #{distance}" unless operator

      # ensure normalize set (can be true or false)
      if distance == "cosine" && column_info[:type] == :cube && neighbor_normalize.nil?
        raise Neighbor::Error, "Set normalize for cosine distance with cube"
      end

      vector = Neighbor::Vector.cast(vector, dimensions: neighbor_dimensions, normalize: neighbor_normalize, column_info: column_info)

      # important! neighbor_vector should already be typecast
      # but use to_f as extra safeguard against SQL injection
      query =
        if column_info[:type] == :vector
          connection.quote("[#{vector.map(&:to_f).join(", ")}]")
        else
          "cube(array[#{vector.map(&:to_f).join(", ")}])"
        end

      order = "#{quoted_attribute} #{operator} #{query}"

      # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
      # with normalized vectors:
      # cosine similarity = 1 - (euclidean distance)**2 / 2
      # cosine distance = 1 - cosine similarity
      # this transformation doesn't change the order, so only needed for select
      neighbor_distance =
        if column_info[:type] != :vector && distance == "cosine"
          "POWER(#{order}, 2) / 2.0"
        elsif column_info[:type] == :vector && distance == "inner_product"
          # <#> yields the negated inner product; flip it back for display
          "(#{order}) * -1"
        else
          order
        end

      # for select, use column_names instead of * to account for ignored columns
      select(*column_names, "#{neighbor_distance} AS neighbor_distance")
        .where.not(attribute_name => nil)
        .order(Arel.sql(order))
    }

    # Finds records nearest to this record's +attribute_name+ vector, excluding
    # this record by primary key. Keyword options (e.g. +distance:+) are
    # forwarded to the +nearest_neighbors+ scope.
    #
    # @raise [ArgumentError] if +attribute_name+ is not a registered neighbor attribute
    def nearest_neighbors(attribute_name = nil, **options)
      if attribute_name.nil?
        warn "[neighbor] nearest_neighbors without an attribute name is deprecated"
        attribute_name = :neighbor_vector
      end
      attribute_name = attribute_name.to_sym
      # important! check if neighbor attribute before calling send
      raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]

      self.class
        .where.not(self.class.primary_key => self[self.class.primary_key])
        .nearest_neighbors(attribute_name, self[attribute_name], **options)
    end
  end
end