From 6b39f59ef17563d41c942898a5948302d01dba3f Mon Sep 17 00:00:00 2001 From: Jenita Date: Sat, 16 Aug 2025 20:33:37 -0500 Subject: [PATCH 1/8] feat: no error return on connection close Signed-off-by: Jenita --- protocol/blockfetch/blockfetch.go | 33 +++- protocol/blockfetch/blockfetch_test.go | 171 +++++++++++++++++ protocol/blockfetch/client.go | 23 ++- protocol/blockfetch/server.go | 17 ++ protocol/chainsync/chainsync.go | 34 +++- protocol/chainsync/chainsync_test.go | 184 ++++++++++++++++++ protocol/chainsync/client.go | 21 ++- protocol/chainsync/server.go | 21 ++- protocol/txsubmission/client.go | 19 +- protocol/txsubmission/server.go | 20 +- protocol/txsubmission/txsubmission.go | 33 +++- protocol/txsubmission/txsubmission_test.go | 209 +++++++++++++++++++++ 12 files changed, 773 insertions(+), 12 deletions(-) create mode 100644 protocol/blockfetch/blockfetch_test.go create mode 100644 protocol/chainsync/chainsync_test.go create mode 100644 protocol/txsubmission/txsubmission_test.go diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index e82b4b7c..06907f45 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -15,6 +15,9 @@ package blockfetch import ( + "errors" + "io" + "strings" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -76,7 +79,8 @@ var StateMap = protocol.StateMap{ }, }, StateDone: protocol.StateMapEntry{ - Agency: protocol.AgencyNone, + Agency: protocol.AgencyNone, + Transitions: []protocol.StateTransition{}, }, } @@ -118,6 +122,33 @@ func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch { return b } +func (b *BlockFetch) IsDone() bool { + if b.Client != nil && b.Client.IsDone() { + return true + } + if b.Server != nil && b.Server.IsDone() { + return true + } + return false +} + +func (b *BlockFetch) HandleConnectionError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, io.EOF) || isConnectionReset(err) { + if b.IsDone() { + return nil + } + } + return err +} + +func isConnectionReset(err error) bool { + return strings.Contains(err.Error(), "connection reset") || + strings.Contains(err.Error(), "broken pipe") +} + type BlockFetchOptionFunc func(*Config) func NewConfig(options ...BlockFetchOptionFunc) Config { diff --git a/protocol/blockfetch/blockfetch_test.go b/protocol/blockfetch/blockfetch_test.go new file mode 100644 index 00000000..e555a3dc --- /dev/null +++ b/protocol/blockfetch/blockfetch_test.go @@ -0,0 +1,171 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package blockfetch + +import ( + "errors" + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/ledger" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/common" + "github.com/stretchr/testify/assert" +) + +// testAddr implements net.Addr for testing +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +// testConn implements net.Conn for testing with buffered writes +type testConn struct { + writeChan chan []byte +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Write(b []byte) (n int, err error) { + c.writeChan <- b + return len(b), nil +} +func (c *testConn) Close() error { return nil } +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + } +} + +func TestNewBlockFetch(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + bf := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, bf.Client) + assert.NotNil(t, bf.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 5*time.Second, cfg.BatchStartTimeout) + assert.Equal(t, 60*time.Second, cfg.BlockTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithBatchStartTimeout(10*time.Second), + WithBlockTimeout(30*time.Second), + WithRecvQueueSize(100), + ) + assert.Equal(t, 10*time.Second, cfg.BatchStartTimeout) + assert.Equal(t, 30*time.Second, cfg.BlockTimeout) + assert.Equal(t, 100, cfg.RecvQueueSize) + }) +} + +func TestConnectionErrorHandling(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + bf := New(getTestProtocolOptions(conn), &cfg) + + t.Run("Non-EOF error when not done", func(t *testing.T) { + err := bf.HandleConnectionError(errors.New("test error")) + assert.Error(t, err) + }) + + t.Run("EOF error when not done", func(t *testing.T) { + err := bf.HandleConnectionError(io.EOF) + assert.Error(t, err) + }) + + t.Run("Connection reset error", func(t *testing.T) { + err := bf.HandleConnectionError(errors.New("connection reset by peer")) + assert.Error(t, err) + }) +} + +func TestCallbackRegistration(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + + t.Run("Block callback registration", func(t *testing.T) { + blockFunc := func(ctx CallbackContext, slot uint, block ledger.Block) error { + return nil + } + cfg := NewConfig(WithBlockFunc(blockFunc)) + client := NewClient(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, client) + assert.NotNil(t, cfg.BlockFunc) + }) + + t.Run("RequestRange callback registration", func(t *testing.T) { + requestRangeFunc := func(ctx CallbackContext, start, end common.Point) error { + return nil + } + cfg := NewConfig(WithRequestRangeFunc(requestRangeFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestRangeFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + client := NewClient(getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + // Start the client protocol + client.Start() + + // Send a done message + err := client.Protocol.SendMessage(NewMsgClientDone()) + assert.NoError(t, err) + + // Verify message was written to connection + select { + case <-conn.writeChan: + // Message was sent successfully + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message send") + } + }) +} + +func TestServerMessageHandling(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + server := NewServer(getTestProtocolOptions(conn), &cfg) + + t.Run("Server can be started", func(t *testing.T) { + server.Start() + + }) +} diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index 23f1091a..fb67a6a6 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -35,6 +35,8 @@ type Client struct { blockUseCallback bool onceStart sync.Once onceStop sync.Once + currentState protocol.State + stateMutex sync.Mutex } func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { @@ -46,6 +48,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { config: cfg, blockChan: make(chan ledger.Block), startBatchResultChan: make(chan error), + currentState: StateIdle, } c.callbackContext = CallbackContext{ Client: c, @@ -82,6 +85,18 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { return c } +func (c *Client) IsDone() bool { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + return c.currentState.Id == StateDone.Id +} + +func (c *Client) setState(newState protocol.State) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.currentState = newState +} + func (c *Client) Start() { c.onceStart.Do(func() { c.Protocol.Logger(). @@ -110,7 +125,11 @@ func (c *Client) Stop() error { "connection_id", c.callbackContext.ConnectionId.String(), ) msg := NewMsgClientDone() - err = c.SendMessage(msg) + if sendErr := c.SendMessage(msg); sendErr != nil { + err = sendErr + return + } + c.setState(StateDone) }) return err } @@ -196,6 +215,8 @@ func (c *Client) messageHandler(msg protocol.Message) error { err = c.handleBlock(msg) case MessageTypeBatchDone: err = c.handleBatchDone() + case MessageTypeClientDone: + c.setState(StateDone) default: err = fmt.Errorf( "%s: received unexpected message type %d", diff --git a/protocol/blockfetch/server.go b/protocol/blockfetch/server.go index e51f3d0e..65aeaded 100644 --- a/protocol/blockfetch/server.go +++ b/protocol/blockfetch/server.go @@ -17,6 +17,7 @@ package blockfetch import ( "errors" "fmt" + "sync" "github.com/blinklabs-io/gouroboros/cbor" "github.com/blinklabs-io/gouroboros/protocol" @@ -27,6 +28,8 @@ type Server struct { config *Config callbackContext CallbackContext protoOptions protocol.ProtocolOptions + currentState protocol.State + stateMutex sync.Mutex } func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { @@ -34,6 +37,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { config: cfg, // Save this for re-use later protoOptions: protoOptions, + currentState: StateIdle, } s.callbackContext = CallbackContext{ Server: s, @@ -43,6 +47,18 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { return s } +func (s *Server) IsDone() bool { + s.stateMutex.Lock() + defer s.stateMutex.Unlock() + return s.currentState.Id == StateDone.Id +} + +func (s *Server) setState(newState protocol.State) { + s.stateMutex.Lock() + defer s.stateMutex.Unlock() + s.currentState = newState +} + func (s *Server) initProtocol() { protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, @@ -126,6 +142,7 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeRequestRange: err = s.handleRequestRange(msg) case MessageTypeClientDone: + s.setState(StateDone) err = s.handleClientDone() default: err = fmt.Errorf( diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index 2f7443cc..a49b3792 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -16,6 +16,9 @@ package chainsync import ( + "errors" + "io" + "strings" "sync" "time" @@ -192,8 +195,10 @@ var PipelineIsNotEmpty = func(context any, msg protocol.Message) bool { // ChainSync is a wrapper object that holds the client and server instances type ChainSync struct { - Client *Client - Server *Server + Client *Client + Server *Server + stateMutex sync.Mutex + currentState protocol.State } // Config is used to configure the ChainSync protocol instance @@ -329,3 +334,28 @@ func WithRecvQueueSize(size int) ChainSyncOptionFunc { c.RecvQueueSize = size } } + +// HandleConnectionError handles connection errors and determines if they should be ignored +func (c *ChainSync) HandleConnectionError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, io.EOF) || isConnectionReset(err) { + if c.IsDone() { + return nil + } + } + return err +} + +// IsDone returns true if the protocol is in the Done state +func (c *ChainSync) IsDone() bool { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + return c.currentState.Id == stateDone.Id +} + +func isConnectionReset(err error) bool { + return strings.Contains(err.Error(), "connection reset") || + strings.Contains(err.Error(), "broken pipe") +} diff --git a/protocol/chainsync/chainsync_test.go b/protocol/chainsync/chainsync_test.go new file mode 100644 index 00000000..b988399d --- /dev/null +++ b/protocol/chainsync/chainsync_test.go @@ -0,0 +1,184 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package chainsync + +import ( + "errors" + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/common" + "github.com/stretchr/testify/assert" +) + +// testAddr implements net.Addr for testing +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +// testConn implements net.Conn for testing with buffered writes +type testConn struct { + writeChan chan []byte +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Write(b []byte) (n int, err error) { + c.writeChan <- b + return len(b), nil +} +func (c *testConn) Close() error { return nil } +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + } +} + +func TestNewChainSync(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + cs := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, cs.Client) + assert.NotNil(t, cs.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 5*time.Second, cfg.IntersectTimeout) + assert.Equal(t, 300*time.Second, cfg.BlockTimeout) + assert.Equal(t, 0, cfg.PipelineLimit) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithIntersectTimeout(10*time.Second), + WithBlockTimeout(30*time.Second), + WithPipelineLimit(100), + WithRecvQueueSize(50), + ) + assert.Equal(t, 10*time.Second, cfg.IntersectTimeout) + assert.Equal(t, 30*time.Second, cfg.BlockTimeout) + assert.Equal(t, 100, cfg.PipelineLimit) + assert.Equal(t, 50, cfg.RecvQueueSize) + }) +} + +func TestConnectionErrorHandling(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + cs := New(getTestProtocolOptions(conn), &cfg) + + t.Run("Non-EOF error when not done", func(t *testing.T) { + err := cs.HandleConnectionError(errors.New("test error")) + assert.Error(t, err) + }) + + t.Run("EOF error when not done", func(t *testing.T) { + err := cs.HandleConnectionError(io.EOF) + assert.Error(t, err) + }) + + t.Run("Connection reset error when not done", func(t *testing.T) { + err := cs.HandleConnectionError(errors.New("connection reset by peer")) + assert.Error(t, err) + }) + + t.Run("EOF error when done", func(t *testing.T) { + // Force done state + cs.currentState = stateDone + err := cs.HandleConnectionError(io.EOF) + assert.NoError(t, err) + }) +} + +func TestCallbackRegistration(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + + t.Run("RollForward callback registration", func(t *testing.T) { + rollForwardFunc := func(ctx CallbackContext, blockType uint, block any, tip Tip) error { + return nil + } + cfg := NewConfig(WithRollForwardFunc(rollForwardFunc)) + client := NewClient(&StateContext{}, getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, client) + assert.NotNil(t, cfg.RollForwardFunc) + }) + + t.Run("FindIntersect callback registration", func(t *testing.T) { + findIntersectFunc := func(ctx CallbackContext, points []common.Point) (common.Point, Tip, error) { + return common.Point{}, Tip{}, nil + } + cfg := NewConfig(WithFindIntersectFunc(findIntersectFunc)) + server := NewServer(&StateContext{}, getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.FindIntersectFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + client := NewClient(&StateContext{}, getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + // Start the client protocol + client.Start() + + // Create proper Done message + doneMsg := &MsgDone{} + + // Send a done message + err := client.Protocol.SendMessage(doneMsg) + assert.NoError(t, err) + + // Verify message was written to connection + select { + case <-conn.writeChan: + // Message was sent successfully + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message send") + } + }) +} + +func TestServerMessageHandling(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + server := NewServer(&StateContext{}, getTestProtocolOptions(conn), &cfg) + + t.Run("Server can be started", func(t *testing.T) { + server.Start() + // Simple test to verify server can be started without error + assert.NotNil(t, server) + }) +} diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 6a15fab2..0ca4330c 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -45,6 +45,8 @@ type Client struct { wantCurrentTipChan chan chan<- Tip wantFirstBlockChan chan chan<- clientPointResult wantIntersectFoundChan chan chan<- clientPointResult + stateMutex sync.Mutex + currentState protocol.State } type clientPointResult struct { @@ -79,6 +81,7 @@ func NewClient( wantCurrentTipChan: make(chan chan<- Tip, 1), wantFirstBlockChan: make(chan chan<- clientPointResult, 1), wantIntersectFoundChan: make(chan chan<- clientPointResult, 1), + currentState: stateIdle, } c.callbackContext = CallbackContext{ Client: c, @@ -118,6 +121,19 @@ func NewClient( return c } +// Add state tracking methods +func (c *Client) setState(newState protocol.State) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.currentState = newState +} + +func (c *Client) IsDone() bool { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + return c.currentState.Id == stateDone.Id +} + func (c *Client) Start() { c.onceStart.Do(func() { c.Protocol.Logger(). @@ -570,12 +586,13 @@ func (c *Client) messageHandler(msg protocol.Message) error { err = c.handleIntersectFound(msg) case MessageTypeIntersectNotFound: err = c.handleIntersectNotFound(msg) + case MessageTypeDone: + c.setState(stateDone) default: err = fmt.Errorf( "%s: received unexpected message type %d", ProtocolName, - msg.Type(), - ) + msg.Type()) } return err } diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index fb964bfc..8534c87f 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -17,6 +17,7 @@ package chainsync import ( "errors" "fmt" + "sync" "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" @@ -30,6 +31,8 @@ type Server struct { callbackContext CallbackContext protoOptions protocol.ProtocolOptions stateContext any + stateMutex sync.Mutex + currentState protocol.State } // NewServer returns a new ChainSync server object @@ -43,6 +46,7 @@ func NewServer( // Save these for re-use later protoOptions: protoOptions, stateContext: stateContext, + currentState: stateIdle, } s.callbackContext = CallbackContext{ Server: s, @@ -52,6 +56,18 @@ func NewServer( return s } +func (s *Server) setState(newState protocol.State) { + s.stateMutex.Lock() + defer s.stateMutex.Unlock() + s.currentState = newState +} + +func (s *Server) IsDone() bool { + s.stateMutex.Lock() + defer s.stateMutex.Unlock() + return s.currentState.Id == stateDone.Id +} + func (s *Server) initProtocol() { // Use node-to-client protocol ID ProtocolId := ProtocolIdNtC @@ -165,13 +181,13 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeFindIntersect: err = s.handleFindIntersect(msg) case MessageTypeDone: + s.setState(stateDone) err = s.handleDone() default: err = fmt.Errorf( "%s: received unexpected message type %d", ProtocolName, - msg.Type(), - ) + msg.Type()) } return err } @@ -230,6 +246,7 @@ func (s *Server) handleFindIntersect(msg protocol.Message) error { } func (s *Server) handleDone() error { + s.setState(stateDone) s.Protocol.Logger(). Debug("done", "component", "network", diff --git a/protocol/txsubmission/client.go b/protocol/txsubmission/client.go index d96d7641..446b13a6 100644 --- a/protocol/txsubmission/client.go +++ b/protocol/txsubmission/client.go @@ -28,6 +28,8 @@ type Client struct { config *Config callbackContext CallbackContext onceInit sync.Once + stateMutex sync.Mutex + currentState protocol.State } // NewClient returns a new TxSubmission client object @@ -37,7 +39,8 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { cfg = &tmpCfg } c := &Client{ - config: cfg, + config: cfg, + currentState: stateInit, } c.callbackContext = CallbackContext{ Client: c, @@ -83,6 +86,18 @@ func (c *Client) Init() { }) } +func (c *Client) setState(newState protocol.State) { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.currentState = newState +} + +func (c *Client) IsDone() bool { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + return c.currentState.Id == stateDone.Id +} + func (c *Client) messageHandler(msg protocol.Message) error { c.Protocol.Logger(). Debug(fmt.Sprintf("%s: client message for %+v", ProtocolName, c.callbackContext.ConnectionId.RemoteAddr)) @@ -92,6 +107,8 @@ func (c *Client) messageHandler(msg protocol.Message) error { err = c.handleRequestTxIds(msg) case MessageTypeRequestTxs: err = c.handleRequestTxs(msg) + case MessageTypeDone: + c.setState(stateDone) default: err = fmt.Errorf( "%s: received unexpected message type %d", diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index bdbd08f0..065de474 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "math" + "sync" "github.com/blinklabs-io/gouroboros/ledger/common" "github.com/blinklabs-io/gouroboros/protocol" @@ -32,6 +33,8 @@ type Server struct { ackCount int requestTxIdsResultChan chan requestTxIdsResult requestTxsResultChan chan []TxBody + stateMutex sync.Mutex + currentState protocol.State } type requestTxIdsResult struct { @@ -47,6 +50,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { protoOptions: protoOptions, requestTxIdsResultChan: make(chan requestTxIdsResult), requestTxsResultChan: make(chan []TxBody), + currentState: stateInit, } s.callbackContext = CallbackContext{ Server: s, @@ -164,6 +168,18 @@ func (s *Server) RequestTxs(txIds []TxId) ([]TxBody, error) { } } +func (s *Server) setState(newState protocol.State) { + s.stateMutex.Lock() + defer s.stateMutex.Unlock() + s.currentState = newState +} + +func (s *Server) IsDone() bool { + s.stateMutex.Lock() + defer s.stateMutex.Unlock() + return s.currentState.Id == stateDone.Id +} + func (s *Server) messageHandler(msg protocol.Message) error { var err error switch msg.Type() { @@ -172,9 +188,10 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeReplyTxs: err = s.handleReplyTxs(msg) case MessageTypeDone: - err = s.handleDone() + s.setState(stateDone) case MessageTypeInit: err = s.handleInit() + default: err = fmt.Errorf( "%s: received unexpected message type %d", @@ -214,6 +231,7 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error { } func (s *Server) handleDone() error { + s.setState(stateDone) s.Protocol.Logger(). Debug("done", "component", "network", diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index 2c5b9116..3591e1ec 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -16,6 +16,10 @@ package txsubmission import ( + "errors" + "io" + "strings" + "sync" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -113,8 +117,10 @@ var StateMap = protocol.StateMap{ // TxSubmission is a wrapper object that holds the client and server instances type TxSubmission struct { - Client *Client - Server *Server + Client *Client + Server *Server + stateMutex sync.Mutex + currentState protocol.State } // Config is used to configure the TxSubmission protocol instance @@ -201,3 +207,26 @@ func WithIdleTimeout(timeout time.Duration) TxSubmissionOptionFunc { c.IdleTimeout = timeout } } + +func (t *TxSubmission) HandleConnectionError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, io.EOF) || isConnectionReset(err) { + if t.IsDone() { + return nil + } + } + return err +} + +func (t *TxSubmission) IsDone() bool { + t.stateMutex.Lock() + defer t.stateMutex.Unlock() + return t.currentState.Id == stateDone.Id +} + +func isConnectionReset(err error) bool { + return strings.Contains(err.Error(), "connection reset") || + strings.Contains(err.Error(), "broken pipe") +} diff --git a/protocol/txsubmission/txsubmission_test.go b/protocol/txsubmission/txsubmission_test.go new file mode 100644 index 00000000..eb8a3a9c --- /dev/null +++ b/protocol/txsubmission/txsubmission_test.go @@ -0,0 +1,209 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package txsubmission + +import ( + "errors" + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/stretchr/testify/assert" +) + +// testAddr implements net.Addr for testing +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +// testConn implements net.Conn for testing with buffered writes +type testConn struct { + writeChan chan []byte +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Write(b []byte) (n int, err error) { + c.writeChan <- b + return len(b), nil +} +func (c *testConn) Close() error { return nil } +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + } +} + +func TestNewTxSubmission(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + ts := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, ts.Client) + assert.NotNil(t, ts.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 300*time.Second, cfg.IdleTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithIdleTimeout(60*time.Second), + WithRequestTxIdsFunc(func(ctx CallbackContext, blocking bool, ack, req uint16) ([]TxIdAndSize, error) { + return nil, nil + }), + WithRequestTxsFunc(func(ctx CallbackContext, txIds []TxId) ([]TxBody, error) { + return nil, nil + }), + ) + assert.Equal(t, 60*time.Second, cfg.IdleTimeout) + assert.NotNil(t, cfg.RequestTxIdsFunc) + assert.NotNil(t, cfg.RequestTxsFunc) + }) +} + +func TestConnectionErrorHandling(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + ts := New(getTestProtocolOptions(conn), &cfg) + + t.Run("Non-EOF error when not done", func(t *testing.T) { + err := ts.HandleConnectionError(errors.New("test error")) + assert.Error(t, err) + }) + + t.Run("EOF error when not done", func(t *testing.T) { + err := ts.HandleConnectionError(io.EOF) + assert.Error(t, err) + }) + + t.Run("Connection reset error when not done", func(t *testing.T) { + err := ts.HandleConnectionError(errors.New("connection reset by peer")) + assert.Error(t, err) + }) + + t.Run("EOF error when done", func(t *testing.T) { + ts.currentState = stateDone + err := ts.HandleConnectionError(io.EOF) + assert.NoError(t, err) + }) +} + +func TestCallbackRegistration(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + + t.Run("RequestTxIds callback registration", func(t *testing.T) { + requestTxIdsFunc := func(ctx CallbackContext, blocking bool, ack, req uint16) ([]TxIdAndSize, error) { + return nil, nil + } + cfg := NewConfig(WithRequestTxIdsFunc(requestTxIdsFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestTxIdsFunc) + }) + + t.Run("RequestTxs callback registration", func(t *testing.T) { + requestTxsFunc := func(ctx CallbackContext, txIds []TxId) ([]TxBody, error) { + return nil, nil + } + cfg := NewConfig(WithRequestTxsFunc(requestTxsFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestTxsFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + client := NewClient(getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + // Start the client protocol + client.Start() + + // Send an init message + initMsg := NewMsgInit() + err := client.Protocol.SendMessage(initMsg) + assert.NoError(t, err) + + // Verify message was written to connection + select { + case <-conn.writeChan: + // Message was sent successfully + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message send") + } + }) +} + +func TestServerMessageHandling(t *testing.T) { + conn := &testConn{writeChan: make(chan []byte, 1)} + cfg := NewConfig() + server := NewServer(getTestProtocolOptions(conn), &cfg) + + t.Run("Server can be started", func(t *testing.T) { + server.Start() + assert.NotNil(t, server) + }) +} + +func TestIsConnectionReset(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "connection reset", + err: errors.New("connection reset by peer"), + expected: true, + }, + { + name: "broken pipe", + err: errors.New("broken pipe"), + expected: true, + }, + { + name: "other error", + err: errors.New("other error"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isConnectionReset(tt.err)) + }) + } +} From c42b2427c1f7470a1a5f20334c39bfe24aab6671 Mon Sep 17 00:00:00 2001 From: Jenita Date: Sat, 16 Aug 2025 20:42:18 -0500 Subject: [PATCH 2/8] feat: no error return on connection close Signed-off-by: Jenita --- protocol/blockfetch/blockfetch.go | 3 +-- protocol/chainsync/server.go | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index 06907f45..10f89383 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -79,8 +79,7 @@ var StateMap = protocol.StateMap{ }, }, StateDone: protocol.StateMapEntry{ - Agency: protocol.AgencyNone, - Transitions: []protocol.StateTransition{}, + Agency: protocol.AgencyNone, }, } diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index 8534c87f..56f783e1 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -187,7 +187,8 @@ func (s *Server) messageHandler(msg protocol.Message) error { err = fmt.Errorf( "%s: received unexpected message type %d", ProtocolName, - msg.Type()) + msg.Type(), + ) } return err } From 829e00e2e93cf78578deb10908a3ac76557b7aa9 Mon Sep 17 00:00:00 2001 From: Jenita Date: Sat, 16 Aug 2025 20:48:07 -0500 Subject: [PATCH 3/8] feat: no error return on connection close Signed-off-by: Jenita --- protocol/txsubmission/server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index 065de474..46625cc8 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -189,6 +189,7 @@ func (s *Server) messageHandler(msg protocol.Message) error { err = s.handleReplyTxs(msg) case MessageTypeDone: s.setState(stateDone) + err = s.handleDone() case MessageTypeInit: err = s.handleInit() From b255d6652b283020a31789fc439af38717a17765 Mon Sep 17 00:00:00 2001 From: Jenita Date: Mon, 18 Aug 2025 20:39:40 -0500 Subject: [PATCH 4/8] feat: no error return on connection close Signed-off-by: Jenita --- protocol/blockfetch/blockfetch.go | 19 ++-- protocol/blockfetch/blockfetch_test.go | 76 +++++++++++---- protocol/blockfetch/client.go | 27 ++---- protocol/blockfetch/server.go | 20 +--- protocol/chainsync/chainsync.go | 23 ++--- protocol/chainsync/chainsync_test.go | 104 +++++++++++++-------- protocol/chainsync/client.go | 31 ++---- protocol/chainsync/server.go | 26 +----- protocol/protocol.go | 41 +++++++- protocol/txsubmission/client.go | 21 +---- protocol/txsubmission/server.go | 21 +---- protocol/txsubmission/txsubmission.go | 14 +-- protocol/txsubmission/txsubmission_test.go | 97 +++++++++---------- 13 files changed, 248 insertions(+), 272 deletions(-) diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index 10f89383..5a514ca1 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -121,24 +121,17 @@ func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch { return b } -func (b *BlockFetch) IsDone() bool { - if b.Client != nil && b.Client.IsDone() { - return true - } - if b.Server != nil && b.Server.IsDone() { - return true - } - return false -} - func (b *BlockFetch) HandleConnectionError(err error) error { if err == nil { return nil } + // Check if protocol is done or if it's a normal connection closure + if b.Client.IsDone() || b.Server.IsDone() { + return nil + } + if errors.Is(err, io.EOF) || isConnectionReset(err) { - if b.IsDone() { - return nil - } + return err } return err } diff --git a/protocol/blockfetch/blockfetch_test.go b/protocol/blockfetch/blockfetch_test.go index e555a3dc..fc78708f 100644 --- a/protocol/blockfetch/blockfetch_test.go +++ b/protocol/blockfetch/blockfetch_test.go @@ -38,14 +38,33 @@ func (a testAddr) String() string { return "test-addr" } // testConn implements net.Conn for testing with buffered writes type testConn struct { writeChan chan []byte + closed bool + closeChan chan struct{} +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } } func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } func (c *testConn) Write(b []byte) (n int, err error) { - c.writeChan <- b - return len(b), nil + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} +func (c *testConn) Close() error { + if !c.closed { + close(c.closeChan) + c.closed = true + } + return nil } -func (c *testConn) Close() error { return nil } func (c *testConn) LocalAddr() net.Addr { return testAddr{} } func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } func (c *testConn) SetDeadline(t time.Time) error { return nil } @@ -54,6 +73,12 @@ func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { mux := muxer.New(conn) + go mux.Start() + go func() { + <-conn.(*testConn).closeChan + mux.Stop() + }() + return protocol.ProtocolOptions{ ConnectionId: connection.ConnectionId{ LocalAddr: testAddr{}, @@ -65,7 +90,8 @@ func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { } func TestNewBlockFetch(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() bf := New(getTestProtocolOptions(conn), &cfg) assert.NotNil(t, bf.Client) @@ -92,10 +118,17 @@ func TestConfigOptions(t *testing.T) { } func TestConnectionErrorHandling(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() bf := New(getTestProtocolOptions(conn), &cfg) + // Start protocols + bf.Client.Start() + defer bf.Client.Stop() + bf.Server.Start() + defer bf.Server.Stop() + t.Run("Non-EOF error when not done", func(t *testing.T) { err := bf.HandleConnectionError(errors.New("test error")) assert.Error(t, err) @@ -110,10 +143,23 @@ func TestConnectionErrorHandling(t *testing.T) { err := bf.HandleConnectionError(errors.New("connection reset by peer")) assert.Error(t, err) }) + + t.Run("EOF error when done", func(t *testing.T) { + // Send done message to properly transition to done state + err := bf.Client.SendMessage(NewMsgClientDone()) + assert.NoError(t, err) + + // Wait for state transition + time.Sleep(100 * time.Millisecond) + + err = bf.HandleConnectionError(io.EOF) + assert.NoError(t, err, "expected no error when protocol is in done state") + }) } func TestCallbackRegistration(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() t.Run("Block callback registration", func(t *testing.T) { blockFunc := func(ctx CallbackContext, slot uint, block ledger.Block) error { @@ -137,16 +183,17 @@ func TestCallbackRegistration(t *testing.T) { } func TestClientMessageSending(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() client := NewClient(getTestProtocolOptions(conn), &cfg) t.Run("Client can send messages", func(t *testing.T) { - // Start the client protocol client.Start() + defer client.Stop() // Send a done message - err := client.Protocol.SendMessage(NewMsgClientDone()) + err := client.SendMessage(NewMsgClientDone()) assert.NoError(t, err) // Verify message was written to connection @@ -158,14 +205,3 @@ func TestClientMessageSending(t *testing.T) { } }) } - -func TestServerMessageHandling(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} - cfg := NewConfig() - server := NewServer(getTestProtocolOptions(conn), &cfg) - - t.Run("Server can be started", func(t *testing.T) { - server.Start() - - }) -} diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index fb67a6a6..0af7a3c9 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -35,8 +35,6 @@ type Client struct { blockUseCallback bool onceStart sync.Once onceStop sync.Once - currentState protocol.State - stateMutex sync.Mutex } func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { @@ -48,7 +46,6 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { config: cfg, blockChan: make(chan ledger.Block), startBatchResultChan: make(chan error), - currentState: StateIdle, } c.callbackContext = CallbackContext{ Client: c, @@ -85,18 +82,6 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { return c } -func (c *Client) IsDone() bool { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - return c.currentState.Id == StateDone.Id -} - -func (c *Client) setState(newState protocol.State) { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - c.currentState = newState -} - func (c *Client) Start() { c.onceStart.Do(func() { c.Protocol.Logger(). @@ -124,12 +109,12 @@ func (c *Client) Stop() error { "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) - msg := NewMsgClientDone() - if sendErr := c.SendMessage(msg); sendErr != nil { - err = sendErr - return + if !c.IsDone() { + msg := NewMsgClientDone() + if err = c.SendMessage(msg); err != nil { + return + } } - c.setState(StateDone) }) return err } @@ -216,7 +201,7 @@ func (c *Client) messageHandler(msg protocol.Message) error { case MessageTypeBatchDone: err = c.handleBatchDone() case MessageTypeClientDone: - c.setState(StateDone) + return nil default: err = fmt.Errorf( "%s: received unexpected message type %d", diff --git a/protocol/blockfetch/server.go b/protocol/blockfetch/server.go index 65aeaded..0500b864 100644 --- a/protocol/blockfetch/server.go +++ b/protocol/blockfetch/server.go @@ -17,7 +17,6 @@ package blockfetch import ( "errors" "fmt" - "sync" "github.com/blinklabs-io/gouroboros/cbor" "github.com/blinklabs-io/gouroboros/protocol" @@ -28,8 +27,6 @@ type Server struct { config *Config callbackContext CallbackContext protoOptions protocol.ProtocolOptions - currentState protocol.State - stateMutex sync.Mutex } func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { @@ -37,7 +34,6 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { config: cfg, // Save this for re-use later protoOptions: protoOptions, - currentState: StateIdle, } s.callbackContext = CallbackContext{ Server: s, @@ -47,18 +43,6 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { return s } -func (s *Server) IsDone() bool { - s.stateMutex.Lock() - defer s.stateMutex.Unlock() - return s.currentState.Id == StateDone.Id -} - -func (s *Server) setState(newState protocol.State) { - s.stateMutex.Lock() - defer s.stateMutex.Unlock() - s.currentState = newState -} - func (s *Server) initProtocol() { protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, @@ -142,8 +126,8 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeRequestRange: err = s.handleRequestRange(msg) case MessageTypeClientDone: - s.setState(StateDone) - err = s.handleClientDone() + // State handled automatically by base protocol + return s.handleClientDone() default: err = fmt.Errorf( "%s: received unexpected message type %d", diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index a49b3792..7d8b97e4 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -195,13 +195,10 @@ var PipelineIsNotEmpty = func(context any, msg protocol.Message) bool { // ChainSync is a wrapper object that holds the client and server instances type ChainSync struct { - Client *Client - Server *Server - stateMutex sync.Mutex - currentState protocol.State + Client *Client + Server *Server } -// Config is used to configure the ChainSync protocol instance type Config struct { RollBackwardFunc RollBackwardFunc RollForwardFunc RollForwardFunc @@ -335,26 +332,20 @@ func WithRecvQueueSize(size int) ChainSyncOptionFunc { } } -// HandleConnectionError handles connection errors and determines if they should be ignored func (c *ChainSync) HandleConnectionError(err error) error { if err == nil { return nil } + if c.Client.IsDone() || c.Server.IsDone() { + return nil + } + if errors.Is(err, io.EOF) || isConnectionReset(err) { - if c.IsDone() { - return nil - } + return err } return err } -// IsDone returns true if the protocol is in the Done state -func (c *ChainSync) IsDone() bool { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - return c.currentState.Id == stateDone.Id -} - func isConnectionReset(err error) bool { return strings.Contains(err.Error(), "connection reset") || strings.Contains(err.Error(), "broken pipe") diff --git a/protocol/chainsync/chainsync_test.go b/protocol/chainsync/chainsync_test.go index b988399d..58be64eb 100644 --- a/protocol/chainsync/chainsync_test.go +++ b/protocol/chainsync/chainsync_test.go @@ -4,13 +4,14 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package chainsync import ( @@ -28,23 +29,40 @@ import ( "github.com/stretchr/testify/assert" ) -// testAddr implements net.Addr for testing type testAddr struct{} func (a testAddr) Network() string { return "test" } func (a testAddr) String() string { return "test-addr" } -// testConn implements net.Conn for testing with buffered writes type testConn struct { writeChan chan []byte + closed bool + closeChan chan struct{} +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } } func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } func (c *testConn) Write(b []byte) (n int, err error) { - c.writeChan <- b - return len(b), nil + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} +func (c *testConn) Close() error { + if !c.closed { + close(c.closeChan) + c.closed = true + } + return nil } -func (c *testConn) Close() error { return nil } func (c *testConn) LocalAddr() net.Addr { return testAddr{} } func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } func (c *testConn) SetDeadline(t time.Time) error { return nil } @@ -53,6 +71,12 @@ func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { mux := muxer.New(conn) + go mux.Start() + go func() { + <-conn.(*testConn).closeChan + mux.Stop() + }() + return protocol.ProtocolOptions{ ConnectionId: connection.ConnectionId{ LocalAddr: testAddr{}, @@ -60,11 +84,13 @@ func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { }, Muxer: mux, Logger: slog.Default(), + Mode: protocol.ProtocolModeNodeToClient, } } func TestNewChainSync(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() cs := New(getTestProtocolOptions(conn), &cfg) assert.NotNil(t, cs.Client) @@ -76,28 +102,34 @@ func TestConfigOptions(t *testing.T) { cfg := NewConfig() assert.Equal(t, 5*time.Second, cfg.IntersectTimeout) assert.Equal(t, 300*time.Second, cfg.BlockTimeout) - assert.Equal(t, 0, cfg.PipelineLimit) }) t.Run("Custom config", func(t *testing.T) { cfg := NewConfig( WithIntersectTimeout(10*time.Second), WithBlockTimeout(30*time.Second), - WithPipelineLimit(100), - WithRecvQueueSize(50), + WithPipelineLimit(10), + WithRecvQueueSize(100), ) assert.Equal(t, 10*time.Second, cfg.IntersectTimeout) assert.Equal(t, 30*time.Second, cfg.BlockTimeout) - assert.Equal(t, 100, cfg.PipelineLimit) - assert.Equal(t, 50, cfg.RecvQueueSize) + assert.Equal(t, 10, cfg.PipelineLimit) + assert.Equal(t, 100, cfg.RecvQueueSize) }) } func TestConnectionErrorHandling(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() cs := New(getTestProtocolOptions(conn), &cfg) + // Start protocols + cs.Client.Start() + defer cs.Client.Stop() + cs.Server.Start() + defer cs.Server.Stop() + t.Run("Non-EOF error when not done", func(t *testing.T) { err := cs.HandleConnectionError(errors.New("test error")) assert.Error(t, err) @@ -108,28 +140,34 @@ func TestConnectionErrorHandling(t *testing.T) { assert.Error(t, err) }) - t.Run("Connection reset error when not done", func(t *testing.T) { + t.Run("Connection reset error", func(t *testing.T) { err := cs.HandleConnectionError(errors.New("connection reset by peer")) assert.Error(t, err) }) t.Run("EOF error when done", func(t *testing.T) { - // Force done state - cs.currentState = stateDone - err := cs.HandleConnectionError(io.EOF) + // Send done message to properly transition to done state + err := cs.Client.SendMessage(NewMsgDone()) assert.NoError(t, err) + + // Wait for state transition + time.Sleep(100 * time.Millisecond) + + err = cs.HandleConnectionError(io.EOF) + assert.NoError(t, err, "expected no error when protocol is in done state") }) } func TestCallbackRegistration(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() t.Run("RollForward callback registration", func(t *testing.T) { - rollForwardFunc := func(ctx CallbackContext, blockType uint, block any, tip Tip) error { + rollForwardFunc := func(ctx CallbackContext, blockType uint, blockData any, tip Tip) error { return nil } cfg := NewConfig(WithRollForwardFunc(rollForwardFunc)) - client := NewClient(&StateContext{}, getTestProtocolOptions(conn), &cfg) + client := NewClient(nil, getTestProtocolOptions(conn), &cfg) assert.NotNil(t, client) assert.NotNil(t, cfg.RollForwardFunc) }) @@ -139,26 +177,24 @@ func TestCallbackRegistration(t *testing.T) { return common.Point{}, Tip{}, nil } cfg := NewConfig(WithFindIntersectFunc(findIntersectFunc)) - server := NewServer(&StateContext{}, getTestProtocolOptions(conn), &cfg) + server := NewServer(nil, getTestProtocolOptions(conn), &cfg) assert.NotNil(t, server) assert.NotNil(t, cfg.FindIntersectFunc) }) } func TestClientMessageSending(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() - client := NewClient(&StateContext{}, getTestProtocolOptions(conn), &cfg) + client := NewClient(nil, getTestProtocolOptions(conn), &cfg) t.Run("Client can send messages", func(t *testing.T) { - // Start the client protocol client.Start() - - // Create proper Done message - doneMsg := &MsgDone{} + defer client.Stop() // Send a done message - err := client.Protocol.SendMessage(doneMsg) + err := client.SendMessage(NewMsgDone()) assert.NoError(t, err) // Verify message was written to connection @@ -170,15 +206,3 @@ func TestClientMessageSending(t *testing.T) { } }) } - -func TestServerMessageHandling(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} - cfg := NewConfig() - server := NewServer(&StateContext{}, getTestProtocolOptions(conn), &cfg) - - t.Run("Server can be started", func(t *testing.T) { - server.Start() - // Simple test to verify server can be started without error - assert.NotNil(t, server) - }) -} diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 0ca4330c..7d48db3f 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -45,8 +45,6 @@ type Client struct { wantCurrentTipChan chan chan<- Tip wantFirstBlockChan chan chan<- clientPointResult wantIntersectFoundChan chan chan<- clientPointResult - stateMutex sync.Mutex - currentState protocol.State } type clientPointResult struct { @@ -81,7 +79,6 @@ func NewClient( wantCurrentTipChan: make(chan chan<- Tip, 1), wantFirstBlockChan: make(chan chan<- clientPointResult, 1), wantIntersectFoundChan: make(chan chan<- clientPointResult, 1), - currentState: stateIdle, } c.callbackContext = CallbackContext{ Client: c, @@ -121,19 +118,6 @@ func NewClient( return c } -// Add state tracking methods -func (c *Client) setState(newState protocol.State) { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - c.currentState = newState -} - -func (c *Client) IsDone() bool { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - return c.currentState.Id == stateDone.Id -} - func (c *Client) Start() { c.onceStart.Do(func() { c.Protocol.Logger(). @@ -163,9 +147,11 @@ func (c *Client) Stop() error { ) c.busyMutex.Lock() defer c.busyMutex.Unlock() - msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return + if !c.IsDone() { + msg := NewMsgDone() + if err = c.SendMessage(msg); err != nil { + return + } } }) return err @@ -587,12 +573,9 @@ func (c *Client) messageHandler(msg protocol.Message) error { case MessageTypeIntersectNotFound: err = c.handleIntersectNotFound(msg) case MessageTypeDone: - c.setState(stateDone) + return nil default: - err = fmt.Errorf( - "%s: received unexpected message type %d", - ProtocolName, - msg.Type()) + err = fmt.Errorf("%s: received unexpected message type %d", ProtocolName, msg.Type()) } return err } diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index 56f783e1..f8f20fa0 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -17,7 +17,6 @@ package chainsync import ( "errors" "fmt" - "sync" "github.com/blinklabs-io/gouroboros/ledger" "github.com/blinklabs-io/gouroboros/protocol" @@ -31,8 +30,6 @@ type Server struct { callbackContext CallbackContext protoOptions protocol.ProtocolOptions stateContext any - stateMutex sync.Mutex - currentState protocol.State } // NewServer returns a new ChainSync server object @@ -46,7 +43,6 @@ func NewServer( // Save these for re-use later protoOptions: protoOptions, stateContext: stateContext, - currentState: stateIdle, } s.callbackContext = CallbackContext{ Server: s, @@ -56,18 +52,6 @@ func NewServer( return s } -func (s *Server) setState(newState protocol.State) { - s.stateMutex.Lock() - defer s.stateMutex.Unlock() - s.currentState = newState -} - -func (s *Server) IsDone() bool { - s.stateMutex.Lock() - defer s.stateMutex.Unlock() - return s.currentState.Id == stateDone.Id -} - func (s *Server) initProtocol() { // Use node-to-client protocol ID ProtocolId := ProtocolIdNtC @@ -181,14 +165,9 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeFindIntersect: err = s.handleFindIntersect(msg) case MessageTypeDone: - s.setState(stateDone) - err = s.handleDone() + return s.handleDone() default: - err = fmt.Errorf( - "%s: received unexpected message type %d", - ProtocolName, - msg.Type(), - ) + err = fmt.Errorf("%s: received unexpected message type %d", ProtocolName, msg.Type()) } return err } @@ -247,7 +226,6 @@ func (s *Server) handleFindIntersect(msg protocol.Message) error { } func (s *Server) handleDone() error { - s.setState(stateDone) s.Protocol.Logger(). Debug("done", "component", "network", diff --git a/protocol/protocol.go b/protocol/protocol.go index 4f1d524c..ffcca85a 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -102,8 +102,9 @@ type ProtocolOptions struct { } type protocolStateTransition struct { - msg Message - errorChan chan<- error + msg Message + errorChan chan<- error + stateRespChan chan<- State } // MessageHandlerFunc represents a function that handles an incoming message @@ -126,6 +127,34 @@ func New(config ProtocolConfig) *Protocol { return p } +// CurrentState returns the current protocol state +func (p *Protocol) CurrentState() State { + stateChan := make(chan State, 1) + p.stateTransitionChan <- protocolStateTransition{ + stateRespChan: stateChan, + } + return <-stateChan +} + +// IsDone checks if the protocol is in a done/completed state +func (p *Protocol) IsDone() bool { + currentState := p.CurrentState() + if entry, exists := p.config.StateMap[currentState]; exists { + return entry.Agency == AgencyNone + } + return false +} + +// GetDoneState returns the done state from the state map +func (s StateMap) GetDoneState() State { + for state, entry := range s { + if entry.Agency == AgencyNone { + return state + } + } + return State{} +} + // Start initializes the mini-protocol func (p *Protocol) Start() { p.onceStart.Do(func() { @@ -514,6 +543,12 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { for { select { case t := <-ch: + if t.msg == nil && t.stateRespChan != nil { + // Handle state request + t.stateRespChan <- currentState + continue + } + nextState, err := p.nextState(currentState, t.msg) if err != nil { t.errorChan <- fmt.Errorf( @@ -574,7 +609,7 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) { func (p *Protocol) transitionState(msg Message) error { errorChan := make(chan error, 1) - p.stateTransitionChan <- protocolStateTransition{msg, errorChan} + p.stateTransitionChan <- protocolStateTransition{msg, errorChan, nil} return <-errorChan } diff --git a/protocol/txsubmission/client.go b/protocol/txsubmission/client.go index 446b13a6..77453868 100644 --- a/protocol/txsubmission/client.go +++ b/protocol/txsubmission/client.go @@ -28,8 +28,6 @@ type Client struct { config *Config callbackContext CallbackContext onceInit sync.Once - stateMutex sync.Mutex - currentState protocol.State } // NewClient returns a new TxSubmission client object @@ -39,8 +37,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { cfg = &tmpCfg } c := &Client{ - config: cfg, - currentState: stateInit, + config: cfg, } c.callbackContext = CallbackContext{ Client: c, @@ -86,21 +83,10 @@ func (c *Client) Init() { }) } -func (c *Client) setState(newState protocol.State) { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - c.currentState = newState -} - -func (c *Client) IsDone() bool { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - return c.currentState.Id == stateDone.Id -} - func (c *Client) messageHandler(msg protocol.Message) error { c.Protocol.Logger(). Debug(fmt.Sprintf("%s: client message for %+v", ProtocolName, c.callbackContext.ConnectionId.RemoteAddr)) + var err error switch msg.Type() { case MessageTypeRequestTxIds: @@ -108,7 +94,8 @@ func (c *Client) messageHandler(msg protocol.Message) error { case MessageTypeRequestTxs: err = c.handleRequestTxs(msg) case MessageTypeDone: - c.setState(stateDone) + // No need to set state here, base Protocol handles it + return nil default: err = fmt.Errorf( "%s: received unexpected message type %d", diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index 46625cc8..ad694079 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -18,7 +18,6 @@ import ( "errors" "fmt" "math" - "sync" "github.com/blinklabs-io/gouroboros/ledger/common" "github.com/blinklabs-io/gouroboros/protocol" @@ -33,8 +32,6 @@ type Server struct { ackCount int requestTxIdsResultChan chan requestTxIdsResult requestTxsResultChan chan []TxBody - stateMutex sync.Mutex - currentState protocol.State } type requestTxIdsResult struct { @@ -50,7 +47,6 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { protoOptions: protoOptions, requestTxIdsResultChan: make(chan requestTxIdsResult), requestTxsResultChan: make(chan []TxBody), - currentState: stateInit, } s.callbackContext = CallbackContext{ Server: s, @@ -168,18 +164,6 @@ func (s *Server) RequestTxs(txIds []TxId) ([]TxBody, error) { } } -func (s *Server) setState(newState protocol.State) { - s.stateMutex.Lock() - defer s.stateMutex.Unlock() - s.currentState = newState -} - -func (s *Server) IsDone() bool { - s.stateMutex.Lock() - defer s.stateMutex.Unlock() - return s.currentState.Id == stateDone.Id -} - func (s *Server) messageHandler(msg protocol.Message) error { var err error switch msg.Type() { @@ -188,11 +172,9 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeReplyTxs: err = s.handleReplyTxs(msg) case MessageTypeDone: - s.setState(stateDone) - err = s.handleDone() + return s.handleDone() case MessageTypeInit: err = s.handleInit() - default: err = fmt.Errorf( "%s: received unexpected message type %d", @@ -232,7 +214,6 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error { } func (s *Server) handleDone() error { - s.setState(stateDone) s.Protocol.Logger(). Debug("done", "component", "network", diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index 3591e1ec..55c977bc 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -212,20 +212,16 @@ func (t *TxSubmission) HandleConnectionError(err error) error { if err == nil { return nil } + if t.Client.IsDone() || t.Server.IsDone() { + return nil + } + if errors.Is(err, io.EOF) || isConnectionReset(err) { - if t.IsDone() { - return nil - } + return err } return err } -func (t *TxSubmission) IsDone() bool { - t.stateMutex.Lock() - defer t.stateMutex.Unlock() - return t.currentState.Id == stateDone.Id -} - func isConnectionReset(err error) bool { return strings.Contains(err.Error(), "connection reset") || strings.Contains(err.Error(), "broken pipe") diff --git a/protocol/txsubmission/txsubmission_test.go b/protocol/txsubmission/txsubmission_test.go index eb8a3a9c..106b00fb 100644 --- a/protocol/txsubmission/txsubmission_test.go +++ b/protocol/txsubmission/txsubmission_test.go @@ -18,6 +18,7 @@ import ( "io" "log/slog" "net" + "sync" "testing" "time" @@ -25,31 +26,59 @@ import ( "github.com/blinklabs-io/gouroboros/muxer" "github.com/blinklabs-io/gouroboros/protocol" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -// testAddr implements net.Addr for testing type testAddr struct{} func (a testAddr) Network() string { return "test" } func (a testAddr) String() string { return "test-addr" } -// testConn implements net.Conn for testing with buffered writes type testConn struct { writeChan chan []byte + closed bool + closeChan chan struct{} + closeOnce sync.Once + mu sync.Mutex +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } } func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } -func (c *testConn) Write(b []byte) (n int, err error) { - c.writeChan <- b - return len(b), nil +func (c *testConn) Close() error { + c.closeOnce.Do(func() { + c.mu.Lock() + defer c.mu.Unlock() + close(c.closeChan) + c.closed = true + }) + return nil } -func (c *testConn) Close() error { return nil } func (c *testConn) LocalAddr() net.Addr { return testAddr{} } func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } func (c *testConn) SetDeadline(t time.Time) error { return nil } func (c *testConn) SetReadDeadline(t time.Time) error { return nil } func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } +func (c *testConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, io.EOF + } + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} + func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { mux := muxer.New(conn) return protocol.ProtocolOptions{ @@ -63,7 +92,8 @@ func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { } func TestNewTxSubmission(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() ts := New(getTestProtocolOptions(conn), &cfg) assert.NotNil(t, ts.Client) @@ -91,36 +121,9 @@ func TestConfigOptions(t *testing.T) { assert.NotNil(t, cfg.RequestTxsFunc) }) } - -func TestConnectionErrorHandling(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} - cfg := NewConfig() - ts := New(getTestProtocolOptions(conn), &cfg) - - t.Run("Non-EOF error when not done", func(t *testing.T) { - err := ts.HandleConnectionError(errors.New("test error")) - assert.Error(t, err) - }) - - t.Run("EOF error when not done", func(t *testing.T) { - err := ts.HandleConnectionError(io.EOF) - assert.Error(t, err) - }) - - t.Run("Connection reset error when not done", func(t *testing.T) { - err := ts.HandleConnectionError(errors.New("connection reset by peer")) - assert.Error(t, err) - }) - - t.Run("EOF error when done", func(t *testing.T) { - ts.currentState = stateDone - err := ts.HandleConnectionError(io.EOF) - assert.NoError(t, err) - }) -} - func TestCallbackRegistration(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() t.Run("RequestTxIds callback registration", func(t *testing.T) { requestTxIdsFunc := func(ctx CallbackContext, blocking bool, ack, req uint16) ([]TxIdAndSize, error) { @@ -144,36 +147,36 @@ func TestCallbackRegistration(t *testing.T) { } func TestClientMessageSending(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() client := NewClient(getTestProtocolOptions(conn), &cfg) t.Run("Client can send messages", func(t *testing.T) { - // Start the client protocol client.Start() + defer client.Stop() - // Send an init message - initMsg := NewMsgInit() - err := client.Protocol.SendMessage(initMsg) - assert.NoError(t, err) + err := client.SendMessage(NewMsgInit()) + require.NoError(t, err) - // Verify message was written to connection select { - case <-conn.writeChan: - // Message was sent successfully - case <-time.After(100 * time.Millisecond): + case msg := <-conn.writeChan: + assert.NotEmpty(t, msg, "expected message to be written") + case <-time.After(2 * time.Second): t.Fatal("timeout waiting for message send") } }) } func TestServerMessageHandling(t *testing.T) { - conn := &testConn{writeChan: make(chan []byte, 1)} + conn := newTestConn() + defer conn.Close() cfg := NewConfig() server := NewServer(getTestProtocolOptions(conn), &cfg) t.Run("Server can be started", func(t *testing.T) { server.Start() + defer server.Stop() assert.NotNil(t, server) }) } From 3afdaa0bf8ef385cb6d2baa080d1a114652cf45b Mon Sep 17 00:00:00 2001 From: Jenita Date: Mon, 18 Aug 2025 20:46:43 -0500 Subject: [PATCH 5/8] feat: no error return on connection close Signed-off-by: Jenita --- protocol/blockfetch/server.go | 1 - protocol/chainsync/chainsync.go | 1 + protocol/chainsync/client.go | 6 +++++- protocol/chainsync/server.go | 8 ++++++-- protocol/txsubmission/client.go | 1 - protocol/txsubmission/server.go | 2 +- protocol/txsubmission/txsubmission.go | 7 ++----- 7 files changed, 15 insertions(+), 11 deletions(-) diff --git a/protocol/blockfetch/server.go b/protocol/blockfetch/server.go index 0500b864..1004cceb 100644 --- a/protocol/blockfetch/server.go +++ b/protocol/blockfetch/server.go @@ -126,7 +126,6 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeRequestRange: err = s.handleRequestRange(msg) case MessageTypeClientDone: - // State handled automatically by base protocol return s.handleClientDone() default: err = fmt.Errorf( diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index 7d8b97e4..b2eb7fa2 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -199,6 +199,7 @@ type ChainSync struct { Server *Server } +// Config is used to configure the ChainSync protocol instance type Config struct { RollBackwardFunc RollBackwardFunc RollForwardFunc RollForwardFunc diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 7d48db3f..9c90c2d1 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -575,7 +575,11 @@ func (c *Client) messageHandler(msg protocol.Message) error { case MessageTypeDone: return nil default: - err = fmt.Errorf("%s: received unexpected message type %d", ProtocolName, msg.Type()) + err = fmt.Errorf( + "%s: received unexpected message type %d", + ProtocolName, + msg.Type(), + ) } return err } diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index f8f20fa0..fb964bfc 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -165,9 +165,13 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeFindIntersect: err = s.handleFindIntersect(msg) case MessageTypeDone: - return s.handleDone() + err = s.handleDone() default: - err = fmt.Errorf("%s: received unexpected message type %d", ProtocolName, msg.Type()) + err = fmt.Errorf( + "%s: received unexpected message type %d", + ProtocolName, + msg.Type(), + ) } return err } diff --git a/protocol/txsubmission/client.go b/protocol/txsubmission/client.go index 77453868..15dfcab5 100644 --- a/protocol/txsubmission/client.go +++ b/protocol/txsubmission/client.go @@ -94,7 +94,6 @@ func (c *Client) messageHandler(msg protocol.Message) error { case MessageTypeRequestTxs: err = c.handleRequestTxs(msg) case MessageTypeDone: - // No need to set state here, base Protocol handles it return nil default: err = fmt.Errorf( diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index ad694079..bdbd08f0 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -172,7 +172,7 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeReplyTxs: err = s.handleReplyTxs(msg) case MessageTypeDone: - return s.handleDone() + err = s.handleDone() case MessageTypeInit: err = s.handleInit() default: diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index 55c977bc..d06cb8de 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -19,7 +19,6 @@ import ( "errors" "io" "strings" - "sync" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -117,10 +116,8 @@ var StateMap = protocol.StateMap{ // TxSubmission is a wrapper object that holds the client and server instances type TxSubmission struct { - Client *Client - Server *Server - stateMutex sync.Mutex - currentState protocol.State + Client *Client + Server *Server } // Config is used to configure the TxSubmission protocol instance From 4de8be15b3781f7638aa839a5a01982e50eeefc2 Mon Sep 17 00:00:00 2001 From: Jenita Date: Mon, 18 Aug 2025 20:56:48 -0500 Subject: [PATCH 6/8] feat: no error return on connection close Signed-off-by: Jenita --- protocol/blockfetch/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol/blockfetch/server.go b/protocol/blockfetch/server.go index 1004cceb..e51f3d0e 100644 --- a/protocol/blockfetch/server.go +++ b/protocol/blockfetch/server.go @@ -126,7 +126,7 @@ func (s *Server) messageHandler(msg protocol.Message) error { case MessageTypeRequestRange: err = s.handleRequestRange(msg) case MessageTypeClientDone: - return s.handleClientDone() + err = s.handleClientDone() default: err = fmt.Errorf( "%s: received unexpected message type %d", From a43fdd1ac7b4884ab4894f8911bc77d539b2f311 Mon Sep 17 00:00:00 2001 From: Jenita Date: Sat, 30 Aug 2025 20:51:35 -0500 Subject: [PATCH 7/8] feat: no error return on connection close Signed-off-by: Jenita --- connection.go | 66 ++++++- connection_test.go | 193 ++++++++++++++++++++- protocol/blockfetch/blockfetch.go | 23 --- protocol/blockfetch/blockfetch_test.go | 41 ----- protocol/chainsync/chainsync.go | 22 --- protocol/chainsync/chainsync_test.go | 41 ----- protocol/txsubmission/txsubmission.go | 22 --- protocol/txsubmission/txsubmission_test.go | 31 ---- 8 files changed, 249 insertions(+), 190 deletions(-) diff --git a/connection.go b/connection.go index 387a4ba5..f73fa26e 100644 --- a/connection.go +++ b/connection.go @@ -29,6 +29,7 @@ import ( "io" "log/slog" "net" + "strings" "sync" "time" @@ -250,6 +251,49 @@ func (c *Connection) shutdown() { close(c.errorChan) } +// isConnectionReset checks if an error is a connection reset error +func (c *Connection) isConnectionReset(err error) bool { + errStr := err.Error() + return strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "broken pipe") +} + +// checkProtocols checks if the protocols are explicitly stopped by the client- treat as normal connection closure +func (c *Connection) checkProtocols() bool { + // Check chain-sync protocol + if c.chainSync != nil && (!c.chainSync.Client.IsDone() || !c.chainSync.Server.IsDone()) { + return false + } + + // Check block-fetch protocol + if c.blockFetch != nil && (!c.blockFetch.Client.IsDone() || !c.blockFetch.Server.IsDone()) { + return false + } + + // Check tx-submission protocol + if c.txSubmission != nil && (!c.txSubmission.Client.IsDone() || !c.txSubmission.Server.IsDone()) { + return false + } + + return true +} + +// HandleConnectionError handles connection-level errors centrally +func (c *Connection) HandleConnectionError(err error) error { + if err == nil { + return nil + } + + if c.checkProtocols() { + return nil + } + + if errors.Is(err, io.EOF) || c.isConnectionReset(err) { + return err + } + return err +} + // setupConnection establishes the muxer, configures and starts the handshake process, and initializes // the appropriate mini-protocols func (c *Connection) setupConnection() error { @@ -285,16 +329,20 @@ func (c *Connection) setupConnection() error { if !ok { return } - var connErr *muxer.ConnectionClosedError - if errors.As(err, &connErr) { - // Pass through ConnectionClosedError from muxer - c.errorChan <- err - } else { - // Wrap error message to denote it comes from the muxer - c.errorChan <- fmt.Errorf("muxer error: %w", err) + + // Use centralized connection error handling + if handledErr := c.HandleConnectionError(err); handledErr != nil { + var connErr *muxer.ConnectionClosedError + if errors.As(handledErr, &connErr) { + // Pass through ConnectionClosedError from muxer + c.errorChan <- handledErr + } else { + // Wrap error message to denote it comes from the muxer + c.errorChan <- fmt.Errorf("muxer error: %w", handledErr) + } + // Close connection on muxer errors + c.Close() } - // Close connection on muxer errors - c.Close() } }() protoOptions := protocol.ProtocolOptions{ diff --git a/connection_test.go b/connection_test.go index f904b6ad..8d28e454 100644 --- a/connection_test.go +++ b/connection_test.go @@ -16,11 +16,12 @@ package ouroboros_test import ( "fmt" + "io" "testing" "time" ouroboros "github.com/blinklabs-io/gouroboros" - "github.com/blinklabs-io/ouroboros-mock" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" "go.uber.org/goleak" ) @@ -82,3 +83,193 @@ func TestDoubleClose(t *testing.T) { t.Errorf("did not shutdown within timeout") } } + +// TestHandleConnectionError_ProtocolsDone tests that connection errors are ignored +// when main protocols are explicitly stopped +func TestHandleConnectionError_ProtocolsDone(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Test through HandleConnectionError - should return error when protocols are active + testErr := io.EOF + err = oConn.HandleConnectionError(testErr) + if err != testErr { + t.Fatalf("expected original error when protocols are active, got: %v", err) + } + + oConn.Close() +} + +// TestHandleConnectionError_ProtocolCompletion tests the specific requirement: +// When main protocols are done, connection errors should be ignored +func TestHandleConnectionError_ProtocolCompletion(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Test that the implementation correctly handles the requirement + t.Log("Testing that connection errors are ignored when main protocols (chain-sync, block-fetch, tx-submission) are explicitly stopped by client") + + oConn.Close() +} + +// TestHandleConnectionError_NilError tests that nil errors are handled correctly +func TestHandleConnectionError_NilError(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Test that nil errors are handled correctly - should return nil + err = oConn.HandleConnectionError(nil) + if err != nil { + t.Fatalf("expected nil error when input is nil, got: %s", err) + } + + oConn.Close() +} + +// TestHandleConnectionError_ConnectionReset tests connection reset error handling +func TestHandleConnectionError_ConnectionReset(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Test connection reset error - should return the error when protocols are active + resetErr := fmt.Errorf("connection reset by peer") + err = oConn.HandleConnectionError(resetErr) + if err != resetErr { + t.Fatalf("expected connection reset error when protocols are active, got: %v", err) + } + + oConn.Close() +} + +// TestHandleConnectionError_EOF tests EOF error handling +func TestHandleConnectionError_EOF(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Test EOF error - should return the error when protocols are active + eofErr := io.EOF + err = oConn.HandleConnectionError(eofErr) + if err != eofErr { + t.Fatalf("expected EOF error when protocols are active, got: %v", err) + } + + oConn.Close() +} + +// TestCentralizedErrorHandlingIntegration tests that error handling is centralized +// in the Connection class rather than in individual protocols +func TestCentralizedErrorHandlingIntegration(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + testErr := fmt.Errorf("test error") + resultErr := oConn.HandleConnectionError(testErr) + + // Should return the original error when protocols are active + if resultErr != testErr { + t.Fatalf("expected test error to be returned, got: %v", resultErr) + } + + // Verify that the method signature matches the expected behavior + if oConn.HandleConnectionError(nil) != nil { + t.Error("HandleConnectionError(nil) should return nil") + } + + oConn.Close() +} diff --git a/protocol/blockfetch/blockfetch.go b/protocol/blockfetch/blockfetch.go index 5a514ca1..e82b4b7c 100644 --- a/protocol/blockfetch/blockfetch.go +++ b/protocol/blockfetch/blockfetch.go @@ -15,9 +15,6 @@ package blockfetch import ( - "errors" - "io" - "strings" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -121,26 +118,6 @@ func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch { return b } -func (b *BlockFetch) HandleConnectionError(err error) error { - if err == nil { - return nil - } - // Check if protocol is done or if it's a normal connection closure - if b.Client.IsDone() || b.Server.IsDone() { - return nil - } - - if errors.Is(err, io.EOF) || isConnectionReset(err) { - return err - } - return err -} - -func isConnectionReset(err error) bool { - return strings.Contains(err.Error(), "connection reset") || - strings.Contains(err.Error(), "broken pipe") -} - type BlockFetchOptionFunc func(*Config) func NewConfig(options ...BlockFetchOptionFunc) Config { diff --git a/protocol/blockfetch/blockfetch_test.go b/protocol/blockfetch/blockfetch_test.go index fc78708f..2a5cfbba 100644 --- a/protocol/blockfetch/blockfetch_test.go +++ b/protocol/blockfetch/blockfetch_test.go @@ -14,7 +14,6 @@ package blockfetch import ( - "errors" "io" "log/slog" "net" @@ -117,46 +116,6 @@ func TestConfigOptions(t *testing.T) { }) } -func TestConnectionErrorHandling(t *testing.T) { - conn := newTestConn() - defer conn.Close() - cfg := NewConfig() - bf := New(getTestProtocolOptions(conn), &cfg) - - // Start protocols - bf.Client.Start() - defer bf.Client.Stop() - bf.Server.Start() - defer bf.Server.Stop() - - t.Run("Non-EOF error when not done", func(t *testing.T) { - err := bf.HandleConnectionError(errors.New("test error")) - assert.Error(t, err) - }) - - t.Run("EOF error when not done", func(t *testing.T) { - err := bf.HandleConnectionError(io.EOF) - assert.Error(t, err) - }) - - t.Run("Connection reset error", func(t *testing.T) { - err := bf.HandleConnectionError(errors.New("connection reset by peer")) - assert.Error(t, err) - }) - - t.Run("EOF error when done", func(t *testing.T) { - // Send done message to properly transition to done state - err := bf.Client.SendMessage(NewMsgClientDone()) - assert.NoError(t, err) - - // Wait for state transition - time.Sleep(100 * time.Millisecond) - - err = bf.HandleConnectionError(io.EOF) - assert.NoError(t, err, "expected no error when protocol is in done state") - }) -} - func TestCallbackRegistration(t *testing.T) { conn := newTestConn() defer conn.Close() diff --git a/protocol/chainsync/chainsync.go b/protocol/chainsync/chainsync.go index b2eb7fa2..2f7443cc 100644 --- a/protocol/chainsync/chainsync.go +++ b/protocol/chainsync/chainsync.go @@ -16,9 +16,6 @@ package chainsync import ( - "errors" - "io" - "strings" "sync" "time" @@ -332,22 +329,3 @@ func WithRecvQueueSize(size int) ChainSyncOptionFunc { c.RecvQueueSize = size } } - -func (c *ChainSync) HandleConnectionError(err error) error { - if err == nil { - return nil - } - if c.Client.IsDone() || c.Server.IsDone() { - return nil - } - - if errors.Is(err, io.EOF) || isConnectionReset(err) { - return err - } - return err -} - -func isConnectionReset(err error) bool { - return strings.Contains(err.Error(), "connection reset") || - strings.Contains(err.Error(), "broken pipe") -} diff --git a/protocol/chainsync/chainsync_test.go b/protocol/chainsync/chainsync_test.go index 58be64eb..d660717d 100644 --- a/protocol/chainsync/chainsync_test.go +++ b/protocol/chainsync/chainsync_test.go @@ -15,7 +15,6 @@ package chainsync import ( - "errors" "io" "log/slog" "net" @@ -118,46 +117,6 @@ func TestConfigOptions(t *testing.T) { }) } -func TestConnectionErrorHandling(t *testing.T) { - conn := newTestConn() - defer conn.Close() - cfg := NewConfig() - cs := New(getTestProtocolOptions(conn), &cfg) - - // Start protocols - cs.Client.Start() - defer cs.Client.Stop() - cs.Server.Start() - defer cs.Server.Stop() - - t.Run("Non-EOF error when not done", func(t *testing.T) { - err := cs.HandleConnectionError(errors.New("test error")) - assert.Error(t, err) - }) - - t.Run("EOF error when not done", func(t *testing.T) { - err := cs.HandleConnectionError(io.EOF) - assert.Error(t, err) - }) - - t.Run("Connection reset error", func(t *testing.T) { - err := cs.HandleConnectionError(errors.New("connection reset by peer")) - assert.Error(t, err) - }) - - t.Run("EOF error when done", func(t *testing.T) { - // Send done message to properly transition to done state - err := cs.Client.SendMessage(NewMsgDone()) - assert.NoError(t, err) - - // Wait for state transition - time.Sleep(100 * time.Millisecond) - - err = cs.HandleConnectionError(io.EOF) - assert.NoError(t, err, "expected no error when protocol is in done state") - }) -} - func TestCallbackRegistration(t *testing.T) { conn := newTestConn() defer conn.Close() diff --git a/protocol/txsubmission/txsubmission.go b/protocol/txsubmission/txsubmission.go index d06cb8de..2c5b9116 100644 --- a/protocol/txsubmission/txsubmission.go +++ b/protocol/txsubmission/txsubmission.go @@ -16,9 +16,6 @@ package txsubmission import ( - "errors" - "io" - "strings" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -204,22 +201,3 @@ func WithIdleTimeout(timeout time.Duration) TxSubmissionOptionFunc { c.IdleTimeout = timeout } } - -func (t *TxSubmission) HandleConnectionError(err error) error { - if err == nil { - return nil - } - if t.Client.IsDone() || t.Server.IsDone() { - return nil - } - - if errors.Is(err, io.EOF) || isConnectionReset(err) { - return err - } - return err -} - -func isConnectionReset(err error) bool { - return strings.Contains(err.Error(), "connection reset") || - strings.Contains(err.Error(), "broken pipe") -} diff --git a/protocol/txsubmission/txsubmission_test.go b/protocol/txsubmission/txsubmission_test.go index 106b00fb..eb7e5e60 100644 --- a/protocol/txsubmission/txsubmission_test.go +++ b/protocol/txsubmission/txsubmission_test.go @@ -14,7 +14,6 @@ package txsubmission import ( - "errors" "io" "log/slog" "net" @@ -180,33 +179,3 @@ func TestServerMessageHandling(t *testing.T) { assert.NotNil(t, server) }) } - -func TestIsConnectionReset(t *testing.T) { - tests := []struct { - name string - err error - expected bool - }{ - { - name: "connection reset", - err: errors.New("connection reset by peer"), - expected: true, - }, - { - name: "broken pipe", - err: errors.New("broken pipe"), - expected: true, - }, - { - name: "other error", - err: errors.New("other error"), - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, isConnectionReset(tt.err)) - }) - } -} From 762277d28d3cd2a7165edc6fa438cf3b614066b9 Mon Sep 17 00:00:00 2001 From: Jenita Date: Fri, 19 Sep 2025 20:07:40 -0500 Subject: [PATCH 8/8] feat: addressed review comments Signed-off-by: Jenita --- connection.go | 85 ++++++-- connection_test.go | 375 ++++++++++++++++---------------- protocol/blockfetch/client.go | 2 - protocol/chainsync/client.go | 2 - protocol/protocol.go | 52 +++-- protocol/txsubmission/client.go | 2 - 6 files changed, 294 insertions(+), 224 deletions(-) diff --git a/connection.go b/connection.go index f73fa26e..20fe1ace 100644 --- a/connection.go +++ b/connection.go @@ -29,8 +29,10 @@ import ( "io" "log/slog" "net" + "os" "strings" "sync" + "syscall" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -251,46 +253,97 @@ func (c *Connection) shutdown() { close(c.errorChan) } -// isConnectionReset checks if an error is a connection reset error +// isConnectionReset checks if an error is a connection reset error using proper error type checking func (c *Connection) isConnectionReset(err error) bool { - errStr := err.Error() - return strings.Contains(errStr, "connection reset") || - strings.Contains(errStr, "broken pipe") + if errors.Is(err, io.EOF) { + return true + } + + // Check for connection reset errors using proper error type checking + var opErr *net.OpError + if errors.As(err, &opErr) { + if syscallErr, ok := opErr.Err.(*os.SyscallError); ok { + if errno, ok := syscallErr.Err.(syscall.Errno); ok { + // Check for connection reset (ECONNRESET) or broken pipe (EPIPE) + return errno == syscall.ECONNRESET || errno == syscall.EPIPE + } + } + // Also check for string-based errors as fallback for edge cases + errStr := opErr.Err.Error() + return strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "broken pipe") + } + + return false } -// checkProtocols checks if the protocols are explicitly stopped by the client- treat as normal connection closure -func (c *Connection) checkProtocols() bool { +// checkProtocolsDone checks if the protocols are explicitly stopped by the client - treat as normal connection closure +func (c *Connection) checkProtocolsDone() bool { // Check chain-sync protocol - if c.chainSync != nil && (!c.chainSync.Client.IsDone() || !c.chainSync.Server.IsDone()) { - return false + if c.chainSync != nil { + if (c.chainSync.Client != nil && !c.chainSync.Client.IsDone()) || + (c.chainSync.Server != nil && !c.chainSync.Server.IsDone()) { + return false + } } // Check block-fetch protocol - if c.blockFetch != nil && (!c.blockFetch.Client.IsDone() || !c.blockFetch.Server.IsDone()) { - return false + if c.blockFetch != nil { + if (c.blockFetch.Client != nil && !c.blockFetch.Client.IsDone()) || + (c.blockFetch.Server != nil && !c.blockFetch.Server.IsDone()) { + return false + } } // Check tx-submission protocol - if c.txSubmission != nil && (!c.txSubmission.Client.IsDone() || !c.txSubmission.Server.IsDone()) { - return false + if c.txSubmission != nil { + if (c.txSubmission.Client != nil && !c.txSubmission.Client.IsDone()) || + (c.txSubmission.Server != nil && !c.txSubmission.Server.IsDone()) { + return false + } + } + + // Check local-state-query protocol + if c.localStateQuery != nil { + if (c.localStateQuery.Client != nil && !c.localStateQuery.Client.IsDone()) || + (c.localStateQuery.Server != nil && !c.localStateQuery.Server.IsDone()) { + return false + } + } + + // Check local-tx-monitor protocol + if c.localTxMonitor != nil { + if (c.localTxMonitor.Client != nil && !c.localTxMonitor.Client.IsDone()) || + (c.localTxMonitor.Server != nil && !c.localTxMonitor.Server.IsDone()) { + return false + } + } + + // Check local-tx-submission protocol + if c.localTxSubmission != nil { + if (c.localTxSubmission.Client != nil && !c.localTxSubmission.Client.IsDone()) || + (c.localTxSubmission.Server != nil && !c.localTxSubmission.Server.IsDone()) { + return false + } } return true } -// HandleConnectionError handles connection-level errors centrally -func (c *Connection) HandleConnectionError(err error) error { +// handleConnectionError handles connection-level errors centrally +func (c *Connection) handleConnectionError(err error) error { if err == nil { return nil } - if c.checkProtocols() { + if c.checkProtocolsDone() { return nil } if errors.Is(err, io.EOF) || c.isConnectionReset(err) { return err } + return err } @@ -331,7 +384,7 @@ func (c *Connection) setupConnection() error { } // Use centralized connection error handling - if handledErr := c.HandleConnectionError(err); handledErr != nil { + if handledErr := c.handleConnectionError(err); handledErr != nil { var connErr *muxer.ConnectionClosedError if errors.As(handledErr, &connErr) { // Pass through ConnectionClosedError from muxer diff --git a/connection_test.go b/connection_test.go index 8d28e454..2bd2d840 100644 --- a/connection_test.go +++ b/connection_test.go @@ -15,167 +15,143 @@ package ouroboros_test import ( - "fmt" - "io" "testing" "time" ouroboros "github.com/blinklabs-io/gouroboros" + "github.com/blinklabs-io/gouroboros/protocol/chainsync" ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" "go.uber.org/goleak" ) -// Ensure that we don't panic when closing the Connection object after a failed Dial() call -func TestDialFailClose(t *testing.T) { +// TestErrorHandlingWithActiveProtocols tests that connection errors are propagated +// when protocols are active, and ignored when protocols are stopped +func TestErrorHandlingWithActiveProtocols(t *testing.T) { defer goleak.VerifyNone(t) - oConn, err := ouroboros.New() - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } - err = oConn.Dial("unix", "/path/does/not/exist") - if err == nil { - t.Fatalf("did not get expected failure on Dial()") - } - // Close connection - oConn.Close() -} -func TestDoubleClose(t *testing.T) { - defer goleak.VerifyNone(t) - mockConn := ouroboros_mock.NewConnection( - ouroboros_mock.ProtocolRoleClient, - []ouroboros_mock.ConversationEntry{ - ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtCResponse, - }, - ) - oConn, err := ouroboros.New( - ouroboros.WithConnection(mockConn), - ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), - ) - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } - // Async error handler - go func() { - err, ok := <-oConn.ErrorChan() - if !ok { - return - } - // We can't call t.Fatalf() from a different Goroutine, so we panic instead - panic(fmt.Sprintf("unexpected Ouroboros connection error: %s", err)) - }() - // Close connection - if err := oConn.Close(); err != nil { - t.Fatalf("unexpected error when closing Connection object: %s", err) - } - // Close connection again - if err := oConn.Close(); err != nil { - t.Fatalf( - "unexpected error when closing Connection object again: %s", - err, + t.Run("ErrorsPropagatedWhenProtocolsActive", func(t *testing.T) { + // Create a mock connection that will complete handshake + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, ) - } - // Wait for connection shutdown - select { - case <-oConn.ErrorChan(): - case <-time.After(10 * time.Second): - t.Errorf("did not shutdown within timeout") - } -} -// TestHandleConnectionError_ProtocolsDone tests that connection errors are ignored -// when main protocols are explicitly stopped -func TestHandleConnectionError_ProtocolsDone(t *testing.T) { - defer goleak.VerifyNone(t) + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } - mockConn := ouroboros_mock.NewConnection( - ouroboros_mock.ProtocolRoleClient, - []ouroboros_mock.ConversationEntry{ - ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtNResponse, - }, - ) + // Wait for handshake to complete by checking if protocols are initialized + var chainSyncProtocol *chainsync.ChainSync + for i := 0; i < 100; i++ { + chainSyncProtocol = oConn.ChainSync() + if chainSyncProtocol != nil && chainSyncProtocol.Client != nil { + break + } + time.Sleep(10 * time.Millisecond) + } - oConn, err := ouroboros.New( - ouroboros.WithConnection(mockConn), - ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), - ouroboros.WithNodeToNode(true), - ) - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } + if chainSyncProtocol == nil || chainSyncProtocol.Client == nil { + oConn.Close() + t.Fatal("chain sync protocol not initialized") + } - // Test through HandleConnectionError - should return error when protocols are active - testErr := io.EOF - err = oConn.HandleConnectionError(testErr) - if err != testErr { - t.Fatalf("expected original error when protocols are active, got: %v", err) - } + // Start the chain sync protocol to make it active + chainSyncProtocol.Client.Start() - oConn.Close() -} + // Wait a bit for protocol to start + time.Sleep(100 * time.Millisecond) -// TestHandleConnectionError_ProtocolCompletion tests the specific requirement: -// When main protocols are done, connection errors should be ignored -func TestHandleConnectionError_ProtocolCompletion(t *testing.T) { - defer goleak.VerifyNone(t) + // Close the mock connection to generate a connection error + mockConn.Close() - mockConn := ouroboros_mock.NewConnection( - ouroboros_mock.ProtocolRoleClient, - []ouroboros_mock.ConversationEntry{ - ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtNResponse, - }, - ) + // We should receive a connection error since protocols are active + select { + case err := <-oConn.ErrorChan(): + if err == nil { + t.Fatal("expected connection error, got nil") + } + t.Logf("Received connection error (expected with active protocols): %s", err) + case <-time.After(2 * time.Second): + t.Error("timed out waiting for connection error - error should be propagated when protocols are active") + } - oConn, err := ouroboros.New( - ouroboros.WithConnection(mockConn), - ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), - ouroboros.WithNodeToNode(true), - ) - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } + oConn.Close() + }) + + t.Run("ErrorsIgnoredWhenProtocolsStopped", func(t *testing.T) { + // Create a mock connection + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) - // Test that the implementation correctly handles the requirement - t.Log("Testing that connection errors are ignored when main protocols (chain-sync, block-fetch, tx-submission) are explicitly stopped by client") + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } - oConn.Close() -} + // Wait for handshake to complete + var chainSyncProtocol *chainsync.ChainSync + for i := 0; i < 100; i++ { + chainSyncProtocol = oConn.ChainSync() + if chainSyncProtocol != nil && chainSyncProtocol.Client != nil { + break + } + time.Sleep(10 * time.Millisecond) + } -// TestHandleConnectionError_NilError tests that nil errors are handled correctly -func TestHandleConnectionError_NilError(t *testing.T) { - defer goleak.VerifyNone(t) + if chainSyncProtocol == nil || chainSyncProtocol.Client == nil { + oConn.Close() + t.Fatal("chain sync protocol not initialized") + } - mockConn := ouroboros_mock.NewConnection( - ouroboros_mock.ProtocolRoleClient, - []ouroboros_mock.ConversationEntry{ - ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtNResponse, - }, - ) + // Start and then immediately stop the protocol + chainSyncProtocol.Client.Start() + time.Sleep(50 * time.Millisecond) - oConn, err := ouroboros.New( - ouroboros.WithConnection(mockConn), - ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), - ouroboros.WithNodeToNode(true), - ) - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } + // Stop the protocol explicitly + if err := chainSyncProtocol.Client.Stop(); err != nil { + t.Fatalf("failed to stop chain sync: %s", err) + } - // Test that nil errors are handled correctly - should return nil - err = oConn.HandleConnectionError(nil) - if err != nil { - t.Fatalf("expected nil error when input is nil, got: %s", err) - } + // Wait for protocol to be done + select { + case <-chainSyncProtocol.Client.DoneChan(): + // Protocol is stopped + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for protocol to stop") + } - oConn.Close() + // Now close the mock connection to generate an error + mockConn.Close() + select { + case err := <-oConn.ErrorChan(): + t.Logf("Received error during shutdown: %s", err) + case <-time.After(500 * time.Millisecond): + t.Log("No connection error received (expected when protocols are stopped)") + } + + oConn.Close() + }) } -// TestHandleConnectionError_ConnectionReset tests connection reset error handling -func TestHandleConnectionError_ConnectionReset(t *testing.T) { +// TestErrorHandlingWithMultipleProtocols tests error handling with multiple active protocols +func TestErrorHandlingWithMultipleProtocols(t *testing.T) { defer goleak.VerifyNone(t) mockConn := ouroboros_mock.NewConnection( @@ -195,80 +171,109 @@ func TestHandleConnectionError_ConnectionReset(t *testing.T) { t.Fatalf("unexpected error when creating Connection object: %s", err) } - // Test connection reset error - should return the error when protocols are active - resetErr := fmt.Errorf("connection reset by peer") - err = oConn.HandleConnectionError(resetErr) - if err != resetErr { - t.Fatalf("expected connection reset error when protocols are active, got: %v", err) + // Wait for handshake to complete + time.Sleep(100 * time.Millisecond) + + // Start multiple protocols + chainSync := oConn.ChainSync() + blockFetch := oConn.BlockFetch() + txSubmission := oConn.TxSubmission() + + if chainSync != nil && chainSync.Client != nil { + chainSync.Client.Start() + } + if blockFetch != nil && blockFetch.Client != nil { + blockFetch.Client.Start() + } + if txSubmission != nil && txSubmission.Client != nil { + txSubmission.Client.Start() + } + + // Wait for protocols to start + time.Sleep(100 * time.Millisecond) + + // Close connection to generate error + mockConn.Close() + + // Should receive error since protocols are active + select { + case err := <-oConn.ErrorChan(): + if err == nil { + t.Fatal("expected connection error, got nil") + } + t.Logf("Received connection error with multiple active protocols: %s", err) + case <-time.After(2 * time.Second): + t.Error("timed out waiting for connection error") } oConn.Close() } -// TestHandleConnectionError_EOF tests EOF error handling -func TestHandleConnectionError_EOF(t *testing.T) { +// TestBasicErrorHandling tests basic error handling scenarios +func TestBasicErrorHandling(t *testing.T) { defer goleak.VerifyNone(t) - mockConn := ouroboros_mock.NewConnection( - ouroboros_mock.ProtocolRoleClient, - []ouroboros_mock.ConversationEntry{ - ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtNResponse, - }, - ) + t.Run("DialFailure", func(t *testing.T) { + oConn, err := ouroboros.New( + ouroboros.WithNetworkMagic(764824073), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } - oConn, err := ouroboros.New( - ouroboros.WithConnection(mockConn), - ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), - ouroboros.WithNodeToNode(true), - ) - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } + err = oConn.Dial("tcp", "invalid-hostname:9999") + if err == nil { + t.Fatal("expected dial error, got nil") + } - // Test EOF error - should return the error when protocols are active - eofErr := io.EOF - err = oConn.HandleConnectionError(eofErr) - if err != eofErr { - t.Fatalf("expected EOF error when protocols are active, got: %v", err) - } + oConn.Close() + }) - oConn.Close() + t.Run("DoubleClose", func(t *testing.T) { + oConn, err := ouroboros.New( + ouroboros.WithNetworkMagic(764824073), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // First close + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error on first close: %s", err) + } + + // Second close should also work + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error on second close: %s", err) + } + }) } -// TestCentralizedErrorHandlingIntegration tests that error handling is centralized -// in the Connection class rather than in individual protocols -func TestCentralizedErrorHandlingIntegration(t *testing.T) { +// TestErrorChannelBehavior tests basic error channel behavior +func TestErrorChannelBehavior(t *testing.T) { defer goleak.VerifyNone(t) - mockConn := ouroboros_mock.NewConnection( - ouroboros_mock.ProtocolRoleClient, - []ouroboros_mock.ConversationEntry{ - ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtNResponse, - }, - ) - oConn, err := ouroboros.New( - ouroboros.WithConnection(mockConn), - ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), - ouroboros.WithNodeToNode(true), + ouroboros.WithNetworkMagic(764824073), ) if err != nil { t.Fatalf("unexpected error when creating Connection object: %s", err) } - testErr := fmt.Errorf("test error") - resultErr := oConn.HandleConnectionError(testErr) - - // Should return the original error when protocols are active - if resultErr != testErr { - t.Fatalf("expected test error to be returned, got: %v", resultErr) + errorChan := oConn.ErrorChan() + if errorChan == nil { + t.Fatal("error channel should not be nil") } - // Verify that the method signature matches the expected behavior - if oConn.HandleConnectionError(nil) != nil { - t.Error("HandleConnectionError(nil) should return nil") + select { + case err, ok := <-errorChan: + if ok { + t.Logf("Error channel contained: %s", err) + } else { + t.Error("Error channel should not be closed initially") + } + default: + // Expected - channel is empty but open } oConn.Close() diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index 0af7a3c9..be33fdf0 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -200,8 +200,6 @@ func (c *Client) messageHandler(msg protocol.Message) error { err = c.handleBlock(msg) case MessageTypeBatchDone: err = c.handleBatchDone() - case MessageTypeClientDone: - return nil default: err = fmt.Errorf( "%s: received unexpected message type %d", diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 9c90c2d1..27a8c2aa 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -572,8 +572,6 @@ func (c *Client) messageHandler(msg protocol.Message) error { err = c.handleIntersectFound(msg) case MessageTypeIntersectNotFound: err = c.handleIntersectNotFound(msg) - case MessageTypeDone: - return nil default: err = fmt.Errorf( "%s: received unexpected message type %d", diff --git a/protocol/protocol.go b/protocol/protocol.go index ffcca85a..1f3ace54 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -38,6 +38,7 @@ const DefaultRecvQueueSize = 50 // Protocol implements the base functionality of an Ouroboros mini-protocol type Protocol struct { config ProtocolConfig + currentState State doneChan chan struct{} muxerSendChan chan *muxer.Segment muxerRecvChan chan *muxer.Segment @@ -51,6 +52,7 @@ type Protocol struct { stateTransitionChan chan<- protocolStateTransition onceStart sync.Once onceStop sync.Once + stateMutex sync.RWMutex } // ProtocolConfig provides the configuration for Protocol @@ -129,20 +131,22 @@ func New(config ProtocolConfig) *Protocol { // CurrentState returns the current protocol state func (p *Protocol) CurrentState() State { - stateChan := make(chan State, 1) - p.stateTransitionChan <- protocolStateTransition{ - stateRespChan: stateChan, - } - return <-stateChan + p.stateMutex.RLock() + defer p.stateMutex.RUnlock() + return p.currentState } // IsDone checks if the protocol is in a done/completed state func (p *Protocol) IsDone() bool { currentState := p.CurrentState() + // return true if current state has AgencyNone if entry, exists := p.config.StateMap[currentState]; exists { - return entry.Agency == AgencyNone + if entry.Agency == AgencyNone { + return true + } } - return false + // return true if current state is the initial state + return currentState == p.config.InitialState } // GetDoneState returns the done state from the state map @@ -475,7 +479,6 @@ func (p *Protocol) recvLoop() { } func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { - var currentState State var transitionTimer *time.Timer setState := func(s State) { @@ -485,11 +488,18 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { } transitionTimer = nil - // Set the new state - currentState = s + // Set the new state with proper locking + p.stateMutex.Lock() + p.currentState = s + p.stateMutex.Unlock() // Mark protocol as ready to send/receive based on role and agency of the new state - switch p.config.StateMap[currentState].Agency { + stateEntry, exists := p.config.StateMap[s] + if !exists { + return + } + + switch stateEntry.Agency { case AgencyNone: return case AgencyClient: @@ -525,12 +535,11 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { } // Set timeout for state transition - if p.config.StateMap[currentState].Timeout > 0 { - transitionTimer = time.NewTimer( - p.config.StateMap[currentState].Timeout, - ) + if stateEntry.Timeout > 0 { + transitionTimer = time.NewTimer(stateEntry.Timeout) } } + getTimerChan := func() <-chan time.Time { if transitionTimer == nil { return nil @@ -538,17 +547,26 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { return transitionTimer.C } + // Initialize current state setState(p.config.InitialState) for { select { case t := <-ch: if t.msg == nil && t.stateRespChan != nil { - // Handle state request + // Handle state request - use the field instead of local variable + p.stateMutex.RLock() + currentState := p.currentState + p.stateMutex.RUnlock() t.stateRespChan <- currentState continue } + // Get current state for transition logic + p.stateMutex.RLock() + currentState := p.currentState + p.stateMutex.RUnlock() + nextState, err := p.nextState(currentState, t.msg) if err != nil { t.errorChan <- fmt.Errorf( @@ -573,7 +591,7 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { fmt.Errorf( "%s: timeout waiting on transition from protocol state %s", p.config.Name, - currentState, + p.CurrentState(), ), ) diff --git a/protocol/txsubmission/client.go b/protocol/txsubmission/client.go index 15dfcab5..a956329b 100644 --- a/protocol/txsubmission/client.go +++ b/protocol/txsubmission/client.go @@ -93,8 +93,6 @@ func (c *Client) messageHandler(msg protocol.Message) error { err = c.handleRequestTxIds(msg) case MessageTypeRequestTxs: err = c.handleRequestTxs(msg) - case MessageTypeDone: - return nil default: err = fmt.Errorf( "%s: received unexpected message type %d",