feat(gossipsub): Add MessageBatch #607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 9 commits into master
Changes from 1 commit
87 changes: 87 additions & 0 deletions gossipsub.go
@@ -3,10 +3,12 @@ package pubsub
import (
	"context"
	"crypto/sha256"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"sort"
	"sync"
	"time"

	pb "github.com/libp2p/go-libp2p-pubsub/pb"
@@ -1140,6 +1142,10 @@ func (gs *GossipSubRouter) connector() {
}

func (gs *GossipSubRouter) Publish(msg *Message) {
	if msg.messageBatch != nil {
		defer msg.messageBatch.doneWithMsg()
	}

	gs.mcache.Put(msg)

	from := msg.ReceivedFrom
@@ -1213,6 +1219,10 @@ func (gs *GossipSubRouter) Publish(msg *Message) {
			continue
		}

		if msg.messageBatch != nil {
			msg.messageBatch.queueRPC(pid, gs.p.idGen.ID(msg), out)
			continue
		}
		gs.sendRPC(pid, out, false)
	}
}
@@ -2204,3 +2214,80 @@ func computeChecksum(mid string) checksum {
	}
	return cs
}

type pendingRPC struct {
	peer peer.ID
	rpc  *RPC
}

// MessageBatch allows a user to batch related messages and then publish them
// at once. This allows the system to prioritize sending a single copy of each
// message before sending more copies, which helps bandwidth-constrained peers.
type MessageBatch struct {
	sync.Mutex
	// pendingRPCsToAdd is a WaitGroup used to wait for all the RPCs to be
	// added to the batch. This library's publish is async, so we need to be
	// careful not to publish the batch before all the RPCs are added.
	pendingRPCsToAdd sync.WaitGroup
	sendRPC          func(peer peer.ID, rpc *RPC, urgent bool)
	rpcs             map[string][]pendingRPC
}

// NewMessageBatch creates a new MessageBatch. This only works for GossipSub.
func NewMessageBatch(ps *PubSub) (*MessageBatch, error) {
	if ps == nil {
		return nil, errors.New("pubsub is nil")
	}
	if gs, ok := ps.rt.(*GossipSubRouter); ok {
		return &MessageBatch{
			sendRPC: gs.sendRPC,
		}, nil
	}
	return nil, errors.New("pubsub is not a GossipSubRouter")
}

// Add adds a message to the batch.
func (p *MessageBatch) Add(ctx context.Context, topic *Topic, data []byte, opts ...PubOpt) error {
	p.pendingRPCsToAdd.Add(1)
	opts = append(opts, func(o *PublishOptions) error {
		o.messageBatch = p
		return nil
	})
	return topic.Publish(ctx, data, opts...)
}

// Publish publishes the messages in the batch.
//
// Users should make sure there is enough space in the peer's outbound queue to
// ensure messages are not dropped. WithPeerOutboundQueueSize should be set to
// at least the expected number of batched messages per peer plus some slack to
// account for gossip messages.
func (p *MessageBatch) Publish() {
	p.pendingRPCsToAdd.Wait()
	p.Lock()
	defer p.Unlock()

	// Drain the queues round-robin across message IDs so that one copy of
	// every message is sent before any message is sent a second time.
	for len(p.rpcs) > 0 {
		for msgID, rpcs := range p.rpcs {
			if len(rpcs) == 0 {
				delete(p.rpcs, msgID)
				continue
			}
			p.sendRPC(rpcs[0].peer, rpcs[0].rpc, false)
			p.rpcs[msgID] = rpcs[1:]
		}
	}
}

func (p *MessageBatch) doneWithMsg() {
	p.pendingRPCsToAdd.Done()
}

func (p *MessageBatch) queueRPC(peer peer.ID, msgID string, rpc *RPC) {
	p.Lock()
	defer p.Unlock()
	if p.rpcs == nil {
		p.rpcs = make(map[string][]pendingRPC)
	}
	p.rpcs[msgID] = append(p.rpcs[msgID], pendingRPC{peer: peer, rpc: rpc})
}
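For context, a minimal caller-side sketch of how the new API fits together. The helper name publishBatch and the error wrapping are illustrative, not part of this diff; the sketch assumes the PubSub instance was built with the GossipSub router and, per the Publish doc comment above, with WithPeerOutboundQueueSize sized for the batch.

package main

import (
	"context"
	"fmt"

	pubsub "github.com/libp2p/go-libp2p-pubsub"
)

// publishBatch is a hypothetical helper showing how NewMessageBatch, Add and
// Publish are used together. It assumes ps was created with pubsub.NewGossipSub
// (otherwise NewMessageBatch returns an error) and that WithPeerOutboundQueueSize
// was set to at least len(payloads) plus some slack for gossip.
func publishBatch(ctx context.Context, ps *pubsub.PubSub, topic *pubsub.Topic, payloads [][]byte) error {
	batch, err := pubsub.NewMessageBatch(ps)
	if err != nil {
		return err
	}
	for i, data := range payloads {
		// Add publishes asynchronously and tags the message with this batch,
		// so the router queues the resulting RPCs instead of sending them.
		if err := batch.Add(ctx, topic, data); err != nil {
			return fmt.Errorf("adding message %d: %w", i, err)
		}
	}
	// Publish waits until every added message has reached the router, then
	// sends one copy of each message before any message goes to a second peer.
	batch.Publish()
	return nil
}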
244 changes: 244 additions & 0 deletions gossipsub_test.go
@@ -9,9 +9,11 @@ import (
"io"
mrand "math/rand"
"sort"
"strings"
"sync"
"sync/atomic"
"testing"
"testing/quick"
"time"

pb "github.com/libp2p/go-libp2p-pubsub/pb"
@@ -3334,3 +3336,245 @@ func BenchmarkAllocDoDropRPC(b *testing.B) {
		gs.doDropRPC(&RPC{}, "peerID", "reason")
	}
}

func TestMessageBatchPublishesRarestFirst(t *testing.T) {
	const maxNumPeers = 256
	const maxNumMessages = 1_000

	err := quick.Check(func(numPeers uint16, numMessages uint16) bool {
		numPeers = numPeers % maxNumPeers
		numMessages = numMessages % maxNumMessages

		output := make([]pendingRPC, 0, numMessages*numPeers)
		batch := &MessageBatch{
			sendRPC: func(peer peer.ID, rpc *RPC, urgent bool) {
				output = append(output, pendingRPC{
					peer: peer,
					rpc:  rpc,
				})
			},
		}

		peers := make([]peer.ID, numPeers)
		for i := 0; i < int(numPeers); i++ {
			peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
		}

		getID := func(r pendingRPC) string {
			return string(r.rpc.Publish[0].Data)
		}

		for i := 0; i < int(numMessages); i++ {
			for j := 0; j < int(numPeers); j++ {
				batch.queueRPC(peers[j], fmt.Sprintf("msg%d", i), &RPC{
					RPC: pb.RPC{
						Publish: []*pb.Message{
							{
								Data: []byte(fmt.Sprintf("msg%d", i)),
							},
						},
					},
				})
			}
		}

		batch.Publish()

		// Check invariants
		// 1. The published rpcs count is the same as the number of messages added
		// 2. Before all message IDs are seen, no message ID may be repeated
		// 3. The set of message ID + peer ID combinations should be the same as the input

		// 1.
		expectedCount := int(numMessages) * int(numPeers)
		if len(output) != expectedCount {
			t.Logf("Expected %d RPCs, got %d", expectedCount, len(output))
			return false
		}

		// 2.
		seen := make(map[string]bool)
		expected := make(map[string]bool)
		for i := 0; i < int(numMessages); i++ {
			expected[fmt.Sprintf("msg%d", i)] = true
		}

		for _, rpc := range output {
			if expected[getID(rpc)] {
				delete(expected, getID(rpc))
			}
			if seen[getID(rpc)] && len(expected) > 0 {
				t.Logf("Message ID %s repeated before all message IDs are seen", getID(rpc))
				return false
			}
			seen[getID(rpc)] = true
		}

		// 3.
		inputSet := make(map[string]bool)
		for i := 0; i < int(numMessages); i++ {
			for j := 0; j < int(numPeers); j++ {
				inputSet[fmt.Sprintf("msg%d:peer%d", i, j)] = true
			}
		}
		for _, rpc := range output {
			if !inputSet[getID(rpc)+":"+string(rpc.peer)] {
				t.Logf("Message ID %s not in input", getID(rpc))
				return false
			}
		}
		return true
	}, &quick.Config{MaxCount: 32})
	if err != nil {
		t.Fatal(err)
	}
}
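To make the ordering invariant checked above concrete, here is a small in-package sketch (hypothetical, not part of this diff; peer and message names are made up) that records the order Publish emits RPCs when two messages are queued for the same two peers.

// Hypothetical in-package sketch of the order Publish produces for two
// messages queued to the same two peers.
func ExampleMessageBatch_rarestFirst() {
	var order []string
	batch := &MessageBatch{
		sendRPC: func(p peer.ID, rpc *RPC, urgent bool) {
			order = append(order, string(rpc.Publish[0].Data)+"->"+string(p))
		},
	}
	for _, msg := range []string{"m1", "m2"} {
		for _, p := range []peer.ID{"peerA", "peerB"} {
			batch.queueRPC(p, msg, &RPC{RPC: pb.RPC{
				Publish: []*pb.Message{{Data: []byte(msg)}},
			}})
		}
	}
	batch.Publish()
	// Map iteration order is randomized, so the exact interleaving varies, but
	// the first two sends are always one copy of m1 and one copy of m2; only
	// afterwards does either message reach its second peer.
	fmt.Println(len(order))
	// Output: 4
}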

func BenchmarkMessageBatchPublish(b *testing.B) {
	const numPeers = 1_000
	const numMessages = 1_000

	batch := &MessageBatch{sendRPC: func(peer peer.ID, rpc *RPC, urgent bool) {}}

	peers := make([]peer.ID, numPeers)
	for i := 0; i < int(numPeers); i++ {
		peers[i] = peer.ID(fmt.Sprintf("peer%d", i))
	}
	msgs := make([]string, numMessages)
	for i := 0; i < numMessages; i++ {
		msgs[i] = fmt.Sprintf("msg%d", i)
	}

	emptyRPC := &RPC{}
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		j := i % len(peers)
		msgIdx := i % numMessages
		batch.queueRPC(peers[j], msgs[msgIdx], emptyRPC)
		if i%100 == 0 {
			batch.Publish()
		}
	}
}

func TestMessageBatchPublish(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hosts := getDefaultHosts(t, 20)

	msgIDFn := func(msg *pb.Message) string {
		hdr := string(msg.Data[0:16])
		msgID := strings.SplitN(hdr, " ", 2)
		return msgID[0]
	}
	const numMessages = 100
	// +8 to account for the gossiping overhead
	psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(msgIDFn), WithPeerOutboundQueueSize(numMessages+8))

	var topics []*Topic
	var msgs []*Subscription
	for _, ps := range psubs {
		topic, err := ps.Join("foobar")
		if err != nil {
			t.Fatal(err)
		}
		topics = append(topics, topic)

		subch, err := topic.Subscribe(WithBufferSize(numMessages + 8))
		if err != nil {
			t.Fatal(err)
		}

		msgs = append(msgs, subch)
	}

	sparseConnect(t, hosts)

	// wait for heartbeats to build mesh
	time.Sleep(time.Second * 2)

	batch, err := NewMessageBatch(psubs[0])
	if err != nil {
		t.Fatal(err)
	}

	for i := 0; i < numMessages; i++ {
		msg := []byte(fmt.Sprintf("%d it's not a floooooood %d", i, i))
		err := batch.Add(ctx, topics[0], msg)
		if err != nil {
			t.Fatal(err)
		}
	}
	batch.Publish()

	for i := 1; i < numMessages; i++ {
		for _, sub := range msgs {
			got, err := sub.Next(ctx)
			if err != nil {
				t.Fatal(sub.err)
			}
			id := msgIDFn(got.Message)
			expected := []byte(fmt.Sprintf("%s it's not a floooooood %s", id, id))
			if !bytes.Equal(expected, got.Data) {
				t.Fatal("got wrong message!")
			}
		}
	}
}

func TestMessageBatchAsyncAddMsg(t *testing.T) {
	// Multiple runs because this is racy
	const runs = 10
	const expectedNumRPCsPerPeer = 10

	peerCounts := []int{2, 3, 5}

	for _, numPeers := range peerCounts {
		t.Run(fmt.Sprintf("%d hosts", numPeers), func(t *testing.T) {
			hosts := getDefaultHosts(t, numPeers)
			psubs := getGossipsubs(context.Background(), hosts)
			denseConnect(t, hosts)

			var publisherTopic *Topic
			for i, psub := range psubs {
				topic, err := psub.Join("foobar")
				if err != nil {
					t.Fatal(err)
				}
				_, err = topic.Subscribe(WithBufferSize(runs * expectedNumRPCsPerPeer))
				if err != nil {
					t.Fatal(err)
				}

				if i == 0 {
					publisherTopic = topic
				}
			}
			// Give the nodes a second to bootstrap
			time.Sleep(2 * time.Second)

			for range runs {
				var sentRPCs atomic.Int32
				b, err := NewMessageBatch(psubs[0])
				if err != nil {
					t.Fatal(err)
				}
				b.sendRPC = func(peer peer.ID, rpc *RPC, urgent bool) {
					sentRPCs.Add(1)
				}

				// publisher
				for i := range expectedNumRPCsPerPeer {
					b.Add(context.Background(), publisherTopic, []byte(fmt.Sprintf("msg%d", i)))
				}
				b.Publish()

				if sentRPCs.Load() != int32(expectedNumRPCsPerPeer*(numPeers-1)) {
					t.Fatalf("expected %d RPCs, got %d", expectedNumRPCsPerPeer*(numPeers-1), sentRPCs.Load())
				}
			}
		})
	}
}
3 changes: 2 additions & 1 deletion pubsub.go
@@ -235,6 +235,7 @@ type Message struct {
	ReceivedFrom  peer.ID
	ValidatorData interface{}
	Local         bool
	messageBatch  *MessageBatch
}

func (m *Message) GetFrom() peer.ID {
@@ -1101,7 +1102,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) {
			continue
		}

-		msg := &Message{pmsg, "", rpc.from, nil, false}
+		msg := &Message{pmsg, "", rpc.from, nil, false, nil}
		if p.shouldPush(msg) {
			toPush = append(toPush, msg)
		}