Skip to content

Commit 8ed43a6

Browse files
CNDB-13583: Add vector ann and brute force metrics
1 parent 18fe538 commit 8ed43a6

16 files changed

+402
-38
lines changed

src/java/org/apache/cassandra/index/sai/IndexContext.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ public IndexContext(@Nonnull String keyspace,
193193
this.hasEuclideanSimilarityFunc = vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
194194

195195
this.indexMetrics = new IndexMetrics(this);
196-
this.columnQueryMetrics = isLiteral() ? new ColumnQueryMetrics.TrieIndexMetrics(keyspace, table, getIndexName())
196+
this.columnQueryMetrics = isVector() ? new ColumnQueryMetrics.VectorIndexMetrics(keyspace, table, getIndexName()) :
197+
isLiteral() ? new ColumnQueryMetrics.TrieIndexMetrics(keyspace, table, getIndexName())
197198
: new ColumnQueryMetrics.BKDIndexMetrics(keyspace, table, getIndexName());
198199

199200
}

src/java/org/apache/cassandra/index/sai/QueryContext.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ public class QueryContext
6060

6161
private final LongAdder queryTimeouts = new LongAdder();
6262

63-
private final LongAdder annNodesVisited = new LongAdder();
63+
private final LongAdder annGraphSearchLatency = new LongAdder();
64+
6465
private float annRerankFloor = 0.0f; // only called from single-threaded setup code
6566

6667
private final LongAdder shadowedPrimaryKeyCount = new LongAdder();
@@ -138,9 +139,10 @@ public void addQueryTimeouts(long val)
138139
{
139140
queryTimeouts.add(val);
140141
}
141-
public void addAnnNodesVisited(long val)
142+
143+
public void addAnnGraphSearchLatency(long val)
142144
{
143-
annNodesVisited.add(val);
145+
annGraphSearchLatency.add(val);
144146
}
145147

146148
public void setFilterSortOrder(FilterSortOrder filterSortOrder)
@@ -201,9 +203,9 @@ public long queryTimeouts()
201203
{
202204
return queryTimeouts.longValue();
203205
}
204-
public long annNodesVisited()
206+
public long annGraphSearchLatency()
205207
{
206-
return annNodesVisited.longValue();
208+
return annGraphSearchLatency.longValue();
207209
}
208210

209211
public FilterSortOrder filterSortOrder()

src/java/org/apache/cassandra/index/sai/disk/v2/V2OnDiskOrdinalsMap.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ public Structure getStructure()
9898
return canFastMapOrdinalsView ? Structure.ONE_TO_ONE : Structure.ZERO_OR_ONE_TO_MANY;
9999
}
100100

101+
@Override
102+
public long cachedBytesUsed()
103+
{
104+
return 0;
105+
}
106+
101107
@Override
102108
public RowIdsView getRowIdsView()
103109
{

src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.apache.cassandra.index.sai.disk.vector.VectorCompression;
5959
import org.apache.cassandra.index.sai.disk.vector.VectorMemtableIndex;
6060
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
61+
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
6162
import org.apache.cassandra.index.sai.plan.Expression;
6263
import org.apache.cassandra.index.sai.plan.Orderer;
6364
import org.apache.cassandra.index.sai.plan.Plan.CostCoefficients;
@@ -103,6 +104,7 @@ public class V2VectorIndexSearcher extends IndexSearcher
103104
private final PrimaryKey.Factory keyFactory;
104105
private final PairedSlidingWindowReservoir expectedActualNodesVisited = new PairedSlidingWindowReservoir(20);
105106
private final ThreadLocal<SparseBits> cachedBits;
107+
private final ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics;
106108

107109
protected V2VectorIndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
108110
PerIndexFiles perIndexFiles,
@@ -113,7 +115,8 @@ protected V2VectorIndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory,
113115
super(primaryKeyMapFactory, perIndexFiles, segmentMetadata, indexContext);
114116
this.graph = graph;
115117
this.keyFactory = PrimaryKey.factory(indexContext.comparator(), indexContext.indexFeatureSet());
116-
cachedBits = ThreadLocal.withInitial(SparseBits::new);
118+
this.cachedBits = ThreadLocal.withInitial(SparseBits::new);
119+
this.columnQueryMetrics = (ColumnQueryMetrics.VectorIndexMetrics) indexContext.getColumnQueryMetrics();
117120
}
118121

119122
@Override
@@ -194,10 +197,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
194197
if (RangeUtil.coversFullRing(keyRange))
195198
{
196199
var estimate = estimateCost(rerankK, graph.size());
197-
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, visited -> {
198-
estimate.updateStatistics(visited);
199-
context.addAnnNodesVisited(visited);
200-
});
200+
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, estimate::updateStatistics);
201201
}
202202

203203
PrimaryKey firstPrimaryKey = keyFactory.createTokenOnly(keyRange.left.getToken());
@@ -214,7 +214,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
214214

215215
// if the range covers the entire segment, skip directly to an index search
216216
if (minSSTableRowId <= metadata.minSSTableRowId && maxSSTableRowId >= metadata.maxSSTableRowId)
217-
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, context::addAnnNodesVisited);
217+
return graph.search(queryVector, limit, rerankK, threshold, Bits.ALL, context, visited -> {});
218218

219219
minSSTableRowId = Math.max(minSSTableRowId, metadata.minSSTableRowId);
220220
maxSSTableRowId = min(maxSSTableRowId, metadata.maxSSTableRowId);
@@ -263,10 +263,7 @@ private CloseableIterator<RowIdWithScore> searchInternal(AbstractBounds<Partitio
263263
// the trouble to add it.
264264
var betterCostEstimate = estimateCost(rerankK, cardinality);
265265

266-
return graph.search(queryVector, limit, rerankK, threshold, bits, context, visited -> {
267-
betterCostEstimate.updateStatistics(visited);
268-
context.addAnnNodesVisited(visited);
269-
});
266+
return graph.search(queryVector, limit, rerankK, threshold, bits, context, betterCostEstimate::updateStatistics);
270267
}
271268
}
272269

@@ -305,8 +302,9 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(CompressedVectors cv
305302
segmentOrdinalPairs.forEachIndexOrdinalPair((i, ordinal) -> {
306303
approximateScores.push(i, scoreFunction.similarityTo(ordinal));
307304
});
305+
columnQueryMetrics.onBruteForceNodesVisited(segmentOrdinalPairs.size());
308306
var reranker = new CloseableReranker(similarityFunction, queryVector, graph.getView());
309-
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK);
307+
return new BruteForceRowIdIterator(approximateScores, segmentOrdinalPairs, reranker, limit, rerankK, columnQueryMetrics);
310308
}
311309

312310
/**
@@ -325,6 +323,7 @@ private CloseableIterator<RowIdWithScore> orderByBruteForce(VectorFloat<?> query
325323
segmentOrdinalPairs.forEachSegmentRowIdOrdinalPair((segmentRowId, ordinal) -> {
326324
scoredRowIds.push(segmentRowId, esf.similarityTo(ordinal));
327325
});
326+
columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size());
328327
return new NodeQueueRowIdIterator(scoredRowIds);
329328
}
330329
}
@@ -348,6 +347,7 @@ private CloseableIterator<RowIdWithScore> filterByBruteForce(VectorFloat<?> quer
348347
if (score >= threshold)
349348
results.add(new RowIdWithScore(segmentRowId, score));
350349
});
350+
columnQueryMetrics.onBruteForceNodesReranked(segmentOrdinalPairs.size());
351351
}
352352
return CloseableIterator.wrap(results.iterator());
353353
}

src/java/org/apache/cassandra/index/sai/disk/v5/V5OnDiskOrdinalsMap.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,4 +354,26 @@ public void close() {
354354
// no-op
355355
}
356356
}
357+
358+
@Override
359+
public long cachedBytesUsed()
360+
{
361+
if (structure != V5VectorPostingsWriter.Structure.ONE_TO_MANY) {
362+
return 0;
363+
}
364+
365+
long bytes = 0;
366+
if (extraRowIds != null) {
367+
bytes += extraRowIds.length * 4L;
368+
}
369+
if (extraOrdinals != null) {
370+
bytes += extraOrdinals.length * 4L;
371+
}
372+
if (extraRowsByOrdinal != null) {
373+
for (int[] rowIds : extraRowsByOrdinal.values()) {
374+
bytes += rowIds.length * 4L;
375+
}
376+
}
377+
return bytes;
378+
}
357379
}

src/java/org/apache/cassandra/index/sai/disk/vector/AutoResumingNodeScoreIterator.java

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
import io.github.jbellis.jvector.graph.GraphSearcher;
2626
import io.github.jbellis.jvector.graph.SearchResult;
27+
import org.apache.cassandra.index.sai.QueryContext;
28+
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
2729
import org.apache.cassandra.tracing.Tracing;
2830
import org.apache.cassandra.utils.AbstractIterator;
2931

@@ -41,6 +43,8 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
4143
private final int rerankK;
4244
private final boolean inMemory;
4345
private final String source;
46+
private final QueryContext context;
47+
private final ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics;
4448
private final IntConsumer nodesVisitedConsumer;
4549
private Iterator<SearchResult.NodeScore> nodeScores;
4650
private int cumulativeNodesVisited;
@@ -51,6 +55,8 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
5155
* no more results.
5256
* @param searcher the {@link GraphSearcher} to use to resume search.
5357
* @param result the first {@link SearchResult} to iterate over
58+
* @param context the {@link QueryContext} to which the latency of each resumed graph search is added
59+
* @param columnQueryMetrics the {@link ColumnQueryMetrics.VectorIndexMetrics} notified of each resumed search result and its elapsed time
5460
* @param nodesVisitedConsumer a consumer that accepts the total number of nodes visited
5561
* @param limit the limit to pass to the {@link GraphSearcher} when resuming search
5662
* @param rerankK the rerankK to pass to the {@link GraphSearcher} when resuming search
@@ -60,6 +66,8 @@ public class AutoResumingNodeScoreIterator extends AbstractIterator<SearchResult
6066
public AutoResumingNodeScoreIterator(GraphSearcher searcher,
6167
GraphSearcherAccessManager accessManager,
6268
SearchResult result,
69+
QueryContext context,
70+
ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics,
6371
IntConsumer nodesVisitedConsumer,
6472
int limit,
6573
int rerankK,
@@ -69,7 +77,9 @@ public AutoResumingNodeScoreIterator(GraphSearcher searcher,
6977
this.searcher = searcher;
7078
this.accessManager = accessManager;
7179
this.nodeScores = Arrays.stream(result.getNodes()).iterator();
72-
this.cumulativeNodesVisited = result.getVisitedCount();
80+
this.context = context;
81+
this.columnQueryMetrics = columnQueryMetrics;
82+
this.cumulativeNodesVisited = 0;
7383
this.nodesVisitedConsumer = nodesVisitedConsumer;
7484
this.limit = max(1, limit / 2); // we shouldn't need as many results on resume
7585
this.rerankK = rerankK;
@@ -83,21 +93,29 @@ protected SearchResult.NodeScore computeNext()
8393
if (nodeScores.hasNext())
8494
return nodeScores.next();
8595

96+
long start = System.nanoTime();
97+
98+
// Search deeper into the graph
8699
var nextResult = searcher.resume(limit, rerankK);
87-
maybeLogTrace(nextResult);
100+
101+
// Record metrics
102+
long elapsed = System.nanoTime() - start;
103+
columnQueryMetrics.onSearchResult(nextResult, elapsed, true);
104+
context.addAnnGraphSearchLatency(elapsed);
88105
cumulativeNodesVisited += nextResult.getVisitedCount();
106+
107+
if (Tracing.isTracing())
108+
{
109+
String msg = inMemory ? "Memory based ANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}"
110+
: "Disk based ANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}";
111+
Tracing.trace(msg, limit, rerankK, nextResult.getVisitedCount(), nextResult.getRerankedCount(), nextResult.getNodes().length, source);
112+
}
113+
89114
// If the next result is empty, we are done searching.
90115
nodeScores = Arrays.stream(nextResult.getNodes()).iterator();
91116
return nodeScores.hasNext() ? nodeScores.next() : endOfData();
92117
}
93118

94-
private void maybeLogTrace(SearchResult result)
95-
{
96-
String msg = inMemory ? "ANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}"
97-
: "DiskANN resume for {}/{} visited {} nodes, reranked {} to return {} results from {}";
98-
Tracing.trace(msg, limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source);
99-
}
100-
101119
@Override
102120
public void close()
103121
{

src/java/org/apache/cassandra/index/sai/disk/vector/BruteForceRowIdIterator.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import io.github.jbellis.jvector.graph.NodeQueue;
2222
import io.github.jbellis.jvector.util.BoundedLongHeap;
23+
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
2324
import org.apache.cassandra.index.sai.utils.SegmentRowIdOrdinalPairs;
2425
import org.apache.cassandra.index.sai.utils.RowIdWithMeta;
2526
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
@@ -62,6 +63,7 @@ public class BruteForceRowIdIterator extends AbstractIterator<RowIdWithScore>
6263
private final CloseableReranker reranker;
6364
private final int topK;
6465
private final int limit;
66+
private final ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics;
6567
private int rerankedCount;
6668

6769
/**
@@ -70,12 +72,14 @@ public class BruteForceRowIdIterator extends AbstractIterator<RowIdWithScore>
7072
* @param reranker A function that takes a graph ordinal and returns the exact similarity score
7173
* @param limit The query limit
7274
* @param topK The number of vectors to resolve and score before returning results
75+
* @param columnQueryMetrics metrics object that is told how many exact-score comparisons were performed each time the exact score queue is refilled
7376
*/
7477
public BruteForceRowIdIterator(NodeQueue approximateScoreQueue,
7578
SegmentRowIdOrdinalPairs segmentOrdinalPairs,
7679
CloseableReranker reranker,
7780
int limit,
78-
int topK)
81+
int topK,
82+
ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics)
7983
{
8084
this.approximateScoreQueue = approximateScoreQueue;
8185
this.segmentOrdinalPairs = segmentOrdinalPairs;
@@ -84,21 +88,25 @@ public BruteForceRowIdIterator(NodeQueue approximateScoreQueue,
8488
assert topK >= limit : "topK must be greater than or equal to limit. Found: " + topK + " < " + limit;
8589
this.limit = limit;
8690
this.topK = topK;
91+
this.columnQueryMetrics = columnQueryMetrics;
8792
this.rerankedCount = topK; // placeholder to kick off computeNext
8893
}
8994

9095
@Override
9196
protected RowIdWithScore computeNext() {
9297
int consumed = rerankedCount - exactScoreQueue.size();
9398
if (consumed >= limit) {
99+
int exactComparisons = 0;
94100
// Refill the exactScoreQueue until it reaches topK exact scores, or the approximate score queue is empty
95101
while (approximateScoreQueue.size() > 0 && exactScoreQueue.size() < topK) {
96102
int segmentOrdinalIndex = approximateScoreQueue.pop();
97103
int rowId = segmentOrdinalPairs.getSegmentRowId(segmentOrdinalIndex);
98104
int ordinal = segmentOrdinalPairs.getOrdinal(segmentOrdinalIndex);
99105
float score = reranker.similarityTo(ordinal);
106+
exactComparisons++;
100107
exactScoreQueue.push(rowId, score);
101108
}
109+
columnQueryMetrics.onBruteForceNodesReranked(exactComparisons);
102110
rerankedCount = exactScoreQueue.size();
103111
}
104112
if (exactScoreQueue.size() == 0)

src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat;
5050
import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter.Structure;
5151
import org.apache.cassandra.index.sai.disk.vector.CassandraOnHeapGraph.PQVersion;
52+
import org.apache.cassandra.index.sai.metrics.ColumnQueryMetrics;
5253
import org.apache.cassandra.index.sai.utils.RowIdWithScore;
5354
import org.apache.cassandra.io.sstable.SSTableId;
5455
import org.apache.cassandra.io.util.FileHandle;
@@ -63,6 +64,7 @@ public class CassandraDiskAnn
6364

6465
public static final int PQ_MAGIC = 0xB011A61C; // PQ_MAGIC, with a lot of liberties taken
6566
protected final PerIndexFiles indexFiles;
67+
private final ColumnQueryMetrics.VectorIndexMetrics columnQueryMetrics;
6668
protected final SegmentMetadata.ComponentMetadataMap componentMetadatas;
6769

6870
private final SSTableId<?> source;
@@ -85,6 +87,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata.Component
8587
this.source = sstableContext.sstable().getId();
8688
this.componentMetadatas = componentMetadatas;
8789
this.indexFiles = indexFiles;
90+
this.columnQueryMetrics = (ColumnQueryMetrics.VectorIndexMetrics) context.getColumnQueryMetrics();
8891

8992
similarityFunction = context.getIndexWriterConfig().getSimilarityFunction();
9093

@@ -152,8 +155,16 @@ else if (compressionType == VectorCompression.CompressionType.BINARY_QUANTIZATIO
152155

153156
SegmentMetadata.ComponentMetadata postingListsMetadata = this.componentMetadatas.get(IndexComponentType.POSTING_LISTS);
154157
ordinalsMap = omFactory.create(indexFiles.postingLists(), postingListsMetadata.offset, postingListsMetadata.length);
158+
if (ordinalsMap.getStructure() == Structure.ZERO_OR_ONE_TO_MANY)
159+
logger.warn("Index {} has structure ZERO_OR_ONE_TO_MANY, which requires on reading the on disk row id" +
160+
" to ordinal mapping for each search. This will be slower.", source);
155161

156162
searchers = ExplicitThreadLocal.withInitial(() -> new GraphSearcherAccessManager(new GraphSearcher(graph)));
163+
164+
// Record metrics for this graph
165+
columnQueryMetrics.onGraphLoaded(compressedVectors == null ? 0 : compressedVectors.ramBytesUsed(),
166+
ordinalsMap.cachedBytesUsed(),
167+
graph.size(0));
157168
}
158169

159170
public Structure getPostingsStructure()
@@ -231,11 +242,15 @@ else if (compressedVectors == null)
231242
var rr = view.rerankerFor(queryVector, similarityFunction);
232243
ssp = new SearchScoreProvider(asf, rr);
233244
}
245+
long start = System.nanoTime();
234246
var result = searcher.search(ssp, limit, rerankK, threshold, context.getAnnRerankFloor(), ordinalsMap.ignoringDeleted(acceptBits));
247+
long elapsed = System.nanoTime() - start;
235248
if (V3OnDiskFormat.ENABLE_RERANK_FLOOR)
236249
context.updateAnnRerankFloor(result.getWorstApproximateScoreInTopK());
237250
Tracing.trace("DiskANN search for {}/{} visited {} nodes, reranked {} to return {} results from {}",
238251
limit, rerankK, result.getVisitedCount(), result.getRerankedCount(), result.getNodes().length, source);
252+
columnQueryMetrics.onSearchResult(result, elapsed, false);
253+
context.addAnnGraphSearchLatency(elapsed);
239254
if (threshold > 0)
240255
{
241256
// Threshold based searches are comprehensive and do not need to resume the search.
@@ -246,7 +261,7 @@ else if (compressedVectors == null)
246261
}
247262
else
248263
{
249-
var nodeScores = new AutoResumingNodeScoreIterator(searcher, graphAccessManager, result, nodesVisitedConsumer, limit, rerankK, false, source.toString());
264+
var nodeScores = new AutoResumingNodeScoreIterator(searcher, graphAccessManager, result, context, columnQueryMetrics, nodesVisitedConsumer, limit, rerankK, false, source.toString());
250265
return new NodeScoreToRowIdWithScoreIterator(nodeScores, ordinalsMap.getRowIdsView());
251266
}
252267
}
@@ -271,6 +286,9 @@ public CompressedVectors getCompressedVectors()
271286
public void close() throws IOException
272287
{
273288
FileUtils.close(ordinalsMap, searchers, graph, graphHandle);
289+
columnQueryMetrics.onGraphClosed(compressedVectors == null ? 0 : compressedVectors.ramBytesUsed(),
290+
ordinalsMap.cachedBytesUsed(),
291+
graph.size(0));
274292
}
275293

276294
public OrdinalsView getOrdinalsView()

0 commit comments

Comments
 (0)