From 7f07677644aec7ce117cb66a0a42ab1b339adda0 Mon Sep 17 00:00:00 2001 From: Dylan Strohschein Date: Fri, 4 Apr 2025 17:49:47 -0400 Subject: [PATCH 1/4] add notification message type --- internal/lsp/lsproto/jsonrpc.go | 74 ++++++++++++++++------- internal/lsp/server.go | 102 +++++++++++++++++--------------- 2 files changed, 108 insertions(+), 68 deletions(-) diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index 952a3635b9..24524a2b12 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -43,19 +43,32 @@ func (id *ID) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &id.int) } -// TODO(jakebailey): NotificationMessage? Use RequestMessage without ID? +type Message struct { + JSONRPC JSONRPCVersion `json:"jsonrpc"` +} + +type NotificationMessage struct { + Message + Method Method `json:"method"` + Params any `json:"params"` +} type RequestMessage struct { - JSONRPC JSONRPCVersion `json:"jsonrpc"` - ID *ID `json:"id"` - Method Method `json:"method"` - Params any `json:"params"` + Message + ID *ID `json:"id"` + Method Method `json:"method"` + Params any `json:"params"` } -func (r *RequestMessage) UnmarshalJSON(data []byte) error { +type RequestOrNotificationMessage struct { + NotificationMessage *NotificationMessage + RequestMessage *RequestMessage +} + +func (r *RequestOrNotificationMessage) UnmarshalJSON(data []byte) error { var raw struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` - ID *ID `json:"id"` + ID *ID `json:"id,omitempty"` Method Method `json:"method"` Params json.RawMessage `json:"params"` } @@ -63,36 +76,55 @@ func (r *RequestMessage) UnmarshalJSON(data []byte) error { return fmt.Errorf("%w: %w", ErrInvalidRequest, err) } - r.ID = raw.ID - r.Method = raw.Method - if r.Method == MethodShutdown || r.Method == MethodExit { + params, err := unmarhalParams(raw.Method, raw.Params) + if err != nil { + return err + } + + if raw.ID != nil { + r.RequestMessage = &RequestMessage{ + ID: raw.ID, + Method: raw.Method, + Params: params, + } + } else { + r.NotificationMessage = &NotificationMessage{ + Method: raw.Method, + Params: params, + } + } + + return nil +} + +func unmarhalParams(rawMethod Method, rawParams []byte) (any, error) { + if rawMethod == MethodShutdown || rawMethod == MethodExit { // These methods have no params. - return nil + return nil, nil } var params any var err error - if unmarshalParams, ok := unmarshallers[raw.Method]; ok { - params, err = unmarshalParams(raw.Params) + if unmarshaller, ok := unmarshallers[rawMethod]; ok { + params, err = unmarshaller(rawParams) } else { // Fall back to default; it's probably an unknown message and we will probably not handle it. - err = json.Unmarshal(raw.Params, ¶ms) + err = json.Unmarshal(rawParams, ¶ms) } - r.Params = params if err != nil { - return fmt.Errorf("%w: %w", ErrInvalidRequest, err) + return nil, fmt.Errorf("%w: %w", ErrInvalidRequest, err) } - return nil + return params, nil } type ResponseMessage struct { - JSONRPC JSONRPCVersion `json:"jsonrpc"` - ID *ID `json:"id,omitempty"` - Result any `json:"result"` - Error *ResponseError `json:"error,omitempty"` + Message + ID *ID `json:"id,omitempty"` + Result any `json:"result"` + Error *ResponseError `json:"error,omitempty"` } type ResponseError struct { diff --git a/internal/lsp/server.go b/internal/lsp/server.go index b2ef52fd5e..81455952c1 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -105,15 +105,20 @@ func (s *Server) Run() error { } if s.initializeParams == nil { - if req.Method == lsproto.MethodInitialize { - if err := s.handleInitialize(req); err != nil { - return err - } - } else { - if err := s.sendError(req.ID, lsproto.ErrServerNotInitialized); err != nil { - return err + if req.RequestMessage != nil { + message := req.RequestMessage + + if message.Method == lsproto.MethodInitialize { + if err := s.handleInitialize(message); err != nil { + return err + } + } else { + if err := s.sendError(message.ID, lsproto.ErrServerNotInitialized); err != nil { + return err + } } } + continue } @@ -123,13 +128,13 @@ func (s *Server) Run() error { } } -func (s *Server) read() (*lsproto.RequestMessage, error) { +func (s *Server) read() (*lsproto.RequestOrNotificationMessage, error) { data, err := s.r.Read() if err != nil { return nil, err } - req := &lsproto.RequestMessage{} + req := &lsproto.RequestOrNotificationMessage{} if err := json.Unmarshal(data, req); err != nil { return nil, fmt.Errorf("%w: %w", lsproto.ErrInvalidRequest, err) } @@ -170,45 +175,45 @@ func (s *Server) sendResponse(resp *lsproto.ResponseMessage) error { return s.w.Write(data) } -func (s *Server) handleMessage(req *lsproto.RequestMessage) error { - s.requestTime = time.Now() - s.requestMethod = string(req.Method) - - params := req.Params - switch params.(type) { - case *lsproto.InitializeParams: - return s.sendError(req.ID, lsproto.ErrInvalidRequest) - case *lsproto.InitializedParams: - return s.handleInitialized(req) - case *lsproto.DidOpenTextDocumentParams: - return s.handleDidOpen(req) - case *lsproto.DidChangeTextDocumentParams: - return s.handleDidChange(req) - case *lsproto.DidSaveTextDocumentParams: - return s.handleDidSave(req) - case *lsproto.DidCloseTextDocumentParams: - return s.handleDidClose(req) - case *lsproto.DocumentDiagnosticParams: - return s.handleDocumentDiagnostic(req) - case *lsproto.HoverParams: - return s.handleHover(req) - case *lsproto.DefinitionParams: - return s.handleDefinition(req) - default: +func (s *Server) handleMessage(msg *lsproto.RequestOrNotificationMessage) error { + if req := msg.RequestMessage; req != nil { switch req.Method { + case lsproto.MethodInitialize: + return s.sendError(req.ID, lsproto.ErrInvalidRequest) + case lsproto.MethodTextDocumentDiagnostic: + return s.handleDocumentDiagnostic(req) + case lsproto.MethodTextDocumentHover: + return s.handleHover(req) + case lsproto.MethodTextDocumentDefinition: + return s.handleDefinition(req) case lsproto.MethodShutdown: s.projectService.Close() return s.sendResult(req.ID, nil) - case lsproto.MethodExit: - return nil default: s.Log("unknown method", req.Method) - if req.ID != nil { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) - } + } + } else if notif := msg.NotificationMessage; notif != nil { + switch notif.Method { + case lsproto.MethodInitialized: + return s.handleInitialized() + case lsproto.MethodTextDocumentDidOpen: + return s.handleDidOpen(notif) + case lsproto.MethodTextDocumentDidChange: + return s.handleDidChange(notif) + case lsproto.MethodTextDocumentDidSave: + return s.handleDidSave(notif) + case lsproto.MethodTextDocumentDidClose: + return s.handleDidClose(notif) + case lsproto.MethodExit: return nil + default: + s.Log("unknown method", notif.Method) } + } else { + s.Log("Failed to parse unknown message") } + + return nil } func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { @@ -254,7 +259,7 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error { }) } -func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { +func (s *Server) handleInitialized() error { s.logger = project.NewLogger([]io.Writer{s.stderr}, project.LogLevelVerbose) s.projectService = project.NewService(s, project.ServiceOptions{ DefaultLibraryPath: s.defaultLibraryPath, @@ -264,24 +269,26 @@ func (s *Server) handleInitialized(req *lsproto.RequestMessage) error { return nil } -func (s *Server) handleDidOpen(req *lsproto.RequestMessage) error { +func (s *Server) handleDidOpen(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidOpenTextDocumentParams) s.projectService.OpenFile(documentUriToFileName(params.TextDocument.Uri), params.TextDocument.Text, languageKindToScriptKind(params.TextDocument.LanguageId), "") return nil } -func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { +func (s *Server) handleDidChange(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidChangeTextDocumentParams) scriptInfo := s.projectService.GetScriptInfo(documentUriToFileName(params.TextDocument.Uri)) if scriptInfo == nil { - return s.sendError(req.ID, lsproto.ErrRequestFailed) + s.logger.Error("Failed to get script info") + return nil } changes := make([]ls.TextChange, len(params.ContentChanges)) for i, change := range params.ContentChanges { if partialChange := change.TextDocumentContentChangePartial; partialChange != nil { if textChange, err := s.converters.fromLspTextChange(partialChange, scriptInfo.FileName()); err != nil { - return s.sendError(req.ID, err) + s.logger.Error(fmt.Sprintf("Error converting %v:", err)) + return nil } else { changes[i] = textChange } @@ -291,7 +298,8 @@ func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { NewText: wholeChange.Text, } } else { - return s.sendError(req.ID, lsproto.ErrInvalidRequest) + s.logger.Error(fmt.Sprintf("Invalid request")) + return nil } } @@ -299,13 +307,13 @@ func (s *Server) handleDidChange(req *lsproto.RequestMessage) error { return nil } -func (s *Server) handleDidSave(req *lsproto.RequestMessage) error { +func (s *Server) handleDidSave(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidSaveTextDocumentParams) s.projectService.MarkFileSaved(documentUriToFileName(params.TextDocument.Uri), *params.Text) return nil } -func (s *Server) handleDidClose(req *lsproto.RequestMessage) error { +func (s *Server) handleDidClose(req *lsproto.NotificationMessage) error { params := req.Params.(*lsproto.DidCloseTextDocumentParams) s.projectService.CloseFile(documentUriToFileName(params.TextDocument.Uri)) return nil From 14db7d9830572a74f58e9eaef049dc6909f0fc74 Mon Sep 17 00:00:00 2001 From: Dylan Strohschein Date: Fri, 4 Apr 2025 17:56:37 -0400 Subject: [PATCH 2/4] Update internal/lsp/lsproto/jsonrpc.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/lsp/lsproto/jsonrpc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index 24524a2b12..1238a8fc63 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -97,7 +97,7 @@ func (r *RequestOrNotificationMessage) UnmarshalJSON(data []byte) error { return nil } -func unmarhalParams(rawMethod Method, rawParams []byte) (any, error) { +func unmarshalParams(rawMethod Method, rawParams []byte) (any, error) { if rawMethod == MethodShutdown || rawMethod == MethodExit { // These methods have no params. return nil, nil From 83092d375e6884b5e65f9b53f88e1a7970071a62 Mon Sep 17 00:00:00 2001 From: Dylan Strohschein Date: Fri, 4 Apr 2025 17:58:38 -0400 Subject: [PATCH 3/4] omitzero and fixed function call --- internal/lsp/lsproto/jsonrpc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index 1238a8fc63..ba221282dd 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -68,7 +68,7 @@ type RequestOrNotificationMessage struct { func (r *RequestOrNotificationMessage) UnmarshalJSON(data []byte) error { var raw struct { JSONRPC JSONRPCVersion `json:"jsonrpc"` - ID *ID `json:"id,omitempty"` + ID *ID `json:"id,omitzero"` Method Method `json:"method"` Params json.RawMessage `json:"params"` } @@ -76,7 +76,7 @@ func (r *RequestOrNotificationMessage) UnmarshalJSON(data []byte) error { return fmt.Errorf("%w: %w", ErrInvalidRequest, err) } - params, err := unmarhalParams(raw.Method, raw.Params) + params, err := unmarshalParams(raw.Method, raw.Params) if err != nil { return err } From 0e472edc66ce7a216fee791d75309db4881897b3 Mon Sep 17 00:00:00 2001 From: Dylan Strohschein Date: Fri, 4 Apr 2025 18:09:44 -0400 Subject: [PATCH 4/4] favor omitzero over omitempty --- internal/lsp/lsproto/jsonrpc.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/lsp/lsproto/jsonrpc.go b/internal/lsp/lsproto/jsonrpc.go index ba221282dd..76fd4edc1e 100644 --- a/internal/lsp/lsproto/jsonrpc.go +++ b/internal/lsp/lsproto/jsonrpc.go @@ -122,13 +122,13 @@ func unmarshalParams(rawMethod Method, rawParams []byte) (any, error) { type ResponseMessage struct { Message - ID *ID `json:"id,omitempty"` + ID *ID `json:"id,omitzero"` Result any `json:"result"` - Error *ResponseError `json:"error,omitempty"` + Error *ResponseError `json:"error,omitzero"` } type ResponseError struct { Code int32 `json:"code"` Message string `json:"message"` - Data any `json:"data,omitempty"` + Data any `json:"data,omitzero"` }