Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADR 30: Bolt Handshake Manifest v1 #619

Merged
merged 19 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions neo4j/internal/bolt/bolt4.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,8 @@ func (b *bolt4) Connect(
hello := map[string]any{
"user_agent": userAgent,
}
// On bolt >= 4.1 add routing to enable/disable routing
if b.minor >= 1 {
if routingContext != nil {
hello["routing"] = routingContext
}
if routingContext != nil {
hello["routing"] = routingContext
}
checkUtcPatch := b.minor >= 3
if checkUtcPatch {
Expand Down
31 changes: 0 additions & 31 deletions neo4j/internal/bolt/bolt4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,37 +267,6 @@ func TestBolt4(outer *testing.T) {
bolt.Close(context.Background())
})

outer.Run("No routing in hello on 4.0", func(t *testing.T) {
routingContext := map[string]string{"some": "thing"}
conn, srv, cleanup := setupBolt4Pipe(t)
defer cleanup()
go func() {
srv.waitForHandshake()
srv.acceptVersion(4, 0)
hmap := srv.waitForHello()
_, exists := hmap["routing"].(map[string]any)
if exists {
panic("Should be no routing entry")
}
srv.acceptHello()
}()
bolt, err := Connect(
context.Background(),
"serverName",
conn,
auth,
"007",
routingContext,
nil,
logger,
nil,
idb.NotificationConfig{},
DefaultReadBufferSize,
)
AssertNoError(t, err)
bolt.Close(context.Background())
})

outer.Run("Failed authentication", func(t *testing.T) {
conn, srv, cleanup := setupBolt4Pipe(t)
defer cleanup()
Expand Down
205 changes: 194 additions & 11 deletions neo4j/internal/bolt/connect.go
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ package bolt

import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"strings"

"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil"
Expand All @@ -35,12 +38,18 @@ type protocolVersion struct {
back byte // Number of minor versions back
}

// Supported versions in priority order
func (p *protocolVersion) formatProtocol() string {
return fmt.Sprintf("%#04X%02X%02X", p.back, p.minor, p.major)
}

// versions lists the supported protocol versions in priority order.
// The first proposal is a marker indicating that the client wishes to use the
// new manifest-style negotiation.
var versions = [4]protocolVersion{
{major: 0xFF, minor: 0x01, back: 0x00}, // Bolt manifest marker
{major: 5, minor: 7, back: 7},
{major: 4, minor: 4, back: 2},
{major: 4, minor: 1},
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
{major: 3, minor: 0},
{major: 3, minor: 0, back: 0},
}

// Connect initiates the negotiation of the Bolt protocol version.
Expand Down Expand Up @@ -70,6 +79,7 @@ func Connect(ctx context.Context,
boltLogger.LogClientMessage("", "<MAGIC> %#010X", handshake[0:4])
boltLogger.LogClientMessage("", "<HANDSHAKE> %#010X %#010X %#010X %#010X", handshake[4:8], handshake[8:12], handshake[12:16], handshake[16:20])
}
// Write handshake proposals to server
_, err := racing.NewRacingWriter(conn).Write(ctx, handshake)
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
Expand All @@ -84,14 +94,29 @@ func Connect(ctx context.Context,
return nil, err
}

if boltLogger != nil {
boltLogger.LogServerMessage("", "<HANDSHAKE> %#010X", buf)
}

major := buf[3]
minor := buf[2]

if major == 80 && minor == 84 {
return nil, &errorutil.UsageError{Message: "server responded HTTP. Make sure you are not trying to connect to the http endpoint " +
"(HTTP defaults to port 7474 whereas BOLT defaults to port 7687)"}
}

// Log legacy handshake response.
if !(major == 0xFF && minor == 0x01) && boltLogger != nil {
boltLogger.LogServerMessage("", "<HANDSHAKE> %#010X", buf)
}

bufferedConn := bufferedConnection(conn, readBufferSize)

// If the server selected manifest negotiation, perform the extended handshake.
if major == 0xFF && minor == 0x01 {
major, minor, err = performManifestNegotiation(ctx, bufferedConn, serverName, errorListener, boltLogger, buf)
if err != nil {
return nil, err
}
}

var boltConn db.Connection
switch major {
case 3:
Expand All @@ -103,10 +128,6 @@ func Connect(ctx context.Context,
case 0:
return nil, fmt.Errorf("server did not accept any of the requested Bolt versions (%#v)", versions)
default:
if major == 80 && minor == 84 {
return nil, &errorutil.UsageError{Message: "server responded HTTP. Make sure you are not trying to connect to the http endpoint " +
"(HTTP defaults to port 7474 whereas BOLT defaults to port 7687)"}
}
return nil, &errorutil.UsageError{Message: fmt.Sprintf("server responded with unsupported version %d.%d", major, minor)}
}
if err = boltConn.Connect(ctx, int(minor), auth, userAgent, routingContext, notificationConfig); err != nil {
Expand All @@ -115,3 +136,165 @@ func Connect(ctx context.Context,
}
return boltConn, nil
}

// performManifestNegotiation handles the manifest-style handshake.
// Returns the negotiated protocol's major and minor version.
func performManifestNegotiation(
ctx context.Context,
conn io.ReadWriteCloser,
serverName string,
errorListener ConnectionErrorListener,
boltLogger log.BoltLogger,
response []byte,
) (byte, byte, error) {
reader := racing.NewRacingReader(conn)

// Read the protocol offerings.
supported, err := readProtocolOfferings(ctx, reader, serverName, errorListener)
if err != nil {
return 0, 0, err
}

// Read the capability mask.
_, capBytes, err := readCapabilityMask(ctx, reader, serverName, errorListener)
if err != nil {
return 0, 0, err
}

// Log the complete server handshake message.
logManifestHandshake(boltLogger, response, supported, capBytes)

// Select an acceptable protocol version.
chosen, err := selectProtocol(supported)
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
// Best-effort attempt to send an invalid handshake (ignore any error).
invalidHandshake := []byte{0x00, 0x00, 0x00, 0x00, 0x00} // 4 bytes for version + 1 byte for capabilities.
_, _ = racing.NewRacingWriter(conn).Write(ctx, invalidHandshake)
return 0, 0, err
}

// Send the handshake confirmation.
if err = sendHandshakeConfirmation(ctx, conn, boltLogger, errorListener, serverName, chosen, []byte{0x00}); err != nil {
return 0, 0, err
}

return chosen.major, chosen.minor, nil
}

// readProtocolOfferings reads the number of protocol offerings and returns the count and
// a slice of supported protocol versions.
func readProtocolOfferings(ctx context.Context, r racing.RacingReader, serverName string, errorListener ConnectionErrorListener) ([]protocolVersion, error) {
count, _, err := readVarInt(ctx, r)
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
return nil, fmt.Errorf("failed to read manifest protocol count: %w", err)
}
supported := make([]protocolVersion, 0, count)
for i := uint64(0); i < count; i++ {
var versionBytes [4]byte
_, err := r.ReadFull(ctx, versionBytes[:])
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
return nil, fmt.Errorf("failed to read manifest protocol version: %w", err)
}
supported = append(supported, protocolVersion{
back: versionBytes[1],
minor: versionBytes[2],
major: versionBytes[3],
})
}
return supported, nil
}

// readCapabilityMask reads the capability bit mask (a Base128 VarInt) and returns both the
// raw value and its encoded byte slice.
func readCapabilityMask(ctx context.Context, r racing.RacingReader, serverName string, errorListener ConnectionErrorListener) (uint64, []byte, error) {
capMask, capBytes, err := readVarInt(ctx, r)
if err != nil {
errorListener.OnDialError(ctx, serverName, err)
return 0, capBytes, fmt.Errorf("failed to read capability mask: %w", err)
}
return capMask, capBytes, nil
}

// logManifestHandshake logs the complete server handshake message for manifest negotiation.
// It prints the initial response, count of offerings, each supported protocol, and the capability mask.
func logManifestHandshake(boltLogger log.BoltLogger, response []byte, supported []protocolVersion, capBytes []byte) {
if boltLogger == nil {
return
}
supportedProtocols := make([]string, 0, len(supported))
for _, p := range supported {
supportedProtocols = append(supportedProtocols, p.formatProtocol())
}
boltLogger.LogServerMessage("", "<HANDSHAKE> %s [%d] %s %s",
fmt.Sprintf("%#X", response),
len(supported),
strings.Join(supportedProtocols, " "),
fmt.Sprintf("%#X", capBytes))
}

// selectProtocol iterates over our protocol proposals (skipping the manifest marker)
// and returns the first candidate whose major version matches and whose minor version
// falls within the range offered by the server.
func selectProtocol(offers []protocolVersion) (protocolVersion, error) {
for _, candidate := range versions[1:] {
for v := int(candidate.minor); v >= int(candidate.minor)-int(candidate.back); v-- {
for _, offer := range offers {
if offer.major != candidate.major {
continue
}
if byte(v) <= offer.minor && byte(v) >= offer.minor-offer.back {
return protocolVersion{major: candidate.major, minor: byte(v)}, nil
}
}
}
}
return protocolVersion{}, fmt.Errorf("none of the server offered Bolt versions are supported (offered: %#v)", offers)
}

// sendHandshakeConfirmation sends the chosen protocol version and capability mask back to the server.
func sendHandshakeConfirmation(ctx context.Context, conn io.ReadWriteCloser, boltLogger log.BoltLogger, errorListener ConnectionErrorListener, serverName string, chosen protocolVersion, capBytes []byte) error {
chosenBytes := []byte{0x00, 0x00, chosen.minor, chosen.major}
if boltLogger != nil {
boltLogger.LogClientMessage("", "<HANDSHAKE> %#X %#X", chosenBytes, capBytes)
}
writer := racing.NewRacingWriter(conn)
if _, err := writer.Write(ctx, chosenBytes); err != nil {
errorListener.OnDialError(ctx, serverName, err)
return err
}
if _, err := writer.Write(ctx, capBytes); err != nil {
errorListener.OnDialError(ctx, serverName, err)
return err
}
return nil
}

// readVarInt returns a Base128-encoded variable-length integer from the reader.
func readVarInt(ctx context.Context, r racing.RacingReader) (uint64, []byte, error) {
var buf [binary.MaxVarintLen64]byte
// Read one byte at a time until a byte with the MSB not set is encountered.
for i := 0; i < len(buf); i++ {
if _, err := r.Read(ctx, buf[i:i+1]); err != nil {
return 0, buf[:i], err
}
// If the continuation bit is not set, we've reached the end of the varint.
if buf[i]&0x80 == 0 {
value, n := binary.Uvarint(buf[:i+1])
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
if n <= 0 {
return 0, buf[:i+1], fmt.Errorf("failed to decode varint")
}
return value, buf[:i+1], nil
}
}
return 0, buf[:], fmt.Errorf("varint too long")
}

// encodeVarInt returns the encoded unsigned integer into a Base128 variable-length integer.
func encodeVarInt(value uint64) []byte {
var buf [binary.MaxVarintLen64]byte
n := binary.PutUvarint(buf[:], value)
return buf[:n]
}
Loading