Skip to content

Commit b2df6c3

Browse files
committed
Handle input values that are not in canonical form
1 parent 81d3c39 commit b2df6c3

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

traverse.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,12 @@ func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) *
9696

9797
pointer, bit := r.traverseTree(ip, 0, uint(prefixLength))
9898

99-
if bit < prefixLength {
100-
ip = ip.Mask(net.CIDRMask(bit, len(ip)*8))
101-
}
99+
// We could skip this when bit >= prefixLength if we assume that the network
100+
// passed in is in canonical form. However, given that this may not be the
101+
// case, it is safest to always take the mask. If this is hot code at some
102+
// point, we could eliminate the allocation of the net.IPMask by zeroing
103+
// out the bits in ip directly.
104+
ip = ip.Mask(net.CIDRMask(bit, len(ip)*8))
102105
networks.nodes = []netNode{
103106
{
104107
ip: ip,

traverse_test.go

+26-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package maxminddb
33
import (
44
"fmt"
55
"net"
6+
"strconv"
7+
"strings"
68
"testing"
79

810
"github.com/stretchr/testify/assert"
@@ -71,13 +73,22 @@ var tests = []networkTest{
7173
},
7274
},
7375
{
76+
// This is intentionally in non-canonical form to test
77+
// that we handle it correctly.
7478
Network: "1.1.1.1/30",
7579
Database: "ipv4",
7680
Expected: []string{
7781
"1.1.1.1/32",
7882
"1.1.1.2/31",
7983
},
8084
},
85+
{
86+
Network: "1.1.1.2/31",
87+
Database: "ipv4",
88+
Expected: []string{
89+
"1.1.1.2/31",
90+
},
91+
},
8192
{
8293
Network: "1.1.1.1/32",
8394
Database: "ipv4",
@@ -267,7 +278,21 @@ func TestNetworksWithin(t *testing.T) {
267278
reader, err := Open(fileName)
268279
require.NoError(t, err, "unexpected error while opening database: %v", err)
269280

270-
_, network, err := net.ParseCIDR(v.Network)
281+
// We are purposely not using net.ParseCIDR so that we can pass in
282+
// values that aren't in canonical form.
283+
parts := strings.Split(v.Network, "/")
284+
ip := net.ParseIP(parts[0])
285+
if v := ip.To4(); v != nil {
286+
ip = v
287+
}
288+
prefixLength, err := strconv.Atoi(parts[1])
289+
require.NoError(t, err)
290+
mask := net.CIDRMask(prefixLength, len(ip)*8)
291+
network := &net.IPNet{
292+
IP: ip,
293+
Mask: mask,
294+
}
295+
271296
require.NoError(t, err)
272297
n := reader.NetworksWithin(network, v.Options...)
273298
var innerIPs []string

0 commit comments

Comments
 (0)