Skip to content

Commit bd605ac

Browse files
committed
feat(transport): make conn I/O handler behavior configurable
Signed-off-by: Roman Volosatovs <[email protected]>
1 parent 481e59d commit bd605ac

File tree

9 files changed

+243
-152
lines changed

9 files changed

+243
-152
lines changed

crates/transport-quic/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ repository.workspace = true
1313
anyhow = { workspace = true, features = ["std"] }
1414
bytes = { workspace = true }
1515
quinn = { workspace = true, features = ["runtime-tokio"] }
16+
tracing = { workspace = true }
1617
wrpc-transport = { workspace = true }
1718

1819
[dev-dependencies]
1920
futures = { workspace = true }
2021
test-log = { workspace = true, features = ["color", "log", "trace"] }
2122
tokio = { workspace = true, features = ["rt-multi-thread"] }
22-
tracing = { workspace = true }
2323
wrpc-test = { workspace = true, features = ["quic"] }

crates/transport-quic/src/lib.rs

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,62 @@
22
33
use anyhow::Context as _;
44
use bytes::Bytes;
5-
use quinn::{Connection, RecvStream, SendStream};
6-
use wrpc_transport::frame::{invoke, Accept, Incoming, Outgoing};
5+
use quinn::{Connection, RecvStream, SendStream, VarInt};
6+
use tracing::{debug, error, trace, warn};
7+
use wrpc_transport::frame::{Accept, Incoming, InvokeBuilder, Outgoing};
78
use wrpc_transport::Invoke;
89

9-
/// QUIC transport client
10+
/// QUIC server with graceful stream shutdown handling
11+
pub type Server = wrpc_transport::Server<(), RecvStream, SendStream, ConnHandler>;
12+
13+
/// QUIC wRPC client
1014
#[derive(Clone, Debug)]
1115
pub struct Client(Connection);
1216

17+
/// Graceful stream shutdown handler
18+
pub struct ConnHandler;
19+
20+
const DONE: VarInt = VarInt::from_u32(1);
21+
22+
impl wrpc_transport::frame::ConnHandler<RecvStream, SendStream> for ConnHandler {
23+
async fn on_ingress(mut rx: RecvStream, res: std::io::Result<()>) {
24+
if let Err(err) = res {
25+
error!(?err, "ingress failed");
26+
} else {
27+
debug!("ingress successfully complete");
28+
}
29+
if let Err(err) = rx.stop(DONE) {
30+
error!(?err, "failed to close stream");
31+
}
32+
}
33+
34+
async fn on_egress(mut tx: SendStream, res: std::io::Result<()>) {
35+
if let Err(err) = res {
36+
error!(?err, "egress failed");
37+
} else {
38+
debug!("egress successfully complete");
39+
}
40+
if let Err(err) = tx.finish() {
41+
error!(?err, "failed to close stream");
42+
}
43+
match tx.stopped().await {
44+
Ok(None) => {
45+
trace!("stream successfully closed")
46+
}
47+
Ok(Some(code)) => {
48+
if code == DONE {
49+
trace!("stream successfully closed")
50+
} else {
51+
warn!(?code, "stream closed with code")
52+
}
53+
}
54+
Err(err) => {
55+
error!(?err, "failed to await stream close");
56+
}
57+
}
58+
}
59+
}
60+
1361
impl From<Connection> for Client {
1462
fn from(conn: Connection) -> Self {
1563
Self(conn)
@@ -37,7 +85,9 @@ impl Invoke for &Client {
3785
.open_bi()
3886
.await
3987
.context("failed to open parameter stream")?;
40-
invoke(tx, rx, instance, func, params, paths).await
88+
InvokeBuilder::<ConnHandler>::default()
89+
.invoke(tx, rx, instance, func, params, paths)
90+
.await
4191
}
4292
}
4393

crates/transport-quic/tests/loopback.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ async fn loopback() -> anyhow::Result<()> {
1515
wrpc_test::with_quic(|clt, srv| async {
1616
let clt = Client::from(clt);
1717
let srv_conn = Client::from(srv);
18-
let srv = Arc::new(wrpc_transport::frame::Server::default());
18+
let srv = Arc::new(wrpc_transport_quic::Server::new());
1919
let invocations = srv
2020
.serve("foo", "bar", [Box::from([Some(42), Some(0)])])
2121
.await

crates/transport/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ anyhow = { workspace = true, features = ["std"] }
2020
bytes = { workspace = true }
2121
futures = { workspace = true, features = ["std"] }
2222
pin-project-lite = { workspace = true }
23+
send-future = { workspace = true }
2324
tokio = { workspace = true, features = ["macros", "rt", "time"] }
2425
tokio-stream = { workspace = true }
2526
tokio-util = { workspace = true, features = ["codec", "io"] }
2627
tracing = { workspace = true, features = ["attributes"] }
27-
send-future = { workspace = true }
2828
wasm-tokio = { workspace = true, features = ["tracing"] }
2929

3030
[target.'cfg(target_family = "wasm")'.dependencies]
Lines changed: 58 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,70 @@
1-
use std::sync::Arc;
1+
use core::marker::PhantomData;
22

33
use anyhow::Context as _;
44
use bytes::{BufMut as _, Bytes, BytesMut};
55
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
6-
use tokio::sync::mpsc;
7-
use tokio::task::JoinSet;
8-
use tokio_stream::wrappers::ReceiverStream;
96
use tokio_util::codec::Encoder;
10-
use tokio_util::io::StreamReader;
11-
use tokio_util::sync::PollSender;
12-
use tracing::{debug, error, instrument, trace, Instrument as _};
7+
use tracing::{instrument, trace};
138
use wasm_tokio::{CoreNameEncoder, CoreVecEncoderBytes};
149

15-
use crate::frame::conn::{egress, ingress, Incoming, Outgoing};
16-
use crate::frame::PROTOCOL;
10+
use crate::frame::conn::{Incoming, Outgoing};
11+
use crate::frame::{Conn, ConnHandler, PROTOCOL};
12+
13+
/// Defines invocation behavior
14+
#[derive(Clone)]
15+
pub struct InvokeBuilder<H = ()>(PhantomData<H>)
16+
where
17+
H: ?Sized;
18+
19+
impl<H> InvokeBuilder<H> {
20+
/// Invoke function `func` on instance `instance`
21+
#[instrument(level = "trace", skip_all)]
22+
pub async fn invoke<P, I, O>(
23+
self,
24+
mut tx: O,
25+
rx: I,
26+
instance: &str,
27+
func: &str,
28+
params: Bytes,
29+
paths: impl AsRef<[P]> + Send,
30+
) -> anyhow::Result<(Outgoing, Incoming)>
31+
where
32+
P: AsRef<[Option<usize>]> + Send + Sync,
33+
I: AsyncRead + Unpin + Send + 'static,
34+
O: AsyncWrite + Unpin + Send + 'static,
35+
H: ConnHandler<I, O>,
36+
{
37+
let mut buf = BytesMut::with_capacity(
38+
17_usize // len(PROTOCOL) + len(instance) + len(func) + len([]) + len(params)
39+
.saturating_add(instance.len())
40+
.saturating_add(func.len())
41+
.saturating_add(params.len()),
42+
);
43+
buf.put_u8(PROTOCOL);
44+
CoreNameEncoder.encode(instance, &mut buf)?;
45+
CoreNameEncoder.encode(func, &mut buf)?;
46+
buf.put_u8(0);
47+
CoreVecEncoderBytes.encode(params, &mut buf)?;
48+
trace!(?buf, "writing invocation");
49+
tx.write_all(&buf)
50+
.await
51+
.context("failed to initialize connection")?;
52+
53+
let Conn { tx, rx } = Conn::new::<H, _, _, _>(rx, tx, paths.as_ref());
54+
Ok((tx, rx))
55+
}
56+
}
57+
58+
impl<H> Default for InvokeBuilder<H> {
59+
fn default() -> Self {
60+
Self(PhantomData)
61+
}
62+
}
1763

1864
/// Invoke function `func` on instance `instance`
1965
#[instrument(level = "trace", skip_all)]
2066
pub async fn invoke<P, I, O>(
21-
mut tx: O,
67+
tx: O,
2268
rx: I,
2369
instance: &str,
2470
func: &str,
@@ -30,65 +76,7 @@ where
3076
I: AsyncRead + Unpin + Send + 'static,
3177
O: AsyncWrite + Unpin + Send + 'static,
3278
{
33-
let mut buf = BytesMut::with_capacity(
34-
17_usize // len(PROTOCOL) + len(instance) + len(func) + len([]) + len(params)
35-
.saturating_add(instance.len())
36-
.saturating_add(func.len())
37-
.saturating_add(params.len()),
38-
);
39-
buf.put_u8(PROTOCOL);
40-
CoreNameEncoder.encode(instance, &mut buf)?;
41-
CoreNameEncoder.encode(func, &mut buf)?;
42-
buf.put_u8(0);
43-
CoreVecEncoderBytes.encode(params, &mut buf)?;
44-
trace!(?buf, "writing invocation");
45-
tx.write_all(&buf)
79+
InvokeBuilder::<()>::default()
80+
.invoke(tx, rx, instance, func, params, paths)
4681
.await
47-
.context("failed to initialize connection")?;
48-
49-
let index = Arc::new(std::sync::Mutex::new(paths.as_ref().iter().collect()));
50-
let (results_tx, results_rx) = mpsc::channel(128);
51-
let mut results_io = JoinSet::new();
52-
results_io.spawn({
53-
let index = Arc::clone(&index);
54-
async move {
55-
if let Err(err) = ingress(rx, &index, results_tx).await {
56-
error!(?err, "result ingress failed");
57-
} else {
58-
debug!("result ingress successfully complete");
59-
}
60-
let Ok(mut index) = index.lock() else {
61-
error!("failed to lock index trie");
62-
return;
63-
};
64-
trace!("shutting down index trie");
65-
index.close_tx();
66-
}
67-
.in_current_span()
68-
});
69-
70-
let (params_tx, params_rx) = mpsc::channel(128);
71-
tokio::spawn(
72-
async {
73-
if let Err(err) = egress(params_rx, tx).await {
74-
error!(?err, "parameter egress failed");
75-
} else {
76-
debug!("parameter egress successfully complete");
77-
}
78-
}
79-
.in_current_span(),
80-
);
81-
Ok((
82-
Outgoing {
83-
tx: PollSender::new(params_tx),
84-
path: Arc::from([]),
85-
path_buf: Bytes::from_static(&[0]),
86-
},
87-
Incoming {
88-
rx: Some(StreamReader::new(ReceiverStream::new(results_rx))),
89-
path: Arc::from([]),
90-
index,
91-
io: Arc::new(results_io),
92-
},
93-
))
9482
}

crates/transport/src/frame/conn/mod.rs

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use core::future::Future;
12
use core::mem;
23
use core::pin::Pin;
34
use core::task::{ready, Context, Poll};
@@ -15,7 +16,7 @@ use tokio_stream::wrappers::ReceiverStream;
1516
use tokio_util::codec::Encoder;
1617
use tokio_util::io::StreamReader;
1718
use tokio_util::sync::PollSender;
18-
use tracing::{instrument, trace};
19+
use tracing::{debug, error, instrument, trace, Instrument as _, Span};
1920
use wasm_tokio::{AsyncReadLeb128 as _, Leb128Encoder};
2021

2122
use crate::Index;
@@ -484,10 +485,10 @@ async fn ingress(
484485
}
485486
}
486487

487-
#[instrument(level = "trace", skip_all, ret(level = "trace"))]
488+
#[instrument(level = "trace", skip_all)]
488489
async fn egress(
489-
mut rx: mpsc::Receiver<(Bytes, Bytes)>,
490490
mut tx: impl AsyncWrite + Unpin,
491+
mut rx: mpsc::Receiver<(Bytes, Bytes)>,
491492
) -> std::io::Result<()> {
492493
let mut buf = BytesMut::with_capacity(5);
493494
trace!("waiting for next frame");
@@ -503,3 +504,90 @@ async fn egress(
503504
trace!("shutting down outgoing stream");
504505
tx.shutdown().await
505506
}
507+
508+
/// Connection handler defines the connection I/O behavior.
509+
/// It is mostly useful for transports that may require additional clean up not already covered
510+
/// by [AsyncWrite::shutdown], for example.
511+
/// This API is experimental and may change in backwards-incompatible ways in the future.
512+
pub trait ConnHandler<Rx, Tx> {
513+
/// Handle ingress completion
514+
fn on_ingress(rx: Rx, res: std::io::Result<()>) -> impl Future<Output = ()> + Send {
515+
_ = rx;
516+
if let Err(err) = res {
517+
error!(?err, "ingress failed");
518+
} else {
519+
debug!("ingress successfully complete");
520+
}
521+
async {}
522+
}
523+
524+
/// Handle egress completion
525+
fn on_egress(tx: Tx, res: std::io::Result<()>) -> impl Future<Output = ()> + Send {
526+
_ = tx;
527+
if let Err(err) = res {
528+
error!(?err, "egress failed");
529+
} else {
530+
debug!("egress successfully complete");
531+
}
532+
async {}
533+
}
534+
}
535+
536+
impl<Rx, Tx> ConnHandler<Rx, Tx> for () {}
537+
538+
/// Peer connection
539+
pub(crate) struct Conn {
540+
rx: Incoming,
541+
tx: Outgoing,
542+
}
543+
544+
impl Conn {
545+
/// Creates a new [Conn] given an [AsyncRead], [ConnHandler] and a set of async paths
546+
fn new<H, Rx, Tx, P>(mut rx: Rx, mut tx: Tx, paths: impl IntoIterator<Item = P>) -> Self
547+
where
548+
Rx: AsyncRead + Unpin + Send + 'static,
549+
Tx: AsyncWrite + Unpin + Send + 'static,
550+
H: ConnHandler<Rx, Tx>,
551+
P: AsRef<[Option<usize>]>,
552+
{
553+
let index = Arc::new(std::sync::Mutex::new(paths.into_iter().collect()));
554+
let (rx_tx, rx_rx) = mpsc::channel(128);
555+
let mut rx_io = JoinSet::new();
556+
let span = Span::current();
557+
rx_io.spawn({
558+
let index = Arc::clone(&index);
559+
async move {
560+
let res = ingress(&mut rx, &index, rx_tx).await;
561+
H::on_ingress(rx, res).await;
562+
let Ok(mut index) = index.lock() else {
563+
error!("failed to lock index trie");
564+
return;
565+
};
566+
trace!("shutting down index trie");
567+
index.close_tx();
568+
}
569+
.instrument(span.clone())
570+
});
571+
let (tx_tx, tx_rx) = mpsc::channel(128);
572+
tokio::spawn(
573+
async {
574+
let res = egress(&mut tx, tx_rx).await;
575+
H::on_egress(tx, res).await;
576+
}
577+
.instrument(span.clone()),
578+
);
579+
Conn {
580+
tx: Outgoing {
581+
tx: PollSender::new(tx_tx),
582+
path: Arc::from([]),
583+
path_buf: Bytes::from_static(&[0]),
584+
},
585+
rx: Incoming {
586+
rx: Some(StreamReader::new(ReceiverStream::new(rx_rx))),
587+
path: Arc::from([]),
588+
index: Arc::clone(&index),
589+
io: Arc::new(rx_io),
590+
},
591+
}
592+
}
593+
}

0 commit comments

Comments
 (0)