proxy, http/httpproxy: do not mismatch IPv6 zone ids against hosts · golang/net@cde1dda (original) (raw)

`@@ -7,8 +7,9 @@ package proxy

`

7

7

`import (

`

8

8

`"context"

`

9

9

`"errors"

`

``

10

`+

"fmt"

`

10

11

`"net"

`

11

``

`-

"reflect"

`

``

12

`+

"slices"

`

12

13

`"testing"

`

13

14

`)

`

14

15

``

`@@ -22,55 +23,118 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {

`

22

23

`}

`

23

24

``

24

25

`func TestPerHost(t *testing.T) {

`

25

``

`-

expectedDef := []string{

`

26

``

`-

"example.com:123",

`

27

``

`-

"1.2.3.4:123",

`

28

``

`-

"[1001::]:123",

`

29

``

`-

}

`

30

``

`-

expectedBypass := []string{

`

31

``

`-

"localhost:123",

`

32

``

`-

"zone:123",

`

33

``

`-

"foo.zone:123",

`

34

``

`-

"127.0.0.1:123",

`

35

``

`-

"10.1.2.3:123",

`

36

``

`-

"[1000::]:123",

`

37

``

`-

}

`

38

``

-

39

``

`-

t.Run("Dial", func(t *testing.T) {

`

40

``

`-

var def, bypass recordingProxy

`

41

``

`-

perHost := NewPerHost(&def, &bypass)

`

42

``

`-

perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")

`

43

``

`-

for _, addr := range expectedDef {

`

44

``

`-

perHost.Dial("tcp", addr)

`

``

26

`+

for _, test := range []struct {

`

``

27

`+

config string // passed to PerHost.AddFromString

`

``

28

`+

nomatch []string // addrs using the default dialer

`

``

29

`+

match []string // addrs using the bypass dialer

`

``

30

`+

}{{

`

``

31

`+

config: "localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16",

`

``

32

`+

nomatch: []string{

`

``

33

`+

"example.com:123",

`

``

34

`+

"1.2.3.4:123",

`

``

35

`+

"[1001::]:123",

`

``

36

`+

},

`

``

37

`+

match: []string{

`

``

38

`+

"localhost:123",

`

``

39

`+

"zone:123",

`

``

40

`+

"foo.zone:123",

`

``

41

`+

"127.0.0.1:123",

`

``

42

`+

"10.1.2.3:123",

`

``

43

`+

"[1000::]:123",

`

``

44

`+

"[1000::%25.example.com]:123",

`

``

45

`+

},

`

``

46

`+

}, {

`

``

47

`+

config: "localhost",

`

``

48

`+

nomatch: []string{

`

``

49

`+

"127.0.0.1:80",

`

``

50

`+

},

`

``

51

`+

match: []string{

`

``

52

`+

"localhost:80",

`

``

53

`+

},

`

``

54

`+

}, {

`

``

55

`+

config: "*.zone",

`

``

56

`+

nomatch: []string{

`

``

57

`+

"foo.com:80",

`

``

58

`+

},

`

``

59

`+

match: []string{

`

``

60

`+

"foo.zone:80",

`

``

61

`+

"foo.bar.zone:80",

`

``

62

`+

},

`

``

63

`+

}, {

`

``

64

`+

config: "1.2.3.4",

`

``

65

`+

nomatch: []string{

`

``

66

`+

"127.0.0.1:80",

`

``

67

`+

"11.2.3.4:80",

`

``

68

`+

},

`

``

69

`+

match: []string{

`

``

70

`+

"1.2.3.4:80",

`

``

71

`+

},

`

``

72

`+

}, {

`

``

73

`+

config: "10.0.0.0/24",

`

``

74

`+

nomatch: []string{

`

``

75

`+

"10.0.1.1:80",

`

``

76

`+

},

`

``

77

`+

match: []string{

`

``

78

`+

"10.0.0.1:80",

`

``

79

`+

"10.0.0.255:80",

`

``

80

`+

},

`

``

81

`+

}, {

`

``

82

`+

config: "fe80::/10",

`

``

83

`+

nomatch: []string{

`

``

84

`+

"[fec0::1]:80",

`

``

85

`+

"[fec0::1%en0]:80",

`

``

86

`+

},

`

``

87

`+

match: []string{

`

``

88

`+

"[fe80::1]:80",

`

``

89

`+

"[fe80::1%en0]:80",

`

``

90

`+

},

`

``

91

`+

}, {

`

``

92

`+

// We don't allow zone IDs in network prefixes,

`

``

93

`+

// so this config matches nothing.

`

``

94

`+

config: "fe80::%en0/10",

`

``

95

`+

nomatch: []string{

`

``

96

`+

"[fec0::1]:80",

`

``

97

`+

"[fec0::1%en0]:80",

`

``

98

`+

"[fe80::1]:80",

`

``

99

`+

"[fe80::1%en0]:80",

`

``

100

`+

"[fe80::1%en1]:80",

`

``

101

`+

},

`

``

102

`+

}} {

`

``

103

`+

for _, addr := range test.match {

`

``

104

`+

testPerHost(t, test.config, addr, true)

`

45

105

` }

`

46

``

`-

for _, addr := range expectedBypass {

`

47

``

`-

perHost.Dial("tcp", addr)

`

``

106

`+

for _, addr := range test.nomatch {

`

``

107

`+

testPerHost(t, test.config, addr, false)

`

48

108

` }

`

``

109

`+

}

`

``

110

`+

}

`

49

111

``

50

``

`-

if !reflect.DeepEqual(expectedDef, def.addrs) {

`

51

``

`-

t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)

`

52

``

`-

}

`

53

``

`-

if !reflect.DeepEqual(expectedBypass, bypass.addrs) {

`

54

``

`-

t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)

`

55

``

`-

}

`

56

``

`-

})

`

``

112

`+

func testPerHost(t *testing.T, config, addr string, wantMatch bool) {

`

``

113

`+

name := fmt.Sprintf("config %q, dial %q", config, addr)

`

57

114

``

58

``

`-

t.Run("DialContext", func(t *testing.T) {

`

59

``

`-

var def, bypass recordingProxy

`

60

``

`-

perHost := NewPerHost(&def, &bypass)

`

61

``

`-

perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")

`

62

``

`-

for _, addr := range expectedDef {

`

63

``

`-

perHost.DialContext(context.Background(), "tcp", addr)

`

64

``

`-

}

`

65

``

`-

for _, addr := range expectedBypass {

`

66

``

`-

perHost.DialContext(context.Background(), "tcp", addr)

`

67

``

`-

}

`

``

115

`+

var def, bypass recordingProxy

`

``

116

`+

perHost := NewPerHost(&def, &bypass)

`

``

117

`+

perHost.AddFromString(config)

`

``

118

`+

perHost.Dial("tcp", addr)

`

68

119

``

69

``

`-

if !reflect.DeepEqual(expectedDef, def.addrs) {

`

70

``

`-

t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)

`

71

``

`-

}

`

72

``

`-

if !reflect.DeepEqual(expectedBypass, bypass.addrs) {

`

73

``

`-

t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)

`

74

``

`-

}

`

75

``

`-

})

`

``

120

`+

// Dial and DialContext should have the same results.

`

``

121

`+

var defc, bypassc recordingProxy

`

``

122

`+

perHostc := NewPerHost(&defc, &bypassc)

`

``

123

`+

perHostc.AddFromString(config)

`

``

124

`+

perHostc.DialContext(context.Background(), "tcp", addr)

`

``

125

`+

if !slices.Equal(def.addrs, defc.addrs) {

`

``

126

`+

t.Errorf("%v: Dial default=%v, bypass=%v; DialContext default=%v, bypass=%v", name, def.addrs, bypass.addrs, defc.addrs, bypass.addrs)

`

``

127

`+

return

`

``

128

`+

}

`

``

129

+

``

130

`+

if got, want := slices.Concat(def.addrs, bypass.addrs), []string{addr}; !slices.Equal(got, want) {

`

``

131

`+

t.Errorf("%v: dialed %q, want %q", name, got, want)

`

``

132

`+

return

`

``

133

`+

}

`

``

134

+

``

135

`+

gotMatch := len(bypass.addrs) > 0

`

``

136

`+

if gotMatch != wantMatch {

`

``

137

`+

t.Errorf("%v: matched=%v, want %v", name, gotMatch, wantMatch)

`

``

138

`+

return

`

``

139

`+

}

`

76

140

`}

`