Skip to content

Commit 9bbdd3c

Browse files
Use shard pointer tracked by writer for recovery (#17868)
Signed-off-by: Varun Bharadwaj <[email protected]>
1 parent 032f409 commit 9bbdd3c

File tree

5 files changed

+130
-49
lines changed

5 files changed

+130
-49
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
7777
### Fixed
7878
- Fix bytes parameter on `_cat/recovery` ([#17598](https://github.com/opensearch-project/OpenSearch/pull/17598))
7979
- Fix slow performance of FeatureFlag checks ([#17611](https://github.com/opensearch-project/OpenSearch/pull/17611))
80+
- Fix shard recovery in pull-based ingestion to avoid skipping messages ([#17868](https://github.com/opensearch-project/OpenSearch/pull/17868)))
8081

8182
### Security
8283

server/src/main/java/org/opensearch/indices/pollingingest/DefaultStreamPoller.java

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public class DefaultStreamPoller implements StreamPoller {
5050
private ExecutorService processorThread;
5151

5252
// start of the batch, inclusive
53-
private IngestionShardPointer batchStartPointer;
53+
private IngestionShardPointer initialBatchStartPointer;
5454
private boolean includeBatchStartPointer = false;
5555

5656
private ResetState resetState;
@@ -105,7 +105,7 @@ public DefaultStreamPoller(
105105
this.consumer = Objects.requireNonNull(consumer);
106106
this.resetState = resetState;
107107
this.resetValue = resetValue;
108-
this.batchStartPointer = startPointer;
108+
this.initialBatchStartPointer = startPointer;
109109
this.state = initialState;
110110
this.persistedPointers = persistedPointers;
111111
if (!this.persistedPointers.isEmpty()) {
@@ -170,23 +170,23 @@ protected void startPoll() {
170170
if (resetState != ResetState.NONE) {
171171
switch (resetState) {
172172
case EARLIEST:
173-
batchStartPointer = consumer.earliestPointer();
174-
logger.info("Resetting offset by seeking to earliest offset {}", batchStartPointer.asString());
173+
initialBatchStartPointer = consumer.earliestPointer();
174+
logger.info("Resetting offset by seeking to earliest offset {}", initialBatchStartPointer.asString());
175175
break;
176176
case LATEST:
177-
batchStartPointer = consumer.latestPointer();
178-
logger.info("Resetting offset by seeking to latest offset {}", batchStartPointer.asString());
177+
initialBatchStartPointer = consumer.latestPointer();
178+
logger.info("Resetting offset by seeking to latest offset {}", initialBatchStartPointer.asString());
179179
break;
180180
case REWIND_BY_OFFSET:
181-
batchStartPointer = consumer.pointerFromOffset(resetValue);
182-
logger.info("Resetting offset by seeking to offset {}", batchStartPointer.asString());
181+
initialBatchStartPointer = consumer.pointerFromOffset(resetValue);
182+
logger.info("Resetting offset by seeking to offset {}", initialBatchStartPointer.asString());
183183
break;
184184
case REWIND_BY_TIMESTAMP:
185-
batchStartPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue));
185+
initialBatchStartPointer = consumer.pointerFromTimestampMillis(Long.parseLong(resetValue));
186186
logger.info(
187187
"Resetting offset by seeking to timestamp {}, corresponding offset {}",
188188
resetValue,
189-
batchStartPointer.asString()
189+
initialBatchStartPointer.asString()
190190
);
191191
break;
192192
}
@@ -209,7 +209,8 @@ protected void startPoll() {
209209
List<IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message>> results;
210210

211211
if (includeBatchStartPointer) {
212-
results = consumer.readNext(batchStartPointer, true, MAX_POLL_SIZE, POLL_TIMEOUT);
212+
results = consumer.readNext(initialBatchStartPointer, true, MAX_POLL_SIZE, POLL_TIMEOUT);
213+
includeBatchStartPointer = false;
213214
} else {
214215
results = consumer.readNext(MAX_POLL_SIZE, POLL_TIMEOUT);
215216
}
@@ -220,38 +221,47 @@ protected void startPoll() {
220221
}
221222

222223
state = State.PROCESSING;
223-
// process the records
224-
boolean firstInBatch = true;
225-
for (IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message> result : results) {
226-
if (firstInBatch) {
227-
// update the batch start pointer to the next batch
228-
batchStartPointer = result.getPointer();
229-
firstInBatch = false;
230-
}
224+
processRecords(results);
225+
} catch (Exception e) {
226+
// Pause ingestion when an error is encountered while polling the streaming source.
227+
// Currently we do not have a good way to skip past the failing messages.
228+
// The user will have the option to manually update the offset and resume ingestion.
229+
// todo: support retry?
230+
logger.error("Pausing ingestion. Fatal error occurred in polling the shard {}: {}", consumer.getShardId(), e);
231+
pause();
232+
}
233+
}
234+
}
231235

232-
// check if the message is already processed
233-
if (isProcessed(result.getPointer())) {
234-
logger.info("Skipping message with pointer {} as it is already processed", result.getPointer().asString());
235-
continue;
236-
}
237-
totalPolledCount.inc();
238-
blockingQueue.put(result);
239-
240-
logger.debug(
241-
"Put message {} with pointer {} to the blocking queue",
242-
String.valueOf(result.getMessage().getPayload()),
243-
result.getPointer().asString()
244-
);
236+
private void processRecords(List<IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message>> results) {
237+
for (IngestionShardConsumer.ReadResult<? extends IngestionShardPointer, ? extends Message> result : results) {
238+
try {
239+
// check if the message is already processed
240+
if (isProcessed(result.getPointer())) {
241+
logger.debug("Skipping message with pointer {} as it is already processed", () -> result.getPointer().asString());
242+
continue;
245243
}
246-
// for future reads, we do not need to include the batch start pointer, and read from the last successful pointer.
247-
includeBatchStartPointer = false;
248-
} catch (Throwable e) {
249-
logger.error("Error in polling the shard {}: {}", consumer.getShardId(), e);
244+
totalPolledCount.inc();
245+
blockingQueue.put(result);
246+
247+
logger.debug(
248+
"Put message {} with pointer {} to the blocking queue",
249+
String.valueOf(result.getMessage().getPayload()),
250+
result.getPointer().asString()
251+
);
252+
} catch (Exception e) {
253+
logger.error(
254+
"Error in processing a record. Shard {}, pointer {}: {}",
255+
consumer.getShardId(),
256+
result.getPointer().asString(),
257+
e
258+
);
250259
errorStrategy.handleError(e, IngestionErrorStrategy.ErrorStage.POLLING);
251260

252261
if (!errorStrategy.shouldIgnoreError(e, IngestionErrorStrategy.ErrorStage.POLLING)) {
253262
// Blocking error encountered. Pause poller to stop processing remaining updates.
254263
pause();
264+
break;
255265
}
256266
}
257267
}
@@ -329,9 +339,16 @@ public boolean isClosed() {
329339
return closed;
330340
}
331341

342+
/**
343+
* Returns the batch start pointer from where the poller can resume in case of shard recovery. The poller and
344+
* processor are decoupled in this implementation, and hence the latest pointer tracked by the processor acts as the
345+
* recovery/start point. In case the processor has not started tracking, then the initial batchStartPointer used by
346+
* the poller acts as the start point.
347+
*/
332348
@Override
333349
public IngestionShardPointer getBatchStartPointer() {
334-
return batchStartPointer;
350+
IngestionShardPointer currentShardPointer = processorRunnable.getCurrentShardPointer();
351+
return currentShardPointer == null ? initialBatchStartPointer : currentShardPointer;
335352
}
336353

337354
@Override

server/src/main/java/org/opensearch/indices/pollingingest/MessageProcessorRunnable.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.document.StoredField;
1414
import org.apache.lucene.index.Term;
1515
import org.opensearch.action.DocWriteRequest;
16+
import org.opensearch.common.Nullable;
1617
import org.opensearch.common.lucene.uid.Versions;
1718
import org.opensearch.common.metrics.CounterMetric;
1819
import org.opensearch.common.util.RequestUtils;
@@ -59,6 +60,10 @@ public class MessageProcessorRunnable implements Runnable {
5960
private final MessageProcessor messageProcessor;
6061
private final CounterMetric stats = new CounterMetric();
6162

63+
// tracks the most recent pointer that is being processed
64+
@Nullable
65+
private volatile IngestionShardPointer currentShardPointer;
66+
6267
/**
6368
* Constructor.
6469
*
@@ -274,6 +279,7 @@ public void run() {
274279
if (readResult != null) {
275280
try {
276281
stats.inc();
282+
currentShardPointer = readResult.getPointer();
277283
messageProcessor.process(readResult.getMessage(), readResult.getPointer());
278284
readResult = null;
279285
} catch (Exception e) {
@@ -308,4 +314,9 @@ public IngestionErrorStrategy getErrorStrategy() {
308314
public void setErrorStrategy(IngestionErrorStrategy errorStrategy) {
309315
this.errorStrategy = errorStrategy;
310316
}
317+
318+
@Nullable
319+
public IngestionShardPointer getCurrentShardPointer() {
320+
return currentShardPointer;
321+
}
311322
}

server/src/test/java/org/opensearch/index/engine/IngestionEngineTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public void testCreateEngine() throws IOException {
102102
// verify the commit data
103103
Assert.assertEquals(7, commitData.size());
104104
// the commiit data is the start of the current batch
105-
Assert.assertEquals("0", commitData.get(StreamPoller.BATCH_START));
105+
Assert.assertEquals("1", commitData.get(StreamPoller.BATCH_START));
106106

107107
// verify the stored offsets
108108
var offset = new FakeIngestionSource.FakeIngestionShardPointer(0);

server/src/test/java/org/opensearch/indices/pollingingest/DefaultStreamPollerTests.java

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,14 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte
267267
);
268268
IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class);
269269
when(mockConsumer.getShardId()).thenReturn(0);
270-
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed"))
271-
.thenReturn(readResultsBatch1)
272-
.thenThrow(new RuntimeException("message3 poll failed"))
273-
.thenReturn(readResultsBatch2)
274-
.thenReturn(Collections.emptyList());
270+
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1);
271+
when(mockConsumer.readNext(anyLong(), anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList());
275272

276273
IngestionErrorStrategy errorStrategy = spy(new DropIngestionErrorStrategy("ingestion_source"));
274+
ArrayBlockingQueue mockQueue = mock(ArrayBlockingQueue.class);
275+
doThrow(new RuntimeException()).doNothing().when(mockQueue).put(any());
276+
processorRunnable = new MessageProcessorRunnable(mockQueue, processor, errorStrategy);
277+
277278
poller = new DefaultStreamPoller(
278279
new FakeIngestionSource.FakeIngestionShardPointer(0),
279280
persistedPointers,
@@ -288,7 +289,7 @@ public void testDropErrorIngestionStrategy() throws TimeoutException, Interrupte
288289
Thread.sleep(sleepTime);
289290

290291
verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING));
291-
verify(processor, times(2)).process(any(), any());
292+
verify(mockQueue, times(4)).put(any());
292293
}
293294

294295
public void testBlockErrorIngestionStrategy() throws TimeoutException, InterruptedException {
@@ -314,12 +315,14 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt
314315
);
315316
IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class);
316317
when(mockConsumer.getShardId()).thenReturn(0);
317-
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenThrow(new RuntimeException("message1 poll failed"))
318-
.thenReturn(readResultsBatch1)
319-
.thenReturn(readResultsBatch2)
320-
.thenReturn(Collections.emptyList());
318+
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1);
319+
when(mockConsumer.readNext(anyLong(), anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList());
321320

322321
IngestionErrorStrategy errorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source"));
322+
ArrayBlockingQueue mockQueue = mock(ArrayBlockingQueue.class);
323+
doThrow(new RuntimeException()).doNothing().when(mockQueue).put(any());
324+
processorRunnable = new MessageProcessorRunnable(mockQueue, processor, errorStrategy);
325+
323326
poller = new DefaultStreamPoller(
324327
new FakeIngestionSource.FakeIngestionShardPointer(0),
325328
persistedPointers,
@@ -334,7 +337,6 @@ public void testBlockErrorIngestionStrategy() throws TimeoutException, Interrupt
334337
Thread.sleep(sleepTime);
335338

336339
verify(errorStrategy, times(1)).handleError(any(), eq(IngestionErrorStrategy.ErrorStage.POLLING));
337-
verify(processor, never()).process(any(), any());
338340
assertEquals(DefaultStreamPoller.State.PAUSED, poller.getState());
339341
assertTrue(poller.isPaused());
340342
}
@@ -374,4 +376,54 @@ public void testUpdateErrorStrategy() {
374376
assertTrue(poller.getErrorStrategy() instanceof BlockIngestionErrorStrategy);
375377
assertTrue(processorRunnable.getErrorStrategy() instanceof BlockIngestionErrorStrategy);
376378
}
379+
380+
public void testPersistedBatchStartPointer() throws TimeoutException, InterruptedException {
381+
messages.add("{\"_id\":\"3\",\"_source\":{\"name\":\"bob\", \"age\": 24}}".getBytes(StandardCharsets.UTF_8));
382+
messages.add("{\"_id\":\"4\",\"_source\":{\"name\":\"alice\", \"age\": 21}}".getBytes(StandardCharsets.UTF_8));
383+
List<
384+
IngestionShardConsumer.ReadResult<
385+
FakeIngestionSource.FakeIngestionShardPointer,
386+
FakeIngestionSource.FakeIngestionMessage>> readResultsBatch1 = fakeConsumer.readNext(
387+
fakeConsumer.earliestPointer(),
388+
true,
389+
2,
390+
100
391+
);
392+
List<
393+
IngestionShardConsumer.ReadResult<
394+
FakeIngestionSource.FakeIngestionShardPointer,
395+
FakeIngestionSource.FakeIngestionMessage>> readResultsBatch2 = fakeConsumer.readNext(
396+
new FakeIngestionSource.FakeIngestionShardPointer(2),
397+
true,
398+
2,
399+
100
400+
);
401+
402+
// This test publishes 4 messages, so use blocking queue of size 3. This ensures the poller is blocked when adding the 4th message
403+
// for validation.
404+
IngestionErrorStrategy errorStrategy = spy(new BlockIngestionErrorStrategy("ingestion_source"));
405+
doThrow(new RuntimeException()).when(processor).process(any(), any());
406+
processorRunnable = new MessageProcessorRunnable(new ArrayBlockingQueue<>(3), processor, errorStrategy);
407+
408+
IngestionShardConsumer mockConsumer = mock(IngestionShardConsumer.class);
409+
when(mockConsumer.getShardId()).thenReturn(0);
410+
when(mockConsumer.readNext(any(), anyBoolean(), anyLong(), anyInt())).thenReturn(readResultsBatch1);
411+
412+
when(mockConsumer.readNext(anyLong(), anyInt())).thenReturn(readResultsBatch2).thenReturn(Collections.emptyList());
413+
414+
poller = new DefaultStreamPoller(
415+
new FakeIngestionSource.FakeIngestionShardPointer(0),
416+
persistedPointers,
417+
mockConsumer,
418+
processorRunnable,
419+
StreamPoller.ResetState.NONE,
420+
"",
421+
errorStrategy,
422+
StreamPoller.State.NONE
423+
);
424+
poller.start();
425+
Thread.sleep(sleepTime);
426+
427+
assertEquals(new FakeIngestionSource.FakeIngestionShardPointer(0), poller.getBatchStartPointer());
428+
}
377429
}

0 commit comments

Comments
 (0)