diff --git a/pyproject.toml b/pyproject.toml index 633a0b59e..4b7b95086 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ dev = [ "pytest-mock", "pytest==8.3.2", "python-dotenv", + "psycopg>=3.1.19", "respx", ] diff --git a/tests/conftest.py b/tests/conftest.py index 7bc869e8f..e20e962eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,13 +2,18 @@ import json import os import uuid -from collections.abc import Iterator +from collections.abc import AsyncGenerator, Iterator from pathlib import Path from typing import Any import httpx import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession +from tests.database import TEST_DB_CONFIG from tracecat import config from tracecat.contexts import ctx_role from tracecat.db.engine import get_async_engine, get_async_session_context_manager @@ -34,6 +39,7 @@ def monkeysession(request: pytest.FixtureRequest): @pytest.fixture(autouse=True, scope="function") async def test_db_engine(): + """Create a new engine for each integration test.""" engine = get_async_engine() try: yield engine @@ -42,6 +48,63 @@ async def test_db_engine(): await engine.dispose() +@pytest.fixture(scope="session") +def db() -> Iterator[None]: + """Session-scoped fixture to create and teardown test database using sync SQLAlchemy.""" + + default_engine = create_engine( + TEST_DB_CONFIG.sys_url_sync, isolation_level="AUTOCOMMIT" + ) + + termination_query = text( + f""" + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{TEST_DB_CONFIG.test_db_name}' + AND pid <> pg_backend_pid(); + """ + ) + + try: + with default_engine.connect() as conn: + # Terminate existing connections + conn.execute(termination_query) + # Create test database + conn.execute(text(f'CREATE DATABASE "{TEST_DB_CONFIG.test_db_name}"')) + logger.info("Created test database") + + # Create sync engine for test db + test_engine = create_engine(TEST_DB_CONFIG.test_url_sync) + with test_engine.begin() as conn: + logger.info("Creating all tables") + SQLModel.metadata.create_all(conn) + yield + finally: + test_engine.dispose() + # # Cleanup - reconnect to system db to drop test db + with default_engine.begin() as conn: + conn.execute(termination_query) + conn.execute( + text(f'DROP DATABASE IF EXISTS "{TEST_DB_CONFIG.test_db_name}"') + ) + logger.info("Dropped test database") + default_engine.dispose() + + +@pytest.fixture(scope="function") +async def session() -> AsyncGenerator[AsyncSession, None]: + """Creates a new database session with (with working transaction) + for test duration. Use this for unit tests.""" + async_engine = create_async_engine(TEST_DB_CONFIG.test_url) + async_session = AsyncSession(async_engine, expire_on_commit=False) + try: + await async_session.begin_nested() + yield async_session + finally: + await async_session.rollback() # Rollback any changes made during the test + await async_engine.dispose() + + @pytest.fixture(autouse=True, scope="session") def env_sandbox(monkeysession: pytest.MonkeyPatch): from dotenv import load_dotenv diff --git a/tests/database.py b/tests/database.py new file mode 100644 index 000000000..ec8c23775 --- /dev/null +++ b/tests/database.py @@ -0,0 +1,29 @@ +import uuid +from dataclasses import dataclass + + +@dataclass(frozen=True) +class DBConfig: + test_db_name: str + base_url: str + + @property + def test_url(self) -> str: + return f"{self.base_url}{self.test_db_name}" + + @property + def test_url_sync(self) -> str: + return self.test_url.replace("+asyncpg", "+psycopg") + + @property + def sys_url(self) -> str: + return f"{self.base_url}postgres" + + @property + def sys_url_sync(self) -> str: + return self.sys_url.replace("+asyncpg", "+psycopg") + + +TEST_DB_NAME = f"test_db_{uuid.uuid4()}" +TEST_DB_URL_BASE = "postgresql+asyncpg://postgres:postgres@localhost:5432/" +TEST_DB_CONFIG = DBConfig(TEST_DB_NAME, TEST_DB_URL_BASE) diff --git a/tests/unit/test_secrets_service.py b/tests/unit/test_secrets_service.py new file mode 100644 index 000000000..762a302e8 --- /dev/null +++ b/tests/unit/test_secrets_service.py @@ -0,0 +1,214 @@ +import uuid +from collections.abc import AsyncGenerator + +import pytest +from pydantic import SecretStr +from sqlmodel.ext.asyncio.session import AsyncSession + +from tracecat.config import TRACECAT__DEFAULT_ORG_ID +from tracecat.db.schemas import Workspace +from tracecat.secrets.enums import SecretType +from tracecat.secrets.models import ( + SecretCreate, + SecretKeyValue, + SecretSearch, + SecretUpdate, +) +from tracecat.secrets.service import SecretsService +from tracecat.types.auth import Role +from tracecat.types.exceptions import TracecatNotFoundError + +pytestmark = pytest.mark.usefixtures("db") + + +@pytest.fixture +async def workspace( + session: AsyncSession, +) -> AsyncGenerator[Workspace, None]: + """Create a test workspace.""" + workspace = Workspace( + name="test-workspace", + owner_id=TRACECAT__DEFAULT_ORG_ID, + ) # type: ignore + session.add(workspace) + await session.commit() + yield workspace + await session.delete(workspace) + await session.commit() + + +@pytest.fixture +async def role(workspace: Workspace) -> Role: + """Create a test role.""" + role = Role( + type="user", + workspace_id=workspace.id, + user_id=uuid.uuid4(), + service_id="tracecat-api", + ) + return role + + +@pytest.fixture +async def service(session: AsyncSession, role: Role) -> SecretsService: + """Create a secrets service instance for testing.""" + return SecretsService(session=session, role=role) + + +@pytest.fixture +def secret_create_params() -> SecretCreate: + """Sample secret creation parameters.""" + return SecretCreate( + name="test-secret", + type=SecretType.SSH_KEY, + description="Test secret", + tags={"test": "test"}, + keys=[SecretKeyValue(key="private_key", value=SecretStr("test-private-key"))], + environment="test", + ) + + +@pytest.mark.anyio +class TestSecretsService: + async def test_create_and_get_secret( + self, service: SecretsService, secret_create_params: SecretCreate + ) -> None: + """Test creating and retrieving a secret.""" + # Create secret + await service.create_secret(secret_create_params) + + # Retrieve by name + secret = await service.get_secret_by_name( + secret_create_params.name, raise_on_error=True + ) + assert secret is not None + assert secret.name == secret_create_params.name + assert secret.type == secret_create_params.type + assert secret.description == secret_create_params.description + assert secret.tags == secret_create_params.tags + assert secret.environment == secret_create_params.environment + + # Verify decrypted keys + decrypted_keys = service.decrypt_keys(secret.encrypted_keys) + assert len(decrypted_keys) == 1 + assert decrypted_keys[0].key == secret_create_params.keys[0].key + assert decrypted_keys[0].value == secret_create_params.keys[0].value + + async def test_update_secret( + self, service: SecretsService, secret_create_params: SecretCreate + ) -> None: + """Test updating a secret.""" + # Create initial secret + await service.create_secret(secret_create_params) + + # Update parameters + update_params = SecretUpdate( + description="Updated description", + keys=[SecretKeyValue(key="new_key", value=SecretStr("new_value"))], + ) + + # Update secret + await service.update_secret_by_name(secret_create_params.name, update_params) + + # Verify updates + updated_secret = await service.get_secret_by_name( + secret_create_params.name, raise_on_error=True + ) + assert updated_secret.description == update_params.description + decrypted_keys = service.decrypt_keys(updated_secret.encrypted_keys) + assert len(decrypted_keys) == 1 + assert decrypted_keys[0].key == "new_key" + assert decrypted_keys[0].value.get_secret_value() == "new_value" + + async def test_delete_secret( + self, service: SecretsService, secret_create_params: SecretCreate + ) -> None: + """Test deleting a secret.""" + # Create secret + await service.create_secret(secret_create_params) + + # Get secret to obtain ID + secret = await service.get_secret_by_name( + secret_create_params.name, raise_on_error=True + ) + assert secret is not None + + # Delete secret + await service.delete_secret_by_id(secret.id) + + # Verify deletion + deleted_secret = await service.get_secret_by_name( + secret_create_params.name, raise_on_error=False + ) + assert deleted_secret is None + + async def test_list_secrets( + self, service: SecretsService, secret_create_params: SecretCreate + ) -> None: + """Test listing secrets.""" + # Create multiple secrets + await service.create_secret(secret_create_params) + + second_secret = SecretCreate( + name="test-secret-2", + type=SecretType.CUSTOM, + description="Second test secret", + tags={"test": "test"}, + keys=[SecretKeyValue(key="api_key", value=SecretStr("test-api-key"))], + environment="test", + ) + await service.create_secret(second_secret) + + # List all secrets + secrets = await service.list_secrets() + assert len(secrets) >= 2 + + # List secrets by type + api_secrets = await service.list_secrets(types={SecretType.CUSTOM}) + assert len(api_secrets) >= 1 + assert all(s.type == SecretType.CUSTOM for s in api_secrets) + + async def test_get_ssh_key( + self, service: SecretsService, secret_create_params: SecretCreate + ) -> None: + """Test retrieving SSH key.""" + # Create SSH key secret + await service.create_secret(secret_create_params) + + # Retrieve SSH key + ssh_key = await service.get_ssh_key(secret_create_params.name) + assert ssh_key.key == "private_key" + assert ssh_key.value.get_secret_value() == "test-private-key" + + async def test_get_nonexistent_ssh_key(self, service: SecretsService) -> None: + """Test retrieving non-existent SSH key.""" + with pytest.raises(TracecatNotFoundError): + await service.get_ssh_key("nonexistent-key") + + async def test_search_secrets( + self, service: SecretsService, secret_create_params: SecretCreate + ) -> None: + """Test searching secrets.""" + # Create a secret + await service.create_secret(secret_create_params) + + # Search by name + found_secrets = await service.search_secrets( + params=SecretSearch( + names={secret_create_params.name}, + environment=secret_create_params.environment, + ) + ) + assert len(found_secrets) == 1 + assert found_secrets[0].name == secret_create_params.name + + # Search by environment + env_secrets = await service.search_secrets( + params=SecretSearch( + environment=secret_create_params.environment, + ) + ) + assert len(env_secrets) >= 1 + assert all( + s.environment == secret_create_params.environment for s in env_secrets + )