diff --git a/packages/api/internal/handlers/admin_kill_team_sandboxes.go b/packages/api/internal/handlers/admin_kill_team_sandboxes.go index 5f3d589254..26ae6c7395 100644 --- a/packages/api/internal/handlers/admin_kill_team_sandboxes.go +++ b/packages/api/internal/handlers/admin_kill_team_sandboxes.go @@ -43,7 +43,7 @@ func (a *APIStore) PostAdminTeamsTeamIDSandboxesKill(c *gin.Context, teamID uuid // Kill each sandbox for _, sbx := range sandboxes { wg.Go(func() error { - err := a.orchestrator.RemoveSandbox(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + err := a.orchestrator.RemoveSandbox(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) if err != nil { logger.L().Error(ctx, "Failed to kill sandbox", logger.WithSandboxID(sbx.SandboxID), diff --git a/packages/api/internal/handlers/sandbox_connect.go b/packages/api/internal/handlers/sandbox_connect.go index 10046adb8a..652c1bd6c2 100644 --- a/packages/api/internal/handlers/sandbox_connect.go +++ b/packages/api/internal/handlers/sandbox_connect.go @@ -70,8 +70,7 @@ func (a *APIStore) PostSandboxesSandboxIDConnect(c *gin.Context, sandboxID api.S } // Sandbox not in store at all → fall through to snapshot resume. - var notFoundErr *sandbox.NotFoundError - if errors.As(apiErr.Err, ¬FoundErr) { + if errors.Is(apiErr.Err, sandbox.ErrNotFound) { break } diff --git a/packages/api/internal/handlers/sandbox_kill.go b/packages/api/internal/handlers/sandbox_kill.go index 12415f6c57..4033966f52 100644 --- a/packages/api/internal/handlers/sandbox_kill.go +++ b/packages/api/internal/handlers/sandbox_kill.go @@ -67,7 +67,7 @@ func (a *APIStore) DeleteSandboxesSandboxID( killedOrRemoved := false - err = a.orchestrator.RemoveSandbox(ctx, teamID, sandboxID, sandbox.StateActionKill) + err = a.orchestrator.RemoveSandbox(ctx, teamID, sandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) switch { case err == nil: killedOrRemoved = true diff --git a/packages/api/internal/handlers/sandbox_pause.go b/packages/api/internal/handlers/sandbox_pause.go index 3a07cee6c0..bb49c10f1b 100644 --- a/packages/api/internal/handlers/sandbox_pause.go +++ b/packages/api/internal/handlers/sandbox_pause.go @@ -39,7 +39,7 @@ func (a *APIStore) PostSandboxesSandboxIDPause(c *gin.Context, sandboxID api.San traceID := span.SpanContext().TraceID().String() c.Set("traceID", traceID) - err = a.orchestrator.RemoveSandbox(ctx, teamID, sandboxID, sandbox.StateActionPause) + err = a.orchestrator.RemoveSandbox(ctx, teamID, sandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) var transErr *sandbox.InvalidStateTransitionError switch { diff --git a/packages/api/internal/handlers/snapshot_template_create.go b/packages/api/internal/handlers/snapshot_template_create.go index 21314c8b10..f826fc9704 100644 --- a/packages/api/internal/handlers/snapshot_template_create.go +++ b/packages/api/internal/handlers/snapshot_template_create.go @@ -99,8 +99,7 @@ func (a *APIStore) PostSandboxesSandboxIDSnapshots(c *gin.Context, sandboxID api result, err := a.orchestrator.CreateSnapshotTemplate(ctx, teamID, sandboxID, opts) if err != nil { - var notFoundErr *sandbox.NotFoundError - if errors.As(err, ¬FoundErr) { + if errors.Is(err, sandbox.ErrNotFound) { logger.L().Debug(ctx, "Sandbox not found for snapshot", logger.WithSandboxID(sandboxID)) a.sendAPIStoreError(c, http.StatusNotFound, utils.SandboxNotFoundMsg(sandboxID)) diff --git a/packages/api/internal/orchestrator/delete_instance.go b/packages/api/internal/orchestrator/delete_instance.go index 9cc6bfe6d7..b9bf391b5c 100644 --- a/packages/api/internal/orchestrator/delete_instance.go +++ b/packages/api/internal/orchestrator/delete_instance.go @@ -17,16 +17,20 @@ import ( sbxlogger "github.com/e2b-dev/infra/packages/shared/pkg/logger/sandbox" ) -func (o *Orchestrator) RemoveSandbox(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) error { +func (o *Orchestrator) RemoveSandbox(ctx context.Context, teamID uuid.UUID, sandboxID string, opts sandbox.RemoveOpts) error { ctx, span := tracer.Start(ctx, "remove-sandbox") defer span.End() - sbx, alreadyDone, finish, err := o.sandboxStore.StartRemoving(ctx, teamID, sandboxID, stateAction) + sbx, alreadyDone, finish, err := o.sandboxStore.StartRemoving(ctx, teamID, sandboxID, opts) if err != nil { - switch stateAction { + // For eviction, propagate all errors to the evictor. + if opts.Eviction { + return err + } + + switch opts.Action { case sandbox.StateActionKill: - var notFoundErr *sandbox.NotFoundError - if errors.As(err, ¬FoundErr) { + if errors.Is(err, sandbox.ErrNotFound) { logger.L().Info(ctx, "Sandbox not found, already removed", logger.WithSandboxID(sandboxID)) return ErrSandboxNotFound @@ -43,8 +47,7 @@ func (o *Orchestrator) RemoveSandbox(ctx context.Context, teamID uuid.UUID, sand return ErrSandboxOperationFailed } case sandbox.StateActionPause: - var notFoundErrPause *sandbox.NotFoundError - if errors.As(err, ¬FoundErrPause) { + if errors.Is(err, sandbox.ErrNotFound) { logger.L().Info(ctx, "Sandbox not found for pause", logger.WithSandboxID(sandboxID)) return ErrSandboxNotFound @@ -65,7 +68,7 @@ func (o *Orchestrator) RemoveSandbox(ctx context.Context, teamID uuid.UUID, sand return ErrSandboxOperationFailed default: - logger.L().Error(ctx, "Invalid state action", logger.WithSandboxID(sandboxID), zap.String("state_action", stateAction.Name)) + logger.L().Error(ctx, "Invalid state action", logger.WithSandboxID(sandboxID), zap.String("state_action", opts.Action.Name)) return ErrSandboxOperationFailed } @@ -80,10 +83,10 @@ func (o *Orchestrator) RemoveSandbox(ctx context.Context, teamID uuid.UUID, sand return nil } - defer func() { go o.countersRemove(context.WithoutCancel(ctx), teamID, stateAction) }() - defer func() { go o.analyticsRemove(context.WithoutCancel(ctx), sbx, stateAction) }() + defer func() { go o.countersRemove(context.WithoutCancel(ctx), teamID, opts.Action) }() + defer func() { go o.analyticsRemove(context.WithoutCancel(ctx), sbx, opts.Action) }() defer o.sandboxStore.Remove(ctx, teamID, sandboxID) - err = o.removeSandboxFromNode(ctx, sbx, stateAction) + err = o.removeSandboxFromNode(ctx, sbx, opts.Action) if err != nil { logger.L().Error(ctx, "Error pausing sandbox", zap.Error(err), logger.WithSandboxID(sbx.SandboxID)) diff --git a/packages/api/internal/orchestrator/evictor/evict.go b/packages/api/internal/orchestrator/evictor/evict.go index 8a1622cb43..fa8734a223 100644 --- a/packages/api/internal/orchestrator/evictor/evict.go +++ b/packages/api/internal/orchestrator/evictor/evict.go @@ -2,6 +2,7 @@ package evictor import ( "context" + "errors" "time" "github.com/google/uuid" @@ -18,12 +19,12 @@ const ( type Evictor struct { store *sandbox.Store - removeSandbox func(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) error + removeSandbox func(ctx context.Context, teamID uuid.UUID, sandboxID string, opts sandbox.RemoveOpts) error } func New( store *sandbox.Store, - removeSandbox func(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) error, + removeSandbox func(ctx context.Context, teamID uuid.UUID, sandboxID string, opts sandbox.RemoveOpts) error, ) *Evictor { return &Evictor{ store: store, @@ -53,16 +54,21 @@ func (e *Evictor) Start(ctx context.Context) { for _, item := range sbxs { g.Go(func() error { - stateAction := sandbox.StateActionKill + action := sandbox.StateActionKill if item.AutoPause { - stateAction = sandbox.StateActionPause + action = sandbox.StateActionPause } - logger.L().Debug(ctx, "Evicting sandbox", logger.WithSandboxID(item.SandboxID), zap.String("state_action", stateAction.Name)) - if err := e.removeSandbox(context.WithoutCancel(ctx), item.TeamID, item.SandboxID, stateAction); err != nil { - logger.L().Debug(ctx, "Evicting sandbox failed", zap.Error(err), logger.WithSandboxID(item.SandboxID)) + if err := e.removeSandbox(context.WithoutCancel(ctx), item.TeamID, item.SandboxID, sandbox.RemoveOpts{Action: action, Eviction: true}); err != nil { + if !errors.Is(err, sandbox.ErrNotEvictable) && !errors.Is(err, sandbox.ErrNotFound) { + logger.L().Debug(ctx, "Evicting sandbox failed", zap.Error(err), logger.WithSandboxID(item.SandboxID)) + } + + return nil } + logger.L().Debug(ctx, "Sandbox evicted", logger.WithSandboxID(item.SandboxID)) + return nil }) } diff --git a/packages/api/internal/orchestrator/keep_alive.go b/packages/api/internal/orchestrator/keep_alive.go index 538b187462..01d6cd059d 100644 --- a/packages/api/internal/orchestrator/keep_alive.go +++ b/packages/api/internal/orchestrator/keep_alive.go @@ -45,14 +45,13 @@ func (o *Orchestrator) KeepAliveFor(ctx context.Context, teamID uuid.UUID, sandb return sbx, nil } - var sbxNotFoundErr *sandbox.NotFoundError var sbxNotRunningErr *sandbox.NotRunningError sbx, err := o.sandboxStore.Update(ctx, teamID, sandboxID, updateFunc) if err != nil { switch { case errors.As(err, &sbxNotRunningErr): return nil, &api.APIError{Code: http.StatusConflict, ClientMsg: utils.SandboxChangingStateMsg(sandboxID, sbxNotRunningErr.State), Err: err} - case errors.As(err, &sbxNotFoundErr): + case errors.Is(err, sandbox.ErrNotFound): return nil, &api.APIError{Code: http.StatusNotFound, ClientMsg: utils.SandboxNotFoundMsg(sandboxID), Err: err} case errors.Is(err, errMaxInstanceLengthExceeded): return nil, &api.APIError{Code: http.StatusBadRequest, ClientMsg: "Max instance length exceeded", Err: err} diff --git a/packages/api/internal/orchestrator/snapshot_template.go b/packages/api/internal/orchestrator/snapshot_template.go index 144fbf8aeb..84478ebf89 100644 --- a/packages/api/internal/orchestrator/snapshot_template.go +++ b/packages/api/internal/orchestrator/snapshot_template.go @@ -39,7 +39,7 @@ func (o *Orchestrator) CreateSnapshotTemplate(ctx context.Context, teamID uuid.U ctx, span := tracer.Start(ctx, "create-snapshot-template") defer span.End() - sbx, alreadyDone, finishSnapshotting, err := o.sandboxStore.StartRemoving(ctx, teamID, sandboxID, sandbox.StateActionSnapshot) + sbx, alreadyDone, finishSnapshotting, err := o.sandboxStore.StartRemoving(ctx, teamID, sandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionSnapshot}) if err != nil { return SnapshotTemplateResult{}, fmt.Errorf("failed to start snapshotting: %w", err) } @@ -97,7 +97,7 @@ func (o *Orchestrator) CreateSnapshotTemplate(ctx context.Context, teamID uuid.U // so RemoveSandbox can proceed without deadlock. finish(err) - if killErr := o.RemoveSandbox(ctx, teamID, sandboxID, sandbox.StateActionKill); killErr != nil { + if killErr := o.RemoveSandbox(ctx, teamID, sandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}); killErr != nil { telemetry.ReportError(ctx, "error killing sandbox after failed checkpoint", killErr) } diff --git a/packages/api/internal/sandbox/errors.go b/packages/api/internal/sandbox/errors.go index fe8ca4054e..d8d05b05e5 100644 --- a/packages/api/internal/sandbox/errors.go +++ b/packages/api/internal/sandbox/errors.go @@ -15,13 +15,7 @@ func (e *LimitExceededError) Error() string { return fmt.Sprintf("team %s has exceeded the limit", e.TeamID.String()) } -type NotFoundError struct { - SandboxID string -} - -func (e *NotFoundError) Error() string { - return fmt.Sprintf("sandbox %s not found", e.SandboxID) -} +var ErrNotFound = errors.New("sandbox not found") type InvalidStateTransitionError struct { CurrentState State @@ -42,3 +36,5 @@ func (e *NotRunningError) Error() string { } var ErrAlreadyExists = errors.New("sandbox already exists") + +var ErrNotEvictable = errors.New("sandbox is not expirable") diff --git a/packages/api/internal/sandbox/states.go b/packages/api/internal/sandbox/states.go index 637c58f947..0a59155310 100644 --- a/packages/api/internal/sandbox/states.go +++ b/packages/api/internal/sandbox/states.go @@ -42,6 +42,12 @@ var ( } ) +// RemoveOpts bundles the parameters that control sandbox removal. +type RemoveOpts struct { + Action StateAction + Eviction bool +} + var AllowedTransitions = map[State]map[State]bool{ StateRunning: {StatePausing: true, StateKilling: true, StateSnapshotting: true}, StatePausing: {StateKilling: true}, diff --git a/packages/api/internal/sandbox/storage/memory/operations.go b/packages/api/internal/sandbox/storage/memory/operations.go index 65b8500f49..c1f52f3820 100644 --- a/packages/api/internal/sandbox/storage/memory/operations.go +++ b/packages/api/internal/sandbox/storage/memory/operations.go @@ -43,12 +43,12 @@ func (s *Storage) get(sandboxID string) (*memorySandbox, error) { func (s *Storage) Get(_ context.Context, teamID uuid.UUID, sandboxID string) (sandbox.Sandbox, error) { item, ok := s.items.Get(sandboxID) if !ok { - return sandbox.Sandbox{}, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } data := item.Data() if data.TeamID != teamID { - return sandbox.Sandbox{}, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } return data, nil @@ -115,14 +115,14 @@ func (s *Storage) ExpiredItems(_ context.Context) ([]sandbox.Sandbox, error) { func (s *Storage) Update(_ context.Context, teamID uuid.UUID, sandboxID string, updateFunc func(sandbox.Sandbox) (sandbox.Sandbox, error)) (sandbox.Sandbox, error) { item, ok := s.items.Get(sandboxID) if !ok { - return sandbox.Sandbox{}, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } item.mu.Lock() defer item.mu.Unlock() if item._data.TeamID != teamID { - return sandbox.Sandbox{}, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } sbx, err := updateFunc(item._data) @@ -135,27 +135,45 @@ func (s *Storage) Update(_ context.Context, teamID uuid.UUID, sandboxID string, return sbx, nil } -func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (sandbox.Sandbox, bool, func(context.Context, error), error) { +func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, opts sandbox.RemoveOpts) (sandbox.Sandbox, bool, func(context.Context, error), error) { sbx, err := s.get(sandboxID) if err != nil { - return sandbox.Sandbox{}, false, nil, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, false, nil, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } data := sbx.Data() if data.TeamID != teamID { - return sandbox.Sandbox{}, false, nil, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, false, nil, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } - alreadyDone, callback, err := startRemoving(ctx, sbx, stateAction) + alreadyDone, callback, err := startRemoving(ctx, sbx, opts) return sbx.Data(), alreadyDone, callback, err } -func startRemoving(ctx context.Context, sbx *memorySandbox, stateAction sandbox.StateAction) (alreadyDone bool, callback func(ctx context.Context, err error), err error) { - newState := stateAction.TargetState - +func startRemoving(ctx context.Context, sbx *memorySandbox, opts sandbox.RemoveOpts) (alreadyDone bool, callback func(ctx context.Context, err error), err error) { sbx.mu.Lock() transition := sbx.transition + + // Resolve eviction under the lock + re-check expiry + if opts.Eviction { + // If there's a transition already in place, don't evict. + if transition != nil { + sbx.mu.Unlock() + + return false, nil, sandbox.ErrNotEvictable + } + + // If sandbox isn't expired (e.g. race condition with KeepAliveFor), skip. + if !sbx._data.IsExpired(time.Now()) { + sbx.mu.Unlock() + + return false, nil, sandbox.ErrNotEvictable + } + } + + newState := opts.Action.TargetState + if transition != nil { currentState := sbx._data.State sbx.mu.Unlock() @@ -175,7 +193,7 @@ func startRemoving(ctx context.Context, sbx *memorySandbox, stateAction sandbox. case currentState == newState: return true, func(context.Context, error) {}, nil case sandbox.AllowedTransitions[currentState][newState]: - return startRemoving(ctx, sbx, stateAction) + return startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: opts.Action}) default: return false, nil, fmt.Errorf("unexpected state transition") } @@ -192,7 +210,7 @@ func startRemoving(ctx context.Context, sbx *memorySandbox, stateAction sandbox. return false, nil, &sandbox.InvalidStateTransitionError{CurrentState: sbx._data.State, TargetState: newState} } - if stateAction.Effect == sandbox.TransitionExpires { + if opts.Action.Effect == sandbox.TransitionExpires { sbx.setExpired() } @@ -204,7 +222,7 @@ func startRemoving(ctx context.Context, sbx *memorySandbox, stateAction sandbox. sbx.mu.Lock() defer sbx.mu.Unlock() - if stateAction.Effect == sandbox.TransitionTransient { + if opts.Action.Effect == sandbox.TransitionTransient { if err == nil && sbx._data.State == newState { sbx._data.State = sandbox.StateRunning } diff --git a/packages/api/internal/sandbox/storage/memory/operations_test.go b/packages/api/internal/sandbox/storage/memory/operations_test.go index b030de1c3b..1994e0d45a 100644 --- a/packages/api/internal/sandbox/storage/memory/operations_test.go +++ b/packages/api/internal/sandbox/storage/memory/operations_test.go @@ -59,7 +59,7 @@ func TestStartRemoving_BasicTransitions(t *testing.T) { sbx._data.State = tt.fromState ctx := t.Context() - alreadyDone, finish, err := startRemoving(ctx, sbx, tt.stateAction) + alreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: tt.stateAction}) switch { case tt.shouldError: @@ -89,7 +89,7 @@ func TestStartRemoving_PauseThenKill(t *testing.T) { ctx := t.Context() // Simulate a pause operation that takes time - alreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, finish) @@ -112,7 +112,7 @@ func TestStartRemoving_PauseThenKill(t *testing.T) { <-started // Ensure the pause operation has started start := time.Now() - alreadyDone2, finish2, err2 := startRemoving(ctx, sbx, sandbox.StateActionKill) + alreadyDone2, finish2, err2 := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) elapsed := time.Since(start) // Should have waited for the pause to complete @@ -142,7 +142,7 @@ func TestStartRemoving_ConcurrentSameState(t *testing.T) { // Three concurrent requests to pause the sandbox for range 3 { go func() { - alreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) if err == nil { if alreadyDone { // Already alreadyDone (waited for another transition) @@ -196,7 +196,7 @@ func TestStartRemoving_Error(t *testing.T) { ctx := t.Context() // First attempt to pause - alreadyDone1, finish1, err := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyDone1, finish1, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone1) require.NotNil(t, finish1) @@ -209,7 +209,7 @@ func TestStartRemoving_Error(t *testing.T) { go func() { // This should wait for the first transition, then try to go to Killed - alreadyDone2, finish2, err2 = startRemoving(ctx, sbx, sandbox.StateActionKill) + alreadyDone2, finish2, err2 = startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) completed <- true }() @@ -230,14 +230,14 @@ func TestStartRemoving_Error(t *testing.T) { assert.Nil(t, finish2) // From Failed state, no transitions are allowed - alreadyDone3, finish3, err3 := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyDone3, finish3, err3 := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.Error(t, err3) require.ErrorIs(t, err3, failureErr) assert.False(t, alreadyDone3) assert.Nil(t, finish3) // Trying to transition to Killed should also fail - alreadyDone4, finish4, err4 := startRemoving(ctx, sbx, sandbox.StateActionKill) + alreadyDone4, finish4, err4 := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) require.Error(t, err4) require.ErrorIs(t, err4, failureErr) assert.False(t, alreadyDone4) @@ -251,7 +251,7 @@ func TestStartRemoving_ContextTimeout(t *testing.T) { sbx := createTestSandbox() // Start a long-running transition - alreadyDone1, finish1, err := startRemoving(t.Context(), sbx, sandbox.StateActionPause) + alreadyDone1, finish1, err := startRemoving(t.Context(), sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone1) require.NotNil(t, finish1) @@ -261,7 +261,7 @@ func TestStartRemoving_ContextTimeout(t *testing.T) { defer cancel() start := time.Now() - _, _, err2 := startRemoving(ctx, sbx, sandbox.StateActionKill) + _, _, err2 := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) elapsed := time.Since(start) // Should timeout after about 20ms @@ -294,7 +294,7 @@ func TestWaitForStateChange_WaitForCompletion(t *testing.T) { ctx := t.Context() // Start a transition - alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyalreadyDone) require.NotNil(t, finish) @@ -325,7 +325,7 @@ func TestWaitForStateChange_WaitWithError(t *testing.T) { ctx := t.Context() // Start a transition - alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyalreadyDone) require.NotNil(t, finish) @@ -358,7 +358,7 @@ func TestWaitForStateChange_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) // Start a transition - alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyalreadyDone) require.NotNil(t, finish) @@ -393,7 +393,7 @@ func TestWaitForStateChange_MultipleWaiters(t *testing.T) { ctx := t.Context() // Start a transition - alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.StateActionPause) + alreadyalreadyDone, finish, err := startRemoving(ctx, sbx, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyalreadyDone) require.NotNil(t, finish) @@ -449,7 +449,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionSnapshot}) require.NoError(t, err) assert.False(t, snapAlreadyDone) require.NotNil(t, finishSnap) @@ -461,7 +461,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { go func() { defer close(pauseDone) - _, pauseAlreadyDone, pauseFinish, pauseErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, pauseAlreadyDone, pauseFinish, pauseErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) }() time.Sleep(50 * time.Millisecond) @@ -505,7 +505,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionSnapshot}) require.NoError(t, err) assert.False(t, snapAlreadyDone) @@ -517,7 +517,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { go func() { defer close(killDone) - _, killAlreadyDone, killFinish, killErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, killAlreadyDone, killFinish, killErr = storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) }() // Give the kill goroutine time to start waiting @@ -564,7 +564,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - _, _, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, _, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionSnapshot}) require.NoError(t, err) // Finish with error — state stays Snapshotting, transition cleared @@ -575,7 +575,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { assert.Equal(t, sandbox.StateSnapshotting, got.State) // Kill proceeds immediately — no active transition, Snapshotting→Killing is allowed - _, killAlreadyDone, killFinish, killErr := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, killAlreadyDone, killFinish, killErr := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) require.NoError(t, killErr) assert.False(t, killAlreadyDone) require.NotNil(t, killFinish) @@ -607,7 +607,7 @@ func TestStartRemoving_DuringSnapshotting(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionSnapshot) + _, snapAlreadyDone, finishSnap, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionSnapshot}) require.NoError(t, err) assert.False(t, snapAlreadyDone) @@ -656,7 +656,7 @@ func TestConcurrency_StressTest(t *testing.T) { stateActions := []sandbox.StateAction{sandbox.StateActionPause, sandbox.StateActionKill} stateAction := stateActions[rand.Intn(len(stateActions))] - alreadyDone, finish, err := startRemoving(t.Context(), sbx, stateAction) + alreadyDone, finish, err := startRemoving(t.Context(), sbx, sandbox.RemoveOpts{Action: stateAction}) if err == nil && (finish != nil || alreadyDone) { if finish != nil { finish(t.Context(), nil) diff --git a/packages/api/internal/sandbox/storage/populate_redis/main.go b/packages/api/internal/sandbox/storage/populate_redis/main.go index 72c4fc01e9..ea6a787d71 100644 --- a/packages/api/internal/sandbox/storage/populate_redis/main.go +++ b/packages/api/internal/sandbox/storage/populate_redis/main.go @@ -79,8 +79,8 @@ func (m *PopulateRedisStorage) Update(ctx context.Context, teamID uuid.UUID, san return sbx, nil } -func (m *PopulateRedisStorage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (sandbox.Sandbox, bool, func(context.Context, error), error) { - return m.memoryBackend.StartRemoving(ctx, teamID, sandboxID, stateAction) +func (m *PopulateRedisStorage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, opts sandbox.RemoveOpts) (sandbox.Sandbox, bool, func(context.Context, error), error) { + return m.memoryBackend.StartRemoving(ctx, teamID, sandboxID, opts) } func (m *PopulateRedisStorage) WaitForStateChange(ctx context.Context, teamID uuid.UUID, sandboxID string) error { diff --git a/packages/api/internal/sandbox/storage/redis/operations.go b/packages/api/internal/sandbox/storage/redis/operations.go index e5f97f1cf5..f67db89ee0 100644 --- a/packages/api/internal/sandbox/storage/redis/operations.go +++ b/packages/api/internal/sandbox/storage/redis/operations.go @@ -59,7 +59,7 @@ func (s *Storage) Get(ctx context.Context, teamID uuid.UUID, sandboxID string) ( key := getSandboxKey(teamID.String(), sandboxID) data, err := s.redisClient.Get(ctx, key).Bytes() if errors.Is(err, redis.Nil) { - return sandbox.Sandbox{}, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } if err != nil { return sandbox.Sandbox{}, fmt.Errorf("failed to get sandbox from Redis: %w", err) @@ -181,7 +181,7 @@ func (s *Storage) Update(ctx context.Context, teamID uuid.UUID, sandboxID string // Get current value data, err := s.redisClient.Get(ctx, key).Bytes() if errors.Is(err, redis.Nil) { - return sandbox.Sandbox{}, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } if err != nil { return sandbox.Sandbox{}, err diff --git a/packages/api/internal/sandbox/storage/redis/state_change.go b/packages/api/internal/sandbox/storage/redis/state_change.go index 81941ae814..05477eeb57 100644 --- a/packages/api/internal/sandbox/storage/redis/state_change.go +++ b/packages/api/internal/sandbox/storage/redis/state_change.go @@ -31,9 +31,7 @@ import ( // // The callback is critical: it deletes the transition key // and sets the result value with short TTL to notify waiters of the outcome. -func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction sandbox.StateAction) (sandbox.Sandbox, bool, func(context.Context, error), error) { - newState := stateAction.TargetState - +func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, opts sandbox.RemoveOpts) (sandbox.Sandbox, bool, func(context.Context, error), error) { key := getSandboxKey(teamID.String(), sandboxID) transitionKey := getTransitionKey(teamID.String(), sandboxID) @@ -58,7 +56,7 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID // Get current sandbox state first data, err := s.redisClient.Get(ctx, key).Bytes() if errors.Is(err, redis.Nil) { - return sandbox.Sandbox{}, false, nil, &sandbox.NotFoundError{SandboxID: sandboxID} + return sandbox.Sandbox{}, false, nil, fmt.Errorf("sandbox %q: %w", sandboxID, sandbox.ErrNotFound) } if err != nil { return sandbox.Sandbox{}, false, nil, fmt.Errorf("failed to get sandbox from Redis: %w", err) @@ -75,13 +73,28 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID return sbx, false, nil, fmt.Errorf("failed to check transition key: %w", err) } + // Resolve eviction under the lock + re-check expiry + if opts.Eviction { + // if there's a transition already in place, don't do anything + if transactionID != "" { + return sbx, false, nil, sandbox.ErrNotEvictable + } + + // if sandbox isn't expired (e.g. race condition with SetTimeout) + if !sbx.IsExpired(time.Now()) { + return sbx, false, nil, sandbox.ErrNotEvictable + } + } + + newState := opts.Action.TargetState + if transactionID != "" { releaseErr := releaseFunc() if releaseErr != nil { logger.L().Warn(ctx, "Failed to release lock before waiting", zap.Error(releaseErr)) } - return s.handleExistingTransition(ctx, teamID, sbx, stateAction, newState, transactionID) + return s.handleExistingTransition(ctx, teamID, sbx, opts.Action, newState, transactionID) } // Check if already in target state @@ -100,7 +113,7 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID // This ensures that on failure the caller sees the pre-mutation state, updated := sbx updated.State = newState - if stateAction.Effect == sandbox.TransitionExpires { + if opts.Action.Effect == sandbox.TransitionExpires { now := time.Now() if !updated.IsExpired(now) { updated.EndTime = now @@ -127,7 +140,7 @@ func (s *Storage) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID logger.L().Debug(ctx, "Started state transition", logger.WithSandboxID(sandboxID), zap.String("state", string(newState)), zap.String("transitionID", transitionID)) - return updated, false, s.createCallback(teamID, sandboxID, transitionKey, resultKey, transitionID, stateAction), nil + return updated, false, s.createCallback(teamID, sandboxID, transitionKey, resultKey, transitionID, opts.Action), nil } // createCallback returns a callback function for completing a transition. @@ -292,5 +305,5 @@ func (s *Storage) handleExistingTransition( } // Retry with new state after transition completes - return s.StartRemoving(ctx, teamID, sbx.SandboxID, stateAction) + return s.StartRemoving(ctx, teamID, sbx.SandboxID, sandbox.RemoveOpts{Action: stateAction}) } diff --git a/packages/api/internal/sandbox/storage/redis/state_change_test.go b/packages/api/internal/sandbox/storage/redis/state_change_test.go index 3469e8b4d5..247a4c09a5 100644 --- a/packages/api/internal/sandbox/storage/redis/state_change_test.go +++ b/packages/api/internal/sandbox/storage/redis/state_change_test.go @@ -74,7 +74,7 @@ func TestStartRemoving_BasicTransitions(t *testing.T) { err := storage.Add(ctx, sbx) require.NoError(t, err) - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, tt.stateAction) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: tt.stateAction}) switch { case tt.shouldError: @@ -119,7 +119,7 @@ func TestStartRemoving_PauseThenKill(t *testing.T) { require.NoError(t, err) // Start pause operation - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -145,7 +145,7 @@ func TestStartRemoving_PauseThenKill(t *testing.T) { // Meanwhile, another request tries to kill the sandbox start := time.Now() - _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) elapsed := time.Since(start) // Should have waited for the pause to complete @@ -186,7 +186,7 @@ func TestStartRemoving_ConcurrentSameState(t *testing.T) { // Three concurrent requests to pause the sandbox for range 3 { go func() { - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) if err != nil { results <- struct { alreadyDone bool @@ -251,13 +251,12 @@ func TestStartRemoving_NotFound(t *testing.T) { ctx := context.Background() teamID := uuid.New() - _, alreadyDone, callback, err := storage.StartRemoving(ctx, teamID, "non-existent", sandbox.StateActionKill) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, teamID, "non-existent", sandbox.RemoveOpts{Action: sandbox.StateActionKill}) require.Error(t, err) assert.False(t, alreadyDone) assert.Nil(t, callback) - var notFoundErr *sandbox.NotFoundError - assert.ErrorAs(t, err, ¬FoundErr) + assert.ErrorIs(t, err, sandbox.ErrNotFound) } func TestStartRemoving_ContextCancellation(t *testing.T) { @@ -270,7 +269,7 @@ func TestStartRemoving_ContextCancellation(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone1, callback1, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone1, callback1, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone1) require.NotNil(t, callback1) @@ -280,7 +279,7 @@ func TestStartRemoving_ContextCancellation(t *testing.T) { defer cancel() start := time.Now() - _, alreadyDone2, _, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone2, _, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) elapsed := time.Since(start) // Should timeout @@ -320,7 +319,7 @@ func TestWaitForStateChange_WaitForCompletion(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -359,7 +358,7 @@ func TestWaitForStateChange_ContextCancellation(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(context.Background(), sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -400,7 +399,7 @@ func TestWaitForStateChange_MultipleWaiters(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -444,7 +443,7 @@ func TestStartRemoving_TransitionKeyTTL(t *testing.T) { require.NoError(t, err) // Start a transition but don't complete it - _, alreadyDone, _, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, _, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) @@ -472,7 +471,7 @@ func TestStartRemoving_CallbackMarksTransitionCompleted(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -514,7 +513,7 @@ func TestStartRemoving_CallbackSetsErrorOnFailure(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -559,7 +558,7 @@ func TestStartRemoving_SetsEndTimeWhenNotExpired(t *testing.T) { beforeTransition := time.Now() // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -586,7 +585,7 @@ func TestStartRemoving_WaiterCompletesOnCallbackSuccess(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -614,7 +613,7 @@ func TestStartRemoving_WaiterCompletesOnCallbackSuccess(t *testing.T) { } // Retry should work now - sandbox is already in pausing state - _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err2) // Already in pausing state from first transition assert.True(t, alreadyDone2) @@ -632,7 +631,7 @@ func TestStartRemoving_WaiterReceivesErrorOnCallbackFailure(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -675,7 +674,7 @@ func TestStartRemoving_DifferentExecutionID(t *testing.T) { require.NoError(t, err) // Start a transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) require.NotNil(t, callback) @@ -698,7 +697,7 @@ func TestStartRemoving_DifferentExecutionID(t *testing.T) { require.NoError(t, err) // Now start a new pause transition - should work since previous transition completed - _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err2) assert.False(t, alreadyDone2, "Should not be alreadyDone since we have a new execution") require.NotNil(t, callback2) @@ -728,7 +727,7 @@ func TestStartRemoving_TransientTransition(t *testing.T) { sbx := createTestSandbox("transient-restore") require.NoError(t, storage.Add(ctx, sbx)) - _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) + _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: transientAction}) require.NoError(t, err) finish(ctx, nil) @@ -747,7 +746,7 @@ func TestStartRemoving_TransientTransition(t *testing.T) { sbx := createTestSandbox("transient-fail-result") require.NoError(t, storage.Add(ctx, sbx)) - _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) + _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: transientAction}) require.NoError(t, err) transitionKey := getTransitionKey(sbx.TeamID.String(), sbx.SandboxID) @@ -773,7 +772,7 @@ func TestStartRemoving_TransientTransition(t *testing.T) { sbx := createTestSandbox("transient-restore-fail") require.NoError(t, storage.Add(ctx, sbx)) - _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, transientAction) + _, _, finish, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: transientAction}) require.NoError(t, err) // Remove the sandbox key to force restoreToRunning to fail @@ -806,13 +805,13 @@ func TestStartRemoving_CompletedTransitionAllowsNewTransition(t *testing.T) { require.NoError(t, err) // Start and complete a pause transition - _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionPause) + _, alreadyDone, callback, err := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionPause}) require.NoError(t, err) assert.False(t, alreadyDone) callback(ctx, nil) // Immediately try to kill - should work since pause is completed - _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.StateActionKill) + _, alreadyDone2, callback2, err2 := storage.StartRemoving(ctx, sbx.TeamID, sbx.SandboxID, sandbox.RemoveOpts{Action: sandbox.StateActionKill}) require.NoError(t, err2) assert.False(t, alreadyDone2) require.NotNil(t, callback2) diff --git a/packages/api/internal/sandbox/store.go b/packages/api/internal/sandbox/store.go index 52bf71e1de..3208075d64 100644 --- a/packages/api/internal/sandbox/store.go +++ b/packages/api/internal/sandbox/store.go @@ -36,7 +36,7 @@ type Storage interface { //nolint: interfacebloat TeamsWithSandboxCount(ctx context.Context) (map[uuid.UUID]int64, error) Update(ctx context.Context, teamID uuid.UUID, sandboxID string, updateFunc func(sandbox Sandbox) (Sandbox, error)) (Sandbox, error) - StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction StateAction) (Sandbox, bool, func(context.Context, error), error) + StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, opts RemoveOpts) (Sandbox, bool, func(context.Context, error), error) WaitForStateChange(ctx context.Context, teamID uuid.UUID, sandboxID string) error Sync(sandboxes []Sandbox, nodeID string) []Sandbox } @@ -150,8 +150,8 @@ func (s *Store) Update(ctx context.Context, teamID uuid.UUID, sandboxID string, return s.storage.Update(ctx, teamID, sandboxID, updateFunc) } -func (s *Store) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, stateAction StateAction) (Sandbox, bool, func(context.Context, error), error) { - return s.storage.StartRemoving(ctx, teamID, sandboxID, stateAction) +func (s *Store) StartRemoving(ctx context.Context, teamID uuid.UUID, sandboxID string, opts RemoveOpts) (Sandbox, bool, func(context.Context, error), error) { + return s.storage.StartRemoving(ctx, teamID, sandboxID, opts) } func (s *Store) WaitForStateChange(ctx context.Context, teamID uuid.UUID, sandboxID string) error {