Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
ConfigurationWithPortsParams,
DevEnvironmentConfiguration,
PortMapping,
RunAttachConfiguration,
RunAttachParams,
RunConfigurationType,
ServiceConfiguration,
TaskConfiguration,
Expand All @@ -57,6 +59,12 @@
get_repo_creds_and_default_branch,
load_repo,
)
from dstack._internal.core.services.ssh.attach import (
SSHProxyAwsSSMConfig,
SSHProxyCommandConfig,
SSHProxyConfig,
SSHProxyJumpConfig,
)
from dstack._internal.utils.common import local_time
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -241,8 +249,30 @@ def apply_configuration(
bind_address: Optional[str] = getattr(
configurator_args, _BIND_ADDRESS_ARG, None
)
# Map the attach.proxy settings to the original configuration
attach_proxy_config = SSHProxyConfig()
if isinstance(conf, RunAttachConfiguration) and isinstance(
conf.attach, RunAttachParams
):
attach = conf.attach
if attach.proxy.type == "jump":
attach_proxy_config = SSHProxyJumpConfig(attach.proxy.proxy_jump)
elif attach.proxy.type == "command":
attach_proxy_config = SSHProxyCommandConfig(attach.proxy.proxy_command)
elif attach.proxy.type == "aws-ssm":
attach_proxy_config = SSHProxyAwsSSMConfig(
profile=attach.proxy.profile,
region=attach.proxy.region,
document_name=attach.proxy.document_name,
)
if attach.proxy.type != "none":
console.print(
f"Using client-side attach proxy: [code]{attach.proxy.type}[/]"
)
try:
if run.attach(bind_address=bind_address):
if run.attach(
bind_address=bind_address, attach_proxy_config=attach_proxy_config
):
for entry in run.logs():
sys.stdout.buffer.write(entry)
sys.stdout.buffer.flush()
Expand Down
74 changes: 74 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,11 +644,83 @@ def schema_extra(schema: Dict[str, Any]):
BaseRunConfigurationConfig.schema_extra(schema)


# Closed set of values accepted as the `type` discriminator of the
# client-side attach proxy models defined below.
AttachProxyType = Literal["none", "jump", "command", "aws-ssm"]


class AttachProxyNone(CoreModel):
    # Variant of the attach proxy union selected by `type: none`.
    # It carries no settings of its own.
    type: Annotated[
        Literal["none"],
        Field(description="No client-side proxy for the first SSH hop"),
    ] = "none"


class AttachProxyJump(CoreModel):
    # Variant of the attach proxy union selected by `type: jump`.
    type: Annotated[
        Literal["jump"],
        Field(description="Use ProxyJump on the client-side for the first SSH hop"),
    ] = "jump"

    # Required: there is no default, so the jump host must be given explicitly.
    proxy_jump: Annotated[
        str, Field(description="Host alias from ~/.ssh/config for using in ProxyJump")
    ]


class AttachProxyCommand(CoreModel):
    # Variant of the attach proxy union selected by `type: command`.
    type: Annotated[
        Literal["command"],
        Field(description="Use ProxyCommand on the client-side for the first SSH hop"),
    ] = "command"

    # Required: the raw command string has no default.
    proxy_command: Annotated[
        str,
        Field(
            description=(
                "ProxyCommand string to execute for the first hop."
                " The value is passed as-is to ssh_config."
                " If you need stream forwarding through SSH,"
                " include '-W %h:%p' yourself."
            ),
        ),
    ]


class AttachProxyAwsSSM(CoreModel):
    # Variant of the attach proxy union selected by `type: aws-ssm`.
    type: Annotated[
        Literal["aws-ssm"],
        Field(description="Use AWS SSM as a proxy for the first SSH hop"),
    ] = "aws-ssm"

    # Both of these are optional and default to None.
    profile: Annotated[
        Optional[str],
        Field(description="AWS profile name to use"),
    ] = None
    region: Annotated[
        Optional[str],
        Field(description="AWS region for SSM"),
    ] = None

    # Defaults to the "AWS-StartSSHSession" document when not overridden.
    document_name: Annotated[
        str, Field(description="SSM document name for SSH session")
    ] = "AWS-StartSSHSession"


class RunAttachParams(CoreModel):
    # Container for the `attach` settings. `proxy` is a union of the four
    # proxy variants, told apart by their `type` field; when omitted it
    # defaults to the "none" variant.
    proxy: Annotated[
        Union[AttachProxyNone, AttachProxyJump, AttachProxyCommand, AttachProxyAwsSSM],
        Field(
            description="Client-side SSH transport overrides for attach",
            discriminator="type",
        ),
    ] = AttachProxyNone()


class RunAttachConfiguration(CoreModel):
    # Mixin that adds an `attach` section to a run configuration.
    # The field is declared with exclude=True and described as client-side
    # only, so it is presumably left out when the configuration is
    # serialized -- verify against the pydantic version in use.
    attach: Annotated[
        RunAttachParams,
        Field(
            description="Attach transport settings (client-side only)",
            exclude=True,
        ),
    ] = RunAttachParams()


class DevEnvironmentConfiguration(
ProfileParams,
BaseRunConfiguration,
ConfigurationWithPortsParams,
DevEnvironmentConfigurationParams,
RunAttachConfiguration,
generate_dual_core_model(DevEnvironmentConfigurationConfig),
):
type: Literal["dev-environment"] = "dev-environment"
Expand Down Expand Up @@ -680,6 +752,7 @@ class TaskConfiguration(
ConfigurationWithCommandsParams,
ConfigurationWithPortsParams,
TaskConfigurationParams,
RunAttachConfiguration,
generate_dual_core_model(TaskConfigurationConfig),
):
type: Literal["task"] = "task"
Expand Down Expand Up @@ -838,6 +911,7 @@ class ServiceConfiguration(
BaseRunConfiguration,
ConfigurationWithCommandsParams,
ServiceConfigurationParams,
RunAttachConfiguration,
generate_dual_core_model(ServiceConfigurationConfig),
):
type: Literal["service"] = "service"
Expand Down
94 changes: 93 additions & 1 deletion src/dstack/_internal/core/services/ssh/attach.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import atexit
import dataclasses
import re
import time
from pathlib import Path
from typing import Optional, Union

import psutil

from dstack._internal.core.errors import SSHError
from dstack._internal.core.errors import ConfigurationError, SSHError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobProvisioningData
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.ssh.client import get_ssh_client_info
from dstack._internal.core.services.ssh.ports import PortsLock
Expand All @@ -25,6 +28,82 @@
# Matches a tunnel spec of the form `[<bind-host>:]<local_port>:localhost:<remote_port>`
# and captures the two port numbers as the named groups `local_port` and `remote_port`.
_SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P<local_port>\d+):localhost:(?P<remote_port>\d+)")


# Options of a single host entry, keyed by option name (e.g. "ProxyJump",
# "ProxyCommand", "HostName"); this is the mapping the SSHProxy*Config
# classes below mutate in place.
HostConfigType = dict[str, Union[str, int, FilePath]]


class SSHProxyConfig:
    """Base first-hop proxy configuration that changes nothing.

    Subclasses override the hooks below to inject proxy options into a host
    entry; this base implementation leaves everything untouched.
    """

    def update_host(self, host: HostConfigType):
        """Leave the given host options as they are."""
        return None

    def apply_provisioning_data(self, provisioning_data: JobProvisioningData):
        """Ignore the provisioning data; the base class needs none of it."""
        return None


@dataclasses.dataclass
class SSHProxyJumpConfig(SSHProxyConfig):
    """Proxy configuration that routes the host through a ProxyJump host."""

    jump_host: str

    def update_host(self, host: HostConfigType):
        """Set the `ProxyJump` option of `host` to the configured jump host."""
        host.update({"ProxyJump": self.jump_host})


@dataclasses.dataclass
class SSHProxyCommandConfig(SSHProxyConfig):
    """Proxy configuration that routes the host through a ProxyCommand."""

    command: str

    def update_host(self, host: HostConfigType):
        """Set the `ProxyCommand` option of `host` to the stored command, unmodified."""
        host.update({"ProxyCommand": self.command})


@dataclasses.dataclass
class SSHProxyAwsSSMConfig(SSHProxyConfig):
    """Proxy configuration that reaches the host through `aws ssm start-session`.

    `profile`, `region` and `document_name` are supplied by the caller.
    `instance_id` and `instance_region` are filled in by
    `apply_provisioning_data()` before `update_host()` is called.
    """

    profile: Optional[str] = None
    region: Optional[str] = None
    document_name: Optional[str] = None
    instance_id: Optional[str] = None
    instance_region: Optional[str] = None

    def update_host(self, host: HostConfigType):
        """Point `host` at the instance and install an SSM-based ProxyCommand.

        When an instance id is known, `HostName` is replaced with it so that
        the `%h` token in the command resolves to the SSM target. An explicit
        `region` takes precedence over the region taken from the provisioning
        data, and the document name falls back to "AWS-StartSSHSession".
        """
        # Imported locally so this block does not depend on a file-level import.
        import shlex

        if self.instance_id:
            host["HostName"] = self.instance_id
        region = self.region or self.instance_region
        document = self.document_name or "AWS-StartSSHSession"
        args = [
            "aws",
            "ssm",
            "start-session",
            "--target",
            "%h",
            "--document-name",
            document,
            "--parameters",
            "portNumber=%p",
        ]
        if region:
            args.extend(["--region", region])
        if self.profile:
            args.extend(["--profile", self.profile])
        # Quote each argument, then quote the whole script for `sh -c`, so a
        # profile/region/document name with spaces, quotes or shell
        # metacharacters can neither break nor inject into the command.
        # `%h` and `portNumber=%p` contain only characters shlex treats as
        # safe, so they are passed through unchanged for ssh to expand.
        script = shlex.join(map(str, args))
        host["ProxyCommand"] = f"sh -c {shlex.quote(script)}"

    def apply_provisioning_data(self, provisioning_data: JobProvisioningData):
        """Record the instance id and region from the provisioning data.

        Raises:
            ConfigurationError: if the base backend of the job is not AWS.
        """
        backend = provisioning_data.get_base_backend()
        if backend != BackendType.AWS:
            raise ConfigurationError(
                "attach.proxy.type=aws-ssm is supported only for the AWS backend"
            )
        self.instance_id = provisioning_data.instance_id
        self.instance_region = provisioning_data.region


class SSHAttach:
@classmethod
def get_control_sock_path(cls, run_name: str) -> Path:
Expand Down Expand Up @@ -67,6 +146,7 @@ def __init__(
service_port: Optional[int] = None,
local_backend: bool = False,
bind_address: Optional[str] = None,
proxy_config: Optional[SSHProxyConfig] = None,
):
self._ports_lock = ports_lock
self.ports = ports_lock.dict()
Expand Down Expand Up @@ -199,6 +279,18 @@ def __init__(
}
)

if proxy_config:
# Apply proxy configuration for the first hop connection
first_hop_key: Optional[str] = None
if f"{run_name}-jump-host" in hosts:
first_hop_key = f"{run_name}-jump-host"
elif f"{run_name}-host" in hosts:
first_hop_key = f"{run_name}-host"
else:
first_hop_key = run_name

proxy_config.update_host(hosts[first_hop_key])

def attach(self):
include_ssh_config(self.ssh_config_path)
for host, options in self.hosts.items():
Expand Down
7 changes: 6 additions & 1 deletion src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from dstack._internal.core.models.runs import Run as RunModel
from dstack._internal.core.services.logs import URLReplacer
from dstack._internal.core.services.ssh.attach import SSHAttach
from dstack._internal.core.services.ssh.attach import SSHAttach, SSHProxyConfig
from dstack._internal.core.services.ssh.ports import PortsLock
from dstack._internal.server.schemas.logs import PollLogsRequest
from dstack._internal.utils.common import get_or_error, make_proxy_url
Expand Down Expand Up @@ -259,6 +259,7 @@ def attach(
ports_overrides: Optional[List[PortMapping]] = None,
replica_num: Optional[int] = None,
job_num: int = 0,
attach_proxy_config: Optional[SSHProxyConfig] = None,
) -> bool:
"""
Establish an SSH tunnel to the instance and update SSH config
Expand Down Expand Up @@ -347,6 +348,9 @@ def attach(
if isinstance(self._run.run_spec.configuration, ServiceConfiguration):
service_port = get_service_port(job.job_spec, self._run.run_spec.configuration)

if attach_proxy_config:
attach_proxy_config.apply_provisioning_data(provisioning_data)

self._ssh_attach = SSHAttach(
hostname=provisioning_data.hostname,
ssh_port=provisioning_data.ssh_port,
Expand All @@ -361,6 +365,7 @@ def attach(
service_port=service_port,
local_backend=provisioning_data.backend == BackendType.LOCAL,
bind_address=bind_address,
proxy_config=attach_proxy_config,
)
if not ports_lock:
self._ssh_attach.attach()
Expand Down