Skip to content

Commit 30a4ae9

Browse files
committed
http2: support trailer headers
1 parent 9820975 commit 30a4ae9

File tree

4 files changed

+418
-18
lines changed

4 files changed

+418
-18
lines changed

httpcore/_async/http2.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def __init__(
7171
h2.events.ResponseReceived
7272
| h2.events.DataReceived
7373
| h2.events.StreamEnded
74-
| h2.events.StreamReset,
74+
| h2.events.StreamReset
75+
| h2.events.TrailersReceived,
7576
],
7677
] = {}
7778

79+
# Mapping from stream ID to trailing headers
80+
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}
81+
7882
# Connection terminated events are stored as state since
7983
# we need to handle them for all streams.
8084
self._connection_terminated: h2.events.ConnectionTerminated | None = None
@@ -152,16 +156,24 @@ async def handle_async_request(self, request: Request) -> Response:
152156
)
153157
trace.return_value = (status, headers)
154158

155-
return Response(
159+
extensions = {
160+
"http_version": b"HTTP/2",
161+
"network_stream": self._network_stream,
162+
"stream_id": stream_id,
163+
}
164+
165+
http2_stream = HTTP2ConnectionByteStream(self, request, stream_id=stream_id)
166+
167+
response = Response(
156168
status=status,
157169
headers=headers,
158-
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
159-
extensions={
160-
"http_version": b"HTTP/2",
161-
"network_stream": self._network_stream,
162-
"stream_id": stream_id,
163-
},
170+
content=http2_stream,
171+
extensions=extensions,
164172
)
173+
174+
http2_stream.set_response(response)
175+
176+
return response
165177
except BaseException as exc: # noqa: PIE786
166178
with AsyncShieldCancellation():
167179
kwargs = {"stream_id": stream_id}
@@ -321,12 +333,21 @@ async def _receive_response_body(
321333
self._h2_state.acknowledge_received_data(amount, stream_id)
322334
await self._write_outgoing_data(request)
323335
yield event.data
336+
elif isinstance(event, h2.events.TrailersReceived):
337+
# Process trailing headers but continue receiving events
338+
# The trailing headers are already stored in self._trailing_headers
339+
continue
324340
elif isinstance(event, h2.events.StreamEnded):
325341
break
326342

327343
async def _receive_stream_event(
328344
self, request: Request, stream_id: int
329-
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
345+
) -> (
346+
h2.events.ResponseReceived
347+
| h2.events.DataReceived
348+
| h2.events.StreamEnded
349+
| h2.events.TrailersReceived
350+
):
330351
"""
331352
Return the next available event for a given stream ID.
332353
@@ -377,10 +398,19 @@ async def _receive_events(
377398
h2.events.DataReceived,
378399
h2.events.StreamEnded,
379400
h2.events.StreamReset,
401+
h2.events.TrailersReceived,
380402
),
381403
):
382404
if event.stream_id in self._events:
383405
self._events[event.stream_id].append(event)
406+
if isinstance(event, h2.events.TrailersReceived):
407+
self._trailing_headers[event.stream_id] = []
408+
if event.headers is not None:
409+
for k, v in event.headers:
410+
if not k.startswith(b":"):
411+
self._trailing_headers[
412+
event.stream_id
413+
].append((k, v))
384414

385415
elif isinstance(event, h2.events.ConnectionTerminated):
386416
self._connection_terminated = event
@@ -409,6 +439,8 @@ async def _receive_remote_settings_change(
409439
async def _response_closed(self, stream_id: int) -> None:
410440
await self._max_streams_semaphore.release()
411441
del self._events[stream_id]
442+
if stream_id in self._trailing_headers:
443+
del self._trailing_headers[stream_id]
412444
async with self._state_lock:
413445
if self._connection_terminated and not self._events:
414446
await self.aclose()
@@ -567,6 +599,10 @@ def __init__(
567599
self._request = request
568600
self._stream_id = stream_id
569601
self._closed = False
602+
self._response: Response | None = None
603+
604+
def set_response(self, response: Response) -> None:
605+
self._response = response
570606

571607
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
572608
kwargs = {"request": self._request, "stream_id": self._stream_id}
@@ -576,6 +612,14 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
576612
request=self._request, stream_id=self._stream_id
577613
):
578614
yield chunk
615+
616+
if (
617+
self._response is not None
618+
and self._stream_id in self._connection._trailing_headers
619+
):
620+
self._response.extensions["trailing_headers"] = (
621+
self._connection._trailing_headers[self._stream_id]
622+
)
579623
except BaseException as exc:
580624
# If we get an exception while streaming the response,
581625
# we want to close the response (and possibly the connection)

httpcore/_sync/http2.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def __init__(
7171
h2.events.ResponseReceived
7272
| h2.events.DataReceived
7373
| h2.events.StreamEnded
74-
| h2.events.StreamReset,
74+
| h2.events.StreamReset
75+
| h2.events.TrailersReceived,
7576
],
7677
] = {}
7778

79+
# Mapping from stream ID to trailing headers
80+
self._trailing_headers: dict[int, list[tuple[bytes, bytes]]] = {}
81+
7882
# Connection terminated events are stored as state since
7983
# we need to handle them for all streams.
8084
self._connection_terminated: h2.events.ConnectionTerminated | None = None
@@ -152,16 +156,24 @@ def handle_request(self, request: Request) -> Response:
152156
)
153157
trace.return_value = (status, headers)
154158

155-
return Response(
159+
extensions = {
160+
"http_version": b"HTTP/2",
161+
"network_stream": self._network_stream,
162+
"stream_id": stream_id,
163+
}
164+
165+
http2_stream = HTTP2ConnectionByteStream(self, request, stream_id=stream_id)
166+
167+
response = Response(
156168
status=status,
157169
headers=headers,
158-
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
159-
extensions={
160-
"http_version": b"HTTP/2",
161-
"network_stream": self._network_stream,
162-
"stream_id": stream_id,
163-
},
170+
content=http2_stream,
171+
extensions=extensions,
164172
)
173+
174+
http2_stream.set_response(response)
175+
176+
return response
165177
except BaseException as exc: # noqa: PIE786
166178
with ShieldCancellation():
167179
kwargs = {"stream_id": stream_id}
@@ -321,12 +333,21 @@ def _receive_response_body(
321333
self._h2_state.acknowledge_received_data(amount, stream_id)
322334
self._write_outgoing_data(request)
323335
yield event.data
336+
elif isinstance(event, h2.events.TrailersReceived):
337+
# Process trailing headers but continue receiving events
338+
# The trailing headers are already stored in self._trailing_headers
339+
continue
324340
elif isinstance(event, h2.events.StreamEnded):
325341
break
326342

327343
def _receive_stream_event(
328344
self, request: Request, stream_id: int
329-
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
345+
) -> (
346+
h2.events.ResponseReceived
347+
| h2.events.DataReceived
348+
| h2.events.StreamEnded
349+
| h2.events.TrailersReceived
350+
):
330351
"""
331352
Return the next available event for a given stream ID.
332353
@@ -377,10 +398,19 @@ def _receive_events(
377398
h2.events.DataReceived,
378399
h2.events.StreamEnded,
379400
h2.events.StreamReset,
401+
h2.events.TrailersReceived,
380402
),
381403
):
382404
if event.stream_id in self._events:
383405
self._events[event.stream_id].append(event)
406+
if isinstance(event, h2.events.TrailersReceived):
407+
self._trailing_headers[event.stream_id] = []
408+
if event.headers is not None:
409+
for k, v in event.headers:
410+
if not k.startswith(b":"):
411+
self._trailing_headers[
412+
event.stream_id
413+
].append((k, v))
384414

385415
elif isinstance(event, h2.events.ConnectionTerminated):
386416
self._connection_terminated = event
@@ -409,6 +439,8 @@ def _receive_remote_settings_change(
409439
def _response_closed(self, stream_id: int) -> None:
410440
self._max_streams_semaphore.release()
411441
del self._events[stream_id]
442+
if stream_id in self._trailing_headers:
443+
del self._trailing_headers[stream_id]
412444
with self._state_lock:
413445
if self._connection_terminated and not self._events:
414446
self.close()
@@ -567,6 +599,10 @@ def __init__(
567599
self._request = request
568600
self._stream_id = stream_id
569601
self._closed = False
602+
self._response: Response | None = None
603+
604+
def set_response(self, response: Response) -> None:
605+
self._response = response
570606

571607
def __iter__(self) -> typing.Iterator[bytes]:
572608
kwargs = {"request": self._request, "stream_id": self._stream_id}
@@ -576,6 +612,14 @@ def __iter__(self) -> typing.Iterator[bytes]:
576612
request=self._request, stream_id=self._stream_id
577613
):
578614
yield chunk
615+
616+
if (
617+
self._response is not None
618+
and self._stream_id in self._connection._trailing_headers
619+
):
620+
self._response.extensions["trailing_headers"] = (
621+
self._connection._trailing_headers[self._stream_id]
622+
)
579623
except BaseException as exc:
580624
# If we get an exception while streaming the response,
581625
# we want to close the response (and possibly the connection)

0 commit comments

Comments
 (0)