Skip to content

Commit 22fb034

Browse files
authored
PE-41 remote file api (#1)
* RPC for remote file streaming -- just implemented 'stat' for now (streaming to come) * allow RPC iterators for MShell RPCs. implement two test commands to test viewing files * implement read-file handler * read-file: allow overriding of content-type and use line's cwd not remote instance cwd * checkpoint on write-file impl * implemented metacommand version of write file * checkpoint, untested write-file impl * multipart handling for write-file data * add usetemp param to writefile
1 parent 681d80e commit 22fb034

File tree

5 files changed

+542
-7
lines changed

5 files changed

+542
-7
lines changed

cmd/main-server.go

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@ package main
22

33
import (
44
"context"
5+
"encoding/base64"
56
"encoding/json"
67
"errors"
78
"fmt"
9+
"io"
810
"io/fs"
911
"log"
12+
"mime/multipart"
1013
"net/http"
1114
"os"
1215
"os/signal"
16+
"path/filepath"
17+
"regexp"
1318
"runtime/debug"
1419
"strconv"
1520
"strings"
@@ -20,6 +25,8 @@ import (
2025
"github.com/google/uuid"
2126
"github.com/gorilla/mux"
2227

28+
"github.com/commandlinedev/apishell/pkg/packet"
29+
"github.com/commandlinedev/apishell/pkg/server"
2330
"github.com/commandlinedev/prompt-server/pkg/cmdrunner"
2431
"github.com/commandlinedev/prompt-server/pkg/pcloud"
2532
"github.com/commandlinedev/prompt-server/pkg/remote"
@@ -49,11 +56,14 @@ const InitialTelemetryWait = 30 * time.Second
4956
const TelemetryTick = 30 * time.Minute
5057
const TelemetryInterval = 8 * time.Hour
5158

59+
const MaxWriteFileMemSize = 20 * (1024 * 1024) // 20M
60+
5261
var GlobalLock = &sync.Mutex{}
5362
var WSStateMap = make(map[string]*scws.WSState) // clientid -> WsState
5463
var GlobalAuthKey string
5564
var BuildTime = "0"
5665
var shutdownOnce sync.Once
66+
var ContentTypeHeaderValidRe = regexp.MustCompile(`^\w+/[\w.+-]+$`)
5767

5868
type ClientActiveState struct {
5969
Fg bool `json:"fg"`
@@ -312,6 +322,273 @@ func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
312322
w.Write(data)
313323
}
314324

325+
type writeFileParamsType struct {
326+
ScreenId string `json:"screenid"`
327+
LineId string `json:"lineid"`
328+
Path string `json:"path"`
329+
UseTemp bool `json:"usetemp,omitempty"`
330+
}
331+
332+
func parseWriteFileParams(r *http.Request) (*writeFileParamsType, multipart.File, error) {
333+
err := r.ParseMultipartForm(MaxWriteFileMemSize)
334+
if err != nil {
335+
return nil, nil, fmt.Errorf("cannot parse multipart form data: %v", err)
336+
}
337+
form := r.MultipartForm
338+
if len(form.Value["params"]) == 0 {
339+
return nil, nil, fmt.Errorf("no params found")
340+
}
341+
paramsStr := form.Value["params"][0]
342+
var params writeFileParamsType
343+
err = json.Unmarshal([]byte(paramsStr), &params)
344+
if err != nil {
345+
return nil, nil, fmt.Errorf("bad params json: %v", err)
346+
}
347+
if len(form.File["data"]) == 0 {
348+
return nil, nil, fmt.Errorf("no data found")
349+
}
350+
fileHeader := form.File["data"][0]
351+
file, err := fileHeader.Open()
352+
if err != nil {
353+
return nil, nil, fmt.Errorf("error opening multipart data file: %v", err)
354+
}
355+
return &params, file, nil
356+
}
357+
358+
func HandleWriteFile(w http.ResponseWriter, r *http.Request) {
359+
defer func() {
360+
r := recover()
361+
if r == nil {
362+
return
363+
}
364+
log.Printf("[error] in write-file: %v\n", r)
365+
debug.PrintStack()
366+
WriteJsonError(w, fmt.Errorf("panic: %v", r))
367+
return
368+
}()
369+
w.Header().Set("Cache-Control", "no-cache")
370+
params, mpFile, err := parseWriteFileParams(r)
371+
if err != nil {
372+
WriteJsonError(w, fmt.Errorf("error parsing multipart form params: %w", err))
373+
return
374+
}
375+
if params.ScreenId == "" || params.LineId == "" || params.Path == "" {
376+
WriteJsonError(w, fmt.Errorf("invalid params, must set screenid, lineid, and path"))
377+
return
378+
}
379+
if _, err := uuid.Parse(params.ScreenId); err != nil {
380+
WriteJsonError(w, fmt.Errorf("invalid screenid: %v", err))
381+
return
382+
}
383+
if _, err := uuid.Parse(params.LineId); err != nil {
384+
WriteJsonError(w, fmt.Errorf("invalid lineid: %v", err))
385+
return
386+
}
387+
_, cmd, err := sstore.GetLineCmdByLineId(r.Context(), params.ScreenId, params.LineId)
388+
if err != nil {
389+
WriteJsonError(w, fmt.Errorf("cannot retrieve line/cmd: %v", err))
390+
return
391+
}
392+
if cmd == nil {
393+
WriteJsonError(w, fmt.Errorf("line not found"))
394+
return
395+
}
396+
if cmd.Remote.RemoteId == "" {
397+
WriteJsonError(w, fmt.Errorf("invalid line, no remote"))
398+
return
399+
}
400+
msh := remote.GetRemoteById(cmd.Remote.RemoteId)
401+
if msh == nil {
402+
WriteJsonError(w, fmt.Errorf("invalid line, cannot resolve remote"))
403+
return
404+
}
405+
cwd := cmd.FeState["cwd"]
406+
writePk := packet.MakeWriteFilePacket()
407+
writePk.ReqId = uuid.New().String()
408+
writePk.UseTemp = params.UseTemp
409+
if filepath.IsAbs(params.Path) {
410+
writePk.Path = params.Path
411+
} else {
412+
writePk.Path = filepath.Join(cwd, params.Path)
413+
}
414+
iter, err := msh.PacketRpcIter(r.Context(), writePk)
415+
if err != nil {
416+
WriteJsonError(w, fmt.Errorf("error: %v", err))
417+
return
418+
}
419+
// first packet should be WriteFileReady
420+
readyIf, err := iter.Next(r.Context())
421+
if err != nil {
422+
WriteJsonError(w, fmt.Errorf("error while getting ready response: %w", err))
423+
return
424+
}
425+
readyPk, ok := readyIf.(*packet.WriteFileReadyPacketType)
426+
if !ok {
427+
WriteJsonError(w, fmt.Errorf("bad ready packet received: %T", readyIf))
428+
return
429+
}
430+
if readyPk.Error != "" {
431+
WriteJsonError(w, fmt.Errorf("ready error: %s", readyPk.Error))
432+
return
433+
}
434+
var buffer [server.MaxFileDataPacketSize]byte
435+
bufSlice := buffer[:]
436+
for {
437+
dataPk := packet.MakeFileDataPacket(writePk.ReqId)
438+
nr, err := io.ReadFull(mpFile, bufSlice)
439+
if err == io.ErrUnexpectedEOF || err == io.EOF {
440+
dataPk.Eof = true
441+
} else if err != nil {
442+
dataErr := fmt.Errorf("error reading file data: %v", err)
443+
dataPk.Error = dataErr.Error()
444+
msh.SendFileData(dataPk)
445+
WriteJsonError(w, dataErr)
446+
return
447+
}
448+
if nr > 0 {
449+
dataPk.Data = make([]byte, nr)
450+
copy(dataPk.Data, bufSlice[0:nr])
451+
}
452+
msh.SendFileData(dataPk)
453+
if dataPk.Eof {
454+
break
455+
}
456+
// slight throttle for sending packets
457+
time.Sleep(10 * time.Millisecond)
458+
}
459+
doneIf, err := iter.Next(r.Context())
460+
if err != nil {
461+
WriteJsonError(w, fmt.Errorf("error while getting done response: %w", err))
462+
return
463+
}
464+
donePk, ok := doneIf.(*packet.WriteFileDonePacketType)
465+
if !ok {
466+
WriteJsonError(w, fmt.Errorf("bad done packet received: %T", doneIf))
467+
return
468+
}
469+
if donePk.Error != "" {
470+
WriteJsonError(w, fmt.Errorf("dne error: %s", donePk.Error))
471+
return
472+
}
473+
WriteJsonSuccess(w, nil)
474+
return
475+
}
476+
477+
func HandleReadFile(w http.ResponseWriter, r *http.Request) {
478+
qvals := r.URL.Query()
479+
screenId := qvals.Get("screenid")
480+
lineId := qvals.Get("lineid")
481+
path := qvals.Get("path") // validate path?
482+
contentType := qvals.Get("mimetype")
483+
if contentType == "" {
484+
contentType = "application/octet-stream"
485+
}
486+
if screenId == "" || lineId == "" {
487+
w.WriteHeader(500)
488+
w.Write([]byte(fmt.Sprintf("must specify sessionid, screenid, and lineid")))
489+
return
490+
}
491+
if path == "" {
492+
w.WriteHeader(500)
493+
w.Write([]byte(fmt.Sprintf("must specify path")))
494+
return
495+
}
496+
if _, err := uuid.Parse(screenId); err != nil {
497+
w.WriteHeader(500)
498+
w.Write([]byte(fmt.Sprintf("invalid screenid: %v", err)))
499+
return
500+
}
501+
if _, err := uuid.Parse(lineId); err != nil {
502+
w.WriteHeader(500)
503+
w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
504+
return
505+
}
506+
if !ContentTypeHeaderValidRe.MatchString(contentType) {
507+
w.WriteHeader(500)
508+
w.Write([]byte(fmt.Sprintf("invalid mimetype specified")))
509+
return
510+
}
511+
_, cmd, err := sstore.GetLineCmdByLineId(r.Context(), screenId, lineId)
512+
if err != nil {
513+
w.WriteHeader(500)
514+
w.Write([]byte(fmt.Sprintf("invalid lineid: %v", err)))
515+
return
516+
}
517+
if cmd == nil {
518+
w.WriteHeader(500)
519+
w.Write([]byte(fmt.Sprintf("invalid line, no cmd")))
520+
return
521+
}
522+
if cmd.Remote.RemoteId == "" {
523+
w.WriteHeader(500)
524+
w.Write([]byte(fmt.Sprintf("invalid line, no remote")))
525+
return
526+
}
527+
streamPk := packet.MakeStreamFilePacket()
528+
streamPk.ReqId = uuid.New().String()
529+
cwd := cmd.FeState["cwd"]
530+
if filepath.IsAbs(path) {
531+
streamPk.Path = path
532+
} else {
533+
streamPk.Path = filepath.Join(cwd, path)
534+
}
535+
msh := remote.GetRemoteById(cmd.Remote.RemoteId)
536+
if msh == nil {
537+
w.WriteHeader(500)
538+
w.Write([]byte(fmt.Sprintf("invalid line, cannot resolve remote")))
539+
return
540+
}
541+
iter, err := msh.StreamFile(r.Context(), streamPk)
542+
if err != nil {
543+
w.WriteHeader(500)
544+
w.Write([]byte(fmt.Sprintf("error trying to stream file: %v", err)))
545+
return
546+
}
547+
defer iter.Close()
548+
respIf, err := iter.Next(r.Context())
549+
if err != nil {
550+
w.WriteHeader(500)
551+
w.Write([]byte(fmt.Sprintf("error getting streamfile response: %v", err)))
552+
return
553+
}
554+
resp, ok := respIf.(*packet.StreamFileResponseType)
555+
if !ok {
556+
w.WriteHeader(500)
557+
w.Write([]byte(fmt.Sprintf("bad response packet type: %T", respIf)))
558+
return
559+
}
560+
if resp.Error != "" {
561+
w.WriteHeader(500)
562+
w.Write([]byte(fmt.Sprintf("error response: %s", resp.Error)))
563+
return
564+
}
565+
infoJson, _ := json.Marshal(resp.Info)
566+
w.Header().Set("X-FileInfo", base64.StdEncoding.EncodeToString(infoJson))
567+
w.Header().Set("Content-Type", contentType)
568+
w.WriteHeader(http.StatusOK)
569+
for {
570+
dataPkIf, err := iter.Next(r.Context())
571+
if err != nil {
572+
log.Printf("error in read-file while getting data: %v\n", err)
573+
break
574+
}
575+
if dataPkIf == nil {
576+
break
577+
}
578+
dataPk, ok := dataPkIf.(*packet.FileDataPacketType)
579+
if !ok {
580+
log.Printf("error in read-file, invalid data packet type: %T", dataPkIf)
581+
break
582+
}
583+
if dataPk.Error != "" {
584+
log.Printf("in read-file, data packet error: %s", dataPk.Error)
585+
break
586+
}
587+
w.Write(dataPk.Data)
588+
}
589+
return
590+
}
591+
315592
func WriteJsonError(w http.ResponseWriter, errVal error) {
316593
w.Header().Set("Content-Type", "application/json")
317594
w.WriteHeader(200)
@@ -576,6 +853,8 @@ func main() {
576853
gr.HandleFunc("/api/get-client-data", AuthKeyWrap(HandleGetClientData))
577854
gr.HandleFunc("/api/set-winsize", AuthKeyWrap(HandleSetWinSize))
578855
gr.HandleFunc("/api/log-active-state", AuthKeyWrap(HandleLogActiveState))
856+
gr.HandleFunc("/api/read-file", AuthKeyWrap(HandleReadFile))
857+
gr.HandleFunc("/api/write-file", AuthKeyWrap(HandleWriteFile)).Methods("POST")
579858
serverAddr := MainServerAddr
580859
if scbase.IsDevMode() {
581860
serverAddr = MainServerDevAddr

0 commit comments

Comments
 (0)