simplify seq test util function #1

Merged
merged 4 commits on Nov 14, 2024
Changes from all commits
87 changes: 72 additions & 15 deletions tests/conftest.py
@@ -1,12 +1,16 @@
from __future__ import annotations

import logging
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import pytest
from asgiref.sync import sync_to_async
from django.conf import settings
from model_bakery import baker

from django_github_app.github import AsyncGitHubAPI

from .settings import DEFAULT_SETTINGS
from .utils import seq

@@ -52,37 +56,90 @@ def pytest_configure(config):


@pytest.fixture
def id_sequence_start():
return 1000
def installation_id():
return seq.next()


@pytest.fixture
def installation_id(id_sequence_start):
return seq(id_sequence_start)
def repository_id():
return seq.next()


@pytest.fixture
def installation_id_iter(id_sequence_start):
return seq.iter(id_sequence_start)
def installation():
return baker.make("django_github_app.Installation", installation_id=seq.next())


@pytest.fixture
def repository_id(id_sequence_start):
return seq(id_sequence_start)
async def ainstallation():
return await sync_to_async(baker.make)(
"django_github_app.Installation", installation_id=seq.next()
)


@pytest.fixture
def repository_id_iter(id_sequence_start):
return seq.iter(id_sequence_start)
def mock_github_api():
mock_api = AsyncMock(spec=AsyncGitHubAPI)

async def mock_getiter(*args, **kwargs):
test_issues = [
{
"number": 1,
"title": "Test Issue 1",
"state": "open",
},
{
"number": 2,
"title": "Test Issue 2",
"state": "closed",
},
]
for issue in test_issues:
yield issue

mock_api.getiter = mock_getiter
mock_api.__aenter__.return_value = mock_api
mock_api.__aexit__.return_value = None

return mock_api


@pytest.fixture
def installation(installation_id):
return baker.make("django_github_app.Installation", installation_id=installation_id)
def repository(installation, mock_github_api):
repository = baker.make(
"django_github_app.Repository",
repository_id=seq.next(),
full_name="owner/repo",
installation=installation,
)

mock_github_api.installation_id = repository.installation.installation_id

if isinstance(repository, list):
for repo in repository:
repo.get_gh_client = MagicMock(mock_github_api)
else:
repository.get_gh_client = MagicMock(return_value=mock_github_api)

return repository


@pytest.fixture
async def ainstallation(installation_id):
return await sync_to_async(baker.make)(
"django_github_app.Installation", installation_id=installation_id
async def arepository(ainstallation, mock_github_api):
installation = await ainstallation
repository = await sync_to_async(baker.make)(
"django_github_app.Repository",
repository_id=seq.next(),
full_name="owner/repo",
installation=installation,
)

mock_github_api.installation_id = repository.installation.installation_id

if isinstance(repository, list):
for repo in repository:
repo.get_gh_client = MagicMock(mock_github_api)
else:
repository.get_gh_client = MagicMock(return_value=mock_github_api)

return repository
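
For orientation, here is a brief, hypothetical test module (not part of this PR) showing how the simplified seq helper and the shared mock_github_api fixture defined above might be exercised; the test names and the URL passed to getiter are illustrative only.

# Hypothetical test module -- illustrates fixture usage; not part of the PR.
import pytest

from .utils import seq


def test_seq_values_never_collide():
    # seq.next() hands out values from one shared, thread-safe counter,
    # so two consecutive calls always differ.
    assert seq.next() != seq.next()


@pytest.mark.asyncio
async def test_mock_github_api_yields_canned_issues(mock_github_api):
    # conftest.py's mock_github_api fixture replaces getiter with an async
    # generator yielding two canned issues; the URL passed here is arbitrary.
    issues = [
        issue async for issue in mock_github_api.getiter("/repos/owner/repo/issues")
    ]
    assert [issue["number"] for issue in issues] == [1, 2]
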
8 changes: 4 additions & 4 deletions tests/events/test_installation.py
@@ -23,7 +23,7 @@ async def test_create_installation(installation_id, repository_id):
data = {
"installation": {
"id": installation_id,
"app_id": seq(1000),
"app_id": seq.next(),
},
"repositories": [
{"id": repository_id, "node_id": "node1234", "full_name": "owner/repo"}
@@ -107,12 +107,12 @@ async def test_sync_installation_data(ainstallation):
assert installation.data == data["installation"]


async def test_sync_installation_repositories(ainstallation, repository_id_iter):
async def test_sync_installation_repositories(ainstallation):
installation = await ainstallation
existing_repo = await sync_to_async(baker.make)(
"django_github_app.Repository",
installation=installation,
repository_id=next(repository_id_iter),
repository_id=seq.next(),
)

data = {
@@ -126,7 +126,7 @@ async def test_sync_installation_repositories(ainstallation, repository_id_iter)
],
"repositories_added": [
{
"id": next(repository_id_iter),
"id": seq.next(),
"node_id": "repo1234",
"full_name": "owner/repo",
}
307 changes: 9 additions & 298 deletions tests/test_models.py
@@ -1,8 +1,6 @@
from __future__ import annotations

import datetime
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import pytest
from asgiref.sync import sync_to_async
@@ -25,119 +23,11 @@
@pytest.fixture
def create_event():
def _create_event(data, event):
return sansio.Event(data=data, event=event, delivery_id=seq("delivery-"))
return sansio.Event(data=data, event=event, delivery_id=seq.next())

return _create_event


@pytest.fixture
def create_installation(installation_id):
def _create_installation(**kwargs):
return baker.make(
"django_github_app.Installation",
installation_id=kwargs.pop("installation_id", installation_id),
**kwargs,
)

return _create_installation


@pytest.fixture
def acreate_installation(installation_id):
async def _acreate_installation(**kwargs):
return await sync_to_async(baker.make)(
"django_github_app.Installation",
installation_id=kwargs.pop("installation_id", installation_id),
**kwargs,
)

return _acreate_installation


@pytest.fixture
def mock_github_api():
mock_api = AsyncMock(spec=AsyncGitHubAPI)

async def mock_getiter(*args, **kwargs):
test_issues = [
{
"number": 1,
"title": "Test Issue 1",
"state": "open",
},
{
"number": 2,
"title": "Test Issue 2",
"state": "closed",
},
]
for issue in test_issues:
yield issue

mock_api.getiter = mock_getiter
mock_api.__aenter__.return_value = mock_api
mock_api.__aexit__.return_value = None

return mock_api


@pytest.fixture
def create_repository(installation, mock_github_api, repository_id):
def _create_repository(**kwargs):
repository = baker.make(
"django_github_app.Repository",
repository_id=repository_id,
full_name=kwargs.pop("full_name", "owner/repo"),
installation=kwargs.pop("installation", installation),
**kwargs,
)

mock_github_api.installation_id = repository.installation.installation_id

if isinstance(repository, list):
for repo in repository:
repo.get_gh_client = MagicMock(mock_github_api)
else:
repository.get_gh_client = MagicMock(return_value=mock_github_api)
return repository

return _create_repository


@pytest.fixture
def acreate_repository(ainstallation, mock_github_api, repository_id):
async def _acreate_repository(**kwargs):
repository = await sync_to_async(baker.make)(
"django_github_app.Repository",
repository_id=repository_id,
full_name=kwargs.pop("full_name", "owner/repo"),
installation=kwargs.pop("installation", await ainstallation),
**kwargs,
)

mock_github_api.installation_id = repository.installation.installation_id

if isinstance(repository, list):
for repo in repository:
repo.get_gh_client = MagicMock(mock_github_api)
else:
repository.get_gh_client = MagicMock(return_value=mock_github_api)

return repository

return _acreate_repository


@pytest.fixture
def repository(create_repository):
return create_repository()


@pytest.fixture
async def arepository(acreate_repository):
return await acreate_repository()


class TestEventLogManager:
@pytest.mark.asyncio
async def test_acreate_from_event(self, create_event):
@@ -221,12 +111,12 @@ class TestInstallationManager:
@pytest.mark.asyncio
async def test_acreate_from_event(self, create_event):
repositories = [
{"id": seq(1000), "node_id": "node1", "full_name": "owner/repo1"},
{"id": seq(1000), "node_id": "node2", "full_name": "owner/repo2"},
{"id": seq.next(), "node_id": "node1", "full_name": "owner/repo1"},
{"id": seq.next(), "node_id": "node2", "full_name": "owner/repo2"},
]
installation_data = {
"id": seq(1000),
"app_id": seq(1000),
"id": seq.next(),
"app_id": seq.next(),
}
event = create_event(
{
@@ -249,12 +139,12 @@ async def test_acreate_from_event(self, create_event):

def test_create_from_event(self, create_event):
repositories = [
{"id": 1, "node_id": "node1", "full_name": "owner/repo1"},
{"id": 2, "node_id": "node2", "full_name": "owner/repo2"},
{"id": seq.next(), "node_id": "node1", "full_name": "owner/repo1"},
{"id": seq.next(), "node_id": "node2", "full_name": "owner/repo2"},
]
installation_data = {
"id": seq(1000),
"app_id": seq(1000),
"id": seq.next(),
"app_id": seq.next(),
}
event = create_event(
{
@@ -327,185 +217,6 @@ def test_from_event_invalid_action(self, create_event):
InstallationStatus.from_event(event)


# class TestInstallation:
# TEST_CASES_STATUS = [
# (InstallationStatus.ACTIVE, "created", InstallationStatus.ACTIVE),
# (InstallationStatus.INACTIVE, "created", InstallationStatus.ACTIVE),
# (InstallationStatus.ACTIVE, "deleted", InstallationStatus.INACTIVE),
# (InstallationStatus.INACTIVE, "deleted", InstallationStatus.INACTIVE),
# (
# InstallationStatus.ACTIVE,
# "new_permissions_accepted",
# InstallationStatus.ACTIVE,
# ),
# (
# InstallationStatus.INACTIVE,
# "new_permissions_accepted",
# InstallationStatus.ACTIVE,
# ),
# (InstallationStatus.ACTIVE, "suspend", InstallationStatus.INACTIVE),
# (InstallationStatus.INACTIVE, "suspend", InstallationStatus.INACTIVE),
# (InstallationStatus.ACTIVE, "unsuspend", InstallationStatus.ACTIVE),
# (InstallationStatus.INACTIVE, "unsuspend", InstallationStatus.ACTIVE),
# ]
#
# @pytest.fixture
# def installation_event(self, create_event):
# def _installation_event(action, installation_id):
# return create_event(
# {"action": action, "installation": {"id": installation_id}},
# "installation",
# )
#
# return _installation_event
#
# @pytest.mark.parametrize("status,action,expected", TEST_CASES_STATUS)
# @pytest.mark.asyncio
# async def test_atoggle_status_from_event(
# self, status, action, expected, acreate_installation, installation_event
# ):
# installation = await acreate_installation(status=status)
# event = installation_event(action, installation.installation_id)
#
# assert installation.status == status
#
# await installation.atoggle_status_from_event(event)
#
# assert installation.status == expected
#
# @pytest.mark.parametrize("status,action,expected", TEST_CASES_STATUS)
# def test_toggle_status_from_event(
# self, status, action, expected, create_installation, installation_event
# ):
# installation = create_installation(status=status)
# event = installation_event(action, installation.installation_id)
#
# assert installation.status == status
#
# installation.toggle_status_from_event(event)
#
# assert installation.status == expected
#
# @pytest.mark.asyncio
# async def test_async_data_from_event(self, ainstallation, create_event):
# data = {"installation": {"foo": "bar"}}
# event = create_event(data, "installation")
# installation = await ainstallation
#
# await installation.async_data_from_event(event)
#
# assert installation.data == data["installation"]
#
# def test_sync_data_from_event(self, installation, create_event):
# data = {"installation": {"foo": "bar"}}
# event = create_event(data, "installation")
#
# installation.sync_data_from_event(event)
#
# assert installation.data == data["installation"]
#
# @pytest.mark.asyncio
# async def test_async_repositories_from_event(
# self, ainstallation, acreate_repository, create_event
# ):
# installation = await ainstallation
#
# removed_repos = [
# {
# "id": repo.repository_id,
# "node_id": repo.repository_node_id,
# "full_name": repo.full_name,
# }
# for repo in await acreate_repository(
# installation=installation,
# repository_id=itertools.cycle(seq.iter(1000)),
# _quantity=2,
# )
# ]
# added_repos = [
# {
# "id": i,
# "node_id": f"node{i}",
# "full_name": f"owner/repo{i}",
# }
# for i in itertools.islice(seq.iter(1000), 2)
# ]
#
# event = create_event(
# {
# "repositories_removed": removed_repos,
# "repositories_added": added_repos,
# },
# "installation",
# )
#
# await installation.async_repositories_from_event(event)
#
# remaining = await Repository.objects.filter(
# repository_id__in=[r["id"] for r in removed_repos]
# ).acount()
# assert remaining == 0
#
# new_repos = await Repository.objects.filter(
# repository_id__in=[r["id"] for r in added_repos]
# ).acount()
# assert new_repos == len(added_repos)
#
# installation_repos = await Repository.objects.filter(
# installation=installation
# ).acount()
# assert installation_repos == len(added_repos)
#
# def test_sync_repositories_from_event(
# self, installation, create_repository, create_event
# ):
# removed_repos = [
# {
# "id": repo.repository_id,
# "node_id": repo.repository_node_id,
# "full_name": repo.full_name,
# }
# for repo in create_repository(
# installation=installation,
# repository_id=itertools.cycle(seq.iter(1000)),
# _quantity=2,
# )
# ]
# added_repos = [
# {
# "id": i,
# "node_id": f"node{i}",
# "full_name": f"owner/repo{i}",
# }
# for i in itertools.islice(seq.iter(1000), 2)
# ]
#
# event = create_event(
# {
# "repositories_removed": removed_repos,
# "repositories_added": added_repos,
# },
# "installation",
# )
#
# installation.sync_repositories_from_event(event)
#
# remaining = Repository.objects.filter(
# repository_id__in=[r["id"] for r in removed_repos]
# )
# assert remaining.count() == 0
#
# new_repos = Repository.objects.filter(
# repository_id__in=[r["id"] for r in added_repos]
# ).count()
# assert new_repos == len(added_repos)
#
# installation_repos = Repository.objects.filter(
# installation=installation
# ).count()
# assert installation_repos == len(added_repos)


class TestRepositoryManager:
@pytest.mark.asyncio
async def test_aget_from_event(self, arepository, create_event):
348 changes: 145 additions & 203 deletions tests/test_utils.py
@@ -1,265 +1,207 @@
from __future__ import annotations

import asyncio
import datetime
import itertools
from threading import Thread
import random
import threading
import time

import pytest

from .utils import seq
from .utils import SequenceGenerator


@pytest.fixture(autouse=True)
def clear_seq_state():
seq._instances = {}
seq._locks = {}
yield
seq._instances = {}
seq._locks = {}
@pytest.fixture
def seq():
"""Fixture to reset the counter before each test."""
SequenceGenerator._instance = None
return SequenceGenerator()


def test_basic_number_sequence():
assert seq(1000) == 1001
assert seq(1000) == 1002
assert seq(1000) == 1003
def test_multithreading(seq):
num_threads = 100
increments_per_thread = 100

def worker():
for _ in range(increments_per_thread):
seq.next()

def test_string_sequence_with_suffix():
assert seq("User-", suffix="-test") == "User-1-test"
assert seq("User-", suffix="-test") == "User-2-test"
assert seq("User-", suffix="-test") == "User-3-test"


def test_custom_start():
assert seq("User-", start=10) == "User-10"
assert seq("User-", start=10) == "User-11"
assert seq("User-", start=10) == "User-12"


def test_custom_increment():
assert seq(1000, increment_by=10) == 1010
assert seq(1000, increment_by=10) == 1020
assert seq(1000, increment_by=10) == 1030


def test_datetime_sequence():
start_date = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
increment = datetime.timedelta(days=1)

assert seq(start_date, increment_by=increment) == datetime.datetime(
2024, 1, 2, tzinfo=datetime.timezone.utc
)
assert seq(start_date, increment_by=increment) == datetime.datetime(
2024, 1, 3, tzinfo=datetime.timezone.utc
)
threads = [threading.Thread(target=worker) for _ in range(num_threads)]

for thread in threads:
thread.start()
for thread in threads:
thread.join()

def test_date_sequence():
start = datetime.date(2024, 1, 1)
increment = datetime.timedelta(days=1)
assert seq.next() == num_threads * increments_per_thread + 1

assert seq(start, increment_by=increment) == datetime.date(2024, 1, 2)
assert seq(start, increment_by=increment) == datetime.date(2024, 1, 3)

@pytest.mark.asyncio
async def test_async(seq):
num_coroutines = 100
increments_per_coroutine = 100

def test_time_sequence():
start = datetime.time(12, 0)
increment = datetime.timedelta(hours=1)
async def worker():
for _ in range(increments_per_coroutine):
seq.next()

assert seq(start, increment_by=increment) == datetime.time(13, 0)
assert seq(start, increment_by=increment) == datetime.time(14, 0)
await asyncio.gather(*(worker() for _ in range(num_coroutines)))

assert seq.next() == num_coroutines * increments_per_coroutine + 1

def test_float_sequence():
assert seq(1.5, increment_by=0.5) == 2.0
assert seq(1.5, increment_by=0.5) == 2.5
assert seq(1.5, increment_by=0.5) == 3.0

@pytest.mark.asyncio
async def test_combined(seq):
num_threads = 50
num_coroutines = 50
increments_per_thread = 50
increments_per_coroutine = 50

def test_same_value_diff_increment_by():
assert seq(1000, increment_by=1) == 1001
assert seq(1000, increment_by=2) == 1002
assert seq(1000, increment_by=1) == 1002
assert seq(1000, increment_by=2) == 1004
def thread_worker():
for _ in range(increments_per_thread):
seq.next()

async def async_worker():
for _ in range(increments_per_coroutine):
seq.next()

def test_same_value_diff_suffix():
assert seq("User-", suffix="-test1") == "User-1-test1"
assert seq("User-", suffix="-test2") == "User-1-test2"
assert seq("User-", suffix="-test1") == "User-2-test1"
assert seq("User-", suffix="-test2") == "User-2-test2"
threads = [threading.Thread(target=thread_worker) for _ in range(num_threads)]

for thread in threads:
thread.start()

def test_invalid_suffix():
with pytest.raises(
TypeError, match="Sequences with suffix can only be used with text values"
):
seq(1000, suffix="-test")
await asyncio.gather(*(async_worker() for _ in range(num_coroutines)))

for thread in threads:
thread.join()

def test_invalid_datetime_increment():
start_date = datetime.datetime.now(datetime.timezone.utc)
with pytest.raises(TypeError, match="increment_by must be a datetime.timedelta"):
seq(start_date, increment_by=1)
expected_value = (num_threads * increments_per_thread) + (
num_coroutines * increments_per_coroutine
)
assert seq.next() == expected_value + 1


def test_safety_threads():
results: list[int] = []
num_threads = 50
iterations_per_thread = 100
def test_multithreading_with_sequence_validation(seq):
num_threads = 100
increments_per_thread = 100
results: set[int] = set()

def worker():
for _ in range(iterations_per_thread):
results.append(seq(1000))
for _ in range(increments_per_thread):
value = seq.next()
results.add(value)
# Random sleep to try to force race conditions
if random.random() < 0.1:
time.sleep(0.001)

threads = [Thread(target=worker) for _ in range(num_threads)]
threads = [threading.Thread(target=worker) for _ in range(num_threads)]

for t in threads:
t.start()
for t in threads:
t.join()
for thread in threads:
thread.start()
for thread in threads:
thread.join()

assert len(results) == num_threads * iterations_per_thread
assert len(set(results)) == len(results)
assert sorted(results) == list(range(1001, 1001 + len(results)))
expected_set = set(range(1, num_threads * increments_per_thread + 1))
assert results == expected_set
assert seq.next() == num_threads * increments_per_thread + 1


@pytest.mark.asyncio
async def test_safety_async():
results: list[int] = []
num_tasks = 50
iterations_per_task = 100
async def test_async_with_sequence_validation(seq):
num_coroutines = 100
increments_per_coroutine = 100
results: set[int] = set()

async def worker():
for _ in range(iterations_per_task):
results.append(seq(1000))
# Simulate some async work
await asyncio.sleep(0.001)

tasks = [asyncio.create_task(worker()) for _ in range(num_tasks)]
await asyncio.gather(*tasks)
for _ in range(increments_per_coroutine):
value = seq.next()
results.add(value)
# Random sleep to try to force race conditions
if random.random() < 0.1:
await asyncio.sleep(0.001)

assert len(results) == num_tasks * iterations_per_task
assert len(set(results)) == len(results)
assert sorted(results) == list(range(1001, 1001 + len(results)))
await asyncio.gather(*(worker() for _ in range(num_coroutines)))

expected_set = set(range(1, num_coroutines * increments_per_coroutine + 1))
assert results == expected_set
assert seq.next() == num_coroutines * increments_per_coroutine + 1

def test_multiple_sequences_threads():
results1 = []
results2 = []
num_threads = 20

def worker1():
results1.append(seq(1000))
@pytest.mark.asyncio
async def test_combined_with_sequence_validation(seq):
num_threads = 50
num_coroutines = 50
increments_per_thread = 50
increments_per_coroutine = 50
results: set[int] = set()

def worker2():
results2.append(seq(2000))
def thread_worker():
for _ in range(increments_per_thread):
value = seq.next()
results.add(value)
if random.random() < 0.1:
time.sleep(0.001)

threads = []
for _ in range(num_threads):
t1 = Thread(target=worker1)
t2 = Thread(target=worker2)
threads.extend([t1, t2])
t1.start()
t2.start()
async def async_worker():
for _ in range(increments_per_coroutine):
value = seq.next()
results.add(value)
if random.random() < 0.1:
await asyncio.sleep(0.001)

for t in threads:
t.join()
threads = [threading.Thread(target=thread_worker) for _ in range(num_threads)]

assert len(results1) == num_threads
assert len(set(results1)) == len(results1)
assert sorted(results1) == list(range(1001, 1001 + len(results1)))
for thread in threads:
thread.start()

assert len(results2) == num_threads
assert len(set(results2)) == len(results2)
assert sorted(results2) == list(range(2001, 2001 + len(results2)))
await asyncio.gather(*(async_worker() for _ in range(num_coroutines)))

for thread in threads:
thread.join()

@pytest.mark.asyncio
async def test_multiple_sequences_async():
results1 = []
results2 = []
num_tasks = 20
expected_total = (num_threads * increments_per_thread) + (
num_coroutines * increments_per_coroutine
)
expected_set = set(range(1, expected_total + 1))
assert results == expected_set
assert seq.next() == expected_total + 1

async def worker1():
results1.append(seq(1000))
# Simulate some async work
await asyncio.sleep(0.001)

async def worker2():
results2.append(seq(2000))
# Simulate some async work
await asyncio.sleep(0.001)
def test_large_scale(seq):
"""Test with a larger number of threads and increments"""
num_threads = 200
increments_per_thread = 1000
results: set[int] = set()

tasks = []
for _ in range(num_tasks):
tasks.append(asyncio.create_task(worker1()))
tasks.append(asyncio.create_task(worker2()))
def worker():
for _ in range(increments_per_thread):
value = seq.next()
results.add(value)

await asyncio.gather(*tasks)
threads = [threading.Thread(target=worker) for _ in range(num_threads)]

assert len(results1) == num_tasks
assert len(set(results1)) == len(results1)
assert sorted(results1) == list(range(1001, 1001 + len(results1)))
for thread in threads:
thread.start()
for thread in threads:
thread.join()

assert len(results2) == num_tasks
assert len(set(results2)) == len(results2)
assert sorted(results2) == list(range(2001, 2001 + len(results2)))
expected_set = set(range(1, num_threads * increments_per_thread + 1))
assert results == expected_set
assert seq.next() == num_threads * increments_per_thread + 1


@pytest.mark.asyncio
async def test_multiple_sequences_async_complex():
num_tasks = 20
results_number = []
results_string = []
results_date = []

start_date = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)

async def worker_number():
results_number.append(seq(1000))
await asyncio.sleep(0.001)

async def worker_string():
results_string.append(seq("User-", suffix="-test"))
await asyncio.sleep(0.001)

async def worker_date():
results_date.append(seq(start_date, increment_by=datetime.timedelta(days=1)))
await asyncio.sleep(0.001)

tasks = []
for _ in range(num_tasks):
tasks.extend(
[
asyncio.create_task(worker_number()),
asyncio.create_task(worker_string()),
asyncio.create_task(worker_date()),
]
)

await asyncio.gather(*tasks)

assert len(results_number) == num_tasks
assert len(set(results_number)) == len(results_number)
assert sorted(results_number) == list(range(1001, 1001 + len(results_number)))

assert len(results_string) == num_tasks
assert len(set(results_string)) == len(results_string)
assert sorted(results_string) == sorted(
[f"User-{i}-test" for i in range(1, num_tasks + 1)]
)
def test_error_conditions(seq):
"""Test error conditions and edge cases"""
seq2 = SequenceGenerator()
assert seq is seq2

assert len(results_date) == num_tasks
assert len(set(results_date)) == len(results_date)
assert sorted(results_date) == [
start_date + datetime.timedelta(days=i) for i in range(1, num_tasks + 1)
]
def create_instance():
SequenceGenerator()

threads = [threading.Thread(target=create_instance) for _ in range(100)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()

def test_sequence_iterator():
cycled = itertools.cycle(seq.iter(1))
assert next(cycled) == 2
assert next(cycled) == 3
assert next(cycled) == 4
instances = [SequenceGenerator() for _ in range(10)]
assert all(instance is seq for instance in instances)
224 changes: 17 additions & 207 deletions tests/utils.py
@@ -1,217 +1,27 @@
from __future__ import annotations

import datetime
import warnings
from collections.abc import Iterator
from threading import Lock
from typing import Any

EPOCH = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc)

class SequenceGenerator:
_instance = None
_lock = Lock()

class seq:
"""A thread-safe sequence generator that mimics model-bakery's seq functionality.
def __init__(self):
self._counter = 1

This class provides a way to generate sequential values for use in tests, particularly
with Django models and model-bakery. Unlike model-bakery's seq, this implementation
is thread-safe and works reliably in async/concurrent test environments.
def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._counter = 1
return cls._instance

The class maintains separate sequences for different parameter combinations using
class-level state, protected by locks for thread safety. It supports numbers,
strings, dates, times, and datetimes.
def next(self):
with self._lock:
current = self._counter
self._counter += 1
return current

Examples:
Simple number sequence:
>>> seq(1000)
1001
>>> seq(1000)
1002
>>> seq(1000)
1003

String sequence with suffix:
>>> seq("User-", suffix="-test")
'User-1-test'
>>> seq("User-", suffix="-test")
'User-2-test'
String sequence with custom start:
>>> seq("User-", start=10)
'User-10'
>>> seq("User-", start=10)
'User-11'
Number sequence with custom increment:
>>> seq(1000, increment_by=10)
1010
>>> seq(1000, increment_by=10)
1020
DateTime sequence:
>>> start_date = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
>>> first = seq(start_date, increment_by=datetime.timedelta(days=1))
>>> first.isoformat()
'2024-01-02T00:00:00+00:00'
>>> second = seq(start_date, increment_by=datetime.timedelta(days=1))
>>> second.isoformat()
'2024-01-03T00:00:00+00:00'
Date sequence:
>>> start = datetime.date(2024, 1, 1)
>>> first = seq(start, increment_by=datetime.timedelta(days=1))
>>> str(first)
'2024-01-02'
>>> second = seq(start, increment_by=datetime.timedelta(days=1))
>>> str(second)
'2024-01-03'
Time sequence:
>>> start = datetime.time(12, 0)
>>> first = seq(start, increment_by=datetime.timedelta(hours=1))
>>> str(first)
'13:00:00'
>>> second = seq(start, increment_by=datetime.timedelta(hours=1))
>>> str(second)
'14:00:00'
"""

_instances = {}
_locks = {}

def __init__(
self,
value: Any,
increment_by: int | float | datetime.timedelta = 1,
start: int | float | None = None,
suffix: str | None = None,
):
"""Initialize sequence parameters."""
self._validate_parameters(value, increment_by, start, suffix)
self.value = value
self.increment_by = increment_by
self.start = start
self.suffix = suffix
self._current = 0
self._increment = 0
self._base = None

def _validate_parameters(
self,
value: Any,
increment_by: int | float | datetime.timedelta,
start: int | float | None,
suffix: str | None,
) -> None:
"""Validate sequence parameters match model-bakery's requirements."""
if suffix and not isinstance(value, str):
raise TypeError("Sequences with suffix can only be used with text values")

if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
if not isinstance(increment_by, datetime.timedelta):
raise TypeError(
"Sequences with values datetime.datetime, datetime.date and datetime.time, "
"increment_by must be a datetime.timedelta."
)
if start:
warnings.warn(
"start parameter is ignored when using seq with date, time or datetime objects",
stacklevel=1,
)

def __new__(
cls,
value: Any,
increment_by: int | float | datetime.timedelta = 1,
start: int | float | None = None,
suffix: str | None = None,
):
key = (value, increment_by, start, suffix)

if key not in cls._locks:
cls._locks[key] = Lock()
cls._instances[key] = super().__new__(cls)

instance = cls._instances[key]

if not hasattr(instance, "_initialized"):
instance.__init__(value, increment_by, start, suffix)
instance._initialized = True
if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
instance._initialize_datetime_sequence(value, increment_by)
else:
instance._initialize_basic_sequence(value, increment_by, start)

with cls._locks[key]:
instance._current += instance._increment

if isinstance(value, (datetime.datetime, datetime.date)):
return instance._generate_datetime_value()
elif isinstance(value, datetime.time):
return instance._generate_time_value()
elif isinstance(instance._base, (int, float)):
return instance._generate_numeric_value()
else:
return instance._generate_text_value()

def _initialize_datetime_sequence(
self,
value: datetime.datetime | datetime.date | datetime.time,
increment_by: datetime.timedelta,
) -> None:
if isinstance(value, datetime.datetime):
date = value
elif isinstance(value, datetime.date):
date = datetime.datetime.combine(value, datetime.datetime.now().time())
else:
date = datetime.datetime.combine(EPOCH.date(), value)

epoch = EPOCH.replace(tzinfo=date.tzinfo)
self._current = (date - epoch).total_seconds()
self._increment = increment_by.total_seconds()

def _initialize_basic_sequence(
self, value: Any, increment_by: Any, start: Any
) -> None:
self._current = 0 if start is None else start - increment_by
self._increment = increment_by
self._base = value

def _generate_time_value(self) -> datetime.time:
total_seconds = self._current % (24 * 3600)
hours = int(total_seconds // 3600)
minutes = int((total_seconds % 3600) // 60)
seconds = int(total_seconds % 60)
return datetime.time(hours, minutes, seconds)

def _generate_datetime_value(self) -> datetime.datetime | datetime.date:
tz = self.value.tzinfo if isinstance(self.value, datetime.datetime) else None
result = datetime.datetime.fromtimestamp(self._current, tz)

if isinstance(self.value, datetime.date) and not isinstance(
self.value, datetime.datetime
):
return result.date()
return result

def _generate_numeric_value(self) -> int | float:
if not isinstance(self._base, (int, float)):
raise ValueError("base must be a numeric type")
return self._base + self._current

def _generate_text_value(self) -> str:
value = [self._base, self._current]
if self.suffix:
value.append(self.suffix)
stringified_value = [str(v) for v in value]
return "".join(stringified_value)

@classmethod
def iter(
cls,
value: Any,
increment_by: int | float | datetime.timedelta = 1,
start: int | float | None = None,
suffix: str | None = None,
) -> Iterator[Any]:
while True:
yield cls(value, increment_by=increment_by, start=start, suffix=suffix)
seq = SequenceGenerator()
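
The rewritten tests/utils.py now boils down to the singleton above plus its module-level instance. A minimal doctest-style sketch of the resulting API (the values assume a fresh process in which nothing else has consumed the sequence; the import path simply mirrors the tests/ layout shown in this PR):

>>> from tests.utils import SequenceGenerator, seq
>>> seq.next()
1
>>> seq.next()
2
>>> SequenceGenerator() is seq  # the class is a process-wide singleton
True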