tutus-consensus/dbft_test.go

package dbft_test
import (
"crypto/rand"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/tutus-one/tutus-consensus"
"github.com/tutus-one/tutus-consensus/internal/consensus"
"github.com/tutus-one/tutus-consensus/internal/crypto"
"github.com/tutus-one/tutus-consensus/timer"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
type Payload = dbft.ConsensusPayload[crypto.Uint256]
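// testState is an in-memory single-node test environment for dbft: it holds the
// validator key set, the outgoing payload queue (ch), a transaction pool and the
// preBlocks/blocks accepted by the service under test.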
type testState struct {
myIndex int
count int
privs []dbft.PrivateKey
pubs []dbft.PublicKey
ch []Payload
currHeight uint32
currHash crypto.Uint256
pool *testPool
preBlocks []dbft.PreBlock[crypto.Uint256]
blocks []dbft.Block[crypto.Uint256]
verify func(b dbft.Block[crypto.Uint256]) bool
}
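// testTx is a minimal transaction stub whose hash is its little-endian encoding;
// testPool is a trivial in-memory transaction pool keyed by that hash.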
type (
testTx uint64
testPool struct {
storage map[crypto.Uint256]testTx
}
)
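// debugTests switches getOptions to a verbose development zap logger.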
const debugTests = false
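// TestDBFT_OnStartPrimarySendPrepareRequest checks that on Start the primary
// broadcasts a PrepareRequest while a backup stays silent, and that on timeout
// the primary sends either a RecoveryRequest or a ChangeView depending on how
// many nodes have been seen recently.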
func TestDBFT_OnStartPrimarySendPrepareRequest(t *testing.T) {
s := newTestState(2, 7)
t.Run("backup sends nothing on start", func(t *testing.T) {
s.currHeight = 0
service, err := dbft.New[crypto.Uint256](s.getOptions()...)
require.NoError(t, err)
service.Start(0)
require.Nil(t, s.tryRecv())
})
t.Run("primary send PrepareRequest on start", func(t *testing.T) {
s.currHeight = 1
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)
p := s.tryRecv()
require.NotNil(t, p)
require.Equal(t, dbft.PrepareRequestType, p.Type())
require.EqualValues(t, 2, p.Height())
require.EqualValues(t, 0, p.ViewNumber())
require.NotNil(t, p.Payload())
require.EqualValues(t, 2, p.ValidatorIndex())
t.Run("primary send ChangeView on timeout", func(t *testing.T) {
service.OnTimeout(s.currHeight+1, 0)
// If too many nodes appear faulty (not seen recently), a RecoveryRequest must be sent.
cv := s.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.RecoveryRequestType, cv.Type())
require.Nil(t, s.tryRecv())
// If all nodes are known to be up, a ChangeView must be sent.
for i := range service.LastSeenMessage {
service.LastSeenMessage[i] = &dbft.HeightView{s.currHeight + 1, 0}
}
service.OnTimeout(s.currHeight+1, 0)
cv = s.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())
require.EqualValues(t, 1, cv.GetChangeView().NewViewNumber())
require.Nil(t, s.tryRecv())
})
})
}
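// TestDBFT_SingleNode checks that a single-validator network (with and without
// the anti-MEV extension) immediately produces a PrepareRequest, a PreCommit in
// the AMEV case, a Commit and a new block.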
func TestDBFT_SingleNode(t *testing.T) {
for _, amev := range []bool{false, true} {
t.Run(fmt.Sprintf("AMEV %t", amev), func(t *testing.T) {
s := newTestState(0, 1)
s.currHeight = 2
opts := s.getOptions()
if amev {
opts = s.getAMEVOptions()
}
service, _ := dbft.New[crypto.Uint256](opts...)
service.Start(0)
p := s.tryRecv()
require.NotNil(t, p)
require.Equal(t, dbft.PrepareRequestType, p.Type())
require.EqualValues(t, 3, p.Height())
require.EqualValues(t, 0, p.ViewNumber())
require.NotNil(t, p.Payload())
require.EqualValues(t, 0, p.ValidatorIndex())
if amev {
cm := s.tryRecv()
require.NotNil(t, cm)
require.Equal(t, dbft.PreCommitType, cm.Type())
require.EqualValues(t, s.currHeight+1, cm.Height())
require.EqualValues(t, 0, cm.ViewNumber())
require.NotNil(t, cm.Payload())
require.EqualValues(t, 0, cm.ValidatorIndex())
}
cm := s.tryRecv()
require.NotNil(t, cm)
require.Equal(t, dbft.CommitType, cm.Type())
require.EqualValues(t, s.currHeight+1, cm.Height())
require.EqualValues(t, 0, cm.ViewNumber())
require.NotNil(t, cm.Payload())
require.EqualValues(t, 0, cm.ValidatorIndex())
b := s.nextBlock()
require.NotNil(t, b)
require.Equal(t, s.currHeight+1, b.Index())
})
}
}
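// TestDBFT_OnReceiveRequestSendResponse checks the backup's reaction to the
// primary's PrepareRequest: a PrepareResponse for a valid request, a ChangeView
// for a request containing an invalid transaction, and no reaction to malformed
// requests (wrong primary index, old height, missing transactions) until the
// missing transactions arrive.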
func TestDBFT_OnReceiveRequestSendResponse(t *testing.T) {
s := newTestState(2, 7)
s.verify = func(b dbft.Block[crypto.Uint256]) bool {
for _, tx := range b.Transactions() {
if tx.(testTx)%10 == 0 {
return false
}
}
return true
}
t.Run("receive request from primary", func(t *testing.T) {
s.currHeight = 4
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{1}
s.pool.Add(txs[0])
p := s.getPrepareRequest(5, txs[0].Hash())
service.Start(0)
service.OnReceive(p)
resp := s.tryRecv()
require.NotNil(t, resp)
require.Equal(t, dbft.PrepareResponseType, resp.Type())
require.EqualValues(t, s.currHeight+1, resp.Height())
require.EqualValues(t, 0, resp.ViewNumber())
require.EqualValues(t, s.myIndex, resp.ValidatorIndex())
require.NotNil(t, resp.Payload())
require.Equal(t, p.Hash(), resp.GetPrepareResponse().PreparationHash())
// do nothing on second receive
service.OnReceive(p)
require.Nil(t, s.tryRecv())
t.Run("receive response from primary", func(t *testing.T) {
resp := s.getPrepareResponse(5, p.Hash(), 0)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())
})
})
t.Run("change view on invalid tx", func(t *testing.T) {
s.currHeight = 4
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{10}
service.Start(0)
for i := range service.LastSeenMessage {
service.LastSeenMessage[i] = &dbft.HeightView{s.currHeight + 1, 0}
}
p := s.getPrepareRequest(5, txs[0].Hash())
service.OnReceive(p)
require.Nil(t, s.tryRecv())
service.OnTransaction(testTx(10))
cv := s.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())
require.EqualValues(t, s.currHeight+1, cv.Height())
require.EqualValues(t, 0, cv.ViewNumber())
require.EqualValues(t, s.myIndex, cv.ValidatorIndex())
require.NotNil(t, cv.Payload())
require.EqualValues(t, 1, cv.GetChangeView().NewViewNumber())
})
t.Run("receive invalid prepare request", func(t *testing.T) {
s.currHeight = 4
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
txs := []testTx{1, 2}
s.pool.Add(txs[0])
service.Start(0)
t.Run("wrong primary index", func(t *testing.T) {
p := s.getPrepareRequest(4, txs[0].Hash())
service.OnReceive(p)
require.Nil(t, s.tryRecv())
})
t.Run("old height", func(t *testing.T) {
p := s.getPrepareRequestWithHeight(5, 3, txs[0].Hash())
service.OnReceive(p)
require.Nil(t, s.tryRecv())
})
t.Run("does not have all transactions", func(t *testing.T) {
p := s.getPrepareRequest(5, txs[0].Hash(), txs[1].Hash())
service.OnReceive(p)
require.Nil(t, s.tryRecv())
// do nothing with already present transaction
service.OnTransaction(txs[0])
require.Nil(t, s.tryRecv())
service.OnTransaction(txs[1])
resp := s.tryRecv()
require.NotNil(t, resp)
require.Equal(t, dbft.PrepareResponseType, resp.Type())
require.EqualValues(t, s.currHeight+1, resp.Height())
require.EqualValues(t, 0, resp.ViewNumber())
require.EqualValues(t, s.myIndex, resp.ValidatorIndex())
require.NotNil(t, resp.Payload())
require.Equal(t, p.Hash(), resp.GetPrepareResponse().PreparationHash())
// do not send response twice
service.OnTransaction(txs[1])
require.Nil(t, s.tryRecv())
})
})
}
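// TestDBFT_CommitOnTransaction checks that a block is not accepted until the
// transaction referenced by the PrepareRequest is received, even when enough
// Commit messages have already been collected.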
func TestDBFT_CommitOnTransaction(t *testing.T) {
s := newTestState(0, 4)
s.currHeight = 1
srv, _ := dbft.New[crypto.Uint256](s.getOptions()...)
srv.Start(0)
require.Nil(t, s.tryRecv())
tx := testTx(42)
req := s.getPrepareRequest(2, tx.Hash())
srv.OnReceive(req)
srv.OnReceive(s.getPrepareResponse(1, req.Hash(), 0))
srv.OnReceive(s.getPrepareResponse(3, req.Hash(), 0))
require.Nil(t, srv.Header()) // missing transaction.
// Test state for forming header.
s1 := &testState{
count: s.count,
pool: newTestPool(),
currHeight: 1,
pubs: s.pubs,
privs: s.privs,
}
s1.pool.Add(tx)
srv1, _ := dbft.New[crypto.Uint256](s1.getOptions()...)
srv1.Start(0)
srv1.OnReceive(req)
srv1.OnReceive(s1.getPrepareResponse(1, req.Hash(), 0))
srv1.OnReceive(s1.getPrepareResponse(3, req.Hash(), 0))
require.NotNil(t, srv1.Header())
for _, i := range []uint16{1, 2, 3} {
require.NoError(t, srv1.Header().Sign(s1.privs[i]))
c := s1.getCommit(i, srv1.Header().Signature(), 0)
srv.OnReceive(c)
}
require.Nil(t, s.nextBlock())
srv.OnTransaction(tx)
require.NotNil(t, s.nextBlock())
}
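// TestDBFT_OnReceiveCommit checks that the node sends its Commit after M
// prepare messages, answers a timeout with a RecoveryMessage once committed,
// and accepts the block after M Commit messages.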
func TestDBFT_OnReceiveCommit(t *testing.T) {
s := newTestState(2, 4)
t.Run("send commit after enough responses", func(t *testing.T) {
s.currHeight = 1
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)
req := s.tryRecv()
require.NotNil(t, req)
resp := s.getPrepareResponse(1, req.Hash(), 0)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())
resp = s.getPrepareResponse(0, req.Hash(), 0)
service.OnReceive(resp)
cm := s.tryRecv()
require.NotNil(t, cm)
require.Equal(t, dbft.CommitType, cm.Type())
require.EqualValues(t, s.currHeight+1, cm.Height())
require.EqualValues(t, 0, cm.ViewNumber())
require.EqualValues(t, s.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
pub := s.pubs[s.myIndex]
require.NoError(t, service.Header().Verify(pub, cm.GetCommit().Signature()))
t.Run("send recovery message on timeout", func(t *testing.T) {
service.OnTimeout(1, 0)
require.Nil(t, s.tryRecv())
service.OnTimeout(s.currHeight+1, 0)
r := s.tryRecv()
require.NotNil(t, r)
require.Equal(t, dbft.RecoveryMessageType, r.Type())
})
t.Run("process block after enough commits", func(t *testing.T) {
s0 := s.copyWithIndex(0)
require.NoError(t, service.Header().Sign(s0.privs[0]))
c0 := s0.getCommit(0, service.Header().Signature(), 0)
service.OnReceive(c0)
require.Nil(t, s.tryRecv())
require.Nil(t, s.nextBlock())
s1 := s.copyWithIndex(1)
require.NoError(t, service.Header().Sign(s1.privs[1]))
c1 := s1.getCommit(1, service.Header().Signature(), 0)
service.OnReceive(c1)
require.Nil(t, s.tryRecv())
b := s.nextBlock()
require.NotNil(t, b)
require.Equal(t, s.currHeight+1, b.Index())
})
})
}
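// TestDBFT_OnReceiveRecoveryRequest checks that a committed node answers a
// RecoveryRequest with a RecoveryMessage sufficient for another node to
// reconstruct the context and reach the commit phase.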
func TestDBFT_OnReceiveRecoveryRequest(t *testing.T) {
s := newTestState(2, 4)
t.Run("send recovery message", func(t *testing.T) {
s.currHeight = 1
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)
req := s.tryRecv()
require.NotNil(t, req)
resp := s.getPrepareResponse(1, req.Hash(), 0)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())
resp = s.getPrepareResponse(0, req.Hash(), 0)
service.OnReceive(resp)
cm := s.tryRecv()
require.NotNil(t, cm)
rr := s.getRecoveryRequest(3)
service.OnReceive(rr)
rm := s.tryRecv()
require.NotNil(t, rm)
require.Equal(t, dbft.RecoveryMessageType, rm.Type())
other := s.copyWithIndex(3)
srv2, _ := dbft.New[crypto.Uint256](other.getOptions()...)
srv2.Start(0)
srv2.OnReceive(rm)
r2 := other.tryRecv()
require.NotNil(t, r2)
require.Equal(t, dbft.PrepareResponseType, r2.Type())
cm2 := other.tryRecv()
require.NotNil(t, cm2)
require.Equal(t, dbft.CommitType, cm2.Type())
pub := other.pubs[other.myIndex]
require.NoError(t, service.Header().Verify(pub, cm2.GetCommit().Signature()))
// The Commit message must be sent only once, even during recovery.
require.Nil(t, s.tryRecv())
})
}
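// TestDBFT_OnReceiveRecoveryRequestResponds enumerates (nodes, sender, receiver)
// combinations and checks that only the F+1 nodes following the sender reply to
// its RecoveryRequest.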
func TestDBFT_OnReceiveRecoveryRequestResponds(t *testing.T) {
type recoveryset struct {
nodes int
sender int
receiver int
replies bool
}
var params []recoveryset
for _, nodes := range []int{4, 5, 7, 10} { // 5 is a bad BFT number, but we want to test the logic anyway.
for sender := range nodes {
for recv := range nodes {
params = append(params, recoveryset{nodes, sender, recv, false})
for i := 1; i <= ((nodes-1)/3)+1; i++ {
ind := (sender + i) % nodes
if ind == recv {
params[len(params)-1].replies = true
break
}
}
}
}
}
for _, param := range params {
t.Run(fmt.Sprintf("%d nodes, %d sender, %d receiver", param.nodes, param.sender, param.receiver), func(t *testing.T) {
s := newTestState(param.receiver, param.nodes)
s.currHeight = 1
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(uint64(param.receiver))
_ = s.tryRecv() // Flush the queue if primary.
rr := s.getRecoveryRequest(uint16(param.sender))
service.OnReceive(rr)
rm := s.tryRecv()
if param.replies {
require.NotNil(t, rm)
require.Equal(t, dbft.RecoveryMessageType, rm.Type())
} else {
require.Nil(t, rm)
}
})
}
}
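// TestDBFT_OnReceiveChangeView checks that the node changes its view after M
// ChangeView messages (its own included) and that, as the new primary, it sends
// a PrepareRequest on the view-1 timeout.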
func TestDBFT_OnReceiveChangeView(t *testing.T) {
s := newTestState(2, 4)
t.Run("change view correctly", func(t *testing.T) {
s.currHeight = 6
service, _ := dbft.New[crypto.Uint256](s.getOptions()...)
service.Start(0)
resp := s.getChangeView(1, 1)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())
resp = s.getChangeView(0, 1)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())
service.OnTimeout(s.currHeight+1, 0)
cv := s.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())
t.Run("primary sends prepare request after timeout", func(t *testing.T) {
service.OnTimeout(s.currHeight+1, 1)
pr := s.tryRecv()
require.NotNil(t, pr)
require.Equal(t, dbft.PrepareRequestType, pr.Type())
})
})
}
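// TestDBFT_Invalid checks that dbft.New fails until every mandatory callback
// option is provided and succeeds (filling in optional defaults) once they all are.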
func TestDBFT_Invalid(t *testing.T) {
t.Run("without keys", func(t *testing.T) {
_, err := dbft.New[crypto.Uint256]()
require.Error(t, err)
})
priv, pub := crypto.Generate(rand.Reader)
require.NotNil(t, priv)
require.NotNil(t, pub)
opts := []func(*dbft.Config[crypto.Uint256]){dbft.WithGetKeyPair[crypto.Uint256](func(_ []dbft.PublicKey) (int, dbft.PrivateKey, dbft.PublicKey) {
return -1, nil, nil
})}
t.Run("without Timer", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithTimer[crypto.Uint256](timer.New()))
t.Run("without CurrentHeight", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithCurrentHeight[crypto.Uint256](func() uint32 { return 0 }))
t.Run("without CurrentBlockHash", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithCurrentBlockHash[crypto.Uint256](func() crypto.Uint256 { return crypto.Uint256{} }))
t.Run("without GetValidators", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithGetValidators[crypto.Uint256](func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey {
return []dbft.PublicKey{pub}
}))
t.Run("without NewBlockFromContext", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewBlockFromContext[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] {
return nil
}))
t.Run("without NewConsensusPayload", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewConsensusPayload[crypto.Uint256](func(_ *dbft.Context[crypto.Uint256], _ dbft.MessageType, _ any) dbft.ConsensusPayload[crypto.Uint256] {
return nil
}))
t.Run("without NewPrepareRequest", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewPrepareRequest[crypto.Uint256](func(uint64, uint64, []crypto.Uint256) dbft.PrepareRequest[crypto.Uint256] {
return nil
}))
t.Run("without NewPrepareResponse", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewPrepareResponse[crypto.Uint256](func(crypto.Uint256) dbft.PrepareResponse[crypto.Uint256] {
return nil
}))
t.Run("without NewChangeView", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewChangeView[crypto.Uint256](func(byte, dbft.ChangeViewReason, uint64) dbft.ChangeView {
return nil
}))
t.Run("without NewCommit", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewCommit[crypto.Uint256](func([]byte) dbft.Commit {
return nil
}))
t.Run("without NewRecoveryRequest", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewRecoveryRequest[crypto.Uint256](func(uint64) dbft.RecoveryRequest {
return nil
}))
t.Run("without NewRecoveryMessage", func(t *testing.T) {
_, err := dbft.New(opts...)
require.Error(t, err)
})
opts = append(opts, dbft.WithNewRecoveryMessage[crypto.Uint256](func() dbft.RecoveryMessage[crypto.Uint256] {
return nil
}), dbft.WithMaxTimePerBlock[crypto.Uint256](func() time.Duration {
return 0
}))
t.Run("MaxTimePerBlock without SubscribeForTxs", func(t *testing.T) {
_, err := dbft.New(opts...)
require.ErrorContains(t, err, "MaxTimePerBlock and SubscribeForTxs should be specified/not specified at the same time")
})
opts = append(opts, dbft.WithSubscribeForTxs[crypto.Uint256](func() {}))
t.Run("with all defaults", func(t *testing.T) {
d, err := dbft.New(opts...)
require.NoError(t, err)
require.NotNil(t, d)
require.NotNil(t, d.RequestTx)
require.NotNil(t, d.GetTx)
require.NotNil(t, d.GetVerified)
require.NotNil(t, d.VerifyBlock)
require.NotNil(t, d.Broadcast)
require.NotNil(t, d.ProcessBlock)
require.NotNil(t, d.GetBlock)
require.NotNil(t, d.Config.WatchOnly)
})
}
// TestDBFT_FourGoodNodesDeadlock checks that the following liveness lock is not really
// a liveness lock and that there's a way to accept a block in this situation.
// 0 :> [type |-> "cv", view |-> 1] <--- this is the primary at view 1
// 1 :> [type |-> "cv", view |-> 1] <--- this is the primary at view 0
// 2 :> [type |-> "commitSent", view |-> 0]
// 3 :> [type |-> "commitSent", view |-> 1]
//
// Test structure note: the test is organized to reproduce the liveness lock scenario
// described in https://github.com/neo-project/neo-modules/issues/792#issue-1609058923
// in the section named "1. Liveness lock with four non-faulty nodes". However, some
// steps are rearranged so that it's possible to reach the target network state described
// above. This is done because the dbft implementation contains additional constraints
// compared to the TLA+ model.
func TestDBFT_FourGoodNodesDeadlock(t *testing.T) {
r0 := newTestState(0, 4)
r0.currHeight = 4
s0, _ := dbft.New[crypto.Uint256](r0.getOptions()...)
s0.Start(0)
r1 := r0.copyWithIndex(1)
s1, _ := dbft.New[crypto.Uint256](r1.getOptions()...)
s1.Start(0)
r2 := r0.copyWithIndex(2)
s2, _ := dbft.New[crypto.Uint256](r2.getOptions()...)
s2.Start(0)
r3 := r0.copyWithIndex(3)
s3, _ := dbft.New[crypto.Uint256](r3.getOptions()...)
s3.Start(0)
// Step 1. The primary (at view 0) replica 1 sends the PrepareRequest message.
reqV0 := r1.tryRecv()
require.NotNil(t, reqV0)
require.Equal(t, dbft.PrepareRequestType, reqV0.Type())
// Step 2 will be performed later, see the comment to Step 2.
// Step 3. The backup (at view 0) replica 0 receives the PrepareRequest of
// view 0 and broadcasts its PrepareResponse.
s0.OnReceive(reqV0)
resp0V0 := r0.tryRecv()
require.NotNil(t, resp0V0)
require.Equal(t, dbft.PrepareResponseType, resp0V0.Type())
// Step 4 will be performed later, see the comment to Step 4.
// Step 5. The backup (at view 0) replica 2 receives the PrepareRequest of
// view 0 and broadcasts its PrepareResponse.
s2.OnReceive(reqV0)
resp2V0 := r2.tryRecv()
require.NotNil(t, resp2V0)
require.Equal(t, dbft.PrepareResponseType, resp2V0.Type())
// Step 6. The backup (at view 0) replica 2 collects M prepare messages (from
// itself and replicas 0, 1) and broadcasts the Commit message for view 0.
s2.OnReceive(resp0V0)
cm2V0 := r2.tryRecv()
require.NotNil(t, cm2V0)
require.Equal(t, dbft.CommitType, cm2V0.Type())
// Step 7. The backup (at view 0) replica 3 decides to change its view
// (possible on timeout) and sends the ChangeView message.
s3.OnReceive(resp0V0)
s3.OnReceive(resp2V0)
s3.OnTimeout(r3.currHeight+1, 0)
cv3V0 := r3.tryRecv()
require.NotNil(t, cv3V0)
require.Equal(t, dbft.ChangeViewType, cv3V0.Type())
// Step 2. The primary (at view 0) replica 1 decides to change its view
// (possible on timeout after receiving at least M non-commit messages from the
// current view) and sends the ChangeView message.
s1.OnReceive(resp0V0)
s1.OnReceive(cv3V0)
s1.OnTimeout(r1.currHeight+1, 0)
cv1V0 := r1.tryRecv()
require.NotNil(t, cv1V0)
require.Equal(t, dbft.ChangeViewType, cv1V0.Type())
// Step 4. The backup (at view 0) replica 0 decides to change its view
// (possible on timeout after receiving at least M non-commit messages from the
// current view) and sends the ChangeView message.
s0.OnReceive(cv3V0)
s0.OnTimeout(r0.currHeight+1, 0)
cv0V0 := r0.tryRecv()
require.NotNil(t, cv0V0)
require.Equal(t, dbft.ChangeViewType, cv0V0.Type())
// Step 8. The primary (at view 0) replica 1 collects M ChangeView messages
// (from itself and replicas 0, 3) and changes its view to 1.
s1.OnReceive(cv0V0)
require.Equal(t, uint8(1), s1.ViewNumber)
// Step 9. The backup (at view 0) replica 0 collects M ChangeView messages
// (from itself and replicas 1, 3) and changes its view to 1.
s0.OnReceive(cv1V0)
require.Equal(t, uint8(1), s0.ViewNumber)
// Step 10. The primary (at view 1) replica 0 sends the PrepareRequest message.
s0.OnTimeout(r0.currHeight+1, 1)
reqV1 := r0.tryRecv()
require.NotNil(t, reqV1)
require.Equal(t, dbft.PrepareRequestType, reqV1.Type())
// Step 11. The backup (at view 1) replica 1 receives the PrepareRequest of
// view 1 and sends the PrepareResponse.
s1.OnReceive(reqV1)
resp1V1 := r1.tryRecv()
require.NotNil(t, resp1V1)
require.Equal(t, dbft.PrepareResponseType, resp1V1.Type())
// Steps 12, 13 will be performed later, see the comments to Step 12, 13.
// Step 14. The backup (at view 0) replica 3 collects M ChangeView messages
// (from itself and replicas 0, 1) and changes its view to 1.
s3.OnReceive(cv0V0)
s3.OnReceive(cv1V0)
require.Equal(t, uint8(1), s3.ViewNumber)
// Intermediate step A. It is added to make Step 12 possible. The backup (at
// view 1) replica 3 doesn't receive anything for a long time and sends a
// RecoveryRequest.
s3.OnTimeout(r3.currHeight+1, 1)
rcvr3V1 := r3.tryRecv()
require.NotNil(t, rcvr3V1)
require.Equal(t, dbft.RecoveryRequestType, rcvr3V1.Type())
// Intermediate step B. The backup (at view 1) replica 1 should receive some
// message from replica 3 to be able to change its view. However, it can't be a
// PrepareResponse because replica 1 would immediately commit then. Thus, the
// only remaining option is to receive a RecoveryRequest from replica 3.
// Replica 1 should then answer with a RecoveryMessage.
s1.OnReceive(rcvr3V1)
rcvrResp1V1 := r1.tryRecv()
require.NotNil(t, rcvrResp1V1)
require.Equal(t, dbft.RecoveryMessageType, rcvrResp1V1.Type())
// Intermediate step C. The primary (at view 1) replica 0 should receive
// RecoveryRequest from replica 3. The purpose of this step is the same as
// in Intermediate step B.
s0.OnReceive(rcvr3V1)
rcvrResp0V1 := r0.tryRecv()
require.NotNil(t, rcvrResp0V1)
require.Equal(t, dbft.RecoveryMessageType, rcvrResp0V1.Type())
// Step 12. According to neo-project/neo-modules#792, at this step the backup (at view 1)
// replica 1 decides to change its view (possible on timeout) and sends the
// ChangeView message. However, a recovery message will be broadcast instead
// of the CV, because there's an additional condition: too many (>F) "lost" or committed
// nodes are present, see https://github.com/roman-khimov/dbft/blob/b769eb3e0f070d6eabb9443a5931eb4a2e46c538/send.go#L68.
// Replica 1 is aware of replica 0, which has sent the PrepareRequest for view 1.
// It can also be aware of replica 2, which has committed at view 0, but that won't
// change the situation. The only remaining way to allow the CV is to receive something
// other than a PrepareResponse from replica 3, removing replica 3 from the list
// of "lost" nodes. That's why we've added Intermediate steps A and B.
//
// After that, replica 1 is allowed to send the CV message.
s1.OnTimeout(r1.currHeight+1, 1)
cv1V1 := r1.tryRecv()
require.NotNil(t, cv1V1)
require.Equal(t, dbft.ChangeViewType, cv1V1.Type())
// Step 13. The primary (at view 1) replica 0 decides to change its view
// (possible on timeout) and sends the ChangeView message.
s0.OnReceive(resp1V1)
s0.OnTimeout(r0.currHeight+1, 1)
cv0V1 := r0.tryRecv()
require.NotNil(t, cv0V1)
require.Equal(t, dbft.ChangeViewType, cv0V1.Type())
// Step 15. The backup (at view 1) replica 3 receives PrepareRequest of view
// 1 and broadcasts its PrepareResponse.
s3.OnReceive(reqV1)
resp3V1 := r3.tryRecv()
require.NotNil(t, resp3V1)
require.Equal(t, dbft.PrepareResponseType, resp3V1.Type())
// Step 16. The backup (at view 1) replica 3 collects M prepare messages and
// broadcasts the Commit message for view 1.
s3.OnReceive(resp1V1)
cm3V1 := r3.tryRecv()
require.NotNil(t, cm3V1)
require.Equal(t, dbft.CommitType, cm3V1.Type())
// Intermediate step D. It is needed to enable step 17 and to check that
// MoreThanFNodesCommittedOrLost works properly and counts Commit messages from
// any view.
s0.OnReceive(cm2V0)
s0.OnReceive(cm3V1)
// Step 17. The issue says that "The rest of undelivered messages eventually
// reaches their receivers, but it doesn't change the node's states.", but that's
// not true; the aim of the test is to show that replicas 0 and 1 can still
// commit at view 1 even after the CV was sent.
s0.OnReceive(resp3V1)
cm0V1 := r0.tryRecv()
require.NotNil(t, cm0V1)
require.Equal(t, dbft.CommitType, cm0V1.Type())
s1.OnReceive(cm0V1)
s1.OnReceive(resp3V1)
cm1V1 := r1.tryRecv()
require.NotNil(t, cm1V1)
require.Equal(t, dbft.CommitType, cm1V1.Type())
// Finally, send the missing Commit messages to replicas 0 and 1; they should
// accept the block.
require.Nil(t, r0.nextBlock())
s0.OnReceive(cm1V1)
require.NotNil(t, r0.nextBlock())
require.Nil(t, r1.nextBlock())
s1.OnReceive(cm3V1)
require.NotNil(t, r1.nextBlock())
}
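// TestDBFT_OnReceiveCommitAMEV checks the anti-MEV flow: a PreCommit after M
// prepare messages, a PreBlock plus Commit after M PreCommits, and the final
// block after M Commits.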
func TestDBFT_OnReceiveCommitAMEV(t *testing.T) {
s := newTestState(2, 4)
t.Run("send preCommit after enough responses", func(t *testing.T) {
s.currHeight = 1
service, _ := dbft.New[crypto.Uint256](s.getAMEVOptions()...)
service.Start(0)
req := s.tryRecv()
require.NotNil(t, req)
resp := s.getPrepareResponse(1, req.Hash(), 0)
service.OnReceive(resp)
require.Nil(t, s.tryRecv())
resp = s.getPrepareResponse(0, req.Hash(), 0)
service.OnReceive(resp)
cm := s.tryRecv()
require.NotNil(t, cm)
require.Equal(t, dbft.PreCommitType, cm.Type())
require.EqualValues(t, s.currHeight+1, cm.Height())
require.EqualValues(t, 0, cm.ViewNumber())
require.EqualValues(t, s.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
pub := s.pubs[s.myIndex]
require.NoError(t, service.PreHeader().Verify(pub, cm.GetPreCommit().Data()))
t.Run("send commit after enough preCommits", func(t *testing.T) {
s0 := s.copyWithIndex(0)
require.NoError(t, service.PreHeader().SetData(s0.privs[0]))
preC0 := s0.getPreCommit(0, service.PreHeader().Data(), 0)
service.OnReceive(preC0)
require.Nil(t, s.tryRecv())
require.Nil(t, s.nextPreBlock())
require.Nil(t, s.nextBlock())
s1 := s.copyWithIndex(1)
require.NoError(t, service.PreHeader().SetData(s1.privs[1]))
preC1 := s1.getPreCommit(1, service.PreHeader().Data(), 0)
service.OnReceive(preC1)
b := s.nextPreBlock()
require.NotNil(t, b)
require.Equal(t, []byte{0, 0, 0, 2}, b.Data()) // After SetData it's equal to node index.
require.Nil(t, s.nextBlock())
c := s.tryRecv()
require.NotNil(t, c)
require.Equal(t, dbft.CommitType, c.Type())
require.EqualValues(t, s.currHeight+1, c.Height())
require.EqualValues(t, 0, c.ViewNumber())
require.EqualValues(t, s.myIndex, c.ValidatorIndex())
require.NotNil(t, c.Payload())
t.Run("process block a after enough commitAcks", func(t *testing.T) {
s0 := s.copyWithIndex(0)
require.NoError(t, service.Header().Sign(s0.privs[0]))
c0 := s0.getAMEVCommit(0, service.Header().Signature())
service.OnReceive(c0)
require.Nil(t, s.tryRecv())
require.Nil(t, s.nextPreBlock())
require.Nil(t, s.nextBlock())
s1 := s.copyWithIndex(1)
require.NoError(t, service.Header().Sign(s1.privs[1]))
c1 := s1.getAMEVCommit(1, service.Header().Signature())
service.OnReceive(c1)
require.Nil(t, s.tryRecv())
require.Nil(t, s.nextPreBlock())
b := s.nextBlock()
require.NotNil(t, b)
require.Equal(t, s.currHeight+1, b.Index())
})
})
})
}
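// TestDBFT_CachedMessages checks that messages received "from the future"
// (for a higher view or before the corresponding PrepareRequest) are cached
// and applied once the node catches up, both with and without the anti-MEV
// extension.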
func TestDBFT_CachedMessages(t *testing.T) {
for _, amev := range []bool{false, true} {
t.Run(fmt.Sprintf("AMEV %t", amev), func(t *testing.T) {
s2 := newTestState(2, 4)
s2.currHeight = 1
s1 := newTestState(1, 4)
s1.currHeight = 1
opts := s2.getOptions()
if amev {
opts = s2.getAMEVOptions()
}
service2, _ := dbft.New[crypto.Uint256](opts...)
service2.Start(0)
opts = s1.getOptions()
if amev {
opts = s1.getAMEVOptions()
}
service1, _ := dbft.New[crypto.Uint256](opts...)
service1.Start(0)
req := s2.tryRecv()
require.NotNil(t, req) // Primary sends a request.
require.Equal(t, dbft.PrepareRequestType, req.Type())
require.Nil(t, s1.tryRecv()) // Backup waits.
cv0 := s1.getChangeView(0, 1)
cv3 := s1.getChangeView(3, 1)
service1.OnReceive(cv0)
service1.OnReceive(cv3)
service1.OnTimeout(s1.currHeight+1, 0)
cv := s1.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())
service1.OnTimeout(s1.currHeight+1, 1)
req = s1.tryRecv()
require.NotNil(t, req)
require.Equal(t, dbft.PrepareRequestType, req.Type())
resp := s1.getPrepareResponse(3, req.Hash(), 1)
service1.OnReceive(resp)
require.Nil(t, s1.tryRecv())
service2.OnReceive(resp) // From the future.
require.Nil(t, s2.tryRecv())
resp = s1.getPrepareResponse(0, req.Hash(), 1)
service2.OnReceive(resp) // From the future.
require.Nil(t, s2.tryRecv())
service1.OnReceive(resp)
cm := s1.tryRecv()
require.NotNil(t, cm)
service2.OnReceive(cm)
require.Nil(t, s2.tryRecv())
if amev {
require.Equal(t, dbft.PreCommitType, cm.Type())
require.EqualValues(t, s1.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s1.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
pub := s1.pubs[s1.myIndex]
require.NoError(t, service1.PreHeader().Verify(pub, cm.GetPreCommit().Data()))
} else {
require.Equal(t, dbft.CommitType, cm.Type())
require.EqualValues(t, s1.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s1.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
}
service2.OnReceive(cv0)
service2.OnReceive(cv3)
service2.OnTimeout(s2.currHeight+1, 0)
cv = s2.tryRecv()
require.NotNil(t, cv)
require.Equal(t, dbft.ChangeViewType, cv.Type())
require.Equal(t, 1, int(service2.ViewNumber))
// s2 has some PrepareResponses, but doesn't have a request.
service2.OnReceive(req)
resp = s2.tryRecv()
require.NotNil(t, resp)
require.Equal(t, dbft.PrepareResponseType, resp.Type())
cm = s2.tryRecv()
require.NotNil(t, cm)
if amev {
require.Equal(t, dbft.PreCommitType, cm.Type())
require.EqualValues(t, s2.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s2.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
pub := s2.pubs[s2.myIndex]
require.NoError(t, service2.PreHeader().Verify(pub, cm.GetPreCommit().Data()))
service2.OnReceive(s2.getPreCommit(0, service2.PreHeader().Data(), 1))
cm = s2.tryRecv()
require.NotNil(t, cm)
require.Equal(t, dbft.CommitType, cm.Type())
} else {
require.Equal(t, dbft.CommitType, cm.Type())
require.EqualValues(t, s2.currHeight+1, cm.Height())
require.EqualValues(t, 1, cm.ViewNumber())
require.EqualValues(t, s2.myIndex, cm.ValidatorIndex())
require.NotNil(t, cm.Payload())
require.NoError(t, service2.Header().Sign(s2.privs[0]))
service2.OnReceive(s2.getCommit(0, service2.Header().Signature(), 1))
require.Nil(t, s2.tryRecv())
b := s2.nextBlock()
require.NotNil(t, b)
require.Equal(t, s2.currHeight+1, b.Index())
}
})
}
}
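// The helpers below construct consensus payloads of particular types as if they
// were received from the validator with the given index.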
func (s testState) getChangeView(from uint16, view byte) Payload {
cv := consensus.NewChangeView(view, 0, 0)
p := consensus.NewConsensusPayload(dbft.ChangeViewType, s.currHeight+1, from, 0, cv)
return p
}
func (s testState) getRecoveryRequest(from uint16) Payload {
p := consensus.NewConsensusPayload(dbft.RecoveryRequestType, s.currHeight+1, from, 0, consensus.NewRecoveryRequest(0))
return p
}
func (s testState) getCommit(from uint16, sign []byte, view byte) Payload {
c := consensus.NewCommit(sign)
p := consensus.NewConsensusPayload(dbft.CommitType, s.currHeight+1, from, view, c)
return p
}
func (s testState) getAMEVCommit(from uint16, sign []byte) Payload {
c := consensus.NewAMEVCommit(sign)
p := consensus.NewConsensusPayload(dbft.CommitType, s.currHeight+1, from, 0, c)
return p
}
func (s testState) getPreCommit(from uint16, data []byte, view byte) Payload {
c := consensus.NewPreCommit(data)
p := consensus.NewConsensusPayload(dbft.PreCommitType, s.currHeight+1, from, view, c)
return p
}
func (s testState) getPrepareResponse(from uint16, phash crypto.Uint256, view byte) Payload {
resp := consensus.NewPrepareResponse(phash)
p := consensus.NewConsensusPayload(dbft.PrepareResponseType, s.currHeight+1, from, view, resp)
return p
}
func (s testState) getPrepareRequest(from uint16, hashes ...crypto.Uint256) Payload {
return s.getPrepareRequestWithHeight(from, s.currHeight+1, hashes...)
}
func (s testState) getPrepareRequestWithHeight(from uint16, height uint32, hashes ...crypto.Uint256) Payload {
req := consensus.NewPrepareRequest(0, 0, hashes)
p := consensus.NewConsensusPayload(dbft.PrepareRequestType, height, from, 0, req)
return p
}
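// newTestState creates a testState for validator myIndex in a committee of
// count freshly generated validators.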
func newTestState(myIndex int, count int) *testState {
s := &testState{
myIndex: myIndex,
count: count,
pool: newTestPool(),
}
s.privs, s.pubs = getTestValidators(count)
return s
}
func (s *testState) tryRecv() Payload {
if len(s.ch) == 0 {
return nil
}
p := s.ch[0]
s.ch = s.ch[1:]
return p
}
func (s *testState) nextBlock() dbft.Block[crypto.Uint256] {
if len(s.blocks) == 0 {
return nil
}
b := s.blocks[0]
s.blocks = s.blocks[1:]
return b
}
func (s *testState) nextPreBlock() dbft.PreBlock[crypto.Uint256] {
if len(s.preBlocks) == 0 {
return nil
}
b := s.preBlocks[0]
s.preBlocks = s.preBlocks[1:]
return b
}
func (s testState) copyWithIndex(myIndex int) *testState {
return &testState{
myIndex: myIndex,
count: s.count,
privs: s.privs,
pubs: s.pubs,
currHeight: s.currHeight,
currHash: s.currHash,
pool: newTestPool(),
}
}
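// getOptions returns a dbft configuration wired to this testState: broadcasts
// are appended to s.ch, processed blocks to s.blocks, and transactions are
// looked up in s.pool.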
func (s *testState) getOptions() []func(*dbft.Config[crypto.Uint256]) {
opts := []func(*dbft.Config[crypto.Uint256]){
dbft.WithTimer[crypto.Uint256](timer.New()),
dbft.WithCurrentHeight[crypto.Uint256](func() uint32 { return s.currHeight }),
dbft.WithCurrentBlockHash[crypto.Uint256](func() crypto.Uint256 { return s.currHash }),
dbft.WithGetValidators[crypto.Uint256](func(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey { return s.pubs }),
dbft.WithGetKeyPair[crypto.Uint256](func(_ []dbft.PublicKey) (int, dbft.PrivateKey, dbft.PublicKey) {
return s.myIndex, s.privs[s.myIndex], s.pubs[s.myIndex]
}),
dbft.WithBroadcast[crypto.Uint256](func(p Payload) { s.ch = append(s.ch, p) }),
dbft.WithGetTx[crypto.Uint256](s.pool.Get),
dbft.WithProcessBlock[crypto.Uint256](func(b dbft.Block[crypto.Uint256]) error { s.blocks = append(s.blocks, b); return nil }),
dbft.WithWatchOnly[crypto.Uint256](func() bool { return false }),
dbft.WithGetBlock[crypto.Uint256](func(crypto.Uint256) dbft.Block[crypto.Uint256] { return nil }),
dbft.WithLogger[crypto.Uint256](zap.NewNop()),
dbft.WithNewBlockFromContext[crypto.Uint256](newBlockFromContext),
dbft.WithTimePerBlock[crypto.Uint256](func() time.Duration {
return time.Second * 10
}),
dbft.WithRequestTx[crypto.Uint256](func(...crypto.Uint256) {}),
dbft.WithGetVerified[crypto.Uint256](func() []dbft.Transaction[crypto.Uint256] { return []dbft.Transaction[crypto.Uint256]{} }),
dbft.WithNewConsensusPayload[crypto.Uint256](newConsensusPayload),
dbft.WithNewPrepareRequest[crypto.Uint256](consensus.NewPrepareRequest),
dbft.WithNewPrepareResponse[crypto.Uint256](consensus.NewPrepareResponse),
dbft.WithNewChangeView[crypto.Uint256](consensus.NewChangeView),
dbft.WithNewCommit[crypto.Uint256](consensus.NewCommit),
dbft.WithNewRecoveryRequest[crypto.Uint256](consensus.NewRecoveryRequest),
dbft.WithNewRecoveryMessage[crypto.Uint256](func() dbft.RecoveryMessage[crypto.Uint256] {
return consensus.NewRecoveryMessage(nil)
}),
dbft.WithVerifyCommit[crypto.Uint256](func(p dbft.ConsensusPayload[crypto.Uint256]) error { return nil }),
}
verify := s.verify
if verify == nil {
verify = func(dbft.Block[crypto.Uint256]) bool { return true }
}
opts = append(opts, dbft.WithVerifyBlock(verify))
if debugTests {
cfg := zap.NewDevelopmentConfig()
cfg.DisableStacktrace = true
logger, _ := cfg.Build()
opts = append(opts, dbft.WithLogger[crypto.Uint256](logger))
}
return opts
}
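// getAMEVOptions extends getOptions with the anti-MEV extension callbacks
// (PreCommit, AMEV commit, PreBlock handling) enabled starting from height 0.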
func (s *testState) getAMEVOptions() []func(*dbft.Config[crypto.Uint256]) {
opts := s.getOptions()
opts = append(opts,
dbft.WithAntiMEVExtensionEnablingHeight[crypto.Uint256](0),
dbft.WithNewPreCommit[crypto.Uint256](consensus.NewPreCommit),
dbft.WithNewCommit[crypto.Uint256](consensus.NewAMEVCommit),
dbft.WithNewPreBlockFromContext[crypto.Uint256](newPreBlockFromContext),
dbft.WithNewBlockFromContext[crypto.Uint256](newAMEVBlockFromContext),
dbft.WithProcessPreBlock(func(b dbft.PreBlock[crypto.Uint256]) error {
s.preBlocks = append(s.preBlocks, b)
return nil
}),
)
return opts
}
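// newBlockFromContext adapts the consensus test block constructor to the
// dbft.WithNewBlockFromContext callback; it returns nil until the transaction
// hashes of the current proposal are known.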
func newBlockFromContext(ctx *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] {
if ctx.TransactionHashes == nil {
return nil
}
block := consensus.NewBlock(ctx.Timestamp, ctx.BlockIndex, ctx.PrevHash, ctx.Nonce, ctx.TransactionHashes)
return block
}
func newPreBlockFromContext(ctx *dbft.Context[crypto.Uint256]) dbft.PreBlock[crypto.Uint256] {
if ctx.TransactionHashes == nil {
return nil
}
pre := consensus.NewPreBlock(ctx.Timestamp, ctx.BlockIndex, ctx.PrevHash, ctx.Nonce, ctx.TransactionHashes)
return pre
}
func newAMEVBlockFromContext(ctx *dbft.Context[crypto.Uint256]) dbft.Block[crypto.Uint256] {
if ctx.TransactionHashes == nil {
return nil
}
var data [][]byte
for _, c := range ctx.PreCommitPayloads {
if c != nil && c.ViewNumber() == ctx.ViewNumber {
data = append(data, c.GetPreCommit().Data())
}
}
pre := consensus.NewAMEVBlock(ctx.PreBlock(), data, ctx.M())
return pre
}
// newConsensusPayload creates a consensus payload of the specified message type
// on behalf of the node described by the given context.
func newConsensusPayload(c *dbft.Context[crypto.Uint256], t dbft.MessageType, msg any) dbft.ConsensusPayload[crypto.Uint256] {
cp := consensus.NewConsensusPayload(t, c.BlockIndex, uint16(c.MyIndex), c.ViewNumber, msg)
return cp
}
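// getTestValidators generates n random key pairs to be used as the validator set.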
func getTestValidators(n int) (privs []dbft.PrivateKey, pubs []dbft.PublicKey) {
for range n {
priv, pub := crypto.Generate(rand.Reader)
privs = append(privs, priv)
pubs = append(pubs, pub)
}
return
}
func (tx testTx) Hash() (h crypto.Uint256) {
binary.LittleEndian.PutUint64(h[:], uint64(tx))
return
}
func newTestPool() *testPool {
return &testPool{
storage: make(map[crypto.Uint256]testTx),
}
}
func (p *testPool) Add(tx testTx) {
p.storage[tx.Hash()] = tx
}
func (p *testPool) Get(h crypto.Uint256) dbft.Transaction[crypto.Uint256] {
if tx, ok := p.storage[h]; ok {
return tx
}
return nil
}