@@ -2,14 +2,19 @@ package main
2
2
3
3
import (
4
4
"context"
5
+ "encoding/base64"
5
6
"encoding/json"
6
7
"errors"
7
8
"fmt"
9
+ "io"
8
10
"io/fs"
9
11
"log"
12
+ "mime/multipart"
10
13
"net/http"
11
14
"os"
12
15
"os/signal"
16
+ "path/filepath"
17
+ "regexp"
13
18
"runtime/debug"
14
19
"strconv"
15
20
"strings"
@@ -20,6 +25,8 @@ import (
20
25
"github.com/google/uuid"
21
26
"github.com/gorilla/mux"
22
27
28
+ "github.com/commandlinedev/apishell/pkg/packet"
29
+ "github.com/commandlinedev/apishell/pkg/server"
23
30
"github.com/commandlinedev/prompt-server/pkg/cmdrunner"
24
31
"github.com/commandlinedev/prompt-server/pkg/pcloud"
25
32
"github.com/commandlinedev/prompt-server/pkg/remote"
@@ -49,11 +56,14 @@ const InitialTelemetryWait = 30 * time.Second
49
56
const TelemetryTick = 30 * time .Minute
50
57
const TelemetryInterval = 8 * time .Hour
51
58
59
+ const MaxWriteFileMemSize = 20 * (1024 * 1024 ) // 20M
60
+
52
61
var GlobalLock = & sync.Mutex {}
53
62
var WSStateMap = make (map [string ]* scws.WSState ) // clientid -> WsState
54
63
var GlobalAuthKey string
55
64
var BuildTime = "0"
56
65
var shutdownOnce sync.Once
66
+ var ContentTypeHeaderValidRe = regexp .MustCompile (`^\w+/[\w.+-]+$` )
57
67
58
68
type ClientActiveState struct {
59
69
Fg bool `json:"fg"`
@@ -312,6 +322,273 @@ func HandleGetPtyOut(w http.ResponseWriter, r *http.Request) {
312
322
w .Write (data )
313
323
}
314
324
325
+ type writeFileParamsType struct {
326
+ ScreenId string `json:"screenid"`
327
+ LineId string `json:"lineid"`
328
+ Path string `json:"path"`
329
+ UseTemp bool `json:"usetemp,omitempty"`
330
+ }
331
+
332
+ func parseWriteFileParams (r * http.Request ) (* writeFileParamsType , multipart.File , error ) {
333
+ err := r .ParseMultipartForm (MaxWriteFileMemSize )
334
+ if err != nil {
335
+ return nil , nil , fmt .Errorf ("cannot parse multipart form data: %v" , err )
336
+ }
337
+ form := r .MultipartForm
338
+ if len (form .Value ["params" ]) == 0 {
339
+ return nil , nil , fmt .Errorf ("no params found" )
340
+ }
341
+ paramsStr := form .Value ["params" ][0 ]
342
+ var params writeFileParamsType
343
+ err = json .Unmarshal ([]byte (paramsStr ), & params )
344
+ if err != nil {
345
+ return nil , nil , fmt .Errorf ("bad params json: %v" , err )
346
+ }
347
+ if len (form .File ["data" ]) == 0 {
348
+ return nil , nil , fmt .Errorf ("no data found" )
349
+ }
350
+ fileHeader := form .File ["data" ][0 ]
351
+ file , err := fileHeader .Open ()
352
+ if err != nil {
353
+ return nil , nil , fmt .Errorf ("error opening multipart data file: %v" , err )
354
+ }
355
+ return & params , file , nil
356
+ }
357
+
358
+ func HandleWriteFile (w http.ResponseWriter , r * http.Request ) {
359
+ defer func () {
360
+ r := recover ()
361
+ if r == nil {
362
+ return
363
+ }
364
+ log .Printf ("[error] in write-file: %v\n " , r )
365
+ debug .PrintStack ()
366
+ WriteJsonError (w , fmt .Errorf ("panic: %v" , r ))
367
+ return
368
+ }()
369
+ w .Header ().Set ("Cache-Control" , "no-cache" )
370
+ params , mpFile , err := parseWriteFileParams (r )
371
+ if err != nil {
372
+ WriteJsonError (w , fmt .Errorf ("error parsing multipart form params: %w" , err ))
373
+ return
374
+ }
375
+ if params .ScreenId == "" || params .LineId == "" || params .Path == "" {
376
+ WriteJsonError (w , fmt .Errorf ("invalid params, must set screenid, lineid, and path" ))
377
+ return
378
+ }
379
+ if _ , err := uuid .Parse (params .ScreenId ); err != nil {
380
+ WriteJsonError (w , fmt .Errorf ("invalid screenid: %v" , err ))
381
+ return
382
+ }
383
+ if _ , err := uuid .Parse (params .LineId ); err != nil {
384
+ WriteJsonError (w , fmt .Errorf ("invalid lineid: %v" , err ))
385
+ return
386
+ }
387
+ _ , cmd , err := sstore .GetLineCmdByLineId (r .Context (), params .ScreenId , params .LineId )
388
+ if err != nil {
389
+ WriteJsonError (w , fmt .Errorf ("cannot retrieve line/cmd: %v" , err ))
390
+ return
391
+ }
392
+ if cmd == nil {
393
+ WriteJsonError (w , fmt .Errorf ("line not found" ))
394
+ return
395
+ }
396
+ if cmd .Remote .RemoteId == "" {
397
+ WriteJsonError (w , fmt .Errorf ("invalid line, no remote" ))
398
+ return
399
+ }
400
+ msh := remote .GetRemoteById (cmd .Remote .RemoteId )
401
+ if msh == nil {
402
+ WriteJsonError (w , fmt .Errorf ("invalid line, cannot resolve remote" ))
403
+ return
404
+ }
405
+ cwd := cmd .FeState ["cwd" ]
406
+ writePk := packet .MakeWriteFilePacket ()
407
+ writePk .ReqId = uuid .New ().String ()
408
+ writePk .UseTemp = params .UseTemp
409
+ if filepath .IsAbs (params .Path ) {
410
+ writePk .Path = params .Path
411
+ } else {
412
+ writePk .Path = filepath .Join (cwd , params .Path )
413
+ }
414
+ iter , err := msh .PacketRpcIter (r .Context (), writePk )
415
+ if err != nil {
416
+ WriteJsonError (w , fmt .Errorf ("error: %v" , err ))
417
+ return
418
+ }
419
+ // first packet should be WriteFileReady
420
+ readyIf , err := iter .Next (r .Context ())
421
+ if err != nil {
422
+ WriteJsonError (w , fmt .Errorf ("error while getting ready response: %w" , err ))
423
+ return
424
+ }
425
+ readyPk , ok := readyIf .(* packet.WriteFileReadyPacketType )
426
+ if ! ok {
427
+ WriteJsonError (w , fmt .Errorf ("bad ready packet received: %T" , readyIf ))
428
+ return
429
+ }
430
+ if readyPk .Error != "" {
431
+ WriteJsonError (w , fmt .Errorf ("ready error: %s" , readyPk .Error ))
432
+ return
433
+ }
434
+ var buffer [server .MaxFileDataPacketSize ]byte
435
+ bufSlice := buffer [:]
436
+ for {
437
+ dataPk := packet .MakeFileDataPacket (writePk .ReqId )
438
+ nr , err := io .ReadFull (mpFile , bufSlice )
439
+ if err == io .ErrUnexpectedEOF || err == io .EOF {
440
+ dataPk .Eof = true
441
+ } else if err != nil {
442
+ dataErr := fmt .Errorf ("error reading file data: %v" , err )
443
+ dataPk .Error = dataErr .Error ()
444
+ msh .SendFileData (dataPk )
445
+ WriteJsonError (w , dataErr )
446
+ return
447
+ }
448
+ if nr > 0 {
449
+ dataPk .Data = make ([]byte , nr )
450
+ copy (dataPk .Data , bufSlice [0 :nr ])
451
+ }
452
+ msh .SendFileData (dataPk )
453
+ if dataPk .Eof {
454
+ break
455
+ }
456
+ // slight throttle for sending packets
457
+ time .Sleep (10 * time .Millisecond )
458
+ }
459
+ doneIf , err := iter .Next (r .Context ())
460
+ if err != nil {
461
+ WriteJsonError (w , fmt .Errorf ("error while getting done response: %w" , err ))
462
+ return
463
+ }
464
+ donePk , ok := doneIf .(* packet.WriteFileDonePacketType )
465
+ if ! ok {
466
+ WriteJsonError (w , fmt .Errorf ("bad done packet received: %T" , doneIf ))
467
+ return
468
+ }
469
+ if donePk .Error != "" {
470
+ WriteJsonError (w , fmt .Errorf ("dne error: %s" , donePk .Error ))
471
+ return
472
+ }
473
+ WriteJsonSuccess (w , nil )
474
+ return
475
+ }
476
+
477
+ func HandleReadFile (w http.ResponseWriter , r * http.Request ) {
478
+ qvals := r .URL .Query ()
479
+ screenId := qvals .Get ("screenid" )
480
+ lineId := qvals .Get ("lineid" )
481
+ path := qvals .Get ("path" ) // validate path?
482
+ contentType := qvals .Get ("mimetype" )
483
+ if contentType == "" {
484
+ contentType = "application/octet-stream"
485
+ }
486
+ if screenId == "" || lineId == "" {
487
+ w .WriteHeader (500 )
488
+ w .Write ([]byte (fmt .Sprintf ("must specify sessionid, screenid, and lineid" )))
489
+ return
490
+ }
491
+ if path == "" {
492
+ w .WriteHeader (500 )
493
+ w .Write ([]byte (fmt .Sprintf ("must specify path" )))
494
+ return
495
+ }
496
+ if _ , err := uuid .Parse (screenId ); err != nil {
497
+ w .WriteHeader (500 )
498
+ w .Write ([]byte (fmt .Sprintf ("invalid screenid: %v" , err )))
499
+ return
500
+ }
501
+ if _ , err := uuid .Parse (lineId ); err != nil {
502
+ w .WriteHeader (500 )
503
+ w .Write ([]byte (fmt .Sprintf ("invalid lineid: %v" , err )))
504
+ return
505
+ }
506
+ if ! ContentTypeHeaderValidRe .MatchString (contentType ) {
507
+ w .WriteHeader (500 )
508
+ w .Write ([]byte (fmt .Sprintf ("invalid mimetype specified" )))
509
+ return
510
+ }
511
+ _ , cmd , err := sstore .GetLineCmdByLineId (r .Context (), screenId , lineId )
512
+ if err != nil {
513
+ w .WriteHeader (500 )
514
+ w .Write ([]byte (fmt .Sprintf ("invalid lineid: %v" , err )))
515
+ return
516
+ }
517
+ if cmd == nil {
518
+ w .WriteHeader (500 )
519
+ w .Write ([]byte (fmt .Sprintf ("invalid line, no cmd" )))
520
+ return
521
+ }
522
+ if cmd .Remote .RemoteId == "" {
523
+ w .WriteHeader (500 )
524
+ w .Write ([]byte (fmt .Sprintf ("invalid line, no remote" )))
525
+ return
526
+ }
527
+ streamPk := packet .MakeStreamFilePacket ()
528
+ streamPk .ReqId = uuid .New ().String ()
529
+ cwd := cmd .FeState ["cwd" ]
530
+ if filepath .IsAbs (path ) {
531
+ streamPk .Path = path
532
+ } else {
533
+ streamPk .Path = filepath .Join (cwd , path )
534
+ }
535
+ msh := remote .GetRemoteById (cmd .Remote .RemoteId )
536
+ if msh == nil {
537
+ w .WriteHeader (500 )
538
+ w .Write ([]byte (fmt .Sprintf ("invalid line, cannot resolve remote" )))
539
+ return
540
+ }
541
+ iter , err := msh .StreamFile (r .Context (), streamPk )
542
+ if err != nil {
543
+ w .WriteHeader (500 )
544
+ w .Write ([]byte (fmt .Sprintf ("error trying to stream file: %v" , err )))
545
+ return
546
+ }
547
+ defer iter .Close ()
548
+ respIf , err := iter .Next (r .Context ())
549
+ if err != nil {
550
+ w .WriteHeader (500 )
551
+ w .Write ([]byte (fmt .Sprintf ("error getting streamfile response: %v" , err )))
552
+ return
553
+ }
554
+ resp , ok := respIf .(* packet.StreamFileResponseType )
555
+ if ! ok {
556
+ w .WriteHeader (500 )
557
+ w .Write ([]byte (fmt .Sprintf ("bad response packet type: %T" , respIf )))
558
+ return
559
+ }
560
+ if resp .Error != "" {
561
+ w .WriteHeader (500 )
562
+ w .Write ([]byte (fmt .Sprintf ("error response: %s" , resp .Error )))
563
+ return
564
+ }
565
+ infoJson , _ := json .Marshal (resp .Info )
566
+ w .Header ().Set ("X-FileInfo" , base64 .StdEncoding .EncodeToString (infoJson ))
567
+ w .Header ().Set ("Content-Type" , contentType )
568
+ w .WriteHeader (http .StatusOK )
569
+ for {
570
+ dataPkIf , err := iter .Next (r .Context ())
571
+ if err != nil {
572
+ log .Printf ("error in read-file while getting data: %v\n " , err )
573
+ break
574
+ }
575
+ if dataPkIf == nil {
576
+ break
577
+ }
578
+ dataPk , ok := dataPkIf .(* packet.FileDataPacketType )
579
+ if ! ok {
580
+ log .Printf ("error in read-file, invalid data packet type: %T" , dataPkIf )
581
+ break
582
+ }
583
+ if dataPk .Error != "" {
584
+ log .Printf ("in read-file, data packet error: %s" , dataPk .Error )
585
+ break
586
+ }
587
+ w .Write (dataPk .Data )
588
+ }
589
+ return
590
+ }
591
+
315
592
func WriteJsonError (w http.ResponseWriter , errVal error ) {
316
593
w .Header ().Set ("Content-Type" , "application/json" )
317
594
w .WriteHeader (200 )
@@ -576,6 +853,8 @@ func main() {
576
853
gr .HandleFunc ("/api/get-client-data" , AuthKeyWrap (HandleGetClientData ))
577
854
gr .HandleFunc ("/api/set-winsize" , AuthKeyWrap (HandleSetWinSize ))
578
855
gr .HandleFunc ("/api/log-active-state" , AuthKeyWrap (HandleLogActiveState ))
856
+ gr .HandleFunc ("/api/read-file" , AuthKeyWrap (HandleReadFile ))
857
+ gr .HandleFunc ("/api/write-file" , AuthKeyWrap (HandleWriteFile )).Methods ("POST" )
579
858
serverAddr := MainServerAddr
580
859
if scbase .IsDevMode () {
581
860
serverAddr = MainServerDevAddr
0 commit comments