Skip to content

Commit 163b3f5

Browse files
MarinPostmadjc
authored andcommitted
add AcceptorBuilder::with_acceptor method
1 parent 6e6df04 commit 163b3f5

File tree

2 files changed

+32
-18
lines changed

2 files changed

+32
-18
lines changed

src/acceptor.rs

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ pub use builder::AcceptorBuilder;
1717
use builder::WantsTlsConfig;
1818

1919
/// A TLS acceptor that can be used with hyper servers.
20-
pub struct TlsAcceptor {
20+
pub struct TlsAcceptor<A = AddrIncoming> {
2121
config: Arc<ServerConfig>,
22-
incoming: AddrIncoming,
22+
acceptor: A,
2323
}
2424

2525
/// An Acceptor for the `https` scheme.
@@ -31,20 +31,27 @@ impl TlsAcceptor {
3131

3232
/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
3333
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> Self {
34-
Self { config, incoming }
34+
Self {
35+
config,
36+
acceptor: incoming,
37+
}
3538
}
3639
}
3740

38-
impl Accept for TlsAcceptor {
39-
type Conn = TlsStream;
41+
impl<A> Accept for TlsAcceptor<A>
42+
where
43+
A: Accept<Error = io::Error> + Unpin,
44+
A::Conn: AsyncRead + AsyncWrite + Unpin,
45+
{
46+
type Conn = TlsStream<A::Conn>;
4047
type Error = io::Error;
4148

4249
fn poll_accept(
4350
self: Pin<&mut Self>,
4451
cx: &mut Context<'_>,
4552
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
4653
let pin = self.get_mut();
47-
Poll::Ready(match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
54+
Poll::Ready(match ready!(Pin::new(&mut pin.acceptor).poll_accept(cx)) {
4855
Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))),
4956
Some(Err(e)) => Some(Err(e)),
5057
None => None,
@@ -66,22 +73,21 @@ where
6673
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
6774
// so we have to TlsAcceptor::accept and handshake to have access to it
6875
// TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first
69-
pub struct TlsStream {
70-
state: State,
76+
pub struct TlsStream<C = AddrStream> {
77+
state: State<C>,
7178
}
7279

73-
impl TlsStream {
74-
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> Self {
80+
impl<C: AsyncRead + AsyncWrite + Unpin> TlsStream<C> {
81+
fn new(stream: C, config: Arc<ServerConfig>) -> Self {
7582
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
7683
Self {
7784
state: State::Handshaking(accept),
7885
}
7986
}
80-
8187
/// Returns a reference to the underlying IO stream.
8288
///
8389
/// This should always return `Some`, except if an error has already been yielded.
84-
pub fn io(&self) -> Option<&AddrStream> {
90+
pub fn io(&self) -> Option<&C> {
8591
match &self.state {
8692
State::Handshaking(accept) => accept.get_ref(),
8793
State::Streaming(stream) => Some(stream.get_ref().0),
@@ -99,7 +105,7 @@ impl TlsStream {
99105
}
100106
}
101107

102-
impl AsyncRead for TlsStream {
108+
impl<C: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<C> {
103109
fn poll_read(
104110
self: Pin<&mut Self>,
105111
cx: &mut Context,
@@ -122,7 +128,7 @@ impl AsyncRead for TlsStream {
122128
}
123129
}
124130

125-
impl AsyncWrite for TlsStream {
131+
impl<C: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<C> {
126132
fn poll_write(
127133
self: Pin<&mut Self>,
128134
cx: &mut Context<'_>,
@@ -159,7 +165,7 @@ impl AsyncWrite for TlsStream {
159165
}
160166
}
161167

162-
enum State {
163-
Handshaking(tokio_rustls::Accept<AddrStream>),
164-
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
168+
enum State<C> {
169+
Handshaking(tokio_rustls::Accept<C>),
170+
Streaming(tokio_rustls::server::TlsStream<C>),
165171
}

src/acceptor/builder.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,17 @@ impl AcceptorBuilder<WantsIncoming> {
9090
/// Passes a [`AddrIncoming`] to configure the TLS connection and
9191
/// creates the [`TlsAcceptor`]
9292
pub fn with_incoming(self, incoming: impl Into<AddrIncoming>) -> TlsAcceptor {
93+
self.with_acceptor(incoming.into())
94+
}
95+
96+
/// Passes an acceptor implementing [`Accept`] to configure the TLS connection and
97+
/// creates the [`TlsAcceptor`]
98+
///
99+
/// [`Accept`]: hyper::server::accept::Accept
100+
pub fn with_acceptor<A>(self, acceptor: A) -> TlsAcceptor<A> {
93101
TlsAcceptor {
94102
config: Arc::new(self.0 .0),
95-
incoming: incoming.into(),
103+
acceptor,
96104
}
97105
}
98106
}

0 commit comments

Comments
 (0)