diff --git a/cmd/tsgo/lsp.go b/cmd/tsgo/lsp.go index 28d460b29f..29c6e38aca 100644 --- a/cmd/tsgo/lsp.go +++ b/cmd/tsgo/lsp.go @@ -1,10 +1,14 @@ package main import ( + "context" "flag" "fmt" "os" + "os/signal" "runtime" + "syscall" + "time" "github.com/microsoft/typescript-go/internal/bundled" "github.com/microsoft/typescript-go/internal/core" @@ -41,6 +45,9 @@ func runLSP(args []string) int { defaultLibraryPath := bundled.LibPath() typingsLocation := getGlobalTypingsCacheLocation() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + s := lsp.NewServer(&lsp.ServerOptions{ In: os.Stdin, Out: os.Stdout, @@ -49,9 +56,27 @@ func runLSP(args []string) int { FS: fs, DefaultLibraryPath: defaultLibraryPath, TypingsLocation: typingsLocation, + SetParentPID: func(pid int) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Second): + p, err := os.FindProcess(pid) + if err != nil { + os.Exit(1) + } + if p.Signal(syscall.Signal(0)) != nil { + os.Exit(1) + } + } + } + }() + }, }) - if err := s.Run(); err != nil { + if err := s.Run(ctx); err != nil { return 1 } return 0 diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 6254b09624..92d63953eb 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -6,12 +6,9 @@ import ( "errors" "fmt" "io" - "os" - "os/signal" "runtime/debug" "slices" "sync" - "syscall" "github.com/microsoft/typescript-go/internal/core" "github.com/microsoft/typescript-go/internal/ls" @@ -31,6 +28,8 @@ type ServerOptions struct { FS vfs.FS DefaultLibraryPath string TypingsLocation string + + SetParentPID func(pid int) } func NewServer(opts *ServerOptions) *Server { @@ -50,6 +49,7 @@ func NewServer(opts *ServerOptions) *Server { fs: opts.FS, defaultLibraryPath: opts.DefaultLibraryPath, typingsLocation: opts.TypingsLocation, + setParentPID: opts.SetParentPID, } } @@ -83,6 +83,8 @@ type Server struct { defaultLibraryPath string typingsLocation string + setParentPID func(pid int) + initializeParams *lsproto.InitializeParams positionEncoding lsproto.PositionEncodingKind @@ -187,10 +189,7 @@ func (s *Server) RefreshDiagnostics(ctx context.Context) error { return nil } -func (s *Server) Run() error { - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - +func (s *Server) Run(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { return s.dispatchLoop(ctx) }) g.Go(func() error { return s.writeLoop(ctx) }) @@ -452,6 +451,12 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R func (s *Server) handleInitialize(req *lsproto.RequestMessage) { s.initializeParams = req.Params.(*lsproto.InitializeParams) + if s.setParentPID != nil { + if pid := s.initializeParams.ProcessId.Value; pid != 0 { + s.setParentPID(int(pid)) + } + } + s.positionEncoding = lsproto.PositionEncodingKindUTF16 if genCapabilities := s.initializeParams.Capabilities.General; genCapabilities != nil && genCapabilities.PositionEncodings != nil { if slices.Contains(*genCapabilities.PositionEncodings, lsproto.PositionEncodingKindUTF8) {