Skip to content

Commit a3d2548

Browse files
authored
sync: implement Clone for watch::Sender (#6388)
1 parent b4ab647 commit a3d2548

File tree

3 files changed

+76
-4
lines changed

3 files changed

+76
-4
lines changed

tokio/src/sync/tests/loom_watch.rs

+20
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,23 @@ fn wait_for_returns_correct_value() {
8888
jh.join().unwrap();
8989
});
9090
}
91+
92+
#[test]
93+
fn multiple_sender_drop_concurrently() {
94+
loom::model(move || {
95+
let (tx1, rx) = watch::channel(0);
96+
let tx2 = tx1.clone();
97+
98+
let jh = thread::spawn(move || {
99+
drop(tx2);
100+
});
101+
assert!(rx.has_changed().is_ok());
102+
103+
drop(tx1);
104+
105+
jh.join().unwrap();
106+
107+
// Check if all sender are dropped and closed flag is set.
108+
assert!(rx.has_changed().is_err());
109+
});
110+
}

tokio/src/sync/watch.rs

+19-3
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@
114114
use crate::sync::notify::Notify;
115115

116116
use crate::loom::sync::atomic::AtomicUsize;
117-
use crate::loom::sync::atomic::Ordering::Relaxed;
117+
use crate::loom::sync::atomic::Ordering::{AcqRel, Relaxed};
118118
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
119119
use std::fmt;
120120
use std::mem;
@@ -146,6 +146,16 @@ pub struct Sender<T> {
146146
shared: Arc<Shared<T>>,
147147
}
148148

149+
impl<T> Clone for Sender<T> {
150+
fn clone(&self) -> Self {
151+
self.shared.ref_count_tx.fetch_add(1, Relaxed);
152+
153+
Self {
154+
shared: self.shared.clone(),
155+
}
156+
}
157+
}
158+
149159
/// Returns a reference to the inner value.
150160
///
151161
/// Outstanding borrows hold a read lock on the inner value. This means that
@@ -238,6 +248,9 @@ struct Shared<T> {
238248
/// Tracks the number of `Receiver` instances.
239249
ref_count_rx: AtomicUsize,
240250

251+
/// Tracks the number of `Sender` instances.
252+
ref_count_tx: AtomicUsize,
253+
241254
/// Notifies waiting receivers that the value changed.
242255
notify_rx: big_notify::BigNotify,
243256

@@ -485,6 +498,7 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
485498
value: RwLock::new(init),
486499
state: AtomicState::new(),
487500
ref_count_rx: AtomicUsize::new(1),
501+
ref_count_tx: AtomicUsize::new(1),
488502
notify_rx: big_notify::BigNotify::new(),
489503
notify_tx: Notify::new(),
490504
});
@@ -1302,8 +1316,10 @@ impl<T> Sender<T> {
13021316

13031317
impl<T> Drop for Sender<T> {
13041318
fn drop(&mut self) {
1305-
self.shared.state.set_closed();
1306-
self.shared.notify_rx.notify_waiters();
1319+
if self.shared.ref_count_tx.fetch_sub(1, AcqRel) == 1 {
1320+
self.shared.state.set_closed();
1321+
self.shared.notify_rx.notify_waiters();
1322+
}
13071323
}
13081324
}
13091325

tokio/tests/sync_watch.rs

+37-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ use wasm_bindgen_test::wasm_bindgen_test as test;
77

88
use tokio::sync::watch;
99
use tokio_test::task::spawn;
10-
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok};
10+
use tokio_test::{
11+
assert_pending, assert_ready, assert_ready_eq, assert_ready_err, assert_ready_ok,
12+
};
1113

1214
#[test]
1315
fn single_rx_recv() {
@@ -332,3 +334,37 @@ fn send_modify_panic() {
332334
assert_ready_ok!(task.poll());
333335
assert_eq!(*rx.borrow_and_update(), "three");
334336
}
337+
338+
#[tokio::test]
339+
async fn multiple_sender() {
340+
let (tx1, mut rx) = watch::channel(0);
341+
let tx2 = tx1.clone();
342+
343+
let mut t = spawn(async {
344+
rx.changed().await.unwrap();
345+
let v1 = *rx.borrow_and_update();
346+
rx.changed().await.unwrap();
347+
let v2 = *rx.borrow_and_update();
348+
(v1, v2)
349+
});
350+
351+
tx1.send(1).unwrap();
352+
assert_pending!(t.poll());
353+
tx2.send(2).unwrap();
354+
assert_ready_eq!(t.poll(), (1, 2));
355+
}
356+
357+
#[tokio::test]
358+
async fn reciever_is_notified_when_last_sender_is_dropped() {
359+
let (tx1, mut rx) = watch::channel(0);
360+
let tx2 = tx1.clone();
361+
362+
let mut t = spawn(rx.changed());
363+
assert_pending!(t.poll());
364+
365+
drop(tx1);
366+
assert!(!t.is_woken());
367+
drop(tx2);
368+
369+
assert!(t.is_woken());
370+
}

0 commit comments

Comments
 (0)