@@ -37,8 +37,15 @@ type Server struct {
37
37
IdleTimeout time.Duration // connection timeout when no activity, none if empty
38
38
MaxTimeout time.Duration // absolute connection timeout, none if empty
39
39
40
- channelHandlers map [string ]channelHandler
41
- requestHandlers map [string ]RequestHandler
40
+ // ChannelHandlers allow overriding the built-in session handlers or provide
41
+ // extensions to the protocol, such as tcpip forwarding. By default only the
42
+ // "session" handler is enabled.
43
+ ChannelHandlers map [string ]ChannelHandler
44
+
45
+ // RequestHandlers allow overriding the server-level request handlers or
46
+ // provide extensions to the protocol, such as tcpip forwarding. By default
47
+ // no handlers are enabled.
48
+ RequestHandlers map [string ]RequestHandler
42
49
43
50
listenerWg sync.WaitGroup
44
51
mu sync.Mutex
@@ -47,12 +54,32 @@ type Server struct {
47
54
connWg sync.WaitGroup
48
55
doneChan chan struct {}
49
56
}
57
+
50
58
type RequestHandler interface {
51
- HandleRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
59
+ HandleSSHRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
60
+ }
61
+
62
+ type RequestHandlerFunc func (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte )
63
+
64
+ func (f RequestHandlerFunc ) HandleSSHRequest (ctx Context , srv * Server , req * gossh.Request ) (ok bool , payload []byte ) {
65
+ return f (ctx , srv , req )
66
+ }
67
+
68
+ var DefaultRequestHandlers = map [string ]RequestHandler {}
69
+
70
+ type ChannelHandler interface {
71
+ HandleSSHChannel (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
72
+ }
73
+
74
+ type ChannelHandlerFunc func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
75
+
76
+ func (f ChannelHandlerFunc ) HandleSSHChannel (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context ) {
77
+ f (srv , conn , newChan , ctx )
52
78
}
53
79
54
- // internal for now
55
- type channelHandler func (srv * Server , conn * gossh.ServerConn , newChan gossh.NewChannel , ctx Context )
80
+ var DefaultChannelHandlers = map [string ]ChannelHandler {
81
+ "session" : ChannelHandlerFunc (DefaultSessionHandler ),
82
+ }
56
83
57
84
func (srv * Server ) ensureHostSigner () error {
58
85
if len (srv .HostSigners ) == 0 {
@@ -68,13 +95,17 @@ func (srv *Server) ensureHostSigner() error {
68
95
func (srv * Server ) ensureHandlers () {
69
96
srv .mu .Lock ()
70
97
defer srv .mu .Unlock ()
71
- srv .requestHandlers = map [string ]RequestHandler {
72
- "tcpip-forward" : forwardedTCPHandler {},
73
- "cancel-tcpip-forward" : forwardedTCPHandler {},
98
+ if srv .RequestHandlers == nil {
99
+ srv .RequestHandlers = map [string ]RequestHandler {}
100
+ for k , v := range DefaultRequestHandlers {
101
+ srv .RequestHandlers [k ] = v
102
+ }
74
103
}
75
- srv .channelHandlers = map [string ]channelHandler {
76
- "session" : sessionHandler ,
77
- "direct-tcpip" : directTcpipHandler ,
104
+ if srv .ChannelHandlers == nil {
105
+ srv .ChannelHandlers = map [string ]ChannelHandler {}
106
+ for k , v := range DefaultChannelHandlers {
107
+ srv .ChannelHandlers [k ] = v
108
+ }
78
109
}
79
110
}
80
111
@@ -186,12 +217,6 @@ func (srv *Server) Serve(l net.Listener) error {
186
217
if srv .Handler == nil {
187
218
srv .Handler = DefaultHandler
188
219
}
189
- if srv .channelHandlers == nil {
190
- srv .channelHandlers = map [string ]channelHandler {
191
- "session" : sessionHandler ,
192
- "direct-tcpip" : directTcpipHandler ,
193
- }
194
- }
195
220
var tempDelay time.Duration
196
221
197
222
srv .trackListener (l , true )
@@ -255,30 +280,32 @@ func (srv *Server) handleConn(newConn net.Conn) {
255
280
//go gossh.DiscardRequests(reqs)
256
281
go srv .handleRequests (ctx , reqs )
257
282
for ch := range chans {
258
- handler , found := srv .channelHandlers [ch .ChannelType ()]
259
- if ! found {
283
+ handler := srv .ChannelHandlers [ch .ChannelType ()]
284
+ if handler == nil {
285
+ handler = srv .ChannelHandlers ["default" ]
286
+ }
287
+ if handler == nil {
260
288
ch .Reject (gossh .UnknownChannelType , "unsupported channel type" )
261
289
continue
262
290
}
263
- go handler (srv , sshConn , ch , ctx )
291
+ go handler . HandleSSHChannel (srv , sshConn , ch , ctx )
264
292
}
265
293
}
266
294
267
295
func (srv * Server ) handleRequests (ctx Context , in <- chan * gossh.Request ) {
268
296
for req := range in {
269
- handler , found := srv .requestHandlers [req .Type ]
270
- if ! found {
271
- if req .WantReply {
272
- req .Reply (false , nil )
273
- }
297
+ handler := srv .RequestHandlers [req .Type ]
298
+ if handler == nil {
299
+ handler = srv .RequestHandlers ["default" ]
300
+ }
301
+ if handler == nil {
302
+ req .Reply (false , nil )
274
303
continue
275
304
}
276
305
/*reqCtx, cancel := context.WithCancel(ctx)
277
306
defer cancel() */
278
- ret , payload := handler .HandleRequest (ctx , srv , req )
279
- if req .WantReply {
280
- req .Reply (ret , payload )
281
- }
307
+ ret , payload := handler .HandleSSHRequest (ctx , srv , req )
308
+ req .Reply (ret , payload )
282
309
}
283
310
}
284
311
0 commit comments