Skip to content

Commit 82ee6e8

Browse files
authored
{client, naming}: allow selector to define its own net.Addr parser (#176)
This is used to avoid unnecessary addr parse which is commonly used in trpc-database. To properly update the DSN library, we need to introduce this feature into the open-source tRPC-Go.
1 parent f727602 commit 82ee6e8

File tree

4 files changed

+80
-4
lines changed

4 files changed

+80
-4
lines changed

client/client.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ func selectorFilter(ctx context.Context, req interface{}, rsp interface{}, next
393393
if err != nil {
394394
return OptionsFromContext(ctx).fixTimeout(err)
395395
}
396-
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address)
396+
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr)
397397

398398
// Start to process the next filter and report.
399399
begin := time.Now()
@@ -471,11 +471,21 @@ func getNode(opts *Options) (*registry.Node, error) {
471471
return node, nil
472472
}
473473

474-
func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) {
474+
func ensureMsgRemoteAddr(
475+
msg codec.Msg,
476+
network, address string,
477+
parseAddr func(network, address string) net.Addr,
478+
) {
475479
// If RemoteAddr has already been set, just return.
476480
if msg.RemoteAddr() != nil {
477481
return
478482
}
483+
484+
if parseAddr != nil {
485+
msg.WithRemoteAddr(parseAddr(network, address))
486+
return
487+
}
488+
479489
switch network {
480490
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
481491
// Check if address can be parsed as an ip.
@@ -484,7 +494,6 @@ func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) {
484494
return
485495
}
486496
}
487-
488497
var addr net.Addr
489498
switch network {
490499
case "tcp", "tcp4", "tcp6":

client/client_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ package client_test
1616
import (
1717
"context"
1818
"errors"
19+
"fmt"
20+
"net"
1921
"testing"
2022
"time"
2123

@@ -409,6 +411,31 @@ func TestFixTimeout(t *testing.T) {
409411
})
410412
}
411413

414+
func TestSelectorRemoteAddrUseUserProvidedParser(t *testing.T) {
415+
selector.Register(t.Name(), &fSelector{
416+
selectNode: func(s string, option ...selector.Option) (*registry.Node, error) {
417+
return &registry.Node{
418+
Network: t.Name(),
419+
Address: t.Name(),
420+
ParseAddr: func(network, address string) net.Addr {
421+
return newUnresolvedAddr(network, address)
422+
}}, nil
423+
},
424+
report: func(node *registry.Node, duration time.Duration, err error) error { return nil },
425+
})
426+
fake := "fake"
427+
codec.Register(fake, nil, &fakeCodec{})
428+
ctx := trpc.BackgroundContext()
429+
require.NotNil(t, client.New().Invoke(ctx, "failbody", nil,
430+
client.WithServiceName(t.Name()),
431+
client.WithProtocol(fake),
432+
client.WithTarget(fmt.Sprintf("%s://xxx", t.Name()))))
433+
addr := trpc.Message(ctx).RemoteAddr()
434+
require.NotNil(t, addr)
435+
require.Equal(t, t.Name(), addr.Network())
436+
require.Equal(t, t.Name(), addr.String())
437+
}
438+
412439
type multiplexedTransport struct {
413440
require func(context.Context, []byte, ...transport.RoundTripOption)
414441
fakeTransport
@@ -527,3 +554,39 @@ func (c *fakeSelector) Select(serviceName string, opt ...selector.Option) (*regi
527554
func (c *fakeSelector) Report(node *registry.Node, cost time.Duration, err error) error {
528555
return nil
529556
}
557+
558+
type fSelector struct {
559+
selectNode func(string, ...selector.Option) (*registry.Node, error)
560+
report func(*registry.Node, time.Duration, error) error
561+
}
562+
563+
func (s *fSelector) Select(serviceName string, opts ...selector.Option) (*registry.Node, error) {
564+
return s.selectNode(serviceName, opts...)
565+
}
566+
567+
func (s *fSelector) Report(node *registry.Node, cost time.Duration, err error) error {
568+
return s.report(node, cost, err)
569+
}
570+
571+
// newUnresolvedAddr returns a new unresolvedAddr.
572+
func newUnresolvedAddr(network, address string) *unresolvedAddr {
573+
return &unresolvedAddr{network: network, address: address}
574+
}
575+
576+
var _ net.Addr = (*unresolvedAddr)(nil)
577+
578+
// unresolvedAddr is a net.Addr which returns the original network or address.
579+
type unresolvedAddr struct {
580+
network string
581+
address string
582+
}
583+
584+
// Network returns the unresolved original network.
585+
func (a *unresolvedAddr) Network() string {
586+
return a.network
587+
}
588+
589+
// String returns the unresolved original address.
590+
func (a *unresolvedAddr) String() string {
591+
return a.address
592+
}

client/stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ func (s *stream) Init(ctx context.Context, opt ...Option) (*Options, error) {
162162
report.SelectNodeFail.Incr()
163163
return nil, err
164164
}
165-
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address)
165+
ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr)
166166
const invalidCost = -1
167167
opts.Node.set(node, node.Address, invalidCost)
168168
if opts.Codec == nil {

naming/registry/node.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ package registry
1515

1616
import (
1717
"fmt"
18+
"net"
1819
"time"
1920
)
2021

@@ -30,6 +31,9 @@ type Node struct {
3031
CostTime time.Duration // 当次请求耗时
3132
EnvKey string // 透传的环境信息
3233
Metadata map[string]interface{}
34+
// ParseAddr should be used to convert Node to net.Addr if it's not nil.
35+
// See test case TestSelectorRemoteAddrUseUserProvidedParser in client package.
36+
ParseAddr func(network, address string) net.Addr
3337
}
3438

3539
// String returns an abbreviation information of node.

0 commit comments

Comments
 (0)