4
4
"context"
5
5
"embed"
6
6
"math"
7
+ "math/rand"
7
8
"os"
8
9
"path/filepath"
9
10
@@ -22,6 +23,19 @@ import (
22
23
//go:embed backend-assets/*
23
24
var backendAssets embed.FS
24
25
26
+ func normalize (vecs [][]float32 ) {
27
+ for i , k := range vecs {
28
+ norm := float64 (0 )
29
+ for _ , x := range k {
30
+ norm += float64 (x * x )
31
+ }
32
+ norm = math .Sqrt (norm )
33
+ for j , x := range k {
34
+ vecs [i ][j ] = x / float32 (norm )
35
+ }
36
+ }
37
+ }
38
+
25
39
var _ = Describe ("Integration tests for the stores backend(s) and internal APIs" , Label ("stores" ), func () {
26
40
Context ("Embedded Store get,set and delete" , func () {
27
41
var sl * model.ModelLoader
@@ -192,17 +206,8 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
192
206
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
193
207
keys := [][]float32 {{0.1 , 0.3 , 0.5 }, {0.5 , 0.5 , 0.5 }, {0.6 , 0.6 , - 0.6 }, {0.7 , - 0.7 , - 0.7 }}
194
208
vals := [][]byte {[]byte ("test0" ), []byte ("test1" ), []byte ("test2" ), []byte ("test3" )}
195
- // normalize the keys
196
- for i , k := range keys {
197
- norm := float64 (0 )
198
- for _ , x := range k {
199
- norm += float64 (x * x )
200
- }
201
- norm = math .Sqrt (norm )
202
- for j , x := range k {
203
- keys [i ][j ] = x / float32 (norm )
204
- }
205
- }
209
+
210
+ normalize (keys )
206
211
207
212
err := store .SetCols (context .Background (), sc , keys , vals )
208
213
Expect (err ).ToNot (HaveOccurred ())
@@ -225,5 +230,121 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
225
230
Expect (ks [1 ]).To (Equal (keys [1 ]))
226
231
Expect (vals [1 ]).To (Equal (vals [1 ]))
227
232
})
233
+
234
+ It ("It produces the correct cosine similarities for orthogonal and opposite unit vectors" , func () {
235
+ keys := [][]float32 {{1.0 , 0.0 , 0.0 }, {0.0 , 1.0 , 0.0 }, {0.0 , 0.0 , 1.0 }, {- 1.0 , 0.0 , 0.0 }}
236
+ vals := [][]byte {[]byte ("x" ), []byte ("y" ), []byte ("z" ), []byte ("-z" )}
237
+
238
+ err := store .SetCols (context .Background (), sc , keys , vals );
239
+ Expect (err ).ToNot (HaveOccurred ())
240
+
241
+ _ , _ , sims , err := store .Find (context .Background (), sc , keys [0 ], 4 )
242
+ Expect (err ).ToNot (HaveOccurred ())
243
+ Expect (sims ).To (Equal ([]float32 {1.0 , 0.0 , 0.0 , - 1.0 }))
244
+ })
245
+
246
+ It ("It produces the correct cosine similarities for orthogonal and opposite vectors" , func () {
247
+ keys := [][]float32 {{1.0 , 0.0 , 1.0 }, {0.0 , 2.0 , 0.0 }, {0.0 , 0.0 , - 1.0 }, {- 1.0 , 0.0 , - 1.0 }}
248
+ vals := [][]byte {[]byte ("x" ), []byte ("y" ), []byte ("z" ), []byte ("-z" )}
249
+
250
+ err := store .SetCols (context .Background (), sc , keys , vals );
251
+ Expect (err ).ToNot (HaveOccurred ())
252
+
253
+ _ , _ , sims , err := store .Find (context .Background (), sc , keys [0 ], 4 )
254
+ Expect (err ).ToNot (HaveOccurred ())
255
+ Expect (sims [0 ]).To (BeNumerically ("~" , 1 , 0.1 ))
256
+ Expect (sims [1 ]).To (BeNumerically ("~" , 0 , 0.1 ))
257
+ Expect (sims [2 ]).To (BeNumerically ("~" , - 0.7 , 0.1 ))
258
+ Expect (sims [3 ]).To (BeNumerically ("~" , - 1 , 0.1 ))
259
+ })
260
+
261
+ expectTriangleEq := func (keys [][]float32 , vals [][]byte ) {
262
+ sims := map [string ]map [string ]float32 {}
263
+
264
+ // compare every key vector pair and store the similarities in a lookup table
265
+ // that uses the values as keys
266
+ for i , k := range keys {
267
+ _ , valsk , simsk , err := store .Find (context .Background (), sc , k , 9 )
268
+ Expect (err ).ToNot (HaveOccurred ())
269
+
270
+ for j , v := range valsk {
271
+ p := string (vals [i ])
272
+ q := string (v )
273
+
274
+ if sims [p ] == nil {
275
+ sims [p ] = map [string ]float32 {}
276
+ }
277
+
278
+ //log.Debug().Strs("vals", []string{p, q}).Float32("similarity", simsk[j]).Send()
279
+
280
+ sims [p ][q ] = simsk [j ]
281
+ }
282
+ }
283
+
284
+ // Check that the triangle inequality holds for every combination of the triplet
285
+ // u, v and w
286
+ for _ , simsu := range sims {
287
+ for w , simw := range simsu {
288
+ // acos(u,w) <= ...
289
+ uws := math .Acos (float64 (simw ))
290
+
291
+ // ... acos(u,v) + acos(v,w)
292
+ for v , _ := range simsu {
293
+ uvws := math .Acos (float64 (simsu [v ])) + math .Acos (float64 (sims [v ][w ]))
294
+
295
+ //log.Debug().Str("u", u).Str("v", v).Str("w", w).Send()
296
+ //log.Debug().Float32("uw", simw).Float32("uv", simsu[v]).Float32("vw", sims[v][w]).Send()
297
+ Expect (uws ).To (BeNumerically ("<=" , uvws ))
298
+ }
299
+ }
300
+ }
301
+ }
302
+
303
+ It ("It obeys the triangle inequality for normalized values" , func () {
304
+ keys := [][]float32 {
305
+ {1.0 , 0.0 , 0.0 }, {0.0 , 1.0 , 0.0 }, {0.0 , 0.0 , 1.0 },
306
+ {- 1.0 , 0.0 , 0.0 }, {0.0 , - 1.0 , 0.0 }, {0.0 , 0.0 , - 1.0 },
307
+ {2.0 , 3.0 , 4.0 }, {9.0 , 7.0 , 1.0 }, {0.0 , - 1.2 , 2.3 },
308
+ }
309
+ vals := [][]byte {
310
+ []byte ("x" ), []byte ("y" ), []byte ("z" ),
311
+ []byte ("-x" ), []byte ("-y" ), []byte ("-z" ),
312
+ []byte ("u" ), []byte ("v" ), []byte ("w" ),
313
+ }
314
+
315
+ normalize (keys [6 :])
316
+
317
+ err := store .SetCols (context .Background (), sc , keys , vals );
318
+ Expect (err ).ToNot (HaveOccurred ())
319
+
320
+ expectTriangleEq (keys , vals )
321
+ })
322
+
323
+ It ("It obeys the triangle inequality" , func () {
324
+ rnd := rand .New (rand .NewSource (151 ))
325
+ keys := make ([][]float32 , 20 )
326
+ vals := make ([][]byte , 20 )
327
+
328
+ for i := range keys {
329
+ k := make ([]float32 , 768 )
330
+
331
+ for j := range k {
332
+ k [j ] = rnd .Float32 ()
333
+ }
334
+
335
+ keys [i ] = k
336
+ }
337
+
338
+ c := byte ('a' )
339
+ for i := range vals {
340
+ vals [i ] = []byte {c }
341
+ c += 1
342
+ }
343
+
344
+ err := store .SetCols (context .Background (), sc , keys , vals );
345
+ Expect (err ).ToNot (HaveOccurred ())
346
+
347
+ expectTriangleEq (keys , vals )
348
+ })
228
349
})
229
350
})
0 commit comments