Skip to content

Commit e804c9e

Browse files
authored
Add client disconnection handling in ASGI and RSGI (#524)
* Handle client disconnection in ASGI * Add `RSGIHTTPProtocol.client_disconnect`
1 parent 74020ed commit e804c9e

File tree

15 files changed

+231
-58
lines changed

15 files changed

+231
-58
lines changed

docs/spec/RSGI.md

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# RSGI Specification
22

3-
**Version:** 1.4
3+
**Version:** 1.5
44

55
## Abstract
66

@@ -165,10 +165,11 @@ And here are descriptions for the upper attributes:
165165

166166
#### HTTP protocol interface
167167

168-
HTTP protocol object implements two awaitable methods to receive the request body, and five different methods to send data, in particular:
168+
HTTP protocol object implements two awaitable methods to receive the request body, five different methods to send data, and one awaitable method to wait for client disconnection, in particular:
169169

170170
- `__call__` to receive the entire body in `bytes` format
171171
- `__aiter__` to receive the body in `bytes` chunks
172+
- `client_disconnect` to watch for client disconnection
172173
- `response_empty` to send back an empty response
173174
- `response_str` to send back a response with a `str` body
174175
- `response_bytes` to send back a response with `bytes` body
@@ -180,6 +181,7 @@ All the upper-mentioned response methods accepts an integer `status` parameter,
180181
```
181182
coroutine __call__() -> body
182183
asynciterator __aiter__() -> body chunks
184+
coroutine client_disconnect()
183185
function response_empty(status, headers)
184186
function response_str(status, headers, body)
185187
function response_bytes(status, headers, body)
@@ -197,6 +199,10 @@ coroutine send_bytes(bytes)
197199
coroutine send_str(str)
198200
```
199201

202+
The `client_disconnect` method will return a future that resolve ones the client has disconnected.
203+
204+
> **Note:** as HTTP supports keep-alived connections, the lifecycle of the client connection might not be the same of the single request. This is why the RSGI specification doesn't imply `client_disconnect` should resolve in case a client sends multiple requests within the same connection, and thus the protocol delegates to the application the responsibility to cancel the disconnection watcher once the response is sent.
205+
200206
### Websocket protocol
201207

202208
WebSockets share some HTTP details - they have a path and headers - but also have more state. Again, most of that state is in the scope, which will live as long as the socket does.

granian/_granian.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class RSGIHTTPStreamTransport:
2121
class RSGIHTTPProtocol:
2222
async def __call__(self) -> bytes: ...
2323
def __aiter__(self) -> Any: ...
24+
async def client_disconnect(self) -> None: ...
2425
def response_empty(self, status: int, headers: List[Tuple[str, str]]): ...
2526
def response_str(self, status: int, headers: List[Tuple[str, str]], body: str): ...
2627
def response_bytes(self, status: int, headers: List[Tuple[str, str]], body: bytes): ...

src/asgi/callbacks.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use pyo3::prelude::*;
22
use pyo3::types::PyDict;
3-
use std::{net::SocketAddr, sync::OnceLock};
4-
use tokio::sync::oneshot;
3+
use std::{
4+
net::SocketAddr,
5+
sync::{Arc, OnceLock},
6+
};
7+
use tokio::sync::{oneshot, Notify};
58

69
use super::{
710
io::{ASGIHTTPProtocol as HTTPProtocol, ASGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport},
@@ -145,14 +148,15 @@ impl CallbackWatcherWebsocket {
145148
pub(crate) fn call_http(
146149
cb: ArcCBScheduler,
147150
rt: RuntimeRef,
151+
disconnect_guard: Arc<Notify>,
148152
server_addr: SocketAddr,
149153
client_addr: SocketAddr,
150154
scheme: &str,
151155
req: hyper::http::request::Parts,
152156
body: hyper::body::Incoming,
153157
) -> oneshot::Receiver<HTTPResponse> {
154158
let (tx, rx) = oneshot::channel();
155-
let protocol = HTTPProtocol::new(rt.clone(), body, tx);
159+
let protocol = HTTPProtocol::new(rt.clone(), body, tx, disconnect_guard);
156160
let scheme: Box<str> = scheme.into();
157161

158162
rt.spawn_blocking(move |py| {

src/asgi/http.rs

+19-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use http_body_util::BodyExt;
22
use hyper::{header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, StatusCode};
3-
use std::net::SocketAddr;
4-
use tokio::sync::mpsc;
3+
use std::{net::SocketAddr, sync::Arc};
4+
use tokio::sync::{mpsc, Notify};
55

66
use super::callbacks::{call_http, call_ws};
77
use crate::{
@@ -16,8 +16,19 @@ const SCHEME_WS: &str = "ws";
1616
const SCHEME_WSS: &str = "wss";
1717

1818
macro_rules! handle_http_response {
19-
($handler:expr, $rt:expr, $callback:expr, $server_addr:expr, $client_addr:expr, $scheme:expr, $req:expr, $body:expr) => {
20-
match $handler($callback, $rt, $server_addr, $client_addr, $req, $scheme, $body).await {
19+
($handler:expr, $rt:expr, $disconnect_guard:expr, $callback:expr, $server_addr:expr, $client_addr:expr, $scheme:expr, $req:expr, $body:expr) => {
20+
match $handler(
21+
$callback,
22+
$rt,
23+
$disconnect_guard,
24+
$server_addr,
25+
$client_addr,
26+
$req,
27+
$scheme,
28+
$body,
29+
)
30+
.await
31+
{
2132
Ok(res) => res,
2233
_ => {
2334
log::error!("ASGI protocol failure");
@@ -32,6 +43,7 @@ macro_rules! handle_request {
3243
#[inline]
3344
pub(crate) async fn $func_name(
3445
rt: RuntimeRef,
46+
disconnect_guard: Arc<Notify>,
3547
callback: ArcCBScheduler,
3648
server_addr: SocketAddr,
3749
client_addr: SocketAddr,
@@ -42,6 +54,7 @@ macro_rules! handle_request {
4254
handle_http_response!(
4355
$handler,
4456
rt,
57+
disconnect_guard,
4558
callback,
4659
server_addr,
4760
client_addr,
@@ -58,6 +71,7 @@ macro_rules! handle_request_with_ws {
5871
#[inline]
5972
pub(crate) async fn $func_name(
6073
rt: RuntimeRef,
74+
disconnect_guard: Arc<Notify>,
6175
callback: ArcCBScheduler,
6276
server_addr: SocketAddr,
6377
client_addr: SocketAddr,
@@ -142,6 +156,7 @@ macro_rules! handle_request_with_ws {
142156
handle_http_response!(
143157
$handler_req,
144158
rt,
159+
disconnect_guard,
145160
callback,
146161
server_addr,
147162
client_addr,

src/asgi/io.rs

+43-11
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use std::{
1313
};
1414
use tokio::{
1515
fs::File,
16-
sync::{mpsc, oneshot, Mutex as AsyncMutex},
16+
sync::{mpsc, oneshot, Mutex as AsyncMutex, Notify},
1717
};
1818
use tokio_tungstenite::tungstenite::Message;
1919
use tokio_util::io::ReaderStream;
@@ -25,7 +25,9 @@ use super::{
2525
use crate::{
2626
conversion::FutureResultToPy,
2727
http::{response_404, HTTPResponse, HTTPResponseBody, HV_SERVER},
28-
runtime::{empty_future_into_py, err_future_into_py, future_into_py_futlike, Runtime, RuntimeRef},
28+
runtime::{
29+
done_future_into_py, empty_future_into_py, err_future_into_py, future_into_py_futlike, Runtime, RuntimeRef,
30+
},
2931
ws::{HyperWebsocket, UpgradeData, WSRxStream, WSTxStream},
3032
};
3133

@@ -37,27 +39,36 @@ static WS_SUBPROTO_HNAME: &str = "Sec-WebSocket-Protocol";
3739
pub(crate) struct ASGIHTTPProtocol {
3840
rt: RuntimeRef,
3941
tx: Mutex<Option<oneshot::Sender<HTTPResponse>>>,
42+
disconnect_guard: Arc<Notify>,
4043
request_body: Arc<AsyncMutex<http_body_util::BodyStream<body::Incoming>>>,
4144
response_started: atomic::AtomicBool,
4245
response_chunked: atomic::AtomicBool,
4346
response_intent: Mutex<Option<(u16, HeaderMap)>>,
4447
body_tx: Mutex<Option<mpsc::Sender<Result<body::Bytes, anyhow::Error>>>>,
4548
flow_rx_exhausted: Arc<atomic::AtomicBool>,
49+
flow_rx_closed: Arc<atomic::AtomicBool>,
4650
flow_tx_waiter: Arc<tokio::sync::Notify>,
4751
sent_response_code: Arc<atomic::AtomicU16>,
4852
}
4953

5054
impl ASGIHTTPProtocol {
51-
pub fn new(rt: RuntimeRef, body: hyper::body::Incoming, tx: oneshot::Sender<HTTPResponse>) -> Self {
55+
pub fn new(
56+
rt: RuntimeRef,
57+
body: hyper::body::Incoming,
58+
tx: oneshot::Sender<HTTPResponse>,
59+
disconnect_guard: Arc<Notify>,
60+
) -> Self {
5261
Self {
5362
rt,
5463
tx: Mutex::new(Some(tx)),
64+
disconnect_guard,
5565
request_body: Arc::new(AsyncMutex::new(http_body_util::BodyStream::new(body))),
5666
response_started: false.into(),
5767
response_chunked: false.into(),
5868
response_intent: Mutex::new(None),
5969
body_tx: Mutex::new(None),
6070
flow_rx_exhausted: Arc::new(atomic::AtomicBool::new(false)),
71+
flow_rx_closed: Arc::new(atomic::AtomicBool::new(false)),
6172
flow_tx_waiter: Arc::new(tokio::sync::Notify::new()),
6273
sent_response_code: Arc::new(atomic::AtomicU16::new(500)),
6374
}
@@ -108,35 +119,56 @@ impl ASGIHTTPProtocol {
108119
#[pymethods]
109120
impl ASGIHTTPProtocol {
110121
fn receive<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
122+
if self.flow_rx_closed.load(atomic::Ordering::Acquire) {
123+
return done_future_into_py(
124+
py,
125+
super::conversion::message_into_py(py, ASGIMessageType::HTTPDisconnect).map(Bound::unbind),
126+
);
127+
}
128+
111129
if self.flow_rx_exhausted.load(atomic::Ordering::Acquire) {
112130
let flow_hld = self.flow_tx_waiter.clone();
131+
let flow_dgr = self.disconnect_guard.clone();
132+
let flow_dsr = self.flow_rx_closed.clone();
113133
return future_into_py_futlike(self.rt.clone(), py, async move {
114-
let () = flow_hld.notified().await;
134+
tokio::select! {
135+
() = flow_hld.notified() => {},
136+
() = flow_dgr.notified() => flow_dsr.store(true, atomic::Ordering::Release),
137+
}
115138
FutureResultToPy::ASGIMessage(ASGIMessageType::HTTPDisconnect)
116139
});
117140
}
118141

119142
let body_ref = self.request_body.clone();
120143
let flow_ref = self.flow_rx_exhausted.clone();
121144
let flow_hld = self.flow_tx_waiter.clone();
145+
let flow_dgr = self.disconnect_guard.clone();
146+
let flow_dsr = self.flow_rx_closed.clone();
122147
future_into_py_futlike(self.rt.clone(), py, async move {
123148
let mut bodym = body_ref.lock().await;
124149
let body = &mut *bodym;
125150
let mut more_body = false;
126-
let chunk = match body.next().await {
127-
Some(Ok(buf)) => {
128-
more_body = true;
129-
Ok(buf.into_data().unwrap_or_default())
151+
152+
let chunk = tokio::select! {
153+
frame = body.next() => match frame {
154+
Some(Ok(buf)) => {
155+
more_body = true;
156+
Some(buf.into_data().unwrap_or_default())
157+
}
158+
Some(Err(_)) => None,
159+
_ => Some(body::Bytes::new()),
160+
},
161+
() = flow_dgr.notified() => {
162+
flow_dsr.store(true, atomic::Ordering::Release);
163+
None
130164
}
131-
Some(Err(err)) => Err(err),
132-
_ => Ok(body::Bytes::new()),
133165
};
134166
if !more_body {
135167
flow_ref.store(true, atomic::Ordering::Release);
136168
}
137169

138170
match chunk {
139-
Ok(data) => FutureResultToPy::ASGIMessage(ASGIMessageType::HTTPRequestBody((data, more_body))),
171+
Some(data) => FutureResultToPy::ASGIMessage(ASGIMessageType::HTTPRequestBody((data, more_body))),
140172
_ => {
141173
flow_hld.notify_one();
142174
FutureResultToPy::ASGIMessage(ASGIMessageType::HTTPDisconnect)

src/callbacks.rs

+29
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,35 @@ impl PyEmptyAwaitable {
333333
}
334334
}
335335

336+
#[pyclass(frozen, module = "granian._granian")]
337+
pub(crate) struct PyDoneAwaitable {
338+
result: PyResult<PyObject>,
339+
}
340+
341+
impl PyDoneAwaitable {
342+
pub(crate) fn new(result: PyResult<PyObject>) -> Self {
343+
Self { result }
344+
}
345+
}
346+
347+
#[pymethods]
348+
impl PyDoneAwaitable {
349+
fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
350+
pyself
351+
}
352+
353+
fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
354+
pyself
355+
}
356+
357+
fn __next__(&self, py: Python) -> PyResult<PyObject> {
358+
self.result
359+
.as_ref()
360+
.map(|v| v.clone_ref(py))
361+
.map_err(|v| v.clone_ref(py))
362+
}
363+
}
364+
336365
#[pyclass(frozen, module = "granian._granian")]
337366
pub(crate) struct PyErrAwaitable {
338367
result: PyResult<()>,

src/rsgi/callbacks.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use pyo3::prelude::*;
2-
use std::sync::OnceLock;
3-
use tokio::sync::oneshot;
2+
use std::sync::{Arc, OnceLock};
3+
use tokio::sync::{oneshot, Notify};
44

55
use super::{
66
io::{RSGIHTTPProtocol as HTTPProtocol, RSGIWebsocketProtocol as WebsocketProtocol, WebsocketDetachedTransport},
@@ -118,11 +118,12 @@ impl CallbackWatcherWebsocket {
118118
pub(crate) fn call_http(
119119
cb: ArcCBScheduler,
120120
rt: RuntimeRef,
121+
disconnect_guard: Arc<Notify>,
121122
body: hyper::body::Incoming,
122123
scope: HTTPScope,
123124
) -> oneshot::Receiver<PyResponse> {
124125
let (tx, rx) = oneshot::channel();
125-
let protocol = HTTPProtocol::new(rt.clone(), tx, body);
126+
let protocol = HTTPProtocol::new(rt.clone(), tx, body, disconnect_guard);
126127

127128
rt.spawn_blocking(move |py| {
128129
if let Ok(watcher) = CallbackWatcherHTTP::new(py, protocol, scope) {

src/rsgi/http.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use futures::sink::SinkExt;
22
use http_body_util::BodyExt;
33
use hyper::{header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, StatusCode};
4-
use std::net::SocketAddr;
5-
use tokio::sync::mpsc;
4+
use std::{net::SocketAddr, sync::Arc};
5+
use tokio::sync::{mpsc, Notify};
66

77
use super::{
88
callbacks::{call_http, call_ws},
@@ -30,8 +30,8 @@ macro_rules! build_scope {
3030
}
3131

3232
macro_rules! handle_http_response {
33-
($handler:expr, $rt:expr, $callback:expr, $body:expr, $scope:expr) => {
34-
match $handler($callback, $rt, $body, $scope).await {
33+
($handler:expr, $rt:expr, $disconnect_guard:expr, $callback:expr, $body:expr, $scope:expr) => {
34+
match $handler($callback, $rt, $disconnect_guard, $body, $scope).await {
3535
Ok(PyResponse::Body(pyres)) => pyres.to_response(),
3636
Ok(PyResponse::File(pyres)) => pyres.to_response().await,
3737
_ => {
@@ -47,6 +47,7 @@ macro_rules! handle_request {
4747
#[inline]
4848
pub(crate) async fn $func_name(
4949
rt: RuntimeRef,
50+
disconnect_guard: Arc<Notify>,
5051
callback: ArcCBScheduler,
5152
server_addr: SocketAddr,
5253
client_addr: SocketAddr,
@@ -55,7 +56,7 @@ macro_rules! handle_request {
5556
) -> HTTPResponse {
5657
let (parts, body) = req.into_parts();
5758
let scope = build_scope!(HTTPScope, server_addr, client_addr, parts, scheme);
58-
handle_http_response!($handler, rt, callback, body, scope)
59+
handle_http_response!($handler, rt, disconnect_guard, callback, body, scope)
5960
}
6061
};
6162
}
@@ -65,6 +66,7 @@ macro_rules! handle_request_with_ws {
6566
#[inline]
6667
pub(crate) async fn $func_name(
6768
rt: RuntimeRef,
69+
disconnect_guard: Arc<Notify>,
6870
callback: ArcCBScheduler,
6971
server_addr: SocketAddr,
7072
client_addr: SocketAddr,
@@ -131,7 +133,7 @@ macro_rules! handle_request_with_ws {
131133

132134
let (parts, body) = req.into_parts();
133135
let scope = build_scope!(HTTPScope, server_addr, client_addr, parts, scheme);
134-
handle_http_response!($handler_req, rt, callback, body, scope)
136+
handle_http_response!($handler_req, rt, disconnect_guard, callback, body, scope)
135137
}
136138
};
137139
}

0 commit comments

Comments
 (0)