Skip to content

Commit d8b9b2a

Browse files
committed
feat(natpmp): rpc error contain all failed attempt messages
1 parent c826707 commit d8b9b2a

File tree

3 files changed

+98
-8
lines changed

3 files changed

+98
-8
lines changed

internal/natpmp/portmapping_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ func Test_Client_AddPortMapping(t *testing.T) {
4949
initialConnectionDuration: time.Millisecond,
5050
exchanges: []udpExchange{{close: true}},
5151
err: ErrConnectionTimeout,
52-
errMessage: "executing remote procedure call: connection timeout: after 1ms",
52+
errMessage: "executing remote procedure call: connection timeout: failed attempts: " +
53+
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
5354
},
5455
"add_udp": {
5556
ctx: context.Background(),

internal/natpmp/rpc.go

+56-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"fmt"
77
"net"
88
"net/netip"
9+
"sort"
10+
"strings"
911
"time"
1012
)
1113

@@ -65,9 +67,8 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
6567
// Note it does not double if the source IP mismatches the gateway IP.
6668
connectionDuration := c.initialConnectionDuration
6769

68-
var totalRetryDuration time.Duration
69-
7070
var retryCount uint
71+
var failedAttempts []string
7172
for retryCount = 0; retryCount < c.maxRetries; retryCount++ {
7273
deadline := time.Now().Add(connectionDuration)
7374
err = connection.SetDeadline(deadline)
@@ -87,8 +88,8 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
8788
}
8889
var netErr net.Error
8990
if errors.As(err, &netErr) && netErr.Timeout() {
90-
totalRetryDuration += connectionDuration
9191
connectionDuration *= 2
92+
failedAttempts = append(failedAttempts, netErr.Error())
9293
continue
9394
}
9495
return nil, fmt.Errorf("reading from udp connection: %w", err)
@@ -98,6 +99,9 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
9899
// Upon receiving a response packet, the client MUST check the source IP
99100
// address, and silently discard the packet if the address is not the
100101
// address of the gateway to which the request was sent.
102+
failedAttempts = append(failedAttempts,
103+
fmt.Sprintf("received response from %s instead of gateway IP %s",
104+
receivedRemoteAddress.IP, gatewayAddress.IP))
101105
continue
102106
}
103107

@@ -106,8 +110,8 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
106110
}
107111

108112
if retryCount == c.maxRetries {
109-
return nil, fmt.Errorf("%w: after %s",
110-
ErrConnectionTimeout, totalRetryDuration)
113+
return nil, fmt.Errorf("%w: failed attempts: %s",
114+
ErrConnectionTimeout, dedupFailedAttempts(failedAttempts))
111115
}
112116

113117
// Opcodes between 0 and 127 are client requests. Opcodes from 128 to
@@ -121,3 +125,50 @@ func (c *Client) rpc(ctx context.Context, gateway netip.Addr,
121125

122126
return response, nil
123127
}
128+
129+
func dedupFailedAttempts(failedAttempts []string) (errorMessage string) {
130+
type data struct {
131+
message string
132+
indices []int
133+
}
134+
messageToData := make(map[string]data, len(failedAttempts))
135+
for i, message := range failedAttempts {
136+
metadata, ok := messageToData[message]
137+
if !ok {
138+
metadata.message = message
139+
}
140+
metadata.indices = append(metadata.indices, i)
141+
sort.Slice(metadata.indices, func(i, j int) bool {
142+
return metadata.indices[i] < metadata.indices[j]
143+
})
144+
messageToData[message] = metadata
145+
}
146+
147+
// Sort by first index
148+
dataSlice := make([]data, 0, len(messageToData))
149+
for _, metadata := range messageToData {
150+
dataSlice = append(dataSlice, metadata)
151+
}
152+
sort.Slice(dataSlice, func(i, j int) bool {
153+
return dataSlice[i].indices[0] < dataSlice[j].indices[0]
154+
})
155+
156+
dedupedFailedAttempts := make([]string, 0, len(dataSlice))
157+
for _, data := range dataSlice {
158+
newMessage := fmt.Sprintf("%s (%s)", data.message,
159+
indicesToTryString(data.indices))
160+
dedupedFailedAttempts = append(dedupedFailedAttempts, newMessage)
161+
}
162+
return strings.Join(dedupedFailedAttempts, "; ")
163+
}
164+
165+
func indicesToTryString(indices []int) string {
166+
if len(indices) == 1 {
167+
return fmt.Sprintf("try %d", indices[0]+1)
168+
}
169+
tries := make([]string, len(indices))
170+
for i, index := range indices {
171+
tries[i] = fmt.Sprintf("%d", index+1)
172+
}
173+
return fmt.Sprintf("tries %s", strings.Join(tries, ", "))
174+
}

internal/natpmp/rpc_test.go

+40-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ func Test_Client_rpc(t *testing.T) {
5353
exchanges: []udpExchange{
5454
{request: []byte{0, 1}, close: true},
5555
},
56-
err: ErrConnectionTimeout,
57-
errMessage: "connection timeout: after 1ms",
56+
err: ErrConnectionTimeout,
57+
errMessage: "connection timeout: failed attempts: " +
58+
"read udp 127.0.0.1:[1-9][0-9]{0,4}->127.0.0.1:[1-9][0-9]{0,4}: i/o timeout \\(try 1\\)",
5859
},
5960
"response_too_small": {
6061
ctx: context.Background(),
@@ -164,3 +165,40 @@ func Test_Client_rpc(t *testing.T) {
164165
})
165166
}
166167
}
168+
169+
func Test_dedupFailedAttempts(t *testing.T) {
170+
t.Parallel()
171+
172+
testCases := map[string]struct {
173+
failedAttempts []string
174+
expected string
175+
}{
176+
"empty": {},
177+
"single_attempt": {
178+
failedAttempts: []string{"test"},
179+
expected: "test (try 1)",
180+
},
181+
"multiple_same_attempts": {
182+
failedAttempts: []string{"test", "test", "test"},
183+
expected: "test (tries 1, 2, 3)",
184+
},
185+
"multiple_different_attempts": {
186+
failedAttempts: []string{"test1", "test2", "test3"},
187+
expected: "test1 (try 1); test2 (try 2); test3 (try 3)",
188+
},
189+
"soup_mix": {
190+
failedAttempts: []string{"test1", "test2", "test1", "test3", "test2"},
191+
expected: "test1 (tries 1, 3); test2 (tries 2, 5); test3 (try 4)",
192+
},
193+
}
194+
195+
for name, testCase := range testCases {
196+
testCase := testCase
197+
t.Run(name, func(t *testing.T) {
198+
t.Parallel()
199+
200+
actual := dedupFailedAttempts(testCase.failedAttempts)
201+
assert.Equal(t, testCase.expected, actual)
202+
})
203+
}
204+
}

0 commit comments

Comments
 (0)