diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 67e01e3e72..98dfe5bb67 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -1306,7 +1306,7 @@ "bzlTransitiveDigest": "ynEctVXvD5j3EVd/ITbHx2D6CkTijK64au5zmriT7cg=", "usagesDigest": "xr+U7navw2+SHogogmHRboKLbXAnkZfsctPQuyjmArw=", "recordedFileInputs": { - "@@//doc/requirements.txt": "f484a8b9cbf0f49e2b5e1d2d36b676f7dc88ecdff9c573001ea835a2738d84ff", + "@@//doc/requirements.txt": "69647040f4c4bdd93fd8369b245316b08cfabd17a23693d833081b5785c0f131", "@@//tools/env/pip3/requirements.txt": "92aa5b99f8051e7e2528f5ff44bb2cb263e3da1682de73908e07e173c2a417ac", "@@//tools/lint/python/requirements.txt": "c6eb43e931b4200ae71e62fc65bc78d72224495a1648dbc0a565f09571724bd8", "@@rules_fuzzing+//fuzzing/requirements.txt": "ab04664be026b632a0d2a2446c4f65982b7654f5b6851d2f9d399a19b7242a5b", diff --git a/acceptance/stun/BUILD.bazel b/acceptance/stun/BUILD.bazel new file mode 100644 index 0000000000..4dec92ea5f --- /dev/null +++ b/acceptance/stun/BUILD.bazel @@ -0,0 +1,16 @@ +load("//:scion.bzl", "scion_go_binary") +load("//acceptance/common:topogen.bzl", "topogen_test") + +topogen_test( + name = "test", + src = "test.py", + args = [ + "--executable=test-client:$(location //acceptance/stun/test-client)", + "--executable=test-server:$(location //acceptance/stun/test-server)", + ], + data = [ + "//acceptance/stun/test-client", + "//acceptance/stun/test-server", + ], + topo = "//topology:tiny.topo", +) diff --git a/acceptance/stun/test-client/BUILD.bazel b/acceptance/stun/test-client/BUILD.bazel new file mode 100644 index 0000000000..0fd68bd376 --- /dev/null +++ b/acceptance/stun/test-client/BUILD.bazel @@ -0,0 +1,19 @@ +load("@rules_go//go:def.bzl", "go_binary", "go_library") + +go_library( + name = "go_default_library", + srcs = ["main.go"], + importpath = "github.com/scionproto/scion/acceptance/stun/test-client", + visibility = ["//visibility:private"], + deps = [ + "//pkg/daemon:go_default_library", + "//pkg/daemon/types:go_default_library", + "//pkg/snet:go_default_library", + ], +) + +go_binary( + name = "test-client", + embed = [":go_default_library"], + visibility = ["//visibility:public"], +) diff --git a/acceptance/stun/test-client/main.go b/acceptance/stun/test-client/main.go new file mode 100644 index 0000000000..062f1137ba --- /dev/null +++ b/acceptance/stun/test-client/main.go @@ -0,0 +1,103 @@ +// Copyright 2025 ETH Zurich +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "flag" + "log" + "os" + + "github.com/scionproto/scion/pkg/daemon" + daemontypes "github.com/scionproto/scion/pkg/daemon/types" + "github.com/scionproto/scion/pkg/snet" +) + +func main() { + log.SetOutput(os.Stdout) + log.Println("Client running") + var daemonAddr string + var localAddr snet.UDPAddr + var remoteAddr snet.UDPAddr + var data string + flag.StringVar(&daemonAddr, "daemon", "127.0.0.1:30255", "Daemon address") + flag.Var(&localAddr, "local", "Local address") + flag.Var(&remoteAddr, "remote", "Remote address") + flag.StringVar(&data, "data", "", "Data") + flag.Parse() + + ctx := context.Background() + + dc, err := daemon.NewService(daemonAddr).Connect(ctx) + if err != nil { + log.Fatalf("Failed to create SCION daemon connector: %v", err) + } + + ps, err := dc.Paths(ctx, remoteAddr.IA, localAddr.IA, daemontypes.PathReqFlags{Refresh: true}) + if err != nil { + log.Fatalf("Failed to lookup paths: %v", err) + } + + if len(ps) == 0 { + log.Fatalf("No paths to %v available", remoteAddr.IA) + } + + sp := ps[0] + + log.Printf("Selected path to %v:", remoteAddr.IA) + log.Printf("\t%v", sp) + + topology, err := daemon.LoadTopology(ctx, dc) + if err != nil { + log.Fatalf("Failed to load topology from daemon: %v", err) + } + + scionNetwork := snet.SCIONNetwork{ + Topology: topology, + STUNEnabled: true, + } + + remoteAddr.Path = sp.Dataplane() + remoteAddr.NextHop = sp.UnderlayNextHop() + + conn, err := scionNetwork.Dial(ctx, "udp", localAddr.Host, &remoteAddr) + if err != nil { + log.Fatalf("Failed to dial SCION address: %v", err) + } + + defer conn.Close() + + log.Print("Successfully established SCION connection") + + _, err = conn.Write([]byte(data)) + if err != nil { + log.Fatalf("Failed to write to SCION connection: %v", err) + } + + log.Printf("Successfully sent data to %v", remoteAddr.IA) + + buf := make([]byte, 4096) + n, err := conn.Read(buf) + if err != nil { + log.Fatalf("Failed to read from SCION connection: %v", err) + } + + response := string(buf[:n]) + log.Printf("Received data: \"%s\"", response) + if response != data { + log.Fatalf("Assertion failed: response does not match sent data") + } + os.Exit(0) +} diff --git a/acceptance/stun/test-server/BUILD.bazel b/acceptance/stun/test-server/BUILD.bazel new file mode 100644 index 0000000000..8e65707313 --- /dev/null +++ b/acceptance/stun/test-server/BUILD.bazel @@ -0,0 +1,15 @@ +load("@rules_go//go:def.bzl", "go_binary", "go_library") + +go_library( + name = "go_default_library", + srcs = ["main.go"], + importpath = "github.com/scionproto/scion/acceptance/stun/test-server", + visibility = ["//visibility:private"], + deps = ["//pkg/snet:go_default_library"], +) + +go_binary( + name = "test-server", + embed = [":go_default_library"], + visibility = ["//visibility:public"], +) diff --git a/acceptance/stun/test-server/main.go b/acceptance/stun/test-server/main.go new file mode 100644 index 0000000000..8e64299811 --- /dev/null +++ b/acceptance/stun/test-server/main.go @@ -0,0 +1,102 @@ +// Copyright 2025 ETH Zurich +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + "log" + "net" + "os" + + "github.com/scionproto/scion/pkg/snet" +) + +func main() { + log.SetOutput(os.Stdout) + log.Print("Server running") + + var localAddr snet.UDPAddr + flag.Var(&localAddr, "local", "Local address") + flag.Parse() + + conn, err := net.ListenUDP("udp", localAddr.Host) + if err != nil { + log.Fatalf("Failed to listen on UDP connection: %v", err) + } + defer conn.Close() + + for { + var pkt snet.Packet + pkt.Prepare() + + n, lastHop, err := conn.ReadFrom(pkt.Bytes) + if err != nil { + log.Printf("Failed to read packet: %v", err) + continue + } + pkt.Bytes = pkt.Bytes[:n] + + err = pkt.Decode() + if err != nil { + log.Printf("Failed to decode packet: %v", err) + continue + } + + pld, ok := pkt.Payload.(snet.UDPPayload) + if !ok { + log.Printf("Failed to read packet payload") + continue + } + + if int(pld.DstPort) != localAddr.Host.Port { + continue + } + + log.Printf("Received data: %q from %v:%v", string(pld.Payload), pkt.Source, pld.SrcPort) + + pkt.Destination, pkt.Source = pkt.Source, pkt.Destination + + rp, ok := pkt.Path.(snet.RawPath) + if !ok { + log.Printf("Failed to reverse path, unexpected path type: %v", pkt.Path) + continue + } + replyPather := snet.DefaultReplyPather{} + replyPath, err := replyPather.ReplyPath(rp) + if err != nil { + log.Printf("Failed to reverse path: %v", err) + continue + } + pkt.Path = replyPath + + pkt.Payload = snet.UDPPayload{ + SrcPort: pld.DstPort, + DstPort: pld.SrcPort, + Payload: pld.Payload, + } + + err = pkt.Serialize() + if err != nil { + log.Printf("Failed to serialize SCION packet: %v", err) + continue + } + + _, err = conn.WriteTo(pkt.Bytes, lastHop) + if err != nil { + log.Printf("Failed to write packet: %v", err) + continue + } + } +} diff --git a/acceptance/stun/test.py b/acceptance/stun/test.py new file mode 100644 index 0000000000..cfb744bbf2 --- /dev/null +++ b/acceptance/stun/test.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# Copyright 2025 ETH Zurich +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from acceptance.common import base +import time +import yaml +from plumbum import local + + +class Test(base.TestTopogen): + def setup_prepare(self): + super().setup_prepare() + + # Modify test topology configuration by adding a separate NAT'ed docker network + # (192.168.123.0/24). + # A new docker container is added which acts as a NAT between the new network and + # AS1-ff00:0:111 (172.20.0.24/29). + # The tester container of AS1-ff00:0:111 is moved inside the separate network, along with + # the accompanying daemon and dispatcher. + with (open(self.artifacts / "gen/scion-dc.yml", "r") as file): + scion_dc = yaml.safe_load(file) + + # Create new docker network + scion_dc["networks"]["local_001"] = { + "driver": "bridge", + "driver_opts": {"com.docker.network.bridge.name": "local_001"}, + "ipam": {"config": [{"subnet": "192.168.123.0/24"}]} + } + + # Move tester dispatcher to new network + scion_dc["services"]["disp_tester_1-ff00_0_111"]["networks"] = \ + {"local_001": {"ipv4_address": "192.168.123.4"}} + + # Connect tester daemon to new network + scion_dc["services"]["sd1-ff00_0_111"]["networks"]["local_001"] = \ + {"ipv4_address": "192.168.123.3"} + + # Move tester container to new network + scion_dc["services"]["tester_1-ff00_0_110"]["environment"]["SCION_DAEMON_ADDRESS"] = \ + "172.20.0.21:30255" + scion_dc["services"]["tester_1-ff00_0_111"].pop("entrypoint") + scion_dc["services"]["tester_1-ff00_0_111"]["command"] = \ + ('bash -c "ip route del default && ip route add default via 192.168.123.2 ' + '&& tail -f /dev/null"') + scion_dc["services"]["tester_1-ff00_0_111"]["environment"] = { + "SCION_DAEMON": "192.168.123.3:30255", + "SCION_DAEMON_ADDRESS": "192.168.123.3:30255", + "SCION_LOCAL_ADDR": "192.168.123.4" + } + + # Create new docker container that acts as a NAT. + # We use iptables for the NAT + # (https://www.man7.org/linux/man-pages/man8/iptables.8.html) + # iptables command breakdown: + # -t nat specifies that the rule applies to the NAT table + # -A POSTROUTING appends the rule to the POSTROUTING chain, which is used to + # modify packets after routing decision and prior to leaving + # the network interface + # -s 192.168.123.0/24 specifies the source address range + # -p tcp/udp specifies the protocol the rule applies to + # -o eth1 specifies that the rule applies to packets leaving eth1 + # -j MASQUERADE dynamically replace source IP of outgoing packets with the + # IP of eth1 + # --random uses random source ports + # --to-ports 31000-32767 specifies to use only ports from the dispatched port range + # see https://www.man7.org/linux/man-pages/man8/iptables-extensions.8.html for more + # information + scion_dc["services"]["nat_1-ff00_0_111"] = { + "command": 'sh -c "apk update && apk add --no-cache iptables ' + '&& iptables -t nat -A POSTROUTING -s 192.168.123.0/24 -p tcp -o eth1 ' + '-j MASQUERADE && iptables -t nat -A POSTROUTING -s 192.168.123.0/24 ' + '-p udp -o eth1 -j MASQUERADE --random --to-ports 31000-32767 ' + '&& tail -f /dev/null"', + "image": "alpine:latest", + "networks": { + "scn_002": {"ipv4_address": "172.20.0.29"}, + "local_001": {"ipv4_address": "192.168.123.2"}, + }, + "cap_add": ["NET_ADMIN"] + } + with open(self.artifacts / "gen/scion-dc.yml", "w") as file: + yaml.dump(scion_dc, file) + + # More configuration changes to reflect new network topology + with open(self.artifacts / "gen/networks.conf", "r") as file: + filecontent = file.read() + + filecontent = filecontent.replace("sd1-ff00_0_111", "nat-ff00_0_111") + filecontent = filecontent.replace( + "tester_1-ff00_0_111 = 172.20.0.29", + "[192.168.123.0/24]\nnat-ff00_0_111 = 192.168.123.2\nsd1-ff00_0_111 = 192.168.123.3\n" + "tester_1-ff00_0_111 = 192.168.123.4") + + with open(self.artifacts / "gen/networks.conf", "w") as file: + file.write(filecontent) + + with open(self.artifacts / "gen/sciond_addresses.json", "r") as file: + filecontent = file.read() + + filecontent = filecontent.replace("172.20.0.28", "192.168.123.4") + + with open(self.artifacts / "gen/sciond_addresses.json", "w") as file: + file.write(filecontent) + + with open(self.artifacts / "gen/ASff00_0_111/sd.toml", "r") as file: + filecontent = file.read() + + filecontent = filecontent.replace("172.20.0.28", "192.168.123.3") + + with open(self.artifacts / "gen/ASff00_0_111/sd.toml", "w") as file: + file.write(filecontent) + + def _run(self): + self.await_connectivity() + time.sleep(10) # wait for everything to start up + + # copy test executables to test container + test_client = local["realpath"](self.get_executable("test-client").executable).strip() + test_server = local["realpath"](self.get_executable("test-server").executable).strip() + self.dc("cp", test_server, "tester_1-ff00_0_110" + ":/bin/") + self.dc("cp", test_client, "tester_1-ff00_0_111" + ":/bin/") + + # run tests (located in test-client/main.go and test-server/main.go) + self.dc.execute_detached("tester_1-ff00_0_110", "bash", "-c", + "test-server -local 1-ff00:0:110,172.20.0.22:31000") + time.sleep(3) + result = self.dc.execute( + "tester_1-ff00_0_111", "bash", "-c", + 'test-client -daemon 192.168.123.3:30255 -local 1-ff00:0:111,192.168.123.4:31000 ' + '-remote 1-ff00:0:110,172.20.0.22:31000 -data "abc"') + print(result) + + +if __name__ == "__main__": + base.main(Test) diff --git a/pkg/snet/BUILD.bazel b/pkg/snet/BUILD.bazel index a6a366f45a..4307d2b1b3 100644 --- a/pkg/snet/BUILD.bazel +++ b/pkg/snet/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "snet.go", "sock_error_posix.go", "sock_error_windows.go", + "stun_conn.go", "svcaddr.go", "udpaddr.go", "writer.go", @@ -37,9 +38,11 @@ go_library( "//pkg/slayers/path/epic:go_default_library", "//pkg/slayers/path/onehop:go_default_library", "//pkg/slayers/path/scion:go_default_library", + "//pkg/stun:go_default_library", "//private/topology:go_default_library", "//private/topology/underlay:go_default_library", "@com_github_gopacket_gopacket//:go_default_library", + "@org_golang_x_sync//singleflight:go_default_library", ] + select({ "@rules_go//go/platform:windows": [ "@org_golang_x_sys//windows:go_default_library", diff --git a/pkg/snet/conn.go b/pkg/snet/conn.go index 1db4f1988e..ad94cb6199 100644 --- a/pkg/snet/conn.go +++ b/pkg/snet/conn.go @@ -75,6 +75,7 @@ func NewCookedConn( if local.Host == nil || local.Host.IP.IsUnspecified() { return nil, serrors.New("nil or unspecified address is not supported.") } + hasSTUN := hasSTUNConn(pconn) return &Conn{ conn: pconn, local: local, @@ -86,12 +87,14 @@ func NewCookedConn( remote: o.remote, dispatchedPortStart: topo.PortRange.Start, dispatchedPortEnd: topo.PortRange.End, + hasSTUN: hasSTUN, }, scionConnReader: scionConnReader{ conn: pconn, buffer: make([]byte, common.SupportedMTU), replyPather: o.replyPather, local: local, + hasSTUN: hasSTUN, }, }, nil } @@ -153,3 +156,13 @@ func apply(opts []ConnOption) options { } return o } + +// hasSTUNConn checks if the provided PacketConn has STUN enabled. +func hasSTUNConn(pc PacketConn) bool { + scionPacketConn, ok := pc.(*SCIONPacketConn) + if !ok { + return false + } + _, ok = scionPacketConn.conn.(*stunConn) + return ok +} diff --git a/pkg/snet/packet_conn.go b/pkg/snet/packet_conn.go index 9352a2c734..249f93ecec 100644 --- a/pkg/snet/packet_conn.go +++ b/pkg/snet/packet_conn.go @@ -114,8 +114,8 @@ type SCIONPacketConnMetrics struct { // SCIONPacketConn gives applications full control over the content of valid SCION // packets. type SCIONPacketConn struct { - // Conn is the connection to send/receive serialized packets on. - Conn *net.UDPConn + // conn is the connection to send/receive serialized packets on. + conn sysPacketConn // SCMPHandler is invoked for packets that contain an SCMP L4. If the // handler is nil, errors are returned back to applications every time an // SCMP message is received. @@ -126,17 +126,26 @@ type SCIONPacketConn struct { Topology Topology } +// sysPacketConn is a wrapper interface around net.PacketConn. +// It exists so custom types can wrap or customize the standard net.PacketConn methods. +type sysPacketConn interface { + net.PacketConn + SyscallConn() (syscall.RawConn, error) + SetReadBuffer(bytes int) error + SetWriteBuffer(bytes int) error +} + func (c *SCIONPacketConn) SetReadBuffer(bytes int) error { - return c.Conn.SetReadBuffer(bytes) + return c.conn.SetReadBuffer(bytes) } func (c *SCIONPacketConn) SetDeadline(d time.Time) error { - return c.Conn.SetDeadline(d) + return c.conn.SetDeadline(d) } func (c *SCIONPacketConn) Close() error { metrics.CounterInc(c.Metrics.Closes) - return c.Conn.Close() + return c.conn.Close() } func (c *SCIONPacketConn) WriteTo(pkt *Packet, ov *net.UDPAddr) error { @@ -145,7 +154,7 @@ func (c *SCIONPacketConn) WriteTo(pkt *Packet, ov *net.UDPAddr) error { } // Send message - n, err := c.Conn.WriteTo(pkt.Bytes, ov) + n, err := c.conn.WriteTo(pkt.Bytes, ov) if err != nil { return serrors.Wrap("Reliable socket write error", err) } @@ -155,11 +164,11 @@ func (c *SCIONPacketConn) WriteTo(pkt *Packet, ov *net.UDPAddr) error { } func (c *SCIONPacketConn) SetWriteBuffer(bytes int) error { - return c.Conn.SetWriteBuffer(bytes) + return c.conn.SetWriteBuffer(bytes) } func (c *SCIONPacketConn) SetWriteDeadline(d time.Time) error { - return c.Conn.SetWriteDeadline(d) + return c.conn.SetWriteDeadline(d) } func (c *SCIONPacketConn) ReadFrom(pkt *Packet, ov *net.UDPAddr) error { @@ -199,12 +208,12 @@ func (c *SCIONPacketConn) ReadFrom(pkt *Packet, ov *net.UDPAddr) error { } func (c *SCIONPacketConn) SyscallConn() (syscall.RawConn, error) { - return c.Conn.SyscallConn() + return c.conn.SyscallConn() } func (c *SCIONPacketConn) readFrom(pkt *Packet) (*net.UDPAddr, error) { pkt.Prepare() - n, remoteAddr, err := c.Conn.ReadFrom(pkt.Bytes) + n, remoteAddr, err := c.conn.ReadFrom(pkt.Bytes) if err != nil { metrics.CounterInc(c.Metrics.UnderlayConnectionErrors) return nil, serrors.Wrap("reading underlay connection", err) @@ -242,11 +251,11 @@ func (c *SCIONPacketConn) readFrom(pkt *Packet) (*net.UDPAddr, error) { } func (c *SCIONPacketConn) SetReadDeadline(d time.Time) error { - return c.Conn.SetReadDeadline(d) + return c.conn.SetReadDeadline(d) } func (c *SCIONPacketConn) LocalAddr() net.Addr { - return c.Conn.LocalAddr() + return c.conn.LocalAddr() } // isShimDispatcher checks that udpAddr corresponds to the address where the diff --git a/pkg/snet/reader.go b/pkg/snet/reader.go index 3769918220..88f8e6ef2d 100644 --- a/pkg/snet/reader.go +++ b/pkg/snet/reader.go @@ -38,6 +38,9 @@ type scionConnReader struct { mtx sync.Mutex buffer []byte + + // hasSTUN indicates whether the conn has STUN enabled. + hasSTUN bool } // ReadFrom reads data into b, returning the length of copied data and the @@ -88,14 +91,9 @@ func (c *scionConnReader) read(b []byte) (int, *UDPAddr, error) { return 0, nil, serrors.New("unexpected payload", "type", common.TypeOf(pkt.Payload)) } - // XXX(JordiSubira): We explicitly forbid nil or unspecified address in the current constructor - // for Conn. - // If this were ever to change, we would always fall into the following if statement, then - // we would like to replace this logic (e.g., using IP_PKTINFO, with its caveats). pktAddrPort := netip.AddrPortFrom(pkt.Destination.Host.IP(), udp.DstPort) - if c.local.IA != pkt.Destination.IA || - c.local.Host.AddrPort() != pktAddrPort { - return 0, nil, serrors.New("packet is destined to a different host", + if c.local.IA != pkt.Destination.IA { + return 0, nil, serrors.New("packet is destined to a different IA", "local_isd_as", c.local.IA, "local_host", c.local.Host, "pkt_destination_isd_as", pkt.Destination.IA, @@ -103,6 +101,27 @@ func (c *scionConnReader) read(b []byte) (int, *UDPAddr, error) { ) } + // XXX(JordiSubira): We explicitly forbid nil or unspecified address in the current constructor + // for Conn. + // If this were ever to change, we would always fall into the following if statement, then + // we would like to replace this logic (e.g., using IP_PKTINFO, with its caveats). + if c.local.Host.AddrPort() != pktAddrPort { + + // If the client is behind a NAT, the SCION packet will hold the mapped external address, + // which is expected to be different from the local address. To handle this case, we check + // whether the underlying connection is a stunConn, which indicates that NAT traversal + // is in use. + // TODO: Is it necessary to check that the address matches one of the mapped addresses? + if !c.hasSTUN { + return 0, nil, serrors.New("packet is destined to a different host", + "local_isd_as", c.local.IA, + "local_host", c.local.Host, + "pkt_destination_isd_as", pkt.Destination.IA, + "pkt_destination_host", pktAddrPort, + ) + } + } + // Extract remote address. // Copy the address data to prevent races. See // https://github.com/scionproto/scion/issues/1659. diff --git a/pkg/snet/snet.go b/pkg/snet/snet.go index b661641c5e..386d946096 100644 --- a/pkg/snet/snet.go +++ b/pkg/snet/snet.go @@ -44,6 +44,7 @@ import ( "github.com/scionproto/scion/pkg/addr" "github.com/scionproto/scion/pkg/log" "github.com/scionproto/scion/pkg/metrics/v2" + "github.com/scionproto/scion/pkg/private/common" "github.com/scionproto/scion/pkg/private/serrors" ) @@ -89,6 +90,8 @@ type SCIONNetwork struct { // SCMPHandler describes the network behaviour upon receiving SCMP traffic. SCMPHandler SCMPHandler PacketConnMetrics SCIONPacketConnMetrics + // STUNEnabled indicates whether STUN should be used for NAT traversal. + STUNEnabled bool } // OpenRaw returns a PacketConn which listens on the specified address. @@ -121,7 +124,7 @@ func (n *SCIONNetwork) OpenRaw(ctx context.Context, addr *net.UDPAddr) (PacketCo return nil, err } return &SCIONPacketConn{ - Conn: pconn, + conn: pconn, SCMPHandler: n.SCMPHandler, Metrics: n.PacketConnMetrics, Topology: n.Topology, @@ -154,6 +157,20 @@ func (n *SCIONNetwork) Dial(ctx context.Context, network string, listen *net.UDP return nil, err } log.FromCtx(ctx).Debug("UDP socket opened on", "addr", packetConn.LocalAddr(), "to", remote) + + if n.STUNEnabled { + scionPacketConn, ok := packetConn.(*SCIONPacketConn) + if !ok { + return nil, serrors.New("expected SCIONPacketConn", "type", common.TypeOf(packetConn)) + } + stunConn, err := newSTUNConn(scionPacketConn.conn) + if err != nil { + return nil, serrors.Wrap("error creating STUN conn", err) + } + scionPacketConn.conn = stunConn + packetConn = scionPacketConn + } + return NewCookedConn(packetConn, n.Topology, WithReplyPather(n.ReplyPather), WithRemote(remote)) } diff --git a/pkg/snet/stun_conn.go b/pkg/snet/stun_conn.go new file mode 100644 index 0000000000..fa66e031d5 --- /dev/null +++ b/pkg/snet/stun_conn.go @@ -0,0 +1,342 @@ +// Copyright 2025 ETH Zurich +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snet + +import ( + "errors" + "net" + "net/netip" + "os" + "sync" + "syscall" + "time" + + "github.com/scionproto/scion/pkg/stun" + + "golang.org/x/sync/singleflight" +) + +const timeoutDuration = 5 * time.Minute + +// stunConn is a wrapper around sysPacketConn that handles STUN requests. +type stunConn struct { + sysPacketConn + recvChan chan dataPacket + maxQueuedBytes int64 + mutex sync.Mutex + sg singleflight.Group + + // the following fields are protected by mutex + queuedBytes int64 + stunChans map[stun.TxID]chan stunResponse + mappings map[netip.AddrPort]*natMapping + readDeadline time.Time + writeDeadline time.Time + readDeadlineChanged chan struct{} + writeDeadlineChanged chan struct{} + cond *sync.Cond // condition variable for pending STUN requests +} + +type dataPacket struct { + data []byte + addr net.Addr +} + +type stunResponse struct { + addr netip.AddrPort + err error +} + +func newSTUNConn(conn sysPacketConn) (*stunConn, error) { + // Get the receive buffer size + sysCallConn, err := conn.SyscallConn() + if err != nil { + return nil, err + } + var rcvBufSize int + err = sysCallConn.Control(func(fd uintptr) { + rcvBufSize, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) + }) + if err != nil { + return nil, err + } + + // assuming lower bound of per packet metadata of 64 bytes + maxNumPacket := max(rcvBufSize/64, 10) + + c := &stunConn{ + sysPacketConn: conn, + recvChan: make(chan dataPacket, maxNumPacket), + maxQueuedBytes: int64(rcvBufSize), + stunChans: make(map[stun.TxID]chan stunResponse), + mappings: make(map[netip.AddrPort]*natMapping), + readDeadlineChanged: make(chan struct{}), + writeDeadlineChanged: make(chan struct{}), + } + c.cond = sync.NewCond(&c.mutex) + + // background goroutine to continuously read from the underlying UDP connection and filter out + // STUN packets + go func() { + buf := make([]byte, 65535) + for { + n, addr, err := c.sysPacketConn.ReadFrom(buf) + if err != nil { + if errors.Is(err, net.ErrClosed) || + errors.Is(err, syscall.EBADF) { // bad file descriptor (connection closed) + close(c.recvChan) + return + } + continue + } + respTxID, mappedAddr, err := stun.ParseResponse(buf[:n]) + if err == nil { + c.mutex.Lock() + ch, ok := c.stunChans[respTxID] + c.mutex.Unlock() + if ok { + select { + case ch <- stunResponse{addr: mappedAddr, err: err}: + default: + } + } + } else if errors.Is(err, stun.ErrNotSTUN) { + func() { + c.mutex.Lock() + defer c.mutex.Unlock() + + pktLen := int64(len(buf[:n])) + if c.queuedBytes+pktLen <= c.maxQueuedBytes { + data := make([]byte, n) + copy(data, buf[:n]) + select { + case c.recvChan <- dataPacket{data: data, addr: addr}: + c.queuedBytes += pktLen + default: + } + } + }() + } // for all other errors, ignore the packet + } + }() + + return c, nil +} + +func (c *stunConn) ReadFrom(b []byte) (int, net.Addr, error) { + deadlineTimer := time.NewTimer(0) + deadlineTimer.Stop() + + for { + c.mutex.Lock() + deadline := c.readDeadline + deadlineChan := c.readDeadlineChanged + c.mutex.Unlock() + + if !deadline.IsZero() { + timeout := time.Until(deadline) + if timeout <= 0 { + return 0, nil, os.ErrDeadlineExceeded + } + deadlineTimer.Reset(timeout) + } + + select { + case pkt, ok := <-c.recvChan: + if !ok { + return 0, nil, net.ErrClosed + } + c.mutex.Lock() + c.queuedBytes -= int64(len(pkt.data)) + c.mutex.Unlock() + n := copy(b, pkt.data) + return n, pkt.addr, nil + case <-deadlineTimer.C: + return 0, nil, os.ErrDeadlineExceeded + case <-deadlineChan: + continue // read deadline changed, re-evaluate + } + } +} + +type natMapping struct { + destination netip.AddrPort + mappedAddr netip.AddrPort + lastUsed time.Time +} + +func (mapping *natMapping) touch() { + mapping.lastUsed = time.Now() +} + +func (mapping *natMapping) isValid() bool { + return time.Since(mapping.lastUsed) < timeoutDuration +} + +func (c *stunConn) mappedAddr(dest netip.AddrPort) (netip.AddrPort, error) { + addr, exists := func() (netip.AddrPort, bool) { + c.mutex.Lock() + defer c.mutex.Unlock() + // Check if mapping exists and is valid + if mapping, ok := c.mappings[dest]; ok && mapping.isValid() { + mapping.touch() + return mapping.mappedAddr, true + } + return netip.AddrPort{}, false + }() + if exists { + return addr, nil + } + + result, err, _ := c.sg.Do(dest.String(), func() (interface{}, error) { + return c.makeSTUNRequest(dest) + }) + + if err != nil { + return netip.AddrPort{}, err + } + + return result.(*natMapping).mappedAddr, nil +} + +func (c *stunConn) makeSTUNRequest(dest netip.AddrPort) (*natMapping, error) { + txID := stun.NewTxID() + stunRequest := stun.Request(txID) + + // values according to RFC 8489 Section 6.2.1 + // TODO: make configurable? + const Rc = 7 // Maximum number of retransmissions + const Rm = 16 // Multiplier for final retransmission wait time + const initialRTO = 500 * time.Millisecond + + c.mutex.Lock() + stunChan := make(chan stunResponse, Rc*2) + c.stunChans[txID] = stunChan + c.mutex.Unlock() + + defer func() { + c.mutex.Lock() + delete(c.stunChans, txID) + c.mutex.Unlock() + }() + + retransmissionTimer := time.NewTimer(0) + retransmissionTimer.Stop() + + deadlineTimer := time.NewTimer(0) + deadlineTimer.Stop() + + var mappedAddr netip.AddrPort + currentRTO := initialRTO + + for i := range Rc { + _, err := c.WriteTo(stunRequest, net.UDPAddrFromAddrPort(dest)) + if err != nil { + return nil, err + } + + var waitDuration time.Duration + if i < Rc-1 { + waitDuration = currentRTO + currentRTO *= 2 + } else { + waitDuration = Rm * initialRTO + } + retransmissionTimer.Reset(waitDuration) + + var timerExpired bool + for !timerExpired { + c.mutex.Lock() + deadline := c.writeDeadline + deadlineChanged := c.writeDeadlineChanged + c.mutex.Unlock() + + if !deadline.IsZero() { + timeout := time.Until(deadline) + if timeout <= 0 { + return nil, os.ErrDeadlineExceeded + } + deadlineTimer.Reset(timeout) + } + + select { + case <-retransmissionTimer.C: + timerExpired = true + case <-deadlineTimer.C: + return nil, os.ErrDeadlineExceeded + case <-deadlineChanged: + continue // write deadline changed, re-evaluate + case resp := <-stunChan: + if resp.err != nil { + return nil, resp.err + } + mappedAddr = resp.addr + if !mappedAddr.IsValid() { + return nil, errors.New("invalid mapped address") + } + + c.mutex.Lock() + mapping := c.mappings[dest] + if mapping == nil { + mapping = &natMapping{destination: dest} + c.mappings[dest] = mapping + } + mapping.mappedAddr = mappedAddr + mapping.touch() + c.mutex.Unlock() + + return mapping, nil + } + } + } + + return nil, errors.New("STUN request timed out") +} + +func (c *stunConn) SetDeadline(t time.Time) error { + c.mutex.Lock() + defer c.mutex.Unlock() + err := c.sysPacketConn.SetWriteDeadline(t) + if err == nil { + c.readDeadline = t + c.writeDeadline = t + close(c.readDeadlineChanged) + close(c.writeDeadlineChanged) + c.readDeadlineChanged = make(chan struct{}) + c.writeDeadlineChanged = make(chan struct{}) + } + return err +} + +func (c *stunConn) SetReadDeadline(t time.Time) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.readDeadline = t + close(c.readDeadlineChanged) + c.readDeadlineChanged = make(chan struct{}) + return nil +} + +func (c *stunConn) SetWriteDeadline(t time.Time) error { + c.mutex.Lock() + defer c.mutex.Unlock() + err := c.sysPacketConn.SetWriteDeadline(t) + if err == nil { + c.writeDeadline = t + close(c.writeDeadlineChanged) + c.writeDeadlineChanged = make(chan struct{}) + } + return err +} diff --git a/pkg/snet/writer.go b/pkg/snet/writer.go index f6661b1d51..dcf6d5d11a 100644 --- a/pkg/snet/writer.go +++ b/pkg/snet/writer.go @@ -36,6 +36,9 @@ type scionConnWriter struct { mtx sync.Mutex buffer []byte + + // hasSTUN indicates whether the conn has STUN enabled. + hasSTUN bool } // WriteTo sends b to raddr. @@ -82,6 +85,19 @@ func (c *scionConnWriter) WriteTo(b []byte, raddr net.Addr) (int, error) { if !ok { return 0, serrors.New("invalid listen host IP", "ip", c.local.Host.IP) } + listenHostPort := uint16(c.local.Host.Port) + + // Rewrite source address if STUN is in use + var err error + listenHostIP, listenHostPort, err = c.stunMappedSource( + raddr, + nextHop, + listenHostIP, + listenHostPort, + ) + if err != nil { + return 0, err + } pkt := &Packet{ Bytes: Bytes(c.buffer), @@ -93,7 +109,7 @@ func (c *scionConnWriter) WriteTo(b []byte, raddr net.Addr) (int, error) { }, Path: path, Payload: UDPPayload{ - SrcPort: uint16(c.local.Host.Port), + SrcPort: listenHostPort, DstPort: uint16(port), Payload: b, }, @@ -121,3 +137,47 @@ func (c *scionConnWriter) SetWriteDeadline(t time.Time) error { func (c *scionConnWriter) isWithinRange(port int) bool { return port >= int(c.dispatchedPortStart) && port <= int(c.dispatchedPortEnd) } + +// stunMappedSource returns the NAT mapped address for the source if the connection is +// using STUN and the destination is in a different IA. Otherwise, it returns the original +// address unchanged. +func (c *scionConnWriter) stunMappedSource( + raddr net.Addr, + nextHop *net.UDPAddr, + listenHostIP netip.Addr, + listenHostPort uint16, +) (netip.Addr, uint16, error) { + + if !c.hasSTUN { + return listenHostIP, listenHostPort, nil + } + + scionPacketConn := c.conn.(*SCIONPacketConn) + stunConn := scionPacketConn.conn.(*stunConn) + + var sameIA bool + switch a := raddr.(type) { + case *UDPAddr: + sameIA = a.IA.Equal(c.local.IA) + case *SVCAddr: + sameIA = a.IA.Equal(c.local.IA) + } + + if sameIA { + return listenHostIP, listenHostPort, nil + } + + nextHopIP, ok := netip.AddrFromSlice(nextHop.IP) + if !ok { + return netip.Addr{}, 0, serrors.New("invalid next hop IP", "ip", nextHop.IP) + } + nextHopIP = nextHopIP.Unmap() + nextHopAddrPort := netip.AddrPortFrom(nextHopIP, uint16(nextHop.Port)) + + mappedAddr, err := stunConn.mappedAddr(nextHopAddrPort) + if err != nil { + return netip.Addr{}, 0, serrors.New("Error getting mapped address for STUN", "stun", err) + } + + return mappedAddr.Addr(), mappedAddr.Port(), nil +} diff --git a/pkg/stun/stun.go b/pkg/stun/stun.go index c0ef158151..8d5113bbbb 100755 --- a/pkg/stun/stun.go +++ b/pkg/stun/stun.go @@ -4,36 +4,78 @@ // Copied from https://github.com/tailscale/tailscale/blob/main/net/stun/stun.go // Modifications: // - Remove requirement for "software" attribute -// - Remove unused methods // Package STUN parses STUN binding request packets and generates response packets. package stun import ( + "bytes" + crand "crypto/rand" "encoding/binary" "errors" "hash/crc32" + "net" "net/netip" ) const ( attrNumFingerprint = 0x8028 + attrMappedAddress = 0x0001 attrXorMappedAddress = 0x0020 - bindingRequest = "\x00\x01" - magicCookie = "\x21\x12\xa4\x42" - lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 - headerLen = 20 + // This alternative attribute type is not + // mentioned in the RFC, but the shift into + // the "comprehension-optional" range seems + // like an easy mistake for a server to make. + // And servers appear to send it. + attrXorMappedAddressAlt = 0x8020 + + bindingRequest = "\x00\x01" + magicCookie = "\x21\x12\xa4\x42" + lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 + headerLen = 20 ) // TxID is a transaction ID. type TxID [12]byte +// NewTxID returns a new random TxID. +func NewTxID() TxID { + var tx TxID + if _, err := crand.Read(tx[:]); err != nil { + panic(err) + } + return tx +} + +// Request generates a binding request STUN packet. +// The transaction ID, tID, should be a random sequence of bytes. +func Request(tID TxID) []byte { + // STUN header, RFC5389 Section 6. + b := make([]byte, 0, headerLen+lenFingerprint) + b = append(b, bindingRequest...) + b = appendU16(b, uint16(lenFingerprint)) // number of bytes following header + b = append(b, magicCookie...) + b = append(b, tID[:]...) + + // Attribute FINGERPRINT, RFC5389 Section 15.5. + fp := fingerPrint(b) + b = appendU16(b, attrNumFingerprint) + b = appendU16(b, 4) + b = appendU32(b, fp) + + return b +} + func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } func appendU16(b []byte, v uint16) []byte { return append(b, byte(v>>8), byte(v)) } +func appendU32(b []byte, v uint32) []byte { + return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + // ParseBindingRequest parses a STUN binding request. func ParseBindingRequest(b []byte) (TxID, error) { if !Is(b) { @@ -66,11 +108,12 @@ func ParseBindingRequest(b []byte) (TxID, error) { } var ( - ErrNotSTUN = errors.New("malformed STUN packet") - ErrMalformedAttrs = errors.New("STUN response has malformed attributes") - ErrNotBindingRequest = errors.New("STUN request not a binding request") - ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") - ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") + ErrNotSTUN = errors.New("malformed STUN packet") + ErrNotSuccessResponse = errors.New("STUN packet is not a response") + ErrMalformedAttrs = errors.New("STUN response has malformed attributes") + ErrNotBindingRequest = errors.New("STUN request not a binding request") + ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") + ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") ) func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { @@ -132,6 +175,120 @@ func Response(txID TxID, addrPort netip.AddrPort) []byte { return b } +// ParseResponse parses a successful binding response STUN packet. +// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. +func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { + if !Is(b) { + return tID, netip.AddrPort{}, ErrNotSTUN + } + copy(tID[:], b[8:8+len(tID)]) + if b[0] != 0x01 || b[1] != 0x01 { + return tID, netip.AddrPort{}, ErrNotSuccessResponse + } + attrsLen := int(binary.BigEndian.Uint16(b[2:4])) + b = b[headerLen:] // remove STUN header + if attrsLen > len(b) { + return tID, netip.AddrPort{}, ErrMalformedAttrs + } else if len(b) > attrsLen { + b = b[:attrsLen] // trim trailing packet bytes + } + + var fallbackAddr netip.AddrPort + + // Read through the attributes. + // The the addr+port reported by XOR-MAPPED-ADDRESS + // as the canonical value. If the attribute is not + // present but the STUN server responds with + // MAPPED-ADDRESS we fall back to it. + if err := foreachAttr(b, func(attrType uint16, attr []byte) error { + switch attrType { + case attrXorMappedAddress, attrXorMappedAddressAlt: + ipSlice, port, err := xorMappedAddress(tID, attr) + if err != nil { + return err + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + addr = netip.AddrPortFrom(ip.Unmap(), port) + } + case attrMappedAddress: + ipSlice, port, err := mappedAddress(attr) + if err != nil { + return ErrMalformedAttrs + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) + } + } + return nil + + }); err != nil { + return TxID{}, netip.AddrPort{}, err + } + + if addr.IsValid() { + return tID, addr, nil + } + if fallbackAddr.IsValid() { + return tID, fallbackAddr, nil + } + return tID, netip.AddrPort{}, ErrMalformedAttrs +} + +func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { + // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + xorPort := binary.BigEndian.Uint16(b[2:4]) + addrField := b[4:] + port = xorPort ^ 0x2112 // first half of magicCookie + + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + xorAddr := addrField[:addrLen] + addr = make([]byte, addrLen) + for i := range xorAddr { + if i < len(magicCookie) { + addr[i] = xorAddr[i] ^ magicCookie[i] + } else { + addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] + } + } + return addr, port, nil +} + +func familyAddrLen(fam byte) int { + switch fam { + case 0x01: // IPv4 + return net.IPv4len + case 0x02: // IPv6 + return net.IPv6len + default: + return 0 + } +} + +func mappedAddress(b []byte) (addr []byte, port uint16, err error) { + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + port = uint16(b[2])<<8 | uint16(b[3]) + addrField := b[4:] + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + return bytes.Clone(addrField[:addrLen]), port, nil +} + // Is reports whether b is a STUN message. func Is(b []byte) bool { return len(b) >= headerLen &&