1
2
3
4
5 package httptest
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "net/http"
13 "testing"
14 )
15
16 func TestRecorder(t *testing.T) {
17 type checkFunc func(*ResponseRecorder) error
18 check := func(fns ...checkFunc) []checkFunc { return fns }
19
20 hasStatus := func(wantCode int) checkFunc {
21 return func(rec *ResponseRecorder) error {
22 if rec.Code != wantCode {
23 return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
24 }
25 return nil
26 }
27 }
28 hasResultStatus := func(want string) checkFunc {
29 return func(rec *ResponseRecorder) error {
30 if rec.Result().Status != want {
31 return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
32 }
33 return nil
34 }
35 }
36 hasResultStatusCode := func(wantCode int) checkFunc {
37 return func(rec *ResponseRecorder) error {
38 if rec.Result().StatusCode != wantCode {
39 return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
40 }
41 return nil
42 }
43 }
44 hasResultContents := func(want string) checkFunc {
45 return func(rec *ResponseRecorder) error {
46 contentBytes, err := io.ReadAll(rec.Result().Body)
47 if err != nil {
48 return err
49 }
50 contents := string(contentBytes)
51 if contents != want {
52 return fmt.Errorf("Result().Body = %s; want %s", contents, want)
53 }
54 return nil
55 }
56 }
57 hasContents := func(want string) checkFunc {
58 return func(rec *ResponseRecorder) error {
59 if rec.Body.String() != want {
60 return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
61 }
62 return nil
63 }
64 }
65 hasFlush := func(want bool) checkFunc {
66 return func(rec *ResponseRecorder) error {
67 if rec.Flushed != want {
68 return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
69 }
70 return nil
71 }
72 }
73 hasOldHeader := func(key, want string) checkFunc {
74 return func(rec *ResponseRecorder) error {
75 if got := rec.HeaderMap.Get(key); got != want {
76 return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
77 }
78 return nil
79 }
80 }
81 hasHeader := func(key, want string) checkFunc {
82 return func(rec *ResponseRecorder) error {
83 if got := rec.Result().Header.Get(key); got != want {
84 return fmt.Errorf("final header %s = %q; want %q", key, got, want)
85 }
86 return nil
87 }
88 }
89 hasNotHeaders := func(keys ...string) checkFunc {
90 return func(rec *ResponseRecorder) error {
91 for _, k := range keys {
92 v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
93 if ok {
94 return fmt.Errorf("unexpected header %s with value %q", k, v)
95 }
96 }
97 return nil
98 }
99 }
100 hasTrailer := func(key, want string) checkFunc {
101 return func(rec *ResponseRecorder) error {
102 if got := rec.Result().Trailer.Get(key); got != want {
103 return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
104 }
105 return nil
106 }
107 }
108 hasNotTrailers := func(keys ...string) checkFunc {
109 return func(rec *ResponseRecorder) error {
110 trailers := rec.Result().Trailer
111 for _, k := range keys {
112 _, ok := trailers[http.CanonicalHeaderKey(k)]
113 if ok {
114 return fmt.Errorf("unexpected trailer %s", k)
115 }
116 }
117 return nil
118 }
119 }
120 hasContentLength := func(length int64) checkFunc {
121 return func(rec *ResponseRecorder) error {
122 if got := rec.Result().ContentLength; got != length {
123 return fmt.Errorf("ContentLength = %d; want %d", got, length)
124 }
125 return nil
126 }
127 }
128
129 for _, tt := range [...]struct {
130 name string
131 h func(w http.ResponseWriter, r *http.Request)
132 checks []checkFunc
133 }{
134 {
135 "200 default",
136 func(w http.ResponseWriter, r *http.Request) {},
137 check(hasStatus(200), hasContents("")),
138 },
139 {
140 "first code only",
141 func(w http.ResponseWriter, r *http.Request) {
142 w.WriteHeader(201)
143 w.WriteHeader(202)
144 w.Write([]byte("hi"))
145 },
146 check(hasStatus(201), hasContents("hi")),
147 },
148 {
149 "write sends 200",
150 func(w http.ResponseWriter, r *http.Request) {
151 w.Write([]byte("hi first"))
152 w.WriteHeader(201)
153 w.WriteHeader(202)
154 },
155 check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
156 },
157 {
158 "write string",
159 func(w http.ResponseWriter, r *http.Request) {
160 io.WriteString(w, "hi first")
161 },
162 check(
163 hasStatus(200),
164 hasContents("hi first"),
165 hasFlush(false),
166 hasHeader("Content-Type", "text/plain; charset=utf-8"),
167 ),
168 },
169 {
170 "flush",
171 func(w http.ResponseWriter, r *http.Request) {
172 w.(http.Flusher).Flush()
173 w.WriteHeader(201)
174 },
175 check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
176 },
177 {
178 "Content-Type detection",
179 func(w http.ResponseWriter, r *http.Request) {
180 io.WriteString(w, "<html>")
181 },
182 check(hasHeader("Content-Type", "text/html; charset=utf-8")),
183 },
184 {
185 "no Content-Type detection with Transfer-Encoding",
186 func(w http.ResponseWriter, r *http.Request) {
187 w.Header().Set("Transfer-Encoding", "some encoding")
188 io.WriteString(w, "<html>")
189 },
190 check(hasHeader("Content-Type", "")),
191 },
192 {
193 "no Content-Type detection if set explicitly",
194 func(w http.ResponseWriter, r *http.Request) {
195 w.Header().Set("Content-Type", "some/type")
196 io.WriteString(w, "<html>")
197 },
198 check(hasHeader("Content-Type", "some/type")),
199 },
200 {
201 "Content-Type detection doesn't crash if HeaderMap is nil",
202 func(w http.ResponseWriter, r *http.Request) {
203
204
205
206 w.(*ResponseRecorder).HeaderMap = nil
207 io.WriteString(w, "<html>")
208 },
209 check(hasHeader("Content-Type", "text/html; charset=utf-8")),
210 },
211 {
212 "Header is not changed after write",
213 func(w http.ResponseWriter, r *http.Request) {
214 hdr := w.Header()
215 hdr.Set("Key", "correct")
216 w.WriteHeader(200)
217 hdr.Set("Key", "incorrect")
218 },
219 check(hasHeader("Key", "correct")),
220 },
221 {
222 "Trailer headers are correctly recorded",
223 func(w http.ResponseWriter, r *http.Request) {
224 w.Header().Set("Non-Trailer", "correct")
225 w.Header().Set("Trailer", "Trailer-A, Trailer-B")
226 w.Header().Add("Trailer", "Trailer-C")
227 io.WriteString(w, "<html>")
228 w.Header().Set("Non-Trailer", "incorrect")
229 w.Header().Set("Trailer-A", "valuea")
230 w.Header().Set("Trailer-C", "valuec")
231 w.Header().Set("Trailer-NotDeclared", "should be omitted")
232 w.Header().Set("Trailer:Trailer-D", "with prefix")
233 },
234 check(
235 hasStatus(200),
236 hasHeader("Content-Type", "text/html; charset=utf-8"),
237 hasHeader("Non-Trailer", "correct"),
238 hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
239 hasTrailer("Trailer-A", "valuea"),
240 hasTrailer("Trailer-C", "valuec"),
241 hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
242 hasTrailer("Trailer-D", "with prefix"),
243 ),
244 },
245 {
246 "Header set without any write",
247 func(w http.ResponseWriter, r *http.Request) {
248 w.Header().Set("X-Foo", "1")
249
250
251
252
253
254 w.(*ResponseRecorder).Code = 0
255 },
256 check(
257 hasOldHeader("X-Foo", "1"),
258 hasStatus(0),
259 hasHeader("X-Foo", "1"),
260 hasResultStatus("200 OK"),
261 hasResultStatusCode(200),
262 ),
263 },
264 {
265 "HeaderMap vs FinalHeaders",
266 func(w http.ResponseWriter, r *http.Request) {
267 h := w.Header()
268 h.Set("X-Foo", "1")
269 w.Write([]byte("hi"))
270 h.Set("X-Foo", "2")
271 h.Set("X-Bar", "2")
272 },
273 check(
274 hasOldHeader("X-Foo", "2"),
275 hasOldHeader("X-Bar", "2"),
276 hasHeader("X-Foo", "1"),
277 hasNotHeaders("X-Bar"),
278 ),
279 },
280 {
281 "setting Content-Length header",
282 func(w http.ResponseWriter, r *http.Request) {
283 body := "Some body"
284 contentLength := fmt.Sprintf("%d", len(body))
285 w.Header().Set("Content-Length", contentLength)
286 io.WriteString(w, body)
287 },
288 check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
289 },
290 {
291 "nil ResponseRecorder.Body",
292 func(w http.ResponseWriter, r *http.Request) {
293 w.(*ResponseRecorder).Body = nil
294 io.WriteString(w, "hi")
295 },
296 check(hasResultContents("")),
297
298 },
299 } {
300 t.Run(tt.name, func(t *testing.T) {
301 r, _ := http.NewRequest("GET", "http://foo.com/", nil)
302 h := http.HandlerFunc(tt.h)
303 rec := NewRecorder()
304 h.ServeHTTP(rec, r)
305 for _, check := range tt.checks {
306 if err := check(rec); err != nil {
307 t.Error(err)
308 }
309 }
310 })
311 }
312 }
313
314 func TestBodyNotAllowed(t *testing.T) {
315 rw := NewRecorder()
316 rw.Body = new(bytes.Buffer)
317 rw.WriteHeader(204)
318
319 _, err := rw.Write([]byte("hello "))
320 if !errors.Is(err, http.ErrBodyNotAllowed) {
321 t.Errorf("expected BodyNotAllowed for Write after 204, got: %v", err)
322 }
323
324 _, err = rw.WriteString("world")
325 if !errors.Is(err, http.ErrBodyNotAllowed) {
326 t.Errorf("expected BodyNotAllowed for WriteString after 204, got: %v", err)
327 }
328
329 if got, want := rw.Body.String(), "hello world"; got != want {
330 t.Errorf("got Body=%q, want %q", got, want)
331 }
332 }
333
334
335 func TestParseContentLength(t *testing.T) {
336 tests := []struct {
337 cl string
338 want int64
339 }{
340 {
341 cl: "3",
342 want: 3,
343 },
344 {
345 cl: "+3",
346 want: -1,
347 },
348 {
349 cl: "-3",
350 want: -1,
351 },
352 {
353
354 cl: "9223372036854775807",
355 want: 9223372036854775807,
356 },
357 {
358 cl: "9223372036854775808",
359 want: -1,
360 },
361 }
362
363 for _, tt := range tests {
364 if got := parseContentLength(tt.cl); got != tt.want {
365 t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
366 }
367 }
368 }
369
370
371
372 func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
373 badCodes := []int{
374 -100, 0, 99, 1000, 20000,
375 }
376 for _, badCode := range badCodes {
377 t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
378 defer func() {
379 if r := recover(); r == nil {
380 t.Fatal("Expected a panic")
381 }
382 }()
383
384 handler := func(rw http.ResponseWriter, _ *http.Request) {
385 rw.WriteHeader(badCode)
386 }
387 r, _ := http.NewRequest("GET", "http://example.org/", nil)
388 rw := NewRecorder()
389 handler(rw, r)
390 })
391 }
392 }
393
View as plain text