@@ -71,10 +71,14 @@ def __init__(
71
71
h2 .events .ResponseReceived
72
72
| h2 .events .DataReceived
73
73
| h2 .events .StreamEnded
74
- | h2 .events .StreamReset ,
74
+ | h2 .events .StreamReset
75
+ | h2 .events .TrailersReceived ,
75
76
],
76
77
] = {}
77
78
79
+ # Mapping from stream ID to trailing headers
80
+ self ._trailing_headers : dict [int , list [tuple [bytes , bytes ]]] = {}
81
+
78
82
# Connection terminated events are stored as state since
79
83
# we need to handle them for all streams.
80
84
self ._connection_terminated : h2 .events .ConnectionTerminated | None = None
@@ -152,16 +156,24 @@ async def handle_async_request(self, request: Request) -> Response:
152
156
)
153
157
trace .return_value = (status , headers )
154
158
155
- return Response (
159
+ extensions = {
160
+ "http_version" : b"HTTP/2" ,
161
+ "network_stream" : self ._network_stream ,
162
+ "stream_id" : stream_id ,
163
+ }
164
+
165
+ http2_stream = HTTP2ConnectionByteStream (self , request , stream_id = stream_id )
166
+
167
+ response = Response (
156
168
status = status ,
157
169
headers = headers ,
158
- content = HTTP2ConnectionByteStream (self , request , stream_id = stream_id ),
159
- extensions = {
160
- "http_version" : b"HTTP/2" ,
161
- "network_stream" : self ._network_stream ,
162
- "stream_id" : stream_id ,
163
- },
170
+ content = http2_stream ,
171
+ extensions = extensions ,
164
172
)
173
+
174
+ http2_stream .set_response (response )
175
+
176
+ return response
165
177
except BaseException as exc : # noqa: PIE786
166
178
with AsyncShieldCancellation ():
167
179
kwargs = {"stream_id" : stream_id }
@@ -321,12 +333,21 @@ async def _receive_response_body(
321
333
self ._h2_state .acknowledge_received_data (amount , stream_id )
322
334
await self ._write_outgoing_data (request )
323
335
yield event .data
336
+ elif isinstance (event , h2 .events .TrailersReceived ):
337
+ # Process trailing headers but continue receiving events
338
+ # The trailing headers are already stored in self._trailing_headers
339
+ continue
324
340
elif isinstance (event , h2 .events .StreamEnded ):
325
341
break
326
342
327
343
async def _receive_stream_event (
328
344
self , request : Request , stream_id : int
329
- ) -> h2 .events .ResponseReceived | h2 .events .DataReceived | h2 .events .StreamEnded :
345
+ ) -> (
346
+ h2 .events .ResponseReceived
347
+ | h2 .events .DataReceived
348
+ | h2 .events .StreamEnded
349
+ | h2 .events .TrailersReceived
350
+ ):
330
351
"""
331
352
Return the next available event for a given stream ID.
332
353
@@ -377,10 +398,19 @@ async def _receive_events(
377
398
h2 .events .DataReceived ,
378
399
h2 .events .StreamEnded ,
379
400
h2 .events .StreamReset ,
401
+ h2 .events .TrailersReceived ,
380
402
),
381
403
):
382
404
if event .stream_id in self ._events :
383
405
self ._events [event .stream_id ].append (event )
406
+ if isinstance (event , h2 .events .TrailersReceived ):
407
+ self ._trailing_headers [event .stream_id ] = []
408
+ if event .headers is not None :
409
+ for k , v in event .headers :
410
+ if not k .startswith (b":" ):
411
+ self ._trailing_headers [
412
+ event .stream_id
413
+ ].append ((k , v ))
384
414
385
415
elif isinstance (event , h2 .events .ConnectionTerminated ):
386
416
self ._connection_terminated = event
@@ -409,6 +439,8 @@ async def _receive_remote_settings_change(
409
439
async def _response_closed (self , stream_id : int ) -> None :
410
440
await self ._max_streams_semaphore .release ()
411
441
del self ._events [stream_id ]
442
+ if stream_id in self ._trailing_headers :
443
+ del self ._trailing_headers [stream_id ]
412
444
async with self ._state_lock :
413
445
if self ._connection_terminated and not self ._events :
414
446
await self .aclose ()
@@ -567,6 +599,10 @@ def __init__(
567
599
self ._request = request
568
600
self ._stream_id = stream_id
569
601
self ._closed = False
602
+ self ._response : Response | None = None
603
+
604
+ def set_response (self , response : Response ) -> None :
605
+ self ._response = response
570
606
571
607
async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
572
608
kwargs = {"request" : self ._request , "stream_id" : self ._stream_id }
@@ -576,6 +612,14 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
576
612
request = self ._request , stream_id = self ._stream_id
577
613
):
578
614
yield chunk
615
+
616
+ if (
617
+ self ._response is not None
618
+ and self ._stream_id in self ._connection ._trailing_headers
619
+ ):
620
+ self ._response .extensions ["trailing_headers" ] = (
621
+ self ._connection ._trailing_headers [self ._stream_id ]
622
+ )
579
623
except BaseException as exc :
580
624
# If we get an exception while streaming the response,
581
625
# we want to close the response (and possibly the connection)
0 commit comments