diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 5c136b687..fbcd5d760 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -78,8 +78,6 @@ import type { RegistryRepositoriesGetRegistryRepositoryResponse, RegistryRepositoriesListRegistryRepositoriesResponse, RegistryRepositoriesReloadRegistryRepositoriesResponse, - RegistryRepositoriesSyncExecutorFromRegistryRepositoryData, - RegistryRepositoriesSyncExecutorFromRegistryRepositoryResponse, RegistryRepositoriesSyncRegistryRepositoryData, RegistryRepositoriesSyncRegistryRepositoryResponse, RegistryRepositoriesUpdateRegistryRepositoryData, @@ -1894,28 +1892,6 @@ export const registryRepositoriesSyncRegistryRepository = ( }) } -/** - * Sync Executor From Registry Repository - * @param data The data for the request. - * @param data.repositoryId - * @returns void Successful Response - * @throws ApiError - */ -export const registryRepositoriesSyncExecutorFromRegistryRepository = ( - data: RegistryRepositoriesSyncExecutorFromRegistryRepositoryData -): CancelablePromise => { - return __request(OpenAPI, { - method: "POST", - url: "/registry/repos/{repository_id}/sync-executor", - path: { - repository_id: data.repositoryId, - }, - errors: { - 422: "Validation Error", - }, - }) -} - /** * List Registry Repositories * List all registry repositories. diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 9f4f0e949..db1f33e16 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -1884,13 +1884,6 @@ export type RegistryRepositoriesSyncRegistryRepositoryData = { export type RegistryRepositoriesSyncRegistryRepositoryResponse = void -export type RegistryRepositoriesSyncExecutorFromRegistryRepositoryData = { - repositoryId: string -} - -export type RegistryRepositoriesSyncExecutorFromRegistryRepositoryResponse = - void - export type RegistryRepositoriesListRegistryRepositoriesResponse = Array @@ -3015,21 +3008,6 @@ export type $OpenApiTs = { } } } - "/registry/repos/{repository_id}/sync-executor": { - post: { - req: RegistryRepositoriesSyncExecutorFromRegistryRepositoryData - res: { - /** - * Successful Response - */ - 204: void - /** - * Validation Error - */ - 422: HTTPValidationError - } - } - } "/registry/repos": { get: { res: { diff --git a/frontend/src/components/registry/registry-repos-table.tsx b/frontend/src/components/registry/registry-repos-table.tsx index c2f44f9a2..85db1bc82 100644 --- a/frontend/src/components/registry/registry-repos-table.tsx +++ b/frontend/src/components/registry/registry-repos-table.tsx @@ -5,7 +5,6 @@ import { RegistryRepositoryReadMinimal } from "@/client" import { DropdownMenuLabel } from "@radix-ui/react-dropdown-menu" import { DotsHorizontalIcon } from "@radix-ui/react-icons" import { - ArrowRightToLineIcon, CopyIcon, LoaderCircleIcon, RefreshCcw, @@ -54,8 +53,6 @@ export function RegistryRepositoriesTable() { syncRepo, syncRepoIsPending, deleteRepo, - syncExecutor, - syncExecutorIsPending, } = useRegistryRepositories() const [selectedRepo, setSelectedRepo] = useState(null) @@ -101,7 +98,7 @@ export function RegistryRepositoriesTable() { label: (
- Sync only + Sync
), action: async () => { @@ -130,120 +127,6 @@ export function RegistryRepositoriesTable() { } }, }, - { - label: ( -
- - Sync and push to executor -
- ), - action: async () => { - if (!selectedRepo) { - console.error("No repository selected") - return - } - console.log("Reloading repository", selectedRepo.origin) - try { - await syncRepo({ repositoryId: selectedRepo.id }) - toast({ - title: "Successfully synced repository", - description: ( -
-
- Successfully reloaded actions from{" "} - {selectedRepo.origin} -
-
- ), - }) - await syncExecutor({ repositoryId: selectedRepo.id }) - toast({ - title: "Successfully pushed to executor", - description: ( -
-
- Successfully pushed actions from{" "} - {selectedRepo.origin} -
-
- ), - }) - } catch (error) { - console.error("Error reloading repository", error) - } finally { - setSelectedRepo(null) - } - }, - }, - ], - } - case AlertAction.SYNC_EXECUTOR: - return { - title: "Push to executor", - description: ( -
- - You are about to push the current version of the repository{" "} - - - {selectedRepo?.origin} - - to the executor. - {selectedRepo?.commit_sha && ( -
- Current SHA: - - {selectedRepo.commit_sha} - -
- )} - {selectedRepo?.last_synced_at && ( -
- Last synced: - - {new Date(selectedRepo.last_synced_at).toLocaleString()} - -
- )} -

- Are you sure you want to proceed? This will reload all existing - modules from this repository on the executor. -

-
- ), - actions: [ - { - label: ( -
- - Push to executor -
- ), - action: async () => { - if (!selectedRepo) { - console.error("No repository selected") - return - } - try { - await syncExecutor({ repositoryId: selectedRepo.id }) - toast({ - title: "Successfully synced executor", - description: ( -
-
- Successfully reloaded actions from{" "} - {selectedRepo.origin} -
-
- ), - }) - } catch (error) { - console.error("Error syncing executor", error) - } finally { - setSelectedRepo(null) - } - }, - }, ], } case AlertAction.DELETE: @@ -402,11 +285,11 @@ export function RegistryRepositoriesTable() { > Open menu {row.original.id === selectedRepo?.id && - (syncRepoIsPending || syncExecutorIsPending) ? ( + syncRepoIsPending ? (
- {syncRepoIsPending ? "Pulling..." : "Pushing..."} + Pulling...
) : ( @@ -473,20 +356,6 @@ export function RegistryRepositoriesTable() { Sync from remote - {row.original.last_synced_at !== null && ( - { - e.stopPropagation() // Prevent row click - setSelectedRepo(row.original) - setAlertAction(AlertAction.SYNC_EXECUTOR) - setAlertOpen(true) - }} - > - - Push to executor - - )} { @@ -526,7 +395,7 @@ export function RegistryRepositoriesTable() { setAlertOpen(false) await action.action() }} - disabled={syncRepoIsPending || syncExecutorIsPending} + disabled={syncRepoIsPending} > {action.label} diff --git a/frontend/src/lib/hooks.tsx b/frontend/src/lib/hooks.tsx index aac44d3e5..ec9c75f42 100644 --- a/frontend/src/lib/hooks.tsx +++ b/frontend/src/lib/hooks.tsx @@ -37,8 +37,6 @@ import { RegistryRepositoriesDeleteRegistryRepositoryData, registryRepositoriesListRegistryRepositories, registryRepositoriesReloadRegistryRepositories, - registryRepositoriesSyncExecutorFromRegistryRepository, - RegistryRepositoriesSyncExecutorFromRegistryRepositoryData, registryRepositoriesSyncRegistryRepository, RegistryRepositoriesSyncRegistryRepositoryData, RegistryRepositoryReadMinimal, @@ -1159,40 +1157,6 @@ export function useRegistryRepositories() { }, }) - const { - mutateAsync: syncExecutor, - isPending: syncExecutorIsPending, - error: syncExecutorError, - } = useMutation({ - mutationFn: async ( - params: RegistryRepositoriesSyncExecutorFromRegistryRepositoryData - ) => await registryRepositoriesSyncExecutorFromRegistryRepository(params), - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ["registry_repositories"] }) - queryClient.invalidateQueries({ queryKey: ["registry_actions"] }) - toast({ - title: "Synced executor", - description: "Executor synced successfully.", - }) - }, - onError: (error: TracecatApiError) => { - const apiError = error as TracecatApiError - switch (apiError.status) { - case 403: - toast({ - title: "You cannot perform this action", - description: `${apiError.message}: ${apiError.body.detail}`, - }) - break - default: - toast({ - title: "Failed to sync executor", - description: `An unexpected error occurred while syncing the executor. ${apiError.message}: ${apiError.body.detail}`, - }) - } - }, - }) - return { repos, reposIsLoading, @@ -1203,9 +1167,6 @@ export function useRegistryRepositories() { deleteRepo, deleteRepoIsPending, deleteRepoError, - syncExecutor, - syncExecutorIsPending, - syncExecutorError, } } diff --git a/pyproject.toml b/pyproject.toml index cad00840e..2e4d0ef81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "alembic==1.13.2", "asyncpg==0.29.0", "authlib>=1.3.1,<1.4.0", + "async-lru==2.0.4", "cloudpickle==3.0.0", "colorlog==6.8.2", "cryptography==43.0.1", @@ -54,6 +55,7 @@ dependencies = [ "tenacity==8.3.0", "uv==0.4.10", "uvicorn==0.29.0", + "virtualenv==20.27.0", ] dynamic = ["version"] diff --git a/tests/unit/test_executor_service.py b/tests/unit/test_executor_service.py new file mode 100644 index 000000000..85d1f6bd2 --- /dev/null +++ b/tests/unit/test_executor_service.py @@ -0,0 +1,145 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest + +from tracecat.dsl.models import ActionStatement, RunActionInput, RunContext +from tracecat.executor.models import DispatchActionContext +from tracecat.executor.service import _dispatch_action, dispatch_action_on_cluster +from tracecat.expressions.common import ExprContext +from tracecat.git import GitUrl +from tracecat.types.auth import Role + + +@pytest.fixture +def mock_session(): + return AsyncMock() + + +@pytest.fixture +def basic_task_input(): + """Fixture that provides a basic RunActionInput without looping.""" + wf_id = "wf-" + uuid.uuid4().hex + wf_exec_id = wf_id + ":exec-test" + wf_run_id = uuid.uuid4() + return RunActionInput( + task=ActionStatement( + action="test_action", + args={"key": "value"}, + ref="test_ref", + ), + exec_context={ + ExprContext.ACTIONS: { + "test_action": { + "args": {"key": "value"}, + "ref": "test-ref", + } + } + }, + run_context=RunContext( + wf_id=wf_id, + wf_exec_id=wf_exec_id, + wf_run_id=wf_run_id, + environment="test-env", + ), + ) + + +@pytest.fixture +def basic_looped_task_input(): + wf_id = "wf-" + uuid.uuid4().hex + wf_exec_id = wf_id + ":exec-test" + wf_run_id = uuid.uuid4() + return RunActionInput( + task=ActionStatement( + action="test_action", + args={"key": "value"}, + ref="test_ref", + for_each="${{ for var.x in [1,2,3] }}", + ), + exec_context={ + ExprContext.ACTIONS: { + "test_action": { + "args": {"key": "value"}, + "ref": "test-ref", + } + } + }, + run_context=RunContext( + wf_id=wf_id, + wf_exec_id=wf_exec_id, + wf_run_id=wf_run_id, + environment="test-env", + ), + ) + + +@pytest.fixture +def dispatch_context(): + return DispatchActionContext( + role=Role(type="service", service_id="tracecat-executor"), + ssh_command="ssh -i /tmp/key", + git_url=GitUrl(host="github.com", org="org", repo="repo", ref="abc123"), + ) + + +@pytest.mark.anyio +async def test_dispatch_action_basic(mock_session, basic_task_input, dispatch_context): + with patch("tracecat.executor.service.run_action_on_ray_cluster") as mock_ray: + mock_ray.return_value = {"result": "success"} + + result = await _dispatch_action(input=basic_task_input, ctx=dispatch_context) + + assert result == {"result": "success"} + mock_ray.assert_called_once_with(basic_task_input, dispatch_context) + + +@pytest.mark.anyio +async def test_dispatch_action_with_foreach( + mock_session, basic_looped_task_input, dispatch_context +): + with patch("tracecat.executor.service.run_action_on_ray_cluster") as mock_ray: + mock_ray.return_value = {"result": "success"} + + result = await _dispatch_action( + input=basic_looped_task_input, ctx=dispatch_context + ) + + assert result == [{"result": "success"}] * 3 + + # Assert the number of calls + assert mock_ray.call_count == 3 + + # Get all calls and their arguments + calls = mock_ray.call_args_list + + # Verify each call's arguments + for i, call in enumerate(calls, 1): + args, kwargs = call + input_arg = args[0] + # Verify the loop variable 'x' was set to different values (1, 2, 3) + assert input_arg.task.args["key"] == "value" + assert input_arg.exec_context[ExprContext.LOCAL_VARS] == {"x": i} + assert args[1] == dispatch_context + + +@pytest.mark.anyio +async def test_dispatch_action_with_git_url(mock_session, basic_task_input): + with ( + patch("tracecat.executor.service.prepare_git_url") as mock_git_url, + patch("tracecat.executor.service._dispatch_action") as mock_dispatch, + patch("tracecat.executor.service.opt_temp_key_file") as mock_key_file, + ): + mock_git_url.return_value = GitUrl( + host="github.com", org="org", repo="repo", ref="abc123" + ) + mock_key_file.return_value.__aenter__.return_value = "ssh -i /tmp/key" + mock_dispatch.return_value = {"result": "success"} + + result = await dispatch_action_on_cluster( + input=basic_task_input, session=mock_session + ) + + assert result == {"result": "success"} + mock_git_url.assert_called_once() + mock_key_file.return_value.__aenter__.assert_called_once() diff --git a/tracecat/api/executor.py b/tracecat/api/executor.py index 1983368b2..8877638f7 100644 --- a/tracecat/api/executor.py +++ b/tracecat/api/executor.py @@ -6,7 +6,6 @@ from tracecat import config from tracecat.api.common import ( - bootstrap_role, custom_generate_unique_id, generic_exception_handler, setup_oss_models, @@ -16,9 +15,6 @@ from tracecat.executor.router import router as executor_router from tracecat.logger import logger from tracecat.middleware import RequestLoggingMiddleware -from tracecat.registry.repositories.service import RegistryReposService -from tracecat.registry.repository import Repository -from tracecat.settings.service import get_setting from tracecat.types.exceptions import TracecatException @@ -28,49 +24,10 @@ async def lifespan(app: FastAPI): await setup_oss_models() except Exception as e: logger.error("Failed to preload OSS models", error=e) - try: - await setup_custom_remote_repository() - except Exception as e: - logger.error("Error setting up custom remote repository", exc=e) - with setup_ray(): yield -async def setup_custom_remote_repository(): - """Install the remote repository if it is set. - - Steps - ----- - 1. Get the SHA of the remote repository from the DB - 2. If it doesn't exist, create it - 3. If it does exist, sync it - """ - role = bootstrap_role() - url = await get_setting( - "git_repo_url", - role=role, - # TODO: Deprecate in future version - default=config.TRACECAT__REMOTE_REPOSITORY_URL, - ) - if not url: - logger.info("Remote repository URL not set, skipping") - return - logger.info("Remote repository URL found", url=url) - async with RegistryReposService.with_session(role) as service: - db_repo = await service.get_repository(url) - # If it doesn't exist, do nothing - if db_repo is None: - logger.warning("Remote repository not found in DB, skipping") - return - # If it does exist, sync it - if db_repo.last_synced_at is None: - logger.info("Remote repository not synced, skipping") - return - repo = Repository(db_repo.origin, role=role) - await repo.load_from_origin(commit_sha=db_repo.commit_sha) - - def create_app(**kwargs) -> FastAPI: if config.TRACECAT__ALLOW_ORIGINS is not None: allow_origins = config.TRACECAT__ALLOW_ORIGINS.split(",") diff --git a/tracecat/dsl/action.py b/tracecat/dsl/action.py index 9fa7ff1e4..34e882d59 100644 --- a/tracecat/dsl/action.py +++ b/tracecat/dsl/action.py @@ -9,13 +9,15 @@ from temporalio.exceptions import ApplicationError from tracecat.contexts import ctx_logger, ctx_run +from tracecat.db.engine import get_async_session_context_manager from tracecat.dsl.common import context_locator from tracecat.dsl.models import ActionErrorInfo, ActionStatement, RunActionInput from tracecat.executor.client import ExecutorClient from tracecat.logger import logger from tracecat.registry.actions.models import RegistryActionValidateResponse from tracecat.types.auth import Role -from tracecat.types.exceptions import ExecutorClientError +from tracecat.types.exceptions import ExecutorClientError, RegistryError +from tracecat.validation.service import validate_registry_action_args def contextualize_message( @@ -61,10 +63,25 @@ async def validate_action_activity( - Validate the action arguments against the UDF spec. - Return the validated arguments. """ - client = ExecutorClient(role=input.role) - return await client.validate_action( - action_name=input.task.action, args=input.task.args - ) + try: + async with get_async_session_context_manager() as session: + result = await validate_registry_action_args( + session=session, + action_name=input.task.action, + args=input.task.args, + ) + + if result.status == "error": + logger.warning( + "Error validating UDF args", + message=result.msg, + details=result.detail, + ) + return RegistryActionValidateResponse.from_validation_result(result) + except KeyError as e: + raise RegistryError( + f"Action {input.task.action!r} not found in registry", + ) from e @staticmethod @activity.defn diff --git a/tracecat/dsl/models.py b/tracecat/dsl/models.py index 2be334583..a19264c21 100644 --- a/tracecat/dsl/models.py +++ b/tracecat/dsl/models.py @@ -2,14 +2,14 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import Annotated, Any, Literal, TypedDict +from typing import Any, Literal, TypedDict from pydantic import BaseModel, Field from tracecat.dsl.constants import DEFAULT_ACTION_TIMEOUT from tracecat.dsl.enums import JoinStrategy from tracecat.expressions.common import ExprContext -from tracecat.expressions.validation import ExpressionStr, TemplateValidator +from tracecat.expressions.validation import ExpressionStr from tracecat.identifiers import WorkflowExecutionID, WorkflowID, WorkflowRunID from tracecat.secrets.constants import DEFAULT_SECRETS_ENVIRONMENT @@ -94,20 +94,13 @@ class ActionStatement(BaseModel): """Control flow options""" - run_if: Annotated[ - str | None, - Field(default=None, description="Condition to run the task"), - TemplateValidator(), - ] - - for_each: Annotated[ - str | list[str] | None, - Field( - default=None, - description="Iterate over a list of items and run the task for each item.", - ), - TemplateValidator(), - ] + run_if: ExpressionStr | None = Field( + default=None, description="Condition to run the task" + ) + for_each: ExpressionStr | list[ExpressionStr] | None = Field( + default=None, + description="Iterate over a list of items and run the task for each item.", + ) retry_policy: ActionRetryPolicy = Field( default_factory=ActionRetryPolicy, description="Retry policy for the action." ) diff --git a/tracecat/executor/client.py b/tracecat/executor/client.py index afb88d304..c16215d21 100644 --- a/tracecat/executor/client.py +++ b/tracecat/executor/client.py @@ -7,19 +7,12 @@ import httpx import orjson -from pydantic import UUID4 -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) from tracecat import config from tracecat.clients import AuthenticatedServiceClient from tracecat.contexts import ctx_role from tracecat.dsl.models import RunActionInput -from tracecat.executor.models import ExecutorActionErrorInfo, ExecutorSyncInput +from tracecat.executor.models import ExecutorActionErrorInfo from tracecat.logger import logger from tracecat.registry.actions.models import ( RegistryActionValidateResponse, @@ -118,63 +111,6 @@ async def validate_action( f"Unexpected error while listing registries: {str(e)}" ) from e - # === Management === - - async def sync_executor( - self, repository_id: UUID4, *, max_attempts: int = 3 - ) -> None: - """Sync the executor from the registry. - - Args: - origin: The origin of the sync request - - Raises: - RegistryError: If the sync fails after all retries - """ - - @retry( - stop=stop_after_attempt(max_attempts), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - httpx.HTTPStatusError, - httpx.RequestError, - httpx.TimeoutException, - httpx.ConnectError, - ) - ), - ) - async def _sync_request() -> None: - try: - async with self._client() as client: - response = await client.post( - "/sync", - content=ExecutorSyncInput( - repository_id=repository_id - ).model_dump_json(), - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - except Exception as e: - logger.error("Error syncing executor", error=e) - raise - - try: - logger.info("Syncing executor", repository_id=repository_id) - _ = await _sync_request() - except httpx.HTTPStatusError as e: - raise RegistryError( - f"Failed to sync executor: HTTP {e.response.status_code}" - ) from e - except httpx.RequestError as e: - raise RegistryError( - f"Network error while syncing executor: {str(e)}" - ) from e - except Exception as e: - raise RegistryError( - f"Unexpected error while syncing executor: {str(e)}" - ) from e - # === Utility === def _handle_http_status_error( diff --git a/tracecat/executor/models.py b/tracecat/executor/models.py index b6bc87d18..c796320f8 100644 --- a/tracecat/executor/models.py +++ b/tracecat/executor/models.py @@ -1,10 +1,13 @@ from __future__ import annotations import traceback +from dataclasses import dataclass from pydantic import UUID4, BaseModel from tracecat.config import TRACECAT__APP_ENV +from tracecat.git import GitUrl +from tracecat.types.auth import Role class ExecutorSyncInput(BaseModel): @@ -56,3 +59,10 @@ def from_exc(e: Exception, action_name: str) -> ExecutorActionErrorInfo: function=tb.name, lineno=tb.lineno, ) + + +@dataclass +class DispatchActionContext: + role: Role + ssh_command: str | None = None + git_url: GitUrl | None = None diff --git a/tracecat/executor/router.py b/tracecat/executor/router.py index 298758bb2..f7002c730 100644 --- a/tracecat/executor/router.py +++ b/tracecat/executor/router.py @@ -4,47 +4,18 @@ from pydantic import BaseModel from tracecat.auth.credentials import RoleACL -from tracecat.contexts import ctx_logger, ctx_role +from tracecat.contexts import ctx_logger from tracecat.db.dependencies import AsyncDBSession from tracecat.dsl.models import RunActionInput -from tracecat.executor.models import ExecutorActionErrorInfo, ExecutorSyncInput +from tracecat.executor.models import ExecutorActionErrorInfo from tracecat.executor.service import dispatch_action_on_cluster from tracecat.logger import logger -from tracecat.registry.actions.models import ( - RegistryActionValidate, - RegistryActionValidateResponse, -) -from tracecat.registry.repository import RegistryReposService, Repository from tracecat.types.auth import Role -from tracecat.types.exceptions import WrappedExecutionError -from tracecat.validation.service import validate_registry_action_args +from tracecat.types.exceptions import TracecatSettingsError, WrappedExecutionError router = APIRouter() -@router.post("/sync") -async def sync_executor( - *, - role: Role = RoleACL( - allow_user=False, # XXX(authz): Users cannot sync the executor - allow_service=True, # Only services can sync the executor - require_workspace="no", - ), - session: AsyncDBSession, - input: ExecutorSyncInput, -) -> None: - """Sync the executor from the registry.""" - rr_service = RegistryReposService(session, role=role) - db_repo = await rr_service.get_repository_by_id(input.repository_id) - # If it doesn't exist, do nothing - if db_repo is None: - logger.info("Remote repository not found in DB, skipping") - return - # If it does exist, sync it - repo = Repository(db_repo.origin, role=role) - await repo.load_from_origin(commit_sha=db_repo.commit_sha) - - @router.post("/run/{action_name}", tags=["execution"]) async def run_action( *, @@ -53,18 +24,24 @@ async def run_action( allow_service=True, # Only services can execute actions require_workspace="no", ), + session: AsyncDBSession, action_name: str, action_input: RunActionInput, ) -> Any: """Execute a registry action.""" ref = action_input.task.ref - ctx_role.set(role) act_logger = logger.bind(role=role, action_name=action_name, ref=ref) ctx_logger.set(act_logger) act_logger.info("Starting action") + try: - return await dispatch_action_on_cluster(input=action_input, role=role) + return await dispatch_action_on_cluster(input=action_input, session=session) + except TracecatSettingsError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"message": str(e)}, + ) from e except WrappedExecutionError as e: # This is an error that occurred inside an executing action err = e.error @@ -85,33 +62,3 @@ async def run_action( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=err_info_dict, ) from e - - -@router.post("/validate/{action_name}") -async def validate_action( - *, - role: Role = RoleACL( - allow_user=False, # XXX(authz): Users cannot validate actions - allow_service=True, # Only services can validate actions - require_workspace="no", - ), - session: AsyncDBSession, - action_name: str, - params: RegistryActionValidate, -) -> RegistryActionValidateResponse: - """Validate a registry action.""" - try: - result = await validate_registry_action_args( - session=session, action_name=action_name, args=params.args - ) - - if result.status == "error": - logger.warning( - "Error validating UDF args", message=result.msg, details=result.detail - ) - return RegistryActionValidateResponse.from_validation_result(result) - except KeyError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Action {action_name!r} not found in registry", - ) from e diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index f4c982d7d..28ba3a69a 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -1,8 +1,3 @@ -"""Functions for executing actions and templates. - -NOTE: This is only used in the API server, not the worker -""" - from __future__ import annotations import asyncio @@ -13,6 +8,8 @@ import ray import uvloop from ray.exceptions import RayTaskError +from ray.runtime_env import RuntimeEnv +from sqlmodel.ext.asyncio.session import AsyncSession from tracecat import config from tracecat.auth.sandbox import AuthSandbox @@ -27,22 +24,22 @@ RunActionInput, ) from tracecat.executor.engine import EXECUTION_TIMEOUT -from tracecat.executor.models import ExecutorActionErrorInfo +from tracecat.executor.models import DispatchActionContext, ExecutorActionErrorInfo from tracecat.expressions.common import ExprContext, ExprOperand from tracecat.expressions.eval import ( eval_templated_object, extract_templated_secrets, get_iterables_from_expression, ) +from tracecat.git import GitUrl, prepare_git_url from tracecat.logger import logger from tracecat.parse import traverse_leaves -from tracecat.registry.actions.models import ( - BoundRegistryAction, -) +from tracecat.registry.actions.models import BoundRegistryAction from tracecat.registry.actions.service import RegistryActionsService from tracecat.secrets.common import apply_masks_object from tracecat.secrets.constants import DEFAULT_SECRETS_ENVIRONMENT from tracecat.secrets.secrets_manager import env_sandbox +from tracecat.ssh import opt_temp_key_file from tracecat.types.auth import Role from tracecat.types.exceptions import TracecatException, WrappedExecutionError @@ -319,15 +316,27 @@ def run_action_task(input: RunActionInput, role: Role) -> ExecutionResult: async def run_action_on_ray_cluster( - input: RunActionInput, role: Role + input: RunActionInput, ctx: DispatchActionContext ) -> ExecutionResult: """Run an action on the ray cluster. If any exceptions are thrown here, they're platform level errors. All application/user level errors are caught by the executor and returned as values. """ + # Initialize runtime environment variables + env_vars = {"GIT_SSH_COMMAND": ctx.ssh_command} if ctx.ssh_command else {} + additional_vars: dict[str, Any] = {} + + # Add git URL to pip dependencies if SHA is present + if ctx.git_url and ctx.git_url.ref: + url = ctx.git_url.to_url() + additional_vars["pip"] = [url] + logger.trace("Adding git URL to runtime env", git_url=ctx.git_url, url=url) - obj_ref = run_action_task.remote(input, role) + runtime_env = RuntimeEnv(env_vars=env_vars, **additional_vars) + + logger.info("Running action on ray cluster", runtime_env=runtime_env) + obj_ref = run_action_task.options(runtime_env=runtime_env).remote(input, ctx.role) try: coro = asyncio.to_thread(ray.get, obj_ref) exec_result = await asyncio.wait_for(coro, timeout=EXECUTION_TIMEOUT) @@ -349,7 +358,12 @@ async def run_action_on_ray_cluster( return exec_result -async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: +async def dispatch_action_on_cluster( + input: RunActionInput, + *, + session: AsyncSession, + git_url: GitUrl | None = None, +) -> Any: """Schedule actions on the ray cluster. This function handles dispatching actions to be executed on a Ray cluster. It supports @@ -358,7 +372,7 @@ async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: Args: input: The RunActionInput containing the task definition and execution context role: The Role used for authorization - + git_url: The Git URL to use for the action Returns: Any: For single actions, returns the ExecutionResult. For for_each loops, returns a list of results from all parallel executions. @@ -367,12 +381,26 @@ async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: TracecatException: If there are errors evaluating for_each expressions or during execution ExecutorErrorWrapper: If there are errors from the executor itself """ + git_url = await prepare_git_url() - task = input.task + role = ctx_role.get() + + async with opt_temp_key_file(git_url=git_url, session=session) as ssh_command: + logger.trace("SSH command", ssh_command=ssh_command) + ctx = DispatchActionContext(role=role, git_url=git_url, ssh_command=ssh_command) + result = await _dispatch_action(input=input, ctx=ctx) + return result + +async def _dispatch_action( + input: RunActionInput, + ctx: DispatchActionContext, +) -> Any: + task = input.task + logger.info("Preparing runtime environment", ctx=ctx) # If there's no for_each, execute normally if not task.for_each: - return await run_action_on_ray_cluster(input, role) + return await run_action_on_ray_cluster(input, ctx) logger.info("Running for_each on action in parallel", action=task.action) @@ -383,7 +411,7 @@ async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: iterators = get_iterables_from_expression(expr=task.for_each, operand=base_context) async def coro(patched_input: RunActionInput): - return await run_action_on_ray_cluster(patched_input, role) + return await run_action_on_ray_cluster(patched_input, ctx) try: async with GatheringTaskGroup() as tg: diff --git a/tracecat/git.py b/tracecat/git.py index 98f364471..9d4d978d8 100644 --- a/tracecat/git.py +++ b/tracecat/git.py @@ -7,7 +7,7 @@ from tracecat.contexts import ctx_role from tracecat.logger import logger from tracecat.registry.repositories.service import RegistryReposService -from tracecat.settings.service import get_setting +from tracecat.settings.service import get_setting_cached from tracecat.ssh import SshEnv from tracecat.types.auth import Role from tracecat.types.exceptions import TracecatSettingsError @@ -111,7 +111,7 @@ async def prepare_git_url(role: Role | None = None) -> GitUrl | None: role = role or ctx_role.get() # Handle the git repo - url = await get_setting( + url = await get_setting_cached( "git_repo_url", # TODO: Deprecate in future version default=config.TRACECAT__REMOTE_REPOSITORY_URL, @@ -122,10 +122,11 @@ async def prepare_git_url(role: Role | None = None) -> GitUrl | None: logger.debug("Runtime environment", url=url) - allowed_domains_setting = await get_setting( + allowed_domains_setting = await get_setting_cached( "git_allowed_domains", # TODO: Deprecate in future version - default=config.TRACECAT__ALLOWED_GIT_DOMAINS, + # Must be hashable + default=frozenset(config.TRACECAT__ALLOWED_GIT_DOMAINS), ) allowed_domains = cast(set[str], allowed_domains_setting or {"github.com"}) diff --git a/tracecat/registry/repositories/router.py b/tracecat/registry/repositories/router.py index 19f70c1dd..85fd41fea 100644 --- a/tracecat/registry/repositories/router.py +++ b/tracecat/registry/repositories/router.py @@ -6,7 +6,6 @@ from tracecat.auth.credentials import RoleACL from tracecat.db.dependencies import AsyncDBSession -from tracecat.executor.client import ExecutorClient from tracecat.logger import logger from tracecat.registry.actions.models import RegistryActionRead from tracecat.registry.actions.service import RegistryActionsService @@ -102,41 +101,6 @@ async def sync_registry_repository( ) from e -@router.post("/{repository_id}/sync-executor", status_code=status.HTTP_204_NO_CONTENT) -async def sync_executor_from_registry_repository( - *, - role: Role = RoleACL( - allow_user=True, - allow_service=False, - require_workspace="no", - min_access_level=AccessLevel.ADMIN, - ), - session: AsyncDBSession, - repository_id: UUID4, -): - # # We might want to update the executor's view of the repository here - # # (3) Update the executor's view of the repository - rr_service = RegistryReposService(session, role) - try: - repo = await rr_service.get_repository_by_id(repository_id) - except NoResultFound as e: - logger.error("Registry repository not found", repository_id=repository_id) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Registry repository not found", - ) from e - logger.info("Syncing executor", origin=repo.origin) - client = ExecutorClient(role=role) - try: - await client.sync_executor(repository_id=repo.id) - except RegistryError as e: - logger.warning("Error syncing executor", exc=e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Error while syncing executor {repo.origin!r}: {e}", - ) from e - - @router.get("") async def list_registry_repositories( *, diff --git a/tracecat/registry/repository.py b/tracecat/registry/repository.py index 57677e0bb..19a19d0f3 100644 --- a/tracecat/registry/repository.py +++ b/tracecat/registry/repository.py @@ -29,9 +29,10 @@ from tracecat import config from tracecat.contexts import ctx_role +from tracecat.db.engine import get_async_session_context_manager from tracecat.expressions.expectations import create_expectation_model from tracecat.expressions.validation import TemplateValidator -from tracecat.git import get_git_repository_sha, parse_git_url +from tracecat.git import GitUrl, get_git_repository_sha, parse_git_url from tracecat.logger import logger from tracecat.parse import safe_url from tracecat.registry.actions.models import BoundRegistryAction, TemplateAction @@ -41,14 +42,8 @@ ) from tracecat.registry.repositories.models import RegistryRepositoryCreate from tracecat.registry.repositories.service import RegistryReposService -from tracecat.secrets.service import SecretsService from tracecat.settings.service import get_setting -from tracecat.ssh import ( - SshEnv, - add_host_to_known_hosts, - add_ssh_key_to_agent, - temporary_ssh_agent, -) +from tracecat.ssh import SshEnv, ssh_context from tracecat.types.auth import Role from tracecat.types.exceptions import RegistryError @@ -118,10 +113,6 @@ def get(self, name: str) -> BoundRegistryAction[ArgsClsT]: """Retrieve a registered udf.""" return self._store[name] - def safe_remote_url(self, remote_registry_url: str) -> str: - """Clean a remote registry url.""" - return safe_url(remote_registry_url) - def init(self, include_base: bool = True, include_templates: bool = True) -> None: """Initialize the registry.""" if not self._is_initialized: @@ -223,11 +214,6 @@ def register_template_action( origin=origin, ) - def _reset(self) -> None: - logger.warning("Resetting registry") - self._store = {} - self._is_initialized = False - def _load_base_udfs(self) -> None: """Load all udfs and template actions into the registry.""" # Load udfs @@ -262,8 +248,9 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None: or {"github.com"}, ) + cleaned_url = safe_url(self._origin) try: - git_url = parse_git_url(self._origin, allowed_domains=allowed_domains) + git_url = parse_git_url(cleaned_url, allowed_domains=allowed_domains) host = git_url.host org = git_url.org repo_name = git_url.repo @@ -288,35 +275,34 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None: package_name=package_name, ) - cleaned_url = self.safe_remote_url(self._origin) - logger.debug("Cleaned URL", url=cleaned_url) + logger.debug("Git URL", git_url=git_url) commit_sha = await self._install_remote_repository( - host=host, repo_url=cleaned_url, commit_sha=commit_sha + git_url=git_url, commit_sha=commit_sha ) module = await self._load_remote_repository(cleaned_url, package_name) logger.info( "Imported and reloaded remote repository", module_name=module.__name__, package_name=package_name, + commit_sha=commit_sha, ) return commit_sha async def _install_remote_repository( - self, host: str, repo_url: str, commit_sha: str | None = None + self, git_url: GitUrl, commit_sha: str | None = None ) -> str: """Install the remote repository into the filesystem and return the commit sha.""" - logger.info("Getting SSH key", role=self.role) - async with SecretsService.with_session(role=self.role) as service: - secret = await service.get_ssh_key() - - async with temporary_ssh_agent() as env: - logger.info("Entered temporary SSH agent context") - await add_ssh_key_to_agent(secret.reveal().value, env=env) - await add_host_to_known_hosts(host, env=env) + url = git_url.to_url() + async with ( + get_async_session_context_manager() as session, + ssh_context(role=self.role, git_url=git_url, session=session) as env, + ): + if env is None: + raise RegistryError("No SSH key found") if commit_sha is None: - commit_sha = await get_git_repository_sha(repo_url, env=env) - await install_remote_repository(repo_url, commit_sha=commit_sha, env=env) + commit_sha = await get_git_repository_sha(url, env=env) + await install_remote_repository(url, commit_sha=commit_sha, env=env) return commit_sha async def _load_remote_repository( diff --git a/tracecat/secrets/service.py b/tracecat/secrets/service.py index 2c50b4738..c8df473e5 100644 --- a/tracecat/secrets/service.py +++ b/tracecat/secrets/service.py @@ -279,14 +279,13 @@ async def get_ssh_key( key_name: str = GIT_SSH_KEY_SECRET_NAME, environment: str | None = None, ) -> SecretKeyValue: - # NOTE: Don't set the workspace_id, as we want to search for - # organization secrets if it's not set. - logger.info("Getting SSH key", key_name=key_name, role=self.role) try: secret = await self.get_org_secret_by_name(key_name, environment) + key = self.decrypt_keys(secret.encrypted_keys)[0] + logger.debug("SSH key found", key_name=key_name, key_length=len(key.value)) + return key except TracecatNotFoundError as e: raise TracecatNotFoundError( f"SSH key {key_name} not found. Please check whether this key exists.\n\n" " If not, please create a key in your organization's credentials page and try again." ) from e - return self.decrypt_keys(secret.encrypted_keys)[0] diff --git a/tracecat/settings/service.py b/tracecat/settings/service.py index 1968e3626..e6e6ecf18 100644 --- a/tracecat/settings/service.py +++ b/tracecat/settings/service.py @@ -3,6 +3,7 @@ from typing import Any import orjson +from async_lru import alru_cache from pydantic import BaseModel, SecretStr from pydantic_core import to_jsonable_python from sqlmodel import col, select @@ -299,6 +300,29 @@ async def get_setting( return no_default_val +@alru_cache(ttl=30) +async def get_setting_cached( + key: str, + *, + role: Role | None = None, + session: AsyncSession | None = None, + default: Any | None = None, +) -> Any | None: + """Cached version of get_setting function. + + Args: + key: The setting key to retrieve + role: Optional role to use for permissions check + session: Optional database session to use + default: Optional default value if setting not found. Must be hashable. + + Returns: + The setting value or None if not found + """ + logger.debug("Cache miss", key=key) + return await get_setting(key, role=role, session=session, default=default) + + def get_setting_override(key: str) -> Any | None: """Get an environment override for a setting.""" # Only allow overrides for specific settings diff --git a/tracecat/ssh.py b/tracecat/ssh.py index 8f5979e0e..799ceb122 100644 --- a/tracecat/ssh.py +++ b/tracecat/ssh.py @@ -239,7 +239,6 @@ async def ssh_context( if git_url is None: yield None else: - logger.info("Getting SSH key", role=role, git_url=git_url) sec_svc = SecretsService(session, role=role) secret = await sec_svc.get_ssh_key() async with temporary_ssh_agent() as env: