Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion p2p/security/noise/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) {
init, resp := net.Pipe()
_ = resp.Close()

session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true)
session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true, nil)
_, err := session.encrypt(nil, []byte("hi"))
if err == nil {
t.Error("expected encryption error when handshake incomplete")
Expand Down
57 changes: 28 additions & 29 deletions p2p/security/noise/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {

if s.initiator {
// stage 0 //
// do not send the payload just yet, as it would be plaintext; not secret.
// Handshake Msg Len = len(DH ephemeral key)
err = s.sendHandshakeMessage(hs, nil, hbuf)
if err != nil {
if err := s.sendHandshakeMessage(hs, s.earlyData, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}

Expand All @@ -91,44 +89,45 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil {
return err
}

// stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, hbuf)
if err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
} else {
// stage 0 //
// We don't expect any payload on the first message.
if _, err := s.readHandshakeMessage(hs); err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}

// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
err = s.sendHandshakeMessage(hs, payload, hbuf)
if err != nil {
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
return nil
}

// stage 2 //
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
// stage 0 //
// We don't expect any payload on the first message.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do now

initialPayload, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler(initialPayload); err != nil {
return err
}
} else if len(initialPayload) > 0 {
return fmt.Errorf("received unexpected early data (%d bytes)", len(initialPayload))
}

return nil
// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}

// stage 2 //
plaintext, err := s.readHandshakeMessage(hs)
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
}

// setCipherStates sets the initial cipher states that will be used to protect
Expand Down
19 changes: 12 additions & 7 deletions p2p/security/noise/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type secureSession struct {
insecureReader *bufio.Reader // to cushion io read syscalls
// we don't buffer writes to avoid introducing latency; optimisation possible. // TODO revisit

earlyData []byte
earlyDataHandler func([]byte) error

qseek int // queued bytes seek value.
qbuf []byte // queued bytes buffer.
rlen [2]byte // work buffer to read in the incoming message length.
Expand All @@ -38,14 +41,16 @@ type secureSession struct {

// newSecureSession creates a Noise session over the given insecureConn Conn, using
// the libp2p identity keypair from the given Transport.
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) {
func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool, earlyData []byte) (*secureSession, error) {
s := &secureSession{
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
insecureConn: insecure,
insecureReader: bufio.NewReader(insecure),
initiator: initiator,
localID: tpt.localID,
localKey: tpt.privateKey,
remoteID: remote,
earlyData: earlyData,
earlyDataHandler: tpt.earlyDataHandler,
}

// the go-routine we create to run the handshake will
Expand Down
36 changes: 31 additions & 5 deletions p2p/security/noise/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,51 @@ const ID = "/noise"

var _ sec.SecureTransport = &Transport{}

type Option func(*Transport) error

// WithEarlyDataHandler specifies a handler for early data sent by the initiator.
// If the error returned is non-nil, the handshake is aborted.
func WithEarlyDataHandler(h func([]byte) error) Option {
return func(t *Transport) error {
t.earlyDataHandler = h
return nil
}
}

// Transport implements the interface sec.SecureTransport
// https://godoc.org/github.com/libp2p/go-libp2p-core/sec#SecureConn
type Transport struct {
localID peer.ID
privateKey crypto.PrivKey

earlyDataHandler func([]byte) error
}

// New creates a new Noise transport using the given private key as its
// libp2p identity key.
func New(privkey crypto.PrivKey) (*Transport, error) {
func New(privkey crypto.PrivKey, opts ...Option) (*Transport, error) {
localID, err := peer.IDFromPrivateKey(privkey)
if err != nil {
return nil, err
}

return &Transport{
t := &Transport{
localID: localID,
privateKey: privkey,
}, nil
}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
}
}

return t, nil
}

// SecureInbound runs the Noise handshake as the responder.
// If p is empty, connections from any peer are accepted.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
c, err := newSecureSession(t, ctx, insecure, p, false)
c, err := newSecureSession(t, ctx, insecure, p, false, nil)
if err != nil {
addr, maErr := manet.FromNetAddr(insecure.RemoteAddr())
if maErr == nil {
Expand All @@ -52,5 +72,11 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer

// SecureOutbound runs the Noise handshake as the initiator.
func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, true)
return newSecureSession(t, ctx, insecure, p, true, nil)
}

// SecureOutboundWithEarlyData runs the Noise handshake as the initiator.
// earlyData is sent (unencrypted!) along with the first handshake message.
func (t *Transport) SecureOutboundWithEarlyData(ctx context.Context, insecure net.Conn, p peer.ID, earlyData []byte) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, true, earlyData)
}
60 changes: 40 additions & 20 deletions p2p/security/noise/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,19 @@ import (
"golang.org/x/crypto/chacha20poly1305"

"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newTestTransport(t *testing.T, typ, bits int) *Transport {
priv, pub, err := crypto.GenerateKeyPair(typ, bits)
if err != nil {
t.Fatal(err)
}
id, err := peer.IDFromPublicKey(pub)
if err != nil {
t.Fatal(err)
}
return &Transport{
localID: id,
privateKey: priv,
}
func newTestTransport(t *testing.T, typ, bits int, opts ...Option) *Transport {
t.Helper()
priv, _, err := crypto.GenerateKeyPair(typ, bits)
require.NoError(t, err)
tr, err := New(priv, opts...)
require.NoError(t, err)
return tr
}

// Create a new pair of connected TCP sockets.
Expand Down Expand Up @@ -84,14 +77,27 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "")
<-done

if initErr != nil {
t.Fatal(initErr)
}
require.NoError(t, initErr)
require.NoError(t, respErr)
return initConn.(*secureSession), respConn.(*secureSession)
}

if respErr != nil {
t.Fatal(respErr)
}
func connectWithEarlyData(t *testing.T, initTransport, respTransport *Transport, earlyData []byte) (*secureSession, *secureSession) {
init, resp := newConnPair(t)

var initConn sec.SecureConn
var initErr error
done := make(chan struct{})
go func() {
defer close(done)
initConn, initErr = initTransport.SecureOutboundWithEarlyData(context.TODO(), init, respTransport.localID, earlyData)
}()

respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "")
<-done

require.NoError(t, initErr)
require.NoError(t, respErr)
return initConn.(*secureSession), respConn.(*secureSession)
}

Expand Down Expand Up @@ -372,3 +378,17 @@ func TestReadUnencryptedFails(t *testing.T) {
require.Error(t, err)
require.Equal(t, 0, afterLen)
}

func TestEarlyData(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
earlyDataChan := make(chan []byte, 1)
respTransport := newTestTransport(t, crypto.Ed25519, 2048, WithEarlyDataHandler(func(b []byte) error {
earlyDataChan <- b
return nil
}))

initConn, respConn := connectWithEarlyData(t, initTransport, respTransport, []byte("foobar"))
defer initConn.Close()
defer respConn.Close()
require.Equal(t, []byte("foobar"), <-earlyDataChan)
}