Skip to content

Commit 7c31a8c

Browse files
authored
feat: s3/transfermanager (v2): round-robin DNS and multi-NIC (#2975)
1 parent 43b305d commit 7c31a8c

File tree

3 files changed

+387
-0
lines changed

3 files changed

+387
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package transfermanager
2+
3+
import (
4+
"sync"
5+
"time"
6+
7+
"github.com/aws/smithy-go/container/private/cache"
8+
"github.com/aws/smithy-go/container/private/cache/lru"
9+
)
10+
11+
// dnsCache implements an LRU cache of DNS query results by host.
12+
//
13+
// Cache retrievals will automatically rotate between IP addresses for
14+
// multi-value query results.
15+
type dnsCache struct {
16+
mu sync.Mutex
17+
addrs cache.Cache
18+
}
19+
20+
// newDNSCache returns an initialized dnsCache with given capacity.
21+
func newDNSCache(cap int) *dnsCache {
22+
return &dnsCache{
23+
addrs: lru.New(cap),
24+
}
25+
}
26+
27+
// GetAddr returns the next IP address for the given host if present in the
28+
// cache.
29+
func (c *dnsCache) GetAddr(host string) (string, bool) {
30+
c.mu.Lock()
31+
defer c.mu.Unlock()
32+
33+
v, ok := c.addrs.Get(host)
34+
if !ok {
35+
return "", false
36+
}
37+
38+
record := v.(*dnsCacheEntry)
39+
if timeNow().After(record.expires) {
40+
return "", false
41+
}
42+
43+
addr := record.addrs[record.index]
44+
record.index = (record.index + 1) % len(record.addrs)
45+
return addr, true
46+
}
47+
48+
// PutAddrs stores a DNS query result in the cache, overwriting any present
49+
// entry for the host if it exists.
50+
func (c *dnsCache) PutAddrs(host string, addrs []string, expires time.Time) {
51+
c.mu.Lock()
52+
defer c.mu.Unlock()
53+
54+
c.addrs.Put(host, &dnsCacheEntry{addrs, expires, 0})
55+
}
56+
57+
// dnsCacheEntry is the per-host value stored in dnsCache.
//
// The field order is significant: entries are constructed with positional
// composite literals.
type dnsCacheEntry struct {
	addrs   []string  // addresses recorded for the host
	expires time.Time // instant after which the entry is treated as stale
	index   int       // position in addrs of the address to hand out next
}

feature/s3/transfermanager/rrdns.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package transfermanager
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"net/http"
8+
"sync"
9+
"time"
10+
11+
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
12+
)
13+
14+
// timeNow is the clock consulted when stamping and checking cache expiry. It
// is a variable so that tests can substitute a fake time source.
var timeNow = time.Now
15+
16+
// WithRoundRobinDNS configures an http.Transport to spread HTTP connections
17+
// across multiple IP addresses for a given host.
18+
//
19+
// This is recommended by the [S3 performance guide] in high-concurrency
20+
// application environments.
21+
//
22+
// WithRoundRobinDNS wraps the underlying DialContext hook on http.Transport.
23+
// Future modifications to this hook MUST preserve said wrapping in order for
24+
// round-robin DNS to operate.
25+
//
26+
// [S3 performance guide]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html
27+
func WithRoundRobinDNS(opts ...func(*RoundRobinDNSOptions)) func(*http.Transport) {
28+
options := &RoundRobinDNSOptions{
29+
TTL: 30 * time.Second,
30+
MaxHosts: 100,
31+
}
32+
for _, opt := range opts {
33+
opt(options)
34+
}
35+
36+
return func(t *http.Transport) {
37+
rr := &rrDNS{
38+
cache: newDNSCache(options.MaxHosts),
39+
expiry: options.TTL,
40+
resolver: &net.Resolver{},
41+
dialContext: t.DialContext,
42+
}
43+
t.DialContext = rr.DialContext
44+
}
45+
}
46+
47+
// RoundRobinDNSOptions holds the tunable settings for round-robin DNS.
type RoundRobinDNSOptions struct {
	// TTL is how long the results of a DNS query remain valid once cached.
	TTL time.Duration

	// MaxHosts limits how many DNS query results, cached by hostname, are
	// stored at once. Round-robin DNS uses an LRU cache.
	MaxHosts int
}
56+
57+
// resolver is the lookup capability rrDNS depends on: turning a host name
// into a list of addresses. It exists so tests can supply a canned resolver.
type resolver interface {
	LookupHost(ctx context.Context, host string) (addrs []string, err error)
}
60+
61+
type rrDNS struct {
62+
sf singleflight.Group
63+
cache *dnsCache
64+
65+
expiry time.Duration
66+
resolver resolver
67+
68+
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
69+
}
70+
71+
// DialContext implements the DialContext hook used by http.Transport,
72+
// pre-caching IP addresses for a given host and distributing them evenly
73+
// across new connections.
74+
func (r *rrDNS) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
75+
host, port, err := net.SplitHostPort(addr)
76+
if err != nil {
77+
return nil, fmt.Errorf("rrdns split host/port: %w", err)
78+
}
79+
80+
ipaddr, err := r.getAddr(ctx, host)
81+
if err != nil {
82+
return nil, fmt.Errorf("rrdns lookup host: %w", err)
83+
}
84+
85+
return r.dialContext(ctx, network, net.JoinHostPort(ipaddr, port))
86+
}
87+
88+
func (r *rrDNS) getAddr(ctx context.Context, host string) (string, error) {
89+
addr, ok := r.cache.GetAddr(host)
90+
if ok {
91+
return addr, nil
92+
}
93+
return r.lookupHost(ctx, host)
94+
}
95+
96+
func (r *rrDNS) lookupHost(ctx context.Context, host string) (string, error) {
97+
ch := r.sf.DoChan(host, func() (interface{}, error) {
98+
addrs, err := r.resolver.LookupHost(ctx, host)
99+
if err != nil {
100+
return nil, err
101+
}
102+
103+
expires := timeNow().Add(r.expiry)
104+
r.cache.PutAddrs(host, addrs, expires)
105+
return nil, nil
106+
})
107+
108+
select {
109+
case result := <-ch:
110+
if result.Err != nil {
111+
return "", result.Err
112+
}
113+
114+
addr, _ := r.cache.GetAddr(host)
115+
return addr, nil
116+
case <-ctx.Done():
117+
return "", ctx.Err()
118+
}
119+
}
120+
121+
// WithRotoDialer configures an http.Transport to cycle through multiple local
122+
// network addresses when creating new HTTP connections.
123+
//
124+
// WithRotoDialer REPLACES the root DialContext hook on the underlying
125+
// Transport, thereby destroying any previously-applied wrappings around it. If
126+
// the caller needs to apply additional decorations to the DialContext hook,
127+
// they must do so after applying WithRotoDialer.
128+
func WithRotoDialer(addrs []net.Addr) func(*http.Transport) {
129+
return func(t *http.Transport) {
130+
var dialers []*net.Dialer
131+
for _, addr := range addrs {
132+
dialers = append(dialers, &net.Dialer{
133+
LocalAddr: addr,
134+
})
135+
}
136+
137+
t.DialContext = (&rotoDialer{
138+
dialers: dialers,
139+
}).DialContext
140+
}
141+
}
142+
143+
// rotoDialer hands out its dialers in rotation, one per new connection.
type rotoDialer struct {
	mu      sync.Mutex    // held by next while reading and advancing index
	dialers []*net.Dialer // dialers to rotate through
	index   int           // position in dialers of the dialer to use next
}

// DialContext dials addr on network using the next dialer in the rotation.
func (r *rotoDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	return r.next().DialContext(ctx, network, addr)
}

// next returns the dialer to use for the upcoming connection and advances the
// rotation, wrapping around after the last dialer.
//
// When no dialers are configured it returns a fresh zero-value net.Dialer
// instead of panicking on an out-of-range index.
func (r *rotoDialer) next() *net.Dialer {
	r.mu.Lock()
	defer r.mu.Unlock()

	if len(r.dialers) == 0 {
		return &net.Dialer{}
	}

	d := r.dialers[r.index]
	r.index = (r.index + 1) % len(r.dialers)
	return d
}
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package transfermanager
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"testing"
8+
"time"
9+
)
10+
11+
// These tests also cover the dnsCache implementation (address cycling, expiry, and LRU eviction).
12+
13+
// mockNow is a manually advanced clock for tests.
type mockNow struct {
	now time.Time // instant currently reported by Now
}

// Now reports the clock's current instant.
func (m *mockNow) Now() time.Time { return m.now }

// Add shifts the clock by d.
func (m *mockNow) Add(d time.Duration) {
	shifted := m.now.Add(d)
	m.now = shifted
}
24+
25+
func useMockNow(m *mockNow) func() {
26+
timeNow = m.Now
27+
return func() {
28+
timeNow = time.Now
29+
}
30+
}
31+
32+
// errDialContextOK is the sentinel returned by mockDialContext.DialContext;
// seeing it from rrDNS.DialContext shows the wrapped dial function was reached.
var errDialContextOK = errors.New("dial context ok")
33+
34+
// mockResolver is a canned resolver for tests.
type mockResolver struct {
	addrs map[string][]string // addresses to report, keyed by host
	err   error               // error returned from every lookup
}

// LookupHost returns the addresses configured for host together with the
// configured error. A host that is not in the map yields a nil slice.
func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
	found := m.addrs[host]
	return found, m.err
}
42+
43+
type mockDialContext struct {
44+
calledWith string
45+
}
46+
47+
func (m *mockDialContext) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
48+
m.calledWith = addr
49+
return nil, errDialContextOK
50+
}
51+
52+
func TestRoundRobinDNS_CycleIPs(t *testing.T) {
53+
restore := useMockNow(&mockNow{})
54+
defer restore()
55+
56+
addrs := []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}
57+
r := &mockResolver{
58+
addrs: map[string][]string{
59+
"s3.us-east-1.amazonaws.com": addrs,
60+
},
61+
}
62+
dc := &mockDialContext{}
63+
64+
rr := &rrDNS{
65+
cache: newDNSCache(1),
66+
resolver: r,
67+
dialContext: dc.DialContext,
68+
}
69+
70+
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0])
71+
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[1])
72+
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[2])
73+
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0])
74+
}
75+
76+
func TestRoundRobinDNS_MultiIP(t *testing.T) {
77+
restore := useMockNow(&mockNow{})
78+
defer restore()
79+
80+
r := &mockResolver{
81+
addrs: map[string][]string{
82+
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
83+
"host2.com": []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"},
84+
},
85+
}
86+
dc := &mockDialContext{}
87+
88+
rr := &rrDNS{
89+
cache: newDNSCache(2),
90+
resolver: r,
91+
dialContext: dc.DialContext,
92+
}
93+
94+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
95+
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0])
96+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1])
97+
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][1])
98+
}
99+
100+
func TestRoundRobinDNS_MaxHosts(t *testing.T) {
101+
restore := useMockNow(&mockNow{})
102+
defer restore()
103+
104+
r := &mockResolver{
105+
addrs: map[string][]string{
106+
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
107+
"host2.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
108+
},
109+
}
110+
dc := &mockDialContext{}
111+
112+
rr := &rrDNS{
113+
cache: newDNSCache(1),
114+
resolver: r,
115+
dialContext: dc.DialContext,
116+
}
117+
118+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
119+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1])
120+
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) // evicts host1
121+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) // evicts host2
122+
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0])
123+
}
124+
125+
func TestRoundRobinDNS_Expires(t *testing.T) {
126+
now := &mockNow{time.Unix(0, 0)}
127+
restore := useMockNow(now)
128+
defer restore()
129+
130+
r := &mockResolver{
131+
addrs: map[string][]string{
132+
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"},
133+
},
134+
}
135+
dc := &mockDialContext{}
136+
137+
rr := &rrDNS{
138+
cache: newDNSCache(2),
139+
expiry: 30,
140+
resolver: r,
141+
dialContext: dc.DialContext,
142+
}
143+
144+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
145+
now.Add(16) // hasn't expired
146+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1])
147+
now.Add(16) // expired, starts over
148+
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0])
149+
}
150+
151+
func expectDialContext(t *testing.T, rr *rrDNS, dc *mockDialContext, host, expect string) {
152+
const port = "443"
153+
154+
t.Helper()
155+
_, err := rr.DialContext(context.Background(), "", net.JoinHostPort(host, port))
156+
if err != errDialContextOK {
157+
t.Errorf("expect sentinel err, got %v", err)
158+
}
159+
actual, _, err := net.SplitHostPort(dc.calledWith)
160+
if err != nil {
161+
t.Fatal(err)
162+
}
163+
if expect != actual {
164+
t.Errorf("expect addr %s, got %s", expect, actual)
165+
}
166+
}

0 commit comments

Comments
 (0)