From 82d68461962a175c5d24024b4764d4f1d9aa61b5 Mon Sep 17 00:00:00 2001 From: Askaholic Date: Fri, 20 Mar 2020 10:47:29 -0800 Subject: [PATCH] Performance: Don't await broadcasts (#548) * Don't await broadcasts * Pre-encode the PONG message --- server/__init__.py | 83 +++++++++---------- server/config.py | 8 -- server/lobbyconnection.py | 4 +- server/protocol/qdatastreamprotocol.py | 7 -- server/servercontext.py | 35 ++------ tests/integration_tests/test_load.py | 17 ---- tests/integration_tests/test_servercontext.py | 30 ------- 7 files changed, 50 insertions(+), 134 deletions(-) diff --git a/server/__init__.py b/server/__init__.py index c8c19b82e..6d66a29e7 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -10,7 +10,6 @@ from typing import Optional import aiomeasures -import asyncio from server.db import FAFDatabase from . import config as config @@ -50,6 +49,7 @@ DIRTY_REPORT_INTERVAL = 1 # Seconds stats = None +logger = logging.getLogger("server") if not config.ENABLE_STATSD: from . import fake_statsd @@ -59,9 +59,7 @@ stats = aiomeasures.StatsD(config.STATSD_SERVER) -def encode_message(message: str): - # Crazy evil encoding scheme - return QDataStreamProtocol.pack_message(message) +PING_MSG = QDataStreamProtocol.pack_message('PING') def run_lobby_server( @@ -86,63 +84,60 @@ async def do_report_dirties(): games.clear_dirty() player_service.clear_dirty() - tasks = [] - if dirty_queues: - tasks.append( - ctx.broadcast({ + try: + if dirty_queues: + ctx.write_broadcast({ 'command': 'matchmaker_info', 'queues': [queue.to_dict() for queue in dirty_queues] }, lambda lobby_conn: lobby_conn.authenticated ) - ) + except Exception: + logger.exception("Error writing matchmaker_info") - if dirty_players: - tasks.append( - ctx.broadcast({ + try: + if dirty_players: + ctx.write_broadcast({ 'command': 'player_info', 'players': [player.to_dict() for player in dirty_players] }, lambda lobby_conn: lobby_conn.authenticated ) - ) + except Exception: + logger.exception("Error writing player_info") # TODO: This spams squillions of messages: we should implement per- # connection message aggregation at the next abstraction layer down :P for game in dirty_games: - if game.state == GameState.ENDED: - games.remove_game(game) - - # So we're going to be broadcasting this to _somebody_... - message = game.to_dict() - - # These games shouldn't be broadcast, but instead privately sent - # to those who are allowed to see them. - if game.visibility == VisibilityState.FRIENDS: - # To see this game, you must have an authenticated - # connection and be a friend of the host, or the host. - def validation_func(lobby_conn): - return lobby_conn.player.id in game.host.friends or \ - lobby_conn.player == game.host - else: - def validation_func(lobby_conn): - return lobby_conn.player.id not in game.host.foes - - tasks.append(ctx.broadcast( - message, - lambda lobby_conn: lobby_conn.authenticated and validation_func(lobby_conn) - )) - - try: - await asyncio.gather(*tasks) - except Exception as e: - logging.getLogger().exception(e) - - ping_msg = encode_message('PING') + try: + if game.state == GameState.ENDED: + games.remove_game(game) + + # So we're going to be broadcasting this to _somebody_... + message = game.to_dict() + + # These games shouldn't be broadcast, but instead privately sent + # to those who are allowed to see them. + if game.visibility == VisibilityState.FRIENDS: + # To see this game, you must have an authenticated + # connection and be a friend of the host, or the host. + def validation_func(lobby_conn): + return lobby_conn.player.id in game.host.friends or \ + lobby_conn.player == game.host + else: + def validation_func(lobby_conn): + return lobby_conn.player.id not in game.host.foes + + ctx.write_broadcast( + message, + lambda lobby_conn: lobby_conn.authenticated and validation_func(lobby_conn) + ) + except Exception: + logger.exception("Error writing game_info %s", game.id) @at_interval(45) - async def ping_broadcast(): - await ctx.broadcast_raw(ping_msg) + def ping_broadcast(): + ctx.write_broadcast_raw(PING_MSG) def make_connection() -> LobbyConnection: return LobbyConnection( diff --git a/server/config.py b/server/config.py index 3ef5335fa..477d56952 100644 --- a/server/config.py +++ b/server/config.py @@ -39,14 +39,6 @@ FORCE_STEAM_LINK_AFTER_DATE = int(os.getenv('FORCE_STEAM_LINK_AFTER_DATE', 1536105599)) # 5 september 2018 by default FORCE_STEAM_LINK = os.getenv('FORCE_STEAM_LINK', 'false').lower() == 'true' -# How long we wait for a connection to read our messages before we consider -# it to be stalled. Stalled connections will be terminated if the max buffer -# size is reached. -CLIENT_STALL_TIME = int(os.getenv('CLIENT_STALL_TIME', 10)) -# Maximum number of bytes we will allow a stalled connection to get behind -# before we terminate their connection. -CLIENT_MAX_WRITE_BUFFER_SIZE = int(os.getenv('CLIENT_MAX_WRITE_BUFFER_SIZE', 2**17)) - NEWBIE_BASE_MEAN = int(os.getenv('NEWBIE_BASE_MEAN', 500)) NEWBIE_MIN_GAMES = int(os.getenv('NEWBIE_MIN_GAMES', 10)) TOP_PLAYER_MIN_RATING = int(os.getenv('TOP_PLAYER_MIN_RATING', 1600)) diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py index 41d9df7f2..adb035d15 100644 --- a/server/lobbyconnection.py +++ b/server/lobbyconnection.py @@ -38,6 +38,8 @@ from .rating import RatingType from .types import Address +PONG_MSG = QDataStreamProtocol.pack_message("PONG") + class ClientError(Exception): """ @@ -181,7 +183,7 @@ async def on_message_received(self, message): await self.abort("Error processing command") async def command_ping(self, msg): - await self.protocol.send_raw(self.protocol.pack_message('PONG')) + await self.protocol.send_raw(PONG_MSG) async def command_pong(self, msg): pass diff --git a/server/protocol/qdatastreamprotocol.py b/server/protocol/qdatastreamprotocol.py index f1962b320..591445ca5 100644 --- a/server/protocol/qdatastreamprotocol.py +++ b/server/protocol/qdatastreamprotocol.py @@ -138,13 +138,6 @@ def close(self): """ self.writer.close() - def abort(self): - """ - Close writer stream immediately discarding the buffer contents - :return: - """ - self.writer.transport.abort() - async def drain(self): """ Await the write buffer to empty. diff --git a/server/servercontext.py b/server/servercontext.py index 9f57d9eee..ead3259ac 100644 --- a/server/servercontext.py +++ b/server/servercontext.py @@ -2,10 +2,9 @@ import server -from .async_functions import gather_without_exceptions -from .config import CLIENT_MAX_WRITE_BUFFER_SIZE, CLIENT_STALL_TIME, TRACE +from .config import TRACE from .decorators import with_logger -from .protocol import DisconnectedError, QDataStreamProtocol +from .protocol import QDataStreamProtocol from .types import Address @@ -54,41 +53,23 @@ def close(self): def __contains__(self, connection): return connection in self.connections.keys() - async def broadcast(self, message, validate_fn=lambda a: True): - await self.broadcast_raw( + def write_broadcast(self, message, validate_fn=lambda a: True): + self.write_broadcast_raw( QDataStreamProtocol.encode_message(message), validate_fn ) self._logger.log(TRACE, "]]: %s", message) - async def broadcast_raw(self, message, validate_fn=lambda a: True): + def write_broadcast_raw(self, data, validate_fn=lambda a: True): server.stats.incr('server.broadcasts') - tasks = [] for conn, proto in self.connections.items(): try: if proto.is_connected() and validate_fn(conn): - tasks.append( - self._send_raw_with_stall_handling(proto, message) - ) + proto.writer.write(data) except Exception: - self._logger.exception("Encountered error in broadcast") - - await gather_without_exceptions(tasks, DisconnectedError) - - async def _send_raw_with_stall_handling(self, proto, message): - try: - await asyncio.wait_for( - proto.send_raw(message), - timeout=CLIENT_STALL_TIME - ) - except asyncio.TimeoutError: - buffer_size = proto.writer.transport.get_write_buffer_size() - if buffer_size > CLIENT_MAX_WRITE_BUFFER_SIZE: - self._logger.warning( - "Terminating stalled connection with buffer size: %i", - buffer_size + self._logger.exception( + "Encountered error in broadcast: %s", conn ) - proto.abort() async def client_connected(self, stream_reader, stream_writer): self._logger.debug("%s: Client connected", self) diff --git a/tests/integration_tests/test_load.py b/tests/integration_tests/test_load.py index 4aacc0e74..d3c0295f6 100644 --- a/tests/integration_tests/test_load.py +++ b/tests/integration_tests/test_load.py @@ -143,20 +143,3 @@ async def test_backpressure_handling(lobby_server, caplog): with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(write_without_reading(proto), 10) - - -@fast_forward(1000) -async def test_backpressure_handling_stalls(lobby_server, caplog): - # TRACE will be spammed with thousands of messages - caplog.set_level(logging.DEBUG) - - _, _, proto = await connect_and_sign_in( - ("test", "test_password"), lobby_server - ) - # Set our local buffer size to 0 to help the server apply backpressure as - # early as possible. - proto.writer.transport.set_write_buffer_limits(high=0) - proto.reader._limit = 0 - - with pytest.raises(DisconnectedError): - await write_without_reading(proto) diff --git a/tests/integration_tests/test_servercontext.py b/tests/integration_tests/test_servercontext.py index d25531639..bf7bcf8a9 100644 --- a/tests/integration_tests/test_servercontext.py +++ b/tests/integration_tests/test_servercontext.py @@ -70,36 +70,6 @@ async def test_serverside_abort(event_loop, mock_context, mock_server): mock_server.on_connection_lost.assert_any_call() -async def test_broadcast_raw(context, mock_server): - srv, ctx = context - (reader, writer) = await asyncio.open_connection( - *srv.sockets[0].getsockname() - ) - writer.close() - - # If connection errors aren't handled, this should fail due to a - # ConnectionError - for _ in range(20): - await ctx.broadcast_raw(b"Some bytes") - - assert len(ctx.connections) == 0 - - -async def test_broadcast(context, mock_server): - srv, ctx = context - (reader, writer) = await asyncio.open_connection( - *srv.sockets[0].getsockname() - ) - writer.close() - - # If connection errors aren't handled, this should fail due to a - # ConnectionError - for _ in range(20): - await ctx.broadcast(["Some message"]) - - assert len(ctx.connections) == 0 - - async def test_connection_broken_external(context, mock_server): """ When the connection breaks while the server is calling protocol.send from