Skip to content

Commit 4db73a1

Browse files
committed
Implement redirect logic in UDP proxy
1 parent 06a17f0 commit 4db73a1

File tree

5 files changed

+214
-28
lines changed

5 files changed

+214
-28
lines changed

client/iface/wgproxy/bind/proxy.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (p *ProxyBind) Pause() {
9292
p.pausedCond.L.Unlock()
9393
}
9494

95-
func (p *ProxyBind) RedirectTo(endpoint *net.UDPAddr) {
95+
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
9696
p.pausedCond.L.Lock()
9797
p.paused = false
9898

client/iface/wgproxy/ebpf/wrapper.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (p *ProxyWrapper) Pause() {
8181
p.pausedCond.L.Unlock()
8282
}
8383

84-
func (p *ProxyWrapper) RedirectTo(endpoint *net.UDPAddr) {
84+
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
8585
p.pausedCond.L.Lock()
8686
p.paused = false
8787

client/iface/wgproxy/proxy.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ type Proxy interface {
1111
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
1212
Work() // Work start or resume the proxy
1313
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
14+
/*
15+
RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
16+
and rewrite the src address to the endpoint address.
17+
With this logic can avoid the package loss from relayed connections.
18+
*/
19+
RedirectAs(endpoint *net.UDPAddr)
1420
CloseConn() error
15-
RedirectTo(endpoint *net.UDPAddr)
1621
}

client/iface/wgproxy/udp/proxy.go

+67-25
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,26 @@ import (
1818
type WGUDPProxy struct {
1919
localWGListenPort int
2020

21-
remoteConn net.Conn
22-
localConn net.Conn
23-
ctx context.Context
24-
cancel context.CancelFunc
25-
closeMu sync.Mutex
26-
closed bool
27-
28-
pausedMu sync.Mutex
29-
paused bool
30-
isStarted bool
21+
remoteConn net.Conn
22+
localConn net.Conn
23+
srcFakerConn *SrcFaker
24+
sendPkg func(data []byte) (int, error)
25+
ctx context.Context
26+
cancel context.CancelFunc
27+
closeMu sync.Mutex
28+
closed bool
29+
30+
paused bool
31+
pausedCond *sync.Cond
32+
isStarted bool
3133
}
3234

3335
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
3436
func NewWGUDPProxy(wgPort int) *WGUDPProxy {
3537
log.Debugf("Initializing new user space proxy with port %d", wgPort)
3638
p := &WGUDPProxy{
3739
localWGListenPort: wgPort,
40+
pausedCond: sync.NewCond(&sync.Mutex{}),
3841
}
3942
return p
4043
}
@@ -54,6 +57,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
5457

5558
p.ctx, p.cancel = context.WithCancel(ctx)
5659
p.localConn = localConn
60+
p.sendPkg = p.localConn.Write
5761
p.remoteConn = remoteConn
5862

5963
return err
@@ -73,15 +77,17 @@ func (p *WGUDPProxy) Work() {
7377
return
7478
}
7579

76-
p.pausedMu.Lock()
80+
p.pausedCond.L.Lock()
7781
p.paused = false
78-
p.pausedMu.Unlock()
82+
p.sendPkg = p.localConn.Write
7983

8084
if !p.isStarted {
8185
p.isStarted = true
8286
go p.proxyToRemote(p.ctx)
8387
go p.proxyToLocal(p.ctx)
8488
}
89+
p.pausedCond.L.Unlock()
90+
p.pausedCond.Signal()
8591
}
8692

8793
// Pause pauses the proxy from receiving data from the remote peer
@@ -90,13 +96,33 @@ func (p *WGUDPProxy) Pause() {
9096
return
9197
}
9298

93-
p.pausedMu.Lock()
99+
p.pausedCond.L.Lock()
94100
p.paused = true
95-
p.pausedMu.Unlock()
101+
p.pausedCond.L.Unlock()
96102
}
97103

98-
func (p *WGUDPProxy) RedirectTo(endpoint *net.UDPAddr) {
99-
// todo implement me
104+
// RedirectAs start to use the fake sourced raw socket as package sender
105+
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
106+
p.pausedCond.L.Lock()
107+
defer func() {
108+
p.pausedCond.L.Unlock()
109+
p.pausedCond.Signal()
110+
}()
111+
112+
p.paused = false
113+
if p.srcFakerConn != nil {
114+
if err := p.srcFakerConn.Close(); err != nil {
115+
log.Errorf("failed to close src faker conn: %s", err)
116+
}
117+
p.srcFakerConn = nil
118+
}
119+
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
120+
if err != nil {
121+
log.Errorf("failed to create src faker conn: %s", err)
122+
return
123+
}
124+
p.srcFakerConn = srcFakerConn
125+
p.sendPkg = p.srcFakerConn.SendPkg
100126
}
101127

102128
// CloseConn close the localConn
@@ -108,25 +134,35 @@ func (p *WGUDPProxy) CloseConn() error {
108134
}
109135

110136
func (p *WGUDPProxy) close() error {
137+
var result *multierror.Error
138+
111139
p.closeMu.Lock()
112140
defer p.closeMu.Unlock()
113141

114142
// prevent double close
115143
if p.closed {
116144
return nil
117145
}
118-
p.closed = true
119146

120147
p.cancel()
121148

122-
var result *multierror.Error
149+
p.pausedCond.L.Lock()
150+
p.paused = false
151+
p.pausedCond.L.Unlock()
152+
p.pausedCond.Signal()
153+
123154
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
124155
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
125156
}
126157

127158
if err := p.localConn.Close(); err != nil {
128159
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
129160
}
161+
162+
if err := p.srcFakerConn.Close(); err != nil {
163+
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
164+
}
165+
130166
return cerrors.FormatErrorOrNil(result)
131167
}
132168

@@ -179,14 +215,20 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
179215
return
180216
}
181217

182-
p.pausedMu.Lock()
183-
if p.paused {
184-
p.pausedMu.Unlock()
185-
continue
218+
for {
219+
p.pausedCond.L.Lock()
220+
if p.paused {
221+
p.pausedCond.Wait()
222+
if !p.paused {
223+
break
224+
}
225+
p.pausedCond.L.Unlock()
226+
continue
227+
}
228+
break
186229
}
187-
188-
_, err = p.localConn.Write(buf[:n])
189-
p.pausedMu.Unlock()
230+
_, err = p.sendPkg(buf[:n])
231+
p.pausedCond.L.Unlock()
190232

191233
if err != nil {
192234
if ctx.Err() != nil {

client/iface/wgproxy/udp/rawsocket.go

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package udp
2+
3+
import (
4+
"fmt"
5+
"net"
6+
"os"
7+
"syscall"
8+
9+
"github.com/google/gopacket"
10+
"github.com/google/gopacket/layers"
11+
log "github.com/sirupsen/logrus"
12+
13+
nbnet "github.com/netbirdio/netbird/util/net"
14+
)
15+
16+
var (
17+
serializeOpts = gopacket.SerializeOptions{
18+
ComputeChecksums: true,
19+
FixLengths: true,
20+
}
21+
22+
localHostNetIPAddr = &net.IPAddr{
23+
IP: net.ParseIP("127.0.0.1"),
24+
}
25+
)
26+
27+
type SrcFaker struct {
28+
srcAddr *net.UDPAddr
29+
30+
rawSocket net.PacketConn
31+
ipH gopacket.SerializableLayer
32+
udpH gopacket.SerializableLayer
33+
layerBuffer gopacket.SerializeBuffer
34+
}
35+
36+
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
37+
rawSocket, err := prepareSenderRawSocket()
38+
if err != nil {
39+
return nil, err
40+
}
41+
42+
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
43+
if err != nil {
44+
return nil, err
45+
}
46+
47+
f := &SrcFaker{
48+
srcAddr: srcAddr,
49+
rawSocket: rawSocket,
50+
ipH: ipH,
51+
udpH: udpH,
52+
layerBuffer: gopacket.NewSerializeBuffer(),
53+
}
54+
55+
return f, nil
56+
}
57+
58+
func (f *SrcFaker) Close() error {
59+
return f.rawSocket.Close()
60+
}
61+
62+
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
63+
defer func() {
64+
if err := f.layerBuffer.Clear(); err != nil {
65+
log.Errorf("failed to clear layer buffer: %s", err)
66+
}
67+
}()
68+
69+
payload := gopacket.Payload(data)
70+
71+
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
72+
if err != nil {
73+
return 0, fmt.Errorf("serialize layers: %w", err)
74+
}
75+
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
76+
if err != nil {
77+
return 0, fmt.Errorf("write to raw conn: %w", err)
78+
}
79+
return n, nil
80+
}
81+
82+
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
83+
ipH := &layers.IPv4{
84+
DstIP: net.ParseIP("127.0.0.1"),
85+
SrcIP: srcAddr.IP,
86+
Version: 4,
87+
TTL: 64,
88+
Protocol: layers.IPProtocolUDP,
89+
}
90+
udpH := &layers.UDP{
91+
SrcPort: layers.UDPPort(srcAddr.Port),
92+
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
93+
}
94+
95+
err := udpH.SetNetworkLayerForChecksum(ipH)
96+
if err != nil {
97+
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
98+
}
99+
100+
return ipH, udpH, nil
101+
}
102+
103+
func prepareSenderRawSocket() (net.PacketConn, error) {
104+
// Create a raw socket.
105+
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
106+
if err != nil {
107+
return nil, fmt.Errorf("creating raw socket failed: %w", err)
108+
}
109+
110+
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
111+
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
112+
if err != nil {
113+
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
114+
}
115+
116+
// Bind the socket to the "lo" interface.
117+
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
118+
if err != nil {
119+
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
120+
}
121+
122+
// Set the fwmark on the socket.
123+
err = nbnet.SetSocketOpt(fd)
124+
if err != nil {
125+
return nil, fmt.Errorf("setting fwmark failed: %w", err)
126+
}
127+
128+
// Convert the file descriptor to a PacketConn.
129+
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
130+
if file == nil {
131+
return nil, fmt.Errorf("converting fd to file failed")
132+
}
133+
packetConn, err := net.FilePacketConn(file)
134+
if err != nil {
135+
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
136+
}
137+
138+
return packetConn, nil
139+
}

0 commit comments

Comments
 (0)