diff --git a/README.md b/README.md index 5163a8a..7e37288 100644 --- a/README.md +++ b/README.md @@ -121,18 +121,26 @@ There's a `create_model` method to create the Peewee model based on database con ```python from ruia_peewee_async import create_model -model = create_model(mysql=mysql) # or postgres=postgres or both +mysql_model, mysql_manager, postgres_model, postgres_manager = create_model(mysql=mysql) # or postgres=postgres or both # create the table at the same time -model = create_mode(postgres=postgres, create_table=True) -rows = model.select().count() +mysql_model, mysql_manager, postgres_model, postgres_manager = create_model(mysql=mysql, create_table=True) # or postgres=postgres or both +rows = mysql_model.select().count() print(rows) ``` And class `Spider` from `ruia_peewee_async` has attributes below related to database you can use. ```python from peewee import Model -from typing import Dict -from peewee_async import Manager, MySQLDatabase, PostgresqlDatabase +from typing import Callable, Dict +from typing import Optional as TOptional +from peewee_async import ( + AsyncQueryWrapper, + Manager, + MySQLDatabase, + PooledMySQLDatabase, + PooledPostgresqlDatabase, + PostgresqlDatabase, +) from ruia import Spider as RuiaSpider class Spider(RuiaSpider): @@ -142,6 +150,8 @@ class Spider(RuiaSpider): postgres_manager: Manager mysql_db: MySQLDatabase postgres_db: PostgresqlDatabase + mysql_filters: TOptional[AsyncQueryWrapper] + postgres_filters: TOptional[AsyncQueryWrapper] ``` For more information, check out [peewee's documentation](http://docs.peewee-orm.com/en/latest/) and [peewee-async's documentation](https://peewee-async.readthedocs.io/en/latest/). diff --git a/ruia_peewee_async/__init__.py b/ruia_peewee_async/__init__.py index 337cf13..8942b78 100644 --- a/ruia_peewee_async/__init__.py +++ b/ruia_peewee_async/__init__.py @@ -96,7 +96,17 @@ async def filter_func(data, spider_ins, database, manager, model, filters) -> bo filtered = False filter_res = getattr(spider_ins, f"{database}_filters") for fil in filters: - if data[fil] in [getattr(x, fil) for x in filter_res]: + fil_res = [getattr(x, fil) for x in filter_res] + if not fil_res: + continue + outfil = data[fil] + if not isinstance(outfil, type(fil_res[0])): + outfil = ( + filter_res[0] # pylint: disable=protected-access + ._meta.columns[fil] + .adapt(outfil) + ) + if outfil in fil_res: filtered = True return filtered @@ -374,7 +384,7 @@ def check_config(kwargs) -> Sequence[Dict]: "database": And(str), "model": And({"table_name": And(str), str: object}), Optional("port"): And(int), - Optional("ssl"): Use(SSLContext), + Optional("ssl"): And(SSLContext), Optional("pool"): And(bool), Optional("min_connections"): And( int, lambda mic: 1 <= mic <= 10 diff --git a/tests/test_config.py b/tests/test_config.py index cd5cf03..55255da 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- +import ssl from copy import deepcopy +from contextlib import contextmanager import pytest from peewee import ModelBase @@ -24,6 +26,14 @@ def docker_cleansup(): return False +@contextmanager +def not_raises(exception): + try: + yield + except exception as exc: + raise pytest.fail(f"DID RAISE {exception}") from exc + + class TestConfig: async def test_process_errconfig( self, event_loop @@ -252,3 +262,16 @@ async def test_pool_config( assert mysql_manager.database.max_connections == 20 assert postgres_manager.database.min_connections == 5 assert postgres_manager.database.max_connections == 20 + + async def test_config( + self, docker_setup, docker_cleanup, event_loop, mysql_config, postgres_config + ): # pylint: disable=redefined-outer-name,unused-argument,unknown-option-value + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.check_hostname = False + ctx.load_verify_locations(cafile="/etc/ssl/certs/ca-certificates.crt") + postgres_config["ssl"] = ctx + mysql_config["ssl"] = ctx + with not_raises(SchemaError): + after_start(mysql=mysql_config) + with not_raises(SchemaError): + after_start(postgres=postgres_config)