diff --git a/codec.go b/codec.go index f2951c4..0a63f12 100644 --- a/codec.go +++ b/codec.go @@ -53,6 +53,10 @@ func stringToBytes(s string) ([]byte, error) { if err != nil { return nil, fmt.Errorf("failed to parse multiaddr %q: invalid value %q for protocol %s: %s", s, sp[0], p.Name, err) } + err = p.Transcoder.ValidateBytes(a) + if err != nil { + return nil, err + } if p.Size < 0 { // varint size. _, _ = b.Write(varint.ToUvarint(uint64(len(a)))) } @@ -63,51 +67,6 @@ func stringToBytes(s string) ([]byte, error) { return b.Bytes(), nil } -func validateBytes(b []byte) (err error) { - if len(b) == 0 { - return fmt.Errorf("empty multiaddr") - } - for len(b) > 0 { - code, n, err := ReadVarintCode(b) - if err != nil { - return err - } - - b = b[n:] - p := ProtocolWithCode(code) - if p.Code == 0 { - return fmt.Errorf("no protocol with code %d", code) - } - - if p.Size == 0 { - continue - } - - n, size, err := sizeForAddr(p, b) - if err != nil { - return err - } - - b = b[n:] - - if len(b) < size || size < 0 { - return fmt.Errorf("invalid value for size %d", len(b)) - } - if p.Path && len(b) != size { - return fmt.Errorf("invalid size of component for path protocol %d: expected %d", size, len(b)) - } - - err = p.Transcoder.ValidateBytes(b[:size]) - if err != nil { - return err - } - - b = b[size:] - } - - return nil -} - func readComponent(b []byte) (int, Component, error) { var offset int code, n, err := ReadVarintCode(b) @@ -122,60 +81,64 @@ func readComponent(b []byte) (int, Component, error) { } if p.Size == 0 { - return offset, Component{ - bytes: b[:offset], + c, err := validateComponent(Component{ + bytes: string(b[:offset]), offset: offset, protocol: p, - }, nil - } + }) - n, size, err := sizeForAddr(p, b[offset:]) - if err != nil { - return 0, Component{}, err + return offset, c, err } - offset += n + var size int + if p.Size < 0 { + // varint + var n int + size, n, err = ReadVarintCode(b[offset:]) + if err != nil { + return 0, Component{}, err + } + offset += n 
+ } else { + // Size is in bits, but we operate on bytes + size = p.Size / 8 + } - if len(b[offset:]) < size || size < 0 { + if len(b[offset:]) < size || size <= 0 { return 0, Component{}, fmt.Errorf("invalid value for size %d", len(b[offset:])) } - return offset + size, Component{ - bytes: b[:offset+size], + c, err := validateComponent(Component{ + bytes: string(b[:offset+size]), protocol: p, offset: offset, - }, nil + }) + + return offset + size, c, err } -func bytesToString(b []byte) (ret string, err error) { +func readMultiaddr(b []byte) (int, Multiaddr, error) { if len(b) == 0 { - return "", fmt.Errorf("empty multiaddr") + return 0, nil, fmt.Errorf("empty multiaddr") } - var buf strings.Builder + var res Multiaddr + bytesRead := 0 + sawPathComponent := false for len(b) > 0 { n, c, err := readComponent(b) if err != nil { - return "", err + return 0, nil, err } b = b[n:] - c.writeTo(&buf) - } + bytesRead += n - return buf.String(), nil -} - -func sizeForAddr(p Protocol, b []byte) (skip, size int, err error) { - switch { - case p.Size > 0: - return 0, (p.Size / 8), nil - case p.Size == 0: - return 0, 0, nil - default: - size, n, err := ReadVarintCode(b) - if err != nil { - return 0, 0, err + if sawPathComponent { + // It is an error to have another component after a path component. + return bytesRead, nil, fmt.Errorf("unexpected component after path component") } - return n, size, nil + sawPathComponent = c.protocol.Path + res = append(res, c) } + return bytesRead, res, nil } diff --git a/component.go b/component.go index 4ee6809..aa8f4ef 100644 --- a/component.go +++ b/component.go @@ -1,7 +1,6 @@ package multiaddr import ( - "bytes" "encoding/binary" "encoding/json" "fmt" @@ -12,20 +11,42 @@ import ( // Component is a single multiaddr Component. type Component struct { - bytes []byte + bytes string // Uses the string type to ensure immutability. 
protocol Protocol offset int } -func (c *Component) Bytes() []byte { - return c.bytes +func (c Component) AsMultiaddr() Multiaddr { + if c.Empty() { + return nil + } + return []Component{c} +} + +func (c Component) Encapsulate(o Multiaddr) Multiaddr { + return c.AsMultiaddr().Encapsulate(o) +} + +func (c Component) Decapsulate(o Multiaddr) Multiaddr { + return c.AsMultiaddr().Decapsulate(o) +} + +func (c Component) Empty() bool { + return len(c.bytes) == 0 +} + +func (c Component) Bytes() []byte { + return []byte(c.bytes) } -func (c *Component) MarshalBinary() ([]byte, error) { +func (c Component) MarshalBinary() ([]byte, error) { return c.Bytes(), nil } func (c *Component) UnmarshalBinary(data []byte) error { + if c == nil { + return errNilPtr + } _, comp, err := readComponent(data) if err != nil { return err @@ -34,11 +55,15 @@ func (c *Component) UnmarshalBinary(data []byte) error { return nil } -func (c *Component) MarshalText() ([]byte, error) { +func (c Component) MarshalText() ([]byte, error) { return []byte(c.String()), nil } func (c *Component) UnmarshalText(data []byte) error { + if c == nil { + return errNilPtr + } + bytes, err := stringToBytes(string(data)) if err != nil { return err @@ -51,7 +76,7 @@ func (c *Component) UnmarshalText(data []byte) error { return nil } -func (c *Component) MarshalJSON() ([]byte, error) { +func (c Component) MarshalJSON() ([]byte, error) { txt, err := c.MarshalText() if err != nil { return nil, err @@ -60,66 +85,67 @@ func (c *Component) MarshalJSON() ([]byte, error) { return json.Marshal(string(txt)) } -func (m *Component) UnmarshalJSON(data []byte) error { +func (c *Component) UnmarshalJSON(data []byte) error { + if c == nil { + return errNilPtr + } + var v string if err := json.Unmarshal(data, &v); err != nil { return err } - return m.UnmarshalText([]byte(v)) -} - -func (c *Component) Equal(o Multiaddr) bool { - if o == nil { - return false - } - return bytes.Equal(c.bytes, o.Bytes()) + return 
c.UnmarshalText([]byte(v)) } -func (c *Component) Protocols() []Protocol { - return []Protocol{c.protocol} +func (c Component) Equal(o Component) bool { + return c.bytes == o.bytes } -func (c *Component) Decapsulate(o Multiaddr) Multiaddr { - if c.Equal(o) { - return nil - } - return c +func (c Component) Compare(o Component) int { + return strings.Compare(c.bytes, o.bytes) } -func (c *Component) Encapsulate(o Multiaddr) Multiaddr { - m := &multiaddr{bytes: c.bytes} - return m.Encapsulate(o) +func (c Component) Protocols() []Protocol { + return []Protocol{c.protocol} } -func (c *Component) ValueForProtocol(code int) (string, error) { +func (c Component) ValueForProtocol(code int) (string, error) { if c.protocol.Code != code { return "", ErrProtocolNotFound } return c.Value(), nil } -func (c *Component) Protocol() Protocol { +func (c Component) Protocol() Protocol { return c.protocol } -func (c *Component) RawValue() []byte { - return c.bytes[c.offset:] +func (c Component) RawValue() []byte { + return []byte(c.bytes[c.offset:]) } -func (c *Component) Value() string { - if c.protocol.Transcoder == nil { +func (c Component) Value() string { + if c.Empty() { return "" } - value, err := c.protocol.Transcoder.BytesToString(c.bytes[c.offset:]) + // This Component MUST have been checked by validateComponent when created + value, _ := c.valueAndErr() + return value +} + +func (c Component) valueAndErr() (string, error) { + if c.protocol.Transcoder == nil { + return "", nil + } + value, err := c.protocol.Transcoder.BytesToString([]byte(c.bytes[c.offset:])) if err != nil { - // This Component must have been checked. - panic(err) + return "", err } - return value + return value, nil } -func (c *Component) String() string { +func (c Component) String() string { var b strings.Builder c.writeTo(&b) return b.String() @@ -127,7 +153,7 @@ func (c *Component) String() string { // writeTo is an efficient, private function for string-formatting a multiaddr. 
// Trust me, we tend to allocate a lot when doing this. -func (c *Component) writeTo(b *strings.Builder) { +func (c Component) writeTo(b *strings.Builder) { b.WriteByte('/') b.WriteString(c.protocol.Name) value := c.Value() @@ -141,25 +167,24 @@ func (c *Component) writeTo(b *strings.Builder) { } // NewComponent constructs a new multiaddr component -func NewComponent(protocol, value string) (*Component, error) { +func NewComponent(protocol, value string) (Component, error) { p := ProtocolWithName(protocol) if p.Code == 0 { - return nil, fmt.Errorf("unsupported protocol: %s", protocol) + return Component{}, fmt.Errorf("unsupported protocol: %s", protocol) } if p.Transcoder != nil { bts, err := p.Transcoder.StringToBytes(value) if err != nil { - return nil, err + return Component{}, err } - return newComponent(p, bts), nil + return newComponent(p, bts) } else if value != "" { - return nil, fmt.Errorf("protocol %s doesn't take a value", p.Name) + return Component{}, fmt.Errorf("protocol %s doesn't take a value", p.Name) } - return newComponent(p, nil), nil - // TODO: handle path /? + return newComponent(p, nil) } -func newComponent(protocol Protocol, bvalue []byte) *Component { +func newComponent(protocol Protocol, bvalue []byte) (Component, error) { size := len(bvalue) size += len(protocol.VCode) if protocol.Size < 0 { @@ -173,14 +198,33 @@ func newComponent(protocol Protocol, bvalue []byte) *Component { } copy(maddr[offset:], bvalue) - // For debugging + // Shouldn't happen if len(maddr) != offset+len(bvalue) { - panic("incorrect length") + return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(maddr), offset+len(bvalue)) } - return &Component{ - bytes: maddr, - protocol: protocol, - offset: offset, + return validateComponent( + Component{ + bytes: string(maddr), + protocol: protocol, + offset: offset, + }) +} + +// validateComponent MUST be called after creating a non-zero Component. 
+// It ensures that we will be able to call all methods on Component without +// error. +func validateComponent(c Component) (Component, error) { + _, err := c.valueAndErr() + if err != nil { + return Component{}, err + + } + if c.protocol.Transcoder != nil { + err = c.protocol.Transcoder.ValidateBytes([]byte(c.bytes[c.offset:])) + if err != nil { + return Component{}, err + } } + return c, nil } diff --git a/interface.go b/interface.go deleted file mode 100644 index 699c54d..0000000 --- a/interface.go +++ /dev/null @@ -1,63 +0,0 @@ -package multiaddr - -import ( - "encoding" - "encoding/json" -) - -/* -Multiaddr is a cross-protocol, cross-platform format for representing -internet addresses. It emphasizes explicitness and self-description. -Learn more here: https://github.com/multiformats/multiaddr - -Multiaddrs have both a binary and string representation. - - import ma "github.com/multiformats/go-multiaddr" - - addr, err := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/80") - // err non-nil when parsing failed. -*/ -type Multiaddr interface { - json.Marshaler - json.Unmarshaler - encoding.TextMarshaler - encoding.TextUnmarshaler - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler - - // Equal returns whether two Multiaddrs are exactly equal - Equal(Multiaddr) bool - - // Bytes returns the []byte representation of this Multiaddr - // - // This function may expose immutable, internal state. Do not modify. - Bytes() []byte - - // String returns the string representation of this Multiaddr - // (may panic if internal state is corrupted) - String() string - - // Protocols returns the list of Protocols this Multiaddr includes - // will panic if protocol code incorrect (and bytes accessed incorrectly) - Protocols() []Protocol - - // Encapsulate wraps this Multiaddr around another. For example: - // - // /ip4/1.2.3.4 encapsulate /tcp/80 = /ip4/1.2.3.4/tcp/80 - // - Encapsulate(Multiaddr) Multiaddr - - // Decapsulate removes a Multiaddr wrapping. 
For example: - // - // /ip4/1.2.3.4/tcp/80 decapsulate /tcp/80 = /ip4/1.2.3.4 - // /ip4/1.2.3.4/tcp/80 decapsulate /udp/80 = /ip4/1.2.3.4/tcp/80 - // /ip4/1.2.3.4/tcp/80 decapsulate /ip4/1.2.3.4 = nil - // - Decapsulate(Multiaddr) Multiaddr - - // ValueForProtocol returns the value (if any) following the specified protocol - // - // Note: protocols can appear multiple times in a single multiaddr. - // Consider using `ForEach` to walk over the addr manually. - ValueForProtocol(code int) (string, error) -} diff --git a/matest/matest.go b/matest/matest.go new file mode 100644 index 0000000..aa76655 --- /dev/null +++ b/matest/matest.go @@ -0,0 +1,80 @@ +// Package matest provides utilities for testing with multiaddrs. +package matest + +import ( + "slices" + + "github.com/multiformats/go-multiaddr" +) + +type TestingT interface { + Errorf(format string, args ...interface{}) +} + +type tHelper interface { + Helper() +} + +func AssertEqualMultiaddr(t TestingT, expected, actual multiaddr.Multiaddr) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !expected.Equal(actual) { + t.Errorf("expected %v, got %v", expected, actual) + return false + } + return true +} + +func AssertEqualMultiaddrs(t TestingT, expected, actual []multiaddr.Multiaddr) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if len(expected) != len(actual) { + t.Errorf("expected %v, got %v", expected, actual) + return false + } + for i, e := range expected { + if !e.Equal(actual[i]) { + t.Errorf("expected %v, got %v", expected, actual) + return false + } + } + return true +} + +// AssertMultiaddrsMatch is the same as AssertEqualMultiaddrs, but it ignores the order of the elements. 
+func AssertMultiaddrsMatch(t TestingT, expected, actual []multiaddr.Multiaddr) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + e := slices.Clone(expected) + a := slices.Clone(actual) + slices.SortFunc(e, func(a, b multiaddr.Multiaddr) int { return a.Compare(b) }) + slices.SortFunc(a, func(a, b multiaddr.Multiaddr) int { return a.Compare(b) }) + return AssertEqualMultiaddrs(t, e, a) +} + +func AssertMultiaddrsContain(t TestingT, haystack []multiaddr.Multiaddr, needle multiaddr.Multiaddr) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + for _, h := range haystack { + if h.Equal(needle) { + return true + } + } + t.Errorf("expected %v to contain %v", haystack, needle) + return false +} + +type MultiaddrMatcher struct { + multiaddr.Multiaddr +} + +func (m MultiaddrMatcher) Matches(x interface{}) bool { + if m2, ok := x.(multiaddr.Multiaddr); ok { + return m.Equal(m2) + } + return false +} diff --git a/multiaddr.go b/multiaddr.go index 5e60780..86092ae 100644 --- a/multiaddr.go +++ b/multiaddr.go @@ -1,17 +1,31 @@ package multiaddr import ( - "bytes" + "cmp" "encoding/json" + "errors" "fmt" "log" + "strings" "golang.org/x/exp/slices" ) -// multiaddr is the data structure representing a Multiaddr -type multiaddr struct { - bytes []byte +var errNilPtr = errors.New("nil ptr") + +// Multiaddr is the data structure representing a Multiaddr +type Multiaddr []Component + +func (m Multiaddr) Empty() bool { + if len(m) == 0 { + return true + } + for _, c := range m { + if !c.Empty() { + return false + } + } + return true } // NewMultiaddr parses and validates an input string, returning a *Multiaddr @@ -26,7 +40,7 @@ func NewMultiaddr(s string) (a Multiaddr, err error) { if err != nil { return nil, err } - return &multiaddr{bytes: b}, nil + return NewMultiaddrBytes(b) } // NewMultiaddrBytes initializes a Multiaddr from a byte representation. 
@@ -38,131 +52,144 @@ func NewMultiaddrBytes(b []byte) (a Multiaddr, err error) { err = fmt.Errorf("%v", e) } }() - - if err := validateBytes(b); err != nil { + bytesRead, m, err := readMultiaddr(b) + if err != nil { return nil, err } - - return &multiaddr{bytes: b}, nil + if bytesRead != len(b) { + return nil, fmt.Errorf("unexpected extra data. %v bytes leftover", len(b)-bytesRead) + } + if len(m) == 0 { + return nil, nil + } + return m, nil } // Equal tests whether two multiaddrs are equal -func (m *multiaddr) Equal(m2 Multiaddr) bool { - if m2 == nil { +func (m Multiaddr) Equal(m2 Multiaddr) bool { + if len(m) != len(m2) { return false } - return bytes.Equal(m.bytes, m2.Bytes()) + for i, c := range m { + if !c.Equal(m2[i]) { + return false + } + } + return true +} + +func (m Multiaddr) Compare(o Multiaddr) int { + for i := 0; i < len(m) && i < len(o); i++ { + if cmp := m[i].Compare(o[i]); cmp != 0 { + return cmp + } + } + return cmp.Compare(len(m), len(o)) } // Bytes returns the []byte representation of this Multiaddr -// -// Do not modify the returned buffer, it may be shared. -func (m *multiaddr) Bytes() []byte { - return m.bytes +func (m Multiaddr) Bytes() []byte { + size := 0 + for _, c := range m { + size += len(c.bytes) + } + + out := make([]byte, 0, size) + for _, c := range m { + out = append(out, c.bytes...) + } + + return out } // String returns the string representation of a Multiaddr -func (m *multiaddr) String() string { - s, err := bytesToString(m.bytes) - if err != nil { - panic(fmt.Errorf("multiaddr failed to convert back to string. corrupted? 
%s", err)) +func (m Multiaddr) String() string { + var buf strings.Builder + + for _, c := range m { + c.writeTo(&buf) } - return s + return buf.String() } -func (m *multiaddr) MarshalBinary() ([]byte, error) { +func (m Multiaddr) MarshalBinary() ([]byte, error) { return m.Bytes(), nil } -func (m *multiaddr) UnmarshalBinary(data []byte) error { +func (m *Multiaddr) UnmarshalBinary(data []byte) error { + if m == nil { + return errNilPtr + } new, err := NewMultiaddrBytes(data) if err != nil { return err } - *m = *(new.(*multiaddr)) + *m = new return nil } -func (m *multiaddr) MarshalText() ([]byte, error) { +func (m Multiaddr) MarshalText() ([]byte, error) { return []byte(m.String()), nil } -func (m *multiaddr) UnmarshalText(data []byte) error { +func (m *Multiaddr) UnmarshalText(data []byte) error { + if m == nil { + return errNilPtr + } + new, err := NewMultiaddr(string(data)) if err != nil { return err } - *m = *(new.(*multiaddr)) + *m = new return nil } -func (m *multiaddr) MarshalJSON() ([]byte, error) { +func (m Multiaddr) MarshalJSON() ([]byte, error) { return json.Marshal(m.String()) } -func (m *multiaddr) UnmarshalJSON(data []byte) error { +func (m *Multiaddr) UnmarshalJSON(data []byte) error { + if m == nil { + return errNilPtr + } var v string if err := json.Unmarshal(data, &v); err != nil { return err } new, err := NewMultiaddr(v) - *m = *(new.(*multiaddr)) + *m = new return err } // Protocols returns the list of protocols this Multiaddr has. // will panic in case we access bytes incorrectly. 
-func (m *multiaddr) Protocols() []Protocol { - ps := make([]Protocol, 0, 8) - b := m.bytes - for len(b) > 0 { - code, n, err := ReadVarintCode(b) - if err != nil { - panic(err) - } - - p := ProtocolWithCode(code) - if p.Code == 0 { - // this is a panic (and not returning err) because this should've been - // caught on constructing the Multiaddr - panic(fmt.Errorf("no protocol with code %d", b[0])) - } - ps = append(ps, p) - b = b[n:] - - n, size, err := sizeForAddr(p, b) - if err != nil { - panic(err) - } - - b = b[n+size:] +func (m Multiaddr) Protocols() []Protocol { + out := make([]Protocol, 0, len(m)) + for _, c := range m { + out = append(out, c.Protocol()) } - return ps + return out } // Encapsulate wraps a given Multiaddr, returning the resulting joined Multiaddr -func (m *multiaddr) Encapsulate(o Multiaddr) Multiaddr { - if o == nil { - return m - } - - mb := m.bytes - ob := o.Bytes() - - b := make([]byte, len(mb)+len(ob)) - copy(b, mb) - copy(b[len(mb):], ob) - return &multiaddr{bytes: b} +func (m Multiaddr) Encapsulate(o Multiaddr) Multiaddr { + return Join(m, o) } -// Decapsulate unwraps Multiaddr up until the given Multiaddr is found. -func (m *multiaddr) Decapsulate(right Multiaddr) Multiaddr { - if right == nil { +func (m Multiaddr) EncapsulateC(c Component) Multiaddr { + if c.Empty() { return m } + out := make([]Component, 0, len(m)+1) + out = append(out, m...) + out = append(out, c) + return out +} - leftParts := Split(m) - rightParts := Split(right) +// Decapsulate unwraps Multiaddr up until the given Multiaddr is found. +func (m Multiaddr) Decapsulate(rightParts Multiaddr) Multiaddr { + leftParts := m lastIndex := -1 for i := range leftParts { @@ -189,28 +216,20 @@ func (m *multiaddr) Decapsulate(right Multiaddr) Multiaddr { } if lastIndex < 0 { - // if multiaddr not contained, returns a copy. - cpy := make([]byte, len(m.bytes)) - copy(cpy, m.bytes) - return &multiaddr{bytes: cpy} + return m } - - return Join(leftParts[:lastIndex]...) 
+ return leftParts[:lastIndex] } var ErrProtocolNotFound = fmt.Errorf("protocol not found in multiaddr") -func (m *multiaddr) ValueForProtocol(code int) (value string, err error) { - err = ErrProtocolNotFound - ForEach(m, func(c Component) bool { +func (m Multiaddr) ValueForProtocol(code int) (value string, err error) { + for _, c := range m { if c.Protocol().Code == code { - value = c.Value() - err = nil - return false + return c.Value(), nil } - return true - }) - return + } + return "", ErrProtocolNotFound } // FilterAddrs is a filter that removes certain addresses, according to the given filters. @@ -246,7 +265,7 @@ func Unique(addrs []Multiaddr) []Multiaddr { return addrs } // Use the new slices package here, as the sort function doesn't allocate (sort.Slice does). - slices.SortFunc(addrs, func(a, b Multiaddr) int { return bytes.Compare(a.Bytes(), b.Bytes()) }) + slices.SortFunc(addrs, func(a, b Multiaddr) int { return a.Compare(b) }) idx := 1 for i := 1; i < len(addrs); i++ { if !addrs[i-1].Equal(addrs[i]) { diff --git a/multiaddr/main.go b/multiaddr/main.go index 6d0aa7b..4266d31 100644 --- a/multiaddr/main.go +++ b/multiaddr/main.go @@ -57,7 +57,7 @@ Options: func infoCommand(addr maddr.Multiaddr) { var compsJson []string - maddr.ForEach(addr, func(comp maddr.Component) bool { + for _, comp := range addr { lengthPrefix := "" if comp.Protocol().Size == maddr.LengthPrefixedVarSize { lengthPrefix = "0x" + hex.EncodeToString(maddr.CodeToVarint(len(comp.RawValue()))) @@ -76,8 +76,7 @@ func infoCommand(addr maddr.Multiaddr) { fmt.Sprintf(`"uvarint": "0x%x", `, comp.Protocol().VCode)+ fmt.Sprintf(`"lengthPrefix": "%s"`, lengthPrefix)+ `}`) - return true - }) + } addrJson := `{ "string": "%[1]s", diff --git a/multiaddr_test.go b/multiaddr_test.go index 2d3b73c..b281c2d 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -21,6 +21,53 @@ func newMultiaddr(t *testing.T, a string) Multiaddr { return m } +func TestReturnsNilOnEmpty(t *testing.T) { + a := 
StringCast("/ip4/1.2.3.4") + a, _ = SplitLast(a) + require.Nil(t, a) + a, _ = SplitLast(a) + require.Nil(t, a) + + // Test that empty multiaddr from various operations returns nil + a = StringCast("/ip4/1.2.3.4/tcp/1234") + _, a = SplitFirst(a) + _, a = SplitFirst(a) + require.Nil(t, a) + _, a = SplitFirst(a) + require.Nil(t, a) + + a = StringCast("/ip4/1.2.3.4/tcp/1234") + a = a.Decapsulate(a) + require.Nil(t, a) + + a = StringCast("/ip4/1.2.3.4/tcp/1234") + a = a.Decapsulate(StringCast("/tcp/1234")) + a = a.Decapsulate(StringCast("/ip4/1.2.3.4")) + require.Nil(t, a) + + // Test that SplitFunc returns nil when we split at beginning and end + a = StringCast("/ip4/1.2.3.4/tcp/1234") + pre, _ := SplitFunc(a, func(c Component) bool { + return c.Protocol().Code == P_IP4 + }) + require.Nil(t, pre) + + a = StringCast("/ip4/1.2.3.4/tcp/1234") + _, post := SplitFunc(a, func(c Component) bool { + return false + }) + require.Nil(t, post) + + _, err := NewMultiaddr("") + require.Error(t, err) + + a = JoinComponents() + require.Nil(t, a) + + a = Join() + require.Nil(t, a) +} + func TestConstructFails(t *testing.T) { cases := []string{ "/ip4", @@ -255,7 +302,7 @@ func TestNilInterface(t *testing.T) { // Test components c, _ := SplitFirst(m1) - c.Equal(m2) + c.AsMultiaddr().Equal(m2) c.Encapsulate(m2) c.Decapsulate(m2) @@ -285,7 +332,7 @@ func TestStringToBytes(t *testing.T) { t.Error("failed to convert \n", s, "to\n", hex.EncodeToString(b1), "got\n", hex.EncodeToString(b2)) } - if err := validateBytes(b2); err != nil { + if _, err := NewMultiaddrBytes(b2); err != nil { t.Error(err, "len:", len(b2)) } } @@ -303,7 +350,6 @@ func TestStringToBytes(t *testing.T) { } func TestBytesToString(t *testing.T) { - testString := func(s1 string, h string) { t.Helper() b, err := hex.DecodeString(h) @@ -311,11 +357,12 @@ func TestBytesToString(t *testing.T) { t.Error("failed to decode hex", h) } - if err := validateBytes(b); err != nil { + if _, err := NewMultiaddrBytes(b); err != nil { 
t.Error(err) } - s2, err := bytesToString(b) + m, err := NewMultiaddrBytes(b) + s2 := m.String() if err != nil { t.Log("236", s1, ":", string(h), ":", s2) t.Error("failed to convert", b, err) @@ -357,7 +404,7 @@ func TestBytesSplitAndJoin(t *testing.T) { } } - joined := Join(split...) + joined := JoinComponents(split...) if !m.Equal(joined) { t.Errorf("joined components failed: %s != %s", m, joined) } @@ -761,11 +808,11 @@ func TestBinaryMarshaler(t *testing.T) { t.Fatal(err) } - var addr2 multiaddr + var addr2 Multiaddr if err = addr2.UnmarshalBinary(b); err != nil { t.Fatal(err) } - if !addr.Equal(&addr2) { + if !addr.Equal(addr2) { t.Error("expected equal addresses in circular marshaling test") } } @@ -777,11 +824,11 @@ func TestTextMarshaler(t *testing.T) { t.Fatal(err) } - var addr2 multiaddr + var addr2 Multiaddr if err = addr2.UnmarshalText(b); err != nil { t.Fatal(err) } - if !addr.Equal(&addr2) { + if !addr.Equal(addr2) { t.Error("expected equal addresses in circular marshaling test") } } @@ -793,11 +840,11 @@ func TestJSONMarshaler(t *testing.T) { t.Fatal(err) } - var addr2 multiaddr + var addr2 Multiaddr if err = addr2.UnmarshalJSON(b); err != nil { t.Fatal(err) } - if !addr.Equal(&addr2) { + if !addr.Equal(addr2) { t.Error("expected equal addresses in circular marshaling test") } } @@ -812,7 +859,7 @@ func TestComponentBinaryMarshaler(t *testing.T) { t.Fatal(err) } - comp2 := &Component{} + comp2 := Component{} if err = comp2.UnmarshalBinary(b); err != nil { t.Fatal(err) } @@ -831,7 +878,7 @@ func TestComponentTextMarshaler(t *testing.T) { t.Fatal(err) } - comp2 := &Component{} + comp2 := Component{} if err = comp2.UnmarshalText(b); err != nil { t.Fatal(err) } @@ -850,7 +897,7 @@ func TestComponentJSONMarshaler(t *testing.T) { t.Fatal(err) } - comp2 := &Component{} + comp2 := Component{} if err = comp2.UnmarshalJSON(b); err != nil { t.Fatal(err) } @@ -859,6 +906,30 @@ func TestComponentJSONMarshaler(t *testing.T) { } } +func TestUseNil(t *testing.T) { + 
f := func() Multiaddr { + return nil + } + + _ = f() + + var foo Multiaddr = nil + foo.Bytes() + foo.Compare(nil) + foo.Decapsulate(nil) + foo.Encapsulate(nil) + foo.Equal(nil) + _, _ = foo.MarshalBinary() + _, _ = foo.MarshalJSON() + _, _ = foo.MarshalText() + foo.Protocols() + _ = foo.String() + _ = foo.UnmarshalBinary(nil) + _ = foo.UnmarshalJSON(nil) + _ = foo.UnmarshalText(nil) + _, _ = foo.ValueForProtocol(0) +} + func TestFilterAddrs(t *testing.T) { bad := []Multiaddr{ newMultiaddr(t, "/ip6/fe80::1/tcp/1234"), @@ -1013,13 +1084,13 @@ func FuzzSplitRoundtrip(f *testing.F) { // Test SplitFirst first, rest := SplitFirst(addr) - joined := Join(first, rest) - require.Equal(t, addr, joined, "SplitFirst and Join should round-trip") + joined := Join(first.AsMultiaddr(), rest) + require.True(t, addr.Equal(joined), "SplitFirst and Join should round-trip") // Test SplitLast rest, last := SplitLast(addr) - joined = Join(rest, last) - require.Equal(t, addr, joined, "SplitLast and Join should round-trip") + joined = Join(rest, last.AsMultiaddr()) + require.True(t, addr.Equal(joined), "SplitLast and Join should round-trip") p := addr.Protocols() if len(p) == 0 { @@ -1044,18 +1115,18 @@ func FuzzSplitRoundtrip(f *testing.F) { return c.Protocol().Code == proto.Code } beforeC, after := SplitFirst(addr) - joined = Join(beforeC, after) - require.Equal(t, addr, joined) + joined = Join(beforeC.AsMultiaddr(), after) + require.True(t, addr.Equal(joined)) tryPubMethods(after) before, afterC := SplitLast(addr) - joined = Join(before, afterC) - require.Equal(t, addr, joined) + joined = Join(before, afterC.AsMultiaddr()) + require.True(t, addr.Equal(joined)) tryPubMethods(before) before, after = SplitFunc(addr, splitFunc) joined = Join(before, after) - require.Equal(t, addr, joined) + require.True(t, addr.Equal(joined)) tryPubMethods(before) tryPubMethods(after) } diff --git a/net/convert.go b/net/convert.go index 4603fa2..be320d3 100644 --- a/net/convert.go +++ b/net/convert.go @@ 
-57,15 +57,17 @@ func MultiaddrToIPNet(m ma.Multiaddr) (*net.IPNet, error) { var ipString string var mask string - ma.ForEach(m, func(c ma.Component) bool { + for _, c := range m { if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 { ipString = c.Value() } if c.Protocol().Code == ma.P_IPCIDR { mask = c.Value() } - return ipString == "" || mask == "" - }) + if ipString != "" && mask != "" { + break + } + } if ipString == "" { return nil, errors.New("no ip protocol found") @@ -102,12 +104,17 @@ func parseBasicNetMaddr(maddr ma.Multiaddr) (net.Addr, error) { func FromIPAndZone(ip net.IP, zone string) (ma.Multiaddr, error) { switch { case ip.To4() != nil: - return ma.NewComponent("ip4", ip.String()) + c, err := ma.NewComponent("ip4", ip.String()) + if err != nil { + return nil, err + } + return c.AsMultiaddr(), nil case ip.To16() != nil: - ip6, err := ma.NewComponent("ip6", ip.String()) + ip6C, err := ma.NewComponent("ip6", ip.String()) if err != nil { return nil, err } + ip6 := ip6C.AsMultiaddr() if zone == "" { return ip6, nil } else { @@ -130,21 +137,18 @@ func FromIP(ip net.IP) (ma.Multiaddr, error) { // ToIP converts a Multiaddr to a net.IP when possible func ToIP(addr ma.Multiaddr) (net.IP, error) { var ip net.IP - ma.ForEach(addr, func(c ma.Component) bool { + for _, c := range addr { switch c.Protocol().Code { case ma.P_IP6ZONE: // we can't return these anyways. - return true + continue case ma.P_IP6, ma.P_IP4: ip = net.IP(c.RawValue()) - return false + return ip, nil } - return false - }) - if ip == nil { return nil, errNotIP } - return ip, nil + return nil, errNotIP } // DialArgs is a convenience function that returns network and address as @@ -167,7 +171,7 @@ func DialArgs(m ma.Multiaddr) (string, string, error) { return network, ip + ":" + port, nil } // Hostname is only true when network is one of the above. 
- panic("unreachable") + return "", "", errors.New("no hostname") // should be unreachable } switch network { @@ -198,48 +202,48 @@ func DialArgs(m ma.Multiaddr) (string, string, error) { // dialArgComponents extracts the raw pieces used in dialing a Multiaddr func dialArgComponents(m ma.Multiaddr) (zone, network, ip, port string, hostname bool, err error) { - ma.ForEach(m, func(c ma.Component) bool { + for _, c := range m { switch network { case "": switch c.Protocol().Code { case ma.P_IP6ZONE: if zone != "" { err = fmt.Errorf("%s has multiple zones", m) - return false + return } zone = c.Value() - return true + continue case ma.P_IP6: network = "ip6" ip = c.Value() - return true + continue case ma.P_IP4: if zone != "" { err = fmt.Errorf("%s has ip4 with zone", m) - return false + return } network = "ip4" ip = c.Value() - return true + continue case ma.P_DNS: network = "ip" hostname = true ip = c.Value() - return true + continue case ma.P_DNS4: network = "ip4" hostname = true ip = c.Value() - return true + continue case ma.P_DNS6: network = "ip6" hostname = true ip = c.Value() - return true + continue case ma.P_UNIX: network = "unix" ip = c.Value() - return false + return } case "ip": switch c.Protocol().Code { @@ -248,7 +252,7 @@ func dialArgComponents(m ma.Multiaddr) (zone, network, ip, port string, hostname case ma.P_TCP: network = "tcp" default: - return false + return } port = c.Value() case "ip4": @@ -258,7 +262,7 @@ func dialArgComponents(m ma.Multiaddr) (zone, network, ip, port string, hostname case ma.P_TCP: network = "tcp4" default: - return false + return } port = c.Value() case "ip6": @@ -268,13 +272,13 @@ func dialArgComponents(m ma.Multiaddr) (zone, network, ip, port string, hostname case ma.P_TCP: network = "tcp6" default: - return false + return } port = c.Value() } // Done. 
- return false - }) + return + } return } @@ -354,5 +358,9 @@ func parseUnixNetAddr(a net.Addr) (ma.Multiaddr, error) { path = "/" + path } - return ma.NewComponent("unix", path) + c, err := ma.NewComponent("unix", path) + if err != nil { + return nil, err + } + return c.AsMultiaddr(), nil } diff --git a/net/ip.go b/net/ip.go index def9321..e8acecb 100644 --- a/net/ip.go +++ b/net/ip.go @@ -64,7 +64,7 @@ func IsIPLoopback(m ma.Multiaddr) bool { return false } c, _ := ma.SplitFirst(m) - if c == nil { + if c.Empty() { return false } switch c.Protocol().Code { @@ -83,7 +83,7 @@ func IsIP6LinkLocal(m ma.Multiaddr) bool { return false } c, _ := ma.SplitFirst(m) - if c == nil || c.Protocol().Code != ma.P_IP6 { + if c.Empty() || c.Protocol().Code != ma.P_IP6 { return false } ip := net.IP(c.RawValue()) @@ -106,11 +106,11 @@ func IsIPUnspecified(m ma.Multiaddr) bool { // else return m func zoneless(m ma.Multiaddr) ma.Multiaddr { head, tail := ma.SplitFirst(m) - if head == nil { + if head.Empty() { return nil } if head.Protocol().Code == ma.P_IP6ZONE { - if tail == nil { + if tail.Empty() { return nil } tailhead, _ := ma.SplitFirst(tail) @@ -127,6 +127,6 @@ func zoneless(m ma.Multiaddr) ma.Multiaddr { // used for NAT64 Translation. See RFC 6052 func IsNAT64IPv4ConvertedIPv6Addr(addr ma.Multiaddr) bool { c, _ := ma.SplitFirst(addr) - return c != nil && c.Protocol().Code == ma.P_IP6 && + return !c.Empty() && c.Protocol().Code == ma.P_IP6 && inAddrRange(c.RawValue(), nat64) } diff --git a/net/resolve.go b/net/resolve.go index 44c2ef1..0b43d08 100644 --- a/net/resolve.go +++ b/net/resolve.go @@ -14,7 +14,7 @@ func ResolveUnspecifiedAddress(resolve ma.Multiaddr, ifaceAddrs []ma.Multiaddr) first, rest := ma.SplitFirst(resolve) // if first component (ip) is not unspecified, use it as is. 
- if !IsIPUnspecified(first) { + if !IsIPUnspecified(first.AsMultiaddr()) { return []ma.Multiaddr{resolve}, nil } diff --git a/net/resolve_test.go b/net/resolve_test.go index e4af820..a1405ae 100644 --- a/net/resolve_test.go +++ b/net/resolve_test.go @@ -36,7 +36,10 @@ func TestResolvingAddrs(t *testing.T) { actual, err := ResolveUnspecifiedAddresses(unspec, iface) require.NoError(t, err) - require.Equal(t, actual, spec) + require.Equal(t, len(actual), len(spec)) + for i := range actual { + require.True(t, actual[i].Equal(spec[i])) + } ip4u := []ma.Multiaddr{newMultiaddr(t, "/ip4/0.0.0.0")} ip4i := []ma.Multiaddr{newMultiaddr(t, "/ip4/1.2.3.4")} diff --git a/testdata/fuzz/FuzzNewMultiaddrBytes/19bd9fb2604afd6f b/testdata/fuzz/FuzzNewMultiaddrBytes/19bd9fb2604afd6f new file mode 100644 index 0000000..1558ec7 --- /dev/null +++ b/testdata/fuzz/FuzzNewMultiaddrBytes/19bd9fb2604afd6f @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("\x90\x03\x03/00!00") diff --git a/util.go b/util.go index 16a347c..71fb9bf 100644 --- a/util.go +++ b/util.go @@ -1,54 +1,48 @@ package multiaddr -import "fmt" +import ( + "fmt" +) // Split returns the sub-address portions of a multiaddr. -func Split(m Multiaddr) []Multiaddr { - if _, ok := m.(*Component); ok { - return []Multiaddr{m} - } - var addrs []Multiaddr - ForEach(m, func(c Component) bool { - addrs = append(addrs, &c) - return true - }) - return addrs +func Split(m Multiaddr) []Component { + return m } -// Join returns a combination of addresses. -func Join(ms ...Multiaddr) Multiaddr { - switch len(ms) { - case 0: - // empty multiaddr, unfortunately, we have callers that rely on - // this contract. 
- return &multiaddr{} - case 1: - return ms[0] +func JoinComponents(cs ...Component) Multiaddr { + if len(cs) == 0 { + return nil } - - length := 0 - for _, m := range ms { - if m == nil { - continue + out := make([]Component, 0, len(cs)) + for _, c := range cs { + if !c.Empty() { + out = append(out, c) } - length += len(m.Bytes()) } + return out +} - bidx := 0 - b := make([]byte, length) - if length == 0 { +// Join returns a combination of addresses. +// Note: This copies all the components from the input Multiaddrs. Depending on +// your use case, you may prefer to use `append(leftMA, rightMA...)` instead. +func Join(ms ...Multiaddr) Multiaddr { + size := 0 + for _, m := range ms { + size += len(m) + } + if size == 0 { return nil } - for _, mb := range ms { - if mb == nil { - continue + + out := make([]Component, 0, size) + for _, m := range ms { + for _, c := range m { + if !c.Empty() { + out = append(out, c) + } } - bidx += copy(b[bidx:], mb.Bytes()) - } - if length == 0 { - return nil } - return &multiaddr{bytes: b} + return out } // Cast re-casts a byte slice as a multiaddr. will panic if it fails to parse. @@ -70,135 +64,64 @@ func StringCast(s string) Multiaddr { } // SplitFirst returns the first component and the rest of the multiaddr. -func SplitFirst(m Multiaddr) (*Component, Multiaddr) { - if m == nil { - return nil, nil - } - // Shortcut if we already have a component - if c, ok := m.(*Component); ok { - return c, nil - } - - b := m.Bytes() - if len(b) == 0 { - return nil, nil - } - n, c, err := readComponent(b) - if err != nil { - panic(err) +func SplitFirst(m Multiaddr) (Component, Multiaddr) { + if m.Empty() { + return Component{}, nil } - if len(b) == n { - return &c, nil + if len(m) == 1 { + return m[0], nil } - return &c, &multiaddr{b[n:]} + return m[0], m[1:] } // SplitLast returns the rest of the multiaddr and the last component. 
-func SplitLast(m Multiaddr) (Multiaddr, *Component) { - if m == nil { - return nil, nil - } - - // Shortcut if we already have a component - if c, ok := m.(*Component); ok { - return nil, c +func SplitLast(m Multiaddr) (Multiaddr, Component) { + if m.Empty() { + return nil, Component{} } - - b := m.Bytes() - if len(b) == 0 { - return nil, nil - } - - var ( - c Component - err error - offset int - ) - for { - var n int - n, c, err = readComponent(b[offset:]) - if err != nil { - panic(err) - } - if len(b) == n+offset { - // Reached end - if offset == 0 { - // Only one component - return nil, &c - } - return &multiaddr{b[:offset]}, &c - } - offset += n + if len(m) == 1 { + // We want to explicitly return a nil slice if the prefix is now empty. + return nil, m[0] } + return m[:len(m)-1], m[len(m)-1] } // SplitFunc splits the multiaddr when the callback first returns true. The // component on which the callback first returns will be included in the // *second* multiaddr. func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { - if m == nil { - return nil, nil - } - // Shortcut if we already have a component - if c, ok := m.(*Component); ok { - if cb(*c) { - return nil, m - } - return m, nil - } - b := m.Bytes() - if len(b) == 0 { + if m.Empty() { return nil, nil } - var ( - c Component - err error - offset int - ) - for offset < len(b) { - var n int - n, c, err = readComponent(b[offset:]) - if err != nil { - panic(err) - } + + idx := len(m) + for i, c := range m { if cb(c) { + idx = i break } - offset += n } - switch offset { - case 0: - return nil, m - case len(b): - return m, nil - default: - return &multiaddr{b[:offset]}, &multiaddr{b[offset:]} + pre, post := m[:idx], m[idx:] + if pre.Empty() { + pre = nil + } + if post.Empty() { + post = nil } + return pre, post } // ForEach walks over the multiaddr, component by component. // -// This function iterates over components *by value* to avoid allocating. +// This function iterates over components. 
// Return true to continue iteration, false to stop. +// +// Prefer to use a standard for range loop instead +// e.g. `for _, c := range m { ... }` func ForEach(m Multiaddr, cb func(c Component) bool) { - if m == nil { - return - } - // Shortcut if we already have a component - if c, ok := m.(*Component); ok { - cb(*c) - return - } - - b := m.Bytes() - for len(b) > 0 { - n, c, err := readComponent(b) - if err != nil { - panic(err) - } + for _, c := range m { if !cb(c) { return } - b = b[n:] } } diff --git a/util_test.go b/util_test.go index 4976d7d..3494486 100644 --- a/util_test.go +++ b/util_test.go @@ -21,8 +21,8 @@ func TestSplitFirstLast(t *testing.T) { head, tail := SplitFirst(addr) rest, last := SplitLast(addr) if len(x) == 0 { - if head != nil { - t.Error("expected head to be nil") + if !head.Empty() { + t.Error("expected head to be empty") } if tail != nil { t.Error("expected tail to be nil") @@ -30,15 +30,15 @@ func TestSplitFirstLast(t *testing.T) { if rest != nil { t.Error("expected rest to be nil") } - if last != nil { - t.Error("expected last to be nil") + if !last.Empty() { + t.Error("expected last to be empty") } continue } - if !head.Equal(StringCast(x[0])) { + if !head.AsMultiaddr().Equal(StringCast(x[0])) { t.Errorf("expected %s to be %s", head, x[0]) } - if !last.Equal(StringCast(x[len(x)-1])) { + if !last.AsMultiaddr().Equal(StringCast(x[len(x)-1])) { t.Errorf("expected %s to be %s", head, x[len(x)-1]) } if len(x) == 1 { @@ -65,33 +65,33 @@ func TestSplitFirstLast(t *testing.T) { t.Fatal(err) } - ci, m := SplitFirst(c) + ci, m := SplitFirst(c.AsMultiaddr()) if !ci.Equal(c) || m != nil { t.Error("split first on component failed") } - m, ci = SplitLast(c) + m, ci = SplitLast(c.AsMultiaddr()) if !ci.Equal(c) || m != nil { t.Error("split last on component failed") } - cis := Split(c) + cis := Split(c.AsMultiaddr()) if len(cis) != 1 || !cis[0].Equal(c) { t.Error("split on component failed") } - m1, m2 := SplitFunc(c, func(c Component) bool { + m1, 
m2 := SplitFunc(c.AsMultiaddr(), func(c Component) bool { return true }) - if m1 != nil || !m2.Equal(c) { + if m1 != nil || !m2.Equal(c.AsMultiaddr()) { t.Error("split func(true) on component failed") } - m1, m2 = SplitFunc(c, func(c Component) bool { + m1, m2 = SplitFunc(c.AsMultiaddr(), func(c Component) bool { return false }) - if !m1.Equal(c) || m2 != nil { + if !m1.Equal(c.AsMultiaddr()) || m2 != nil { t.Error("split func(false) on component failed") } i := 0 - ForEach(c, func(ci Component) bool { + ForEach(c.AsMultiaddr(), func(ci Component) bool { if i != 0 { t.Error("expected exactly one component") } @@ -119,10 +119,10 @@ func TestSplitFunc(t *testing.T) { for i, cs := range x { target := StringCast(cs) a, b := SplitFunc(addr, func(c Component) bool { - return c.Equal(target) + return c.AsMultiaddr().Equal(target) }) if i == 0 { - if a != nil { + if !a.Empty() { t.Error("expected nil addr") } } else { @@ -135,7 +135,7 @@ func TestSplitFunc(t *testing.T) { } } a, b := SplitFunc(addr, func(_ Component) bool { return false }) - if !a.Equal(addr) || b != nil { + if !a.Equal(addr) || !b.Empty() { t.Error("should not have split") } } diff --git a/v015-MIGRATION.md b/v015-MIGRATION.md new file mode 100644 index 0000000..ab75a39 --- /dev/null +++ b/v015-MIGRATION.md @@ -0,0 +1,15 @@ +## Breaking changes in the large refactor of go-multiaddr v0.15 + +- There is no `Multiaddr` interface type. +- `Multiaddr` is now a concrete type, not an interface. +- Empty Multiaddrs should be checked with `.Empty()`, not `== nil`. This is similar to how slices should be checked with `len(s) == 0` rather than `s == nil`. +- Components do not implement `Multiaddr` as there is no `Multiaddr` to implement. +- `Multiaddr` can no longer be a key in a map. If you want unique Multiaddrs, use `Multiaddr.String()` as the key; otherwise, you can use the pointer value `*Multiaddr`. + +## Callouts + +- `Multiaddr.Bytes()` is an `O(n)` operation for n Components, as opposed to an `O(1)` operation. 
+ +## Migration tips for v0.15 + +- To encapsulate a Component onto a Multiaddr, use `m.EncapsulateC(c)` instead of the old form, `m.Encapsulate(c)`. `Encapsulate` now only accepts a `Multiaddr`; `EncapsulateC` accepts a `Component`.