diff --git a/sqlmesh/cli/daemon_connector.py b/sqlmesh/cli/daemon_connector.py
new file mode 100644
index 0000000000..5fff626a2d
--- /dev/null
+++ b/sqlmesh/cli/daemon_connector.py
@@ -0,0 +1,205 @@
+import json
+import socket
+import typing as t
+import uuid
+from pathlib import Path
+
+from sqlmesh.core.console import JanitorState, JanitorStateRenderer
+from sqlmesh.lsp.cli_calls import (
+    DaemonCommunicationModeUnixSocket,
+    LockFile,
+    generate_lock_file,
+    return_lock_file_path,
+)
+from sqlmesh.utils.pydantic import PydanticModel
+
+
+class LSPCLICallRequest(PydanticModel):
+    """Request to call a CLI command through the LSP."""
+
+    arguments: t.List[str]
+
+
+class SocketMessageFinished(PydanticModel):
+    state: t.Literal["finished"] = "finished"
+
+
+class SocketMessageOngoing(PydanticModel):
+    state: t.Literal["ongoing"] = "ongoing"
+    message: t.Dict[str, t.Any]
+
+
+class SocketMessageError(PydanticModel):
+    state: t.Literal["error"] = "error"
+    message: str
+
+
+SocketMessage = t.Union[SocketMessageFinished, SocketMessageOngoing, SocketMessageError]
+
+
+def _validate_lock_file(lock_file_path: Path) -> LockFile:
+    """Validate that the lock file is compatible with the current version."""
+    current_lock = generate_lock_file()
+    try:
+        read_file = LockFile.model_validate_json(lock_file_path.read_text())
+    except Exception as e:
+        raise ValueError(f"Failed to parse lock file: {e}")
+
+    if not read_file.validate_lock_file(current_lock):
+        raise ValueError(
+            f"Lock file version mismatch. Expected: {current_lock.version}, "
+            f"Got: {read_file.version}"
+        )
+    return read_file
+
+
+class DaemonConnector:
+    """Connects to the LSP daemon via socket to execute commands."""
+
+    def __init__(self, project_path: Path, lock_file: LockFile):
+        self.project_path = project_path
+        self.lock_file = lock_file
+        self.renderer = JanitorStateRenderer()
+
+    def _open_connection(self) -> t.Tuple[t.BinaryIO, t.BinaryIO]:
+        communication = self.lock_file.communication
+        if communication is None:
+            raise ValueError("Lock file does not specify a communication mode")
+
+        if isinstance(communication.type, DaemonCommunicationModeUnixSocket):
+            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+            sock.connect(communication.type.socket)
+            rfile = sock.makefile("rb", buffering=0)
+            wfile = sock.makefile("wb", buffering=0)
+            return rfile, wfile
+        raise ValueError("Only Unix socket communication is supported")
+
+    def _send_jsonrpc_request(self, connection: t.Any, method: str, params: dict) -> str:
+        """Send a JSON-RPC request over the connection and return the request ID."""
+        request_id = str(uuid.uuid4())
+        jsonrpc_request = {"jsonrpc": "2.0", "method": method, "params": params, "id": request_id}
+
+        # JSON-RPC over the connection uses a Content-Length header (LSP protocol style).
+        message = json.dumps(jsonrpc_request)
+        content_length = len(message.encode("utf-8"))
+
+        # Send with the Content-Length header.
+        header = f"Content-Length: {content_length}\r\n\r\n"
+        full_message = header.encode("utf-8") + message.encode("utf-8")
+        connection.write(full_message)
+        connection.flush()
+
+        return request_id
+
+    def _read_jsonrpc_message(self, connection: t.Any) -> t.Dict[str, t.Any]:
+        """Read any JSON-RPC message (response or notification) from the connection."""
+        # Read headers byte by byte until the blank line that terminates them.
+        headers = b""
+        while b"\r\n\r\n" not in headers:
+            chunk = connection.read(1)
+            if not chunk:
+                raise ValueError("Connection closed while reading headers")
+            headers += chunk
+
+        # Parse the Content-Length header.
+        header_str = headers.decode("utf-8")
+        content_length = None
+        for line in header_str.split("\r\n"):
+            if line.startswith("Content-Length:"):
+                content_length = int(line.split(":")[1].strip())
+                break
+
+        if content_length is None:
+            raise ValueError("No Content-Length header found")
+
+        # Read the content.
+        content = connection.read(content_length)
+        if len(content) < content_length:
+            raise ValueError("Connection closed while reading content")
+
+        # Parse the JSON-RPC message.
+        return json.loads(content.decode("utf-8"))
+
+    def _read_jsonrpc_response(self, connection: t.Any, expected_id: str) -> t.Any:
+        """Read a JSON-RPC response from the connection."""
+        message = self._read_jsonrpc_message(connection)
+
+        if message.get("id") != expected_id:
+            raise ValueError(f"Unexpected response ID: {message.get('id')}")
+
+        if "error" in message:
+            raise ValueError(f"JSON-RPC error: {message['error']}")
+
+        return message.get("result", {})
+
+    def call_janitor(self, ignore_ttl: bool = False) -> bool:
+        rfile = wfile = None
+        try:
+            rfile, wfile = self._open_connection()
+
+            request = LSPCLICallRequest(
+                arguments=["janitor"] + (["--ignore-ttl"] if ignore_ttl else [])
+            )
+            request_id = self._send_jsonrpc_request(wfile, "sqlmesh/cli/call", request.model_dump())
+
+            with self.renderer as renderer:
+                while True:
+                    try:
+                        message_data = self._read_jsonrpc_message(rfile)
+                        if message_data.get("id") == request_id:
+                            result = message_data.get("result", {})
+                            if result.get("state") == "finished":
+                                return True
+                            if result.get("state") == "error":
+                                print(f"Error from daemon: {result.get('message', 'Unknown error')}")
+                                return False
+                        elif message_data.get("method") == "sqlmesh/cli/update":
+                            params = message_data.get("params", {})
+                            if params.get("state") == "ongoing":
+                                message = params.get("message", {})
+                                if "state" in message:
+                                    # The notification carries the inner state dump, so
+                                    # re-wrap it before validating the discriminated union.
+                                    janitor_state = JanitorState.model_validate({"state": message})
+                                    renderer.render(janitor_state)
+                    except Exception as stream_error:
+                        print(f"Stream ended: {stream_error}")
+                        break
+            return True
+        except Exception as e:
+            print(f"Failed to communicate with daemon: {e}")
+            return False
+        finally:
+            try:
+                if rfile:
+                    rfile.close()
+            finally:
+                if wfile:
+                    wfile.close()
+
+
+def get_daemon_connector(project_path: Path) -> t.Optional[DaemonConnector]:
+    """Return a daemon connector if a valid lock file exists."""
+    lock_path = return_lock_file_path(project_path)
+
+    if not lock_path.exists():
+        return None
+
+    try:
+        # Validate the lock file.
+        lock_file = _validate_lock_file(lock_path)
+
+        # The daemon is only reachable if communication info is present.
+        if lock_file.communication is None:
+            return None
+
+        return DaemonConnector(project_path, lock_file)
+    except Exception as e:
+        # Log the error but don't fail; fall back to direct execution.
+        print(f"Warning: Could not connect to daemon: {e}")
+        return None
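Both ends of this connection hand-roll the same framing the LSP uses: a `Content-Length` header, a blank line, then the UTF-8 JSON body. A minimal, self-contained sketch of that wire format (the `write_frame`/`read_frame` helpers are illustrative, not part of the diff):

```python
import json
import socket

def write_frame(wfile, payload: dict) -> None:
    # LSP-style framing: Content-Length header, blank line, JSON body.
    body = json.dumps(payload).encode("utf-8")
    wfile.write(f"Content-Length: {len(body)}\r\n\r\n".encode("utf-8") + body)
    wfile.flush()

def read_frame(rfile) -> dict:
    headers = b""
    while b"\r\n\r\n" not in headers:  # read until the header terminator
        chunk = rfile.read(1)
        if not chunk:
            raise EOFError("connection closed")
        headers += chunk
    length = next(
        int(line.split(":")[1])
        for line in headers.decode("utf-8").split("\r\n")
        if line.startswith("Content-Length:")
    )
    return json.loads(rfile.read(length).decode("utf-8"))

# Round-trip over an in-process socket pair.
client, server = socket.socketpair()
with client.makefile("wb", buffering=0) as wfile, server.makefile("rb", buffering=0) as rfile:
    write_frame(wfile, {"jsonrpc": "2.0", "method": "sqlmesh/cli/call", "id": "1",
                        "params": {"arguments": ["janitor"]}})
    print(read_frame(rfile))  # the same request, decoded on the other end
```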
diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py
index 2f18c0a4b7..5990d6f2c2 100644
--- a/sqlmesh/cli/main.py
+++ b/sqlmesh/cli/main.py
@@ -24,6 +24,7 @@
 from sqlmesh.utils import Verbosity
 from sqlmesh.utils.date import TimeLike
 from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
+from sqlmesh.cli.daemon_connector import get_daemon_connector
 
 logger = logging.getLogger(__name__)
 
@@ -640,7 +641,16 @@ def janitor(ctx: click.Context, ignore_ttl: bool, **kwargs: t.Any) -> None:
     The janitor cleans up old environments and expired snapshots.
     """
-    ctx.obj.run_janitor(ignore_ttl, **kwargs)
+    project_path = Path.cwd()
+    daemon = get_daemon_connector(project_path)
+    if daemon:
+        print("Connecting to SQLMesh daemon...")
+        success = daemon.call_janitor(ignore_ttl)
+        if not success:
+            raise click.ClickException("Janitor failed while running through the daemon")
+    else:
+        # No daemon available, run directly.
+        ctx.obj.run_janitor(ignore_ttl, **kwargs)
 
 
 @cli.command("destroy")
diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py
index 8af837b08a..41f011189a 100644
--- a/sqlmesh/core/console.py
+++ b/sqlmesh/core/console.py
@@ -52,6 +52,7 @@
     format_destructive_change_msg,
     format_additive_change_msg,
 )
+from sqlmesh.utils.pydantic import PydanticModel
 from sqlmesh.utils.rich import strip_ansi_codes
 
 if t.TYPE_CHECKING:
@@ -186,6 +187,126 @@ def stop_cleanup(self, success: bool = True) -> None:
         """
 
 
+from pydantic import Field
+from rich.console import Console as RichConsole, Group
+from rich.live import Live
+from rich.text import Text
+
+
+class JanitorStateStarted(PydanticModel):
+    state: t.Literal["started"] = "started"
+    cleanedup_objects: t.List[str] = []
+    # TODO: ignore_ttl is carried in the state but never used by the renderer;
+    # consider dropping it from the interface.
+    ignore_ttl: bool
+
+
+class JanitorStateFinished(PydanticModel):
+    state: t.Literal["finished"] = "finished"
+    cleanedup_objects: t.List[str] = []
+    ignore_ttl: bool
+    success: bool
+
+
+JanitorStates = t.Union[
+    JanitorStateStarted,
+    JanitorStateFinished,
+]
+
+
+class JanitorState(PydanticModel):
+    state: t.Annotated[
+        JanitorStates,
+        Field(discriminator="state"),
+    ]
+
+
+class JanitorStateRenderer:
+    """Renders streamed janitor state updates as a rich live display."""
+
+    def __init__(self) -> None:
+        self.console = RichConsole()
+        self.live = Live(auto_refresh=True)
+
+    def __enter__(self) -> "JanitorStateRenderer":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: t.Optional[t.Type[BaseException]],
+        exc_value: t.Optional[BaseException],
+        traceback: t.Optional[t.Any],
+    ) -> None:
+        self.live.stop()
+
+    @staticmethod
+    def _render_progress(cleanedup_objects: t.List[str]) -> t.List[Text]:
+        return [
+            Text("Started janitor job."),
+            *[Text(f"Deleted object {object_name}") for object_name in cleanedup_objects],
+        ]
+
+    @staticmethod
+    def _render_completion(state: JanitorStates) -> t.List[Text]:
+        if isinstance(state, JanitorStateFinished):
+            if state.success:
+                return [Text("Cleanup complete.")]
+            return [Text("Cleanup failed!")]
+        return []
+
+    def render(self, state: JanitorState) -> None:
+        if not self.live.is_started:
+            self.live.start()
+        inner = state.state
+        group = Group(
+            *JanitorStateRenderer._render_progress(inner.cleanedup_objects),
+            *JanitorStateRenderer._render_completion(inner),
+        )
+        self.live.update(group)
+        self.live.refresh()
+
+
+class JanitorConsoleDaemon(JanitorConsole):
+    """Janitor console that forwards state updates to a callback instead of printing."""
+
+    def __init__(self, callback: t.Callable[[JanitorState], None]):
+        self.callback = callback
+        self.state: t.Optional[JanitorState] = None
+
+    def send_message(self) -> None:
+        self.callback(self.state)
+
+    def start_cleanup(self, ignore_ttl: bool) -> bool:
+        # TODO: The interactive console prompts for confirmation here; the
+        # server should eventually wait for that input as well.
+        self.state = JanitorState(state=JanitorStateStarted(ignore_ttl=ignore_ttl))
+        self.send_message()
+        return True
+
+    def update_cleanup_progress(self, object_name: str) -> None:
+        state = self.state
+        if state is None or not isinstance(state.state, JanitorStateStarted):
+            raise ValueError("should only report progress after having started")
+        state.state.cleanedup_objects.append(object_name)
+        self.send_message()
+
+    def stop_cleanup(self, success: bool = True) -> None:
+        state = self.state
+        if state is None:
+            raise ValueError("should only finish after having started")
+        state_inside = state.state
+        if not isinstance(state_inside, JanitorStateStarted):
+            raise ValueError("should only finish after having started")
+
+        self.state = JanitorState(
+            state=JanitorStateFinished(
+                cleanedup_objects=state_inside.cleanedup_objects,
+                ignore_ttl=state_inside.ignore_ttl,
+                success=success,
+            )
+        )
+        self.send_message()
+        self.state = None
+
+
 class DestroyConsole(abc.ABC):
     """Console for describing a destroy operation"""
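The `JanitorState` wrapper exists so the client can rebuild typed state from plain JSON: pydantic selects the concrete model by the `state` discriminator. A standalone sketch of the same pattern using plain `pydantic.BaseModel` (assuming sqlmesh's `PydanticModel` validates the same way, since it wraps pydantic):

```python
import typing as t
from pydantic import BaseModel, Field

class Started(BaseModel):
    state: t.Literal["started"] = "started"
    cleanedup_objects: t.List[str] = []
    ignore_ttl: bool

class Finished(BaseModel):
    state: t.Literal["finished"] = "finished"
    cleanedup_objects: t.List[str] = []
    ignore_ttl: bool
    success: bool

class State(BaseModel):
    # The inner "state" literal picks the concrete model during validation.
    state: t.Annotated[t.Union[Started, Finished], Field(discriminator="state")]

payload = {"state": {"state": "finished", "cleanedup_objects": ["dev__old"],
                     "ignore_ttl": False, "success": True}}
parsed = State.model_validate(payload)
assert isinstance(parsed.state, Finished)
print(parsed.state.cleanedup_objects)  # ['dev__old']
```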
@@ -2614,16 +2735,10 @@ def show_row_diff(
                 table.add_column(column_name, style=style, header_style=style)
 
             for _, row in column_table.iterrows():
-                table.add_row(
-                    *[
-                        str(
-                            round(cell, row_diff.decimals)
-                            if isinstance(cell, float)
-                            else cell
-                        )
-                        for cell in row
-                    ]
-                )
+                table.add_row(*[
+                    str(round(cell, row_diff.decimals) if isinstance(cell, float) else cell)
+                    for cell in row
+                ])
 
             self.console.print(
                 f"Column: [underline][bold cyan]{column}[/bold cyan][/underline]",
@@ -2996,13 +3111,11 @@ def going_forward_change_callback(change: t.Dict[str, bool]) -> None:
 
         add_to_layout_widget(
             prompt,
-            widgets.HBox(
-                [
-                    widgets.Label("Effective From Date:", layout={"width": "8rem"}),
-                    date_picker,
-                    going_forward_checkbox,
-                ]
-            ),
+            widgets.HBox([
+                widgets.Label("Effective From Date:", layout={"width": "8rem"}),
+                date_picker,
+                going_forward_checkbox,
+            ]),
         )
 
         self._add_to_dynamic_options(prompt)
@@ -3047,30 +3160,24 @@ def end_change_callback(change: t.Dict[str, datetime.datetime]) -> None:
         if plan_builder.is_start_and_end_allowed:
             add_to_layout_widget(
                 prompt,
-                widgets.HBox(
-                    [
-                        widgets.Label(
-                            f"Start {backfill_or_preview} Date:", layout={"width": "8rem"}
-                        ),
-                        _date_picker(
-                            plan_builder, to_date(plan_builder.build().start), start_change_callback
-                        ),
-                    ]
-                ),
+                widgets.HBox([
+                    widgets.Label(f"Start {backfill_or_preview} Date:", layout={"width": "8rem"}),
+                    _date_picker(
+                        plan_builder, to_date(plan_builder.build().start), start_change_callback
+                    ),
+                ]),
             )
 
             add_to_layout_widget(
                 prompt,
-                widgets.HBox(
-                    [
-                        widgets.Label(f"End {backfill_or_preview} Date:", layout={"width": "8rem"}),
-                        _date_picker(
-                            plan_builder,
-                            to_date(plan_builder.build().end),
-                            end_change_callback,
-                        ),
-                    ]
-                ),
+                widgets.HBox([
+                    widgets.Label(f"End {backfill_or_preview} Date:", layout={"width": "8rem"}),
+                    _date_picker(
+                        plan_builder,
+                        to_date(plan_builder.build().end),
+                        end_change_callback,
+                    ),
+                ]),
             )
 
             self._add_to_dynamic_options(prompt)
@@ -3308,9 +3415,10 @@ def __init__(self, **kwargs: t.Any) -> None:
         self.warning_capture_only = kwargs.pop("warning_capture_only", False)
         self.error_capture_only = kwargs.pop("error_capture_only", False)
 
-        super().__init__(
-            **{**kwargs, "console": RichConsole(no_color=True, width=kwargs.pop("width", None))}
-        )
+        super().__init__(**{
+            **kwargs,
+            "console": RichConsole(no_color=True, width=kwargs.pop("width", None)),
+        })
 
     def show_environment_difference_summary(
         self,
@@ -3458,9 +3566,9 @@ def _print_modified_models(
                 f"* `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}` ({category_str})"
             )
 
-            indirectly_modified_children = sorted(
-                [s for s in indirectly_modified if snapshot.snapshot_id in s.parents]
-            )
+            indirectly_modified_children = sorted([
+                s for s in indirectly_modified if snapshot.snapshot_id in s.parents
+            ])
 
             if not no_diff:
                 diff_text = context_diff.text_diff(snapshot.name)
@@ -3468,9 +3576,9 @@ def _print_modified_models(
                 if diff_text:
                     diff_text = f"\n```diff\n{diff_text}\n```"
                     # these are part of a Markdown list, so indent them by 2 spaces to relate them to the current list item
-                    diff_text_indented = "\n".join(
-                        [f"  {line}" for line in diff_text.splitlines()]
-                    )
+                    diff_text_indented = "\n".join([
+                        f"  {line}" for line in diff_text.splitlines()
+                    ])
                     self._print(diff_text_indented)
         else:
             if indirectly_modified_children:
@@ -3771,13 +3879,11 @@ def update_snapshot_evaluation_progress(
         status = "Loaded" if finished_loading else "Loading"
         print(f"{status} '{view_name}', Completed Batches: {loaded_batches}/{total_batches}")
         if finished_loading:
-            total_finished_loading = len(
-                [
-                    s
-                    for s, total in self.evaluation_model_batch_sizes.items()
-                    if self.evaluation_batch_progress.get(s.snapshot_id, (None, -1))[1] == total
-                ]
-            )
+            total_finished_loading = len([
+                s
+                for s, total in self.evaluation_model_batch_sizes.items()
+                if self.evaluation_batch_progress.get(s.snapshot_id, (None, -1))[1] == total
+            ])
             total = len(self.evaluation_batch_progress)
             print(f"Completed Loading {total_finished_loading}/{total} Models")
 
@@ -4153,9 +4259,10 @@ def create_console(
     rich_console_kwargs: t.Dict[str, t.Any] = {"theme": srich.theme}
    if runtime_env.is_jupyter or runtime_env.is_google_colab:
         rich_console_kwargs["force_jupyter"] = True
-    return runtime_env_mapping[runtime_env](
-        **{**{"console": RichConsole(**rich_console_kwargs)}, **kwargs}
-    )
+    return runtime_env_mapping[runtime_env](**{
+        **{"console": RichConsole(**rich_console_kwargs)},
+        **kwargs,
+    })
 
 
 def _format_missing_intervals(snapshot: Snapshot, missing: SnapshotIntervals) -> str:
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py
index e3feb1e14b..16652b88f1 100644
--- a/sqlmesh/core/context.py
+++ b/sqlmesh/core/context.py
@@ -413,9 +413,10 @@ def __init__(
             global_defaults = self.config.model_defaults.model_dump(exclude_unset=True)
             gateway_defaults = gw_model_defaults.model_dump(exclude_unset=True)
 
-            self.config.model_defaults = ModelDefaultsConfig(
-                **{**global_defaults, **gateway_defaults}
-            )
+            self.config.model_defaults = ModelDefaultsConfig(**{
+                **global_defaults,
+                **gateway_defaults,
+            })
 
         # This allows overriding the default dialect's normalization strategy, so for example
         # one can do `dialect="duckdb,normalization_strategy=lowercase"` and this will be
@@ -513,13 +514,11 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
 
         self.dag.add(model.fqn, model.depends_on)
 
-        self._models.update(
-            {
-                model.fqn: model,
-                # bust the fingerprint cache for all downstream models
-                **{fqn: self._models[fqn].copy() for fqn in self.dag.downstream(model.fqn)},
-            }
-        )
+        self._models.update({
+            model.fqn: model,
+            # bust the fingerprint cache for all downstream models
+            **{fqn: self._models[fqn].copy() for fqn in self.dag.downstream(model.fqn)},
+        })
 
         update_model_schemas(
             self.dag,
diff --git a/sqlmesh/lsp/cli_calls.py b/sqlmesh/lsp/cli_calls.py
new file mode 100644
index 0000000000..51fc19a2ef
--- /dev/null
+++ b/sqlmesh/lsp/cli_calls.py
@@ -0,0 +1,130 @@
+import errno
+import os
+import time
+import typing as t
+from pathlib import Path
+
+from pydantic import Field
+
+from sqlmesh._version import __version__
+from sqlmesh.utils.pydantic import PydanticModel
+
+
+class DaemonCommunicationModeTCP(PydanticModel):
+    type: t.Literal["tcp"] = "tcp"
+    address: str
+
+
+class DaemonCommunicationModeUnixSocket(PydanticModel):
+    type: t.Literal["unix_socket"] = "unix_socket"
+    socket: str
+
+
+class DaemonCommunicationMode(PydanticModel):
+    type: t.Union[DaemonCommunicationModeTCP, DaemonCommunicationModeUnixSocket] = Field(
+        discriminator="type"
+    )
+
+
+class LockFile(PydanticModel):
+    version: str
+    file_path: str
+    communication: t.Optional[DaemonCommunicationMode] = None
+
+    def validate_lock_file(self, other: "LockFile") -> bool:
+        return self.version == other.version and self.file_path == other.file_path
+
+
+def generate_lock_file() -> LockFile:
+    return LockFile(
+        version=__version__,
+        file_path=str(Path(__file__).resolve()),
+    )
+
+
+def return_lock_file_path(project_path: Path) -> Path:
+    return project_path / ".sqlmesh" / ".lsp_lock"
+
+
+class FileLock:
+    """
+    A simple, portable file-based lock.
+
+    Usage:
+        with FileLock("/tmp/mylock.lock"):
+            # critical section
+            ...
+    """
+
+    def __init__(self, lock_path: str, timeout: float = 10.0, wait: float = 0.1):
+        """
+        :param lock_path: Path to the lock file.
+        :param timeout: Maximum seconds to wait before raising TimeoutError.
+        :param wait: Sleep time (in seconds) between attempts.
+        """
+        self.lock_path = lock_path
+        self.timeout = timeout
+        self.wait = wait
+        self.fd: t.Optional[int] = None  # file descriptor returned by os.open
+
+    def acquire(self) -> None:
+        """
+        Acquire the lock. Blocks until the lock is obtained or the timeout
+        is reached, in which case a TimeoutError is raised.
+        """
+        start = time.monotonic()
+        while True:
+            try:
+                # O_CREAT | O_EXCL guarantees atomic creation.
+                self.fd = os.open(
+                    self.lock_path,
+                    os.O_CREAT | os.O_EXCL | os.O_RDWR,
+                    0o644,  # permissions for the new file
+                )
+                # Write the PID into the file so that stale locks can be
+                # identified later.
+                os.write(self.fd, str(os.getpid()).encode())
+                return
+            except FileExistsError:
+                # File already exists: somebody else holds the lock.
+                if time.monotonic() - start > self.timeout:
+                    raise TimeoutError(
+                        f"Could not acquire lock on {self.lock_path} after {self.timeout} seconds"
+                    )
+                time.sleep(self.wait)
+
+    def release(self) -> None:
+        """
+        Release the lock and delete the lock file.
+        """
+        if self.fd is not None:
+            try:
+                os.close(self.fd)
+            finally:
+                self.fd = None
+
+        # Delete the lock file so that other processes can acquire the lock.
+        try:
+            os.unlink(self.lock_path)
+        except FileNotFoundError:
+            # The file might have been removed by another process.
+            pass
+        except PermissionError as exc:
+            # On Windows a file can't be removed if it is still opened by
+            # another process; swallow the error, since the lock is already
+            # released because we closed our fd.
+            if exc.errno != errno.EACCES:
+                raise
+
+    # Context-manager support ---------------------------------------------
+    def __enter__(self) -> "FileLock":
+        self.acquire()
+        return self
+
+    def __exit__(self, exc_type: t.Any, exc_val: t.Any, exc_tb: t.Any) -> None:
+        self.release()
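A quick usage sketch of `FileLock` as defined above (the lock path is illustrative); the second acquisition times out while the first is still held:

```python
import tempfile
from pathlib import Path

from sqlmesh.lsp.cli_calls import FileLock

lock_path = Path(tempfile.gettempdir()) / "sqlmesh_demo.lock"

with FileLock(str(lock_path), timeout=2.0):
    # Only one process at a time gets here; the file records our PID.
    print(f"holding {lock_path}")

    # A competing acquirer gives up after its timeout elapses.
    try:
        FileLock(str(lock_path), timeout=0.5).acquire()
    except TimeoutError as e:
        print(e)  # Could not acquire lock on ... after 0.5 seconds

# On exiting the with-block the lock file is closed and unlinked.
assert not lock_path.exists()
```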
diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py
index 71dc5e1e2b..4eb0fa5643 100755
--- a/sqlmesh/lsp/main.py
+++ b/sqlmesh/lsp/main.py
@@ -3,6 +3,8 @@
 from itertools import chain
 import logging
+import socket
+import sys
 import typing as t
 from pathlib import Path
 import urllib.parse
@@ -26,6 +28,12 @@
     ApiResponseGetModels,
     ApiResponseGetTableDiff,
 )
+from sqlmesh.lsp.cli_calls import (
+    DaemonCommunicationMode,
+    DaemonCommunicationModeUnixSocket,
+    generate_lock_file,
+    return_lock_file_path,
+)
 from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
 from sqlmesh.lsp.completions import get_sql_completions
@@ -69,6 +77,7 @@
     GetModelsResponse,
     ModelInfo,
 )
+from pydantic import Field
 from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
 from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
 from sqlmesh.lsp.hints import get_hints
@@ -126,11 +135,40 @@ class ContextFailed:
 ContextState = Union[NoContext, ContextLoaded, ContextFailed]
 
 
+class LSPCLICallRequest(PydanticModel):
+    arguments: t.List[str]
+
+
+class SocketMessageFinished(PydanticModel):
+    state: t.Literal["finished"] = "finished"
+
+
+T = t.TypeVar("T", bound=PydanticModel)
+
+
+class SocketMessageOngoing(PydanticModel, t.Generic[T]):
+    state: t.Literal["ongoing"] = "ongoing"
+    message: T
+
+
+class SocketMessageError(PydanticModel):
+    state: t.Literal["error"] = "error"
+    message: str
+
+
+class SocketMessage(PydanticModel):
+    state: t.Union[SocketMessageOngoing, SocketMessageFinished, SocketMessageError] = Field(
+        discriminator="state"
+    )
+
+
+CLI_CALL_FEATURE = "sqlmesh/cli/call"
+
+
 class SQLMeshLanguageServer:
     # Specified folders take precedence over workspace folders or looking
     # for a config files. They are explicitly set by the user and optionally
     # pass in at init
     specified_paths: t.Optional[t.List[Path]] = None
+    cleanup_calls: t.List = []
 
     def __init__(
@@ -170,6 +208,7 @@ def __init__(
             RUN_TEST_FEATURE: self._run_test,
             GET_ENVIRONMENTS_FEATURE: self._custom_get_environments,
             GET_MODELS_FEATURE: self._custom_get_models,
+            CLI_CALL_FEATURE: self._custom_cli_call,
         }
 
         # Register LSP features (e.g., formatting, hover, etc.)
@@ -232,26 +271,26 @@ def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> Al
         try:
             context = self._context_get_or_load(uri)
             return LSPContext.get_completions(context, uri, content)
-        except Exception as e:
+        except Exception:
             from sqlmesh.lsp.completions import get_sql_completions
 
             return get_sql_completions(None, URI(params.textDocument.uri), content)
 
     def _custom_render_model(
-        self, ls: LanguageServer, params: RenderModelRequest
+        self, _: LanguageServer, params: RenderModelRequest
     ) -> RenderModelResponse:
         uri = URI(params.textDocumentUri)
         context = self._context_get_or_load(uri)
         return RenderModelResponse(models=context.render_model(uri))
 
     def _custom_all_models_for_render(
-        self, ls: LanguageServer, params: AllModelsForRenderRequest
+        self, ls: LanguageServer, _: AllModelsForRenderRequest
     ) -> AllModelsForRenderResponse:
         context = self._context_get_or_load()
         return AllModelsForRenderResponse(models=context.list_of_models_for_rendering())
 
     def _custom_format_project(
-        self, ls: LanguageServer, params: FormatProjectRequest
+        self, ls: LanguageServer, _: FormatProjectRequest
     ) -> FormatProjectResponse:
         """Format all models in the current project."""
         try:
@@ -263,7 +302,7 @@ def _custom_format_project(
         return FormatProjectResponse()
 
     def _custom_get_environments(
-        self, ls: LanguageServer, params: GetEnvironmentsRequest
+        self, ls: LanguageServer, _: GetEnvironmentsRequest
     ) -> GetEnvironmentsResponse:
         """Get all environments in the current project."""
         try:
@@ -425,6 +464,71 @@ def _custom_api(
 
         raise NotImplementedError(f"API request not implemented: {request.url}")
 
+    def _custom_cli_call(self, ls: LanguageServer, params: LSPCLICallRequest) -> SocketMessage:
+        """Handle CLI call requests from the daemon connector."""
+        try:
+            context = self._context_get_or_load()
+            if not context or not hasattr(context, "context"):
+                return SocketMessage(state=SocketMessageError(message="No context available"))
+
+            arguments = params.arguments
+
+            # For now, only support the janitor command.
+            if not arguments or arguments[0] != "janitor":
+                return SocketMessage(
+                    state=SocketMessageError(message="Only 'janitor' command is supported")
+                )
+
+            # Parse janitor arguments.
+            ignore_ttl = "--ignore-ttl" in arguments
+
+            # Run the janitor with a custom console that sends updates via notifications.
+            try:
+                import json
+
+                from sqlmesh.core.console import JanitorConsoleDaemon, JanitorState
+
+                def send_state(state: JanitorState) -> None:
+                    """Callback to send janitor state updates as notifications."""
+                    # Send each update as a separate JSON-RPC notification.
+                    notification = {
+                        "jsonrpc": "2.0",
+                        "method": "sqlmesh/cli/update",
+                        "params": {"state": "ongoing", "message": state.state.model_dump()},
+                    }
+
+                    # Frame the notification with a Content-Length header.
+                    message = json.dumps(notification)
+                    content_length = len(message.encode("utf-8"))
+                    header = f"Content-Length: {content_length}\r\n\r\n"
+                    full_message = header.encode("utf-8") + message.encode("utf-8")
+
+                    # Write directly to the server's output stream.
+                    if hasattr(ls.lsp, "_writer") and ls.lsp._writer:
+                        ls.lsp._writer.write(full_message)
+                        ls.lsp._writer.flush()
+
+                # Create the daemon console with the callback.
+                daemon_console = JanitorConsoleDaemon(callback=send_state)
+
+                # Run the janitor with the daemon console, restoring the original afterwards.
+                original_console = context.context.console
+                try:
+                    context.context.console = daemon_console
+                    context.context.run_janitor(ignore_ttl=ignore_ttl)
+                finally:
+                    context.context.console = original_console
+
+                # Return the final success message.
+                return SocketMessage(state=SocketMessageFinished())
+
+            except Exception as e:
+                return SocketMessage(state=SocketMessageError(message=f"Janitor failed: {e}"))
+
+        except Exception as e:
+            return SocketMessage(state=SocketMessageError(message=f"CLI call failed: {e}"))
+
     def _custom_supported_methods(
         self, ls: LanguageServer, params: SupportedMethodsRequest
     ) -> SupportedMethodsResponse:
@@ -1181,10 +1285,15 @@ def _uri_to_path(uri: str) -> Path:
         """Convert a URI to a path."""
         return URI(uri).to_path()
 
-    def start(self) -> None:
-        """Start the server with I/O transport."""
+    def start(self, rfile: t.Optional[t.Any] = None, wfile: t.Optional[t.Any] = None) -> None:
+        """Start the server over stdio or, if streams are given, over a Unix socket."""
         logging.basicConfig(level=logging.DEBUG)
-        self.server.start_io()
+
+        if rfile is None or wfile is None:
+            self.server.start_io()
+        else:
+            self.server.start_io(rfile, wfile)
 
 
 def loaded_sqlmesh_message(ls: LanguageServer) -> None:
@@ -1195,9 +1304,39 @@ def loaded_sqlmesh_message(ls: LanguageServer) -> None:
 
 
 def main() -> None:
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--pipe", help="Unix socket path for communication", default=None)
+    args = parser.parse_args()
+
     sqlmesh_server = SQLMeshLanguageServer(context_class=Context)
-    sqlmesh_server.start()
+    if args.pipe:
+        # Connect to the pipe the VS Code client created.
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        try:
+            sock.connect(args.pipe)
+        except OSError as e:
+            print(f"Failed to connect to pipe {args.pipe}: {e}", file=sys.stderr)
+            sys.exit(1)
+
+        # Wrap the socket as file-like binary streams for pygls.
+        rfile = sock.makefile("rb", buffering=0)
+        wfile = sock.makefile("wb", buffering=0)
+
+        try:
+            sqlmesh_server.start(rfile, wfile)
+        finally:
+            for stream in (rfile, wfile, sock):
+                try:
+                    stream.close()
+                except Exception:
+                    pass
+    else:
+        sqlmesh_server.start()
 
 
 if __name__ == "__main__":
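Note the direction of the connection: with `--pipe` the server *connects*, so whoever launches it must create and listen on the socket first (VS Code's `TransportKind.pipe` does this automatically). A rough sketch of driving the server the same way from plain Python, where the pipe path and the `python -m sqlmesh.lsp.main` invocation are assumptions for illustration:

```python
import socket
import subprocess
import tempfile
from pathlib import Path

# Illustrative pipe path; the VS Code client generates its own.
pipe_path = Path(tempfile.gettempdir()) / "sqlmesh_demo.sock"
pipe_path.unlink(missing_ok=True)

# The client listens; the server dials back in.
listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
listener.bind(str(pipe_path))
listener.listen(1)

# Assumes the module is runnable as a script (it guards on __main__).
proc = subprocess.Popen(["python", "-m", "sqlmesh.lsp.main", "--pipe", str(pipe_path)])
conn, _ = listener.accept()  # all LSP traffic now flows over `conn`
```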
diff --git a/vscode/extension/src/lsp/lsp.ts b/vscode/extension/src/lsp/lsp.ts
index 1a11249853..9fbbaa8d5c 100644
--- a/vscode/extension/src/lsp/lsp.ts
+++ b/vscode/extension/src/lsp/lsp.ts
@@ -17,6 +17,9 @@ import {
 } from '../utilities/errors'
 import { CustomLSPMethods } from './custom'
 import { resolveProjectPath } from '../utilities/config'
+import * as path from 'path'
+import * as os from 'os'
+import * as crypto from 'crypto'
 
 type SupportedMethodsState =
   | { type: 'not-fetched' }
@@ -96,18 +99,23 @@ export class LSPClient implements Disposable {
     }
 
     const workspacePath = sqlmesh.value.workspacePath
+
+    // With TransportKind.pipe the language client creates the pipe itself and
+    // passes its path to the server via `--pipe=<path>`, so no extra args are
+    // needed here. `socketAddressForWorkspace` below derives a stable address
+    // in case we ever need to create the pipe ourselves.
     const serverOptions: ServerOptions = {
       run: {
         command: sqlmesh.value.bin,
-        transport: TransportKind.stdio,
+        transport: TransportKind.pipe,
         options: { cwd: workspacePath, env: sqlmesh.value.env },
         args: sqlmesh.value.args,
       },
       debug: {
         command: sqlmesh.value.bin,
-        transport: TransportKind.stdio,
+        transport: TransportKind.pipe,
         options: { cwd: workspacePath, env: sqlmesh.value.env },
         args: sqlmesh.value.args,
       },
     }
     const paths = resolveProjectPath(getWorkspaceFolders()[0])
@@ -267,3 +275,17 @@ export class LSPClient implements Disposable {
     }
   }
 }
+
+function socketAddressForWorkspace(workspacePath: string) {
+  const hash = crypto.createHash('md5').update(workspacePath).digest('hex').slice(0, 8)
+  const base = `sqlmesh_${hash}`
+
+  if (process.platform === 'win32') {
+    // Windows wants a Named Pipe path, not a filesystem .sock file
+    // NOTE: double backslashes are required in JS strings
+    return `\\\\.\\pipe\\${base}`
+  }
+
+  // POSIX: use a short dir (tmp) to avoid the ~108 byte sun_path limit
+  return path.join(os.tmpdir(), `${base}.sock`)
+}
\ No newline at end of file
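Nothing in this diff yet writes the lock file that `get_daemon_connector` reads, which is presumably why `sqlmesh/lsp/main.py` imports `generate_lock_file` and friends without using them. One way the server could populate it, reusing the models above and mirroring the extension's `socketAddressForWorkspace` naming (the `write_lock_file` helper is hypothetical):

```python
import hashlib
import tempfile
from pathlib import Path

from sqlmesh.lsp.cli_calls import (
    DaemonCommunicationMode,
    DaemonCommunicationModeUnixSocket,
    generate_lock_file,
    return_lock_file_path,
)

def write_lock_file(project_path: Path, workspace_path: str) -> Path:
    # Mirror the extension's md5-based naming so both sides derive the
    # same socket address for a given workspace.
    digest = hashlib.md5(workspace_path.encode("utf-8")).hexdigest()[:8]
    socket_path = str(Path(tempfile.gettempdir()) / f"sqlmesh_{digest}.sock")

    lock = generate_lock_file().model_copy(
        update={
            "communication": DaemonCommunicationMode(
                type=DaemonCommunicationModeUnixSocket(socket=socket_path)
            )
        }
    )

    lock_path = return_lock_file_path(project_path)
    lock_path.parent.mkdir(parents=True, exist_ok=True)
    lock_path.write_text(lock.model_dump_json())
    return lock_path
```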