Source file
src/net/http/transport_dial_test.go
1
2
3
4
5 package http_test
6
7 import (
8 "context"
9 "crypto/tls"
10 "errors"
11 "io"
12 "net"
13 "net/http"
14 "net/http/httptrace"
15 "strings"
16 "sync"
17 "testing"
18 "testing/synctest"
19 )
20
21
22 func TestTransportPoolConnReusePriorConnection(t *testing.T) {
23 synctest.Test(t, func(t *testing.T) {
24 dt := newTransportDialTester(t, http1Mode)
25
26
27 rt1 := dt.roundTrip()
28 c1 := dt.wantDial()
29 c1.finish(nil)
30 rt1.wantDone(c1, "HTTP/1.1")
31 rt1.finish()
32
33
34 rt2 := dt.roundTrip()
35 rt2.wantDone(c1, "HTTP/1.1")
36 rt2.finish()
37 })
38 }
39
40
41 func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
42 synctest.Test(t, func(t *testing.T) {
43 dt := newTransportDialTester(t, http1Mode)
44
45
46 rt1 := dt.roundTrip()
47 c1 := dt.wantDial()
48 c1.finish(nil)
49 rt1.wantDone(c1, "HTTP/1.1")
50
51
52
53 rt2 := dt.roundTrip()
54 c2 := dt.wantDial()
55 c2.finish(nil)
56 rt2.wantDone(c2, "HTTP/1.1")
57 })
58 }
59
60
61
62 func testTransportPoolConnHTTP2NoStrictMaxConcurrentRequests(t *testing.T) {
63 synctest.Test(t, func(t *testing.T) {
64 dt := newTransportDialTester(t, http2Mode, func(srv *http.Server) {
65 srv.HTTP2 = &http.HTTP2Config{
66 MaxConcurrentStreams: 2,
67 }
68 })
69
70
71 rt1 := dt.roundTrip()
72 c1 := dt.wantDial()
73 c1.finish(nil)
74 rt1.wantDone(c1, "HTTP/2.0")
75
76
77 rt2 := dt.roundTrip()
78 rt2.wantDone(c1, "HTTP/2.0")
79
80
81 rt3 := dt.roundTrip()
82 c2 := dt.wantDial()
83 c2.finish(nil)
84 rt3.wantDone(c2, "HTTP/2.0")
85
86 rt1.finish()
87 rt2.finish()
88 rt3.finish()
89
90
91 rt4 := dt.roundTrip()
92 rt4.wantDone(c1, "HTTP/2.0")
93 rt5 := dt.roundTrip()
94 rt5.wantDone(c1, "HTTP/2.0")
95 rt6 := dt.roundTrip()
96 rt6.wantDone(c2, "HTTP/2.0")
97 rt4.finish()
98 rt5.finish()
99 rt6.finish()
100 })
101 }
102
103
104
105
106 func TestTransportPoolConnHTTP2StrictMaxConcurrentRequests(t *testing.T) {
107 t.Skip("skipped until h2_bundle.go includes support for StrictMaxConcurrentRequests")
108
109 synctest.Test(t, func(t *testing.T) {
110 dt := newTransportDialTester(t, http2Mode, func(srv *http.Server) {
111 srv.HTTP2.MaxConcurrentStreams = 2
112 }, func(tr *http.Transport) {
113 tr.HTTP2 = &http.HTTP2Config{
114 StrictMaxConcurrentRequests: true,
115 }
116 })
117
118
119 rt1 := dt.roundTrip()
120 c1 := dt.wantDial()
121 c1.finish(nil)
122 rt1.wantDone(c1, "HTTP/2.0")
123
124
125 rt2 := dt.roundTrip()
126 rt2.wantDone(c1, "HTTP/2.0")
127
128
129 rt3 := dt.roundTrip()
130
131
132 rt1.finish()
133 rt3.wantDone(c1, "HTTP/2.0")
134
135 rt2.finish()
136 rt3.finish()
137 })
138 }
139
140
141 func TestTransportPoolConnHTTP2Startup(t *testing.T) {
142 synctest.Test(t, func(t *testing.T) {
143 dt := newTransportDialTester(t, http2Mode, func(srv *http.Server) {})
144
145
146
147 rt1 := dt.roundTrip()
148 rt2 := dt.roundTrip()
149 c1 := dt.wantDial()
150 c2 := dt.wantDial()
151
152
153 c1.finish(nil)
154 rt1.wantDone(c1, "HTTP/2.0")
155 rt2.wantDone(c1, "HTTP/2.0")
156
157 rt1.finish()
158 rt2.finish()
159 c2.finish(nil)
160 })
161 }
162
163
164
165 func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
166 synctest.Test(t, func(t *testing.T) {
167 dt := newTransportDialTester(t, http1Mode)
168
169
170 rt1 := dt.roundTrip()
171 c1 := dt.wantDial()
172 c1.finish(nil)
173 rt1.wantDone(c1, "HTTP/1.1")
174
175
176
177
178 rt2 := dt.roundTrip()
179 c2 := dt.wantDial()
180 rt1.finish()
181 rt2.wantDone(c1, "HTTP/1.1")
182
183
184
185
186
187 rt3 := dt.roundTrip()
188 c3 := dt.wantDial()
189 c2.finish(nil)
190 rt3.wantDone(c2, "HTTP/1.1")
191
192 c3.finish(nil)
193 })
194 }
195
196
197 func TestTransportPoolDisableKeepAlives(t *testing.T) {
198 synctest.Test(t, func(t *testing.T) {
199 dt := newTransportDialTester(t, http1Mode, func(tr *http.Transport) {
200 tr.DisableKeepAlives = true
201 })
202
203
204 for range 2 {
205 rt := dt.roundTrip()
206 c := dt.wantDial()
207 c.finish(nil)
208 rt.wantDone(c, "HTTP/1.1")
209 rt.finish()
210 }
211 })
212 }
213
214
215 func TestTransportPoolCancelRequestReusesConn(t *testing.T) {
216 synctest.Test(t, func(t *testing.T) {
217 dt := newTransportDialTester(t, http1Mode)
218
219
220 rt1 := dt.roundTrip()
221 c1 := dt.wantDial()
222 rt1.cancel()
223 rt1.wantError()
224
225
226 rt2 := dt.roundTrip()
227 c2 := dt.wantDial()
228 c1.finish(nil)
229 rt2.wantDone(c1, "HTTP/1.1")
230 rt2.finish()
231
232 c2.finish(nil)
233 })
234 }
235
236
237 func TestTransportPoolCancelRequestWithDisableKeepAlives(t *testing.T) {
238 synctest.Test(t, func(t *testing.T) {
239 dt := newTransportDialTester(t, http1Mode, func(tr *http.Transport) {
240 tr.DisableKeepAlives = true
241 })
242
243
244 rt1 := dt.roundTrip()
245 c1 := dt.wantDial()
246 rt1.cancel()
247 rt1.wantError()
248
249
250 c1.finish(nil)
251
252
253 rt2 := dt.roundTrip()
254 c2 := dt.wantDial()
255 c2.finish(nil)
256 rt2.wantDone(c2, "HTTP/1.1")
257 rt2.finish()
258 })
259 }
260
261
262 func TestTransportPoolConnectionBroken(t *testing.T) {
263 synctest.Test(t, func(t *testing.T) {
264 dt := newTransportDialTester(t, http1Mode)
265
266
267
268 rt1 := dt.roundTrip()
269 c1 := dt.wantDial()
270 c1.finish(nil)
271 rt1.wantDone(c1, "HTTP/1.1")
272 c1.fakeNetConn.Close()
273 rt1.finish()
274
275
276 rt2 := dt.roundTrip()
277 c2 := dt.wantDial()
278 c2.finish(nil)
279 rt2.wantDone(c2, "HTTP/1.1")
280 c2.fakeNetConn.Close()
281 rt2.finish()
282 })
283 }
284
285
286 func TestTransportPoolClosesConnsPastMaxIdleConnsPerHost(t *testing.T) {
287 synctest.Test(t, func(t *testing.T) {
288 dt := newTransportDialTester(t, http1Mode, func(tr *http.Transport) {
289 tr.MaxIdleConnsPerHost = 1
290 })
291
292
293 rt1 := dt.roundTrip("host1.fake.tld")
294 c1 := dt.wantDial()
295 c1.finish(nil)
296 rt1.wantDone(c1, "HTTP/1.1")
297
298
299 rt2 := dt.roundTrip("host1.fake.tld")
300 c2 := dt.wantDial()
301 c2.finish(nil)
302 rt2.wantDone(c2, "HTTP/1.1")
303
304
305 rt3 := dt.roundTrip("host2.fake.tld")
306 c3 := dt.wantDial()
307 c3.finish(nil)
308 rt3.wantDone(c3, "HTTP/1.1")
309
310
311 rt3.finish()
312 rt2.finish()
313 rt1.finish()
314 c1.wantClosed()
315
316
317 rt4 := dt.roundTrip("host1.fake.tld")
318 rt4.wantDone(c2, "HTTP/1.1")
319 rt4.finish()
320 rt5 := dt.roundTrip("host2.fake.tld")
321 rt5.wantDone(c3, "HTTP/1.1")
322 rt5.finish()
323 })
324 }
325
326
327
328 func TestTransportPoolMaxIdleConnsPerHostHTTP2(t *testing.T) {
329 synctest.Test(t, func(t *testing.T) {
330 t.Skip("skipped until h2_bundle.go includes support for MaxConcurrentStreams")
331
332 dt := newTransportDialTester(t, http2Mode, func(srv *http.Server) {
333 srv.HTTP2 = &http.HTTP2Config{
334 MaxConcurrentStreams: 1,
335 }
336 }, func(tr *http.Transport) {
337 tr.MaxIdleConnsPerHost = 1
338 })
339
340
341 rt1 := dt.roundTrip()
342 c1 := dt.wantDial()
343 c1.finish(nil)
344 rt1.wantDone(c1, "HTTP/2.0")
345
346
347 rt2 := dt.roundTrip()
348 c2 := dt.wantDial()
349 c2.finish(nil)
350 rt2.wantDone(c2, "HTTP/2.0")
351
352
353
354 rt1.finish()
355 rt2.finish()
356
357
358 rt3 := dt.roundTrip()
359 rt3.wantDone(c1, "HTTP/2.0")
360 rt4 := dt.roundTrip()
361 rt4.wantDone(c2, "HTTP/2.0")
362 })
363 }
364
365
366 type transportDialTester struct {
367 t *testing.T
368 cst *clientServerTest
369
370 dialsMu sync.Mutex
371 dials []*transportDialTesterConn
372
373 roundTripCount int
374 dialCount int
375 }
376
377
378 type transportDialTesterRoundTrip struct {
379 t *testing.T
380
381 roundTripID int
382 cancel context.CancelFunc
383 reqBody io.WriteCloser
384 respBodyClosed bool
385 returned bool
386
387 res *http.Response
388 err error
389 conn *transportDialTesterConn
390 }
391
392
393
394 type transportDialTesterConn struct {
395 t *testing.T
396
397 connID int
398 ready chan error
399 protos []string
400 closed chan struct{}
401
402 *fakeNetConn
403 }
404
405 func newTransportDialTester(t *testing.T, mode testMode, opts ...any) *transportDialTester {
406 t.Helper()
407 dt := &transportDialTester{
408 t: t,
409 }
410 dialContext := func(_ context.Context, network, address string) (*transportDialTesterConn, error) {
411 c := &transportDialTesterConn{
412 t: t,
413 ready: make(chan error),
414 closed: make(chan struct{}),
415 }
416
417
418 dt.dialsMu.Lock()
419 dt.dials = append(dt.dials, c)
420 dt.dialsMu.Unlock()
421
422 select {
423 case err := <-c.ready:
424 if err != nil {
425 return nil, err
426 }
427 case <-t.Context().Done():
428 t.Errorf("test finished with dial in progress")
429 return nil, errors.New("test finished")
430 }
431
432 c.fakeNetConn = dt.cst.li.connect()
433 t.Cleanup(func() {
434 c.fakeNetConn.Close()
435 })
436
437
438 return c, nil
439 }
440 dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
441
442 http.NewResponseController(w).EnableFullDuplex()
443 w.WriteHeader(200)
444 http.NewResponseController(w).Flush()
445
446
447 io.ReadAll(r.Body)
448 }), append([]any{optFakeNet, func(tr *http.Transport) {
449 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
450 return dialContext(ctx, network, dt.cst.ts.Listener.Addr().String())
451 }
452 tr.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
453 conn, err := dialContext(ctx, network, dt.cst.ts.Listener.Addr().String())
454 if err != nil {
455 return nil, err
456 }
457 config := &tls.Config{
458 InsecureSkipVerify: true,
459 NextProtos: []string{"h2", "http/1.1"},
460 }
461 if conn.protos != nil {
462 config.NextProtos = conn.protos
463 }
464 tc := tls.Client(conn, config)
465 if err := tc.Handshake(); err != nil {
466 return nil, err
467 }
468 return tc, nil
469 }
470 }}, opts...)...)
471 return dt
472 }
473
474
475
476 func (dt *transportDialTester) roundTrip(opts ...any) *transportDialTesterRoundTrip {
477 dt.t.Helper()
478 host := "fake.tld"
479 for _, o := range opts {
480 switch o := o.(type) {
481 case string:
482 host = o
483 default:
484 dt.t.Fatalf("unknown option type %T", o)
485 }
486 }
487 ctx, cancel := context.WithCancel(context.Background())
488 pr, pw := io.Pipe()
489 dt.roundTripCount++
490 rt := &transportDialTesterRoundTrip{
491 t: dt.t,
492 roundTripID: dt.roundTripCount,
493 reqBody: pw,
494 cancel: cancel,
495 }
496 dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
497 dt.t.Cleanup(func() {
498 rt.cancel()
499 rt.finish()
500 })
501 go func() {
502 ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
503 GotConn: func(info httptrace.GotConnInfo) {
504 c := info.Conn
505 if tlsConn, ok := c.(*tls.Conn); ok {
506 c = tlsConn.NetConn()
507 }
508 rt.conn = c.(*transportDialTesterConn)
509 },
510 })
511 proto, _, _ := strings.Cut(dt.cst.ts.URL, ":")
512 req, _ := http.NewRequestWithContext(ctx, "POST", proto+"://"+host, pr)
513 req.Header.Set("Content-Type", "text/plain")
514 rt.res, rt.err = dt.cst.tr.RoundTrip(req)
515 dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
516 rt.returned = true
517 }()
518 return rt
519 }
520
521
522 func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn, wantProto string) {
523 rt.t.Helper()
524 synctest.Wait()
525 if !rt.returned {
526 rt.t.Fatalf("RoundTrip %v: still running, want to have returned", rt.roundTripID)
527 }
528 if rt.err != nil {
529 rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
530 }
531 if rt.conn != c {
532 rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
533 }
534 if got, want := rt.conn, c; got != want {
535 rt.t.Fatalf("RoundTrip %v: sent on conn %v, want conn %v", rt.roundTripID, got.connID, want.connID)
536 }
537 if got, want := rt.res.Proto, wantProto; got != want {
538 rt.t.Fatalf("RoundTrip %v: got protocol %q, want %q", rt.roundTripID, got, want)
539 }
540 }
541
542
543 func (rt *transportDialTesterRoundTrip) wantError() {
544 rt.t.Helper()
545 synctest.Wait()
546 if !rt.returned {
547 rt.t.Fatalf("RoundTrip %v: still running, want to have returned", rt.roundTripID)
548 }
549 if rt.err == nil {
550 rt.t.Fatalf("RoundTrip %v: success, want error", rt.roundTripID)
551 }
552 }
553
554
555
556 func (rt *transportDialTesterRoundTrip) finish() {
557 rt.t.Helper()
558
559 synctest.Wait()
560 if !rt.returned {
561 rt.t.Fatalf("RoundTrip %v: still running, want to have returned", rt.roundTripID)
562 }
563 if rt.err != nil {
564 return
565 }
566
567 if rt.respBodyClosed {
568 return
569 }
570 rt.respBodyClosed = true
571 rt.reqBody.Close()
572 io.ReadAll(rt.res.Body)
573 rt.res.Body.Close()
574 rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
575 }
576
577
578 func (dt *transportDialTester) wantDial() *transportDialTesterConn {
579 dt.t.Helper()
580 synctest.Wait()
581 dt.dialsMu.Lock()
582 defer dt.dialsMu.Unlock()
583 if len(dt.dials) == 0 {
584 dt.t.Fatalf("no dial started, want one")
585 }
586 c := dt.dials[0]
587 dt.dials = dt.dials[1:]
588 dt.dialCount++
589 c.connID = dt.dialCount
590 dt.t.Logf("Dial %v: started", c.connID)
591 return c
592 }
593
594
595 func (c *transportDialTesterConn) finish(err error) {
596 c.t.Helper()
597 c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
598 c.ready <- err
599 close(c.ready)
600 }
601
602 func (c *transportDialTesterConn) wantClosed() {
603 c.t.Helper()
604 <-c.closed
605 }
606
607 func (c *transportDialTesterConn) Close() error {
608 select {
609 case <-c.closed:
610 default:
611 c.t.Logf("Conn %v: closed", c.connID)
612 close(c.closed)
613 }
614 return nil
615 }
616
View as plain text