Source file
src/syscall/creds_test.go
Documentation: syscall
1
2
3
4
5
6
7
8 package syscall_test
9
import (
	"bytes"
	"errors"
	"net"
	"os"
	"syscall"
	"testing"
)
17
// TestSCMCredentials tests sending and receiving credentials (PID, UID, GID)
// in an ancillary message between the two ends of a UNIX socket pair, once
// per socket type. SO_PASSCRED is enabled on fds[0], the end that is wrapped
// as srv and read with ReadMsgUnix, so that the credentials are delivered.
func TestSCMCredentials(t *testing.T) {
	socketTypeTests := []struct {
		name       string
		socketType int
		dataLen    int // number of payload bytes ReadMsgUnix is expected to report
	}{
		{"SOCK_STREAM", syscall.SOCK_STREAM, 1},
		{"SOCK_DGRAM", syscall.SOCK_DGRAM, 0},
	}

	for _, tt := range socketTypeTests {
		// Each socket type runs as its own subtest so that the deferred
		// Close calls below fire at the end of that case instead of keeping
		// every iteration's descriptors open until the whole test returns.
		// t.Run is synchronous here, so capturing tt is safe.
		t.Run(tt.name, func(t *testing.T) {
			fds, err := syscall.Socketpair(syscall.AF_LOCAL, tt.socketType, 0)
			if err != nil {
				t.Fatalf("Socketpair: %v", err)
			}

			err = syscall.SetsockoptInt(fds[0], syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
			if err != nil {
				// The raw fds are not owned by an *os.File yet, so close
				// them directly; close errors are ignored on this path.
				syscall.Close(fds[0])
				syscall.Close(fds[1])
				t.Fatalf("SetsockoptInt: %v", err)
			}

			srvFile := os.NewFile(uintptr(fds[0]), "server")
			cliFile := os.NewFile(uintptr(fds[1]), "client")
			defer srvFile.Close()
			defer cliFile.Close()

			srv, err := net.FileConn(srvFile)
			if err != nil {
				t.Fatalf("FileConn: %v", err)
			}
			defer srv.Close()

			cli, err := net.FileConn(cliFile)
			if err != nil {
				t.Fatalf("FileConn: %v", err)
			}
			defer cli.Close()

			var ucred syscall.Ucred
			if os.Getuid() != 0 {
				// An unprivileged sender claiming UID/GID 0 must be rejected.
				// Linux reports EPERM, or EINVAL when UID/GID 0 is not mapped
				// in the caller's user namespace (e.g. rootless containers).
				ucred.Pid = int32(os.Getpid())
				ucred.Uid = 0
				ucred.Gid = 0
				oob := syscall.UnixCredentials(&ucred)
				_, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
				if !errors.Is(err, syscall.EPERM) && !errors.Is(err, syscall.EINVAL) {
					t.Fatalf("WriteMsgUnix failed with %v, want EPERM or EINVAL", err)
				}
			}

			ucred.Pid = int32(os.Getpid())
			ucred.Uid = uint32(os.Getuid())
			ucred.Gid = uint32(os.Getgid())
			oob := syscall.UnixCredentials(&ucred)

			// No payload is written; tt.dataLen records how many bytes the
			// reader nevertheless sees (1 for SOCK_STREAM, presumably a
			// dummy byte the net package sends to carry the ancillary data).
			n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
			if err != nil {
				t.Fatalf("WriteMsgUnix: %v", err)
			}
			if n != 0 {
				t.Fatalf("WriteMsgUnix n = %d, want 0", n)
			}
			if oobn != len(oob) {
				t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
			}

			oob2 := make([]byte, 10*len(oob))
			n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
			if err != nil {
				t.Fatalf("ReadMsgUnix: %v", err)
			}
			if flags != syscall.MSG_CMSG_CLOEXEC {
				t.Fatalf("ReadMsgUnix flags = %#x, want %#x (MSG_CMSG_CLOEXEC)", flags, syscall.MSG_CMSG_CLOEXEC)
			}
			if n != tt.dataLen {
				t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
			}
			if oobn2 != oobn {
				// If SO_PASSCRED were missing on the reading socket, no
				// credentials would be delivered and oobn2 would be 0.
				t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
			}
			oob2 = oob2[:oobn2]
			if !bytes.Equal(oob, oob2) {
				t.Fatal("ReadMsgUnix oob bytes don't match")
			}

			scm, err := syscall.ParseSocketControlMessage(oob2)
			if err != nil {
				t.Fatalf("ParseSocketControlMessage: %v", err)
			}
			// Guard the scm[0] access below instead of risking a panic.
			if len(scm) != 1 {
				t.Fatalf("ParseSocketControlMessage returned %d messages, want 1", len(scm))
			}
			newUcred, err := syscall.ParseUnixCredentials(&scm[0])
			if err != nil {
				t.Fatalf("ParseUnixCredentials: %v", err)
			}
			if *newUcred != ucred {
				t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
			}
		})
	}
}
137
View as plain text