Commit 92922c6

Gracefully Shutdown Foreground Server on Interrupt (#2927)
Summary

This pull request is a second attempt at #2863: gracefully shutting down foreground servers when receiving a SIGINT. This PR takes greater care to fix tests which became flaky with the original change. Fixes #1855.

Fixed Tests and Their Changes

osv_mcp_server_test.go

This file has a test that became flaky, specifically the "Running OSV MCP server in the foreground" tests. This is the only e2e test that exercises the foreground server the way a user would. I made the following changes to the test/code to ensure it passes:

- Update the test to document how SIGINT should behave rather than using thv stop/rm.
- Simplify the handling of context cancellation in runForeground. Previously, we had two goroutines doing parallel shutdown work: workloadManager.RunWorkload and runForeground. runForeground calls RunWorkload, so it is natural to block within runForeground until RunWorkload returns. Previously, the two shutdown routines could race on modifying/deleting the workload.

fetch_mcp_server_test.go

This file, and likely others, exercises --foreground indirectly. The fetch tests are structured to stop a server with a shared name in between tests. Previously, thv stop sent a SIGINT to the background thv restart --foreground process that was spawned by thv run. When the background process received that signal, the signal handler removed by this PR called os.Exit, causing the process to exit and stop babysitting the workload.

The problem and solution

thv restart was intentionally coded to ignore context cancellation because we wanted a timeout around the restart's initial startup. Because thv restart ignored all context cancellation, it effectively ignored all SIGINTs, and the process would continually resuscitate the server with the stale state we were attempting to change across tests. The solution preserves the timeout on startup but ensures post-startup context cancellation is respected.
1 parent 5084b21 commit 92922c6
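
The change wires one cancellation path from the OS signal down to workload cleanup: the signal handler cancels a root context, the foreground supervisor unblocks once the workload has stopped, and cleanup then runs on a fresh bounded context. Below is a condensed sketch of that flow in Go; names such as runWorkload and "example-server" are simplified stand-ins, not the toolhive API.

package main

import (
	"context"
	"errors"
	"log"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// runWorkload stands in for workloadManager.RunWorkload: it blocks until the
// context is cancelled or the workload fails, stopping the server either way.
func runWorkload(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}

// cleanupAndWait stands in for the post-exit cleanup: it deletes the workload
// on a background context with its own timeout, independent of the cancelled ctx.
func cleanupAndWait(name string) {
	cleanupCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	_ = cleanupCtx // the real code deletes the workload here using this bounded context
	log.Printf("cleaned up %s", name)
}

func main() {
	// Signal handler owns the root context; cancel() replaces the old os.Exit,
	// so cancellation now flows down to the workload instead of killing the process.
	ctx, cancel := context.WithCancel(context.Background())
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-sigCh
		cancel()
	}()

	// Foreground path: block until the workload supervisor returns, then clean up.
	errCh := make(chan error, 1)
	go func() { errCh <- runWorkload(ctx) }()

	err := <-errCh // blocks until the workload has stopped (e.g. after Ctrl+C)
	cleanupAndWait("example-server")
	if err != nil && !errors.Is(err, context.Canceled) {
		os.Exit(1)
	}
}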

File tree

6 files changed: +139 -98 lines changed

cmd/thv/app/run.go

Lines changed: 14 additions & 25 deletions

@@ -7,9 +7,7 @@ import (
 	"net"
 	"net/url"
 	"os"
-	"os/signal"
 	"strings"
-	"syscall"
 	"time"

 	"github.com/spf13/cobra"
@@ -126,7 +124,7 @@ func init() {
 	AddOIDCFlags(runCmd)
 }

-func cleanupAndWait(workloadManager workloads.Manager, name string, cancel context.CancelFunc, errCh <-chan error) {
+func cleanupAndWait(workloadManager workloads.Manager, name string) {
 	cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cleanupCancel()

@@ -138,13 +136,6 @@ func cleanupAndWait(workloadManager workloads.Manager, name string, cancel conte
 		logger.Warnf("DeleteWorkloads group error for %q: %v", name, err)
 	}
 	}
-
-	cancel()
-	select {
-	case <-errCh:
-	case <-time.After(5 * time.Second):
-		logger.Warnf("Timeout waiting for workload to stop")
-	}
 }

 // nolint:gocyclo // This function is complex by design
@@ -304,28 +295,26 @@ func getworkloadDefaultName(_ context.Context, serverOrImage string) string {
 }

 func runForeground(ctx context.Context, workloadManager workloads.Manager, runnerConfig *runner.RunConfig) error {
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
-	sigCh := make(chan os.Signal, 1)
-	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
-	defer signal.Stop(sigCh)

 	errCh := make(chan error, 1)
 	go func() {
 		errCh <- workloadManager.RunWorkload(ctx, runnerConfig)
 	}()

-	select {
-	case sig := <-sigCh:
-		if !process.IsDetached() {
-			logger.Infof("Received signal: %v, stopping server %q", sig, runnerConfig.BaseName)
-			cleanupAndWait(workloadManager, runnerConfig.BaseName, cancel, errCh)
-		}
-		return nil
-	case err := <-errCh:
-		return err
+	// workloadManager.RunWorkload will block until the context is cancelled
+	// or an unrecoverable error is returned. In either case, it will stop the server.
+	// We wait until workloadManager.RunWorkload exits before deleting the workload,
+	// so stopping and deleting don't race.
+	//
+	// There's room for improvement in the factoring here.
+	// Shutdown and cancellation logic is unnecessarily spread across two goroutines.
+	err := <-errCh
+	if !process.IsDetached() {
+		logger.Infof("RunWorkload Exited. Error: %v, stopping server %q", err, runnerConfig.BaseName)
+		cleanupAndWait(workloadManager, runnerConfig.BaseName)
 	}
+	return err
+
 }

 func validateGroup(ctx context.Context, workloadsManager workloads.Manager, serverOrImage string) error {
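
In short, the new runForeground waits for RunWorkload to return before touching the workload again. A minimal sketch of that ordering, with an illustrative signature rather than the real workloads.Manager interface:

package app

import "context"

// superviseForeground mirrors the new runForeground shape (illustrative
// signature, not the workloads.Manager interface): deletion only starts
// after run has fully returned and stopped the server, so the stop and
// delete paths can no longer race on the same workload.
func superviseForeground(ctx context.Context, run func(context.Context) error, deleteWorkload func()) error {
	errCh := make(chan error, 1)
	go func() { errCh <- run(ctx) }()

	err := <-errCh // run has already stopped the server by the time this returns
	deleteWorkload()
	return err
}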

cmd/thv/main.go

Lines changed: 10 additions & 5 deletions

@@ -2,6 +2,7 @@
 package main

 import (
+	"context"
 	"os"
 	"os/signal"
 	"syscall"
@@ -12,7 +13,6 @@ import (
 	"github.com/stacklok/toolhive/cmd/thv/app"
 	"github.com/stacklok/toolhive/pkg/client"
 	"github.com/stacklok/toolhive/pkg/container"
-	"github.com/stacklok/toolhive/pkg/container/runtime"
 	"github.com/stacklok/toolhive/pkg/lockfile"
 	"github.com/stacklok/toolhive/pkg/logger"
 	"github.com/stacklok/toolhive/pkg/migration"
@@ -23,7 +23,7 @@ func main() {
 	logger.Initialize()

 	// Setup signal handling for graceful cleanup
-	setupSignalHandler()
+	ctx := setupSignalHandler()

 	// Clean up stale lock files on startup
 	cleanupStaleLockFiles()
@@ -47,8 +47,10 @@ func main() {
 		migration.CheckAndPerformDefaultGroupMigration()
 	}

+	cmd := app.NewRootCmd(!app.IsCompletionCommand(os.Args))
+
 	// Skip update check for completion command or if we are running in kubernetes
-	if err := app.NewRootCmd(!app.IsCompletionCommand(os.Args) && !runtime.IsKubernetesRuntime()).Execute(); err != nil {
+	if err := cmd.ExecuteContext(ctx); err != nil {
 		// Clean up any remaining lock files on error exit
 		lockfile.CleanupAllLocks()
 		os.Exit(1)
@@ -59,16 +61,19 @@ func main() {
 }

 // setupSignalHandler configures signal handling to ensure lock files are cleaned up
-func setupSignalHandler() {
+func setupSignalHandler() context.Context {
 	sigCh := make(chan os.Signal, 1)
 	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT)

+	ctx, cancel := context.WithCancel(context.Background())
 	go func() {
 		<-sigCh
 		logger.Debugf("Received signal, cleaning up lock files...")
 		lockfile.CleanupAllLocks()
-		os.Exit(0)
+		cancel()
 	}()
+
+	return ctx
 }

 // cleanupStaleLockFiles removes stale lock files from known directories on startup
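
With ExecuteContext, the context produced by setupSignalHandler reaches every cobra handler via cmd.Context(). A hedged sketch of the same wiring using the standard library's signal.NotifyContext (a hypothetical command; the real setupSignalHandler also cleans up lock files before cancelling, which NotifyContext does not do):

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"

	"github.com/spf13/cobra"
)

func main() {
	// Same effect as a hand-rolled handler: a context cancelled on SIGINT/SIGTERM.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	root := &cobra.Command{
		Use: "example", // hypothetical command, for illustration only
		RunE: func(cmd *cobra.Command, _ []string) error {
			<-cmd.Context().Done() // cobra hands the ExecuteContext context to handlers
			fmt.Println("interrupted, shutting down")
			return nil
		},
	}

	if err := root.ExecuteContext(ctx); err != nil {
		os.Exit(1)
	}
}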

pkg/runner/runner.go

Lines changed: 8 additions & 11 deletions

@@ -8,9 +8,7 @@ import (
 	"fmt"
 	"net/http"
 	"os"
-	"os/signal"
 	"strings"
-	"syscall"
 	"time"

 	"golang.org/x/oauth2"
@@ -317,16 +315,19 @@ func (r *Runner) Run(ctx context.Context) error {

 	// Define a function to stop the MCP server
 	stopMCPServer := func(reason string) {
+		// Use a background context to avoid cancellation of the main context.
+		cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 1*time.Minute)
+		defer cleanupCancel()
 		logger.Infof("Stopping MCP server: %s", reason)

 		// Stop the transport (which also stops the container, monitoring, and handles removal)
 		logger.Infof("Stopping %s transport...", r.Config.Transport)
-		if err := transportHandler.Stop(ctx); err != nil {
+		if err := transportHandler.Stop(cleanupCtx); err != nil {
 			logger.Warnf("Warning: Failed to stop transport: %v", err)
 		}

 		// Cleanup telemetry provider
-		if err := r.Cleanup(ctx); err != nil {
+		if err := r.Cleanup(cleanupCtx); err != nil {
 			logger.Warnf("Warning: Failed to cleanup telemetry: %v", err)
 		}

@@ -335,7 +336,7 @@ func (r *Runner) Run(ctx context.Context) error {
 		if err := process.RemovePIDFile(r.Config.BaseName); err != nil {
 			logger.Warnf("Warning: Failed to remove PID file: %v", err)
 		}
-		if err := r.statusManager.ResetWorkloadPID(ctx, r.Config.BaseName); err != nil {
+		if err := r.statusManager.ResetWorkloadPID(cleanupCtx, r.Config.BaseName); err != nil {
 			logger.Warnf("Warning: Failed to reset workload %s PID: %v", r.Config.ContainerName, err)
 		}

@@ -354,10 +355,6 @@ func (r *Runner) Run(ctx context.Context) error {
 		logger.Info("Press Ctrl+C to stop or wait for container to exit")
 	}

-	// Set up signal handling
-	sigCh := make(chan os.Signal, 1)
-	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
-
 	// Create a done channel to signal when the server has been stopped
 	doneCh := make(chan struct{})

@@ -399,8 +396,8 @@ func (r *Runner) Run(ctx context.Context) error {

 	// Wait for either a signal or the done channel to be closed
 	select {
-	case sig := <-sigCh:
-		stopMCPServer(fmt.Sprintf("Received signal %s", sig))
+	case <-ctx.Done():
+		stopMCPServer("Context cancelled")
 	case <-doneCh:
 		// The transport has already been stopped (likely by the container exit)
 		// Clean up the PID file and state
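
The shutdown closure deliberately avoids the run context, which is already cancelled by the time cleanup starts; a fresh context with its own deadline keeps the teardown calls from failing immediately. A minimal sketch of the pattern (stopTransport is a stand-in, not the transport package API):

package runnerexample

import (
	"context"
	"log"
	"time"
)

// shutdown runs teardown on its own context so that the cancellation which
// triggered the shutdown cannot also abort the cleanup calls.
// stopTransport is a stand-in for the real transport/telemetry teardown.
func shutdown(stopTransport func(context.Context) error) {
	cleanupCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()

	if err := stopTransport(cleanupCtx); err != nil {
		log.Printf("warning: failed to stop transport: %v", err)
	}
}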

pkg/workloads/manager.go

Lines changed: 58 additions & 27 deletions

@@ -872,7 +872,7 @@ func (d *DefaultManager) DeleteWorkloads(_ context.Context, names []string) (*er
 }

 // RestartWorkloads restarts the specified workloads by name.
-func (d *DefaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) {
+func (d *DefaultManager) RestartWorkloads(ctx context.Context, names []string, foreground bool) (*errgroup.Group, error) {
 	// Validate all workload names to prevent path traversal attacks
 	for _, name := range names {
 		if err := types.ValidateWorkloadName(name); err != nil {
@@ -884,7 +884,7 @@ func (d *DefaultManager) RestartWorkloads(_ context.Context, names []string, for

 	for _, name := range names {
 		group.Go(func() error {
-			return d.restartSingleWorkload(name, foreground)
+			return d.restartSingleWorkload(ctx, name, foreground)
 		})
 	}

@@ -943,39 +943,59 @@ func (d *DefaultManager) updateSingleWorkload(workloadName string, newConfig *ru
 }

 // restartSingleWorkload handles the restart logic for a single workload
-func (d *DefaultManager) restartSingleWorkload(name string, foreground bool) error {
-	// Create a child context with a longer timeout
-	childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout)
-	defer cancel()
+func (d *DefaultManager) restartSingleWorkload(ctx context.Context, name string, foreground bool) error {

 	// First, try to load the run configuration to check if it's a remote workload
-	runConfig, err := runner.LoadState(childCtx, name)
+	runConfig, err := runner.LoadState(ctx, name)
 	if err != nil {
 		// If we can't load the state, it might be a container workload or the workload doesn't exist
 		// Try to restart it as a container workload
-		return d.restartContainerWorkload(childCtx, name, foreground)
+		return d.restartContainerWorkload(ctx, name, foreground)
 	}

 	// Check if this is a remote workload
 	if runConfig.RemoteURL != "" {
-		return d.restartRemoteWorkload(childCtx, name, runConfig, foreground)
+		return d.restartRemoteWorkload(ctx, name, runConfig, foreground)
 	}

 	// This is a container-based workload
-	return d.restartContainerWorkload(childCtx, name, foreground)
+	return d.restartContainerWorkload(ctx, name, foreground)
 }

 // restartRemoteWorkload handles restarting a remote workload
+// It blocks until the context is cancelled or there is already a supervisor process running.
 func (d *DefaultManager) restartRemoteWorkload(
 	ctx context.Context,
 	name string,
 	runConfig *runner.RunConfig,
 	foreground bool,
 ) error {
+	mcpRunner, err := d.maybeSetupRemoteWorkload(ctx, name, runConfig)
+	if err != nil {
+		return fmt.Errorf("failed to setup remote workload: %w", err)
+	}
+
+	if mcpRunner == nil {
+		return nil
+	}
+
+	return d.startWorkload(ctx, name, mcpRunner, foreground)
+}
+
+// maybeSetupRemoteWorkload is the startup steps for a remote workload.
+// A runner may not be returned if the workload is already running and supervised.
+func (d *DefaultManager) maybeSetupRemoteWorkload(
+	ctx context.Context,
+	name string,
+	runConfig *runner.RunConfig,
+) (*runner.Runner, error) {
+	ctx, cancel := context.WithTimeout(ctx, AsyncOperationTimeout)
+	defer cancel()
+
 	// Get workload status using the status manager
 	workload, err := d.statuses.GetWorkload(ctx, name)
 	if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
-		return err
+		return nil, err
 	}

 	// If workload is already running, check if the supervisor process is healthy
@@ -986,7 +1006,7 @@ func (d *DefaultManager) restartRemoteWorkload(
 	if supervisorAlive {
 		// Workload is running and healthy - preserve old behavior (no-op)
 		logger.Infof("Remote workload %s is already running", name)
-		return nil
+		return nil, nil
 	}

 	// Supervisor is dead/missing - we need to clean up and restart to fix the damaged state
@@ -1015,7 +1035,7 @@ func (d *DefaultManager) restartRemoteWorkload(
 	// Load runner configuration from state
 	mcpRunner, err := d.loadRunnerFromState(ctx, runConfig.BaseName)
 	if err != nil {
-		return fmt.Errorf("failed to load state for %s: %v", runConfig.BaseName, err)
+		return nil, fmt.Errorf("failed to load state for %s: %v", runConfig.BaseName, err)
 	}

 	// Set status to starting
@@ -1024,16 +1044,31 @@ func (d *DefaultManager) restartRemoteWorkload(
 	}

 	logger.Infof("Loaded configuration from state for %s", runConfig.BaseName)
+	return mcpRunner, nil
+}
+
+// restartContainerWorkload handles restarting a container-based workload.
+// It blocks until the context is cancelled or there is already a supervisor process running.
+func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error {
+	workloadName, mcpRunner, err := d.maybeSetupContainerWorkload(ctx, name)
+	if err != nil {
+		return fmt.Errorf("failed to setup container workload: %w", err)
+	}
+
+	if mcpRunner == nil {
+		return nil
+	}

-	// Start the remote workload using the loaded runner
-	// Use background context to avoid timeout cancellation - same reasoning as container workloads
-	return d.startWorkload(context.Background(), name, mcpRunner, foreground)
+	return d.startWorkload(ctx, workloadName, mcpRunner, foreground)
 }

-// restartContainerWorkload handles restarting a container-based workload
+// maybeSetupContainerWorkload is the startup steps for a container-based workload.
+// A runner may not be returned if the workload is already running and supervised.
 //
 //nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases
-func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error {
+func (d *DefaultManager) maybeSetupContainerWorkload(ctx context.Context, name string) (string, *runner.Runner, error) {
+	ctx, cancel := context.WithTimeout(ctx, AsyncOperationTimeout)
+	defer cancel()
 	// Get container info to resolve partial names and extract proper workload name
 	var containerName string
 	var workloadName string
@@ -1057,7 +1092,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
 	// Get workload status using the status manager
 	workload, err := d.statuses.GetWorkload(ctx, name)
 	if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
-		return err
+		return "", nil, err
 	}

 	// Check if workload is running and healthy (including supervisor process)
@@ -1068,7 +1103,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
 	if supervisorAlive {
 		// Workload is running and healthy - preserve old behavior (no-op)
 		logger.Infof("Container %s is already running", containerName)
-		return nil
+		return "", nil, nil
 	}

 	// Supervisor is dead/missing - we need to clean up and restart to fix the damaged state
@@ -1107,7 +1142,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
 		if statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, err.Error()); statusErr != nil {
 			logger.Warnf("Failed to set workload %s status to error: %v", workloadName, statusErr)
 		}
-		return fmt.Errorf("failed to stop container %s: %v", containerName, err)
+		return "", nil, fmt.Errorf("failed to stop container %s: %v", containerName, err)
 	}
 	logger.Infof("Container %s stopped", containerName)
 }
@@ -1126,7 +1161,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
 	// Load runner configuration from state
 	mcpRunner, err := d.loadRunnerFromState(ctx, workloadName)
 	if err != nil {
-		return fmt.Errorf("failed to load state for %s: %v", workloadName, err)
+		return "", nil, fmt.Errorf("failed to load state for %s: %v", workloadName, err)
 	}

 	// Set workload status to starting - use the workload name for status operations
@@ -1135,11 +1170,7 @@ func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name stri
 	}
 	logger.Infof("Loaded configuration from state for %s", workloadName)

-	// Start the workload with background context to avoid timeout cancellation
-	// The ctx with AsyncOperationTimeout is only for the restart setup operations,
-	// but the actual workload should run indefinitely with its own lifecycle management
-	// Use workload name for user-facing operations
-	return d.startWorkload(context.Background(), workloadName, mcpRunner, foreground)
+	return workloadName, mcpRunner, nil
 }

 // startWorkload starts the workload in either foreground or background mode
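
The restart path now separates setup from supervision: the maybeSetup* helpers bound only the setup steps with AsyncOperationTimeout, while the long-running start inherits the caller's context so a later SIGINT still cancels it. A hedged sketch of that factoring (setup, start, and the timeout value are illustrative stand-ins):

package workloadsexample

import (
	"context"
	"time"
)

// asyncOperationTimeout mirrors the bounded window applied to restart setup only.
const asyncOperationTimeout = 5 * time.Minute // illustrative value

// restart bounds the setup phase with a timeout, but runs the workload itself
// on the caller's context so post-startup cancellation (e.g. SIGINT) is honored.
func restart(ctx context.Context, setup func(context.Context) error, start func(context.Context) error) error {
	setupCtx, cancel := context.WithTimeout(ctx, asyncOperationTimeout)
	defer cancel()
	if err := setup(setupCtx); err != nil {
		return err
	}
	return start(ctx) // note: the parent ctx, not the timed-out setupCtx
}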
