Resend ReplicationRequest If Finalization Is Expected But Not Received #180

Open · wants to merge 13 commits into main
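
In short: when a replicating node receives a response that carries only a notarization for a sequence at or below the highest sequence it has observed, the response should not satisfy the pending request, since a finalization is expected for that sequence. The request's timeout task is left in place, so the ReplicationRequest is resent. A rough standalone sketch of that acceptance rule (quorumRound, satisfiesRequest, and highestFinalizedSeq are illustrative names, not the repository's API):

package main

import "fmt"

// quorumRound is a stand-in for one entry of a replication response; it may
// carry a notarization, a finalization, or both.
type quorumRound struct {
	seq          uint64
	notarization bool
	finalization bool
}

// satisfiesRequest reports whether an entry completes a replication request,
// given the highest sequence for which a finalization is expected.
func satisfiesRequest(qr quorumRound, highestFinalizedSeq uint64) bool {
	if qr.seq <= highestFinalizedSeq && !qr.finalization && qr.notarization {
		// a finalization was expected but not received: leave the request
		// pending so its timeout task fires and the request is resent
		return false
	}
	return true
}

func main() {
	fmt.Println(satisfiesRequest(quorumRound{seq: 5, notarization: true}, 10)) // false: resend
	fmt.Println(satisfiesRequest(quorumRound{seq: 5, finalization: true}, 10)) // true: accepted
}
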
28 changes: 21 additions & 7 deletions epoch.go
@@ -2430,20 +2430,34 @@ func (e *Epoch) handleReplicationRequest(req *ReplicationRequest, from NodeID) e

// locateQuorumRecord locates a block with a notarization or finalization in the epoch's memory or storage.
func (e *Epoch) locateQuorumRecord(seq uint64) *VerifiedQuorumRound {
	var notarizedRound *Round
	for _, round := range e.rounds {
		blockSeq := round.block.BlockHeader().Seq
		if blockSeq == seq {
			if round.finalization == nil && round.notarization == nil {
				break
			}
			if round.finalization != nil {
				return &VerifiedQuorumRound{
					VerifiedBlock: round.block,
					Finalization:  round.finalization,
				}
			} else if round.notarization != nil {
				if notarizedRound == nil {
					notarizedRound = round
				} else if round.notarization.Vote.Round > notarizedRound.num {
					// set the notarized round if it is the highest round we have seen so far
					notarizedRound = round
				}
			}
		}
	}

	if notarizedRound != nil {
		// we have a notarization, but no finalization, so we return the notarization
		return &VerifiedQuorumRound{
			VerifiedBlock: notarizedRound.block,
			Notarization:  notarizedRound.notarization,
		}
	}

	block, finalization, exists := e.Storage.Retrieve(seq)
	if exists {
		return &VerifiedQuorumRound{
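
The selection rule above, restated as a self-contained sketch: among the in-memory rounds holding the requested sequence, a finalization always wins; otherwise the notarization from the highest round is returned, since several rounds may carry the same sequence when earlier blocks fail to finalize (record and pickQuorumRecord are hypothetical names for illustration):

// record is a simplified stand-in for a Round tracked in memory.
type record struct {
	round     uint64
	notarized bool
	finalized bool
}

// pickQuorumRecord prefers a finalized record, and otherwise falls back to
// the notarized record from the highest round.
func pickQuorumRecord(records []record) *record {
	var best *record
	for i := range records {
		r := &records[i]
		if r.finalized {
			return r // a finalization always wins
		}
		if r.notarized && (best == nil || r.round > best.round) {
			best = r // keep the highest notarized round seen so far
		}
	}
	return best // may be nil; the caller then falls back to storage
}
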
10 changes: 8 additions & 2 deletions epoch_multinode_test.go
@@ -144,7 +144,6 @@ type testNodeConfig struct {
// newSimplexNode creates a new testNode and adds it to [net].
func newSimplexNode(t *testing.T, nodeID NodeID, net *inMemNetwork, bb BlockBuilder, config *testNodeConfig) *testNode {
	comm := newTestComm(nodeID, net, allowAllMessages)
	epochConfig := defaultTestNodeEpochConfig(t, nodeID, comm, bb)

	if config != nil {
@@ -475,6 +474,13 @@ func onlyAllowEmptyRoundMessages(msg *Message, _, _ NodeID) bool {
	return false
}

type testNetworkCommunication interface {
	Communication
	setFilter(filter messageFilter)
}

var _ testNetworkCommunication = (*testComm)(nil)

type testComm struct {
	from NodeID
	net  *inMemNetwork
@@ -648,7 +654,7 @@ func (n *inMemNetwork) addNode(node *testNode) {

func (n *inMemNetwork) setAllNodesMessageFilter(filter messageFilter) {
	for _, instance := range n.instances {
		instance.e.Comm.(testNetworkCommunication).setFilter(filter)
	}
}

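
The new testNetworkCommunication interface lets tests substitute their own Communication implementation while the in-memory network can still install message filters. A custom comm only needs to embed an existing implementation and override the methods it cares about; a hedged sketch of the pattern (recordingComm is a hypothetical example, not part of this PR):

// recordingComm wraps a base test comm, counting sent messages while
// inheriting Broadcast and setFilter from the embedded implementation.
type recordingComm struct {
	testNetworkCommunication
	sent int
}

func (r *recordingComm) Send(msg *Message, to NodeID) {
	r.sent++                                 // custom bookkeeping
	r.testNetworkCommunication.Send(msg, to) // delegate to the wrapped comm
}
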
9 changes: 8 additions & 1 deletion replication.go
@@ -177,15 +177,22 @@ func (r *ReplicationState) createReplicationTimeoutTask(start, end uint64, nodes
func (r *ReplicationState) receivedReplicationResponse(data []QuorumRound, node NodeID) {
	seqs := make([]uint64, 0, len(data))

	// remove all sequences where we expect a finalization but only received a notarization
	highestSeq := r.highestSequenceObserved.seq
	for _, qr := range data {
		if qr.GetSequence() <= highestSeq && qr.Finalization == nil && qr.Notarization != nil {
			r.logger.Debug("Received notarization without finalization, skipping", zap.Stringer("from", node), zap.Uint64("seq", qr.GetSequence()))
			continue
		}

		seqs = append(seqs, qr.GetSequence())
	}

	slices.Sort(seqs)

	task := FindReplicationTask(r.timeoutHandler, node, seqs)
	if task == nil {
		r.logger.Debug("Could not find a timeout task associated with the replication response", zap.Stringer("from", node), zap.Any("seqs", seqs))
		return
	}
	r.timeoutHandler.RemoveTask(node, task.TaskID)
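
Why skipping those sequences leads to a resend: the timeout task registered when the request was sent appears to be keyed by the exact sequences requested, and it is only removed once a response accounts for them. A filtered response produces a different sequence list, the lookup fails, and the task later fires and resends the ReplicationRequest. A simplified sketch of that matching step (timeoutTask and findTask are illustrative, assuming tasks are matched by their sorted sequence list):

package sketch

import "slices"

// timeoutTask is a simplified stand-in for a scheduled replication retry.
type timeoutTask struct {
	seqs   []uint64 // the sequences this request asked for, sorted
	resend func()   // invoked if the task is still scheduled at timeout
}

// findTask returns the task whose requested sequences exactly match the
// sequences accepted from a response, or nil if none match.
func findTask(tasks []*timeoutTask, accepted []uint64) *timeoutTask {
	for _, t := range tasks {
		if slices.Equal(t.seqs, accepted) {
			return t // match: the caller removes it, cancelling the retry
		}
	}
	return nil // no match: the task stays scheduled and will resend
}
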
149 changes: 149 additions & 0 deletions replication_timeout_test.go
@@ -4,7 +4,10 @@
package simplex_test

import (
	"bytes"
	"sync"
	"testing"
	"time"

	"github.com/ava-labs/simplex"

@@ -79,6 +82,8 @@ func (m *testTimeoutMessageFilter) failOnReplicationRequest(msg *simplex.Message
	return true
}

// receivedReplicationRequest is used to filter out replication responses, and to notify a channel
// when a replication response is intercepted.
func (m *testTimeoutMessageFilter) receivedReplicationRequest(msg *simplex.Message, _, _ simplex.NodeID) bool {
	if msg.VerifiedReplicationResponse != nil || msg.ReplicationResponse != nil {
		m.replicationResponses <- struct{}{}
@@ -285,3 +290,147 @@ func TestReplicationRequestIncompleteResponses(t *testing.T) {
	laggingNode.e.AdvanceTime(laggingNode.e.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 2))
	laggingNode.storage.waitForBlockCommit(startSeq)
}

type collectNotarizationComm struct {
	lock          *sync.Mutex
	notarizations map[uint64]*simplex.Notarization
	testNetworkCommunication

	replicationResponses chan struct{}
}

func newCollectNotarizationComm(nodeID simplex.NodeID, net *inMemNetwork, notarizations map[uint64]*simplex.Notarization, lock *sync.Mutex) *collectNotarizationComm {
	return &collectNotarizationComm{
		notarizations:            notarizations,
		testNetworkCommunication: newTestComm(nodeID, net, allowAllMessages),
		replicationResponses:     make(chan struct{}, 3),
		lock:                     lock,
	}
}

func (c *collectNotarizationComm) Send(msg *simplex.Message, to simplex.NodeID) {
	if msg.Notarization != nil {
		c.lock.Lock()
		c.notarizations[msg.Notarization.Vote.Round] = msg.Notarization
		c.lock.Unlock()
	}
	c.testNetworkCommunication.Send(msg, to)
}

func (c *collectNotarizationComm) Broadcast(msg *simplex.Message) {
	if msg.Notarization != nil {
		c.lock.Lock()
		c.notarizations[msg.Notarization.Vote.Round] = msg.Notarization
		c.lock.Unlock()
	}
	c.testNetworkCommunication.Broadcast(msg)
}

func (c *collectNotarizationComm) removeFinalizationsFromReplicationResponses(msg *simplex.Message, from, to simplex.NodeID) bool {
	if msg.VerifiedReplicationResponse != nil || msg.ReplicationResponse != nil {
		newData := make([]simplex.VerifiedQuorumRound, 0, len(msg.VerifiedReplicationResponse.Data))

		// guard the notarizations map, which Send and Broadcast write to concurrently
		c.lock.Lock()
		for i := 0; i < len(msg.VerifiedReplicationResponse.Data); i++ {
			qr := msg.VerifiedReplicationResponse.Data[i]
			// strip the finalization and substitute the notarization collected for this round
			if qr.Finalization != nil && c.notarizations[qr.GetRound()] != nil {
				qr.Finalization = nil
				qr.Notarization = c.notarizations[qr.GetRound()]
			}
			newData = append(newData, qr)
		}
		c.lock.Unlock()

		msg.VerifiedReplicationResponse.Data = newData
		c.replicationResponses <- struct{}{}
	}
	return true
}

// TestReplicationRequestWithoutFinalization tests that a replication request is not marked as completed
// if we are expecting a finalization but it is not present in the response.
func TestReplicationRequestWithoutFinalization(t *testing.T) {
	nodes := []simplex.NodeID{{1}, {2}, {3}, []byte("lagging")}
	endDisconnect := uint64(10)
	bb := newTestControlledBlockBuilder(t)
	laggingBb := newTestControlledBlockBuilder(t)
	net := newInMemNetwork(t, nodes)

	notarizations := make(map[uint64]*simplex.Notarization)
	mapLock := &sync.Mutex{}
	testConfig := func(nodeID simplex.NodeID) *testNodeConfig {
		return &testNodeConfig{
			replicationEnabled: true,
			comm:               newCollectNotarizationComm(nodeID, net, notarizations, mapLock),
		}
	}

	notarizationComm := newCollectNotarizationComm(nodes[0], net, notarizations, mapLock)
	newSimplexNode(t, nodes[0], net, bb, &testNodeConfig{
		replicationEnabled: true,
		comm:               notarizationComm,
	})
	newSimplexNode(t, nodes[1], net, bb, testConfig(nodes[1]))
	newSimplexNode(t, nodes[2], net, bb, testConfig(nodes[2]))
	laggingNode := newSimplexNode(t, nodes[3], net, laggingBb, testConfig(nodes[3]))

	epochTimes := make([]time.Time, 0, 4)
	for _, n := range net.instances {
		epochTimes = append(epochTimes, n.e.StartTime)
	}

	// lagging node disconnects
	net.Disconnect(nodes[3])
	net.startInstances()

	missedSeqs := uint64(0)
	// normal nodes continue to make progress
	for i := uint64(0); i < endDisconnect; i++ {
		emptyRound := bytes.Equal(simplex.LeaderForRound(nodes, i), nodes[3])
		if emptyRound {
			advanceWithoutLeader(t, net, bb, epochTimes, i, laggingNode.e.ID)
			missedSeqs++
		} else {
			bb.triggerNewBlock()
			for _, n := range net.instances[:3] {
				n.storage.waitForBlockCommit(i - missedSeqs)
			}
		}
	}

	// all nodes except the lagging node have progressed and committed [endDisconnect - missedSeqs] blocks
	for _, n := range net.instances[:3] {
		require.Equal(t, endDisconnect-missedSeqs, n.storage.Height())
	}
	require.Equal(t, uint64(0), laggingNode.storage.Height())
	require.Equal(t, uint64(0), laggingNode.e.Metadata().Round)

	// lagging node reconnects
	net.setAllNodesMessageFilter(notarizationComm.removeFinalizationsFromReplicationResponses)
	net.Connect(nodes[3])
	bb.triggerNewBlock()
	for _, n := range net.instances {
		if n.e.ID.Equals(laggingNode.e.ID) {
			continue
		}

		n.storage.waitForBlockCommit(endDisconnect - missedSeqs)
		<-notarizationComm.replicationResponses
	}

	// wait until the lagging node sends replication requests;
	// due to the removeFinalizationsFromReplicationResponses message filter,
	// the lagging node should not have processed any replication responses
	require.Equal(t, uint64(0), laggingNode.storage.Height())

	// these replication requests should still be in the timeout handler
	laggingNode.e.AdvanceTime(laggingNode.e.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 2))

	for range net.instances[:3] {
		// the lagging node should have resent replication requests,
		// and the normal nodes should have responded
		<-notarizationComm.replicationResponses
	}

	require.Equal(t, uint64(0), laggingNode.storage.Height())

	// the replication requests are still in the timeout handler,
	// but now we allow the lagging node to process the responses
	net.setAllNodesMessageFilter(allowAllMessages)
	laggingNode.e.AdvanceTime(laggingNode.e.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 4))
	laggingNode.storage.waitForBlockCommit(endDisconnect - missedSeqs)
}