diff --git a/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py b/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py index 3cc845b21..f9f1ace4d 100644 --- a/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py +++ b/lyrebird/mock/extra_mock_server/lyrebird_proxy_protocol.py @@ -1,6 +1,7 @@ from aiohttp import web, client from typing import List, Set, Optional from urllib import parse as urlparse +import re class UnknownLyrebirdProxyProtocol(Exception): @@ -120,14 +121,19 @@ def protocol_read_from_query_2(self, request: web.Request, lb_config): if not proxy_host: return - # remove lyrebrid proxy protocol keys from query string origin_query_str = '' - for query_key, query_value in request.query.items(): - if query_key in ['proxyscheme', 'proxyhost', 'proxypath']: - continue - origin_query_str += f'&{query_key}={query_value}' - if len(origin_query_str) >= 1: - origin_query_str = '?'+origin_query_str[1:] + qs_index = request.path_qs.find('?') + if qs_index >= 0: + # query string to 2D array + # like a=1&b=2 ==> [(a, 1), (b, 2)] + raw_query_string = request.path_qs[qs_index+1:] + raw_query_array = re.split('\\&|\\=', raw_query_string) + raw_query_items = list(zip(raw_query_array[::2], raw_query_array[1::2])) + # remove lyrebrid proxy protocol keys from query string + raw_query_items = list(filter(lambda x: x[0] not in ['proxyscheme', + 'proxyhost', 'proxypath'], raw_query_items)) + # 2D array to query string + origin_query_str = '?'+'&'.join([f'{item[0]}={item[1]}' for item in raw_query_items]) origin_url = f'{proxy_scheme}://{urlparse.unquote(proxy_host)}{urlparse.unquote(proxy_path)}{origin_query_str}' diff --git a/lyrebird/mock/extra_mock_server/server.py b/lyrebird/mock/extra_mock_server/server.py index b1e1831a6..fcfaa6e68 100644 --- a/lyrebird/mock/extra_mock_server/server.py +++ b/lyrebird/mock/extra_mock_server/server.py @@ -23,7 +23,7 @@ def is_filtered(context: LyrebirdProxyContext): allow list like ''' global lb_config - filters = lb_config.get('proxy.filters') + filters = lb_config.get('proxy.filters', []) for _filter in filters: if re.search(_filter, context.origin_url): return True @@ -107,20 +107,26 @@ async def req_handler(request: web.Request): return web.Response(status=500, text=f'{e.__class__.__name__}') -async def _run_app(config): +def init_app(config): + global lb_config + lb_config = config + global logger log.init(config) logger = log.get_logger() - global lb_config - lb_config = config + app = web.Application() + app.router.add_route('*', r'/{path:(.*)}', req_handler) + + return app + + +async def _run_app(config): + app = init_app(config) port = config.get('extra.mock.port') port = port if port else 9999 - app = web.Application() - app.router.add_route('*', r'/{path:(.*)}', req_handler) - try: app_runner = web.AppRunner(app, auto_decompress=False) await app_runner.setup() diff --git a/lyrebird/version.py b/lyrebird/version.py index 191a0391e..0410cc02c 100644 --- a/lyrebird/version.py +++ b/lyrebird/version.py @@ -1,3 +1,3 @@ -IVERSION = (2, 10, 1) +IVERSION = (2, 10, 2) VERSION = ".".join(str(i) for i in IVERSION) LYREBIRD = "Lyrebird " + VERSION diff --git a/requirements.dev.txt b/requirements.dev.txt index 9868b929e..e1c0cf8fb 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -2,3 +2,4 @@ autopep8==1.7.0 pip pytest pytest-cov +pytest-aiohttp diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 000000000..2f4c80e30 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/tests/test_extra_mock_server.py b/tests/test_extra_mock_server.py new file mode 100644 index 000000000..e95247984 --- /dev/null +++ b/tests/test_extra_mock_server.py @@ -0,0 +1,51 @@ +from aiohttp import web +from lyrebird.mock.extra_mock_server.server import init_app + +config = { + "version": "1.6.8", + "proxy.filters": [ + ], + "mock.proxy_headers": { + "scheme": "MKScheme", + "host": "MKOriginHost", + "port": "MKOriginPort" + }} + + +async def test_proxy_args_in_path(aiohttp_client, loop): + app = init_app(config) + client = await aiohttp_client(app) + resp = await client.get('/http://www.bing.com') + assert resp.status == 200 + text = await resp.text() + assert 'bing' in text + + +async def test_proxy_args_in_headers(aiohttp_client, loop): + app = init_app(config) + client = await aiohttp_client(app) + resp = await client.get('/http://www.bing.com', headers={ + 'MKScheme': 'http', + 'MKOriginHost': 'www.bing.com' + }) + assert resp.status == 200 + text = await resp.text() + assert 'bing' in text + + +async def test_proxy_args_in_query_v1(aiohttp_client, loop): + app = init_app(config) + client = await aiohttp_client(app) + resp = await client.get('/?proxy=http%3A//www.bing.com') + assert resp.status == 200 + text = await resp.text() + assert 'bing' in text + + +async def test_proxy_args_in_query_v2(aiohttp_client, loop): + app = init_app(config) + client = await aiohttp_client(app) + resp = await client.get('/?proxyscheme=http&proxyhost=www.bing.com') + assert resp.status == 200 + text = await resp.text() + assert 'bing' in text