From d86cd63dfa1e47135fa3a71afdc0b151037ff4fd Mon Sep 17 00:00:00 2001 From: Kenta Murata Date: Sat, 20 Sep 2025 06:26:22 +0900 Subject: [PATCH 1/2] Support client-side proxies for apply command This is a concept implementation to support SSH proxy in `dstack apply` comand. ## Summary - Add client-side SSH proxy configuration to dev/task/service configs - Let `dstack apply` reuse those settings for ProxyJump, ProxyCommand, and AWS SSM ## Liimitations - `dstack attach` still bypasses the new proxy controls, so detached runs cannot yet use them ## Next Steps - Add cloud-specific helpers (e.g., GCP, Azure) on top of this structure ## Example configuration ```yaml type: dev-environment name: example ide: vscode attach: proxy: type: aws-ssm profile: sso-profile-name ``` --- .../cli/services/configurators/run.py | 30 +++++- .../_internal/core/models/configurations.py | 74 +++++++++++++++ .../_internal/core/services/ssh/attach.py | 94 ++++++++++++++++++- src/dstack/api/_public/runs.py | 7 +- 4 files changed, 202 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 0403a57a6..dbace6c18 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -40,6 +40,8 @@ ConfigurationWithPortsParams, DevEnvironmentConfiguration, PortMapping, + RunAttachConfiguration, + RunAttachParams, RunConfigurationType, ServiceConfiguration, TaskConfiguration, @@ -57,6 +59,12 @@ get_repo_creds_and_default_branch, load_repo, ) +from dstack._internal.core.services.ssh.attach import ( + SSHProxyAwsSSMConfig, + SSHProxyCommandConfig, + SSHProxyConfig, + SSHProxyJumpConfig, +) from dstack._internal.utils.common import local_time from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger @@ -241,8 +249,28 @@ def apply_configuration( bind_address: Optional[str] = getattr( configurator_args, _BIND_ADDRESS_ARG, None ) + # Map the attach.proxy settings to the original configuration + attach_proxy_config = SSHProxyConfig() + if isinstance(conf, RunAttachConfiguration) and conf.attach is not None: + attach: RunAttachParams = conf.attach + if attach.proxy.type == "jump": + attach_proxy_config = SSHProxyJumpConfig(attach.proxy.proxy_jump) + elif attach.proxy.type == "command": + attach_proxy_config = SSHProxyCommandConfig(attach.proxy.proxy_command) + elif attach.proxy.type == "aws-ssm": + attach_proxy_config = SSHProxyAwsSSMConfig( + profile=attach.proxy.profile, + region=attach.proxy.region, + document_name=attach.proxy.document_name, + ) + if attach.proxy.type != "none": + console.print( + f"Using client-side attach proxy: [code]{attach.proxy.type}[/]" + ) try: - if run.attach(bind_address=bind_address): + if run.attach( + bind_address=bind_address, attach_proxy_config=attach_proxy_config + ): for entry in run.logs(): sys.stdout.buffer.write(entry) sys.stdout.buffer.flush() diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 6fe8132de..fc4ab4590 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -644,11 +644,83 @@ def schema_extra(schema: Dict[str, Any]): BaseRunConfigurationConfig.schema_extra(schema) +AttachProxyType = Literal["none", "jump", "command", "aws-ssm"] + + +class AttachProxyNone(CoreModel): + type: Annotated[ + Literal["none"], Field(description="No client-side proxy for the first SSH hop") + ] = "none" + + +class AttachProxyJump(CoreModel): + type: Annotated[ + Literal["jump"], + Field(description="Use ProxyJump on the client-side for the first SSH hop"), + ] = "jump" + proxy_jump: Annotated[ + str, + Field(description="Host alias from ~/.ssh/config for using in ProxyJump"), + ] + + +class AttachProxyCommand(CoreModel): + type: Annotated[ + Literal["command"], + Field(description="Use ProxyCommand on the client-side for the first SSH hop"), + ] = "command" + proxy_command: Annotated[ + str, + Field( + description=( + "ProxyCommand string to execute for the first hop." + " The value is passed as-is to ssh_config." + " If you need stream forwarding through SSH, include '-W %h:%p' yourself." + ) + ), + ] + + +class AttachProxyAwsSSM(CoreModel): + type: Annotated[ + Literal["aws-ssm"], Field(description="Use AWS SSM as a proxy for the first SSH hop") + ] = "aws-ssm" + profile: Annotated[Optional[str], Field(description="AWS profile name to use")] = None + region: Annotated[Optional[str], Field(description="AWS region for SSM")] = None + document_name: Annotated[ + str, + Field(description="SSM document name for SSH session"), + ] = "AWS-StartSSHSession" + + +class RunAttachParams(CoreModel): + proxy: Annotated[ + Union[ + AttachProxyNone, + AttachProxyJump, + AttachProxyCommand, + AttachProxyAwsSSM, + ], + Field( + discriminator="type", + description="Client-side SSH transport overrides for attach", + ), + ] = AttachProxyNone() + + +class RunAttachConfiguration(CoreModel): + attach: Annotated[ + RunAttachParams, + Field(description="Attach transport settings (client-side only)", exclude=True), + ] = RunAttachParams() + + class DevEnvironmentConfiguration( ProfileParams, BaseRunConfiguration, ConfigurationWithPortsParams, DevEnvironmentConfigurationParams, + RunAttachConfiguration, generate_dual_core_model(DevEnvironmentConfigurationConfig), ): type: Literal["dev-environment"] = "dev-environment" @@ -680,6 +752,7 @@ class TaskConfiguration( ConfigurationWithCommandsParams, ConfigurationWithPortsParams, TaskConfigurationParams, + RunAttachConfiguration, generate_dual_core_model(TaskConfigurationConfig), ): type: Literal["task"] = "task" @@ -838,6 +911,7 @@ class ServiceConfiguration( BaseRunConfiguration, ConfigurationWithCommandsParams, ServiceConfigurationParams, + RunAttachConfiguration, generate_dual_core_model(ServiceConfigurationConfig), ): type: Literal["service"] = "service" diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index d0ad4ac64..5ed78554d 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -1,4 +1,5 @@ import atexit +import dataclasses import re import time from pathlib import Path @@ -6,8 +7,10 @@ import psutil -from dstack._internal.core.errors import SSHError +from dstack._internal.core.errors import ConfigurationError, SSHError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import SSHConnectionParams +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.ssh.client import get_ssh_client_info from dstack._internal.core.services.ssh.ports import PortsLock @@ -25,6 +28,82 @@ _SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P\d+):localhost:(?P\d+)") +HostConfigType = dict[str, Union[str, int, FilePath]] + + +class SSHProxyConfig: + """Do nothing""" + + def update_host(self, host: HostConfigType): + pass + + def apply_provisioning_data(self, provisioning_data: JobProvisioningData): + pass + + +@dataclasses.dataclass +class SSHProxyJumpConfig(SSHProxyConfig): + """Add ProxyJump to the given host configuration""" + + jump_host: str + + def update_host(self, host: HostConfigType): + host["ProxyJump"] = self.jump_host + + +@dataclasses.dataclass +class SSHProxyCommandConfig(SSHProxyConfig): + """Add ProxyCommand to the given host configuration""" + + command: str + + def update_host(self, host: HostConfigType): + host["ProxyCommand"] = self.command + + +@dataclasses.dataclass(kw_only=True) +class SSHProxyAwsSSMConfig(SSHProxyConfig): + """Add ProxyCommand to use AWS SSM""" + + profile: Optional[str] = None + region: Optional[str] = None + document_name: Optional[str] = None + instance_id: Optional[str] = None + instance_region: Optional[str] = None + + def update_host(self, host: HostConfigType): + if self.instance_id: + host["HostName"] = self.instance_id + region = self.region if self.region else self.instance_region + document = self.document_name if self.document_name else "AWS-StartSSHSession" + args = [ + "aws", + "ssm", + "start-session", + "--target", + "%h", + "--document-name", + document, + "--parameters", + "portNumber=%p", + ] + if region: + args.extend(["--region", region]) + if self.profile: + args.extend(["--profile", self.profile]) + command = f"sh -c '{' '.join(map(str, args))}'" + host["ProxyCommand"] = command + + def apply_provisioning_data(self, provisioning_data: JobProvisioningData): + backend = provisioning_data.get_base_backend() + if backend != BackendType.AWS: + raise ConfigurationError( + "attach.proxy.type=aws-ssm is supported only for the AWS backend" + ) + self.instance_id = provisioning_data.instance_id + self.instance_region = provisioning_data.region + + class SSHAttach: @classmethod def get_control_sock_path(cls, run_name: str) -> Path: @@ -67,6 +146,7 @@ def __init__( service_port: Optional[int] = None, local_backend: bool = False, bind_address: Optional[str] = None, + proxy_config: Optional[SSHProxyConfig] = None, ): self._ports_lock = ports_lock self.ports = ports_lock.dict() @@ -199,6 +279,18 @@ def __init__( } ) + if proxy_config: + # Apply proxy configuration for the first hop connection + first_hop_key: Optional[str] = None + if f"{run_name}-jump-host" in hosts: + first_hop_key = f"{run_name}-jump-host" + elif f"{run_name}-host" in hosts: + first_hop_key = f"{run_name}-host" + else: + first_hop_key = run_name + + proxy_config.update_host(hosts[first_hop_key]) + def attach(self): include_ssh_config(self.ssh_config_path) for host, options in self.hosts.items(): diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index c6b786337..51265670b 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -46,7 +46,7 @@ ) from dstack._internal.core.models.runs import Run as RunModel from dstack._internal.core.services.logs import URLReplacer -from dstack._internal.core.services.ssh.attach import SSHAttach +from dstack._internal.core.services.ssh.attach import SSHAttach, SSHProxyConfig from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.utils.common import get_or_error, make_proxy_url @@ -259,6 +259,7 @@ def attach( ports_overrides: Optional[List[PortMapping]] = None, replica_num: Optional[int] = None, job_num: int = 0, + attach_proxy_config: Optional[SSHProxyConfig] = None, ) -> bool: """ Establish an SSH tunnel to the instance and update SSH config @@ -347,6 +348,9 @@ def attach( if isinstance(self._run.run_spec.configuration, ServiceConfiguration): service_port = get_service_port(job.job_spec, self._run.run_spec.configuration) + if attach_proxy_config: + attach_proxy_config.apply_provisioning_data(provisioning_data) + self._ssh_attach = SSHAttach( hostname=provisioning_data.hostname, ssh_port=provisioning_data.ssh_port, @@ -361,6 +365,7 @@ def attach( service_port=service_port, local_backend=provisioning_data.backend == BackendType.LOCAL, bind_address=bind_address, + proxy_config=attach_proxy_config, ) if not ports_lock: self._ssh_attach.attach() From a45a93def0409e77dd0a1a9d497d9fa3b3cdf6b9 Mon Sep 17 00:00:00 2001 From: Kenta Murata Date: Mon, 22 Sep 2025 15:20:28 +0900 Subject: [PATCH 2/2] Fix for pyright errors --- src/dstack/_internal/cli/services/configurators/run.py | 6 ++++-- src/dstack/_internal/core/services/ssh/attach.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index dbace6c18..0ac413cc2 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -251,8 +251,10 @@ def apply_configuration( ) # Map the attach.proxy settings to the original configuration attach_proxy_config = SSHProxyConfig() - if isinstance(conf, RunAttachConfiguration) and conf.attach is not None: - attach: RunAttachParams = conf.attach + if isinstance(conf, RunAttachConfiguration) and isinstance( + conf.attach, RunAttachParams + ): + attach = conf.attach if attach.proxy.type == "jump": attach_proxy_config = SSHProxyJumpConfig(attach.proxy.proxy_jump) elif attach.proxy.type == "command": diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index 5ed78554d..6136f5cb8 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -61,7 +61,7 @@ def update_host(self, host: HostConfigType): host["ProxyCommand"] = self.command -@dataclasses.dataclass(kw_only=True) +@dataclasses.dataclass class SSHProxyAwsSSMConfig(SSHProxyConfig): """Add ProxyCommand to use AWS SSM"""