experimental/credentials: Add credentials that don't enforce ALPN (#7… · grpc/grpc-go@54b3eb9 (original) (raw)
``
1
`+
/*
`
``
2
`+
`
``
3
`+
- Copyright 2025 gRPC authors.
`
``
4
`+
`
``
5
`+
- Licensed under the Apache License, Version 2.0 (the "License");
`
``
6
`+
- you may not use this file except in compliance with the License.
`
``
7
`+
- You may obtain a copy of the License at
`
``
8
`+
`
``
9
`+
`
``
10
`+
`
``
11
`+
- Unless required by applicable law or agreed to in writing, software
`
``
12
`+
- distributed under the License is distributed on an "AS IS" BASIS,
`
``
13
`+
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
`
``
14
`+
- See the License for the specific language governing permissions and
`
``
15
`+
- limitations under the License.
`
``
16
`+
`
``
17
`+
*/
`
``
18
+
``
19
`+
package credentials
`
``
20
+
``
21
`+
import (
`
``
22
`+
"context"
`
``
23
`+
"crypto/tls"
`
``
24
`+
"net"
`
``
25
`+
"strings"
`
``
26
`+
"testing"
`
``
27
`+
"time"
`
``
28
+
``
29
`+
"google.golang.org/grpc/credentials"
`
``
30
`+
"google.golang.org/grpc/internal/grpctest"
`
``
31
`+
"google.golang.org/grpc/testdata"
`
``
32
`+
)
`
``
33
+
``
34
`+
const defaultTestTimeout = 10 * time.Second
`
``
35
+
``
36
`+
type s struct {
`
``
37
`+
grpctest.Tester
`
``
38
`+
}
`
``
39
+
``
40
`+
func Test(t *testing.T) {
`
``
41
`+
grpctest.RunSubTests(t, s{})
`
``
42
`+
}
`
``
43
+
``
44
`+
func (s) TestTLSOverrideServerName(t *testing.T) {
`
``
45
`+
expectedServerName := "server.name"
`
``
46
`+
c := NewTLSWithALPNDisabled(nil)
`
``
47
`+
c.OverrideServerName(expectedServerName)
`
``
48
`+
if c.Info().ServerName != expectedServerName {
`
``
49
`+
t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
`
``
50
`+
}
`
``
51
`+
}
`
``
52
+
``
53
`+
func (s) TestTLSClone(t *testing.T) {
`
``
54
`+
expectedServerName := "server.name"
`
``
55
`+
c := NewTLSWithALPNDisabled(nil)
`
``
56
`+
c.OverrideServerName(expectedServerName)
`
``
57
`+
cc := c.Clone()
`
``
58
`+
if cc.Info().ServerName != expectedServerName {
`
``
59
`+
t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
`
``
60
`+
}
`
``
61
`+
cc.OverrideServerName("")
`
``
62
`+
if c.Info().ServerName != expectedServerName {
`
``
63
`+
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
`
``
64
`+
}
`
``
65
+
``
66
`+
}
`
``
67
+
``
68
`+
type serverHandshake func(net.Conn) (credentials.AuthInfo, error)
`
``
69
+
``
70
`+
func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) {
`
``
71
`+
tcs := []struct {
`
``
72
`+
name string
`
``
73
`+
address string
`
``
74
`+
}{
`
``
75
`+
{
`
``
76
`+
name: "localhost",
`
``
77
`+
address: "localhost:0",
`
``
78
`+
},
`
``
79
`+
{
`
``
80
`+
name: "ipv4",
`
``
81
`+
address: "127.0.0.1:0",
`
``
82
`+
},
`
``
83
`+
{
`
``
84
`+
name: "ipv6",
`
``
85
`+
address: "[::1]:0",
`
``
86
`+
},
`
``
87
`+
}
`
``
88
+
``
89
`+
for _, tc := range tcs {
`
``
90
`+
t.Run(tc.name, func(t *testing.T) {
`
``
91
`+
done := make(chan credentials.AuthInfo, 1)
`
``
92
`+
lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address)
`
``
93
`+
defer lis.Close()
`
``
94
`+
lisAddr := lis.Addr().String()
`
``
95
`+
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
`
``
96
`+
// wait until server sends serverAuthInfo or fails.
`
``
97
`+
serverAuthInfo, ok := <-done
`
``
98
`+
if !ok {
`
``
99
`+
t.Fatalf("Error at server-side")
`
``
100
`+
}
`
``
101
`+
if !compare(clientAuthInfo, serverAuthInfo) {
`
``
102
`+
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
`
``
103
`+
}
`
``
104
`+
})
`
``
105
`+
}
`
``
106
`+
}
`
``
107
+
``
108
`+
func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) {
`
``
109
`+
done := make(chan credentials.AuthInfo, 1)
`
``
110
`+
lis := launchServer(t, gRPCServerHandshake, done)
`
``
111
`+
defer lis.Close()
`
``
112
`+
clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
`
``
113
`+
// wait until server sends serverAuthInfo or fails.
`
``
114
`+
serverAuthInfo, ok := <-done
`
``
115
`+
if !ok {
`
``
116
`+
t.Fatalf("Error at server-side")
`
``
117
`+
}
`
``
118
`+
if !compare(clientAuthInfo, serverAuthInfo) {
`
``
119
`+
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
`
``
120
`+
}
`
``
121
`+
}
`
``
122
+
``
123
`+
func (s) TestServerAndClientHandshake(t *testing.T) {
`
``
124
`+
done := make(chan credentials.AuthInfo, 1)
`
``
125
`+
lis := launchServer(t, gRPCServerHandshake, done)
`
``
126
`+
defer lis.Close()
`
``
127
`+
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
`
``
128
`+
// wait until server sends serverAuthInfo or fails.
`
``
129
`+
serverAuthInfo, ok := <-done
`
``
130
`+
if !ok {
`
``
131
`+
t.Fatalf("Error at server-side")
`
``
132
`+
}
`
``
133
`+
if !compare(clientAuthInfo, serverAuthInfo) {
`
``
134
`+
t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
`
``
135
`+
}
`
``
136
`+
}
`
``
137
+
``
138
`+
func compare(a1, a2 credentials.AuthInfo) bool {
`
``
139
`+
if a1.AuthType() != a2.AuthType() {
`
``
140
`+
return false
`
``
141
`+
}
`
``
142
`+
switch a1.AuthType() {
`
``
143
`+
case "tls":
`
``
144
`+
state1 := a1.(credentials.TLSInfo).State
`
``
145
`+
state2 := a2.(credentials.TLSInfo).State
`
``
146
`+
if state1.Version == state2.Version &&
`
``
147
`+
state1.HandshakeComplete == state2.HandshakeComplete &&
`
``
148
`+
state1.CipherSuite == state2.CipherSuite &&
`
``
149
`+
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
`
``
150
`+
return true
`
``
151
`+
}
`
``
152
`+
return false
`
``
153
`+
default:
`
``
154
`+
return false
`
``
155
`+
}
`
``
156
`+
}
`
``
157
+
``
158
`+
func launchServer(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo) net.Listener {
`
``
159
`+
return launchServerOnListenAddress(t, hs, done, "localhost:0")
`
``
160
`+
}
`
``
161
+
``
162
`+
func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, address string) net.Listener {
`
``
163
`+
lis, err := net.Listen("tcp", address)
`
``
164
`+
if err != nil {
`
``
165
`+
if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
`
``
166
`+
strings.Contains(err.Error(), "socket: address family not supported by protocol") {
`
``
167
`+
t.Skipf("no support for address %v", address)
`
``
168
`+
}
`
``
169
`+
t.Fatalf("Failed to listen: %v", err)
`
``
170
`+
}
`
``
171
`+
go serverHandle(t, hs, done, lis)
`
``
172
`+
return lis
`
``
173
`+
}
`
``
174
+
``
175
`+
// Is run in a separate goroutine.
`
``
176
`+
func serverHandle(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, lis net.Listener) {
`
``
177
`+
serverRawConn, err := lis.Accept()
`
``
178
`+
if err != nil {
`
``
179
`+
t.Errorf("Server failed to accept connection: %v", err)
`
``
180
`+
close(done)
`
``
181
`+
return
`
``
182
`+
}
`
``
183
`+
serverAuthInfo, err := hs(serverRawConn)
`
``
184
`+
if err != nil {
`
``
185
`+
t.Errorf("Server failed while handshake. Error: %v", err)
`
``
186
`+
serverRawConn.Close()
`
``
187
`+
close(done)
`
``
188
`+
return
`
``
189
`+
}
`
``
190
`+
done <- serverAuthInfo
`
``
191
`+
}
`
``
192
+
``
193
`+
func clientHandle(t *testing.T, hs func(net.Conn, string) (credentials.AuthInfo, error), lisAddr string) credentials.AuthInfo {
`
``
194
`+
conn, err := net.Dial("tcp", lisAddr)
`
``
195
`+
if err != nil {
`
``
196
`+
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
`
``
197
`+
}
`
``
198
`+
defer conn.Close()
`
``
199
`+
clientAuthInfo, err := hs(conn, lisAddr)
`
``
200
`+
if err != nil {
`
``
201
`+
t.Fatalf("Error on client while handshake. Error: %v", err)
`
``
202
`+
}
`
``
203
`+
return clientAuthInfo
`
``
204
`+
}
`
``
205
+
``
206
`+
// Server handshake implementation in gRPC.
`
``
207
`+
func gRPCServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {
`
``
208
`+
serverTLS, err := NewServerTLSFromFileWithALPNDisabled(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
`
``
209
`+
if err != nil {
`
``
210
`+
return nil, err
`
``
211
`+
}
`
``
212
`+
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
`
``
213
`+
if err != nil {
`
``
214
`+
return nil, err
`
``
215
`+
}
`
``
216
`+
return serverAuthInfo, nil
`
``
217
`+
}
`
``
218
+
``
219
`+
// Client handshake implementation in gRPC.
`
``
220
`+
func gRPCClientHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
`
``
221
`+
clientTLS := NewTLSWithALPNDisabled(&tls.Config{InsecureSkipVerify: true})
`
``
222
`+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
`
``
223
`+
defer cancel()
`
``
224
`+
_, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn)
`
``
225
`+
if err != nil {
`
``
226
`+
return nil, err
`
``
227
`+
}
`
``
228
`+
return authInfo, nil
`
``
229
`+
}
`
``
230
+
``
231
`+
func tlsServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {
`
``
232
`+
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
`
``
233
`+
if err != nil {
`
``
234
`+
return nil, err
`
``
235
`+
}
`
``
236
`+
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
`
``
237
`+
serverConn := tls.Server(conn, serverTLSConfig)
`
``
238
`+
err = serverConn.Handshake()
`
``
239
`+
if err != nil {
`
``
240
`+
return nil, err
`
``
241
`+
}
`
``
242
`+
return credentials.TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil
`
``
243
`+
}
`
``
244
+
``
245
`+
func tlsClientHandshake(conn net.Conn, _ string) (credentials.AuthInfo, error) {
`
``
246
`+
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
`
``
247
`+
clientConn := tls.Client(conn, clientTLSConfig)
`
``
248
`+
if err := clientConn.Handshake(); err != nil {
`
``
249
`+
return nil, err
`
``
250
`+
}
`
``
251
`+
return credentials.TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil
`
``
252
`+
}
`