diff --git a/README.md b/README.md index f6c678e..c8afbf4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,17 @@ # ruia-peewee-async -A Ruia plugin that uses the peewee-async to store data to MySQL + +A [Ruia](https://github.com/howie6879/ruia) plugin that uses [peewee-async](https://github.com/05bit/peewee-async) to store data to MySQL or PostgreSQL or both. + + +## Installation + +```shell +pip install ruia-peewee-async +``` + +## Usage + + +```python + +``` diff --git a/ruia_peewee_async/__init__.py b/ruia_peewee_async/__init__.py index f5e11f3..960d1ed 100644 --- a/ruia_peewee_async/__init__.py +++ b/ruia_peewee_async/__init__.py @@ -6,13 +6,22 @@ from peewee import DoesNotExist, Model, Query from peewee_async import Manager, MySQLDatabase, PostgresqlDatabase from pymysql import OperationalError -from ruia import Spider +from ruia import Spider as RuiaSpider from ruia.exceptions import SpiderHookError +class Spider(RuiaSpider): + mysql_model: Union[Model, Dict] + mysql_manager: Manager + postgres_model: Union[Model, Dict] + postgres_manager: Manager + mysql_db: MySQLDatabase + postgres_db: PostgresqlDatabase + + class TargetDB(Enum): MYSQL = 0 - POSTGRESQL = 1 + POSTGRES = 1 BOTH = 2 @@ -32,7 +41,7 @@ async def process(spider_ins, callback_result): try: if database == TargetDB.MYSQL: await spider_ins.mysql_manager.create(spider_ins.mysql_model, **data) - elif database == TargetDB.POSTGRESQL: + elif database == TargetDB.POSTGRES: await spider_ins.postgres_manager.create( spider_ins.postgres_model, **data ) @@ -69,42 +78,31 @@ def __init__( self.only = only @staticmethod - async def _update(spider_ins, data, database, query, create_when_not_exists, only): - if database == TargetDB.MYSQL: + async def _deal_update( + spider_ins, data, query, create_when_not_exists, only, databases + ): + for database in databases: + database = database.lower() + manager: Manager = getattr(spider_ins, f"{database}_manager") + model: Model = getattr(spider_ins, f"{database}_model") try: - model_ins = await spider_ins.mysql_manager.get( - spider_ins.mysql_model, **query - ) + model_ins = await manager.get(model, **query) except DoesNotExist: if create_when_not_exists: - await spider_ins.mysql_manager.create( - spider_ins.mysql_model, **data - ) + await manager.create(model, **data) else: model_ins.__data__.update(data) - await spider_ins.mysql_manager.update(model_ins, only=only) - elif database == TargetDB.POSTGRESQL: - model_ins, created = await spider_ins.postgres_manager.get_or_create( - spider_ins.postgres_model, query, defaults=data - ) - if not created: - model_ins.__data__.update(data) - await spider_ins.postgres_manager.update(model_ins, only=only) - elif database == TargetDB.BOTH: - model_ins, created = await spider_ins.mysql_manager.get_or_create( - spider_ins.mysql_model, query, defaults=data - ) - if not created: - model_ins.__data__.update(data) - await spider_ins.mysql_manager.update(model_ins, only=only) - model_ins, created = await spider_ins.postgres_manager.get_or_create( - spider_ins.postgres_model, query, defaults=data - ) - if not created: - model_ins.__data__.update(data) - await spider_ins.postgres_manager.update(model_ins, only=only) + await manager.update(model_ins, only=only) + + @staticmethod + async def _update(spider_ins, data, query, database, create_when_not_exists, only): + if database == TargetDB.BOTH: + databases = [TargetDB.MYSQL.name, TargetDB.POSTGRES.name] else: - raise ValueError(f"TargetDB Enum value error: {database}") + databases = [database.name] + await RuiaPeeweeUpdate._deal_update( + spider_ins, data, query, create_when_not_exists, only, databases + ) @staticmethod async def process(spider_ins, callback_result): @@ -123,7 +121,7 @@ async def process(spider_ins, callback_result): ) try: await RuiaPeeweeUpdate._update( - spider_ins, data, database, query, create_when_not_exists, only + spider_ins, data, query, database, create_when_not_exists, only ) except OperationalError as ope: spider_ins.logger.error( @@ -134,8 +132,8 @@ async def process(spider_ins, callback_result): def init_spider(*, spider_ins: Spider): - mysql_config = getattr(spider_ins, "mysql_config", None) - postgres_config = getattr(spider_ins, "postgres_config", None) + mysql_config = getattr(spider_ins, "mysql_config", {}) + postgres_config = getattr(spider_ins, "postgres_config", {}) if ( (not mysql_config and not postgres_config) or (mysql_config and not isinstance(mysql_config, dict)) diff --git a/tests/common.py b/tests/common.py index fb6eee2..c00a2d4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- +import typing -# -*- coding: utf-8 -*- -from peewee import CharField, Model -from ruia import AttrField, Item, Response, Spider, TextField +from ruia import AttrField, Item, Middleware, Response, TextField + +from ruia_peewee_async import RuiaPeeweeInsert, RuiaPeeweeUpdate, Spider, TargetDB class HackerNewsItem(Item): @@ -23,7 +24,44 @@ async def parse(self, response: Response): yield item -class ResultModel(Model): +class Insert(HackerNewsSpider): + def __init__( + self, + middleware: typing.Union[typing.Iterable, Middleware] = None, + loop=None, + is_async_start: bool = False, + cancel_tasks: bool = True, + target_db: TargetDB = TargetDB.MYSQL, + **spider_kwargs, + ): + self.target_db = target_db + super().__init__( + middleware, loop, is_async_start, cancel_tasks, **spider_kwargs + ) + + async def parse(self, response): + async for item in super().parse(response): + yield RuiaPeeweeInsert(item.results, database=self.target_db) + + +class Update(HackerNewsSpider): + def __init__( + self, + middleware: typing.Union[typing.Iterable, Middleware] = None, + loop=None, + is_async_start: bool = False, + cancel_tasks: bool = True, + target_db: TargetDB = TargetDB.MYSQL, + **spider_kwargs, + ): + self.target_db = target_db + super().__init__( + middleware, loop, is_async_start, cancel_tasks, **spider_kwargs + ) - title = CharField() - url = CharField() + async def parse(self, response): + async for item in super().parse(response): + res = {} + res["title"] = item.results["title"] + res["url"] = "http://testing.com" + yield RuiaPeeweeUpdate(res, {"title": res["title"]}, self.target_db) diff --git a/tests/conftest.py b/tests/conftest.py index 8b8c5bd..fd50d5f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- from logging import getLogger +import psycopg2 import pymysql import pytest logger = getLogger(__name__) -def check(mysql_config) -> bool: +def check_mysql(mysql_config) -> bool: try: connection = pymysql.connect(**mysql_config) return connection.open @@ -16,17 +17,40 @@ def check(mysql_config) -> bool: return False -@pytest.fixture(scope="session") +@pytest.fixture(scope="class") def mysql(docker_ip, docker_services): port = docker_services.port_for("mysql", 3306) mysql_config = { "host": docker_ip, "port": port, - "user": "root", + "user": "ruiamysql", "password": "abc123", "database": "ruiamysql", } - mysql_info = f"docker_ip: {docker_ip}, port: {port}" - logger.info(mysql_info) - docker_services.wait_until_responsive(lambda: check(mysql_config), 300, 10) + docker_services.wait_until_responsive(lambda: check_mysql(mysql_config), 300, 10) return mysql_config + + +def check_postgres(postgres_config): + try: + conn = psycopg2.connect(**postgres_config) + return conn.status == psycopg2.extensions.STATUS_READY + except psycopg2.OperationalError: + logger.info("Waitting for PostgreSQL starting completed.") + return False + + +@pytest.fixture(scope="class") +def postgresql(docker_ip, docker_services): + port = docker_services.port_for("postgres", 5432) + postgres_config = { + "host": docker_ip, + "port": port, + "user": "ruiapostgres", + "password": "abc123", + "database": "ruiapostgres", + } + docker_services.wait_until_responsive( + lambda: check_postgres(postgres_config), 300, 10 + ) + return postgres_config diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 28ff374..4d602bb 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -15,6 +15,6 @@ services: ports: - '54321:5432' environment: - POSTGRES_USER: ruia_peewee_async + POSTGRES_USER: ruiapostgres POSTGRES_PASSWORD: abc123 image: 'postgres:latest' diff --git a/tests/test_both.py b/tests/test_both.py new file mode 100644 index 0000000..7be64e5 --- /dev/null +++ b/tests/test_both.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +from random import randint + +import pytest +from peewee import CharField + +from ruia_peewee_async import TargetDB, init_spider + +from .common import Insert, Update + + +class BothInsert(Insert): + async def parse(self, response): + async for item in super().parse(response): + yield item + + +class BothUpdate(Update): + async def parse(self, response): + async for item in super().parse(response): + yield item + + +def basic_setup(mysql, postgresql): + async def init_after_start(spider_ins): + spider_ins.mysql_config = mysql + spider_ins.mysql_model = { + "table_name": "ruia_mysql_both", + "title": CharField(), + "url": CharField(), + } + spider_ins.postgres_config = postgresql + spider_ins.postgres_model = { + "table_name": "ruia_postgres_both", + "title": CharField(), + "url": CharField(), + } + init_spider(spider_ins=spider_ins) + + return init_after_start + + +class TestBoth: + @pytest.mark.dependency() + async def test_both_insert(self, mysql, postgresql, event_loop): + after_start = basic_setup(mysql, postgresql) + spider_ins = await BothInsert.async_start( + loop=event_loop, after_start=after_start, target_db=TargetDB.BOTH + ) + count_mysql = await spider_ins.mysql_manager.count( + spider_ins.mysql_model.select() + ) + count_postgres = await spider_ins.postgres_manager.count( + spider_ins.postgres_model.select() + ) + assert count_mysql >= 10, "Should insert 10 rows in MySQL." + assert count_postgres >= 10, "Should insert 10 rows in PostgreSQL." + + @pytest.mark.dependency(depends=["TestBoth::test_both_insert"]) + async def test_both_update(self, mysql, postgresql, event_loop): + after_start = basic_setup(mysql, postgresql) + spider_ins = await BothUpdate.async_start( + loop=event_loop, after_start=after_start, target_db=TargetDB.BOTH + ) + mysql_one = await spider_ins.mysql_manager.get( + spider_ins.mysql_model, id=randint(1, 11) + ) + postgres_one = await spider_ins.postgres_manager.get( + spider_ins.postgres_model, id=randint(1, 11) + ) + assert mysql_one.url == "http://testing.com" + assert postgres_one.url == "http://testing.com" diff --git a/tests/test_mysql.py b/tests/test_mysql.py index 90c2fdb..21fc327 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -1,37 +1,31 @@ # -*- coding: utf-8 -*- -from logging import getLogger from random import randint import pytest from peewee import CharField -from ruia_peewee_async import RuiaPeeweeInsert, RuiaPeeweeUpdate, init_spider +from ruia_peewee_async import init_spider -from .common import HackerNewsSpider +from .common import Insert, Update -logger = getLogger(__name__) - -class MySQLInsert(HackerNewsSpider): +class MySQLInsert(Insert): async def parse(self, response): async for item in super().parse(response): - yield RuiaPeeweeInsert(item.results) + yield item -class MySQLUpdate(HackerNewsSpider): +class MySQLUpdate(Update): async def parse(self, response): async for item in super().parse(response): - res = {} - res["title"] = item.results["title"] - res["url"] = "http://testing.com" - yield RuiaPeeweeUpdate(res, {"title": res["title"]}) + yield item -def base_setup(mysql): +def basic_setup(mysql): async def init_after_start(spider_ins): spider_ins.mysql_config = mysql spider_ins.mysql_model = { - "table_name": "ruia_mysql_test", + "table_name": "ruia_mysql", "title": CharField(), "url": CharField(), } @@ -40,22 +34,23 @@ async def init_after_start(spider_ins): return init_after_start -@pytest.mark.dependency() -async def test_mysql_insert(mysql, event_loop): - after_start = base_setup(mysql) - spider_ins = await MySQLInsert.async_start(loop=event_loop, after_start=after_start) - count = await spider_ins.mysql_manager.count(spider_ins.mysql_model.select()) - one = await spider_ins.mysql_manager.get(spider_ins.mysql_model, id=randint(1, 11)) - one_msg = f"One data, title: {one.title}, url: {one.url}" - logger.info(one_msg) - assert count >= 10, "Should insert 10 rows in MySQL." - - -@pytest.mark.dependency(depends=["test_mysql_insert"]) -async def test_mysql_update(mysql, event_loop): - after_start = base_setup(mysql) - spider_ins = await MySQLUpdate.async_start(loop=event_loop, after_start=after_start) - one = await spider_ins.mysql_manager.get(spider_ins.mysql_model, id=randint(1, 11)) - one_msg = f"One data, title: {one.title}, url: {one.url}" - logger.info(one_msg) - assert one.url == "http://testing.com" +class TestMySQL: + @pytest.mark.dependency() + async def test_mysql_insert(self, mysql, event_loop): + after_start = basic_setup(mysql) + spider_ins = await MySQLInsert.async_start( + loop=event_loop, after_start=after_start + ) + count = await spider_ins.mysql_manager.count(spider_ins.mysql_model.select()) + assert count >= 10, "Should insert 10 rows in MySQL." + + @pytest.mark.dependency(depends=["TestMySQL::test_mysql_insert"]) + async def test_mysql_update(self, mysql, event_loop): + after_start = basic_setup(mysql) + spider_ins = await MySQLUpdate.async_start( + loop=event_loop, after_start=after_start + ) + one = await spider_ins.mysql_manager.get( + spider_ins.mysql_model, id=randint(1, 11) + ) + assert one.url == "http://testing.com" diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py new file mode 100644 index 0000000..5bbbec3 --- /dev/null +++ b/tests/test_postgresql.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +from random import randint + +import pytest +from peewee import CharField + +from ruia_peewee_async import TargetDB, init_spider + +from .common import Insert, Update + + +class PostgresqlInsert(Insert): + async def parse(self, response): + async for item in super().parse(response): + yield item + + +class PostgresqlUpdate(Update): + async def parse(self, response): + async for item in super().parse(response): + yield item + + +def basic_setup(postgresql): + async def init_after_start(spider_ins): + spider_ins.postgres_config = postgresql + spider_ins.postgres_model = { + "table_name": "ruia_postgres", + "title": CharField(), + "url": CharField(), + } + init_spider(spider_ins=spider_ins) + + return init_after_start + + +class TestPostgreSQL: + @pytest.mark.dependency() + async def test_postgres_insert(self, postgresql, event_loop): + after_start = basic_setup(postgresql) + spider_ins = await PostgresqlInsert.async_start( + loop=event_loop, after_start=after_start, target_db=TargetDB.POSTGRES + ) + count = await spider_ins.postgres_manager.count( + spider_ins.postgres_model.select() + ) + assert count >= 10, "Should insert 10 rows in PostgreSQL." + + @pytest.mark.dependency(depends=["TestPostgreSQL::test_postgres_insert"]) + async def test_postgres_update(self, postgresql, event_loop): + after_start = basic_setup(postgresql) + spider_ins = await PostgresqlUpdate.async_start( + loop=event_loop, after_start=after_start, target_db=TargetDB.POSTGRES + ) + one = await spider_ins.postgres_manager.get( + spider_ins.postgres_model, id=randint(1, 11) + ) + assert one.url == "http://testing.com"