@@ -2,8 +2,11 @@ package http
2
2
3
3
import (
4
4
"context"
5
- "errors"
6
5
"net"
6
+ "net/http"
7
+ "net/http/httptrace"
8
+
9
+ "github.com/sirupsen/logrus"
7
10
)
8
11
9
12
var privateIPBlocks []* net.IPNet
@@ -26,27 +29,75 @@ func init() {
26
29
}
27
30
}
28
31
29
- func isPrivateIP ( ip net.IP ) bool {
30
- for _ , block := range privateIPBlocks {
32
+ func blocksContain ( blocks [] * net. IPNet , ip net.IP ) bool {
33
+ for _ , block := range blocks {
31
34
if block .Contains (ip ) {
32
35
return true
33
36
}
34
37
}
35
38
return false
36
39
}
37
40
38
- func isLocalAddress (addr string ) bool {
39
- ip := net .ParseIP (addr )
40
- return isPrivateIP (ip )
41
+ func isPrivateIP (ip net.IP ) bool {
42
+ return blocksContain (privateIPBlocks , ip )
41
43
}
42
44
43
- // SafeDialContext exchanges a DialContext for a SafeDialContext that will never dial a reserved IP range
44
- func SafeDialContext (dialContext func (ctx context.Context , network , addr string ) (net.Conn , error )) func (ctx context.Context , network , addr string ) (net.Conn , error ) {
45
- return func (ctx context.Context , network , addr string ) (net.Conn , error ) {
46
- if isLocalAddress (addr ) {
47
- return nil , errors .New ("Connection to local network address denied" )
48
- }
45
+ type noLocalTransport struct {
46
+ inner http.RoundTripper
47
+ errlog logrus.FieldLogger
48
+ allowedBlocks []* net.IPNet
49
+ }
50
+
51
+ func (no noLocalTransport ) RoundTrip (req * http.Request ) (* http.Response , error ) {
52
+ ctx , cancel := context .WithCancel (req .Context ())
53
+
54
+ ctx = httptrace .WithClientTrace (ctx , & httptrace.ClientTrace {
55
+ ConnectStart : func (network , addr string ) {
56
+ host , _ , err := net .SplitHostPort (addr )
57
+ if err != nil {
58
+ cancel ()
59
+ no .errlog .WithError (err ).Error ("Cancelled request due to error in address parsing" )
60
+ return
61
+ }
62
+ ip := net .ParseIP (host )
63
+ if ip == nil {
64
+ cancel ()
65
+ no .errlog .WithError (err ).Error ("Cancelled request due to error in ip parsing" )
66
+ return
67
+ }
68
+
69
+ if blocksContain (no .allowedBlocks , ip ) {
70
+ return
71
+ }
72
+
73
+ if isPrivateIP (ip ) {
74
+ cancel ()
75
+ no .errlog .Error ("Cancelled attempted request to ip in private range" )
76
+ return
77
+ }
78
+ },
79
+ })
49
80
50
- return dialContext (ctx , network , addr )
81
+ req = req .WithContext (ctx )
82
+ return no .inner .RoundTrip (req )
83
+ }
84
+
85
+ func SafeRoundtripper (trans http.RoundTripper , log logrus.FieldLogger , allowedBlocks ... * net.IPNet ) http.RoundTripper {
86
+ if trans == nil {
87
+ trans = http .DefaultTransport
88
+ }
89
+
90
+ ret := & noLocalTransport {
91
+ inner : trans ,
92
+ errlog : log .WithField ("transport" , "local_blocker" ),
93
+ allowedBlocks : allowedBlocks ,
51
94
}
95
+
96
+ return ret
97
+ }
98
+
99
+ func SafeHTTPClient (client * http.Client , log logrus.FieldLogger , allowedBlocks ... * net.IPNet ) * http.Client {
100
+ client .Transport = SafeRoundtripper (client .Transport , log , allowedBlocks ... )
101
+
102
+ return client
52
103
}
0 commit comments