diff --git a/epoch.go b/epoch.go
index 9a73132..ee27078 100644
--- a/epoch.go
+++ b/epoch.go
@@ -2430,20 +2430,34 @@ func (e *Epoch) handleReplicationRequest(req *ReplicationRequest, from NodeID) error
 // locateQuorumRecord locates a block with a notarization or finalization in the epoch's memory or storage.
 func (e *Epoch) locateQuorumRecord(seq uint64) *VerifiedQuorumRound {
+	var notarizedRound *Round
 	for _, round := range e.rounds {
 		blockSeq := round.block.BlockHeader().Seq
 		if blockSeq == seq {
-			if round.finalization == nil && round.notarization == nil {
-				break
-			}
-			return &VerifiedQuorumRound{
-				VerifiedBlock: round.block,
-				Notarization:  round.notarization,
-				Finalization:  round.finalization,
+			if round.finalization != nil {
+				return &VerifiedQuorumRound{
+					VerifiedBlock: round.block,
+					Finalization:  round.finalization,
+				}
+			} else if round.notarization != nil {
+				if notarizedRound == nil {
+					notarizedRound = round
+				} else if round.notarization.Vote.Round > notarizedRound.num {
+					// keep the notarization from the highest round seen so far
+					notarizedRound = round
+				}
 			}
 		}
 	}
+	if notarizedRound != nil {
+		// we have a notarization but no finalization, so return the notarization
+		return &VerifiedQuorumRound{
+			VerifiedBlock: notarizedRound.block,
+			Notarization:  notarizedRound.notarization,
+		}
+	}
+
 	block, finalization, exists := e.Storage.Retrieve(seq)
 	if exists {
 		return &VerifiedQuorumRound{
diff --git a/epoch_multinode_test.go b/epoch_multinode_test.go
index af560c1..b6df514 100644
--- a/epoch_multinode_test.go
+++ b/epoch_multinode_test.go
@@ -144,7 +144,6 @@ type testNodeConfig struct {
 // newSimplexNode creates a new testNode and adds it to [net].
 func newSimplexNode(t *testing.T, nodeID NodeID, net *inMemNetwork, bb BlockBuilder, config *testNodeConfig) *testNode {
 	comm := newTestComm(nodeID, net, allowAllMessages)
-
 	epochConfig := defaultTestNodeEpochConfig(t, nodeID, comm, bb)
 
 	if config != nil {
@@ -475,6 +474,13 @@ func onlyAllowEmptyRoundMessages(msg *Message, _, _ NodeID) bool {
 	return false
 }
 
+type testNetworkCommunication interface {
+	Communication
+	setFilter(filter messageFilter)
+}
+
+var _ testNetworkCommunication = (*testComm)(nil)
+
 type testComm struct {
 	from NodeID
 	net  *inMemNetwork
@@ -648,7 +654,7 @@ func (n *inMemNetwork) addNode(node *testNode) {
 
 func (n *inMemNetwork) setAllNodesMessageFilter(filter messageFilter) {
 	for _, instance := range n.instances {
-		instance.e.Comm.(*testComm).setFilter(filter)
+		instance.e.Comm.(testNetworkCommunication).setFilter(filter)
 	}
 }
diff --git a/replication.go b/replication.go
index 8f23f8a..a79a5f7 100644
--- a/replication.go
+++ b/replication.go
@@ -177,7 +177,14 @@ func (r *ReplicationState) createReplicationTimeoutTask(start, end uint64, nodes
 func (r *ReplicationState) receivedReplicationResponse(data []QuorumRound, node NodeID) {
 	seqs := make([]uint64, 0, len(data))
 
+	// skip all sequences where we expect a finalization but received only a notarization
+	highestSeq := r.highestSequenceObserved.seq
 	for _, qr := range data {
+		if qr.GetSequence() <= highestSeq && qr.Finalization == nil && qr.Notarization != nil {
+			r.logger.Debug("Received notarization without finalization, skipping", zap.Stringer("from", node), zap.Uint64("seq", qr.GetSequence()))
+			continue
+		}
+
 		seqs = append(seqs, qr.GetSequence())
 	}
 
@@ -185,7 +192,7 @@
 	task := FindReplicationTask(r.timeoutHandler, node, seqs)
 	if task == nil {
-		r.logger.Debug("Could not find a timeout task associated with the replication response", zap.Stringer("from", node))
+		r.logger.Debug("Could not find a timeout task associated with the replication response", zap.Stringer("from", node), zap.Any("seqs", seqs))
 		return
 	}
 	r.timeoutHandler.RemoveTask(node, task.TaskID)
diff --git a/replication_timeout_test.go b/replication_timeout_test.go
index c3684fa..8818d52 100644
--- a/replication_timeout_test.go
+++ b/replication_timeout_test.go
@@ -4,7 +4,10 @@ package simplex_test
 
 import (
+	"bytes"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/ava-labs/simplex"
@@ -79,6 +82,8 @@ func (m *testTimeoutMessageFilter) failOnReplicationRequest(msg *simplex.Message
 	return true
 }
 
+// receivedReplicationRequest filters out replication responses, and notifies a
+// channel whenever a replication response is intercepted.
 func (m *testTimeoutMessageFilter) receivedReplicationRequest(msg *simplex.Message, _, _ simplex.NodeID) bool {
 	if msg.VerifiedReplicationResponse != nil || msg.ReplicationResponse != nil {
 		m.replicationResponses <- struct{}{}
@@ -285,3 +290,147 @@ func TestReplicationRequestIncompleteResponses(t *testing.T) {
 	laggingNode.e.AdvanceTime(laggingNode.e.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 2))
 	laggingNode.storage.waitForBlockCommit(startSeq)
 }
+
+type collectNotarizationComm struct {
+	lock          *sync.Mutex
+	notarizations map[uint64]*simplex.Notarization
+	testNetworkCommunication
+
+	replicationResponses chan struct{}
+}
+
+func newCollectNotarizationComm(nodeID simplex.NodeID, net *inMemNetwork, notarizations map[uint64]*simplex.Notarization, lock *sync.Mutex) *collectNotarizationComm {
+	return &collectNotarizationComm{
+		notarizations:            notarizations,
+		testNetworkCommunication: newTestComm(nodeID, net, allowAllMessages),
+		replicationResponses:     make(chan struct{}, 3),
+		lock:                     lock,
+	}
+}
+
+func (c *collectNotarizationComm) Send(msg *simplex.Message, to simplex.NodeID) {
+	if msg.Notarization != nil {
+		c.lock.Lock()
+		c.notarizations[msg.Notarization.Vote.Round] = msg.Notarization
+		c.lock.Unlock()
+	}
+	c.testNetworkCommunication.Send(msg, to)
+}
+
+func (c *collectNotarizationComm) Broadcast(msg *simplex.Message) {
+	if msg.Notarization != nil {
+		c.lock.Lock()
+		c.notarizations[msg.Notarization.Vote.Round] = msg.Notarization
+		c.lock.Unlock()
+	}
+	c.testNetworkCommunication.Broadcast(msg)
+}
+
+// removeFinalizationsFromReplicationResponses strips the finalizations from outgoing
+// replication responses, replacing each with the collected notarization for the same
+// round, and signals the replicationResponses channel for every response seen.
+func (c *collectNotarizationComm) removeFinalizationsFromReplicationResponses(msg *simplex.Message, from, to simplex.NodeID) bool {
+	if msg.VerifiedReplicationResponse != nil || msg.ReplicationResponse != nil {
+		if msg.VerifiedReplicationResponse != nil {
+			newData := make([]simplex.VerifiedQuorumRound, 0, len(msg.VerifiedReplicationResponse.Data))
+			c.lock.Lock()
+			for i := 0; i < len(msg.VerifiedReplicationResponse.Data); i++ {
+				qr := msg.VerifiedReplicationResponse.Data[i]
+				if qr.Finalization != nil && c.notarizations[qr.GetRound()] != nil {
+					qr.Finalization = nil
+					qr.Notarization = c.notarizations[qr.GetRound()]
+				}
+				newData = append(newData, qr)
+			}
+			c.lock.Unlock()
+			msg.VerifiedReplicationResponse.Data = newData
+		}
+		c.replicationResponses <- struct{}{}
+	}
+	return true
+}
+
+// TestReplicationRequestWithoutFinalization tests that a replication request is not marked as completed
+// if we expect a finalization but it is not present in the response.
+func TestReplicationRequestWithoutFinalization(t *testing.T) {
+	nodes := []simplex.NodeID{{1}, {2}, {3}, []byte("lagging")}
+	endDisconnect := uint64(10)
+	bb := newTestControlledBlockBuilder(t)
+	laggingBb := newTestControlledBlockBuilder(t)
+	net := newInMemNetwork(t, nodes)
+
+	notarizations := make(map[uint64]*simplex.Notarization)
+	mapLock := &sync.Mutex{}
+	testConfig := func(nodeID simplex.NodeID) *testNodeConfig {
+		return &testNodeConfig{
+			replicationEnabled: true,
+			comm:               newCollectNotarizationComm(nodeID, net, notarizations, mapLock),
+		}
+	}
+
+	notarizationComm := newCollectNotarizationComm(nodes[0], net, notarizations, mapLock)
+	newSimplexNode(t, nodes[0], net, bb, &testNodeConfig{
+		replicationEnabled: true,
+		comm:               notarizationComm,
+	})
+	newSimplexNode(t, nodes[1], net, bb, testConfig(nodes[1]))
+	newSimplexNode(t, nodes[2], net, bb, testConfig(nodes[2]))
+	laggingNode := newSimplexNode(t, nodes[3], net, laggingBb, testConfig(nodes[3]))
+
+	epochTimes := make([]time.Time, 0, 4)
+	for _, n := range net.instances {
+		epochTimes = append(epochTimes, n.e.StartTime)
+	}
+
+	// the lagging node disconnects
+	net.Disconnect(nodes[3])
+	net.startInstances()
+
+	missedSeqs := uint64(0)
+	// the normal nodes continue to make progress
+	for i := uint64(0); i < endDisconnect; i++ {
+		emptyRound := bytes.Equal(simplex.LeaderForRound(nodes, i), nodes[3])
+		if emptyRound {
+			advanceWithoutLeader(t, net, bb, epochTimes, i, laggingNode.e.ID)
+			missedSeqs++
+		} else {
+			bb.triggerNewBlock()
+			for _, n := range net.instances[:3] {
+				n.storage.waitForBlockCommit(i - missedSeqs)
+			}
+		}
+	}
+
+	// all nodes except the lagging node have progressed and committed [endDisconnect - missedSeqs] blocks
+	for _, n := range net.instances[:3] {
+		require.Equal(t, endDisconnect-missedSeqs, n.storage.Height())
+	}
+	require.Equal(t, uint64(0), laggingNode.storage.Height())
+	require.Equal(t, uint64(0), laggingNode.e.Metadata().Round)
+
+	// the lagging node reconnects
+	net.setAllNodesMessageFilter(notarizationComm.removeFinalizationsFromReplicationResponses)
+	net.Connect(nodes[3])
+	bb.triggerNewBlock()
+	for _, n := range net.instances {
+		if n.e.ID.Equals(laggingNode.e.ID) {
+			continue
+		}
+
+		n.storage.waitForBlockCommit(endDisconnect - missedSeqs)
+		<-notarizationComm.replicationResponses
+	}
+
+	// the lagging node has sent replication requests, but because the
+	// removeFinalizationsFromReplicationResponses message filter stripped the
+	// finalizations, it should not have processed any of the responses
+	require.Equal(t, uint64(0), laggingNode.storage.Height())
+
+	// the replication requests should therefore still be in the timeout handler
+	laggingNode.e.AdvanceTime(laggingNode.e.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 2))
+
+	for range net.instances[:3] {
+		// the lagging node should have re-sent its replication requests,
+		// and the normal nodes should have responded
+		<-notarizationComm.replicationResponses
+	}
+
+	require.Equal(t, uint64(0), laggingNode.storage.Height())
+
+	// the replication requests are still in the timeout handler, but now we
+	// allow the lagging node to process the responses
+	net.setAllNodesMessageFilter(allowAllMessages)
+	laggingNode.e.AdvanceTime(laggingNode.e.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 4))
+	laggingNode.storage.waitForBlockCommit(endDisconnect - missedSeqs)
+}
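Illustration (not part of the patch): a minimal, self-contained Go sketch of the skip rule the replication.go hunk adds to receivedReplicationResponse. The quorumRound type and its fields below are simplified stand-ins for the real simplex types, introduced here only for the example.

package main

import "fmt"

// simplified stand-ins for the simplex types (illustrative only)
type notarization struct{}
type finalization struct{}

type quorumRound struct {
	seq          uint64
	notarization *notarization
	finalization *finalization
}

// shouldSkip mirrors the new check: every sequence at or below the highest
// observed sequence is expected to carry a finalization, so a notarization-only
// entry must not count toward completing a replication request for that sequence.
func shouldSkip(qr quorumRound, highestSeq uint64) bool {
	return qr.seq <= highestSeq && qr.finalization == nil && qr.notarization != nil
}

func main() {
	highestSeq := uint64(10)
	rounds := []quorumRound{
		{seq: 4, finalization: &finalization{}},  // kept: finalized
		{seq: 5, notarization: &notarization{}},  // skipped: finalization expected
		{seq: 11, notarization: &notarization{}}, // kept: beyond highest observed seq
	}
	for _, qr := range rounds {
		fmt.Printf("seq=%d skip=%v\n", qr.seq, shouldSkip(qr, highestSeq))
	}
}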