Skip to content

Commit 0aa04a3

Browse files
authored
Merge pull request #82 from netlify/safe-http-client
Add a SafeHttpClient method
2 parents 7ea416f + ee82797 commit 0aa04a3

File tree

2 files changed

+129
-20
lines changed

2 files changed

+129
-20
lines changed

http/http.go

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package http
22

33
import (
44
"context"
5-
"errors"
65
"net"
6+
"net/http"
7+
"net/http/httptrace"
8+
9+
"github.com/sirupsen/logrus"
710
)
811

912
var privateIPBlocks []*net.IPNet
@@ -26,27 +29,75 @@ func init() {
2629
}
2730
}
2831

29-
func isPrivateIP(ip net.IP) bool {
30-
for _, block := range privateIPBlocks {
32+
func blocksContain(blocks []*net.IPNet, ip net.IP) bool {
33+
for _, block := range blocks {
3134
if block.Contains(ip) {
3235
return true
3336
}
3437
}
3538
return false
3639
}
3740

38-
func isLocalAddress(addr string) bool {
39-
ip := net.ParseIP(addr)
40-
return isPrivateIP(ip)
41+
func isPrivateIP(ip net.IP) bool {
42+
return blocksContain(privateIPBlocks, ip)
4143
}
4244

43-
// SafeDialContext exchanges a DialContext for a SafeDialContext that will never dial a reserved IP range
44-
func SafeDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
45-
return func(ctx context.Context, network, addr string) (net.Conn, error) {
46-
if isLocalAddress(addr) {
47-
return nil, errors.New("Connection to local network address denied")
48-
}
45+
type noLocalTransport struct {
46+
inner http.RoundTripper
47+
errlog logrus.FieldLogger
48+
allowedBlocks []*net.IPNet
49+
}
50+
51+
func (no noLocalTransport) RoundTrip(req *http.Request) (*http.Response, error) {
52+
ctx, cancel := context.WithCancel(req.Context())
53+
54+
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
55+
ConnectStart: func(network, addr string) {
56+
host, _, err := net.SplitHostPort(addr)
57+
if err != nil {
58+
cancel()
59+
no.errlog.WithError(err).Error("Cancelled request due to error in address parsing")
60+
return
61+
}
62+
ip := net.ParseIP(host)
63+
if ip == nil {
64+
cancel()
65+
no.errlog.WithError(err).Error("Cancelled request due to error in ip parsing")
66+
return
67+
}
68+
69+
if blocksContain(no.allowedBlocks, ip) {
70+
return
71+
}
72+
73+
if isPrivateIP(ip) {
74+
cancel()
75+
no.errlog.Error("Cancelled attempted request to ip in private range")
76+
return
77+
}
78+
},
79+
})
4980

50-
return dialContext(ctx, network, addr)
81+
req = req.WithContext(ctx)
82+
return no.inner.RoundTrip(req)
83+
}
84+
85+
func SafeRoundtripper(trans http.RoundTripper, log logrus.FieldLogger, allowedBlocks ...*net.IPNet) http.RoundTripper {
86+
if trans == nil {
87+
trans = http.DefaultTransport
88+
}
89+
90+
ret := &noLocalTransport{
91+
inner: trans,
92+
errlog: log.WithField("transport", "local_blocker"),
93+
allowedBlocks: allowedBlocks,
5194
}
95+
96+
return ret
97+
}
98+
99+
func SafeHTTPClient(client *http.Client, log logrus.FieldLogger, allowedBlocks ...*net.IPNet) *http.Client {
100+
client.Transport = SafeRoundtripper(client.Transport, log, allowedBlocks...)
101+
102+
return client
52103
}

http/http_test.go

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,74 @@
11
package http
22

33
import (
4+
"net"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/url"
48
"testing"
59

10+
"github.com/sirupsen/logrus"
611
"github.com/stretchr/testify/assert"
712
)
813

9-
func TestIsLocalAddress(t *testing.T) {
10-
assert.False(t, isLocalAddress("216.58.194.206"))
11-
assert.True(t, isLocalAddress("127.0.0.1"))
12-
assert.True(t, isLocalAddress("10.0.0.1"))
13-
assert.True(t, isLocalAddress("192.168.0.1"))
14-
assert.True(t, isLocalAddress("172.16.0.0"))
15-
assert.True(t, isLocalAddress("169.254.169.254"))
14+
func TestIsPrivateIP(t *testing.T) {
15+
tests := []struct {
16+
ip string
17+
expected bool
18+
}{
19+
{"216.58.194.206", false},
20+
{"127.0.0.1", true},
21+
{"10.0.0.1", true},
22+
{"192.168.0.1", true},
23+
{"172.16.0.0", true},
24+
{"169.254.169.254", true},
25+
}
26+
27+
for _, tt := range tests {
28+
ip := net.ParseIP(tt.ip)
29+
assert.Equal(t, tt.expected, isPrivateIP(ip))
30+
}
31+
}
32+
33+
func TestSafeHTTPClient(t *testing.T) {
34+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
35+
w.Write([]byte("Done"))
36+
}))
37+
defer ts.Close()
38+
tsURL, err := url.Parse(ts.URL)
39+
if err != nil {
40+
t.Fatal(err)
41+
}
42+
43+
client := SafeHTTPClient(&http.Client{}, logrus.New())
44+
45+
// It blocks the local IP.
46+
_, err = client.Get(ts.URL)
47+
assert.NotNil(t, err)
48+
49+
// It blocks localhost.
50+
_, err = client.Get("http://localhost:" + tsURL.Port())
51+
assert.NotNil(t, err)
52+
53+
// It succeeds when the local IP range used by the testserver is removed from
54+
// the blacklist.
55+
ipNet := popMatchingBlock(net.ParseIP(tsURL.Hostname()))
56+
_, err = client.Get(ts.URL)
57+
assert.Nil(t, err)
58+
privateIPBlocks = append(privateIPBlocks, ipNet)
59+
60+
// It allows whitelisting for local development.
61+
client = SafeHTTPClient(&http.Client{}, logrus.New(), ipNet)
62+
_, err = client.Get(ts.URL)
63+
assert.Nil(t, err)
64+
}
65+
66+
func popMatchingBlock(ip net.IP) *net.IPNet {
67+
for i, ipNet := range privateIPBlocks {
68+
if ipNet.Contains(ip) {
69+
privateIPBlocks = append(privateIPBlocks[:i], privateIPBlocks[i+1:]...)
70+
return ipNet
71+
}
72+
}
73+
return nil
1674
}

0 commit comments

Comments
 (0)