Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADR 30: Bolt Handshake Manifest v1 #619

Merged
merged 19 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 204 additions & 7 deletions neo4j/internal/bolt/connect.go
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ package bolt
import (
"context"
"fmt"
"io"
"net"
"strings"

"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil"
Expand All @@ -35,12 +37,18 @@ type protocolVersion struct {
back byte // Number of minor versions back
}

// Supported versions in priority order
func (p *protocolVersion) formatProtocol() string {
return fmt.Sprintf("0x%04X%02X%02X", p.back, p.minor, p.major)
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}

// versions lists the supported protocol versions in priority order.
// The first proposal is a marker indicating that the client wishes to use the
// new manifest-style negotiation.
var versions = [4]protocolVersion{
{major: 0xFF, minor: 0x01, back: 0x00}, // Bolt manifest marker
{major: 5, minor: 7, back: 7},
{major: 4, minor: 4, back: 2},
{major: 4, minor: 1},
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
{major: 3, minor: 0},
{major: 3, minor: 0, back: 0},
}

// Connect initiates the negotiation of the Bolt protocol version.
Expand Down Expand Up @@ -70,6 +78,7 @@ func Connect(ctx context.Context,
boltLogger.LogClientMessage("", "<MAGIC> %#010X", handshake[0:4])
boltLogger.LogClientMessage("", "<HANDSHAKE> %#010X %#010X %#010X %#010X", handshake[4:8], handshake[8:12], handshake[12:16], handshake[16:20])
}
// Write handshake proposals to server
_, err := racing.NewRacingWriter(conn).Write(ctx, handshake)
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
Expand All @@ -84,14 +93,24 @@ func Connect(ctx context.Context,
return nil, err
}

if boltLogger != nil {
boltLogger.LogServerMessage("", "<HANDSHAKE> %#010X", buf)
}

major := buf[3]
minor := buf[2]

// Log legacy handshake response.
if !(major == 0xFF && minor == 0x01) && boltLogger != nil {
boltLogger.LogServerMessage("", "<HANDSHAKE> %#010X", buf)
}

bufferedConn := bufferedConnection(conn, readBufferSize)

// If the server selected manifest negotiation, perform the extended handshake.
if major == 0xFF && minor == 0x01 {
major, minor, err = performManifestNegotiation(ctx, bufferedConn, serverName, errorListener, boltLogger, buf)
if err != nil {
return nil, err
}
}

var boltConn db.Connection
switch major {
case 3:
Expand All @@ -115,3 +134,181 @@ func Connect(ctx context.Context,
}
return boltConn, nil
}

// performManifestNegotiation handles the manifest-style handshake.
// Returns the negotiated protocol's major and minor version.
func performManifestNegotiation(
ctx context.Context,
conn io.ReadWriteCloser,
serverName string,
errorListener ConnectionErrorListener,
boltLogger log.BoltLogger,
response []byte,
) (byte, byte, error) {
reader := racing.NewRacingReader(conn)

// Read the protocol offerings.
count, supported, err := readProtocolOfferings(ctx, reader, serverName, errorListener)
if err != nil {
return 0, 0, err
}

// Read the capability mask.
_, capBytes, err := readCapabilityMask(ctx, reader, serverName, errorListener)
if err != nil {
return 0, 0, err
}

// Log the complete server handshake message.
logManifestHandshake(boltLogger, response, count, supported, capBytes)

// Select an acceptable protocol version.
chosen, err := selectProtocol(supported, errorListener, serverName)
if err != nil {
invalidHandshake := []byte{0x00, 0x00, 0x00, 0x00, 0x00} // 4 bytes for version + 1 byte for capabilities.
if _, err := conn.Write(invalidHandshake); err != nil {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
errorListener.OnDialError(ctx, serverName, err)
}
return 0, 0, err
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}

// Send the handshake confirmation.
if err = sendHandshakeConfirmation(ctx, conn, boltLogger, errorListener, serverName, chosen, capBytes); err != nil {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
return 0, 0, err
}

return chosen.major, chosen.minor, nil
}

// readProtocolOfferings reads the number of protocol offerings and returns the count and
// a slice of supported protocol versions.
func readProtocolOfferings(ctx context.Context, r racing.RacingReader, serverName string, errorListener ConnectionErrorListener) (uint64, []protocolVersion, error) {
count, err := readVarInt(ctx, r)
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
return 0, nil, fmt.Errorf("failed to read manifest protocol count: %w", err)
}
supported := make([]protocolVersion, count)
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
for i := uint64(0); i < count; i++ {
var versionBytes [4]byte
_, err := r.ReadFull(ctx, versionBytes[:])
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
return 0, nil, fmt.Errorf("failed to read manifest protocol version: %w", err)
}
supported[i] = protocolVersion{
back: versionBytes[1],
minor: versionBytes[2],
major: versionBytes[3],
}
}
return count, supported, nil
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}

// readCapabilityMask reads the capability bit mask (a Base128 VarInt) and returns both the
// raw value and its encoded byte slice.
func readCapabilityMask(ctx context.Context, r racing.RacingReader, serverName string, errorListener ConnectionErrorListener) (uint64, []byte, error) {
capMask, err := readVarInt(ctx, r)
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
return 0, nil, fmt.Errorf("failed to read capability mask: %w", err)
}
capBytes, err := encodeVarInt(capMask)
if err != nil {
return 0, nil, fmt.Errorf("failed to encode capability mask: %w", err)
}
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
return capMask, capBytes, nil
}

// logManifestHandshake logs the complete server handshake message for manifest negotiation.
// It prints the initial response, count of offerings, each supported protocol, and the capability mask.
func logManifestHandshake(boltLogger log.BoltLogger, response []byte, count uint64, supported []protocolVersion, capBytes []byte) {
if boltLogger == nil {
return
}
var supportedProtocols []string
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
for _, p := range supported {
supportedProtocols = append(supportedProtocols, p.formatProtocol())
}
boltLogger.LogServerMessage("", "<HANDSHAKE> %s [%d] %s %s",
fmt.Sprintf("%#X", response),
count,
strings.Join(supportedProtocols, " "),
fmt.Sprintf("%#X", capBytes))
}

// selectProtocol iterates over our protocol proposals (skipping the manifest marker)
// and returns the first candidate whose major version matches and whose minor version
// falls within the range offered by the server.
func selectProtocol(supported []protocolVersion, errorListener ConnectionErrorListener, serverName string) (protocolVersion, error) {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
proposals := versions[1:]
for _, candidate := range proposals {
for _, offer := range supported {
if candidate.major == offer.major &&
candidate.minor <= offer.minor &&
candidate.minor >= (offer.minor-offer.back) {
return candidate, nil
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
return protocolVersion{}, fmt.Errorf("none of the server offered Bolt versions are supported (offered: %#v)", supported)
}

// sendHandshakeConfirmation sends the chosen protocol version and capability mask back to the server.
func sendHandshakeConfirmation(ctx context.Context, conn io.ReadWriteCloser, boltLogger log.BoltLogger, errorListener ConnectionErrorListener, serverName string, chosen protocolVersion, capBytes []byte) error {
chosenBytes := []byte{0x00, 0x00, chosen.minor, chosen.major}
if boltLogger != nil {
boltLogger.LogClientMessage("", "<HANDSHAKE> %#X %#X", chosenBytes, capBytes)
}
if _, err := conn.Write(chosenBytes); err != nil {
errorListener.OnDialError(ctx, serverName, err)
return err
}
if _, err := conn.Write(capBytes); err != nil {
errorListener.OnDialError(ctx, serverName, err)
return err
}
return nil
}

// readVarInt reads a Base128-encoded variable-length integer and returns the
// decoded unsigned integer, or an error if the value is too long or the read fails.
func readVarInt(ctx context.Context, r racing.RacingReader) (uint64, error) {
var result uint64
var shift uint
var buf [1]byte
for {
_, err := r.Read(ctx, buf[:])
if err != nil {
return 0, err
}
b := buf[0]
result |= uint64(b&0x7F) << shift
// The most significant bit is the continuation flag.
if b&0x80 == 0 {
break
}
shift += 7
if shift >= 64 {
return 0, fmt.Errorf("varint too long")
}
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}
return result, nil
}

// encodeVarInt encodes the given unsigned integer into a Base128 variable-length integer.
// Returns the encoded bytes or an error if the encoding fails.
func encodeVarInt(value uint64) ([]byte, error) {
var buf []byte
for {
b := byte(value & 0x7F)
value >>= 7
if value != 0 {
buf = append(buf, b|0x80)
} else {
buf = append(buf, b)
break
}
}
return buf, nil
}
87 changes: 87 additions & 0 deletions neo4j/internal/bolt/connect_test.go
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package bolt

import (
"bytes"
"context"
"testing"

Expand Down Expand Up @@ -99,3 +100,89 @@ func TestConnect(ot *testing.T) {
}
})
}

// fakeConn is a simple in-memory implementation of io.ReadWriteCloser.
type fakeConn struct {
r *bytes.Buffer // Data to be read (simulated server response)
w *bytes.Buffer // Data written by the client
}

func newFakeConn(readData []byte) *fakeConn {
return &fakeConn{
r: bytes.NewBuffer(readData),
w: &bytes.Buffer{},
}
}

func (f *fakeConn) Read(p []byte) (int, error) {
return f.r.Read(p)
}

func (f *fakeConn) Write(p []byte) (int, error) {
return f.w.Write(p)
}

func (f *fakeConn) Close() error {
return nil
}

// TestPerformManifestNegotiationSuccess simulates a successful manifest handshake.
// It provides a valid manifest handshake response and verifies that the negotiated
// protocol version is correct and that the handshake confirmation is written.
func TestPerformManifestNegotiationSuccess(t *testing.T) {
ctx := context.Background()
serverName := "testServer"
errorListener := &noopErrorListener{}

manifestData := []byte{
0x03, // count = 3
0x00, 0x07, 0x07, 0x05, // offering 1 --> protocol version 5.7 (back 7)
0x00, 0x02, 0x04, 0x04, // offering 2 --> protocol version 4.4 (back 2)
0x00, 0x00, 0x00, 0x03, // offering 3 --> protocol version 3.0
0x8F, 0x01, // capability mask
}
fake := newFakeConn(manifestData)
response := []byte{0x00, 0x00, 0x01, 0xFF}

major, minor, err := performManifestNegotiation(ctx, fake, serverName, errorListener, nil, response)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}

if major != 5 || minor != 7 {
t.Fatalf("Expected negotiated version 5.7, got %d.%d", major, minor)
}

expectedConfirmation := []byte{0x00, 0x00, 0x07, 0x05, 0x8F, 0x01}
if !bytes.Equal(fake.w.Bytes(), expectedConfirmation) {
t.Errorf("Handshake confirmation mismatch.\nExpected: % X\nGot: % X", expectedConfirmation, fake.w.Bytes())
}
}

// TestPerformManifestNegotiationNoSupportedVersion simulates a manifest handshake in which
// none of the server-offered protocol versions is acceptable to the client.
// It verifies that an error is returned and that the invalid handshake is sent.
func TestPerformManifestNegotiationNoSupportedVersion(t *testing.T) {
ctx := context.Background()
serverName := "testServer"
errorListener := &noopErrorListener{}

manifestData := []byte{
0x01, // count = 1
0x00, 0x00, 0xFF, 0xFF, // offering 1 --> protocol version 255.255
0x00, // capability mask
}
fake := newFakeConn(manifestData)
response := []byte{0x00, 0x00, 0x01, 0xFF}

_, _, err := performManifestNegotiation(ctx, fake, serverName, errorListener, nil, response)
if err == nil {
t.Fatal("Expected error for unsupported protocol version, got nil")
}

// In case of no supported version, the invalid handshake is sent.
expectedInvalid := []byte{0x00, 0x00, 0x00, 0x00, 0x00}
if !bytes.Equal(fake.w.Bytes(), expectedInvalid) {
t.Errorf("Expected invalid handshake % X, got % X", expectedInvalid, fake.w.Bytes())
}
}
2 changes: 1 addition & 1 deletion testkit-backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,6 @@ func (b *backend) handleRequest(req map[string]any) {
"Feature:Auth:Kerberos",
"Feature:Auth:Managed",
"Feature:Bolt:3.0",
"Feature:Bolt:4.1",
"Feature:Bolt:4.2",
"Feature:Bolt:4.3",
"Feature:Bolt:4.4",
Expand All @@ -1303,6 +1302,7 @@ func (b *backend) handleRequest(req map[string]any) {
"Feature:Bolt:5.6",
"Feature:Bolt:5.7",
"Feature:Bolt:Patch:UTC",
"Feature:Bolt:HandshakeManifestV1",
"Feature:Impersonation",
//"Feature:TLS:1.1",
"Feature:TLS:1.2",
Expand Down