package main

import (
	"context"
	"crypto/rand"
	"flag"
	"fmt"
	"net/http"
	"net/http/pprof"
	"os"
	"os/signal"
	"slices"
	"sync"
	"syscall"
	"time"

	"github.com/tutus-one/tutus-consensus"
	"github.com/tutus-one/tutus-consensus/internal/consensus"
	"github.com/tutus-one/tutus-consensus/internal/crypto"
	"go.uber.org/zap"
)

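// simNode is a single node of the simulated cluster: it owns a dBFT
// instance, a signing key pair, a local transaction pool and an inbound
// message channel that stands in for the network transport.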
type (
	simNode struct {
		id       int
		d        *dbft.DBFT[crypto.Uint256]
		messages chan dbft.ConsensusPayload[crypto.Uint256]
		key      dbft.PrivateKey
		pub      dbft.PublicKey
		pool     *memPool
		cluster  []*simNode
		log      *zap.Logger

		height     uint32
		lastHash   crypto.Uint256
		validators []dbft.PublicKey
	}
)

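// defaultChanSize is the buffer size of every node's message channel;
// Broadcast drops payloads once a receiver's buffer is full.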
const (
	defaultChanSize = 100
)

var (
	nodebug    = flag.Bool("nodebug", false, "disable debug logging")
	count      = flag.Int("count", 7, "node count")
	watchers   = flag.Int("watchers", 7, "watch-only node count")
	blocked    = flag.Int("blocked", -1, "blocked validator (payloads from it are dropped)")
	txPerBlock = flag.Int("txblock", 1, "transactions per block")
	txCount    = flag.Int("txcount", 100000, "transactions on every node")
	duration   = flag.Duration("duration", time.Second*20, "duration of simulation (0 means run until interrupted)")
)

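// main wires the simulated cluster together: it creates consensus and
// watch-only nodes, shares the sorted validator list between them and runs
// every node's event loop until the context is cancelled.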
func main() {
	flag.Parse()

	initDebugger()

	logger := initLogger()
	clusterSize := *count
	watchOnly := *watchers
	nodes := make([]*simNode, clusterSize+watchOnly)

	initNodes(nodes, logger)
	updatePublicKeys(nodes, clusterSize)

	ctx, cancel := initContext(*duration)
	defer cancel()

	wg := new(sync.WaitGroup)
	wg.Add(len(nodes))

	for i := range nodes {
		go func(i int) {
			defer wg.Done()

			nodes[i].Run(ctx)
		}(i)
	}

	wg.Wait()
}

// Run implements a simple event loop: it reacts to dBFT timer events and to
// consensus payloads delivered by other nodes until the context is done.
func (n *simNode) Run(ctx context.Context) {
	n.d.Start(0)

	for {
		select {
		case <-ctx.Done():
			n.log.Info("context cancelled")
			return
		case <-n.d.Timer.C():
			n.d.OnTimeout(n.d.Timer.Height(), n.d.Timer.View())
		case msg := <-n.messages:
			n.d.OnReceive(msg)
		}
	}
}

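// initNodes initializes all simulation nodes, panicking on the first failure.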
func initNodes(nodes []*simNode, log *zap.Logger) {
	for i := range nodes {
		if err := initSimNode(nodes, i, log); err != nil {
			panic(err)
		}
	}
}

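// initSimNode creates the i-th node: it generates a fresh key pair, wires the
// node's callbacks into a new dBFT instance and pre-fills the local pool with
// *txCount transactions.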
func initSimNode(nodes []*simNode, i int, log *zap.Logger) error {
	key, pub := crypto.Generate(rand.Reader)
	nodes[i] = &simNode{
		id:       i,
		messages: make(chan dbft.ConsensusPayload[crypto.Uint256], defaultChanSize),
		key:      key,
		pub:      pub,
		pool:     newMemoryPool(),
		log:      log.With(zap.Int("id", i)),
		cluster:  nodes,
	}

	var err error
	nodes[i].d, err = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get,
		nodes[i].pool.GetVerified,
		nodes[i].Broadcast,
		nodes[i].ProcessBlock,
		nodes[i].CurrentHeight,
		nodes[i].CurrentBlockHash,
		nodes[i].GetValidators,
		nodes[i].VerifyPayload,
	)
	if err != nil {
		return fmt.Errorf("failed to initialize dBFT: %w", err)
	}

	nodes[i].addTx(*txCount)

	return nil
}

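// updatePublicKeys distributes the sorted list of the first n validator keys
// to every node, including the watch-only ones.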
func updatePublicKeys(nodes []*simNode, n int) {
	pubs := make([]dbft.PublicKey, n)
	for i := range pubs {
		pubs[i] = nodes[i].pub
	}

	sortValidators(pubs)

	for i := range nodes {
		nodes[i].validators = pubs
	}
}

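// sortValidators brings validator keys into a canonical order, so that every
// node computes identical validator indexes.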
func sortValidators(pubs []dbft.PublicKey) {
	slices.SortFunc(pubs, func(a, b dbft.PublicKey) int {
		x := a.(*crypto.ECDSAPub)
		y := b.(*crypto.ECDSAPub)
		return x.Compare(y)
	})
}

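// Broadcast sends a payload to every other node in the cluster. Delivery is
// best-effort: if a receiver's channel is full, the payload is dropped.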
func (n *simNode) Broadcast(m dbft.ConsensusPayload[crypto.Uint256]) {
	for i, node := range n.cluster {
		if i != n.id {
			select {
			case node.messages <- m:
			default:
				n.log.Warn("can't broadcast message: channel is full")
			}
		}
	}
}

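// CurrentHeight and CurrentBlockHash report the node's view of the chain tip.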
func (n *simNode) CurrentHeight() uint32 { return n.height }
func (n *simNode) CurrentBlockHash() crypto.Uint256 { return n.lastHash }

// GetValidators always returns the same list of validators.
func (n *simNode) GetValidators(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey {
	return n.validators
}

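// ProcessBlock accepts a finalized block: it evicts the block's transactions
// from the local pool and advances the node's height and last block hash.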
func (n *simNode) ProcessBlock(b dbft.Block[crypto.Uint256]) error {
	n.d.Logger.Debug("received block", zap.Uint32("height", b.Index()))

	for _, tx := range b.Transactions() {
		n.pool.Delete(tx.Hash())
	}

	n.height = b.Index()
	n.lastHash = b.Hash()
	return nil
}

// VerifyPayload accepts any payload unless it originates from the blocked
// validator selected via the -blocked flag.
func (n *simNode) VerifyPayload(p dbft.ConsensusPayload[crypto.Uint256]) error {
	if *blocked != -1 && p.ValidatorIndex() == uint16(*blocked) {
		return fmt.Errorf("message from blocked validator: %d", *blocked)
	}
	return nil
}

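// addTx fills the node's pool with count dummy transactions.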
func (n *simNode) addTx(count int) {
	for i := range count {
		tx := consensus.Tx64(uint64(i))
		n.pool.Add(&tx)
	}
}

// =============================
// Memory pool for transactions.
// =============================

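// memPool is a minimal thread-safe transaction store keyed by hash.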
type memPool struct {
	mtx   *sync.RWMutex
	store map[crypto.Uint256]dbft.Transaction[crypto.Uint256]
}

func newMemoryPool() *memPool {
	return &memPool{
		mtx:   new(sync.RWMutex),
		store: make(map[crypto.Uint256]dbft.Transaction[crypto.Uint256]),
	}
}

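// Add stores a transaction unless one with the same hash is already present.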
func (p *memPool) Add(tx dbft.Transaction[crypto.Uint256]) {
	p.mtx.Lock()

	h := tx.Hash()
	if _, ok := p.store[h]; !ok {
		p.store[h] = tx
	}

	p.mtx.Unlock()
}

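// Get returns the transaction with the given hash or nil if it is unknown.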
func (p *memPool) Get(h crypto.Uint256) (tx dbft.Transaction[crypto.Uint256]) {
	p.mtx.RLock()
	tx = p.store[h]
	p.mtx.RUnlock()

	return
}

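// Delete removes the transaction with the given hash, if present.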
func (p *memPool) Delete(h crypto.Uint256) {
	p.mtx.Lock()
	delete(p.store, h)
	p.mtx.Unlock()
}

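// GetVerified returns up to *txPerBlock transactions from the pool; map
// iteration order makes the selection effectively random.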
func (p *memPool) GetVerified() (txx []dbft.Transaction[crypto.Uint256]) {
	n := *txPerBlock
	if n == 0 {
		return
	}

	p.mtx.RLock()
	defer p.mtx.RUnlock()

	txx = make([]dbft.Transaction[crypto.Uint256], 0, n)
	for _, tx := range p.store {
		txx = append(txx, tx)

		if n--; n == 0 {
			return
		}
	}

	return
}

// initDebugger starts serving pprof debug endpoints on localhost:6060.
func initDebugger() {
	r := http.NewServeMux()
	r.HandleFunc("/debug/pprof/", pprof.Index)
	r.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	r.HandleFunc("/debug/pprof/profile", pprof.Profile)
	r.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	r.HandleFunc("/debug/pprof/trace", pprof.Trace)

	go func() {
		err := http.ListenAndServe("localhost:6060", r)
		if err != nil {
			panic(err)
		}
	}()
}

// initLogger initializes a new development logger (or returns the global
// no-op logger when -nodebug is set).
func initLogger() *zap.Logger {
	if *nodebug {
		return zap.L()
	}

	logger, err := zap.NewDevelopment()
	if err != nil {
		panic("can't init logger")
	}

	return logger
}

// initContext creates a new context which is cancelled either by Ctrl+C
// (SIGINT/SIGTERM) or, for a non-zero d, when the given duration elapses.
func initContext(d time.Duration) (ctx context.Context, cancel func()) {
	if d != 0 {
		ctx, cancel = context.WithTimeout(context.Background(), d)
	} else {
		ctx, cancel = context.WithCancel(context.Background())
	}

	// Exit by Ctrl+C.
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		<-c
		cancel()
	}()

	return ctx, cancel
}