Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions conn/bind_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ type StdNetBind struct {

blackhole4 bool
blackhole6 bool

extraFns []ControlFn
}

func NewStdNetBind() Bind {
func NewStdNetBind(fns []ControlFn) Bind {
return &StdNetBind{
udpAddrPool: sync.Pool{
New: func() any {
Expand All @@ -70,6 +72,8 @@ func NewStdNetBind() Bind {
return &msgs
},
},

extraFns: fns,
}
}

Expand Down Expand Up @@ -119,8 +123,8 @@ func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}

func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
func listenNet(network string, port int, fns []ControlFn) (*net.UDPConn, int, error) {
conn, err := listenConfig(fns).ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
Expand Down Expand Up @@ -156,13 +160,13 @@ again:
var v4pc *ipv4.PacketConn
var v6pc *ipv6.PacketConn

v4conn, port, err = listenNet("udp4", port)
v4conn, port, err = listenNet("udp4", port, s.extraFns)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}

// Listen on the same port as we're using for ipv4.
v6conn, port, err = listenNet("udp6", port)
v6conn, port, err = listenNet("udp6", port, s.extraFns)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
v4conn.Close()
tries++
Expand Down Expand Up @@ -338,7 +342,7 @@ func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}

func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
func (s *StdNetBind) Send(bufs [][]byte, services []uint64, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
Expand Down
4 changes: 2 additions & 2 deletions conn/bind_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewDefaultBind() Bind { return NewWinRingBind() }

func NewWinRingBind() Bind {
if !winrio.Initialize() {
return NewStdNetBind()
return NewStdNetBind([]ControlFn{})
}
return new(WinRingBind)
}
Expand Down Expand Up @@ -486,7 +486,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}

func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
func (bind *WinRingBind) Send(bufs [][]byte, services []uint64, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
Expand Down
2 changes: 1 addition & 1 deletion conn/bindtest/bindtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
}
}

func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
func (c *ChannelBind) Send(bufs [][]byte, services []uint64, ep conn.Endpoint) error {
for _, b := range bufs {
select {
case <-c.closeSignal:
Expand Down
2 changes: 1 addition & 1 deletion conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Bind interface {

// Send writes one or more packets in bufs to address ep. The length of
// bufs must not exceed BatchSize().
Send(bufs [][]byte, ep Endpoint) error
Send(bufs [][]byte, services []uint64, ep Endpoint) error

// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
Expand Down
12 changes: 9 additions & 3 deletions conn/controlfns.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,29 @@ const socketBufferSize = 7 << 20
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
type ControlFn func(network, address string, c syscall.RawConn) error

// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
var controlFns = []ControlFn{}

// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
func listenConfig(extraFns []ControlFn) *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}

for _, fn := range extraFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
Expand Down
2 changes: 1 addition & 1 deletion conn/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

package conn

func NewDefaultBind() Bind { return NewStdNetBind() }
func NewDefaultBind() Bind { return NewStdNetBind(nil) }
4 changes: 2 additions & 2 deletions conn/sticky_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func Test_getSrcFromControl(t *testing.T) {

func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
Expand All @@ -239,7 +239,7 @@ func Test_listenConfig(t *testing.T) {
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
Expand Down
10 changes: 5 additions & 5 deletions device/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,11 +426,11 @@ type fakeBindSized struct {
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil
}
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
func (b *fakeBindSized) BatchSize() int { return b.size }
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(bufs [][]byte, services []uint64, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
func (b *fakeBindSized) BatchSize() int { return b.size }

type fakeTUNDeviceSized struct {
size int
Expand Down
4 changes: 2 additions & 2 deletions device/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}

func (peer *Peer) SendBuffers(buffers [][]byte) error {
func (peer *Peer) SendBuffers(buffers [][]byte, services []uint64) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()

Expand All @@ -133,7 +133,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error {
}
peer.endpoint.Unlock()

err := peer.device.net.bind.Send(buffers, endpoint)
err := peer.device.net.bind.Send(buffers, services, endpoint)
if err == nil {
var totalLen uint64
for _, b := range buffers {
Expand Down
24 changes: 20 additions & 4 deletions device/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ type QueueOutboundElement struct {
nonce uint64 // nonce for encryption
keypair *Keypair // keypair for encryption
peer *Peer // related peer
service uint64 // inner packet service identifier
drop bool // service identifier result, should drop this packet
}

type QueueOutboundElementsContainer struct {
Expand Down Expand Up @@ -130,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()

err = peer.SendBuffers([][]byte{packet})
err = peer.SendBuffers([][]byte{packet}, []uint64{0})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
Expand Down Expand Up @@ -167,7 +169,7 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketSent()

// TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{packet})
err = peer.SendBuffers([][]byte{packet}, []uint64{0})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
Expand All @@ -187,7 +189,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
packet := make([]byte, MessageCookieReplySize)
_ = reply.marshal(packet)
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
device.net.bind.Send([][]byte{packet}, []uint64{0}, initiatingElem.endpoint)

return nil
}
Expand Down Expand Up @@ -445,6 +447,14 @@ func (device *Device) RoutineEncryption(id int) {

for elemsContainer := range device.queue.encryption.c {
for _, elem := range elemsContainer.elems {
// identify inner packet
service, shouldDrop := ExecuteServiceFns(elem.packet)
if shouldDrop {
elem.drop = true
continue
}
elem.service = service

// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]

Expand Down Expand Up @@ -483,9 +493,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device.log.Verbosef("%v - Routine: sequential sender - started", peer)

bufs := make([][]byte, 0, maxBatchSize)
services := make([]uint64, 0, maxBatchSize)

for elemsContainer := range peer.queue.outbound.c {
bufs = bufs[:0]
services = services[:0]
if elemsContainer == nil {
return
}
Expand All @@ -507,16 +519,20 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
dataSent := false
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
if elem.drop {
continue
}
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
}
bufs = append(bufs, elem.packet)
services = append(services, elem.service)
}

peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()

err := peer.SendBuffers(bufs)
err := peer.SendBuffers(bufs, services)
if dataSent {
peer.timersDataSent()
}
Expand Down
26 changes: 26 additions & 0 deletions device/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package device

// ServiceFn process packet and return serivce id and drop flag
type ServiceFn func(buff []byte) (service uint64, shouldDrop bool)

var serviceFns []ServiceFn

// RegisterServiceFn register service function to identify packet
func RegisterServiceFn(fn ServiceFn) {
serviceFns = append(serviceFns, fn)
}

// ExecuteServiceFns to process packet data
func ExecuteServiceFns(buff []byte) (service uint64, shouldDrop bool) {
finalService := uint64(0)
for _, fn := range serviceFns {
service, shouldDrop = fn(buff)
if service != 0 {
finalService = service
}
if shouldDrop {
return finalService, true
}
}
return finalService, false
}