Black Lives Matter. Support the Equal Justice Initiative.

Source file src/net/http/transport_test.go

Documentation: net/http

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests for transport.go.
     6  //
// More tests are in clientserver_test.go (for things testing both client & server for both
// HTTP/1 and HTTP/2).
     9  
    10  package http_test
    11  
    12  import (
    13  	"bufio"
    14  	"bytes"
    15  	"compress/gzip"
    16  	"context"
    17  	"crypto/rand"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"encoding/binary"
    21  	"errors"
    22  	"fmt"
    23  	"go/token"
    24  	"internal/nettrace"
    25  	"io"
    26  	"log"
    27  	mrand "math/rand"
    28  	"net"
    29  	. "net/http"
    30  	"net/http/httptest"
    31  	"net/http/httptrace"
    32  	"net/http/httputil"
    33  	"net/http/internal/testcert"
    34  	"net/textproto"
    35  	"net/url"
    36  	"os"
    37  	"reflect"
    38  	"runtime"
    39  	"strconv"
    40  	"strings"
    41  	"sync"
    42  	"sync/atomic"
    43  	"testing"
    44  	"testing/iotest"
    45  	"time"
    46  
    47  	"golang.org/x/net/http/httpguts"
    48  )
    49  
    50  // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
    51  //       and then verify that the final 2 responses get errors back.
    52  
    53  // hostPortHandler writes back the client's "host:port".
    54  var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
    55  	if r.FormValue("close") == "true" {
    56  		w.Header().Set("Connection", "close")
    57  	}
    58  	w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
    59  	w.Write([]byte(r.RemoteAddr))
    60  })
    61  
    62  // testCloseConn is a net.Conn tracked by a testConnSet.
    63  type testCloseConn struct {
    64  	net.Conn
    65  	set *testConnSet
    66  }
    67  
    68  func (c *testCloseConn) Close() error {
    69  	c.set.remove(c)
    70  	return c.Conn.Close()
    71  }
    72  
    73  // testConnSet tracks a set of TCP connections and whether they've
    74  // been closed.
    75  type testConnSet struct {
    76  	t      *testing.T
    77  	mu     sync.Mutex // guards closed and list
    78  	closed map[net.Conn]bool
    79  	list   []net.Conn // in order created
    80  }
    81  
    82  func (tcs *testConnSet) insert(c net.Conn) {
    83  	tcs.mu.Lock()
    84  	defer tcs.mu.Unlock()
    85  	tcs.closed[c] = false
    86  	tcs.list = append(tcs.list, c)
    87  }
    88  
    89  func (tcs *testConnSet) remove(c net.Conn) {
    90  	tcs.mu.Lock()
    91  	defer tcs.mu.Unlock()
    92  	tcs.closed[c] = true
    93  }
    94  
    95  // some tests use this to manage raw tcp connections for later inspection
    96  func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
    97  	connSet := &testConnSet{
    98  		t:      t,
    99  		closed: make(map[net.Conn]bool),
   100  	}
   101  	dial := func(n, addr string) (net.Conn, error) {
   102  		c, err := net.Dial(n, addr)
   103  		if err != nil {
   104  			return nil, err
   105  		}
   106  		tc := &testCloseConn{c, connSet}
   107  		connSet.insert(tc)
   108  		return tc, nil
   109  	}
   110  	return connSet, dial
   111  }
   112  
   113  func (tcs *testConnSet) check(t *testing.T) {
   114  	tcs.mu.Lock()
   115  	defer tcs.mu.Unlock()
   116  	for i := 4; i >= 0; i-- {
   117  		for i, c := range tcs.list {
   118  			if tcs.closed[c] {
   119  				continue
   120  			}
   121  			if i != 0 {
   122  				tcs.mu.Unlock()
   123  				time.Sleep(50 * time.Millisecond)
   124  				tcs.mu.Lock()
   125  				continue
   126  			}
   127  			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
   128  		}
   129  	}
   130  }
   131  
   132  func TestReuseRequest(t *testing.T) {
   133  	defer afterTest(t)
   134  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   135  		w.Write([]byte("{}"))
   136  	}))
   137  	defer ts.Close()
   138  
   139  	c := ts.Client()
   140  	req, _ := NewRequest("GET", ts.URL, nil)
   141  	res, err := c.Do(req)
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  	err = res.Body.Close()
   146  	if err != nil {
   147  		t.Fatal(err)
   148  	}
   149  
   150  	res, err = c.Do(req)
   151  	if err != nil {
   152  		t.Fatal(err)
   153  	}
   154  	err = res.Body.Close()
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  }
   159  
   160  // Two subsequent requests and verify their response is the same.
   161  // The response from the server is our own IP:port
   162  func TestTransportKeepAlives(t *testing.T) {
   163  	defer afterTest(t)
   164  	ts := httptest.NewServer(hostPortHandler)
   165  	defer ts.Close()
   166  
   167  	c := ts.Client()
   168  	for _, disableKeepAlive := range []bool{false, true} {
   169  		c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
   170  		fetch := func(n int) string {
   171  			res, err := c.Get(ts.URL)
   172  			if err != nil {
   173  				t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
   174  			}
   175  			body, err := io.ReadAll(res.Body)
   176  			if err != nil {
   177  				t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
   178  			}
   179  			return string(body)
   180  		}
   181  
   182  		body1 := fetch(1)
   183  		body2 := fetch(2)
   184  
   185  		bodiesDiffer := body1 != body2
   186  		if bodiesDiffer != disableKeepAlive {
   187  			t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
   188  				disableKeepAlive, bodiesDiffer, body1, body2)
   189  		}
   190  	}
   191  }
   192  
   193  func TestTransportConnectionCloseOnResponse(t *testing.T) {
   194  	defer afterTest(t)
   195  	ts := httptest.NewServer(hostPortHandler)
   196  	defer ts.Close()
   197  
   198  	connSet, testDial := makeTestDial(t)
   199  
   200  	c := ts.Client()
   201  	tr := c.Transport.(*Transport)
   202  	tr.Dial = testDial
   203  
   204  	for _, connectionClose := range []bool{false, true} {
   205  		fetch := func(n int) string {
   206  			req := new(Request)
   207  			var err error
   208  			req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
   209  			if err != nil {
   210  				t.Fatalf("URL parse error: %v", err)
   211  			}
   212  			req.Method = "GET"
   213  			req.Proto = "HTTP/1.1"
   214  			req.ProtoMajor = 1
   215  			req.ProtoMinor = 1
   216  
   217  			res, err := c.Do(req)
   218  			if err != nil {
   219  				t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
   220  			}
   221  			defer res.Body.Close()
   222  			body, err := io.ReadAll(res.Body)
   223  			if err != nil {
   224  				t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
   225  			}
   226  			return string(body)
   227  		}
   228  
   229  		body1 := fetch(1)
   230  		body2 := fetch(2)
   231  		bodiesDiffer := body1 != body2
   232  		if bodiesDiffer != connectionClose {
   233  			t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
   234  				connectionClose, bodiesDiffer, body1, body2)
   235  		}
   236  
   237  		tr.CloseIdleConnections()
   238  	}
   239  
   240  	connSet.check(t)
   241  }
   242  
   243  func TestTransportConnectionCloseOnRequest(t *testing.T) {
   244  	defer afterTest(t)
   245  	ts := httptest.NewServer(hostPortHandler)
   246  	defer ts.Close()
   247  
   248  	connSet, testDial := makeTestDial(t)
   249  
   250  	c := ts.Client()
   251  	tr := c.Transport.(*Transport)
   252  	tr.Dial = testDial
   253  	for _, connectionClose := range []bool{false, true} {
   254  		fetch := func(n int) string {
   255  			req := new(Request)
   256  			var err error
   257  			req.URL, err = url.Parse(ts.URL)
   258  			if err != nil {
   259  				t.Fatalf("URL parse error: %v", err)
   260  			}
   261  			req.Method = "GET"
   262  			req.Proto = "HTTP/1.1"
   263  			req.ProtoMajor = 1
   264  			req.ProtoMinor = 1
   265  			req.Close = connectionClose
   266  
   267  			res, err := c.Do(req)
   268  			if err != nil {
   269  				t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
   270  			}
   271  			if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(connectionClose); got != want {
   272  				t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v",
   273  					connectionClose, got, !connectionClose)
   274  			}
   275  			body, err := io.ReadAll(res.Body)
   276  			if err != nil {
   277  				t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
   278  			}
   279  			return string(body)
   280  		}
   281  
   282  		body1 := fetch(1)
   283  		body2 := fetch(2)
   284  		bodiesDiffer := body1 != body2
   285  		if bodiesDiffer != connectionClose {
   286  			t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
   287  				connectionClose, bodiesDiffer, body1, body2)
   288  		}
   289  
   290  		tr.CloseIdleConnections()
   291  	}
   292  
   293  	connSet.check(t)
   294  }
   295  
   296  // if the Transport's DisableKeepAlives is set, all requests should
   297  // send Connection: close.
   298  // HTTP/1-only (Connection: close doesn't exist in h2)
   299  func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
   300  	defer afterTest(t)
   301  	ts := httptest.NewServer(hostPortHandler)
   302  	defer ts.Close()
   303  
   304  	c := ts.Client()
   305  	c.Transport.(*Transport).DisableKeepAlives = true
   306  
   307  	res, err := c.Get(ts.URL)
   308  	if err != nil {
   309  		t.Fatal(err)
   310  	}
   311  	res.Body.Close()
   312  	if res.Header.Get("X-Saw-Close") != "true" {
   313  		t.Errorf("handler didn't see Connection: close ")
   314  	}
   315  }
   316  
   317  // Test that Transport only sends one "Connection: close", regardless of
   318  // how "close" was indicated.
   319  func TestTransportRespectRequestWantsClose(t *testing.T) {
   320  	tests := []struct {
   321  		disableKeepAlives bool
   322  		close             bool
   323  	}{
   324  		{disableKeepAlives: false, close: false},
   325  		{disableKeepAlives: false, close: true},
   326  		{disableKeepAlives: true, close: false},
   327  		{disableKeepAlives: true, close: true},
   328  	}
   329  
   330  	for _, tc := range tests {
   331  		t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
   332  			func(t *testing.T) {
   333  				defer afterTest(t)
   334  				ts := httptest.NewServer(hostPortHandler)
   335  				defer ts.Close()
   336  
   337  				c := ts.Client()
   338  				c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
   339  				req, err := NewRequest("GET", ts.URL, nil)
   340  				if err != nil {
   341  					t.Fatal(err)
   342  				}
   343  				count := 0
   344  				trace := &httptrace.ClientTrace{
   345  					WroteHeaderField: func(key string, field []string) {
   346  						if key != "Connection" {
   347  							return
   348  						}
   349  						if httpguts.HeaderValuesContainsToken(field, "close") {
   350  							count += 1
   351  						}
   352  					},
   353  				}
   354  				req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   355  				req.Close = tc.close
   356  				res, err := c.Do(req)
   357  				if err != nil {
   358  					t.Fatal(err)
   359  				}
   360  				defer res.Body.Close()
   361  				if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
   362  					t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
   363  				}
   364  			})
   365  	}
   366  
   367  }
   368  
   369  func TestTransportIdleCacheKeys(t *testing.T) {
   370  	defer afterTest(t)
   371  	ts := httptest.NewServer(hostPortHandler)
   372  	defer ts.Close()
   373  	c := ts.Client()
   374  	tr := c.Transport.(*Transport)
   375  
   376  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   377  		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
   378  	}
   379  
   380  	resp, err := c.Get(ts.URL)
   381  	if err != nil {
   382  		t.Error(err)
   383  	}
   384  	io.ReadAll(resp.Body)
   385  
   386  	keys := tr.IdleConnKeysForTesting()
   387  	if e, g := 1, len(keys); e != g {
   388  		t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
   389  	}
   390  
   391  	if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
   392  		t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
   393  	}
   394  
   395  	tr.CloseIdleConnections()
   396  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   397  		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
   398  	}
   399  }
   400  
   401  // Tests that the HTTP transport re-uses connections when a client
   402  // reads to the end of a response Body without closing it.
   403  func TestTransportReadToEndReusesConn(t *testing.T) {
   404  	defer afterTest(t)
   405  	const msg = "foobar"
   406  
   407  	var addrSeen map[string]int
   408  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   409  		addrSeen[r.RemoteAddr]++
   410  		if r.URL.Path == "/chunked/" {
   411  			w.WriteHeader(200)
   412  			w.(Flusher).Flush()
   413  		} else {
   414  			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
   415  			w.WriteHeader(200)
   416  		}
   417  		w.Write([]byte(msg))
   418  	}))
   419  	defer ts.Close()
   420  
   421  	buf := make([]byte, len(msg))
   422  
   423  	for pi, path := range []string{"/content-length/", "/chunked/"} {
   424  		wantLen := []int{len(msg), -1}[pi]
   425  		addrSeen = make(map[string]int)
   426  		for i := 0; i < 3; i++ {
   427  			res, err := Get(ts.URL + path)
   428  			if err != nil {
   429  				t.Errorf("Get %s: %v", path, err)
   430  				continue
   431  			}
   432  			// We want to close this body eventually (before the
   433  			// defer afterTest at top runs), but not before the
   434  			// len(addrSeen) check at the bottom of this test,
   435  			// since Closing this early in the loop would risk
   436  			// making connections be re-used for the wrong reason.
   437  			defer res.Body.Close()
   438  
   439  			if res.ContentLength != int64(wantLen) {
   440  				t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
   441  			}
   442  			n, err := res.Body.Read(buf)
   443  			if n != len(msg) || err != io.EOF {
   444  				t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg))
   445  			}
   446  		}
   447  		if len(addrSeen) != 1 {
   448  			t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
   449  		}
   450  	}
   451  }
   452  
   453  func TestTransportMaxPerHostIdleConns(t *testing.T) {
   454  	defer afterTest(t)
   455  	stop := make(chan struct{}) // stop marks the exit of main Test goroutine
   456  	defer close(stop)
   457  
   458  	resch := make(chan string)
   459  	gotReq := make(chan bool)
   460  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   461  		gotReq <- true
   462  		var msg string
   463  		select {
   464  		case <-stop:
   465  			return
   466  		case msg = <-resch:
   467  		}
   468  		_, err := w.Write([]byte(msg))
   469  		if err != nil {
   470  			t.Errorf("Write: %v", err)
   471  			return
   472  		}
   473  	}))
   474  	defer ts.Close()
   475  
   476  	c := ts.Client()
   477  	tr := c.Transport.(*Transport)
   478  	maxIdleConnsPerHost := 2
   479  	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
   480  
   481  	// Start 3 outstanding requests and wait for the server to get them.
   482  	// Their responses will hang until we write to resch, though.
   483  	donech := make(chan bool)
   484  	doReq := func() {
   485  		defer func() {
   486  			select {
   487  			case <-stop:
   488  				return
   489  			case donech <- t.Failed():
   490  			}
   491  		}()
   492  		resp, err := c.Get(ts.URL)
   493  		if err != nil {
   494  			t.Error(err)
   495  			return
   496  		}
   497  		if _, err := io.ReadAll(resp.Body); err != nil {
   498  			t.Errorf("ReadAll: %v", err)
   499  			return
   500  		}
   501  	}
   502  	go doReq()
   503  	<-gotReq
   504  	go doReq()
   505  	<-gotReq
   506  	go doReq()
   507  	<-gotReq
   508  
   509  	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
   510  		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
   511  	}
   512  
   513  	resch <- "res1"
   514  	<-donech
   515  	keys := tr.IdleConnKeysForTesting()
   516  	if e, g := 1, len(keys); e != g {
   517  		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
   518  	}
   519  	addr := ts.Listener.Addr().String()
   520  	cacheKey := "|http|" + addr
   521  	if keys[0] != cacheKey {
   522  		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
   523  	}
   524  	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
   525  		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
   526  	}
   527  
   528  	resch <- "res2"
   529  	<-donech
   530  	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
   531  		t.Errorf("after second response, idle conns = %d; want %d", g, w)
   532  	}
   533  
   534  	resch <- "res3"
   535  	<-donech
   536  	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
   537  		t.Errorf("after third response, idle conns = %d; want %d", g, w)
   538  	}
   539  }
   540  
   541  func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
   542  	defer afterTest(t)
   543  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   544  		_, err := w.Write([]byte("foo"))
   545  		if err != nil {
   546  			t.Fatalf("Write: %v", err)
   547  		}
   548  	}))
   549  	defer ts.Close()
   550  	c := ts.Client()
   551  	tr := c.Transport.(*Transport)
   552  	dialStarted := make(chan struct{})
   553  	stallDial := make(chan struct{})
   554  	tr.Dial = func(network, addr string) (net.Conn, error) {
   555  		dialStarted <- struct{}{}
   556  		<-stallDial
   557  		return net.Dial(network, addr)
   558  	}
   559  
   560  	tr.DisableKeepAlives = true
   561  	tr.MaxConnsPerHost = 1
   562  
   563  	preDial := make(chan struct{})
   564  	reqComplete := make(chan struct{})
   565  	doReq := func(reqId string) {
   566  		req, _ := NewRequest("GET", ts.URL, nil)
   567  		trace := &httptrace.ClientTrace{
   568  			GetConn: func(hostPort string) {
   569  				preDial <- struct{}{}
   570  			},
   571  		}
   572  		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   573  		resp, err := tr.RoundTrip(req)
   574  		if err != nil {
   575  			t.Errorf("unexpected error for request %s: %v", reqId, err)
   576  		}
   577  		_, err = io.ReadAll(resp.Body)
   578  		if err != nil {
   579  			t.Errorf("unexpected error for request %s: %v", reqId, err)
   580  		}
   581  		reqComplete <- struct{}{}
   582  	}
   583  	// get req1 to dial-in-progress
   584  	go doReq("req1")
   585  	<-preDial
   586  	<-dialStarted
   587  
   588  	// get req2 to waiting on conns per host to go down below max
   589  	go doReq("req2")
   590  	<-preDial
   591  	select {
   592  	case <-dialStarted:
   593  		t.Error("req2 dial started while req1 dial in progress")
   594  		return
   595  	default:
   596  	}
   597  
   598  	// let req1 complete
   599  	stallDial <- struct{}{}
   600  	<-reqComplete
   601  
   602  	// let req2 complete
   603  	<-dialStarted
   604  	stallDial <- struct{}{}
   605  	<-reqComplete
   606  }
   607  
   608  func TestTransportMaxConnsPerHost(t *testing.T) {
   609  	defer afterTest(t)
   610  	CondSkipHTTP2(t)
   611  
   612  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   613  		_, err := w.Write([]byte("foo"))
   614  		if err != nil {
   615  			t.Fatalf("Write: %v", err)
   616  		}
   617  	})
   618  
   619  	testMaxConns := func(scheme string, ts *httptest.Server) {
   620  		defer ts.Close()
   621  
   622  		c := ts.Client()
   623  		tr := c.Transport.(*Transport)
   624  		tr.MaxConnsPerHost = 1
   625  		if err := ExportHttp2ConfigureTransport(tr); err != nil {
   626  			t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
   627  		}
   628  
   629  		mu := sync.Mutex{}
   630  		var conns []net.Conn
   631  		var dialCnt, gotConnCnt, tlsHandshakeCnt int32
   632  		tr.Dial = func(network, addr string) (net.Conn, error) {
   633  			atomic.AddInt32(&dialCnt, 1)
   634  			c, err := net.Dial(network, addr)
   635  			mu.Lock()
   636  			defer mu.Unlock()
   637  			conns = append(conns, c)
   638  			return c, err
   639  		}
   640  
   641  		doReq := func() {
   642  			trace := &httptrace.ClientTrace{
   643  				GotConn: func(connInfo httptrace.GotConnInfo) {
   644  					if !connInfo.Reused {
   645  						atomic.AddInt32(&gotConnCnt, 1)
   646  					}
   647  				},
   648  				TLSHandshakeStart: func() {
   649  					atomic.AddInt32(&tlsHandshakeCnt, 1)
   650  				},
   651  			}
   652  			req, _ := NewRequest("GET", ts.URL, nil)
   653  			req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
   654  
   655  			resp, err := c.Do(req)
   656  			if err != nil {
   657  				t.Fatalf("request failed: %v", err)
   658  			}
   659  			defer resp.Body.Close()
   660  			_, err = io.ReadAll(resp.Body)
   661  			if err != nil {
   662  				t.Fatalf("read body failed: %v", err)
   663  			}
   664  		}
   665  
   666  		wg := sync.WaitGroup{}
   667  		for i := 0; i < 10; i++ {
   668  			wg.Add(1)
   669  			go func() {
   670  				defer wg.Done()
   671  				doReq()
   672  			}()
   673  		}
   674  		wg.Wait()
   675  
   676  		expected := int32(tr.MaxConnsPerHost)
   677  		if dialCnt != expected {
   678  			t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected)
   679  		}
   680  		if gotConnCnt != expected {
   681  			t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected)
   682  		}
   683  		if ts.TLS != nil && tlsHandshakeCnt != expected {
   684  			t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected)
   685  		}
   686  
   687  		if t.Failed() {
   688  			t.FailNow()
   689  		}
   690  
   691  		mu.Lock()
   692  		for _, c := range conns {
   693  			c.Close()
   694  		}
   695  		conns = nil
   696  		mu.Unlock()
   697  		tr.CloseIdleConnections()
   698  
   699  		doReq()
   700  		expected++
   701  		if dialCnt != expected {
   702  			t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt)
   703  		}
   704  		if gotConnCnt != expected {
   705  			t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected)
   706  		}
   707  		if ts.TLS != nil && tlsHandshakeCnt != expected {
   708  			t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected)
   709  		}
   710  	}
   711  
   712  	testMaxConns("http", httptest.NewServer(h))
   713  	testMaxConns("https", httptest.NewTLSServer(h))
   714  
   715  	ts := httptest.NewUnstartedServer(h)
   716  	ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
   717  	ts.StartTLS()
   718  	testMaxConns("http2", ts)
   719  }
   720  
   721  func TestTransportRemovesDeadIdleConnections(t *testing.T) {
   722  	setParallel(t)
   723  	defer afterTest(t)
   724  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   725  		io.WriteString(w, r.RemoteAddr)
   726  	}))
   727  	defer ts.Close()
   728  
   729  	c := ts.Client()
   730  	tr := c.Transport.(*Transport)
   731  
   732  	doReq := func(name string) string {
   733  		// Do a POST instead of a GET to prevent the Transport's
   734  		// idempotent request retry logic from kicking in...
   735  		res, err := c.Post(ts.URL, "", nil)
   736  		if err != nil {
   737  			t.Fatalf("%s: %v", name, err)
   738  		}
   739  		if res.StatusCode != 200 {
   740  			t.Fatalf("%s: %v", name, res.Status)
   741  		}
   742  		defer res.Body.Close()
   743  		slurp, err := io.ReadAll(res.Body)
   744  		if err != nil {
   745  			t.Fatalf("%s: %v", name, err)
   746  		}
   747  		return string(slurp)
   748  	}
   749  
   750  	first := doReq("first")
   751  	keys1 := tr.IdleConnKeysForTesting()
   752  
   753  	ts.CloseClientConnections()
   754  
   755  	var keys2 []string
   756  	if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool {
   757  		keys2 = tr.IdleConnKeysForTesting()
   758  		return len(keys2) == 0
   759  	}) {
   760  		t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2)
   761  	}
   762  
   763  	second := doReq("second")
   764  	if first == second {
   765  		t.Errorf("expected a different connection between requests. got %q both times", first)
   766  	}
   767  }
   768  
   769  // Test that the Transport notices when a server hangs up on its
   770  // unexpectedly (a keep-alive connection is closed).
   771  func TestTransportServerClosingUnexpectedly(t *testing.T) {
   772  	setParallel(t)
   773  	defer afterTest(t)
   774  	ts := httptest.NewServer(hostPortHandler)
   775  	defer ts.Close()
   776  	c := ts.Client()
   777  
   778  	fetch := func(n, retries int) string {
   779  		condFatalf := func(format string, arg ...interface{}) {
   780  			if retries <= 0 {
   781  				t.Fatalf(format, arg...)
   782  			}
   783  			t.Logf("retrying shortly after expected error: "+format, arg...)
   784  			time.Sleep(time.Second / time.Duration(retries))
   785  		}
   786  		for retries >= 0 {
   787  			retries--
   788  			res, err := c.Get(ts.URL)
   789  			if err != nil {
   790  				condFatalf("error in req #%d, GET: %v", n, err)
   791  				continue
   792  			}
   793  			body, err := io.ReadAll(res.Body)
   794  			if err != nil {
   795  				condFatalf("error in req #%d, ReadAll: %v", n, err)
   796  				continue
   797  			}
   798  			res.Body.Close()
   799  			return string(body)
   800  		}
   801  		panic("unreachable")
   802  	}
   803  
   804  	body1 := fetch(1, 0)
   805  	body2 := fetch(2, 0)
   806  
   807  	// Close all the idle connections in a way that's similar to
   808  	// the server hanging up on us. We don't use
   809  	// httptest.Server.CloseClientConnections because it's
   810  	// best-effort and stops blocking after 5 seconds. On a loaded
   811  	// machine running many tests concurrently it's possible for
   812  	// that method to be async and cause the body3 fetch below to
   813  	// run on an old connection. This function is synchronous.
   814  	ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
   815  
   816  	body3 := fetch(3, 5)
   817  
   818  	if body1 != body2 {
   819  		t.Errorf("expected body1 and body2 to be equal")
   820  	}
   821  	if body2 == body3 {
   822  		t.Errorf("expected body2 and body3 to be different")
   823  	}
   824  }
   825  
   826  // Test for https://golang.org/issue/2616 (appropriate issue number)
   827  // This fails pretty reliably with GOMAXPROCS=100 or something high.
   828  func TestStressSurpriseServerCloses(t *testing.T) {
   829  	defer afterTest(t)
   830  	if testing.Short() {
   831  		t.Skip("skipping test in short mode")
   832  	}
   833  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   834  		w.Header().Set("Content-Length", "5")
   835  		w.Header().Set("Content-Type", "text/plain")
   836  		w.Write([]byte("Hello"))
   837  		w.(Flusher).Flush()
   838  		conn, buf, _ := w.(Hijacker).Hijack()
   839  		buf.Flush()
   840  		conn.Close()
   841  	}))
   842  	defer ts.Close()
   843  	c := ts.Client()
   844  
   845  	// Do a bunch of traffic from different goroutines. Send to activityc
   846  	// after each request completes, regardless of whether it failed.
   847  	// If these are too high, OS X exhausts its ephemeral ports
   848  	// and hangs waiting for them to transition TCP states. That's
   849  	// not what we want to test. TODO(bradfitz): use an io.Pipe
   850  	// dialer for this test instead?
   851  	const (
   852  		numClients    = 20
   853  		reqsPerClient = 25
   854  	)
   855  	activityc := make(chan bool)
   856  	for i := 0; i < numClients; i++ {
   857  		go func() {
   858  			for i := 0; i < reqsPerClient; i++ {
   859  				res, err := c.Get(ts.URL)
   860  				if err == nil {
   861  					// We expect errors since the server is
   862  					// hanging up on us after telling us to
   863  					// send more requests, so we don't
   864  					// actually care what the error is.
   865  					// But we want to close the body in cases
   866  					// where we won the race.
   867  					res.Body.Close()
   868  				}
   869  				if !<-activityc { // Receives false when close(activityc) is executed
   870  					return
   871  				}
   872  			}
   873  		}()
   874  	}
   875  
   876  	// Make sure all the request come back, one way or another.
   877  	for i := 0; i < numClients*reqsPerClient; i++ {
   878  		select {
   879  		case activityc <- true:
   880  		case <-time.After(5 * time.Second):
   881  			close(activityc)
   882  			t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile")
   883  		}
   884  	}
   885  }
   886  
   887  // TestTransportHeadResponses verifies that we deal with Content-Lengths
   888  // with no bodies properly
   889  func TestTransportHeadResponses(t *testing.T) {
   890  	defer afterTest(t)
   891  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   892  		if r.Method != "HEAD" {
   893  			panic("expected HEAD; got " + r.Method)
   894  		}
   895  		w.Header().Set("Content-Length", "123")
   896  		w.WriteHeader(200)
   897  	}))
   898  	defer ts.Close()
   899  	c := ts.Client()
   900  
   901  	for i := 0; i < 2; i++ {
   902  		res, err := c.Head(ts.URL)
   903  		if err != nil {
   904  			t.Errorf("error on loop %d: %v", i, err)
   905  			continue
   906  		}
   907  		if e, g := "123", res.Header.Get("Content-Length"); e != g {
   908  			t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
   909  		}
   910  		if e, g := int64(123), res.ContentLength; e != g {
   911  			t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
   912  		}
   913  		if all, err := io.ReadAll(res.Body); err != nil {
   914  			t.Errorf("loop %d: Body ReadAll: %v", i, err)
   915  		} else if len(all) != 0 {
   916  			t.Errorf("Bogus body %q", all)
   917  		}
   918  	}
   919  }
   920  
   921  // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
   922  // on responses to HEAD requests.
   923  func TestTransportHeadChunkedResponse(t *testing.T) {
   924  	defer afterTest(t)
   925  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
   926  		if r.Method != "HEAD" {
   927  			panic("expected HEAD; got " + r.Method)
   928  		}
   929  		w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
   930  		w.Header().Set("x-client-ipport", r.RemoteAddr)
   931  		w.WriteHeader(200)
   932  	}))
   933  	defer ts.Close()
   934  	c := ts.Client()
   935  
   936  	// Ensure that we wait for the readLoop to complete before
   937  	// calling Head again
   938  	didRead := make(chan bool)
   939  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
   940  	defer SetReadLoopBeforeNextReadHook(nil)
   941  
   942  	res1, err := c.Head(ts.URL)
   943  	<-didRead
   944  
   945  	if err != nil {
   946  		t.Fatalf("request 1 error: %v", err)
   947  	}
   948  
   949  	res2, err := c.Head(ts.URL)
   950  	<-didRead
   951  
   952  	if err != nil {
   953  		t.Fatalf("request 2 error: %v", err)
   954  	}
   955  	if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
   956  		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
   957  	}
   958  }
   959  
   960  var roundTripTests = []struct {
   961  	accept       string
   962  	expectAccept string
   963  	compressed   bool
   964  }{
   965  	// Requests with no accept-encoding header use transparent compression
   966  	{"", "gzip", false},
   967  	// Requests with other accept-encoding should pass through unmodified
   968  	{"foo", "foo", false},
   969  	// Requests with accept-encoding == gzip should be passed through
   970  	{"gzip", "gzip", true},
   971  }
   972  
   973  // Test that the modification made to the Request by the RoundTripper is cleaned up
   974  func TestRoundTripGzip(t *testing.T) {
   975  	setParallel(t)
   976  	defer afterTest(t)
   977  	const responseBody = "test response body"
   978  	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
   979  		accept := req.Header.Get("Accept-Encoding")
   980  		if expect := req.FormValue("expect_accept"); accept != expect {
   981  			t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
   982  				req.FormValue("testnum"), accept, expect)
   983  		}
   984  		if accept == "gzip" {
   985  			rw.Header().Set("Content-Encoding", "gzip")
   986  			gz := gzip.NewWriter(rw)
   987  			gz.Write([]byte(responseBody))
   988  			gz.Close()
   989  		} else {
   990  			rw.Header().Set("Content-Encoding", accept)
   991  			rw.Write([]byte(responseBody))
   992  		}
   993  	}))
   994  	defer ts.Close()
   995  	tr := ts.Client().Transport.(*Transport)
   996  
   997  	for i, test := range roundTripTests {
   998  		// Test basic request (no accept-encoding)
   999  		req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
  1000  		if test.accept != "" {
  1001  			req.Header.Set("Accept-Encoding", test.accept)
  1002  		}
  1003  		res, err := tr.RoundTrip(req)
  1004  		if err != nil {
  1005  			t.Errorf("%d. RoundTrip: %v", i, err)
  1006  			continue
  1007  		}
  1008  		var body []byte
  1009  		if test.compressed {
  1010  			var r *gzip.Reader
  1011  			r, err = gzip.NewReader(res.Body)
  1012  			if err != nil {
  1013  				t.Errorf("%d. gzip NewReader: %v", i, err)
  1014  				continue
  1015  			}
  1016  			body, err = io.ReadAll(r)
  1017  			res.Body.Close()
  1018  		} else {
  1019  			body, err = io.ReadAll(res.Body)
  1020  		}
  1021  		if err != nil {
  1022  			t.Errorf("%d. Error: %q", i, err)
  1023  			continue
  1024  		}
  1025  		if g, e := string(body), responseBody; g != e {
  1026  			t.Errorf("%d. body = %q; want %q", i, g, e)
  1027  		}
  1028  		if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
  1029  			t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
  1030  		}
  1031  		if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
  1032  			t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
  1033  		}
  1034  	}
  1035  
  1036  }
  1037  
  1038  func TestTransportGzip(t *testing.T) {
  1039  	setParallel(t)
  1040  	defer afterTest(t)
  1041  	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
  1042  	const nRandBytes = 1024 * 1024
  1043  	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
  1044  		if req.Method == "HEAD" {
  1045  			if g := req.Header.Get("Accept-Encoding"); g != "" {
  1046  				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
  1047  			}
  1048  			return
  1049  		}
  1050  		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
  1051  			t.Errorf("Accept-Encoding = %q, want %q", g, e)
  1052  		}
  1053  		rw.Header().Set("Content-Encoding", "gzip")
  1054  
  1055  		var w io.Writer = rw
  1056  		var buf bytes.Buffer
  1057  		if req.FormValue("chunked") == "0" {
  1058  			w = &buf
  1059  			defer io.Copy(rw, &buf)
  1060  			defer func() {
  1061  				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
  1062  			}()
  1063  		}
  1064  		gz := gzip.NewWriter(w)
  1065  		gz.Write([]byte(testString))
  1066  		if req.FormValue("body") == "large" {
  1067  			io.CopyN(gz, rand.Reader, nRandBytes)
  1068  		}
  1069  		gz.Close()
  1070  	}))
  1071  	defer ts.Close()
  1072  	c := ts.Client()
  1073  
  1074  	for _, chunked := range []string{"1", "0"} {
  1075  		// First fetch something large, but only read some of it.
  1076  		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
  1077  		if err != nil {
  1078  			t.Fatalf("large get: %v", err)
  1079  		}
  1080  		buf := make([]byte, len(testString))
  1081  		n, err := io.ReadFull(res.Body, buf)
  1082  		if err != nil {
  1083  			t.Fatalf("partial read of large response: size=%d, %v", n, err)
  1084  		}
  1085  		if e, g := testString, string(buf); e != g {
  1086  			t.Errorf("partial read got %q, expected %q", g, e)
  1087  		}
  1088  		res.Body.Close()
  1089  		// Read on the body, even though it's closed
  1090  		n, err = res.Body.Read(buf)
  1091  		if n != 0 || err == nil {
  1092  			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
  1093  		}
  1094  
  1095  		// Then something small.
  1096  		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
  1097  		if err != nil {
  1098  			t.Fatal(err)
  1099  		}
  1100  		body, err := io.ReadAll(res.Body)
  1101  		if err != nil {
  1102  			t.Fatal(err)
  1103  		}
  1104  		if g, e := string(body), testString; g != e {
  1105  			t.Fatalf("body = %q; want %q", g, e)
  1106  		}
  1107  		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
  1108  			t.Fatalf("Content-Encoding = %q; want %q", g, e)
  1109  		}
  1110  
  1111  		// Read on the body after it's been fully read:
  1112  		n, err = res.Body.Read(buf)
  1113  		if n != 0 || err == nil {
  1114  			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
  1115  		}
  1116  		res.Body.Close()
  1117  		n, err = res.Body.Read(buf)
  1118  		if n != 0 || err == nil {
  1119  			t.Errorf("expected Read error after Close; got %d, %v", n, err)
  1120  		}
  1121  	}
  1122  
  1123  	// And a HEAD request too, because they're always weird.
  1124  	res, err := c.Head(ts.URL)
  1125  	if err != nil {
  1126  		t.Fatalf("Head: %v", err)
  1127  	}
  1128  	if res.StatusCode != 200 {
  1129  		t.Errorf("Head status=%d; want=200", res.StatusCode)
  1130  	}
  1131  }
  1132  
  1133  // If a request has Expect:100-continue header, the request blocks sending body until the first response.
  1134  // Premature consumption of the request body should not be occurred.
  1135  func TestTransportExpect100Continue(t *testing.T) {
  1136  	setParallel(t)
  1137  	defer afterTest(t)
  1138  
  1139  	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
  1140  		switch req.URL.Path {
  1141  		case "/100":
  1142  			// This endpoint implicitly responds 100 Continue and reads body.
  1143  			if _, err := io.Copy(io.Discard, req.Body); err != nil {
  1144  				t.Error("Failed to read Body", err)
  1145  			}
  1146  			rw.WriteHeader(StatusOK)
  1147  		case "/200":
  1148  			// Go 1.5 adds Connection: close header if the client expect
  1149  			// continue but not entire request body is consumed.
  1150  			rw.WriteHeader(StatusOK)
  1151  		case "/500":
  1152  			rw.WriteHeader(StatusInternalServerError)
  1153  		case "/keepalive":
  1154  			// This hijacked endpoint responds error without Connection:close.
  1155  			_, bufrw, err := rw.(Hijacker).Hijack()
  1156  			if err != nil {
  1157  				log.Fatal(err)
  1158  			}
  1159  			bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
  1160  			bufrw.WriteString("Content-Length: 0\r\n\r\n")
  1161  			bufrw.Flush()
  1162  		case "/timeout":
  1163  			// This endpoint tries to read body without 100 (Continue) response.
  1164  			// After ExpectContinueTimeout, the reading will be started.
  1165  			conn, bufrw, err := rw.(Hijacker).Hijack()
  1166  			if err != nil {
  1167  				log.Fatal(err)
  1168  			}
  1169  			if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
  1170  				t.Error("Failed to read Body", err)
  1171  			}
  1172  			bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
  1173  			bufrw.Flush()
  1174  			conn.Close()
  1175  		}
  1176  
  1177  	}))
  1178  	defer ts.Close()
  1179  
  1180  	tests := []struct {
  1181  		path   string
  1182  		body   []byte
  1183  		sent   int
  1184  		status int
  1185  	}{
  1186  		{path: "/100", body: []byte("hello"), sent: 5, status: 200},       // Got 100 followed by 200, entire body is sent.
  1187  		{path: "/200", body: []byte("hello"), sent: 0, status: 200},       // Got 200 without 100. body isn't sent.
  1188  		{path: "/500", body: []byte("hello"), sent: 0, status: 500},       // Got 500 without 100. body isn't sent.
  1189  		{path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent.
  1190  		{path: "/timeout", body: []byte("hello"), sent: 5, status: 200},   // Timeout exceeded and entire body is sent.
  1191  	}
  1192  
  1193  	c := ts.Client()
  1194  	for i, v := range tests {
  1195  		tr := &Transport{
  1196  			ExpectContinueTimeout: 2 * time.Second,
  1197  		}
  1198  		defer tr.CloseIdleConnections()
  1199  		c.Transport = tr
  1200  		body := bytes.NewReader(v.body)
  1201  		req, err := NewRequest("PUT", ts.URL+v.path, body)
  1202  		if err != nil {
  1203  			t.Fatal(err)
  1204  		}
  1205  		req.Header.Set("Expect", "100-continue")
  1206  		req.ContentLength = int64(len(v.body))
  1207  
  1208  		resp, err := c.Do(req)
  1209  		if err != nil {
  1210  			t.Fatal(err)
  1211  		}
  1212  		resp.Body.Close()
  1213  
  1214  		sent := len(v.body) - body.Len()
  1215  		if v.status != resp.StatusCode {
  1216  			t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
  1217  		}
  1218  		if v.sent != sent {
  1219  			t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
  1220  		}
  1221  	}
  1222  }
  1223  
  1224  func TestSOCKS5Proxy(t *testing.T) {
  1225  	defer afterTest(t)
  1226  	ch := make(chan string, 1)
  1227  	l := newLocalListener(t)
  1228  	defer l.Close()
  1229  	defer close(ch)
  1230  	proxy := func(t *testing.T) {
  1231  		s, err := l.Accept()
  1232  		if err != nil {
  1233  			t.Errorf("socks5 proxy Accept(): %v", err)
  1234  			return
  1235  		}
  1236  		defer s.Close()
  1237  		var buf [22]byte
  1238  		if _, err := io.ReadFull(s, buf[:3]); err != nil {
  1239  			t.Errorf("socks5 proxy initial read: %v", err)
  1240  			return
  1241  		}
  1242  		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
  1243  			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
  1244  			return
  1245  		}
  1246  		if _, err := s.Write([]byte{5, 0}); err != nil {
  1247  			t.Errorf("socks5 proxy initial write: %v", err)
  1248  			return
  1249  		}
  1250  		if _, err := io.ReadFull(s, buf[:4]); err != nil {
  1251  			t.Errorf("socks5 proxy second read: %v", err)
  1252  			return
  1253  		}
  1254  		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
  1255  			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
  1256  			return
  1257  		}
  1258  		var ipLen int
  1259  		switch buf[3] {
  1260  		case 1:
  1261  			ipLen = net.IPv4len
  1262  		case 4:
  1263  			ipLen = net.IPv6len
  1264  		default:
  1265  			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
  1266  			return
  1267  		}
  1268  		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
  1269  			t.Errorf("socks5 proxy address read: %v", err)
  1270  			return
  1271  		}
  1272  		ip := net.IP(buf[4 : ipLen+4])
  1273  		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
  1274  		copy(buf[:3], []byte{5, 0, 0})
  1275  		if _, err := s.Write(buf[:ipLen+6]); err != nil {
  1276  			t.Errorf("socks5 proxy connect write: %v", err)
  1277  			return
  1278  		}
  1279  		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
  1280  
  1281  		// Implement proxying.
  1282  		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
  1283  		targetConn, err := net.Dial("tcp", targetHost)
  1284  		if err != nil {
  1285  			t.Errorf("net.Dial failed")
  1286  			return
  1287  		}
  1288  		go io.Copy(targetConn, s)
  1289  		io.Copy(s, targetConn) // Wait for the client to close the socket.
  1290  		targetConn.Close()
  1291  	}
  1292  
  1293  	pu, err := url.Parse("socks5://" + l.Addr().String())
  1294  	if err != nil {
  1295  		t.Fatal(err)
  1296  	}
  1297  
  1298  	sentinelHeader := "X-Sentinel"
  1299  	sentinelValue := "12345"
  1300  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
  1301  		w.Header().Set(sentinelHeader, sentinelValue)
  1302  	})
  1303  	for _, useTLS := range []bool{false, true} {
  1304  		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
  1305  			var ts *httptest.Server
  1306  			if useTLS {
  1307  				ts = httptest.NewTLSServer(h)
  1308  			} else {
  1309  				ts = httptest.NewServer(h)
  1310  			}
  1311  			go proxy(t)
  1312  			c := ts.Client()
  1313  			c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1314  			r, err := c.Head(ts.URL)
  1315  			if err != nil {
  1316  				t.Fatal(err)
  1317  			}
  1318  			if r.Header.Get(sentinelHeader) != sentinelValue {
  1319  				t.Errorf("Failed to retrieve sentinel value")
  1320  			}
  1321  			var got string
  1322  			select {
  1323  			case got = <-ch:
  1324  			case <-time.After(5 * time.Second):
  1325  				t.Fatal("timeout connecting to socks5 proxy")
  1326  			}
  1327  			ts.Close()
  1328  			tsu, err := url.Parse(ts.URL)
  1329  			if err != nil {
  1330  				t.Fatal(err)
  1331  			}
  1332  			want := "proxy for " + tsu.Host
  1333  			if got != want {
  1334  				t.Errorf("got %q, want %q", got, want)
  1335  			}
  1336  		})
  1337  	}
  1338  }
  1339  
  1340  func TestTransportProxy(t *testing.T) {
  1341  	defer afterTest(t)
  1342  	testCases := []struct{ httpsSite, httpsProxy bool }{
  1343  		{false, false},
  1344  		{false, true},
  1345  		{true, false},
  1346  		{true, true},
  1347  	}
  1348  	for _, testCase := range testCases {
  1349  		httpsSite := testCase.httpsSite
  1350  		httpsProxy := testCase.httpsProxy
  1351  		t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
  1352  			siteCh := make(chan *Request, 1)
  1353  			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1354  				siteCh <- r
  1355  			})
  1356  			proxyCh := make(chan *Request, 1)
  1357  			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
  1358  				proxyCh <- r
  1359  				// Implement an entire CONNECT proxy
  1360  				if r.Method == "CONNECT" {
  1361  					hijacker, ok := w.(Hijacker)
  1362  					if !ok {
  1363  						t.Errorf("hijack not allowed")
  1364  						return
  1365  					}
  1366  					clientConn, _, err := hijacker.Hijack()
  1367  					if err != nil {
  1368  						t.Errorf("hijacking failed")
  1369  						return
  1370  					}
  1371  					res := &Response{
  1372  						StatusCode: StatusOK,
  1373  						Proto:      "HTTP/1.1",
  1374  						ProtoMajor: 1,
  1375  						ProtoMinor: 1,
  1376  						Header:     make(Header),
  1377  					}
  1378  
  1379  					targetConn, err := net.Dial("tcp", r.URL.Host)
  1380  					if err != nil {
  1381  						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
  1382  						return
  1383  					}
  1384  
  1385  					if err := res.Write(clientConn); err != nil {
  1386  						t.Errorf("Writing 200 OK failed: %v", err)
  1387  						return
  1388  					}
  1389  
  1390  					go io.Copy(targetConn, clientConn)
  1391  					go func() {
  1392  						io.Copy(clientConn, targetConn)
  1393  						targetConn.Close()
  1394  					}()
  1395  				}
  1396  			})
  1397  			var ts *httptest.Server
  1398  			if httpsSite {
  1399  				ts = httptest.NewTLSServer(h1)
  1400  			} else {
  1401  				ts = httptest.NewServer(h1)
  1402  			}
  1403  			var proxy *httptest.Server
  1404  			if httpsProxy {
  1405  				proxy = httptest.NewTLSServer(h2)
  1406  			} else {
  1407  				proxy = httptest.NewServer(h2)
  1408  			}
  1409  
  1410  			pu, err := url.Parse(proxy.URL)
  1411  			if err != nil {
  1412  				t.Fatal(err)
  1413  			}
  1414  
  1415  			// If neither server is HTTPS or both are, then c may be derived from either.
  1416  			// If only one server is HTTPS, c must be derived from that server in order
  1417  			// to ensure that it is configured to use the fake root CA from testcert.go.
  1418  			c := proxy.Client()
  1419  			if httpsSite {
  1420  				c = ts.Client()
  1421  			}
  1422  
  1423  			c.Transport.(*Transport).Proxy = ProxyURL(pu)
  1424  			if _, err := c.Head(ts.URL); err != nil {
  1425  				t.Error(err)
  1426  			}
  1427  			var got *Request
  1428  			select {
  1429  			case got = <-proxyCh:
  1430  			case <-time.After(5 * time.Second):
  1431  				t.Fatal("timeout connecting to http proxy")
  1432  			}
  1433  			c.Transport.(*Transport).CloseIdleConnections()
  1434  			ts.Close()
  1435  			proxy.Close()
  1436  			if httpsSite {
  1437  				// First message should be a CONNECT, asking for a socket to the real server,
  1438  				if got.Method != "CONNECT" {
  1439  					t.Errorf("Wrong method for secure proxying: %q", got.Method)
  1440  				}
  1441  				gotHost := got.URL.Host
  1442  				pu, err := url.Parse(ts.URL)
  1443  				if err != nil {
  1444  					t.Fatal("Invalid site URL")
  1445  				}
  1446  				if wantHost := pu.Host; gotHost != wantHost {
  1447  					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
  1448  				}
  1449  
  1450  				// The next message on the channel should be from the site's server.
  1451  				next := <-siteCh
  1452  				if next.Method != "HEAD" {
  1453  					t.Errorf("Wrong method at destination: %s", next.Method)
  1454  				}
  1455  				if nextURL := next.URL.String(); nextURL != "/" {
  1456  					t.Errorf("Wrong URL at destination: %s", nextURL)
  1457  				}
  1458  			} else {
  1459  				if got.Method != "HEAD" {
  1460  					t.Errorf("Wrong method for destination: %q", got.Method)
  1461  				}
  1462  				gotURL := got.URL.String()
  1463  				wantURL := ts.URL + "/"
  1464  				if gotURL != wantURL {
  1465  					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
  1466  				}
  1467  			}
  1468  		})
  1469  	}
  1470  }
  1471  
  1472  // Issue 28012: verify that the Transport closes its TCP connection to http proxies
  1473  // when they're slow to reply to HTTPS CONNECT responses.
  1474  func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
  1475  	setParallel(t)
  1476  	defer afterTest(t)
  1477  
  1478  	ctx, cancel := context.WithCancel(context.Background())
  1479  	defer cancel()
  1480  
  1481  	ln := newLocalListener(t)
  1482  	defer ln.Close()
  1483  	listenerDone := make(chan struct{})
  1484  	go func() {
  1485  		defer close(listenerDone)
  1486  		c, err := ln.Accept()
  1487  		if err != nil {
  1488  			t.Errorf("Accept: %v", err)
  1489  			return
  1490  		}
  1491  		defer c.Close()
  1492  		// Read the CONNECT request
  1493  		br := bufio.NewReader(c)
  1494  		cr, err := ReadRequest(br)
  1495  		if err != nil {
  1496  			t.Errorf("proxy server failed to read CONNECT request")
  1497  			return
  1498  		}
  1499  		if cr.Method != "CONNECT" {
  1500  			t.Errorf("unexpected method %q", cr.Method)
  1501  			return
  1502  		}
  1503  
  1504  		// Now hang and never write a response; instead, cancel the request and wait
  1505  		// for the client to close.
  1506  		// (Prior to Issue 28012 being fixed, we never closed.)
  1507  		cancel()
  1508  		var buf [1]byte
  1509  		_, err = br.Read(buf[:])
  1510  		if err != io.EOF {
  1511  			t.Errorf("proxy server Read err = %v; want EOF", err)
  1512  		}
  1513  		return
  1514  	}()
  1515  
  1516  	c := &Client{
  1517  		Transport: &Transport{
  1518  			Proxy: func(*Request) (*url.URL, error) {
  1519  				return url.Parse("http://" + ln.Addr().String())
  1520  			},
  1521  		},
  1522  	}
  1523  	req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
  1524  	if err != nil {
  1525  		t.Fatal(err)
  1526  	}
  1527  	_, err = c.Do(req)
  1528  	if err == nil {
  1529  		t.Errorf("unexpected Get success")
  1530  	}
  1531  
  1532  	// Wait unconditionally for the listener goroutine to exit: this should never
  1533  	// hang, so if it does we want a full goroutine dump — and that's exactly what
  1534  	// the testing package will give us when the test run times out.
  1535  	<-listenerDone
  1536  }
  1537  
  1538  // Issue 16997: test transport dial preserves typed errors
  1539  func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
  1540  	defer afterTest(t)
  1541  
  1542  	var errDial = errors.New("some dial error")
  1543  
  1544  	tr := &Transport{
  1545  		Proxy: func(*Request) (*url.URL, error) {
  1546  			return url.Parse("http://proxy.fake.tld/")
  1547  		},
  1548  		Dial: func(string, string) (net.Conn, error) {
  1549  			return nil, errDial
  1550  		},
  1551  	}
  1552  	defer tr.CloseIdleConnections()
  1553  
  1554  	c := &Client{Transport: tr}
  1555  	req, _ := NewRequest("GET", "http://fake.tld", nil)
  1556  	res, err := c.Do(req)
  1557  	if err == nil {
  1558  		res.Body.Close()
  1559  		t.Fatal("wanted a non-nil error")
  1560  	}
  1561  
  1562  	uerr, ok := err.(*url.Error)
  1563  	if !ok {
  1564  		t.Fatalf("got %T, want *url.Error", err)
  1565  	}
  1566  	oe, ok := uerr.Err.(*net.OpError)
  1567  	if !ok {
  1568  		t.Fatalf("url.Error.Err =  %T; want *net.OpError", uerr.Err)
  1569  	}
  1570  	want := &net.OpError{
  1571  		Op:  "proxyconnect",
  1572  		Net: "tcp",
  1573  		Err: errDial, // original error, unwrapped.
  1574  	}
  1575  	if !reflect.DeepEqual(oe, want) {
  1576  		t.Errorf("Got error %#v; want %#v", oe, want)
  1577  	}
  1578  }
  1579  
  1580  // Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
  1581  //
  1582  // (A bug caused dialConn to instead write the per-request Proxy-Authorization
  1583  // header through to the shared Header instance, introducing a data race.)
  1584  func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
  1585  	setParallel(t)
  1586  	defer afterTest(t)
  1587  
  1588  	proxy := httptest.NewTLSServer(NotFoundHandler())
  1589  	defer proxy.Close()
  1590  	c := proxy.Client()
  1591  
  1592  	tr := c.Transport.(*Transport)
  1593  	tr.Proxy = func(*Request) (*url.URL, error) {
  1594  		u, _ := url.Parse(proxy.URL)
  1595  		u.User = url.UserPassword("aladdin", "opensesame")
  1596  		return u, nil
  1597  	}
  1598  	h := tr.ProxyConnectHeader
  1599  	if h == nil {
  1600  		h = make(Header)
  1601  	}
  1602  	tr.ProxyConnectHeader = h.Clone()
  1603  
  1604  	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
  1605  	if err != nil {
  1606  		t.Fatal(err)
  1607  	}
  1608  	_, err = c.Do(req)
  1609  	if err == nil {
  1610  		t.Errorf("unexpected Get success")
  1611  	}
  1612  
  1613  	if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
  1614  		t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
  1615  	}
  1616  }
  1617  
  1618  // TestTransportGzipRecursive sends a gzip quine and checks that the
  1619  // client gets the same value back. This is more cute than anything,
  1620  // but checks that we don't recurse forever, and checks that
  1621  // Content-Encoding is removed.
  1622  func TestTransportGzipRecursive(t *testing.T) {
  1623  	defer afterTest(t)
  1624  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1625  		w.Header().Set("Content-Encoding", "gzip")
  1626  		w.Write(rgz)
  1627  	}))
  1628  	defer ts.Close()
  1629  
  1630  	c := ts.Client()
  1631  	res, err := c.Get(ts.URL)
  1632  	if err != nil {
  1633  		t.Fatal(err)
  1634  	}
  1635  	body, err := io.ReadAll(res.Body)
  1636  	if err != nil {
  1637  		t.Fatal(err)
  1638  	}
  1639  	if !bytes.Equal(body, rgz) {
  1640  		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
  1641  			body, rgz)
  1642  	}
  1643  	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
  1644  		t.Fatalf("Content-Encoding = %q; want %q", g, e)
  1645  	}
  1646  }
  1647  
  1648  // golang.org/issue/7750: request fails when server replies with
  1649  // a short gzip body
  1650  func TestTransportGzipShort(t *testing.T) {
  1651  	defer afterTest(t)
  1652  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1653  		w.Header().Set("Content-Encoding", "gzip")
  1654  		w.Write([]byte{0x1f, 0x8b})
  1655  	}))
  1656  	defer ts.Close()
  1657  
  1658  	c := ts.Client()
  1659  	res, err := c.Get(ts.URL)
  1660  	if err != nil {
  1661  		t.Fatal(err)
  1662  	}
  1663  	defer res.Body.Close()
  1664  	_, err = io.ReadAll(res.Body)
  1665  	if err == nil {
  1666  		t.Fatal("Expect an error from reading a body.")
  1667  	}
  1668  	if err != io.ErrUnexpectedEOF {
  1669  		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
  1670  	}
  1671  }
  1672  
  1673  // Wait until number of goroutines is no greater than nmax, or time out.
  1674  func waitNumGoroutine(nmax int) int {
  1675  	nfinal := runtime.NumGoroutine()
  1676  	for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- {
  1677  		time.Sleep(50 * time.Millisecond)
  1678  		runtime.GC()
  1679  		nfinal = runtime.NumGoroutine()
  1680  	}
  1681  	return nfinal
  1682  }
  1683  
  1684  // tests that persistent goroutine connections shut down when no longer desired.
  1685  func TestTransportPersistConnLeak(t *testing.T) {
  1686  	// Not parallel: counts goroutines
  1687  	defer afterTest(t)
  1688  
  1689  	const numReq = 25
  1690  	gotReqCh := make(chan bool, numReq)
  1691  	unblockCh := make(chan bool, numReq)
  1692  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1693  		gotReqCh <- true
  1694  		<-unblockCh
  1695  		w.Header().Set("Content-Length", "0")
  1696  		w.WriteHeader(204)
  1697  	}))
  1698  	defer ts.Close()
  1699  	c := ts.Client()
  1700  	tr := c.Transport.(*Transport)
  1701  
  1702  	n0 := runtime.NumGoroutine()
  1703  
  1704  	didReqCh := make(chan bool, numReq)
  1705  	failed := make(chan bool, numReq)
  1706  	for i := 0; i < numReq; i++ {
  1707  		go func() {
  1708  			res, err := c.Get(ts.URL)
  1709  			didReqCh <- true
  1710  			if err != nil {
  1711  				t.Logf("client fetch error: %v", err)
  1712  				failed <- true
  1713  				return
  1714  			}
  1715  			res.Body.Close()
  1716  		}()
  1717  	}
  1718  
  1719  	// Wait for all goroutines to be stuck in the Handler.
  1720  	for i := 0; i < numReq; i++ {
  1721  		select {
  1722  		case <-gotReqCh:
  1723  			// ok
  1724  		case <-failed:
  1725  			// Not great but not what we are testing:
  1726  			// sometimes an overloaded system will fail to make all the connections.
  1727  		}
  1728  	}
  1729  
  1730  	nhigh := runtime.NumGoroutine()
  1731  
  1732  	// Tell all handlers to unblock and reply.
  1733  	close(unblockCh)
  1734  
  1735  	// Wait for all HTTP clients to be done.
  1736  	for i := 0; i < numReq; i++ {
  1737  		<-didReqCh
  1738  	}
  1739  
  1740  	tr.CloseIdleConnections()
  1741  	nfinal := waitNumGoroutine(n0 + 5)
  1742  
  1743  	growth := nfinal - n0
  1744  
  1745  	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
  1746  	// Previously we were leaking one per numReq.
  1747  	if int(growth) > 5 {
  1748  		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
  1749  		t.Error("too many new goroutines")
  1750  	}
  1751  }
  1752  
  1753  // golang.org/issue/4531: Transport leaks goroutines when
  1754  // request.ContentLength is explicitly short
  1755  func TestTransportPersistConnLeakShortBody(t *testing.T) {
  1756  	// Not parallel: measures goroutines.
  1757  	defer afterTest(t)
  1758  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1759  	}))
  1760  	defer ts.Close()
  1761  	c := ts.Client()
  1762  	tr := c.Transport.(*Transport)
  1763  
  1764  	n0 := runtime.NumGoroutine()
  1765  	body := []byte("Hello")
  1766  	for i := 0; i < 20; i++ {
  1767  		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  1768  		if err != nil {
  1769  			t.Fatal(err)
  1770  		}
  1771  		req.ContentLength = int64(len(body) - 2) // explicitly short
  1772  		_, err = c.Do(req)
  1773  		if err == nil {
  1774  			t.Fatal("Expect an error from writing too long of a body.")
  1775  		}
  1776  	}
  1777  	nhigh := runtime.NumGoroutine()
  1778  	tr.CloseIdleConnections()
  1779  	nfinal := waitNumGoroutine(n0 + 5)
  1780  
  1781  	growth := nfinal - n0
  1782  
  1783  	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
  1784  	// Previously we were leaking one per numReq.
  1785  	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
  1786  	if int(growth) > 5 {
  1787  		t.Error("too many new goroutines")
  1788  	}
  1789  }
  1790  
  1791  // A countedConn is a net.Conn that decrements an atomic counter when finalized.
  1792  type countedConn struct {
  1793  	net.Conn
  1794  }
  1795  
  1796  // A countingDialer dials connections and counts the number that remain reachable.
  1797  type countingDialer struct {
  1798  	dialer      net.Dialer
  1799  	mu          sync.Mutex
  1800  	total, live int64
  1801  }
  1802  
  1803  func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  1804  	conn, err := d.dialer.DialContext(ctx, network, address)
  1805  	if err != nil {
  1806  		return nil, err
  1807  	}
  1808  
  1809  	counted := new(countedConn)
  1810  	counted.Conn = conn
  1811  
  1812  	d.mu.Lock()
  1813  	defer d.mu.Unlock()
  1814  	d.total++
  1815  	d.live++
  1816  
  1817  	runtime.SetFinalizer(counted, d.decrement)
  1818  	return counted, nil
  1819  }
  1820  
  1821  func (d *countingDialer) decrement(*countedConn) {
  1822  	d.mu.Lock()
  1823  	defer d.mu.Unlock()
  1824  	d.live--
  1825  }
  1826  
  1827  func (d *countingDialer) Read() (total, live int64) {
  1828  	d.mu.Lock()
  1829  	defer d.mu.Unlock()
  1830  	return d.total, d.live
  1831  }
  1832  
  1833  func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
  1834  	defer afterTest(t)
  1835  
  1836  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1837  		// Close every connection so that it cannot be kept alive.
  1838  		conn, _, err := w.(Hijacker).Hijack()
  1839  		if err != nil {
  1840  			t.Errorf("Hijack failed unexpectedly: %v", err)
  1841  			return
  1842  		}
  1843  		conn.Close()
  1844  	}))
  1845  	defer ts.Close()
  1846  
  1847  	var d countingDialer
  1848  	c := ts.Client()
  1849  	c.Transport.(*Transport).DialContext = d.DialContext
  1850  
  1851  	body := []byte("Hello")
  1852  	for i := 0; ; i++ {
  1853  		total, live := d.Read()
  1854  		if live < total {
  1855  			break
  1856  		}
  1857  		if i >= 1<<12 {
  1858  			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
  1859  		}
  1860  
  1861  		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  1862  		if err != nil {
  1863  			t.Fatal(err)
  1864  		}
  1865  		_, err = c.Do(req)
  1866  		if err == nil {
  1867  			t.Fatal("expected broken connection")
  1868  		}
  1869  
  1870  		runtime.GC()
  1871  	}
  1872  }
  1873  
  1874  type countedContext struct {
  1875  	context.Context
  1876  }
  1877  
  1878  type contextCounter struct {
  1879  	mu   sync.Mutex
  1880  	live int64
  1881  }
  1882  
  1883  func (cc *contextCounter) Track(ctx context.Context) context.Context {
  1884  	counted := new(countedContext)
  1885  	counted.Context = ctx
  1886  	cc.mu.Lock()
  1887  	defer cc.mu.Unlock()
  1888  	cc.live++
  1889  	runtime.SetFinalizer(counted, cc.decrement)
  1890  	return counted
  1891  }
  1892  
  1893  func (cc *contextCounter) decrement(*countedContext) {
  1894  	cc.mu.Lock()
  1895  	defer cc.mu.Unlock()
  1896  	cc.live--
  1897  }
  1898  
  1899  func (cc *contextCounter) Read() (live int64) {
  1900  	cc.mu.Lock()
  1901  	defer cc.mu.Unlock()
  1902  	return cc.live
  1903  }
  1904  
  1905  func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
  1906  	defer afterTest(t)
  1907  
  1908  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1909  		runtime.Gosched()
  1910  		w.WriteHeader(StatusOK)
  1911  	}))
  1912  	defer ts.Close()
  1913  
  1914  	c := ts.Client()
  1915  	c.Transport.(*Transport).MaxConnsPerHost = 1
  1916  
  1917  	ctx := context.Background()
  1918  	body := []byte("Hello")
  1919  	doPosts := func(cc *contextCounter) {
  1920  		var wg sync.WaitGroup
  1921  		for n := 64; n > 0; n-- {
  1922  			wg.Add(1)
  1923  			go func() {
  1924  				defer wg.Done()
  1925  
  1926  				ctx := cc.Track(ctx)
  1927  				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
  1928  				if err != nil {
  1929  					t.Error(err)
  1930  				}
  1931  
  1932  				_, err = c.Do(req.WithContext(ctx))
  1933  				if err != nil {
  1934  					t.Errorf("Do failed with error: %v", err)
  1935  				}
  1936  			}()
  1937  		}
  1938  		wg.Wait()
  1939  	}
  1940  
  1941  	var initialCC contextCounter
  1942  	doPosts(&initialCC)
  1943  
  1944  	// flushCC exists only to put pressure on the GC to finalize the initialCC
  1945  	// contexts: the flushCC allocations should eventually displace the initialCC
  1946  	// allocations.
  1947  	var flushCC contextCounter
  1948  	for i := 0; ; i++ {
  1949  		live := initialCC.Read()
  1950  		if live == 0 {
  1951  			break
  1952  		}
  1953  		if i >= 100 {
  1954  			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
  1955  		}
  1956  		doPosts(&flushCC)
  1957  		runtime.GC()
  1958  	}
  1959  }
  1960  
  1961  // This used to crash; https://golang.org/issue/3266
  1962  func TestTransportIdleConnCrash(t *testing.T) {
  1963  	defer afterTest(t)
  1964  	var tr *Transport
  1965  
  1966  	unblockCh := make(chan bool, 1)
  1967  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1968  		<-unblockCh
  1969  		tr.CloseIdleConnections()
  1970  	}))
  1971  	defer ts.Close()
  1972  	c := ts.Client()
  1973  	tr = c.Transport.(*Transport)
  1974  
  1975  	didreq := make(chan bool)
  1976  	go func() {
  1977  		res, err := c.Get(ts.URL)
  1978  		if err != nil {
  1979  			t.Error(err)
  1980  		} else {
  1981  			res.Body.Close() // returns idle conn
  1982  		}
  1983  		didreq <- true
  1984  	}()
  1985  	unblockCh <- true
  1986  	<-didreq
  1987  }
  1988  
  1989  // Test that the transport doesn't close the TCP connection early,
  1990  // before the response body has been read. This was a regression
  1991  // which sadly lacked a triggering test. The large response body made
  1992  // the old race easier to trigger.
  1993  func TestIssue3644(t *testing.T) {
  1994  	defer afterTest(t)
  1995  	const numFoos = 5000
  1996  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  1997  		w.Header().Set("Connection", "close")
  1998  		for i := 0; i < numFoos; i++ {
  1999  			w.Write([]byte("foo "))
  2000  		}
  2001  	}))
  2002  	defer ts.Close()
  2003  	c := ts.Client()
  2004  	res, err := c.Get(ts.URL)
  2005  	if err != nil {
  2006  		t.Fatal(err)
  2007  	}
  2008  	defer res.Body.Close()
  2009  	bs, err := io.ReadAll(res.Body)
  2010  	if err != nil {
  2011  		t.Fatal(err)
  2012  	}
  2013  	if len(bs) != numFoos*len("foo ") {
  2014  		t.Errorf("unexpected response length")
  2015  	}
  2016  }
  2017  
  2018  // Test that a client receives a server's reply, even if the server doesn't read
  2019  // the entire request body.
  2020  func TestIssue3595(t *testing.T) {
  2021  	setParallel(t)
  2022  	defer afterTest(t)
  2023  	const deniedMsg = "sorry, denied."
  2024  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2025  		Error(w, deniedMsg, StatusUnauthorized)
  2026  	}))
  2027  	defer ts.Close()
  2028  	c := ts.Client()
  2029  	res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
  2030  	if err != nil {
  2031  		t.Errorf("Post: %v", err)
  2032  		return
  2033  	}
  2034  	got, err := io.ReadAll(res.Body)
  2035  	if err != nil {
  2036  		t.Fatalf("Body ReadAll: %v", err)
  2037  	}
  2038  	if !strings.Contains(string(got), deniedMsg) {
  2039  		t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
  2040  	}
  2041  }
  2042  
  2043  // From https://golang.org/issue/4454 ,
  2044  // "client fails to handle requests with no body and chunked encoding"
  2045  func TestChunkedNoContent(t *testing.T) {
  2046  	defer afterTest(t)
  2047  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2048  		w.WriteHeader(StatusNoContent)
  2049  	}))
  2050  	defer ts.Close()
  2051  
  2052  	c := ts.Client()
  2053  	for _, closeBody := range []bool{true, false} {
  2054  		const n = 4
  2055  		for i := 1; i <= n; i++ {
  2056  			res, err := c.Get(ts.URL)
  2057  			if err != nil {
  2058  				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
  2059  			} else {
  2060  				if closeBody {
  2061  					res.Body.Close()
  2062  				}
  2063  			}
  2064  		}
  2065  	}
  2066  }
  2067  
  2068  func TestTransportConcurrency(t *testing.T) {
  2069  	// Not parallel: uses global test hooks.
  2070  	defer afterTest(t)
  2071  	maxProcs, numReqs := 16, 500
  2072  	if testing.Short() {
  2073  		maxProcs, numReqs = 4, 50
  2074  	}
  2075  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
  2076  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2077  		fmt.Fprintf(w, "%v", r.FormValue("echo"))
  2078  	}))
  2079  	defer ts.Close()
  2080  
  2081  	var wg sync.WaitGroup
  2082  	wg.Add(numReqs)
  2083  
  2084  	// Due to the Transport's "socket late binding" (see
  2085  	// idleConnCh in transport.go), the numReqs HTTP requests
  2086  	// below can finish with a dial still outstanding. To keep
  2087  	// the leak checker happy, keep track of pending dials and
  2088  	// wait for them to finish (and be closed or returned to the
  2089  	// idle pool) before we close idle connections.
  2090  	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
  2091  	defer SetPendingDialHooks(nil, nil)
  2092  
  2093  	c := ts.Client()
  2094  	reqs := make(chan string)
  2095  	defer close(reqs)
  2096  
  2097  	for i := 0; i < maxProcs*2; i++ {
  2098  		go func() {
  2099  			for req := range reqs {
  2100  				res, err := c.Get(ts.URL + "/?echo=" + req)
  2101  				if err != nil {
  2102  					t.Errorf("error on req %s: %v", req, err)
  2103  					wg.Done()
  2104  					continue
  2105  				}
  2106  				all, err := io.ReadAll(res.Body)
  2107  				if err != nil {
  2108  					t.Errorf("read error on req %s: %v", req, err)
  2109  					wg.Done()
  2110  					continue
  2111  				}
  2112  				if string(all) != req {
  2113  					t.Errorf("body of req %s = %q; want %q", req, all, req)
  2114  				}
  2115  				res.Body.Close()
  2116  				wg.Done()
  2117  			}
  2118  		}()
  2119  	}
  2120  	for i := 0; i < numReqs; i++ {
  2121  		reqs <- fmt.Sprintf("request-%d", i)
  2122  	}
  2123  	wg.Wait()
  2124  }
  2125  
  2126  func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
  2127  	setParallel(t)
  2128  	defer afterTest(t)
  2129  	const debug = false
  2130  	mux := NewServeMux()
  2131  	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
  2132  		io.Copy(w, neverEnding('a'))
  2133  	})
  2134  	ts := httptest.NewServer(mux)
  2135  	defer ts.Close()
  2136  	timeout := 100 * time.Millisecond
  2137  
  2138  	c := ts.Client()
  2139  	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
  2140  		conn, err := net.Dial(n, addr)
  2141  		if err != nil {
  2142  			return nil, err
  2143  		}
  2144  		conn.SetDeadline(time.Now().Add(timeout))
  2145  		if debug {
  2146  			conn = NewLoggingConn("client", conn)
  2147  		}
  2148  		return conn, nil
  2149  	}
  2150  
  2151  	getFailed := false
  2152  	nRuns := 5
  2153  	if testing.Short() {
  2154  		nRuns = 1
  2155  	}
  2156  	for i := 0; i < nRuns; i++ {
  2157  		if debug {
  2158  			println("run", i+1, "of", nRuns)
  2159  		}
  2160  		sres, err := c.Get(ts.URL + "/get")
  2161  		if err != nil {
  2162  			if !getFailed {
  2163  				// Make the timeout longer, once.
  2164  				getFailed = true
  2165  				t.Logf("increasing timeout")
  2166  				i--
  2167  				timeout *= 10
  2168  				continue
  2169  			}
  2170  			t.Errorf("Error issuing GET: %v", err)
  2171  			break
  2172  		}
  2173  		_, err = io.Copy(io.Discard, sres.Body)
  2174  		if err == nil {
  2175  			t.Errorf("Unexpected successful copy")
  2176  			break
  2177  		}
  2178  	}
  2179  	if debug {
  2180  		println("tests complete; waiting for handlers to finish")
  2181  	}
  2182  }
  2183  
  2184  func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
  2185  	setParallel(t)
  2186  	defer afterTest(t)
  2187  	const debug = false
  2188  	mux := NewServeMux()
  2189  	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
  2190  		io.Copy(w, neverEnding('a'))
  2191  	})
  2192  	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
  2193  		defer r.Body.Close()
  2194  		io.Copy(io.Discard, r.Body)
  2195  	})
  2196  	ts := httptest.NewServer(mux)
  2197  	timeout := 100 * time.Millisecond
  2198  
  2199  	c := ts.Client()
  2200  	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
  2201  		conn, err := net.Dial(n, addr)
  2202  		if err != nil {
  2203  			return nil, err
  2204  		}
  2205  		conn.SetDeadline(time.Now().Add(timeout))
  2206  		if debug {
  2207  			conn = NewLoggingConn("client", conn)
  2208  		}
  2209  		return conn, nil
  2210  	}
  2211  
  2212  	getFailed := false
  2213  	nRuns := 5
  2214  	if testing.Short() {
  2215  		nRuns = 1
  2216  	}
  2217  	for i := 0; i < nRuns; i++ {
  2218  		if debug {
  2219  			println("run", i+1, "of", nRuns)
  2220  		}
  2221  		sres, err := c.Get(ts.URL + "/get")
  2222  		if err != nil {
  2223  			if !getFailed {
  2224  				// Make the timeout longer, once.
  2225  				getFailed = true
  2226  				t.Logf("increasing timeout")
  2227  				i--
  2228  				timeout *= 10
  2229  				continue
  2230  			}
  2231  			t.Errorf("Error issuing GET: %v", err)
  2232  			break
  2233  		}
  2234  		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
  2235  		_, err = c.Do(req)
  2236  		if err == nil {
  2237  			sres.Body.Close()
  2238  			t.Errorf("Unexpected successful PUT")
  2239  			break
  2240  		}
  2241  		sres.Body.Close()
  2242  	}
  2243  	if debug {
  2244  		println("tests complete; waiting for handlers to finish")
  2245  	}
  2246  	ts.Close()
  2247  }
  2248  
  2249  func TestTransportResponseHeaderTimeout(t *testing.T) {
  2250  	setParallel(t)
  2251  	defer afterTest(t)
  2252  	if testing.Short() {
  2253  		t.Skip("skipping timeout test in -short mode")
  2254  	}
  2255  	inHandler := make(chan bool, 1)
  2256  	mux := NewServeMux()
  2257  	mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
  2258  		inHandler <- true
  2259  	})
  2260  	mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
  2261  		inHandler <- true
  2262  		time.Sleep(2 * time.Second)
  2263  	})
  2264  	ts := httptest.NewServer(mux)
  2265  	defer ts.Close()
  2266  
  2267  	c := ts.Client()
  2268  	c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond
  2269  
  2270  	tests := []struct {
  2271  		path    string
  2272  		want    int
  2273  		wantErr string
  2274  	}{
  2275  		{path: "/fast", want: 200},
  2276  		{path: "/slow", wantErr: "timeout awaiting response headers"},
  2277  		{path: "/fast", want: 200},
  2278  	}
  2279  	for i, tt := range tests {
  2280  		req, _ := NewRequest("GET", ts.URL+tt.path, nil)
  2281  		req = req.WithT(t)
  2282  		res, err := c.Do(req)
  2283  		select {
  2284  		case <-inHandler:
  2285  		case <-time.After(5 * time.Second):
  2286  			t.Errorf("never entered handler for test index %d, %s", i, tt.path)
  2287  			continue
  2288  		}
  2289  		if err != nil {
  2290  			uerr, ok := err.(*url.Error)
  2291  			if !ok {
  2292  				t.Errorf("error is not an url.Error; got: %#v", err)
  2293  				continue
  2294  			}
  2295  			nerr, ok := uerr.Err.(net.Error)
  2296  			if !ok {
  2297  				t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
  2298  				continue
  2299  			}
  2300  			if !nerr.Timeout() {
  2301  				t.Errorf("want timeout error; got: %q", nerr)
  2302  				continue
  2303  			}
  2304  			if strings.Contains(err.Error(), tt.wantErr) {
  2305  				continue
  2306  			}
  2307  			t.Errorf("%d. unexpected error: %v", i, err)
  2308  			continue
  2309  		}
  2310  		if tt.wantErr != "" {
  2311  			t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
  2312  			continue
  2313  		}
  2314  		if res.StatusCode != tt.want {
  2315  			t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
  2316  		}
  2317  	}
  2318  }
  2319  
  2320  func TestTransportCancelRequest(t *testing.T) {
  2321  	setParallel(t)
  2322  	defer afterTest(t)
  2323  	if testing.Short() {
  2324  		t.Skip("skipping test in -short mode")
  2325  	}
  2326  	unblockc := make(chan bool)
  2327  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2328  		fmt.Fprintf(w, "Hello")
  2329  		w.(Flusher).Flush() // send headers and some body
  2330  		<-unblockc
  2331  	}))
  2332  	defer ts.Close()
  2333  	defer close(unblockc)
  2334  
  2335  	c := ts.Client()
  2336  	tr := c.Transport.(*Transport)
  2337  
  2338  	req, _ := NewRequest("GET", ts.URL, nil)
  2339  	res, err := c.Do(req)
  2340  	if err != nil {
  2341  		t.Fatal(err)
  2342  	}
  2343  	go func() {
  2344  		time.Sleep(1 * time.Second)
  2345  		tr.CancelRequest(req)
  2346  	}()
  2347  	t0 := time.Now()
  2348  	body, err := io.ReadAll(res.Body)
  2349  	d := time.Since(t0)
  2350  
  2351  	if err != ExportErrRequestCanceled {
  2352  		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
  2353  	}
  2354  	if string(body) != "Hello" {
  2355  		t.Errorf("Body = %q; want Hello", body)
  2356  	}
  2357  	if d < 500*time.Millisecond {
  2358  		t.Errorf("expected ~1 second delay; got %v", d)
  2359  	}
  2360  	// Verify no outstanding requests after readLoop/writeLoop
  2361  	// goroutines shut down.
  2362  	for tries := 5; tries > 0; tries-- {
  2363  		n := tr.NumPendingRequestsForTesting()
  2364  		if n == 0 {
  2365  			break
  2366  		}
  2367  		time.Sleep(100 * time.Millisecond)
  2368  		if tries == 1 {
  2369  			t.Errorf("pending requests = %d; want 0", n)
  2370  		}
  2371  	}
  2372  }
  2373  
  2374  func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
  2375  	setParallel(t)
  2376  	defer afterTest(t)
  2377  	if testing.Short() {
  2378  		t.Skip("skipping test in -short mode")
  2379  	}
  2380  	unblockc := make(chan bool)
  2381  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2382  		<-unblockc
  2383  	}))
  2384  	defer ts.Close()
  2385  	defer close(unblockc)
  2386  
  2387  	c := ts.Client()
  2388  	tr := c.Transport.(*Transport)
  2389  
  2390  	donec := make(chan bool)
  2391  	req, _ := NewRequest("GET", ts.URL, body)
  2392  	go func() {
  2393  		defer close(donec)
  2394  		c.Do(req)
  2395  	}()
  2396  	start := time.Now()
  2397  	timeout := 10 * time.Second
  2398  	for time.Since(start) < timeout {
  2399  		time.Sleep(100 * time.Millisecond)
  2400  		tr.CancelRequest(req)
  2401  		select {
  2402  		case <-donec:
  2403  			return
  2404  		default:
  2405  		}
  2406  	}
  2407  	t.Errorf("Do of canceled request has not returned after %v", timeout)
  2408  }
  2409  
  2410  func TestTransportCancelRequestInDo(t *testing.T) {
  2411  	testTransportCancelRequestInDo(t, nil)
  2412  }
  2413  
  2414  func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
  2415  	testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
  2416  }
  2417  
// TestTransportCancelRequestInDial checks Transport.CancelRequest while the
// request's dial is still blocked. Do must return with a "request canceled
// while waiting for connection" error, and calling CancelRequest a second
// time must not panic. Events from the dialer, the test, and the Do
// goroutine are logged to logbuf and compared against an exact transcript.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf bytes.Buffer
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is only closed when the test returns, so the dial stays
	// blocked for the whole time Do is being canceled.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial is a handshake. The test sends true once it knows the dialer is
	// running. It closes the channel (so the receive yields false) if it gave
	// up waiting, which lets the dialer bail out.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			// Reached only after the test returns. The expected transcript
			// below shows Do reporting cancellation, not this error.
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	select {
	case inDial <- true:
	case <-time.After(5 * time.Second):
		close(inDial)
		t.Fatal("timeout; never saw blocking dial")
	}

	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	tr.CancelRequest(req) // used to panic on second call

	select {
	case <-gotres:
	case <-time.After(5 * time.Second):
		// NOTE(review): panics instead of t.Fatal, presumably to get a full
		// goroutine dump on a hang — confirm.
		panic("hang. events are: " + logbuf.String())
	}

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
  2475  
  2476  func TestCancelRequestWithChannel(t *testing.T) {
  2477  	setParallel(t)
  2478  	defer afterTest(t)
  2479  	if testing.Short() {
  2480  		t.Skip("skipping test in -short mode")
  2481  	}
  2482  	unblockc := make(chan bool)
  2483  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2484  		fmt.Fprintf(w, "Hello")
  2485  		w.(Flusher).Flush() // send headers and some body
  2486  		<-unblockc
  2487  	}))
  2488  	defer ts.Close()
  2489  	defer close(unblockc)
  2490  
  2491  	c := ts.Client()
  2492  	tr := c.Transport.(*Transport)
  2493  
  2494  	req, _ := NewRequest("GET", ts.URL, nil)
  2495  	ch := make(chan struct{})
  2496  	req.Cancel = ch
  2497  
  2498  	res, err := c.Do(req)
  2499  	if err != nil {
  2500  		t.Fatal(err)
  2501  	}
  2502  	go func() {
  2503  		time.Sleep(1 * time.Second)
  2504  		close(ch)
  2505  	}()
  2506  	t0 := time.Now()
  2507  	body, err := io.ReadAll(res.Body)
  2508  	d := time.Since(t0)
  2509  
  2510  	if err != ExportErrRequestCanceled {
  2511  		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
  2512  	}
  2513  	if string(body) != "Hello" {
  2514  		t.Errorf("Body = %q; want Hello", body)
  2515  	}
  2516  	if d < 500*time.Millisecond {
  2517  		t.Errorf("expected ~1 second delay; got %v", d)
  2518  	}
  2519  	// Verify no outstanding requests after readLoop/writeLoop
  2520  	// goroutines shut down.
  2521  	for tries := 5; tries > 0; tries-- {
  2522  		n := tr.NumPendingRequestsForTesting()
  2523  		if n == 0 {
  2524  			break
  2525  		}
  2526  		time.Sleep(100 * time.Millisecond)
  2527  		if tries == 1 {
  2528  			t.Errorf("pending requests = %d; want 0", n)
  2529  		}
  2530  	}
  2531  }
  2532  
  2533  func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
  2534  	testCancelRequestWithChannelBeforeDo(t, false)
  2535  }
  2536  func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
  2537  	testCancelRequestWithChannelBeforeDo(t, true)
  2538  }
  2539  func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
  2540  	setParallel(t)
  2541  	defer afterTest(t)
  2542  	unblockc := make(chan bool)
  2543  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2544  		<-unblockc
  2545  	}))
  2546  	defer ts.Close()
  2547  	defer close(unblockc)
  2548  
  2549  	c := ts.Client()
  2550  
  2551  	req, _ := NewRequest("GET", ts.URL, nil)
  2552  	if withCtx {
  2553  		ctx, cancel := context.WithCancel(context.Background())
  2554  		cancel()
  2555  		req = req.WithContext(ctx)
  2556  	} else {
  2557  		ch := make(chan struct{})
  2558  		req.Cancel = ch
  2559  		close(ch)
  2560  	}
  2561  
  2562  	_, err := c.Do(req)
  2563  	if ue, ok := err.(*url.Error); ok {
  2564  		err = ue.Err
  2565  	}
  2566  	if withCtx {
  2567  		if err != context.Canceled {
  2568  			t.Errorf("Do error = %v; want %v", err, context.Canceled)
  2569  		}
  2570  	} else {
  2571  		if err == nil || !strings.Contains(err.Error(), "canceled") {
  2572  			t.Errorf("Do error = %v; want cancellation", err)
  2573  		}
  2574  	}
  2575  }
  2576  
  2577  // Issue 11020. The returned error message should be errRequestCanceled
  2578  func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
  2579  	defer afterTest(t)
  2580  
  2581  	serverConnCh := make(chan net.Conn, 1)
  2582  	tr := &Transport{
  2583  		Dial: func(network, addr string) (net.Conn, error) {
  2584  			cc, sc := net.Pipe()
  2585  			serverConnCh <- sc
  2586  			return cc, nil
  2587  		},
  2588  	}
  2589  	defer tr.CloseIdleConnections()
  2590  	errc := make(chan error, 1)
  2591  	req, _ := NewRequest("GET", "http://example.com/", nil)
  2592  	go func() {
  2593  		_, err := tr.RoundTrip(req)
  2594  		errc <- err
  2595  	}()
  2596  
  2597  	sc := <-serverConnCh
  2598  	verb := make([]byte, 3)
  2599  	if _, err := io.ReadFull(sc, verb); err != nil {
  2600  		t.Errorf("Error reading HTTP verb from server: %v", err)
  2601  	}
  2602  	if string(verb) != "GET" {
  2603  		t.Errorf("server received %q; want GET", verb)
  2604  	}
  2605  	defer sc.Close()
  2606  
  2607  	tr.CancelRequest(req)
  2608  
  2609  	err := <-errc
  2610  	if err == nil {
  2611  		t.Fatalf("unexpected success from RoundTrip")
  2612  	}
  2613  	if err != ExportErrRequestCanceled {
  2614  		t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
  2615  	}
  2616  }
  2617  
  2618  // golang.org/issue/3672 -- Client can't close HTTP stream
  2619  // Calling Close on a Response.Body used to just read until EOF.
  2620  // Now it actually closes the TCP connection.
  2621  func TestTransportCloseResponseBody(t *testing.T) {
  2622  	defer afterTest(t)
  2623  	writeErr := make(chan error, 1)
  2624  	msg := []byte("young\n")
  2625  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  2626  		for {
  2627  			_, err := w.Write(msg)
  2628  			if err != nil {
  2629  				writeErr <- err
  2630  				return
  2631  			}
  2632  			w.(Flusher).Flush()
  2633  		}
  2634  	}))
  2635  	defer ts.Close()
  2636  
  2637  	c := ts.Client()
  2638  	tr := c.Transport.(*Transport)
  2639  
  2640  	req, _ := NewRequest("GET", ts.URL, nil)
  2641  	defer tr.CancelRequest(req)
  2642  
  2643  	res, err := c.Do(req)
  2644  	if err != nil {
  2645  		t.Fatal(err)
  2646  	}
  2647  
  2648  	const repeats = 3
  2649  	buf := make([]byte, len(msg)*repeats)
  2650  	want := bytes.Repeat(msg, repeats)
  2651  
  2652  	_, err = io.ReadFull(res.Body, buf)
  2653  	if err != nil {
  2654  		t.Fatal(err)
  2655  	}
  2656  	if !bytes.Equal(buf, want) {
  2657  		t.Fatalf("read %q; want %q", buf, want)
  2658  	}
  2659  	didClose := make(chan error, 1)
  2660  	go func() {
  2661  		didClose <- res.Body.Close()
  2662  	}()
  2663  	select {
  2664  	case err := <-didClose:
  2665  		if err != nil {
  2666  			t.Errorf("Close = %v", err)
  2667  		}
  2668  	case <-time.After(10 * time.Second):
  2669  		t.Fatal("too long waiting for close")
  2670  	}
  2671  	select {
  2672  	case err := <-writeErr:
  2673  		if err == nil {
  2674  			t.Errorf("expected non-nil write error")
  2675  		}
  2676  	case <-time.After(10 * time.Second):
  2677  		t.Fatal("too long waiting for write error")
  2678  	}
  2679  }
  2680  
  2681  type fooProto struct{}
  2682  
  2683  func (fooProto) RoundTrip(req *Request) (*Response, error) {
  2684  	res := &Response{
  2685  		Status:     "200 OK",
  2686  		StatusCode: 200,
  2687  		Header:     make(Header),
  2688  		Body:       io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
  2689  	}
  2690  	return res, nil
  2691  }
  2692  
  2693  func TestTransportAltProto(t *testing.T) {
  2694  	defer afterTest(t)
  2695  	tr := &Transport{}
  2696  	c := &Client{Transport: tr}
  2697  	tr.RegisterProtocol("foo", fooProto{})
  2698  	res, err := c.Get("foo://bar.com/path")
  2699  	if err != nil {
  2700  		t.Fatal(err)
  2701  	}
  2702  	bodyb, err := io.ReadAll(res.Body)
  2703  	if err != nil {
  2704  		t.Fatal(err)
  2705  	}
  2706  	body := string(bodyb)
  2707  	if e := "You wanted foo://bar.com/path"; body != e {
  2708  		t.Errorf("got response %q, want %q", body, e)
  2709  	}
  2710  }
  2711  
  2712  func TestTransportNoHost(t *testing.T) {
  2713  	defer afterTest(t)
  2714  	tr := &Transport{}
  2715  	_, err := tr.RoundTrip(&Request{
  2716  		Header: make(Header),
  2717  		URL: &url.URL{
  2718  			Scheme: "http",
  2719  		},
  2720  	})
  2721  	want := "http: no Host in request URL"
  2722  	if got := fmt.Sprint(err); got != want {
  2723  		t.Errorf("error = %v; want %q", err, want)
  2724  	}
  2725  }
  2726  
  2727  // Issue 13311
  2728  func TestTransportEmptyMethod(t *testing.T) {
  2729  	req, _ := NewRequest("GET", "http://foo.com/", nil)
  2730  	req.Method = ""                                 // docs say "For client requests an empty string means GET"
  2731  	got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport
  2732  	if err != nil {
  2733  		t.Fatal(err)
  2734  	}
  2735  	if !strings.Contains(string(got), "GET ") {
  2736  		t.Fatalf("expected substring 'GET '; got: %s", got)
  2737  	}
  2738  }
  2739  
// TestTransportSocketLateBinding checks the Transport's "socket late
// binding". While /foo still holds the only connection, a /bar request is
// issued. The dialer allows just one real dial, so /bar can only complete by
// reusing /foo's connection once /foo finishes. The test verifies that both
// requests report the same client address.
func TestTransportSocketLateBinding(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	mux := NewServeMux()
	// fooGate holds the /foo handler open (after its headers are flushed)
	// until the test releases it.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()

	// Each dial consumes one value from dialGate: true performs a real dial,
	// false fails it. Any extra dial blocks until a value is sent.
	dialGate := make(chan bool, 1)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		if <-dialGate {
			return net.Dial(n, addr)
		}
		return nil, errors.New("manually closed")
	}

	dialGate <- true // only allow one dial
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}
	time.AfterFunc(200*time.Millisecond, func() {
		// let the foo response finish so we can use its
		// connection for /bar
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
	})

	// Any dial started for /bar blocks on dialGate, so /bar can only be
	// served by /foo's connection once it is released above.
	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
	// Unblock the still-pending dial (if any) so it fails instead of leaking.
	dialGate <- false
}
  2794  
// Issue 2184: TestTransportReading100Continue checks that the client skips a
// "100 Continue" interim response and returns the final response that
// follows it. The server is a hand-written loop over in-memory pipes, so the
// exact bytes on the wire are under the test's control.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response is the fake server for one dialed connection. It reads
	// requests from r. For each one it writes an interim status line (default
	// "100 Continue", overridable via the X-Want-Response-Code request
	// header) followed immediately by a 200 OK response to w. It returns on
	// EOF, on an error, or after answering the request whose Request-Id is
	// reqID(numReqs).
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		// n counts requests on this connection, starting at 1.
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				// NOTE(review): this assumes request n on this connection
				// carries reqBody(n), i.e. that all requests reuse one
				// connection. err is always nil here, so its %v adds nothing.
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// Interim response (status + Date only), then the real 200
			// response. The template uses "\n" and is converted to CRLF.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	// Each dial builds a fresh pair of io.Pipes wrapped in an rwTestConn and
	// starts a send100Response goroutine as the server end.
	tr := &Transport{
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe() // server read/write
			cr, cw := io.Pipe() // client read/write
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks the final status code, checks that
	// the server echoed the Request-Id (when one is set), and reads the body
	// to completion.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Several requests, each preceded by a 100 Continue, to make sure the
	// client's response parsing is not off by one.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
  2895  
  2896  // Issue 17739: the HTTP client must ignore any unknown 1xx
  2897  // informational responses before the actual response.
  2898  func TestTransportIgnore1xxResponses(t *testing.T) {
  2899  	setParallel(t)
  2900  	defer afterTest(t)
  2901  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2902  		conn, buf, _ := w.(Hijacker).Hijack()
  2903  		buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
  2904  		buf.Flush()
  2905  		conn.Close()
  2906  	}))
  2907  	defer cst.close()
  2908  	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
  2909  
  2910  	var got bytes.Buffer
  2911  
  2912  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  2913  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  2914  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  2915  			fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
  2916  			return nil
  2917  		},
  2918  	}))
  2919  	res, err := cst.c.Do(req)
  2920  	if err != nil {
  2921  		t.Fatal(err)
  2922  	}
  2923  	defer res.Body.Close()
  2924  
  2925  	res.Write(&got)
  2926  	want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
  2927  	if got.String() != want {
  2928  		t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want)
  2929  	}
  2930  }
  2931  
  2932  func TestTransportLimits1xxResponses(t *testing.T) {
  2933  	setParallel(t)
  2934  	defer afterTest(t)
  2935  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2936  		conn, buf, _ := w.(Hijacker).Hijack()
  2937  		for i := 0; i < 10; i++ {
  2938  			buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
  2939  		}
  2940  		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
  2941  		buf.Flush()
  2942  		conn.Close()
  2943  	}))
  2944  	defer cst.close()
  2945  	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
  2946  
  2947  	res, err := cst.c.Get(cst.ts.URL)
  2948  	if res != nil {
  2949  		defer res.Body.Close()
  2950  	}
  2951  	got := fmt.Sprint(err)
  2952  	wantSub := "too many 1xx informational responses"
  2953  	if !strings.Contains(got, wantSub) {
  2954  		t.Errorf("Get error = %v; want substring %q", err, wantSub)
  2955  	}
  2956  }
  2957  
  2958  // Issue 26161: the HTTP client must treat 101 responses
  2959  // as the final response.
  2960  func TestTransportTreat101Terminal(t *testing.T) {
  2961  	setParallel(t)
  2962  	defer afterTest(t)
  2963  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  2964  		conn, buf, _ := w.(Hijacker).Hijack()
  2965  		buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
  2966  		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
  2967  		buf.Flush()
  2968  		conn.Close()
  2969  	}))
  2970  	defer cst.close()
  2971  	res, err := cst.c.Get(cst.ts.URL)
  2972  	if err != nil {
  2973  		t.Fatal(err)
  2974  	}
  2975  	defer res.Body.Close()
  2976  	if res.StatusCode != StatusSwitchingProtocols {
  2977  		t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
  2978  	}
  2979  }
  2980  
  2981  type proxyFromEnvTest struct {
  2982  	req string // URL to fetch; blank means "http://example.com"
  2983  
  2984  	env      string // HTTP_PROXY
  2985  	httpsenv string // HTTPS_PROXY
  2986  	noenv    string // NO_PROXY
  2987  	reqmeth  string // REQUEST_METHOD
  2988  
  2989  	want    string
  2990  	wanterr error
  2991  }
  2992  
  2993  func (t proxyFromEnvTest) String() string {
  2994  	var buf bytes.Buffer
  2995  	space := func() {
  2996  		if buf.Len() > 0 {
  2997  			buf.WriteByte(' ')
  2998  		}
  2999  	}
  3000  	if t.env != "" {
  3001  		fmt.Fprintf(&buf, "http_proxy=%q", t.env)
  3002  	}
  3003  	if t.httpsenv != "" {
  3004  		space()
  3005  		fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
  3006  	}
  3007  	if t.noenv != "" {
  3008  		space()
  3009  		fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
  3010  	}
  3011  	if t.reqmeth != "" {
  3012  		space()
  3013  		fmt.Fprintf(&buf, "request_method=%q", t.reqmeth)
  3014  	}
  3015  	req := "http://example.com"
  3016  	if t.req != "" {
  3017  		req = t.req
  3018  	}
  3019  	space()
  3020  	fmt.Fprintf(&buf, "req=%q", req)
  3021  	return strings.TrimSpace(buf.String())
  3022  }
  3023  
  3024  var proxyFromEnvTests = []proxyFromEnvTest{
  3025  	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
  3026  	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
  3027  	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
  3028  	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
  3029  	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
  3030  	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
  3031  	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
  3032  
  3033  	// Don't use secure for http
  3034  	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
  3035  	// Use secure for https.
  3036  	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
  3037  	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},
  3038  
  3039  	// Issue 16405: don't use HTTP_PROXY in a CGI environment,
  3040  	// where HTTP_PROXY can be attacker-controlled.
  3041  	{env: "http://10.1.2.3:8080", reqmeth: "POST",
  3042  		want:    "<nil>",
  3043  		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},
  3044  
  3045  	{want: "<nil>"},
  3046  
  3047  	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
  3048  	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
  3049  	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
  3050  	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
  3051  	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
  3052  }
  3053  
  3054  func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
  3055  	t.Helper()
  3056  	reqURL := tt.req
  3057  	if reqURL == "" {
  3058  		reqURL = "http://example.com"
  3059  	}
  3060  	req, _ := NewRequest("GET", reqURL, nil)
  3061  	url, err := proxyForRequest(req)
  3062  	if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
  3063  		t.Errorf("%v: got error = %q, want %q", tt, g, e)
  3064  		return
  3065  	}
  3066  	if got := fmt.Sprintf("%s", url); got != tt.want {
  3067  		t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
  3068  	}
  3069  }
  3070  
  3071  func TestProxyFromEnvironment(t *testing.T) {
  3072  	ResetProxyEnv()
  3073  	defer ResetProxyEnv()
  3074  	for _, tt := range proxyFromEnvTests {
  3075  		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
  3076  			os.Setenv("HTTP_PROXY", tt.env)
  3077  			os.Setenv("HTTPS_PROXY", tt.httpsenv)
  3078  			os.Setenv("NO_PROXY", tt.noenv)
  3079  			os.Setenv("REQUEST_METHOD", tt.reqmeth)
  3080  			ResetCachedEnvironment()
  3081  			return ProxyFromEnvironment(req)
  3082  		})
  3083  	}
  3084  }
  3085  
  3086  func TestProxyFromEnvironmentLowerCase(t *testing.T) {
  3087  	ResetProxyEnv()
  3088  	defer ResetProxyEnv()
  3089  	for _, tt := range proxyFromEnvTests {
  3090  		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
  3091  			os.Setenv("http_proxy", tt.env)
  3092  			os.Setenv("https_proxy", tt.httpsenv)
  3093  			os.Setenv("no_proxy", tt.noenv)
  3094  			os.Setenv("REQUEST_METHOD", tt.reqmeth)
  3095  			ResetCachedEnvironment()
  3096  			return ProxyFromEnvironment(req)
  3097  		})
  3098  	}
  3099  }
  3100  
  3101  func TestIdleConnChannelLeak(t *testing.T) {
  3102  	// Not parallel: uses global test hooks.
  3103  	var mu sync.Mutex
  3104  	var n int
  3105  
  3106  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3107  		mu.Lock()
  3108  		n++
  3109  		mu.Unlock()
  3110  	}))
  3111  	defer ts.Close()
  3112  
  3113  	const nReqs = 5
  3114  	didRead := make(chan bool, nReqs)
  3115  	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
  3116  	defer SetReadLoopBeforeNextReadHook(nil)
  3117  
  3118  	c := ts.Client()
  3119  	tr := c.Transport.(*Transport)
  3120  	tr.Dial = func(netw, addr string) (net.Conn, error) {
  3121  		return net.Dial(netw, ts.Listener.Addr().String())
  3122  	}
  3123  
  3124  	// First, without keep-alives.
  3125  	for _, disableKeep := range []bool{true, false} {
  3126  		tr.DisableKeepAlives = disableKeep
  3127  		for i := 0; i < nReqs; i++ {
  3128  			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
  3129  			if err != nil {
  3130  				t.Fatal(err)
  3131  			}
  3132  			// Note: no res.Body.Close is needed here, since the
  3133  			// response Content-Length is zero. Perhaps the test
  3134  			// should be more explicit and use a HEAD, but tests
  3135  			// elsewhere guarantee that zero byte responses generate
  3136  			// a "Content-Length: 0" instead of chunking.
  3137  		}
  3138  
  3139  		// At this point, each of the 5 Transport.readLoop goroutines
  3140  		// are scheduling noting that there are no response bodies (see
  3141  		// earlier comment), and are then calling putIdleConn, which
  3142  		// decrements this count. Usually that happens quickly, which is
  3143  		// why this test has seemed to work for ages. But it's still
  3144  		// racey: we have wait for them to finish first. See Issue 10427
  3145  		for i := 0; i < nReqs; i++ {
  3146  			<-didRead
  3147  		}
  3148  
  3149  		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
  3150  			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
  3151  		}
  3152  	}
  3153  }
  3154  
  3155  // Verify the status quo: that the Client.Post function coerces its
  3156  // body into a ReadCloser if it's a Closer, and that the Transport
  3157  // then closes it.
  3158  func TestTransportClosesRequestBody(t *testing.T) {
  3159  	defer afterTest(t)
  3160  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3161  		io.Copy(io.Discard, r.Body)
  3162  	}))
  3163  	defer ts.Close()
  3164  
  3165  	c := ts.Client()
  3166  
  3167  	closes := 0
  3168  
  3169  	res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  3170  	if err != nil {
  3171  		t.Fatal(err)
  3172  	}
  3173  	res.Body.Close()
  3174  	if closes != 1 {
  3175  		t.Errorf("closes = %d; want 1", closes)
  3176  	}
  3177  }
  3178  
  3179  func TestTransportTLSHandshakeTimeout(t *testing.T) {
  3180  	defer afterTest(t)
  3181  	if testing.Short() {
  3182  		t.Skip("skipping in short mode")
  3183  	}
  3184  	ln := newLocalListener(t)
  3185  	defer ln.Close()
  3186  	testdonec := make(chan struct{})
  3187  	defer close(testdonec)
  3188  
  3189  	go func() {
  3190  		c, err := ln.Accept()
  3191  		if err != nil {
  3192  			t.Error(err)
  3193  			return
  3194  		}
  3195  		<-testdonec
  3196  		c.Close()
  3197  	}()
  3198  
  3199  	getdonec := make(chan struct{})
  3200  	go func() {
  3201  		defer close(getdonec)
  3202  		tr := &Transport{
  3203  			Dial: func(_, _ string) (net.Conn, error) {
  3204  				return net.Dial("tcp", ln.Addr().String())
  3205  			},
  3206  			TLSHandshakeTimeout: 250 * time.Millisecond,
  3207  		}
  3208  		cl := &Client{Transport: tr}
  3209  		_, err := cl.Get("https://dummy.tld/")
  3210  		if err == nil {
  3211  			t.Error("expected error")
  3212  			return
  3213  		}
  3214  		ue, ok := err.(*url.Error)
  3215  		if !ok {
  3216  			t.Errorf("expected url.Error; got %#v", err)
  3217  			return
  3218  		}
  3219  		ne, ok := ue.Err.(net.Error)
  3220  		if !ok {
  3221  			t.Errorf("expected net.Error; got %#v", err)
  3222  			return
  3223  		}
  3224  		if !ne.Timeout() {
  3225  			t.Errorf("expected timeout error; got %v", err)
  3226  		}
  3227  		if !strings.Contains(err.Error(), "handshake timeout") {
  3228  			t.Errorf("expected 'handshake timeout' in error; got %v", err)
  3229  		}
  3230  	}()
  3231  	select {
  3232  	case <-getdonec:
  3233  	case <-time.After(5 * time.Second):
  3234  		t.Error("test timeout; TLS handshake hung?")
  3235  	}
  3236  }
  3237  
  3238  // Trying to repro golang.org/issue/3514
  3239  func TestTLSServerClosesConnection(t *testing.T) {
  3240  	defer afterTest(t)
  3241  
  3242  	closedc := make(chan bool, 1)
  3243  	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3244  		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
  3245  			conn, _, _ := w.(Hijacker).Hijack()
  3246  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
  3247  			conn.Close()
  3248  			closedc <- true
  3249  			return
  3250  		}
  3251  		fmt.Fprintf(w, "hello")
  3252  	}))
  3253  	defer ts.Close()
  3254  
  3255  	c := ts.Client()
  3256  	tr := c.Transport.(*Transport)
  3257  
  3258  	var nSuccess = 0
  3259  	var errs []error
  3260  	const trials = 20
  3261  	for i := 0; i < trials; i++ {
  3262  		tr.CloseIdleConnections()
  3263  		res, err := c.Get(ts.URL + "/keep-alive-then-die")
  3264  		if err != nil {
  3265  			t.Fatal(err)
  3266  		}
  3267  		<-closedc
  3268  		slurp, err := io.ReadAll(res.Body)
  3269  		if err != nil {
  3270  			t.Fatal(err)
  3271  		}
  3272  		if string(slurp) != "foo" {
  3273  			t.Errorf("Got %q, want foo", slurp)
  3274  		}
  3275  
  3276  		// Now try again and see if we successfully
  3277  		// pick a new connection.
  3278  		res, err = c.Get(ts.URL + "/")
  3279  		if err != nil {
  3280  			errs = append(errs, err)
  3281  			continue
  3282  		}
  3283  		slurp, err = io.ReadAll(res.Body)
  3284  		if err != nil {
  3285  			errs = append(errs, err)
  3286  			continue
  3287  		}
  3288  		nSuccess++
  3289  	}
  3290  	if nSuccess > 0 {
  3291  		t.Logf("successes = %d of %d", nSuccess, trials)
  3292  	} else {
  3293  		t.Errorf("All runs failed:")
  3294  	}
  3295  	for _, err := range errs {
  3296  		t.Logf("  err: %v", err)
  3297  	}
  3298  }
  3299  
  3300  // byteFromChanReader is an io.Reader that reads a single byte at a
  3301  // time from the channel. When the channel is closed, the reader
  3302  // returns io.EOF.
  3303  type byteFromChanReader chan byte
  3304  
  3305  func (c byteFromChanReader) Read(p []byte) (n int, err error) {
  3306  	if len(p) == 0 {
  3307  		return
  3308  	}
  3309  	b, ok := <-c
  3310  	if !ok {
  3311  		return 0, io.EOF
  3312  	}
  3313  	p[0] = b
  3314  	return 1, nil
  3315  }
  3316  
  3317  // Verifies that the Transport doesn't reuse a connection in the case
  3318  // where the server replies before the request has been fully
  3319  // written. We still honor that reply (see TestIssue3595), but don't
  3320  // send future requests on the connection because it's then in a
  3321  // questionable state.
  3322  // golang.org/issue/7569
  3323  func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
  3324  	setParallel(t)
  3325  	defer afterTest(t)
  3326  	var sconn struct {
  3327  		sync.Mutex
  3328  		c net.Conn
  3329  	}
  3330  	var getOkay bool
  3331  	closeConn := func() {
  3332  		sconn.Lock()
  3333  		defer sconn.Unlock()
  3334  		if sconn.c != nil {
  3335  			sconn.c.Close()
  3336  			sconn.c = nil
  3337  			if !getOkay {
  3338  				t.Logf("Closed server connection")
  3339  			}
  3340  		}
  3341  	}
  3342  	defer closeConn()
  3343  
  3344  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3345  		if r.Method == "GET" {
  3346  			io.WriteString(w, "bar")
  3347  			return
  3348  		}
  3349  		conn, _, _ := w.(Hijacker).Hijack()
  3350  		sconn.Lock()
  3351  		sconn.c = conn
  3352  		sconn.Unlock()
  3353  		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
  3354  		go io.Copy(io.Discard, conn)
  3355  	}))
  3356  	defer ts.Close()
  3357  	c := ts.Client()
  3358  
  3359  	const bodySize = 256 << 10
  3360  	finalBit := make(byteFromChanReader, 1)
  3361  	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
  3362  	req.ContentLength = bodySize
  3363  	res, err := c.Do(req)
  3364  	if err := wantBody(res, err, "foo"); err != nil {
  3365  		t.Errorf("POST response: %v", err)
  3366  	}
  3367  	donec := make(chan bool)
  3368  	go func() {
  3369  		defer close(donec)
  3370  		res, err = c.Get(ts.URL)
  3371  		if err := wantBody(res, err, "bar"); err != nil {
  3372  			t.Errorf("GET response: %v", err)
  3373  			return
  3374  		}
  3375  		getOkay = true // suppress test noise
  3376  	}()
  3377  	time.AfterFunc(5*time.Second, closeConn)
  3378  	select {
  3379  	case <-donec:
  3380  		finalBit <- 'x' // unblock the writeloop of the first Post
  3381  		close(finalBit)
  3382  	case <-time.After(7 * time.Second):
  3383  		t.Fatal("timeout waiting for GET request to finish")
  3384  	}
  3385  }
  3386  
  3387  // Tests that we don't leak Transport persistConn.readLoop goroutines
  3388  // when a server hangs up immediately after saying it would keep-alive.
  3389  func TestTransportIssue10457(t *testing.T) {
  3390  	defer afterTest(t) // used to fail in goroutine leak check
  3391  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3392  		// Send a response with no body, keep-alive
  3393  		// (implicit), and then lie and immediately close the
  3394  		// connection. This forces the Transport's readLoop to
  3395  		// immediately Peek an io.EOF and get to the point
  3396  		// that used to hang.
  3397  		conn, _, _ := w.(Hijacker).Hijack()
  3398  		conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
  3399  		conn.Close()
  3400  	}))
  3401  	defer ts.Close()
  3402  	c := ts.Client()
  3403  
  3404  	res, err := c.Get(ts.URL)
  3405  	if err != nil {
  3406  		t.Fatalf("Get: %v", err)
  3407  	}
  3408  	defer res.Body.Close()
  3409  
  3410  	// Just a sanity check that we at least get the response. The real
  3411  	// test here is that the "defer afterTest" above doesn't find any
  3412  	// leaked goroutines.
  3413  	if got, want := res.Header.Get("Foo"), "Bar"; got != want {
  3414  		t.Errorf("Foo header = %q; want %q", got, want)
  3415  	}
  3416  }
  3417  
  3418  type closerFunc func() error
  3419  
  3420  func (f closerFunc) Close() error { return f() }
  3421  
  3422  type writerFuncConn struct {
  3423  	net.Conn
  3424  	write func(p []byte) (n int, err error)
  3425  }
  3426  
  3427  func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
  3428  
  3429  // Issues 4677, 18241, and 17844. If we try to reuse a connection that the
  3430  // server is in the process of closing, we may end up successfully writing out
  3431  // our request (or a portion of our request) only to find a connection error
  3432  // when we try to read from (or finish writing to) the socket.
  3433  //
  3434  // NOTE: we resend a request only if:
  3435  //   - we reused a keep-alive connection
  3436  //   - we haven't yet received any header data
  3437  //   - either we wrote no bytes to the server, or the request is idempotent
  3438  // This automatically prevents an infinite resend loop because we'll run out of
  3439  // the cached keep-alive connections eventually.
  3440  func TestRetryRequestsOnError(t *testing.T) {
  3441  	newRequest := func(method, urlStr string, body io.Reader) *Request {
  3442  		req, err := NewRequest(method, urlStr, body)
  3443  		if err != nil {
  3444  			t.Fatal(err)
  3445  		}
  3446  		return req
  3447  	}
  3448  
  3449  	testCases := []struct {
  3450  		name       string
  3451  		failureN   int
  3452  		failureErr error
  3453  		// Note that we can't just re-use the Request object across calls to c.Do
  3454  		// because we need to rewind Body between calls.  (GetBody is only used to
  3455  		// rewind Body on failure and redirects, not just because it's done.)
  3456  		req       func() *Request
  3457  		reqString string
  3458  	}{
  3459  		{
  3460  			name: "IdempotentNoBodySomeWritten",
  3461  			// Believe that we've written some bytes to the server, so we know we're
  3462  			// not just in the "retry when no bytes sent" case".
  3463  			failureN: 1,
  3464  			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
  3465  			failureErr: ExportErrServerClosedIdle,
  3466  			req: func() *Request {
  3467  				return newRequest("GET", "http://fake.golang", nil)
  3468  			},
  3469  			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
  3470  		},
  3471  		{
  3472  			name: "IdempotentGetBodySomeWritten",
  3473  			// Believe that we've written some bytes to the server, so we know we're
  3474  			// not just in the "retry when no bytes sent" case".
  3475  			failureN: 1,
  3476  			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
  3477  			failureErr: ExportErrServerClosedIdle,
  3478  			req: func() *Request {
  3479  				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
  3480  			},
  3481  			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
  3482  		},
  3483  		{
  3484  			name: "NothingWrittenNoBody",
  3485  			// It's key that we return 0 here -- that's what enables Transport to know
  3486  			// that nothing was written, even though this is a non-idempotent request.
  3487  			failureN:   0,
  3488  			failureErr: errors.New("second write fails"),
  3489  			req: func() *Request {
  3490  				return newRequest("DELETE", "http://fake.golang", nil)
  3491  			},
  3492  			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
  3493  		},
  3494  		{
  3495  			name: "NothingWrittenGetBody",
  3496  			// It's key that we return 0 here -- that's what enables Transport to know
  3497  			// that nothing was written, even though this is a non-idempotent request.
  3498  			failureN:   0,
  3499  			failureErr: errors.New("second write fails"),
  3500  			// Note that NewRequest will set up GetBody for strings.Reader, which is
  3501  			// required for the retry to occur
  3502  			req: func() *Request {
  3503  				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
  3504  			},
  3505  			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
  3506  		},
  3507  	}
  3508  
  3509  	for _, tc := range testCases {
  3510  		t.Run(tc.name, func(t *testing.T) {
  3511  			defer afterTest(t)
  3512  
  3513  			var (
  3514  				mu     sync.Mutex
  3515  				logbuf bytes.Buffer
  3516  			)
  3517  			logf := func(format string, args ...interface{}) {
  3518  				mu.Lock()
  3519  				defer mu.Unlock()
  3520  				fmt.Fprintf(&logbuf, format, args...)
  3521  				logbuf.WriteByte('\n')
  3522  			}
  3523  
  3524  			ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3525  				logf("Handler")
  3526  				w.Header().Set("X-Status", "ok")
  3527  			}))
  3528  			defer ts.Close()
  3529  
  3530  			var writeNumAtomic int32
  3531  			c := ts.Client()
  3532  			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
  3533  				logf("Dial")
  3534  				c, err := net.Dial(network, ts.Listener.Addr().String())
  3535  				if err != nil {
  3536  					logf("Dial error: %v", err)
  3537  					return nil, err
  3538  				}
  3539  				return &writerFuncConn{
  3540  					Conn: c,
  3541  					write: func(p []byte) (n int, err error) {
  3542  						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
  3543  							logf("intentional write failure")
  3544  							return tc.failureN, tc.failureErr
  3545  						}
  3546  						logf("Write(%q)", p)
  3547  						return c.Write(p)
  3548  					},
  3549  				}, nil
  3550  			}
  3551  
  3552  			SetRoundTripRetried(func() {
  3553  				logf("Retried.")
  3554  			})
  3555  			defer SetRoundTripRetried(nil)
  3556  
  3557  			for i := 0; i < 3; i++ {
  3558  				t0 := time.Now()
  3559  				req := tc.req()
  3560  				res, err := c.Do(req)
  3561  				if err != nil {
  3562  					if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 {
  3563  						mu.Lock()
  3564  						got := logbuf.String()
  3565  						mu.Unlock()
  3566  						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
  3567  					}
  3568  					t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse)
  3569  				}
  3570  				res.Body.Close()
  3571  				if res.Request != req {
  3572  					t.Errorf("Response.Request != original request; want identical Request")
  3573  				}
  3574  			}
  3575  
  3576  			mu.Lock()
  3577  			got := logbuf.String()
  3578  			mu.Unlock()
  3579  			want := fmt.Sprintf(`Dial
  3580  Write("%s")
  3581  Handler
  3582  intentional write failure
  3583  Retried.
  3584  Dial
  3585  Write("%s")
  3586  Handler
  3587  Write("%s")
  3588  Handler
  3589  `, tc.reqString, tc.reqString, tc.reqString)
  3590  			if got != want {
  3591  				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
  3592  			}
  3593  		})
  3594  	}
  3595  }
  3596  
  3597  // Issue 6981
  3598  func TestTransportClosesBodyOnError(t *testing.T) {
  3599  	setParallel(t)
  3600  	defer afterTest(t)
  3601  	readBody := make(chan error, 1)
  3602  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3603  		_, err := io.ReadAll(r.Body)
  3604  		readBody <- err
  3605  	}))
  3606  	defer ts.Close()
  3607  	c := ts.Client()
  3608  	fakeErr := errors.New("fake error")
  3609  	didClose := make(chan bool, 1)
  3610  	req, _ := NewRequest("POST", ts.URL, struct {
  3611  		io.Reader
  3612  		io.Closer
  3613  	}{
  3614  		io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
  3615  		closerFunc(func() error {
  3616  			select {
  3617  			case didClose <- true:
  3618  			default:
  3619  			}
  3620  			return nil
  3621  		}),
  3622  	})
  3623  	res, err := c.Do(req)
  3624  	if res != nil {
  3625  		defer res.Body.Close()
  3626  	}
  3627  	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
  3628  		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
  3629  	}
  3630  	select {
  3631  	case err := <-readBody:
  3632  		if err == nil {
  3633  			t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
  3634  		}
  3635  	case <-time.After(5 * time.Second):
  3636  		t.Error("timeout waiting for server handler to complete")
  3637  	}
  3638  	select {
  3639  	case <-didClose:
  3640  	default:
  3641  		t.Errorf("didn't see Body.Close")
  3642  	}
  3643  }
  3644  
  3645  func TestTransportDialTLS(t *testing.T) {
  3646  	setParallel(t)
  3647  	defer afterTest(t)
  3648  	var mu sync.Mutex // guards following
  3649  	var gotReq, didDial bool
  3650  
  3651  	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3652  		mu.Lock()
  3653  		gotReq = true
  3654  		mu.Unlock()
  3655  	}))
  3656  	defer ts.Close()
  3657  	c := ts.Client()
  3658  	c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
  3659  		mu.Lock()
  3660  		didDial = true
  3661  		mu.Unlock()
  3662  		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
  3663  		if err != nil {
  3664  			return nil, err
  3665  		}
  3666  		return c, c.Handshake()
  3667  	}
  3668  
  3669  	res, err := c.Get(ts.URL)
  3670  	if err != nil {
  3671  		t.Fatal(err)
  3672  	}
  3673  	res.Body.Close()
  3674  	mu.Lock()
  3675  	if !gotReq {
  3676  		t.Error("didn't get request")
  3677  	}
  3678  	if !didDial {
  3679  		t.Error("didn't use dial hook")
  3680  	}
  3681  }
  3682  
  3683  func TestTransportDialContext(t *testing.T) {
  3684  	setParallel(t)
  3685  	defer afterTest(t)
  3686  	var mu sync.Mutex // guards following
  3687  	var gotReq bool
  3688  	var receivedContext context.Context
  3689  
  3690  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3691  		mu.Lock()
  3692  		gotReq = true
  3693  		mu.Unlock()
  3694  	}))
  3695  	defer ts.Close()
  3696  	c := ts.Client()
  3697  	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
  3698  		mu.Lock()
  3699  		receivedContext = ctx
  3700  		mu.Unlock()
  3701  		return net.Dial(netw, addr)
  3702  	}
  3703  
  3704  	req, err := NewRequest("GET", ts.URL, nil)
  3705  	if err != nil {
  3706  		t.Fatal(err)
  3707  	}
  3708  	ctx := context.WithValue(context.Background(), "some-key", "some-value")
  3709  	res, err := c.Do(req.WithContext(ctx))
  3710  	if err != nil {
  3711  		t.Fatal(err)
  3712  	}
  3713  	res.Body.Close()
  3714  	mu.Lock()
  3715  	if !gotReq {
  3716  		t.Error("didn't get request")
  3717  	}
  3718  	if receivedContext != ctx {
  3719  		t.Error("didn't receive correct context")
  3720  	}
  3721  }
  3722  
  3723  func TestTransportDialTLSContext(t *testing.T) {
  3724  	setParallel(t)
  3725  	defer afterTest(t)
  3726  	var mu sync.Mutex // guards following
  3727  	var gotReq bool
  3728  	var receivedContext context.Context
  3729  
  3730  	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3731  		mu.Lock()
  3732  		gotReq = true
  3733  		mu.Unlock()
  3734  	}))
  3735  	defer ts.Close()
  3736  	c := ts.Client()
  3737  	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
  3738  		mu.Lock()
  3739  		receivedContext = ctx
  3740  		mu.Unlock()
  3741  		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
  3742  		if err != nil {
  3743  			return nil, err
  3744  		}
  3745  		return c, c.HandshakeContext(ctx)
  3746  	}
  3747  
  3748  	req, err := NewRequest("GET", ts.URL, nil)
  3749  	if err != nil {
  3750  		t.Fatal(err)
  3751  	}
  3752  	ctx := context.WithValue(context.Background(), "some-key", "some-value")
  3753  	res, err := c.Do(req.WithContext(ctx))
  3754  	if err != nil {
  3755  		t.Fatal(err)
  3756  	}
  3757  	res.Body.Close()
  3758  	mu.Lock()
  3759  	if !gotReq {
  3760  		t.Error("didn't get request")
  3761  	}
  3762  	if receivedContext != ctx {
  3763  		t.Error("didn't receive correct context")
  3764  	}
  3765  }
  3766  
  3767  // Test for issue 8755
  3768  // Ensure that if a proxy returns an error, it is exposed by RoundTrip
  3769  func TestRoundTripReturnsProxyError(t *testing.T) {
  3770  	badProxy := func(*Request) (*url.URL, error) {
  3771  		return nil, errors.New("errorMessage")
  3772  	}
  3773  
  3774  	tr := &Transport{Proxy: badProxy}
  3775  
  3776  	req, _ := NewRequest("GET", "http://example.com", nil)
  3777  
  3778  	_, err := tr.RoundTrip(req)
  3779  
  3780  	if err == nil {
  3781  		t.Error("Expected proxy error to be returned by RoundTrip")
  3782  	}
  3783  }
  3784  
  3785  // tests that putting an idle conn after a call to CloseIdleConns does return it
  3786  func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
  3787  	tr := &Transport{}
  3788  	wantIdle := func(when string, n int) bool {
  3789  		got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
  3790  		if got == n {
  3791  			return true
  3792  		}
  3793  		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
  3794  		return false
  3795  	}
  3796  	wantIdle("start", 0)
  3797  	if !tr.PutIdleTestConn("http", "example.com") {
  3798  		t.Fatal("put failed")
  3799  	}
  3800  	if !tr.PutIdleTestConn("http", "example.com") {
  3801  		t.Fatal("second put failed")
  3802  	}
  3803  	wantIdle("after put", 2)
  3804  	tr.CloseIdleConnections()
  3805  	if !tr.IsIdleForTesting() {
  3806  		t.Error("should be idle after CloseIdleConnections")
  3807  	}
  3808  	wantIdle("after close idle", 0)
  3809  	if tr.PutIdleTestConn("http", "example.com") {
  3810  		t.Fatal("put didn't fail")
  3811  	}
  3812  	wantIdle("after second put", 0)
  3813  
  3814  	tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
  3815  	if tr.IsIdleForTesting() {
  3816  		t.Error("shouldn't be idle after QueueForIdleConnForTesting")
  3817  	}
  3818  	if !tr.PutIdleTestConn("http", "example.com") {
  3819  		t.Fatal("after re-activation")
  3820  	}
  3821  	wantIdle("after final put", 1)
  3822  }
  3823  
  3824  // Test for issue 34282
  3825  // Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn
  3826  func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
  3827  	tr := &Transport{}
  3828  	wantIdle := func(when string, n int) bool {
  3829  		got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2
  3830  		if got == n {
  3831  			return true
  3832  		}
  3833  		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
  3834  		return false
  3835  	}
  3836  	wantIdle("start", 0)
  3837  	alt := funcRoundTripper(func() {})
  3838  	if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
  3839  		t.Fatal("put failed")
  3840  	}
  3841  	wantIdle("after put", 1)
  3842  	ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  3843  		GotConn: func(httptrace.GotConnInfo) {
  3844  			// tr.getConn should leave it for the HTTP/2 alt to call GotConn.
  3845  			t.Error("GotConn called")
  3846  		},
  3847  	})
  3848  	req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
  3849  	_, err := tr.RoundTrip(req)
  3850  	if err != errFakeRoundTrip {
  3851  		t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
  3852  	}
  3853  	wantIdle("after round trip", 1)
  3854  }
  3855  
  3856  func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
  3857  	if testing.Short() {
  3858  		t.Skip("skipping in short mode")
  3859  	}
  3860  
  3861  	trFunc := func(tr *Transport) {
  3862  		tr.MaxConnsPerHost = 1
  3863  		tr.MaxIdleConnsPerHost = 1
  3864  		tr.IdleConnTimeout = 10 * time.Millisecond
  3865  	}
  3866  	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
  3867  	defer cst.close()
  3868  
  3869  	if _, err := cst.c.Get(cst.ts.URL); err != nil {
  3870  		t.Fatalf("got error: %s", err)
  3871  	}
  3872  
  3873  	time.Sleep(100 * time.Millisecond)
  3874  	got := make(chan error)
  3875  	go func() {
  3876  		if _, err := cst.c.Get(cst.ts.URL); err != nil {
  3877  			got <- err
  3878  		}
  3879  		close(got)
  3880  	}()
  3881  
  3882  	timeout := time.NewTimer(5 * time.Second)
  3883  	defer timeout.Stop()
  3884  	select {
  3885  	case err := <-got:
  3886  		if err != nil {
  3887  			t.Fatalf("got error: %s", err)
  3888  		}
  3889  	case <-timeout.C:
  3890  		t.Fatal("request never completed")
  3891  	}
  3892  }
  3893  
  3894  // This tests that a client requesting a content range won't also
  3895  // implicitly ask for gzip support. If they want that, they need to do it
  3896  // on their own.
  3897  // golang.org/issue/8923
  3898  func TestTransportRangeAndGzip(t *testing.T) {
  3899  	defer afterTest(t)
  3900  	reqc := make(chan *Request, 1)
  3901  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3902  		reqc <- r
  3903  	}))
  3904  	defer ts.Close()
  3905  	c := ts.Client()
  3906  
  3907  	req, _ := NewRequest("GET", ts.URL, nil)
  3908  	req.Header.Set("Range", "bytes=7-11")
  3909  	res, err := c.Do(req)
  3910  	if err != nil {
  3911  		t.Fatal(err)
  3912  	}
  3913  
  3914  	select {
  3915  	case r := <-reqc:
  3916  		if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
  3917  			t.Error("Transport advertised gzip support in the Accept header")
  3918  		}
  3919  		if r.Header.Get("Range") == "" {
  3920  			t.Error("no Range in request")
  3921  		}
  3922  	case <-time.After(10 * time.Second):
  3923  		t.Fatal("timeout")
  3924  	}
  3925  	res.Body.Close()
  3926  }
  3927  
  3928  // Test for issue 10474
  3929  func TestTransportResponseCancelRace(t *testing.T) {
  3930  	defer afterTest(t)
  3931  
  3932  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3933  		// important that this response has a body.
  3934  		var b [1024]byte
  3935  		w.Write(b[:])
  3936  	}))
  3937  	defer ts.Close()
  3938  	tr := ts.Client().Transport.(*Transport)
  3939  
  3940  	req, err := NewRequest("GET", ts.URL, nil)
  3941  	if err != nil {
  3942  		t.Fatal(err)
  3943  	}
  3944  	res, err := tr.RoundTrip(req)
  3945  	if err != nil {
  3946  		t.Fatal(err)
  3947  	}
  3948  	// If we do an early close, Transport just throws the connection away and
  3949  	// doesn't reuse it. In order to trigger the bug, it has to reuse the connection
  3950  	// so read the body
  3951  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  3952  		t.Fatal(err)
  3953  	}
  3954  
  3955  	req2, err := NewRequest("GET", ts.URL, nil)
  3956  	if err != nil {
  3957  		t.Fatal(err)
  3958  	}
  3959  	tr.CancelRequest(req)
  3960  	res, err = tr.RoundTrip(req2)
  3961  	if err != nil {
  3962  		t.Fatal(err)
  3963  	}
  3964  	res.Body.Close()
  3965  }
  3966  
  3967  // Test for issue 19248: Content-Encoding's value is case insensitive.
  3968  func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
  3969  	setParallel(t)
  3970  	defer afterTest(t)
  3971  	for _, ce := range []string{"gzip", "GZIP"} {
  3972  		ce := ce
  3973  		t.Run(ce, func(t *testing.T) {
  3974  			const encodedString = "Hello Gopher"
  3975  			ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  3976  				w.Header().Set("Content-Encoding", ce)
  3977  				gz := gzip.NewWriter(w)
  3978  				gz.Write([]byte(encodedString))
  3979  				gz.Close()
  3980  			}))
  3981  			defer ts.Close()
  3982  
  3983  			res, err := ts.Client().Get(ts.URL)
  3984  			if err != nil {
  3985  				t.Fatal(err)
  3986  			}
  3987  
  3988  			body, err := io.ReadAll(res.Body)
  3989  			res.Body.Close()
  3990  			if err != nil {
  3991  				t.Fatal(err)
  3992  			}
  3993  
  3994  			if string(body) != encodedString {
  3995  				t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
  3996  			}
  3997  		})
  3998  	}
  3999  }
  4000  
  4001  func TestTransportDialCancelRace(t *testing.T) {
  4002  	defer afterTest(t)
  4003  
  4004  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
  4005  	defer ts.Close()
  4006  	tr := ts.Client().Transport.(*Transport)
  4007  
  4008  	req, err := NewRequest("GET", ts.URL, nil)
  4009  	if err != nil {
  4010  		t.Fatal(err)
  4011  	}
  4012  	SetEnterRoundTripHook(func() {
  4013  		tr.CancelRequest(req)
  4014  	})
  4015  	defer SetEnterRoundTripHook(nil)
  4016  	res, err := tr.RoundTrip(req)
  4017  	if err != ExportErrRequestCanceled {
  4018  		t.Errorf("expected canceled request error; got %v", err)
  4019  		if err == nil {
  4020  			res.Body.Close()
  4021  		}
  4022  	}
  4023  }
  4024  
  4025  // logWritesConn is a net.Conn that logs each Write call to writes
  4026  // and then proxies to w.
  4027  // It proxies Read calls to a reader it receives from rch.
  4028  type logWritesConn struct {
  4029  	net.Conn // nil. crash on use.
  4030  
  4031  	w io.Writer
  4032  
  4033  	rch <-chan io.Reader
  4034  	r   io.Reader // nil until received by rch
  4035  
  4036  	mu     sync.Mutex
  4037  	writes []string
  4038  }
  4039  
  4040  func (c *logWritesConn) Write(p []byte) (n int, err error) {
  4041  	c.mu.Lock()
  4042  	defer c.mu.Unlock()
  4043  	c.writes = append(c.writes, string(p))
  4044  	return c.w.Write(p)
  4045  }
  4046  
  4047  func (c *logWritesConn) Read(p []byte) (n int, err error) {
  4048  	if c.r == nil {
  4049  		c.r = <-c.rch
  4050  	}
  4051  	return c.r.Read(p)
  4052  }
  4053  
  4054  func (c *logWritesConn) Close() error { return nil }
  4055  
  4056  // Issue 6574
  4057  func TestTransportFlushesBodyChunks(t *testing.T) {
  4058  	defer afterTest(t)
  4059  	resBody := make(chan io.Reader, 1)
  4060  	connr, connw := io.Pipe() // connection pipe pair
  4061  	lw := &logWritesConn{
  4062  		rch: resBody,
  4063  		w:   connw,
  4064  	}
  4065  	tr := &Transport{
  4066  		Dial: func(network, addr string) (net.Conn, error) {
  4067  			return lw, nil
  4068  		},
  4069  	}
  4070  	bodyr, bodyw := io.Pipe() // body pipe pair
  4071  	go func() {
  4072  		defer bodyw.Close()
  4073  		for i := 0; i < 3; i++ {
  4074  			fmt.Fprintf(bodyw, "num%d\n", i)
  4075  		}
  4076  	}()
  4077  	resc := make(chan *Response)
  4078  	go func() {
  4079  		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
  4080  		req.Header.Set("User-Agent", "x") // known value for test
  4081  		res, err := tr.RoundTrip(req)
  4082  		if err != nil {
  4083  			t.Errorf("RoundTrip: %v", err)
  4084  			close(resc)
  4085  			return
  4086  		}
  4087  		resc <- res
  4088  
  4089  	}()
  4090  	// Fully consume the request before checking the Write log vs. want.
  4091  	req, err := ReadRequest(bufio.NewReader(connr))
  4092  	if err != nil {
  4093  		t.Fatal(err)
  4094  	}
  4095  	io.Copy(io.Discard, req.Body)
  4096  
  4097  	// Unblock the transport's roundTrip goroutine.
  4098  	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
  4099  	res, ok := <-resc
  4100  	if !ok {
  4101  		return
  4102  	}
  4103  	defer res.Body.Close()
  4104  
  4105  	want := []string{
  4106  		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
  4107  		"5\r\nnum0\n\r\n",
  4108  		"5\r\nnum1\n\r\n",
  4109  		"5\r\nnum2\n\r\n",
  4110  		"0\r\n\r\n",
  4111  	}
  4112  	if !reflect.DeepEqual(lw.writes, want) {
  4113  		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
  4114  	}
  4115  }
  4116  
  4117  // Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
  4118  func TestTransportFlushesRequestHeader(t *testing.T) {
  4119  	defer afterTest(t)
  4120  	gotReq := make(chan struct{})
  4121  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4122  		close(gotReq)
  4123  	}))
  4124  	defer cst.close()
  4125  
  4126  	pr, pw := io.Pipe()
  4127  	req, err := NewRequest("POST", cst.ts.URL, pr)
  4128  	if err != nil {
  4129  		t.Fatal(err)
  4130  	}
  4131  	gotRes := make(chan struct{})
  4132  	go func() {
  4133  		defer close(gotRes)
  4134  		res, err := cst.tr.RoundTrip(req)
  4135  		if err != nil {
  4136  			t.Error(err)
  4137  			return
  4138  		}
  4139  		res.Body.Close()
  4140  	}()
  4141  
  4142  	select {
  4143  	case <-gotReq:
  4144  		pw.Close()
  4145  	case <-time.After(5 * time.Second):
  4146  		t.Fatal("timeout waiting for handler to get request")
  4147  	}
  4148  	<-gotRes
  4149  }
  4150  
  4151  // Issue 11745.
  4152  func TestTransportPrefersResponseOverWriteError(t *testing.T) {
  4153  	if testing.Short() {
  4154  		t.Skip("skipping in short mode")
  4155  	}
  4156  	defer afterTest(t)
  4157  	const contentLengthLimit = 1024 * 1024 // 1MB
  4158  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  4159  		if r.ContentLength >= contentLengthLimit {
  4160  			w.WriteHeader(StatusBadRequest)
  4161  			r.Body.Close()
  4162  			return
  4163  		}
  4164  		w.WriteHeader(StatusOK)
  4165  	}))
  4166  	defer ts.Close()
  4167  	c := ts.Client()
  4168  
  4169  	fail := 0
  4170  	count := 100
  4171  	bigBody := strings.Repeat("a", contentLengthLimit*2)
  4172  	for i := 0; i < count; i++ {
  4173  		req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
  4174  		if err != nil {
  4175  			t.Fatal(err)
  4176  		}
  4177  		resp, err := c.Do(req)
  4178  		if err != nil {
  4179  			fail++
  4180  			t.Logf("%d = %#v", i, err)
  4181  			if ue, ok := err.(*url.Error); ok {
  4182  				t.Logf("urlErr = %#v", ue.Err)
  4183  				if ne, ok := ue.Err.(*net.OpError); ok {
  4184  					t.Logf("netOpError = %#v", ne.Err)
  4185  				}
  4186  			}
  4187  		} else {
  4188  			resp.Body.Close()
  4189  			if resp.StatusCode != 400 {
  4190  				t.Errorf("Expected status code 400, got %v", resp.Status)
  4191  			}
  4192  		}
  4193  	}
  4194  	if fail > 0 {
  4195  		t.Errorf("Failed %v out of %v\n", fail, count)
  4196  	}
  4197  }
  4198  
  4199  func TestTransportAutomaticHTTP2(t *testing.T) {
  4200  	testTransportAutoHTTP(t, &Transport{}, true)
  4201  }
  4202  
  4203  func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
  4204  	testTransportAutoHTTP(t, &Transport{
  4205  		ForceAttemptHTTP2: true,
  4206  		TLSClientConfig:   new(tls.Config),
  4207  	}, true)
  4208  }
  4209  
  4210  // golang.org/issue/14391: also check DefaultTransport
  4211  func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
  4212  	testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
  4213  }
  4214  
  4215  func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
  4216  	testTransportAutoHTTP(t, &Transport{
  4217  		TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
  4218  	}, false)
  4219  }
  4220  
  4221  func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
  4222  	testTransportAutoHTTP(t, &Transport{
  4223  		TLSClientConfig: new(tls.Config),
  4224  	}, false)
  4225  }
  4226  
  4227  func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
  4228  	testTransportAutoHTTP(t, &Transport{
  4229  		ExpectContinueTimeout: 1 * time.Second,
  4230  	}, true)
  4231  }
  4232  
  4233  func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
  4234  	var d net.Dialer
  4235  	testTransportAutoHTTP(t, &Transport{
  4236  		Dial: d.Dial,
  4237  	}, false)
  4238  }
  4239  
  4240  func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
  4241  	var d net.Dialer
  4242  	testTransportAutoHTTP(t, &Transport{
  4243  		DialContext: d.DialContext,
  4244  	}, false)
  4245  }
  4246  
  4247  func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
  4248  	testTransportAutoHTTP(t, &Transport{
  4249  		DialTLS: func(network, addr string) (net.Conn, error) {
  4250  			panic("unused")
  4251  		},
  4252  	}, false)
  4253  }
  4254  
  4255  func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
  4256  	CondSkipHTTP2(t)
  4257  	_, err := tr.RoundTrip(new(Request))
  4258  	if err == nil {
  4259  		t.Error("expected error from RoundTrip")
  4260  	}
  4261  	if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
  4262  		t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
  4263  	}
  4264  }
  4265  
  4266  // Issue 13633: there was a race where we returned bodyless responses
  4267  // to callers before recycling the persistent connection, which meant
  4268  // a client doing two subsequent requests could end up on different
  4269  // connections. It's somewhat harmless but enough tests assume it's
  4270  // not true in order to test other things that it's worth fixing.
  4271  // Plus it's nice to be consistent and not have timing-dependent
  4272  // behavior.
  4273  func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
  4274  	defer afterTest(t)
  4275  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  4276  		w.Header().Set("X-Addr", r.RemoteAddr)
  4277  		// Empty response body.
  4278  	}))
  4279  	defer cst.close()
  4280  	n := 100
  4281  	if testing.Short() {
  4282  		n = 10
  4283  	}
  4284  	var firstAddr string
  4285  	for i := 0; i < n; i++ {
  4286  		res, err := cst.c.Get(cst.ts.URL)
  4287  		if err != nil {
  4288  			log.Fatal(err)
  4289  		}
  4290  		addr := res.Header.Get("X-Addr")
  4291  		if i == 0 {
  4292  			firstAddr = addr
  4293  		} else if addr != firstAddr {
  4294  			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
  4295  		}
  4296  		res.Body.Close()
  4297  	}
  4298  }
  4299  
// Issue 13839
// TestNoCrashReturningTransportAltConn cancels a request while its DialTLS
// func is still running, after a TLS conn negotiating the "foo" protocol
// (which has a TLSNextProto entry) has been established. It expects Do to
// return a url.Error wrapping ExportErrRequestCanceledConn. The "foo"
// TLSNextProto constructor is still expected to run for the abandoned conn
// (the test waits on madeRoundTripper), but the RoundTripper it returns must
// never be used.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Count pending dials via the test hooks so wg.Wait at the end can wait
	// for the abandoned dial to be fully handled.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// testDone keeps the server side of the conn open until the test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	// TLS server that only offers the "foo" protocol: accept one conn,
	// complete the handshake, then hold the conn until testDone is closed.
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The URL host is never resolved: DialTLS below always dials addr.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		// Finish the handshake, cancel the request, and only hand the conn
		// back after Do has returned, so the conn arrives after the
		// request was canceled.
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release DialTLS, then wait for the "foo" RoundTripper to be built
	// for the late conn and for all pending dials to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
  4382  
  4383  func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
  4384  	testTransportReuseConnection_Gzip(t, true)
  4385  }
  4386  
  4387  func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
  4388  	testTransportReuseConnection_Gzip(t, false)
  4389  }
  4390  
  4391  // Make sure we re-use underlying TCP connection for gzipped responses too.
  4392  func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
  4393  	setParallel(t)
  4394  	defer afterTest(t)
  4395  	addr := make(chan string, 2)
  4396  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  4397  		addr <- r.RemoteAddr
  4398  		w.Header().Set("Content-Encoding", "gzip")
  4399  		if chunked {
  4400  			w.(Flusher).Flush()
  4401  		}
  4402  		w.Write(rgz) // arbitrary gzip response
  4403  	}))
  4404  	defer ts.Close()
  4405  	c := ts.Client()
  4406  
  4407  	for i := 0; i < 2; i++ {
  4408  		res, err := c.Get(ts.URL)
  4409  		if err != nil {
  4410  			t.Fatal(err)
  4411  		}
  4412  		buf := make([]byte, len(rgz))
  4413  		if n, err := io.ReadFull(res.Body, buf); err != nil {
  4414  			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
  4415  		}
  4416  		// Note: no res.Body.Close call. It should work without it,
  4417  		// since the flate.Reader's internal buffering will hit EOF
  4418  		// and that should be sufficient.
  4419  	}
  4420  	a1, a2 := <-addr, <-addr
  4421  	if a1 != a2 {
  4422  		t.Fatalf("didn't reuse connection")
  4423  	}
  4424  }
  4425  
  4426  func TestTransportResponseHeaderLength(t *testing.T) {
  4427  	setParallel(t)
  4428  	defer afterTest(t)
  4429  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  4430  		if r.URL.Path == "/long" {
  4431  			w.Header().Set("Long", strings.Repeat("a", 1<<20))
  4432  		}
  4433  	}))
  4434  	defer ts.Close()
  4435  	c := ts.Client()
  4436  	c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
  4437  
  4438  	if res, err := c.Get(ts.URL); err != nil {
  4439  		t.Fatal(err)
  4440  	} else {
  4441  		res.Body.Close()
  4442  	}
  4443  
  4444  	res, err := c.Get(ts.URL + "/long")
  4445  	if err == nil {
  4446  		defer res.Body.Close()
  4447  		var n int64
  4448  		for k, vv := range res.Header {
  4449  			for _, v := range vv {
  4450  				n += int64(len(k)) + int64(len(v))
  4451  			}
  4452  		}
  4453  		t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
  4454  	}
  4455  	if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
  4456  		t.Errorf("got error: %v; want %q", err, want)
  4457  	}
  4458  }
  4459  
  4460  func TestTransportEventTrace(t *testing.T)    { testTransportEventTrace(t, h1Mode, false) }
  4461  func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) }
  4462  
  4463  // test a non-nil httptrace.ClientTrace but with all hooks set to zero.
  4464  func TestTransportEventTrace_NoHooks(t *testing.T)    { testTransportEventTrace(t, h1Mode, true) }
  4465  func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) }
  4466  
// testTransportEventTrace exercises httptrace hooks for a POST sent with
// "Expect: 100-continue", followed by a GET, against a test server.
// h2 is passed to newClientServerTest as its mode flag and also adds TLS
// handshake hooks. Requests go to the host "dns-is-faked.golang", which is
// resolved by a fake resolver stored in the context under
// nettrace.LookupIPAltResolverKey, so the DNS hooks fire without real DNS.
// Every hook appends a line to buf, and the log is then checked for the
// expected events. With noHooks, the trace is replaced by a zero ClientTrace
// (all hooks nil), and the test only checks that the first request succeeds.
func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
	defer afterTest(t)
	const resBody = "some body"
	// The WroteRequest hook sends here; the POST handler waits on it (up to
	// 5s) before writing its response. Buffered so the hook does not block.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// Do nothing for the second request.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			select {
			case <-gotWroteReqEvent:
			case <-time.After(5 * time.Second):
				t.Error("timeout waiting for WroteRequest event")
			}
		}
		io.WriteString(w, resBody)
	}))
	defer cst.close()

	// A non-zero timeout makes the Transport wait for the server's
	// 100 Continue before sending the body (checked via the
	// Wait100Continue/Got100Continue hooks below).
	cst.tr.ExpectContinueTimeout = 1 * time.Second

	var mu sync.Mutex // guards buf
	var buf bytes.Buffer
	// logf appends one formatted line to buf, under mu.
	logf := func(format string, args ...interface{}) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	// The fake resolver answers with the server's real IP; port is reused to
	// build the request URL.
	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake DNS server.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if h2 {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// zero out all func pointers, trying to get some path to crash
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// Done at this point. Just testing a full HTTP
		// requests can happen with a trace pointing to a zero
		// ClientTrace, full of nil func pointers.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce reports an error unless sub occurs exactly once in got.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	// wantOnceOrMore reports an error unless sub occurs at least once in got.
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if h2 {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
		// WroteHeaderField hook is not yet implemented in h2.)
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// And do a second request:
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	// GetConn must have fired once per request, i.e. twice in total.
	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
  4649  
  4650  func TestTransportEventTraceTLSVerify(t *testing.T) {
  4651  	var mu sync.Mutex
  4652  	var buf bytes.Buffer
  4653  	logf := func(format string, args ...interface{}) {
  4654  		mu.Lock()
  4655  		defer mu.Unlock()
  4656  		fmt.Fprintf(&buf, format, args...)
  4657  		buf.WriteByte('\n')
  4658  	}
  4659  
  4660  	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  4661  		t.Error("Unexpected request")
  4662  	}))
  4663  	defer ts.Close()
  4664  	ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
  4665  		logf("%s", p)
  4666  		return len(p), nil
  4667  	}), "", 0)
  4668  
  4669  	certpool := x509.NewCertPool()
  4670  	certpool.AddCert(ts.Certificate())
  4671  
  4672  	c := &Client{Transport: &Transport{
  4673  		TLSClientConfig: &tls.Config{
  4674  			ServerName: "dns-is-faked.golang",
  4675  			RootCAs:    certpool,
  4676  		},
  4677  	}}
  4678  
  4679  	trace := &httptrace.ClientTrace{
  4680  		TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
  4681  		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
  4682  			logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
  4683  		},
  4684  	}
  4685  
  4686  	req, _ := NewRequest("GET", ts.URL, nil)
  4687  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
  4688  	_, err := c.Do(req)
  4689  	if err == nil {
  4690  		t.Error("Expected request to fail TLS verification")
  4691  	}
  4692  
  4693  	mu.Lock()
  4694  	got := buf.String()
  4695  	mu.Unlock()
  4696  
  4697  	wantOnce := func(sub string) {
  4698  		if strings.Count(got, sub) != 1 {
  4699  			t.Errorf("expected substring %q exactly once in output.", sub)
  4700  		}
  4701  	}
  4702  
  4703  	wantOnce("TLSHandshakeStart")
  4704  	wantOnce("TLSHandshakeDone")
  4705  	wantOnce("err = x509: certificate is valid for example.com")
  4706  
  4707  	if t.Failed() {
  4708  		t.Errorf("Output:\n%s", got)
  4709  	}
  4710  }
  4711  
  4712  var (
  4713  	isDNSHijackedOnce sync.Once
  4714  	isDNSHijacked     bool
  4715  )
  4716  
  4717  func skipIfDNSHijacked(t *testing.T) {
  4718  	// Skip this test if the user is using a shady/ISP
  4719  	// DNS server hijacking queries.
  4720  	// See issues 16732, 16716.
  4721  	isDNSHijackedOnce.Do(func() {
  4722  		addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
  4723  		isDNSHijacked = len(addrs) != 0
  4724  	})
  4725  	if isDNSHijacked {
  4726  		t.Skip("skipping; test requires non-hijacking DNS server")
  4727  	}
  4728  }
  4729  
  4730  func TestTransportEventTraceRealDNS(t *testing.T) {
  4731  	skipIfDNSHijacked(t)
  4732  	defer afterTest(t)
  4733  	tr := &Transport{}
  4734  	defer tr.CloseIdleConnections()
  4735  	c := &Client{Transport: tr}
  4736  
  4737  	var mu sync.Mutex // guards buf
  4738  	var buf bytes.Buffer
  4739  	logf := func(format string, args ...interface{}) {
  4740  		mu.Lock()
  4741  		defer mu.Unlock()
  4742  		fmt.Fprintf(&buf, format, args...)
  4743  		buf.WriteByte('\n')
  4744  	}
  4745  
  4746  	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
  4747  	trace := &httptrace.ClientTrace{
  4748  		DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
  4749  		DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
  4750  		ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
  4751  		ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
  4752  	}
  4753  	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
  4754  
  4755  	resp, err := c.Do(req)
  4756  	if err == nil {
  4757  		resp.Body.Close()
  4758  		t.Fatal("expected error during DNS lookup")
  4759  	}
  4760  
  4761  	mu.Lock()
  4762  	got := buf.String()
  4763  	mu.Unlock()
  4764  
  4765  	wantSub := func(sub string) {
  4766  		if !strings.Contains(got, sub) {
  4767  			t.Errorf("expected substring %q in output.", sub)
  4768  		}
  4769  	}
  4770  	wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
  4771  	wantSub("DNSDone: {Addrs:[] Err:")
  4772  	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
  4773  		t.Errorf("should not see Connect events")
  4774  	}
  4775  	if t.Failed() {
  4776  		t.Errorf("Output:\n%s", got)
  4777  	}
  4778  }
  4779  
  4780  // Issue 14353: port can only contain digits.
  4781  func TestTransportRejectsAlphaPort(t *testing.T) {
  4782  	res, err := Get("http://dummy.tld:123foo/bar")
  4783  	if err == nil {
  4784  		res.Body.Close()
  4785  		t.Fatal("unexpected success")
  4786  	}
  4787  	ue, ok := err.(*url.Error)
  4788  	if !ok {
  4789  		t.Fatalf("got %#v; want *url.Error", err)
  4790  	}
  4791  	got := ue.Err.Error()
  4792  	want := `invalid port ":123foo" after host`
  4793  	if got != want {
  4794  		t.Errorf("got error %q; want %q", got, want)
  4795  	}
  4796  }
  4797  
  4798  // Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1
  4799  // connections. The http2 test is done in TestTransportEventTrace_h2
  4800  func TestTLSHandshakeTrace(t *testing.T) {
  4801  	defer afterTest(t)
  4802  	ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
  4803  	defer ts.Close()
  4804  
  4805  	var mu sync.Mutex
  4806  	var start, done bool
  4807  	trace := &httptrace.ClientTrace{
  4808  		TLSHandshakeStart: func() {
  4809  			mu.Lock()
  4810  			defer mu.Unlock()
  4811  			start = true
  4812  		},
  4813  		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
  4814  			mu.Lock()
  4815  			defer mu.Unlock()
  4816  			done = true
  4817  			if err != nil {
  4818  				t.Fatal("Expected error to be nil but was:", err)
  4819  			}
  4820  		},
  4821  	}
  4822  
  4823  	c := ts.Client()
  4824  	req, err := NewRequest("GET", ts.URL, nil)
  4825  	if err != nil {
  4826  		t.Fatal("Unable to construct test request:", err)
  4827  	}
  4828  	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
  4829  
  4830  	r, err := c.Do(req)
  4831  	if err != nil {
  4832  		t.Fatal("Unexpected error making request:", err)
  4833  	}
  4834  	r.Body.Close()
  4835  	mu.Lock()
  4836  	defer mu.Unlock()
  4837  	if !start {
  4838  		t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
  4839  	}
  4840  	if !done {
  4841  		t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't")
  4842  	}
  4843  }
  4844  
  4845  func TestTransportMaxIdleConns(t *testing.T) {
  4846  	defer afterTest(t)
  4847  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  4848  		// No body for convenience.
  4849  	}))
  4850  	defer ts.Close()
  4851  	c := ts.Client()
  4852  	tr := c.Transport.(*Transport)
  4853  	tr.MaxIdleConns = 4
  4854  
  4855  	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
  4856  	if err != nil {
  4857  		t.Fatal(err)
  4858  	}
  4859  	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
  4860  		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
  4861  	})
  4862  
  4863  	hitHost := func(n int) {
  4864  		req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
  4865  		req = req.WithContext(ctx)
  4866  		res, err := c.Do(req)
  4867  		if err != nil {
  4868  			t.Fatal(err)
  4869  		}
  4870  		res.Body.Close()
  4871  	}
  4872  	for i := 0; i < 4; i++ {
  4873  		hitHost(i)
  4874  	}
  4875  	want := []string{
  4876  		"|http|host-0.dns-is-faked.golang:" + port,
  4877  		"|http|host-1.dns-is-faked.golang:" + port,
  4878  		"|http|host-2.dns-is-faked.golang:" + port,
  4879  		"|http|host-3.dns-is-faked.golang:" + port,
  4880  	}
  4881  	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
  4882  		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
  4883  	}
  4884  
  4885  	// Now hitting the 5th host should kick out the first host:
  4886  	hitHost(4)
  4887  	want = []string{
  4888  		"|http|host-1.dns-is-faked.golang:" + port,
  4889  		"|http|host-2.dns-is-faked.golang:" + port,
  4890  		"|http|host-3.dns-is-faked.golang:" + port,
  4891  		"|http|host-4.dns-is-faked.golang:" + port,
  4892  	}
  4893  	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
  4894  		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
  4895  	}
  4896  }
  4897  
  4898  func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) }
  4899  func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) }
  4900  func testTransportIdleConnTimeout(t *testing.T, h2 bool) {
  4901  	if testing.Short() {
  4902  		t.Skip("skipping in short mode")
  4903  	}
  4904  	defer afterTest(t)
  4905  
  4906  	const timeout = 1 * time.Second
  4907  
  4908  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  4909  		// No body for convenience.
  4910  	}))
  4911  	defer cst.close()
  4912  	tr := cst.tr
  4913  	tr.IdleConnTimeout = timeout
  4914  	defer tr.CloseIdleConnections()
  4915  	c := &Client{Transport: tr}
  4916  
  4917  	idleConns := func() []string {
  4918  		if h2 {
  4919  			return tr.IdleConnStrsForTesting_h2()
  4920  		} else {
  4921  			return tr.IdleConnStrsForTesting()
  4922  		}
  4923  	}
  4924  
  4925  	var conn string
  4926  	doReq := func(n int) {
  4927  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  4928  		req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  4929  			PutIdleConn: func(err error) {
  4930  				if err != nil {
  4931  					t.Errorf("failed to keep idle conn: %v", err)
  4932  				}
  4933  			},
  4934  		}))
  4935  		res, err := c.Do(req)
  4936  		if err != nil {
  4937  			t.Fatal(err)
  4938  		}
  4939  		res.Body.Close()
  4940  		conns := idleConns()
  4941  		if len(conns) != 1 {
  4942  			t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
  4943  		}
  4944  		if conn == "" {
  4945  			conn = conns[0]
  4946  		}
  4947  		if conn != conns[0] {
  4948  			t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n)
  4949  		}
  4950  	}
  4951  	for i := 0; i < 3; i++ {
  4952  		doReq(i)
  4953  		time.Sleep(timeout / 2)
  4954  	}
  4955  	time.Sleep(timeout * 3 / 2)
  4956  	if got := idleConns(); len(got) != 0 {
  4957  		t.Errorf("idle conns = %q; want none", got)
  4958  	}
  4959  }
  4960  
// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
// HTTP/2 connection was established but its caller no longer
// wanted it. (Assuming the connection cache was enabled, which it is
// by default)
//
// This test reproduced the crash by setting the IdleConnTimeout low
// (to make the test reasonable) and then making a request which is
// canceled by the DialTLS hook, which then also waits to return the
// real connection until after the RoundTrip saw the error.  Then we
// know the successful tls.Dial from DialTLS will need to go into the
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) {
	setParallel(t)
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// nothing
	}))
	defer cst.close()

	// ctx is canceled from inside DialTLS, after the TLS dial succeeded.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is sent on once cst.c.Do has returned its error; DialTLS
	// waits for it before handing back the dialed conn.
	sawDoErr := make(chan bool, 1)
	// testDone unblocks a still-waiting DialTLS when the test returns.
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request now that a real h2 conn exists, so the
		// request fails while this dial still succeeds.
		cancel()

		// Hold the conn until Do has reported its error (or the test
		// ends, or 5s pass), then return it to the Transport.
		failTimer := time.NewTimer(5 * time.Second)
		defer failTimer.Stop()
		select {
		case <-sawDoErr:
		case <-testDone:
		case <-failTimer.C:
			t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail")
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Wait for the explosion.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
  5027  
  5028  type funcConn struct {
  5029  	net.Conn
  5030  	read  func([]byte) (int, error)
  5031  	write func([]byte) (int, error)
  5032  }
  5033  
  5034  func (c funcConn) Read(p []byte) (int, error)  { return c.read(p) }
  5035  func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
  5036  func (c funcConn) Close() error                { return nil }
  5037  
  5038  // Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
  5039  // back to the caller.
  5040  func TestTransportReturnsPeekError(t *testing.T) {
  5041  	errValue := errors.New("specific error value")
  5042  
  5043  	wrote := make(chan struct{})
  5044  	var wroteOnce sync.Once
  5045  
  5046  	tr := &Transport{
  5047  		Dial: func(network, addr string) (net.Conn, error) {
  5048  			c := funcConn{
  5049  				read: func([]byte) (int, error) {
  5050  					<-wrote
  5051  					return 0, errValue
  5052  				},
  5053  				write: func(p []byte) (int, error) {
  5054  					wroteOnce.Do(func() { close(wrote) })
  5055  					return len(p), nil
  5056  				},
  5057  			}
  5058  			return c, nil
  5059  		},
  5060  	}
  5061  	_, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
  5062  	if err != errValue {
  5063  		t.Errorf("error = %#v; want %v", err, errValue)
  5064  	}
  5065  }
  5066  
  5067  // Issue 13835: international domain names should work
  5068  func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) }
  5069  func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) }
  5070  func testTransportIDNA(t *testing.T, h2 bool) {
  5071  	defer afterTest(t)
  5072  
  5073  	const uniDomain = "гофер.го"
  5074  	const punyDomain = "xn--c1ae0ajs.xn--c1aw"
  5075  
  5076  	var port string
  5077  	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
  5078  		want := punyDomain + ":" + port
  5079  		if r.Host != want {
  5080  			t.Errorf("Host header = %q; want %q", r.Host, want)
  5081  		}
  5082  		if h2 {
  5083  			if r.TLS == nil {
  5084  				t.Errorf("r.TLS == nil")
  5085  			} else if r.TLS.ServerName != punyDomain {
  5086  				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
  5087  			}
  5088  		}
  5089  		w.Header().Set("Hit-Handler", "1")
  5090  	}))
  5091  	defer cst.close()
  5092  
  5093  	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
  5094  	if err != nil {
  5095  		t.Fatal(err)
  5096  	}
  5097  
  5098  	// Install a fake DNS server.
  5099  	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
  5100  		if host != punyDomain {
  5101  			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
  5102  			return nil, nil
  5103  		}
  5104  		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
  5105  	})
  5106  
  5107  	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
  5108  	trace := &httptrace.ClientTrace{
  5109  		GetConn: func(hostPort string) {
  5110  			want := net.JoinHostPort(punyDomain, port)
  5111  			if hostPort != want {
  5112  				t.Errorf("getting conn for %q; want %q", hostPort, want)
  5113  			}
  5114  		},
  5115  		DNSStart: func(e httptrace.DNSStartInfo) {
  5116  			if e.Host != punyDomain {
  5117  				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
  5118  			}
  5119  		},
  5120  	}
  5121  	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
  5122  
  5123  	res, err := cst.tr.RoundTrip(req)
  5124  	if err != nil {
  5125  		t.Fatal(err)
  5126  	}
  5127  	defer res.Body.Close()
  5128  	if res.Header.Get("Hit-Handler") != "1" {
  5129  		out, err := httputil.DumpResponse(res, true)
  5130  		if err != nil {
  5131  			t.Fatal(err)
  5132  		}
  5133  		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
  5134  	}
  5135  }
  5136  
  5137  // Issue 13290: send User-Agent in proxy CONNECT
  5138  func TestTransportProxyConnectHeader(t *testing.T) {
  5139  	defer afterTest(t)
  5140  	reqc := make(chan *Request, 1)
  5141  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  5142  		if r.Method != "CONNECT" {
  5143  			t.Errorf("method = %q; want CONNECT", r.Method)
  5144  		}
  5145  		reqc <- r
  5146  		c, _, err := w.(Hijacker).Hijack()
  5147  		if err != nil {
  5148  			t.Errorf("Hijack: %v", err)
  5149  			return
  5150  		}
  5151  		c.Close()
  5152  	}))
  5153  	defer ts.Close()
  5154  
  5155  	c := ts.Client()
  5156  	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
  5157  		return url.Parse(ts.URL)
  5158  	}
  5159  	c.Transport.(*Transport).ProxyConnectHeader = Header{
  5160  		"User-Agent": {"foo"},
  5161  		"Other":      {"bar"},
  5162  	}
  5163  
  5164  	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
  5165  	if err == nil {
  5166  		res.Body.Close()
  5167  		t.Errorf("unexpected success")
  5168  	}
  5169  	select {
  5170  	case <-time.After(3 * time.Second):
  5171  		t.Fatal("timeout")
  5172  	case r := <-reqc:
  5173  		if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
  5174  			t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
  5175  		}
  5176  		if got, want := r.Header.Get("Other"), "bar"; got != want {
  5177  			t.Errorf("CONNECT request Other = %q; want %q", got, want)
  5178  		}
  5179  	}
  5180  }
  5181  
  5182  func TestTransportProxyGetConnectHeader(t *testing.T) {
  5183  	defer afterTest(t)
  5184  	reqc := make(chan *Request, 1)
  5185  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  5186  		if r.Method != "CONNECT" {
  5187  			t.Errorf("method = %q; want CONNECT", r.Method)
  5188  		}
  5189  		reqc <- r
  5190  		c, _, err := w.(Hijacker).Hijack()
  5191  		if err != nil {
  5192  			t.Errorf("Hijack: %v", err)
  5193  			return
  5194  		}
  5195  		c.Close()
  5196  	}))
  5197  	defer ts.Close()
  5198  
  5199  	c := ts.Client()
  5200  	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
  5201  		return url.Parse(ts.URL)
  5202  	}
  5203  	// These should be ignored:
  5204  	c.Transport.(*Transport).ProxyConnectHeader = Header{
  5205  		"User-Agent": {"foo"},
  5206  		"Other":      {"bar"},
  5207  	}
  5208  	c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
  5209  		return Header{
  5210  			"User-Agent": {"foo2"},
  5211  			"Other":      {"bar2"},
  5212  		}, nil
  5213  	}
  5214  
  5215  	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
  5216  	if err == nil {
  5217  		res.Body.Close()
  5218  		t.Errorf("unexpected success")
  5219  	}
  5220  	select {
  5221  	case <-time.After(3 * time.Second):
  5222  		t.Fatal("timeout")
  5223  	case r := <-reqc:
  5224  		if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
  5225  			t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
  5226  		}
  5227  		if got, want := r.Header.Get("Other"), "bar2"; got != want {
  5228  			t.Errorf("CONNECT request Other = %q; want %q", got, want)
  5229  		}
  5230  	}
  5231  }
  5232  
  5233  var errFakeRoundTrip = errors.New("fake roundtrip")
  5234  
  5235  type funcRoundTripper func()
  5236  
  5237  func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
  5238  	fn()
  5239  	return nil, errFakeRoundTrip
  5240  }
  5241  
  5242  func wantBody(res *Response, err error, want string) error {
  5243  	if err != nil {
  5244  		return err
  5245  	}
  5246  	slurp, err := io.ReadAll(res.Body)
  5247  	if err != nil {
  5248  		return fmt.Errorf("error reading body: %v", err)
  5249  	}
  5250  	if string(slurp) != want {
  5251  		return fmt.Errorf("body = %q; want %q", slurp, want)
  5252  	}
  5253  	if err := res.Body.Close(); err != nil {
  5254  		return fmt.Errorf("body Close = %v", err)
  5255  	}
  5256  	return nil
  5257  }
  5258  
  5259  func newLocalListener(t *testing.T) net.Listener {
  5260  	ln, err := net.Listen("tcp", "127.0.0.1:0")
  5261  	if err != nil {
  5262  		ln, err = net.Listen("tcp6", "[::1]:0")
  5263  	}
  5264  	if err != nil {
  5265  		t.Fatal(err)
  5266  	}
  5267  	return ln
  5268  }
  5269  
  5270  type countCloseReader struct {
  5271  	n *int
  5272  	io.Reader
  5273  }
  5274  
  5275  func (cr countCloseReader) Close() error {
  5276  	(*cr.n)++
  5277  	return nil
  5278  }
  5279  
// rgz is a gzip quine that uncompresses to itself.
//
// Layout per RFC 1952: the leading bytes are the gzip member header.
// These are ID1/ID2 magic 0x1f 0x8b, CM 0x08 (deflate), FLG 0x08 (FNAME
// set), a zero MTIME, XFL 0x00 and OS 0x00. The NUL-terminated file name
// "recursive" follows, then the deflate data. The final eight bytes are
// the CRC-32 and ISIZE trailer. ISIZE is 0xfa (250), which equals
// len(rgz), as a stream that decompresses to itself requires.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
  5315  
  5316  // Ensure that a missing status doesn't make the server panic
  5317  // See Issue https://golang.org/issues/21701
  5318  func TestMissingStatusNoPanic(t *testing.T) {
  5319  	t.Parallel()
  5320  
  5321  	const want = "unknown status code"
  5322  
  5323  	ln := newLocalListener(t)
  5324  	addr := ln.Addr().String()
  5325  	done := make(chan bool)
  5326  	fullAddrURL := fmt.Sprintf("http://%s", addr)
  5327  	raw := "HTTP/1.1 400\r\n" +
  5328  		"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
  5329  		"Content-Type: text/html; charset=utf-8\r\n" +
  5330  		"Content-Length: 10\r\n" +
  5331  		"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
  5332  		"Vary: Accept-Encoding\r\n\r\n" +
  5333  		"Aloha Olaa"
  5334  
  5335  	go func() {
  5336  		defer close(done)
  5337  
  5338  		conn, _ := ln.Accept()
  5339  		if conn != nil {
  5340  			io.WriteString(conn, raw)
  5341  			io.ReadAll(conn)
  5342  			conn.Close()
  5343  		}
  5344  	}()
  5345  
  5346  	proxyURL, err := url.Parse(fullAddrURL)
  5347  	if err != nil {
  5348  		t.Fatalf("proxyURL: %v", err)
  5349  	}
  5350  
  5351  	tr := &Transport{Proxy: ProxyURL(proxyURL)}
  5352  
  5353  	req, _ := NewRequest("GET", "https://golang.org/", nil)
  5354  	res, err, panicked := doFetchCheckPanic(tr, req)
  5355  	if panicked {
  5356  		t.Error("panicked, expecting an error")
  5357  	}
  5358  	if res != nil && res.Body != nil {
  5359  		io.Copy(io.Discard, res.Body)
  5360  		res.Body.Close()
  5361  	}
  5362  
  5363  	if err == nil || !strings.Contains(err.Error(), want) {
  5364  		t.Errorf("got=%v want=%q", err, want)
  5365  	}
  5366  
  5367  	ln.Close()
  5368  	<-done
  5369  }
  5370  
  5371  func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
  5372  	defer func() {
  5373  		if r := recover(); r != nil {
  5374  			panicked = true
  5375  		}
  5376  	}()
  5377  	res, err = tr.RoundTrip(req)
  5378  	return
  5379  }
  5380  
  5381  // Issue 22330: do not allow the response body to be read when the status code
  5382  // forbids a response body.
  5383  func TestNoBodyOnChunked304Response(t *testing.T) {
  5384  	defer afterTest(t)
  5385  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5386  		conn, buf, _ := w.(Hijacker).Hijack()
  5387  		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
  5388  		buf.Flush()
  5389  		conn.Close()
  5390  	}))
  5391  	defer cst.close()
  5392  
  5393  	// Our test server above is sending back bogus data after the
  5394  	// response (the "0\r\n\r\n" part), which causes the Transport
  5395  	// code to log spam. Disable keep-alives so we never even try
  5396  	// to reuse the connection.
  5397  	cst.tr.DisableKeepAlives = true
  5398  
  5399  	res, err := cst.c.Get(cst.ts.URL)
  5400  	if err != nil {
  5401  		t.Fatal(err)
  5402  	}
  5403  
  5404  	if res.Body != NoBody {
  5405  		t.Errorf("Unexpected body on 304 response")
  5406  	}
  5407  }
  5408  
  5409  type funcWriter func([]byte) (int, error)
  5410  
  5411  func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
  5412  
  5413  type doneContext struct {
  5414  	context.Context
  5415  	err error
  5416  }
  5417  
  5418  func (doneContext) Done() <-chan struct{} {
  5419  	c := make(chan struct{})
  5420  	close(c)
  5421  	return c
  5422  }
  5423  
  5424  func (d doneContext) Err() error { return d.err }
  5425  
  5426  // Issue 25852: Transport should check whether Context is done early.
  5427  func TestTransportCheckContextDoneEarly(t *testing.T) {
  5428  	tr := &Transport{}
  5429  	req, _ := NewRequest("GET", "http://fake.example/", nil)
  5430  	wantErr := errors.New("some error")
  5431  	req = req.WithContext(doneContext{context.Background(), wantErr})
  5432  	_, err := tr.RoundTrip(req)
  5433  	if err != wantErr {
  5434  		t.Errorf("error = %v; want %v", err, wantErr)
  5435  	}
  5436  }
  5437  
  5438  // Issue 23399: verify that if a client request times out, the Transport's
  5439  // conn is closed so that it's not reused.
  5440  //
  5441  // This is the test variant that times out before the server replies with
  5442  // any response headers.
  5443  func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
  5444  	setParallel(t)
  5445  	defer afterTest(t)
  5446  	inHandler := make(chan net.Conn, 1)
  5447  	handlerReadReturned := make(chan bool, 1)
  5448  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5449  		conn, _, err := w.(Hijacker).Hijack()
  5450  		if err != nil {
  5451  			t.Error(err)
  5452  			return
  5453  		}
  5454  		inHandler <- conn
  5455  		n, err := conn.Read([]byte{0})
  5456  		if n != 0 || err != io.EOF {
  5457  			t.Errorf("unexpected Read result: %v, %v", n, err)
  5458  		}
  5459  		handlerReadReturned <- true
  5460  	}))
  5461  	defer cst.close()
  5462  
  5463  	const timeout = 50 * time.Millisecond
  5464  	cst.c.Timeout = timeout
  5465  
  5466  	_, err := cst.c.Get(cst.ts.URL)
  5467  	if err == nil {
  5468  		t.Fatal("unexpected Get succeess")
  5469  	}
  5470  
  5471  	select {
  5472  	case c := <-inHandler:
  5473  		select {
  5474  		case <-handlerReadReturned:
  5475  			// Success.
  5476  			return
  5477  		case <-time.After(5 * time.Second):
  5478  			t.Error("Handler's conn.Read seems to be stuck in Read")
  5479  			c.Close() // close it to unblock Handler
  5480  		}
  5481  	case <-time.After(timeout * 10):
  5482  		// If we didn't get into the Handler in 50ms, that probably means
  5483  		// the builder was just slow and the Get failed in that time
  5484  		// but never made it to the server. That's fine. We'll usually
  5485  		// test the part above on faster machines.
  5486  		t.Skip("skipping test on slow builder")
  5487  	}
  5488  }
  5489  
  5490  // Issue 23399: verify that if a client request times out, the Transport's
  5491  // conn is closed so that it's not reused.
  5492  //
  5493  // This is the test variant that has the server send response headers
  5494  // first, and time out during the write of the response body.
  5495  func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
  5496  	setParallel(t)
  5497  	defer afterTest(t)
  5498  	inHandler := make(chan net.Conn, 1)
  5499  	handlerResult := make(chan error, 1)
  5500  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5501  		w.Header().Set("Content-Length", "100")
  5502  		w.(Flusher).Flush()
  5503  		conn, _, err := w.(Hijacker).Hijack()
  5504  		if err != nil {
  5505  			t.Error(err)
  5506  			return
  5507  		}
  5508  		conn.Write([]byte("foo"))
  5509  		inHandler <- conn
  5510  		n, err := conn.Read([]byte{0})
  5511  		// The error should be io.EOF or "read tcp
  5512  		// 127.0.0.1:35827->127.0.0.1:40290: read: connection
  5513  		// reset by peer" depending on timing. Really we just
  5514  		// care that it returns at all. But if it returns with
  5515  		// data, that's weird.
  5516  		if n != 0 || err == nil {
  5517  			handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err)
  5518  			return
  5519  		}
  5520  		handlerResult <- nil
  5521  	}))
  5522  	defer cst.close()
  5523  
  5524  	// Set Timeout to something very long but non-zero to exercise
  5525  	// the codepaths that check for it. But rather than wait for it to fire
  5526  	// (which would make the test slow), we send on the req.Cancel channel instead,
  5527  	// which happens to exercise the same code paths.
  5528  	cst.c.Timeout = time.Minute // just to be non-zero, not to hit it.
  5529  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  5530  	cancel := make(chan struct{})
  5531  	req.Cancel = cancel
  5532  
  5533  	res, err := cst.c.Do(req)
  5534  	if err != nil {
  5535  		select {
  5536  		case <-inHandler:
  5537  			t.Fatalf("Get error: %v", err)
  5538  		default:
  5539  			// Failed before entering handler. Ignore result.
  5540  			t.Skip("skipping test on slow builder")
  5541  		}
  5542  	}
  5543  
  5544  	close(cancel)
  5545  	got, err := io.ReadAll(res.Body)
  5546  	if err == nil {
  5547  		t.Fatalf("unexpected success; read %q, nil", got)
  5548  	}
  5549  
  5550  	select {
  5551  	case c := <-inHandler:
  5552  		select {
  5553  		case err := <-handlerResult:
  5554  			if err != nil {
  5555  				t.Errorf("handler: %v", err)
  5556  			}
  5557  			return
  5558  		case <-time.After(5 * time.Second):
  5559  			t.Error("Handler's conn.Read seems to be stuck in Read")
  5560  			c.Close() // close it to unblock Handler
  5561  		}
  5562  	case <-time.After(5 * time.Second):
  5563  		t.Fatal("timeout")
  5564  	}
  5565  }
  5566  
  5567  func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
  5568  	setParallel(t)
  5569  	defer afterTest(t)
  5570  	done := make(chan struct{})
  5571  	defer close(done)
  5572  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5573  		conn, _, err := w.(Hijacker).Hijack()
  5574  		if err != nil {
  5575  			t.Error(err)
  5576  			return
  5577  		}
  5578  		defer conn.Close()
  5579  		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
  5580  		bs := bufio.NewScanner(conn)
  5581  		bs.Scan()
  5582  		fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
  5583  		<-done
  5584  	}))
  5585  	defer cst.close()
  5586  
  5587  	req, _ := NewRequest("GET", cst.ts.URL, nil)
  5588  	req.Header.Set("Upgrade", "foo")
  5589  	req.Header.Set("Connection", "upgrade")
  5590  	res, err := cst.c.Do(req)
  5591  	if err != nil {
  5592  		t.Fatal(err)
  5593  	}
  5594  	if res.StatusCode != 101 {
  5595  		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
  5596  	}
  5597  	rwc, ok := res.Body.(io.ReadWriteCloser)
  5598  	if !ok {
  5599  		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
  5600  	}
  5601  	defer rwc.Close()
  5602  	bs := bufio.NewScanner(rwc)
  5603  	if !bs.Scan() {
  5604  		t.Fatalf("expected readable input")
  5605  	}
  5606  	if got, want := bs.Text(), "Some buffered data"; got != want {
  5607  		t.Errorf("read %q; want %q", got, want)
  5608  	}
  5609  	io.WriteString(rwc, "echo\n")
  5610  	if !bs.Scan() {
  5611  		t.Fatalf("expected another line")
  5612  	}
  5613  	if got, want := bs.Text(), "ECHO"; got != want {
  5614  		t.Errorf("read %q; want %q", got, want)
  5615  	}
  5616  }
  5617  
  5618  func TestTransportCONNECTBidi(t *testing.T) {
  5619  	defer afterTest(t)
  5620  	const target = "backend:443"
  5621  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5622  		if r.Method != "CONNECT" {
  5623  			t.Errorf("unexpected method %q", r.Method)
  5624  			w.WriteHeader(500)
  5625  			return
  5626  		}
  5627  		if r.RequestURI != target {
  5628  			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
  5629  			w.WriteHeader(500)
  5630  			return
  5631  		}
  5632  		nc, brw, err := w.(Hijacker).Hijack()
  5633  		if err != nil {
  5634  			t.Error(err)
  5635  			return
  5636  		}
  5637  		defer nc.Close()
  5638  		nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
  5639  		// Switch to a little protocol that capitalize its input lines:
  5640  		for {
  5641  			line, err := brw.ReadString('\n')
  5642  			if err != nil {
  5643  				if err != io.EOF {
  5644  					t.Error(err)
  5645  				}
  5646  				return
  5647  			}
  5648  			io.WriteString(brw, strings.ToUpper(line))
  5649  			brw.Flush()
  5650  		}
  5651  	}))
  5652  	defer cst.close()
  5653  	pr, pw := io.Pipe()
  5654  	defer pw.Close()
  5655  	req, err := NewRequest("CONNECT", cst.ts.URL, pr)
  5656  	if err != nil {
  5657  		t.Fatal(err)
  5658  	}
  5659  	req.URL.Opaque = target
  5660  	res, err := cst.c.Do(req)
  5661  	if err != nil {
  5662  		t.Fatal(err)
  5663  	}
  5664  	defer res.Body.Close()
  5665  	if res.StatusCode != 200 {
  5666  		t.Fatalf("status code = %d; want 200", res.StatusCode)
  5667  	}
  5668  	br := bufio.NewReader(res.Body)
  5669  	for _, str := range []string{"foo", "bar", "baz"} {
  5670  		fmt.Fprintf(pw, "%s\n", str)
  5671  		got, err := br.ReadString('\n')
  5672  		if err != nil {
  5673  			t.Fatal(err)
  5674  		}
  5675  		got = strings.TrimSpace(got)
  5676  		want := strings.ToUpper(str)
  5677  		if got != want {
  5678  			t.Fatalf("got %q; want %q", got, want)
  5679  		}
  5680  	}
  5681  }
  5682  
  5683  func TestTransportRequestReplayable(t *testing.T) {
  5684  	someBody := io.NopCloser(strings.NewReader(""))
  5685  	tests := []struct {
  5686  		name string
  5687  		req  *Request
  5688  		want bool
  5689  	}{
  5690  		{
  5691  			name: "GET",
  5692  			req:  &Request{Method: "GET"},
  5693  			want: true,
  5694  		},
  5695  		{
  5696  			name: "GET_http.NoBody",
  5697  			req:  &Request{Method: "GET", Body: NoBody},
  5698  			want: true,
  5699  		},
  5700  		{
  5701  			name: "GET_body",
  5702  			req:  &Request{Method: "GET", Body: someBody},
  5703  			want: false,
  5704  		},
  5705  		{
  5706  			name: "POST",
  5707  			req:  &Request{Method: "POST"},
  5708  			want: false,
  5709  		},
  5710  		{
  5711  			name: "POST_idempotency-key",
  5712  			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
  5713  			want: true,
  5714  		},
  5715  		{
  5716  			name: "POST_x-idempotency-key",
  5717  			req:  &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
  5718  			want: true,
  5719  		},
  5720  		{
  5721  			name: "POST_body",
  5722  			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
  5723  			want: false,
  5724  		},
  5725  	}
  5726  	for _, tt := range tests {
  5727  		t.Run(tt.name, func(t *testing.T) {
  5728  			got := tt.req.ExportIsReplayable()
  5729  			if got != tt.want {
  5730  				t.Errorf("replyable = %v; want %v", got, tt.want)
  5731  			}
  5732  		})
  5733  	}
  5734  }
  5735  
  5736  // testMockTCPConn is a mock TCP connection used to test that
  5737  // ReadFrom is called when sending the request body.
  5738  type testMockTCPConn struct {
  5739  	*net.TCPConn
  5740  
  5741  	ReadFromCalled bool
  5742  }
  5743  
  5744  func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
  5745  	c.ReadFromCalled = true
  5746  	return c.TCPConn.ReadFrom(r)
  5747  }
  5748  
  5749  func TestTransportRequestWriteRoundTrip(t *testing.T) {
  5750  	nBytes := int64(1 << 10)
  5751  	newFileFunc := func() (r io.Reader, done func(), err error) {
  5752  		f, err := os.CreateTemp("", "net-http-newfilefunc")
  5753  		if err != nil {
  5754  			return nil, nil, err
  5755  		}
  5756  
  5757  		// Write some bytes to the file to enable reading.
  5758  		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
  5759  			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
  5760  		}
  5761  		if _, err := f.Seek(0, 0); err != nil {
  5762  			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
  5763  		}
  5764  
  5765  		done = func() {
  5766  			f.Close()
  5767  			os.Remove(f.Name())
  5768  		}
  5769  
  5770  		return f, done, nil
  5771  	}
  5772  
  5773  	newBufferFunc := func() (io.Reader, func(), error) {
  5774  		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
  5775  	}
  5776  
  5777  	cases := []struct {
  5778  		name             string
  5779  		readerFunc       func() (io.Reader, func(), error)
  5780  		contentLength    int64
  5781  		expectedReadFrom bool
  5782  	}{
  5783  		{
  5784  			name:             "file, length",
  5785  			readerFunc:       newFileFunc,
  5786  			contentLength:    nBytes,
  5787  			expectedReadFrom: true,
  5788  		},
  5789  		{
  5790  			name:       "file, no length",
  5791  			readerFunc: newFileFunc,
  5792  		},
  5793  		{
  5794  			name:          "file, negative length",
  5795  			readerFunc:    newFileFunc,
  5796  			contentLength: -1,
  5797  		},
  5798  		{
  5799  			name:          "buffer",
  5800  			contentLength: nBytes,
  5801  			readerFunc:    newBufferFunc,
  5802  		},
  5803  		{
  5804  			name:       "buffer, no length",
  5805  			readerFunc: newBufferFunc,
  5806  		},
  5807  		{
  5808  			name:          "buffer, length -1",
  5809  			contentLength: -1,
  5810  			readerFunc:    newBufferFunc,
  5811  		},
  5812  	}
  5813  
  5814  	for _, tc := range cases {
  5815  		t.Run(tc.name, func(t *testing.T) {
  5816  			r, cleanup, err := tc.readerFunc()
  5817  			if err != nil {
  5818  				t.Fatal(err)
  5819  			}
  5820  			defer cleanup()
  5821  
  5822  			tConn := &testMockTCPConn{}
  5823  			trFunc := func(tr *Transport) {
  5824  				tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  5825  					var d net.Dialer
  5826  					conn, err := d.DialContext(ctx, network, addr)
  5827  					if err != nil {
  5828  						return nil, err
  5829  					}
  5830  
  5831  					tcpConn, ok := conn.(*net.TCPConn)
  5832  					if !ok {
  5833  						return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
  5834  					}
  5835  
  5836  					tConn.TCPConn = tcpConn
  5837  					return tConn, nil
  5838  				}
  5839  			}
  5840  
  5841  			cst := newClientServerTest(
  5842  				t,
  5843  				h1Mode,
  5844  				HandlerFunc(func(w ResponseWriter, r *Request) {
  5845  					io.Copy(io.Discard, r.Body)
  5846  					r.Body.Close()
  5847  					w.WriteHeader(200)
  5848  				}),
  5849  				trFunc,
  5850  			)
  5851  			defer cst.close()
  5852  
  5853  			req, err := NewRequest("PUT", cst.ts.URL, r)
  5854  			if err != nil {
  5855  				t.Fatal(err)
  5856  			}
  5857  			req.ContentLength = tc.contentLength
  5858  			req.Header.Set("Content-Type", "application/octet-stream")
  5859  			resp, err := cst.c.Do(req)
  5860  			if err != nil {
  5861  				t.Fatal(err)
  5862  			}
  5863  			defer resp.Body.Close()
  5864  			if resp.StatusCode != 200 {
  5865  				t.Fatalf("status code = %d; want 200", resp.StatusCode)
  5866  			}
  5867  
  5868  			if !tConn.ReadFromCalled && tc.expectedReadFrom {
  5869  				t.Fatalf("did not call ReadFrom")
  5870  			}
  5871  
  5872  			if tConn.ReadFromCalled && !tc.expectedReadFrom {
  5873  				t.Fatalf("ReadFrom was unexpectedly invoked")
  5874  			}
  5875  		})
  5876  	}
  5877  }
  5878  
  5879  func TestTransportClone(t *testing.T) {
  5880  	tr := &Transport{
  5881  		Proxy:                  func(*Request) (*url.URL, error) { panic("") },
  5882  		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
  5883  		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
  5884  		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
  5885  		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
  5886  		TLSClientConfig:        new(tls.Config),
  5887  		TLSHandshakeTimeout:    time.Second,
  5888  		DisableKeepAlives:      true,
  5889  		DisableCompression:     true,
  5890  		MaxIdleConns:           1,
  5891  		MaxIdleConnsPerHost:    1,
  5892  		MaxConnsPerHost:        1,
  5893  		IdleConnTimeout:        time.Second,
  5894  		ResponseHeaderTimeout:  time.Second,
  5895  		ExpectContinueTimeout:  time.Second,
  5896  		ProxyConnectHeader:     Header{},
  5897  		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
  5898  		MaxResponseHeaderBytes: 1,
  5899  		ForceAttemptHTTP2:      true,
  5900  		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
  5901  			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
  5902  		},
  5903  		ReadBufferSize:  1,
  5904  		WriteBufferSize: 1,
  5905  	}
  5906  	tr2 := tr.Clone()
  5907  	rv := reflect.ValueOf(tr2).Elem()
  5908  	rt := rv.Type()
  5909  	for i := 0; i < rt.NumField(); i++ {
  5910  		sf := rt.Field(i)
  5911  		if !token.IsExported(sf.Name) {
  5912  			continue
  5913  		}
  5914  		if rv.Field(i).IsZero() {
  5915  			t.Errorf("cloned field t2.%s is zero", sf.Name)
  5916  		}
  5917  	}
  5918  
  5919  	if _, ok := tr2.TLSNextProto["foo"]; !ok {
  5920  		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
  5921  	}
  5922  
  5923  	// But test that a nil TLSNextProto is kept nil:
  5924  	tr = new(Transport)
  5925  	tr2 = tr.Clone()
  5926  	if tr2.TLSNextProto != nil {
  5927  		t.Errorf("Transport.TLSNextProto unexpected non-nil")
  5928  	}
  5929  }
  5930  
  5931  func TestIs408(t *testing.T) {
  5932  	tests := []struct {
  5933  		in   string
  5934  		want bool
  5935  	}{
  5936  		{"HTTP/1.0 408", true},
  5937  		{"HTTP/1.1 408", true},
  5938  		{"HTTP/1.8 408", true},
  5939  		{"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
  5940  		{"HTTP/1.1 408 ", true},
  5941  		{"HTTP/1.1 40", false},
  5942  		{"http/1.0 408", false},
  5943  		{"HTTP/1-1 408", false},
  5944  	}
  5945  	for _, tt := range tests {
  5946  		if got := Export_is408Message([]byte(tt.in)); got != tt.want {
  5947  			t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
  5948  		}
  5949  	}
  5950  }
  5951  
  5952  func TestTransportIgnores408(t *testing.T) {
  5953  	// Not parallel. Relies on mutating the log package's global Output.
  5954  	defer log.SetOutput(log.Writer())
  5955  
  5956  	var logout bytes.Buffer
  5957  	log.SetOutput(&logout)
  5958  
  5959  	defer afterTest(t)
  5960  	const target = "backend:443"
  5961  
  5962  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  5963  		nc, _, err := w.(Hijacker).Hijack()
  5964  		if err != nil {
  5965  			t.Error(err)
  5966  			return
  5967  		}
  5968  		defer nc.Close()
  5969  		nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
  5970  		nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
  5971  	}))
  5972  	defer cst.close()
  5973  	req, err := NewRequest("GET", cst.ts.URL, nil)
  5974  	if err != nil {
  5975  		t.Fatal(err)
  5976  	}
  5977  	res, err := cst.c.Do(req)
  5978  	if err != nil {
  5979  		t.Fatal(err)
  5980  	}
  5981  	slurp, err := io.ReadAll(res.Body)
  5982  	if err != nil {
  5983  		t.Fatal(err)
  5984  	}
  5985  	if err != nil {
  5986  		t.Fatal(err)
  5987  	}
  5988  	if string(slurp) != "ok" {
  5989  		t.Fatalf("got %q; want ok", slurp)
  5990  	}
  5991  
  5992  	t0 := time.Now()
  5993  	for i := 0; i < 50; i++ {
  5994  		time.Sleep(time.Duration(i) * 5 * time.Millisecond)
  5995  		if cst.tr.IdleConnKeyCountForTesting() == 0 {
  5996  			if got := logout.String(); got != "" {
  5997  				t.Fatalf("expected no log output; got: %s", got)
  5998  			}
  5999  			return
  6000  		}
  6001  	}
  6002  	t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0))
  6003  }
  6004  
  6005  func TestInvalidHeaderResponse(t *testing.T) {
  6006  	setParallel(t)
  6007  	defer afterTest(t)
  6008  	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  6009  		conn, buf, _ := w.(Hijacker).Hijack()
  6010  		buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
  6011  			"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
  6012  			"Content-Type: text/html; charset=utf-8\r\n" +
  6013  			"Content-Length: 0\r\n" +
  6014  			"Foo : bar\r\n\r\n"))
  6015  		buf.Flush()
  6016  		conn.Close()
  6017  	}))
  6018  	defer cst.close()
  6019  	res, err := cst.c.Get(cst.ts.URL)
  6020  	if err != nil {
  6021  		t.Fatal(err)
  6022  	}
  6023  	defer res.Body.Close()
  6024  	if v := res.Header.Get("Foo"); v != "" {
  6025  		t.Errorf(`unexpected "Foo" header: %q`, v)
  6026  	}
  6027  	if v := res.Header.Get("Foo "); v != "bar" {
  6028  		t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
  6029  	}
  6030  }
  6031  
  6032  type bodyCloser bool
  6033  
  6034  func (bc *bodyCloser) Close() error {
  6035  	*bc = true
  6036  	return nil
  6037  }
  6038  func (bc *bodyCloser) Read(b []byte) (n int, err error) {
  6039  	return 0, io.EOF
  6040  }
  6041  
  6042  // Issue 35015: ensure that Transport closes the body on any error
  6043  // with an invalid request, as promised by Client.Do docs.
  6044  func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
  6045  	cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  6046  		t.Errorf("Should not have been invoked")
  6047  	}))
  6048  	defer cst.Close()
  6049  
  6050  	u, _ := url.Parse(cst.URL)
  6051  
  6052  	tests := []struct {
  6053  		name    string
  6054  		req     *Request
  6055  		wantErr string
  6056  	}{
  6057  		{
  6058  			name: "invalid method",
  6059  			req: &Request{
  6060  				Method: " ",
  6061  				URL:    u,
  6062  			},
  6063  			wantErr: "invalid method",
  6064  		},
  6065  		{
  6066  			name: "nil URL",
  6067  			req: &Request{
  6068  				Method: "GET",
  6069  			},
  6070  			wantErr: "nil Request.URL",
  6071  		},
  6072  		{
  6073  			name: "invalid header key",
  6074  			req: &Request{
  6075  				Method: "GET",
  6076  				Header: Header{"💡": {"emoji"}},
  6077  				URL:    u,
  6078  			},
  6079  			wantErr: "invalid header field name",
  6080  		},
  6081  		{
  6082  			name: "invalid header value",
  6083  			req: &Request{
  6084  				Method: "POST",
  6085  				Header: Header{"key": {"\x19"}},
  6086  				URL:    u,
  6087  			},
  6088  			wantErr: "invalid header field value",
  6089  		},
  6090  		{
  6091  			name: "non HTTP(s) scheme",
  6092  			req: &Request{
  6093  				Method: "POST",
  6094  				URL:    &url.URL{Scheme: "faux"},
  6095  			},
  6096  			wantErr: "unsupported protocol scheme",
  6097  		},
  6098  		{
  6099  			name: "no Host in URL",
  6100  			req: &Request{
  6101  				Method: "POST",
  6102  				URL:    &url.URL{Scheme: "http"},
  6103  			},
  6104  			wantErr: "no Host",
  6105  		},
  6106  	}
  6107  
  6108  	for _, tt := range tests {
  6109  		t.Run(tt.name, func(t *testing.T) {
  6110  			var bc bodyCloser
  6111  			req := tt.req
  6112  			req.Body = &bc
  6113  			_, err := DefaultClient.Do(tt.req)
  6114  			if err == nil {
  6115  				t.Fatal("Expected an error")
  6116  			}
  6117  			if !bc {
  6118  				t.Fatal("Expected body to have been closed")
  6119  			}
  6120  			if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
  6121  				t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", g, w)
  6122  			}
  6123  		})
  6124  	}
  6125  }
  6126  
  6127  // breakableConn is a net.Conn wrapper with a Write method
  6128  // that will fail when its brokenState is true.
  6129  type breakableConn struct {
  6130  	net.Conn
  6131  	*brokenState
  6132  }
  6133  
  6134  type brokenState struct {
  6135  	sync.Mutex
  6136  	broken bool
  6137  }
  6138  
  6139  func (w *breakableConn) Write(b []byte) (n int, err error) {
  6140  	w.Lock()
  6141  	defer w.Unlock()
  6142  	if w.broken {
  6143  		return 0, errors.New("some write error")
  6144  	}
  6145  	return w.Conn.Write(b)
  6146  }
  6147  
  6148  // Issue 34978: don't cache a broken HTTP/2 connection
  6149  func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
  6150  	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
  6151  	defer cst.close()
  6152  
  6153  	var brokenState brokenState
  6154  
  6155  	const numReqs = 5
  6156  	var numDials, gotConns uint32 // atomic
  6157  
  6158  	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  6159  		atomic.AddUint32(&numDials, 1)
  6160  		c, err := net.Dial(netw, addr)
  6161  		if err != nil {
  6162  			t.Errorf("unexpected Dial error: %v", err)
  6163  			return nil, err
  6164  		}
  6165  		return &breakableConn{c, &brokenState}, err
  6166  	}
  6167  
  6168  	for i := 1; i <= numReqs; i++ {
  6169  		brokenState.Lock()
  6170  		brokenState.broken = false
  6171  		brokenState.Unlock()
  6172  
  6173  		// doBreak controls whether we break the TCP connection after the TLS
  6174  		// handshake (before the HTTP/2 handshake). We test a few failures
  6175  		// in a row followed by a final success.
  6176  		doBreak := i != numReqs
  6177  
  6178  		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
  6179  			GotConn: func(info httptrace.GotConnInfo) {
  6180  				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
  6181  				atomic.AddUint32(&gotConns, 1)
  6182  			},
  6183  			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
  6184  				brokenState.Lock()
  6185  				defer brokenState.Unlock()
  6186  				if doBreak {
  6187  					brokenState.broken = true
  6188  				}
  6189  			},
  6190  		})
  6191  		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
  6192  		if err != nil {
  6193  			t.Fatal(err)
  6194  		}
  6195  		_, err = cst.c.Do(req)
  6196  		if doBreak != (err != nil) {
  6197  			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
  6198  		}
  6199  	}
  6200  	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
  6201  		t.Errorf("GotConn calls = %v; want %v", got, want)
  6202  	}
  6203  	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
  6204  		t.Errorf("Dials = %v; want %v", got, want)
  6205  	}
  6206  }
  6207  
  6208  // Issue 34941
  6209  // When the client has too many concurrent requests on a single connection,
  6210  // http.http2noCachedConnError is reported on multiple requests. There should
  6211  // only be one decrement regardless of the number of failures.
  6212  func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
  6213  	defer afterTest(t)
  6214  	CondSkipHTTP2(t)
  6215  
  6216  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
  6217  		_, err := w.Write([]byte("foo"))
  6218  		if err != nil {
  6219  			t.Fatalf("Write: %v", err)
  6220  		}
  6221  	})
  6222  
  6223  	ts := httptest.NewUnstartedServer(h)
  6224  	ts.EnableHTTP2 = true
  6225  	ts.StartTLS()
  6226  	defer ts.Close()
  6227  
  6228  	c := ts.Client()
  6229  	tr := c.Transport.(*Transport)
  6230  	tr.MaxConnsPerHost = 1
  6231  	if err := ExportHttp2ConfigureTransport(tr); err != nil {
  6232  		t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
  6233  	}
  6234  
  6235  	errCh := make(chan error, 300)
  6236  	doReq := func() {
  6237  		resp, err := c.Get(ts.URL)
  6238  		if err != nil {
  6239  			errCh <- fmt.Errorf("request failed: %v", err)
  6240  			return
  6241  		}
  6242  		defer resp.Body.Close()
  6243  		_, err = io.ReadAll(resp.Body)
  6244  		if err != nil {
  6245  			errCh <- fmt.Errorf("read body failed: %v", err)
  6246  		}
  6247  	}
  6248  
  6249  	var wg sync.WaitGroup
  6250  	for i := 0; i < 300; i++ {
  6251  		wg.Add(1)
  6252  		go func() {
  6253  			defer wg.Done()
  6254  			doReq()
  6255  		}()
  6256  	}
  6257  	wg.Wait()
  6258  	close(errCh)
  6259  
  6260  	for err := range errCh {
  6261  		t.Errorf("error occurred: %v", err)
  6262  	}
  6263  }
  6264  
  6265  // Issue 36820
  6266  // Test that we use the older backward compatible cancellation protocol
  6267  // when a RoundTripper is registered via RegisterProtocol.
  6268  func TestAltProtoCancellation(t *testing.T) {
  6269  	defer afterTest(t)
  6270  	tr := &Transport{}
  6271  	c := &Client{
  6272  		Transport: tr,
  6273  		Timeout:   time.Millisecond,
  6274  	}
  6275  	tr.RegisterProtocol("timeout", timeoutProto{})
  6276  	_, err := c.Get("timeout://bar.com/path")
  6277  	if err == nil {
  6278  		t.Error("request unexpectedly succeeded")
  6279  	} else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) {
  6280  		t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr)
  6281  	}
  6282  }
  6283  
  6284  var timeoutProtoErr = errors.New("canceled as expected")
  6285  
  6286  type timeoutProto struct{}
  6287  
  6288  func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
  6289  	select {
  6290  	case <-req.Cancel:
  6291  		return nil, timeoutProtoErr
  6292  	case <-time.After(5 * time.Second):
  6293  		return nil, errors.New("request was not canceled")
  6294  	}
  6295  }
  6296  
  6297  type roundTripFunc func(r *Request) (*Response, error)
  6298  
  6299  func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
  6300  
  6301  // Issue 32441: body is not reset after ErrSkipAltProtocol
  6302  func TestIssue32441(t *testing.T) {
  6303  	defer afterTest(t)
  6304  	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  6305  		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
  6306  			t.Error("body length is zero")
  6307  		}
  6308  	}))
  6309  	defer ts.Close()
  6310  	c := ts.Client()
  6311  	c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
  6312  		// Draining body to trigger failure condition on actual request to server.
  6313  		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
  6314  			t.Error("body length is zero during round trip")
  6315  		}
  6316  		return nil, ErrSkipAltProtocol
  6317  	}))
  6318  	if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
  6319  		t.Error(err)
  6320  	}
  6321  }
  6322  
  6323  // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
  6324  // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13.
  6325  func TestTransportRejectsSignInContentLength(t *testing.T) {
  6326  	cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
  6327  		w.Header().Set("Content-Length", "+3")
  6328  		w.Write([]byte("abc"))
  6329  	}))
  6330  	defer cst.Close()
  6331  
  6332  	c := cst.Client()
  6333  	res, err := c.Get(cst.URL)
  6334  	if err == nil || res != nil {
  6335  		t.Fatal("Expected a non-nil error and a nil http.Response")
  6336  	}
  6337  	if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
  6338  		t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
  6339  	}
  6340  }
  6341  
  6342  // dumpConn is a net.Conn which writes to Writer and reads from Reader
  6343  type dumpConn struct {
  6344  	io.Writer
  6345  	io.Reader
  6346  }
  6347  
  6348  func (c *dumpConn) Close() error                       { return nil }
  6349  func (c *dumpConn) LocalAddr() net.Addr                { return nil }
  6350  func (c *dumpConn) RemoteAddr() net.Addr               { return nil }
  6351  func (c *dumpConn) SetDeadline(t time.Time) error      { return nil }
  6352  func (c *dumpConn) SetReadDeadline(t time.Time) error  { return nil }
  6353  func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
  6354  
  6355  // delegateReader is a reader that delegates to another reader,
  6356  // once it arrives on a channel.
  6357  type delegateReader struct {
  6358  	c chan io.Reader
  6359  	r io.Reader // nil until received from c
  6360  }
  6361  
  6362  func (r *delegateReader) Read(p []byte) (int, error) {
  6363  	if r.r == nil {
  6364  		var ok bool
  6365  		if r.r, ok = <-r.c; !ok {
  6366  			return 0, errors.New("delegate closed")
  6367  		}
  6368  	}
  6369  	return r.r.Read(p)
  6370  }
  6371  
// testTransportRace sends req through a fresh Transport whose only
// "connection" lives in memory, then restores req.Body to its original value.
//
// Request bytes the Transport writes go into an io.Pipe that a helper
// goroutine reads and parses. Response bytes come from a delegateReader that
// blocks until the helper supplies a canned 204 reply. RoundTrip's results
// are ignored. TestErrorWriteLoopRace calls this with requests whose contexts
// time out at random points; the aim is presumably that a race-detector run
// catches any Transport goroutine still touching req.Body after RoundTrip
// returns.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	// Note: t is the Transport here (not a *testing.T). Inside Dial the
	// parameter named net shadows the net package (the result type net.Conn
	// in the signature still refers to the package).
	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})
	// Wait for the request before replying with a dummy response:
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Ensure all the body is read; otherwise
			// we'll get a partial dump.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		// Either hand the canned response to whoever is reading dr.c, or,
		// if the outer function is already waiting on quitReadCh and nobody
		// is receiving on dr.c, signal the outer function and close dr.c.
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// Ensure delegate is closed so Read doesn't block forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Ensure the reader returns before we reset req.Body to prevent
	// a data race on req.Body.
	// Closing pw lets the helper's reads from pr finish (io.Pipe reports EOF
	// to the reader once the writer is closed), so it can reach its select.
	// Receiving from quitReadCh then waits for that helper to signal or exit.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
  6415  
  6416  // Issue 37669
  6417  // Test that a cancellation doesn't result in a data race due to the writeLoop
  6418  // goroutine being left running, if the caller mutates the processed Request
  6419  // upon completion.
  6420  func TestErrorWriteLoopRace(t *testing.T) {
  6421  	if testing.Short() {
  6422  		return
  6423  	}
  6424  	t.Parallel()
  6425  	for i := 0; i < 1000; i++ {
  6426  		delay := time.Duration(mrand.Intn(5)) * time.Millisecond
  6427  		ctx, cancel := context.WithTimeout(context.Background(), delay)
  6428  		defer cancel()
  6429  
  6430  		r := bytes.NewBuffer(make([]byte, 10000))
  6431  		req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
  6432  		if err != nil {
  6433  			t.Fatal(err)
  6434  		}
  6435  
  6436  		testTransportRace(req)
  6437  	}
  6438  }
  6439  
// Issue 41600
// Test that a new request which uses the connection of an active request
// cannot cause it to be canceled as well.
//
// Sequence (single connection: MaxIdleConns = 1, MaxConnsPerHost = 1):
//  1. Request 1 gets its response. Its PutIdleConn trace hook then blocks
//     while the connection is being returned to the idle pool.
//  2. Request 2 (cancellable context) is sent while request 1 is still
//     parked in that hook, and reaches the server.
//  3. Request 2's context is canceled, then request 1's hook is released.
//
// Request 1 must still succeed, and request 2 must fail with
// context.Canceled. NOTE(review): this relies on the connection already
// being reusable while the PutIdleConn hook runs; confirm against
// transport.go if the hook's call site moves.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	// Each handler invocation publishes its own channel on reqc and waits on
	// it, so the test decides when each response is written. Capacity 2:
	// one per request.
	reqc := make(chan chan struct{}, 2)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	}))
	defer ts.Close()

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	// putidlec carries request 1's "resume" channel out of its PutIdleConn hook.
	putidlec := make(chan chan struct{})
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Signal that the idle conn has been returned to the pool,
				// and wait for the order to proceed.
				ch := make(chan struct{})
				putidlec <- ch
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if err != nil {
			t.Errorf("request 1: got err %v, want nil", err)
		}
	}()

	// Wait for the first request to receive a response and return the
	// connection to the idle pool.
	r1c := <-reqc
	close(r1c)
	idlec := <-putidlec

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}
	}()

	// Wait for the second request to arrive at the server, and then cancel
	// the request context.
	r2c := <-reqc
	cancel()

	// Give the cancelation a moment to take effect, and then unblock the first request.
	time.Sleep(1 * time.Millisecond)
	close(idlec)

	// Let request 2's handler finish, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
  6515  

View as plain text