Skip to content

Commit f199e8c

Browse files
authored
Merge pull request #108 from gliderlabs/configurable-handlers
Configurable channel handlers Closes #89, #71
2 parents a9daacc + dd61f8b commit f199e8c

File tree

4 files changed

+68
-36
lines changed

4 files changed

+68
-36
lines changed

server.go

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,15 @@ type Server struct {
3737
IdleTimeout time.Duration // connection timeout when no activity, none if empty
3838
MaxTimeout time.Duration // absolute connection timeout, none if empty
3939

40-
channelHandlers map[string]channelHandler
41-
requestHandlers map[string]RequestHandler
40+
// ChannelHandlers allow overriding the built-in session handlers or provide
41+
// extensions to the protocol, such as tcpip forwarding. By default only the
42+
// "session" handler is enabled.
43+
ChannelHandlers map[string]ChannelHandler
44+
45+
// RequestHandlers allow overriding the server-level request handlers or
46+
// provide extensions to the protocol, such as tcpip forwarding. By default
47+
// no handlers are enabled.
48+
RequestHandlers map[string]RequestHandler
4249

4350
listenerWg sync.WaitGroup
4451
mu sync.Mutex
@@ -47,12 +54,32 @@ type Server struct {
4754
connWg sync.WaitGroup
4855
doneChan chan struct{}
4956
}
57+
5058
type RequestHandler interface {
51-
HandleRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
59+
HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
60+
}
61+
62+
type RequestHandlerFunc func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
63+
64+
func (f RequestHandlerFunc) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) {
65+
return f(ctx, srv, req)
66+
}
67+
68+
var DefaultRequestHandlers = map[string]RequestHandler{}
69+
70+
type ChannelHandler interface {
71+
HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
72+
}
73+
74+
type ChannelHandlerFunc func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
75+
76+
func (f ChannelHandlerFunc) HandleSSHChannel(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
77+
f(srv, conn, newChan, ctx)
5278
}
5379

54-
// internal for now
55-
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
80+
var DefaultChannelHandlers = map[string]ChannelHandler{
81+
"session": ChannelHandlerFunc(DefaultSessionHandler),
82+
}
5683

5784
func (srv *Server) ensureHostSigner() error {
5885
if len(srv.HostSigners) == 0 {
@@ -68,13 +95,17 @@ func (srv *Server) ensureHostSigner() error {
6895
func (srv *Server) ensureHandlers() {
6996
srv.mu.Lock()
7097
defer srv.mu.Unlock()
71-
srv.requestHandlers = map[string]RequestHandler{
72-
"tcpip-forward": forwardedTCPHandler{},
73-
"cancel-tcpip-forward": forwardedTCPHandler{},
98+
if srv.RequestHandlers == nil {
99+
srv.RequestHandlers = map[string]RequestHandler{}
100+
for k, v := range DefaultRequestHandlers {
101+
srv.RequestHandlers[k] = v
102+
}
74103
}
75-
srv.channelHandlers = map[string]channelHandler{
76-
"session": sessionHandler,
77-
"direct-tcpip": directTcpipHandler,
104+
if srv.ChannelHandlers == nil {
105+
srv.ChannelHandlers = map[string]ChannelHandler{}
106+
for k, v := range DefaultChannelHandlers {
107+
srv.ChannelHandlers[k] = v
108+
}
78109
}
79110
}
80111

@@ -186,12 +217,6 @@ func (srv *Server) Serve(l net.Listener) error {
186217
if srv.Handler == nil {
187218
srv.Handler = DefaultHandler
188219
}
189-
if srv.channelHandlers == nil {
190-
srv.channelHandlers = map[string]channelHandler{
191-
"session": sessionHandler,
192-
"direct-tcpip": directTcpipHandler,
193-
}
194-
}
195220
var tempDelay time.Duration
196221

197222
srv.trackListener(l, true)
@@ -255,30 +280,32 @@ func (srv *Server) handleConn(newConn net.Conn) {
255280
//go gossh.DiscardRequests(reqs)
256281
go srv.handleRequests(ctx, reqs)
257282
for ch := range chans {
258-
handler, found := srv.channelHandlers[ch.ChannelType()]
259-
if !found {
283+
handler := srv.ChannelHandlers[ch.ChannelType()]
284+
if handler == nil {
285+
handler = srv.ChannelHandlers["default"]
286+
}
287+
if handler == nil {
260288
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
261289
continue
262290
}
263-
go handler(srv, sshConn, ch, ctx)
291+
go handler.HandleSSHChannel(srv, sshConn, ch, ctx)
264292
}
265293
}
266294

267295
func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
268296
for req := range in {
269-
handler, found := srv.requestHandlers[req.Type]
270-
if !found {
271-
if req.WantReply {
272-
req.Reply(false, nil)
273-
}
297+
handler := srv.RequestHandlers[req.Type]
298+
if handler == nil {
299+
handler = srv.RequestHandlers["default"]
300+
}
301+
if handler == nil {
302+
req.Reply(false, nil)
274303
continue
275304
}
276305
/*reqCtx, cancel := context.WithCancel(ctx)
277306
defer cancel() */
278-
ret, payload := handler.HandleRequest(ctx, srv, req)
279-
if req.WantReply {
280-
req.Reply(ret, payload)
281-
}
307+
ret, payload := handler.HandleSSHRequest(ctx, srv, req)
308+
req.Reply(ret, payload)
282309
}
283310
}
284311

session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ type Session interface {
7777
// when there is no signal channel specified
7878
const maxSigBufSize = 128
7979

80-
func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
80+
func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
8181
ch, reqs, err := newChan.Accept()
8282
if err != nil {
8383
// TODO: trigger event callback

session_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ func (srv *Server) serveOnce(l net.Listener) error {
1919
if e != nil {
2020
return e
2121
}
22-
srv.channelHandlers = map[string]channelHandler{
23-
"session": sessionHandler,
24-
"direct-tcpip": directTcpipHandler,
22+
srv.ChannelHandlers = map[string]ChannelHandler{
23+
"session": ChannelHandlerFunc(DefaultSessionHandler),
24+
"direct-tcpip": ChannelHandlerFunc(DirectTCPIPHandler),
2525
}
2626
srv.handleConn(conn)
2727
return nil

tcpip.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ type localForwardChannelData struct {
2323
OriginPort uint32
2424
}
2525

26-
func directTcpipHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
26+
// DirectTCPIPHandler can be enabled by adding it to the server's
27+
// ChannelHandlers under direct-tcpip.
28+
func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
2729
d := localForwardChannelData{}
2830
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
2931
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
@@ -84,12 +86,15 @@ type remoteForwardChannelData struct {
8486
OriginPort uint32
8587
}
8688

87-
type forwardedTCPHandler struct {
89+
// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and
90+
// adding it to the server's RequestHandlers under tcpip-forward and
91+
// cancel-tcpip-forward.
92+
type ForwardedTCPHandler struct {
8893
forwards map[string]net.Listener
8994
sync.Mutex
9095
}
9196

92-
func (h forwardedTCPHandler) HandleRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
97+
func (h ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
9398
h.Lock()
9499
if h.forwards == nil {
95100
h.forwards = make(map[string]net.Listener)

0 commit comments

Comments
 (0)