Skip to content

Commit 4e5d962

Browse files
committed
Implement remote DNS
This commit implements remote DNS. It introduces two new dependencies: ttlcache and dns. Remote DNS intercepts UDP DNS queries for A records on port 53. It replies with an unused IP address from an address pool, 198.18.0.0/15 by default. When obtaining a new address from the pool, tun2socks needs to memorize which name the address belongs to, so that when a client connects to the address, it can instruct the proxy to connect to the FQDN. To implement this IP to name mapping, ttlcache is used. To prevent using multiple addresses for the same name, ttlcache is also used to implement a name to IP mapping. If an IP address is already cached for a name, that address is returned instread. When building a connection, the connection metadata is inspected and if the destination address is associated with a DNS name, the proxy is instructed to use this name instead of the IP address.
1 parent 63f71e0 commit 4e5d962

File tree

14 files changed

+380
-21
lines changed

14 files changed

+380
-21
lines changed

component/remotedns/handle.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package remotedns
2+
3+
import (
4+
"net"
5+
6+
"github.com/miekg/dns"
7+
"gvisor.dev/gvisor/pkg/tcpip"
8+
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
9+
"gvisor.dev/gvisor/pkg/tcpip/stack"
10+
"gvisor.dev/gvisor/pkg/waiter"
11+
12+
"github.com/xjasonlyu/tun2socks/v2/log"
13+
M "github.com/xjasonlyu/tun2socks/v2/metadata"
14+
)
15+
16+
func RewriteMetadata(metadata *M.Metadata) bool {
17+
if !IsEnabled() {
18+
return false
19+
}
20+
dstName, found := getCachedName(metadata.DstIP)
21+
if !found {
22+
return false
23+
}
24+
metadata.DstIP = nil
25+
metadata.DstName = dstName
26+
return true
27+
}
28+
29+
func HandleDNSQuery(s *stack.Stack, id stack.TransportEndpointID, ptr *stack.PacketBuffer) bool {
30+
if !IsEnabled() {
31+
return false
32+
}
33+
34+
msg := dns.Msg{}
35+
err := msg.Unpack(ptr.Data().AsRange().ToSlice())
36+
37+
isCorrectEndpoint := id.LocalPort == 53 && (listenAddress.Equal(id.LocalAddress.AsSlice()) || listenAddress.IsUnspecified())
38+
39+
// Ignore UDP packets that are not matching the listen address and are not recursive queries
40+
if !isCorrectEndpoint || err != nil || len(msg.Question) != 1 || msg.Question[0].Qtype != dns.TypeA &&
41+
msg.Question[0].Qtype != dns.TypeAAAA || msg.Question[0].Qclass != dns.ClassINET || !msg.RecursionDesired ||
42+
msg.Response {
43+
return false
44+
}
45+
46+
qname := msg.Question[0].Name
47+
qtype := msg.Question[0].Qtype
48+
49+
log.Debugf("[DNS] query %s %s", dns.TypeToString[qtype], qname)
50+
51+
var ip net.IP
52+
if qtype == dns.TypeA {
53+
rr := dns.A{}
54+
ip = findOrInsertNameAndReturnIP(4, qname)
55+
if ip != nil {
56+
rr.A = ip
57+
rr.Hdr.Name = qname
58+
rr.Hdr.Ttl = dnsTTL
59+
rr.Hdr.Class = dns.ClassINET
60+
rr.Hdr.Rrtype = qtype
61+
msg.Answer = append(msg.Answer, &rr)
62+
} else {
63+
log.Warnf("[DNS] IP space exhausted")
64+
msg.Rcode = dns.RcodeServerFailure
65+
}
66+
}
67+
68+
msg.Response = true
69+
msg.RecursionDesired = false
70+
msg.RecursionAvailable = true
71+
72+
var wq waiter.Queue
73+
74+
ep, err2 := s.NewEndpoint(ptr.TransportProtocolNumber, ptr.NetworkProtocolNumber, &wq)
75+
if err2 != nil {
76+
return true
77+
}
78+
defer ep.Close()
79+
80+
ep.Bind(tcpip.FullAddress{NIC: ptr.NICID, Addr: id.LocalAddress, Port: id.LocalPort})
81+
conn := gonet.NewUDPConn(&wq, ep)
82+
defer conn.Close()
83+
packed, err := msg.Pack()
84+
if err != nil {
85+
return true
86+
}
87+
_, _ = conn.WriteTo(packed, &net.UDPAddr{IP: id.RemoteAddress.AsSlice(), Port: int(id.RemotePort)})
88+
return true
89+
}

component/remotedns/iputil.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package remotedns
2+
3+
import "net"
4+
5+
func copyIP(ip net.IP) net.IP {
6+
dup := make(net.IP, len(ip))
7+
copy(dup, ip)
8+
return dup
9+
}
10+
11+
func incrementIP(ip net.IP) net.IP {
12+
result := copyIP(ip)
13+
for i := len(result) - 1; i >= 0; i-- {
14+
result[i]++
15+
if result[i] != 0 {
16+
break
17+
}
18+
}
19+
return result
20+
}
21+
22+
func getBroadcastAddress(ipnet *net.IPNet) net.IP {
23+
result := copyIP(ipnet.IP)
24+
for i := 0; i < len(ipnet.IP); i++ {
25+
result[i] |= ^ipnet.Mask[i]
26+
}
27+
return result
28+
}
29+
30+
func getNetworkAddress(ipnet *net.IPNet) net.IP {
31+
result := copyIP(ipnet.IP)
32+
for i := 0; i < len(ipnet.IP); i++ {
33+
result[i] &= ipnet.Mask[i]
34+
}
35+
return result
36+
}

component/remotedns/pool.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package remotedns
2+
3+
import (
4+
"net"
5+
"sync"
6+
"time"
7+
8+
"github.com/jellydator/ttlcache/v3"
9+
)
10+
11+
var (
12+
ipToName = ttlcache.New[string, string]()
13+
nameToIP = ttlcache.New[string, net.IP]()
14+
mutex = sync.Mutex{}
15+
16+
ip4NextAddress net.IP
17+
ip4BroadcastAddress net.IP
18+
)
19+
20+
func findOrInsertNameAndReturnIP(ipVersion int, name string) net.IP {
21+
if ipVersion != 4 {
22+
panic("Method not implemented for IPv6")
23+
}
24+
mutex.Lock()
25+
defer mutex.Unlock()
26+
var result net.IP = nil
27+
var ipnet *net.IPNet
28+
var nextAddress *net.IP
29+
var broadcastAddress net.IP
30+
if ipVersion == 4 {
31+
ipnet = ip4net
32+
nextAddress = &ip4NextAddress
33+
broadcastAddress = ip4BroadcastAddress
34+
}
35+
36+
nameToIP.DeleteExpired()
37+
ipToName.DeleteExpired()
38+
39+
entry := nameToIP.Get(name)
40+
if entry != nil {
41+
ip := entry.Value()
42+
ipToName.Touch(ip.String())
43+
return ip
44+
}
45+
46+
// Beginning from the pointer to the next most likely free IP, loop through the IP address space
47+
// until either a free IP is found or the space is exhausted
48+
passedBroadcastAddress := false
49+
for result == nil {
50+
if nextAddress.Equal(broadcastAddress) {
51+
*nextAddress = getNetworkAddress(ipnet)
52+
*nextAddress = incrementIP(ipnet.IP)
53+
54+
// We have seen the broadcast address twice during looping
55+
// This means that our IP address space is exhausted
56+
if passedBroadcastAddress {
57+
return nil
58+
}
59+
passedBroadcastAddress = true
60+
}
61+
62+
// Skip the listen address if that is inside our pool range
63+
if nextAddress.Equal(listenAddress) {
64+
*nextAddress = incrementIP(*nextAddress)
65+
continue
66+
}
67+
68+
// Do not touch entries that exist in the cache already.
69+
hasKey := ipToName.Has((*nextAddress).String())
70+
if !hasKey {
71+
_ = ipToName.Set((*nextAddress).String(), name, time.Duration(dnsTTL)*time.Second+cacheGraceTime)
72+
_ = nameToIP.Set(name, *nextAddress, time.Duration(dnsTTL)*time.Second+cacheGraceTime)
73+
result = *nextAddress
74+
}
75+
76+
*nextAddress = incrementIP(*nextAddress)
77+
}
78+
79+
return result
80+
}
81+
82+
func getCachedName(address net.IP) (string, bool) {
83+
mutex.Lock()
84+
defer mutex.Unlock()
85+
entry := ipToName.Get(address.String())
86+
if entry == nil {
87+
return "", false
88+
}
89+
nameToIP.Touch(entry.Value())
90+
return entry.Value(), true
91+
}

component/remotedns/settings.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package remotedns
2+
3+
import (
4+
"errors"
5+
"net"
6+
"time"
7+
)
8+
9+
// Timeouts are somewhat arbitrary. For example, netcat will resolve the DNS
10+
// names upon startup and then stick to the resolved IP address. A timeout of 1
11+
// second may therefore be too low in cases where the first UDP packet is not
12+
// sent immediately.
13+
// cacheGraceTime defines how long an entry should still be retained in the cache
14+
// after being resolved by DNS.
15+
const (
16+
cacheGraceTime = 30 * time.Second
17+
)
18+
19+
var (
20+
enabled = false
21+
dnsTTL uint32 = 0
22+
ip4net *net.IPNet
23+
listenAddress net.IP
24+
)
25+
26+
func IsEnabled() bool {
27+
return enabled
28+
}
29+
30+
func SetDNSTTL(timeout time.Duration) {
31+
dnsTTL = uint32(timeout.Seconds())
32+
}
33+
34+
func SetListenAddress(ip net.IP) {
35+
listenAddress = ip
36+
}
37+
38+
func SetNetwork(ipnet *net.IPNet) error {
39+
leadingOnes, _ := ipnet.Mask.Size()
40+
if len(ipnet.IP) == 4 {
41+
if leadingOnes > 30 {
42+
return errors.New("IPv4 remote DNS subnet too small")
43+
}
44+
ip4net = ipnet
45+
} else {
46+
return errors.New("unsupported protocol")
47+
}
48+
return nil
49+
}
50+
51+
func Enable() {
52+
ip4NextAddress = incrementIP(getNetworkAddress(ip4net))
53+
ip4BroadcastAddress = getBroadcastAddress(ip4net)
54+
enabled = true
55+
}

core/udp.go

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,38 @@ import (
77
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
88
"gvisor.dev/gvisor/pkg/waiter"
99

10+
"github.com/xjasonlyu/tun2socks/v2/component/remotedns"
1011
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
1112
"github.com/xjasonlyu/tun2socks/v2/core/option"
1213
)
1314

1415
func withUDPHandler(handle func(adapter.UDPConn)) option.Option {
1516
return func(s *stack.Stack) error {
16-
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
17-
var (
18-
wq waiter.Queue
19-
id = r.ID()
20-
)
21-
ep, err := r.CreateEndpoint(&wq)
22-
if err != nil {
23-
glog.Debugf("forward udp request: %s:%d->%s:%d: %s",
24-
id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
25-
return
17+
s.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, ptr *stack.PacketBuffer) bool {
18+
if remotedns.HandleDNSQuery(s, id, ptr) {
19+
return true
2620
}
2721

28-
conn := &udpConn{
29-
UDPConn: gonet.NewUDPConn(&wq, ep),
30-
id: id,
31-
}
32-
handle(conn)
22+
udpForwarder := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
23+
var (
24+
wq waiter.Queue
25+
id = r.ID()
26+
)
27+
ep, err := r.CreateEndpoint(&wq)
28+
if err != nil {
29+
glog.Debugf("forward udp request %s:%d->%s:%d: %s",
30+
id.RemoteAddress, id.RemotePort, id.LocalAddress, id.LocalPort, err)
31+
return
32+
}
33+
34+
conn := &udpConn{
35+
UDPConn: gonet.NewUDPConn(&wq, ep),
36+
id: id,
37+
}
38+
handle(conn)
39+
})
40+
return udpForwarder.HandlePacket(id, ptr)
3341
})
34-
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
3542
return nil
3643
}
3744
}

0 commit comments

Comments
 (0)