1
2
3
4
5
6
7 package httptest
8
9 import (
10 "context"
11 "crypto/tls"
12 "crypto/x509"
13 "flag"
14 "fmt"
15 "log"
16 "net"
17 "net/http"
18 "net/http/internal/testcert"
19 "os"
20 "strings"
21 "sync"
22 "time"
23 )
24
25
26
// A Server bundles an http.Server with the listener it serves on and the
// bookkeeping needed to start it, hand out a matching client, and shut it
// down again.
type Server struct {
	// URL is the base URL of the server: "http://" (Start) or "https://"
	// (StartTLS) followed by the listener's address. It is empty until the
	// server has been started; both start methods panic if it is already set.
	URL string
	// Listener is the listener the server accepts connections on. StartTLS
	// replaces it with a TLS listener wrapping the original one.
	Listener net.Listener

	// EnableHTTP2 is consulted by StartTLS: when true and TLS.NextProtos is
	// unset, "h2" is advertised instead of "http/1.1", and the client
	// transport is created with ForceAttemptHTTP2 enabled.
	EnableHTTP2 bool

	// TLS is the optional TLS configuration. StartTLS replaces it with a
	// clone of the supplied value (or a fresh config when nil) and fills in
	// NextProtos and Certificates when they are not set.
	TLS *tls.Config

	// Config is the underlying http.Server. Starting the server installs a
	// ConnState hook on it; any hook already present is still invoked.
	Config *http.Server

	// certificate is the parsed form of the first certificate in
	// TLS.Certificates; it is set by StartTLS.
	certificate *x509.Certificate

	// wg counts the serving goroutine plus every tracked connection.
	// Close waits on it before returning.
	wg sync.WaitGroup

	mu     sync.Mutex                  // guards closed and conns
	closed bool                        // set once by Close
	conns  map[net.Conn]http.ConnState // tracked connections; entries are removed on hijack/close

	// client is the HTTP client configured for this server. It is created
	// by Start or StartTLS if not already set and is returned by Client.
	client *http.Client
}
60
61 func newLocalListener() net.Listener {
62 if serveFlag != "" {
63 l, err := net.Listen("tcp", serveFlag)
64 if err != nil {
65 panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
66 }
67 return l
68 }
69 l, err := net.Listen("tcp", "127.0.0.1:0")
70 if err != nil {
71 if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
72 panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
73 }
74 }
75 return l
76 }
77
78
79
80
81
82
83
84
85
86
// serveFlag holds the value of the "httptest.serve" command-line flag, which
// init registers only when the flag appears in os.Args. When non-empty,
// newLocalListener listens on this address and Start blocks forever after
// announcing the URL on stderr.
var serveFlag string
88
89 func init() {
90 if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
91 flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
92 }
93 }
94
// strSliceContainsPrefix reports whether any element of v begins with pre.
// It returns false for an empty or nil slice.
func strSliceContainsPrefix(v []string, pre string) bool {
	found := false
	for i := 0; i < len(v) && !found; i++ {
		found = strings.HasPrefix(v[i], pre)
	}
	return found
}
103
104
105
106 func NewServer(handler http.Handler) *Server {
107 ts := NewUnstartedServer(handler)
108 ts.Start()
109 return ts
110 }
111
112
113
114
115
116
117
118 func NewUnstartedServer(handler http.Handler) *Server {
119 return &Server{
120 Listener: newLocalListener(),
121 Config: &http.Server{Handler: handler},
122 }
123 }
124
125
126 func (s *Server) Start() {
127 if s.URL != "" {
128 panic("Server already started")
129 }
130
131 if s.client == nil {
132 tr := &http.Transport{}
133 dialer := net.Dialer{}
134
135
136
137 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
138 if tr.Dial != nil {
139 return tr.Dial(network, addr)
140 }
141 if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") {
142 addr = s.Listener.Addr().String()
143 }
144 return dialer.DialContext(ctx, network, addr)
145 }
146 s.client = &http.Client{Transport: tr}
147
148 }
149 s.URL = "http://" + s.Listener.Addr().String()
150 s.wrap()
151 s.goServe()
152 if serveFlag != "" {
153 fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
154 select {}
155 }
156 }
157
158
159 func (s *Server) StartTLS() {
160 if s.URL != "" {
161 panic("Server already started")
162 }
163 if s.client == nil {
164 s.client = &http.Client{}
165 }
166 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
167 if err != nil {
168 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
169 }
170
171 existingConfig := s.TLS
172 if existingConfig != nil {
173 s.TLS = existingConfig.Clone()
174 } else {
175 s.TLS = new(tls.Config)
176 }
177 if s.TLS.NextProtos == nil {
178 nextProtos := []string{"http/1.1"}
179 if s.EnableHTTP2 {
180 nextProtos = []string{"h2"}
181 }
182 s.TLS.NextProtos = nextProtos
183 }
184 if len(s.TLS.Certificates) == 0 {
185 s.TLS.Certificates = []tls.Certificate{cert}
186 }
187 s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
188 if err != nil {
189 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
190 }
191 certpool := x509.NewCertPool()
192 certpool.AddCert(s.certificate)
193 tr := &http.Transport{
194 TLSClientConfig: &tls.Config{
195 RootCAs: certpool,
196 },
197 ForceAttemptHTTP2: s.EnableHTTP2,
198 }
199 dialer := net.Dialer{}
200 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
201 if tr.Dial != nil {
202 return tr.Dial(network, addr)
203 }
204 if addr == "example.com:443" || strings.HasSuffix(addr, ".example.com:443") {
205 addr = s.Listener.Addr().String()
206 }
207 return dialer.DialContext(ctx, network, addr)
208 }
209 s.client.Transport = tr
210 s.Listener = tls.NewListener(s.Listener, s.TLS)
211 s.URL = "https://" + s.Listener.Addr().String()
212 s.wrap()
213 s.goServe()
214 }
215
216
217
218 func NewTLSServer(handler http.Handler) *Server {
219 ts := NewUnstartedServer(handler)
220 ts.StartTLS()
221 return ts
222 }
223
// closeIdleTransport is implemented by round trippers that can drop their
// idle connections. Close uses it to type-assert http.DefaultTransport and
// the server's own client transport.
type closeIdleTransport interface {
	CloseIdleConnections()
}
227
228
229
// Close shuts the server down and waits until the serving goroutine and all
// tracked connections have finished. Calling it more than once is allowed;
// only the first call performs the shutdown steps, but every call waits.
func (s *Server) Close() {
	s.mu.Lock()
	if !s.closed {
		s.closed = true
		s.Listener.Close()
		s.Config.SetKeepAlivesEnabled(false)
		for c, st := range s.conns {
			// Only connections that are not in the middle of a request are
			// force-closed here: idle ones (between requests) and new ones
			// (connected, no request seen yet). Active connections are left
			// alone so in-flight requests can finish; wg.Wait below waits
			// for them.
			if st == http.StateIdle || st == http.StateNew {
				s.closeConn(c)
			}
		}
		// If shutdown has not completed after 5 seconds, log which
		// connections are still open. The deferred Stop cancels the timer
		// once Close returns.
		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
		defer t.Stop()
	}
	s.mu.Unlock()

	// Drop idle connections held by the default transport as well.
	// NOTE(review): presumably a convenience for tests that talk to the
	// server through http.DefaultTransport rather than s.Client — confirm.
	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
		t.CloseIdleConnections()
	}

	// Also drop idle connections held by the server's own client, if any.
	if s.client != nil {
		if t, ok := s.client.Transport.(closeIdleTransport); ok {
			t.CloseIdleConnections()
		}
	}

	s.wg.Wait()
}
281
282 func (s *Server) logCloseHangDebugInfo() {
283 s.mu.Lock()
284 defer s.mu.Unlock()
285 var buf strings.Builder
286 buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
287 for c, st := range s.conns {
288 fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
289 }
290 log.Print(buf.String())
291 }
292
293
294 func (s *Server) CloseClientConnections() {
295 s.mu.Lock()
296 nconn := len(s.conns)
297 ch := make(chan struct{}, nconn)
298 for c := range s.conns {
299 go s.closeConnChan(c, ch)
300 }
301 s.mu.Unlock()
302
303
304
305
306
307
308
309 timer := time.NewTimer(5 * time.Second)
310 defer timer.Stop()
311 for i := 0; i < nconn; i++ {
312 select {
313 case <-ch:
314 case <-timer.C:
315
316 return
317 }
318 }
319 }
320
321
322
// Certificate returns the certificate parsed by StartTLS from the first
// entry of the server's TLS certificates. It returns nil if StartTLS has
// not set one.
func (s *Server) Certificate() *x509.Certificate {
	return s.certificate
}
326
327
328
329
330
331
332
// Client returns the HTTP client configured for this server by Start or
// StartTLS. After StartTLS its transport trusts the server's certificate.
// Close closes idle connections held by this client's transport.
func (s *Server) Client() *http.Client {
	return s.client
}
336
337 func (s *Server) goServe() {
338 s.wg.Add(1)
339 go func() {
340 defer s.wg.Done()
341 s.Config.Serve(s.Listener)
342 }()
343 }
344
345
346
// wrap installs a ConnState hook on s.Config that records each connection's
// state in s.conns and keeps s.wg in step with the set of tracked
// connections. Any hook that was already set is preserved and called after
// the bookkeeping, while s.mu is still held.
func (s *Server) wrap() {
	oldHook := s.Config.ConnState
	s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
		s.mu.Lock()
		defer s.mu.Unlock()

		switch cs {
		case http.StateNew:
			// A connection must not be reported as new twice.
			if _, exists := s.conns[c]; exists {
				panic("invalid state transition")
			}
			if s.conns == nil {
				s.conns = make(map[net.Conn]http.ConnState)
			}
			// Register the connection with the wait group so that Close
			// blocks until it has been hijacked or closed.
			s.wg.Add(1)
			s.conns[c] = cs
			if s.closed {
				// The server is already shutting down, so this connection
				// arrived too late to be seen by Close's sweep; close it
				// right away.
				s.closeConn(c)
			}
		case http.StateActive:
			// Only tracked connections are validated and updated.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateNew && oldState != http.StateIdle {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
		case http.StateIdle:
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateActive {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
			// Once the server is closed, a connection that goes idle is
			// closed instead of being kept for another request.
			if s.closed {
				s.closeConn(c)
			}
		case http.StateHijacked, http.StateClosed:
			// Terminal states: stop tracking the connection. The check
			// ensures wg.Done runs at most once per tracked connection.
			if _, ok := s.conns[c]; ok {
				delete(s.conns, c)
				// Deferred so that Done runs only after the user's hook
				// below has finished, keeping Close from returning early.
				defer s.wg.Done()
			}
		}
		if oldHook != nil {
			oldHook(c, cs)
		}
	}
}
404
405
406
407 func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
408
409
410
411 func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
412 c.Close()
413 if done != nil {
414 done <- struct{}{}
415 }
416 }
417
View as plain text