Skip to content

Commit 131e5bd

Browse files
webtransport: have the server send the certificates (#1757)
1 parent 214b337 commit 131e5bd

File tree

4 files changed

+92
-54
lines changed

4 files changed

+92
-54
lines changed

p2p/transport/webtransport/cert_manager.go

+32-10
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package libp2pwebtransport
22

33
import (
4-
"bytes"
54
"context"
65
"crypto/sha256"
76
"crypto/tls"
87
"fmt"
98
"sync"
109
"time"
1110

11+
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
12+
1213
"github.com/benbjohnson/clock"
1314
ma "github.com/multiformats/go-multiaddr"
1415
"github.com/multiformats/go-multihash"
@@ -54,6 +55,8 @@ type certManager struct {
5455
currentConfig *certConfig
5556
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
5657
addrComp ma.Multiaddr
58+
59+
protobuf []byte
5760
}
5861

5962
func newCertManager(clock clock.Clock) (*certManager, error) {
@@ -88,6 +91,9 @@ func (m *certManager) rollConfig() error {
8891
m.lastConfig = m.currentConfig
8992
m.currentConfig = m.nextConfig
9093
m.nextConfig = c
94+
if err := m.cacheProtobuf(); err != nil {
95+
return err
96+
}
9197
return m.cacheAddrComponent()
9298
}
9399

@@ -131,17 +137,33 @@ func (m *certManager) AddrComponent() ma.Multiaddr {
131137
return m.addrComp
132138
}
133139

134-
func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error {
135-
for _, h := range hashes {
136-
if h.Code != multihash.SHA2_256 {
137-
return fmt.Errorf("expected SHA256 hash, got %d", h.Code)
138-
}
139-
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) &&
140-
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) &&
141-
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) {
142-
return fmt.Errorf("found unexpected hash: %+x", h.Digest)
140+
func (m *certManager) Protobuf() []byte {
141+
return m.protobuf
142+
}
143+
144+
func (m *certManager) cacheProtobuf() error {
145+
hashes := make([][32]byte, 0, 3)
146+
if m.lastConfig != nil {
147+
hashes = append(hashes, m.lastConfig.sha256)
148+
}
149+
hashes = append(hashes, m.currentConfig.sha256)
150+
if m.nextConfig != nil {
151+
hashes = append(hashes, m.nextConfig.sha256)
152+
}
153+
154+
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(hashes))}
155+
for _, certHash := range hashes {
156+
h, err := multihash.Encode(certHash[:], multihash.SHA2_256)
157+
if err != nil {
158+
return fmt.Errorf("failed to encode certificate hash: %w", err)
143159
}
160+
msg.CertHashes = append(msg.CertHashes, h)
161+
}
162+
msgBytes, err := msg.Marshal()
163+
if err != nil {
164+
return fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
144165
}
166+
m.protobuf = msgBytes
145167
return nil
146168
}
147169

p2p/transport/webtransport/listener.go

+15-28
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@ import (
99
"net/http"
1010
"time"
1111

12+
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
13+
1214
"github.com/libp2p/go-libp2p/core/connmgr"
1315
"github.com/libp2p/go-libp2p/core/network"
1416
tpt "github.com/libp2p/go-libp2p/core/transport"
1517
"github.com/libp2p/go-libp2p/p2p/security/noise"
16-
pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb"
1718

1819
"github.com/lucas-clemente/quic-go/http3"
1920
"github.com/marten-seemann/webtransport-go"
2021
ma "github.com/multiformats/go-multiaddr"
2122
manet "github.com/multiformats/go-multiaddr/net"
22-
"github.com/multiformats/go-multihash"
2323
)
2424

2525
var errClosed = errors.New("closed")
@@ -197,7 +197,19 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
197197
if err != nil {
198198
return nil, err
199199
}
200-
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData)))
200+
var earlyData []byte
201+
if l.isStaticTLSConf {
202+
var msg pb.WebTransport
203+
var err error
204+
earlyData, err = msg.Marshal()
205+
if err != nil {
206+
return nil, err
207+
}
208+
} else {
209+
earlyData = l.certManager.Protobuf()
210+
}
211+
212+
n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataSender(earlyData)))
201213
if err != nil {
202214
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
203215
}
@@ -212,31 +224,6 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*
212224
}, nil
213225
}
214226

215-
func (l *listener) checkEarlyData(b []byte) error {
216-
var msg pb.WebTransport
217-
if err := msg.Unmarshal(b); err != nil {
218-
fmt.Println(1)
219-
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
220-
}
221-
222-
if l.isStaticTLSConf {
223-
if len(msg.CertHashes) > 0 {
224-
return errors.New("using static TLS config, didn't expect any certificate hashes")
225-
}
226-
return nil
227-
}
228-
229-
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
230-
for _, h := range msg.CertHashes {
231-
dh, err := multihash.Decode(h)
232-
if err != nil {
233-
return fmt.Errorf("failed to decode hash: %w", err)
234-
}
235-
hashes = append(hashes, *dh)
236-
}
237-
return l.certManager.Verify(hashes)
238-
}
239-
240227
func (l *listener) Addr() net.Addr {
241228
return l.addr
242229
}

p2p/transport/webtransport/transport.go

+43-11
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package libp2pwebtransport
22

33
import (
4+
"bytes"
45
"context"
56
"crypto/tls"
67
"crypto/x509"
8+
"errors"
79
"fmt"
810
"io"
911
"sync"
@@ -196,32 +198,62 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
196198

197199
// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
198200
// The server will verify that it advertised all of these certificate hashes.
199-
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))}
200-
for _, certHash := range certHashes {
201-
h, err := multihash.Encode(certHash.Digest, certHash.Code)
201+
var verified bool
202+
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b []byte) error {
203+
decodedCertHashes, err := decodeCertHashesFromProtobuf(b)
202204
if err != nil {
203-
return nil, fmt.Errorf("failed to encode certificate hash: %w", err)
205+
return err
204206
}
205-
msg.CertHashes = append(msg.CertHashes, h)
206-
}
207-
msgBytes, err := msg.Marshal()
208-
if err != nil {
209-
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
210-
}
211-
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil))
207+
for _, sent := range certHashes {
208+
var found bool
209+
for _, rcvd := range decodedCertHashes {
210+
if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) {
211+
found = true
212+
break
213+
}
214+
}
215+
if !found {
216+
return fmt.Errorf("missing cert hash: %v", sent)
217+
}
218+
}
219+
verified = true
220+
return nil
221+
}), nil))
212222
if err != nil {
213223
return nil, fmt.Errorf("failed to create Noise transport: %w", err)
214224
}
215225
c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p)
216226
if err != nil {
217227
return nil, err
218228
}
229+
// The Noise handshake _should_ guarantee that our verification callback is called.
230+
// Double-check just in case.
231+
if !verified {
232+
return nil, errors.New("didn't verify")
233+
}
219234
return &connSecurityMultiaddrs{
220235
ConnSecurity: c,
221236
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
222237
}, nil
223238
}
224239

240+
func decodeCertHashesFromProtobuf(b []byte) ([]multihash.DecodedMultihash, error) {
241+
var msg pb.WebTransport
242+
if err := msg.Unmarshal(b); err != nil {
243+
return nil, fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
244+
}
245+
246+
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
247+
for _, h := range msg.CertHashes {
248+
dh, err := multihash.Decode(h)
249+
if err != nil {
250+
return nil, fmt.Errorf("failed to decode hash: %w", err)
251+
}
252+
hashes = append(hashes, *dh)
253+
}
254+
return hashes, nil
255+
}
256+
225257
func (t *transport) CanDial(addr ma.Multiaddr) bool {
226258
var numHashes int
227259
ma.ForEach(addr, func(c ma.Component) bool {

p2p/transport/webtransport/transport_test.go

+2-5
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,8 @@ func TestHashVerification(t *testing.T) {
162162
})
163163

164164
t.Run("fails when adding a wrong hash", func(t *testing.T) {
165-
conn, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
166-
if err != nil {
167-
_, err = conn.AcceptStream()
168-
require.Error(t, err)
169-
}
165+
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
166+
require.Error(t, err)
170167
})
171168

172169
require.NoError(t, ln.Close())

0 commit comments

Comments
 (0)