Skip to content

Commit f92f9c1

Browse files
committed
Fix handling of timeout in SseEmitter
Closes gh-34426
1 parent 2b38c00 commit f92f9c1

File tree

1 file changed

+70
-22
lines changed

1 file changed

+70
-22
lines changed

Diff for: spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java

+70-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
2121
import java.util.LinkedHashSet;
2222
import java.util.List;
2323
import java.util.Set;
24-
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.concurrent.atomic.AtomicReference;
2525
import java.util.function.Consumer;
2626

2727
import org.springframework.http.MediaType;
@@ -73,21 +73,20 @@ public class ResponseBodyEmitter {
7373
@Nullable
7474
private Handler handler;
7575

76+
private final AtomicReference<State> state = new AtomicReference<>(State.START);
77+
7678
/** Store send data before handler is initialized. */
7779
private final Set<DataWithMediaType> earlySendAttempts = new LinkedHashSet<>(8);
7880

79-
/** Store successful completion before the handler is initialized. */
80-
private final AtomicBoolean complete = new AtomicBoolean();
81-
8281
/** Store an error before the handler is initialized. */
8382
@Nullable
8483
private Throwable failure;
8584

86-
private final DefaultCallback timeoutCallback = new DefaultCallback();
85+
private final TimeoutCallback timeoutCallback = new TimeoutCallback();
8786

8887
private final ErrorCallback errorCallback = new ErrorCallback();
8988

90-
private final DefaultCallback completionCallback = new DefaultCallback();
89+
private final CompletionCallback completionCallback = new CompletionCallback();
9190

9291

9392
/**
@@ -128,7 +127,7 @@ synchronized void initialize(Handler handler) throws IOException {
128127
this.earlySendAttempts.clear();
129128
}
130129

131-
if (this.complete.get()) {
130+
if (this.state.get() == State.COMPLETE) {
132131
if (this.failure != null) {
133132
this.handler.completeWithError(this.failure);
134133
}
@@ -144,7 +143,7 @@ synchronized void initialize(Handler handler) throws IOException {
144143
}
145144

146145
void initializeWithError(Throwable ex) {
147-
if (this.complete.compareAndSet(false, true)) {
146+
if (this.state.compareAndSet(State.START, State.COMPLETE)) {
148147
this.failure = ex;
149148
this.earlySendAttempts.clear();
150149
this.errorCallback.accept(ex);
@@ -186,8 +185,7 @@ public void send(Object object) throws IOException {
186185
* @throws java.lang.IllegalStateException wraps any other errors
187186
*/
188187
public synchronized void send(Object object, @Nullable MediaType mediaType) throws IOException {
189-
Assert.state(!this.complete.get(), () -> "ResponseBodyEmitter has already completed" +
190-
(this.failure != null ? " with error: " + this.failure : ""));
188+
assertNotComplete();
191189
if (this.handler != null) {
192190
try {
193191
this.handler.send(object, mediaType);
@@ -214,11 +212,15 @@ public synchronized void send(Object object, @Nullable MediaType mediaType) thro
214212
* @since 6.0.12
215213
*/
216214
public synchronized void send(Set<DataWithMediaType> items) throws IOException {
217-
Assert.state(!this.complete.get(), () -> "ResponseBodyEmitter has already completed" +
218-
(this.failure != null ? " with error: " + this.failure : ""));
215+
assertNotComplete();
219216
sendInternal(items);
220217
}
221218

219+
private void assertNotComplete() {
220+
Assert.state(this.state.get() == State.START, () -> "ResponseBodyEmitter has already completed" +
221+
(this.failure != null ? " with error: " + this.failure : ""));
222+
}
223+
222224
private void sendInternal(Set<DataWithMediaType> items) throws IOException {
223225
if (items.isEmpty()) {
224226
return;
@@ -248,7 +250,7 @@ private void sendInternal(Set<DataWithMediaType> items) throws IOException {
248250
* related events such as an error while {@link #send(Object) sending}.
249251
*/
250252
public void complete() {
251-
if (this.complete.compareAndSet(false, true) && this.handler != null) {
253+
if (trySetComplete() && this.handler != null) {
252254
this.handler.complete();
253255
}
254256
}
@@ -265,14 +267,19 @@ public void complete() {
265267
* {@link #send(Object) sending}.
266268
*/
267269
public void completeWithError(Throwable ex) {
268-
if (this.complete.compareAndSet(false, true)) {
270+
if (trySetComplete()) {
269271
this.failure = ex;
270272
if (this.handler != null) {
271273
this.handler.completeWithError(ex);
272274
}
273275
}
274276
}
275277

278+
private boolean trySetComplete() {
279+
return (this.state.compareAndSet(State.START, State.COMPLETE) ||
280+
(this.state.compareAndSet(State.TIMEOUT, State.COMPLETE)));
281+
}
282+
276283
/**
277284
* Register code to invoke when the async request times out. This method is
278285
* called from a container thread when an async request times out.
@@ -369,7 +376,7 @@ public MediaType getMediaType() {
369376
}
370377

371378

372-
private class DefaultCallback implements Runnable {
379+
private class TimeoutCallback implements Runnable {
373380

374381
private final List<Runnable> delegates = new ArrayList<>(1);
375382

@@ -379,9 +386,10 @@ public synchronized void addDelegate(Runnable delegate) {
379386

380387
@Override
381388
public void run() {
382-
ResponseBodyEmitter.this.complete.compareAndSet(false, true);
383-
for (Runnable delegate : this.delegates) {
384-
delegate.run();
389+
if (ResponseBodyEmitter.this.state.compareAndSet(State.START, State.TIMEOUT)) {
390+
for (Runnable delegate : this.delegates) {
391+
delegate.run();
392+
}
385393
}
386394
}
387395
}
@@ -397,11 +405,51 @@ public synchronized void addDelegate(Consumer<Throwable> callback) {
397405

398406
@Override
399407
public void accept(Throwable t) {
400-
ResponseBodyEmitter.this.complete.compareAndSet(false, true);
401-
for(Consumer<Throwable> delegate : this.delegates) {
402-
delegate.accept(t);
408+
if (ResponseBodyEmitter.this.state.compareAndSet(State.START, State.COMPLETE)) {
409+
for (Consumer<Throwable> delegate : this.delegates) {
410+
delegate.accept(t);
411+
}
412+
}
413+
}
414+
}
415+
416+
417+
private class CompletionCallback implements Runnable {
418+
419+
private final List<Runnable> delegates = new ArrayList<>(1);
420+
421+
public synchronized void addDelegate(Runnable delegate) {
422+
this.delegates.add(delegate);
423+
}
424+
425+
@Override
426+
public void run() {
427+
if (ResponseBodyEmitter.this.state.compareAndSet(State.START, State.COMPLETE)) {
428+
for (Runnable delegate : this.delegates) {
429+
delegate.run();
430+
}
403431
}
404432
}
405433
}
406434

435+
436+
/**
437+
* Represents a state for {@link ResponseBodyEmitter}.
438+
* <p><pre>
439+
* START ----+
440+
* | |
441+
* v |
442+
* TIMEOUT |
443+
* | |
444+
* v |
445+
* COMPLETE <--+
446+
* </pre>
447+
* @since 6.2.4
448+
*/
449+
private enum State {
450+
START,
451+
TIMEOUT, // handling a timeout
452+
COMPLETE
453+
}
454+
407455
}

0 commit comments

Comments
 (0)