// Copyright 2012 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package httptest import ( "bufio" "io" "net" "net/http" "testing" ) type newServerFunc func(http.Handler) *Server var newServers = map[string]newServerFunc{ "NewServer": NewServer, "NewTLSServer": NewTLSServer, // The manual variants of newServer create a Server manually by only filling // in the exported fields of Server. "NewServerManual": func(h http.Handler) *Server { ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} ts.Start() return ts }, "NewTLSServerManual": func(h http.Handler) *Server { ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} ts.StartTLS() return ts }, } func TestServer(t *testing.T) { for _, name := range []string{"NewServer", "NewServerManual"} { t.Run(name, func(t *testing.T) { newServer := newServers[name] t.Run("Server", func(t *testing.T) { testServer(t, newServer) }) t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) }) t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) }) t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) }) t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) }) }) } for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} { t.Run(name, func(t *testing.T) { newServer := newServers[name] t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) }) t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) }) }) } } func testServer(t *testing.T, newServer newServerFunc) { ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) defer ts.Close() res, err := http.Get(ts.URL) if err != nil { t.Fatal(err) } got, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) } if string(got) != "hello" { t.Errorf("got %q, want hello", string(got)) } } // Issue 12781 func testGetAfterClose(t *testing.T, newServer newServerFunc) { ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) res, err := http.Get(ts.URL) if err != nil { t.Fatal(err) } got, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } if string(got) != "hello" { t.Fatalf("got %q, want hello", string(got)) } ts.Close() res, err = http.Get(ts.URL) if err == nil { body, _ := io.ReadAll(res.Body) t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body) } } func testServerCloseBlocking(t *testing.T, newServer newServerFunc) { ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) dial := func() net.Conn { c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } return c } // Keep one connection in StateNew (connected, but not sending anything) cnew := dial() defer cnew.Close() // Keep one connection in StateIdle (idle after a request) cidle := dial() defer cidle.Close() cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) _, err := http.ReadResponse(bufio.NewReader(cidle), nil) if err != nil { t.Fatal(err) } ts.Close() // test we don't hang here forever. } // Issue 14290 func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) { var s *Server s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.CloseClientConnections() })) defer s.Close() res, err := http.Get(s.URL) if err == nil { res.Body.Close() t.Fatalf("Unexpected response: %#v", res) } } // Tests that the Server.Client method works and returns an http.Client that can hit // NewTLSServer without cert warnings. func testServerClient(t *testing.T, newTLSServer newServerFunc) { ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) defer ts.Close() client := ts.Client() res, err := client.Get(ts.URL) if err != nil { t.Fatal(err) } got, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) } if string(got) != "hello" { t.Errorf("got %q, want hello", string(got)) } } // Tests that the Server.Client.Transport interface is implemented // by a *http.Transport. func testServerClientTransportType(t *testing.T, newServer newServerFunc) { ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) defer ts.Close() client := ts.Client() if _, ok := client.Transport.(*http.Transport); !ok { t.Errorf("got %T, want *http.Transport", client.Transport) } } // Tests that the TLS Server.Client.Transport interface is implemented // by a *http.Transport. func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) { ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) defer ts.Close() client := ts.Client() if _, ok := client.Transport.(*http.Transport); !ok { t.Errorf("got %T, want *http.Transport", client.Transport) } } type onlyCloseListener struct { net.Listener } func (onlyCloseListener) Close() error { return nil } // Issue 19729: panic in Server.Close for values created directly // without a constructor (so the unexported client field is nil). func TestServerZeroValueClose(t *testing.T) { ts := &Server{ Listener: onlyCloseListener{}, Config: &http.Server{}, } ts.Close() // tests that it doesn't panic } func TestTLSServerWithHTTP2(t *testing.T) { modes := []struct { name string wantProto string }{ {"http1", "HTTP/1.1"}, {"http2", "HTTP/2.0"}, } for _, tt := range modes { t.Run(tt.name, func(t *testing.T) { cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Proto", r.Proto) })) switch tt.name { case "http2": cst.EnableHTTP2 = true cst.StartTLS() default: cst.Start() } defer cst.Close() res, err := cst.Client().Get(cst.URL) if err != nil { t.Fatalf("Failed to make request: %v", err) } if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w { t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w) } }) } }