@@ -366,6 +366,75 @@ async fn stream_write_zero() -> io::Result<()> {
366
366
Ok ( ( ) ) as io:: Result < ( ) >
367
367
}
368
368
369
+ #[ tokio:: test]
370
+ async fn stream_shutdown ( ) -> io:: Result < ( ) > {
371
+ struct TestIo < ' a > {
372
+ conn : & ' a mut Connection ,
373
+ is_shutdown : bool ,
374
+ }
375
+
376
+ impl AsyncRead for TestIo < ' _ > {
377
+ fn poll_read (
378
+ self : Pin < & mut Self > ,
379
+ _cx : & mut Context < ' _ > ,
380
+ _: & mut ReadBuf < ' _ > ,
381
+ ) -> Poll < io:: Result < ( ) > > {
382
+ Poll :: Pending
383
+ }
384
+ }
385
+
386
+ impl AsyncWrite for TestIo < ' _ > {
387
+ fn poll_write (
388
+ mut self : Pin < & mut Self > ,
389
+ _cx : & mut Context < ' _ > ,
390
+ mut buf : & [ u8 ] ,
391
+ ) -> Poll < io:: Result < usize > > {
392
+ let len = self . conn . read_tls ( buf. by_ref ( ) ) ?;
393
+ self . conn
394
+ . process_new_packets ( )
395
+ . map_err ( |err| io:: Error :: new ( io:: ErrorKind :: InvalidData , err) ) ?;
396
+ Poll :: Ready ( Ok ( len) )
397
+ }
398
+
399
+ fn poll_flush ( mut self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
400
+ self . conn
401
+ . process_new_packets ( )
402
+ . map_err ( |err| io:: Error :: new ( io:: ErrorKind :: InvalidData , err) ) ?;
403
+ Poll :: Ready ( Ok ( ( ) ) )
404
+ }
405
+
406
+ fn poll_shutdown ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
407
+ self . get_mut ( ) . is_shutdown = true ;
408
+ Poll :: Ready ( Ok ( ( ) ) )
409
+ }
410
+ }
411
+
412
+ let ( server, mut client) = make_pair ( ) ;
413
+ let mut server = Connection :: from ( server) ;
414
+ poll_fn ( |cx| do_handshake ( & mut client, & mut server, cx) ) . await ?;
415
+
416
+ {
417
+ let mut io = TestIo {
418
+ conn : & mut server,
419
+ is_shutdown : false ,
420
+ } ;
421
+ let mut stream = Stream :: new ( & mut io, & mut client) ;
422
+
423
+ stream. session . send_close_notify ( ) ;
424
+ stream. shutdown ( ) . await ?;
425
+
426
+ assert ! ( !io. is_shutdown) ;
427
+ }
428
+
429
+ assert ! ( !server. is_handshaking( ) ) ;
430
+ assert ! ( server
431
+ . process_new_packets( )
432
+ . map_err( io:: Error :: other) ?
433
+ . peer_has_closed( ) ) ;
434
+
435
+ Ok ( ( ) ) as io:: Result < ( ) >
436
+ }
437
+
369
438
fn make_pair ( ) -> ( ServerConnection , ClientConnection ) {
370
439
let ( sconfig, cconfig) = utils:: make_configs ( ) ;
371
440
let server = ServerConnection :: new ( Arc :: new ( sconfig) ) . unwrap ( ) ;
0 commit comments