Skip to content

Commit 4cb9f94

Browse files
committed
Add CheckAddrWithOptions method with IPv4/IPv6 and socket mark support
1 parent 716b728 commit 4cb9f94

File tree

7 files changed

+288
-12
lines changed

7 files changed

+288
-12
lines changed

checker_linux.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,25 @@ func (c *Checker) CheckAddr(addr string, timeout time.Duration) (err error) {
141141

142142
// CheckAddrZeroLinger is like CheckAddr with an extra parameter indicating whether to enable zero linger.
143143
func (c *Checker) CheckAddrZeroLinger(addr string, timeout time.Duration, zeroLinger bool) error {
144+
opts := DefaultOptions().WithTimeout(timeout).WithZeroLinger(zeroLinger)
145+
return c.CheckAddrWithOptions(addr, opts)
146+
}
147+
148+
// CheckAddrWithOptions performs a TCP check with given address and options.
149+
// A successful check will result in nil error.
150+
// ErrTimeout is returned if timeout.
151+
// Note: timeout includes domain resolving.
152+
func (c *Checker) CheckAddrWithOptions(addr string, opts Options) error {
144153
// Set deadline
145-
deadline := time.Now().Add(timeout)
154+
deadline := time.Now().Add(opts.Timeout)
146155

147-
// Parse address
148-
rAddr, family, err := parseSockAddr(addr)
156+
// Parse address with specified network
157+
rAddr, family, err := parseSockAddrWithNetwork(addr, opts.Network)
149158
if err != nil {
150159
return err
151160
}
152161
// Create socket with options set
153-
fd, err := createSocketZeroLinger(family, zeroLinger)
162+
fd, err := createSocketWithOptions(family, opts.ZeroLinger, opts.Mark)
154163
if err != nil {
155164
return err
156165
}

checker_nonlinux.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,17 @@ func (c *Checker) CheckAddr(addr string, timeout time.Duration) error {
4242

4343
// CheckAddrZeroLinger is CheckerAddr with a zeroLinger parameter.
4444
func (c *Checker) CheckAddrZeroLinger(addr string, timeout time.Duration, zeroLinger bool) error {
45-
conn, err := net.DialTimeout("tcp", addr, timeout)
45+
opts := DefaultOptions().WithTimeout(timeout).WithZeroLinger(zeroLinger)
46+
return c.CheckAddrWithOptions(addr, opts)
47+
}
48+
49+
// CheckAddrWithOptions performs a TCP check with given address and options.
50+
// NOTE: zeroLinger is ignored on non-POSIX operating systems because
51+
// net.TCPConn.SetLinger is only implemented in src/net/sockopt_posix.go.
52+
func (c *Checker) CheckAddrWithOptions(addr string, opts Options) error {
53+
conn, err := net.DialTimeout(opts.Network, addr, opts.Timeout)
4654
if conn != nil {
47-
if zeroLinger {
55+
if opts.ZeroLinger {
4856
// Simply ignore the error since this is a fake implementation.
4957
_ = conn.(*net.TCPConn).SetLinger(0)
5058
}

options.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package tcp
2+
3+
import (
4+
"time"
5+
)
6+
7+
// Options contains configuration for TCP connectivity checks
8+
type Options struct {
9+
// Timeout specifies the maximum duration for the check operation
10+
Timeout time.Duration
11+
12+
// Network specifies the network type for address resolution and connection
13+
// Supported values:
14+
// "tcp" - Try IPv4 first, then IPv6 (default behavior)
15+
// "tcp4" - IPv4 only
16+
// "tcp6" - IPv6 only
17+
Network string
18+
19+
// ZeroLinger indicates whether to set SO_LINGER with zero timeout
20+
// This forces the connection to be reset immediately when closed
21+
ZeroLinger bool
22+
23+
// Mark sets the SO_MARK socket option (Linux only)
24+
// This is useful for traffic marking and routing policies
25+
// Value of 0 means no mark is set
26+
Mark int
27+
}
28+
29+
// DefaultOptions returns Options with default values
30+
func DefaultOptions() Options {
31+
return Options{
32+
Timeout: time.Second * 3,
33+
Network: "tcp",
34+
ZeroLinger: true,
35+
Mark: 0, // No mark by default
36+
}
37+
}
38+
39+
// WithTimeout sets the timeout for the operation
40+
func (o Options) WithTimeout(timeout time.Duration) Options {
41+
o.Timeout = timeout
42+
return o
43+
}
44+
45+
// WithNetwork sets the network type (tcp, tcp4, tcp6)
46+
func (o Options) WithNetwork(network string) Options {
47+
o.Network = network
48+
return o
49+
}
50+
51+
// WithZeroLinger sets the zero linger option
52+
func (o Options) WithZeroLinger(zeroLinger bool) Options {
53+
o.ZeroLinger = zeroLinger
54+
return o
55+
}
56+
57+
// WithMark sets the SO_MARK socket option (Linux only)
58+
// This is useful for traffic marking and routing policies
59+
func (o Options) WithMark(mark int) Options {
60+
o.Mark = mark
61+
return o
62+
}

options_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package tcp
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
)
8+
9+
// setupTestChecker creates a checker with CheckingLoop running and waits until ready
10+
func setupTestChecker(t *testing.T) (*Checker, context.CancelFunc) {
11+
checker := NewChecker()
12+
ctx, cancel := context.WithCancel(context.Background())
13+
go func() {
14+
_ = checker.CheckingLoop(ctx)
15+
}()
16+
<-checker.WaitReady()
17+
return checker, cancel
18+
}
19+
20+
// testWithChecker runs a test function with a prepared checker and test server
21+
func testWithChecker(t *testing.T, testFunc func(*testing.T, *Checker, string)) {
22+
checker, cancel := setupTestChecker(t)
23+
defer cancel()
24+
25+
testAddr, stopServer := StartTestServer()
26+
defer stopServer()
27+
28+
testFunc(t, checker, testAddr)
29+
}
30+
31+
func TestOptions(t *testing.T) {
32+
// Test default options
33+
opts := DefaultOptions()
34+
if opts.Network != "tcp" {
35+
t.Errorf("expected default network to be 'tcp', got %s", opts.Network)
36+
}
37+
if opts.Timeout != 3*time.Second {
38+
t.Errorf("expected default timeout to be 3s, got %v", opts.Timeout)
39+
}
40+
if !opts.ZeroLinger {
41+
t.Error("expected default ZeroLinger to be true")
42+
}
43+
if opts.Mark != 0 {
44+
t.Errorf("expected default Mark to be 0, got %d", opts.Mark)
45+
}
46+
47+
// Test fluent interface
48+
customOpts := DefaultOptions().
49+
WithTimeout(5 * time.Second).
50+
WithNetwork("tcp6").
51+
WithZeroLinger(false).
52+
WithMark(100)
53+
54+
if customOpts.Network != "tcp6" {
55+
t.Errorf("expected network to be 'tcp6', got %s", customOpts.Network)
56+
}
57+
if customOpts.Timeout != 5*time.Second {
58+
t.Errorf("expected timeout to be 5s, got %v", customOpts.Timeout)
59+
}
60+
if customOpts.ZeroLinger {
61+
t.Error("expected ZeroLinger to be false")
62+
}
63+
if customOpts.Mark != 100 {
64+
t.Errorf("expected Mark to be 100, got %d", customOpts.Mark)
65+
}
66+
}
67+
68+
func TestCheckerWithOptions(t *testing.T) {
69+
testWithChecker(t, func(t *testing.T, checker *Checker, testAddr string) {
70+
// Test basic functionality with default options
71+
opts := DefaultOptions().WithTimeout(2 * time.Second)
72+
err := checker.CheckAddrWithOptions(testAddr, opts)
73+
if err != nil {
74+
t.Errorf("Connection to test server failed: %v", err)
75+
}
76+
77+
// Test IPv4 specific
78+
opts4 := DefaultOptions().WithTimeout(2 * time.Second).WithNetwork("tcp4")
79+
err = checker.CheckAddrWithOptions(testAddr, opts4)
80+
if err != nil {
81+
t.Errorf("IPv4 connection to test server failed: %v", err)
82+
}
83+
})
84+
}
85+
86+
func TestIPv6Support(t *testing.T) {
87+
checker, cancel := setupTestChecker(t)
88+
defer cancel()
89+
90+
// Try to start IPv6 server
91+
testAddr6, stopServer6, err := StartTestServerIPv6()
92+
if err != nil {
93+
t.Skipf("Skipping IPv6 test: %v", err)
94+
return
95+
}
96+
defer stopServer6()
97+
98+
// Test IPv6 connection with tcp6 network
99+
opts6 := DefaultOptions().WithTimeout(2 * time.Second).WithNetwork("tcp6")
100+
err = checker.CheckAddrWithOptions(testAddr6, opts6)
101+
if err != nil {
102+
t.Errorf("IPv6 connection to test server failed: %v", err)
103+
}
104+
105+
// Test IPv6 connection with tcp network (should also work)
106+
optsGeneral := DefaultOptions().WithTimeout(2 * time.Second).WithNetwork("tcp")
107+
err = checker.CheckAddrWithOptions(testAddr6, optsGeneral)
108+
if err != nil {
109+
t.Errorf("General TCP connection to IPv6 test server failed: %v", err)
110+
}
111+
112+
// Test that tcp4 fails on IPv6 address (should fail)
113+
opts4 := DefaultOptions().WithTimeout(1 * time.Second).WithNetwork("tcp4")
114+
err = checker.CheckAddrWithOptions(testAddr6, opts4)
115+
if err == nil {
116+
t.Error("Expected tcp4 connection to IPv6 address to fail, but it succeeded")
117+
} else {
118+
t.Logf("tcp4 connection to IPv6 address failed as expected: %v", err)
119+
}
120+
}
121+
122+
func TestMarkOption(t *testing.T) {
123+
testWithChecker(t, func(t *testing.T, checker *Checker, testAddr string) {
124+
// Test with mark set (only meaningful on Linux)
125+
opts := DefaultOptions().WithTimeout(2 * time.Second).WithMark(42)
126+
err := checker.CheckAddrWithOptions(testAddr, opts)
127+
if err != nil {
128+
t.Errorf("Connection with mark to test server failed: %v", err)
129+
}
130+
})
131+
}
132+
133+
func TestBackwardCompatibility(t *testing.T) {
134+
testWithChecker(t, func(t *testing.T, checker *Checker, testAddr string) {
135+
// Old API should still work
136+
err := checker.CheckAddr(testAddr, 2*time.Second)
137+
if err != nil {
138+
t.Errorf("Old API connection to test server failed: %v", err)
139+
}
140+
141+
// CheckAddrZeroLinger should use new implementation internally
142+
err = checker.CheckAddrZeroLinger(testAddr, 2*time.Second, true)
143+
if err != nil {
144+
t.Errorf("CheckAddrZeroLinger connection to test server failed: %v", err)
145+
}
146+
})
147+
}

socket_linux.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,28 @@ import (
1212

1313
const maxEpollEvents = 32
1414

15-
// createSocket creates a socket with necessary options set.
16-
func createSocketZeroLinger(family int, zeroLinger bool) (fd int, err error) {
15+
// createSocketWithOptions creates a socket with specified options
16+
func createSocketWithOptions(family int, zeroLinger bool, mark int) (fd int, err error) {
1717
// Create socket
1818
fd, err = _createNonBlockingSocket(family)
19-
if err == nil {
20-
if zeroLinger {
21-
err = _setZeroLinger(fd)
19+
if err != nil {
20+
return
21+
}
22+
23+
if zeroLinger {
24+
if err = _setZeroLinger(fd); err != nil {
25+
_ = unix.Close(fd)
26+
return
2227
}
2328
}
29+
30+
if mark != 0 {
31+
if err = _setMark(fd, mark); err != nil {
32+
_ = unix.Close(fd)
33+
return
34+
}
35+
}
36+
2437
return
2538
}
2639

@@ -62,6 +75,11 @@ func _setZeroLinger(fd int) error {
6275
return unix.SetsockoptLinger(fd, unix.SOL_SOCKET, unix.SO_LINGER, &zeroLinger)
6376
}
6477

78+
// setMark sets SO_MARK for given fd (Linux only)
79+
func _setMark(fd int, mark int) error {
80+
return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, mark)
81+
}
82+
6583
func createPoller() (fd int, err error) {
6684
fd, err = unix.EpollCreate1(unix.EPOLL_CLOEXEC)
6785
if err != nil {

socket_unix.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@ import (
1010

1111
// parseSockAddr resolves given addr to unix.Sockaddr
1212
func parseSockAddr(addr string) (sAddr unix.Sockaddr, family int, err error) {
13-
tAddr, err := net.ResolveTCPAddr("tcp", addr)
13+
return parseSockAddrWithNetwork(addr, "tcp")
14+
}
15+
16+
// parseSockAddrWithNetwork resolves given addr to unix.Sockaddr with specified network
17+
func parseSockAddrWithNetwork(addr, network string) (sAddr unix.Sockaddr, family int, err error) {
18+
tAddr, err := net.ResolveTCPAddr(network, addr)
1419
if err != nil {
1520
return
1621
}

test_server_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tcp
22

33
import (
44
"context"
5+
"net"
56
"net/http"
67
"net/http/httptest"
78
)
@@ -12,3 +13,29 @@ func StartTestServer() (string, context.CancelFunc) {
1213
addr := ts.Listener.Addr().String()
1314
return addr, ts.Close
1415
}
16+
17+
// StartTestServerIPv6 starts a test HTTP server on IPv6 loopback and returns its address and a cancel function
18+
func StartTestServerIPv6() (string, context.CancelFunc, error) {
19+
// Create a listener on IPv6 loopback address
20+
listener, err := net.Listen("tcp6", "[::1]:0")
21+
if err != nil {
22+
return "", nil, err
23+
}
24+
25+
// Create HTTP server with the IPv6 listener
26+
server := &http.Server{
27+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
28+
}
29+
30+
// Start server in background
31+
go func() {
32+
_ = server.Serve(listener)
33+
}()
34+
35+
addr := listener.Addr().String()
36+
cancel := func() {
37+
_ = server.Close()
38+
}
39+
40+
return addr, cancel, nil
41+
}

0 commit comments

Comments
 (0)