Skip to content

Commit 08358e5

Browse files
authored
Merge pull request torvalds#750 from wedsonaf/async-net
rust: add async tcp connection support
2 parents 8ddb2ec + 3abfbec commit 08358e5

File tree

7 files changed

+345
-5
lines changed

7 files changed

+345
-5
lines changed

rust/helpers.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,13 @@ void rust_helper_init_wait(struct wait_queue_entry *wq_entry)
272272
}
273273
EXPORT_SYMBOL_GPL(rust_helper_init_wait);
274274

275+
/*
 * Out-of-line wrapper that forwards to init_waitqueue_func_entry() and is
 * exported as a GPL-only symbol.
 *
 * NOTE(review): presumably needed because the wrapped helper is an inline
 * function that cannot be bound directly from Rust -- confirm.
 */
void rust_helper_init_waitqueue_func_entry(struct wait_queue_entry *wq_entry,
276+
wait_queue_func_t func)
277+
{
278+
init_waitqueue_func_entry(wq_entry, func);
279+
}
280+
EXPORT_SYMBOL_GPL(rust_helper_init_waitqueue_func_entry);
281+
275282
int rust_helper_signal_pending(struct task_struct *t)
276283
{
277284
return signal_pending(t);

rust/kernel/bindings_helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,7 @@
3939
/* `bindgen` gets confused at certain things. */
4040
const gfp_t BINDINGS_GFP_KERNEL = GFP_KERNEL;
4141
const gfp_t BINDINGS___GFP_ZERO = __GFP_ZERO;
42+
/* Poll event bits, re-exported as typed `__poll_t` constants with a BINDINGS_ prefix. */
const __poll_t BINDINGS_EPOLLIN = EPOLLIN;
43+
const __poll_t BINDINGS_EPOLLOUT = EPOLLOUT;
44+
const __poll_t BINDINGS_EPOLLERR = EPOLLERR;
45+
const __poll_t BINDINGS_EPOLLHUP = EPOLLHUP;

rust/kernel/kasync.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// SPDX-License-Identifier: GPL-2.0
2+
3+
//! Kernel async functionality.
4+
5+
#[cfg(CONFIG_NET)]
6+
pub mod net;

rust/kernel/kasync/net.rs

Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
// SPDX-License-Identifier: GPL-2.0
2+
3+
//! Async networking.
4+
5+
use crate::{bindings, c_types, error::code::*, net, sync::NoWaitLock, types::Opaque, Result};
6+
use core::{
7+
future::Future,
8+
marker::{PhantomData, PhantomPinned},
9+
ops::Deref,
10+
pin::Pin,
11+
task::{Context, Poll, Waker},
12+
};
13+
14+
/// A socket listening on a TCP port.
15+
///
16+
/// The [`TcpListener::accept`] method is meant to be used in async contexts.
17+
pub struct TcpListener {
18+
/// The wrapped synchronous listener that the async operations delegate to.
listener: net::TcpListener,
19+
}
20+
21+
impl TcpListener {
22+
/// Creates a new TCP listener.
23+
///
24+
/// It is configured to listen on the given socket address for the given namespace.
25+
pub fn try_new(ns: &net::Namespace, addr: &net::SocketAddr) -> Result<Self> {
26+
Ok(Self {
27+
listener: net::TcpListener::try_new(ns, addr)?,
28+
})
29+
}
30+
31+
/// Accepts a new connection.
32+
///
33+
/// Returns a future that when ready indicates the result of the accept operation; on success,
34+
/// it contains the newly-accepted tcp stream.
35+
pub fn accept(&self) -> impl Future<Output = Result<TcpStream>> + '_ {
36+
SocketFuture::from_listener(
37+
self,
38+
bindings::BINDINGS_EPOLLIN | bindings::BINDINGS_EPOLLERR,
39+
|| {
40+
Ok(TcpStream {
41+
stream: self.listener.accept(false)?,
42+
})
43+
},
44+
)
45+
}
46+
}
47+
48+
impl Deref for TcpListener {
49+
type Target = net::TcpListener;
50+
51+
/// Gives shared access to the wrapped [`net::TcpListener`].
fn deref(&self) -> &Self::Target {
52+
&self.listener
53+
}
54+
}
55+
56+
/// A connected TCP socket.
57+
///
58+
/// The potentially blocking methods (e.g., [`TcpStream::read`], [`TcpStream::write`]) are meant
59+
/// to be used in async contexts.
60+
///
61+
/// # Examples
62+
///
63+
/// ```
64+
/// # use kernel::prelude::*;
65+
/// # use kernel::kasync::net::TcpStream;
66+
/// async fn echo_server(stream: TcpStream) -> Result {
67+
/// let mut buf = [0u8; 1024];
68+
/// loop {
69+
/// let n = stream.read(&mut buf).await?;
70+
/// if n == 0 {
71+
/// return Ok(());
72+
/// }
73+
/// stream.write_all(&buf[..n]).await?;
74+
/// }
75+
/// }
76+
/// ```
77+
pub struct TcpStream {
78+
/// The wrapped synchronous stream that the async operations delegate to.
stream: net::TcpStream,
79+
}
80+
81+
impl TcpStream {
82+
/// Reads data from a connected socket.
83+
///
84+
/// Returns a future that when ready indicates the result of the read operation; on success, it
85+
/// contains the number of bytes read, which will be zero if the connection is closed.
86+
pub fn read<'a>(&'a self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize>> + 'a {
87+
SocketFuture::from_stream(
88+
self,
89+
bindings::BINDINGS_EPOLLIN | bindings::BINDINGS_EPOLLHUP | bindings::BINDINGS_EPOLLERR,
90+
|| self.stream.read(buf, false),
91+
)
92+
}
93+
94+
/// Writes data to the connected socket.
95+
///
96+
/// Returns a future that when ready indicates the result of the write operation; on success, it
97+
/// contains the number of bytes written.
98+
pub fn write<'a>(&'a self, buf: &'a [u8]) -> impl Future<Output = Result<usize>> + 'a {
99+
SocketFuture::from_stream(
100+
self,
101+
bindings::BINDINGS_EPOLLOUT | bindings::BINDINGS_EPOLLHUP | bindings::BINDINGS_EPOLLERR,
102+
|| self.stream.write(buf, false),
103+
)
104+
}
105+
106+
/// Writes all the data to the connected socket.
107+
///
108+
/// Returns a future that when ready indicates the result of the write operation; on success, it
109+
/// has written all the data.
110+
pub async fn write_all<'a>(&'a self, buf: &'a [u8]) -> Result {
111+
let mut rem = buf;
112+
113+
while !rem.is_empty() {
114+
let n = self.write(rem).await?;
115+
rem = &rem[n..];
116+
}
117+
118+
Ok(())
119+
}
120+
}
121+
122+
impl Deref for TcpStream {
123+
type Target = net::TcpStream;
124+
125+
/// Gives shared access to the wrapped [`net::TcpStream`].
fn deref(&self) -> &Self::Target {
126+
&self.stream
127+
}
128+
}
129+
130+
/// A future for a socket operation.
131+
///
132+
/// # Invariants
133+
///
134+
/// `sock` is always non-null and valid for the duration of the lifetime of the instance.
135+
struct SocketFuture<'a, Out, F: FnMut() -> Result<Out> + Send + 'a> {
136+
/// Socket whose wait queue the future registers on; see the type invariants.
sock: *mut bindings::socket,
137+
/// Event bits of interest: the wake-up callback ignores notifications sharing no bit with it.
mask: u32,
138+
/// Set by `poll` right before `wq_entry` is added to the socket's wait queue; checked on drop
/// to decide whether the entry must be removed again.
is_queued: bool,
139+
/// Wait queue entry; it starts out uninitialised and is only initialised when queued.
wq_entry: Opaque<bindings::wait_queue_entry>,
140+
/// Most recent waker handed to `poll`; also read by the wake-up callback.
waker: NoWaitLock<Option<Waker>>,
141+
/// Ties the instance to the lifetime `'a`.
_p: PhantomData<&'a ()>,
142+
/// Opts the type out of `Unpin`.
_pin: PhantomPinned,
143+
/// The operation attempted on every poll; an `EAGAIN` error from it is reported as `Pending`.
operation: F,
144+
}
145+
146+
// SAFETY: A kernel socket can be used from any thread; `wq_entry` is only used on drop and when
147+
// `is_queued` is initially `false`.
148+
unsafe impl<Out, F: FnMut() -> Result<Out> + Send> Send for SocketFuture<'_, Out, F> {}
149+
150+
impl<'a, Out, F: FnMut() -> Result<Out> + Send + 'a> SocketFuture<'a, Out, F> {
151+
/// Creates a new socket future.
152+
///
153+
/// # Safety
154+
///
155+
/// Callers must ensure that `sock` is non-null, valid, and remains valid for the lifetime
156+
/// (`'a`) of the returned instance.
157+
unsafe fn new(sock: *mut bindings::socket, mask: u32, operation: F) -> Self {
158+
Self {
159+
sock,
160+
mask,
161+
is_queued: false,
162+
wq_entry: Opaque::uninit(),
163+
waker: NoWaitLock::new(None),
164+
operation,
165+
_p: PhantomData,
166+
_pin: PhantomPinned,
167+
}
168+
}
169+
170+
/// Creates a new socket future for a tcp listener.
171+
fn from_listener(listener: &'a TcpListener, mask: u32, operation: F) -> Self {
172+
// SAFETY: The socket is guaranteed to remain valid because it is bound to the reference to
173+
// the listener (whose existence guarantees the socket remains valid).
174+
unsafe { Self::new(listener.listener.sock, mask, operation) }
175+
}
176+
177+
/// Creates a new socket future for a tcp stream.
178+
fn from_stream(stream: &'a TcpStream, mask: u32, operation: F) -> Self {
179+
// SAFETY: The socket is guaranteed to remain valid because it is bound to the reference to
180+
// the stream (whose existence guarantees the socket remains valid).
181+
unsafe { Self::new(stream.stream.sock, mask, operation) }
182+
}
183+
184+
/// Callback called when the socket changes state.
185+
///
186+
/// If the state matches the one we're waiting on, we wake up the tak so that the future can be
187+
/// polled again.
188+
unsafe extern "C" fn wake_callback(
189+
wq_entry: *mut bindings::wait_queue_entry,
190+
_mode: c_types::c_uint,
191+
_flags: c_types::c_int,
192+
key: *mut c_types::c_void,
193+
) -> c_types::c_int {
194+
let mask = key as u32;
195+
196+
// SAFETY: The future is valid while this callback is called because we remove from the
197+
// queue on drop.
198+
//
199+
// There is a potential soundness issue here because we're generating a shared reference to
200+
// `Self` while `Self::poll` has a mutable (unique) reference. However, for `!Unpin` types
201+
// (like `Self`), `&mut T` is treated as `*mut T` per
202+
// https://github.com/rust-lang/rust/issues/63818 -- so we avoid the unsoundness. Once a
203+
// more definitive solution is available, we can change this to use it.
204+
let s = unsafe { &*crate::container_of!(wq_entry, Self, wq_entry) };
205+
if mask & s.mask == 0 {
206+
// Nothing to do as this notification doesn't interest us.
207+
return 0;
208+
}
209+
210+
// If we can't acquire the waker lock, the waker is in the process of being modified. Our
211+
// attempt to acquire the lock will be reported to the lock owner, so it will trigger the
212+
// wake up.
213+
if let Some(guard) = s.waker.try_lock() {
214+
if let Some(ref w) = *guard {
215+
let cloned = w.clone();
216+
drop(guard);
217+
cloned.wake();
218+
return 1;
219+
}
220+
}
221+
0
222+
}
223+
224+
/// Poll the future once.
225+
///
226+
/// It calls the operation and converts `EAGAIN` errors into a pending state.
227+
fn poll_once(self: Pin<&mut Self>) -> Poll<Result<Out>> {
228+
// SAFETY: We never move out of `this`.
229+
let this = unsafe { self.get_unchecked_mut() };
230+
match (this.operation)() {
231+
Ok(s) => Poll::Ready(Ok(s)),
232+
Err(e) => {
233+
if e == EAGAIN {
234+
Poll::Pending
235+
} else {
236+
Poll::Ready(Err(e))
237+
}
238+
}
239+
}
240+
}
241+
242+
/// Updates the waker stored in the future.
243+
///
244+
/// It automatically triggers a wake up on races with the reactor.
245+
fn set_waker(&self, waker: &Waker) {
246+
if let Some(mut guard) = self.waker.try_lock() {
247+
let old = core::mem::replace(&mut *guard, Some(waker.clone()));
248+
let contention = guard.unlock();
249+
drop(old);
250+
if !contention {
251+
return;
252+
}
253+
}
254+
255+
// We either couldn't store the waker because the existing one is being awakened, or the
256+
// reactor tried to acquire the lock while we held it (contention). In either case, we just
257+
// wake it up to ensure we don't miss any notification.
258+
waker.wake_by_ref();
259+
}
260+
}
261+
262+
impl<Out, F: FnMut() -> Result<Out> + Send> Future for SocketFuture<'_, Out, F> {
263+
type Output = Result<Out>;
264+
265+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
266+
match self.as_mut().poll_once() {
267+
Poll::Ready(r) => Poll::Ready(r),
268+
Poll::Pending => {
269+
// Store away the latest waker every time we may `Pending`.
270+
self.set_waker(cx.waker());
271+
if self.is_queued {
272+
// Nothing else to do was the waiter is already queued.
273+
return Poll::Pending;
274+
}
275+
276+
// SAFETY: We never move out of `this`.
277+
let this = unsafe { self.as_mut().get_unchecked_mut() };
278+
279+
this.is_queued = true;
280+
281+
// SAFETY: `wq_entry` is valid for write.
282+
unsafe {
283+
bindings::init_waitqueue_func_entry(
284+
this.wq_entry.get(),
285+
Some(Self::wake_callback),
286+
)
287+
};
288+
289+
// SAFETY: `wq_entry` was just initialised above and is valid for read/write.
290+
// By the type invariants, the socket is always valid.
291+
unsafe {
292+
bindings::add_wait_queue(
293+
core::ptr::addr_of_mut!((*this.sock).wq.wait),
294+
this.wq_entry.get(),
295+
)
296+
};
297+
298+
// If the future wasn't queued yet, we need to poll again in case it reached
299+
// the desired state between the last poll and being queued (in which case we
300+
// would have missed the notification).
301+
self.poll_once()
302+
}
303+
}
304+
}
305+
}
306+
307+
impl<Out, F: FnMut() -> Result<Out> + Send> Drop for SocketFuture<'_, Out, F> {
308+
fn drop(&mut self) {
309+
if !self.is_queued {
310+
return;
311+
}
312+
313+
// SAFETY: `wq_entry` is initialised because `is_queued` is set to `true`, so it is valid
314+
// for read/write. By the type invariants, the socket is always valid.
315+
unsafe {
316+
bindings::remove_wait_queue(
317+
core::ptr::addr_of_mut!((*self.sock).wq.wait),
318+
self.wq_entry.get(),
319+
)
320+
};
321+
}
322+
}

rust/kernel/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ pub mod file;
5454
pub mod gpio;
5555
pub mod hwrng;
5656
pub mod irq;
57+
pub mod kasync;
5758
pub mod miscdev;
5859
pub mod mm;
5960
#[cfg(CONFIG_NET)]

rust/kernel/net.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ impl SocketAddrV6 {
231231
///
232232
/// The socket pointer is always non-null and valid.
233233
pub struct TcpListener {
234-
sock: *mut bindings::socket,
234+
pub(crate) sock: *mut bindings::socket,
235235
}
236236

237237
// SAFETY: `TcpListener` is just a wrapper for a kernel socket, which can be used from any thread.
@@ -313,7 +313,7 @@ impl Drop for TcpListener {
313313
///
314314
/// The socket pointer is always non-null and valid.
315315
pub struct TcpStream {
316-
sock: *mut bindings::socket,
316+
pub(crate) sock: *mut bindings::socket,
317317
}
318318

319319
// SAFETY: `TcpStream` is just a wrapper for a kernel socket, which can be used from any thread.
@@ -332,7 +332,7 @@ impl TcpStream {
332332
/// - If `block` is `false`, returns [`crate::error::code::EAGAIN`];
333333
/// - If `block` is `true`, blocks until an error occurs, the connection is closed, or some
334334
/// data becomes readable.
335-
pub fn read(&mut self, buf: &mut [u8], block: bool) -> Result<usize> {
335+
pub fn read(&self, buf: &mut [u8], block: bool) -> Result<usize> {
336336
let mut msg = bindings::msghdr::default();
337337
let mut vec = bindings::kvec {
338338
iov_base: buf.as_mut_ptr().cast(),
@@ -364,7 +364,7 @@ impl TcpStream {
364364
/// If the send buffer of the socket is full, one of two behaviours will occur:
365365
/// - If `block` is `false`, returns [`crate::error::code::EAGAIN`];
366366
/// - If `block` is `true`, blocks until an error occurs or some data is written.
367-
pub fn write(&mut self, buf: &[u8], block: bool) -> Result<usize> {
367+
pub fn write(&self, buf: &[u8], block: bool) -> Result<usize> {
368368
let mut msg = bindings::msghdr {
369369
msg_flags: if block { 0 } else { bindings::MSG_DONTWAIT },
370370
..bindings::msghdr::default()

0 commit comments

Comments
 (0)