From b070430e6832a2da379637145edbd16e0a72fb58 Mon Sep 17 00:00:00 2001 From: Jon Date: Sat, 28 Mar 2026 12:37:25 +0800 Subject: [PATCH 1/4] feat(execd): add WebSocket PTY support Add interactive terminal sessions over WebSocket to execd: - REST endpoints to create/resize/close PTY sessions - WebSocket handler with output replay on reconnect - Circular replay buffer (1 MiB) for session history - Windows stubs for cross-platform build compatibility Co-Authored-By: Claude Sonnet 4.6 --- components/execd/go.mod | 1 + components/execd/go.sum | 2 + components/execd/pkg/runtime/ctrl.go | 1 + components/execd/pkg/runtime/pty_session.go | 523 ++++++++++++++++++ .../execd/pkg/runtime/pty_session_test.go | 342 ++++++++++++ .../execd/pkg/runtime/pty_session_windows.go | 67 +++ components/execd/pkg/runtime/replay_buffer.go | 125 +++++ .../execd/pkg/runtime/replay_buffer_test.go | 154 ++++++ .../pkg/web/controller/pty_controller.go | 133 +++++ components/execd/pkg/web/controller/pty_ws.go | 358 ++++++++++++ .../execd/pkg/web/controller/pty_ws_test.go | 397 +++++++++++++ components/execd/pkg/web/model/error.go | 1 + components/execd/pkg/web/model/pty.go | 32 ++ components/execd/pkg/web/model/pty_ws.go | 55 ++ components/execd/pkg/web/router.go | 14 + 15 files changed, 2205 insertions(+) create mode 100644 components/execd/pkg/runtime/pty_session.go create mode 100644 components/execd/pkg/runtime/pty_session_test.go create mode 100644 components/execd/pkg/runtime/pty_session_windows.go create mode 100644 components/execd/pkg/runtime/replay_buffer.go create mode 100644 components/execd/pkg/runtime/replay_buffer_test.go create mode 100644 components/execd/pkg/web/controller/pty_controller.go create mode 100644 components/execd/pkg/web/controller/pty_ws.go create mode 100644 components/execd/pkg/web/controller/pty_ws_test.go create mode 100644 components/execd/pkg/web/model/pty.go create mode 100644 components/execd/pkg/web/model/pty_ws.go diff --git a/components/execd/go.mod b/components/execd/go.mod index 0b3f39cb..d44b521c 100644 --- a/components/execd/go.mod +++ b/components/execd/go.mod @@ -5,6 +5,7 @@ go 1.24.0 require ( github.com/alibaba/opensandbox/internal v0.0.0 github.com/bmatcuk/doublestar/v4 v4.9.1 + github.com/creack/pty v1.1.24 github.com/gin-gonic/gin v1.10.0 github.com/go-playground/validator/v10 v10.28.0 github.com/go-sql-driver/mysql v1.8.1 diff --git a/components/execd/go.sum b/components/execd/go.sum index 0ebe846c..88c4b7fe 100644 --- a/components/execd/go.sum +++ b/components/execd/go.sum @@ -11,6 +11,8 @@ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/components/execd/pkg/runtime/ctrl.go b/components/execd/pkg/runtime/ctrl.go index 3a835b31..46870bfd 100644 --- a/components/execd/pkg/runtime/ctrl.go +++ b/components/execd/pkg/runtime/ctrl.go @@ -42,6 +42,7 @@ type Controller struct { defaultLanguageSessions sync.Map // map[Language]string commandClientMap sync.Map // map[sessionID]*commandKernel bashSessionClientMap sync.Map // map[sessionID]*bashSession + ptySessionMap sync.Map // map[sessionID]*ptySession db *sql.DB dbOnce sync.Once } diff --git a/components/execd/pkg/runtime/pty_session.go b/components/execd/pkg/runtime/pty_session.go new file mode 100644 index 00000000..84fa3493 --- /dev/null +++ b/components/execd/pkg/runtime/pty_session.go @@ -0,0 +1,523 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows +// +build !windows + +package runtime + +import ( + "errors" + "fmt" + "io" + "os" + "os/exec" + "sync" + "sync/atomic" + "syscall" + + "github.com/creack/pty" + + "github.com/alibaba/opensandbox/execd/pkg/log" +) + +var errPTYSessionNotSupported = errors.New("pty session not supported") + +// IsPTYSessionSupported reports whether PTY sessions are supported on this platform. +func IsPTYSessionSupported() bool { return true } + +func NewPTYSessionID() string { + return uuidString() +} + +// ptySession manages a single interactive PTY or pipe-mode bash process. +// +// Lifecycle: +// 1. Create via newPTYSession. +// 2. Call StartPTY() or StartPipe() from the WS handler (after LockWS). +// 3. Zero or more clients call AttachOutput() to receive live output. +// 4. The bash process exits → Done() closes → exit frame sent. +// 5. Call close() to terminate an early session and release resources. +type ptySession struct { + id string + cwd string + + mu sync.Mutex + closing bool + + // Process tracking (guarded by mu) + pid int // PID of the running bash process (0 = not running) + lastExitCode int // exit code; -1 until process exits + doneCh chan struct{} // closed when process exits (non-nil after Start*) + + // Stdin (PTY master in PTY mode; write end of os.Pipe in pipe mode) + stdin io.WriteCloser + + // PTY-specific + isPTY bool + ptmx *os.File // PTY master fd; nil in pipe mode + + // Replay + replay *replayBuffer + + // WS exclusive lock: only one WebSocket client at a time. + wsConnected atomic.Bool + + // Output broadcast (guards stdoutW / stderrW). + // The broadcast goroutine holds outMu only while reading the pointer; writes + // to the pipe happen outside the lock to avoid blocking broadcast on slow clients. + outMu sync.Mutex + stdoutW *io.PipeWriter // current per-connection sink; nil when no client attached + stderrW *io.PipeWriter // nil in PTY mode +} + +func newPTYSession(id, cwd string) *ptySession { + return &ptySession{ + id: id, + cwd: cwd, + replay: newReplayBuffer(), + lastExitCode: -1, + } +} + +// LockWS attempts to acquire the exclusive WebSocket connection lock. +// Returns true on success, false if another client is already connected. +func (s *ptySession) LockWS() bool { + return s.wsConnected.CompareAndSwap(false, true) +} + +// UnlockWS releases the WebSocket connection lock. +func (s *ptySession) UnlockWS() { + s.wsConnected.Store(false) +} + +// IsRunning returns true if the bash process is currently alive. +func (s *ptySession) IsRunning() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.pid != 0 +} + +// IsPTY returns true when the session was started in PTY mode. +func (s *ptySession) IsPTY() bool { + return s.isPTY +} + +// ExitCode returns the exit code of the last process, or -1 if it has not exited yet. +func (s *ptySession) ExitCode() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.lastExitCode +} + +// Done returns a channel that is closed when the bash process exits. +// Returns nil if the process has not been started yet. +func (s *ptySession) Done() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.doneCh +} + +// ReplayBuffer returns the session's replay buffer (thread-safe). +func (s *ptySession) ReplayBuffer() *replayBuffer { + return s.replay +} + +// StartPTY launches bash via pty.StartWithSize. +// Must be called with the WS lock held. +func (s *ptySession) StartPTY() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.pid != 0 { + return errors.New("pty session already started") + } + if s.closing { + return errors.New("pty session is closing") + } + + cmd := exec.Command("bash", "--norc", "--noprofile") + cmd.Env = os.Environ() + if s.cwd != "" { + cmd.Dir = s.cwd + } + // Do NOT set Setpgid: pty.StartWithSize sets Setsid+Setctty internally. + // Combining Setsid+Setpgid causes EPERM (setpgid is illegal for a session leader). + + ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: 80, Rows: 24}) + if err != nil { + return fmt.Errorf("pty.StartWithSize: %w", err) + } + + s.ptmx = ptmx + s.isPTY = true + s.pid = cmd.Process.Pid + s.doneCh = make(chan struct{}) + s.stdin = ptmx // write to the PTY master to feed stdin + + go s.broadcastPTY() + go s.waitAndExit(cmd, ptmx) + + return nil +} + +// StartPipe launches bash with plain stdin/stdout/stderr os.Pipes. +// Must be called with the WS lock held. +func (s *ptySession) StartPipe() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.pid != 0 { + return errors.New("pty session already started") + } + if s.closing { + return errors.New("pty session is closing") + } + + stdinR, stdinW, err := os.Pipe() + if err != nil { + return fmt.Errorf("stdin pipe: %w", err) + } + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + _ = stdinR.Close() + _ = stdinW.Close() + return fmt.Errorf("stdout pipe: %w", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + _ = stdinR.Close() + _ = stdinW.Close() + _ = stdoutR.Close() + _ = stdoutW.Close() + return fmt.Errorf("stderr pipe: %w", err) + } + + cmd := exec.Command("bash", "--norc", "--noprofile") + cmd.Env = os.Environ() + if s.cwd != "" { + cmd.Dir = s.cwd + } + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + cmd.Stdin = stdinR + cmd.Stdout = stdoutW + cmd.Stderr = stderrW + + if err := cmd.Start(); err != nil { + _ = stdinR.Close() + _ = stdinW.Close() + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + return fmt.Errorf("cmd.Start: %w", err) + } + + // Close the child-side ends in the parent — the child has its own copies. + _ = stdinR.Close() + _ = stdoutW.Close() + _ = stderrW.Close() + + s.isPTY = false + s.pid = cmd.Process.Pid + s.doneCh = make(chan struct{}) + s.stdin = stdinW + + go s.broadcastPipe(stdoutR, true) + go s.broadcastPipe(stderrR, false) + go s.waitAndExitPipe(cmd, stdinW, stdoutR, stderrR) + + return nil +} + +// broadcastPTY reads from the PTY master and fans out to replay + active WS client. +func (s *ptySession) broadcastPTY() { + buf := make([]byte, 32*1024) + for { + n, err := s.ptmx.Read(buf) + if n > 0 { + chunk := buf[:n] + s.replay.write(chunk) + s.fanout(chunk, true) + } + if err != nil { + // EIO or EOF when the child exits — normal termination + break + } + } +} + +// broadcastPipe reads from a pipe (stdout or stderr) and fans out to replay + active WS client. +func (s *ptySession) broadcastPipe(r *os.File, isStdout bool) { + buf := make([]byte, 32*1024) + for { + n, err := r.Read(buf) + if n > 0 { + chunk := buf[:n] + s.replay.write(chunk) + s.fanout(chunk, isStdout) + } + if err != nil { + break + } + } + _ = r.Close() +} + +// fanout delivers chunk to the current per-connection PipeWriter (if any). +// If isStdout is false, it targets stderrW instead. +func (s *ptySession) fanout(chunk []byte, isStdout bool) { + s.outMu.Lock() + var w *io.PipeWriter + if isStdout { + w = s.stdoutW + } else { + w = s.stderrW + } + s.outMu.Unlock() + + if w != nil { + if _, err := w.Write(chunk); err != nil { + // Pipe was closed (client detached) — ignore. + log.Warning("pty fanout write: %v", err) + } + } +} + +// waitAndExit waits for the PTY-mode process and updates session state on exit. +func (s *ptySession) waitAndExit(cmd *exec.Cmd, ptmx *os.File) { + _ = cmd.Wait() + + // Close the PTY master to unblock the broadcast goroutine. + _ = ptmx.Close() + + s.mu.Lock() + exitCode := 0 + if cmd.ProcessState != nil { + exitCode = cmd.ProcessState.ExitCode() + } + s.lastExitCode = exitCode + s.pid = 0 + doneCh := s.doneCh + s.mu.Unlock() + + close(doneCh) +} + +// waitAndExitPipe waits for the pipe-mode process and updates session state on exit. +func (s *ptySession) waitAndExitPipe(cmd *exec.Cmd, stdinW, stdoutR, stderrR *os.File) { + _ = cmd.Wait() + + // Close stdin write-end so the child (if still running) sees EOF. + _ = stdinW.Close() + + s.mu.Lock() + exitCode := 0 + if cmd.ProcessState != nil { + exitCode = cmd.ProcessState.ExitCode() + } + s.lastExitCode = exitCode + s.pid = 0 + doneCh := s.doneCh + s.mu.Unlock() + + close(doneCh) +} + +// WriteStdin writes p to bash stdin (PTY master or pipe write-end). +func (s *ptySession) WriteStdin(p []byte) (int, error) { + s.mu.Lock() + w := s.stdin + s.mu.Unlock() + if w == nil { + return 0, errors.New("session not started") + } + return w.Write(p) +} + +// AttachOutput creates a fresh per-connection io.Pipe and swaps it into the +// broadcast fanout path. +// +// Ordering guarantee (no duplicates on reconnect): +// - Caller must snapshot the replay buffer BEFORE calling AttachOutput. +// - Bytes produced between the snapshot and AttachOutput are delivered via +// the live pipe only (not in the snapshot), so each byte arrives exactly once. +// +// Returns (stdout reader, stderr reader [nil in PTY mode], detach func). +// Calling detach() closes the writers, sending EOF to the readers and +// unblocking all pump goroutines. +func (s *ptySession) AttachOutput() (io.Reader, io.Reader, func()) { + stdoutR, stdoutW := io.Pipe() + + s.outMu.Lock() + s.stdoutW = stdoutW + s.outMu.Unlock() + + if s.isPTY { + detach := func() { + s.outMu.Lock() + s.stdoutW = nil + s.outMu.Unlock() + _ = stdoutW.Close() + } + return stdoutR, nil, detach + } + + // Pipe mode: also attach stderr. + stderrR, stderrW := io.Pipe() + + s.outMu.Lock() + s.stderrW = stderrW + s.outMu.Unlock() + + detach := func() { + s.outMu.Lock() + s.stdoutW = nil + s.stderrW = nil + s.outMu.Unlock() + _ = stdoutW.Close() + _ = stderrW.Close() + } + return stdoutR, stderrR, detach +} + +// SendSignal sends the named signal to the process group. +// Recognised names: SIGINT, SIGTERM, SIGKILL, SIGQUIT, SIGHUP. +func (s *ptySession) SendSignal(name string) { + s.mu.Lock() + pid := s.pid + s.mu.Unlock() + if pid == 0 { + return + } + + sig := parseSignalName(name) + if sig == 0 { + log.Warning("ptySession.SendSignal: unknown signal %q", name) + return + } + + // In PTY mode (setsid), pgid == pid automatically. + // In pipe mode (Setpgid), pgid is also == pid. + // Either way, Kill(-pid, sig) sends to the process group. + if err := syscall.Kill(-pid, sig); err != nil { + log.Warning("ptySession.SendSignal kill(-%d, %v): %v", pid, sig, err) + } +} + +func parseSignalName(name string) syscall.Signal { + switch name { + case "SIGINT": + return syscall.SIGINT + case "SIGTERM": + return syscall.SIGTERM + case "SIGKILL": + return syscall.SIGKILL + case "SIGQUIT": + return syscall.SIGQUIT + case "SIGHUP": + return syscall.SIGHUP + default: + return 0 + } +} + +// ResizePTY updates the terminal window size (PTY mode only; no-op in pipe mode). +func (s *ptySession) ResizePTY(cols, rows uint16) error { + s.mu.Lock() + ptmx := s.ptmx + s.mu.Unlock() + if ptmx == nil { + return nil // pipe mode or not started + } + return pty.Setsize(ptmx, &pty.Winsize{Cols: cols, Rows: rows}) +} + +// close terminates the session and releases all resources. +// Safe to call multiple times. +func (s *ptySession) close() { + s.mu.Lock() + if s.closing { + s.mu.Unlock() + return + } + s.closing = true + pid := s.pid + ptmx := s.ptmx + stdin := s.stdin + s.mu.Unlock() + + if pid != 0 { + _ = syscall.Kill(-pid, syscall.SIGKILL) + } + if ptmx != nil { + _ = ptmx.Close() + } else if stdin != nil { + _ = stdin.Close() + } + + // Detach any active WS output pipe so pump goroutines unblock. + s.outMu.Lock() + stdoutW := s.stdoutW + stderrW := s.stderrW + s.stdoutW = nil + s.stderrW = nil + s.outMu.Unlock() + if stdoutW != nil { + _ = stdoutW.Close() + } + if stderrW != nil { + _ = stderrW.Close() + } +} + +// CreatePTYSession creates a new PTY session and stores it in the map. +func (c *Controller) CreatePTYSession(id, cwd string) *ptySession { + s := newPTYSession(id, cwd) + c.ptySessionMap.Store(id, s) + log.Info("created pty session %s", id) + return s +} + +// GetPTYSession looks up a PTY session by ID. Returns nil if not found. +func (c *Controller) GetPTYSession(id string) *ptySession { + if v, ok := c.ptySessionMap.Load(id); ok { + if s, ok := v.(*ptySession); ok { + return s + } + } + return nil +} + +// DeletePTYSession terminates and removes a PTY session. +// Returns ErrContextNotFound if the session does not exist. +func (c *Controller) DeletePTYSession(id string) error { + s := c.GetPTYSession(id) + if s == nil { + return ErrContextNotFound + } + s.close() + c.ptySessionMap.Delete(id) + log.Info("deleted pty session %s", id) + return nil +} + +// GetPTYSessionStatus returns status information for a PTY session. +func (c *Controller) GetPTYSessionStatus(id string) (running bool, outputOffset int64, err error) { + s := c.GetPTYSession(id) + if s == nil { + return false, 0, ErrContextNotFound + } + return s.IsRunning(), s.replay.Total(), nil +} diff --git a/components/execd/pkg/runtime/pty_session_test.go b/components/execd/pkg/runtime/pty_session_test.go new file mode 100644 index 00000000..d9c2d3f4 --- /dev/null +++ b/components/execd/pkg/runtime/pty_session_test.go @@ -0,0 +1,342 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows +// +build !windows + +package runtime + +import ( + "bufio" + "io" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// replayContains polls the replay buffer until it contains substr or timeout expires. +func replayContains(t *testing.T, s *ptySession, substr string, timeout time.Duration) bool { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + data, _ := s.replay.readFrom(0) + if strings.Contains(string(data), substr) { + return true + } + time.Sleep(25 * time.Millisecond) + } + return false +} + +func TestPTYSession_BasicExecution(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPTY()) + t.Cleanup(func() { s.close() }) + + stdoutR, _, detach := s.AttachOutput() + defer detach() + go io.Copy(io.Discard, stdoutR) //nolint:errcheck + + _, err := s.WriteStdin([]byte("echo hello_pty\n")) + require.NoError(t, err) + + require.True(t, replayContains(t, s, "hello_pty", 5*time.Second), + "expected 'hello_pty' in PTY replay buffer") +} + +func TestPTYSession_IsRunning(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.False(t, s.IsRunning()) + require.NoError(t, s.StartPTY()) + t.Cleanup(func() { s.close() }) + require.True(t, s.IsRunning()) +} + +func TestPTYSession_ResizeWinsize(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPTY()) + t.Cleanup(func() { s.close() }) + + // Attach output so the broadcast goroutine has a sink and fills the replay buffer. + stdoutR, _, detach := s.AttachOutput() + defer detach() + go io.Copy(io.Discard, stdoutR) //nolint:errcheck + + // Wait for bash to start (prompt appears). + time.Sleep(150 * time.Millisecond) + + require.NoError(t, s.ResizePTY(120, 40)) + + // stty size reports "rows cols" (e.g. "40 120"). + _, err := s.WriteStdin([]byte("stty size\n")) + require.NoError(t, err) + + require.True(t, replayContains(t, s, "40 120", 5*time.Second), + "expected 'stty size' output '40 120' in replay buffer") +} + +func TestPTYSession_ANSISequences(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPTY()) + t.Cleanup(func() { s.close() }) + + stdoutR, _, detach := s.AttachOutput() + defer detach() + go io.Copy(io.Discard, stdoutR) //nolint:errcheck + + // Send printf with explicit ESC bytes via $'\033'. + _, err := s.WriteStdin([]byte("printf $'\\033[1;32mGREEN\\033[0m\\n'\n")) + require.NoError(t, err) + + // Wait for "GREEN" to appear in the replay buffer. + require.True(t, replayContains(t, s, "GREEN", 5*time.Second), + "expected 'GREEN' in replay buffer") + + // PTY mode should propagate ESC bytes verbatim. + data, _ := s.replay.readFrom(0) + assert.Contains(t, string(data), "\x1b", "expected ESC bytes in PTY output") +} + +func TestPTYSession_PipeMode(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPipe()) + t.Cleanup(func() { s.close() }) + + require.False(t, s.IsPTY()) + + stdoutR, stderrR, detach := s.AttachOutput() + defer detach() + require.NotNil(t, stderrR) + + stdoutCh := make(chan string, 32) + go func() { + sc := bufio.NewScanner(stdoutR) + for sc.Scan() { + stdoutCh <- sc.Text() + } + }() + + stderrCh := make(chan string, 32) + go func() { + sc := bufio.NewScanner(stderrR) + for sc.Scan() { + stderrCh <- sc.Text() + } + }() + + _, err := s.WriteStdin([]byte("echo hello_pipe\necho err_pipe >&2\n")) + require.NoError(t, err) + + require.True(t, waitForLine(stdoutCh, "hello_pipe", 5*time.Second), + "expected 'hello_pipe' on stdout") + require.True(t, waitForLine(stderrCh, "err_pipe", 5*time.Second), + "expected 'err_pipe' on stderr") +} + +func TestPTYSession_ReconnectReplay(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPTY()) + t.Cleanup(func() { s.close() }) + + // First connection — drain output so replay buffer fills. + stdoutR1, _, detach1 := s.AttachOutput() + go io.Copy(io.Discard, stdoutR1) //nolint:errcheck + + _, err := s.WriteStdin([]byte("echo first_output\n")) + require.NoError(t, err) + + require.True(t, replayContains(t, s, "first_output", 5*time.Second), + "expected 'first_output' in replay buffer") + + snapshot, snapshotOff := s.replay.readFrom(0) + detach1() + time.Sleep(50 * time.Millisecond) + + // Reconnect — replay from offset 0 should return the same bytes we snapshotted. + replay, replayOff := s.replay.readFrom(0) + require.Equal(t, snapshotOff, replayOff) + // The new snapshot may be larger (more output arrived), but must contain snapshot. + require.True(t, strings.HasPrefix(string(replay), string(snapshot)) || len(replay) >= len(snapshot), + "replay should contain at least the original snapshot bytes") + + // Second connection. + stdoutR2, _, detach2 := s.AttachOutput() + defer detach2() + go io.Copy(io.Discard, stdoutR2) //nolint:errcheck + + offsetAfterFirst := int64(len(replay)) + + _, err = s.WriteStdin([]byte("echo second_output\n")) + require.NoError(t, err) + + require.True(t, replayContains(t, s, "second_output", 5*time.Second), + "expected 'second_output' in replay buffer") + + // Delta replay from offset-after-first should contain only new bytes. + newData, _ := s.replay.readFrom(offsetAfterFirst) + require.True(t, strings.Contains(string(newData), "second_output"), + "delta replay should contain 'second_output', got %q", string(newData)) +} + +func TestPTYSession_SendSIGINT(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPTY()) + t.Cleanup(func() { s.close() }) + + stdoutR, _, detach := s.AttachOutput() + defer detach() + go io.Copy(io.Discard, stdoutR) //nolint:errcheck + + // Start a sleep inside the PTY. + _, err := s.WriteStdin([]byte("sleep 30\n")) + require.NoError(t, err) + time.Sleep(200 * time.Millisecond) + + // SIGINT interrupts the sleep; bash continues. + s.SendSignal("SIGINT") + + // Bash itself should still be alive. + require.Eventually(t, func() bool { + return s.IsRunning() + }, 2*time.Second, 50*time.Millisecond, "bash should still be running after SIGINT") +} + +func TestPTYSession_CloseTerminatesProcess(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPTY()) + + require.True(t, s.IsRunning()) + s.close() + + done := s.Done() + if done != nil { + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("process did not exit within 3s after close()") + } + } +} + +func TestPTYSession_ExitCode(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + + s := newPTYSession(uuidString(), "") + require.NoError(t, s.StartPTY()) + + stdoutR, _, detach := s.AttachOutput() + go io.Copy(io.Discard, stdoutR) //nolint:errcheck + + _, _ = s.WriteStdin([]byte("exit 42\n")) + + done := s.Done() + select { + case <-done: + case <-time.After(5 * time.Second): + detach() + t.Fatal("bash did not exit within 5s") + } + detach() + + require.Equal(t, 42, s.ExitCode()) +} + +func TestPTYSession_LockWS(t *testing.T) { + s := newPTYSession(uuidString(), "") + require.True(t, s.LockWS(), "first lock should succeed") + require.False(t, s.LockWS(), "second lock should fail") + s.UnlockWS() + require.True(t, s.LockWS(), "lock after unlock should succeed") + s.UnlockWS() +} + +func TestPTYSession_ControllerCRUD(t *testing.T) { + c := NewController("", "") + id := uuidString() + + sess := c.CreatePTYSession(id, "") + require.NotNil(t, sess) + + got := c.GetPTYSession(id) + require.NotNil(t, got) + require.Equal(t, id, got.id) + + running, offset, err := c.GetPTYSessionStatus(id) + require.NoError(t, err) + require.False(t, running) + require.Equal(t, int64(0), offset) + + require.NoError(t, c.DeletePTYSession(id)) + require.Nil(t, c.GetPTYSession(id)) + + require.ErrorIs(t, c.DeletePTYSession(id), ErrContextNotFound) +} + +// waitForLine reads from ch until target is found or timeout expires. +func waitForLine(ch <-chan string, target string, timeout time.Duration) bool { + deadline := time.After(timeout) + for { + select { + case line := <-ch: + if strings.Contains(line, target) { + return true + } + case <-deadline: + return false + } + } +} + +// discardAll drains r until EOF. +func discardAll(r io.Reader) { + _, _ = io.Copy(io.Discard, r) +} diff --git a/components/execd/pkg/runtime/pty_session_windows.go b/components/execd/pkg/runtime/pty_session_windows.go new file mode 100644 index 00000000..c52fcba1 --- /dev/null +++ b/components/execd/pkg/runtime/pty_session_windows.go @@ -0,0 +1,67 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows +// +build windows + +package runtime + +import ( + "errors" + "io" +) + +// ptySession is an opaque stub on Windows (real type in pty_session.go, !windows). +// All methods the controller layer calls must be present here so Windows cross-compilation succeeds. +type ptySession struct{} + +var errPTYSessionNotSupported = errors.New("pty session is not supported on windows") + +// IsPTYSessionSupported reports whether PTY sessions are supported on this platform. +func IsPTYSessionSupported() bool { return false } + +// NewPTYSessionID returns a new unique session ID (Windows stub). +func NewPTYSessionID() string { return "" } + +// CreatePTYSession is not supported on Windows. +func (c *Controller) CreatePTYSession(id, cwd string) *ptySession { return nil } //nolint:revive + +// GetPTYSession is not supported on Windows. +func (c *Controller) GetPTYSession(id string) *ptySession { return nil } //nolint:revive + +// DeletePTYSession is not supported on Windows. +func (c *Controller) DeletePTYSession(id string) error { //nolint:revive + return errPTYSessionNotSupported +} + +// GetPTYSessionStatus is not supported on Windows. +func (c *Controller) GetPTYSessionStatus(id string) (bool, int64, error) { //nolint:revive + return false, 0, errPTYSessionNotSupported +} + +// Method stubs so the controller layer can call them without build-tag guards. + +func (s *ptySession) LockWS() bool { return false } +func (s *ptySession) UnlockWS() {} +func (s *ptySession) IsRunning() bool { return false } +func (s *ptySession) IsPTY() bool { return false } +func (s *ptySession) ExitCode() int { return -1 } +func (s *ptySession) Done() <-chan struct{} { return nil } +func (s *ptySession) ReplayBuffer() *replayBuffer { return nil } +func (s *ptySession) StartPTY() error { return errPTYSessionNotSupported } +func (s *ptySession) StartPipe() error { return errPTYSessionNotSupported } +func (s *ptySession) WriteStdin(_ []byte) (int, error) { return 0, errPTYSessionNotSupported } +func (s *ptySession) AttachOutput() (io.Reader, io.Reader, func()) { return nil, nil, func() {} } +func (s *ptySession) SendSignal(_ string) {} +func (s *ptySession) ResizePTY(_, _ uint16) error { return nil } diff --git a/components/execd/pkg/runtime/replay_buffer.go b/components/execd/pkg/runtime/replay_buffer.go new file mode 100644 index 00000000..5285bea7 --- /dev/null +++ b/components/execd/pkg/runtime/replay_buffer.go @@ -0,0 +1,125 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtime + +import "sync" + +const replayBufferSize = 1 << 20 // 1 MiB + +// replayBuffer is a fixed-capacity circular byte buffer with a monotonic write counter. +// +// Invariant: head == total % size (where size is the buffer capacity). +// This means the byte at absolute offset o (o >= oldest) is stored at buf[o % size]. +type replayBuffer struct { + mu sync.Mutex + buf []byte + size int // == replayBufferSize (or smaller in tests) + head int // next write position; always == total % size + total int64 // monotonic byte counter (total bytes ever written) +} + +func newReplayBuffer() *replayBuffer { + return &replayBuffer{ + buf: make([]byte, replayBufferSize), + size: replayBufferSize, + } +} + +// write appends p to the buffer, evicting the oldest bytes when the buffer is full. +// The invariant head == total % size is preserved on every call. +func (r *replayBuffer) write(p []byte) { + if len(p) == 0 { + return + } + r.mu.Lock() + defer r.mu.Unlock() + + // When p is larger than the whole buffer we only keep the last size bytes, + // but we still advance total by the full len(p) to maintain the invariant. + if len(p) >= r.size { + skip := len(p) - r.size + r.total += int64(skip) + r.head = (r.head + skip) % r.size + p = p[skip:] + // len(p) == r.size now + } + + // len(p) <= r.size — split into at most two contiguous copies. + n := copy(r.buf[r.head:], p) + r.head = (r.head + n) % r.size + r.total += int64(n) + + if n < len(p) { + rest := p[n:] + copy(r.buf[r.head:], rest) + r.head = (r.head + len(rest)) % r.size + r.total += int64(len(rest)) + } +} + +// Total returns the total number of bytes ever written to the buffer. +func (r *replayBuffer) Total() int64 { + r.mu.Lock() + defer r.mu.Unlock() + return r.total +} + +// readFrom is the unexported thin wrapper used inside the runtime package. +func (r *replayBuffer) readFrom(offset int64) ([]byte, int64) { + return r.ReadFrom(offset) +} + +// ReadFrom returns a snapshot of all bytes starting from the given absolute byte offset. +// +// Returns (data, actualOffset) where actualOffset is the offset of the first returned byte: +// - If offset >= total, returns (nil, total) — caller is already caught up. +// - If offset < oldest retained byte, clamps to oldest (bytes were evicted). +// - Otherwise returns bytes [offset, total). +func (r *replayBuffer) ReadFrom(offset int64) ([]byte, int64) { + r.mu.Lock() + defer r.mu.Unlock() + + if offset >= r.total { + return nil, r.total + } + + // Oldest retained absolute offset. + oldest := r.total - int64(r.size) + if oldest < 0 { + oldest = 0 + } + if offset < oldest { + offset = oldest + } + + count := int(r.total - offset) + if count <= 0 { + return nil, r.total + } + + // Thanks to the invariant head == total % size, the byte at absolute offset o + // is stored at buf[o % size]. This holds whether or not the buffer has wrapped. + start := int(offset % int64(r.size)) + end := start + count + + result := make([]byte, count) + if end <= r.size { + copy(result, r.buf[start:end]) + } else { + n := copy(result, r.buf[start:]) + copy(result[n:], r.buf[:count-n]) + } + return result, offset +} diff --git a/components/execd/pkg/runtime/replay_buffer_test.go b/components/execd/pkg/runtime/replay_buffer_test.go new file mode 100644 index 00000000..4c0227fd --- /dev/null +++ b/components/execd/pkg/runtime/replay_buffer_test.go @@ -0,0 +1,154 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtime + +import ( + "bytes" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReplayBuffer_BasicWriteRead(t *testing.T) { + rb := newReplayBuffer() + rb.write([]byte("hello")) + rb.write([]byte(" world")) + + data, off := rb.readFrom(0) + require.Equal(t, int64(0), off) + require.Equal(t, []byte("hello world"), data) + require.Equal(t, int64(11), rb.Total()) +} + +func TestReplayBuffer_ReadFromMiddle(t *testing.T) { + rb := newReplayBuffer() + rb.write([]byte("abcde")) + + data, off := rb.readFrom(2) + require.Equal(t, int64(2), off) + require.Equal(t, []byte("cde"), data) +} + +func TestReplayBuffer_ReadFromCurrent(t *testing.T) { + rb := newReplayBuffer() + rb.write([]byte("abc")) + + data, off := rb.readFrom(3) + require.Nil(t, data, "should return nil when caught up") + require.Equal(t, int64(3), off) +} + +func TestReplayBuffer_CircularEviction(t *testing.T) { + rb := &replayBuffer{ + buf: make([]byte, 8), + size: 8, + } + + // Write 6 bytes: "abcdef" + rb.write([]byte("abcdef")) + require.Equal(t, int64(6), rb.Total()) + + // Write 4 more bytes: now total=10, oldest=2 (evicted "ab") + rb.write([]byte("ghij")) + require.Equal(t, int64(10), rb.Total()) + + // offset 0 should be clamped to oldest=2 + data, off := rb.readFrom(0) + require.Equal(t, int64(2), off) + require.Equal(t, []byte("cdefghij"), data) + + // Read from offset 5 (within retained range) + data, off = rb.readFrom(5) + require.Equal(t, int64(5), off) + require.Equal(t, []byte("fghij"), data) +} + +func TestReplayBuffer_LargeGap(t *testing.T) { + rb := &replayBuffer{ + buf: make([]byte, 4), + size: 4, + } + // Write "ABCDEF" — total=6, oldest=2, retained="CDEF" + rb.write([]byte("ABCDEF")) + + // Requesting from 0 should clamp to oldest=2 + data, off := rb.readFrom(0) + require.Equal(t, int64(2), off) + require.Equal(t, []byte("CDEF"), data) + + // Requesting from 1 should also clamp to oldest=2 + data, off = rb.readFrom(1) + require.Equal(t, int64(2), off) + require.Equal(t, []byte("CDEF"), data) +} + +func TestReplayBuffer_Concurrent(t *testing.T) { + rb := newReplayBuffer() + chunk := bytes.Repeat([]byte("x"), 1024) + + var wg sync.WaitGroup + for range 16 { + wg.Add(1) + go func() { + defer wg.Done() + for range 64 { + rb.write(chunk) + } + }() + } + for range 4 { + wg.Add(1) + go func() { + defer wg.Done() + for range 32 { + rb.readFrom(0) + rb.Total() + } + }() + } + wg.Wait() + + total := rb.Total() + require.Equal(t, int64(16*64*1024), total) +} + +func TestReplayBuffer_ExactlyFull(t *testing.T) { + rb := &replayBuffer{ + buf: make([]byte, 4), + size: 4, + } + rb.write([]byte("1234")) + require.Equal(t, int64(4), rb.Total()) + + data, off := rb.readFrom(0) + require.Equal(t, int64(0), off) + require.Equal(t, []byte("1234"), data) +} + +func TestReplayBuffer_WriteWrapsCorrectly(t *testing.T) { + rb := &replayBuffer{ + buf: make([]byte, 4), + size: 4, + } + // Write "ABCD" — buffer full + rb.write([]byte("ABCD")) + // Write "EF" — evicts "AB", retained "CDEF" + rb.write([]byte("EF")) + + data, off := rb.readFrom(0) + require.Equal(t, int64(2), off, "offset should be clamped to oldest=2") + require.Equal(t, []byte("CDEF"), data) +} diff --git a/components/execd/pkg/web/controller/pty_controller.go b/components/execd/pkg/web/controller/pty_controller.go new file mode 100644 index 00000000..fc9e5eb9 --- /dev/null +++ b/components/execd/pkg/web/controller/pty_controller.go @@ -0,0 +1,133 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package controller + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/alibaba/opensandbox/execd/pkg/runtime" + "github.com/alibaba/opensandbox/execd/pkg/web/model" +) + +// PTYController handles /pty/* REST endpoints. +type PTYController struct { + *basicController +} + +// NewPTYController creates a new PTYController from the current Gin context. +func NewPTYController(ctx *gin.Context) *PTYController { + return &PTYController{basicController: newBasicController(ctx)} +} + +// CreatePTYSession handles POST /pty. +// Creates a new PTY session and returns its session_id. +func (c *PTYController) CreatePTYSession() { + if !runtime.IsPTYSessionSupported() { + c.RespondError( + http.StatusNotImplemented, + model.ErrorCodeNotSupported, + "pty sessions are not supported on this platform", + ) + return + } + + var req model.CreatePTYSessionRequest + if err := c.bindJSON(&req); err != nil && !errors.Is(err, io.EOF) { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeInvalidRequest, + fmt.Sprintf("error parsing request: %v", err), + ) + return + } + + id := runtime.NewPTYSessionID() + codeRunner.CreatePTYSession(id, req.Cwd) + c.ctx.JSON(http.StatusCreated, model.CreatePTYSessionResponse{SessionID: id}) +} + +// GetPTYSessionStatus handles GET /pty/:sessionId. +func (c *PTYController) GetPTYSessionStatus() { + id := c.ctx.Param("sessionId") + if id == "" { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeMissingQuery, + "missing path parameter 'sessionId'", + ) + return + } + + running, offset, err := codeRunner.GetPTYSessionStatus(id) + if err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError( + http.StatusNotFound, + model.ErrorCodeContextNotFound, + fmt.Sprintf("pty session %s not found", id), + ) + return + } + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error getting pty session status: %v", err), + ) + return + } + + c.RespondSuccess(model.PTYSessionStatusResponse{ + SessionID: id, + Running: running, + OutputOffset: offset, + }) +} + +// DeletePTYSession handles DELETE /pty/:sessionId. +func (c *PTYController) DeletePTYSession() { + id := c.ctx.Param("sessionId") + if id == "" { + c.RespondError( + http.StatusBadRequest, + model.ErrorCodeMissingQuery, + "missing path parameter 'sessionId'", + ) + return + } + + if err := codeRunner.DeletePTYSession(id); err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError( + http.StatusNotFound, + model.ErrorCodeContextNotFound, + fmt.Sprintf("pty session %s not found", id), + ) + return + } + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error deleting pty session: %v", err), + ) + return + } + + c.RespondSuccess(nil) +} diff --git a/components/execd/pkg/web/controller/pty_ws.go b/components/execd/pkg/web/controller/pty_ws.go new file mode 100644 index 00000000..31978ab6 --- /dev/null +++ b/components/execd/pkg/web/controller/pty_ws.go @@ -0,0 +1,358 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package controller + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + + "github.com/alibaba/opensandbox/execd/pkg/log" + "github.com/alibaba/opensandbox/execd/pkg/web/model" +) + +var wsUpgrader = websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + // Allow all origins — execd runs behind a trusted reverse proxy. + CheckOrigin: func(r *http.Request) bool { return true }, +} + +const ( + wsPingInterval = 30 * time.Second + wsReadDeadline = 60 * time.Second + wsWriteDeadline = 10 * time.Second +) + +// PTYSessionWebSocket handles GET /pty/:sessionId/ws. +// +// 1. Look up session → 404 before upgrade if missing +// 2. Acquire exclusive WS lock → 409 if already held +// 3. Upgrade HTTP → WebSocket +// 4. Start bash if not already running +// 5. Snapshot replay buffer BEFORE AttachOutput (prevents duplicate delivery) +// 6. AttachOutput (live pipe now active) +// 7. defer: detach → pumpWg.Wait → UnlockWS +// 8. Send replay frame if snapshot non-empty +// 9. Send connected frame +// 10. Start RFC 6455 ping, streamPump(s), exitWatcher goroutines +// 11. Read loop: dispatch client frames +func PTYSessionWebSocket(ctx *gin.Context) { + id := ctx.Param("sessionId") + if id == "" { + ctx.JSON(http.StatusBadRequest, model.ErrorResponse{ + Code: model.ErrorCodeMissingQuery, + Message: "missing path parameter 'sessionId'", + }) + return + } + + // 1. Look up session — must happen before upgrade so we can return HTTP errors. + session := codeRunner.GetPTYSession(id) + if session == nil { + ctx.JSON(http.StatusNotFound, model.ErrorResponse{ + Code: model.ErrorCodeContextNotFound, + Message: "pty session " + id + " not found", + }) + return + } + + // 2. Acquire exclusive WS lock. + if !session.LockWS() { + ctx.JSON(http.StatusConflict, model.ErrorResponse{ + Code: model.WSErrCodeAlreadyConnected, + Message: "another client is already connected to pty session " + id, + }) + return + } + // NOTE: the lock is released at the very end of this function (see defer below), + // only after all pump goroutines have exited. + + // 3. Upgrade HTTP connection to WebSocket. + conn, err := wsUpgrader.Upgrade(ctx.Writer, ctx.Request, nil) + if err != nil { + log.Warning("pty ws upgrade failed for session %s: %v", id, err) + session.UnlockWS() + return + } + + // Resolve query parameters. + pipeMode := ctx.Query("pty") == "0" + since := queryInt64(ctx.Query("since"), 0) + + // 4. Start bash if not already running. + if !session.IsRunning() { + var startErr error + if pipeMode { + startErr = session.StartPipe() + } else { + startErr = session.StartPTY() + } + if startErr != nil { + log.Warning("pty start failed for session %s: %v", id, startErr) + writeErrFrame(conn, model.WSErrCodeStartFailed, startErr.Error()) + _ = conn.Close() + session.UnlockWS() + return + } + } + + // 5. Snapshot replay buffer BEFORE attaching live pipe. + snapshotBytes, snapshotOffset := session.ReplayBuffer().ReadFrom(since) + + // 6. Attach per-connection output pipe — live sink is now active. + stdoutR, stderrR, detach := session.AttachOutput() + + // 7. Deferred cleanup order: detach writers → wait for pump goroutines → unlock WS. + var pumpWg sync.WaitGroup + defer func() { + detach() + pumpWg.Wait() + session.UnlockWS() + }() + + // cancelCh is closed to signal all goroutines to stop. + cancelCh := make(chan struct{}) + cancelOnce := sync.OnceFunc(func() { close(cancelCh) }) + + // connMu serialises all writes to conn (gorilla/websocket requires single-writer). + var connMu sync.Mutex + + writeJSON := func(v any) error { + connMu.Lock() + defer connMu.Unlock() + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + return conn.WriteJSON(v) + } + + closeConn := func(code int, text string) { + connMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + _ = conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(code, text)) + connMu.Unlock() + _ = conn.Close() + } + + // Set initial read deadline; pong handler resets it. + _ = conn.SetReadDeadline(time.Now().Add(wsReadDeadline)) + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(wsReadDeadline)) + }) + + // 8. Send replay frame if there is missed output. + if len(snapshotBytes) > 0 { + frame := make([]byte, 1+8+len(snapshotBytes)) + frame[0] = model.BinReplay + binary.BigEndian.PutUint64(frame[1:9], uint64(snapshotOffset)) + copy(frame[9:], snapshotBytes) + // No connMu needed — pump goroutines not yet started. + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + if err2 := conn.WriteMessage(websocket.BinaryMessage, frame); err2 != nil { + log.Warning("pty ws send replay for session %s: %v", id, err2) + return + } + } + + // 9. Send connected frame. + mode := "pty" + if !session.IsPTY() { + mode = "pipe" + } + if err2 := writeJSON(model.ServerFrame{ + Type: "connected", + SessionID: id, + Mode: mode, + }); err2 != nil { + log.Warning("pty ws send connected for session %s: %v", id, err2) + return + } + + // 10a. RFC 6455 binary ping goroutine (30 s interval). + go func() { + t := time.NewTicker(wsPingInterval) + defer t.Stop() + for { + select { + case <-cancelCh: + return + case <-t.C: + connMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + pingErr := conn.WriteMessage(websocket.PingMessage, nil) + connMu.Unlock() + if pingErr != nil { + cancelOnce() + return + } + } + } + }() + + // streamPump reads raw chunks from r and sends them as binary frames over WS. + streamPump := func(r io.Reader, typeByte byte, name string) { + defer pumpWg.Done() + const chunkSize = 32 * 1024 + frame := make([]byte, 1+chunkSize) // single allocation for session lifetime + frame[0] = typeByte + for { + select { + case <-cancelCh: + return + default: + } + n, readErr := r.Read(frame[1:]) + if n > 0 { + connMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + writeErr := conn.WriteMessage(websocket.BinaryMessage, frame[:1+n]) + connMu.Unlock() + if writeErr != nil { + log.Warning("pty ws write %s for session %s: %v", name, id, writeErr) + cancelOnce() + return + } + } + if readErr != nil { + // io.EOF or io.ErrClosedPipe when detach() closes the PipeWriter. + return + } + } + } + + // 10b. Launch stdout pump. + pumpWg.Add(1) + go streamPump(stdoutR, model.BinStdout, "stdout") + + // 10c. Launch stderr pump (pipe mode only). + if stderrR != nil { + pumpWg.Add(1) + go streamPump(stderrR, model.BinStderr, "stderr") + } + + // 10d. Exit watcher: waits for the process to exit, then sends exit frame + // and closes the WS connection immediately (unblocks ReadJSON in the read loop). + go func() { + doneCh := session.Done() + if doneCh == nil { + return + } + select { + case <-doneCh: + case <-cancelCh: + return + } + exitCode := session.ExitCode() + _ = writeJSON(model.ServerFrame{ + Type: "exit", + ExitCode: &exitCode, + }) + closeConn(websocket.CloseNormalClosure, "process exited") + cancelOnce() + }() + + // 11. Client read loop. + for { + select { + case <-cancelCh: + return + default: + } + + msgType, data, err2 := conn.ReadMessage() + if err2 != nil { + cancelOnce() + return + } + + // Any incoming frame resets the read deadline. + _ = conn.SetReadDeadline(time.Now().Add(wsReadDeadline)) + + switch msgType { + + case websocket.BinaryMessage: + if len(data) == 0 { + continue + } + if data[0] != model.BinStdin { + continue // only stdin expected C→S + } + if _, writeErr := session.WriteStdin(data[1:]); writeErr != nil { + _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeStdinWriteFailed, + Error: writeErr.Error()}) + cancelOnce() + return + } + + case websocket.TextMessage: + var frame model.ClientFrame + if json.Unmarshal(data, &frame) != nil { + continue + } + switch frame.Type { + case "stdin": + // wscat / debug fallback: plain UTF-8 text, no base64. + if _, writeErr := session.WriteStdin([]byte(frame.Data)); writeErr != nil { + _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeStdinWriteFailed, + Error: writeErr.Error()}) + cancelOnce() + return + } + case "signal": + session.SendSignal(frame.Signal) + case "resize": + if frame.Cols > 0 && frame.Rows > 0 { + if resErr := session.ResizePTY(uint16(frame.Cols), uint16(frame.Rows)); resErr != nil { + log.Warning("pty resize session %s: %v", id, resErr) + } + } + case "ping": + _ = writeJSON(model.ServerFrame{Type: "pong"}) + default: + _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeInvalidFrame, + Error: fmt.Sprintf("unknown frame type %q", frame.Type)}) + } + } + } +} + +// writeErrFrame sends a JSON error frame. Safe to call before pump goroutines start. +func writeErrFrame(conn *websocket.Conn, code, message string) { + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + _ = conn.WriteJSON(model.ServerFrame{ + Type: "error", + Error: message, + Code: code, + }) +} + +// queryInt64 parses a decimal query string value, returning defaultVal on error. +func queryInt64(s string, defaultVal int64) int64 { + if s == "" { + return defaultVal + } + var n int64 + if _, err := fmt.Sscanf(s, "%d", &n); err != nil { + return defaultVal + } + return n +} diff --git a/components/execd/pkg/web/controller/pty_ws_test.go b/components/execd/pkg/web/controller/pty_ws_test.go new file mode 100644 index 00000000..d3cfb91e --- /dev/null +++ b/components/execd/pkg/web/controller/pty_ws_test.go @@ -0,0 +1,397 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows +// +build !windows + +package controller + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os/exec" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + + "github.com/alibaba/opensandbox/execd/pkg/runtime" + "github.com/alibaba/opensandbox/execd/pkg/web/model" +) + +// buildPTYRouter assembles a minimal Gin router with only the /pty routes, +// avoiding any import cycle with pkg/web. +func buildPTYRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(gin.Recovery()) + + pty := r.Group("/pty") + { + pty.POST("", func(ctx *gin.Context) { + NewPTYController(ctx).CreatePTYSession() + }) + pty.GET("/:sessionId", func(ctx *gin.Context) { + NewPTYController(ctx).GetPTYSessionStatus() + }) + pty.DELETE("/:sessionId", func(ctx *gin.Context) { + NewPTYController(ctx).DeletePTYSession() + }) + pty.GET("/:sessionId/ws", PTYSessionWebSocket) + } + return r +} + +// newPTYTestServer creates a test HTTP server with fresh codeRunner and PTY routes only. +func newPTYTestServer(t *testing.T) *httptest.Server { + t.Helper() + prev := codeRunner + codeRunner = runtime.NewController("", "") + t.Cleanup(func() { codeRunner = prev }) + return httptest.NewServer(buildPTYRouter()) +} + +// wsDialPTY dials a WebSocket URL; accepts an optional extra query string. +func wsDialPTY(t *testing.T, baseURL, path, query string) *websocket.Conn { + t.Helper() + u := "ws" + strings.TrimPrefix(baseURL+path, "http") + if query != "" { + u += "?" + query + } + conn, resp, err := websocket.DefaultDialer.Dial(u, nil) + if err != nil { + if resp != nil { + t.Fatalf("WS dial %s: %v (HTTP %d)", u, err, resp.StatusCode) + } + t.Fatalf("WS dial %s: %v", u, err) + } + t.Cleanup(func() { _ = conn.Close() }) + return conn +} + +// wsDialExpectHTTP dials and returns the HTTP status without upgrading (for 4xx cases). +func wsDialExpectHTTP(t *testing.T, baseURL, path, query string) int { + t.Helper() + u := "ws" + strings.TrimPrefix(baseURL+path, "http") + if query != "" { + u += "?" + query + } + _, resp, err := websocket.DefaultDialer.Dial(u, nil) + require.Error(t, err, "expected dial to fail with HTTP error") + require.NotNil(t, resp) + return resp.StatusCode +} + +// ptyCreateSession calls POST /pty and returns the session_id. +func ptyCreateSession(t *testing.T, srv *httptest.Server) string { + t.Helper() + resp, err := http.Post(srv.URL+"/pty", "application/json", strings.NewReader(`{}`)) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusCreated, resp.StatusCode) + var r model.CreatePTYSessionResponse + require.NoError(t, json.NewDecoder(resp.Body).Decode(&r)) + require.NotEmpty(t, r.SessionID) + return r.SessionID +} + +// ptyReadFrame reads the next frame, handling both binary data frames and JSON control frames. +func ptyReadFrame(conn *websocket.Conn, timeout time.Duration) (model.ServerFrame, error) { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + msgType, raw, err := conn.ReadMessage() + if err != nil { + return model.ServerFrame{}, err + } + if msgType == websocket.TextMessage { + var f model.ServerFrame + return f, json.Unmarshal(raw, &f) + } + // Binary data frame: decode type byte into ServerFrame. + if len(raw) == 0 { + return model.ServerFrame{}, errors.New("empty binary frame") + } + switch raw[0] { + case model.BinStdout: + return model.ServerFrame{Type: "stdout", Data: string(raw[1:])}, nil + case model.BinStderr: + return model.ServerFrame{Type: "stderr", Data: string(raw[1:])}, nil + case model.BinReplay: + if len(raw) < 9 { + return model.ServerFrame{}, fmt.Errorf("replay frame too short: %d bytes", len(raw)) + } + offset := int64(binary.BigEndian.Uint64(raw[1:9])) + return model.ServerFrame{Type: "replay", Data: string(raw[9:]), Offset: offset}, nil + } + return model.ServerFrame{}, fmt.Errorf("unknown binary frame type 0x%02x", raw[0]) +} + +// ptyWaitFrame reads frames until one with the given type is found. +func ptyWaitFrame(t *testing.T, conn *websocket.Conn, wantType string, timeout time.Duration) model.ServerFrame { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + f, err := ptyReadFrame(conn, time.Until(deadline)) + if err != nil { + t.Fatalf("ptyWaitFrame(%q): %v", wantType, err) + } + if f.Type == wantType { + return f + } + } + t.Fatalf("ptyWaitFrame(%q): timed out after %s", wantType, timeout) + return model.ServerFrame{} +} + +// ptyOutputContains reads frames until stdout/stderr/replay contains substr. +func ptyOutputContains(t *testing.T, conn *websocket.Conn, substr string, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + f, err := ptyReadFrame(conn, time.Until(deadline)) + if err != nil { + t.Fatalf("ptyOutputContains(%q): read error: %v", substr, err) + } + if f.Type == "stdout" || f.Type == "stderr" || f.Type == "replay" { + if strings.Contains(f.Data, substr) { + return + } + } + } + t.Fatalf("ptyOutputContains(%q): timed out", substr) +} + +// ptyWriteStdin sends a binary stdin frame (production path). +func ptyWriteStdin(t *testing.T, conn *websocket.Conn, text string) { + t.Helper() + frame := make([]byte, 1+len(text)) + frame[0] = model.BinStdin + copy(frame[1:], text) + require.NoError(t, conn.WriteMessage(websocket.BinaryMessage, frame)) +} + +// --- Tests --- + +func TestPTYWS_UnknownSessionReturns404(t *testing.T) { + srv := newPTYTestServer(t) + defer srv.Close() + + code := wsDialExpectHTTP(t, srv.URL, "/pty/nonexistent/ws", "") + require.Equal(t, http.StatusNotFound, code) +} + +func TestPTYWS_AlreadyConnectedReturns409(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + conn1 := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "") + ptyWaitFrame(t, conn1, "connected", 10*time.Second) + + code := wsDialExpectHTTP(t, srv.URL, "/pty/"+id+"/ws", "") + require.Equal(t, http.StatusConflict, code) +} + +func TestPTYWS_ConnectedFramePTYMode(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + conn := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "") + f := ptyWaitFrame(t, conn, "connected", 10*time.Second) + + require.Equal(t, "connected", f.Type) + require.Equal(t, id, f.SessionID) + require.Equal(t, "pty", f.Mode) +} + +func TestPTYWS_StdinForwarding(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + conn := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "") + ptyWaitFrame(t, conn, "connected", 10*time.Second) + + ptyWriteStdin(t, conn, "echo hello_ws\n") + ptyOutputContains(t, conn, "hello_ws", 8*time.Second) +} + +func TestPTYWS_PingPong(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + conn := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "") + ptyWaitFrame(t, conn, "connected", 10*time.Second) + + require.NoError(t, conn.WriteJSON(model.ClientFrame{Type: "ping"})) + f := ptyWaitFrame(t, conn, "pong", 5*time.Second) + require.Equal(t, "pong", f.Type) +} + +func TestPTYWS_ExitFrame(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + conn := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "") + ptyWaitFrame(t, conn, "connected", 10*time.Second) + + ptyWriteStdin(t, conn, "exit 0\n") + + f := ptyWaitFrame(t, conn, "exit", 10*time.Second) + require.Equal(t, "exit", f.Type) + require.NotNil(t, f.ExitCode) + require.Equal(t, 0, *f.ExitCode) +} + +func TestPTYWS_ReplayOnReconnect(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + + // First connection: produce output. + conn1 := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "") + ptyWaitFrame(t, conn1, "connected", 10*time.Second) + ptyWriteStdin(t, conn1, "echo replay_test\n") + ptyOutputContains(t, conn1, "replay_test", 8*time.Second) + + // Check offset via REST. + resp, err := http.Get(srv.URL + "/pty/" + id) + require.NoError(t, err) + defer resp.Body.Close() + var status model.PTYSessionStatusResponse + require.NoError(t, json.NewDecoder(resp.Body).Decode(&status)) + require.True(t, status.OutputOffset > 0) + + // Disconnect. + _ = conn1.Close() + time.Sleep(100 * time.Millisecond) + + // Reconnect from offset 0 — should receive a replay frame. + conn2 := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "since=0") + + deadline := time.Now().Add(8 * time.Second) + gotReplay := false + for time.Now().Before(deadline) { + f, err2 := ptyReadFrame(conn2, time.Until(deadline)) + if err2 != nil { + break + } + if f.Type == "replay" && strings.Contains(f.Data, "replay_test") { + gotReplay = true + break + } + } + require.True(t, gotReplay, "expected replay frame containing 'replay_test'") +} + +func TestPTYWS_ResizeFrame(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + conn := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "") + ptyWaitFrame(t, conn, "connected", 10*time.Second) + + require.NoError(t, conn.WriteJSON(model.ClientFrame{ + Type: "resize", + Cols: 120, + Rows: 40, + })) + time.Sleep(100 * time.Millisecond) + + ptyWriteStdin(t, conn, "stty size\n") + ptyOutputContains(t, conn, "40 120", 8*time.Second) +} + +func TestPTYWS_PipeModeConnectedFrame(t *testing.T) { + if _, err := exec.LookPath("bash"); err != nil { + t.Skip("bash not found") + } + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + conn := wsDialPTY(t, srv.URL, "/pty/"+id+"/ws", "pty=0") + + f := ptyWaitFrame(t, conn, "connected", 10*time.Second) + require.Equal(t, "pipe", f.Mode) +} + +func TestPTYWS_RESTGetStatus(t *testing.T) { + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + + resp, err := http.Get(srv.URL + "/pty/" + id) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + var s model.PTYSessionStatusResponse + require.NoError(t, json.NewDecoder(resp.Body).Decode(&s)) + require.Equal(t, id, s.SessionID) + require.False(t, s.Running) +} + +func TestPTYWS_RESTDeleteSession(t *testing.T) { + srv := newPTYTestServer(t) + defer srv.Close() + + id := ptyCreateSession(t, srv) + + req, err := http.NewRequest(http.MethodDelete, srv.URL+"/pty/"+id, nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + req2, _ := http.NewRequest(http.MethodDelete, srv.URL+"/pty/"+id, nil) + resp2, err := http.DefaultClient.Do(req2) + require.NoError(t, err) + _ = resp2.Body.Close() + require.Equal(t, http.StatusNotFound, resp2.StatusCode) +} diff --git a/components/execd/pkg/web/model/error.go b/components/execd/pkg/web/model/error.go index 80e0ef23..46c21ac5 100644 --- a/components/execd/pkg/web/model/error.go +++ b/components/execd/pkg/web/model/error.go @@ -26,6 +26,7 @@ const ( ErrorCodeFileNotFound ErrorCode = "FILE_NOT_FOUND" ErrorCodeUnknown ErrorCode = "UNKNOWN" ErrorCodeContextNotFound ErrorCode = "CONTEXT_NOT_FOUND" + ErrorCodeNotSupported ErrorCode = "NOT_SUPPORTED" ) type ErrorResponse struct { diff --git a/components/execd/pkg/web/model/pty.go b/components/execd/pkg/web/model/pty.go new file mode 100644 index 00000000..6d9a3e9e --- /dev/null +++ b/components/execd/pkg/web/model/pty.go @@ -0,0 +1,32 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +// CreatePTYSessionRequest is the request body for POST /pty. +type CreatePTYSessionRequest struct { + Cwd string `json:"cwd,omitempty"` +} + +// CreatePTYSessionResponse is the response for POST /pty. +type CreatePTYSessionResponse struct { + SessionID string `json:"session_id"` +} + +// PTYSessionStatusResponse is the response for GET /pty/:sessionId. +type PTYSessionStatusResponse struct { + SessionID string `json:"session_id"` + Running bool `json:"running"` + OutputOffset int64 `json:"output_offset"` +} diff --git a/components/execd/pkg/web/model/pty_ws.go b/components/execd/pkg/web/model/pty_ws.go new file mode 100644 index 00000000..fcb564e8 --- /dev/null +++ b/components/execd/pkg/web/model/pty_ws.go @@ -0,0 +1,55 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package model + +// ClientFrame is a JSON frame sent from the WebSocket client to the server. +type ClientFrame struct { + Type string `json:"type"` + Data string `json:"data,omitempty"` + Cols int `json:"cols,omitempty"` + Rows int `json:"rows,omitempty"` + Signal string `json:"signal,omitempty"` +} + +// ServerFrame is a JSON frame sent from the server to the WebSocket client. +type ServerFrame struct { + Type string `json:"type"` + SessionID string `json:"session_id,omitempty"` + Mode string `json:"mode,omitempty"` + Data string `json:"data,omitempty"` + Offset int64 `json:"offset,omitempty"` + ExitCode *int `json:"exit_code,omitempty"` + Error string `json:"error,omitempty"` + Code string `json:"code,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// Binary WebSocket frame type bytes — prefix byte for all binary frames. +const ( + BinStdin byte = 0x00 // Client → Server: raw stdin bytes + BinStdout byte = 0x01 // Server → Client: raw stdout bytes + BinStderr byte = 0x02 // Server → Client: raw stderr bytes (pipe mode) + BinReplay byte = 0x03 // Server → Client: [8 bytes int64 BE offset][raw bytes] +) + +// WebSocket error codes sent in ServerFrame.Code. +const ( + WSErrCodeSessionGone = "SESSION_GONE" + WSErrCodeStartFailed = "START_FAILED" + WSErrCodeStdinWriteFailed = "STDIN_WRITE_FAILED" + WSErrCodeInvalidFrame = "INVALID_FRAME" + WSErrCodeAlreadyConnected = "ALREADY_CONNECTED" + WSErrCodeRuntimeError = "RUNTIME_ERROR" +) diff --git a/components/execd/pkg/web/router.go b/components/execd/pkg/web/router.go index 8894257d..47f1931e 100644 --- a/components/execd/pkg/web/router.go +++ b/components/execd/pkg/web/router.go @@ -83,6 +83,14 @@ func NewRouter(accessToken string) *gin.Engine { metric.GET("/watch", withMetric(func(c *controller.MetricController) { c.WatchMetrics() })) } + pty := r.Group("/pty") + { + pty.POST("", withPTY(func(c *controller.PTYController) { c.CreatePTYSession() })) + pty.GET("/:sessionId", withPTY(func(c *controller.PTYController) { c.GetPTYSessionStatus() })) + pty.DELETE("/:sessionId", withPTY(func(c *controller.PTYController) { c.DeletePTYSession() })) + pty.GET("/:sessionId/ws", controller.PTYSessionWebSocket) + } + return r } @@ -104,6 +112,12 @@ func withMetric(fn func(*controller.MetricController)) gin.HandlerFunc { } } +func withPTY(fn func(*controller.PTYController)) gin.HandlerFunc { + return func(ctx *gin.Context) { + fn(controller.NewPTYController(ctx)) + } +} + func accessTokenMiddleware(token string) gin.HandlerFunc { return func(ctx *gin.Context) { if token == "" { From 84e9b8dd91a96237c0f82eb57133511ff8b4bc78 Mon Sep 17 00:00:00 2001 From: Jon Date: Sat, 28 Mar 2026 18:20:18 +0800 Subject: [PATCH 2/4] fix(execd/pty): close output-loss race and map unsupported to 501 --- components/execd/pkg/runtime/pty_session.go | 60 ++++++++++++++++--- .../execd/pkg/runtime/pty_session_windows.go | 3 + .../pkg/web/controller/pty_controller.go | 18 ++++++ components/execd/pkg/web/controller/pty_ws.go | 12 ++-- 4 files changed, 77 insertions(+), 16 deletions(-) diff --git a/components/execd/pkg/runtime/pty_session.go b/components/execd/pkg/runtime/pty_session.go index 84fa3493..453c9de3 100644 --- a/components/execd/pkg/runtime/pty_session.go +++ b/components/execd/pkg/runtime/pty_session.go @@ -247,9 +247,7 @@ func (s *ptySession) broadcastPTY() { for { n, err := s.ptmx.Read(buf) if n > 0 { - chunk := buf[:n] - s.replay.write(chunk) - s.fanout(chunk, true) + s.writeAndFanout(buf[:n], true) } if err != nil { // EIO or EOF when the child exits — normal termination @@ -264,9 +262,7 @@ func (s *ptySession) broadcastPipe(r *os.File, isStdout bool) { for { n, err := r.Read(buf) if n > 0 { - chunk := buf[:n] - s.replay.write(chunk) - s.fanout(chunk, isStdout) + s.writeAndFanout(buf[:n], isStdout) } if err != nil { break @@ -275,10 +271,15 @@ func (s *ptySession) broadcastPipe(r *os.File, isStdout bool) { _ = r.Close() } -// fanout delivers chunk to the current per-connection PipeWriter (if any). -// If isStdout is false, it targets stderrW instead. -func (s *ptySession) fanout(chunk []byte, isStdout bool) { +// writeAndFanout writes chunk to the replay buffer and delivers it to the +// active per-connection pipe, atomically under outMu. +// +// Holding outMu across both operations closes the window where bytes written +// to replay after ReadFrom but before AttachOutput would be silently dropped. +// Lock order is always outMu → replay.mu (both paths), so no deadlock is possible. +func (s *ptySession) writeAndFanout(chunk []byte, isStdout bool) { s.outMu.Lock() + s.replay.write(chunk) // acquires replay.mu inside (outMu → replay.mu) var w *io.PipeWriter if isStdout { w = s.stdoutW @@ -392,6 +393,47 @@ func (s *ptySession) AttachOutput() (io.Reader, io.Reader, func()) { return stdoutR, stderrR, detach } +// AttachOutputWithSnapshot atomically snapshots the replay buffer and attaches +// the per-connection output pipe, eliminating the output-loss window that exists +// when ReadFrom and AttachOutput are called separately. +// +// Must be used together with writeAndFanout (which holds outMu during both +// replay.write and the fanout pointer read). +// +// Lock order is always outMu → replay.mu (both paths), so no deadlock is possible. +// +// Returns (stdoutR, stderrR [nil in PTY mode], detach, snapshotBytes, snapshotOffset). +func (s *ptySession) AttachOutputWithSnapshot(since int64) (io.Reader, io.Reader, func(), []byte, int64) { + stdoutR, stdoutW := io.Pipe() + var stderrR io.Reader + var stderrW *io.PipeWriter + if !s.isPTY { + stderrR, stderrW = io.Pipe() + } + + s.outMu.Lock() + snapshotBytes, snapshotOffset := s.replay.ReadFrom(since) // acquires replay.mu inside + s.stdoutW = stdoutW + if stderrW != nil { + s.stderrW = stderrW + } + s.outMu.Unlock() + + detach := func() { + s.outMu.Lock() + s.stdoutW = nil + if stderrW != nil { + s.stderrW = nil + } + s.outMu.Unlock() + _ = stdoutW.Close() + if stderrW != nil { + _ = stderrW.Close() + } + } + return stdoutR, stderrR, detach, snapshotBytes, snapshotOffset +} + // SendSignal sends the named signal to the process group. // Recognised names: SIGINT, SIGTERM, SIGKILL, SIGQUIT, SIGHUP. func (s *ptySession) SendSignal(name string) { diff --git a/components/execd/pkg/runtime/pty_session_windows.go b/components/execd/pkg/runtime/pty_session_windows.go index c52fcba1..77b7bdd6 100644 --- a/components/execd/pkg/runtime/pty_session_windows.go +++ b/components/execd/pkg/runtime/pty_session_windows.go @@ -63,5 +63,8 @@ func (s *ptySession) StartPTY() error { return func (s *ptySession) StartPipe() error { return errPTYSessionNotSupported } func (s *ptySession) WriteStdin(_ []byte) (int, error) { return 0, errPTYSessionNotSupported } func (s *ptySession) AttachOutput() (io.Reader, io.Reader, func()) { return nil, nil, func() {} } +func (s *ptySession) AttachOutputWithSnapshot(_ int64) (io.Reader, io.Reader, func(), []byte, int64) { + return nil, nil, func() {}, nil, 0 +} func (s *ptySession) SendSignal(_ string) {} func (s *ptySession) ResizePTY(_, _ uint16) error { return nil } diff --git a/components/execd/pkg/web/controller/pty_controller.go b/components/execd/pkg/web/controller/pty_controller.go index fc9e5eb9..065f93ec 100644 --- a/components/execd/pkg/web/controller/pty_controller.go +++ b/components/execd/pkg/web/controller/pty_controller.go @@ -65,6 +65,15 @@ func (c *PTYController) CreatePTYSession() { // GetPTYSessionStatus handles GET /pty/:sessionId. func (c *PTYController) GetPTYSessionStatus() { + if !runtime.IsPTYSessionSupported() { + c.RespondError( + http.StatusNotImplemented, + model.ErrorCodeNotSupported, + "pty sessions are not supported on this platform", + ) + return + } + id := c.ctx.Param("sessionId") if id == "" { c.RespondError( @@ -102,6 +111,15 @@ func (c *PTYController) GetPTYSessionStatus() { // DeletePTYSession handles DELETE /pty/:sessionId. func (c *PTYController) DeletePTYSession() { + if !runtime.IsPTYSessionSupported() { + c.RespondError( + http.StatusNotImplemented, + model.ErrorCodeNotSupported, + "pty sessions are not supported on this platform", + ) + return + } + id := c.ctx.Param("sessionId") if id == "" { c.RespondError( diff --git a/components/execd/pkg/web/controller/pty_ws.go b/components/execd/pkg/web/controller/pty_ws.go index 31978ab6..72b227b5 100644 --- a/components/execd/pkg/web/controller/pty_ws.go +++ b/components/execd/pkg/web/controller/pty_ws.go @@ -49,8 +49,7 @@ const ( // 2. Acquire exclusive WS lock → 409 if already held // 3. Upgrade HTTP → WebSocket // 4. Start bash if not already running -// 5. Snapshot replay buffer BEFORE AttachOutput (prevents duplicate delivery) -// 6. AttachOutput (live pipe now active) +// 5+6. AtomicAttachOutputWithSnapshot (snapshot + attach under outMu — no loss window) // 7. defer: detach → pumpWg.Wait → UnlockWS // 8. Send replay frame if snapshot non-empty // 9. Send connected frame @@ -116,11 +115,10 @@ func PTYSessionWebSocket(ctx *gin.Context) { } } - // 5. Snapshot replay buffer BEFORE attaching live pipe. - snapshotBytes, snapshotOffset := session.ReplayBuffer().ReadFrom(since) - - // 6. Attach per-connection output pipe — live sink is now active. - stdoutR, stderrR, detach := session.AttachOutput() + // 5+6. Atomically snapshot replay buffer and attach live pipe — eliminates the + // output-loss window where bytes written between ReadFrom and AttachOutput + // would be dropped by fanout (stdoutW still nil) yet missed by snapshot. + stdoutR, stderrR, detach, snapshotBytes, snapshotOffset := session.AttachOutputWithSnapshot(since) // 7. Deferred cleanup order: detach writers → wait for pump goroutines → unlock WS. var pumpWg sync.WaitGroup From e48c639966e5cb16178cda6d1c05a5bf2aed12b7 Mon Sep 17 00:00:00 2001 From: Jon Date: Sat, 28 Mar 2026 18:40:10 +0800 Subject: [PATCH 3/4] fix(execd/pty): export PTYSession interface to satisfy codeExecutionRunner --- components/execd/pkg/runtime/pty_session.go | 39 ++++++++++++++++--- .../execd/pkg/runtime/pty_session_test.go | 2 +- .../execd/pkg/runtime/pty_session_windows.go | 23 ++++++++++- .../pkg/web/controller/codeinterpreting.go | 4 ++ .../web/controller/codeinterpreting_test.go | 5 +++ 5 files changed, 65 insertions(+), 8 deletions(-) diff --git a/components/execd/pkg/runtime/pty_session.go b/components/execd/pkg/runtime/pty_session.go index 453c9de3..049e5947 100644 --- a/components/execd/pkg/runtime/pty_session.go +++ b/components/execd/pkg/runtime/pty_session.go @@ -34,6 +34,25 @@ import ( var errPTYSessionNotSupported = errors.New("pty session not supported") +// PTYSession is the public interface for an interactive PTY/pipe session. +// The concrete implementation (*ptySession) is unexported; callers outside +// this package must use this interface. +type PTYSession interface { + LockWS() bool + UnlockWS() + IsRunning() bool + IsPTY() bool + ExitCode() int + Done() <-chan struct{} + StartPTY() error + StartPipe() error + WriteStdin(p []byte) (int, error) + AttachOutput() (io.Reader, io.Reader, func()) + AttachOutputWithSnapshot(since int64) (io.Reader, io.Reader, func(), []byte, int64) + SendSignal(name string) + ResizePTY(cols, rows uint16) error +} + // IsPTYSessionSupported reports whether PTY sessions are supported on this platform. func IsPTYSessionSupported() bool { return true } @@ -525,15 +544,16 @@ func (s *ptySession) close() { } // CreatePTYSession creates a new PTY session and stores it in the map. -func (c *Controller) CreatePTYSession(id, cwd string) *ptySession { +func (c *Controller) CreatePTYSession(id, cwd string) PTYSession { s := newPTYSession(id, cwd) c.ptySessionMap.Store(id, s) log.Info("created pty session %s", id) return s } -// GetPTYSession looks up a PTY session by ID. Returns nil if not found. -func (c *Controller) GetPTYSession(id string) *ptySession { +// getPTYSession looks up a PTY session by ID. Returns nil if not found. +// For internal use only; outside callers should use GetPTYSession. +func (c *Controller) getPTYSession(id string) *ptySession { if v, ok := c.ptySessionMap.Load(id); ok { if s, ok := v.(*ptySession); ok { return s @@ -542,10 +562,19 @@ func (c *Controller) GetPTYSession(id string) *ptySession { return nil } +// GetPTYSession looks up a PTY session by ID. Returns nil if not found. +func (c *Controller) GetPTYSession(id string) PTYSession { + s := c.getPTYSession(id) + if s == nil { + return nil + } + return s +} + // DeletePTYSession terminates and removes a PTY session. // Returns ErrContextNotFound if the session does not exist. func (c *Controller) DeletePTYSession(id string) error { - s := c.GetPTYSession(id) + s := c.getPTYSession(id) if s == nil { return ErrContextNotFound } @@ -557,7 +586,7 @@ func (c *Controller) DeletePTYSession(id string) error { // GetPTYSessionStatus returns status information for a PTY session. func (c *Controller) GetPTYSessionStatus(id string) (running bool, outputOffset int64, err error) { - s := c.GetPTYSession(id) + s := c.getPTYSession(id) if s == nil { return false, 0, ErrContextNotFound } diff --git a/components/execd/pkg/runtime/pty_session_test.go b/components/execd/pkg/runtime/pty_session_test.go index d9c2d3f4..e5cb7dee 100644 --- a/components/execd/pkg/runtime/pty_session_test.go +++ b/components/execd/pkg/runtime/pty_session_test.go @@ -308,7 +308,7 @@ func TestPTYSession_ControllerCRUD(t *testing.T) { got := c.GetPTYSession(id) require.NotNil(t, got) - require.Equal(t, id, got.id) + require.Equal(t, id, got.(*ptySession).id) running, offset, err := c.GetPTYSessionStatus(id) require.NoError(t, err) diff --git a/components/execd/pkg/runtime/pty_session_windows.go b/components/execd/pkg/runtime/pty_session_windows.go index 77b7bdd6..2d65810e 100644 --- a/components/execd/pkg/runtime/pty_session_windows.go +++ b/components/execd/pkg/runtime/pty_session_windows.go @@ -26,6 +26,25 @@ import ( // All methods the controller layer calls must be present here so Windows cross-compilation succeeds. type ptySession struct{} +// PTYSession is the public interface for an interactive PTY/pipe session. +// The concrete implementation (*ptySession) is unexported; callers outside +// this package must use this interface. +type PTYSession interface { + LockWS() bool + UnlockWS() + IsRunning() bool + IsPTY() bool + ExitCode() int + Done() <-chan struct{} + StartPTY() error + StartPipe() error + WriteStdin(p []byte) (int, error) + AttachOutput() (io.Reader, io.Reader, func()) + AttachOutputWithSnapshot(since int64) (io.Reader, io.Reader, func(), []byte, int64) + SendSignal(name string) + ResizePTY(cols, rows uint16) error +} + var errPTYSessionNotSupported = errors.New("pty session is not supported on windows") // IsPTYSessionSupported reports whether PTY sessions are supported on this platform. @@ -35,10 +54,10 @@ func IsPTYSessionSupported() bool { return false } func NewPTYSessionID() string { return "" } // CreatePTYSession is not supported on Windows. -func (c *Controller) CreatePTYSession(id, cwd string) *ptySession { return nil } //nolint:revive +func (c *Controller) CreatePTYSession(id, cwd string) PTYSession { return nil } //nolint:revive // GetPTYSession is not supported on Windows. -func (c *Controller) GetPTYSession(id string) *ptySession { return nil } //nolint:revive +func (c *Controller) GetPTYSession(id string) PTYSession { return nil } //nolint:revive // DeletePTYSession is not supported on Windows. func (c *Controller) DeletePTYSession(id string) error { //nolint:revive diff --git a/components/execd/pkg/web/controller/codeinterpreting.go b/components/execd/pkg/web/controller/codeinterpreting.go index e9023605..e30131f3 100644 --- a/components/execd/pkg/web/controller/codeinterpreting.go +++ b/components/execd/pkg/web/controller/codeinterpreting.go @@ -58,6 +58,10 @@ type codeExecutionRunner interface { SeekBackgroundCommandOutput(session string, cursor int64) ([]byte, int64, error) DeleteBashSession(sessionID string) error Interrupt(sessionID string) error + CreatePTYSession(id, cwd string) runtime.PTYSession + GetPTYSession(id string) runtime.PTYSession + DeletePTYSession(id string) error + GetPTYSessionStatus(id string) (bool, int64, error) } func NewCodeInterpretingController(ctx *gin.Context) *CodeInterpretingController { diff --git a/components/execd/pkg/web/controller/codeinterpreting_test.go b/components/execd/pkg/web/controller/codeinterpreting_test.go index f97c6eef..417bd1f4 100644 --- a/components/execd/pkg/web/controller/codeinterpreting_test.go +++ b/components/execd/pkg/web/controller/codeinterpreting_test.go @@ -89,6 +89,11 @@ func (f *fakeCodeRunner) Interrupt(_ string) error { return nil } +func (f *fakeCodeRunner) CreatePTYSession(_ string, _ string) runtime.PTYSession { return nil } +func (f *fakeCodeRunner) GetPTYSession(_ string) runtime.PTYSession { return nil } +func (f *fakeCodeRunner) DeletePTYSession(_ string) error { return nil } +func (f *fakeCodeRunner) GetPTYSessionStatus(_ string) (bool, int64, error) { return false, 0, nil } + func TestBuildExecuteCodeRequestDefaultsToCommand(t *testing.T) { ctrl := &CodeInterpretingController{} req := model.RunCodeRequest{ From 9ce847a6216112ea9af0fa1da0fcb6bc50e81238 Mon Sep 17 00:00:00 2001 From: Jon Date: Sat, 28 Mar 2026 19:11:17 +0800 Subject: [PATCH 4/4] fix(execd/pty): fix readFrom casing and refactor WS read loop --- components/execd/pkg/runtime/pty_session.go | 2 - .../execd/pkg/runtime/pty_session_test.go | 10 +- .../execd/pkg/runtime/pty_session_windows.go | 24 +- components/execd/pkg/runtime/replay_buffer.go | 5 - .../execd/pkg/runtime/replay_buffer_test.go | 20 +- components/execd/pkg/web/controller/pty_ws.go | 237 ++++++++++-------- 6 files changed, 160 insertions(+), 138 deletions(-) diff --git a/components/execd/pkg/runtime/pty_session.go b/components/execd/pkg/runtime/pty_session.go index 049e5947..74a0e38e 100644 --- a/components/execd/pkg/runtime/pty_session.go +++ b/components/execd/pkg/runtime/pty_session.go @@ -32,8 +32,6 @@ import ( "github.com/alibaba/opensandbox/execd/pkg/log" ) -var errPTYSessionNotSupported = errors.New("pty session not supported") - // PTYSession is the public interface for an interactive PTY/pipe session. // The concrete implementation (*ptySession) is unexported; callers outside // this package must use this interface. diff --git a/components/execd/pkg/runtime/pty_session_test.go b/components/execd/pkg/runtime/pty_session_test.go index e5cb7dee..f7604ddc 100644 --- a/components/execd/pkg/runtime/pty_session_test.go +++ b/components/execd/pkg/runtime/pty_session_test.go @@ -34,7 +34,7 @@ func replayContains(t *testing.T, s *ptySession, substr string, timeout time.Dur t.Helper() deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { - data, _ := s.replay.readFrom(0) + data, _ := s.replay.ReadFrom(0) if strings.Contains(string(data), substr) { return true } @@ -124,7 +124,7 @@ func TestPTYSession_ANSISequences(t *testing.T) { "expected 'GREEN' in replay buffer") // PTY mode should propagate ESC bytes verbatim. - data, _ := s.replay.readFrom(0) + data, _ := s.replay.ReadFrom(0) assert.Contains(t, string(data), "\x1b", "expected ESC bytes in PTY output") } @@ -187,12 +187,12 @@ func TestPTYSession_ReconnectReplay(t *testing.T) { require.True(t, replayContains(t, s, "first_output", 5*time.Second), "expected 'first_output' in replay buffer") - snapshot, snapshotOff := s.replay.readFrom(0) + snapshot, snapshotOff := s.replay.ReadFrom(0) detach1() time.Sleep(50 * time.Millisecond) // Reconnect — replay from offset 0 should return the same bytes we snapshotted. - replay, replayOff := s.replay.readFrom(0) + replay, replayOff := s.replay.ReadFrom(0) require.Equal(t, snapshotOff, replayOff) // The new snapshot may be larger (more output arrived), but must contain snapshot. require.True(t, strings.HasPrefix(string(replay), string(snapshot)) || len(replay) >= len(snapshot), @@ -212,7 +212,7 @@ func TestPTYSession_ReconnectReplay(t *testing.T) { "expected 'second_output' in replay buffer") // Delta replay from offset-after-first should contain only new bytes. - newData, _ := s.replay.readFrom(offsetAfterFirst) + newData, _ := s.replay.ReadFrom(offsetAfterFirst) require.True(t, strings.Contains(string(newData), "second_output"), "delta replay should contain 'second_output', got %q", string(newData)) } diff --git a/components/execd/pkg/runtime/pty_session_windows.go b/components/execd/pkg/runtime/pty_session_windows.go index 2d65810e..ed534685 100644 --- a/components/execd/pkg/runtime/pty_session_windows.go +++ b/components/execd/pkg/runtime/pty_session_windows.go @@ -71,19 +71,19 @@ func (c *Controller) GetPTYSessionStatus(id string) (bool, int64, error) { //nol // Method stubs so the controller layer can call them without build-tag guards. -func (s *ptySession) LockWS() bool { return false } -func (s *ptySession) UnlockWS() {} -func (s *ptySession) IsRunning() bool { return false } -func (s *ptySession) IsPTY() bool { return false } -func (s *ptySession) ExitCode() int { return -1 } -func (s *ptySession) Done() <-chan struct{} { return nil } -func (s *ptySession) ReplayBuffer() *replayBuffer { return nil } -func (s *ptySession) StartPTY() error { return errPTYSessionNotSupported } -func (s *ptySession) StartPipe() error { return errPTYSessionNotSupported } -func (s *ptySession) WriteStdin(_ []byte) (int, error) { return 0, errPTYSessionNotSupported } +func (s *ptySession) LockWS() bool { return false } +func (s *ptySession) UnlockWS() {} +func (s *ptySession) IsRunning() bool { return false } +func (s *ptySession) IsPTY() bool { return false } +func (s *ptySession) ExitCode() int { return -1 } +func (s *ptySession) Done() <-chan struct{} { return nil } +func (s *ptySession) ReplayBuffer() *replayBuffer { return nil } +func (s *ptySession) StartPTY() error { return errPTYSessionNotSupported } +func (s *ptySession) StartPipe() error { return errPTYSessionNotSupported } +func (s *ptySession) WriteStdin(_ []byte) (int, error) { return 0, errPTYSessionNotSupported } func (s *ptySession) AttachOutput() (io.Reader, io.Reader, func()) { return nil, nil, func() {} } func (s *ptySession) AttachOutputWithSnapshot(_ int64) (io.Reader, io.Reader, func(), []byte, int64) { return nil, nil, func() {}, nil, 0 } -func (s *ptySession) SendSignal(_ string) {} -func (s *ptySession) ResizePTY(_, _ uint16) error { return nil } +func (s *ptySession) SendSignal(_ string) {} +func (s *ptySession) ResizePTY(_, _ uint16) error { return nil } diff --git a/components/execd/pkg/runtime/replay_buffer.go b/components/execd/pkg/runtime/replay_buffer.go index 5285bea7..5d79e377 100644 --- a/components/execd/pkg/runtime/replay_buffer.go +++ b/components/execd/pkg/runtime/replay_buffer.go @@ -76,11 +76,6 @@ func (r *replayBuffer) Total() int64 { return r.total } -// readFrom is the unexported thin wrapper used inside the runtime package. -func (r *replayBuffer) readFrom(offset int64) ([]byte, int64) { - return r.ReadFrom(offset) -} - // ReadFrom returns a snapshot of all bytes starting from the given absolute byte offset. // // Returns (data, actualOffset) where actualOffset is the offset of the first returned byte: diff --git a/components/execd/pkg/runtime/replay_buffer_test.go b/components/execd/pkg/runtime/replay_buffer_test.go index 4c0227fd..e9438e80 100644 --- a/components/execd/pkg/runtime/replay_buffer_test.go +++ b/components/execd/pkg/runtime/replay_buffer_test.go @@ -27,7 +27,7 @@ func TestReplayBuffer_BasicWriteRead(t *testing.T) { rb.write([]byte("hello")) rb.write([]byte(" world")) - data, off := rb.readFrom(0) + data, off := rb.ReadFrom(0) require.Equal(t, int64(0), off) require.Equal(t, []byte("hello world"), data) require.Equal(t, int64(11), rb.Total()) @@ -37,7 +37,7 @@ func TestReplayBuffer_ReadFromMiddle(t *testing.T) { rb := newReplayBuffer() rb.write([]byte("abcde")) - data, off := rb.readFrom(2) + data, off := rb.ReadFrom(2) require.Equal(t, int64(2), off) require.Equal(t, []byte("cde"), data) } @@ -46,7 +46,7 @@ func TestReplayBuffer_ReadFromCurrent(t *testing.T) { rb := newReplayBuffer() rb.write([]byte("abc")) - data, off := rb.readFrom(3) + data, off := rb.ReadFrom(3) require.Nil(t, data, "should return nil when caught up") require.Equal(t, int64(3), off) } @@ -66,12 +66,12 @@ func TestReplayBuffer_CircularEviction(t *testing.T) { require.Equal(t, int64(10), rb.Total()) // offset 0 should be clamped to oldest=2 - data, off := rb.readFrom(0) + data, off := rb.ReadFrom(0) require.Equal(t, int64(2), off) require.Equal(t, []byte("cdefghij"), data) // Read from offset 5 (within retained range) - data, off = rb.readFrom(5) + data, off = rb.ReadFrom(5) require.Equal(t, int64(5), off) require.Equal(t, []byte("fghij"), data) } @@ -85,12 +85,12 @@ func TestReplayBuffer_LargeGap(t *testing.T) { rb.write([]byte("ABCDEF")) // Requesting from 0 should clamp to oldest=2 - data, off := rb.readFrom(0) + data, off := rb.ReadFrom(0) require.Equal(t, int64(2), off) require.Equal(t, []byte("CDEF"), data) // Requesting from 1 should also clamp to oldest=2 - data, off = rb.readFrom(1) + data, off = rb.ReadFrom(1) require.Equal(t, int64(2), off) require.Equal(t, []byte("CDEF"), data) } @@ -114,7 +114,7 @@ func TestReplayBuffer_Concurrent(t *testing.T) { go func() { defer wg.Done() for range 32 { - rb.readFrom(0) + rb.ReadFrom(0) rb.Total() } }() @@ -133,7 +133,7 @@ func TestReplayBuffer_ExactlyFull(t *testing.T) { rb.write([]byte("1234")) require.Equal(t, int64(4), rb.Total()) - data, off := rb.readFrom(0) + data, off := rb.ReadFrom(0) require.Equal(t, int64(0), off) require.Equal(t, []byte("1234"), data) } @@ -148,7 +148,7 @@ func TestReplayBuffer_WriteWrapsCorrectly(t *testing.T) { // Write "EF" — evicts "AB", retained "CDEF" rb.write([]byte("EF")) - data, off := rb.readFrom(0) + data, off := rb.ReadFrom(0) require.Equal(t, int64(2), off, "offset should be clamped to oldest=2") require.Equal(t, []byte("CDEF"), data) } diff --git a/components/execd/pkg/web/controller/pty_ws.go b/components/execd/pkg/web/controller/pty_ws.go index 72b227b5..6273557f 100644 --- a/components/execd/pkg/web/controller/pty_ws.go +++ b/components/execd/pkg/web/controller/pty_ws.go @@ -27,6 +27,7 @@ import ( "github.com/gorilla/websocket" "github.com/alibaba/opensandbox/execd/pkg/log" + "github.com/alibaba/opensandbox/execd/pkg/runtime" "github.com/alibaba/opensandbox/execd/pkg/web/model" ) @@ -186,89 +187,152 @@ func PTYSessionWebSocket(ctx *gin.Context) { } // 10a. RFC 6455 binary ping goroutine (30 s interval). - go func() { - t := time.NewTicker(wsPingInterval) - defer t.Stop() - for { - select { - case <-cancelCh: - return - case <-t.C: - connMu.Lock() - _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) - pingErr := conn.WriteMessage(websocket.PingMessage, nil) - connMu.Unlock() - if pingErr != nil { - cancelOnce() - return - } - } - } - }() - - // streamPump reads raw chunks from r and sends them as binary frames over WS. - streamPump := func(r io.Reader, typeByte byte, name string) { - defer pumpWg.Done() - const chunkSize = 32 * 1024 - frame := make([]byte, 1+chunkSize) // single allocation for session lifetime - frame[0] = typeByte - for { - select { - case <-cancelCh: - return - default: - } - n, readErr := r.Read(frame[1:]) - if n > 0 { - connMu.Lock() - _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) - writeErr := conn.WriteMessage(websocket.BinaryMessage, frame[:1+n]) - connMu.Unlock() - if writeErr != nil { - log.Warning("pty ws write %s for session %s: %v", name, id, writeErr) - cancelOnce() - return - } - } - if readErr != nil { - // io.EOF or io.ErrClosedPipe when detach() closes the PipeWriter. - return - } - } - } + go ptyPingLoop(conn, &connMu, cancelCh, cancelOnce) // 10b. Launch stdout pump. pumpWg.Add(1) - go streamPump(stdoutR, model.BinStdout, "stdout") + go ptyStreamPump(stdoutR, model.BinStdout, "stdout", id, conn, &connMu, &pumpWg, cancelCh, cancelOnce) // 10c. Launch stderr pump (pipe mode only). if stderrR != nil { pumpWg.Add(1) - go streamPump(stderrR, model.BinStderr, "stderr") + go ptyStreamPump(stderrR, model.BinStderr, "stderr", id, conn, &connMu, &pumpWg, cancelCh, cancelOnce) } // 10d. Exit watcher: waits for the process to exit, then sends exit frame // and closes the WS connection immediately (unblocks ReadJSON in the read loop). - go func() { - doneCh := session.Done() - if doneCh == nil { + go ptyExitWatcher(session, writeJSON, closeConn, cancelCh, cancelOnce) + + // 11. Client read loop. + ptyClientReadLoop(conn, session, id, writeJSON, cancelCh, cancelOnce) +} + +// ptyPingLoop sends periodic WebSocket pings until cancelCh is closed. +func ptyPingLoop(conn *websocket.Conn, connMu *sync.Mutex, cancelCh <-chan struct{}, cancelOnce func()) { + t := time.NewTicker(wsPingInterval) + defer t.Stop() + for { + select { + case <-cancelCh: return + case <-t.C: + connMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + pingErr := conn.WriteMessage(websocket.PingMessage, nil) + connMu.Unlock() + if pingErr != nil { + cancelOnce() + return + } } + } +} + +// ptyStreamPump reads raw chunks from r and sends them as binary frames over WS. +func ptyStreamPump(r io.Reader, typeByte byte, name, id string, conn *websocket.Conn, connMu *sync.Mutex, pumpWg *sync.WaitGroup, cancelCh <-chan struct{}, cancelOnce func()) { + defer pumpWg.Done() + const chunkSize = 32 * 1024 + frame := make([]byte, 1+chunkSize) // single allocation for session lifetime + frame[0] = typeByte + for { select { - case <-doneCh: case <-cancelCh: return + default: } - exitCode := session.ExitCode() - _ = writeJSON(model.ServerFrame{ - Type: "exit", - ExitCode: &exitCode, - }) - closeConn(websocket.CloseNormalClosure, "process exited") + n, readErr := r.Read(frame[1:]) + if n > 0 { + connMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(wsWriteDeadline)) + writeErr := conn.WriteMessage(websocket.BinaryMessage, frame[:1+n]) + connMu.Unlock() + if writeErr != nil { + log.Warning("pty ws write %s for session %s: %v", name, id, writeErr) + cancelOnce() + return + } + } + if readErr != nil { + // io.EOF or io.ErrClosedPipe when detach() closes the PipeWriter. + return + } + } +} + +// ptyExitWatcher waits for the session process to exit, then sends an exit frame +// and closes the WS connection. +func ptyExitWatcher(session runtime.PTYSession, writeJSON func(any) error, closeConn func(int, string), cancelCh <-chan struct{}, cancelOnce func()) { + doneCh := session.Done() + if doneCh == nil { + return + } + select { + case <-doneCh: + case <-cancelCh: + return + } + exitCode := session.ExitCode() + _ = writeJSON(model.ServerFrame{ + Type: "exit", + ExitCode: &exitCode, + }) + closeConn(websocket.CloseNormalClosure, "process exited") + cancelOnce() +} + +// ptyHandleBinaryMsg processes an incoming binary WebSocket frame from the client. +// Returns true if the connection should be terminated. +func ptyHandleBinaryMsg(session runtime.PTYSession, data []byte, writeJSON func(any) error, cancelOnce func()) bool { + if len(data) == 0 { + return false + } + if data[0] != model.BinStdin { + return false // only stdin expected C→S + } + if _, writeErr := session.WriteStdin(data[1:]); writeErr != nil { + _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeStdinWriteFailed, + Error: writeErr.Error()}) cancelOnce() - }() + return true + } + return false +} - // 11. Client read loop. +// ptyHandleTextMsg processes an incoming text WebSocket frame from the client. +// Returns true if the connection should be terminated. +func ptyHandleTextMsg(session runtime.PTYSession, id string, data []byte, writeJSON func(any) error, cancelOnce func()) bool { + var frame model.ClientFrame + if json.Unmarshal(data, &frame) != nil { + return false + } + switch frame.Type { + case "stdin": + // wscat / debug fallback: plain UTF-8 text, no base64. + if _, writeErr := session.WriteStdin([]byte(frame.Data)); writeErr != nil { + _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeStdinWriteFailed, + Error: writeErr.Error()}) + cancelOnce() + return true + } + case "signal": + session.SendSignal(frame.Signal) + case "resize": + if frame.Cols > 0 && frame.Rows > 0 { + if resErr := session.ResizePTY(uint16(frame.Cols), uint16(frame.Rows)); resErr != nil { + log.Warning("pty resize session %s: %v", id, resErr) + } + } + case "ping": + _ = writeJSON(model.ServerFrame{Type: "pong"}) + default: + _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeInvalidFrame, + Error: fmt.Sprintf("unknown frame type %q", frame.Type)}) + } + return false +} + +// ptyClientReadLoop processes incoming WebSocket messages until the connection closes. +func ptyClientReadLoop(conn *websocket.Conn, session runtime.PTYSession, id string, writeJSON func(any) error, cancelCh <-chan struct{}, cancelOnce func()) { for { select { case <-cancelCh: @@ -276,8 +340,8 @@ func PTYSessionWebSocket(ctx *gin.Context) { default: } - msgType, data, err2 := conn.ReadMessage() - if err2 != nil { + msgType, data, err := conn.ReadMessage() + if err != nil { cancelOnce() return } @@ -286,48 +350,13 @@ func PTYSessionWebSocket(ctx *gin.Context) { _ = conn.SetReadDeadline(time.Now().Add(wsReadDeadline)) switch msgType { - case websocket.BinaryMessage: - if len(data) == 0 { - continue - } - if data[0] != model.BinStdin { - continue // only stdin expected C→S - } - if _, writeErr := session.WriteStdin(data[1:]); writeErr != nil { - _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeStdinWriteFailed, - Error: writeErr.Error()}) - cancelOnce() + if ptyHandleBinaryMsg(session, data, writeJSON, cancelOnce) { return } - case websocket.TextMessage: - var frame model.ClientFrame - if json.Unmarshal(data, &frame) != nil { - continue - } - switch frame.Type { - case "stdin": - // wscat / debug fallback: plain UTF-8 text, no base64. - if _, writeErr := session.WriteStdin([]byte(frame.Data)); writeErr != nil { - _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeStdinWriteFailed, - Error: writeErr.Error()}) - cancelOnce() - return - } - case "signal": - session.SendSignal(frame.Signal) - case "resize": - if frame.Cols > 0 && frame.Rows > 0 { - if resErr := session.ResizePTY(uint16(frame.Cols), uint16(frame.Rows)); resErr != nil { - log.Warning("pty resize session %s: %v", id, resErr) - } - } - case "ping": - _ = writeJSON(model.ServerFrame{Type: "pong"}) - default: - _ = writeJSON(model.ServerFrame{Type: "error", Code: model.WSErrCodeInvalidFrame, - Error: fmt.Sprintf("unknown frame type %q", frame.Type)}) + if ptyHandleTextMsg(session, id, data, writeJSON, cancelOnce) { + return } } }