Skip to content

Commit 0b35430

Browse files
committed
await tasks to cancel
1 parent e0c3c6d commit 0b35430

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

crates/bevy_tasks/src/task_pool.rs

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ use std::{
22
future::Future,
33
marker::PhantomData,
44
mem,
5-
pin::Pin,
65
sync::Arc,
76
thread::{self, JoinHandle},
87
};
98

9+
use async_task::FallibleTask;
1010
use concurrent_queue::ConcurrentQueue;
1111
use futures_lite::{future, pin, FutureExt};
1212

@@ -248,8 +248,8 @@ impl TaskPool {
248248
let task_scope_executor = &async_executor::Executor::default();
249249
let task_scope_executor: &'env async_executor::Executor =
250250
unsafe { mem::transmute(task_scope_executor) };
251-
let spawned: ConcurrentQueue<async_executor::Task<T>> = ConcurrentQueue::unbounded();
252-
let spawned_ref: &'env ConcurrentQueue<async_executor::Task<T>> =
251+
let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
252+
let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
253253
unsafe { mem::transmute(&spawned) };
254254

255255
let scope = Scope {
@@ -270,7 +270,7 @@ impl TaskPool {
270270
let get_results = async move {
271271
let mut results = Vec::with_capacity(spawned.len());
272272
while let Ok(task) = spawned.pop() {
273-
results.push(task.await);
273+
results.push(task.await.unwrap());
274274
}
275275

276276
results
@@ -279,23 +279,8 @@ impl TaskPool {
279279
// Pin the futures on the stack.
280280
pin!(get_results);
281281

282-
// SAFETY: This function blocks until all futures complete, so we do not read/write
283-
// the data from futures outside of the 'scope lifetime. However,
284-
// rust has no way of knowing this so we must convert to 'static
285-
// here to appease the compiler as it is unable to validate safety.
286-
let get_results: Pin<&mut (dyn Future<Output = Vec<T>> + 'static + Send)> = get_results;
287-
let get_results: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static + Send)> =
288-
unsafe { mem::transmute(get_results) };
289-
290-
// The thread that calls scope() will participate in driving tasks in the pool
291-
// forward until the tasks that are spawned by this scope() call
292-
// complete. (If the caller of scope() happens to be a thread in
293-
// this thread pool, and we only have one thread in the pool, then
294-
// simply calling future::block_on(spawned) would deadlock.)
295-
let mut spawned = task_scope_executor.spawn(get_results);
296-
297282
loop {
298-
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
283+
if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
299284
break result;
300285
};
301286

@@ -375,7 +360,7 @@ impl Drop for TaskPool {
375360
pub struct Scope<'scope, 'env: 'scope, T> {
376361
executor: &'scope async_executor::Executor<'scope>,
377362
task_scope_executor: &'scope async_executor::Executor<'scope>,
378-
spawned: &'scope ConcurrentQueue<async_executor::Task<T>>,
363+
spawned: &'scope ConcurrentQueue<FallibleTask<T>>,
379364
// make `Scope` invariant over 'scope and 'env
380365
scope: PhantomData<&'scope mut &'scope ()>,
381366
env: PhantomData<&'env mut &'env ()>,
@@ -391,7 +376,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
391376
///
392377
/// For more information, see [`TaskPool::scope`].
393378
pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
394-
let task = self.executor.spawn(f);
379+
let task = self.executor.spawn(f).fallible();
395380
// ConcurrentQueue only errors when closed or full, but we never
396381
// close and use an unbouded queue, so it is safe to unwrap
397382
self.spawned.push(task).unwrap();
@@ -404,13 +389,26 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
404389
///
405390
/// For more information, see [`TaskPool::scope`].
406391
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
407-
let task = self.task_scope_executor.spawn(f);
392+
let task = self.task_scope_executor.spawn(f).fallible();
408393
// ConcurrentQueue only errors when closed or full, but we never
409394
// close and use an unbouded queue, so it is safe to unwrap
410395
self.spawned.push(task).unwrap();
411396
}
412397
}
413398

399+
impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
400+
where
401+
T: 'scope,
402+
{
403+
fn drop(&mut self) {
404+
future::block_on(async {
405+
while let Ok(task) = self.spawned.pop() {
406+
task.cancel().await;
407+
}
408+
});
409+
}
410+
}
411+
414412
#[cfg(test)]
415413
#[allow(clippy::disallowed_types)]
416414
mod tests {

0 commit comments

Comments
 (0)