diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 6972225..22deaad 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -5,7 +5,7 @@ jobs: name: base strategy: matrix: - python: [ '2.7', '3.5', '3.6', '3.6', '3.7'] + python: [ '2.7', '3.6', '3.6', '3.7', '3.8', '3.9'] # os: ['ubuntu-latest', 'windows-latest', 'macOs-latest'] os: ['ubuntu-latest', 'windows-latest'] @@ -20,6 +20,9 @@ jobs: run: pip install -U setuptools wheel - name: install run: pip install .[dev,ci] + - name: install async requirements + if: matrix.python != '2.7' + run: pip install trio curio - name: test run: python -m pytest --reruns 5 tests/ --cov oscpy/ --cov-branch - name: coveralls diff --git a/.gitignore b/.gitignore index 8316b71..188ef19 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,10 @@ *.c *.pyd *.egg-info +*.swp +*.swn +.coverage +htmlcov/ .pytest_cache build dist diff --git a/README.md b/README.md index ec66c64..0cacfcb 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,29 @@ with OSCAsyncServer(port=8000) as OSC: print("unknown address {}".format(address)) ``` +Server (curio) + +```python +async def osc_app(address, port): + osc = OSCCurioServer(encoding='utf8') + osc.listen(address=address, port=port, default=True) + + @osc.address("/example") + async def example(*values): + print(f"got {values} on /example") + await curio.sleep(4) + print("done sleeping") + + @osc.address("/stop") + async def stop(*values): + print(f"time to leave!") + await osc.stop() + + await osc.process() + +curio.run(osc_app, '0.0.0.0', 8000) +``` + Client ```python diff --git a/examples/asyncio_example.py b/examples/asyncio_example.py new file mode 100644 index 0000000..5d194a5 --- /dev/null +++ b/examples/asyncio_example.py @@ -0,0 +1,33 @@ +import asyncio + +from oscpy.server.asyncio_server import OSCAsyncioServer + + +async def osc_app(address, port): + osc = OSCAsyncioServer(encoding='utf8') + osc.listen(address=address, port=port, default=True) + sock2 = osc.listen(address=address, port=port + 1) + + @osc.address("/example") + async def example(*values): + print(f"got {values} on /example") + await asyncio.sleep(4) + print("done sleeping") + + @osc.address("/test") + async def test(*values): + print(f"got {values} on /test") + await asyncio.sleep(4) + print("done sleeping") + + @osc.address("/stop", sock=sock2) + async def stop(*values): + print(f"time to leave!") + osc.terminate_server() + + print(sock2.getsockname()) + asyncio.get_event_loop().create_task(osc.process()) + await osc.join_server() + + +asyncio.run(osc_app('localhost', 8000)) diff --git a/examples/curio_example.py b/examples/curio_example.py new file mode 100644 index 0000000..44c875d --- /dev/null +++ b/examples/curio_example.py @@ -0,0 +1,29 @@ +import curio + +from oscpy.server.curio_server import OSCCurioServer + + +async def osc_app(address, port): + osc = OSCCurioServer(encoding='utf8') + osc.listen(address=address, port=port, default=True) + + @osc.address("/example") + async def example(*values): + print(f"got {values} on /example") + await curio.sleep(4) + print("done sleeping") + + @osc.address("/test") + async def test(*values): + print(f"got {values} on /test") + await curio.sleep(4) + print("done sleeping") + + @osc.address("/stop") + async def stop(*values): + print(f"time to leave!") + await osc.stop_all() + + await osc.process() + +curio.run(osc_app, '0.0.0.0', 8000) diff --git a/examples/thread_example.py b/examples/thread_example.py new file mode 100644 index 0000000..d0d23f0 --- /dev/null +++ b/examples/thread_example.py @@ -0,0 +1,22 @@ +import sys +from time import sleep + +from oscpy.server import OSCThreadServer + +osc = OSCThreadServer(encoding='utf8') +sock = osc.listen(address='0.0.0.0', port=8000, default=True) + +@osc.address('/address') +def callback(*values): + print("got values: {}".format(values)) + + +@osc.address('/stop') +def callback(*values): + print("time to leave") + osc.stop_all() + osc.terminate_server() + + +# wait until the server exits +osc.join_server() diff --git a/examples/trio_example.py b/examples/trio_example.py new file mode 100644 index 0000000..a01eda8 --- /dev/null +++ b/examples/trio_example.py @@ -0,0 +1,34 @@ +import trio + +from oscpy.server.trio_server import OSCTrioServer + + +async def osc_app(address, port): + osc = OSCTrioServer(encoding='utf8') + await osc.listen(address=address, port=port, default=True) + + @osc.address("/example") + async def example(*values): + print(f"got {values} on /example") + await trio.sleep(4) + print("done sleeping") + + @osc.address("/test") + async def test(*values): + print(f"got {values} on /test") + await trio.sleep(4) + print("done sleeping") + + @osc.address("/stop") + async def stop(*values): + print(f"time to leave!") + await osc.stop_all() + + @osc.address("/info") + async def info(): + address, port = osc.getaddress() + print(address, port) + + await osc.process() + +trio.run(osc_app, '0.0.0.0', 8000) diff --git a/oscpy/server.py b/oscpy/server/__init__.py similarity index 80% rename from oscpy/server.py rename to oscpy/server/__init__.py index 7e4bab1..6186db3 100644 --- a/oscpy/server.py +++ b/oscpy/server/__init__.py @@ -1,18 +1,15 @@ """Server API. - -This module currently only implements `OSCThreadServer`, a thread based server. """ import logging -from threading import Thread, Event - import os import re -import inspect from sys import platform -from time import sleep, time +from threading import Event +import inspect +from time import time from functools import partial -from select import select import socket +from select import select from oscpy import __version__ from oscpy.parser import read_packet, UNICODE @@ -22,6 +19,8 @@ logger = logging.getLogger(__name__) +UDP_MAX_SIZE = 65535 + def ServerClass(cls): """Decorate classes with for methods implementing OSC endpoints. @@ -48,21 +47,12 @@ def __init__(self, *args, **kwargs): __FILE__ = inspect.getfile(ServerClass) -class OSCThreadServer(object): - """A thread-based OSC server. - - Listens for osc messages in a thread, and dispatches the messages - values to callbacks from there. - - The '/_oscpy/' namespace is reserved for metadata about the OSCPy - internals, please see package documentation for further details. - """ - +class OSCBaseServer(object): def __init__( self, drop_late_bundles=False, timeout=0.01, advanced_matching=False, encoding='', encoding_errors='strict', default_handler=None, intercept_errors=True ): - """Create an OSCThreadServer. + """Create an OSC Server. - `timeout` is a number of seconds used as a time limit for select() calls in the listening thread, optiomal, defaults to @@ -87,7 +77,6 @@ def __init__( callbacks will be intercepted and logged. If False, the handler thread will terminate mostly silently on such exceptions. """ - self._must_loop = True self._termination_event = Event() self.addresses = {} @@ -104,14 +93,60 @@ def __init__( self.stats_received = Stats() self.stats_sent = Stats() - t = Thread(target=self._run_listener) - t.daemon = True - t.start() - self._thread = t - self._smart_address_cache = {} self._smart_part_cache = {} + @staticmethod + def get_socket(family, addr): + sock = socket.socket(family, socket.SOCK_DGRAM) + sock.bind(addr) + return sock + + def listen( + self, address='localhost', port=0, default=False, family='inet' + ): + """Start listening on an (address, port). + + - if `port` is 0, the system will allocate a free port + - if `default` is True, the instance will save this socket as the + default one for subsequent calls to methods with an optional socket + - `family` accepts the 'unix' and 'inet' values, a socket of the + corresponding type will be created. + If family is 'unix', then the address must be a filename, the + `port` value won't be used. 'unix' sockets are not defined on + Windows. + + The socket created to listen is returned, and can be used later + with methods accepting the `sock` parameter. + """ + if family == 'unix': + family_ = socket.AF_UNIX + elif family == 'inet': + family_ = socket.AF_INET + else: + raise ValueError( + "Unknown socket family, accepted values are 'unix' and 'inet'" + ) + + if family == 'unix': + addr = address + else: + addr = (address, port) + sock = self.get_socket(family_, addr) + self.add_socket(sock, default) + return sock + + def add_socket(self, sock, default): + self.sockets.append(sock) + if default and not self.default_socket: + self.default_socket = sock + elif default: + raise RuntimeError( + 'Only one default socket authorized! Please set ' + 'default=False to other calls to listen()' + ) + self.bind_meta_routes(sock) + def bind(self, address, callback, sock=None, get_address=False): """Bind a callback to an osc address. @@ -222,49 +257,6 @@ def unbind(self, address, callback, sock=None): self.addresses[(sock, address)] = callbacks - def listen( - self, address='localhost', port=0, default=False, family='inet' - ): - """Start listening on an (address, port). - - - if `port` is 0, the system will allocate a free port - - if `default` is True, the instance will save this socket as the - default one for subsequent calls to methods with an optional socket - - `family` accepts the 'unix' and 'inet' values, a socket of the - corresponding type will be created. - If family is 'unix', then the address must be a filename, the - `port` value won't be used. 'unix' sockets are not defined on - Windows. - - The socket created to listen is returned, and can be used later - with methods accepting the `sock` parameter. - """ - if family == 'unix': - family_ = socket.AF_UNIX - elif family == 'inet': - family_ = socket.AF_INET - else: - raise ValueError( - "Unknown socket family, accepted values are 'unix' and 'inet'" - ) - - sock = socket.socket(family_, socket.SOCK_DGRAM) - if family == 'unix': - addr = address - else: - addr = (address, port) - sock.bind(addr) - self.sockets.append(sock) - if default and not self.default_socket: - self.default_socket = sock - elif default: - raise RuntimeError( - 'Only one default socket authorized! Please set ' - 'default=False to other calls to listen()' - ) - self.bind_meta_routes(sock) - return sock - def close(self, sock=None): """Close a socket opened by the server.""" if not sock and self.default_socket: @@ -272,13 +264,22 @@ def close(self, sock=None): elif not sock: raise RuntimeError('no default socket yet and no socket provided') + if sock == self.default_socket: + self.default_socket = None + + if sock not in self.sockets: + return + + self.sockets.remove(sock) + read = select([sock], [], [], 0) if platform != 'win32' and sock.family == socket.AF_UNIX: + print(sock.getsockname()) os.unlink(sock.getsockname()) else: sock.close() - if sock == self.default_socket: - self.default_socket = None + if sock in read: + sock.recvfrom(UDP_MAX_SIZE) def getaddress(self, sock=None): """Wrap call to getsockname. @@ -295,120 +296,6 @@ def getaddress(self, sock=None): return sock.getsockname() - def stop(self, s=None): - """Close and remove a socket from the server's sockets. - - If `sock` is None, uses the default socket for the server. - - """ - if not s and self.default_socket: - s = self.default_socket - - if s in self.sockets: - read = select([s], [], [], 0) - s.close() - if s in read: - s.recvfrom(65535) - self.sockets.remove(s) - else: - raise RuntimeError('{} is not one of my sockets!'.format(s)) - - def stop_all(self): - """Call stop on all the existing sockets.""" - for s in self.sockets[:]: - self.stop(s) - sleep(10e-9) - - def terminate_server(self): - """Request the inner thread to finish its tasks and exit. - - May be called from an event, too. - """ - self._must_loop = False - - def join_server(self, timeout=None): - """Wait for the server to exit (`terminate_server()` must have been called before). - - Returns True if and only if the inner thread exited before timeout.""" - return self._termination_event.wait(timeout=timeout) - - def _run_listener(self): - """Wrapper just ensuring that the handler thread cleans up on exit.""" - try: - self._listen() - finally: - self._termination_event.set() - - def _listen(self): - """(internal) Busy loop to listen for events. - - This method is called in a thread by the `listen` method, and - will be the one actually listening for messages on the server's - sockets, and calling the callbacks when messages are received. - """ - - match = self._match_address - advanced_matching = self.advanced_matching - addresses = self.addresses - stats = self.stats_received - - def _execute_callbacks(_callbacks_list): - for cb, get_address in _callbacks_list: - try: - if get_address: - cb(address, *values) - else: - cb(*values) - except Exception as exc: - if self.intercept_errors: - logger.error("Unhandled exception caught in oscpy server", exc_info=True) - else: - raise - - while self._must_loop: - - drop_late = self.drop_late_bundles - if not self.sockets: - sleep(.01) - continue - else: - try: - read, write, error = select(self.sockets, [], [], self.timeout) - except (ValueError, socket.error): - continue - - for sender_socket in read: - try: - data, sender = sender_socket.recvfrom(65535) - except ConnectionResetError: - continue - - for address, tags, values, offset in read_packet( - data, drop_late=drop_late, encoding=self.encoding, - encoding_errors=self.encoding_errors - ): - stats.calls += 1 - stats.bytes += offset - stats.params += len(values) - stats.types.update(tags) - - matched = False - if advanced_matching: - for sock, addr in addresses: - if sock == sender_socket and match(addr, address): - callbacks_list = addresses.get((sock, addr), []) - if callbacks_list: - matched = True - _execute_callbacks(callbacks_list) - else: - callbacks_list = addresses.get((sender_socket, address), []) - if callbacks_list: - matched = True - _execute_callbacks(callbacks_list) - - if not matched and self.default_handler: - self.default_handler(address, *values) - @staticmethod def _match_address(smart_address, target_address): """(internal) Check if provided `smart_address` matches address. @@ -486,7 +373,7 @@ def get_sender(self): """ frames = inspect.getouterframes(inspect.currentframe()) for frame, filename, _, function, _, _ in frames: - if function == '_listen' and __FILE__.startswith(filename): + if function == 'handle_message' and frame.f_locals.get('self') is self and 'sender_socket' in frame.f_locals: break else: raise RuntimeError('get_sender() not called from a callback') @@ -635,3 +522,72 @@ def _get_stats_sent(self, port, *args): self.stats_sent.to_tuple(), port=port ) + + def _execute_callbacks(self, callbacks_list, address, values): + for cb, get_address in callbacks_list: + try: + if get_address: + cb(address, *values) + else: + cb(*values) + except Exception: + if self.intercept_errors: + logger.exception("Ignoring unhandled exception caught in oscpy server") + else: + logger.exception("Unhandled exception caught in oscpy server") + raise + + def handle_message(self, data, sender, sender_socket): + for callbacks, values, address in self.callbacks(data, sender, sender_socket): + self._execute_callbacks(callbacks, address, values) + + + def callbacks(self, data, sender, sender_socket): + match = self._match_address + advanced_matching = self.advanced_matching + addresses = self.addresses + stats = self.stats_received + drop_late = self.drop_late_bundles + + for address, tags, values, offset in read_packet( + data, drop_late=drop_late, encoding=self.encoding, + encoding_errors=self.encoding_errors + ): + stats.calls += 1 + stats.bytes += offset + stats.params += len(values) + stats.types.update(tags) + + matched = False + if advanced_matching: + for sock, addr in addresses: + if sock == sender_socket and match(addr, address): + callbacks_list = addresses.get((sock, addr), []) + if callbacks_list: + matched = True + yield callbacks_list, values, address + else: + callbacks_list = addresses.get((sender_socket, address), []) + if callbacks_list: + matched = True + yield callbacks_list, values, address + + if not matched and self.default_handler: + yield [(self.default_handler, True)], values, address + + def terminate_server(self): + """Request the inner thread to finish its tasks and exit. + + May be called from an event, too. + """ + self._termination_event.set() + + def join_server(self, timeout=None): + """Wait for the server to exit (`terminate_server()` must have been called before). + + Returns True if and only if the inner thread exited before timeout.""" + return self._termination_event.wait(timeout=timeout) + +# backward compatibility + +from oscpy.server.thread_server import OSCThreadServer diff --git a/oscpy/server/asyncio_server.py b/oscpy/server/asyncio_server.py new file mode 100644 index 0000000..3b7f836 --- /dev/null +++ b/oscpy/server/asyncio_server.py @@ -0,0 +1,121 @@ +import asyncio +import socket +from functools import partial +from logging import getLogger +from typing import Awaitable + +from oscpy.server import OSCBaseServer + + +logger = getLogger(__name__) + + +class OSCAsyncioServer(OSCBaseServer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.listeners = {} + self._termination_event = asyncio.Event() + + def listen(self, address='localhost', port=0, default=False, family='inet', **kwargs): + loop = asyncio.get_event_loop() + if family == 'unix': + family_ = socket.AF_UNIX + elif family == 'inet': + family_ = socket.AF_INET + else: + raise ValueError( + "Unknown socket family, accepted values are 'unix' and 'inet'" + ) + + if family == 'unix': + addr = address + else: + addr = (address, port) + + sock = self.get_socket( + family=family_, + addr=addr, + ) + self.listeners[(address, port or sock.getsockname()[1])] = loop.create_datagram_endpoint( + partial(OSCProtocol, self.handle_message, sock), + sock=sock, + ) + self.add_socket(sock, default) + return sock + + async def process(self): + return await asyncio.gather( + *self.listeners.values(), + return_exceptions=True, + ) + + async def handle_message(self, data, sender, sender_socket): + for callbacks, values, address in self.callbacks(data, sender, sender_socket): + await self._execute_callbacks(callbacks, address, values) + + async def _execute_callbacks(self, callbacks_list, address, values): + for cb, get_address in callbacks_list: + result = None + try: + if get_address: + result = cb(address, *values) + else: + result = cb(*values) + if isinstance(result, Awaitable): + await result + except asyncio.CancelledError: + ... + except Exception: + if self.intercept_errors: + logger.exception("Ignoring unhandled exception caught in oscpy server") + else: + raise + finally: + if result: + del result + + def stop(self, sock=None): + """Close and remove a socket from the server's sockets. + + If `sock` is None, uses the default socket for the server. + + """ + if not sock and self.default_socket: + sock = self.default_socket + + if sock in self.sockets: + sock.close() + self.sockets.remove(sock) + if sock is self.default_socket: + self.default_socket = None + else: + raise RuntimeError('{} is not one of my sockets!'.format(sock)) + + def stop_all(self): + for sock in self.sockets[:]: + self.stop(sock) + + async def join_server(self, timeout=None): + """Wait for the server to exit (`terminate_server()` must have been called before). + + Returns True if and only if the inner thread exited before timeout.""" + return await self._termination_event.wait() + + +class OSCProtocol(asyncio.DatagramProtocol): + def __init__(self, message_handler, sock, **kwargs): + super().__init__(**kwargs) + self.message_handler = message_handler + self.socket = sock + self.loop = asyncio.get_event_loop() + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + self.loop.call_soon( + lambda: asyncio.ensure_future(self.message_handler(data, addr, self.socket)) + ) + + def getsockname(self): + return self.socket.getsockname() diff --git a/oscpy/server/curio_server.py b/oscpy/server/curio_server.py new file mode 100644 index 0000000..1a6e4b3 --- /dev/null +++ b/oscpy/server/curio_server.py @@ -0,0 +1,95 @@ +import logging +from typing import Awaitable +from sys import platform +import os + +from curio import TaskGroup, socket +from oscpy.server import OSCBaseServer, UDP_MAX_SIZE + +logging.basicConfig() +logger = logging.getLogger(__name__) + + +class OSCCurioServer(OSCBaseServer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.task_groups = {} + + @staticmethod + def get_socket(family, addr): + # identical to the parent method, except here socket is curio.socket + sock = socket.socket(family, socket.SOCK_DGRAM) + sock.bind(addr) + return sock + + async def _listen(self, sock): + async with TaskGroup(wait=all) as g: + self.task_groups[sock] = g + while not self._termination_event.is_set(): + data, addr = await sock.recvfrom(UDP_MAX_SIZE) + await g.spawn( + self.handle_message( + data, + addr, + drop_late=False, + sender_socket=sock + ) + ) + await g.join() + + async def handle_message(self, data, sender, drop_late, sender_socket): + for callbacks, values, address in self.callbacks(data, sender, sender_socket): + await self._execute_callbacks(callbacks, address, values) + + async def _execute_callbacks(self, callbacks_list, address, values): + for cb, get_address in callbacks_list: + try: + if get_address: + result = cb(address, *values) + else: + result = cb(*values) + if isinstance(result, Awaitable): + await result + except Exception: + if self.intercept_errors: + logger.exception("Ignoring unhandled exception caught in oscpy server") + else: + raise + + async def process(self): + async with TaskGroup(wait=all) as g: + self.tasks_group = g + for s in self.sockets: + await g.spawn(self._listen, s) + + async def close(self, sock=None): + """Close a socket opened by the server.""" + if not sock and self.default_socket: + sock = self.default_socket + elif not sock: + raise RuntimeError('no default socket yet and no socket provided') + + if sock not in self.sockets: + logger.warning("Ignoring requested to close an unknown socket %s" % sock) + + if sock == self.default_socket: + self.default_socket = None + + if platform != 'win32' and sock.family == socket.AF_UNIX: + os.unlink(sock.getsockname()) + else: + await sock.close() + + async def stop_all(self): + await self.tasks_group.cancel_remaining() + + async def stop(self, sock=None): + if not sock and self.default_socket: + sock = self.default_socket + + if sock in self.sockets: + g = self.task_groups.pop(sock) + await g.cancel_remaining() + else: + raise RuntimeError('{} is not one of my sockets!'.format(sock)) diff --git a/oscpy/server/thread_server.py b/oscpy/server/thread_server.py new file mode 100644 index 0000000..7e806f7 --- /dev/null +++ b/oscpy/server/thread_server.py @@ -0,0 +1,86 @@ +import logging +from threading import Thread + +from time import sleep +from select import select +import socket + +from oscpy.server import OSCBaseServer, UDP_MAX_SIZE + +logger = logging.getLogger(__name__) + + +class OSCThreadServer(OSCBaseServer): + """A thread-based OSC server. + + Listens for osc messages in a thread, and dispatches the messages + values to callbacks from there. + + The '/_oscpy/' namespace is reserved for metadata about the OSCPy + internals, please see package documentation for further details. + """ + + def __init__(self, *args, **kwargs): + super(OSCThreadServer, self).__init__(*args, **kwargs) + t = Thread(target=self._run_listener) + t.daemon = True + t.start() + self._thread = t + + def stop(self, s=None): + """Close and remove a socket from the server's sockets. + + If `sock` is None, uses the default socket for the server. + + """ + if not s and self.default_socket: + s = self.default_socket + + if s in self.sockets: + self.close(s) + else: + raise RuntimeError('{} is not one of my sockets!'.format(s)) + + def stop_all(self): + """Call stop on all the existing sockets.""" + for s in self.sockets[:]: + self.stop(s) + sleep(10e-9) + + def _run_listener(self): + """Wrapper just ensuring that the handler thread cleans up on exit.""" + try: + self._listen() + finally: + self._termination_event.set() + + def _listen(self): + """(internal) Busy loop to listen for events. + + This method is called in a thread by the `listen` method, and + will be the one actually listening for messages on the server's + sockets, and calling the callbacks when messages are received. + """ + + while not self._termination_event.is_set(): + if not self.sockets: + sleep(.01) + continue + else: + try: + read, write, error = select(self.sockets, [], [], self.timeout) + except (ValueError, socket.error): + continue + + for sender_socket in read: + try: + data, sender = sender_socket.recvfrom(UDP_MAX_SIZE) + except ConnectionResetError: + continue + + self.handle_message(data, sender, sender_socket) + + def join_server(self, timeout=None): + result = super(OSCThreadServer, self).join_server(timeout=timeout) + self._thread.join(timeout=timeout) + return result diff --git a/oscpy/server/trio_server.py b/oscpy/server/trio_server.py new file mode 100644 index 0000000..d21eeb3 --- /dev/null +++ b/oscpy/server/trio_server.py @@ -0,0 +1,149 @@ +import os +import logging +from functools import partial +from sys import platform +from typing import Awaitable + +from trio import socket, open_nursery, move_on_after +from oscpy.server import OSCBaseServer, UDP_MAX_SIZE + +logging.basicConfig() +logger = logging.getLogger(__name__) + + +class OSCTrioServer(OSCBaseServer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.nurseries = {} + + @staticmethod + async def get_socket(family, addr): + # identical to the parent method, except here socket is trio.socket + # and bind needs to be awaited + sock = socket.socket(family, socket.SOCK_DGRAM) + await sock.bind(addr) + return sock + + async def listen( + self, address='localhost', port=0, default=False, family='inet' + ): + if family == 'unix': + family_ = socket.AF_UNIX + elif family == 'inet': + family_ = socket.AF_INET + else: + raise ValueError( + "Unknown socket family, accepted values are 'unix' and 'inet'" + ) + + if family == 'unix': + addr = address + else: + addr = (address, port) + sock = await self.get_socket(family_, addr) + self.add_socket(sock, default) + return sock + + async def _listen(self, sock): + async with open_nursery() as nursery: + self.nurseries[sock] = nursery + try: + while True: + data, addr = await sock.recvfrom(UDP_MAX_SIZE) + nursery.start_soon( + partial( + self.handle_message, + data, + addr, + drop_late=False, + sender_socket=sock + ) + ) + finally: + with move_on_after(1) as cleanup_scope: + cleanup_scope.shield = True + logger.info("socket %s cancelled", sock) + await self.stop(sock) + + async def handle_message(self, data, sender, drop_late, sender_socket): + for callbacks, values, address in self.callbacks(data, sender, sender_socket): + await self._execute_callbacks(callbacks, address, values) + + async def _execute_callbacks(self, callbacks_list, address, values): + for cb, get_address in callbacks_list: + try: + if get_address: + result = cb(address, *values) + else: + result = cb(*values) + if isinstance(result, Awaitable): + await result + + except Exception: + if self.intercept_errors: + logger.error("Ignoring unhandled exception caught in oscpy server", exc_info=True) + else: + logger.exception("Unhandled exception caught in oscpy server") + raise + + async def process(self): + async with open_nursery() as nursery: + self.nursery = nursery + for s in self.sockets: + nursery.start_soon(self._listen, s) + + async def stop_all(self): + """Exit the main nursery, cancelling any in progress task + """ + self.nursery.cancel_scope.deadline = 0 + + async def stop(self, sock=None): + if sock is None: + if self.default_socket: + sock = self.default_socket + else: + raise RuntimeError('no default socket yet and no socket provided') + if sock in self.sockets: + self.sockets.remove(sock) + else: + raise RuntimeError("Socket %s is not managed by this server" % sock) + sock.close() + if sock in self.nurseries: + nursery = self.nurseries.pop(sock) + nursery.cancel_scope.deadline = 0 + + if sock is self.default_socket: + self.default_socket = None + + async def close(self, sock=None): + """Close a socket opened by the server.""" + if not sock and self.default_socket: + sock = self.default_socket + elif not sock: + raise RuntimeError('no default socket yet and no socket provided') + + if sock not in self.sockets: + logger.warning("Ignoring requested to close an unknown socket %s" % sock) + + if sock == self.default_socket: + self.default_socket = None + + if platform != 'win32' and sock.family == socket.AF_UNIX: + os.unlink(sock.getsockname()) + else: + sock.close() + + def getaddress(self, sock=None): + """Wrap call to getsockname. + + If `sock` is None, uses the default socket for the server. + + Returns (ip, port) for an inet socket, or filename for an unix + socket. + """ + if not sock and self.default_socket: + sock = self.default_socket + elif not sock: + raise RuntimeError('no default socket yet and no socket provided') + + return sock.getsockname() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_server.py b/tests/test_server.py index df43ce2..f7e8985 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,97 +1,135 @@ # coding: utf8 import pytest from time import time, sleep -from sys import platform +from sys import platform, version_info import socket from tempfile import mktemp from os.path import exists from os import unlink +from threading import Event from oscpy.server import OSCThreadServer, ServerClass from oscpy.client import send_message, send_bundle, OSCClient from oscpy import __version__ +from tests.utils import runner, _await, _callback + +if version_info > (3, 5, 0): + from oscpy.server.curio_server import OSCCurioServer + from oscpy.server.trio_server import OSCTrioServer + from oscpy.server.asyncio_server import OSCAsyncioServer + server_classes = { + OSCThreadServer, + OSCTrioServer, + OSCAsyncioServer, + OSCCurioServer, + } +else: + # so we can refer to them safely + OSCTrioServer = OSCAsyncioServer = OSCCurioServer = None + server_classes = {OSCThreadServer} + + +# # force a one second interval between each test to avoid messages hitting the +# # wrong server +# def teardown_function(function): +# sleep(1) -def test_instance(): - OSCThreadServer() +@pytest.mark.parametrize("cls", server_classes) +def test_instance(cls): + cls() -def test_listen(): - osc = OSCThreadServer() - sock = osc.listen() - osc.stop(sock) +@pytest.mark.parametrize("cls", server_classes) +def test_listen_simple(cls): + osc = cls() + sock = _await(osc.listen, osc) + runner(osc, timeout=1, socket=sock) + _await(osc.close, osc, (sock,)) -def test_getaddress(): - osc = OSCThreadServer() - sock = osc.listen() + +@pytest.mark.parametrize("cls", server_classes) +def test_getaddress(cls): + osc = cls() + sock = _await(osc.listen, osc) assert osc.getaddress(sock)[0] == '127.0.0.1' with pytest.raises(RuntimeError): osc.getaddress() - sock2 = osc.listen(default=True) + sock2 = _await(osc.listen, osc, kwargs=dict(default=True)) assert osc.getaddress(sock2)[0] == '127.0.0.1' - osc.stop(sock) + runner(osc, timeout=1, socket=sock) -def test_listen_default(): - osc = OSCThreadServer() - sock = osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_listen_default(cls): + osc = cls() + sock = _await(osc.listen, osc, kwargs=dict(default=True)) with pytest.raises(RuntimeError) as e_info: # noqa - osc.listen(default=True) + # osc.listen(default=True) + _await(osc.listen, osc, kwargs=dict(default=True)) - osc.close(sock) - osc.listen(default=True) + _await(osc.close, osc, (sock,)) + _await(osc.listen, osc, kwargs=dict(default=True)) -def test_close(): - osc = OSCThreadServer() - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_close(cls): + osc = cls() + sock = _await(osc.listen, osc, kwargs=dict(default=True)) + _await(osc.close, osc, (sock,)) - osc.close() - with pytest.raises(RuntimeError) as e_info: # noqa - osc.close() - if platform != 'win32': - filename = mktemp() - unix = osc.listen(address=filename, family='unix') - assert exists(filename) - osc.close(unix) - assert not exists(filename) +@pytest.mark.skipif(platform == 'win32', reason="unix sockets not available on windows") +@pytest.mark.parametrize("cls", server_classes) +def test_close_unix(cls): + osc = cls() + filename = mktemp() + unix = _await(osc.listen, osc, kwargs=dict(address=filename, family='unix')) + assert exists(filename) + _await(osc.close, osc, (unix,)) + assert not exists(filename) -def test_stop_unknown(): - osc = OSCThreadServer() +@pytest.mark.parametrize("cls", server_classes - {OSCCurioServer}) +def test_stop_unknown(cls): + osc = cls() with pytest.raises(RuntimeError): - osc.stop(socket.socket()) + _await(osc.stop, osc, args=[socket.socket()]) -def test_stop_default(): - osc = OSCThreadServer() - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes - {OSCCurioServer}) +def test_stop_default(cls): + osc = cls() + _await(osc.listen, osc, kwargs=dict(default=True)) assert len(osc.sockets) == 1 - osc.stop() + _await(osc.stop, osc) assert len(osc.sockets) == 0 -def test_stop_all(): - osc = OSCThreadServer() - sock = osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes - {OSCCurioServer}) +def test_stop_all(cls): + osc = cls() + sock = _await(osc.listen, osc, kwargs=dict(default=True)) host, port = sock.getsockname() - osc.listen() + sock2 = _await(osc.listen, osc) assert len(osc.sockets) == 2 - osc.stop_all() + runner(osc, timeout=.2) + _await(osc.stop_all, osc) assert len(osc.sockets) == 0 sleep(.1) - osc.listen(address=host, port=port) + sock3 = _await(osc.listen, osc, kwargs=dict(default=True)) assert len(osc.sockets) == 1 - osc.stop_all() + runner(osc, timeout=.2) + _await(osc.stop_all, osc) -def test_terminate_server(): - osc = OSCThreadServer() +@pytest.mark.parametrize("cls", {OSCThreadServer}) +def test_terminate_server(cls): + osc = cls() assert not osc.join_server(timeout=0.1) assert osc._thread.is_alive() osc.terminate_server() @@ -99,55 +137,59 @@ def test_terminate_server(): assert not osc._thread.is_alive() -def test_send_message_without_socket(): - osc = OSCThreadServer() +@pytest.mark.parametrize("cls", server_classes) +def test_send_message_without_socket(cls): + osc = cls() with pytest.raises(RuntimeError): osc.send_message(b'/test', [], 'localhost', 0) -def test_intercept_errors(caplog): +@pytest.mark.parametrize("cls", server_classes) +def test_intercept_errors(caplog, cls): - cont = [] + event = Event() def success(*values): - cont.append(True) + event.set() def broken_callback(*values): raise ValueError("some bad value") - osc = OSCThreadServer() - sock = osc.listen() + osc = cls() + sock = _await(osc.listen, osc) port = sock.getsockname()[1] osc.bind(b'/broken_callback', broken_callback, sock) osc.bind(b'/success', success, sock) send_message(b'/broken_callback', [b'test'], 'localhost', port) - sleep(0.01) send_message(b'/success', [b'test'], 'localhost', port) - assert not osc.join_server(timeout=0.02) # Thread not stopped - assert cont == [True] + runner(osc, timeout=.2) + assert event.is_set() assert len(caplog.records) == 1, caplog.records record = caplog.records[0] - assert record.msg == "Unhandled exception caught in oscpy server" + assert record.msg == "Ignoring unhandled exception caught in oscpy server" assert not record.args assert record.exc_info - osc = OSCThreadServer(intercept_errors=False) - sock = osc.listen() + osc = cls(intercept_errors=False) + sock = _await(osc.listen, osc) port = sock.getsockname()[1] osc.bind(b'/broken_callback', broken_callback, sock) - send_message(b'/broken_callback', [b'test'], 'localhost', port) - assert osc.join_server(timeout=0.02) # Thread properly sets termination event on crash - - assert len(caplog.records) == 1, caplog.records # Unchanged + try: + send_message(b'/broken_callback', [b'test'], 'localhost', port) + runner(osc, timeout=.2) + except Exception: + pass + assert len(caplog.records) == 2, caplog.records # Unchanged -def test_send_bundle_without_socket(): - osc = OSCThreadServer() +@pytest.mark.parametrize("cls", server_classes) +def test_send_bundle_without_socket(cls): + osc = cls() with pytest.raises(RuntimeError): osc.send_bundle([], 'localhost', 0) - osc.listen(default=True) + sock = _await(osc.listen, osc, kwargs={'default': True}) osc.send_bundle( ( (b'/test', []), @@ -156,73 +198,67 @@ def test_send_bundle_without_socket(): ) -def test_bind(): - osc = OSCThreadServer() - sock = osc.listen() +@pytest.mark.parametrize("cls", server_classes) +def test_bind1(cls): + osc = cls() + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] - cont = [] + event = Event() def success(*values): - cont.append(True) + event.set() - osc.bind(b'/success', success, sock) + osc.bind(b'/success', success) send_message(b'/success', [b'test', 1, 1.12345], 'localhost', port) - - timeout = time() + 5 - while not cont: - if time() > timeout: - raise OSError('timeout while waiting for success message.') + runner(osc, timeout=.2) + assert event.is_set(), 'timeout while waiting for success message.' -def test_bind_get_address(): - osc = OSCThreadServer() - sock = osc.listen() +@pytest.mark.parametrize("cls", server_classes) +def test_bind_get_address(cls): + osc = cls() + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] - cont = [] + event = Event() def success(address, *values): assert address == b'/success' - cont.append(True) + event.set() osc.bind(b'/success', success, sock, get_address=True) send_message(b'/success', [b'test', 1, 1.12345], 'localhost', port) - timeout = time() + 5 - while not cont: - if time() > timeout: - raise OSError('timeout while waiting for success message.') + runner(osc) + assert event.wait(1), 'timeout while waiting for success message.' -def test_bind_get_address_smart(): - osc = OSCThreadServer(advanced_matching=True) - sock = osc.listen() +@pytest.mark.parametrize("cls", server_classes) +def test_bind_get_address_smart(cls): + osc = cls(advanced_matching=True) + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] - cont = [] + event = Event() def success(address, *values): assert address == b'/success/a' - cont.append(True) + event.set() osc.bind(b'/success/?', success, sock, get_address=True) send_message(b'/success/a', [b'test', 1, 1.12345], 'localhost', port) + runner(osc, timeout=1, socket=sock) + assert event.wait(1), 'timeout while waiting for success message.' - timeout = time() + 5 - while not cont: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - - -def test_reuse_callback(): - osc = OSCThreadServer() - sock = osc.listen() +@pytest.mark.parametrize("cls", server_classes) +def test_reuse_callback(cls): + osc = cls() + sock = _await(osc.listen, osc) port = sock.getsockname()[1] - cont = [] def success(*values): - cont.append(True) + pass osc.bind(b'/success', success, sock) osc.bind(b'/success', success, sock) @@ -231,14 +267,15 @@ def success(*values): assert len(osc.addresses.get((sock, b'/success2'))) == 1 -def test_unbind(): - osc = OSCThreadServer() - sock = osc.listen() +@pytest.mark.parametrize("cls", server_classes) +def test_unbind(cls): + osc = cls() + sock = _await(osc.listen, osc) port = sock.getsockname()[1] - cont = [] + event = Event() def failure(*values): - cont.append(True) + event.set() osc.bind(b'/failure', failure, sock) with pytest.raises(RuntimeError) as e_info: # noqa @@ -247,110 +284,94 @@ def failure(*values): send_message(b'/failure', [b'test', 1, 1.12345], 'localhost', port) - timeout = time() + 1 - while time() > timeout: - assert not cont - sleep(10e-9) + assert not event.wait(1), "Unexpected call to failure()" -def test_unbind_default(): - osc = OSCThreadServer() - sock = osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_unbind_default(cls): + osc = cls() + sock = _await(osc.listen, osc, kwargs={'default': True}) port = sock.getsockname()[1] - cont = [] + event = Event() def failure(*values): - cont.append(True) + event.set() osc.bind(b'/failure', failure) osc.unbind(b'/failure', failure) send_message(b'/failure', [b'test', 1, 1.12345], 'localhost', port) - timeout = time() + 1 - while time() > timeout: - assert not cont - sleep(10e-9) + assert not event.wait(1), "Unexpected call to failure()" + +@pytest.mark.parametrize("cls", server_classes) +def test_bind_multi(cls): + osc = cls() -def test_bind_multi(): - osc = OSCThreadServer() - sock1 = osc.listen() + sock1 = _await(osc.listen, osc) port1 = sock1.getsockname()[1] + event1 = Event() + osc.bind(b'/success', _callback(osc, lambda *_: event1.set()), sock1) - sock2 = osc.listen() + sock2 = _await(osc.listen, osc) port2 = sock2.getsockname()[1] - cont = [] - - def success1(*values): - cont.append(True) - - def success2(*values): - cont.append(False) - - osc.bind(b'/success', success1, sock1) - osc.bind(b'/success', success2, sock2) + event2 = Event() + osc.bind(b'/success', _callback(osc, lambda *_: event2.set()), sock2) send_message(b'/success', [b'test', 1, 1.12345], 'localhost', port1) send_message(b'/success', [b'test', 1, 1.12345], 'localhost', port2) - timeout = time() + 5 - while len(cont) < 2: - if time() > timeout: - raise OSError('timeout while waiting for success message.') + runner(osc, timeout=.1) + assert ( + event2.is_set() + and + event1.is_set() + ), 'timeout while waiting for success message.' - assert True in cont and False in cont - -def test_bind_address(): - osc = OSCThreadServer() - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_bind_address(cls): + osc = cls() + _await(osc.listen, osc, kwargs=dict(default=True)) result = [] + event = Event() @osc.address(b'/test') def success(*args): - result.append(True) + event.set() timeout = time() + 1 send_message(b'/test', [], *osc.getaddress()) - while len(result) < 1: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) - - assert True in result + runner(osc) + assert event.wait(1), 'timeout while waiting for test message.' -def test_bind_address_class(): - osc = OSCThreadServer() - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_bind_address_class(cls): + osc = cls() + _await(osc.listen, osc, kwargs=dict(default=True)) @ServerClass class Test(object): def __init__(self): - self.result = [] + self.event = Event() @osc.address_method(b'/test') def success(self, *args): - self.result.append(True) - - timeout = time() + 1 + self.event.set() test = Test() send_message(b'/test', [], *osc.getaddress()) - - while len(test.result) < 1: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) - - assert True in test.result + runner(osc) + assert test.event.wait(1), 'timeout while waiting for test message.' -def test_bind_no_default(): - osc = OSCThreadServer() +@pytest.mark.parametrize("cls", server_classes) +def test_bind_no_default(cls): + osc = cls() def success(*values): pass @@ -359,27 +380,27 @@ def success(*values): osc.bind(b'/success', success) -def test_bind_default(): - osc = OSCThreadServer() - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_bind_default(cls): + osc = cls() + _await(osc.listen, osc, kwargs=dict(default=True)) port = osc.getaddress()[1] - cont = [] + event = Event() def success(*values): - cont.append(True) + event.set() osc.bind(b'/success', success) send_message(b'/success', [b'test', 1, 1.12345], 'localhost', port) - timeout = time() + 5 - while not cont: - if time() > timeout: - raise OSError('timeout while waiting for success message.') + runner(osc) + assert event.wait(1), 'timeout while waiting for test message.' -def test_smart_address_match(): - osc = OSCThreadServer(advanced_matching=True) +@pytest.mark.parametrize("cls", server_classes) +def test_smart_address_match(cls): + osc = cls(advanced_matching=True) address = osc.create_smart_address(b'/test?') assert osc._match_address(address, b'/testa') @@ -448,16 +469,19 @@ def test_smart_address_match(): assert not osc._match_address(address, b'/testtest/stuff') -def test_smart_address_cache(): - osc = OSCThreadServer(advanced_matching=True) +@pytest.mark.parametrize("cls", server_classes) +def test_smart_address_cache(cls): + osc = cls(advanced_matching=True) assert osc.create_smart_address(b'/a') == osc.create_smart_address(b'/a') -def test_advanced_matching(): - osc = OSCThreadServer(advanced_matching=True) - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_advanced_matching(cls): + osc = cls(advanced_matching=True) + _await(osc.listen, osc, kwargs=dict(default=True)) port = osc.getaddress()[1] result = {} + event = Event() def save_result(f): name = f.__name__ @@ -549,6 +573,10 @@ def parts_somestrings2(*values): def parts_somestrings3(*values): pass + @osc.address(b'/done') + def done(*values): + event.set() + send_bundle( ( (b'/a', [1]), @@ -609,6 +637,7 @@ def parts_somestrings3(*values): (b'/part1/part2/string2', [48]), (b'/part1/part2/prefix-string1', [49]), (b'/part1/part2/sprefix-tring2', [50]), + (b'/done', []), ), 'localhost', port ) @@ -626,44 +655,40 @@ def parts_somestrings3(*values): 'parts_somestrings3': [(49,)] } - timeout = time() + 5 - while result != expected: - if time() > timeout: - print("expected: {}\n result: {}\n".format(expected, result)) - raise OSError('timeout while waiting for expected result.') - sleep(10e-9) + runner(osc, timeout=.1) + assert event.wait(1), 'timeout while waiting for test message.' + assert result == expected -def test_decorator(): - osc = OSCThreadServer() - sock = osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_decorator(cls): + osc = cls() + sock = _await(osc.listen, osc, kwargs=dict(default=True)) port = sock.getsockname()[1] - cont = [] + event1 = Event() + event2 = Event() @osc.address(b'/test1', sock) def test1(*values): - print("test1 called") - cont.append(True) + event1.set() @osc.address(b'/test2') def test2(*values): - print("test1 called") - cont.append(True) + event2.set() send_message(b'/test1', [], 'localhost', port) send_message(b'/test2', [], 'localhost', port) - timeout = time() + 1 - while len(cont) < 2: - if time() > timeout: - raise OSError('timeout while waiting for success message.') + runner(osc) + assert event1.wait(1) and event2.is_set(), "timeout waiting for test messages" -def test_answer(): - cont = [] +@pytest.mark.parametrize("cls", {OSCThreadServer}) +def test_answer(cls): + event = Event() - osc_1 = OSCThreadServer() - osc_1.listen(default=True) + osc_1 = cls(intercept_errors=False) + _await(osc_1.listen, osc_1, kwargs=dict(default=True)) @osc_1.address(b'/ping') def ping(*values): @@ -676,86 +701,91 @@ def ping(*values): ] ) - osc_2 = OSCThreadServer() - osc_2.listen(default=True) + osc_2 = OSCThreadServer(intercept_errors=False) + _await(osc_2.listen, osc_2, kwargs=dict(default=True)) @osc_2.address(b'/pong') def pong(*values): osc_2.answer(b'/ping', [True]) - osc_3 = OSCThreadServer() - osc_3.listen(default=True) + osc_3 = OSCThreadServer(intercept_errors=False) + _await(osc_3.listen, osc_3, kwargs=dict(default=True)) @osc_3.address(b'/zap') def zap(*values): if True in values: - cont.append(True) + event.set() osc_2.send_message(b'/ping', [], *osc_1.getaddress()) + runner(osc_1) + runner(osc_2) + runner(osc_3) with pytest.raises(RuntimeError) as e_info: # noqa osc_1.answer(b'/bing', []) - timeout = time() + 2 - while not cont: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) + assert event.wait(1), 'timeout while waiting for test message.' -def test_socket_family(): - osc = OSCThreadServer() - assert osc.listen().family == socket.AF_INET +@pytest.mark.parametrize("cls", server_classes) +def test_socket_family(cls): + osc = cls() + sock = _await(osc.listen, osc) + assert sock.family == socket.AF_INET filename = mktemp() if platform != 'win32': - assert osc.listen(address=filename, family='unix').family == socket.AF_UNIX # noqa + sock = _await(osc.listen, osc, kwargs=dict(address=filename, family='unix')) + assert sock.family == socket.AF_UNIX # noqa else: with pytest.raises(AttributeError) as e_info: - osc.listen(address=filename, family='unix') + _await(osc.listen, osc, kwargs=dict(family='unix')) if exists(filename): unlink(filename) with pytest.raises(ValueError) as e_info: # noqa - osc.listen(family='') + _await(osc.listen, osc, kwargs=dict(family='')) -def test_encoding_send(): - osc = OSCThreadServer() - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_encoding_send(cls): + osc = cls() + _await(osc.listen, osc, kwargs=dict(default=True)) values = [] + event = Event() @osc.address(b'/encoded') def encoded(*val): for v in val: assert isinstance(v, bytes) values.append(val) + event.set() send_message( u'/encoded', ['hello world', u'ééééé ààààà'], *osc.getaddress(), encoding='utf8') - timeout = time() + 2 - while not values: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) + runner(osc) + assert event.wait(1), 'timeout while waiting for test message.' -def test_encoding_receive(): - osc = OSCThreadServer(encoding='utf8') - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_encoding_receive(cls): + osc = cls(encoding='utf8') + _await(osc.listen, osc, kwargs=dict(default=True)) values = [] + event = Event() @osc.address(u'/encoded') def encoded(*val): for v in val: assert not isinstance(v, bytes) values.append(val) + event.set() send_message( b'/encoded', @@ -765,16 +795,15 @@ def encoded(*val): ], *osc.getaddress()) - timeout = time() + 2 - while not values: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) + runner(osc) + assert event.wait(1), 'timeout while waiting for test message.' -def test_encoding_send_receive(): - osc = OSCThreadServer(encoding='utf8') - osc.listen(default=True) +@pytest.mark.parametrize("cls", server_classes) +def test_encoding_send_receive(cls): + osc = cls(encoding='utf8') + _await(osc.listen, osc, kwargs=dict(default=True)) + event = Event() values = [] @@ -783,33 +812,34 @@ def encoded(*val): for v in val: assert not isinstance(v, bytes) values.append(val) + event.set() send_message( u'/encoded', ['hello world', u'ééééé ààààà'], *osc.getaddress(), encoding='utf8') - timeout = time() + 2 - while not values: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) + runner(osc) + assert event.wait(1), 'timeout while waiting for test message.' -def test_default_handler(): +@pytest.mark.parametrize("cls", server_classes) +def test_default_handler(cls): results = [] + event = Event() def test(address, *values): results.append((address, values)) + event.set() - osc = OSCThreadServer(default_handler=test) - osc.listen(default=True) + osc = cls(default_handler=test) + _await(osc.listen, osc, kwargs=dict(default=True)) @osc.address(b'/passthrough') def passthrough(*values): pass - osc.send_bundle( + send_bundle( ( (b'/test', []), (b'/passthrough', []), @@ -818,11 +848,8 @@ def passthrough(*values): *osc.getaddress() ) - timeout = time() + 2 - while len(results) < 2: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) + runner(osc) + assert event.wait(2), 'timeout while waiting for test message.' expected = ( (b'/test', tuple()), @@ -833,15 +860,15 @@ def passthrough(*values): assert e == r -def test_get_version(): - osc = OSCThreadServer(encoding='utf8') - osc.listen(default=True) +@pytest.mark.parametrize("cls", {OSCThreadServer}) +def test_get_version(cls): + osc = cls(encoding='utf8') + _await(osc.listen, osc, kwargs=dict(default=True)) values = [] @osc.address(u'/_oscpy/version/answer') def cb(val): - print(val) values.append(val) send_message( @@ -854,19 +881,16 @@ def cb(val): encoding_errors='strict' ) - timeout = time() + 2 - while not values: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) - + runner(osc) assert __version__ in values -def test_get_routes(): - osc = OSCThreadServer(encoding='utf8') - osc.listen(default=True) +@pytest.mark.parametrize("cls", {OSCThreadServer}) +def test_get_routes(cls): + osc = cls(encoding='utf8') + _await(osc.listen, osc, kwargs=dict(default=True)) + event = Event() values = [] @osc.address(u'/test_route') @@ -876,6 +900,7 @@ def dummy(*val): @osc.address(u'/_oscpy/routes/answer') def cb(*routes): values.extend(routes) + event.set() send_message( b'/_oscpy/routes', @@ -887,89 +912,67 @@ def cb(*routes): encoding_errors='strict' ) - timeout = time() + 2 - while not values: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) - + runner(osc) + assert event.wait(1) assert u'/test_route' in values -def test_get_sender(): - osc = OSCThreadServer(encoding='utf8') - osc.listen(default=True) - - values = [] +@pytest.mark.parametrize("cls", server_classes) +def test_get_sender(cls): + osc = cls(encoding='utf8') + _await(osc.listen, osc, kwargs=dict(default=True)) + event = Event() @osc.address(u'/test_route') def callback(*val): - values.append(osc.get_sender()) + osc.get_sender() + event.set() with pytest.raises(RuntimeError, - match='get_sender\(\) not called from a callback'): + match=r'get_sender\(\) not called from a callback'): osc.get_sender() send_message( - b'/test_route', - [ - osc.getaddress()[1] - ], + u'/test_route', + [osc.getaddress()[1]], *osc.getaddress(), encoding='utf8' ) - timeout = time() + 2 - while not values: - if time() > timeout: - raise OSError('timeout while waiting for success message.') - sleep(10e-9) + runner(osc) + assert event.wait(2), 'timeout while waiting for test message.' -def test_server_different_port(): +@pytest.mark.parametrize("cls", server_classes) +def test_server_different_port(cls): # used for storing values received by callback_3000 - checklist = [] + checklist = [Event(), Event()] - def callback_3000(*values): - checklist.append(values[0]) + def callback(index): + checklist[index].set() # server, will be tested: - server_3000 = OSCThreadServer(encoding='utf8') - sock_3000 = server_3000.listen(address='0.0.0.0', port=3000, default=True) - server_3000.bind(b'/callback_3000', callback_3000) - - # clients sending to different ports, used to test the server: - client_3000 = OSCClient(address='localhost', port=3000, encoding='utf8') - - # server sends message to himself, should work: - server_3000.send_message( - b'/callback_3000', - ["a"], - ip_address='localhost', - port=3000 - ) - sleep(0.05) + osc = cls(encoding='utf8') + sock = _await(osc.listen, osc, kwargs=dict(address='0.0.0.0', default=True)) + port = sock.getsockname()[1] + osc.bind('/callback', callback) - # client sends message to server, will be received properly: - client_3000.send_message(b'/callback_3000', ["b"]) - sleep(0.05) + # clients sending to different ports, used to test the osc: + client = OSCClient(address='localhost', port=port, encoding='utf8') + + # osc.send_message(b'/callback', [0], ip_address='localhost', port=port) + client.send_message('/callback', [0]) # sever sends message on different port, might crash the server on windows: - server_3000.send_message( - b'/callback_3000', - ["nobody is going to receive this"], - ip_address='localhost', - port=3001 - ) - sleep(0.05) + osc.send_message('/callback', ["nobody is going to receive this"], ip_address='localhost', port=port + 1) # client sends message to server again. if server is dead, message # will not be received: - client_3000.send_message(b'/callback_3000', ["c"]) - sleep(0.1) + client.send_message('/callback', [1]) # if 'c' is missing in the received checklist, the server thread # crashed and could not recieve the last message from the client: - assert checklist == ['a', 'b', 'c'] + runner(osc, timeout=0.1) + assert all(event.wait(1) for event in checklist) - server_3000.stop() # clean up + # osc.stop() # clean up diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..2bbc665 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,20 @@ +from sys import version_info +from time import sleep + +if version_info > (3, 5, 0): + from tests.utils_async import _await, runner, _callback +else: + def runner(osc, timeout=5, socket=None): + sleep(timeout) + if socket: + osc.stop(socket) + else: + osc.stop_all() + + def _await(something, osc, args=None, kwargs=None): + args = args or [] + kwargs = kwargs or {} + return something(*args, **kwargs) + + def _callback(osc, function): + return function diff --git a/tests/utils_async.py b/tests/utils_async.py new file mode 100644 index 0000000..8271e7b --- /dev/null +++ b/tests/utils_async.py @@ -0,0 +1,55 @@ +from functools import partial +from typing import Awaitable +from time import sleep +import asyncio + +import curio +import trio +from oscpy.server.curio_server import OSCCurioServer +from oscpy.server.trio_server import OSCTrioServer +from oscpy.server.asyncio_server import OSCAsyncioServer +from oscpy.server.thread_server import OSCThreadServer + +def _await(something, osc, args=None, kwargs=None, timeout=1): + args = args or [] + kwargs = kwargs or {} + if isinstance(osc, OSCTrioServer): + return trio.run(partial(something, *args, **kwargs)) + if isinstance(osc, OSCCurioServer): + async def wrapper(): + result = something(*args, **kwargs) + if isinstance(result, Awaitable): + result = await result + return result + return curio.run(wrapper) + else: + return something(*args, **kwargs) + +async def _trio_with_timout(process, timeout): + with trio.move_on_after(timeout): + await process() + +def runner(osc, timeout=1, socket=None): + if isinstance(osc, OSCThreadServer): + sleep(timeout) + if socket: + osc.stop(socket) + else: + osc.stop_all() + elif isinstance(osc, OSCCurioServer): + try: + curio.run(curio.timeout_after(timeout, osc.process)) + except curio.TaskTimeout: + ... + elif isinstance(osc, OSCTrioServer): + trio.run(lambda: _trio_with_timout(osc.process, timeout)) + elif isinstance(osc, OSCAsyncioServer): + loop = asyncio.get_event_loop() + loop.run_until_complete(osc.process()) + +def _callback(osc, function): + if isinstance(osc, OSCAsyncioServer): + async def _(*args, **kwargs): + return function(*args, **kwargs) + return _ + return function