Skip to content

Commit 46805b0

Browse files
authored
refactor: Follows up on #261 (#264)
* add comment * decorate err * rename offset field * more validation checks * Add benchmark for component validation * Use *Protocol in Component
1 parent 1ef63b5 commit 46805b0

File tree

5 files changed

+135
-18
lines changed

5 files changed

+135
-18
lines changed

codec.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func stringToBytes(s string) ([]byte, error) {
5555
}
5656
err = p.Transcoder.ValidateBytes(a)
5757
if err != nil {
58-
return nil, err
58+
return nil, fmt.Errorf("failed to validate multiaddr %q: invalid value %q for protocol %s: %w", s, sp[0], p.Name, err)
5959
}
6060
if p.Size < 0 { // varint size.
6161
_, _ = b.Write(varint.ToUvarint(uint64(len(a))))
@@ -79,12 +79,16 @@ func readComponent(b []byte) (int, Component, error) {
7979
if p.Code == 0 {
8080
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
8181
}
82+
pPtr := protocolPtrByCode[code]
83+
if pPtr == nil {
84+
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
85+
}
8286

8387
if p.Size == 0 {
8488
c, err := validateComponent(Component{
85-
bytes: string(b[:offset]),
86-
offset: offset,
87-
protocol: p,
89+
bytes: string(b[:offset]),
90+
valueStartIdx: offset,
91+
protocol: pPtr,
8892
})
8993

9094
return offset, c, err
@@ -109,9 +113,9 @@ func readComponent(b []byte) (int, Component, error) {
109113
}
110114

111115
c, err := validateComponent(Component{
112-
bytes: string(b[:offset+size]),
113-
protocol: p,
114-
offset: offset,
116+
bytes: string(b[:offset+size]),
117+
protocol: pPtr,
118+
valueStartIdx: offset,
115119
})
116120

117121
return offset + size, c, err

component.go

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package multiaddr
22

33
import (
4+
"bytes"
45
"encoding/binary"
56
"encoding/json"
67
"fmt"
@@ -11,9 +12,11 @@ import (
1112

1213
// Component is a single multiaddr Component.
1314
type Component struct {
14-
bytes string // Uses the string type to ensure immutability.
15-
protocol Protocol
16-
offset int
15+
// bytes is the raw bytes of the component. It includes the protocol code as
16+
// varint, possibly the size of the value, and the value.
17+
bytes string // string for immutability.
18+
protocol *Protocol
19+
valueStartIdx int // Index of the first byte of the Component's value in the bytes array
1720
}
1821

1922
func (c Component) AsMultiaddr() Multiaddr {
@@ -107,22 +110,31 @@ func (c Component) Compare(o Component) int {
107110
}
108111

109112
func (c Component) Protocols() []Protocol {
110-
return []Protocol{c.protocol}
113+
if c.protocol == nil {
114+
return nil
115+
}
116+
return []Protocol{*c.protocol}
111117
}
112118

113119
func (c Component) ValueForProtocol(code int) (string, error) {
120+
if c.protocol == nil {
121+
return "", fmt.Errorf("component has nil protocol")
122+
}
114123
if c.protocol.Code != code {
115124
return "", ErrProtocolNotFound
116125
}
117126
return c.Value(), nil
118127
}
119128

120129
func (c Component) Protocol() Protocol {
121-
return c.protocol
130+
if c.protocol == nil {
131+
return Protocol{}
132+
}
133+
return *c.protocol
122134
}
123135

124136
func (c Component) RawValue() []byte {
125-
return []byte(c.bytes[c.offset:])
137+
return []byte(c.bytes[c.valueStartIdx:])
126138
}
127139

128140
func (c Component) Value() string {
@@ -135,10 +147,13 @@ func (c Component) Value() string {
135147
}
136148

137149
func (c Component) valueAndErr() (string, error) {
150+
if c.protocol == nil {
151+
return "", fmt.Errorf("component has nil protocol")
152+
}
138153
if c.protocol.Transcoder == nil {
139154
return "", nil
140155
}
141-
value, err := c.protocol.Transcoder.BytesToString([]byte(c.bytes[c.offset:]))
156+
value, err := c.protocol.Transcoder.BytesToString([]byte(c.bytes[c.valueStartIdx:]))
142157
if err != nil {
143158
return "", err
144159
}
@@ -154,6 +169,9 @@ func (c Component) String() string {
154169
// writeTo is an efficient, private function for string-formatting a multiaddr.
155170
// Trust me, we tend to allocate a lot when doing this.
156171
func (c Component) writeTo(b *strings.Builder) {
172+
if c.protocol == nil {
173+
return
174+
}
157175
b.WriteByte('/')
158176
b.WriteString(c.protocol.Name)
159177
value := c.Value()
@@ -185,6 +203,11 @@ func NewComponent(protocol, value string) (Component, error) {
185203
}
186204

187205
func newComponent(protocol Protocol, bvalue []byte) (Component, error) {
206+
protocolPtr := protocolPtrByCode[protocol.Code]
207+
if protocolPtr == nil {
208+
protocolPtr = &protocol
209+
}
210+
188211
size := len(bvalue)
189212
size += len(protocol.VCode)
190213
if protocol.Size < 0 {
@@ -205,23 +228,63 @@ func newComponent(protocol Protocol, bvalue []byte) (Component, error) {
205228

206229
return validateComponent(
207230
Component{
208-
bytes: string(maddr),
209-
protocol: protocol,
210-
offset: offset,
231+
bytes: string(maddr),
232+
protocol: protocolPtr,
233+
valueStartIdx: offset,
211234
})
212235
}
213236

214237
// validateComponent MUST be called after creating a non-zero Component.
215238
// It ensures that we will be able to call all methods on Component without
216239
// error.
217240
func validateComponent(c Component) (Component, error) {
241+
if c.protocol == nil {
242+
return Component{}, fmt.Errorf("component is missing its protocol")
243+
}
244+
if c.valueStartIdx > len(c.bytes) {
245+
return Component{}, fmt.Errorf("component valueStartIdx is greater than the length of the component's bytes")
246+
}
247+
248+
if len(c.protocol.VCode) == 0 {
249+
return Component{}, fmt.Errorf("Component is missing its protocol's VCode field")
250+
}
251+
if len(c.bytes) < len(c.protocol.VCode) {
252+
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.bytes), len(c.protocol.VCode))
253+
}
254+
if !bytes.Equal([]byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode) {
255+
return Component{}, fmt.Errorf("component's VCode field is invalid: %v != %v", []byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode)
256+
}
257+
if c.protocol.Size < 0 {
258+
size, n, err := ReadVarintCode([]byte(c.bytes[len(c.protocol.VCode):]))
259+
if err != nil {
260+
return Component{}, err
261+
}
262+
if size != len(c.bytes[c.valueStartIdx:]) {
263+
return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
264+
}
265+
266+
if len(c.protocol.VCode)+n+size != len(c.bytes) {
267+
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+n+size, len(c.bytes))
268+
}
269+
} else {
270+
// Fixed size value
271+
size := c.protocol.Size / 8
272+
if size != len(c.bytes[c.valueStartIdx:]) {
273+
return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
274+
}
275+
276+
if len(c.protocol.VCode)+size != len(c.bytes) {
277+
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+size, len(c.bytes))
278+
}
279+
}
280+
218281
_, err := c.valueAndErr()
219282
if err != nil {
220283
return Component{}, err
221284

222285
}
223286
if c.protocol.Transcoder != nil {
224-
err = c.protocol.Transcoder.ValidateBytes([]byte(c.bytes[c.offset:]))
287+
err = c.protocol.Transcoder.ValidateBytes([]byte(c.bytes[c.valueStartIdx:]))
225288
if err != nil {
226289
return Component{}, err
227290
}

matest/matest.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ type MultiaddrMatcher struct {
7272
multiaddr.Multiaddr
7373
}
7474

75+
// Implements the Matcher interface for gomock.Matcher
76+
// Let's us use this struct in gomock tests. Example:
77+
// Expect(mock.Method(gomock.Any(), multiaddrMatcher).Return(nil)
7578
func (m MultiaddrMatcher) Matches(x interface{}) bool {
7679
if m2, ok := x.(multiaddr.Multiaddr); ok {
7780
return m.Equal(m2)

multiaddr_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,3 +1132,43 @@ func FuzzSplitRoundtrip(f *testing.F) {
11321132
}
11331133
})
11341134
}
1135+
1136+
func BenchmarkComponentValidation(b *testing.B) {
1137+
comp, err := NewComponent("ip4", "127.0.0.1")
1138+
if err != nil {
1139+
b.Fatal(err)
1140+
}
1141+
b.ReportAllocs()
1142+
for i := 0; i < b.N; i++ {
1143+
_, err := validateComponent(comp)
1144+
if err != nil {
1145+
b.Fatal(err)
1146+
}
1147+
}
1148+
}
1149+
1150+
func FuzzComponents(f *testing.F) {
1151+
for _, v := range good {
1152+
m := StringCast(v)
1153+
for _, c := range m {
1154+
f.Add(c.Bytes())
1155+
}
1156+
}
1157+
f.Fuzz(func(t *testing.T, compBytes []byte) {
1158+
n, c, err := readComponent(compBytes)
1159+
if err != nil {
1160+
t.Skip()
1161+
}
1162+
if c.protocol == nil {
1163+
t.Fatal("component has nil protocol")
1164+
}
1165+
if c.protocol.Code == 0 {
1166+
t.Fatal("component has nil protocol code")
1167+
}
1168+
if !bytes.Equal(c.Bytes(), compBytes[:n]) {
1169+
t.Logf("component bytes: %v", c.Bytes())
1170+
t.Logf("original bytes: %v", compBytes[:n])
1171+
t.Fatal("component bytes are not equal to the original bytes")
1172+
}
1173+
})
1174+
}

protocol.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ type Protocol struct {
4747
var protocolsByName = map[string]Protocol{}
4848
var protocolsByCode = map[int]Protocol{}
4949

50+
// Keep a map of pointers so that we can reuse the same pointer for the same protocol.
51+
var protocolPtrByCode = map[int]*Protocol{}
52+
5053
// Protocols is the list of multiaddr protocols supported by this module.
5154
var Protocols = []Protocol{}
5255

@@ -65,10 +68,14 @@ func AddProtocol(p Protocol) error {
6568
if p.Path && p.Size >= 0 {
6669
return fmt.Errorf("path protocols must have variable-length sizes")
6770
}
71+
if len(p.VCode) == 0 {
72+
return fmt.Errorf("protocol code %d is missing its VCode field", p.Code)
73+
}
6874

6975
Protocols = append(Protocols, p)
7076
protocolsByName[p.Name] = p
7177
protocolsByCode[p.Code] = p
78+
protocolPtrByCode[p.Code] = &p
7279
return nil
7380
}
7481

0 commit comments

Comments
 (0)