package main import ( "context" "crypto/rand" "flag" "fmt" "net/http" "net/http/pprof" "os" "os/signal" "slices" "sync" "syscall" "time" "git.marketally.com/tutus-one/tutus-consensus" "git.marketally.com/tutus-one/tutus-consensus/internal/consensus" "git.marketally.com/tutus-one/tutus-consensus/internal/crypto" "go.uber.org/zap" ) type ( simNode struct { id int d *dbft.DBFT[crypto.Uint256] messages chan dbft.ConsensusPayload[crypto.Uint256] key dbft.PrivateKey pub dbft.PublicKey pool *memPool cluster []*simNode log *zap.Logger height uint32 lastHash crypto.Uint256 validators []dbft.PublicKey } ) const ( defaultChanSize = 100 ) var ( nodebug = flag.Bool("nodebug", false, "disable debug logging") count = flag.Int("count", 7, "node count") watchers = flag.Int("watchers", 7, "watch-only node count") blocked = flag.Int("blocked", -1, "blocked validator (payloads from him/her are dropped)") txPerBlock = flag.Int("txblock", 1, "transactions per block") txCount = flag.Int("txcount", 100000, "transactions on every node") duration = flag.Duration("duration", time.Second*20, "duration of simulation (infinite by default)") ) func main() { flag.Parse() initDebugger() logger := initLogger() clusterSize := *count watchOnly := *watchers nodes := make([]*simNode, clusterSize+watchOnly) initNodes(nodes, logger) updatePublicKeys(nodes, clusterSize) ctx, cancel := initContext(*duration) defer cancel() wg := new(sync.WaitGroup) wg.Add(len(nodes)) for i := range nodes { go func(i int) { defer wg.Done() nodes[i].Run(ctx) }(i) } wg.Wait() } // Run implements simple event loop. func (n *simNode) Run(ctx context.Context) { n.d.Start(0) for { select { case <-ctx.Done(): n.log.Info("context cancelled") return case <-n.d.Timer.C(): n.d.OnTimeout(n.d.Timer.Height(), n.d.Timer.View()) case msg := <-n.messages: n.d.OnReceive(msg) } } } func initNodes(nodes []*simNode, log *zap.Logger) { for i := range nodes { if err := initSimNode(nodes, i, log); err != nil { panic(err) } } } func initSimNode(nodes []*simNode, i int, log *zap.Logger) error { key, pub := crypto.Generate(rand.Reader) nodes[i] = &simNode{ id: i, messages: make(chan dbft.ConsensusPayload[crypto.Uint256], defaultChanSize), key: key, pub: pub, pool: newMemoryPool(), log: log.With(zap.Int("id", i)), cluster: nodes, } var err error nodes[i].d, err = consensus.New(nodes[i].log, key, pub, nodes[i].pool.Get, nodes[i].pool.GetVerified, nodes[i].Broadcast, nodes[i].ProcessBlock, nodes[i].CurrentHeight, nodes[i].CurrentBlockHash, nodes[i].GetValidators, nodes[i].VerifyPayload, ) if err != nil { return fmt.Errorf("failed to initialize dBFT: %w", err) } nodes[i].addTx(*txCount) return nil } func updatePublicKeys(nodes []*simNode, n int) { pubs := make([]dbft.PublicKey, n) for i := range pubs { pubs[i] = nodes[i].pub } sortValidators(pubs) for i := range nodes { nodes[i].validators = pubs } } func sortValidators(pubs []dbft.PublicKey) { slices.SortFunc(pubs, func(a, b dbft.PublicKey) int { x := a.(*crypto.ECDSAPub) y := b.(*crypto.ECDSAPub) return x.Compare(y) }) } func (n *simNode) Broadcast(m dbft.ConsensusPayload[crypto.Uint256]) { for i, node := range n.cluster { if i != n.id { select { case node.messages <- m: default: n.log.Warn("can't broadcast message: channel is full") } } } } func (n *simNode) CurrentHeight() uint32 { return n.height } func (n *simNode) CurrentBlockHash() crypto.Uint256 { return n.lastHash } // GetValidators always returns the same list of validators. func (n *simNode) GetValidators(...dbft.Transaction[crypto.Uint256]) []dbft.PublicKey { return n.validators } func (n *simNode) ProcessBlock(b dbft.Block[crypto.Uint256]) error { n.d.Logger.Debug("received block", zap.Uint32("height", b.Index())) for _, tx := range b.Transactions() { n.pool.Delete(tx.Hash()) } n.height = b.Index() n.lastHash = b.Hash() return nil } // VerifyPayload verifies that payload was received from a good validator. func (n *simNode) VerifyPayload(p dbft.ConsensusPayload[crypto.Uint256]) error { if *blocked != -1 && p.ValidatorIndex() == uint16(*blocked) { return fmt.Errorf("message from blocked validator: %d", *blocked) } return nil } func (n *simNode) addTx(count int) { for i := range count { tx := consensus.Tx64(uint64(i)) n.pool.Add(&tx) } } // ============================= // Memory pool for transactions. // ============================= type memPool struct { mtx *sync.RWMutex store map[crypto.Uint256]dbft.Transaction[crypto.Uint256] } func newMemoryPool() *memPool { return &memPool{ mtx: new(sync.RWMutex), store: make(map[crypto.Uint256]dbft.Transaction[crypto.Uint256]), } } func (p *memPool) Add(tx dbft.Transaction[crypto.Uint256]) { p.mtx.Lock() h := tx.Hash() if _, ok := p.store[h]; !ok { p.store[h] = tx } p.mtx.Unlock() } func (p *memPool) Get(h crypto.Uint256) (tx dbft.Transaction[crypto.Uint256]) { p.mtx.RLock() tx = p.store[h] p.mtx.RUnlock() return } func (p *memPool) Delete(h crypto.Uint256) { p.mtx.Lock() delete(p.store, h) p.mtx.Unlock() } func (p *memPool) GetVerified() (txx []dbft.Transaction[crypto.Uint256]) { n := *txPerBlock if n == 0 { return } txx = make([]dbft.Transaction[crypto.Uint256], 0, n) for _, tx := range p.store { txx = append(txx, tx) if n--; n == 0 { return } } return } // initDebugger initializes pprof debug facilities. func initDebugger() { r := http.NewServeMux() r.HandleFunc("/debug/pprof/", pprof.Index) r.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) r.HandleFunc("/debug/pprof/profile", pprof.Profile) r.HandleFunc("/debug/pprof/symbol", pprof.Symbol) r.HandleFunc("/debug/pprof/trace", pprof.Trace) go func() { err := http.ListenAndServe("localhost:6060", r) if err != nil { panic(err) } }() } // initLogger initializes new logger. func initLogger() *zap.Logger { if *nodebug { return zap.L() } logger, err := zap.NewDevelopment() if err != nil { panic("can't init logger") } return logger } // initContext creates new context which will be cancelled by Ctrl+C. func initContext(d time.Duration) (ctx context.Context, cancel func()) { // exit by Ctrl+C c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) go func() { <-c cancel() }() if d != 0 { return context.WithTimeout(context.Background(), *duration) } return context.WithCancel(context.Background()) }