Skip to content

Commit 9e0498d

Browse files
committed
http2: use synthetic timers for ping timeouts in tests
Change-Id: I642890519b066937ade3c13e8387c31d29e912f4 Reviewed-on: https://go-review.googlesource.com/c/net/+/572377 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Jonathan Amsterdam <[email protected]>
1 parent 31d9683 commit 9e0498d

File tree

4 files changed

+240
-70
lines changed

4 files changed

+240
-70
lines changed

http2/clientconn_test.go

+9
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
123123
tc.fr.SetMaxReadFrameSize(10 << 20)
124124

125125
t.Cleanup(func() {
126+
tc.sync()
126127
if tc.rerr == nil {
127128
tc.rerr = io.EOF
128129
}
@@ -459,6 +460,14 @@ func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, he
459460
tc.sync()
460461
}
461462

463+
func (tc *testClientConn) writePing(ack bool, data [8]byte) {
464+
tc.t.Helper()
465+
if err := tc.fr.WritePing(ack, data); err != nil {
466+
tc.t.Fatal(err)
467+
}
468+
tc.sync()
469+
}
470+
462471
func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
463472
tc.t.Helper()
464473
if err := tc.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {

http2/testsync.go

+93-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package http2
55

66
import (
7+
"context"
78
"sync"
89
"time"
910
)
@@ -173,25 +174,64 @@ func (h *testSyncHooks) condWait(cond *sync.Cond) {
173174
h.unlock()
174175
}
175176

176-
// newTimer creates a new timer: A time.Timer if h is nil, or a synthetic timer in tests.
177+
// newTimer creates a new fake timer.
177178
func (h *testSyncHooks) newTimer(d time.Duration) timer {
178179
h.lock()
179180
defer h.unlock()
180181
t := &fakeTimer{
181-
when: h.now.Add(d),
182-
c: make(chan time.Time),
182+
hooks: h,
183+
when: h.now.Add(d),
184+
c: make(chan time.Time),
183185
}
184186
h.timers = append(h.timers, t)
185187
return t
186188
}
187189

190+
// afterFunc creates a new fake AfterFunc timer.
191+
func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer {
192+
h.lock()
193+
defer h.unlock()
194+
t := &fakeTimer{
195+
hooks: h,
196+
when: h.now.Add(d),
197+
f: f,
198+
}
199+
h.timers = append(h.timers, t)
200+
return t
201+
}
202+
203+
func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
204+
ctx, cancel := context.WithCancel(ctx)
205+
t := h.afterFunc(d, cancel)
206+
return ctx, func() {
207+
t.Stop()
208+
cancel()
209+
}
210+
}
211+
212+
func (h *testSyncHooks) timeUntilEvent() time.Duration {
213+
h.lock()
214+
defer h.unlock()
215+
var next time.Time
216+
for _, t := range h.timers {
217+
if next.IsZero() || t.when.Before(next) {
218+
next = t.when
219+
}
220+
}
221+
if d := next.Sub(h.now); d > 0 {
222+
return d
223+
}
224+
return 0
225+
}
226+
188227
// advance advances time and causes synthetic timers to fire.
189228
func (h *testSyncHooks) advance(d time.Duration) {
190229
h.lock()
191230
defer h.unlock()
192231
h.now = h.now.Add(d)
193232
timers := h.timers[:0]
194233
for _, t := range h.timers {
234+
t := t // remove after go.mod depends on go1.22
195235
t.mu.Lock()
196236
switch {
197237
case t.when.After(h.now):
@@ -200,7 +240,20 @@ func (h *testSyncHooks) advance(d time.Duration) {
200240
// stopped timer
201241
default:
202242
t.when = time.Time{}
203-
close(t.c)
243+
if t.c != nil {
244+
close(t.c)
245+
}
246+
if t.f != nil {
247+
h.total++
248+
go func() {
249+
defer func() {
250+
h.lock()
251+
h.total--
252+
h.unlock()
253+
}()
254+
t.f()
255+
}()
256+
}
204257
}
205258
t.mu.Unlock()
206259
}
@@ -212,13 +265,16 @@ func (h *testSyncHooks) advance(d time.Duration) {
212265
type timer interface {
213266
C() <-chan time.Time
214267
Stop() bool
268+
Reset(d time.Duration) bool
215269
}
216270

271+
// timeTimer implements timer using real time.
217272
type timeTimer struct {
218273
t *time.Timer
219274
c chan time.Time
220275
}
221276

277+
// newTimeTimer creates a new timer using real time.
222278
func newTimeTimer(d time.Duration) timer {
223279
ch := make(chan time.Time)
224280
t := time.AfterFunc(d, func() {
@@ -227,20 +283,49 @@ func newTimeTimer(d time.Duration) timer {
227283
return &timeTimer{t, ch}
228284
}
229285

230-
func (t timeTimer) C() <-chan time.Time { return t.c }
231-
func (t timeTimer) Stop() bool { return t.t.Stop() }
286+
// newTimeAfterFunc creates an AfterFunc timer using real time.
287+
func newTimeAfterFunc(d time.Duration, f func()) timer {
288+
return &timeTimer{
289+
t: time.AfterFunc(d, f),
290+
}
291+
}
232292

293+
func (t timeTimer) C() <-chan time.Time { return t.c }
294+
func (t timeTimer) Stop() bool { return t.t.Stop() }
295+
func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
296+
297+
// fakeTimer implements timer using fake time.
233298
type fakeTimer struct {
299+
hooks *testSyncHooks
300+
234301
mu sync.Mutex
235-
when time.Time
236-
c chan time.Time
302+
when time.Time // when the timer will fire
303+
c chan time.Time // closed when the timer fires; mutually exclusive with f
304+
f func() // called when the timer fires; mutually exclusive with c
237305
}
238306

239307
func (t *fakeTimer) C() <-chan time.Time { return t.c }
308+
240309
func (t *fakeTimer) Stop() bool {
241310
t.mu.Lock()
242311
defer t.mu.Unlock()
243312
stopped := t.when.IsZero()
244313
t.when = time.Time{}
245314
return stopped
246315
}
316+
317+
func (t *fakeTimer) Reset(d time.Duration) bool {
318+
if t.c != nil || t.f == nil {
319+
panic("fakeTimer only supports Reset on AfterFunc timers")
320+
}
321+
t.mu.Lock()
322+
defer t.mu.Unlock()
323+
t.hooks.lock()
324+
defer t.hooks.unlock()
325+
active := !t.when.IsZero()
326+
t.when = t.hooks.now.Add(d)
327+
if !active {
328+
t.hooks.timers = append(t.hooks.timers, t)
329+
}
330+
return active
331+
}

http2/transport.go

+56-13
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,21 @@ func (cc *ClientConn) newTimer(d time.Duration) timer {
391391
return newTimeTimer(d)
392392
}
393393

394+
// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
395+
func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer {
396+
if cc.syncHooks != nil {
397+
return cc.syncHooks.afterFunc(d, f)
398+
}
399+
return newTimeAfterFunc(d, f)
400+
}
401+
402+
func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
403+
if cc.syncHooks != nil {
404+
return cc.syncHooks.contextWithTimeout(ctx, d)
405+
}
406+
return context.WithTimeout(ctx, d)
407+
}
408+
394409
// clientStream is the state for a single HTTP/2 stream. One of these
395410
// is created for each Transport.RoundTrip call.
396411
type clientStream struct {
@@ -875,7 +890,7 @@ func (cc *ClientConn) healthCheck() {
875890
pingTimeout := cc.t.pingTimeout()
876891
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
877892
// trigger the healthCheck again if there is no frame received.
878-
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
893+
ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
879894
defer cancel()
880895
cc.vlogf("http2: Transport sending health check")
881896
err := cc.Ping(ctx)
@@ -1432,6 +1447,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
14321447
if cc.reqHeaderMu == nil {
14331448
panic("RoundTrip on uninitialized ClientConn") // for tests
14341449
}
1450+
var newStreamHook func(*clientStream)
1451+
if cc.syncHooks != nil {
1452+
newStreamHook = cc.syncHooks.newstream
1453+
cc.syncHooks.blockUntil(func() bool {
1454+
select {
1455+
case cc.reqHeaderMu <- struct{}{}:
1456+
<-cc.reqHeaderMu
1457+
case <-cs.reqCancel:
1458+
case <-ctx.Done():
1459+
default:
1460+
return false
1461+
}
1462+
return true
1463+
})
1464+
}
14351465
select {
14361466
case cc.reqHeaderMu <- struct{}{}:
14371467
case <-cs.reqCancel:
@@ -1456,8 +1486,8 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) {
14561486
}
14571487
cc.mu.Unlock()
14581488

1459-
if cc.syncHooks != nil {
1460-
cc.syncHooks.newstream(cs)
1489+
if newStreamHook != nil {
1490+
newStreamHook(cs)
14611491
}
14621492

14631493
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
@@ -2369,10 +2399,9 @@ func (rl *clientConnReadLoop) run() error {
23692399
cc := rl.cc
23702400
gotSettings := false
23712401
readIdleTimeout := cc.t.ReadIdleTimeout
2372-
var t *time.Timer
2402+
var t timer
23732403
if readIdleTimeout != 0 {
2374-
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
2375-
defer t.Stop()
2404+
t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
23762405
}
23772406
for {
23782407
f, err := cc.fr.ReadFrame()
@@ -3067,24 +3096,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error {
30673096
}
30683097
cc.mu.Unlock()
30693098
}
3070-
errc := make(chan error, 1)
3099+
var pingError error
3100+
errc := make(chan struct{})
30713101
cc.goRun(func() {
30723102
cc.wmu.Lock()
30733103
defer cc.wmu.Unlock()
3074-
if err := cc.fr.WritePing(false, p); err != nil {
3075-
errc <- err
3104+
if pingError = cc.fr.WritePing(false, p); pingError != nil {
3105+
close(errc)
30763106
return
30773107
}
3078-
if err := cc.bw.Flush(); err != nil {
3079-
errc <- err
3108+
if pingError = cc.bw.Flush(); pingError != nil {
3109+
close(errc)
30803110
return
30813111
}
30823112
})
3113+
if cc.syncHooks != nil {
3114+
cc.syncHooks.blockUntil(func() bool {
3115+
select {
3116+
case <-c:
3117+
case <-errc:
3118+
case <-ctx.Done():
3119+
case <-cc.readerDone:
3120+
default:
3121+
return false
3122+
}
3123+
return true
3124+
})
3125+
}
30833126
select {
30843127
case <-c:
30853128
return nil
3086-
case err := <-errc:
3087-
return err
3129+
case <-errc:
3130+
return pingError
30883131
case <-ctx.Done():
30893132
return ctx.Err()
30903133
case <-cc.readerDone:

0 commit comments

Comments
 (0)