Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 94 additions & 1 deletion src/flyte/remote/_client/controlplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from flyteidl2.dataproxy import dataproxy_service_pb2
from flyteidl2.dataproxy.dataproxy_service_connect import DataProxyServiceClient
from flyteidl2.project.project_service_connect import ProjectServiceClient
from flyteidl2.secret import payload_pb2 as secret_payload_pb2
from flyteidl2.secret.secret_connect import SecretServiceClient
from flyteidl2.task.task_service_connect import TaskServiceClient
from flyteidl2.trigger.trigger_service_connect import TriggerServiceClient
Expand Down Expand Up @@ -313,6 +314,94 @@ async def _select_and_build(self, req: cluster_payload_pb2.SelectClusterRequest)
return DataProxyServiceClient(**new_cfg.connect_kwargs())


class ClusterAwareSecretService:
    """Secret service client that routes each call to the correct cluster.

    Same pattern as ClusterAwareDataProxy: uses SelectCluster with
    OPERATION_USE_SECRETS to discover the cluster endpoint, then dispatches
    to a per-cluster SecretServiceClient. Resolved clients are cached per
    (organization, project, domain) scope.
    """

    def __init__(
        self,
        cluster_service: ClusterService,
        session_config: SessionConfig,
        default_client: SecretServiceClient,
    ):
        """Store the collaborators used for routing.

        Args:
            cluster_service: Client whose ``select_cluster`` is awaited to
                discover the endpoint for a given scope.
            session_config: Session config of the current connection; its
                endpoint and TLS flags are reused for per-cluster sessions.
            default_client: Client returned when SelectCluster reports no
                endpoint or the same endpoint as ``session_config``.
        """
        self._cluster_service = cluster_service
        self._session_config = session_config
        self._default_client = default_client

    async def create_secret(
        self, request: secret_payload_pb2.CreateSecretRequest
    ) -> secret_payload_pb2.CreateSecretResponse:
        """Create a secret on the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.create_secret(request)

    async def update_secret(
        self, request: secret_payload_pb2.UpdateSecretRequest
    ) -> secret_payload_pb2.UpdateSecretResponse:
        """Update a secret on the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.update_secret(request)

    async def get_secret(self, request: secret_payload_pb2.GetSecretRequest) -> secret_payload_pb2.GetSecretResponse:
        """Fetch a secret from the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.get_secret(request)

    async def list_secrets(
        self, request: secret_payload_pb2.ListSecretsRequest
    ) -> secret_payload_pb2.ListSecretsResponse:
        """List secrets on the cluster selected for the request's scope.

        Unlike the other calls, the scope fields live directly on the
        request rather than on a nested ``id``.
        """
        client = await self._resolve(request.organization, request.project, request.domain)
        return await client.list_secrets(request)

    async def delete_secret(
        self, request: secret_payload_pb2.DeleteSecretRequest
    ) -> secret_payload_pb2.DeleteSecretResponse:
        """Delete a secret on the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.delete_secret(request)

    async def _resolve_for_id(self, secret_id) -> SecretServiceClient:
        """Resolve the client for an identifier carrying organization/project/domain."""
        return await self._resolve(secret_id.organization, secret_id.project, secret_id.domain)

    # NOTE(review): the decorator wraps the method itself, so ``self`` is part
    # of the cached call arguments and the cache is attached to the class
    # attribute -- confirm that this lifetime is acceptable for instances.
    @alru_cache
    async def _resolve(self, org: str, project: str, domain: str) -> SecretServiceClient:
        """Cached SelectCluster lookup for secrets.

        Routes by ProjectIdentifier when project and domain are set,
        DomainIdentifier when only domain is set (domain-scoped secrets),
        or OrgIdentifier for org-wide secrets.

        Args:
            org: Organization name placed on whichever identifier is sent.
            project: Project name; empty for domain- or org-scoped secrets.
            domain: Domain name; empty for org-scoped secrets.

        Returns:
            The default client when SelectCluster returns no endpoint or the
            session's own endpoint; otherwise a new SecretServiceClient built
            from a session config created for the returned endpoint.

        Raises:
            RuntimeError: If SelectCluster fails, or if a session config
                cannot be created for the returned endpoint. The original
                exception is chained as the cause.
        """
        from flyte._logging import logger

        req = cluster_payload_pb2.SelectClusterRequest(
            operation=cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_USE_SECRETS,
        )
        if project and domain:
            req.project_id.CopyFrom(identifier_pb2.ProjectIdentifier(name=project, domain=domain, organization=org))
        elif domain:
            req.domain_id.CopyFrom(identifier_pb2.DomainIdentifier(name=domain, organization=org))
        else:
            req.org_id.CopyFrom(identifier_pb2.OrgIdentifier(name=org))
        try:
            resp = await self._cluster_service.select_cluster(req)
        except Exception as e:
            raise RuntimeError(f"SelectCluster failed for OPERATION_USE_SECRETS: {e}") from e

        endpoint = resp.cluster_endpoint
        # No endpoint, or the one we are already connected to: reuse the
        # existing client instead of opening another session.
        if not endpoint or endpoint == self._session_config.endpoint:
            return self._default_client

        try:
            new_cfg = await create_session_config(
                endpoint,
                insecure=self._session_config.insecure,
                insecure_skip_verify=self._session_config.insecure_skip_verify,
                # Authentication is pointed at the original session endpoint,
                # not at the selected cluster endpoint.
                auth_endpoint=self._session_config.endpoint,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to create session for cluster endpoint '{endpoint}': {e}") from e

        logger.debug(f"Created SecretService client for cluster endpoint: {endpoint}")
        return SecretServiceClient(**new_cfg.connect_kwargs())


class ClientSet:
def __init__(self, session_cfg: SessionConfig):
self._console = Console(session_cfg.endpoint, session_cfg.insecure)
Expand All @@ -323,10 +412,14 @@ def __init__(self, session_cfg: SessionConfig):
self._app_service = AppServiceClient(**shared)
self._run_service = RunServiceClient(**shared)
self._log_service = RunLogsServiceClient(**shared)
self._secrets_service = SecretServiceClient(**shared)
self._identity_service = IdentityServiceClient(**shared)
self._trigger_service = TriggerServiceClient(**shared)
self._cluster_service = ClusterServiceClient(**shared)
self._secrets_service = ClusterAwareSecretService(
cluster_service=self._cluster_service,
session_config=session_cfg,
default_client=SecretServiceClient(**shared),
)
self._dataproxy = ClusterAwareDataProxy(
cluster_service=self._cluster_service,
session_config=session_cfg,
Expand Down
202 changes: 202 additions & 0 deletions tests/flyte/remote/test_cluster_aware_secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""Tests for the ClusterAwareSecretService wrapper in flyte.remote._client.controlplane."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.secret import definition_pb2 as secret_definition_pb2
from flyteidl2.secret import payload_pb2 as secret_payload_pb2

from flyte.remote._client.controlplane import ClusterAwareSecretService


def _make_wrapper(
    cluster_endpoint: str = "",
    own_endpoint: str = "dns:///localhost:8090",
):
    """Build a ClusterAwareSecretService wired entirely to mocks.

    Returns a ``(wrapper, cluster_service, default_client)`` tuple.
    ``cluster_service.select_cluster`` resolves to a SelectClusterResponse
    carrying ``cluster_endpoint``; the session config reports
    ``own_endpoint`` with ``insecure=True`` and ``insecure_skip_verify=False``.
    """
    select_response = cluster_payload_pb2.SelectClusterResponse(cluster_endpoint=cluster_endpoint)
    cluster_service = MagicMock()
    cluster_service.select_cluster = AsyncMock(return_value=select_response)

    session_config = MagicMock(endpoint=own_endpoint, insecure=True, insecure_skip_verify=False)

    # Each RPC on the fallback client resolves to an empty response of the
    # matching type.
    response_types = {
        "create_secret": secret_payload_pb2.CreateSecretResponse,
        "update_secret": secret_payload_pb2.UpdateSecretResponse,
        "get_secret": secret_payload_pb2.GetSecretResponse,
        "list_secrets": secret_payload_pb2.ListSecretsResponse,
        "delete_secret": secret_payload_pb2.DeleteSecretResponse,
    }
    default_client = MagicMock()
    for rpc_name, response_type in response_types.items():
        setattr(default_client, rpc_name, AsyncMock(return_value=response_type()))

    wrapper = ClusterAwareSecretService(
        cluster_service=cluster_service,
        session_config=session_config,
        default_client=default_client,
    )
    return wrapper, cluster_service, default_client


def _secret_id(org="o", project="p", domain="d", name="s"):
    """Build a SecretIdentifier; the defaults describe a project-scoped secret."""
    fields = {"organization": org, "project": project, "domain": domain, "name": name}
    return secret_definition_pb2.SecretIdentifier(**fields)


# --- Routing: project-scoped secrets ---


@pytest.mark.asyncio
async def test_create_secret_routes_by_project():
    """create_secret with a fully-qualified id selects the cluster by project_id."""
    service, selector, fallback = _make_wrapper()
    request = secret_payload_pb2.CreateSecretRequest(id=_secret_id())

    await service.create_secret(request)

    select_req = selector.select_cluster.await_args.args[0]
    expected_op = cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_USE_SECRETS
    assert select_req.operation == expected_op
    assert select_req.WhichOneof("resource") == "project_id"
    project_id = select_req.project_id
    assert project_id.name == "p"
    assert project_id.domain == "d"
    assert project_id.organization == "o"
    fallback.create_secret.assert_awaited_once_with(request)


@pytest.mark.asyncio
async def test_get_secret_routes_by_project():
    """get_secret with a fully-qualified id selects the cluster by project_id."""
    service, selector, fallback = _make_wrapper()
    request = secret_payload_pb2.GetSecretRequest(id=_secret_id())

    await service.get_secret(request)

    select_req = selector.select_cluster.await_args.args[0]
    assert select_req.WhichOneof("resource") == "project_id"
    assert select_req.project_id.name == "p"
    fallback.get_secret.assert_awaited_once_with(request)


@pytest.mark.asyncio
async def test_update_secret_routes_by_project():
    """update_secret with a fully-qualified id selects the cluster by project_id."""
    service, selector, fallback = _make_wrapper()
    request = secret_payload_pb2.UpdateSecretRequest(id=_secret_id())

    await service.update_secret(request)

    select_req = selector.select_cluster.await_args.args[0]
    assert select_req.WhichOneof("resource") == "project_id"
    fallback.update_secret.assert_awaited_once_with(request)


@pytest.mark.asyncio
async def test_delete_secret_routes_by_project():
    """delete_secret with a fully-qualified id selects the cluster by project_id."""
    service, selector, fallback = _make_wrapper()
    request = secret_payload_pb2.DeleteSecretRequest(id=_secret_id())

    await service.delete_secret(request)

    select_req = selector.select_cluster.await_args.args[0]
    assert select_req.WhichOneof("resource") == "project_id"
    fallback.delete_secret.assert_awaited_once_with(request)


@pytest.mark.asyncio
async def test_list_secrets_routes_by_project():
    """list_secrets reads the scope from the request's top-level fields."""
    service, selector, fallback = _make_wrapper()
    request = secret_payload_pb2.ListSecretsRequest(organization="o", project="p", domain="d")

    await service.list_secrets(request)

    select_req = selector.select_cluster.await_args.args[0]
    assert select_req.WhichOneof("resource") == "project_id"
    project_id = select_req.project_id
    assert project_id.name == "p"
    assert project_id.domain == "d"
    fallback.list_secrets.assert_awaited_once_with(request)


# --- Routing: domain-scoped secrets ---


@pytest.mark.asyncio
async def test_get_secret_domain_only_routes_by_domain_id():
    """With an empty project and a set domain, the request carries domain_id."""
    service, selector, fallback = _make_wrapper()
    domain_scoped_id = _secret_id(project="", domain="d")
    request = secret_payload_pb2.GetSecretRequest(id=domain_scoped_id)

    await service.get_secret(request)

    select_req = selector.select_cluster.await_args.args[0]
    assert select_req.WhichOneof("resource") == "domain_id"
    domain_id = select_req.domain_id
    assert domain_id.name == "d"
    assert domain_id.organization == "o"
    fallback.get_secret.assert_awaited_once_with(request)


# --- Routing: org-wide secrets ---


@pytest.mark.asyncio
async def test_get_secret_org_only_routes_by_org_id():
    """With neither project nor domain set, the request carries org_id."""
    service, selector, fallback = _make_wrapper()
    org_scoped_id = _secret_id(project="", domain="")
    request = secret_payload_pb2.GetSecretRequest(id=org_scoped_id)

    await service.get_secret(request)

    select_req = selector.select_cluster.await_args.args[0]
    assert select_req.WhichOneof("resource") == "org_id"
    assert select_req.org_id.name == "o"
    fallback.get_secret.assert_awaited_once_with(request)


# --- Caching ---


@pytest.mark.asyncio
async def test_cache_hits_reuse_selected_client():
    """Two calls for the same scope trigger a single SelectCluster lookup."""
    service, selector, fallback = _make_wrapper()
    request = secret_payload_pb2.GetSecretRequest(id=_secret_id())

    # Same request twice: the second call must be served from the cache.
    for _ in range(2):
        await service.get_secret(request)

    assert selector.select_cluster.await_count == 1
    assert fallback.get_secret.await_count == 2


@pytest.mark.asyncio
async def test_different_projects_get_separate_cache_entries():
    """Distinct projects each need their own SelectCluster lookup."""
    service, selector, _ = _make_wrapper()

    for project_name in ("p1", "p2"):
        request = secret_payload_pb2.GetSecretRequest(id=_secret_id(project=project_name))
        await service.get_secret(request)

    assert selector.select_cluster.await_count == 2


# --- Remote cluster ---


@pytest.mark.asyncio
async def test_remote_cluster_endpoint_creates_new_client():
    """A foreign cluster endpoint yields a dedicated, cached client.

    Besides checking that calls go to the new client instead of the default
    one, this verifies how the per-cluster session is created: the selected
    endpoint is the target, the TLS flags are inherited from the original
    session config, and authentication points at the original endpoint.
    """
    wrapper, cluster_service, default_client = _make_wrapper(cluster_endpoint="dns:///other:8090")

    new_client_inst = MagicMock()
    new_client_inst.get_secret = AsyncMock(return_value=secret_payload_pb2.GetSecretResponse())
    new_session_cfg = MagicMock()
    new_session_cfg.connect_kwargs.return_value = {}
    # Keep a handle on the patched factory so its call arguments can be checked.
    create_session = AsyncMock(return_value=new_session_cfg)

    with (
        patch(
            "flyte.remote._client.controlplane.create_session_config",
            new=create_session,
        ),
        patch(
            "flyte.remote._client.controlplane.SecretServiceClient",
            return_value=new_client_inst,
        ) as client_cls,
    ):
        req = secret_payload_pb2.GetSecretRequest(id=_secret_id())
        await wrapper.get_secret(req)
        await wrapper.get_secret(req)

    assert cluster_service.select_cluster.await_count == 1
    assert new_client_inst.get_secret.await_count == 2
    default_client.get_secret.assert_not_awaited()

    # The session is built once (second call is a cache hit) with the values
    # taken from the mocked session config in _make_wrapper.
    create_session.assert_awaited_once_with(
        "dns:///other:8090",
        insecure=True,
        insecure_skip_verify=False,
        auth_endpoint="dns:///localhost:8090",
    )
    # connect_kwargs() is mocked to return {}, so the client is built with no arguments.
    client_cls.assert_called_once_with()
Loading