3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
|
# File 'lib/neighbor/model.rb', line 3
# Declares one or more vector attributes on the model and wires up
# nearest-neighbor search for them.
#
# The first call in a class hierarchy registers a validation, a +before_save+
# normalization hook, the +nearest_neighbors+ scope and the
# +#nearest_neighbors+ instance method. All of those consult
# +neighbor_attributes+ at call time, so later calls (on the same class or on a
# subclass) only record their attributes.
#
# @param attribute_names [Array<Symbol, String>] names of the vector columns
# @param dimensions [Integer, nil] expected vector length; when nil, the
#   column's +limit+ (if any) is used for validation
# @param normalize [Boolean, nil] when truthy, stored vectors are normalized
#   before save and query vectors are normalized before searching
# @raise [ArgumentError] if no attribute name is given
# @raise [Neighbor::Error] if an attribute is already declared on this class
#   or an ancestor
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
  if attribute_names.empty?
    raise ArgumentError, "has_neighbors requires an attribute name"
  end

  attribute_names.map!(&:to_sym)

  class_eval do
    @neighbor_attributes ||= {}

    if @neighbor_attributes.empty?
      # Merged view: attributes inherited from ancestors plus this class's own.
      def self.neighbor_attributes
        parent_attributes =
          if superclass.respond_to?(:neighbor_attributes)
            superclass.neighbor_attributes
          else
            {}
          end

        parent_attributes.merge(@neighbor_attributes || {})
      end
    end

    attribute_names.each do |attribute_name|
      if neighbor_attributes[attribute_name]
        raise Neighbor::Error, "has_neighbors already called for #{attribute_name.inspect}"
      end
      @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}
    end

    # The callbacks, scope and instance method below read neighbor_attributes
    # at call time, and Active Record callbacks/scopes are inherited, so they are
    # registered only on the first call in the hierarchy. Comparing the merged
    # (inherited) count stops a subclass from adding a second validate /
    # before_save pair, which would duplicate errors and re-normalize parent
    # attributes.
    return if neighbor_attributes.size != attribute_names.size

    # NOTE: locals assigned inside the blocks below intentionally use names that
    # differ from this method's `dimensions` / `normalize` arguments; reusing
    # those names would write to the captured method locals, which are shared by
    # every invocation of these long-lived blocks.
    validate do
      self.class.neighbor_attributes.each do |k, v|
        value = read_attribute(k)
        next if value.nil?

        column_info = self.class.columns_hash[k.to_s]
        expected_dimensions = v[:dimensions] || column_info&.limit

        unless Neighbor::Utils.validate_dimensions(value, column_info&.type, expected_dimensions).nil?
          errors.add(k, "must have #{expected_dimensions} dimensions")
        end
        unless Neighbor::Utils.validate_finite(value, column_info&.type)
          errors.add(k, "must have finite values")
        end
      end
    end

    before_save do
      self.class.neighbor_attributes.each do |k, v|
        next unless v[:normalize] && attribute_changed?(k)

        value = read_attribute(k)
        next if value.nil?

        self[k] = Neighbor::Utils.normalize(value, column_info: self.class.columns_hash[k.to_s])
      end
    end

    # Options arrive as a trailing positional Hash (the lambda takes
    # `options = nil`), so the keyword checks are done by hand.
    scope :nearest_neighbors, ->(attribute_name, vector, options = nil) {
      raise ArgumentError, "missing keyword: :distance" unless options.is_a?(Hash) && options.key?(:distance)

      # Read rather than delete, so a Hash the caller reuses is left untouched.
      distance = options[:distance]
      precision = options[:precision]
      unknown_keys = options.keys - [:distance, :precision]
      raise ArgumentError, "unknown keywords: #{unknown_keys.map(&:inspect).join(", ")}" unless unknown_keys.empty?

      attribute_name = attribute_name.to_sym
      attribute_options = neighbor_attributes[attribute_name]
      raise ArgumentError, "Invalid attribute" unless attribute_options

      normalize_query = attribute_options[:normalize]
      attribute_dimensions = attribute_options[:dimensions]

      return none if vector.nil?

      distance = distance.to_s
      quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"

      column_info = columns_hash[attribute_name.to_s]
      column_type = column_info&.type

      # SQL distance operator for (column type, distance name); nil means the
      # distance is not supported for that column type.
      operator =
        case column_type
        when :bit
          case distance
          when "hamming" then "<~>"
          when "jaccard" then "<%>"
          when "hamming2" then "#"
          end
        when :vector, :halfvec, :sparsevec
          case distance
          when "inner_product" then "<#>"
          when "cosine" then "<=>"
          when "euclidean" then "<->"
          when "taxicab" then "<+>"
          end
        when :cube
          case distance
          when "taxicab" then "<#>"
          when "chebyshev" then "<=>"
          when "euclidean", "cosine" then "<->"
          end
        else
          raise ArgumentError, "Unsupported type: #{column_type}"
        end

      raise ArgumentError, "Invalid distance: #{distance}" unless operator

      # cube "cosine" reuses the euclidean operator, so an explicit normalize
      # setting (true or false) is required; nil is rejected.
      if distance == "cosine" && column_type == :cube && normalize_query.nil?
        raise Neighbor::Error, "Set normalize for cosine distance with cube"
      end

      column_attribute = klass.type_for_attribute(attribute_name)
      vector = column_attribute.cast(vector)
      Neighbor::Utils.validate(vector, dimensions: attribute_dimensions, column_info: column_info)
      vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize_query

      query = connection.quote(column_attribute.serialize(vector))

      unless precision.nil?
        case precision.to_s
        when "half"
          cast_dimensions = attribute_dimensions || column_info&.limit
          raise ArgumentError, "Unknown dimensions" unless cast_dimensions
          quoted_attribute += "::halfvec(#{connection.quote(cast_dimensions.to_i)})"
        else
          raise ArgumentError, "Invalid precision"
        end
      end

      order = "#{quoted_attribute} #{operator} #{query}"
      # "#" (hamming2) is a bitwise XOR on bit strings; bit_count turns it into a number.
      order = "bit_count(#{order})" if operator == "#"

      # The reported distance can differ from the ORDER BY expression:
      # cube cosine reports d**2 / 2 of the euclidean distance d, and
      # inner_product flips the sign of the <#> result.
      neighbor_distance =
        if column_type == :cube && distance == "cosine"
          "POWER(#{order}, 2) / 2.0"
        elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
          "(#{order}) * -1"
        else
          order
        end

      # Add the model's columns only when the relation has no explicit select yet.
      select_columns = select_values.any? ? [] : column_names
      select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
        .where.not(attribute_name => nil)
        .reorder(Arel.sql(order))
    }

    # Records nearest to this record's +attribute_name+ vector, excluding this
    # record itself (matched on every primary key column).
    def nearest_neighbors(attribute_name, **options)
      attribute_name = attribute_name.to_sym
      # Check the attribute is declared before reading it via self[].
      raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]

      self.class
        .where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
        .nearest_neighbors(attribute_name, self[attribute_name], **options)
    end
  end
end
|