Skip to content

Commit bea3bdf

Browse files
sukunrtMarcoPolo
authored andcommitted
holepunch: pass address function in constructor
1 parent 71c66ee commit bea3bdf

File tree

5 files changed

+52
-76
lines changed

5 files changed

+52
-76
lines changed

p2p/host/basic/basic_host.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
274274
opts.HolePunchingOptions = append(hpOpts, opts.HolePunchingOptions...)
275275

276276
}
277-
h.hps, err = holepunch.NewService(h, h.ids, opts.HolePunchingOptions...)
277+
h.hps, err = holepunch.NewService(h, h.ids, func() []ma.Multiaddr {
278+
addrs := h.AllAddrs()
279+
if opts.AddrsFactory != nil {
280+
addrs = opts.AddrsFactory(addrs)
281+
}
282+
// AllAddrs may ignore observed addresses in favour of NAT mappings. Use both for hole punching.
283+
addrs = append(addrs, h.ids.OwnObservedAddrs()...)
284+
addrs = ma.Unique(addrs)
285+
return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) })
286+
}, opts.HolePunchingOptions...)
278287
if err != nil {
279288
return nil, fmt.Errorf("failed to create hole punch service: %w", err)
280289
}

p2p/protocol/holepunch/holepunch_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ func TestNoHolePunchIfDirectConnExists(t *testing.T) {
9494
require.GreaterOrEqual(t, nc1, 1)
9595
nc2 := len(h2.Network().ConnsToPeer(h1.ID()))
9696
require.GreaterOrEqual(t, nc2, 1)
97-
9897
require.NoError(t, hps.DirectConnect(h2.ID()))
9998
require.Len(t, h1.Network().ConnsToPeer(h2.ID()), nc1)
10099
require.Len(t, h2.Network().ConnsToPeer(h1.ID()), nc2)
@@ -473,8 +472,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc
473472
hps = addHolePunchService(t, h2, h2opt...)
474473
}
475474

476-
// h1 has a relay addr
477-
// h2 should connect to the relay addr
475+
// h2 has a relay addr
478476
var raddr ma.Multiaddr
479477
for _, a := range h2.Addrs() {
480478
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
@@ -483,6 +481,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc
483481
}
484482
}
485483
require.NotEmpty(t, raddr)
484+
// h1 should connect to the relay addr
486485
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
487486
ID: h2.ID(),
488487
Addrs: []ma.Multiaddr{raddr},
@@ -492,7 +491,9 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc
492491

493492
func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service {
494493
t.Helper()
495-
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
494+
hps, err := holepunch.NewService(h, newMockIDService(t, h), func() []ma.Multiaddr {
495+
return append(h.Addrs(), ma.StringCast("/ip4/1.2.3.4/tcp/1234"))
496+
}, opts...)
496497
require.NoError(t, err)
497498
return hps
498499
}
@@ -505,7 +506,6 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host,
505506
libp2p.ResourceManager(&network.NullResourceManager{}),
506507
)
507508
require.NoError(t, err)
508-
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
509-
require.NoError(t, err)
509+
hps := addHolePunchService(t, h, opts...)
510510
return h, hps
511511
}

p2p/protocol/holepunch/holepuncher.go

+12-11
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ type holePuncher struct {
3737
host host.Host
3838
refCount sync.WaitGroup
3939

40-
ids identify.IDService
40+
ids identify.IDService
41+
listenAddrs func() []ma.Multiaddr
4142

4243
// active hole punches for deduplicating
4344
activeMx sync.Mutex
@@ -50,13 +51,14 @@ type holePuncher struct {
5051
filter AddrFilter
5152
}
5253

53-
func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer, filter AddrFilter) *holePuncher {
54+
func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, tracer *tracer, filter AddrFilter) *holePuncher {
5455
hp := &holePuncher{
55-
host: h,
56-
ids: ids,
57-
active: make(map[peer.ID]struct{}),
58-
tracer: tracer,
59-
filter: filter,
56+
host: h,
57+
ids: ids,
58+
active: make(map[peer.ID]struct{}),
59+
tracer: tracer,
60+
filter: filter,
61+
listenAddrs: listenAddrs,
6062
}
6163
hp.ctx, hp.ctxCancel = context.WithCancel(context.Background())
6264
h.Network().Notify((*netNotifiee)(hp))
@@ -102,16 +104,15 @@ func (hp *holePuncher) directConnect(rp peer.ID) error {
102104
if getDirectConnection(hp.host, rp) != nil {
103105
return nil
104106
}
105-
106107
// short-circuit hole punching if a direct dial works.
107108
// attempt a direct connection ONLY if we have a public address for the remote peer
108109
for _, a := range hp.host.Peerstore().Addrs(rp) {
109-
if manet.IsPublicAddr(a) && !isRelayAddress(a) {
110+
if !isRelayAddress(a) && manet.IsPublicAddr(a) {
110111
forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching")
111112
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout)
112113

113114
tstart := time.Now()
114-
// This dials *all* public addresses from the peerstore.
115+
// This dials *all* addresses, public and private, from the peerstore.
115116
err := hp.host.Connect(dialCtx, peer.AddrInfo{ID: rp})
116117
dt := time.Since(tstart)
117118
cancel()
@@ -206,7 +207,7 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
206207
str.SetDeadline(time.Now().Add(StreamTimeout))
207208

208209
// send a CONNECT and start RTT measurement.
209-
obsAddrs := removeRelayAddrs(hp.ids.OwnObservedAddrs())
210+
obsAddrs := removeRelayAddrs(hp.listenAddrs())
210211
if hp.filter != nil {
211212
obsAddrs = hp.filter.FilterLocal(str.Conn().RemotePeer(), obsAddrs)
212213
}

p2p/protocol/holepunch/svc.go

+22-51
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"slices"
78
"sync"
89
"time"
910

1011
logging "github.com/ipfs/go-log/v2"
11-
"github.com/libp2p/go-libp2p/core/event"
1212
"github.com/libp2p/go-libp2p/core/host"
1313
"github.com/libp2p/go-libp2p/core/network"
1414
"github.com/libp2p/go-libp2p/core/peer"
1515
"github.com/libp2p/go-libp2p/core/protocol"
16-
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
1716
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb"
1817
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
1918
"github.com/libp2p/go-msgio/pbio"
@@ -47,7 +46,13 @@ type Service struct {
4746
ctxCancel context.CancelFunc
4847

4948
host host.Host
50-
ids identify.IDService
49+
// ids helps with connection reversal. We wait for identify to complete and attempt
50+
// a direct connection to the peer if it's publicly reachable.
51+
ids identify.IDService
52+
// listenAddrs provides the addresses for the host to be used for hole punching. We use this
53+
// and not host.Addrs because host.Addrs might remove public unreachable address and only advertise
54+
// publicly reachable relay addresses.
55+
listenAddrs func() []ma.Multiaddr
5156

5257
holePuncherMx sync.Mutex
5358
holePuncher *holePuncher
@@ -65,7 +70,7 @@ type Service struct {
6570
// no matter if they are behind a NAT / firewall or not.
6671
// The Service handles DCUtR streams (which are initiated from the node behind
6772
// a NAT / Firewall once we establish a connection to them through a relay.
68-
func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, error) {
73+
func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, opts ...Option) (*Service, error) {
6974
if ids == nil {
7075
return nil, errors.New("identify service can't be nil")
7176
}
@@ -76,6 +81,7 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,
7681
ctxCancel: cancel,
7782
host: h,
7883
ids: ids,
84+
listenAddrs: listenAddrs,
7985
hasPublicAddrsChan: make(chan struct{}),
8086
}
8187

@@ -88,18 +94,18 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,
8894
s.tracer.Start()
8995

9096
s.refCount.Add(1)
91-
go s.watchForPublicAddr()
97+
go s.waitForPublicAddr()
9298

9399
return s, nil
94100
}
95101

96-
func (s *Service) watchForPublicAddr() {
102+
func (s *Service) waitForPublicAddr() {
97103
defer s.refCount.Done()
98104

99105
log.Debug("waiting until we have at least one public address", "peer", s.host.ID())
100106

101107
// TODO: We should have an event here that fires when identify discovers a new
102-
// address (and when autonat confirms that address).
108+
// address.
103109
// As we currently don't have an event like this, just check our observed addresses
104110
// regularly (exponential backoff starting at 250 ms, capped at 5s).
105111
duration := 250 * time.Millisecond
@@ -125,44 +131,27 @@ func (s *Service) watchForPublicAddr() {
125131
}
126132
}
127133

128-
// Only start the holePuncher if we're behind a NAT / firewall.
129-
sub, err := s.host.EventBus().Subscribe(&event.EvtLocalReachabilityChanged{}, eventbus.Name("holepunch"))
130-
if err != nil {
131-
log.Debugf("failed to subscripe to Reachability event: %s", err)
134+
s.holePuncherMx.Lock()
135+
if s.ctx.Err() != nil {
136+
// service is closed
132137
return
133138
}
134-
defer sub.Close()
135-
for {
136-
select {
137-
case <-s.ctx.Done():
138-
return
139-
case e, ok := <-sub.Out():
140-
if !ok {
141-
return
142-
}
143-
if e.(event.EvtLocalReachabilityChanged).Reachability != network.ReachabilityPrivate {
144-
continue
145-
}
146-
s.holePuncherMx.Lock()
147-
s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer, s.filter)
148-
s.holePuncherMx.Unlock()
149-
close(s.hasPublicAddrsChan)
150-
return
151-
}
152-
}
139+
s.holePuncher = newHolePuncher(s.host, s.ids, s.listenAddrs, s.tracer, s.filter)
140+
s.holePuncherMx.Unlock()
141+
close(s.hasPublicAddrsChan)
153142
}
154143

155144
// Close closes the Hole Punch Service.
156145
func (s *Service) Close() error {
157146
var err error
147+
s.ctxCancel()
158148
s.holePuncherMx.Lock()
159149
if s.holePuncher != nil {
160150
err = s.holePuncher.Close()
161151
}
162152
s.holePuncherMx.Unlock()
163153
s.tracer.Close()
164154
s.host.RemoveStreamHandler(Protocol)
165-
s.ctxCancel()
166155
s.refCount.Wait()
167156
return err
168157
}
@@ -172,7 +161,7 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, remo
172161
if !isRelayAddress(str.Conn().RemoteMultiaddr()) {
173162
return 0, nil, nil, fmt.Errorf("received hole punch stream: %s", str.Conn().RemoteMultiaddr())
174163
}
175-
ownAddrs = s.getPublicAddrs()
164+
ownAddrs = s.listenAddrs()
176165
if s.filter != nil {
177166
ownAddrs = s.filter.FilterLocal(str.Conn().RemotePeer(), ownAddrs)
178167
}
@@ -277,25 +266,7 @@ func (s *Service) handleNewStream(str network.Stream) {
277266

278267
// getPublicAddrs returns public observed and interface addresses
279268
func (s *Service) getPublicAddrs() []ma.Multiaddr {
280-
addrs := removeRelayAddrs(s.ids.OwnObservedAddrs())
281-
282-
interfaceListenAddrs, err := s.host.Network().InterfaceListenAddresses()
283-
if err != nil {
284-
log.Debugf("failed to get to get InterfaceListenAddresses: %s", err)
285-
} else {
286-
addrs = append(addrs, interfaceListenAddrs...)
287-
}
288-
289-
addrs = ma.Unique(addrs)
290-
291-
publicAddrs := make([]ma.Multiaddr, 0, len(addrs))
292-
293-
for _, addr := range addrs {
294-
if manet.IsPublicAddr(addr) {
295-
publicAddrs = append(publicAddrs, addr)
296-
}
297-
}
298-
return publicAddrs
269+
return slices.DeleteFunc(s.listenAddrs(), func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) })
299270
}
300271

301272
// DirectConnect is only exposed for testing purposes.

p2p/protocol/holepunch/util.go

+2-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package holepunch
22

33
import (
44
"context"
5+
"slices"
56

67
"github.com/libp2p/go-libp2p/core/host"
78
"github.com/libp2p/go-libp2p/core/network"
@@ -11,13 +12,7 @@ import (
1112
)
1213

1314
func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
14-
result := make([]ma.Multiaddr, 0, len(addrs))
15-
for _, addr := range addrs {
16-
if !isRelayAddress(addr) {
17-
result = append(result, addr)
18-
}
19-
}
20-
return result
15+
return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return isRelayAddress(a) })
2116
}
2217

2318
func isRelayAddress(a ma.Multiaddr) bool {

0 commit comments

Comments
 (0)