Source file src/net/http/httptest/recorder_test.go

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/http"
    13  	"testing"
    14  )
    15  
    16  func TestRecorder(t *testing.T) {
    17  	type checkFunc func(*ResponseRecorder) error
    18  	check := func(fns ...checkFunc) []checkFunc { return fns }
    19  
    20  	hasStatus := func(wantCode int) checkFunc {
    21  		return func(rec *ResponseRecorder) error {
    22  			if rec.Code != wantCode {
    23  				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
    24  			}
    25  			return nil
    26  		}
    27  	}
    28  	hasResultStatus := func(want string) checkFunc {
    29  		return func(rec *ResponseRecorder) error {
    30  			if rec.Result().Status != want {
    31  				return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
    32  			}
    33  			return nil
    34  		}
    35  	}
    36  	hasResultStatusCode := func(wantCode int) checkFunc {
    37  		return func(rec *ResponseRecorder) error {
    38  			if rec.Result().StatusCode != wantCode {
    39  				return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
    40  			}
    41  			return nil
    42  		}
    43  	}
    44  	hasResultContents := func(want string) checkFunc {
    45  		return func(rec *ResponseRecorder) error {
    46  			contentBytes, err := io.ReadAll(rec.Result().Body)
    47  			if err != nil {
    48  				return err
    49  			}
    50  			contents := string(contentBytes)
    51  			if contents != want {
    52  				return fmt.Errorf("Result().Body = %s; want %s", contents, want)
    53  			}
    54  			return nil
    55  		}
    56  	}
    57  	hasContents := func(want string) checkFunc {
    58  		return func(rec *ResponseRecorder) error {
    59  			if rec.Body.String() != want {
    60  				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
    61  			}
    62  			return nil
    63  		}
    64  	}
    65  	hasFlush := func(want bool) checkFunc {
    66  		return func(rec *ResponseRecorder) error {
    67  			if rec.Flushed != want {
    68  				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
    69  			}
    70  			return nil
    71  		}
    72  	}
    73  	hasOldHeader := func(key, want string) checkFunc {
    74  		return func(rec *ResponseRecorder) error {
    75  			if got := rec.HeaderMap.Get(key); got != want {
    76  				return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
    77  			}
    78  			return nil
    79  		}
    80  	}
    81  	hasHeader := func(key, want string) checkFunc {
    82  		return func(rec *ResponseRecorder) error {
    83  			if got := rec.Result().Header.Get(key); got != want {
    84  				return fmt.Errorf("final header %s = %q; want %q", key, got, want)
    85  			}
    86  			return nil
    87  		}
    88  	}
    89  	hasNotHeaders := func(keys ...string) checkFunc {
    90  		return func(rec *ResponseRecorder) error {
    91  			for _, k := range keys {
    92  				v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
    93  				if ok {
    94  					return fmt.Errorf("unexpected header %s with value %q", k, v)
    95  				}
    96  			}
    97  			return nil
    98  		}
    99  	}
   100  	hasTrailer := func(key, want string) checkFunc {
   101  		return func(rec *ResponseRecorder) error {
   102  			if got := rec.Result().Trailer.Get(key); got != want {
   103  				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
   104  			}
   105  			return nil
   106  		}
   107  	}
   108  	hasNotTrailers := func(keys ...string) checkFunc {
   109  		return func(rec *ResponseRecorder) error {
   110  			trailers := rec.Result().Trailer
   111  			for _, k := range keys {
   112  				_, ok := trailers[http.CanonicalHeaderKey(k)]
   113  				if ok {
   114  					return fmt.Errorf("unexpected trailer %s", k)
   115  				}
   116  			}
   117  			return nil
   118  		}
   119  	}
   120  	hasContentLength := func(length int64) checkFunc {
   121  		return func(rec *ResponseRecorder) error {
   122  			if got := rec.Result().ContentLength; got != length {
   123  				return fmt.Errorf("ContentLength = %d; want %d", got, length)
   124  			}
   125  			return nil
   126  		}
   127  	}
   128  
   129  	for _, tt := range [...]struct {
   130  		name   string
   131  		h      func(w http.ResponseWriter, r *http.Request)
   132  		checks []checkFunc
   133  	}{
   134  		{
   135  			"200 default",
   136  			func(w http.ResponseWriter, r *http.Request) {},
   137  			check(hasStatus(200), hasContents("")),
   138  		},
   139  		{
   140  			"first code only",
   141  			func(w http.ResponseWriter, r *http.Request) {
   142  				w.WriteHeader(201)
   143  				w.WriteHeader(202)
   144  				w.Write([]byte("hi"))
   145  			},
   146  			check(hasStatus(201), hasContents("hi")),
   147  		},
   148  		{
   149  			"write sends 200",
   150  			func(w http.ResponseWriter, r *http.Request) {
   151  				w.Write([]byte("hi first"))
   152  				w.WriteHeader(201)
   153  				w.WriteHeader(202)
   154  			},
   155  			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
   156  		},
   157  		{
   158  			"write string",
   159  			func(w http.ResponseWriter, r *http.Request) {
   160  				io.WriteString(w, "hi first")
   161  			},
   162  			check(
   163  				hasStatus(200),
   164  				hasContents("hi first"),
   165  				hasFlush(false),
   166  				hasHeader("Content-Type", "text/plain; charset=utf-8"),
   167  			),
   168  		},
   169  		{
   170  			"flush",
   171  			func(w http.ResponseWriter, r *http.Request) {
   172  				w.(http.Flusher).Flush() // also sends a 200
   173  				w.WriteHeader(201)
   174  			},
   175  			check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
   176  		},
   177  		{
   178  			"Content-Type detection",
   179  			func(w http.ResponseWriter, r *http.Request) {
   180  				io.WriteString(w, "<html>")
   181  			},
   182  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   183  		},
   184  		{
   185  			"no Content-Type detection with Transfer-Encoding",
   186  			func(w http.ResponseWriter, r *http.Request) {
   187  				w.Header().Set("Transfer-Encoding", "some encoding")
   188  				io.WriteString(w, "<html>")
   189  			},
   190  			check(hasHeader("Content-Type", "")), // no header
   191  		},
   192  		{
   193  			"no Content-Type detection if set explicitly",
   194  			func(w http.ResponseWriter, r *http.Request) {
   195  				w.Header().Set("Content-Type", "some/type")
   196  				io.WriteString(w, "<html>")
   197  			},
   198  			check(hasHeader("Content-Type", "some/type")),
   199  		},
   200  		{
   201  			"Content-Type detection doesn't crash if HeaderMap is nil",
   202  			func(w http.ResponseWriter, r *http.Request) {
   203  				// Act as if the user wrote new(httptest.ResponseRecorder)
   204  				// rather than using NewRecorder (which initializes
   205  				// HeaderMap)
   206  				w.(*ResponseRecorder).HeaderMap = nil
   207  				io.WriteString(w, "<html>")
   208  			},
   209  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
   210  		},
   211  		{
   212  			"Header is not changed after write",
   213  			func(w http.ResponseWriter, r *http.Request) {
   214  				hdr := w.Header()
   215  				hdr.Set("Key", "correct")
   216  				w.WriteHeader(200)
   217  				hdr.Set("Key", "incorrect")
   218  			},
   219  			check(hasHeader("Key", "correct")),
   220  		},
   221  		{
   222  			"Trailer headers are correctly recorded",
   223  			func(w http.ResponseWriter, r *http.Request) {
   224  				w.Header().Set("Non-Trailer", "correct")
   225  				w.Header().Set("Trailer", "Trailer-A, Trailer-B")
   226  				w.Header().Add("Trailer", "Trailer-C")
   227  				io.WriteString(w, "<html>")
   228  				w.Header().Set("Non-Trailer", "incorrect")
   229  				w.Header().Set("Trailer-A", "valuea")
   230  				w.Header().Set("Trailer-C", "valuec")
   231  				w.Header().Set("Trailer-NotDeclared", "should be omitted")
   232  				w.Header().Set("Trailer:Trailer-D", "with prefix")
   233  			},
   234  			check(
   235  				hasStatus(200),
   236  				hasHeader("Content-Type", "text/html; charset=utf-8"),
   237  				hasHeader("Non-Trailer", "correct"),
   238  				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
   239  				hasTrailer("Trailer-A", "valuea"),
   240  				hasTrailer("Trailer-C", "valuec"),
   241  				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
   242  				hasTrailer("Trailer-D", "with prefix"),
   243  			),
   244  		},
   245  		{
   246  			"Header set without any write", // Issue 15560
   247  			func(w http.ResponseWriter, r *http.Request) {
   248  				w.Header().Set("X-Foo", "1")
   249  
   250  				// Simulate somebody using
   251  				// new(ResponseRecorder) instead of
   252  				// using the constructor which sets
   253  				// this to 200
   254  				w.(*ResponseRecorder).Code = 0
   255  			},
   256  			check(
   257  				hasOldHeader("X-Foo", "1"),
   258  				hasStatus(0),
   259  				hasHeader("X-Foo", "1"),
   260  				hasResultStatus("200 OK"),
   261  				hasResultStatusCode(200),
   262  			),
   263  		},
   264  		{
   265  			"HeaderMap vs FinalHeaders", // more for Issue 15560
   266  			func(w http.ResponseWriter, r *http.Request) {
   267  				h := w.Header()
   268  				h.Set("X-Foo", "1")
   269  				w.Write([]byte("hi"))
   270  				h.Set("X-Foo", "2")
   271  				h.Set("X-Bar", "2")
   272  			},
   273  			check(
   274  				hasOldHeader("X-Foo", "2"),
   275  				hasOldHeader("X-Bar", "2"),
   276  				hasHeader("X-Foo", "1"),
   277  				hasNotHeaders("X-Bar"),
   278  			),
   279  		},
   280  		{
   281  			"setting Content-Length header",
   282  			func(w http.ResponseWriter, r *http.Request) {
   283  				body := "Some body"
   284  				contentLength := fmt.Sprintf("%d", len(body))
   285  				w.Header().Set("Content-Length", contentLength)
   286  				io.WriteString(w, body)
   287  			},
   288  			check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
   289  		},
   290  		{
   291  			"nil ResponseRecorder.Body", // Issue 26642
   292  			func(w http.ResponseWriter, r *http.Request) {
   293  				w.(*ResponseRecorder).Body = nil
   294  				io.WriteString(w, "hi")
   295  			},
   296  			check(hasResultContents("")), // check we don't crash reading the body
   297  
   298  		},
   299  	} {
   300  		t.Run(tt.name, func(t *testing.T) {
   301  			r, _ := http.NewRequest("GET", "http://foo.com/", nil)
   302  			h := http.HandlerFunc(tt.h)
   303  			rec := NewRecorder()
   304  			h.ServeHTTP(rec, r)
   305  			for _, check := range tt.checks {
   306  				if err := check(rec); err != nil {
   307  					t.Error(err)
   308  				}
   309  			}
   310  		})
   311  	}
   312  }
   313  
   314  func TestBodyNotAllowed(t *testing.T) {
   315  	rw := NewRecorder()
   316  	rw.Body = new(bytes.Buffer)
   317  	rw.WriteHeader(204)
   318  
   319  	_, err := rw.Write([]byte("hello "))
   320  	if !errors.Is(err, http.ErrBodyNotAllowed) {
   321  		t.Errorf("expected BodyNotAllowed for Write after 204, got: %v", err)
   322  	}
   323  
   324  	_, err = rw.WriteString("world")
   325  	if !errors.Is(err, http.ErrBodyNotAllowed) {
   326  		t.Errorf("expected BodyNotAllowed for WriteString after 204, got: %v", err)
   327  	}
   328  
   329  	if got, want := rw.Body.String(), "hello world"; got != want {
   330  		t.Errorf("got Body=%q, want %q", got, want)
   331  	}
   332  }
   333  
   334  // issue 39017 - disallow Content-Length values such as "+3"
   335  func TestParseContentLength(t *testing.T) {
   336  	tests := []struct {
   337  		cl   string
   338  		want int64
   339  	}{
   340  		{
   341  			cl:   "3",
   342  			want: 3,
   343  		},
   344  		{
   345  			cl:   "+3",
   346  			want: -1,
   347  		},
   348  		{
   349  			cl:   "-3",
   350  			want: -1,
   351  		},
   352  		{
   353  			// max int64, for safe conversion before returning
   354  			cl:   "9223372036854775807",
   355  			want: 9223372036854775807,
   356  		},
   357  		{
   358  			cl:   "9223372036854775808",
   359  			want: -1,
   360  		},
   361  	}
   362  
   363  	for _, tt := range tests {
   364  		if got := parseContentLength(tt.cl); got != tt.want {
   365  			t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
   366  		}
   367  	}
   368  }
   369  
   370  // Ensure that httptest.Recorder panics when given a non-3 digit (XXX)
   371  // status HTTP code. See https://golang.org/issues/45353
   372  func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
   373  	badCodes := []int{
   374  		-100, 0, 99, 1000, 20000,
   375  	}
   376  	for _, badCode := range badCodes {
   377  		t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
   378  			defer func() {
   379  				if r := recover(); r == nil {
   380  					t.Fatal("Expected a panic")
   381  				}
   382  			}()
   383  
   384  			handler := func(rw http.ResponseWriter, _ *http.Request) {
   385  				rw.WriteHeader(badCode)
   386  			}
   387  			r, _ := http.NewRequest("GET", "http://example.org/", nil)
   388  			rw := NewRecorder()
   389  			handler(rw, r)
   390  		})
   391  	}
   392  }
   393  

View as plain text