Skip to content

Commit 1b11f62

Browse files
committed
add TlsAcceptor::with_acceptor method
1 parent 6e6df04 commit 1b11f62

File tree

2 files changed

+35
-18
lines changed

2 files changed

+35
-18
lines changed

src/acceptor.rs

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ pub use builder::AcceptorBuilder;
1717
use builder::WantsTlsConfig;
1818

1919
/// A TLS acceptor that can be used with hyper servers.
20-
pub struct TlsAcceptor {
20+
pub struct TlsAcceptor<A = AddrIncoming> {
2121
config: Arc<ServerConfig>,
22-
incoming: AddrIncoming,
22+
acceptor: A,
2323
}
2424

2525
/// An Acceptor for the `https` scheme.
@@ -31,20 +31,23 @@ impl TlsAcceptor {
3131

3232
/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
3333
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> Self {
34-
Self { config, incoming }
34+
Self { config, acceptor: incoming }
3535
}
3636
}
3737

38-
impl Accept for TlsAcceptor {
39-
type Conn = TlsStream;
38+
impl<A> Accept for TlsAcceptor<A>
39+
where A: Accept<Error = io::Error> + Unpin,
40+
A::Conn: AsyncRead + AsyncWrite + Unpin,
41+
{
42+
type Conn = TlsStream<A::Conn>;
4043
type Error = io::Error;
4144

4245
fn poll_accept(
4346
self: Pin<&mut Self>,
4447
cx: &mut Context<'_>,
4548
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
4649
let pin = self.get_mut();
47-
Poll::Ready(match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
50+
Poll::Ready(match ready!(Pin::new(&mut pin.acceptor).poll_accept(cx)) {
4851
Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))),
4952
Some(Err(e)) => Some(Err(e)),
5053
None => None,
@@ -66,22 +69,24 @@ where
6669
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
6770
// so we have to TlsAcceptor::accept and handshake to have access to it
6871
// TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first
69-
pub struct TlsStream {
70-
state: State,
72+
pub struct TlsStream<C = AddrStream> {
73+
state: State<C>,
7174
}
7275

73-
impl TlsStream {
74-
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> Self {
76+
77+
impl<C> TlsStream<C>
78+
where C: AsyncRead + AsyncWrite + Unpin
79+
{
80+
fn new(stream: C, config: Arc<ServerConfig>) -> Self {
7581
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
7682
Self {
7783
state: State::Handshaking(accept),
7884
}
7985
}
80-
8186
/// Returns a reference to the underlying IO stream.
8287
///
8388
/// This should always return `Some`, except if an error has already been yielded.
84-
pub fn io(&self) -> Option<&AddrStream> {
89+
pub fn io(&self) -> Option<&C> {
8590
match &self.state {
8691
State::Handshaking(accept) => accept.get_ref(),
8792
State::Streaming(stream) => Some(stream.get_ref().0),
@@ -99,7 +104,9 @@ impl TlsStream {
99104
}
100105
}
101106

102-
impl AsyncRead for TlsStream {
107+
impl<C> AsyncRead for TlsStream<C>
108+
where C: AsyncRead + AsyncWrite + Unpin
109+
{
103110
fn poll_read(
104111
self: Pin<&mut Self>,
105112
cx: &mut Context,
@@ -122,7 +129,9 @@ impl AsyncRead for TlsStream {
122129
}
123130
}
124131

125-
impl AsyncWrite for TlsStream {
132+
impl<C> AsyncWrite for TlsStream<C>
133+
where C: AsyncRead + AsyncWrite + Unpin
134+
{
126135
fn poll_write(
127136
self: Pin<&mut Self>,
128137
cx: &mut Context<'_>,
@@ -159,7 +168,7 @@ impl AsyncWrite for TlsStream {
159168
}
160169
}
161170

162-
enum State {
163-
Handshaking(tokio_rustls::Accept<AddrStream>),
164-
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
171+
enum State<C> {
172+
Handshaking(tokio_rustls::Accept<C>),
173+
Streaming(tokio_rustls::server::TlsStream<C>),
165174
}

src/acceptor/builder.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,17 @@ impl AcceptorBuilder<WantsIncoming> {
9090
/// Passes a [`AddrIncoming`] to configure the TLS connection and
9191
/// creates the [`TlsAcceptor`]
9292
pub fn with_incoming(self, incoming: impl Into<AddrIncoming>) -> TlsAcceptor {
93+
self.with_acceptor(incoming.into())
94+
}
95+
96+
/// Passes an acceptor implementing [Accept] to configure the TLS connection and
97+
/// creates the [`TlsAcceptor`]
98+
///
99+
/// [Accept]: hyper::server::accept::Accept
100+
pub fn with_acceptor<A>(self, acceptor: A) -> TlsAcceptor<A> {
93101
TlsAcceptor {
94102
config: Arc::new(self.0 .0),
95-
incoming: incoming.into(),
103+
acceptor,
96104
}
97105
}
98106
}

0 commit comments

Comments
 (0)