Skip to content

Commit a730a19

Browse files
olegnntaiki-e
authored andcommitted
FlattenUnordered: always replace inner wakers (#2726)
1 parent 890f893 commit a730a19

File tree

3 files changed

+55
-12
lines changed

3 files changed

+55
-12
lines changed

futures-util/src/stream/stream/flatten_unordered.rs

+11-12
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,8 @@ impl WrappedWaker {
209209
///
210210
/// This function will modify waker's `inner_waker` via `UnsafeCell`, so
211211
/// it should be used only during `POLLING` phase by one thread at the time.
212-
unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) -> Waker {
212+
unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) {
213213
*self_arc.inner_waker.get() = cx.waker().clone().into();
214-
waker(self_arc.clone())
215214
}
216215

217216
/// Attempts to start the waking process for the waker with the given value.
@@ -414,6 +413,12 @@ where
414413
}
415414
};
416415

416+
// Safety: now state is `POLLING`.
417+
unsafe {
418+
WrappedWaker::replace_waker(this.stream_waker, cx);
419+
WrappedWaker::replace_waker(this.inner_streams_waker, cx)
420+
};
421+
417422
if poll_state_value & NEED_TO_POLL_STREAM != NONE {
418423
let mut stream_waker = None;
419424

@@ -431,13 +436,9 @@ where
431436

432437
break;
433438
} else {
434-
// Initialize base stream waker if it's not yet initialized
435-
if stream_waker.is_none() {
436-
// Safety: now state is `POLLING`.
437-
stream_waker
438-
.replace(unsafe { WrappedWaker::replace_waker(this.stream_waker, cx) });
439-
}
440-
let mut cx = Context::from_waker(stream_waker.as_ref().unwrap());
439+
let mut cx = Context::from_waker(
440+
stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())),
441+
);
441442

442443
match this.stream.as_mut().poll_next(&mut cx) {
443444
Poll::Ready(Some(item)) => {
@@ -475,9 +476,7 @@ where
475476
}
476477

477478
if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
478-
// Safety: now state is `POLLING`.
479-
let inner_streams_waker =
480-
unsafe { WrappedWaker::replace_waker(this.inner_streams_waker, cx) };
479+
let inner_streams_waker = waker(this.inner_streams_waker.clone());
481480
let mut cx = Context::from_waker(&inner_streams_waker);
482481

483482
match this.inner_streams.as_mut().poll_next(&mut cx) {

futures/tests/no-std/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![cfg(nightly)]
22
#![no_std]
3+
#![allow(useless_anonymous_reexport)]
34

45
#[cfg(feature = "futures-core-alloc")]
56
#[cfg(target_has_atomic = "ptr")]

futures/tests/stream.rs

+43
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use futures::stream::{self, StreamExt};
1414
use futures::task::Poll;
1515
use futures::{ready, FutureExt};
1616
use futures_core::Stream;
17+
use futures_executor::ThreadPool;
1718
use futures_test::task::noop_context;
1819

1920
#[test]
@@ -65,6 +66,7 @@ fn flatten_unordered() {
6566
use futures::task::*;
6667
use std::convert::identity;
6768
use std::pin::Pin;
69+
use std::sync::atomic::{AtomicBool, Ordering};
6870
use std::thread;
6971
use std::time::Duration;
7072

@@ -322,6 +324,47 @@ fn flatten_unordered() {
322324
assert_eq!(values, (0..60).collect::<Vec<u8>>());
323325
});
324326
}
327+
328+
// nested `flatten_unordered`
329+
let te = ThreadPool::new().unwrap();
330+
let handle = te
331+
.spawn_with_handle(async move {
332+
let inner = stream::iter(0..10)
333+
.then(|_| {
334+
let task = Arc::new(AtomicBool::new(false));
335+
let mut spawned = false;
336+
337+
future::poll_fn(move |cx| {
338+
if !spawned {
339+
let waker = cx.waker().clone();
340+
let task = task.clone();
341+
342+
std::thread::spawn(move || {
343+
std::thread::sleep(Duration::from_millis(500));
344+
task.store(true, Ordering::Release);
345+
346+
waker.wake_by_ref()
347+
});
348+
spawned = true;
349+
}
350+
351+
if task.load(Ordering::Acquire) {
352+
Poll::Ready(Some(()))
353+
} else {
354+
Poll::Pending
355+
}
356+
})
357+
})
358+
.map(|_| stream::once(future::ready(())))
359+
.flatten_unordered(None);
360+
361+
let stream = stream::once(future::ready(inner)).flatten_unordered(None);
362+
363+
assert_eq!(stream.count().await, 10);
364+
})
365+
.unwrap();
366+
367+
block_on(handle);
325368
}
326369

327370
#[test]

0 commit comments

Comments
 (0)