diff --git a/internal/command/exec/exec.go b/internal/command/exec/exec.go new file mode 100644 index 00000000..d0f48be1 --- /dev/null +++ b/internal/command/exec/exec.go @@ -0,0 +1,14 @@ +package exec + +import "github.com/spf13/cobra" + +func NewCommand() *cobra.Command { + command := &cobra.Command{ + Use: "exec", + Short: "Execute commands inside resources", + } + + command.AddCommand(newExecVMCommand()) + + return command +} diff --git a/internal/command/exec/vm.go b/internal/command/exec/vm.go new file mode 100644 index 00000000..a9ca99c8 --- /dev/null +++ b/internal/command/exec/vm.go @@ -0,0 +1,288 @@ +package exec + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/signal" + "syscall" + "time" + + "github.com/cirruslabs/orchard/internal/execstream" + "github.com/cirruslabs/orchard/pkg/client" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/spf13/cobra" + "golang.org/x/term" +) + +var ( + vmTimeout time.Duration + vmInteractive bool + vmTTY bool +) + +func newExecVMCommand() *cobra.Command { + command := &cobra.Command{ + Use: "vm VM_NAME COMMAND [ARGS...]", + Short: "Execute a command inside the VM", + Args: cobra.MinimumNArgs(2), + RunE: runExecVM, + } + + command.Flags().DurationVarP(&vmTimeout, "timeout", "w", 60*time.Second, + "time to wait for the VM to reach running state") + command.Flags().BoolVarP(&vmInteractive, "interactive", "i", false, + "attach local standard input to the remote command") + command.Flags().BoolVarP(&vmTTY, "tty", "t", false, + "allocate a pseudo-terminal on the remote end") + + return command +} + +func runExecVM(cmd *cobra.Command, args []string) error { + cmd.SilenceUsage = true + + name := args[0] + commandArgs := args[1:] + + client, err := client.New() + if err != nil { + return err + } + + ctx := cmd.Context() + + if err := waitForVMRunning(ctx, client, name, vmTimeout); err != nil { + return err + } + + rows, cols := uint32(0), uint32(0) + if vmTTY { + width, height, err := term.GetSize(int(os.Stdout.Fd())) + if err == nil { + cols = uint32(width) + rows = uint32(height) + } + } + + interactive := vmInteractive || vmTTY + + waitSeconds := uint16(vmTimeout / time.Second) + if waitSeconds == 0 { + waitSeconds = 1 + } + + conn, err := client.VMs().Exec(ctx, name, commandArgs, interactive, vmTTY, rows, cols, waitSeconds) + if err != nil { + return fmt.Errorf("failed to start exec session: %w", err) + } + defer conn.Close() + + decoder := execstream.NewDecoder(conn) + encoder := execstream.NewEncoder(conn) + + stdinCh := make(chan error, 1) + resizeCh := make(chan error, 1) + + if vmInteractive || vmTTY { + if vmTTY { + stdinFD := int(os.Stdin.Fd()) + state, err := term.MakeRaw(stdinFD) + if err != nil { + return fmt.Errorf("failed to put terminal into raw mode: %w", err) + } + defer func() { + _ = term.Restore(stdinFD, state) + }() + + go monitorTerminalResize(ctx, encoder, resizeCh) + } + + go streamStdin(ctx, encoder, stdinCh) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(sigCh) + + var exitCode int32 + +loop: + for { + var frame execstream.Frame + + err := execstream.ReadFrame(decoder, &frame) + if err != nil { + if errors.Is(err, io.EOF) { + break loop + } + return fmt.Errorf("exec session read failed: %w", err) + } + + switch frame.Type { + case execstream.FrameTypeStdout: + if len(frame.Data) > 0 { + if _, err := os.Stdout.Write(frame.Data); err != nil { + return err + } + } + case execstream.FrameTypeStderr: + if len(frame.Data) > 0 { + if vmTTY { + if _, err := os.Stdout.Write(frame.Data); err != nil { + return err + } + } else { + if _, err := os.Stderr.Write(frame.Data); err != nil { + return err + } + } + } + case execstream.FrameTypeExit: + if frame.Exit != nil { + exitCode = frame.Exit.Code + } + break loop + case execstream.FrameTypeError: + return fmt.Errorf("exec error: %s", frame.Error) + } + } + + select { + case err := <-stdinCh: + if err != nil && !errors.Is(err, context.Canceled) { + return err + } + default: + } + + select { + case err := <-resizeCh: + if err != nil && !errors.Is(err, context.Canceled) { + return err + } + default: + } + + if exitCode != 0 { + os.Exit(int(exitCode)) + } + + return nil +} + +func waitForVMRunning(ctx context.Context, client *client.Client, name string, timeout time.Duration) error { + if timeout <= 0 { + timeout = time.Second + } + + deadline := time.Now().Add(timeout) + + for { + vm, err := client.VMs().Get(ctx, name) + if err != nil { + return err + } + + switch vm.Status { + case v1.VMStatusRunning: + return nil + case v1.VMStatusFailed: + return fmt.Errorf("VM %s is in failed state: %s", name, vm.StatusMessage) + } + + if time.Now().After(deadline) { + return fmt.Errorf("VM %s did not reach running state within %s", name, timeout) + } + + select { + case <-time.After(1 * time.Second): + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func streamStdin(ctx context.Context, encoder *json.Encoder, errCh chan<- error) { + reader := bufio.NewReader(os.Stdin) + + for { + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + default: + } + + buf := make([]byte, 4096) + n, err := reader.Read(buf) + if n > 0 { + if err := execstream.WriteFrame(encoder, &execstream.Frame{ + Type: execstream.FrameTypeStdin, + Data: buf[:n], + }); err != nil { + errCh <- err + return + } + } + + if errors.Is(err, io.EOF) { + execstream.WriteFrame(encoder, &execstream.Frame{Type: execstream.FrameTypeStdin}) + errCh <- nil + return + } + + if err != nil { + errCh <- err + return + } + } +} + +func monitorTerminalResize(ctx context.Context, encoder *json.Encoder, errCh chan<- error) { + stdoutFD := int(os.Stdout.Fd()) + prevWidth, prevHeight, err := term.GetSize(stdoutFD) + if err != nil { + errCh <- err + return + } + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case <-ticker.C: + width, height, err := term.GetSize(stdoutFD) + if err != nil { + errCh <- err + return + } + + if width == prevWidth && height == prevHeight { + continue + } + + if err := execstream.WriteFrame(encoder, &execstream.Frame{ + Type: execstream.FrameTypeResize, + Terminal: &execstream.TerminalSize{ + Rows: uint32(height), + Cols: uint32(width), + }, + }); err != nil { + errCh <- err + return + } + + prevWidth = width + prevHeight = height + } + } +} diff --git a/internal/command/root.go b/internal/command/root.go index 618bc5af..f0e92ad3 100644 --- a/internal/command/root.go +++ b/internal/command/root.go @@ -6,6 +6,7 @@ import ( "github.com/cirruslabs/orchard/internal/command/create" deletepkg "github.com/cirruslabs/orchard/internal/command/deletecmd" "github.com/cirruslabs/orchard/internal/command/dev" + "github.com/cirruslabs/orchard/internal/command/exec" "github.com/cirruslabs/orchard/internal/command/get" "github.com/cirruslabs/orchard/internal/command/list" "github.com/cirruslabs/orchard/internal/command/localnetworkhelper" @@ -51,6 +52,7 @@ func NewRootCmd() *cobra.Command { pause.NewCommand(), portforward.NewCommand(), resume.NewCommand(), + exec.NewCommand(), set.NewCommand(), ssh.NewCommand(), vnc.NewCommand(), diff --git a/internal/controller/api.go b/internal/controller/api.go index 774487c1..6363d3b2 100644 --- a/internal/controller/api.go +++ b/internal/controller/api.go @@ -118,6 +118,9 @@ func (controller *Controller) initAPI() *gin.Engine { v1.GET("/rpc/port-forward", func(c *gin.Context) { controller.rpcPortForward(c).Respond(c) }) + v1.GET("/rpc/exec", func(c *gin.Context) { + controller.rpcExec(c).Respond(c) + }) v1.POST("/rpc/resolve-ip", func(c *gin.Context) { controller.rpcResolveIP(c).Respond(c) }) @@ -138,6 +141,9 @@ func (controller *Controller) initAPI() *gin.Engine { v1.GET("/vms/:name/port-forward", func(c *gin.Context) { controller.portForwardVM(c).Respond(c) }) + v1.GET("/vms/:name/exec", func(c *gin.Context) { + controller.execVM(c).Respond(c) + }) v1.GET("/vms/:name/ip", func(c *gin.Context) { controller.ip(c).Respond(c) }) diff --git a/internal/controller/api_controller.go b/internal/controller/api_controller.go index 1fdc2cc5..90d6e200 100644 --- a/internal/controller/api_controller.go +++ b/internal/controller/api_controller.go @@ -23,6 +23,8 @@ func (controller *Controller) controllerInfo(ctx *gin.Context) responder.Respond capabilities = append(capabilities, v1pkg.ControllerCapabilityRPCV2) } + capabilities = append(capabilities, v1pkg.ControllerCapabilityExec) + return responder.JSON(http.StatusOK, &v1pkg.ControllerInfo{ Version: version.Version, Commit: version.Commit, diff --git a/internal/controller/api_rpc_exec.go b/internal/controller/api_rpc_exec.go new file mode 100644 index 00000000..a12c12dd --- /dev/null +++ b/internal/controller/api_rpc_exec.go @@ -0,0 +1,59 @@ +package controller + +import ( + "context" + "github.com/cirruslabs/orchard/internal/controller/rendezvous" + "github.com/cirruslabs/orchard/internal/responder" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "net" + "time" +) + +func (controller *Controller) rpcExec(ctx *gin.Context) responder.Responder { + if responder := controller.authorize(ctx, v1.ServiceAccountRoleComputeWrite); responder != nil { + return responder + } + + session := ctx.Query("session") + errorMessage := ctx.Query("errorMessage") + + wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + return responder.Error(err) + } + defer func() { + _ = wsConn.CloseNow() + }() + + proxyCtx, err := controller.connRendezvous.Respond(session, rendezvous.ResultWithErrorMessage[net.Conn]{ + Result: websocket.NetConn(ctx, wsConn, websocket.MessageBinary), + ErrorMessage: errorMessage, + }) + if err != nil { + return controller.wsError(wsConn, websocket.StatusInternalError, "exec RPC", + "failure to respond with the established WebSocket connection", err) + } + + for { + select { + case <-proxyCtx.Done(): + return responder.Empty() + case <-time.After(controller.pingInterval): + pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second) + + if err := wsConn.Ping(pingCtx); err != nil { + controller.logger.Warnf("exec RPC: failed to ping the worker, "+ + "connection might time out: %v", err) + } + + pingCtxCancel() + case <-ctx.Done(): + return controller.wsErrorNoClose("exec RPC", + "worker unexpectedly disconnected", ctx.Err()) + } + } +} diff --git a/internal/controller/api_rpc_watch.go b/internal/controller/api_rpc_watch.go index d7d0d180..5c406350 100644 --- a/internal/controller/api_rpc_watch.go +++ b/internal/controller/api_rpc_watch.go @@ -71,6 +71,24 @@ func (controller *Controller) rpcWatch(ctx *gin.Context) responder.Responder { Session: typedAction.ResolveIpAction.Session, VMUID: typedAction.ResolveIpAction.VmUid, } + case *rpc.WatchInstruction_ExecAction: + execAction := typedAction.ExecAction + + watchInstruction.ExecAction = &v1.ExecAction{ + Session: execAction.Session, + VMUID: execAction.VmUid, + Command: execAction.Command, + Args: execAction.Args, + Interactive: execAction.Interactive, + TTY: execAction.Tty, + } + + if execAction.TerminalSize != nil { + watchInstruction.ExecAction.Terminal = &v1.TerminalSize{ + Rows: execAction.TerminalSize.Rows, + Cols: execAction.TerminalSize.Cols, + } + } default: continue } diff --git a/internal/controller/api_vms_exec.go b/internal/controller/api_vms_exec.go new file mode 100644 index 00000000..f41c911f --- /dev/null +++ b/internal/controller/api_vms_exec.go @@ -0,0 +1,357 @@ +package controller + +import ( + "context" + "encoding/json" + "errors" + "fmt" + storepkg "github.com/cirruslabs/orchard/internal/controller/store" + "github.com/cirruslabs/orchard/internal/execstream" + "github.com/cirruslabs/orchard/internal/netconncancel" + "github.com/cirruslabs/orchard/internal/responder" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/cirruslabs/orchard/rpc" + "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "io" + "net/http" + "strconv" + "time" +) + +func (controller *Controller) execVM(ctx *gin.Context) responder.Responder { + if responder := controller.authorize(ctx, v1.ServiceAccountRoleComputeWrite); responder != nil { + return responder + } + + name := ctx.Param("name") + + command := ctx.Query("command") + if command == "" { + return responder.Code(http.StatusBadRequest) + } + + args := ctx.QueryArray("arg") + + interactive, err := parseBoolWithDefault(ctx.Query("interactive"), false) + if err != nil { + return responder.Code(http.StatusBadRequest) + } + + tty, err := parseBoolWithDefault(ctx.Query("tty"), false) + if err != nil { + return responder.Code(http.StatusBadRequest) + } + + if tty { + interactive = true + } + + rows, err := parseUint32(ctx.Query("rows")) + if err != nil { + return responder.Code(http.StatusBadRequest) + } + + cols, err := parseUint32(ctx.Query("cols")) + if err != nil { + return responder.Code(http.StatusBadRequest) + } + + waitRaw := ctx.DefaultQuery("wait", "10") + wait, err := strconv.ParseUint(waitRaw, 10, 16) + if err != nil { + return responder.Code(http.StatusBadRequest) + } + + waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second) + defer waitContextCancel() + + vm, responderImpl := controller.waitForVM(waitContext, name) + if responderImpl != nil { + return responderImpl + } + + var workerResource *v1.Worker + + if responderImpl := controller.storeView(func(txn storepkg.Transaction) responder.Responder { + var err error + + workerResource, err = txn.GetWorker(vm.Worker) + if err != nil { + return responder.Error(err) + } + + return nil + }); responderImpl != nil { + return responderImpl + } + + if workerResource == nil || !workerResource.Capabilities.Has(v1.WorkerCapabilityExec) { + return responder.JSON(http.StatusNotImplemented, + NewErrorResponse("worker %s does not support exec", vm.Worker)) + } + + rendezvousCtx, rendezvousCtxCancel := context.WithCancel(ctx) + defer rendezvousCtxCancel() + + session := uuid.New().String() + + boomerangConnCh, cancel := controller.connRendezvous.Request(rendezvousCtx, session) + defer cancel() + + var terminalSize *rpc.WatchInstruction_Exec_TerminalSize + if rows > 0 && cols > 0 { + terminalSize = &rpc.WatchInstruction_Exec_TerminalSize{ + Rows: rows, + Cols: cols, + } + } + + err = controller.workerNotifier.Notify(waitContext, vm.Worker, &rpc.WatchInstruction{ + Action: &rpc.WatchInstruction_ExecAction{ + ExecAction: &rpc.WatchInstruction_Exec{ + Session: session, + VmUid: vm.UID, + Command: command, + Args: args, + Interactive: interactive, + Tty: tty, + TerminalSize: func() *rpc.WatchInstruction_Exec_TerminalSize { + if terminalSize == nil { + return nil + } + + return &rpc.WatchInstruction_Exec_TerminalSize{ + Rows: terminalSize.Rows, + Cols: terminalSize.Cols, + } + }(), + }, + }, + }) + if err != nil { + controller.logger.Warnf("failed to request exec session from the worker %s: %v", + vm.Worker, err) + + return responder.Code(http.StatusServiceUnavailable) + } + + select { + case rendezvousResponse := <-boomerangConnCh: + if rendezvousResponse.ErrorMessage != "" { + return responder.Error(fmt.Errorf("failed to establish exec session on the worker: %s", + rendezvousResponse.ErrorMessage)) + } + + if rendezvousResponse.Result == nil { + return responder.Error(errors.New("failed to establish exec session on the worker: no connection")) + } + + wsConn, err := websocket.Accept(ctx.Writer, ctx.Request, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + _ = rendezvousResponse.Result.Close() + + return responder.Error(err) + } + defer func() { + _ = wsConn.CloseNow() + }() + + workerConnWithCancel := netconncancel.New(rendezvousResponse.Result, rendezvousCtxCancel) + defer func() { + _ = workerConnWithCancel.Close() + }() + + expectedMsgType := websocket.MessageText + wsConnAsNetConn := websocket.NetConn(ctx, wsConn, expectedMsgType) + defer func() { + _ = wsConnAsNetConn.Close() + }() + + commandFrame := &execstream.Frame{ + Type: execstream.FrameTypeCommand, + Command: &execstream.Command{ + Name: command, + Args: args, + Interactive: interactive, + TTY: tty, + }, + } + if terminalSize != nil { + commandFrame.Command.Terminal = &execstream.TerminalSize{ + Rows: terminalSize.Rows, + Cols: terminalSize.Cols, + } + } + + workerEncoder := execstream.NewEncoder(workerConnWithCancel) + workerDecoder := execstream.NewDecoder(workerConnWithCancel) + clientEncoder := execstream.NewEncoder(wsConnAsNetConn) + clientDecoder := execstream.NewDecoder(wsConnAsNetConn) + + if err := execstream.WriteFrame(workerEncoder, commandFrame); err != nil { + return controller.wsError(wsConn, websocket.StatusInternalError, "exec session", + "failed to deliver command to worker", err) + } + + workerErrCh := make(chan error, 1) + clientErrCh := make(chan error, 1) + exitCh := make(chan int32, 1) + + go controller.forwardExecFromWorker(workerDecoder, clientEncoder, workerErrCh, exitCh) + go controller.forwardExecFromClient(clientDecoder, workerEncoder, clientErrCh) + + pingTicker := time.NewTicker(controller.pingInterval) + defer pingTicker.Stop() + + for { + select { + case err := <-workerErrCh: + if err == nil { + continue + } + + if errors.Is(err, context.Canceled) { + return responder.Empty() + } + + if statusErr, ok := status.FromError(err); ok && statusErr.Code() == codes.Canceled { + return responder.Empty() + } + + if errors.Is(err, io.EOF) { + return controller.wsError(wsConn, websocket.StatusInternalError, "exec session", + "worker closed the exec stream unexpectedly", err) + } + + return controller.wsError(wsConn, websocket.StatusInternalError, "exec session", + "failed while proxying worker stream", err) + case err := <-clientErrCh: + if err == nil { + continue + } + + var websocketCloseError websocket.CloseError + if errors.As(err, &websocketCloseError) { + return responder.Empty() + } + + if errors.Is(err, io.EOF) { + return responder.Empty() + } + + return controller.wsError(wsConn, websocket.StatusInternalError, "exec session", + "failed while proxying client stream", err) + case exitCode := <-exitCh: + if err := wsConn.Close(websocket.StatusNormalClosure, + fmt.Sprintf("command exited with code %d", exitCode)); err != nil { + controller.logger.Warnf("exec session: failed to close WebSocket connection: %v", err) + } + + return responder.Empty() + case <-pingTicker.C: + pingCtx, pingCtxCancel := context.WithTimeout(ctx, 5*time.Second) + + if err := wsConn.Ping(pingCtx); err != nil { + controller.logger.Warnf("exec session: failed to ping the client, "+ + "connection might time out: %v", err) + } + + pingCtxCancel() + case <-ctx.Done(): + return responder.Error(ctx.Err()) + } + } + case <-ctx.Done(): + return responder.Error(ctx.Err()) + } +} + +func parseBoolWithDefault(raw string, defaultValue bool) (bool, error) { + if raw == "" { + return defaultValue, nil + } + + value, err := strconv.ParseBool(raw) + if err != nil { + return false, err + } + + return value, nil +} + +func parseUint32(raw string) (uint32, error) { + if raw == "" { + return 0, nil + } + + value, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return 0, err + } + + return uint32(value), nil +} + +func (controller *Controller) forwardExecFromWorker( + decoder *json.Decoder, + encoder *json.Encoder, + errCh chan<- error, + exitCh chan<- int32, +) { + for { + var frame execstream.Frame + + if err := execstream.ReadFrame(decoder, &frame); err != nil { + errCh <- err + + return + } + + if err := execstream.WriteFrame(encoder, &frame); err != nil { + errCh <- err + + return + } + + if frame.Type == execstream.FrameTypeExit && frame.Exit != nil { + exitCh <- frame.Exit.Code + + return + } + } +} + +func (controller *Controller) forwardExecFromClient( + decoder *json.Decoder, + encoder *json.Encoder, + errCh chan<- error, +) { + for { + var frame execstream.Frame + + if err := execstream.ReadFrame(decoder, &frame); err != nil { + errCh <- err + + return + } + + switch frame.Type { + case execstream.FrameTypeStdin, execstream.FrameTypeResize: + if err := execstream.WriteFrame(encoder, &frame); err != nil { + errCh <- err + + return + } + default: + errCh <- fmt.Errorf("unsupported frame type %q received from client", frame.Type) + + return + } + } +} diff --git a/internal/controller/api_workers.go b/internal/controller/api_workers.go index e7480d15..cb293a78 100644 --- a/internal/controller/api_workers.go +++ b/internal/controller/api_workers.go @@ -102,6 +102,7 @@ func (controller *Controller) createWorker(ctx *gin.Context) responder.Responder dbWorker.Labels = worker.Labels dbWorker.DefaultCPU = worker.DefaultCPU dbWorker.DefaultMemory = worker.DefaultMemory + dbWorker.Capabilities = worker.Capabilities if err := txn.SetWorker(*dbWorker); err != nil { return responder.Error(err) diff --git a/internal/controller/rpc.go b/internal/controller/rpc.go index d3edc337..06af0075 100644 --- a/internal/controller/rpc.go +++ b/internal/controller/rpc.go @@ -82,6 +82,45 @@ func (controller *Controller) PortForward(stream rpc.Controller_PortForwardServe } } +func (controller *Controller) Exec(stream rpc.Controller_ExecServer) error { + if !controller.authorizeGRPC(stream.Context(), v1pkg.ServiceAccountRoleComputeWrite) { + return status.Errorf(codes.Unauthenticated, "auth failed") + } + + sessionMetadataValue := metadata.ValueFromIncomingContext(stream.Context(), rpc.MetadataWorkerExecSessionKey) + if len(sessionMetadataValue) == 0 { + return status.Errorf(codes.InvalidArgument, "no session in metadata") + } + + conn := &grpc_net_conn.Conn{ + Stream: stream, + Request: &rpc.ExecData{}, + Response: &rpc.ExecData{}, + Encode: grpc_net_conn.SimpleEncoder(func(message proto.Message) *[]byte { + return &message.(*rpc.ExecData).Data + }), + Decode: grpc_net_conn.SimpleDecoder(func(message proto.Message) *[]byte { + return &message.(*rpc.ExecData).Data + }), + } + + proxyCtx, err := controller.connRendezvous.Respond(sessionMetadataValue[0], + rendezvous.ResultWithErrorMessage[net.Conn]{ + Result: conn, + }, + ) + if err != nil { + return err + } + + select { + case <-proxyCtx.Done(): + return proxyCtx.Err() + case <-stream.Context().Done(): + return stream.Context().Err() + } +} + func (controller *Controller) ResolveIP(ctx context.Context, request *rpc.ResolveIPResult) (*emptypb.Empty, error) { if !controller.authorizeGRPC(ctx, v1pkg.ServiceAccountRoleComputeWrite) { return nil, status.Errorf(codes.Unauthenticated, "auth failed") diff --git a/internal/execstream/frame.go b/internal/execstream/frame.go new file mode 100644 index 00000000..a8e0a449 --- /dev/null +++ b/internal/execstream/frame.go @@ -0,0 +1,68 @@ +package execstream + +import ( + "encoding/json" + "io" +) + +type FrameType string + +const ( + FrameTypeCommand FrameType = "command" + FrameTypeStdin FrameType = "stdin" + FrameTypeStdout FrameType = "stdout" + FrameTypeStderr FrameType = "stderr" + FrameTypeResize FrameType = "resize" + FrameTypeExit FrameType = "exit" + FrameTypeError FrameType = "error" +) + +// Frame captures a single event flowing between controller, worker and clients. +// +// The payload is encoded as JSON where binary blobs (stdin/stdout/stderr data) are +// automatically base64-encoded by the JSON encoder. +type Frame struct { + Type FrameType `json:"type"` + + Command *Command `json:"command,omitempty"` + Data []byte `json:"data,omitempty"` + Terminal *TerminalSize `json:"terminal,omitempty"` + Exit *Exit `json:"exit,omitempty"` + Error string `json:"error,omitempty"` +} + +type Command struct { + Name string `json:"name"` + Args []string `json:"args,omitempty"` + Interactive bool `json:"interactive,omitempty"` + TTY bool `json:"tty,omitempty"` + Terminal *TerminalSize `json:"terminal,omitempty"` +} + +type TerminalSize struct { + Rows uint32 `json:"rows"` + Cols uint32 `json:"cols"` +} + +type Exit struct { + Code int32 `json:"code"` +} + +func NewEncoder(w io.Writer) *json.Encoder { + encoder := json.NewEncoder(w) + encoder.SetEscapeHTML(false) + + return encoder +} + +func NewDecoder(r io.Reader) *json.Decoder { + return json.NewDecoder(r) +} + +func WriteFrame(encoder *json.Encoder, frame *Frame) error { + return encoder.Encode(frame) +} + +func ReadFrame(decoder *json.Decoder, frame *Frame) error { + return decoder.Decode(frame) +} diff --git a/internal/tests/integration_test.go b/internal/tests/integration_test.go index 5b62e42a..83520746 100644 --- a/internal/tests/integration_test.go +++ b/internal/tests/integration_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "bytes" "context" "fmt" "net" @@ -12,6 +13,7 @@ import ( "time" "github.com/cirruslabs/orchard/internal/controller" + "github.com/cirruslabs/orchard/internal/execstream" "github.com/cirruslabs/orchard/internal/imageconstant" "github.com/cirruslabs/orchard/internal/tests/devcontroller" "github.com/cirruslabs/orchard/internal/tests/wait" @@ -26,6 +28,14 @@ import ( "golang.org/x/exp/slices" ) +func integrationTestImage() string { + if image := os.Getenv("ORCHARD_TEST_IMAGE"); image != "" { + return image + } + + return imageconstant.DefaultMacosImage +} + func TestSingleVM(t *testing.T) { devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) @@ -38,7 +48,7 @@ func TestSingleVM(t *testing.T) { Meta: v1.Meta{ Name: "test-vm", }, - Image: imageconstant.DefaultMacosImage, + Image: integrationTestImage(), CPU: 4, Memory: 8 * 1024, Headless: true, @@ -111,7 +121,7 @@ func TestFailedStartupScript(t *testing.T) { Meta: v1.Meta{ Name: "test-vm", }, - Image: imageconstant.DefaultMacosImage, + Image: integrationTestImage(), CPU: 4, Memory: 8 * 1024, Headless: true, @@ -149,7 +159,7 @@ func TestPortForwarding(t *testing.T) { Meta: v1.Meta{ Name: "test-vm", }, - Image: imageconstant.DefaultMacosImage, + Image: integrationTestImage(), CPU: 4, Memory: 8 * 1024, Headless: true, @@ -190,6 +200,81 @@ func TestPortForwarding(t *testing.T) { require.Contains(t, string(unameOutput), "Darwin arm64") } +func TestExec(t *testing.T) { + ctx := context.Background() + + devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t) + + workers, err := devClient.Workers().List(ctx) + require.NoError(t, err) + require.NotEmpty(t, workers) + require.True(t, workers[0].Capabilities.Has(v1.WorkerCapabilityExec)) + + vmName := "exec-vm" + + err = devClient.VMs().Create(ctx, &v1.VM{ + Meta: v1.Meta{ + Name: vmName, + }, + Image: integrationTestImage(), + CPU: 4, + Memory: 8 * 1024, + Headless: true, + }) + require.NoError(t, err) + + require.True(t, wait.Wait(10*time.Minute, func() bool { + vm, err := devClient.VMs().Get(ctx, vmName) + require.NoError(t, err) + t.Logf("Waiting for the VM to start. Current status: %s", vm.Status) + + return vm.Status == v1.VMStatusRunning || vm.Status == v1.VMStatusFailed + }), "failed to start VM for exec test") + + vm, err := devClient.VMs().Get(ctx, vmName) + require.NoError(t, err) + require.Equal(t, v1.VMStatusRunning, vm.Status, "VM failed to reach running state: %s", vm.StatusMessage) + + conn, err := devClient.VMs().Exec(ctx, vmName, []string{"/bin/echo", "orchard"}, + false, false, 0, 0, 180) + require.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + require.NoError(t, conn.SetDeadline(time.Now().Add(2*time.Minute))) + + decoder := execstream.NewDecoder(conn) + + var stdout bytes.Buffer + var stderr bytes.Buffer + var exitCode int32 + exitReceived := false + + for !exitReceived { + var frame execstream.Frame + + require.NoError(t, execstream.ReadFrame(decoder, &frame)) + + switch frame.Type { + case execstream.FrameTypeStdout: + stdout.Write(frame.Data) + case execstream.FrameTypeStderr: + stderr.Write(frame.Data) + case execstream.FrameTypeExit: + require.NotNil(t, frame.Exit, "exit frame missing payload") + exitCode = frame.Exit.Code + exitReceived = true + case execstream.FrameTypeError: + t.Fatalf("exec error: %s", frame.Error) + } + } + + require.Equal(t, int32(0), exitCode) + require.Contains(t, stdout.String(), "orchard") + require.Zero(t, stderr.Len()) +} + // TestSchedulerHealthCheckingNonExistentWorker ensures that scheduler // will eventually fail VMs that are scheduled on a worker that was // deleted from the API. diff --git a/internal/worker/exec.go b/internal/worker/exec.go new file mode 100644 index 00000000..5c38641d --- /dev/null +++ b/internal/worker/exec.go @@ -0,0 +1,673 @@ +package worker + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/cirruslabs/orchard/internal/execstream" + "github.com/cirruslabs/orchard/internal/worker/vmmanager" + v1 "github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/cirruslabs/orchard/rpc" + guestagentrpc "github.com/cirruslabs/orchard/rpc/guestagent" + "github.com/samber/lo" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "io" + "net" + "os" + "path/filepath" + "time" +) + +type execOptions struct { + Session string + VMUID string + Command string + Args []string + Interactive bool + TTY bool + Terminal *execstream.TerminalSize +} + +func execOptionsFromProto(action *rpc.WatchInstruction_Exec) execOptions { + opts := execOptions{ + Session: action.Session, + VMUID: action.VmUid, + Command: action.Command, + Args: append([]string(nil), action.Args...), + Interactive: action.Interactive, + TTY: action.Tty, + } + + if action.TerminalSize != nil { + opts.Terminal = &execstream.TerminalSize{ + Rows: action.TerminalSize.Rows, + Cols: action.TerminalSize.Cols, + } + } + + return opts +} + +func execOptionsFromV1(action *v1.ExecAction) execOptions { + opts := execOptions{ + Session: action.Session, + VMUID: action.VMUID, + Command: action.Command, + Args: append([]string(nil), action.Args...), + Interactive: action.Interactive, + TTY: action.TTY, + } + + if action.Terminal != nil { + opts.Terminal = &execstream.TerminalSize{ + Rows: action.Terminal.Rows, + Cols: action.Terminal.Cols, + } + } + + return opts +} + +func (worker *Worker) runExecSession( + ctx context.Context, + opts execOptions, + controllerConn net.Conn, + vmHint *vmmanager.VM, +) error { + defer controllerConn.Close() + + controllerDecoder := execstream.NewDecoder(controllerConn) + controllerEncoder := execstream.NewEncoder(controllerConn) + + commandDetails := execstream.Command{ + Name: opts.Command, + Args: append([]string(nil), opts.Args...), + Interactive: opts.Interactive, + TTY: opts.TTY, + Terminal: cloneTerminalSize(opts.Terminal), + } + + var firstFrame execstream.Frame + + if err := execstream.ReadFrame(controllerDecoder, &firstFrame); err != nil { + errWrapped := fmt.Errorf("failed to read command frame: %w", err) + worker.sendExecErrorFrame(controllerEncoder, errWrapped.Error()) + + return errWrapped + } + + if firstFrame.Type != execstream.FrameTypeCommand || firstFrame.Command == nil { + errWrapped := fmt.Errorf("expected command frame, got %q", firstFrame.Type) + worker.sendExecErrorFrame(controllerEncoder, errWrapped.Error()) + + return errWrapped + } + + frameCommand := firstFrame.Command + + if frameCommand.Name != "" { + commandDetails.Name = frameCommand.Name + } + if len(frameCommand.Args) != 0 { + commandDetails.Args = append([]string(nil), frameCommand.Args...) + } + if frameCommand.Interactive { + commandDetails.Interactive = true + } + if frameCommand.TTY { + commandDetails.TTY = true + } + if frameCommand.Terminal != nil { + commandDetails.Terminal = &execstream.TerminalSize{ + Rows: frameCommand.Terminal.Rows, + Cols: frameCommand.Terminal.Cols, + } + } else if commandDetails.Terminal == nil && opts.Terminal != nil { + commandDetails.Terminal = cloneTerminalSize(opts.Terminal) + } + + if commandDetails.Name == "" { + errWrapped := errors.New("command name is empty") + worker.sendExecErrorFrame(controllerEncoder, errWrapped.Error()) + + return errWrapped + } + + if commandDetails.Args == nil { + commandDetails.Args = []string{} + } + + vm := vmHint + + if vm == nil { + var err error + + vm, err = worker.findVMByUID(opts.VMUID) + if err != nil { + worker.sendExecErrorFrame(controllerEncoder, err.Error()) + + return err + } + } + + socketPath, err := vmControlSocketPath(vm) + if err != nil { + errWrapped := fmt.Errorf("failed to determine control socket path: %w", err) + worker.sendExecErrorFrame(controllerEncoder, errWrapped.Error()) + + return errWrapped + } + + agentConn, err := worker.connectToGuestAgent(ctx, socketPath) + if err != nil { + errWrapped := fmt.Errorf("failed to connect to guest agent: %w", err) + worker.sendExecErrorFrame(controllerEncoder, errWrapped.Error()) + + return errWrapped + } + defer agentConn.Close() + + var terminalSize *guestagentrpc.TerminalSize + + if commandDetails.Terminal != nil { + terminalSize = &guestagentrpc.TerminalSize{ + Rows: commandDetails.Terminal.Rows, + Cols: commandDetails.Terminal.Cols, + } + } + + commandReq := &guestagentrpc.ExecRequest{ + Type: &guestagentrpc.ExecRequest_Command_{ + Command: &guestagentrpc.ExecRequest_Command{ + Name: commandDetails.Name, + Args: append([]string(nil), commandDetails.Args...), + Interactive: commandDetails.Interactive, + Tty: commandDetails.TTY, + TerminalSize: func() *guestagentrpc.TerminalSize { + if terminalSize == nil { + return nil + } + + return &guestagentrpc.TerminalSize{ + Rows: terminalSize.Rows, + Cols: terminalSize.Cols, + } + }(), + }, + }, + } + + methods := []string{guestagentrpc.Agent_Exec_FullMethodName, "/Agent/Exec"} + var lastErr error + +outer: + for idx, method := range methods { + streamCtx, streamCancel := context.WithCancel(ctx) + + agentStream, err := worker.establishAgentExecStream(streamCtx, agentConn, method) + if err != nil { + streamCancel() + lastErr = err + if status.Code(err) == codes.Unimplemented && idx+1 < len(methods) { + worker.logger.Debugf("exec session: guest agent refused method %s, falling back", method) + continue outer + } + + errWrapped := fmt.Errorf("failed to start exec stream: %w", err) + worker.sendExecErrorFrame(controllerEncoder, errWrapped.Error()) + + return errWrapped + } + + if err := agentStream.Send(commandReq); err != nil { + _ = agentStream.CloseSend() + streamCancel() + lastErr = err + if status.Code(err) == codes.Unimplemented && idx+1 < len(methods) { + worker.logger.Debugf("exec session: guest agent rejected method %s on send, falling back", method) + continue outer + } + + errWrapped := fmt.Errorf("failed to send command to guest agent: %w", err) + worker.sendExecErrorFrame(controllerEncoder, errWrapped.Error()) + + return errWrapped + } + + agentErrCh := make(chan error, 1) + exitCh := make(chan int32, 1) + + go worker.forwardAgentToController(agentStream, controllerEncoder, agentErrCh, exitCh) + + controllerErrCh := make(chan error, 1) + controllerStarted := false + timer := time.NewTimer(200 * time.Millisecond) + + for { + select { + case <-timer.C: + if !controllerStarted { + controllerStarted = true + go worker.forwardControllerToAgent(agentStream, controllerDecoder, controllerErrCh, commandDetails.TTY) + + worker.logger.Infow("exec session started", + "session", opts.Session, + "vm_uid", opts.VMUID, + "command", commandDetails.Name, + "args", commandDetails.Args, + "interactive", commandDetails.Interactive, + "tty", commandDetails.TTY, + "method", method, + ) + } + + case err := <-controllerErrCh: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + streamCancel() + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) { + worker.sendExecErrorFrame(controllerEncoder, fmt.Sprintf("controller stream error: %v", err)) + + return err + } + + return err + + case err := <-agentErrCh: + if err != nil { + if status.Code(err) == codes.Unimplemented && idx+1 < len(methods) { + lastErr = err + streamCancel() + _ = agentStream.CloseSend() + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + if controllerStarted { + select { + case ctrlErr := <-controllerErrCh: + if ctrlErr != nil && !errors.Is(ctrlErr, context.Canceled) && !errors.Is(ctrlErr, io.EOF) { + worker.logger.Debugf("exec session: controller stream closed during fallback: %v", ctrlErr) + } + default: + } + } + worker.logger.Debugf("exec session: guest agent rejected method %s mid-stream, falling back", method) + + continue outer + } + + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + streamCancel() + worker.sendExecErrorFrame(controllerEncoder, fmt.Sprintf("guest agent error: %v", err)) + + return err + } + + case exitCode := <-exitCh: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + streamCancel() + worker.logger.Infow("exec session finished", + "session", opts.Session, + "vm_uid", opts.VMUID, + "exit_code", exitCode, + ) + + return nil + + case <-ctx.Done(): + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + streamCancel() + + return ctx.Err() + } + } + } + + if lastErr != nil { + worker.sendExecErrorFrame(controllerEncoder, fmt.Sprintf("guest agent error: %v", lastErr)) + + return lastErr + } + + unsupportedErr := fmt.Errorf("guest agent exec unsupported") + worker.sendExecErrorFrame(controllerEncoder, unsupportedErr.Error()) + + return unsupportedErr +} + +func (worker *Worker) forwardControllerToAgent( + agentStream guestagentrpc.Agent_ExecClient, + decoder *json.Decoder, + errCh chan<- error, + tty bool, +) { + for { + var frame execstream.Frame + + if err := execstream.ReadFrame(decoder, &frame); err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + _ = agentStream.CloseSend() + } + + errCh <- err + + return + } + + switch frame.Type { + case execstream.FrameTypeStdin: + data := frame.Data + + if tty && len(data) == 0 { + // send EOT + data = []byte{0x04} + } + + if !tty && len(data) == 0 { + if err := agentStream.CloseSend(); err != nil { + errCh <- err + } else { + errCh <- io.EOF + } + + return + } + + if err := agentStream.Send(&guestagentrpc.ExecRequest{ + Type: &guestagentrpc.ExecRequest_StandardInput{ + StandardInput: &guestagentrpc.IOChunk{Data: data}, + }, + }); err != nil { + errCh <- err + + return + } + case execstream.FrameTypeResize: + if !tty || frame.Terminal == nil { + continue + } + + if err := agentStream.Send(&guestagentrpc.ExecRequest{ + Type: &guestagentrpc.ExecRequest_TerminalResize{ + TerminalResize: &guestagentrpc.TerminalSize{ + Rows: frame.Terminal.Rows, + Cols: frame.Terminal.Cols, + }, + }, + }); err != nil { + errCh <- err + + return + } + case execstream.FrameTypeError: + errCh <- fmt.Errorf("controller reported error: %s", frame.Error) + + return + default: + // Ignore unsupported frame types + } + } +} + +func (worker *Worker) forwardAgentToController( + agentStream guestagentrpc.Agent_ExecClient, + encoder *json.Encoder, + errCh chan<- error, + exitCh chan<- int32, +) { + for { + resp, err := agentStream.Recv() + if err != nil { + errCh <- err + + return + } + + switch typed := resp.Type.(type) { + case *guestagentrpc.ExecResponse_StandardOutput: + if err := execstream.WriteFrame(encoder, &execstream.Frame{ + Type: execstream.FrameTypeStdout, + Data: typed.StandardOutput.Data, + }); err != nil { + errCh <- err + + return + } + case *guestagentrpc.ExecResponse_StandardError: + if err := execstream.WriteFrame(encoder, &execstream.Frame{ + Type: execstream.FrameTypeStderr, + Data: typed.StandardError.Data, + }); err != nil { + errCh <- err + + return + } + case *guestagentrpc.ExecResponse_Exit_: + if typed.Exit == nil { + continue + } + + if err := execstream.WriteFrame(encoder, &execstream.Frame{ + Type: execstream.FrameTypeExit, + Exit: &execstream.Exit{Code: typed.Exit.Code}, + }); err != nil { + errCh <- err + + return + } + + exitCh <- typed.Exit.Code + + return + default: + // ignore unknown payloads + } + } +} + +func (worker *Worker) findVMByUID(uid string) (*vmmanager.VM, error) { + vm, ok := lo.Find(worker.vmm.List(), func(item *vmmanager.VM) bool { + return item.Resource.UID == uid + }) + if !ok { + return nil, fmt.Errorf("VM with UID %q not found", uid) + } + + if !vm.Started() { + return nil, fmt.Errorf("VM with UID %q is not running", uid) + } + + return vm, nil +} + +func (worker *Worker) waitForVMByUID( + ctx context.Context, + uid string, + initialErr error, +) (*vmmanager.VM, error) { + worker.requestVMSyncing() + + waitCtx, waitCtxCancel := context.WithTimeout(ctx, 15*time.Second) + defer waitCtxCancel() + + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + + lastErr := initialErr + + for { + vm, err := worker.findVMByUID(uid) + if err == nil { + return vm, nil + } + + if err != nil { + lastErr = err + } + + select { + case <-waitCtx.Done(): + return nil, lastErr + case <-ticker.C: + continue + } + } +} + +func vmControlSocketPath(vm *vmmanager.VM) (string, error) { + if vm == nil { + return "", errors.New("nil VM provided") + } + + tartHome := os.Getenv("TART_HOME") + if tartHome == "" { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to determine user home directory: %w", err) + } + + tartHome = filepath.Join(homeDir, ".tart") + } + + return filepath.Join(tartHome, "vms", vm.OnDiskName().String(), "control.sock"), nil +} + +func (worker *Worker) connectToGuestAgent(ctx context.Context, socketPath string) (*grpc.ClientConn, error) { + waitCtx, waitCancel := context.WithTimeout(ctx, 2*time.Minute) + defer waitCancel() + + backoff := 500 * time.Millisecond + + for { + if _, err := os.Stat(socketPath); err != nil { + if !errors.Is(err, os.ErrNotExist) { + worker.logger.Warnf("exec session: control socket check failed: %v", err) + } + } + + attemptCtx, attemptCancel := context.WithTimeout(waitCtx, 5*time.Second) + + conn, err := grpc.NewClient( + "unix://"+socketPath, + grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) { + var dialer net.Dialer + + return dialer.DialContext(ctx, "unix", socketPath) + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err == nil { + if readyErr := waitForClientConnReady(attemptCtx, conn); readyErr == nil { + attemptCancel() + + return conn, nil + } else { + _ = conn.Close() + err = readyErr + } + } + + attemptCancel() + + if waitCtx.Err() != nil { + return nil, err + } + + worker.logger.Debugf("exec session: guest agent not ready yet: %v", err) + + select { + case <-time.After(backoff): + if backoff < 5*time.Second { + backoff *= 2 + } + case <-waitCtx.Done(): + return nil, waitCtx.Err() + } + } +} + +func (worker *Worker) sendExecErrorFrame(encoder *json.Encoder, message string) { + if encoder == nil { + return + } + + if err := execstream.WriteFrame(encoder, &execstream.Frame{ + Type: execstream.FrameTypeError, + Error: message, + }); err != nil && !errors.Is(err, io.EOF) { + worker.logger.Warnf("exec session: failed to send error frame: %v", err) + } +} + +func cloneTerminalSize(terminal *execstream.TerminalSize) *execstream.TerminalSize { + if terminal == nil { + return nil + } + + return &execstream.TerminalSize{ + Rows: terminal.Rows, + Cols: terminal.Cols, + } +} + +func (worker *Worker) establishAgentExecStream( + ctx context.Context, + conn *grpc.ClientConn, + method string, +) (guestagentrpc.Agent_ExecClient, error) { + stream, err := conn.NewStream(ctx, &guestagentrpc.Agent_ServiceDesc.Streams[0], + method, grpc.StaticMethod()) + if err != nil { + return nil, err + } + + return &grpc.GenericClientStream[guestagentrpc.ExecRequest, guestagentrpc.ExecResponse]{ + ClientStream: stream, + }, nil +} + +func waitForClientConnReady(ctx context.Context, conn *grpc.ClientConn) error { + for { + state := conn.GetState() + if state == connectivity.Ready { + return nil + } + + conn.Connect() + + if !conn.WaitForStateChange(ctx, state) { + if err := ctx.Err(); err != nil { + return err + } + + return fmt.Errorf("gRPC connection state change timed out while waiting for readiness") + } + } +} diff --git a/internal/worker/rpc.go b/internal/worker/rpc.go index d9a81be5..1b04e72a 100644 --- a/internal/worker/rpc.go +++ b/internal/worker/rpc.go @@ -59,6 +59,8 @@ func (worker *Worker) watchRPC(ctx context.Context) error { worker.requestVMSyncing() case *rpc.WatchInstruction_ResolveIpAction: go worker.handleGetIP(ctxWithMetadata, client, action.ResolveIpAction) + case *rpc.WatchInstruction_ExecAction: + go worker.handleExec(ctxWithMetadata, client, action.ExecAction) } } } @@ -184,3 +186,41 @@ func (worker *Worker) handleGetIP( return } } + +func (worker *Worker) handleExec( + ctx context.Context, + client rpc.ControllerClient, + execAction *rpc.WatchInstruction_Exec, +) { + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + grpcMetadata := metadata.Join( + worker.grpcMetadata(), + metadata.Pairs(rpc.MetadataWorkerExecSessionKey, execAction.Session), + ) + ctxWithMetadata := metadata.NewOutgoingContext(subCtx, grpcMetadata) + + stream, err := client.Exec(ctxWithMetadata) + if err != nil { + worker.logger.Warnf("exec failed: failed to call Exec() RPC method: %v", err) + + return + } + + conn := &grpc_net_conn.Conn{ + Stream: stream, + Request: &rpc.ExecData{}, + Response: &rpc.ExecData{}, + Encode: grpc_net_conn.SimpleEncoder(func(message proto.Message) *[]byte { + return &message.(*rpc.ExecData).Data + }), + Decode: grpc_net_conn.SimpleDecoder(func(message proto.Message) *[]byte { + return &message.(*rpc.ExecData).Data + }), + } + + if err := worker.runExecSession(subCtx, execOptionsFromProto(execAction), conn, nil); err != nil { + worker.logger.Warnf("exec session failed: %v", err) + } +} diff --git a/internal/worker/rpcv2.go b/internal/worker/rpcv2.go index bed413dc..3c6481f2 100644 --- a/internal/worker/rpcv2.go +++ b/internal/worker/rpcv2.go @@ -25,6 +25,8 @@ func (worker *Worker) watchRPCV2(ctx context.Context) error { worker.requestVMSyncing() } else if resolveIPAction := watchInstruction.ResolveIPAction; resolveIPAction != nil { go worker.handleGetIPV2(ctx, resolveIPAction) + } else if execAction := watchInstruction.ExecAction; execAction != nil { + go worker.handleExecV2(ctx, execAction) } case watchErr := <-watchErrCh: return watchErr @@ -157,3 +159,37 @@ func (worker *Worker) handleGetIPV2Inner( return ip, nil } + +func (worker *Worker) handleExecV2(ctx context.Context, execAction *v1.ExecAction) { + var errorMessage string + + vm, err := worker.findVMByUID(execAction.VMUID) + if err != nil { + worker.logger.Infof("exec session: VM %s not immediately available, retrying after syncing: %v", + execAction.VMUID, err) + + vm, err = worker.waitForVMByUID(ctx, execAction.VMUID, err) + if err != nil { + errorMessage = err.Error() + } + } + + if errorMessage != "" { + if _, err := worker.client.RPC().RespondExec(ctx, execAction.Session, errorMessage); err != nil { + worker.logger.Warnf("exec failed: failed to call API: %v", err) + } + + return + } + + conn, err := worker.client.RPC().RespondExec(ctx, execAction.Session, "") + if err != nil { + worker.logger.Warnf("exec failed: failed to call API: %v", err) + + return + } + + if err := worker.runExecSession(ctx, execOptionsFromV1(execAction), conn, vm); err != nil { + worker.logger.Warnf("exec session failed: %v", err) + } +} diff --git a/internal/worker/worker.go b/internal/worker/worker.go index f0199208..70ddb141 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -230,6 +230,7 @@ func (worker *Worker) registerWorker(ctx context.Context) error { MachineID: platformUUID, DefaultCPU: worker.defaultCPU, DefaultMemory: worker.defaultMemory, + Capabilities: v1.WorkerCapabilities{v1.WorkerCapabilityExec}, }) if err != nil { return err diff --git a/pkg/client/client.go b/pkg/client/client.go index 6406d682..42c03f55 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -254,6 +254,21 @@ func (client *Client) wsRequest( ctx context.Context, path string, params map[string]string, +) (net.Conn, error) { + values := url.Values{} + + for key, value := range params { + values.Set(key, value) + } + + return client.wsRequestValues(ctx, path, values, websocket.MessageBinary) +} + +func (client *Client) wsRequestValues( + ctx context.Context, + path string, + params url.Values, + messageType websocket.MessageType, ) (net.Conn, error) { endpointURL := client.formatPath(path) @@ -265,8 +280,10 @@ func (client *Client) wsRequest( } values := endpointURL.Query() - for key, value := range params { - values.Set(key, value) + for key, valuesSlice := range params { + for _, value := range valuesSlice { + values.Add(key, value) + } } endpointURL.RawQuery = values.Encode() @@ -290,7 +307,7 @@ func (client *Client) wsRequest( return nil, err } - return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil + return websocket.NetConn(ctx, conn, messageType), nil } func (client *Client) formatPath(path string) *url.URL { diff --git a/pkg/client/rpc.go b/pkg/client/rpc.go index fa587d3c..0d20862d 100644 --- a/pkg/client/rpc.go +++ b/pkg/client/rpc.go @@ -60,6 +60,17 @@ func (service *RPCService) RespondPortForward( }) } +func (service *RPCService) RespondExec( + ctx context.Context, + session string, + errorMessage string, +) (net.Conn, error) { + return service.client.wsRequest(ctx, "rpc/exec", map[string]string{ + "session": session, + "errorMessage": errorMessage, + }) +} + func (service *RPCService) RespondIP( ctx context.Context, session string, diff --git a/pkg/client/vms.go b/pkg/client/vms.go index 372ae4cc..c75a78b6 100644 --- a/pkg/client/vms.go +++ b/pkg/client/vms.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/cirruslabs/orchard/pkg/resource/v1" + "github.com/coder/websocket" "net" "net/http" "net/url" @@ -121,6 +122,44 @@ func (service *VMsService) StreamEvents(name string) *EventStreamer { return NewEventStreamer(service.client, fmt.Sprintf("vms/%s/events", url.PathEscape(name))) } +func (service *VMsService) Exec( + ctx context.Context, + name string, + command []string, + interactive bool, + tty bool, + rows uint32, + cols uint32, + waitSeconds uint16, +) (net.Conn, error) { + if len(command) == 0 { + return nil, fmt.Errorf("command must contain at least one element") + } + + params := url.Values{} + params.Set("command", command[0]) + + for _, arg := range command[1:] { + params.Add("arg", arg) + } + + params.Set("interactive", strconv.FormatBool(interactive)) + params.Set("tty", strconv.FormatBool(tty)) + + if rows > 0 { + params.Set("rows", strconv.FormatUint(uint64(rows), 10)) + } + if cols > 0 { + params.Set("cols", strconv.FormatUint(uint64(cols), 10)) + } + if waitSeconds > 0 { + params.Set("wait", strconv.FormatUint(uint64(waitSeconds), 10)) + } + + return service.client.wsRequestValues(ctx, fmt.Sprintf("vms/%s/exec", url.PathEscape(name)), + params, websocket.MessageText) +} + func (service *VMsService) Logs(ctx context.Context, name string) (lines []string, err error) { var events []v1.Event err = service.client.request(ctx, http.MethodGet, fmt.Sprintf("vms/%s/events", url.PathEscape(name)), diff --git a/pkg/resource/v1/v1.go b/pkg/resource/v1/v1.go index 8164010b..78bf93f1 100644 --- a/pkg/resource/v1/v1.go +++ b/pkg/resource/v1/v1.go @@ -130,6 +130,7 @@ type ControllerCapability string const ( ControllerCapabilityRPCV1 ControllerCapability = "rpc-v1" ControllerCapabilityRPCV2 ControllerCapability = "rpc-v2" + ControllerCapabilityExec ControllerCapability = "exec" ) type ControllerCapabilities []ControllerCapability diff --git a/pkg/resource/v1/watch_instruction.go b/pkg/resource/v1/watch_instruction.go index f7b08fd8..2627a089 100644 --- a/pkg/resource/v1/watch_instruction.go +++ b/pkg/resource/v1/watch_instruction.go @@ -4,6 +4,7 @@ type WatchInstruction struct { PortForwardAction *PortForwardAction `json:"portForwardAction,omitempty"` SyncVMsAction *SyncVMsAction `json:"syncVMsAction,omitempty"` ResolveIPAction *ResolveIPAction `json:"resolveIPAction,omitempty"` + ExecAction *ExecAction `json:"execAction,omitempty"` } type PortForwardAction struct { @@ -20,3 +21,18 @@ type ResolveIPAction struct { Session string `json:"session"` VMUID string `json:"vmUID"` } + +type ExecAction struct { + Session string `json:"session"` + VMUID string `json:"vmUID"` + Command string `json:"command"` + Args []string `json:"args"` + Interactive bool `json:"interactive"` + TTY bool `json:"tty"` + Terminal *TerminalSize `json:"terminal,omitempty"` +} + +type TerminalSize struct { + Rows uint32 `json:"rows"` + Cols uint32 `json:"cols"` +} diff --git a/pkg/resource/v1/worker.go b/pkg/resource/v1/worker.go index b0a3b678..99dd6237 100644 --- a/pkg/resource/v1/worker.go +++ b/pkg/resource/v1/worker.go @@ -24,9 +24,29 @@ type Worker struct { // when it doesn't explicitly request a specific amount. DefaultMemory uint64 `json:"defaultMemory,omitempty"` + Capabilities WorkerCapabilities `json:"capabilities,omitempty"` + Meta } func (worker Worker) Offline(workerOfflineTimeout time.Duration) bool { return time.Since(worker.LastSeen) > workerOfflineTimeout } + +type WorkerCapability string + +const ( + WorkerCapabilityExec WorkerCapability = "exec" +) + +type WorkerCapabilities []WorkerCapability + +func (workerCapabilities WorkerCapabilities) Has(capability WorkerCapability) bool { + for _, workerCapability := range workerCapabilities { + if workerCapability == capability { + return true + } + } + + return false +} diff --git a/rpc/constants.go b/rpc/constants.go index f9cf3be2..ccff29c3 100644 --- a/rpc/constants.go +++ b/rpc/constants.go @@ -8,3 +8,5 @@ const MetadataServiceAccountTokenKey = "x-orchard-service-account-token" const MetadataWorkerNameKey = "x-orchard-worker-name" const MetadataWorkerPortForwardingSessionKey = "x-orchard-port-forwarding-session" + +const MetadataWorkerExecSessionKey = "x-orchard-exec-session" diff --git a/rpc/guestagent/guestagent.pb.go b/rpc/guestagent/guestagent.pb.go new file mode 100644 index 00000000..3647ae29 --- /dev/null +++ b/rpc/guestagent/guestagent.pb.go @@ -0,0 +1,626 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.9 +// protoc (unknown) +// source: guestagent/guestagent.proto + +package guestagent + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ExecRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Type: + // + // *ExecRequest_Command_ + // *ExecRequest_StandardInput + // *ExecRequest_TerminalResize + Type isExecRequest_Type `protobuf_oneof:"type"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecRequest) Reset() { + *x = ExecRequest{} + mi := &file_guestagent_guestagent_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecRequest) ProtoMessage() {} + +func (x *ExecRequest) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecRequest.ProtoReflect.Descriptor instead. +func (*ExecRequest) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{0} +} + +func (x *ExecRequest) GetType() isExecRequest_Type { + if x != nil { + return x.Type + } + return nil +} + +func (x *ExecRequest) GetCommand() *ExecRequest_Command { + if x != nil { + if x, ok := x.Type.(*ExecRequest_Command_); ok { + return x.Command + } + } + return nil +} + +func (x *ExecRequest) GetStandardInput() *IOChunk { + if x != nil { + if x, ok := x.Type.(*ExecRequest_StandardInput); ok { + return x.StandardInput + } + } + return nil +} + +func (x *ExecRequest) GetTerminalResize() *TerminalSize { + if x != nil { + if x, ok := x.Type.(*ExecRequest_TerminalResize); ok { + return x.TerminalResize + } + } + return nil +} + +type isExecRequest_Type interface { + isExecRequest_Type() +} + +type ExecRequest_Command_ struct { + Command *ExecRequest_Command `protobuf:"bytes,1,opt,name=command,proto3,oneof"` +} + +type ExecRequest_StandardInput struct { + StandardInput *IOChunk `protobuf:"bytes,2,opt,name=standard_input,json=standardInput,proto3,oneof"` +} + +type ExecRequest_TerminalResize struct { + TerminalResize *TerminalSize `protobuf:"bytes,3,opt,name=terminal_resize,json=terminalResize,proto3,oneof"` +} + +func (*ExecRequest_Command_) isExecRequest_Type() {} + +func (*ExecRequest_StandardInput) isExecRequest_Type() {} + +func (*ExecRequest_TerminalResize) isExecRequest_Type() {} + +type ExecResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Type: + // + // *ExecResponse_Exit_ + // *ExecResponse_StandardOutput + // *ExecResponse_StandardError + Type isExecResponse_Type `protobuf_oneof:"type"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecResponse) Reset() { + *x = ExecResponse{} + mi := &file_guestagent_guestagent_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecResponse) ProtoMessage() {} + +func (x *ExecResponse) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecResponse.ProtoReflect.Descriptor instead. +func (*ExecResponse) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{1} +} + +func (x *ExecResponse) GetType() isExecResponse_Type { + if x != nil { + return x.Type + } + return nil +} + +func (x *ExecResponse) GetExit() *ExecResponse_Exit { + if x != nil { + if x, ok := x.Type.(*ExecResponse_Exit_); ok { + return x.Exit + } + } + return nil +} + +func (x *ExecResponse) GetStandardOutput() *IOChunk { + if x != nil { + if x, ok := x.Type.(*ExecResponse_StandardOutput); ok { + return x.StandardOutput + } + } + return nil +} + +func (x *ExecResponse) GetStandardError() *IOChunk { + if x != nil { + if x, ok := x.Type.(*ExecResponse_StandardError); ok { + return x.StandardError + } + } + return nil +} + +type isExecResponse_Type interface { + isExecResponse_Type() +} + +type ExecResponse_Exit_ struct { + Exit *ExecResponse_Exit `protobuf:"bytes,1,opt,name=exit,proto3,oneof"` +} + +type ExecResponse_StandardOutput struct { + StandardOutput *IOChunk `protobuf:"bytes,2,opt,name=standard_output,json=standardOutput,proto3,oneof"` +} + +type ExecResponse_StandardError struct { + StandardError *IOChunk `protobuf:"bytes,3,opt,name=standard_error,json=standardError,proto3,oneof"` +} + +func (*ExecResponse_Exit_) isExecResponse_Type() {} + +func (*ExecResponse_StandardOutput) isExecResponse_Type() {} + +func (*ExecResponse_StandardError) isExecResponse_Type() {} + +type TerminalSize struct { + state protoimpl.MessageState `protogen:"open.v1"` + Rows uint32 `protobuf:"varint,1,opt,name=rows,proto3" json:"rows,omitempty"` + Cols uint32 `protobuf:"varint,2,opt,name=cols,proto3" json:"cols,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TerminalSize) Reset() { + *x = TerminalSize{} + mi := &file_guestagent_guestagent_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TerminalSize) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TerminalSize) ProtoMessage() {} + +func (x *TerminalSize) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TerminalSize.ProtoReflect.Descriptor instead. +func (*TerminalSize) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{2} +} + +func (x *TerminalSize) GetRows() uint32 { + if x != nil { + return x.Rows + } + return 0 +} + +func (x *TerminalSize) GetCols() uint32 { + if x != nil { + return x.Cols + } + return 0 +} + +type IOChunk struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IOChunk) Reset() { + *x = IOChunk{} + mi := &file_guestagent_guestagent_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IOChunk) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IOChunk) ProtoMessage() {} + +func (x *IOChunk) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IOChunk.ProtoReflect.Descriptor instead. +func (*IOChunk) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{3} +} + +func (x *IOChunk) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type ResolveIPRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ResolveIPRequest) Reset() { + *x = ResolveIPRequest{} + mi := &file_guestagent_guestagent_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ResolveIPRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ResolveIPRequest) ProtoMessage() {} + +func (x *ResolveIPRequest) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ResolveIPRequest.ProtoReflect.Descriptor instead. +func (*ResolveIPRequest) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{4} +} + +type ResolveIPResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ResolveIPResponse) Reset() { + *x = ResolveIPResponse{} + mi := &file_guestagent_guestagent_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ResolveIPResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ResolveIPResponse) ProtoMessage() {} + +func (x *ResolveIPResponse) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ResolveIPResponse.ProtoReflect.Descriptor instead. +func (*ResolveIPResponse) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{5} +} + +func (x *ResolveIPResponse) GetIp() string { + if x != nil { + return x.Ip + } + return "" +} + +type ExecRequest_Command struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Args []string `protobuf:"bytes,2,rep,name=args,proto3" json:"args,omitempty"` + Interactive bool `protobuf:"varint,3,opt,name=interactive,proto3" json:"interactive,omitempty"` + Tty bool `protobuf:"varint,4,opt,name=tty,proto3" json:"tty,omitempty"` + TerminalSize *TerminalSize `protobuf:"bytes,5,opt,name=terminal_size,json=terminalSize,proto3" json:"terminal_size,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecRequest_Command) Reset() { + *x = ExecRequest_Command{} + mi := &file_guestagent_guestagent_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecRequest_Command) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecRequest_Command) ProtoMessage() {} + +func (x *ExecRequest_Command) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecRequest_Command.ProtoReflect.Descriptor instead. +func (*ExecRequest_Command) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{0, 0} +} + +func (x *ExecRequest_Command) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *ExecRequest_Command) GetArgs() []string { + if x != nil { + return x.Args + } + return nil +} + +func (x *ExecRequest_Command) GetInteractive() bool { + if x != nil { + return x.Interactive + } + return false +} + +func (x *ExecRequest_Command) GetTty() bool { + if x != nil { + return x.Tty + } + return false +} + +func (x *ExecRequest_Command) GetTerminalSize() *TerminalSize { + if x != nil { + return x.TerminalSize + } + return nil +} + +type ExecResponse_Exit struct { + state protoimpl.MessageState `protogen:"open.v1"` + Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ExecResponse_Exit) Reset() { + *x = ExecResponse_Exit{} + mi := &file_guestagent_guestagent_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ExecResponse_Exit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecResponse_Exit) ProtoMessage() {} + +func (x *ExecResponse_Exit) ProtoReflect() protoreflect.Message { + mi := &file_guestagent_guestagent_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecResponse_Exit.ProtoReflect.Descriptor instead. +func (*ExecResponse_Exit) Descriptor() ([]byte, []int) { + return file_guestagent_guestagent_proto_rawDescGZIP(), []int{1, 0} +} + +func (x *ExecResponse_Exit) GetCode() int32 { + if x != nil { + return x.Code + } + return 0 +} + +var File_guestagent_guestagent_proto protoreflect.FileDescriptor + +const file_guestagent_guestagent_proto_rawDesc = "" + + "\n" + + "\x1bguestagent/guestagent.proto\x12\n" + + "guestagent\"\xfc\x02\n" + + "\vExecRequest\x12;\n" + + "\acommand\x18\x01 \x01(\v2\x1f.guestagent.ExecRequest.CommandH\x00R\acommand\x12<\n" + + "\x0estandard_input\x18\x02 \x01(\v2\x13.guestagent.IOChunkH\x00R\rstandardInput\x12C\n" + + "\x0fterminal_resize\x18\x03 \x01(\v2\x18.guestagent.TerminalSizeH\x00R\x0eterminalResize\x1a\xa4\x01\n" + + "\aCommand\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n" + + "\x04args\x18\x02 \x03(\tR\x04args\x12 \n" + + "\vinteractive\x18\x03 \x01(\bR\vinteractive\x12\x10\n" + + "\x03tty\x18\x04 \x01(\bR\x03tty\x12=\n" + + "\rterminal_size\x18\x05 \x01(\v2\x18.guestagent.TerminalSizeR\fterminalSizeB\x06\n" + + "\x04type\"\xe5\x01\n" + + "\fExecResponse\x123\n" + + "\x04exit\x18\x01 \x01(\v2\x1d.guestagent.ExecResponse.ExitH\x00R\x04exit\x12>\n" + + "\x0fstandard_output\x18\x02 \x01(\v2\x13.guestagent.IOChunkH\x00R\x0estandardOutput\x12<\n" + + "\x0estandard_error\x18\x03 \x01(\v2\x13.guestagent.IOChunkH\x00R\rstandardError\x1a\x1a\n" + + "\x04Exit\x12\x12\n" + + "\x04code\x18\x01 \x01(\x05R\x04codeB\x06\n" + + "\x04type\"6\n" + + "\fTerminalSize\x12\x12\n" + + "\x04rows\x18\x01 \x01(\rR\x04rows\x12\x12\n" + + "\x04cols\x18\x02 \x01(\rR\x04cols\"\x1d\n" + + "\aIOChunk\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\"\x12\n" + + "\x10ResolveIPRequest\"#\n" + + "\x11ResolveIPResponse\x12\x0e\n" + + "\x02ip\x18\x01 \x01(\tR\x02ip2\x90\x01\n" + + "\x05Agent\x12=\n" + + "\x04Exec\x12\x17.guestagent.ExecRequest\x1a\x18.guestagent.ExecResponse(\x010\x01\x12H\n" + + "\tResolveIP\x12\x1c.guestagent.ResolveIPRequest\x1a\x1d.guestagent.ResolveIPResponseB.Z,github.com/cirruslabs/orchard/rpc/guestagentb\x06proto3" + +var ( + file_guestagent_guestagent_proto_rawDescOnce sync.Once + file_guestagent_guestagent_proto_rawDescData []byte +) + +func file_guestagent_guestagent_proto_rawDescGZIP() []byte { + file_guestagent_guestagent_proto_rawDescOnce.Do(func() { + file_guestagent_guestagent_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_guestagent_guestagent_proto_rawDesc), len(file_guestagent_guestagent_proto_rawDesc))) + }) + return file_guestagent_guestagent_proto_rawDescData +} + +var file_guestagent_guestagent_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_guestagent_guestagent_proto_goTypes = []any{ + (*ExecRequest)(nil), // 0: guestagent.ExecRequest + (*ExecResponse)(nil), // 1: guestagent.ExecResponse + (*TerminalSize)(nil), // 2: guestagent.TerminalSize + (*IOChunk)(nil), // 3: guestagent.IOChunk + (*ResolveIPRequest)(nil), // 4: guestagent.ResolveIPRequest + (*ResolveIPResponse)(nil), // 5: guestagent.ResolveIPResponse + (*ExecRequest_Command)(nil), // 6: guestagent.ExecRequest.Command + (*ExecResponse_Exit)(nil), // 7: guestagent.ExecResponse.Exit +} +var file_guestagent_guestagent_proto_depIdxs = []int32{ + 6, // 0: guestagent.ExecRequest.command:type_name -> guestagent.ExecRequest.Command + 3, // 1: guestagent.ExecRequest.standard_input:type_name -> guestagent.IOChunk + 2, // 2: guestagent.ExecRequest.terminal_resize:type_name -> guestagent.TerminalSize + 7, // 3: guestagent.ExecResponse.exit:type_name -> guestagent.ExecResponse.Exit + 3, // 4: guestagent.ExecResponse.standard_output:type_name -> guestagent.IOChunk + 3, // 5: guestagent.ExecResponse.standard_error:type_name -> guestagent.IOChunk + 2, // 6: guestagent.ExecRequest.Command.terminal_size:type_name -> guestagent.TerminalSize + 0, // 7: guestagent.Agent.Exec:input_type -> guestagent.ExecRequest + 4, // 8: guestagent.Agent.ResolveIP:input_type -> guestagent.ResolveIPRequest + 1, // 9: guestagent.Agent.Exec:output_type -> guestagent.ExecResponse + 5, // 10: guestagent.Agent.ResolveIP:output_type -> guestagent.ResolveIPResponse + 9, // [9:11] is the sub-list for method output_type + 7, // [7:9] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name +} + +func init() { file_guestagent_guestagent_proto_init() } +func file_guestagent_guestagent_proto_init() { + if File_guestagent_guestagent_proto != nil { + return + } + file_guestagent_guestagent_proto_msgTypes[0].OneofWrappers = []any{ + (*ExecRequest_Command_)(nil), + (*ExecRequest_StandardInput)(nil), + (*ExecRequest_TerminalResize)(nil), + } + file_guestagent_guestagent_proto_msgTypes[1].OneofWrappers = []any{ + (*ExecResponse_Exit_)(nil), + (*ExecResponse_StandardOutput)(nil), + (*ExecResponse_StandardError)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_guestagent_guestagent_proto_rawDesc), len(file_guestagent_guestagent_proto_rawDesc)), + NumEnums: 0, + NumMessages: 8, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_guestagent_guestagent_proto_goTypes, + DependencyIndexes: file_guestagent_guestagent_proto_depIdxs, + MessageInfos: file_guestagent_guestagent_proto_msgTypes, + }.Build() + File_guestagent_guestagent_proto = out.File + file_guestagent_guestagent_proto_goTypes = nil + file_guestagent_guestagent_proto_depIdxs = nil +} diff --git a/rpc/guestagent/guestagent.proto b/rpc/guestagent/guestagent.proto new file mode 100644 index 00000000..43ac60e1 --- /dev/null +++ b/rpc/guestagent/guestagent.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; + +package guestagent; + +option go_package = "github.com/cirruslabs/orchard/rpc/guestagent"; + +service Agent { + rpc Exec(stream ExecRequest) returns (stream ExecResponse); + rpc ResolveIP(ResolveIPRequest) returns (ResolveIPResponse); +} + +message ExecRequest { + message Command { + string name = 1; + repeated string args = 2; + bool interactive = 3; + bool tty = 4; + TerminalSize terminal_size = 5; + } + + oneof type { + Command command = 1; + IOChunk standard_input = 2; + TerminalSize terminal_resize = 3; + } +} + +message ExecResponse { + message Exit { + int32 code = 1; + } + + oneof type { + Exit exit = 1; + IOChunk standard_output = 2; + IOChunk standard_error = 3; + } +} + +message TerminalSize { + uint32 rows = 1; + uint32 cols = 2; +} + +message IOChunk { + bytes data = 1; +} + +message ResolveIPRequest { + // nothing for now +} + +message ResolveIPResponse { + string ip = 1; +} diff --git a/rpc/guestagent/guestagent_grpc.pb.go b/rpc/guestagent/guestagent_grpc.pb.go new file mode 100644 index 00000000..bb9641a5 --- /dev/null +++ b/rpc/guestagent/guestagent_grpc.pb.go @@ -0,0 +1,154 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: guestagent/guestagent.proto + +package guestagent + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Agent_Exec_FullMethodName = "/guestagent.Agent/Exec" + Agent_ResolveIP_FullMethodName = "/guestagent.Agent/ResolveIP" +) + +// AgentClient is the client API for Agent service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type AgentClient interface { + Exec(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ExecRequest, ExecResponse], error) + ResolveIP(ctx context.Context, in *ResolveIPRequest, opts ...grpc.CallOption) (*ResolveIPResponse, error) +} + +type agentClient struct { + cc grpc.ClientConnInterface +} + +func NewAgentClient(cc grpc.ClientConnInterface) AgentClient { + return &agentClient{cc} +} + +func (c *agentClient) Exec(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ExecRequest, ExecResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &Agent_ServiceDesc.Streams[0], Agent_Exec_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[ExecRequest, ExecResponse]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Agent_ExecClient = grpc.BidiStreamingClient[ExecRequest, ExecResponse] + +func (c *agentClient) ResolveIP(ctx context.Context, in *ResolveIPRequest, opts ...grpc.CallOption) (*ResolveIPResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ResolveIPResponse) + err := c.cc.Invoke(ctx, Agent_ResolveIP_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// AgentServer is the server API for Agent service. +// All implementations must embed UnimplementedAgentServer +// for forward compatibility. +type AgentServer interface { + Exec(grpc.BidiStreamingServer[ExecRequest, ExecResponse]) error + ResolveIP(context.Context, *ResolveIPRequest) (*ResolveIPResponse, error) + mustEmbedUnimplementedAgentServer() +} + +// UnimplementedAgentServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedAgentServer struct{} + +func (UnimplementedAgentServer) Exec(grpc.BidiStreamingServer[ExecRequest, ExecResponse]) error { + return status.Errorf(codes.Unimplemented, "method Exec not implemented") +} +func (UnimplementedAgentServer) ResolveIP(context.Context, *ResolveIPRequest) (*ResolveIPResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ResolveIP not implemented") +} +func (UnimplementedAgentServer) mustEmbedUnimplementedAgentServer() {} +func (UnimplementedAgentServer) testEmbeddedByValue() {} + +// UnsafeAgentServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to AgentServer will +// result in compilation errors. +type UnsafeAgentServer interface { + mustEmbedUnimplementedAgentServer() +} + +func RegisterAgentServer(s grpc.ServiceRegistrar, srv AgentServer) { + // If the following call pancis, it indicates UnimplementedAgentServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Agent_ServiceDesc, srv) +} + +func _Agent_Exec_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(AgentServer).Exec(&grpc.GenericServerStream[ExecRequest, ExecResponse]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Agent_ExecServer = grpc.BidiStreamingServer[ExecRequest, ExecResponse] + +func _Agent_ResolveIP_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ResolveIPRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AgentServer).ResolveIP(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Agent_ResolveIP_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AgentServer).ResolveIP(ctx, req.(*ResolveIPRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Agent_ServiceDesc is the grpc.ServiceDesc for Agent service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Agent_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "guestagent.Agent", + HandlerType: (*AgentServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ResolveIP", + Handler: _Agent_ResolveIP_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Exec", + Handler: _Agent_Exec_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "guestagent/guestagent.proto", +} diff --git a/rpc/orchard.pb.go b/rpc/orchard.pb.go index 6f7cbef0..61e87fb1 100644 --- a/rpc/orchard.pb.go +++ b/rpc/orchard.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.34.2 +// protoc-gen-go v1.36.9 // protoc (unknown) // source: orchard.proto @@ -12,6 +12,7 @@ import ( emptypb "google.golang.org/protobuf/types/known/emptypb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -22,25 +23,23 @@ const ( ) type WatchInstruction struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Types that are assignable to Action: + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Action: // // *WatchInstruction_PortForwardAction // *WatchInstruction_SyncVmsAction // *WatchInstruction_ResolveIpAction - Action isWatchInstruction_Action `protobuf_oneof:"action"` + // *WatchInstruction_ExecAction + Action isWatchInstruction_Action `protobuf_oneof:"action"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *WatchInstruction) Reset() { *x = WatchInstruction{} - if protoimpl.UnsafeEnabled { - mi := &file_orchard_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_orchard_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *WatchInstruction) String() string { @@ -51,7 +50,7 @@ func (*WatchInstruction) ProtoMessage() {} func (x *WatchInstruction) ProtoReflect() protoreflect.Message { mi := &file_orchard_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -66,30 +65,45 @@ func (*WatchInstruction) Descriptor() ([]byte, []int) { return file_orchard_proto_rawDescGZIP(), []int{0} } -func (m *WatchInstruction) GetAction() isWatchInstruction_Action { - if m != nil { - return m.Action +func (x *WatchInstruction) GetAction() isWatchInstruction_Action { + if x != nil { + return x.Action } return nil } func (x *WatchInstruction) GetPortForwardAction() *WatchInstruction_PortForward { - if x, ok := x.GetAction().(*WatchInstruction_PortForwardAction); ok { - return x.PortForwardAction + if x != nil { + if x, ok := x.Action.(*WatchInstruction_PortForwardAction); ok { + return x.PortForwardAction + } } return nil } func (x *WatchInstruction) GetSyncVmsAction() *WatchInstruction_SyncVMs { - if x, ok := x.GetAction().(*WatchInstruction_SyncVmsAction); ok { - return x.SyncVmsAction + if x != nil { + if x, ok := x.Action.(*WatchInstruction_SyncVmsAction); ok { + return x.SyncVmsAction + } } return nil } func (x *WatchInstruction) GetResolveIpAction() *WatchInstruction_ResolveIP { - if x, ok := x.GetAction().(*WatchInstruction_ResolveIpAction); ok { - return x.ResolveIpAction + if x != nil { + if x, ok := x.Action.(*WatchInstruction_ResolveIpAction); ok { + return x.ResolveIpAction + } + } + return nil +} + +func (x *WatchInstruction) GetExecAction() *WatchInstruction_Exec { + if x != nil { + if x, ok := x.Action.(*WatchInstruction_ExecAction); ok { + return x.ExecAction + } } return nil } @@ -110,27 +124,30 @@ type WatchInstruction_ResolveIpAction struct { ResolveIpAction *WatchInstruction_ResolveIP `protobuf:"bytes,3,opt,name=resolve_ip_action,json=resolveIpAction,proto3,oneof"` } +type WatchInstruction_ExecAction struct { + ExecAction *WatchInstruction_Exec `protobuf:"bytes,4,opt,name=exec_action,json=execAction,proto3,oneof"` +} + func (*WatchInstruction_PortForwardAction) isWatchInstruction_Action() {} func (*WatchInstruction_SyncVmsAction) isWatchInstruction_Action() {} func (*WatchInstruction_ResolveIpAction) isWatchInstruction_Action() {} +func (*WatchInstruction_ExecAction) isWatchInstruction_Action() {} + type PortForwardData struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields - - Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + sizeCache protoimpl.SizeCache } func (x *PortForwardData) Reset() { *x = PortForwardData{} - if protoimpl.UnsafeEnabled { - mi := &file_orchard_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_orchard_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *PortForwardData) String() string { @@ -141,7 +158,7 @@ func (*PortForwardData) ProtoMessage() {} func (x *PortForwardData) ProtoReflect() protoreflect.Message { mi := &file_orchard_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -163,22 +180,63 @@ func (x *PortForwardData) GetData() []byte { return nil } -type ResolveIPResult struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache +type ExecData struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} - Session string `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` - Ip string `protobuf:"bytes,2,opt,name=ip,proto3" json:"ip,omitempty"` +func (x *ExecData) Reset() { + *x = ExecData{} + mi := &file_orchard_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } -func (x *ResolveIPResult) Reset() { - *x = ResolveIPResult{} - if protoimpl.UnsafeEnabled { - mi := &file_orchard_proto_msgTypes[2] +func (x *ExecData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExecData) ProtoMessage() {} + +func (x *ExecData) ProtoReflect() protoreflect.Message { + mi := &file_orchard_proto_msgTypes[2] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms } + return mi.MessageOf(x) +} + +// Deprecated: Use ExecData.ProtoReflect.Descriptor instead. +func (*ExecData) Descriptor() ([]byte, []int) { + return file_orchard_proto_rawDescGZIP(), []int{2} +} + +func (x *ExecData) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type ResolveIPResult struct { + state protoimpl.MessageState `protogen:"open.v1"` + Session string `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` + Ip string `protobuf:"bytes,2,opt,name=ip,proto3" json:"ip,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ResolveIPResult) Reset() { + *x = ResolveIPResult{} + mi := &file_orchard_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ResolveIPResult) String() string { @@ -188,8 +246,8 @@ func (x *ResolveIPResult) String() string { func (*ResolveIPResult) ProtoMessage() {} func (x *ResolveIPResult) ProtoReflect() protoreflect.Message { - mi := &file_orchard_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_orchard_proto_msgTypes[3] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -201,7 +259,7 @@ func (x *ResolveIPResult) ProtoReflect() protoreflect.Message { // Deprecated: Use ResolveIPResult.ProtoReflect.Descriptor instead. func (*ResolveIPResult) Descriptor() ([]byte, []int) { - return file_orchard_proto_rawDescGZIP(), []int{2} + return file_orchard_proto_rawDescGZIP(), []int{3} } func (x *ResolveIPResult) GetSession() string { @@ -219,25 +277,22 @@ func (x *ResolveIPResult) GetIp() string { } type WatchInstruction_PortForward struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` // we can have multiple port forwards for the same vm/port pair // let's distinguish them by a unique session Session string `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` // can be empty to request port-forwarding to the worker itself - VmUid string `protobuf:"bytes,2,opt,name=vm_uid,json=vmUid,proto3" json:"vm_uid,omitempty"` - Port uint32 `protobuf:"varint,3,opt,name=port,proto3" json:"port,omitempty"` + VmUid string `protobuf:"bytes,2,opt,name=vm_uid,json=vmUid,proto3" json:"vm_uid,omitempty"` + Port uint32 `protobuf:"varint,3,opt,name=port,proto3" json:"port,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *WatchInstruction_PortForward) Reset() { *x = WatchInstruction_PortForward{} - if protoimpl.UnsafeEnabled { - mi := &file_orchard_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_orchard_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *WatchInstruction_PortForward) String() string { @@ -247,8 +302,8 @@ func (x *WatchInstruction_PortForward) String() string { func (*WatchInstruction_PortForward) ProtoMessage() {} func (x *WatchInstruction_PortForward) ProtoReflect() protoreflect.Message { - mi := &file_orchard_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_orchard_proto_msgTypes[4] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -285,18 +340,16 @@ func (x *WatchInstruction_PortForward) GetPort() uint32 { } type WatchInstruction_SyncVMs struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *WatchInstruction_SyncVMs) Reset() { *x = WatchInstruction_SyncVMs{} - if protoimpl.UnsafeEnabled { - mi := &file_orchard_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_orchard_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *WatchInstruction_SyncVMs) String() string { @@ -306,8 +359,8 @@ func (x *WatchInstruction_SyncVMs) String() string { func (*WatchInstruction_SyncVMs) ProtoMessage() {} func (x *WatchInstruction_SyncVMs) ProtoReflect() protoreflect.Message { - mi := &file_orchard_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_orchard_proto_msgTypes[5] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -323,23 +376,20 @@ func (*WatchInstruction_SyncVMs) Descriptor() ([]byte, []int) { } type WatchInstruction_ResolveIP struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` // we can have multiple IP resolution requests for the same vm // let's distinguish them by a unique session - Session string `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` - VmUid string `protobuf:"bytes,2,opt,name=vm_uid,json=vmUid,proto3" json:"vm_uid,omitempty"` + Session string `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` + VmUid string `protobuf:"bytes,2,opt,name=vm_uid,json=vmUid,proto3" json:"vm_uid,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *WatchInstruction_ResolveIP) Reset() { *x = WatchInstruction_ResolveIP{} - if protoimpl.UnsafeEnabled { - mi := &file_orchard_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_orchard_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *WatchInstruction_ResolveIP) String() string { @@ -349,8 +399,8 @@ func (x *WatchInstruction_ResolveIP) String() string { func (*WatchInstruction_ResolveIP) ProtoMessage() {} func (x *WatchInstruction_ResolveIP) ProtoReflect() protoreflect.Message { - mi := &file_orchard_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_orchard_proto_msgTypes[6] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -379,97 +429,241 @@ func (x *WatchInstruction_ResolveIP) GetVmUid() string { return "" } -var File_orchard_proto protoreflect.FileDescriptor +type WatchInstruction_Exec struct { + state protoimpl.MessageState `protogen:"open.v1"` + // we can have multiple exec requests for the same vm + // so use a session identifier + Session string `protobuf:"bytes,1,opt,name=session,proto3" json:"session,omitempty"` + VmUid string `protobuf:"bytes,2,opt,name=vm_uid,json=vmUid,proto3" json:"vm_uid,omitempty"` + Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"` + Args []string `protobuf:"bytes,4,rep,name=args,proto3" json:"args,omitempty"` + Interactive bool `protobuf:"varint,5,opt,name=interactive,proto3" json:"interactive,omitempty"` + Tty bool `protobuf:"varint,6,opt,name=tty,proto3" json:"tty,omitempty"` + TerminalSize *WatchInstruction_Exec_TerminalSize `protobuf:"bytes,7,opt,name=terminal_size,json=terminalSize,proto3" json:"terminal_size,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WatchInstruction_Exec) Reset() { + *x = WatchInstruction_Exec{} + mi := &file_orchard_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WatchInstruction_Exec) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WatchInstruction_Exec) ProtoMessage() {} + +func (x *WatchInstruction_Exec) ProtoReflect() protoreflect.Message { + mi := &file_orchard_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WatchInstruction_Exec.ProtoReflect.Descriptor instead. +func (*WatchInstruction_Exec) Descriptor() ([]byte, []int) { + return file_orchard_proto_rawDescGZIP(), []int{0, 3} +} + +func (x *WatchInstruction_Exec) GetSession() string { + if x != nil { + return x.Session + } + return "" +} + +func (x *WatchInstruction_Exec) GetVmUid() string { + if x != nil { + return x.VmUid + } + return "" +} + +func (x *WatchInstruction_Exec) GetCommand() string { + if x != nil { + return x.Command + } + return "" +} + +func (x *WatchInstruction_Exec) GetArgs() []string { + if x != nil { + return x.Args + } + return nil +} + +func (x *WatchInstruction_Exec) GetInteractive() bool { + if x != nil { + return x.Interactive + } + return false +} + +func (x *WatchInstruction_Exec) GetTty() bool { + if x != nil { + return x.Tty + } + return false +} + +func (x *WatchInstruction_Exec) GetTerminalSize() *WatchInstruction_Exec_TerminalSize { + if x != nil { + return x.TerminalSize + } + return nil +} + +type WatchInstruction_Exec_TerminalSize struct { + state protoimpl.MessageState `protogen:"open.v1"` + Rows uint32 `protobuf:"varint,1,opt,name=rows,proto3" json:"rows,omitempty"` + Cols uint32 `protobuf:"varint,2,opt,name=cols,proto3" json:"cols,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WatchInstruction_Exec_TerminalSize) Reset() { + *x = WatchInstruction_Exec_TerminalSize{} + mi := &file_orchard_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WatchInstruction_Exec_TerminalSize) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WatchInstruction_Exec_TerminalSize) ProtoMessage() {} + +func (x *WatchInstruction_Exec_TerminalSize) ProtoReflect() protoreflect.Message { + mi := &file_orchard_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} -var file_orchard_proto_rawDesc = []byte{ - 0x0a, 0x0d, 0x6f, 0x72, 0x63, 0x68, 0x61, 0x72, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, - 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x9a, 0x03, 0x0a, - 0x10, 0x57, 0x61, 0x74, 0x63, 0x68, 0x49, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x12, 0x4f, 0x0a, 0x13, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x5f, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, - 0x2e, 0x57, 0x61, 0x74, 0x63, 0x68, 0x49, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x48, 0x00, 0x52, - 0x11, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x12, 0x43, 0x0a, 0x0f, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x76, 0x6d, 0x73, 0x5f, 0x61, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x57, 0x61, - 0x74, 0x63, 0x68, 0x49, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x53, - 0x79, 0x6e, 0x63, 0x56, 0x4d, 0x73, 0x48, 0x00, 0x52, 0x0d, 0x73, 0x79, 0x6e, 0x63, 0x56, 0x6d, - 0x73, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x49, 0x0a, 0x11, 0x72, 0x65, 0x73, 0x6f, 0x6c, - 0x76, 0x65, 0x5f, 0x69, 0x70, 0x5f, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x57, 0x61, 0x74, 0x63, 0x68, 0x49, 0x6e, 0x73, 0x74, 0x72, 0x75, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x49, 0x50, 0x48, - 0x00, 0x52, 0x0f, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x49, 0x70, 0x41, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x1a, 0x52, 0x0a, 0x0b, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x06, 0x76, - 0x6d, 0x5f, 0x75, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x6d, 0x55, - 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, - 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x1a, 0x09, 0x0a, 0x07, 0x53, 0x79, 0x6e, 0x63, 0x56, 0x4d, - 0x73, 0x1a, 0x3c, 0x0a, 0x09, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x49, 0x50, 0x12, 0x18, - 0x0a, 0x07, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x06, 0x76, 0x6d, 0x5f, 0x75, - 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x6d, 0x55, 0x69, 0x64, 0x42, - 0x08, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x25, 0x0a, 0x0f, 0x50, 0x6f, 0x72, - 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x0a, 0x04, - 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, - 0x22, 0x3b, 0x0a, 0x0f, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x49, 0x50, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, - 0x02, 0x69, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x70, 0x32, 0xb0, 0x01, - 0x0a, 0x0a, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x12, 0x34, 0x0a, 0x05, - 0x57, 0x61, 0x74, 0x63, 0x68, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, - 0x57, 0x61, 0x74, 0x63, 0x68, 0x49, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x30, 0x01, 0x12, 0x35, 0x0a, 0x0b, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x12, 0x10, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x44, - 0x61, 0x74, 0x61, 0x1a, 0x10, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x44, 0x61, 0x74, 0x61, 0x28, 0x01, 0x30, 0x01, 0x12, 0x35, 0x0a, 0x09, 0x52, 0x65, 0x73, - 0x6f, 0x6c, 0x76, 0x65, 0x49, 0x50, 0x12, 0x10, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, - 0x49, 0x50, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x42, 0x23, 0x5a, 0x21, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, - 0x69, 0x72, 0x72, 0x75, 0x73, 0x6c, 0x61, 0x62, 0x73, 0x2f, 0x6f, 0x72, 0x63, 0x68, 0x61, 0x72, - 0x64, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +// Deprecated: Use WatchInstruction_Exec_TerminalSize.ProtoReflect.Descriptor instead. +func (*WatchInstruction_Exec_TerminalSize) Descriptor() ([]byte, []int) { + return file_orchard_proto_rawDescGZIP(), []int{0, 3, 0} } +func (x *WatchInstruction_Exec_TerminalSize) GetRows() uint32 { + if x != nil { + return x.Rows + } + return 0 +} + +func (x *WatchInstruction_Exec_TerminalSize) GetCols() uint32 { + if x != nil { + return x.Cols + } + return 0 +} + +var File_orchard_proto protoreflect.FileDescriptor + +const file_orchard_proto_rawDesc = "" + + "\n" + + "\rorchard.proto\x1a\x1bgoogle/protobuf/empty.proto\"\xf3\x05\n" + + "\x10WatchInstruction\x12O\n" + + "\x13port_forward_action\x18\x01 \x01(\v2\x1d.WatchInstruction.PortForwardH\x00R\x11portForwardAction\x12C\n" + + "\x0fsync_vms_action\x18\x02 \x01(\v2\x19.WatchInstruction.SyncVMsH\x00R\rsyncVmsAction\x12I\n" + + "\x11resolve_ip_action\x18\x03 \x01(\v2\x1b.WatchInstruction.ResolveIPH\x00R\x0fresolveIpAction\x129\n" + + "\vexec_action\x18\x04 \x01(\v2\x16.WatchInstruction.ExecH\x00R\n" + + "execAction\x1aR\n" + + "\vPortForward\x12\x18\n" + + "\asession\x18\x01 \x01(\tR\asession\x12\x15\n" + + "\x06vm_uid\x18\x02 \x01(\tR\x05vmUid\x12\x12\n" + + "\x04port\x18\x03 \x01(\rR\x04port\x1a\t\n" + + "\aSyncVMs\x1a<\n" + + "\tResolveIP\x12\x18\n" + + "\asession\x18\x01 \x01(\tR\asession\x12\x15\n" + + "\x06vm_uid\x18\x02 \x01(\tR\x05vmUid\x1a\x9b\x02\n" + + "\x04Exec\x12\x18\n" + + "\asession\x18\x01 \x01(\tR\asession\x12\x15\n" + + "\x06vm_uid\x18\x02 \x01(\tR\x05vmUid\x12\x18\n" + + "\acommand\x18\x03 \x01(\tR\acommand\x12\x12\n" + + "\x04args\x18\x04 \x03(\tR\x04args\x12 \n" + + "\vinteractive\x18\x05 \x01(\bR\vinteractive\x12\x10\n" + + "\x03tty\x18\x06 \x01(\bR\x03tty\x12H\n" + + "\rterminal_size\x18\a \x01(\v2#.WatchInstruction.Exec.TerminalSizeR\fterminalSize\x1a6\n" + + "\fTerminalSize\x12\x12\n" + + "\x04rows\x18\x01 \x01(\rR\x04rows\x12\x12\n" + + "\x04cols\x18\x02 \x01(\rR\x04colsB\b\n" + + "\x06action\"%\n" + + "\x0fPortForwardData\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\"\x1e\n" + + "\bExecData\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\";\n" + + "\x0fResolveIPResult\x12\x18\n" + + "\asession\x18\x01 \x01(\tR\asession\x12\x0e\n" + + "\x02ip\x18\x02 \x01(\tR\x02ip2\xd2\x01\n" + + "\n" + + "Controller\x124\n" + + "\x05Watch\x12\x16.google.protobuf.Empty\x1a\x11.WatchInstruction0\x01\x125\n" + + "\vPortForward\x12\x10.PortForwardData\x1a\x10.PortForwardData(\x010\x01\x12 \n" + + "\x04Exec\x12\t.ExecData\x1a\t.ExecData(\x010\x01\x125\n" + + "\tResolveIP\x12\x10.ResolveIPResult\x1a\x16.google.protobuf.EmptyB#Z!github.com/cirruslabs/orchard/rpcb\x06proto3" + var ( file_orchard_proto_rawDescOnce sync.Once - file_orchard_proto_rawDescData = file_orchard_proto_rawDesc + file_orchard_proto_rawDescData []byte ) func file_orchard_proto_rawDescGZIP() []byte { file_orchard_proto_rawDescOnce.Do(func() { - file_orchard_proto_rawDescData = protoimpl.X.CompressGZIP(file_orchard_proto_rawDescData) + file_orchard_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_orchard_proto_rawDesc), len(file_orchard_proto_rawDesc))) }) return file_orchard_proto_rawDescData } -var file_orchard_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_orchard_proto_msgTypes = make([]protoimpl.MessageInfo, 9) var file_orchard_proto_goTypes = []any{ - (*WatchInstruction)(nil), // 0: WatchInstruction - (*PortForwardData)(nil), // 1: PortForwardData - (*ResolveIPResult)(nil), // 2: ResolveIPResult - (*WatchInstruction_PortForward)(nil), // 3: WatchInstruction.PortForward - (*WatchInstruction_SyncVMs)(nil), // 4: WatchInstruction.SyncVMs - (*WatchInstruction_ResolveIP)(nil), // 5: WatchInstruction.ResolveIP - (*emptypb.Empty)(nil), // 6: google.protobuf.Empty + (*WatchInstruction)(nil), // 0: WatchInstruction + (*PortForwardData)(nil), // 1: PortForwardData + (*ExecData)(nil), // 2: ExecData + (*ResolveIPResult)(nil), // 3: ResolveIPResult + (*WatchInstruction_PortForward)(nil), // 4: WatchInstruction.PortForward + (*WatchInstruction_SyncVMs)(nil), // 5: WatchInstruction.SyncVMs + (*WatchInstruction_ResolveIP)(nil), // 6: WatchInstruction.ResolveIP + (*WatchInstruction_Exec)(nil), // 7: WatchInstruction.Exec + (*WatchInstruction_Exec_TerminalSize)(nil), // 8: WatchInstruction.Exec.TerminalSize + (*emptypb.Empty)(nil), // 9: google.protobuf.Empty } var file_orchard_proto_depIdxs = []int32{ - 3, // 0: WatchInstruction.port_forward_action:type_name -> WatchInstruction.PortForward - 4, // 1: WatchInstruction.sync_vms_action:type_name -> WatchInstruction.SyncVMs - 5, // 2: WatchInstruction.resolve_ip_action:type_name -> WatchInstruction.ResolveIP - 6, // 3: Controller.Watch:input_type -> google.protobuf.Empty - 1, // 4: Controller.PortForward:input_type -> PortForwardData - 2, // 5: Controller.ResolveIP:input_type -> ResolveIPResult - 0, // 6: Controller.Watch:output_type -> WatchInstruction - 1, // 7: Controller.PortForward:output_type -> PortForwardData - 6, // 8: Controller.ResolveIP:output_type -> google.protobuf.Empty - 6, // [6:9] is the sub-list for method output_type - 3, // [3:6] is the sub-list for method input_type - 3, // [3:3] is the sub-list for extension type_name - 3, // [3:3] is the sub-list for extension extendee - 0, // [0:3] is the sub-list for field type_name + 4, // 0: WatchInstruction.port_forward_action:type_name -> WatchInstruction.PortForward + 5, // 1: WatchInstruction.sync_vms_action:type_name -> WatchInstruction.SyncVMs + 6, // 2: WatchInstruction.resolve_ip_action:type_name -> WatchInstruction.ResolveIP + 7, // 3: WatchInstruction.exec_action:type_name -> WatchInstruction.Exec + 8, // 4: WatchInstruction.Exec.terminal_size:type_name -> WatchInstruction.Exec.TerminalSize + 9, // 5: Controller.Watch:input_type -> google.protobuf.Empty + 1, // 6: Controller.PortForward:input_type -> PortForwardData + 2, // 7: Controller.Exec:input_type -> ExecData + 3, // 8: Controller.ResolveIP:input_type -> ResolveIPResult + 0, // 9: Controller.Watch:output_type -> WatchInstruction + 1, // 10: Controller.PortForward:output_type -> PortForwardData + 2, // 11: Controller.Exec:output_type -> ExecData + 9, // 12: Controller.ResolveIP:output_type -> google.protobuf.Empty + 9, // [9:13] is the sub-list for method output_type + 5, // [5:9] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name } func init() { file_orchard_proto_init() } @@ -477,92 +671,19 @@ func file_orchard_proto_init() { if File_orchard_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_orchard_proto_msgTypes[0].Exporter = func(v any, i int) any { - switch v := v.(*WatchInstruction); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_orchard_proto_msgTypes[1].Exporter = func(v any, i int) any { - switch v := v.(*PortForwardData); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_orchard_proto_msgTypes[2].Exporter = func(v any, i int) any { - switch v := v.(*ResolveIPResult); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_orchard_proto_msgTypes[3].Exporter = func(v any, i int) any { - switch v := v.(*WatchInstruction_PortForward); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_orchard_proto_msgTypes[4].Exporter = func(v any, i int) any { - switch v := v.(*WatchInstruction_SyncVMs); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_orchard_proto_msgTypes[5].Exporter = func(v any, i int) any { - switch v := v.(*WatchInstruction_ResolveIP); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } file_orchard_proto_msgTypes[0].OneofWrappers = []any{ (*WatchInstruction_PortForwardAction)(nil), (*WatchInstruction_SyncVmsAction)(nil), (*WatchInstruction_ResolveIpAction)(nil), + (*WatchInstruction_ExecAction)(nil), } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_orchard_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_orchard_proto_rawDesc), len(file_orchard_proto_rawDesc)), NumEnums: 0, - NumMessages: 6, + NumMessages: 9, NumExtensions: 0, NumServices: 1, }, @@ -571,7 +692,6 @@ func file_orchard_proto_init() { MessageInfos: file_orchard_proto_msgTypes, }.Build() File_orchard_proto = out.File - file_orchard_proto_rawDesc = nil file_orchard_proto_goTypes = nil file_orchard_proto_depIdxs = nil } diff --git a/rpc/orchard.proto b/rpc/orchard.proto index 7328e7aa..0f16b985 100644 --- a/rpc/orchard.proto +++ b/rpc/orchard.proto @@ -12,6 +12,9 @@ service Controller { // session information is passed in the requests metadata rpc PortForward(stream PortForwardData) returns (stream PortForwardData); + // session information is passed in the requests metadata + rpc Exec(stream ExecData) returns (stream ExecData); + // worker calls this method when it has successfully resolved the VM's IP rpc ResolveIP(ResolveIPResult) returns (google.protobuf.Empty); } @@ -34,11 +37,28 @@ message WatchInstruction { string session = 1; string vm_uid = 2; } + message Exec { + message TerminalSize { + uint32 rows = 1; + uint32 cols = 2; + } + + // we can have multiple exec requests for the same vm + // so use a session identifier + string session = 1; + string vm_uid = 2; + string command = 3; + repeated string args = 4; + bool interactive = 5; + bool tty = 6; + TerminalSize terminal_size = 7; + } oneof action { PortForward port_forward_action = 1; SyncVMs sync_vms_action = 2; ResolveIP resolve_ip_action = 3; + Exec exec_action = 4; } } @@ -46,6 +66,10 @@ message PortForwardData { bytes data = 1; } +message ExecData { + bytes data = 1; +} + message ResolveIPResult { string session = 1; string ip = 2; diff --git a/rpc/orchard_grpc.pb.go b/rpc/orchard_grpc.pb.go index d4f5a299..6825d683 100644 --- a/rpc/orchard_grpc.pb.go +++ b/rpc/orchard_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.4.0 +// - protoc-gen-go-grpc v1.5.1 // - protoc (unknown) // source: orchard.proto @@ -16,12 +16,13 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.62.0 or later. -const _ = grpc.SupportPackageIsVersion8 +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 const ( Controller_Watch_FullMethodName = "/Controller/Watch" Controller_PortForward_FullMethodName = "/Controller/PortForward" + Controller_Exec_FullMethodName = "/Controller/Exec" Controller_ResolveIP_FullMethodName = "/Controller/ResolveIP" ) @@ -30,10 +31,12 @@ const ( // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type ControllerClient interface { // message bus between the controller and a worker - Watch(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (Controller_WatchClient, error) + Watch(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[WatchInstruction], error) // single purpose method when a port forward is requested and running // session information is passed in the requests metadata - PortForward(ctx context.Context, opts ...grpc.CallOption) (Controller_PortForwardClient, error) + PortForward(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[PortForwardData, PortForwardData], error) + // session information is passed in the requests metadata + Exec(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ExecData, ExecData], error) // worker calls this method when it has successfully resolved the VM's IP ResolveIP(ctx context.Context, in *ResolveIPResult, opts ...grpc.CallOption) (*emptypb.Empty, error) } @@ -46,13 +49,13 @@ func NewControllerClient(cc grpc.ClientConnInterface) ControllerClient { return &controllerClient{cc} } -func (c *controllerClient) Watch(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (Controller_WatchClient, error) { +func (c *controllerClient) Watch(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (grpc.ServerStreamingClient[WatchInstruction], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &Controller_ServiceDesc.Streams[0], Controller_Watch_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &controllerWatchClient{ClientStream: stream} + x := &grpc.GenericClientStream[emptypb.Empty, WatchInstruction]{ClientStream: stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -62,55 +65,35 @@ func (c *controllerClient) Watch(ctx context.Context, in *emptypb.Empty, opts .. return x, nil } -type Controller_WatchClient interface { - Recv() (*WatchInstruction, error) - grpc.ClientStream -} - -type controllerWatchClient struct { - grpc.ClientStream -} - -func (x *controllerWatchClient) Recv() (*WatchInstruction, error) { - m := new(WatchInstruction) - if err := x.ClientStream.RecvMsg(m); err != nil { - return nil, err - } - return m, nil -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Controller_WatchClient = grpc.ServerStreamingClient[WatchInstruction] -func (c *controllerClient) PortForward(ctx context.Context, opts ...grpc.CallOption) (Controller_PortForwardClient, error) { +func (c *controllerClient) PortForward(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[PortForwardData, PortForwardData], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &Controller_ServiceDesc.Streams[1], Controller_PortForward_FullMethodName, cOpts...) if err != nil { return nil, err } - x := &controllerPortForwardClient{ClientStream: stream} + x := &grpc.GenericClientStream[PortForwardData, PortForwardData]{ClientStream: stream} return x, nil } -type Controller_PortForwardClient interface { - Send(*PortForwardData) error - Recv() (*PortForwardData, error) - grpc.ClientStream -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Controller_PortForwardClient = grpc.BidiStreamingClient[PortForwardData, PortForwardData] -type controllerPortForwardClient struct { - grpc.ClientStream -} - -func (x *controllerPortForwardClient) Send(m *PortForwardData) error { - return x.ClientStream.SendMsg(m) -} - -func (x *controllerPortForwardClient) Recv() (*PortForwardData, error) { - m := new(PortForwardData) - if err := x.ClientStream.RecvMsg(m); err != nil { +func (c *controllerClient) Exec(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ExecData, ExecData], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &Controller_ServiceDesc.Streams[2], Controller_Exec_FullMethodName, cOpts...) + if err != nil { return nil, err } - return m, nil + x := &grpc.GenericClientStream[ExecData, ExecData]{ClientStream: stream} + return x, nil } +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Controller_ExecClient = grpc.BidiStreamingClient[ExecData, ExecData] + func (c *controllerClient) ResolveIP(ctx context.Context, in *ResolveIPResult, opts ...grpc.CallOption) (*emptypb.Empty, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(emptypb.Empty) @@ -123,32 +106,41 @@ func (c *controllerClient) ResolveIP(ctx context.Context, in *ResolveIPResult, o // ControllerServer is the server API for Controller service. // All implementations must embed UnimplementedControllerServer -// for forward compatibility +// for forward compatibility. type ControllerServer interface { // message bus between the controller and a worker - Watch(*emptypb.Empty, Controller_WatchServer) error + Watch(*emptypb.Empty, grpc.ServerStreamingServer[WatchInstruction]) error // single purpose method when a port forward is requested and running // session information is passed in the requests metadata - PortForward(Controller_PortForwardServer) error + PortForward(grpc.BidiStreamingServer[PortForwardData, PortForwardData]) error + // session information is passed in the requests metadata + Exec(grpc.BidiStreamingServer[ExecData, ExecData]) error // worker calls this method when it has successfully resolved the VM's IP ResolveIP(context.Context, *ResolveIPResult) (*emptypb.Empty, error) mustEmbedUnimplementedControllerServer() } -// UnimplementedControllerServer must be embedded to have forward compatible implementations. -type UnimplementedControllerServer struct { -} +// UnimplementedControllerServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedControllerServer struct{} -func (UnimplementedControllerServer) Watch(*emptypb.Empty, Controller_WatchServer) error { +func (UnimplementedControllerServer) Watch(*emptypb.Empty, grpc.ServerStreamingServer[WatchInstruction]) error { return status.Errorf(codes.Unimplemented, "method Watch not implemented") } -func (UnimplementedControllerServer) PortForward(Controller_PortForwardServer) error { +func (UnimplementedControllerServer) PortForward(grpc.BidiStreamingServer[PortForwardData, PortForwardData]) error { return status.Errorf(codes.Unimplemented, "method PortForward not implemented") } +func (UnimplementedControllerServer) Exec(grpc.BidiStreamingServer[ExecData, ExecData]) error { + return status.Errorf(codes.Unimplemented, "method Exec not implemented") +} func (UnimplementedControllerServer) ResolveIP(context.Context, *ResolveIPResult) (*emptypb.Empty, error) { return nil, status.Errorf(codes.Unimplemented, "method ResolveIP not implemented") } func (UnimplementedControllerServer) mustEmbedUnimplementedControllerServer() {} +func (UnimplementedControllerServer) testEmbeddedByValue() {} // UnsafeControllerServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to ControllerServer will @@ -158,6 +150,13 @@ type UnsafeControllerServer interface { } func RegisterControllerServer(s grpc.ServiceRegistrar, srv ControllerServer) { + // If the following call pancis, it indicates UnimplementedControllerServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } s.RegisterService(&Controller_ServiceDesc, srv) } @@ -166,47 +165,25 @@ func _Controller_Watch_Handler(srv interface{}, stream grpc.ServerStream) error if err := stream.RecvMsg(m); err != nil { return err } - return srv.(ControllerServer).Watch(m, &controllerWatchServer{ServerStream: stream}) + return srv.(ControllerServer).Watch(m, &grpc.GenericServerStream[emptypb.Empty, WatchInstruction]{ServerStream: stream}) } -type Controller_WatchServer interface { - Send(*WatchInstruction) error - grpc.ServerStream -} - -type controllerWatchServer struct { - grpc.ServerStream -} - -func (x *controllerWatchServer) Send(m *WatchInstruction) error { - return x.ServerStream.SendMsg(m) -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Controller_WatchServer = grpc.ServerStreamingServer[WatchInstruction] func _Controller_PortForward_Handler(srv interface{}, stream grpc.ServerStream) error { - return srv.(ControllerServer).PortForward(&controllerPortForwardServer{ServerStream: stream}) -} - -type Controller_PortForwardServer interface { - Send(*PortForwardData) error - Recv() (*PortForwardData, error) - grpc.ServerStream + return srv.(ControllerServer).PortForward(&grpc.GenericServerStream[PortForwardData, PortForwardData]{ServerStream: stream}) } -type controllerPortForwardServer struct { - grpc.ServerStream -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Controller_PortForwardServer = grpc.BidiStreamingServer[PortForwardData, PortForwardData] -func (x *controllerPortForwardServer) Send(m *PortForwardData) error { - return x.ServerStream.SendMsg(m) +func _Controller_Exec_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(ControllerServer).Exec(&grpc.GenericServerStream[ExecData, ExecData]{ServerStream: stream}) } -func (x *controllerPortForwardServer) Recv() (*PortForwardData, error) { - m := new(PortForwardData) - if err := x.ServerStream.RecvMsg(m); err != nil { - return nil, err - } - return m, nil -} +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type Controller_ExecServer = grpc.BidiStreamingServer[ExecData, ExecData] func _Controller_ResolveIP_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(ResolveIPResult) @@ -250,6 +227,12 @@ var Controller_ServiceDesc = grpc.ServiceDesc{ ServerStreams: true, ClientStreams: true, }, + { + StreamName: "Exec", + Handler: _Controller_Exec_Handler, + ServerStreams: true, + ClientStreams: true, + }, }, Metadata: "orchard.proto", }