Add task cancellation check in aggregation code paths #18426

Merged
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -78,7 +78,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix 'system call filter not installed' caused when network.host: 0.0.0.0 ([#18309](https://github.com/opensearch-project/OpenSearch/pull/18309))
- Fix MatrixStatsAggregator reuse when mode parameter changes ([#18242](https://github.com/opensearch-project/OpenSearch/issues/18242))
- Replace the deprecated construction method of TopScoreDocCollectorManager with the new method ([#18395](https://github.com/opensearch-project/OpenSearch/pull/18395))
-- Fixed Approximate Framework regression with Lucene 10.2.1 by updating `intersectRight` BKD walk and `IntRef` visit method ([#18358](https://github.com/opensearch-project/OpenSearch/issues/18358
+- Fixed Approximate Framework regression with Lucene 10.2.1 by updating `intersectRight` BKD walk and `IntRef` visit method ([#18358](https://github.com/opensearch-project/OpenSearch/issues/18358))
+- Add task cancellation checks in aggregators ([#18426](https://github.com/opensearch-project/OpenSearch/pull/18426))

### Security

server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java
@@ -38,6 +38,7 @@
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
@@ -328,4 +329,10 @@
public String toString() {
return name;
}

protected void checkCancelled() {
if (context.isCancelled()) {
throw new TaskCancelledException("The query has been cancelled");
}
}
}
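
The helper above is the primitive every code path in this PR reuses. A minimal sketch of how a concrete aggregator is expected to call it; the subclass below is hypothetical, not part of the diff:

// Hypothetical aggregator sketch: poll checkCancelled() once per owning bucket
// ordinal, so a cancelled task stops paying for expensive bucket construction.
@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
    InternalAggregation[] results = new InternalAggregation[owningBucketOrds.length];
    for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
        checkCancelled(); // throws TaskCancelledException once context.isCancelled() is true
        results[ordIdx] = buildEmptyAggregation(); // placeholder body for the sketch
    }
    return results;
}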
@@ -235,9 +235,11 @@ protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {}
* array of ordinals
*/
protected final InternalAggregations[] buildSubAggsForBuckets(long[] bucketOrdsToCollect) throws IOException {
checkCancelled();
beforeBuildingBuckets(bucketOrdsToCollect);
InternalAggregation[][] aggregations = new InternalAggregation[subAggregators.length][];
for (int i = 0; i < subAggregators.length; i++) {
checkCancelled();
aggregations[i] = subAggregators[i].buildAggregations(bucketOrdsToCollect);
}
InternalAggregations[] result = new InternalAggregations[bucketOrdsToCollect.length];
@@ -323,6 +325,7 @@ protected final <B> InternalAggregation[] buildAggregationsForFixedBucketCount(
BucketBuilderForFixedCount<B> bucketBuilder,
Function<List<B>, InternalAggregation> resultBuilder
) throws IOException {
checkCancelled();
int totalBuckets = owningBucketOrds.length * bucketsPerOwningBucketOrd;
long[] bucketOrdsToCollect = new long[totalBuckets];
int bucketOrdIdx = 0;
@@ -373,6 +376,7 @@ protected final InternalAggregation[] buildAggregationsForSingleBucket(long[] ow
* `consumeBucketsAndMaybeBreak(owningBucketOrds.length)`
* here but we don't because single bucket aggs never have.
*/
checkCancelled();
InternalAggregations[] subAggregationResults = buildSubAggsForBuckets(owningBucketOrds);
InternalAggregation[] results = new InternalAggregation[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
@@ -403,6 +407,7 @@ protected final <B> InternalAggregation[] buildAggregationsForVariableBuckets(
BucketBuilderForVariable<B> bucketBuilder,
ResultBuilderForVariable<B> resultBuilder
) throws IOException {
checkCancelled();
long totalOrdsToCollect = 0;
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
totalOrdsToCollect += bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);
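These checks are cooperative: they only fire after something flips the task's cancelled flag, normally the tasks API on the coordinating node. A hedged sketch of triggering that from a Java client (`client` and the node/task ids are illustrative; the request builder calls are assumed from the standard tasks module, not shown in this PR):

import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.opensearch.core.tasks.TaskId;

// Hedged sketch: cancel a long-running search task; the aggregators above then
// throw TaskCancelledException at their next checkCancelled() call on the shard.
CancelTasksRequest cancel = new CancelTasksRequest();
cancel.setTaskId(new TaskId("node-1", 42)); // hypothetical node id and task id
cancel.setReason("aborting an expensive aggregation");
client.admin().cluster().cancelTasks(cancel).actionGet();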
server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java
@@ -208,6 +208,7 @@

@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
checkCancelled();
// Buckets are ordered into groups - [keyed filters] [key1&key2 intersects]
int maxOrd = owningBucketOrds.length * totalNumKeys;
int totalBucketsToBuild = 0;
@@ -260,6 +260,7 @@ protected void doPostCollection() throws IOException {

@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
checkCancelled();
// Composite aggregator must be at the top of the aggregation tree
assert owningBucketOrds.length == 1 && owningBucketOrds[0] == 0L;
if (deferredCollectors != NO_OP_COLLECTOR) {
@@ -200,6 +200,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
owningBucketOrds,
keys.length + (showOtherBucket ? 1 : 0),
(offsetInOwningOrd, docCount, subAggregationResults) -> {
checkCancelled();
if (offsetInOwningOrd < keys.length) {
return new InternalFilters.InternalBucket(keys[offsetInOwningOrd], docCount, subAggregationResults, keyed);
}
@@ -104,10 +104,12 @@ public AbstractHistogramAggregator(
@Override
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
return buildAggregationsForVariableBuckets(owningBucketOrds, bucketOrds, (bucketValue, docCount, subAggregationResults) -> {
checkCancelled();
double roundKey = Double.longBitsToDouble(bucketValue);
double key = roundKey * interval + offset;
return new InternalHistogram.Bucket(key, docCount, keyed, formatter, subAggregationResults);
}, (owningBucketOrd, buckets) -> {
checkCancelled();
// the contract of the histogram aggregation is that shards must return buckets ordered by key in ascending order
CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator());

@@ -285,6 +285,7 @@ protected final InternalAggregation[] buildAggregations(
subAggregationResults
),
(owningBucketOrd, buckets) -> {
checkCancelled();
// the contract of the histogram aggregation is that shards must return
// buckets ordered by key in ascending order
CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator());
@@ -733,6 +734,7 @@ private int increaseRoundingIfNeeded(long owningBucketOrd, int oldEstimatedBucke
private void rebucket() {
rebucketCount++;
try (LongKeyedBucketOrds oldOrds = bucketOrds) {
checkCancelled();
long[] mergeMap = new long[Math.toIntExact(oldOrds.size())];
bucketOrds = new LongKeyedBucketOrds.FromMany(context.bigArrays());
for (long owningBucketOrd = 0; owningBucketOrd <= oldOrds.maxOwningBucketOrd(); owningBucketOrd++) {
@@ -207,6 +207,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
subAggregationResults
),
(owningBucketOrd, buckets) -> {
checkCancelled();
// the contract of the histogram aggregation is that shards must return buckets ordered by key in ascending order
CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator());

@@ -585,6 +585,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I

List<InternalVariableWidthHistogram.Bucket> buckets = new ArrayList<>(numClusters);
for (int bucketOrd = 0; bucketOrd < numClusters; bucketOrd++) {
checkCancelled();
buckets.add(collector.buildBucket(bucketOrd, subAggregationResults[bucketOrd]));
}

@@ -469,6 +469,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
owningBucketOrds,
ranges.length,
(offsetInOwningOrd, docCount, subAggregationResults) -> {
checkCancelled();
Range range = ranges[offsetInOwningOrd];
return rangeFactory.createBucket(range.key, range.from, range.to, docCount, subAggregationResults, keyed, format);
},
@@ -805,6 +805,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
B[][] topBucketsPreOrd = buildTopBucketsPerOrd(owningBucketOrds.length);
long[] otherDocCount = new long[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
checkCancelled();
final int size;
if (localBucketCountThresholds.getMinDocCount() == 0) {
// if minDocCount == 0 then we can end up with more buckets then maxBucketOrd() returns
@@ -133,6 +133,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
long offset = 0;
for (int owningOrdIdx = 0; owningOrdIdx < owningBucketOrds.length; owningOrdIdx++) {
try (LongHash bucketsInThisOwningBucketToCollect = new LongHash(1, context.bigArrays())) {
checkCancelled();
filters[owningOrdIdx] = newFilter();
List<LongRareTerms.Bucket> builtBuckets = new ArrayList<>();
LongKeyedBucketOrds.BucketOrdsEnum collectedBuckets = bucketOrds.ordsEnum(owningBucketOrds[owningOrdIdx]);
@@ -249,6 +249,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
B[][] topBucketsPerOrd = buildTopBucketsPerOrd(owningBucketOrds.length);
long[] otherDocCounts = new long[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
checkCancelled();
collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
int size = (int) Math.min(bucketOrds.size(), localBucketCountThresholds.getRequiredSize());

@@ -125,6 +125,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
InternalMultiTerms.Bucket[][] topBucketsPerOrd = new InternalMultiTerms.Bucket[owningBucketOrds.length][];
long[] otherDocCounts = new long[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
checkCancelled();
collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);

@@ -259,6 +259,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
B[][] topBucketsPerOrd = buildTopBucketsPerOrd(owningBucketOrds.length);
long[] otherDocCounts = new long[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
checkCancelled();
collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);

@@ -136,6 +136,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
long offset = 0;
for (int owningOrdIdx = 0; owningOrdIdx < owningBucketOrds.length; owningOrdIdx++) {
try (BytesRefHash bucketsInThisOwningBucketToCollect = new BytesRefHash(context.bigArrays())) {
checkCancelled();
filters[owningOrdIdx] = newFilter();
List<StringRareTerms.Bucket> builtBuckets = new ArrayList<>();
BytesKeyedBucketOrds.BucketOrdsEnum collectedBuckets = bucketOrds.ordsEnum(owningBucketOrds[owningOrdIdx]);
server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java
@@ -80,6 +80,7 @@
*/
@PublicApi(since = "1.0.0")
public abstract class ValuesSource {
private Runnable cancellationCheck;

/**
* Get the current {@link BytesValues}.
@@ -101,6 +102,10 @@
*/
public abstract Function<Rounding, Rounding.Prepared> roundingPreparer(IndexReader reader) throws IOException;

protected void setCancellationCheck(Runnable cancellationCheck) {
this.cancellationCheck = cancellationCheck;
}

/**
* Check if this values source supports using global ordinals
*/
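Keeping `cancellationCheck` as a plain `Runnable` leaves `ValuesSource` decoupled from `SearchContext`; whoever constructs the source can hand it whatever guard it owns. A hedged wiring sketch (the setter is protected, so the caller is assumed to sit in the same package; the lambda body mirrors `AggregatorBase.checkCancelled()` above, and `valuesSource`/`context` are stand-in names):

// Hedged wiring sketch: give the values source the same cancellation guard the
// aggregator uses, so long doc-values iterations can poll it too.
valuesSource.setCancellationCheck(() -> {
    if (context.isCancelled()) {
        throw new TaskCancelledException("The query has been cancelled");
    }
});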
@@ -8,6 +8,7 @@

package org.opensearch.search.aggregations;

import org.opensearch.action.search.SearchShardTask;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexService;
@@ -38,6 +39,8 @@ public void setUp() throws Exception {
client().admin().indices().prepareRefresh("idx").get();
context = createSearchContext(index);
((TestSearchContext) context).setConcurrentSegmentSearchEnabled(true);
SearchShardTask task = new SearchShardTask(0, "n/a", "n/a", "test-kind", null, null);
context.setTask(task);
}

protected AggregatorFactories getAggregationFactories(String agg) throws IOException {
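setUp now attaches a SearchShardTask because the new checks consult the task behind the context; a context without a task has nothing to report as cancelled. A hedged sketch of a test exercising the path (assuming cancel(String) is accessible to the test, that the test context delegates isCancelled() to the task, and with `aggregator` standing for whichever aggregator the test built):

// Hedged test sketch: cancel the task, then expect any buildAggregations(...)
// call to trip the new checkCancelled() guard.
task.cancel("simulated cancellation");
expectThrows(TaskCancelledException.class, () -> aggregator.buildAggregations(new long[] { 0 }));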