Skip to content

conn updates 2 #1660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Legend: ✅ Done | 🔧 In Progress | 🔷 Planned | 🤞 Stretch Goal
- 🔷 Monaco Theming
- 🤞 Blockcontroller fixes for terminal escape sequences
- 🤞 Explore VSCode Extension Compatibility with standalone Monaco Editor (language servers)
- 🔷 Various Connection Bugs + Improvements
- 🔧 Various Connection Bugs + Improvements
- 🔧 More Connection Config Options

## Future Releases
Expand Down
38 changes: 22 additions & 16 deletions cmd/wsh/cmd/wshcmd-conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,26 @@ func validateConnectionName(name string) error {
return nil
}

func connStatusRun(cmd *cobra.Command, args []string) error {
func getAllConnStatus() ([]wshrpc.ConnStatus, error) {
var allResp []wshrpc.ConnStatus
sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil)
if err != nil {
return fmt.Errorf("getting ssh connection status: %w", err)
return nil, fmt.Errorf("getting ssh connection status: %w", err)
}
allResp = append(allResp, sshResp...)
wslResp, err := wshclient.WslStatusCommand(RpcClient, nil)
if err != nil {
return fmt.Errorf("getting wsl connection status: %w", err)
return nil, fmt.Errorf("getting wsl connection status: %w", err)
}
allResp = append(allResp, wslResp...)
return allResp, nil
}

func connStatusRun(cmd *cobra.Command, args []string) error {
allResp, err := getAllConnStatus()
if err != nil {
return err
}
if len(allResp) == 0 {
WriteStdout("no connections\n")
return nil
Expand Down Expand Up @@ -142,21 +150,19 @@ func connDisconnectRun(cmd *cobra.Command, args []string) error {
}

func connDisconnectAllRun(cmd *cobra.Command, args []string) error {
resp, err := wshclient.ConnStatusCommand(RpcClient, nil)
allConns, err := getAllConnStatus()
if err != nil {
return fmt.Errorf("getting connection status: %w", err)
}
if len(resp) == 0 {
return nil
return err
}
for _, conn := range resp {
if conn.Status == "connected" {
err := wshclient.ConnDisconnectCommand(RpcClient, conn.Connection, &wshrpc.RpcOpts{Timeout: 10000})
if err != nil {
WriteStdout("error disconnecting %q: %v\n", conn.Connection, err)
} else {
WriteStdout("disconnected %q\n", conn.Connection)
}
for _, conn := range allConns {
if conn.Status != "connected" {
continue
}
err := wshclient.ConnDisconnectCommand(RpcClient, conn.Connection, &wshrpc.RpcOpts{Timeout: 10000})
if err != nil {
WriteStdout("error disconnecting %q: %v\n", conn.Connection, err)
} else {
WriteStdout("disconnected %q\n", conn.Connection)
}
}
return nil
Expand Down
1 change: 0 additions & 1 deletion frontend/app/store/keymodel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ function registerGlobalKeys() {
return false;
}
globalKeyMap.set("Cmd:f", activateSearch);
globalKeyMap.set("Ctrl:f", activateSearch);
globalKeyMap.set("Escape", deactivateSearch);
const allKeys = Array.from(globalKeyMap.keys());
// special case keys, handled by web view
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ require (
go.opentelemetry.io/otel/metric v1.29.0 // indirect
go.opentelemetry.io/otel/trace v1.29.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
golang.org/x/mod v0.22.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/oauth2 v0.24.0 // indirect
golang.org/x/sync v0.10.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
Expand Down
86 changes: 52 additions & 34 deletions pkg/blockcontroller/blockcontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,27 +273,34 @@ func createCmdStrAndOpts(blockId string, blockMeta waveobj.MetaMapType) (string,
}

func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj.MetaMapType) error {
shellProc, err := bc.setupAndStartShellProcess(rc, blockMeta)
if err != nil {
return err
}
return bc.manageRunningShellProcess(shellProc, rc, blockMeta)
}

func (bc *BlockController) setupAndStartShellProcess(rc *RunShellOpts, blockMeta waveobj.MetaMapType) (*shellexec.ShellProc, error) {
// create a circular blockfile for the output
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
err := filestore.WFS.MakeFile(ctx, bc.BlockId, BlockFile_Term, nil, filestore.FileOptsType{MaxSize: DefaultTermMaxFileSize, Circular: true})
if err != nil && err != fs.ErrExist {
err = fs.ErrExist
return fmt.Errorf("error creating blockfile: %w", err)
fsErr := filestore.WFS.MakeFile(ctx, bc.BlockId, BlockFile_Term, nil, filestore.FileOptsType{MaxSize: DefaultTermMaxFileSize, Circular: true})
if fsErr != nil && fsErr != fs.ErrExist {
return nil, fmt.Errorf("error creating blockfile: %w", fsErr)
}
if err == fs.ErrExist {
if fsErr == fs.ErrExist {
// reset the terminal state
bc.resetTerminalState()
}
err = nil
bcInitStatus := bc.GetRuntimeStatus()
if bcInitStatus.ShellProcStatus == Status_Running {
return nil
return nil, nil
}
// TODO better sync here (don't let two starts happen at the same times)
remoteName := blockMeta.GetString(waveobj.MetaKey_Connection, "")
var cmdStr string
var cmdOpts shellexec.CommandOptsType
var err error
if bc.ControllerType == BlockController_Shell {
cmdOpts.Env = make(map[string]string)
cmdOpts.Interactive = true
Expand All @@ -302,19 +309,19 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
if cmdOpts.Cwd != "" {
cwdPath, err := wavebase.ExpandHomeDir(cmdOpts.Cwd)
if err != nil {
return err
return nil, err
}
cmdOpts.Cwd = cwdPath
}
} else if bc.ControllerType == BlockController_Cmd {
var cmdOptsPtr *shellexec.CommandOptsType
cmdStr, cmdOptsPtr, err = createCmdStrAndOpts(bc.BlockId, blockMeta)
if err != nil {
return err
return nil, err
}
cmdOpts = *cmdOptsPtr
} else {
return fmt.Errorf("unknown controller type %q", bc.ControllerType)
return nil, fmt.Errorf("unknown controller type %q", bc.ControllerType)
}
var shellProc *shellexec.ShellProc
if strings.HasPrefix(remoteName, "wsl://") {
Expand All @@ -325,45 +332,45 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
connStatus := wslConn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return fmt.Errorf("not connected, cannot start shellproc")
return nil, fmt.Errorf("not connected, cannot start shellproc")
}

// create jwt
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
return nil, fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
return err
return nil, err
}
} else if remoteName != "" {
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
defer cancelFunc()

opts, err := remote.ParseOpts(remoteName)
if err != nil {
return err
return nil, err
}
conn := conncontroller.GetConn(credentialCtx, opts, false, &wshrpc.ConnKeywords{})
connStatus := conn.DeriveConnStatus()
if connStatus.Status != conncontroller.Status_Connected {
return fmt.Errorf("not connected, cannot start shellproc")
return nil, fmt.Errorf("not connected, cannot start shellproc")
}
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()}, conn.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
return nil, fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
if !conn.WshEnabled.Load() {
shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn)
if err != nil {
return err
return nil, err
}
} else {
shellProc, err = shellexec.StartRemoteShellProc(rc.TermSize, cmdStr, cmdOpts, conn)
Expand All @@ -376,19 +383,16 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
log.Print("attempting install without wsh")
shellProc, err = shellexec.StartRemoteShellProcNoWsh(rc.TermSize, cmdStr, cmdOpts, conn)
if err != nil {
return err
return nil, err
}
}
}
if err != nil {
return err
}
} else {
// local terminal
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}, wavebase.GetDomainSocketName())
if err != nil {
return fmt.Errorf("error making jwt token: %w", err)
return nil, fmt.Errorf("error making jwt token: %w", err)
}
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
Expand All @@ -407,14 +411,18 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
}
shellProc, err = shellexec.StartShellProc(rc.TermSize, cmdStr, cmdOpts)
if err != nil {
return err
return nil, err
}
}
bc.UpdateControllerAndSendUpdate(func() bool {
bc.ShellProc = shellProc
bc.ShellProcStatus = Status_Running
return true
})
return shellProc, nil
}

func (bc *BlockController) manageRunningShellProcess(shellProc *shellexec.ShellProc, rc *RunShellOpts, blockMeta waveobj.MetaMapType) error {
shellInputCh := make(chan *BlockInputUnion, 32)
bc.ShellInputCh = shellInputCh

Expand Down Expand Up @@ -473,14 +481,7 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
shellProc.Cmd.Write(ic.InputData)
}
if ic.TermSize != nil {
err = setTermSize(ctx, bc.BlockId, *ic.TermSize)
if err != nil {
log.Printf("error setting pty size: %v\n", err)
}
err = shellProc.Cmd.SetSize(ic.TermSize.Rows, ic.TermSize.Cols)
if err != nil {
log.Printf("error setting pty size: %v\n", err)
}
updateTermSize(shellProc, bc.BlockId, *ic.TermSize)
}
}
}()
Expand Down Expand Up @@ -522,6 +523,17 @@ func (bc *BlockController) DoRunShellCommand(rc *RunShellOpts, blockMeta waveobj
return nil
}

func updateTermSize(shellProc *shellexec.ShellProc, blockId string, termSize waveobj.TermSize) {
err := setTermSizeInDB(blockId, termSize)
if err != nil {
log.Printf("error setting pty size: %v\n", err)
}
err = shellProc.Cmd.SetSize(termSize.Rows, termSize.Cols)
if err != nil {
log.Printf("error setting pty size: %v\n", err)
}
}

func checkCloseOnExit(blockId string, exitCode int) {
ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancelFn()
Expand Down Expand Up @@ -569,16 +581,22 @@ func getTermSize(bdata *waveobj.Block) waveobj.TermSize {
}
}

func setTermSize(ctx context.Context, blockId string, termSize waveobj.TermSize) error {
func setTermSizeInDB(blockId string, termSize waveobj.TermSize) error {
ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second)
defer cancelFn()
ctx = waveobj.ContextWithUpdates(ctx)
bdata, err := wstore.DBMustGet[*waveobj.Block](context.Background(), blockId)
bdata, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId)
if err != nil {
return fmt.Errorf("error getting block data: %v", err)
}
if bdata.RuntimeOpts == nil {
return fmt.Errorf("error from nil RuntimeOpts: %v", err)
bdata.RuntimeOpts = &waveobj.RuntimeOpts{}
}
bdata.RuntimeOpts.TermSize = termSize
err = wstore.DBUpdate(ctx, bdata)
if err != nil {
return fmt.Errorf("error updating block data: %v", err)
}
updates := waveobj.ContextGetUpdatesRtn(ctx)
wps.Broker.SendUpdateEvents(updates)
return nil
Expand Down
Loading
Loading