diff --git a/fou_linux.go b/fou_linux.go index ed55b2b7..b93108bb 100644 --- a/fou_linux.go +++ b/fou_linux.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package netlink @@ -171,36 +172,40 @@ func (h *Handle) FouList(fam int) ([]Fou, error) { func deserializeFouMsg(msg []byte) (Fou, error) { // we'll skip to byte 4 to first attribute msg = msg[3:] - var shift int fou := Fou{} for { - // attribute header is at least 16 bits + // attribute header is at least 32 bits (16 bit type + 16 bit length) if len(msg) < 4 { return fou, ErrAttrHeaderTruncated } lgt := int(binary.BigEndian.Uint16(msg[0:2])) - if len(msg) < lgt+4 { + lgt4 := lgt & (^0x3) + + // Padding to 4 bytes according to netlink man7 + lgtPad := lgt4 + 4 + if lgt4 == lgt { + lgtPad = lgt + } + + if len(msg) < lgtPad { return fou, ErrAttrBodyTruncated } attr := binary.BigEndian.Uint16(msg[2:4]) - shift = lgt + 3 switch attr { case FOU_ATTR_AF: fou.Family = int(msg[5]) case FOU_ATTR_PORT: fou.Port = int(binary.BigEndian.Uint16(msg[5:7])) - // port is 2 bytes - shift = lgt + 2 case FOU_ATTR_IPPROTO: fou.Protocol = int(msg[5]) case FOU_ATTR_TYPE: fou.EncapType = int(msg[5]) } - msg = msg[shift:] + msg = msg[lgtPad:] if len(msg) < 4 { break diff --git a/fou_test.go b/fou_test.go index b252bbad..afdc32c0 100644 --- a/fou_test.go +++ b/fou_test.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package netlink @@ -42,7 +43,7 @@ func TestFouDeserializeMsg(t *testing.T) { } // deserialize truncated attribute header - msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0, 0} + msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0} if _, err := deserializeFouMsg(msg); err == nil { t.Error("expected attribute body truncated error") } else if err != ErrAttrBodyTruncated {