Skip to content

Commit cfbb56d

Browse files
committed
fix(stores): Add tests for known results and triangle inequality
This adds some more tests to check the cosine similarity function has some expected mathematical properties.
1 parent 6913b29 commit cfbb56d

File tree

1 file changed

+132
-11
lines changed

1 file changed

+132
-11
lines changed

tests/integration/stores_test.go

Lines changed: 132 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"embed"
66
"math"
7+
"math/rand"
78
"os"
89
"path/filepath"
910

@@ -22,6 +23,19 @@ import (
2223
//go:embed backend-assets/*
2324
var backendAssets embed.FS
2425

26+
// normalize rescales every vector in vecs to unit Euclidean length, in place.
//
// The squared magnitude is accumulated in float64 from widened components so
// that long vectors do not lose precision to float32 rounding. A vector whose
// magnitude is zero (including an empty vector) is left untouched, because
// dividing by a zero norm would fill it with NaN.
func normalize(vecs [][]float32) {
	for _, vec := range vecs {
		sumSq := float64(0)
		for _, x := range vec {
			sumSq += float64(x) * float64(x)
		}

		// Nothing sensible to scale to; keep the zero vector as it is.
		if sumSq == 0 {
			continue
		}

		norm := float32(math.Sqrt(sumSq))
		// vec shares its backing array with the caller's slice, so writing
		// through it updates vecs in place.
		for j, x := range vec {
			vec[j] = x / norm
		}
	}
}
38+
2539
var _ = Describe("Integration tests for the stores backend(s) and internal APIs", Label("stores"), func() {
2640
Context("Embedded Store get,set and delete", func() {
2741
var sl *model.ModelLoader
@@ -192,17 +206,8 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
192206
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
193207
keys := [][]float32{{0.1, 0.3, 0.5}, {0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}
194208
vals := [][]byte{[]byte("test0"), []byte("test1"), []byte("test2"), []byte("test3")}
195-
// normalize the keys
196-
for i, k := range keys {
197-
norm := float64(0)
198-
for _, x := range k {
199-
norm += float64(x * x)
200-
}
201-
norm = math.Sqrt(norm)
202-
for j, x := range k {
203-
keys[i][j] = x / float32(norm)
204-
}
205-
}
209+
210+
normalize(keys)
206211

207212
err := store.SetCols(context.Background(), sc, keys, vals)
208213
Expect(err).ToNot(HaveOccurred())
@@ -225,5 +230,121 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
225230
Expect(ks[1]).To(Equal(keys[1]))
226231
Expect(vals[1]).To(Equal(vals[1]))
227232
})
233+
234+
// Queries the store with the unit x vector. The expected similarities are
// compared exactly: 1 for the vector itself, 0 for the two orthogonal axes
// and -1 for the negated vector.
It("It produces the correct cosine similarities for orthogonal and opposite unit vectors", func() {
	keys := [][]float32{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}, {-1.0, 0.0, 0.0}}
	// The last key is the negation of the first one, hence the "-x" label.
	vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-x")}

	err := store.SetCols(context.Background(), sc, keys, vals)
	Expect(err).ToNot(HaveOccurred())

	_, _, sims, err := store.Find(context.Background(), sc, keys[0], 4)
	Expect(err).ToNot(HaveOccurred())
	Expect(sims).To(Equal([]float32{1.0, 0.0, 0.0, -1.0}))
})
245+
246+
// Same idea as the unit-vector case, but the keys are not unit length, so the
// similarities are only checked to within 0.1. Relative to the query {1, 0, 1}:
//   - {0, 2, 0} is orthogonal               -> 0
//   - {0, 0, -1} is at 135 degrees          -> -1/sqrt(2), roughly -0.7
//   - {-1, 0, -1} points the opposite way   -> -1
It("It produces the correct cosine similarities for orthogonal and opposite vectors", func() {
	keys := [][]float32{{1.0, 0.0, 1.0}, {0.0, 2.0, 0.0}, {0.0, 0.0, -1.0}, {-1.0, 0.0, -1.0}}
	// Labels name the direction of each key.
	vals := [][]byte{[]byte("xz"), []byte("y"), []byte("-z"), []byte("-xz")}

	err := store.SetCols(context.Background(), sc, keys, vals)
	Expect(err).ToNot(HaveOccurred())

	_, _, sims, err := store.Find(context.Background(), sc, keys[0], 4)
	Expect(err).ToNot(HaveOccurred())
	Expect(sims[0]).To(BeNumerically("~", 1, 0.1))
	Expect(sims[1]).To(BeNumerically("~", 0, 0.1))
	Expect(sims[2]).To(BeNumerically("~", -0.7, 0.1))
	Expect(sims[3]).To(BeNumerically("~", -1, 0.1))
})
260+
261+
// expectTriangleEq asserts that the angle derived from the store's similarity
// scores obeys the triangle inequality,
//
//	acos(sim(u,w)) <= acos(sim(u,v)) + acos(sim(v,w)),
//
// for every triplet u, v, w whose pairwise similarities were returned by
// store.Find.
//
// keys and vals are parallel slices: vals[i] is the value stored under
// keys[i] and doubles as that key's label in the lookup table, so the values
// must be unique.
expectTriangleEq := func(keys [][]float32, vals [][]byte) {
	sims := make(map[string]map[string]float32, len(keys))

	// Compare every key vector pair and store the similarities in a lookup
	// table that uses the values as keys.
	for i, k := range keys {
		// Ask for as many results as there are keys. A smaller fixed limit
		// would leave most pairs out of the table for larger key sets.
		_, valsk, simsk, err := store.Find(context.Background(), sc, k, len(keys))
		Expect(err).ToNot(HaveOccurred())

		p := string(vals[i])
		if sims[p] == nil {
			sims[p] = make(map[string]float32, len(valsk))
		}

		for j, v := range valsk {
			sims[p][string(v)] = simsk[j]
		}
	}

	// Check that the triangle inequality holds for every combination of the
	// triplet u, v and w.
	for _, simsu := range sims {
		for w, simuw := range simsu {
			// acos(u,w) <= ...
			uws := math.Acos(float64(simuw))

			// ... acos(u,v) + acos(v,w)
			for v, simuv := range simsu {
				simvw, found := sims[v][w]
				if !found {
					// The pair (v, w) was not returned by Find. Skip it
					// rather than let the map's zero value pose as a
					// similarity of 0.
					continue
				}

				uvws := math.Acos(float64(simuv)) + math.Acos(float64(simvw))
				Expect(uws).To(BeNumerically("<=", uvws))
			}
		}
	}
}
302+
303+
// Stores the six signed unit axis vectors plus three arbitrary vectors scaled
// to unit length, then checks the triangle inequality across all of them.
It("It obeys the triangle inequality for normalized values", func() {
	keys := [][]float32{
		{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0},
		{-1.0, 0.0, 0.0}, {0.0, -1.0, 0.0}, {0.0, 0.0, -1.0},
		{2.0, 3.0, 4.0}, {9.0, 7.0, 1.0}, {0.0, -1.2, 2.3},
	}
	vals := [][]byte{
		[]byte("x"), []byte("y"), []byte("z"),
		[]byte("-x"), []byte("-y"), []byte("-z"),
		[]byte("u"), []byte("v"), []byte("w"),
	}

	// The first six keys are already unit length; only the last three need
	// scaling.
	normalize(keys[6:])

	err := store.SetCols(context.Background(), sc, keys, vals)
	Expect(err).ToNot(HaveOccurred())

	expectTriangleEq(keys, vals)
})
322+
323+
// Checks the triangle inequality on 20 pseudo-random 768-dimensional keys that
// are stored without being normalized first.
It("It obeys the triangle inequality", func() {
	// A fixed seed keeps the generated vectors identical from run to run.
	rnd := rand.New(rand.NewSource(151))
	keys := make([][]float32, 20)
	vals := make([][]byte, 20)

	for i := range keys {
		k := make([]float32, 768)

		for j := range k {
			k[j] = rnd.Float32()
		}

		keys[i] = k
	}

	// Label the keys with consecutive single letters starting at 'a'.
	c := byte('a')
	for i := range vals {
		vals[i] = []byte{c}
		c++
	}

	err := store.SetCols(context.Background(), sc, keys, vals)
	Expect(err).ToNot(HaveOccurred())

	expectTriangleEq(keys, vals)
})
228349
})
229350
})

0 commit comments

Comments
 (0)