Skip to content

Commit 4b174ce

Browse files
committed
sync: fix cloning value when receiving from broadcast channel
The broadcast channel does not require values to implement `Sync` yet it calls the `.clone()` method without synchronizing. This is unsound logic. This patch adds per-value synchronization on receive to handle this case. It is unlikely any usage of the broadcast channel is currently at risk of the unsoundeness issue as it requires accessing a `!Sync` type during `.clone()`, which would be very unusual when using the broadcast channel.
1 parent 9681ce2 commit 4b174ce

File tree

1 file changed

+27
-28
lines changed

1 file changed

+27
-28
lines changed

Diff for: tokio/src/sync/broadcast.rs

+27-28
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@
118118
119119
use crate::loom::cell::UnsafeCell;
120120
use crate::loom::sync::atomic::{AtomicBool, AtomicUsize};
121-
use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
121+
use crate::loom::sync::{Arc, Mutex, MutexGuard};
122122
use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
123123
use crate::util::WakeList;
124124

@@ -303,7 +303,7 @@ use self::error::{RecvError, SendError, TryRecvError};
303303
/// Data shared between senders and receivers.
304304
struct Shared<T> {
305305
/// slots in the channel.
306-
buffer: Box<[RwLock<Slot<T>>]>,
306+
buffer: Box<[Mutex<Slot<T>>]>,
307307

308308
/// Mask a position -> index.
309309
mask: usize,
@@ -347,7 +347,7 @@ struct Slot<T> {
347347
///
348348
/// The value is set by `send` when the write lock is held. When a reader
349349
/// drops, `rem` is decremented. When it hits zero, the value is dropped.
350-
val: UnsafeCell<Option<T>>,
350+
val: Option<T>,
351351
}
352352

353353
/// An entry in the wait queue.
@@ -385,7 +385,7 @@ generate_addr_of_methods! {
385385
}
386386

387387
struct RecvGuard<'a, T> {
388-
slot: RwLockReadGuard<'a, Slot<T>>,
388+
slot: MutexGuard<'a, Slot<T>>,
389389
}
390390

391391
/// Receive a value future.
@@ -394,11 +394,15 @@ struct Recv<'a, T> {
394394
receiver: &'a mut Receiver<T>,
395395

396396
/// Entry in the waiter `LinkedList`.
397-
waiter: UnsafeCell<Waiter>,
397+
waiter: WaiterCell,
398398
}
399399

400-
unsafe impl<'a, T: Send> Send for Recv<'a, T> {}
401-
unsafe impl<'a, T: Send> Sync for Recv<'a, T> {}
400+
// The wrapper around `UnsafeCell` isolates the unsafe impl `Send` and `Sync`
401+
// from `Recv`.
402+
struct WaiterCell(UnsafeCell<Waiter>);
403+
404+
unsafe impl Send for WaiterCell {}
405+
unsafe impl Sync for WaiterCell {}
402406

403407
/// Max number of receivers. Reserve space to lock.
404408
const MAX_RECEIVERS: usize = usize::MAX >> 2;
@@ -466,12 +470,6 @@ pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
466470
(tx, rx)
467471
}
468472

469-
unsafe impl<T: Send> Send for Sender<T> {}
470-
unsafe impl<T: Send> Sync for Sender<T> {}
471-
472-
unsafe impl<T: Send> Send for Receiver<T> {}
473-
unsafe impl<T: Send> Sync for Receiver<T> {}
474-
475473
impl<T> Sender<T> {
476474
/// Creates the sending-half of the [`broadcast`] channel.
477475
///
@@ -510,10 +508,10 @@ impl<T> Sender<T> {
510508
let mut buffer = Vec::with_capacity(capacity);
511509

512510
for i in 0..capacity {
513-
buffer.push(RwLock::new(Slot {
511+
buffer.push(Mutex::new(Slot {
514512
rem: AtomicUsize::new(0),
515513
pos: (i as u64).wrapping_sub(capacity as u64),
516-
val: UnsafeCell::new(None),
514+
val: None,
517515
}));
518516
}
519517

@@ -599,7 +597,7 @@ impl<T> Sender<T> {
599597
tail.pos = tail.pos.wrapping_add(1);
600598

601599
// Get the slot
602-
let mut slot = self.shared.buffer[idx].write().unwrap();
600+
let mut slot = self.shared.buffer[idx].lock();
603601

604602
// Track the position
605603
slot.pos = pos;
@@ -608,7 +606,7 @@ impl<T> Sender<T> {
608606
slot.rem.with_mut(|v| *v = rem);
609607

610608
// Write the value
611-
slot.val = UnsafeCell::new(Some(value));
609+
slot.val = Some(value);
612610

613611
// Release the slot lock before notifying the receivers.
614612
drop(slot);
@@ -695,7 +693,7 @@ impl<T> Sender<T> {
695693
while low < high {
696694
let mid = low + (high - low) / 2;
697695
let idx = base_idx.wrapping_add(mid) & self.shared.mask;
698-
if self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 {
696+
if self.shared.buffer[idx].lock().rem.load(SeqCst) == 0 {
699697
low = mid + 1;
700698
} else {
701699
high = mid;
@@ -737,7 +735,7 @@ impl<T> Sender<T> {
737735
let tail = self.shared.tail.lock();
738736

739737
let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize;
740-
self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0
738+
self.shared.buffer[idx].lock().rem.load(SeqCst) == 0
741739
}
742740

743741
/// Returns the number of active receivers.
@@ -1057,7 +1055,7 @@ impl<T> Receiver<T> {
10571055
let idx = (self.next & self.shared.mask as u64) as usize;
10581056

10591057
// The slot holding the next value to read
1060-
let mut slot = self.shared.buffer[idx].read().unwrap();
1058+
let mut slot = self.shared.buffer[idx].lock();
10611059

10621060
if slot.pos != self.next {
10631061
// Release the `slot` lock before attempting to acquire the `tail`
@@ -1074,7 +1072,7 @@ impl<T> Receiver<T> {
10741072
let mut tail = self.shared.tail.lock();
10751073

10761074
// Acquire slot lock again
1077-
slot = self.shared.buffer[idx].read().unwrap();
1075+
slot = self.shared.buffer[idx].lock();
10781076

10791077
// Make sure the position did not change. This could happen in the
10801078
// unlikely event that the buffer is wrapped between dropping the
@@ -1367,12 +1365,12 @@ impl<'a, T> Recv<'a, T> {
13671365
fn new(receiver: &'a mut Receiver<T>) -> Recv<'a, T> {
13681366
Recv {
13691367
receiver,
1370-
waiter: UnsafeCell::new(Waiter {
1368+
waiter: WaiterCell(UnsafeCell::new(Waiter {
13711369
queued: AtomicBool::new(false),
13721370
waker: None,
13731371
pointers: linked_list::Pointers::new(),
13741372
_p: PhantomPinned,
1375-
}),
1373+
})),
13761374
}
13771375
}
13781376

@@ -1384,7 +1382,7 @@ impl<'a, T> Recv<'a, T> {
13841382
is_unpin::<&mut Receiver<T>>();
13851383

13861384
let me = self.get_unchecked_mut();
1387-
(me.receiver, &me.waiter)
1385+
(me.receiver, &me.waiter.0)
13881386
}
13891387
}
13901388
}
@@ -1418,6 +1416,7 @@ impl<'a, T> Drop for Recv<'a, T> {
14181416
// `Shared::notify_rx` before we drop the object.
14191417
let queued = self
14201418
.waiter
1419+
.0
14211420
.with(|ptr| unsafe { (*ptr).queued.load(Acquire) });
14221421

14231422
// If the waiter is queued, we need to unlink it from the waiters list.
@@ -1432,6 +1431,7 @@ impl<'a, T> Drop for Recv<'a, T> {
14321431
// `Relaxed` order suffices because we hold the tail lock.
14331432
let queued = self
14341433
.waiter
1434+
.0
14351435
.with_mut(|ptr| unsafe { (*ptr).queued.load(Relaxed) });
14361436

14371437
if queued {
@@ -1440,7 +1440,7 @@ impl<'a, T> Drop for Recv<'a, T> {
14401440
// safety: tail lock is held and the wait node is verified to be in
14411441
// the list.
14421442
unsafe {
1443-
self.waiter.with_mut(|ptr| {
1443+
self.waiter.0.with_mut(|ptr| {
14441444
tail.waiters.remove((&mut *ptr).into());
14451445
});
14461446
}
@@ -1486,16 +1486,15 @@ impl<'a, T> RecvGuard<'a, T> {
14861486
where
14871487
T: Clone,
14881488
{
1489-
self.slot.val.with(|ptr| unsafe { (*ptr).clone() })
1489+
self.slot.val.clone()
14901490
}
14911491
}
14921492

14931493
impl<'a, T> Drop for RecvGuard<'a, T> {
14941494
fn drop(&mut self) {
14951495
// Decrement the remaining counter
14961496
if 1 == self.slot.rem.fetch_sub(1, SeqCst) {
1497-
// Safety: Last receiver, drop the value
1498-
self.slot.val.with_mut(|ptr| unsafe { *ptr = None });
1497+
self.slot.val = None;
14991498
}
15001499
}
15011500
}

0 commit comments

Comments
 (0)