experimental/credentials: Add credentials that don't enforce ALPN (#7… · grpc/grpc-go@54b3eb9 (original) (raw)

``

1

`+

/*

`

``

2

`+

`

``

3

`+

`

``

4

`+

`

``

5

`+

`

``

6

`+

`

``

7

`+

`

``

8

`+

`

``

9

`+

`

``

10

`+

`

``

11

`+

`

``

12

`+

`

``

13

`+

`

``

14

`+

`

``

15

`+

`

``

16

`+

`

``

17

`+

*/

`

``

18

+

``

19

`+

package credentials

`

``

20

+

``

21

`+

import (

`

``

22

`+

"context"

`

``

23

`+

"crypto/tls"

`

``

24

`+

"net"

`

``

25

`+

"strings"

`

``

26

`+

"testing"

`

``

27

`+

"time"

`

``

28

+

``

29

`+

"google.golang.org/grpc/credentials"

`

``

30

`+

"google.golang.org/grpc/internal/grpctest"

`

``

31

`+

"google.golang.org/grpc/testdata"

`

``

32

`+

)

`

``

33

+

``

34

`+

const defaultTestTimeout = 10 * time.Second

`

``

35

+

``

36

`+

type s struct {

`

``

37

`+

grpctest.Tester

`

``

38

`+

}

`

``

39

+

``

40

`+

func Test(t *testing.T) {

`

``

41

`+

grpctest.RunSubTests(t, s{})

`

``

42

`+

}

`

``

43

+

``

44

`+

func (s) TestTLSOverrideServerName(t *testing.T) {

`

``

45

`+

expectedServerName := "server.name"

`

``

46

`+

c := NewTLSWithALPNDisabled(nil)

`

``

47

`+

c.OverrideServerName(expectedServerName)

`

``

48

`+

if c.Info().ServerName != expectedServerName {

`

``

49

`+

t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)

`

``

50

`+

}

`

``

51

`+

}

`

``

52

+

``

53

`+

func (s) TestTLSClone(t *testing.T) {

`

``

54

`+

expectedServerName := "server.name"

`

``

55

`+

c := NewTLSWithALPNDisabled(nil)

`

``

56

`+

c.OverrideServerName(expectedServerName)

`

``

57

`+

cc := c.Clone()

`

``

58

`+

if cc.Info().ServerName != expectedServerName {

`

``

59

`+

t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)

`

``

60

`+

}

`

``

61

`+

cc.OverrideServerName("")

`

``

62

`+

if c.Info().ServerName != expectedServerName {

`

``

63

`+

t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)

`

``

64

`+

}

`

``

65

+

``

66

`+

}

`

``

67

+

``

68

`+

type serverHandshake func(net.Conn) (credentials.AuthInfo, error)

`

``

69

+

``

70

`+

func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) {

`

``

71

`+

tcs := []struct {

`

``

72

`+

name string

`

``

73

`+

address string

`

``

74

`+

}{

`

``

75

`+

{

`

``

76

`+

name: "localhost",

`

``

77

`+

address: "localhost:0",

`

``

78

`+

},

`

``

79

`+

{

`

``

80

`+

name: "ipv4",

`

``

81

`+

address: "127.0.0.1:0",

`

``

82

`+

},

`

``

83

`+

{

`

``

84

`+

name: "ipv6",

`

``

85

`+

address: "[::1]:0",

`

``

86

`+

},

`

``

87

`+

}

`

``

88

+

``

89

`+

for _, tc := range tcs {

`

``

90

`+

t.Run(tc.name, func(t *testing.T) {

`

``

91

`+

done := make(chan credentials.AuthInfo, 1)

`

``

92

`+

lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address)

`

``

93

`+

defer lis.Close()

`

``

94

`+

lisAddr := lis.Addr().String()

`

``

95

`+

clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)

`

``

96

`+

// wait until server sends serverAuthInfo or fails.

`

``

97

`+

serverAuthInfo, ok := <-done

`

``

98

`+

if !ok {

`

``

99

`+

t.Fatalf("Error at server-side")

`

``

100

`+

}

`

``

101

`+

if !compare(clientAuthInfo, serverAuthInfo) {

`

``

102

`+

t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)

`

``

103

`+

}

`

``

104

`+

})

`

``

105

`+

}

`

``

106

`+

}

`

``

107

+

``

108

`+

func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) {

`

``

109

`+

done := make(chan credentials.AuthInfo, 1)

`

``

110

`+

lis := launchServer(t, gRPCServerHandshake, done)

`

``

111

`+

defer lis.Close()

`

``

112

`+

clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())

`

``

113

`+

// wait until server sends serverAuthInfo or fails.

`

``

114

`+

serverAuthInfo, ok := <-done

`

``

115

`+

if !ok {

`

``

116

`+

t.Fatalf("Error at server-side")

`

``

117

`+

}

`

``

118

`+

if !compare(clientAuthInfo, serverAuthInfo) {

`

``

119

`+

t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)

`

``

120

`+

}

`

``

121

`+

}

`

``

122

+

``

123

`+

func (s) TestServerAndClientHandshake(t *testing.T) {

`

``

124

`+

done := make(chan credentials.AuthInfo, 1)

`

``

125

`+

lis := launchServer(t, gRPCServerHandshake, done)

`

``

126

`+

defer lis.Close()

`

``

127

`+

clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())

`

``

128

`+

// wait until server sends serverAuthInfo or fails.

`

``

129

`+

serverAuthInfo, ok := <-done

`

``

130

`+

if !ok {

`

``

131

`+

t.Fatalf("Error at server-side")

`

``

132

`+

}

`

``

133

`+

if !compare(clientAuthInfo, serverAuthInfo) {

`

``

134

`+

t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)

`

``

135

`+

}

`

``

136

`+

}

`

``

137

+

``

138

`+

func compare(a1, a2 credentials.AuthInfo) bool {

`

``

139

`+

if a1.AuthType() != a2.AuthType() {

`

``

140

`+

return false

`

``

141

`+

}

`

``

142

`+

switch a1.AuthType() {

`

``

143

`+

case "tls":

`

``

144

`+

state1 := a1.(credentials.TLSInfo).State

`

``

145

`+

state2 := a2.(credentials.TLSInfo).State

`

``

146

`+

if state1.Version == state2.Version &&

`

``

147

`+

state1.HandshakeComplete == state2.HandshakeComplete &&

`

``

148

`+

state1.CipherSuite == state2.CipherSuite &&

`

``

149

`+

state1.NegotiatedProtocol == state2.NegotiatedProtocol {

`

``

150

`+

return true

`

``

151

`+

}

`

``

152

`+

return false

`

``

153

`+

default:

`

``

154

`+

return false

`

``

155

`+

}

`

``

156

`+

}

`

``

157

+

``

158

`+

func launchServer(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo) net.Listener {

`

``

159

`+

return launchServerOnListenAddress(t, hs, done, "localhost:0")

`

``

160

`+

}

`

``

161

+

``

162

`+

func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, address string) net.Listener {

`

``

163

`+

lis, err := net.Listen("tcp", address)

`

``

164

`+

if err != nil {

`

``

165

`+

if strings.Contains(err.Error(), "bind: cannot assign requested address") ||

`

``

166

`+

strings.Contains(err.Error(), "socket: address family not supported by protocol") {

`

``

167

`+

t.Skipf("no support for address %v", address)

`

``

168

`+

}

`

``

169

`+

t.Fatalf("Failed to listen: %v", err)

`

``

170

`+

}

`

``

171

`+

go serverHandle(t, hs, done, lis)

`

``

172

`+

return lis

`

``

173

`+

}

`

``

174

+

``

175

`+

// Is run in a separate goroutine.

`

``

176

`+

func serverHandle(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, lis net.Listener) {

`

``

177

`+

serverRawConn, err := lis.Accept()

`

``

178

`+

if err != nil {

`

``

179

`+

t.Errorf("Server failed to accept connection: %v", err)

`

``

180

`+

close(done)

`

``

181

`+

return

`

``

182

`+

}

`

``

183

`+

serverAuthInfo, err := hs(serverRawConn)

`

``

184

`+

if err != nil {

`

``

185

`+

t.Errorf("Server failed while handshake. Error: %v", err)

`

``

186

`+

serverRawConn.Close()

`

``

187

`+

close(done)

`

``

188

`+

return

`

``

189

`+

}

`

``

190

`+

done <- serverAuthInfo

`

``

191

`+

}

`

``

192

+

``

193

`+

func clientHandle(t *testing.T, hs func(net.Conn, string) (credentials.AuthInfo, error), lisAddr string) credentials.AuthInfo {

`

``

194

`+

conn, err := net.Dial("tcp", lisAddr)

`

``

195

`+

if err != nil {

`

``

196

`+

t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)

`

``

197

`+

}

`

``

198

`+

defer conn.Close()

`

``

199

`+

clientAuthInfo, err := hs(conn, lisAddr)

`

``

200

`+

if err != nil {

`

``

201

`+

t.Fatalf("Error on client while handshake. Error: %v", err)

`

``

202

`+

}

`

``

203

`+

return clientAuthInfo

`

``

204

`+

}

`

``

205

+

``

206

`+

// Server handshake implementation in gRPC.

`

``

207

`+

func gRPCServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {

`

``

208

`+

serverTLS, err := NewServerTLSFromFileWithALPNDisabled(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))

`

``

209

`+

if err != nil {

`

``

210

`+

return nil, err

`

``

211

`+

}

`

``

212

`+

_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)

`

``

213

`+

if err != nil {

`

``

214

`+

return nil, err

`

``

215

`+

}

`

``

216

`+

return serverAuthInfo, nil

`

``

217

`+

}

`

``

218

+

``

219

`+

// Client handshake implementation in gRPC.

`

``

220

`+

func gRPCClientHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {

`

``

221

`+

clientTLS := NewTLSWithALPNDisabled(&tls.Config{InsecureSkipVerify: true})

`

``

222

`+

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)

`

``

223

`+

defer cancel()

`

``

224

`+

_, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn)

`

``

225

`+

if err != nil {

`

``

226

`+

return nil, err

`

``

227

`+

}

`

``

228

`+

return authInfo, nil

`

``

229

`+

}

`

``

230

+

``

231

`+

func tlsServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {

`

``

232

`+

cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))

`

``

233

`+

if err != nil {

`

``

234

`+

return nil, err

`

``

235

`+

}

`

``

236

`+

serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}

`

``

237

`+

serverConn := tls.Server(conn, serverTLSConfig)

`

``

238

`+

err = serverConn.Handshake()

`

``

239

`+

if err != nil {

`

``

240

`+

return nil, err

`

``

241

`+

}

`

``

242

`+

return credentials.TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil

`

``

243

`+

}

`

``

244

+

``

245

`+

func tlsClientHandshake(conn net.Conn, _ string) (credentials.AuthInfo, error) {

`

``

246

`+

clientTLSConfig := &tls.Config{InsecureSkipVerify: true}

`

``

247

`+

clientConn := tls.Client(conn, clientTLSConfig)

`

``

248

`+

if err := clientConn.Handshake(); err != nil {

`

``

249

`+

return nil, err

`

``

250

`+

}

`

``

251

`+

return credentials.TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil

`

``

252

`+

}

`