Skip to content

Commit

Permalink
test: Add test database for service tests (#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt authored Dec 23, 2024
1 parent 74bb3d7 commit 33e427e
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ dev = [
"pytest-mock",
"pytest==8.3.2",
"python-dotenv",
"psycopg>=3.1.19",
"respx",
]

Expand Down
65 changes: 64 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
import json
import os
import uuid
from collections.abc import Iterator
from collections.abc import AsyncGenerator, Iterator
from pathlib import Path
from typing import Any

import httpx
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession

from tests.database import TEST_DB_CONFIG
from tracecat import config
from tracecat.contexts import ctx_role
from tracecat.db.engine import get_async_engine, get_async_session_context_manager
Expand All @@ -34,6 +39,7 @@ def monkeysession(request: pytest.FixtureRequest):

@pytest.fixture(autouse=True, scope="function")
async def test_db_engine():
"""Create a new engine for each integration test."""
engine = get_async_engine()
try:
yield engine
Expand All @@ -42,6 +48,63 @@ async def test_db_engine():
await engine.dispose()


@pytest.fixture(scope="session")
def db() -> Iterator[None]:
"""Session-scoped fixture to create and teardown test database using sync SQLAlchemy."""

default_engine = create_engine(
TEST_DB_CONFIG.sys_url_sync, isolation_level="AUTOCOMMIT"
)

termination_query = text(
f"""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{TEST_DB_CONFIG.test_db_name}'
AND pid <> pg_backend_pid();
"""
)

try:
with default_engine.connect() as conn:
# Terminate existing connections
conn.execute(termination_query)
# Create test database
conn.execute(text(f'CREATE DATABASE "{TEST_DB_CONFIG.test_db_name}"'))
logger.info("Created test database")

# Create sync engine for test db
test_engine = create_engine(TEST_DB_CONFIG.test_url_sync)
with test_engine.begin() as conn:
logger.info("Creating all tables")
SQLModel.metadata.create_all(conn)
yield
finally:
test_engine.dispose()
# # Cleanup - reconnect to system db to drop test db
with default_engine.begin() as conn:
conn.execute(termination_query)
conn.execute(
text(f'DROP DATABASE IF EXISTS "{TEST_DB_CONFIG.test_db_name}"')
)
logger.info("Dropped test database")
default_engine.dispose()


@pytest.fixture(scope="function")
async def session() -> AsyncGenerator[AsyncSession, None]:
"""Creates a new database session with (with working transaction)
for test duration. Use this for unit tests."""
async_engine = create_async_engine(TEST_DB_CONFIG.test_url)
async_session = AsyncSession(async_engine, expire_on_commit=False)
try:
await async_session.begin_nested()
yield async_session
finally:
await async_session.rollback() # Rollback any changes made during the test
await async_engine.dispose()


@pytest.fixture(autouse=True, scope="session")
def env_sandbox(monkeysession: pytest.MonkeyPatch):
from dotenv import load_dotenv
Expand Down
29 changes: 29 additions & 0 deletions tests/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import uuid
from dataclasses import dataclass


@dataclass(frozen=True)
class DBConfig:
test_db_name: str
base_url: str

@property
def test_url(self) -> str:
return f"{self.base_url}{self.test_db_name}"

@property
def test_url_sync(self) -> str:
return self.test_url.replace("+asyncpg", "+psycopg")

@property
def sys_url(self) -> str:
return f"{self.base_url}postgres"

@property
def sys_url_sync(self) -> str:
return self.sys_url.replace("+asyncpg", "+psycopg")


TEST_DB_NAME = f"test_db_{uuid.uuid4()}"
TEST_DB_URL_BASE = "postgresql+asyncpg://postgres:postgres@localhost:5432/"
TEST_DB_CONFIG = DBConfig(TEST_DB_NAME, TEST_DB_URL_BASE)
214 changes: 214 additions & 0 deletions tests/unit/test_secrets_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import uuid
from collections.abc import AsyncGenerator

import pytest
from pydantic import SecretStr
from sqlmodel.ext.asyncio.session import AsyncSession

from tracecat.config import TRACECAT__DEFAULT_ORG_ID
from tracecat.db.schemas import Workspace
from tracecat.secrets.enums import SecretType
from tracecat.secrets.models import (
SecretCreate,
SecretKeyValue,
SecretSearch,
SecretUpdate,
)
from tracecat.secrets.service import SecretsService
from tracecat.types.auth import Role
from tracecat.types.exceptions import TracecatNotFoundError

pytestmark = pytest.mark.usefixtures("db")


@pytest.fixture
async def workspace(
session: AsyncSession,
) -> AsyncGenerator[Workspace, None]:
"""Create a test workspace."""
workspace = Workspace(
name="test-workspace",
owner_id=TRACECAT__DEFAULT_ORG_ID,
) # type: ignore
session.add(workspace)
await session.commit()
yield workspace
await session.delete(workspace)
await session.commit()


@pytest.fixture
async def role(workspace: Workspace) -> Role:
"""Create a test role."""
role = Role(
type="user",
workspace_id=workspace.id,
user_id=uuid.uuid4(),
service_id="tracecat-api",
)
return role


@pytest.fixture
async def service(session: AsyncSession, role: Role) -> SecretsService:
"""Create a secrets service instance for testing."""
return SecretsService(session=session, role=role)


@pytest.fixture
def secret_create_params() -> SecretCreate:
"""Sample secret creation parameters."""
return SecretCreate(
name="test-secret",
type=SecretType.SSH_KEY,
description="Test secret",
tags={"test": "test"},
keys=[SecretKeyValue(key="private_key", value=SecretStr("test-private-key"))],
environment="test",
)


@pytest.mark.anyio
class TestSecretsService:
async def test_create_and_get_secret(
self, service: SecretsService, secret_create_params: SecretCreate
) -> None:
"""Test creating and retrieving a secret."""
# Create secret
await service.create_secret(secret_create_params)

# Retrieve by name
secret = await service.get_secret_by_name(
secret_create_params.name, raise_on_error=True
)
assert secret is not None
assert secret.name == secret_create_params.name
assert secret.type == secret_create_params.type
assert secret.description == secret_create_params.description
assert secret.tags == secret_create_params.tags
assert secret.environment == secret_create_params.environment

# Verify decrypted keys
decrypted_keys = service.decrypt_keys(secret.encrypted_keys)
assert len(decrypted_keys) == 1
assert decrypted_keys[0].key == secret_create_params.keys[0].key
assert decrypted_keys[0].value == secret_create_params.keys[0].value

async def test_update_secret(
self, service: SecretsService, secret_create_params: SecretCreate
) -> None:
"""Test updating a secret."""
# Create initial secret
await service.create_secret(secret_create_params)

# Update parameters
update_params = SecretUpdate(
description="Updated description",
keys=[SecretKeyValue(key="new_key", value=SecretStr("new_value"))],
)

# Update secret
await service.update_secret_by_name(secret_create_params.name, update_params)

# Verify updates
updated_secret = await service.get_secret_by_name(
secret_create_params.name, raise_on_error=True
)
assert updated_secret.description == update_params.description
decrypted_keys = service.decrypt_keys(updated_secret.encrypted_keys)
assert len(decrypted_keys) == 1
assert decrypted_keys[0].key == "new_key"
assert decrypted_keys[0].value.get_secret_value() == "new_value"

async def test_delete_secret(
self, service: SecretsService, secret_create_params: SecretCreate
) -> None:
"""Test deleting a secret."""
# Create secret
await service.create_secret(secret_create_params)

# Get secret to obtain ID
secret = await service.get_secret_by_name(
secret_create_params.name, raise_on_error=True
)
assert secret is not None

# Delete secret
await service.delete_secret_by_id(secret.id)

# Verify deletion
deleted_secret = await service.get_secret_by_name(
secret_create_params.name, raise_on_error=False
)
assert deleted_secret is None

async def test_list_secrets(
self, service: SecretsService, secret_create_params: SecretCreate
) -> None:
"""Test listing secrets."""
# Create multiple secrets
await service.create_secret(secret_create_params)

second_secret = SecretCreate(
name="test-secret-2",
type=SecretType.CUSTOM,
description="Second test secret",
tags={"test": "test"},
keys=[SecretKeyValue(key="api_key", value=SecretStr("test-api-key"))],
environment="test",
)
await service.create_secret(second_secret)

# List all secrets
secrets = await service.list_secrets()
assert len(secrets) >= 2

# List secrets by type
api_secrets = await service.list_secrets(types={SecretType.CUSTOM})
assert len(api_secrets) >= 1
assert all(s.type == SecretType.CUSTOM for s in api_secrets)

async def test_get_ssh_key(
self, service: SecretsService, secret_create_params: SecretCreate
) -> None:
"""Test retrieving SSH key."""
# Create SSH key secret
await service.create_secret(secret_create_params)

# Retrieve SSH key
ssh_key = await service.get_ssh_key(secret_create_params.name)
assert ssh_key.key == "private_key"
assert ssh_key.value.get_secret_value() == "test-private-key"

async def test_get_nonexistent_ssh_key(self, service: SecretsService) -> None:
"""Test retrieving non-existent SSH key."""
with pytest.raises(TracecatNotFoundError):
await service.get_ssh_key("nonexistent-key")

async def test_search_secrets(
self, service: SecretsService, secret_create_params: SecretCreate
) -> None:
"""Test searching secrets."""
# Create a secret
await service.create_secret(secret_create_params)

# Search by name
found_secrets = await service.search_secrets(
params=SecretSearch(
names={secret_create_params.name},
environment=secret_create_params.environment,
)
)
assert len(found_secrets) == 1
assert found_secrets[0].name == secret_create_params.name

# Search by environment
env_secrets = await service.search_secrets(
params=SecretSearch(
environment=secret_create_params.environment,
)
)
assert len(env_secrets) >= 1
assert all(
s.environment == secret_create_params.environment for s in env_secrets
)

0 comments on commit 33e427e

Please sign in to comment.