from flyteidl2.secret import payload_pb2 as secret_payload_pb2


class ClusterAwareSecretService:
    """Secret service client that routes each call to the correct cluster.

    Same pattern as ClusterAwareDataProxy: uses SelectCluster with
    OPERATION_USE_SECRETS to discover the cluster endpoint, then dispatches
    to a per-cluster SecretServiceClient.

    Routing decisions are cached per instance, keyed by
    ``(organization, project, domain)``. Clients for non-default endpoints are
    additionally shared per endpoint, so several scopes that resolve to the
    same cluster reuse a single session and client.
    """

    def __init__(
        self,
        cluster_service: ClusterService,
        session_config: SessionConfig,
        default_client: SecretServiceClient,
    ):
        """Store the collaborators and set up the per-instance caches.

        Args:
            cluster_service: Client used to issue SelectCluster requests.
            session_config: Session of the control plane endpoint; its
                ``endpoint``, ``insecure`` and ``insecure_skip_verify`` values
                are reused when a session for another cluster is created.
            default_client: Client returned when SelectCluster reports no
                endpoint or the control plane's own endpoint.
        """
        self._cluster_service = cluster_service
        self._session_config = session_config
        self._default_client = default_client
        # One client per remote endpoint, shared by every scope routed there.
        self._clients_by_endpoint: dict[str, SecretService] = {}
        # The cache wraps the *bound* method so it belongs to this instance.
        # Decorating the method on the class would key the cache on ``self``
        # and keep every instance alive for the lifetime of the class.
        self._cached_lookup = alru_cache(self._select_client)

    async def create_secret(
        self, request: secret_payload_pb2.CreateSecretRequest
    ) -> secret_payload_pb2.CreateSecretResponse:
        """Create a secret on the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.create_secret(request)

    async def update_secret(
        self, request: secret_payload_pb2.UpdateSecretRequest
    ) -> secret_payload_pb2.UpdateSecretResponse:
        """Update a secret on the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.update_secret(request)

    async def get_secret(self, request: secret_payload_pb2.GetSecretRequest) -> secret_payload_pb2.GetSecretResponse:
        """Fetch a secret from the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.get_secret(request)

    async def list_secrets(
        self, request: secret_payload_pb2.ListSecretsRequest
    ) -> secret_payload_pb2.ListSecretsResponse:
        """List secrets on the cluster selected for the request's own scope fields."""
        client = await self._resolve(request.organization, request.project, request.domain)
        return await client.list_secrets(request)

    async def delete_secret(
        self, request: secret_payload_pb2.DeleteSecretRequest
    ) -> secret_payload_pb2.DeleteSecretResponse:
        """Delete a secret on the cluster selected for ``request.id``."""
        client = await self._resolve_for_id(request.id)
        return await client.delete_secret(request)

    async def _resolve_for_id(self, secret_id) -> SecretService:
        """Resolve the client for a secret identifier's organization/project/domain."""
        return await self._resolve(secret_id.organization, secret_id.project, secret_id.domain)

    async def _resolve(self, org: str, project: str, domain: str) -> SecretService:
        """Return the client for the given scope, using the per-instance cache.

        Args:
            org: Organization name.
            project: Project name; empty for domain- or org-scoped secrets.
            domain: Domain name; empty for org-scoped secrets.

        Raises:
            RuntimeError: If SelectCluster fails or a session for the selected
                endpoint cannot be created.
        """
        return await self._cached_lookup(org, project, domain)

    async def _select_client(self, org: str, project: str, domain: str) -> SecretService:
        """Uncached SelectCluster lookup for secrets.

        Routes by ProjectIdentifier when project and domain are set,
        DomainIdentifier when only domain is set (domain-scoped secrets),
        or OrgIdentifier otherwise (org-wide secrets).
        """
        from flyte._logging import logger

        req = cluster_payload_pb2.SelectClusterRequest(
            operation=cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_USE_SECRETS,
        )
        if project and domain:
            req.project_id.CopyFrom(identifier_pb2.ProjectIdentifier(name=project, domain=domain, organization=org))
        elif domain:
            req.domain_id.CopyFrom(identifier_pb2.DomainIdentifier(name=domain, organization=org))
        else:
            req.org_id.CopyFrom(identifier_pb2.OrgIdentifier(name=org))

        try:
            resp = await self._cluster_service.select_cluster(req)
        except Exception as e:
            raise RuntimeError(f"SelectCluster failed for OPERATION_USE_SECRETS: {e}") from e

        endpoint = resp.cluster_endpoint
        # No endpoint, or our own endpoint: the already-built client is the right one.
        if not endpoint or endpoint == self._session_config.endpoint:
            return self._default_client

        # Another scope may already have opened a session to this cluster.
        existing = self._clients_by_endpoint.get(endpoint)
        if existing is not None:
            return existing

        try:
            new_cfg = await create_session_config(
                endpoint,
                insecure=self._session_config.insecure,
                insecure_skip_verify=self._session_config.insecure_skip_verify,
                auth_endpoint=self._session_config.endpoint,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to create session for cluster endpoint '{endpoint}': {e}") from e

        logger.debug(f"Created SecretService client for cluster endpoint: {endpoint}")
        # setdefault: a concurrent lookup may have stored a client for this
        # endpoint while we were awaiting the session; keep the first one.
        return self._clients_by_endpoint.setdefault(endpoint, SecretServiceClient(**new_cfg.connect_kwargs()))
"""Tests for the ClusterAwareSecretService wrapper in flyte.remote._client.controlplane."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from flyteidl2.cluster import payload_pb2 as cluster_payload_pb2
from flyteidl2.secret import definition_pb2 as secret_definition_pb2
from flyteidl2.secret import payload_pb2 as secret_payload_pb2

from flyte.remote._client.controlplane import ClusterAwareSecretService

_CREATE_SESSION_CONFIG = "flyte.remote._client.controlplane.create_session_config"
_SECRET_SERVICE_CLIENT = "flyte.remote._client.controlplane.SecretServiceClient"


def _make_wrapper(
    cluster_endpoint: str = "",
    own_endpoint: str = "dns:///localhost:8090",
):
    """Build a wrapper around mocked collaborators.

    Returns a ``(wrapper, cluster_service, default_client)`` tuple. The mocked
    SelectCluster call answers with ``cluster_endpoint`` and the mocked session
    config reports ``own_endpoint`` as its endpoint.
    """
    cluster_service = MagicMock()
    cluster_service.select_cluster = AsyncMock(
        return_value=cluster_payload_pb2.SelectClusterResponse(cluster_endpoint=cluster_endpoint)
    )
    session_config = MagicMock()
    session_config.endpoint = own_endpoint
    session_config.insecure = True
    session_config.insecure_skip_verify = False
    default_client = MagicMock()
    default_client.create_secret = AsyncMock(return_value=secret_payload_pb2.CreateSecretResponse())
    default_client.update_secret = AsyncMock(return_value=secret_payload_pb2.UpdateSecretResponse())
    default_client.get_secret = AsyncMock(return_value=secret_payload_pb2.GetSecretResponse())
    default_client.list_secrets = AsyncMock(return_value=secret_payload_pb2.ListSecretsResponse())
    default_client.delete_secret = AsyncMock(return_value=secret_payload_pb2.DeleteSecretResponse())
    return (
        ClusterAwareSecretService(
            cluster_service=cluster_service,
            session_config=session_config,
            default_client=default_client,
        ),
        cluster_service,
        default_client,
    )


def _secret_id(org="o", project="p", domain="d", name="s"):
    """Build a SecretIdentifier with short default values for every field."""
    return secret_definition_pb2.SecretIdentifier(organization=org, project=project, domain=domain, name=name)


def _sent_request(cluster_service):
    """Return the SelectClusterRequest passed to the most recent select_cluster call."""
    return cluster_service.select_cluster.await_args[0][0]


# --- Routing: project-scoped secrets ---


@pytest.mark.asyncio
async def test_create_secret_routes_by_project():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.CreateSecretRequest(id=_secret_id())

    await wrapper.create_secret(req)

    sent = _sent_request(cluster_service)
    assert sent.operation == cluster_payload_pb2.SelectClusterRequest.Operation.OPERATION_USE_SECRETS
    assert sent.WhichOneof("resource") == "project_id"
    assert sent.project_id.name == "p"
    assert sent.project_id.domain == "d"
    assert sent.project_id.organization == "o"
    default_client.create_secret.assert_awaited_once_with(req)


@pytest.mark.asyncio
async def test_get_secret_routes_by_project():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.GetSecretRequest(id=_secret_id())

    await wrapper.get_secret(req)

    sent = _sent_request(cluster_service)
    assert sent.WhichOneof("resource") == "project_id"
    assert sent.project_id.name == "p"
    default_client.get_secret.assert_awaited_once_with(req)


@pytest.mark.asyncio
async def test_update_secret_routes_by_project():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.UpdateSecretRequest(id=_secret_id())

    await wrapper.update_secret(req)

    sent = _sent_request(cluster_service)
    assert sent.WhichOneof("resource") == "project_id"
    default_client.update_secret.assert_awaited_once_with(req)


@pytest.mark.asyncio
async def test_delete_secret_routes_by_project():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.DeleteSecretRequest(id=_secret_id())

    await wrapper.delete_secret(req)

    sent = _sent_request(cluster_service)
    assert sent.WhichOneof("resource") == "project_id"
    default_client.delete_secret.assert_awaited_once_with(req)


@pytest.mark.asyncio
async def test_list_secrets_routes_by_project():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.ListSecretsRequest(organization="o", project="p", domain="d")

    await wrapper.list_secrets(req)

    sent = _sent_request(cluster_service)
    assert sent.WhichOneof("resource") == "project_id"
    assert sent.project_id.name == "p"
    assert sent.project_id.domain == "d"
    default_client.list_secrets.assert_awaited_once_with(req)


# --- Routing: domain-scoped secrets ---


@pytest.mark.asyncio
async def test_get_secret_domain_only_routes_by_domain_id():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.GetSecretRequest(id=_secret_id(project="", domain="d"))

    await wrapper.get_secret(req)

    sent = _sent_request(cluster_service)
    assert sent.WhichOneof("resource") == "domain_id"
    assert sent.domain_id.name == "d"
    assert sent.domain_id.organization == "o"
    default_client.get_secret.assert_awaited_once_with(req)


# --- Routing: org-wide secrets ---


@pytest.mark.asyncio
async def test_get_secret_org_only_routes_by_org_id():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.GetSecretRequest(id=_secret_id(project="", domain=""))

    await wrapper.get_secret(req)

    sent = _sent_request(cluster_service)
    assert sent.WhichOneof("resource") == "org_id"
    assert sent.org_id.name == "o"
    default_client.get_secret.assert_awaited_once_with(req)


# --- Caching ---


@pytest.mark.asyncio
async def test_cache_hits_reuse_selected_client():
    wrapper, cluster_service, default_client = _make_wrapper()
    req = secret_payload_pb2.GetSecretRequest(id=_secret_id())

    await wrapper.get_secret(req)
    await wrapper.get_secret(req)

    # The second call must be served from the routing cache.
    assert cluster_service.select_cluster.await_count == 1
    assert default_client.get_secret.await_count == 2


@pytest.mark.asyncio
async def test_different_projects_get_separate_cache_entries():
    wrapper, cluster_service, _ = _make_wrapper()

    await wrapper.get_secret(secret_payload_pb2.GetSecretRequest(id=_secret_id(project="p1")))
    await wrapper.get_secret(secret_payload_pb2.GetSecretRequest(id=_secret_id(project="p2")))

    assert cluster_service.select_cluster.await_count == 2


# --- Remote cluster ---


@pytest.mark.asyncio
async def test_remote_cluster_endpoint_creates_new_client():
    wrapper, cluster_service, default_client = _make_wrapper(cluster_endpoint="dns:///other:8090")

    new_client_inst = MagicMock()
    new_client_inst.get_secret = AsyncMock(return_value=secret_payload_pb2.GetSecretResponse())
    new_session_cfg = MagicMock()
    new_session_cfg.connect_kwargs.return_value = {}
    create_cfg = AsyncMock(return_value=new_session_cfg)

    with (
        patch(_CREATE_SESSION_CONFIG, new=create_cfg),
        patch(_SECRET_SERVICE_CLIENT, return_value=new_client_inst),
    ):
        req = secret_payload_pb2.GetSecretRequest(id=_secret_id())
        await wrapper.get_secret(req)
        await wrapper.get_secret(req)

    assert cluster_service.select_cluster.await_count == 1
    assert new_client_inst.get_secret.await_count == 2
    default_client.get_secret.assert_not_awaited()
    # The remote session reuses the TLS flags of the control plane session and
    # authenticates against the control plane endpoint.
    create_cfg.assert_awaited_once_with(
        "dns:///other:8090",
        insecure=True,
        insecure_skip_verify=False,
        auth_endpoint="dns:///localhost:8090",
    )


@pytest.mark.asyncio
async def test_matching_cluster_endpoint_uses_default_client():
    # SelectCluster answering with our own endpoint must not open a new session.
    wrapper, _, default_client = _make_wrapper(cluster_endpoint="dns:///localhost:8090")
    create_cfg = AsyncMock()

    with patch(_CREATE_SESSION_CONFIG, new=create_cfg):
        req = secret_payload_pb2.GetSecretRequest(id=_secret_id())
        await wrapper.get_secret(req)

    create_cfg.assert_not_awaited()
    default_client.get_secret.assert_awaited_once_with(req)


# --- Error handling ---


@pytest.mark.asyncio
async def test_select_cluster_failure_is_wrapped_in_runtime_error():
    wrapper, cluster_service, default_client = _make_wrapper()
    cluster_service.select_cluster.side_effect = ValueError("boom")

    with pytest.raises(RuntimeError, match="SelectCluster failed for OPERATION_USE_SECRETS"):
        await wrapper.get_secret(secret_payload_pb2.GetSecretRequest(id=_secret_id()))

    default_client.get_secret.assert_not_awaited()


@pytest.mark.asyncio
async def test_session_creation_failure_is_wrapped_in_runtime_error():
    wrapper, _, default_client = _make_wrapper(cluster_endpoint="dns:///other:8090")

    with patch(_CREATE_SESSION_CONFIG, new=AsyncMock(side_effect=ValueError("nope"))):
        with pytest.raises(RuntimeError, match="Failed to create session for cluster endpoint"):
            await wrapper.get_secret(secret_payload_pb2.GetSecretRequest(id=_secret_id()))

    default_client.get_secret.assert_not_awaited()