Skip to content

Commit

Permalink
Completed all tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jack Deng <[email protected]>
  • Loading branch information
JackTheMico committed Aug 23, 2022
1 parent e955dc4 commit e0e7549
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 83 deletions.
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
# ruia-peewee-async
A Ruia plugin that uses the peewee-async to store data to MySQL

A [Ruia](https://github.com/howie6879/ruia) plugin that uses [peewee-async](https://github.com/05bit/peewee-async) to store data in MySQL, PostgreSQL, or both.


## Installation

```shell
pip install ruia-peewee-async
```

## Usage


```python

```
70 changes: 34 additions & 36 deletions ruia_peewee_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,22 @@
from peewee import DoesNotExist, Model, Query
from peewee_async import Manager, MySQLDatabase, PostgresqlDatabase
from pymysql import OperationalError
from ruia import Spider
from ruia import Spider as RuiaSpider
from ruia.exceptions import SpiderHookError


class Spider(RuiaSpider):
    """Ruia spider that declares the attributes used by this plugin.

    The class adds no behaviour of its own; it only annotates the
    per-backend attributes so that code receiving a spider instance can
    rely on their names and types.
    """

    # MySQL backend: the target model (a peewee ``Model`` or a dict
    # description of one) and the async manager used to talk to it.
    mysql_model: Union[Model, Dict]
    mysql_manager: Manager
    # PostgreSQL backend: same pair of attributes as for MySQL.
    postgres_model: Union[Model, Dict]
    postgres_manager: Manager
    # Database objects for each backend.
    mysql_db: MySQLDatabase
    postgres_db: PostgresqlDatabase


class TargetDB(Enum):
    """Select which database backend(s) a result is written to.

    ``POSTGRES`` is the canonical member for value ``1``: the lower-cased
    member name is used to build attribute names such as
    ``postgres_manager`` and ``postgres_model``, so its ``name`` must be
    exactly ``"POSTGRES"``.  ``POSTGRESQL`` is kept as an alias of the same
    member so that code written against the earlier spelling keeps working.
    """

    MYSQL = 0
    POSTGRES = 1
    # Alias: declared after POSTGRES so that TargetDB(1).name stays "POSTGRES".
    POSTGRESQL = 1
    BOTH = 2


Expand All @@ -32,7 +41,7 @@ async def process(spider_ins, callback_result):
try:
if database == TargetDB.MYSQL:
await spider_ins.mysql_manager.create(spider_ins.mysql_model, **data)
elif database == TargetDB.POSTGRESQL:
elif database == TargetDB.POSTGRES:
await spider_ins.postgres_manager.create(
spider_ins.postgres_model, **data
)
Expand Down Expand Up @@ -69,42 +78,31 @@ def __init__(
self.only = only

@staticmethod
async def _update(spider_ins, data, database, query, create_when_not_exists, only):
if database == TargetDB.MYSQL:
async def _deal_update(
spider_ins, data, query, create_when_not_exists, only, databases
):
for database in databases:
database = database.lower()
manager: Manager = getattr(spider_ins, f"{database}_manager")
model: Model = getattr(spider_ins, f"{database}_model")
try:
model_ins = await spider_ins.mysql_manager.get(
spider_ins.mysql_model, **query
)
model_ins = await manager.get(model, **query)
except DoesNotExist:
if create_when_not_exists:
await spider_ins.mysql_manager.create(
spider_ins.mysql_model, **data
)
await manager.create(model, **data)
else:
model_ins.__data__.update(data)
await spider_ins.mysql_manager.update(model_ins, only=only)
elif database == TargetDB.POSTGRESQL:
model_ins, created = await spider_ins.postgres_manager.get_or_create(
spider_ins.postgres_model, query, defaults=data
)
if not created:
model_ins.__data__.update(data)
await spider_ins.postgres_manager.update(model_ins, only=only)
elif database == TargetDB.BOTH:
model_ins, created = await spider_ins.mysql_manager.get_or_create(
spider_ins.mysql_model, query, defaults=data
)
if not created:
model_ins.__data__.update(data)
await spider_ins.mysql_manager.update(model_ins, only=only)
model_ins, created = await spider_ins.postgres_manager.get_or_create(
spider_ins.postgres_model, query, defaults=data
)
if not created:
model_ins.__data__.update(data)
await spider_ins.postgres_manager.update(model_ins, only=only)
await manager.update(model_ins, only=only)

@staticmethod
async def _update(spider_ins, data, query, database, create_when_not_exists, only):
if database == TargetDB.BOTH:
databases = [TargetDB.MYSQL.name, TargetDB.POSTGRES.name]
else:
raise ValueError(f"TargetDB Enum value error: {database}")
databases = [database.name]
await RuiaPeeweeUpdate._deal_update(
spider_ins, data, query, create_when_not_exists, only, databases
)

@staticmethod
async def process(spider_ins, callback_result):
Expand All @@ -123,7 +121,7 @@ async def process(spider_ins, callback_result):
)
try:
await RuiaPeeweeUpdate._update(
spider_ins, data, database, query, create_when_not_exists, only
spider_ins, data, query, database, create_when_not_exists, only
)
except OperationalError as ope:
spider_ins.logger.error(
Expand All @@ -134,8 +132,8 @@ async def process(spider_ins, callback_result):


def init_spider(*, spider_ins: Spider):
mysql_config = getattr(spider_ins, "mysql_config", None)
postgres_config = getattr(spider_ins, "postgres_config", None)
mysql_config = getattr(spider_ins, "mysql_config", {})
postgres_config = getattr(spider_ins, "postgres_config", {})
if (
(not mysql_config and not postgres_config)
or (mysql_config and not isinstance(mysql_config, dict))
Expand Down
50 changes: 44 additions & 6 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import typing

# -*- coding: utf-8 -*-
from peewee import CharField, Model
from ruia import AttrField, Item, Response, Spider, TextField
from ruia import AttrField, Item, Middleware, Response, TextField

from ruia_peewee_async import RuiaPeeweeInsert, RuiaPeeweeUpdate, Spider, TargetDB


class HackerNewsItem(Item):
Expand All @@ -23,7 +24,44 @@ async def parse(self, response: Response):
yield item


class ResultModel(Model):
class Insert(HackerNewsSpider):
    """Test spider that wraps every parsed item in a ``RuiaPeeweeInsert``."""

    def __init__(
        self,
        middleware: typing.Union[typing.Iterable, Middleware] = None,
        loop=None,
        is_async_start: bool = False,
        cancel_tasks: bool = True,
        target_db: TargetDB = TargetDB.MYSQL,
        **spider_kwargs,
    ):
        """Remember the target database, then initialise the base spider.

        Args:
            middleware: Forwarded unchanged to the parent initialiser.
            loop: Forwarded unchanged to the parent initialiser.
            is_async_start: Forwarded unchanged to the parent initialiser.
            cancel_tasks: Forwarded unchanged to the parent initialiser.
            target_db: Backend selector attached to every yielded insert.
            **spider_kwargs: Extra keyword arguments for the parent.
        """
        # The attribute is assigned before the parent initialiser runs.
        self.target_db = target_db
        base_args = (middleware, loop, is_async_start, cancel_tasks)
        super().__init__(*base_args, **spider_kwargs)

    async def parse(self, response):
        """Yield one ``RuiaPeeweeInsert`` per item produced by the parent."""
        parent_items = super().parse(response)
        async for parsed in parent_items:
            yield RuiaPeeweeInsert(parsed.results, database=self.target_db)


class Update(HackerNewsSpider):
def __init__(
self,
middleware: typing.Union[typing.Iterable, Middleware] = None,
loop=None,
is_async_start: bool = False,
cancel_tasks: bool = True,
target_db: TargetDB = TargetDB.MYSQL,
**spider_kwargs,
):
self.target_db = target_db
super().__init__(
middleware, loop, is_async_start, cancel_tasks, **spider_kwargs
)

title = CharField()
url = CharField()
async def parse(self, response):
async for item in super().parse(response):
res = {}
res["title"] = item.results["title"]
res["url"] = "http://testing.com"
yield RuiaPeeweeUpdate(res, {"title": res["title"]}, self.target_db)
36 changes: 30 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
from logging import getLogger

import psycopg2
import pymysql
import pytest

logger = getLogger(__name__)


def check(mysql_config) -> bool:
def check_mysql(mysql_config) -> bool:
try:
connection = pymysql.connect(**mysql_config)
return connection.open
Expand All @@ -16,17 +17,40 @@ def check(mysql_config) -> bool:
return False


@pytest.fixture(scope="session")
@pytest.fixture(scope="class")
def mysql(docker_ip, docker_services):
port = docker_services.port_for("mysql", 3306)
mysql_config = {
"host": docker_ip,
"port": port,
"user": "root",
"user": "ruiamysql",
"password": "abc123",
"database": "ruiamysql",
}
mysql_info = f"docker_ip: {docker_ip}, port: {port}"
logger.info(mysql_info)
docker_services.wait_until_responsive(lambda: check(mysql_config), 300, 10)
docker_services.wait_until_responsive(lambda: check_mysql(mysql_config), 300, 10)
return mysql_config


def check_postgres(postgres_config) -> bool:
    """Report whether a PostgreSQL server accepts connections.

    Args:
        postgres_config: Keyword arguments passed to ``psycopg2.connect``.

    Returns:
        ``True`` when a connection could be opened and its status is
        ``STATUS_READY``; ``False`` when connecting raised
        ``psycopg2.OperationalError``.
    """
    try:
        conn = psycopg2.connect(**postgres_config)
    except psycopg2.OperationalError:
        logger.info("Waitting for PostgreSQL starting completed.")
        return False
    try:
        return conn.status == psycopg2.extensions.STATUS_READY
    finally:
        # This probe is called repeatedly while waiting for the server;
        # close the connection so each poll does not leak one.
        conn.close()


@pytest.fixture(scope="class")
def postgresql(docker_ip, docker_services):
    """Wait for the ``postgres`` docker service and return its settings.

    Args:
        docker_ip: Host address of the docker services.
        docker_services: Helper used to resolve the published port and to
            poll until the service responds.

    Returns:
        A dict with ``host``, ``port``, ``user``, ``password`` and
        ``database`` keys describing the PostgreSQL connection.
    """
    postgres_config = {
        "host": docker_ip,
        "port": docker_services.port_for("postgres", 5432),
        "user": "ruiapostgres",
        "password": "abc123",
        "database": "ruiapostgres",
    }

    def _is_ready():
        # Probe the server with the very settings handed to the tests.
        return check_postgres(postgres_config)

    # NOTE(review): 300 and 10 are presumably the timeout and the pause
    # between probes, in seconds - confirm against pytest-docker.
    docker_services.wait_until_responsive(_is_ready, 300, 10)
    return postgres_config
2 changes: 1 addition & 1 deletion tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ services:
ports:
- '54321:5432'
environment:
POSTGRES_USER: ruia_peewee_async
POSTGRES_USER: ruiapostgres
POSTGRES_PASSWORD: abc123
image: 'postgres:latest'
72 changes: 72 additions & 0 deletions tests/test_both.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
from random import randint

import pytest
from peewee import CharField

from ruia_peewee_async import TargetDB, init_spider

from .common import Insert, Update


class BothInsert(Insert):
    """Insert spider used by the tests that target both databases."""

    async def parse(self, response):
        """Re-yield, unchanged, everything the parent ``parse`` produces."""
        parent_results = super().parse(response)
        async for insert_request in parent_results:
            yield insert_request


class BothUpdate(Update):
    """Update spider used by the tests that target both databases."""

    async def parse(self, response):
        """Re-yield, unchanged, everything the parent ``parse`` produces."""
        parent_results = super().parse(response)
        async for update_request in parent_results:
            yield update_request


def basic_setup(mysql, postgresql):
    """Build an ``after_start`` hook that configures both backends.

    Args:
        mysql: MySQL connection settings stored as ``mysql_config``.
        postgresql: PostgreSQL connection settings stored as
            ``postgres_config``.

    Returns:
        A coroutine function taking the spider instance; it attaches the
        configs and model descriptions and then calls ``init_spider``.
    """

    def _model_spec(table_name):
        # Fresh field objects on every call so the two models never share
        # a field instance.
        return {
            "table_name": table_name,
            "title": CharField(),
            "url": CharField(),
        }

    async def init_after_start(spider_ins):
        spider_ins.mysql_config = mysql
        spider_ins.mysql_model = _model_spec("ruia_mysql_both")
        spider_ins.postgres_config = postgresql
        spider_ins.postgres_model = _model_spec("ruia_postgres_both")
        init_spider(spider_ins=spider_ins)

    return init_after_start


class TestBoth:
    """Insert and update tests that write to MySQL and PostgreSQL together."""

    @pytest.mark.dependency()
    async def test_both_insert(self, mysql, postgresql, event_loop):
        """Run the insert spider and check that both tables hold rows."""
        after_start = basic_setup(mysql, postgresql)
        spider_ins = await BothInsert.async_start(
            loop=event_loop, after_start=after_start, target_db=TargetDB.BOTH
        )
        count_mysql = await spider_ins.mysql_manager.count(
            spider_ins.mysql_model.select()
        )
        count_postgres = await spider_ins.postgres_manager.count(
            spider_ins.postgres_model.select()
        )
        assert count_mysql >= 10, "Should insert 10 rows in MySQL."
        assert count_postgres >= 10, "Should insert 10 rows in PostgreSQL."

    @pytest.mark.dependency(depends=["TestBoth::test_both_insert"])
    async def test_both_update(self, mysql, postgresql, event_loop):
        """Run the update spider and spot-check one row in each database."""
        after_start = basic_setup(mysql, postgresql)
        spider_ins = await BothUpdate.async_start(
            loop=event_loop, after_start=after_start, target_db=TargetDB.BOTH
        )
        # randint is inclusive on both ends and the insert test only
        # guarantees ten rows, so sample ids from 1..10 to avoid asking for
        # a row that may not exist.
        mysql_one = await spider_ins.mysql_manager.get(
            spider_ins.mysql_model, id=randint(1, 10)
        )
        postgres_one = await spider_ins.postgres_manager.get(
            spider_ins.postgres_model, id=randint(1, 10)
        )
        assert mysql_one.url == "http://testing.com"
        assert postgres_one.url == "http://testing.com"
Loading

0 comments on commit e0e7549

Please sign in to comment.