Source file src/net/http/httptest/server.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Implementation of Server
     6  
     7  package httptest
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"flag"
    14  	"fmt"
    15  	"log"
    16  	"net"
    17  	"net/http"
    18  	"net/http/internal/testcert"
    19  	"os"
    20  	"strings"
    21  	"sync"
    22  	"time"
    23  )
    24  
// A Server is an HTTP server listening on a system-chosen port on the
// local loopback interface, for use in end-to-end HTTP tests.
type Server struct {
	// URL is set by Start (http://...) or StartTLS (https://...).
	URL      string // base URL of form http://ipaddr:port with no trailing slash
	// Listener accepts the server's connections. StartTLS replaces it with
	// a TLS listener wrapping the original one.
	Listener net.Listener

	// EnableHTTP2 controls whether HTTP/2 is enabled
	// on the server. It must be set between calling
	// NewUnstartedServer and calling Server.StartTLS.
	EnableHTTP2 bool

	// TLS is the optional TLS configuration, populated with a new config
	// after TLS is started. If set on an unstarted server before StartTLS
	// is called, existing fields are copied into the new config.
	TLS *tls.Config

	// Config may be changed after calling NewUnstartedServer and
	// before Start or StartTLS.
	Config *http.Server

	// certificate is a parsed version of the TLS config certificate, if present.
	// It is only set by StartTLS.
	certificate *x509.Certificate

	// wg counts the goroutine running Config.Serve plus every connection
	// currently tracked in conns (added on StateNew, released on
	// StateHijacked/StateClosed). Close blocks until it drops to zero, i.e.
	// until Serve has returned and all tracked connections have ended.
	wg sync.WaitGroup

	mu     sync.Mutex // guards closed and conns
	closed bool
	// conns maps each tracked connection to its latest ConnState; entries
	// are removed on StateHijacked/StateClosed, so terminal states never appear.
	conns  map[net.Conn]http.ConnState // except terminal states

	// client is configured for use with the server.
	// Close calls CloseIdleConnections on its transport (when the transport
	// supports it); the transport itself is not otherwise shut down.
	client *http.Client
}
    60  
    61  func newLocalListener() net.Listener {
    62  	if serveFlag != "" {
    63  		l, err := net.Listen("tcp", serveFlag)
    64  		if err != nil {
    65  			panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
    66  		}
    67  		return l
    68  	}
    69  	l, err := net.Listen("tcp", "127.0.0.1:0")
    70  	if err != nil {
    71  		if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
    72  			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
    73  		}
    74  	}
    75  	return l
    76  }
    77  
    78  // When debugging a particular http server-based test,
    79  // this flag lets you run
    80  //
    81  //	go test -run='^BrokenTest$' -httptest.serve=127.0.0.1:8000
    82  //
    83  // to start the broken server so you can interact with it manually.
    84  // We only register this flag if it looks like the caller knows about it
    85  // and is trying to use it as we don't want to pollute flags and this
    86  // isn't really part of our API. Don't depend on this.
    87  var serveFlag string
    88  
    89  func init() {
    90  	if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
    91  		flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
    92  	}
    93  }
    94  
    95  func strSliceContainsPrefix(v []string, pre string) bool {
    96  	for _, s := range v {
    97  		if strings.HasPrefix(s, pre) {
    98  			return true
    99  		}
   100  	}
   101  	return false
   102  }
   103  
   104  // NewServer starts and returns a new [Server].
   105  // The caller should call Close when finished, to shut it down.
   106  func NewServer(handler http.Handler) *Server {
   107  	ts := NewUnstartedServer(handler)
   108  	ts.Start()
   109  	return ts
   110  }
   111  
   112  // NewUnstartedServer returns a new [Server] but doesn't start it.
   113  //
   114  // After changing its configuration, the caller should call Start or
   115  // StartTLS.
   116  //
   117  // The caller should call Close when finished, to shut it down.
   118  func NewUnstartedServer(handler http.Handler) *Server {
   119  	return &Server{
   120  		Listener: newLocalListener(),
   121  		Config:   &http.Server{Handler: handler},
   122  	}
   123  }
   124  
   125  // Start starts a server from NewUnstartedServer.
   126  func (s *Server) Start() {
   127  	if s.URL != "" {
   128  		panic("Server already started")
   129  	}
   130  
   131  	if s.client == nil {
   132  		tr := &http.Transport{}
   133  		dialer := net.Dialer{}
   134  		// User code may set either of Dial or DialContext, with DialContext taking precedence.
   135  		// We set DialContext here to preserve any context values that are passed in,
   136  		// but fall back to Dial if the user has set it.
   137  		tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
   138  			if tr.Dial != nil {
   139  				return tr.Dial(network, addr)
   140  			}
   141  			if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") {
   142  				addr = s.Listener.Addr().String()
   143  			}
   144  			return dialer.DialContext(ctx, network, addr)
   145  		}
   146  		s.client = &http.Client{Transport: tr}
   147  
   148  	}
   149  	s.URL = "http://" + s.Listener.Addr().String()
   150  	s.wrap()
   151  	s.goServe()
   152  	if serveFlag != "" {
   153  		fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
   154  		select {}
   155  	}
   156  }
   157  
   158  // StartTLS starts TLS on a server from NewUnstartedServer.
   159  func (s *Server) StartTLS() {
   160  	if s.URL != "" {
   161  		panic("Server already started")
   162  	}
   163  	if s.client == nil {
   164  		s.client = &http.Client{}
   165  	}
   166  	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
   167  	if err != nil {
   168  		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
   169  	}
   170  
   171  	existingConfig := s.TLS
   172  	if existingConfig != nil {
   173  		s.TLS = existingConfig.Clone()
   174  	} else {
   175  		s.TLS = new(tls.Config)
   176  	}
   177  	if s.TLS.NextProtos == nil {
   178  		nextProtos := []string{"http/1.1"}
   179  		if s.EnableHTTP2 {
   180  			nextProtos = []string{"h2"}
   181  		}
   182  		s.TLS.NextProtos = nextProtos
   183  	}
   184  	if len(s.TLS.Certificates) == 0 {
   185  		s.TLS.Certificates = []tls.Certificate{cert}
   186  	}
   187  	s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
   188  	if err != nil {
   189  		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
   190  	}
   191  	certpool := x509.NewCertPool()
   192  	certpool.AddCert(s.certificate)
   193  	tr := &http.Transport{
   194  		TLSClientConfig: &tls.Config{
   195  			RootCAs: certpool,
   196  		},
   197  		ForceAttemptHTTP2: s.EnableHTTP2,
   198  	}
   199  	dialer := net.Dialer{}
   200  	tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
   201  		if tr.Dial != nil {
   202  			return tr.Dial(network, addr)
   203  		}
   204  		if addr == "example.com:443" || strings.HasSuffix(addr, ".example.com:443") {
   205  			addr = s.Listener.Addr().String()
   206  		}
   207  		return dialer.DialContext(ctx, network, addr)
   208  	}
   209  	s.client.Transport = tr
   210  	s.Listener = tls.NewListener(s.Listener, s.TLS)
   211  	s.URL = "https://" + s.Listener.Addr().String()
   212  	s.wrap()
   213  	s.goServe()
   214  }
   215  
   216  // NewTLSServer starts and returns a new [Server] using TLS.
   217  // The caller should call Close when finished, to shut it down.
   218  func NewTLSServer(handler http.Handler) *Server {
   219  	ts := NewUnstartedServer(handler)
   220  	ts.StartTLS()
   221  	return ts
   222  }
   223  
// closeIdleTransport is satisfied by round trippers (such as *http.Transport)
// that can drop their pooled idle connections. Close type-asserts
// http.DefaultTransport and the server client's transport against it.
type closeIdleTransport interface {
	CloseIdleConnections()
}
   227  
   228  // Close shuts down the server and blocks until all outstanding
   229  // requests on this server have completed.
   230  func (s *Server) Close() {
   231  	s.mu.Lock()
   232  	if !s.closed {
   233  		s.closed = true
   234  		s.Listener.Close()
   235  		s.Config.SetKeepAlivesEnabled(false)
   236  		for c, st := range s.conns {
   237  			// Force-close any idle connections (those between
   238  			// requests) and new connections (those which connected
   239  			// but never sent a request). StateNew connections are
   240  			// super rare and have only been seen (in
   241  			// previously-flaky tests) in the case of
   242  			// socket-late-binding races from the http Client
   243  			// dialing this server and then getting an idle
   244  			// connection before the dial completed. There is thus
   245  			// a connected connection in StateNew with no
   246  			// associated Request. We only close StateIdle and
   247  			// StateNew because they're not doing anything. It's
   248  			// possible StateNew is about to do something in a few
   249  			// milliseconds, but a previous CL to check again in a
   250  			// few milliseconds wasn't liked (early versions of
   251  			// https://golang.org/cl/15151) so now we just
   252  			// forcefully close StateNew. The docs for Server.Close say
   253  			// we wait for "outstanding requests", so we don't close things
   254  			// in StateActive.
   255  			if st == http.StateIdle || st == http.StateNew {
   256  				s.closeConn(c)
   257  			}
   258  		}
   259  		// If this server doesn't shut down in 5 seconds, tell the user why.
   260  		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
   261  		defer t.Stop()
   262  	}
   263  	s.mu.Unlock()
   264  
   265  	// Not part of httptest.Server's correctness, but assume most
   266  	// users of httptest.Server will be using the standard
   267  	// transport, so help them out and close any idle connections for them.
   268  	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
   269  		t.CloseIdleConnections()
   270  	}
   271  
   272  	// Also close the client idle connections.
   273  	if s.client != nil {
   274  		if t, ok := s.client.Transport.(closeIdleTransport); ok {
   275  			t.CloseIdleConnections()
   276  		}
   277  	}
   278  
   279  	s.wg.Wait()
   280  }
   281  
   282  func (s *Server) logCloseHangDebugInfo() {
   283  	s.mu.Lock()
   284  	defer s.mu.Unlock()
   285  	var buf strings.Builder
   286  	buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
   287  	for c, st := range s.conns {
   288  		fmt.Fprintf(&buf, "  %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
   289  	}
   290  	log.Print(buf.String())
   291  }
   292  
   293  // CloseClientConnections closes any open HTTP connections to the test Server.
   294  func (s *Server) CloseClientConnections() {
   295  	s.mu.Lock()
   296  	nconn := len(s.conns)
   297  	ch := make(chan struct{}, nconn)
   298  	for c := range s.conns {
   299  		go s.closeConnChan(c, ch)
   300  	}
   301  	s.mu.Unlock()
   302  
   303  	// Wait for outstanding closes to finish.
   304  	//
   305  	// Out of paranoia for making a late change in Go 1.6, we
   306  	// bound how long this can wait, since golang.org/issue/14291
   307  	// isn't fully understood yet. At least this should only be used
   308  	// in tests.
   309  	timer := time.NewTimer(5 * time.Second)
   310  	defer timer.Stop()
   311  	for i := 0; i < nconn; i++ {
   312  		select {
   313  		case <-ch:
   314  		case <-timer.C:
   315  			// Too slow. Give up.
   316  			return
   317  		}
   318  	}
   319  }
   320  
// Certificate returns the certificate used by the server, or nil if
// the server doesn't use TLS.
//
// The value is parsed by StartTLS from the first DER certificate of
// TLS.Certificates[0]. It stays nil until StartTLS has run.
func (s *Server) Certificate() *x509.Certificate {
	return s.certificate
}
   326  
// Client returns an HTTP client configured for making requests to the server.
// It is configured to trust the server's TLS test certificate and will
// close its idle connections on [Server.Close].
// Use Server.URL as the base URL to send requests to the server.
// The returned client will also redirect any requests to "example.com"
// or its subdomains to the server.
//
// The client is created by Start or StartTLS (NewUnstartedServer does not
// create one), so Client returns nil on a server that has not been started.
func (s *Server) Client() *http.Client {
	return s.client
}
   336  
   337  func (s *Server) goServe() {
   338  	s.wg.Add(1)
   339  	go func() {
   340  		defer s.wg.Done()
   341  		s.Config.Serve(s.Listener)
   342  	}()
   343  }
   344  
// wrap installs the connection state-tracking hook to know which
// connections are idle.
//
// The hook replaces s.Config.ConnState and chains to any hook that was
// already installed, calling it after its own bookkeeping. While holding s.mu:
//   - StateNew: start tracking the conn in s.conns and add one to s.wg;
//     if the server is already closed, close the conn immediately.
//   - StateActive / StateIdle: record the new state for tracked conns,
//     panicking on a transition it does not expect. On StateIdle the conn
//     is closed if the server is already closed.
//   - StateHijacked / StateClosed: stop tracking the conn and release its
//     s.wg count, but only after the chained hook has returned.
func (s *Server) wrap() {
	oldHook := s.Config.ConnState
	s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
		s.mu.Lock()
		defer s.mu.Unlock()

		switch cs {
		case http.StateNew:
			// A connection must not be reported as new twice.
			if _, exists := s.conns[c]; exists {
				panic("invalid state transition")
			}
			if s.conns == nil {
				s.conns = make(map[net.Conn]http.ConnState)
			}
			// Add c to the set of tracked conns and increment it to the
			// waitgroup.
			s.wg.Add(1)
			s.conns[c] = cs
			if s.closed {
				// Probably just a socket-late-binding dial from
				// the default transport that lost the race (and
				// thus this connection is now idle and will
				// never be used).
				s.closeConn(c)
			}
		case http.StateActive:
			// Only a New or Idle connection may become Active; untracked
			// conns are ignored.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateNew && oldState != http.StateIdle {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
		case http.StateIdle:
			// Only an Active connection may become Idle.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateActive {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
			// Close has already swept idle conns; drop ones that become
			// idle afterwards too.
			if s.closed {
				s.closeConn(c)
			}
		case http.StateHijacked, http.StateClosed:
			// Remove c from the set of tracked conns and decrement it from the
			// waitgroup, unless it was previously removed.
			if _, ok := s.conns[c]; ok {
				delete(s.conns, c)
				// Keep Close from returning until the user's ConnState hook
				// (if any) finishes.
				defer s.wg.Done()
			}
		}
		if oldHook != nil {
			oldHook(c, cs)
		}
	}
}
   404  
// closeConn closes c.
// s.mu must be held.
//
// It is closeConnChan with no completion channel, so c is closed
// synchronously on the calling goroutine.
func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
   408  
   409  // closeConnChan is like closeConn, but takes an optional channel to receive a value
   410  // when the goroutine closing c is done.
   411  func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
   412  	c.Close()
   413  	if done != nil {
   414  		done <- struct{}{}
   415  	}
   416  }
   417  

View as plain text