Source file src/cmd/go/internal/vcweb/vcstest/vcstest.go

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package vcstest serves the repository scripts in cmd/go/testdata/vcstest
     6  // using the [vcweb] script engine.
     7  package vcstest
     8  
     9  import (
    10  	"bytes"
    11  	"cmd/go/internal/vcs"
    12  	"cmd/go/internal/vcweb"
    13  	"cmd/go/internal/web/intercept"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"encoding/pem"
    17  	"fmt"
    18  	"internal/testenv"
    19  	"io"
    20  	"log"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"net/url"
    24  	"os"
    25  	"path/filepath"
    26  	"testing"
    27  )
    28  
    29  var Hosts = []string{
    30  	"vcs-test.golang.org",
    31  }
    32  
    33  type Server struct {
    34  	vcweb   *vcweb.Server
    35  	workDir string
    36  	HTTP    *httptest.Server
    37  	HTTPS   *httptest.Server
    38  }
    39  
    40  // NewServer returns a new test-local vcweb server that serves VCS requests
    41  // for modules with paths that begin with "vcs-test.golang.org" using the
    42  // scripts in cmd/go/testdata/vcstest.
    43  func NewServer() (srv *Server, err error) {
    44  	if vcs.VCSTestRepoURL != "" {
    45  		panic("vcs URL hooks already set")
    46  	}
    47  
    48  	scriptDir := filepath.Join(testenv.GOROOT(nil), "src/cmd/go/testdata/vcstest")
    49  
    50  	workDir, err := os.MkdirTemp("", "vcstest")
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	defer func() {
    55  		if err != nil {
    56  			os.RemoveAll(workDir)
    57  		}
    58  	}()
    59  
    60  	logger := log.Default()
    61  	if !testing.Verbose() {
    62  		logger = log.New(io.Discard, "", log.LstdFlags)
    63  	}
    64  	handler, err := vcweb.NewServer(scriptDir, workDir, logger)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	defer func() {
    69  		if err != nil {
    70  			handler.Close()
    71  		}
    72  	}()
    73  
    74  	srvHTTP := httptest.NewUnstartedServer(handler)
    75  	srvHTTP.Config.ErrorLog = testLogger()
    76  	srvHTTP.Start()
    77  	httpURL, err := url.Parse(srvHTTP.URL)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	defer func() {
    82  		if err != nil {
    83  			srvHTTP.Close()
    84  		}
    85  	}()
    86  
    87  	srvHTTPS := httptest.NewUnstartedServer(handler)
    88  	srvHTTPS.Config.ErrorLog = testLogger()
    89  	srvHTTPS.StartTLS()
    90  	httpsURL, err := url.Parse(srvHTTPS.URL)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	defer func() {
    95  		if err != nil {
    96  			srvHTTPS.Close()
    97  		}
    98  	}()
    99  
   100  	srv = &Server{
   101  		vcweb:   handler,
   102  		workDir: workDir,
   103  		HTTP:    srvHTTP,
   104  		HTTPS:   srvHTTPS,
   105  	}
   106  	vcs.VCSTestRepoURL = srv.HTTP.URL
   107  	vcs.VCSTestHosts = Hosts
   108  
   109  	interceptors := make([]intercept.Interceptor, 0, 2*len(Hosts))
   110  	for _, host := range Hosts {
   111  		interceptors = append(interceptors,
   112  			intercept.Interceptor{Scheme: "http", FromHost: host, ToHost: httpURL.Host, Client: srv.HTTP.Client()},
   113  			intercept.Interceptor{Scheme: "https", FromHost: host, ToHost: httpsURL.Host, Client: srv.HTTPS.Client()})
   114  	}
   115  	intercept.EnableTestHooks(interceptors)
   116  
   117  	fmt.Fprintln(os.Stderr, "vcs-test.golang.org rerouted to "+srv.HTTP.URL)
   118  	fmt.Fprintln(os.Stderr, "https://vcs-test.golang.org rerouted to "+srv.HTTPS.URL)
   119  
   120  	return srv, nil
   121  }
   122  
   123  func testLogger() *log.Logger {
   124  	return log.New(httpLogger{}, "vcweb: ", 0)
   125  }
   126  
   127  type httpLogger struct{}
   128  
   129  func (httpLogger) Write(b []byte) (int, error) {
   130  	if bytes.Contains(b, []byte("TLS handshake error")) {
   131  		return len(b), nil
   132  	}
   133  	return os.Stdout.Write(b)
   134  }
   135  
   136  func (srv *Server) Close() error {
   137  	if vcs.VCSTestRepoURL != srv.HTTP.URL {
   138  		panic("vcs URL hooks modified before Close")
   139  	}
   140  	vcs.VCSTestRepoURL = ""
   141  	vcs.VCSTestHosts = nil
   142  	intercept.DisableTestHooks()
   143  
   144  	srv.HTTP.Close()
   145  	srv.HTTPS.Close()
   146  	err := srv.vcweb.Close()
   147  	if rmErr := os.RemoveAll(srv.workDir); err == nil {
   148  		err = rmErr
   149  	}
   150  	return err
   151  }
   152  
   153  func (srv *Server) WriteCertificateFile() (string, error) {
   154  	b := pem.EncodeToMemory(&pem.Block{
   155  		Type:  "CERTIFICATE",
   156  		Bytes: srv.HTTPS.Certificate().Raw,
   157  	})
   158  
   159  	filename := filepath.Join(srv.workDir, "cert.pem")
   160  	if err := os.WriteFile(filename, b, 0644); err != nil {
   161  		return "", err
   162  	}
   163  	return filename, nil
   164  }
   165  
   166  // TLSClient returns an http.Client that can talk to the httptest.Server
   167  // whose certificate is written to the given file path.
   168  func TLSClient(certFile string) (*http.Client, error) {
   169  	client := &http.Client{
   170  		Transport: http.DefaultTransport.(*http.Transport).Clone(),
   171  	}
   172  
   173  	pemBytes, err := os.ReadFile(certFile)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	certpool := x509.NewCertPool()
   179  	if !certpool.AppendCertsFromPEM(pemBytes) {
   180  		return nil, fmt.Errorf("no certificates found in %s", certFile)
   181  	}
   182  	client.Transport.(*http.Transport).TLSClientConfig = &tls.Config{
   183  		RootCAs: certpool,
   184  	}
   185  
   186  	return client, nil
   187  }
   188  

View as plain text