From 2327158109afd5895f84ad8cd70dc4b7e86da3ad Mon Sep 17 00:00:00 2001 From: "skyler.su" Date: Thu, 26 Mar 2026 16:11:40 +0800 Subject: [PATCH 1/3] feat(execd): tune jupyter idle polling and sse completion wait --- components/execd/pkg/flag/flags.go | 4 +++ components/execd/pkg/flag/parser.go | 15 +++++++++ .../execd/pkg/jupyter/execute/execute.go | 17 ++++++++-- .../pkg/web/controller/codeinterpreting.go | 33 +++++++++++++++++-- 4 files changed, 64 insertions(+), 5 deletions(-) diff --git a/components/execd/pkg/flag/flags.go b/components/execd/pkg/flag/flags.go index b1442954f..9974cf8ce 100644 --- a/components/execd/pkg/flag/flags.go +++ b/components/execd/pkg/flag/flags.go @@ -34,4 +34,8 @@ var ( // ApiGracefulShutdownTimeout waits before tearing down SSE streams. ApiGracefulShutdownTimeout time.Duration + + // JupyterIdlePollInterval controls how often ExecuteCodeStream checks for + // late execute_result/error messages after receiving idle status. + JupyterIdlePollInterval time.Duration ) diff --git a/components/execd/pkg/flag/parser.go b/components/execd/pkg/flag/parser.go index c1c0dc080..67ee54e01 100644 --- a/components/execd/pkg/flag/parser.go +++ b/components/execd/pkg/flag/parser.go @@ -28,6 +28,7 @@ const ( jupyterHostEnv = "JUPYTER_HOST" jupyterTokenEnv = "JUPYTER_TOKEN" gracefulShutdownTimeoutEnv = "EXECD_API_GRACE_SHUTDOWN" + jupyterIdlePollIntervalEnv = "EXECD_JUPYTER_IDLE_POLL_INTERVAL" ) // InitFlags registers CLI flags and env overrides. @@ -37,6 +38,7 @@ func InitFlags() { ServerLogLevel = 6 ServerAccessToken = "" ApiGracefulShutdownTimeout = time.Second * 1 + JupyterIdlePollInterval = 10 * time.Millisecond // First, set default values from environment variables if jupyterFromEnv := os.Getenv(jupyterHostEnv); jupyterFromEnv != "" { @@ -65,7 +67,20 @@ func InitFlags() { ApiGracefulShutdownTimeout = duration } + if idlePollInterval := os.Getenv(jupyterIdlePollIntervalEnv); idlePollInterval != "" { + duration, err := time.ParseDuration(idlePollInterval) + if err != nil { + stdlog.Panicf("Failed to parse jupyter idle poll interval from env: %v", err) + } + if duration <= 0 { + stdlog.Printf("Invalid %s=%s; fallback to default %s", jupyterIdlePollIntervalEnv, idlePollInterval, JupyterIdlePollInterval) + } else { + JupyterIdlePollInterval = duration + } + } + flag.DurationVar(&ApiGracefulShutdownTimeout, "graceful-shutdown-timeout", ApiGracefulShutdownTimeout, "API graceful shutdown timeout duration (default: 3s)") + flag.DurationVar(&JupyterIdlePollInterval, "jupyter-idle-poll-interval", JupyterIdlePollInterval, "Polling interval after Jupyter idle status before closing stream (default: 10ms)") // Parse flags - these will override environment variables if provided flag.Parse() diff --git a/components/execd/pkg/jupyter/execute/execute.go b/components/execd/pkg/jupyter/execute/execute.go index 23a65e94b..0c4096961 100644 --- a/components/execd/pkg/jupyter/execute/execute.go +++ b/components/execd/pkg/jupyter/execute/execute.go @@ -25,6 +25,8 @@ import ( "github.com/google/uuid" "github.com/gorilla/websocket" + + execdflag "github.com/alibaba/opensandbox/execd/pkg/flag" ) // HTTPClient defines the HTTP client interface @@ -268,8 +270,19 @@ func (c *Client) ExecuteCodeStream(code string, resultChan chan *ExecutionResult resultChan <- notify resultMutex.Unlock() - for result.ExecutionCount <= 0 && result.Error == nil { - time.Sleep(300 * time.Millisecond) + pollInterval := execdflag.JupyterIdlePollInterval + if pollInterval <= 0 { + pollInterval = 10 * time.Millisecond + } + + for { + resultMutex.Lock() + done := result.ExecutionCount > 0 || result.Error != nil + resultMutex.Unlock() + if done { + break + } + time.Sleep(pollInterval) } // Close result channel diff --git a/components/execd/pkg/web/controller/codeinterpreting.go b/components/execd/pkg/web/controller/codeinterpreting.go index b0facef4e..73ec70960 100644 --- a/components/execd/pkg/web/controller/codeinterpreting.go +++ b/components/execd/pkg/web/controller/codeinterpreting.go @@ -113,6 +113,16 @@ func (c *CodeInterpretingController) RunCode() { defer cancel() runCodeRequest := c.buildExecuteCodeRequest(request) eventsHandler := c.setServerEventsHandler(ctx) + + // completeCh is closed when OnExecuteComplete fires, meaning the final SSE + // event has been written and flushed. We only wait for this callback as a + // safety check and then return immediately to avoid fixed tail latency. + completeCh := make(chan struct{}) + origComplete := eventsHandler.OnExecuteComplete + eventsHandler.OnExecuteComplete = func(executionTime time.Duration) { + origComplete(executionTime) + close(completeCh) + } runCodeRequest.Hooks = eventsHandler c.setupSSEResponse() @@ -126,7 +136,10 @@ func (c *CodeInterpretingController) RunCode() { return } - time.Sleep(flag.ApiGracefulShutdownTimeout) + select { + case <-completeCh: + case <-time.After(flag.ApiGracefulShutdownTimeout): + } } // GetContext returns a specific code context by id. @@ -305,7 +318,18 @@ func (c *CodeInterpretingController) RunInSession() { } ctx, cancel := context.WithCancel(c.ctx.Request.Context()) defer cancel() - runReq.Hooks = c.setServerEventsHandler(ctx) + + // completeCh is closed when OnExecuteComplete fires, meaning the final SSE + // event has been written and flushed. We only wait for this callback as a + // safety check and then return immediately to avoid fixed tail latency. + completeCh := make(chan struct{}) + hooks := c.setServerEventsHandler(ctx) + origComplete := hooks.OnExecuteComplete + hooks.OnExecuteComplete = func(executionTime time.Duration) { + origComplete(executionTime) + close(completeCh) + } + runReq.Hooks = hooks c.setupSSEResponse() err := codeRunner.RunInBashSession(ctx, runReq) @@ -318,7 +342,10 @@ func (c *CodeInterpretingController) RunInSession() { return } - time.Sleep(flag.ApiGracefulShutdownTimeout) + select { + case <-completeCh: + case <-time.After(flag.ApiGracefulShutdownTimeout): + } } // DeleteSession deletes a bash session (delete_session API). From 6c9a91407b501bbcec39fae3d1f6e74791c9dbc6 Mon Sep 17 00:00:00 2001 From: "skyler.su" Date: Thu, 26 Mar 2026 16:59:59 +0800 Subject: [PATCH 2/3] feat(execd): solve the issues from github copilot review --- components/execd/pkg/flag/parser.go | 6 +- components/execd/pkg/flag/parser_test.go | 41 +++ .../execd/pkg/jupyter/execute/execute_test.go | 142 +++++++++++ .../pkg/web/controller/codeinterpreting.go | 72 +++++- .../web/controller/codeinterpreting_test.go | 234 ++++++++++++++++++ 5 files changed, 483 insertions(+), 12 deletions(-) create mode 100644 components/execd/pkg/flag/parser_test.go diff --git a/components/execd/pkg/flag/parser.go b/components/execd/pkg/flag/parser.go index 67ee54e01..d24b9e8e0 100644 --- a/components/execd/pkg/flag/parser.go +++ b/components/execd/pkg/flag/parser.go @@ -79,11 +79,15 @@ func InitFlags() { } } - flag.DurationVar(&ApiGracefulShutdownTimeout, "graceful-shutdown-timeout", ApiGracefulShutdownTimeout, "API graceful shutdown timeout duration (default: 3s)") + flag.DurationVar(&ApiGracefulShutdownTimeout, "graceful-shutdown-timeout", ApiGracefulShutdownTimeout, "API graceful shutdown timeout duration (default: 1s)") flag.DurationVar(&JupyterIdlePollInterval, "jupyter-idle-poll-interval", JupyterIdlePollInterval, "Polling interval after Jupyter idle status before closing stream (default: 10ms)") // Parse flags - these will override environment variables if provided flag.Parse() + if JupyterIdlePollInterval <= 0 { + stdlog.Printf("Invalid --jupyter-idle-poll-interval=%s; fallback to default %s", JupyterIdlePollInterval, 10*time.Millisecond) + JupyterIdlePollInterval = 10 * time.Millisecond + } // Log final values log.Info("Jupyter server host is: %s", JupyterServerHost) diff --git a/components/execd/pkg/flag/parser_test.go b/components/execd/pkg/flag/parser_test.go new file mode 100644 index 000000000..23f2a70e6 --- /dev/null +++ b/components/execd/pkg/flag/parser_test.go @@ -0,0 +1,41 @@ +// Copyright 2025 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flag + +import ( + "flag" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestInitFlagsSanitizesNonPositiveJupyterIdlePollIntervalFromCLI(t *testing.T) { + previousArgs := os.Args + previousCommandLine := flag.CommandLine + defaultPollInterval := 10 * time.Millisecond + + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + os.Args = []string{previousArgs[0], "--jupyter-idle-poll-interval=0"} + t.Cleanup(func() { + os.Args = previousArgs + flag.CommandLine = previousCommandLine + }) + + InitFlags() + + require.Equal(t, defaultPollInterval, JupyterIdlePollInterval) +} diff --git a/components/execd/pkg/jupyter/execute/execute_test.go b/components/execd/pkg/jupyter/execute/execute_test.go index b23a3a9e3..08ef456a0 100644 --- a/components/execd/pkg/jupyter/execute/execute_test.go +++ b/components/execd/pkg/jupyter/execute/execute_test.go @@ -23,6 +23,9 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + + execdflag "github.com/alibaba/opensandbox/execd/pkg/flag" ) // Create WebSocket test server @@ -156,3 +159,142 @@ func TestExecuteCodeStream(t *testing.T) { t.Errorf("expected at least 4 results, got %d", resultCount) } } + +func TestExecuteCodeStreamWaitsForLateExecuteResultUsingConfiguredPollInterval(t *testing.T) { + previousPollInterval := execdflag.JupyterIdlePollInterval + execdflag.JupyterIdlePollInterval = time.Millisecond + t.Cleanup(func() { + execdflag.JupyterIdlePollInterval = previousPollInterval + }) + + server := createTestServer(t, func(conn *websocket.Conn) { + var executeRequest Message + err := conn.ReadJSON(&executeRequest) + if err != nil { + t.Fatalf("failed to read execution request: %v", err) + } + + statusContent, _ := json.Marshal(StatusUpdate{ExecutionState: StateIdle}) + statusMsg := Message{ + Header: Header{ + MessageID: "status-msg-id", + Session: executeRequest.Header.Session, + MessageType: string(MsgStatus), + }, + ParentHeader: executeRequest.Header, + Content: json.RawMessage(statusContent), + } + require.NoError(t, conn.WriteJSON(statusMsg)) + + time.Sleep(15 * time.Millisecond) + + resultContent, _ := json.Marshal(ExecuteResult{ + ExecutionCount: 1, + Data: map[string]interface{}{ + "text/plain": "Completed late", + }, + Metadata: map[string]interface{}{}, + }) + executeResultMsg := Message{ + Header: Header{ + MessageID: "result-msg-id", + Session: executeRequest.Header.Session, + MessageType: string(MsgExecuteResult), + }, + ParentHeader: executeRequest.Header, + Content: json.RawMessage(resultContent), + } + require.NoError(t, conn.WriteJSON(executeResultMsg)) + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/kernels/test-kernel-id/channels" + executor := NewExecutor(wsURL, nil) + require.NoError(t, executor.Connect()) + defer executor.Disconnect() + + resultChan := make(chan *ExecutionResult, 10) + require.NoError(t, executor.ExecuteCodeStream("print('late result')", resultChan)) + + start := time.Now() + var gotLateResult bool + for result := range resultChan { + if result != nil && result.ExecutionCount == 1 { + gotLateResult = true + } + } + elapsed := time.Since(start) + + require.True(t, gotLateResult, "expected late execute_result to be delivered before stream close") + require.Less(t, elapsed, 100*time.Millisecond, "expected stream to close promptly after late execute_result") +} + +func TestExecuteCodeStreamFallsBackWhenPollIntervalIsNonPositive(t *testing.T) { + previousPollInterval := execdflag.JupyterIdlePollInterval + execdflag.JupyterIdlePollInterval = 0 + t.Cleanup(func() { + execdflag.JupyterIdlePollInterval = previousPollInterval + }) + + server := createTestServer(t, func(conn *websocket.Conn) { + var executeRequest Message + err := conn.ReadJSON(&executeRequest) + if err != nil { + t.Fatalf("failed to read execution request: %v", err) + } + + statusContent, _ := json.Marshal(StatusUpdate{ExecutionState: StateIdle}) + statusMsg := Message{ + Header: Header{ + MessageID: "status-msg-id", + Session: executeRequest.Header.Session, + MessageType: string(MsgStatus), + }, + ParentHeader: executeRequest.Header, + Content: json.RawMessage(statusContent), + } + require.NoError(t, conn.WriteJSON(statusMsg)) + + time.Sleep(15 * time.Millisecond) + + resultContent, _ := json.Marshal(ExecuteResult{ + ExecutionCount: 1, + Data: map[string]interface{}{ + "text/plain": "Completed with fallback", + }, + Metadata: map[string]interface{}{}, + }) + executeResultMsg := Message{ + Header: Header{ + MessageID: "result-msg-id", + Session: executeRequest.Header.Session, + MessageType: string(MsgExecuteResult), + }, + ParentHeader: executeRequest.Header, + Content: json.RawMessage(resultContent), + } + require.NoError(t, conn.WriteJSON(executeResultMsg)) + }) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/kernels/test-kernel-id/channels" + executor := NewExecutor(wsURL, nil) + require.NoError(t, executor.Connect()) + defer executor.Disconnect() + + resultChan := make(chan *ExecutionResult, 10) + require.NoError(t, executor.ExecuteCodeStream("print('fallback')", resultChan)) + + start := time.Now() + var gotLateResult bool + for result := range resultChan { + if result != nil && result.ExecutionCount == 1 { + gotLateResult = true + } + } + elapsed := time.Since(start) + + require.True(t, gotLateResult, "expected late execute_result to be delivered before stream close") + require.GreaterOrEqual(t, elapsed, 10*time.Millisecond, "expected non-positive poll interval to fall back to runtime default") + require.Less(t, elapsed, 120*time.Millisecond, "expected fallback poll interval to still close stream promptly") +} diff --git a/components/execd/pkg/web/controller/codeinterpreting.go b/components/execd/pkg/web/controller/codeinterpreting.go index 73ec70960..e90236052 100644 --- a/components/execd/pkg/web/controller/codeinterpreting.go +++ b/components/execd/pkg/web/controller/codeinterpreting.go @@ -26,11 +26,12 @@ import ( "github.com/gin-gonic/gin" "github.com/alibaba/opensandbox/execd/pkg/flag" + "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" "github.com/alibaba/opensandbox/execd/pkg/runtime" "github.com/alibaba/opensandbox/execd/pkg/web/model" ) -var codeRunner *runtime.Controller +var codeRunner codeExecutionRunner func InitCodeRunner() { codeRunner = runtime.NewController(flag.JupyterServerHost, flag.JupyterServerToken) @@ -44,6 +45,21 @@ type CodeInterpretingController struct { chunkWriter sync.Mutex } +type codeExecutionRunner interface { + CreateContext(req *runtime.CreateContextRequest) (string, error) + Execute(request *runtime.ExecuteCodeRequest) error + GetContext(session string) (runtime.CodeContext, error) + GetCommandStatus(session string) (*runtime.CommandStatus, error) + ListContext(language string) ([]runtime.CodeContext, error) + DeleteLanguageContext(language runtime.Language) error + DeleteContext(session string) error + CreateBashSession(req *runtime.CreateContextRequest) (string, error) + RunInBashSession(ctx context.Context, req *runtime.ExecuteCodeRequest) error + SeekBackgroundCommandOutput(session string, cursor int64) ([]byte, int64, error) + DeleteBashSession(sessionID string) error + Interrupt(sessionID string) error +} + func NewCodeInterpretingController(ctx *gin.Context) *CodeInterpretingController { return &CodeInterpretingController{ basicController: newBasicController(ctx), @@ -118,10 +134,21 @@ func (c *CodeInterpretingController) RunCode() { // event has been written and flushed. We only wait for this callback as a // safety check and then return immediately to avoid fixed tail latency. completeCh := make(chan struct{}) + var completeOnce sync.Once + signalComplete := func() { + completeOnce.Do(func() { + close(completeCh) + }) + } origComplete := eventsHandler.OnExecuteComplete eventsHandler.OnExecuteComplete = func(executionTime time.Duration) { origComplete(executionTime) - close(completeCh) + signalComplete() + } + origError := eventsHandler.OnExecuteError + eventsHandler.OnExecuteError = func(err *execute.ErrorOutput) { + origError(err) + signalComplete() } runCodeRequest.Hooks = eventsHandler @@ -136,10 +163,7 @@ func (c *CodeInterpretingController) RunCode() { return } - select { - case <-completeCh: - case <-time.After(flag.ApiGracefulShutdownTimeout): - } + waitForExecutionComplete(ctx, completeCh) } // GetContext returns a specific code context by id. @@ -323,11 +347,22 @@ func (c *CodeInterpretingController) RunInSession() { // event has been written and flushed. We only wait for this callback as a // safety check and then return immediately to avoid fixed tail latency. completeCh := make(chan struct{}) + var completeOnce sync.Once + signalComplete := func() { + completeOnce.Do(func() { + close(completeCh) + }) + } hooks := c.setServerEventsHandler(ctx) origComplete := hooks.OnExecuteComplete hooks.OnExecuteComplete = func(executionTime time.Duration) { origComplete(executionTime) - close(completeCh) + signalComplete() + } + origError := hooks.OnExecuteError + hooks.OnExecuteError = func(err *execute.ErrorOutput) { + origError(err) + signalComplete() } runReq.Hooks = hooks @@ -342,10 +377,7 @@ func (c *CodeInterpretingController) RunInSession() { return } - select { - case <-completeCh: - case <-time.After(flag.ApiGracefulShutdownTimeout): - } + waitForExecutionComplete(ctx, completeCh) } // DeleteSession deletes a bash session (delete_session API). @@ -396,6 +428,24 @@ func (c *CodeInterpretingController) buildExecuteCodeRequest(request model.RunCo return req } +func waitForExecutionComplete(ctx context.Context, completeCh <-chan struct{}) { + timer := time.NewTimer(flag.ApiGracefulShutdownTimeout) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + }() + + select { + case <-completeCh: + case <-ctx.Done(): + case <-timer.C: + } +} + func (c *CodeInterpretingController) interrupt() { session := c.ctx.Query("id") if session == "" { diff --git a/components/execd/pkg/web/controller/codeinterpreting_test.go b/components/execd/pkg/web/controller/codeinterpreting_test.go index d4f11dcc0..f97c6eef1 100644 --- a/components/execd/pkg/web/controller/codeinterpreting_test.go +++ b/components/execd/pkg/web/controller/codeinterpreting_test.go @@ -15,17 +15,80 @@ package controller import ( + "context" "encoding/json" "net/http" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/alibaba/opensandbox/execd/pkg/flag" + "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" "github.com/alibaba/opensandbox/execd/pkg/runtime" "github.com/alibaba/opensandbox/execd/pkg/web/model" "github.com/stretchr/testify/require" ) +type fakeCodeRunner struct { + execute func(request *runtime.ExecuteCodeRequest) error + runInBashSession func(_ context.Context, _ *runtime.ExecuteCodeRequest) error +} + +func (f *fakeCodeRunner) CreateContext(_ *runtime.CreateContextRequest) (string, error) { + return "", nil +} + +func (f *fakeCodeRunner) Execute(request *runtime.ExecuteCodeRequest) error { + if f.execute != nil { + return f.execute(request) + } + return nil +} + +func (f *fakeCodeRunner) GetContext(_ string) (runtime.CodeContext, error) { + return runtime.CodeContext{}, nil +} + +func (f *fakeCodeRunner) GetCommandStatus(_ string) (*runtime.CommandStatus, error) { + return nil, nil +} + +func (f *fakeCodeRunner) ListContext(_ string) ([]runtime.CodeContext, error) { + return nil, nil +} + +func (f *fakeCodeRunner) DeleteLanguageContext(_ runtime.Language) error { + return nil +} + +func (f *fakeCodeRunner) DeleteContext(_ string) error { + return nil +} + +func (f *fakeCodeRunner) CreateBashSession(_ *runtime.CreateContextRequest) (string, error) { + return "", nil +} + +func (f *fakeCodeRunner) RunInBashSession(ctx context.Context, req *runtime.ExecuteCodeRequest) error { + if f.runInBashSession != nil { + return f.runInBashSession(ctx, req) + } + return nil +} + +func (f *fakeCodeRunner) SeekBackgroundCommandOutput(_ string, _ int64) ([]byte, int64, error) { + return nil, 0, nil +} + +func (f *fakeCodeRunner) DeleteBashSession(_ string) error { + return nil +} + +func (f *fakeCodeRunner) Interrupt(_ string) error { + return nil +} + func TestBuildExecuteCodeRequestDefaultsToCommand(t *testing.T) { ctrl := &CodeInterpretingController{} req := model.RunCodeRequest{ @@ -92,3 +155,174 @@ func TestGetContext_MissingIDReturns400(t *testing.T) { require.Equal(t, model.ErrorCodeMissingQuery, resp.Code) require.Equal(t, "missing path parameter 'contextId'", resp.Message) } + +func TestRunCodeReturnsBeforeGracefulShutdownTimeoutAfterImmediateComplete(t *testing.T) { + previousRunner := codeRunner + previousTimeout := flag.ApiGracefulShutdownTimeout + codeRunner = &fakeCodeRunner{ + execute: func(request *runtime.ExecuteCodeRequest) error { + request.Hooks.OnExecuteComplete(5 * time.Millisecond) + return nil + }, + } + flag.ApiGracefulShutdownTimeout = 200 * time.Millisecond + t.Cleanup(func() { + codeRunner = previousRunner + flag.ApiGracefulShutdownTimeout = previousTimeout + }) + + body := []byte(`{"code":"print(1)","context":{"id":"ctx-1","language":"python"}}`) + ctx, w := newTestContext(http.MethodPost, "/code/run", body) + ctrl := NewCodeInterpretingController(ctx) + + start := time.Now() + ctrl.RunCode() + elapsed := time.Since(start) + + require.Equal(t, http.StatusOK, w.Code) + require.Less(t, elapsed, flag.ApiGracefulShutdownTimeout/2) +} + +func TestRunInSessionReturnsBeforeGracefulShutdownTimeoutAfterImmediateComplete(t *testing.T) { + previousRunner := codeRunner + previousTimeout := flag.ApiGracefulShutdownTimeout + codeRunner = &fakeCodeRunner{ + runInBashSession: func(_ context.Context, request *runtime.ExecuteCodeRequest) error { + request.Hooks.OnExecuteComplete(5 * time.Millisecond) + return nil + }, + } + flag.ApiGracefulShutdownTimeout = 200 * time.Millisecond + t.Cleanup(func() { + codeRunner = previousRunner + flag.ApiGracefulShutdownTimeout = previousTimeout + }) + + body := []byte(`{"command":"echo hi","timeout":0}`) + ctx, w := newTestContext(http.MethodPost, "/sessions/session-1/run", body) + ctx.Params = append(ctx.Params, gin.Param{Key: "sessionId", Value: "session-1"}) + ctrl := NewCodeInterpretingController(ctx) + + start := time.Now() + ctrl.RunInSession() + elapsed := time.Since(start) + + require.Equal(t, http.StatusOK, w.Code) + require.Less(t, elapsed, flag.ApiGracefulShutdownTimeout/2) +} + +func TestRunCodeReturnsBeforeGracefulShutdownTimeoutWhenRequestContextCanceled(t *testing.T) { + previousRunner := codeRunner + previousTimeout := flag.ApiGracefulShutdownTimeout + reqCtx, cancelReq := context.WithCancel(context.Background()) + codeRunner = &fakeCodeRunner{ + execute: func(_ *runtime.ExecuteCodeRequest) error { + cancelReq() + return nil + }, + } + flag.ApiGracefulShutdownTimeout = 200 * time.Millisecond + t.Cleanup(func() { + codeRunner = previousRunner + flag.ApiGracefulShutdownTimeout = previousTimeout + cancelReq() + }) + + body := []byte(`{"code":"print(1)","context":{"id":"ctx-1","language":"python"}}`) + ctx, w := newTestContext(http.MethodPost, "/code/run", body) + ctx.Request = ctx.Request.WithContext(reqCtx) + ctrl := NewCodeInterpretingController(ctx) + + start := time.Now() + ctrl.RunCode() + elapsed := time.Since(start) + + require.Equal(t, http.StatusOK, w.Code) + require.Less(t, elapsed, flag.ApiGracefulShutdownTimeout/2) +} + +func TestRunInSessionReturnsBeforeGracefulShutdownTimeoutWhenRequestContextCanceled(t *testing.T) { + previousRunner := codeRunner + previousTimeout := flag.ApiGracefulShutdownTimeout + reqCtx, cancelReq := context.WithCancel(context.Background()) + codeRunner = &fakeCodeRunner{ + runInBashSession: func(_ context.Context, _ *runtime.ExecuteCodeRequest) error { + cancelReq() + return nil + }, + } + flag.ApiGracefulShutdownTimeout = 200 * time.Millisecond + t.Cleanup(func() { + codeRunner = previousRunner + flag.ApiGracefulShutdownTimeout = previousTimeout + cancelReq() + }) + + body := []byte(`{"command":"echo hi","timeout":0}`) + ctx, w := newTestContext(http.MethodPost, "/sessions/session-1/run", body) + ctx.Request = ctx.Request.WithContext(reqCtx) + ctx.Params = append(ctx.Params, gin.Param{Key: "sessionId", Value: "session-1"}) + ctrl := NewCodeInterpretingController(ctx) + + start := time.Now() + ctrl.RunInSession() + elapsed := time.Since(start) + + require.Equal(t, http.StatusOK, w.Code) + require.Less(t, elapsed, flag.ApiGracefulShutdownTimeout/2) +} + +func TestRunCodeReturnsBeforeGracefulShutdownTimeoutAfterImmediateError(t *testing.T) { + previousRunner := codeRunner + previousTimeout := flag.ApiGracefulShutdownTimeout + codeRunner = &fakeCodeRunner{ + execute: func(request *runtime.ExecuteCodeRequest) error { + request.Hooks.OnExecuteError(&execute.ErrorOutput{EName: "ExecError", EValue: "boom"}) + return nil + }, + } + flag.ApiGracefulShutdownTimeout = 200 * time.Millisecond + t.Cleanup(func() { + codeRunner = previousRunner + flag.ApiGracefulShutdownTimeout = previousTimeout + }) + + body := []byte(`{"code":"print(1)","context":{"id":"ctx-1","language":"python"}}`) + ctx, w := newTestContext(http.MethodPost, "/code/run", body) + ctrl := NewCodeInterpretingController(ctx) + + start := time.Now() + ctrl.RunCode() + elapsed := time.Since(start) + + require.Equal(t, http.StatusOK, w.Code) + require.Less(t, elapsed, flag.ApiGracefulShutdownTimeout/2) +} + +func TestRunInSessionReturnsBeforeGracefulShutdownTimeoutAfterImmediateError(t *testing.T) { + previousRunner := codeRunner + previousTimeout := flag.ApiGracefulShutdownTimeout + codeRunner = &fakeCodeRunner{ + runInBashSession: func(_ context.Context, request *runtime.ExecuteCodeRequest) error { + request.Hooks.OnExecuteError(&execute.ErrorOutput{EName: "ExecError", EValue: "boom"}) + return nil + }, + } + flag.ApiGracefulShutdownTimeout = 200 * time.Millisecond + t.Cleanup(func() { + codeRunner = previousRunner + flag.ApiGracefulShutdownTimeout = previousTimeout + }) + + body := []byte(`{"command":"echo hi","timeout":0}`) + ctx, w := newTestContext(http.MethodPost, "/sessions/session-1/run", body) + ctx.Params = append(ctx.Params, gin.Param{Key: "sessionId", Value: "session-1"}) + ctrl := NewCodeInterpretingController(ctx) + + start := time.Now() + ctrl.RunInSession() + elapsed := time.Since(start) + + require.Equal(t, http.StatusOK, w.Code) + require.Less(t, elapsed, flag.ApiGracefulShutdownTimeout/2) +} From bdb12720a496c7877b0135c9506313e3934f8b40 Mon Sep 17 00:00:00 2001 From: "skyler.su" Date: Fri, 27 Mar 2026 17:05:35 +0800 Subject: [PATCH 3/3] feat(execd): solve the issues from unit test --- .../execd/pkg/jupyter/execute/execute.go | 282 ++++++++++-------- components/execd/pkg/runtime/jupyter.go | 58 ++-- .../execd/pkg/runtime/jupyter_hooks_test.go | 64 ++++ 3 files changed, 244 insertions(+), 160 deletions(-) create mode 100644 components/execd/pkg/runtime/jupyter_hooks_test.go diff --git a/components/execd/pkg/jupyter/execute/execute.go b/components/execd/pkg/jupyter/execute/execute.go index 0c4096961..02807cbec 100644 --- a/components/execd/pkg/jupyter/execute/execute.go +++ b/components/execd/pkg/jupyter/execute/execute.go @@ -113,16 +113,50 @@ func (c *Client) IsConnected() bool { return c.conn != nil } +type streamExecutionState struct { + startTime time.Time + result *ExecutionResult + executeDone bool + executeMutex sync.Mutex + resultMutex sync.Mutex +} + +func newStreamExecutionState(startTime time.Time) *streamExecutionState { + return &streamExecutionState{ + startTime: startTime, + result: &ExecutionResult{ + Status: "ok", + Stream: make([]*StreamOutput, 0), + ExecutionTime: 0, + }, + } +} + // ExecuteCodeStream executes code in streaming mode, sending results to the provided channel func (c *Client) ExecuteCodeStream(code string, resultChan chan *ExecutionResult) error { if !c.IsConnected() { return errors.New("not connected to kernel, please call Connect method") } - // record start time - startTime := time.Now() + msg, err := c.buildExecuteMessage(code) + if err != nil { + return err + } - // prepare execution request + state := newStreamExecutionState(time.Now()) + + // Clear temporary handlers + c.clearTemporaryHandlers() + c.registerExecuteCodeStreamHandlers(state, resultChan) + + if err := c.writeMessage(msg); err != nil { + return fmt.Errorf("failed to send execution request: %w", err) + } + + return nil +} + +func (c *Client) buildExecuteMessage(code string) (*Message, error) { msgID := c.nextMessageID() request := &ExecuteRequest{ Code: code, @@ -133,13 +167,11 @@ func (c *Client) ExecuteCodeStream(code string, resultChan chan *ExecutionResult StopOnError: true, } - // serialize request content content, err := json.Marshal(request) if err != nil { - return fmt.Errorf("failed to serialize request: %w", err) + return nil, fmt.Errorf("failed to serialize request: %w", err) } - // create message msg := &Message{ Header: Header{ MessageID: msgID, @@ -155,153 +187,139 @@ func (c *Client) ExecuteCodeStream(code string, resultChan chan *ExecutionResult Channel: "shell", } - // Create result object - result := &ExecutionResult{ - Status: "ok", - Stream: make([]*StreamOutput, 0), - ExecutionTime: 0, - } - - // Register temporary handler to receive execution result - var executeDone bool - var executeMutex sync.Mutex - var executeResult *ExecuteResult - - // Create mutex to protect result object - var resultMutex sync.Mutex - - // Clear temporary handlers - c.clearTemporaryHandlers() + return msg, nil +} +func (c *Client) registerExecuteCodeStreamHandlers(state *streamExecutionState, resultChan chan *ExecutionResult) { c.registerHandler(MsgExecuteReply, func(msg *Message) { - var execReply ExecuteReply - if err := json.Unmarshal(msg.Content, &execReply); err != nil { - return - } - - resultMutex.Lock() - result.ExecutionCount = execReply.ExecutionCount - if execReply.EName != "" { - result.Error = &execReply.ErrorOutput - } - resultMutex.Unlock() + c.handleExecuteReply(msg, state) }) - - // register execution result handler c.registerHandler(MsgExecuteResult, func(msg *Message) { - var execResult ExecuteResult - if err := json.Unmarshal(msg.Content, &execResult); err != nil { - return - } + c.handleExecuteResult(msg, state, resultChan) + }) + c.registerHandler(MsgStream, func(msg *Message) { + c.handleStreamOutput(msg, state, resultChan) + }) + c.registerHandler(MsgError, func(msg *Message) { + c.handleExecutionError(msg, state, resultChan) + }) + c.registerHandler(MsgStatus, func(msg *Message) { + c.handleExecutionStatus(msg, state, resultChan) + }) +} - executeMutex.Lock() - executeResult = &execResult - executeMutex.Unlock() +func (c *Client) handleExecuteReply(msg *Message, state *streamExecutionState) { + var execReply ExecuteReply + if err := json.Unmarshal(msg.Content, &execReply); err != nil { + return + } - resultMutex.Lock() - result.ExecutionCount = execResult.ExecutionCount + state.resultMutex.Lock() + defer state.resultMutex.Unlock() + state.result.ExecutionCount = execReply.ExecutionCount + if execReply.EName != "" { + state.result.Error = &execReply.ErrorOutput + } +} - notify := &ExecutionResult{} - notify.ExecutionCount = executeResult.ExecutionCount - notify.ExecutionData = executeResult.Data +func (c *Client) handleExecuteResult(msg *Message, state *streamExecutionState, resultChan chan *ExecutionResult) { + var execResult ExecuteResult + if err := json.Unmarshal(msg.Content, &execResult); err != nil { + return + } - resultChan <- notify - resultMutex.Unlock() - }) + state.resultMutex.Lock() + defer state.resultMutex.Unlock() + state.result.ExecutionCount = execResult.ExecutionCount - // Register stream output handler - c.registerHandler(MsgStream, func(msg *Message) { - var stream StreamOutput - if err := json.Unmarshal(msg.Content, &stream); err != nil { - return - } + notify := &ExecutionResult{ + ExecutionCount: execResult.ExecutionCount, + ExecutionData: execResult.Data, + } + resultChan <- notify +} - resultMutex.Lock() - result.Stream = append(result.Stream, &stream) +func (c *Client) handleStreamOutput(msg *Message, state *streamExecutionState, resultChan chan *ExecutionResult) { + var stream StreamOutput + if err := json.Unmarshal(msg.Content, &stream); err != nil { + return + } + + state.resultMutex.Lock() + defer state.resultMutex.Unlock() + state.result.Stream = append(state.result.Stream, &stream) + notify := &ExecutionResult{ + Stream: []*StreamOutput{&stream}, + } + resultChan <- notify +} - notify := &ExecutionResult{} - notify.Stream = []*StreamOutput{&stream} +func (c *Client) handleExecutionError(msg *Message, state *streamExecutionState, resultChan chan *ExecutionResult) { + var errOutput ErrorOutput + if err := json.Unmarshal(msg.Content, &errOutput); err != nil { + return + } - resultChan <- notify - resultMutex.Unlock() - }) + state.resultMutex.Lock() + defer state.resultMutex.Unlock() + state.result.Status = "error" + state.result.Error = &errOutput + notify := &ExecutionResult{ + Error: &errOutput, + Status: "error", + } + resultChan <- notify +} - // register error handler - c.registerHandler(MsgError, func(msg *Message) { - var errOutput ErrorOutput - if err := json.Unmarshal(msg.Content, &errOutput); err != nil { - return - } +func (c *Client) handleExecutionStatus(msg *Message, state *streamExecutionState, resultChan chan *ExecutionResult) { + var status StatusUpdate + if err := json.Unmarshal(msg.Content, &status); err != nil { + return + } + if status.ExecutionState != StateIdle { + return + } - resultMutex.Lock() - result.Status = "error" - result.Error = &errOutput + state.executeMutex.Lock() + defer state.executeMutex.Unlock() + if state.executeDone { + return + } + state.executeDone = true + go c.finalizeExecution(state, resultChan) +} - notify := &ExecutionResult{} - notify.Error = &errOutput - notify.Status = "error" +func (c *Client) finalizeExecution(state *streamExecutionState, resultChan chan *ExecutionResult) { + state.resultMutex.Lock() + state.result.ExecutionTime = time.Since(state.startTime) + notify := &ExecutionResult{ + ExecutionTime: state.result.ExecutionTime, + } + resultChan <- notify + state.resultMutex.Unlock() - resultChan <- notify - resultMutex.Unlock() - }) + pollInterval := execdflag.JupyterIdlePollInterval + if pollInterval <= 0 { + pollInterval = 10 * time.Millisecond + } - // register status handler - c.registerHandler(MsgStatus, func(msg *Message) { - var status StatusUpdate - if err := json.Unmarshal(msg.Content, &status); err != nil { - return + for { + state.resultMutex.Lock() + done := state.result.ExecutionCount > 0 || state.result.Error != nil + state.resultMutex.Unlock() + if done { + break } + time.Sleep(pollInterval) + } - if status.ExecutionState == StateIdle { - executeMutex.Lock() - - // Check whether execution can be completed - if !executeDone { - executeDone = true - go func() { - // calculate execution time - resultMutex.Lock() - result.ExecutionTime = time.Since(startTime) - - // Send final result - notify := &ExecutionResult{} - notify.ExecutionTime = result.ExecutionTime - - resultChan <- notify - resultMutex.Unlock() - - pollInterval := execdflag.JupyterIdlePollInterval - if pollInterval <= 0 { - pollInterval = 10 * time.Millisecond - } - - for { - resultMutex.Lock() - done := result.ExecutionCount > 0 || result.Error != nil - resultMutex.Unlock() - if done { - break - } - time.Sleep(pollInterval) - } - - // Close result channel - close(resultChan) - }() - } - executeMutex.Unlock() - } - }) + close(resultChan) +} - // send execution request +func (c *Client) writeMessage(msg *Message) error { c.mu.Lock() - err = c.conn.WriteJSON(msg) - c.mu.Unlock() - if err != nil { - return fmt.Errorf("failed to send execution request: %w", err) - } - - return nil + defer c.mu.Unlock() + return c.conn.WriteJSON(msg) } // ExecuteCodeWithCallback executes code using callback functions diff --git a/components/execd/pkg/runtime/jupyter.go b/components/execd/pkg/runtime/jupyter.go index 9ea33b13b..f4732d515 100644 --- a/components/execd/pkg/runtime/jupyter.go +++ b/components/execd/pkg/runtime/jupyter.go @@ -82,34 +82,7 @@ func (c *Controller) runJupyterCode(ctx context.Context, kernel *jupyterKernel, if result == nil { return nil } - - if result.ExecutionCount > 0 || len(result.ExecutionData) > 0 { - request.Hooks.OnExecuteResult(result.ExecutionData, result.ExecutionCount) - } - - if result.Status != "" { - request.Hooks.OnExecuteStatus(result.Status) - } - - if result.ExecutionTime > 0 { - request.Hooks.OnExecuteComplete(result.ExecutionTime) - } - - if result.Error != nil { - request.Hooks.OnExecuteError(result.Error) - } - - if len(result.Stream) > 0 { - for _, stream := range result.Stream { - switch stream.Name { - case execute.StreamStdout: - request.Hooks.OnExecuteStdout(stream.Text) - case execute.StreamStderr: - request.Hooks.OnExecuteStderr(stream.Text) - default: - } - } - } + dispatchExecutionResultHooks(request, result) case <-ctx.Done(): log.Warning("context cancelled, try to interrupt kernel") @@ -127,6 +100,35 @@ func (c *Controller) runJupyterCode(ctx context.Context, kernel *jupyterKernel, } } +func dispatchExecutionResultHooks(request *ExecuteCodeRequest, result *execute.ExecutionResult) { + if result.ExecutionCount > 0 || len(result.ExecutionData) > 0 { + request.Hooks.OnExecuteResult(result.ExecutionData, result.ExecutionCount) + } + + if result.Status != "" { + request.Hooks.OnExecuteStatus(result.Status) + } + + if result.Error != nil { + request.Hooks.OnExecuteError(result.Error) + } + + // Treat completion as success-only terminal signal. For failed executions, + // error should be the terminal event to avoid losing error delivery. + if result.ExecutionTime > 0 && result.Error == nil { + request.Hooks.OnExecuteComplete(result.ExecutionTime) + } + + for _, stream := range result.Stream { + switch stream.Name { + case execute.StreamStdout: + request.Hooks.OnExecuteStdout(stream.Text) + case execute.StreamStderr: + request.Hooks.OnExecuteStderr(stream.Text) + } + } +} + // setWorkingDir configures the working directory for a kernel session. func (c *Controller) setWorkingDir(_ *jupyterKernel, _ *CreateContextRequest) error { return nil diff --git a/components/execd/pkg/runtime/jupyter_hooks_test.go b/components/execd/pkg/runtime/jupyter_hooks_test.go new file mode 100644 index 000000000..7a9e3a242 --- /dev/null +++ b/components/execd/pkg/runtime/jupyter_hooks_test.go @@ -0,0 +1,64 @@ +package runtime + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" +) + +func TestDispatchExecutionResultHooks_ErrorSkipsComplete(t *testing.T) { + var ( + errorCalls int + completeCalls int + ) + + req := &ExecuteCodeRequest{ + Hooks: ExecuteResultHook{ + OnExecuteError: func(_ *execute.ErrorOutput) { + errorCalls++ + }, + OnExecuteComplete: func(_ time.Duration) { + completeCalls++ + }, + }, + } + + dispatchExecutionResultHooks(req, &execute.ExecutionResult{ + ExecutionTime: 35 * time.Millisecond, + Error: &execute.ErrorOutput{ + EName: "RuntimeError", + EValue: "boom", + }, + }) + + require.Equal(t, 1, errorCalls) + require.Equal(t, 0, completeCalls) +} + +func TestDispatchExecutionResultHooks_SuccessEmitsComplete(t *testing.T) { + var ( + errorCalls int + completeCalls int + ) + + req := &ExecuteCodeRequest{ + Hooks: ExecuteResultHook{ + OnExecuteError: func(_ *execute.ErrorOutput) { + errorCalls++ + }, + OnExecuteComplete: func(_ time.Duration) { + completeCalls++ + }, + }, + } + + dispatchExecutionResultHooks(req, &execute.ExecutionResult{ + ExecutionTime: 50 * time.Millisecond, + }) + + require.Equal(t, 0, errorCalls) + require.Equal(t, 1, completeCalls) +}