Skip to content

Commit e471548

Browse files
michaeljmarshalldjatnieks
authored andcommitted
CNDB-13583: Add vector ann and brute force metrics (#1683)
### What is the issue riptano/cndb#13583 ### What does this PR fix and why was it fixed This PR adds comprehensive metrics for Storage Attached Indexes (SAI) vector search operations, providing crucial insights into both ANN (Approximate Nearest Neighbor) graph searches and brute force operations. New Vector Search Metrics: Search Operation Counters: - `ANNNodesVisited`: Total number of nodes visited during ANN searches (this is equivalent to approximate similarity score computations) - `ANNNodesReranked`: Number of nodes that underwent exact distance computation for reranking (this is equivalent to exact similarity score computations) - `ANNNodesExpanded`: Total number of nodes whose edges were explored - `ANNNodesExpandedBaseLayer`: Number of nodes expanded in the base layer of the graph - `ANNGraphSearches`: Count of new graph searches initiated - `ANNGraphResumes`: Count of resumed graph searches (when a search continues from previous results) - `ANNGraphSearchLatency`: Timer measuring individual graph search latency (Note: This measures per-graph search time, not total query time which may involve multiple graphs) Brute Force Operation Counters: - `BruteForceNodesVisited`: Number of nodes visited during brute force searches (approximate similarity comparisons) - `BruteForceNodesReranked`: Number of nodes that underwent exact similarity computation during brute force searches Memory Usage Tracking: - `quantizationMemoryBytes`: Current memory usage by the quantization (PQ or BQ) data structures - `ordinalsMapMemoryBytes`: Current memory usage by ordinals mapping structures (only matters in some cases) - `onDiskGraphsCount`: Number of currently loaded graph segments - `onDiskGraphVectorsCount`: Total number of vectors in currently loaded graphs These metrics will help us: 1. Understand if we are performing more comparisons than expected 2. Get insight into number of graphs queried 3. Get insight into the brute force vs graph query path 4. Understand current memory utilization The counters provide operations/second metrics, allowing calculation of per-query averages by dividing by the number of queries. The memory tracking metrics help monitor resource usage across graph segments as they are loaded and unloaded.
1 parent f889e22 commit e471548

16 files changed

+405
-38
lines changed

src/java/org/apache/cassandra/index/sai/IndexContext.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ public IndexContext(@Nonnull String keyspace,
198198
this.hasEuclideanSimilarityFunc = vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
199199

200200
this.indexMetrics = new IndexMetrics(this);
201-
this.columnQueryMetrics = isLiteral() ? new ColumnQueryMetrics.TrieIndexMetrics(keyspace, table, getIndexName())
201+
this.columnQueryMetrics = isVector() ? new ColumnQueryMetrics.VectorIndexMetrics(keyspace, table, getIndexName()) :
202+
isLiteral() ? new ColumnQueryMetrics.TrieIndexMetrics(keyspace, table, getIndexName())
202203
: new ColumnQueryMetrics.BKDIndexMetrics(keyspace, table, getIndexName());
203204

204205
}

src/java/org/apache/cassandra/index/sai/QueryContext.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ public class QueryContext
6161

6262
private final LongAdder queryTimeouts = new LongAdder();
6363

64-
private final LongAdder annNodesVisited = new LongAdder();
64+
private final LongAdder annGraphSearchLatency = new LongAdder();
65+
6566
private float annRerankFloor = 0.0f; // only called from single-threaded setup code
6667

6768
private final LongAdder shadowedPrimaryKeyCount = new LongAdder();
@@ -139,9 +140,10 @@ public void addQueryTimeouts(long val)
139140
{
140141
queryTimeouts.add(val);
141142
}
142-
public void addAnnNodesVisited(long val)
143+
144+
public void addAnnGraphSearchLatency(long val)
143145
{
144-
annNodesVisited.add(val);
146+
annGraphSearchLatency.add(val);
145147
}
146148

147149
public void setFilterSortOrder(FilterSortOrder filterSortOrder)
@@ -202,9 +204,9 @@ public long queryTimeouts()
202204
{
203205
return queryTimeouts.longValue();
204206
}
205-
public long annNodesVisited()
207+
public long annGraphSearchLatency()
206208
{
207-
return annNodesVisited.longValue();
209+
return annGraphSearchLatency.longValue();
208210
}
209211

210212
public FilterSortOrder filterSortOrder()

src/java/org/apache/cassandra/index/sai/disk/v2/V2OnDiskOrdinalsMap.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ public Structure getStructure()
9898
return canFastMapOrdinalsView ? Structure.ONE_TO_ONE : Structure.ZERO_OR_ONE_TO_MANY;
9999
}
100100

101+
@Override
102+
public long cachedBytesUsed()
103+
{
104+
return 0;
105+
}
106+
101107
@Override
102108
public RowIdsView getRowIdsView()
103109
{

src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.apache.cassandra.index.sai.disk.vector.VectorCompression;
5959
import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex;
6060
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
61+
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
6162
import org.apache.cassandra.index.sai.plan.Expression;
6263
import org.apache.cassandra.index.sai.plan.Orderer;
6364
import org.apache.cassandra.index.sai.plan.Plan.CostCoefficients;
@@ -103,6 +104,7 @@ public class V2VectorIndexSearcher extends IndexSearcher
103104
private final PrimaryKey.Factory keyFactory;
104105
private final PairedSlidingWindowReservoir expectedActualNodesVisited = new PairedSlidingWindowReservoir(20);
105106
private final ThreadLocal<SparseBits> cachedBits;
107+
private final ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics;
106108

107109
protected V2VectorIndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
108110
PerIndexFiles perIndexFiles,
@@ -113,7 +115,8 @@ protected V2VectorIndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
113115
super(primaryKeyMapFactory, perIndexFiles, segmentMetadata, indexContext);
114116
this.graph = graph;
115117
this.keyFactory = PrimaryKey.factory(indexContext.comparator(), indexContext.indexFeatureSet());
116-
cachedBits = ThreadLocal.withInitial(SparseBits::new);
118+
this.cachedBits = ThreadLocal.withInitial(SparseBits::new);
119+
this.columnQueryMetrics = (ColumnQueryMetrics.VectorIndexMetrics) indexContext.getColumnQueryMetrics();
117120
}
118121

119122
@Override
@@ -194,10 +197,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
194197
if (RangeUtil.coversFullRing(keyRange))
195198
{
196199
var estimate = estimateCost(rerankK, graph.size());
197-
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, visited -> {
198-
estimate.updateStatistics(visited);
199-
context.addAnnNodesVisited(visited);
200-
});
200+
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, estimate::updateStatistics);
201201
}
202202

203203
PrimaryKey firstPrimaryKey = keyFactory.createTokenOnly(keyRange.left.getToken());
@@ -214,7 +214,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
214214

215215
// if the range covers the entire segment, skip directly to an index search
216216
if (minSSTableRowId <= metadata.minSSTableRowId && maxSSTableRowId >= metadata.maxSSTableRowId)
217-
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, context::addAnnNodesVisited);
217+
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, visited -> {});
218218

219219
minSSTableRowId = Math.max(minSSTableRowId, metadata.minSSTableRowId);
220220
maxSSTableRowId = min(maxSSTableRowId, metadata.maxSSTableRowId);
@@ -263,10 +263,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
263263
// the trouble to add it.
264264
var betterCostEstimate = estimateCost(rerankK, cardinality);
265265

266-
return graph.search(queryVector, limit, rerankK, threshold, bits, context, visited -> {
267-
betterCostEstimate.updateStatistics(visited);
268-
context.addAnnNodesVisited(visited);
269-
});
266+
return graph.search(queryVector, limit, rerankK, threshold, bits, context, betterCostEstimate::updateStatistics);
270267
}
271268
}
272269

@@ -305,8 +302,9 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
305302
segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> {
306303
approximateScores.push(i, scoreFunction.similarityTo(ordinal));
307304
});
305+
columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size());
308306
var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView());
309-
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK);
307+
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK, columnQueryMetrics);
310308
}
311309

312310
/**
@@ -325,6 +323,7 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
325323
segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> {
326324
scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal));
327325
});
326+
columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size());
328327
return new NodeQueueRowIdIterator(scoredRowIds);
329328
}
330329
}
@@ -348,6 +347,7 @@ private CloseableIterator<RowIdWithScore> filterByBruteForce(VectorFloat<?> quer
348347
if (score >= threshold)
349348
results.add(new RowIdWithScore(segmentRowId, score));
350349
});
350+
columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size());
351351
}
352352
return CloseableIterator.wrap(results.iterator());
353353
}

src/java/org/apache/cassandra/index/sai/disk/v5/V5OnDiskOrdinalsMap.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,4 +354,26 @@ public void close() {
354354
// no-op
355355
}
356356
}
357+
358+
@Override
359+
public long cachedBytesUsed()
360+
{
361+
if (structure != V5VectorPostingsWriter.Structure.ONE_TO_MANY) {
362+
return 0;
363+
}
364+
365+
long bytes = 0;
366+
if (extraRowIds != null) {
367+
bytes += extraRowIds.length * 4L;
368+
}
369+
if (extraOrdinals != null) {
370+
bytes += extraOrdinals.length * 4L;
371+
}
372+
if (extraRowsByOrdinal != null) {
373+
for (int[] rowIds : extraRowsByOrdinal.values()) {
374+
bytes += rowIds.length * 4L;
375+
}
376+
}
377+
return bytes;
378+
}
357379
}

src/java/org/apache/cassandra/index/sai/disk/vector/AutoResumingNodeScoreIterator.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424

2525
import io.github.jbellis.jvector.graph.GraphSearcher;
2626
import io.github.jbellis.jvector.graph.SearchResult;
27+
import org.apache.cassandra.index.sai.QueryContext;
28+
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
2729
import org.apache.cassandra.tracing.Tracing;
2830
import org.apache.cassandra.utils.AbstractIterator;
2931

3032
import static java.lang.Math.max;
33+
import static org.apache.cassandra.utils.Clock.Global.nanoTime;
3134

3235
/**
3336
* An iterator over {@link SearchResult.NodeScore} backed by a {@link SearchResult} that resumes search
@@ -41,6 +44,8 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
4144
private final int rerankK;
4245
private final boolean inMemory;
4346
private final String source;
47+
private final QueryContext context;
48+
private final ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics;
4449
private final IntConsumer nodesVisitedConsumer;
4550
private Iterator<SearchResult.NodeScore> nodeScores;
4651
private int cumulativeNodesVisited;
@@ -51,6 +56,8 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
5156
* no more results.
5257
* @param searcher the {@link GraphSearcher} to use to resume search.
5358
* @param result the first {@link SearchResult} to iterate over
59+
* @param context the {@link QueryContext} to use to record metrics
60+
* @param columnQueryMetrics object to record metrics
5461
* @param nodesVisitedConsumer a consumer that accepts the total number of nodes visited
5562
* @param limit the limit to pass to the {@link GraphSearcher} when resuming search
5663
* @param rerankK the rerankK to pass to the {@link GraphSearcher} when resuming search
@@ -60,6 +67,8 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
6067
public AutoResumingNodeScoreIterator(GraphSearcher searcher,
6168
GraphSearcherAccessManager accessManager,
6269
SearchResult result,
70+
QueryContext context,
71+
ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics,
6372
IntConsumer nodesVisitedConsumer,
6473
int limit,
6574
int rerankK,
@@ -69,7 +78,9 @@ public AutoResumingNodeScoreIterator(GraphSearcher searcher,
6978
this.searcher = searcher;
7079
this.accessManager = accessManager;
7180
this.nodeScores = Arrays.stream(result.getNodes()).iterator();
72-
this.cumulativeNodesVisited = result.getVisitedCount();
81+
this.context = context;
82+
this.columnQueryMetrics = columnQueryMetrics;
83+
this.cumulativeNodesVisited = 0;
7384
this.nodesVisitedConsumer = nodesVisitedConsumer;
7485
this.limit = max(1, limit / 2); // we shouldn't need as many results on resume
7586
this.rerankK = rerankK;
@@ -83,21 +94,29 @@ protected SearchResult.NodeScore computeNext()
8394
if (nodeScores.hasNext())
8495
return nodeScores.next();
8596

97+
long start = nanoTime();
98+
99+
// Search deeper into the graph
86100
var nextResult = searcher.resume(limit, rerankK);
87-
maybeLogTrace(nextResult);
101+
102+
// Record metrics
103+
long elapsed = nanoTime() - start;
104+
columnQueryMetrics.onSearchResult(nextResult, elapsed, true);
105+
context.addAnnGraphSearchLatency(elapsed);
88106
cumulativeNodesVisited += nextResult.getVisitedCount();
107+
108+
if (Tracing.isTracing())
109+
{
110+
String msg = inMemory ? "Memory based ANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}"
111+
: "Disk based ANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}";
112+
Tracing.trace(msg, limit, rerankK, nextResult.getVisitedCount(), nextResult.getRerankedCount(), nextResult.getNodes().length, source);
113+
}
114+
89115
// If the next result is empty, we are done searching.
90116
nodeScores = Arrays.stream(nextResult.getNodes()).iterator();
91117
return nodeScores.hasNext() ? nodeScores.next() : endOfData();
92118
}
93119

94-
private void maybeLogTrace(SearchResult result)
95-
{
96-
String msg = inMemory ? "ANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}"
97-
: "DiskANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}";
98-
Tracing.trace(msg, limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source);
99-
}
100-
101120
@Override
102121
public void close()
103122
{

src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import io.github.jbellis.jvector.graph.NodeQueue;
2222
import io.github.jbellis.jvector.util.BoundedLongHeap;
23+
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
2324
import org.apache.cassandra.index.sai.utils.SegmentRowIdOrdinalPairs;
2425
import org.apache.cassandra.index.sai.utils.RowIdWithMeta;
2526
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
@@ -62,6 +63,7 @@ public class BruteForceRowIdIterator extends AbstractIterator<RowIdWithScore>
6263
private final CloseableReranker reranker;
6364
private final int topK;
6465
private final int limit;
66+
private final ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics;
6567
private int rerankedCount;
6668

6769
/**
@@ -70,12 +72,14 @@ public class BruteForceRowIdIterator extends AbstractIterator<RowIdWithScore>
7072
* @param reranker A function that takes a graph ordinal and returns the exact similarity score
7173
* @param limit The query limit
7274
* @param topK The number of vectors to resolve and score before returning results
75+
* @param columnQueryMetrics object to record metrics
7376
*/
7477
public BruteForceRowIdIterator(NodeQueue approximateScoreQueue,
7578
SegmentRowIdOrdinalPairs segmentOrdinalPairs,
7679
CloseableReranker reranker,
7780
int limit,
78-
int topK)
81+
int topK,
82+
ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics)
7983
{
8084
this.approximateScoreQueue = approximateScoreQueue;
8185
this.segmentOrdinalPairs = segmentOrdinalPairs;
@@ -84,21 +88,25 @@ public BruteForceRowIdIterator(NodeQueue approximateScoreQueue,
8488
assert topK >= limit : "topK must be greater than or equal to limit. Found: " + topK + " < " + limit;
8589
this.limit = limit;
8690
this.topK = topK;
91+
this.columnQueryMetrics = columnQueryMetrics;
8792
this.rerankedCount = topK; // placeholder to kick off computeNext
8893
}
8994

9095
@Override
9196
protected RowIdWithScore computeNext() {
9297
int consumed = rerankedCount - exactScoreQueue.size();
9398
if (consumed >= limit) {
99+
int exactComparisons = 0;
94100
// Refill the exactScoreQueue until it reaches topK exact scores, or the approximate score queue is empty
95101
while (approximateScoreQueue.size() > 0 && exactScoreQueue.size() < topK) {
96102
int segmentOrdinalIndex = approximateScoreQueue.pop();
97103
int rowId = segmentOrdinalPairs.getSegmentRowId(segmentOrdinalIndex);
98104
int ordinal = segmentOrdinalPairs.getOrdinal(segmentOrdinalIndex);
99105
float score = reranker.similarityTo(ordinal);
106+
exactComparisons++;
100107
exactScoreQueue.push(rowId, score);
101108
}
109+
columnQueryMetrics.onBruteForceNodesReranked(exactComparisons);
102110
rerankedCount = exactScoreQueue.size();
103111
}
104112
if (exactScoreQueue.size() == 0)

0 commit comments

Comments
 (0)