Black Lives Matter. Support the Equal Justice Initiative.

Source file src/net/splice_test.go

Documentation: net

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build linux
     6  // +build linux
     7  
     8  package net
     9  
    10  import (
    11  	"io"
    12  	"log"
    13  	"os"
    14  	"os/exec"
    15  	"strconv"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  )
    20  
    21  func TestSplice(t *testing.T) {
    22  	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
    23  	if !testableNetwork("unixgram") {
    24  		t.Skip("skipping unix-to-tcp tests")
    25  	}
    26  	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
    27  	t.Run("no-unixpacket", testSpliceNoUnixpacket)
    28  	t.Run("no-unixgram", testSpliceNoUnixgram)
    29  }
    30  
    31  func testSplice(t *testing.T, upNet, downNet string) {
    32  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
    33  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
    34  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
    35  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
    36  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
    37  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
    38  	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
    39  	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
    40  }
    41  
    42  type spliceTestCase struct {
    43  	upNet, downNet string
    44  
    45  	chunkSize, totalSize int
    46  	limitReadSize        int
    47  }
    48  
    49  func (tc spliceTestCase) test(t *testing.T) {
    50  	clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  	defer serverUp.Close()
    55  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  	defer cleanup()
    60  	clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  	defer serverDown.Close()
    65  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  	defer cleanup()
    70  	var (
    71  		r    io.Reader = serverUp
    72  		size           = tc.totalSize
    73  	)
    74  	if tc.limitReadSize > 0 {
    75  		if tc.limitReadSize < size {
    76  			size = tc.limitReadSize
    77  		}
    78  
    79  		r = &io.LimitedReader{
    80  			N: int64(tc.limitReadSize),
    81  			R: serverUp,
    82  		}
    83  		defer serverUp.Close()
    84  	}
    85  	n, err := io.Copy(serverDown, r)
    86  	serverDown.Close()
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	if want := int64(size); want != n {
    91  		t.Errorf("want %d bytes spliced, got %d", want, n)
    92  	}
    93  
    94  	if tc.limitReadSize > 0 {
    95  		wantN := 0
    96  		if tc.limitReadSize > size {
    97  			wantN = tc.limitReadSize - size
    98  		}
    99  
   100  		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
   101  			t.Errorf("r.N = %d, want %d", n, wantN)
   102  		}
   103  	}
   104  }
   105  
   106  func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
   107  	clientUp, serverUp, err := spliceTestSocketPair(upNet)
   108  	if err != nil {
   109  		t.Fatal(err)
   110  	}
   111  	defer clientUp.Close()
   112  	clientDown, serverDown, err := spliceTestSocketPair(downNet)
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	defer clientDown.Close()
   117  
   118  	serverUp.Close()
   119  
   120  	// We'd like to call net.splice here and check the handled return
   121  	// value, but we disable splice on old Linux kernels.
   122  	//
   123  	// In that case, poll.Splice and net.splice return a non-nil error
   124  	// and handled == false. We'd ideally like to see handled == true
   125  	// because the source reader is at EOF, but if we're running on an old
   126  	// kernel, and splice is disabled, we won't see EOF from net.splice,
   127  	// because we won't touch the reader at all.
   128  	//
   129  	// Trying to untangle the errors from net.splice and match them
   130  	// against the errors created by the poll package would be brittle,
   131  	// so this is a higher level test.
   132  	//
   133  	// The following ReadFrom should return immediately, regardless of
   134  	// whether splice is disabled or not. The other side should then
   135  	// get a goodbye signal. Test for the goodbye signal.
   136  	msg := "bye"
   137  	go func() {
   138  		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
   139  		io.WriteString(serverDown, msg)
   140  		serverDown.Close()
   141  	}()
   142  
   143  	buf := make([]byte, 3)
   144  	_, err = io.ReadFull(clientDown, buf)
   145  	if err != nil {
   146  		t.Errorf("clientDown: %v", err)
   147  	}
   148  	if string(buf) != msg {
   149  		t.Errorf("clientDown got %q, want %q", buf, msg)
   150  	}
   151  }
   152  
   153  func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
   154  	front, err := newLocalListener(upNet)
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	defer front.Close()
   159  	back, err := newLocalListener(downNet)
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	defer back.Close()
   164  
   165  	var wg sync.WaitGroup
   166  	wg.Add(2)
   167  
   168  	proxy := func() {
   169  		src, err := front.Accept()
   170  		if err != nil {
   171  			return
   172  		}
   173  		dst, err := Dial(downNet, back.Addr().String())
   174  		if err != nil {
   175  			return
   176  		}
   177  		defer dst.Close()
   178  		defer src.Close()
   179  		go func() {
   180  			io.Copy(src, dst)
   181  			wg.Done()
   182  		}()
   183  		go func() {
   184  			io.Copy(dst, src)
   185  			wg.Done()
   186  		}()
   187  	}
   188  
   189  	go proxy()
   190  
   191  	toFront, err := Dial(upNet, front.Addr().String())
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	io.WriteString(toFront, "foo")
   197  	toFront.Close()
   198  
   199  	fromProxy, err := back.Accept()
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	defer fromProxy.Close()
   204  
   205  	_, err = io.ReadAll(fromProxy)
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	wg.Wait()
   211  }
   212  
   213  func testSpliceNoUnixpacket(t *testing.T) {
   214  	clientUp, serverUp, err := spliceTestSocketPair("unixpacket")
   215  	if err != nil {
   216  		t.Fatal(err)
   217  	}
   218  	defer clientUp.Close()
   219  	defer serverUp.Close()
   220  	clientDown, serverDown, err := spliceTestSocketPair("tcp")
   221  	if err != nil {
   222  		t.Fatal(err)
   223  	}
   224  	defer clientDown.Close()
   225  	defer serverDown.Close()
   226  	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
   227  	// and handled == false.  If poll.Splice gets an EINVAL on the first
   228  	// try, it assumes the kernel it's running on doesn't support splice
   229  	// for unix sockets and returns handled == false. This works for our
   230  	// purposes by somewhat of an accident, but is not entirely correct.
   231  	//
   232  	// What we want is err == nil and handled == false, i.e. we never
   233  	// called poll.Splice, because we know the unix socket's network.
   234  	_, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
   235  	if err != nil || handled != false {
   236  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   237  	}
   238  }
   239  
   240  func testSpliceNoUnixgram(t *testing.T) {
   241  	addr, err := ResolveUnixAddr("unixgram", testUnixAddr())
   242  	if err != nil {
   243  		t.Fatal(err)
   244  	}
   245  	defer os.Remove(addr.Name)
   246  	up, err := ListenUnixgram("unixgram", addr)
   247  	if err != nil {
   248  		t.Fatal(err)
   249  	}
   250  	defer up.Close()
   251  	clientDown, serverDown, err := spliceTestSocketPair("tcp")
   252  	if err != nil {
   253  		t.Fatal(err)
   254  	}
   255  	defer clientDown.Close()
   256  	defer serverDown.Close()
   257  	// Analogous to testSpliceNoUnixpacket.
   258  	_, err, handled := splice(serverDown.(*TCPConn).fd, up)
   259  	if err != nil || handled != false {
   260  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   261  	}
   262  }
   263  
   264  func BenchmarkSplice(b *testing.B) {
   265  	testHookUninstaller.Do(uninstallTestHooks)
   266  
   267  	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
   268  	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
   269  }
   270  
   271  func benchSplice(b *testing.B, upNet, downNet string) {
   272  	for i := 0; i <= 10; i++ {
   273  		chunkSize := 1 << uint(i+10)
   274  		tc := spliceTestCase{
   275  			upNet:     upNet,
   276  			downNet:   downNet,
   277  			chunkSize: chunkSize,
   278  		}
   279  
   280  		b.Run(strconv.Itoa(chunkSize), tc.bench)
   281  	}
   282  }
   283  
   284  func (tc spliceTestCase) bench(b *testing.B) {
   285  	// To benchmark the genericReadFrom code path, set this to false.
   286  	useSplice := true
   287  
   288  	clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
   289  	if err != nil {
   290  		b.Fatal(err)
   291  	}
   292  	defer serverUp.Close()
   293  
   294  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
   295  	if err != nil {
   296  		b.Fatal(err)
   297  	}
   298  	defer cleanup()
   299  
   300  	clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
   301  	if err != nil {
   302  		b.Fatal(err)
   303  	}
   304  	defer serverDown.Close()
   305  
   306  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
   307  	if err != nil {
   308  		b.Fatal(err)
   309  	}
   310  	defer cleanup()
   311  
   312  	b.SetBytes(int64(tc.chunkSize))
   313  	b.ResetTimer()
   314  
   315  	if useSplice {
   316  		_, err := io.Copy(serverDown, serverUp)
   317  		if err != nil {
   318  			b.Fatal(err)
   319  		}
   320  	} else {
   321  		type onlyReader struct {
   322  			io.Reader
   323  		}
   324  		_, err := io.Copy(serverDown, onlyReader{serverUp})
   325  		if err != nil {
   326  			b.Fatal(err)
   327  		}
   328  	}
   329  }
   330  
   331  func spliceTestSocketPair(net string) (client, server Conn, err error) {
   332  	ln, err := newLocalListener(net)
   333  	if err != nil {
   334  		return nil, nil, err
   335  	}
   336  	defer ln.Close()
   337  	var cerr, serr error
   338  	acceptDone := make(chan struct{})
   339  	go func() {
   340  		server, serr = ln.Accept()
   341  		acceptDone <- struct{}{}
   342  	}()
   343  	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
   344  	<-acceptDone
   345  	if cerr != nil {
   346  		if server != nil {
   347  			server.Close()
   348  		}
   349  		return nil, nil, cerr
   350  	}
   351  	if serr != nil {
   352  		if client != nil {
   353  			client.Close()
   354  		}
   355  		return nil, nil, serr
   356  	}
   357  	return client, server, nil
   358  }
   359  
   360  func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
   361  	f, err := conn.(interface{ File() (*os.File, error) }).File()
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  
   366  	cmd := exec.Command(os.Args[0], os.Args[1:]...)
   367  	cmd.Env = []string{
   368  		"GO_NET_TEST_SPLICE=1",
   369  		"GO_NET_TEST_SPLICE_OP=" + op,
   370  		"GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
   371  		"GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
   372  		"TMPDIR=" + os.Getenv("TMPDIR"),
   373  	}
   374  	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
   375  	cmd.Stdout = os.Stdout
   376  	cmd.Stderr = os.Stderr
   377  
   378  	if err := cmd.Start(); err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	donec := make(chan struct{})
   383  	go func() {
   384  		cmd.Wait()
   385  		conn.Close()
   386  		f.Close()
   387  		close(donec)
   388  	}()
   389  
   390  	return func() {
   391  		select {
   392  		case <-donec:
   393  		case <-time.After(5 * time.Second):
   394  			log.Printf("killing splice client after 5 second shutdown timeout")
   395  			cmd.Process.Kill()
   396  			select {
   397  			case <-donec:
   398  			case <-time.After(5 * time.Second):
   399  				log.Printf("splice client didn't die after 10 seconds")
   400  			}
   401  		}
   402  	}, nil
   403  }
   404  
   405  func init() {
   406  	if os.Getenv("GO_NET_TEST_SPLICE") == "" {
   407  		return
   408  	}
   409  	defer os.Exit(0)
   410  
   411  	f := os.NewFile(uintptr(3), "splice-test-conn")
   412  	defer f.Close()
   413  
   414  	conn, err := FileConn(f)
   415  	if err != nil {
   416  		log.Fatal(err)
   417  	}
   418  
   419  	var chunkSize int
   420  	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
   421  		log.Fatal(err)
   422  	}
   423  	buf := make([]byte, chunkSize)
   424  
   425  	var totalSize int
   426  	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
   427  		log.Fatal(err)
   428  	}
   429  
   430  	var fn func([]byte) (int, error)
   431  	switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
   432  	case "r":
   433  		fn = conn.Read
   434  	case "w":
   435  		defer conn.Close()
   436  
   437  		fn = conn.Write
   438  	default:
   439  		log.Fatalf("unknown op %q", op)
   440  	}
   441  
   442  	var n int
   443  	for count := 0; count < totalSize; count += n {
   444  		if count+chunkSize > totalSize {
   445  			buf = buf[:totalSize-count]
   446  		}
   447  
   448  		var err error
   449  		if n, err = fn(buf); err != nil {
   450  			return
   451  		}
   452  	}
   453  }
   454  

View as plain text