Black Lives Matter. Support the Equal Justice Initiative.

Source file src/net/http/httptest/server_test.go

Documentation: net/http/httptest

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bufio"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"testing"
    13  )
    14  
    15  type newServerFunc func(http.Handler) *Server
    16  
    17  var newServers = map[string]newServerFunc{
    18  	"NewServer":    NewServer,
    19  	"NewTLSServer": NewTLSServer,
    20  
    21  	// The manual variants of newServer create a Server manually by only filling
    22  	// in the exported fields of Server.
    23  	"NewServerManual": func(h http.Handler) *Server {
    24  		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
    25  		ts.Start()
    26  		return ts
    27  	},
    28  	"NewTLSServerManual": func(h http.Handler) *Server {
    29  		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
    30  		ts.StartTLS()
    31  		return ts
    32  	},
    33  }
    34  
    35  func TestServer(t *testing.T) {
    36  	for _, name := range []string{"NewServer", "NewServerManual"} {
    37  		t.Run(name, func(t *testing.T) {
    38  			newServer := newServers[name]
    39  			t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
    40  			t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
    41  			t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
    42  			t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
    43  			t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
    44  		})
    45  	}
    46  	for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
    47  		t.Run(name, func(t *testing.T) {
    48  			newServer := newServers[name]
    49  			t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
    50  			t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
    51  		})
    52  	}
    53  }
    54  
    55  func testServer(t *testing.T, newServer newServerFunc) {
    56  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    57  		w.Write([]byte("hello"))
    58  	}))
    59  	defer ts.Close()
    60  	res, err := http.Get(ts.URL)
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  	got, err := io.ReadAll(res.Body)
    65  	res.Body.Close()
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  	if string(got) != "hello" {
    70  		t.Errorf("got %q, want hello", string(got))
    71  	}
    72  }
    73  
    74  // Issue 12781
    75  func testGetAfterClose(t *testing.T, newServer newServerFunc) {
    76  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    77  		w.Write([]byte("hello"))
    78  	}))
    79  
    80  	res, err := http.Get(ts.URL)
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	got, err := io.ReadAll(res.Body)
    85  	if err != nil {
    86  		t.Fatal(err)
    87  	}
    88  	if string(got) != "hello" {
    89  		t.Fatalf("got %q, want hello", string(got))
    90  	}
    91  
    92  	ts.Close()
    93  
    94  	res, err = http.Get(ts.URL)
    95  	if err == nil {
    96  		body, _ := io.ReadAll(res.Body)
    97  		t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
    98  	}
    99  }
   100  
   101  func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
   102  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   103  		w.Write([]byte("hello"))
   104  	}))
   105  	dial := func() net.Conn {
   106  		c, err := net.Dial("tcp", ts.Listener.Addr().String())
   107  		if err != nil {
   108  			t.Fatal(err)
   109  		}
   110  		return c
   111  	}
   112  
   113  	// Keep one connection in StateNew (connected, but not sending anything)
   114  	cnew := dial()
   115  	defer cnew.Close()
   116  
   117  	// Keep one connection in StateIdle (idle after a request)
   118  	cidle := dial()
   119  	defer cidle.Close()
   120  	cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
   121  	_, err := http.ReadResponse(bufio.NewReader(cidle), nil)
   122  	if err != nil {
   123  		t.Fatal(err)
   124  	}
   125  
   126  	ts.Close() // test we don't hang here forever.
   127  }
   128  
   129  // Issue 14290
   130  func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
   131  	var s *Server
   132  	s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   133  		s.CloseClientConnections()
   134  	}))
   135  	defer s.Close()
   136  	res, err := http.Get(s.URL)
   137  	if err == nil {
   138  		res.Body.Close()
   139  		t.Fatalf("Unexpected response: %#v", res)
   140  	}
   141  }
   142  
   143  // Tests that the Server.Client method works and returns an http.Client that can hit
   144  // NewTLSServer without cert warnings.
   145  func testServerClient(t *testing.T, newTLSServer newServerFunc) {
   146  	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   147  		w.Write([]byte("hello"))
   148  	}))
   149  	defer ts.Close()
   150  	client := ts.Client()
   151  	res, err := client.Get(ts.URL)
   152  	if err != nil {
   153  		t.Fatal(err)
   154  	}
   155  	got, err := io.ReadAll(res.Body)
   156  	res.Body.Close()
   157  	if err != nil {
   158  		t.Fatal(err)
   159  	}
   160  	if string(got) != "hello" {
   161  		t.Errorf("got %q, want hello", string(got))
   162  	}
   163  }
   164  
   165  // Tests that the Server.Client.Transport interface is implemented
   166  // by a *http.Transport.
   167  func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
   168  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   169  	}))
   170  	defer ts.Close()
   171  	client := ts.Client()
   172  	if _, ok := client.Transport.(*http.Transport); !ok {
   173  		t.Errorf("got %T, want *http.Transport", client.Transport)
   174  	}
   175  }
   176  
   177  // Tests that the TLS Server.Client.Transport interface is implemented
   178  // by a *http.Transport.
   179  func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
   180  	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   181  	}))
   182  	defer ts.Close()
   183  	client := ts.Client()
   184  	if _, ok := client.Transport.(*http.Transport); !ok {
   185  		t.Errorf("got %T, want *http.Transport", client.Transport)
   186  	}
   187  }
   188  
   189  type onlyCloseListener struct {
   190  	net.Listener
   191  }
   192  
   193  func (onlyCloseListener) Close() error { return nil }
   194  
   195  // Issue 19729: panic in Server.Close for values created directly
   196  // without a constructor (so the unexported client field is nil).
   197  func TestServerZeroValueClose(t *testing.T) {
   198  	ts := &Server{
   199  		Listener: onlyCloseListener{},
   200  		Config:   &http.Server{},
   201  	}
   202  
   203  	ts.Close() // tests that it doesn't panic
   204  }
   205  
   206  func TestTLSServerWithHTTP2(t *testing.T) {
   207  	modes := []struct {
   208  		name      string
   209  		wantProto string
   210  	}{
   211  		{"http1", "HTTP/1.1"},
   212  		{"http2", "HTTP/2.0"},
   213  	}
   214  
   215  	for _, tt := range modes {
   216  		t.Run(tt.name, func(t *testing.T) {
   217  			cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   218  				w.Header().Set("X-Proto", r.Proto)
   219  			}))
   220  
   221  			switch tt.name {
   222  			case "http2":
   223  				cst.EnableHTTP2 = true
   224  				cst.StartTLS()
   225  			default:
   226  				cst.Start()
   227  			}
   228  
   229  			defer cst.Close()
   230  
   231  			res, err := cst.Client().Get(cst.URL)
   232  			if err != nil {
   233  				t.Fatalf("Failed to make request: %v", err)
   234  			}
   235  			if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
   236  				t.Fatalf("X-Proto header mismatch:\n\tgot:  %q\n\twant: %q", g, w)
   237  			}
   238  		})
   239  	}
   240  }
   241  

View as plain text