3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
|
# File 'lib/neighbor/model.rb', line 3
# Registers one or more vector attributes on the model and, the first time it
# is called for a class, defines the +nearest_neighbors+ scope and the
# +nearest_neighbors+ instance method.
#
# @param attribute_names [Array<#to_sym>] attributes that hold vectors; when
#   omitted, +:neighbor_vector+ is used and a deprecation warning is printed
# @param dimensions stored per attribute and forwarded to Neighbor::Vector
# @param normalize stored per attribute and forwarded to Neighbor::Vector
# @raise [Error] if has_neighbors was already called for an attribute
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
  if attribute_names.empty?
    warn "[neighbor] has_neighbors without an attribute name is deprecated"
    attribute_names << :neighbor_vector
  else
    attribute_names.map!(&:to_sym)
  end

  class_eval do
    @neighbor_attributes ||= {}

    # Define the reader only once per class. It merges the attributes declared
    # on the superclass (when it has any) with the ones declared here.
    if @neighbor_attributes.empty?
      def self.neighbor_attributes
        parent_attributes =
          if superclass.respond_to?(:neighbor_attributes)
            superclass.neighbor_attributes
          else
            {}
          end
        parent_attributes.merge(@neighbor_attributes || {})
      end
    end

    attribute_names.each do |attribute_name|
      raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
      @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}
      attribute attribute_name, Neighbor::Vector.new(dimensions: dimensions, normalize: normalize, model: self, attribute_name: attribute_name)
    end

    # If attributes were already registered on this class by an earlier call,
    # the scope and instance method below exist already; stop here.
    return if @neighbor_attributes.size != attribute_names.size

    scope :nearest_neighbors, ->(attribute_name, vector = nil, options = nil) {
      # Two-argument form: the second argument is the options hash.
      if options.nil? && vector.is_a?(Hash)
        options = vector
        vector = nil
      end
      raise ArgumentError, "missing keyword: :distance" unless options.is_a?(Hash) && options.key?(:distance)

      # Read the options without deleting from the hash: it may be the very
      # object the caller passed in, and must not lose its :distance key.
      distance = options[:distance]
      unknown_keys = options.keys - [:distance]
      raise ArgumentError, "unknown keywords: #{unknown_keys.map(&:inspect).join(", ")}" unless unknown_keys.empty?

      # Deprecated form: the first argument is the vector itself.
      if vector.nil? && !attribute_name.nil? && attribute_name.respond_to?(:to_a)
        warn "[neighbor] nearest_neighbors without an attribute name is deprecated"
        vector = attribute_name
        attribute_name = :neighbor_vector
      end
      attribute_name = attribute_name.to_sym

      # These locals are deliberately NOT named `dimensions` / `normalize`:
      # those names belong to has_neighbors' parameters, which this lambda
      # closes over, so assigning to them would write to variables shared by
      # every call of the scope.
      settings = neighbor_attributes[attribute_name]
      raise ArgumentError, "Invalid attribute" unless settings
      vector_normalize = settings[:normalize]
      vector_dimensions = settings[:dimensions]

      return none if vector.nil?

      distance = distance.to_s
      quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"
      column_info = klass.type_for_attribute(attribute_name).column_info
      vector_column = column_info[:type] == :vector

      # The same operator symbols mean different distances for the two column
      # types, hence two separate mappings.
      operator =
        if vector_column
          case distance
          when "inner_product"
            "<#>"
          when "cosine"
            "<=>"
          when "euclidean"
            "<->"
          end
        else
          case distance
          when "taxicab"
            "<#>"
          when "chebyshev"
            "<=>"
          when "euclidean", "cosine"
            "<->"
          end
        end
      raise ArgumentError, "Invalid distance: #{distance}" unless operator

      # Cosine on cube is computed from euclidean distance (see below), which
      # is only meaningful when a normalize setting has been chosen.
      if distance == "cosine" && column_info[:type] == :cube && vector_normalize.nil?
        raise Neighbor::Error, "Set normalize for cosine distance with cube"
      end

      vector = Neighbor::Vector.cast(vector, dimensions: vector_dimensions, normalize: vector_normalize, column_info: column_info)

      # Every element goes through to_f, so only numeric text is interpolated.
      query =
        if vector_column
          connection.quote("[#{vector.map(&:to_f).join(", ")}]")
        else
          "cube(array[#{vector.map(&:to_f).join(", ")}])"
        end
      order = "#{quoted_attribute} #{operator} #{query}"

      neighbor_distance =
        if !vector_column && distance == "cosine"
          # For unit-length vectors, squared euclidean distance / 2 equals
          # cosine distance.
          "POWER(#{order}, 2) / 2.0"
        elsif vector_column && distance == "inner_product"
          # pgvector's <#> yields the negative inner product; flip the sign.
          "(#{order}) * -1"
        else
          order
        end

      select(*column_names, "#{neighbor_distance} AS neighbor_distance")
        .where.not(attribute_name => nil)
        .order(Arel.sql(order))
    }

    # Finds the records nearest to this one, excluding the record itself.
    #
    # @param attribute_name [#to_sym, nil] vector attribute to compare; when
    #   nil, +:neighbor_vector+ is used and a deprecation warning is printed
    # @param options keyword options forwarded to the scope (e.g. +distance:+)
    # @raise [ArgumentError] if the attribute was not declared with has_neighbors
    def nearest_neighbors(attribute_name = nil, **options)
      if attribute_name.nil?
        warn "[neighbor] nearest_neighbors without an attribute name is deprecated"
        attribute_name = :neighbor_vector
      end
      attribute_name = attribute_name.to_sym
      raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]

      self.class
        .where.not(self.class.primary_key => self[self.class.primary_key])
        .nearest_neighbors(attribute_name, self[attribute_name], **options)
    end
  end
end
|