diff --git a/SimpleWebSocketServer/SimpleWebSocketServer.py b/SimpleWebSocketServer/SimpleWebSocketServer.py index ccb9949..7d6acd6 100644 --- a/SimpleWebSocketServer/SimpleWebSocketServer.py +++ b/SimpleWebSocketServer/SimpleWebSocketServer.py @@ -623,20 +623,46 @@ def __init__(self, host, port, websocketclass, selectInterval = 0.1): self.serversocket.listen(5) self.selectInterval = selectInterval self.connections = {} + self.needDecorate = {} + self.needDecorateWriters = [] self.listeners = [self.serversocket] def _decorateSocket(self, sock): return sock + def _handshakeSocket(self, sock): + pass + def _constructWebSocket(self, sock, address): return self.websocketclass(self, sock, address) def close(self): - self.serversocket.close() - - for desc, conn in self.connections.items(): - conn.close() - self._handleClose(conn) + serversocket = self.serversocket + self.serversocket = None + self.listeners.remove(serversocket) + serversocket.close() + serversocket = None + + for desc, client in self.connections.items(): + # queue the CLOSE command + client.close() + + for listener in self.listeners: + if listener not in self.connections: + # SSL handshake not finished yet + listener.close() + self.listeners = [] + + # send the queued CLOSE command + self.selectInterval = 0.1 + limit = 5 + while self.connections and limit: + self.serveonce(closing=True) + limit -= 1 + + for desc, client in self.connections.items(): + # close any hung clients (most clients will be closed inside serveonce()) + self._handleClose(client) def _handleClose(self, client): client.client.close() @@ -647,58 +673,98 @@ def _handleClose(self, client): except: pass - def serveonce(self): - writers = [] - for fileno in self.listeners: - if fileno == self.serversocket: - continue - client = self.connections[fileno] - if client.sendq: - writers.append(fileno) + def serveonce(self, closing=False): + writers = self.needDecorateWriters + [fileno for fileno in self.connections if self.connections[fileno].sendq] rList, wList, xList = select(self.listeners, writers, self.listeners, self.selectInterval) + # self.close() may be called from a different thread + if (not self.serversocket) and (not closing): + return + for ready in wList: + if ready in self.needDecorateWriters: + try: + self._handshakeSocket(ready) + except ssl.SSLWantReadError: + self.needDecorateWriters.remove(ready) + continue + except ssl.SSLWantWriteError: + continue + except Exception as n: + ready.close() + self.listeners.remove(ready) + else: + self.connections[ready] = self._constructWebSocket(ready, self.needDecorate[ready]) + + # if we're here either the handshake finished or there was an unhandled exception + self.needDecorateWriters.remove(ready) + del self.needDecorate[ready] + continue + client = self.connections[ready] try: while client.sendq: opcode, payload = client.sendq.popleft() remaining = client._sendBuffer(payload) if remaining is not None: - client.sendq.appendleft((opcode, remaining)) - break + client.sendq.appendleft((opcode, remaining)) + break else: - if opcode == CLOSE: - raise Exception('received client close') + if opcode == CLOSE: + raise Exception('received client close') except Exception as n: self._handleClose(client) del self.connections[ready] - self.listeners.remove(ready) + if ready in self.listeners: + self.listeners.remove(ready) for ready in rList: if ready == self.serversocket: sock = None try: sock, address = self.serversocket.accept() + sock.settimeout(30) newsock = self._decorateSocket(sock) - newsock.setblocking(0) - fileno = newsock.fileno() - self.connections[fileno] = self._constructWebSocket(newsock, address) - self.listeners.append(fileno) + newsock.setblocking(False) + self.needDecorate[newsock] = address + self.listeners.append(newsock) except Exception as n: if sock is not None: sock.close() - else: - if ready not in self.connections: - continue - client = self.connections[ready] + continue + + if ready in self.needDecorate: try: - client._handleData() + self._handshakeSocket(ready) + except ssl.SSLWantReadError: + continue + except ssl.SSLWantWriteError: + self.needDecorateWriters.append(ready) + continue except Exception as n: - self._handleClose(client) - del self.connections[ready] + ready.close() self.listeners.remove(ready) + else: + self.connections[ready] = self._constructWebSocket(ready, self.needDecorate[ready]) + + # if we're here either the handshake finished or there was an unhandled exception + if ready in self.needDecorateWriters: + self.needDecorateWriters.remove(ready) + del self.needDecorate[ready] + continue + + if ready not in self.connections: + continue + + client = self.connections[ready] + try: + client._handleData() + except Exception as n: + self._handleClose(client) + del self.connections[ready] + self.listeners.remove(ready) for failed in xList: if failed == self.serversocket: @@ -713,7 +779,7 @@ def serveonce(self): self.listeners.remove(failed) def serveforever(self): - while True: + while self.serversocket: self.serveonce() class SimpleSSLWebSocketServer(SimpleWebSocketServer): @@ -734,9 +800,12 @@ def close(self): super(SimpleSSLWebSocketServer, self).close() def _decorateSocket(self, sock): - sslsock = self.context.wrap_socket(sock, server_side=True) + sslsock = self.context.wrap_socket(sock, server_side=True, do_handshake_on_connect=False) return sslsock + def _handshakeSocket(self, sock): + sock.do_handshake() + def _constructWebSocket(self, sock, address): ws = self.websocketclass(self, sock, address) ws.usingssl = True