From 25add2a0bac43823a4a5ef9217578a2cf5bcfc69 Mon Sep 17 00:00:00 2001 From: Ratan Kulshreshtha Date: Wed, 19 Jun 2019 18:41:31 +0530 Subject: [PATCH] Change addCleanup() calls to context managers in tests (#1624) --- test/contrib/test_socks.py | 179 +- .../with_dummyserver/test_chunked_transfer.py | 76 +- test/with_dummyserver/test_connectionpool.py | 1457 +++++++++-------- test/with_dummyserver/test_https.py | 827 +++++----- test/with_dummyserver/test_no_ssl.py | 20 +- test/with_dummyserver/test_poolmanager.py | 498 +++--- .../test_proxy_poolmanager.py | 552 ++++--- test/with_dummyserver/test_socketlevel.py | 607 ++++--- 8 files changed, 2089 insertions(+), 2127 deletions(-) diff --git a/test/contrib/test_socks.py b/test/contrib/test_socks.py index 8a253a8f..56303b8e 100644 --- a/test/contrib/test_socks.py +++ b/test/contrib/test_socks.py @@ -230,13 +230,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_local_dns(self): def request_handler(listener): @@ -264,13 +263,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://localhost") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://localhost") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_correct_header_line(self): def request_handler(listener): @@ -302,10 +300,9 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://example.com") - assert response.status == 200 + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://example.com") + assert response.status == 200 def test_connection_timeouts(self): event = threading.Event() @@ -315,12 +312,11 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url) as pm: - with pytest.raises(ConnectTimeoutError): - pm.request("GET", "http://example.com", timeout=0.001, retries=False) - event.set() + with pytest.raises(ConnectTimeoutError): + pm.request("GET", "http://example.com", timeout=0.001, retries=False) + event.set() def test_connection_failure(self): event = threading.Event() @@ -331,12 +327,11 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url) as pm: - event.wait() - with 
pytest.raises(NewConnectionError): - pm.request("GET", "http://example.com", retries=False) + event.wait() + with pytest.raises(NewConnectionError): + pm.request("GET", "http://example.com", retries=False) def test_proxy_rejection(self): evt = threading.Event() @@ -353,12 +348,11 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url) as pm: - with pytest.raises(NewConnectionError): - pm.request("GET", "http://example.com", retries=False) - evt.set() + with pytest.raises(NewConnectionError): + pm.request("GET", "http://example.com", retries=False) + evt.set() def test_socks_with_password(self): def request_handler(listener): @@ -388,14 +382,13 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="user", password="pass") - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url, username="user", password="pass") as pm: - response = pm.request("GET", "http://16.17.18.19") + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_socks_with_auth_in_url(self): """ @@ -430,14 +423,13 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://user:pass@%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url) as pm: - response = pm.request("GET", "http://16.17.18.19") + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_socks_with_invalid_password(self): def request_handler(listener): @@ -450,15 +442,16 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="user", password="badpass") - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager( + proxy_url, username="user", password="badpass" + ) as pm: - try: - pm.request("GET", "http://example.com", retries=False) - except NewConnectionError as e: - assert "SOCKS5 authentication failed" in str(e) - else: - self.fail("Did not raise") + try: + pm.request("GET", "http://example.com", retries=False) + except NewConnectionError as e: + assert "SOCKS5 authentication failed" in str(e) + else: + self.fail("Did not raise") def test_source_address_works(self): expected_port = _get_free_port(self.host) @@ -490,12 +483,11 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager( + with socks.SOCKSProxyManager( proxy_url, source_address=("127.0.0.1", expected_port) - ) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 + ) as pm: + response = pm.request("GET", "http://16.17.18.19") + assert response.status == 200 class TestSOCKS4Proxy(IPV4SocketDummyServerTestCase): @@ -532,13 
+524,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.headers["Server"] == "SocksTestServer" - assert response.data == b"" + assert response.status == 200 + assert response.headers["Server"] == "SocksTestServer" + assert response.data == b"" def test_local_dns(self): def request_handler(listener): @@ -566,13 +557,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://localhost") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://localhost") - assert response.status == 200 - assert response.headers["Server"] == "SocksTestServer" - assert response.data == b"" + assert response.status == 200 + assert response.headers["Server"] == "SocksTestServer" + assert response.data == b"" def test_correct_header_line(self): def request_handler(listener): @@ -604,10 +594,9 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4a://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://example.com") - assert response.status == 200 + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://example.com") + assert response.status == 200 def test_proxy_rejection(self): evt = threading.Event() @@ -624,12 +613,11 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4a://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url) as pm: - with pytest.raises(NewConnectionError): - pm.request("GET", "http://example.com", retries=False) - evt.set() + with pytest.raises(NewConnectionError): + pm.request("GET", "http://example.com", retries=False) + evt.set() def test_socks4_with_username(self): def request_handler(listener): @@ -657,13 +645,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="user") - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") + with socks.SOCKSProxyManager(proxy_url, username="user") as pm: + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_socks_with_invalid_username(self): def request_handler(listener): @@ -674,15 +661,14 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4a://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="baduser") - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url, username="baduser") as pm: - try: - pm.request("GET", "http://example.com", retries=False) - except NewConnectionError as e: - assert "different user-ids" in str(e) - else: - self.fail("Did not raise") 
+ try: + pm.request("GET", "http://example.com", retries=False) + except NewConnectionError as e: + assert "different user-ids" in str(e) + else: + self.fail("Did not raise") class TestSOCKSWithTLS(IPV4SocketDummyServerTestCase): @@ -726,10 +712,9 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, ca_certs=DEFAULT_CA) - self.addCleanup(pm.clear) - response = pm.request("GET", "https://localhost") + with socks.SOCKSProxyManager(proxy_url, ca_certs=DEFAULT_CA) as pm: + response = pm.request("GET", "https://localhost") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" diff --git a/test/with_dummyserver/test_chunked_transfer.py b/test/with_dummyserver/test_chunked_transfer.py index c5746ddd..7470b07c 100644 --- a/test/with_dummyserver/test_chunked_transfer.py +++ b/test/with_dummyserver/test_chunked_transfer.py @@ -27,38 +27,36 @@ def socket_handler(listener): def test_chunks(self): self.start_chunked_handler() chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] - pool = HTTPConnectionPool(self.host, self.port, retries=False) - pool.urlopen("GET", "/", chunks, headers=dict(DNT="1"), chunked=True) - self.addCleanup(pool.close) - - assert b"Transfer-Encoding" in self.buffer - body = self.buffer.split(b"\r\n\r\n", 1)[1] - lines = body.split(b"\r\n") - # Empty chunks should have been skipped, as this could not be distinguished - # from terminating the transmission - for i, chunk in enumerate([c for c in chunks if c]): - assert lines[i * 2] == hex(len(chunk))[2:].encode("utf-8") - assert lines[i * 2 + 1] == chunk.encode("utf-8") + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", chunks, headers=dict(DNT="1"), chunked=True) + + assert b"Transfer-Encoding" in self.buffer + body = self.buffer.split(b"\r\n\r\n", 1)[1] + lines = body.split(b"\r\n") + # Empty chunks should have been skipped, as this could not be distinguished + # from terminating the transmission + for i, chunk in enumerate([c for c in chunks if c]): + assert lines[i * 2] == hex(len(chunk))[2:].encode("utf-8") + assert lines[i * 2 + 1] == chunk.encode("utf-8") def _test_body(self, data): self.start_chunked_handler() - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - pool.urlopen("GET", "/", data, chunked=True) - header, body = self.buffer.split(b"\r\n\r\n", 1) + pool.urlopen("GET", "/", data, chunked=True) + header, body = self.buffer.split(b"\r\n\r\n", 1) - assert b"Transfer-Encoding: chunked" in header.split(b"\r\n") - if data: - bdata = data if isinstance(data, bytes) else data.encode("utf-8") - assert b"\r\n" + bdata + b"\r\n" in body - assert body.endswith(b"\r\n0\r\n\r\n") + assert b"Transfer-Encoding: chunked" in header.split(b"\r\n") + if data: + bdata = data if isinstance(data, bytes) else data.encode("utf-8") + assert b"\r\n" + bdata + b"\r\n" in body + assert body.endswith(b"\r\n0\r\n\r\n") - len_str = body.split(b"\r\n", 1)[0] - stated_len = int(len_str, 16) - assert stated_len == len(bdata) - else: - assert body == b"0\r\n\r\n" + len_str = body.split(b"\r\n", 1)[0] + stated_len = int(len_str, 16) + assert stated_len == len(bdata) + else: + assert body == 
b"0\r\n\r\n" def test_bytestring_body(self): self._test_body(b"thisshouldbeonechunk\r\nasdf") @@ -83,25 +81,23 @@ def test_empty_iterable_body(self): def test_removes_duplicate_host_header(self): self.start_chunked_handler() chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.urlopen("GET", "/", chunks, headers={"Host": "test.org"}, chunked=True) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", chunks, headers={"Host": "test.org"}, chunked=True) - header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() - header_lines = header_block.split(b"\r\n")[1:] + header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() + header_lines = header_block.split(b"\r\n")[1:] - host_headers = [x for x in header_lines if x.startswith(b"host")] - assert len(host_headers) == 1 + host_headers = [x for x in header_lines if x.startswith(b"host")] + assert len(host_headers) == 1 def test_provides_default_host_header(self): self.start_chunked_handler() chunks = ["foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz"] - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.urlopen("GET", "/", chunks, chunked=True) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", chunks, chunked=True) - header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() - header_lines = header_block.split(b"\r\n")[1:] + header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() + header_lines = header_block.split(b"\r\n")[1:] - host_headers = [x for x in header_lines if x.startswith(b"host")] - assert len(host_headers) == 1 + host_headers = [x for x in header_lines if x.startswith(b"host")] + assert len(host_headers) == 1 diff --git a/test/with_dummyserver/test_connectionpool.py b/test/with_dummyserver/test_connectionpool.py index 65580055..7dfdba81 100644 --- a/test/with_dummyserver/test_connectionpool.py +++ b/test/with_dummyserver/test_connectionpool.py @@ -51,41 +51,39 @@ def test_timeout_float(self): ready_event = self.start_basic_handler(block_send=block_event, num=2) # Pool-global timeout - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=SHORT_TIMEOUT, retries=False - ) - self.addCleanup(pool.close) - wait_for_socket(ready_event) - with pytest.raises(ReadTimeoutError): + ) as pool: + wait_for_socket(ready_event) + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") + block_event.set() # Release block + + # Shouldn't raise this time + wait_for_socket(ready_event) + block_event.set() # Pre-release block pool.request("GET", "/") - block_event.set() # Release block - - # Shouldn't raise this time - wait_for_socket(ready_event) - block_event.set() # Pre-release block - pool.request("GET", "/") def test_conn_closed(self): block_event = Event() self.start_basic_handler(block_send=block_event, num=1) - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=SHORT_TIMEOUT, retries=False - ) - self.addCleanup(pool.close) - conn = pool._get_conn() - pool._put_conn(conn) - try: - pool.urlopen("GET", "/") - self.fail("The request should fail with a timeout error.") - except ReadTimeoutError: - if conn.sock: - with pytest.raises(socket.error): - conn.sock.recv(1024) - finally: + ) as pool: + conn = pool._get_conn() pool._put_conn(conn) - - block_event.set() + try: + pool.urlopen("GET", "/") + self.fail("The request should fail 
with a timeout error.") + except ReadTimeoutError: + if conn.sock: + with pytest.raises(socket.error): + conn.sock.recv(1024) + finally: + pool._put_conn(conn) + + block_event.set() def test_timeout(self): # Requests should time out when expected @@ -94,63 +92,63 @@ def test_timeout(self): # Pool-global timeout timeout = Timeout(read=SHORT_TIMEOUT) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout, retries=False) - self.addCleanup(pool.close) + with HTTPConnectionPool( + self.host, self.port, timeout=timeout, retries=False + ) as pool: - wait_for_socket(ready_event) - conn = pool._get_conn() - with pytest.raises(ReadTimeoutError): - pool._make_request(conn, "GET", "/") - pool._put_conn(conn) - block_event.set() # Release request + wait_for_socket(ready_event) + conn = pool._get_conn() + with pytest.raises(ReadTimeoutError): + pool._make_request(conn, "GET", "/") + pool._put_conn(conn) + block_event.set() # Release request - wait_for_socket(ready_event) - block_event.clear() - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") - block_event.set() # Release request + wait_for_socket(ready_event) + block_event.clear() + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") + block_event.set() # Release request # Request-specific timeouts should raise errors - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=LONG_TIMEOUT, retries=False - ) - self.addCleanup(pool.close) - - conn = pool._get_conn() - wait_for_socket(ready_event) - now = time.time() - with pytest.raises(ReadTimeoutError): - pool._make_request(conn, "GET", "/", timeout=timeout) - delta = time.time() - now - block_event.set() # Release request - - message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" - assert delta < LONG_TIMEOUT, message - pool._put_conn(conn) - - wait_for_socket(ready_event) - now = time.time() - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/", timeout=timeout) - delta = time.time() - now - - message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" - assert delta < LONG_TIMEOUT, message - block_event.set() # Release request - - # Timeout int/float passed directly to request and _make_request should - # raise a request timeout - wait_for_socket(ready_event) - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/", timeout=SHORT_TIMEOUT) - block_event.set() # Release request + ) as pool: + + conn = pool._get_conn() + wait_for_socket(ready_event) + now = time.time() + with pytest.raises(ReadTimeoutError): + pool._make_request(conn, "GET", "/", timeout=timeout) + delta = time.time() - now + block_event.set() # Release request + + message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" + assert delta < LONG_TIMEOUT, message + pool._put_conn(conn) - wait_for_socket(ready_event) - conn = pool._new_conn() - # FIXME: This assert flakes sometimes. Not sure why. 
- with pytest.raises(ReadTimeoutError): - pool._make_request(conn, "GET", "/", timeout=SHORT_TIMEOUT) - block_event.set() # Release request + wait_for_socket(ready_event) + now = time.time() + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/", timeout=timeout) + delta = time.time() - now + + message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" + assert delta < LONG_TIMEOUT, message + block_event.set() # Release request + + # Timeout int/float passed directly to request and _make_request should + # raise a request timeout + wait_for_socket(ready_event) + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/", timeout=SHORT_TIMEOUT) + block_event.set() # Release request + + wait_for_socket(ready_event) + conn = pool._new_conn() + # FIXME: This assert flakes sometimes. Not sure why. + with pytest.raises(ReadTimeoutError): + pool._make_request(conn, "GET", "/", timeout=SHORT_TIMEOUT) + block_event.set() # Release request def test_connect_timeout(self): url = "/" @@ -158,47 +156,47 @@ def test_connect_timeout(self): timeout = Timeout(connect=SHORT_TIMEOUT) # Pool-global timeout - pool = HTTPConnectionPool(host, port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - with pytest.raises(ConnectTimeoutError): - pool._make_request(conn, "GET", url) + with HTTPConnectionPool(host, port, timeout=timeout) as pool: + conn = pool._get_conn() + with pytest.raises(ConnectTimeoutError): + pool._make_request(conn, "GET", url) - # Retries - retries = Retry(connect=0) - with pytest.raises(MaxRetryError): - pool.request("GET", url, retries=retries) + # Retries + retries = Retry(connect=0) + with pytest.raises(MaxRetryError): + pool.request("GET", url, retries=retries) # Request-specific connection timeouts big_timeout = Timeout(read=LONG_TIMEOUT, connect=LONG_TIMEOUT) - pool = HTTPConnectionPool(host, port, timeout=big_timeout, retries=False) - self.addCleanup(pool.close) - conn = pool._get_conn() - with pytest.raises(ConnectTimeoutError): - pool._make_request(conn, "GET", url, timeout=timeout) + with HTTPConnectionPool(host, port, timeout=big_timeout, retries=False) as pool: + conn = pool._get_conn() + with pytest.raises(ConnectTimeoutError): + pool._make_request(conn, "GET", url, timeout=timeout) - pool._put_conn(conn) - with pytest.raises(ConnectTimeoutError): - pool.request("GET", url, timeout=timeout) + pool._put_conn(conn) + with pytest.raises(ConnectTimeoutError): + pool.request("GET", url, timeout=timeout) def test_total_applies_connect(self): host, port = TARPIT_HOST, 80 timeout = Timeout(total=None, connect=SHORT_TIMEOUT) - pool = HTTPConnectionPool(host, port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) - with pytest.raises(ConnectTimeoutError): - pool._make_request(conn, "GET", "/") + with HTTPConnectionPool(host, port, timeout=timeout) as pool: + conn = pool._get_conn() + try: + with pytest.raises(ConnectTimeoutError): + pool._make_request(conn, "GET", "/") + finally: + conn.close() timeout = Timeout(connect=3, read=5, total=SHORT_TIMEOUT) - pool = HTTPConnectionPool(host, port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) - with pytest.raises(ConnectTimeoutError): - pool._make_request(conn, "GET", "/") + with HTTPConnectionPool(host, port, timeout=timeout) as pool: + conn = pool._get_conn() + try: + with pytest.raises(ConnectTimeoutError): + pool._make_request(conn, "GET", "/") + finally: + 
conn.close() def test_total_timeout(self): block_event = Event() @@ -207,59 +205,61 @@ def test_total_timeout(self): wait_for_socket(ready_event) # This will get the socket to raise an EAGAIN on the read timeout = Timeout(connect=3, read=SHORT_TIMEOUT) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") + with HTTPConnectionPool( + self.host, self.port, timeout=timeout, retries=False + ) as pool: + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") - block_event.set() - wait_for_socket(ready_event) - block_event.clear() + block_event.set() + wait_for_socket(ready_event) + block_event.clear() # The connect should succeed and this should hit the read timeout timeout = Timeout(connect=3, read=5, total=SHORT_TIMEOUT) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") + with HTTPConnectionPool( + self.host, self.port, timeout=timeout, retries=False + ) as pool: + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") def test_create_connection_timeout(self): self.start_basic_handler(block_send=Event(), num=0) # needed for self.port timeout = Timeout(connect=SHORT_TIMEOUT, total=LONG_TIMEOUT) - pool = HTTPConnectionPool( + with HTTPConnectionPool( TARPIT_HOST, self.port, timeout=timeout, retries=False - ) - self.addCleanup(pool.close) - conn = pool._new_conn() - with pytest.raises(ConnectTimeoutError): - conn.connect() + ) as pool: + conn = pool._new_conn() + with pytest.raises(ConnectTimeoutError): + conn.connect() class TestConnectionPool(HTTPDummyServerTestCase): - def setUp(self): - self.pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(self.pool.close) - def test_get(self): - r = self.pool.request("GET", "/specific_method", fields={"method": "GET"}) - assert r.status == 200, r.data + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/specific_method", fields={"method": "GET"}) + assert r.status == 200, r.data def test_post_url(self): - r = self.pool.request("POST", "/specific_method", fields={"method": "POST"}) - assert r.status == 200, r.data + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("POST", "/specific_method", fields={"method": "POST"}) + assert r.status == 200, r.data def test_urlopen_put(self): - r = self.pool.urlopen("PUT", "/specific_method?method=PUT") - assert r.status == 200, r.data + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.urlopen("PUT", "/specific_method?method=PUT") + assert r.status == 200, r.data def test_wrong_specific_method(self): # To make sure the dummy server is actually returning failed responses - r = self.pool.request("GET", "/specific_method", fields={"method": "POST"}) - assert r.status == 400, r.data + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/specific_method", fields={"method": "POST"}) + assert r.status == 400, r.data - r = self.pool.request("POST", "/specific_method", fields={"method": "GET"}) - assert r.status == 400, r.data + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("POST", "/specific_method", fields={"method": "GET"}) + assert r.status == 400, r.data def test_upload(self): data = "I'm in ur multipart form-data, hazing a cheezburgr" @@ -270,28 +270,31 @@ def test_upload(self): "filefield": ("lolcat.txt", 
data), } - r = self.pool.request("POST", "/upload", fields=fields) - assert r.status == 200, r.data + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("POST", "/upload", fields=fields) + assert r.status == 200, r.data def test_one_name_multiple_values(self): fields = [("foo", "a"), ("foo", "b")] - # urlencode - r = self.pool.request("GET", "/echo", fields=fields) - assert r.data == b"foo=a&foo=b" + with HTTPConnectionPool(self.host, self.port) as pool: + # urlencode + r = pool.request("GET", "/echo", fields=fields) + assert r.data == b"foo=a&foo=b" - # multipart - r = self.pool.request("POST", "/echo", fields=fields) - assert r.data.count(b'name="foo"') == 2 + # multipart + r = pool.request("POST", "/echo", fields=fields) + assert r.data.count(b'name="foo"') == 2 def test_request_method_body(self): - body = b"hi" - r = self.pool.request("POST", "/echo", body=body) - assert r.data == body + with HTTPConnectionPool(self.host, self.port) as pool: + body = b"hi" + r = pool.request("POST", "/echo", body=body) + assert r.data == body - fields = [("hi", "hello")] - with pytest.raises(TypeError): - self.pool.request("POST", "/echo", body=body, fields=fields) + fields = [("hi", "hello")] + with pytest.raises(TypeError): + pool.request("POST", "/echo", body=body, fields=fields) def test_unicode_upload(self): fieldname = u("myfile") @@ -305,461 +308,475 @@ def test_unicode_upload(self): u("upload_size"): size, fieldname: (filename, data), } - - r = self.pool.request("POST", "/upload", fields=fields) - assert r.status == 200, r.data + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("POST", "/upload", fields=fields) + assert r.status == 200, r.data def test_nagle(self): """ Test that connections have TCP_NODELAY turned on """ # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) - pool._make_request(conn, "GET", "/") - tcp_nodelay_setting = conn.sock.getsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY - ) - assert tcp_nodelay_setting + with HTTPConnectionPool(self.host, self.port) as pool: + conn = pool._get_conn() + try: + pool._make_request(conn, "GET", "/") + tcp_nodelay_setting = conn.sock.getsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY + ) + assert tcp_nodelay_setting + finally: + conn.close() def test_socket_options(self): """Test that connections accept socket options.""" # This test needs to be here in order to be run. socket.create_connection actually tries to # connect to the host provided so we need a dummyserver to be running. - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, socket_options=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)], - ) - s = pool._new_conn()._new_conn() # Get the socket - using_keepalive = s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) > 0 - assert using_keepalive - s.close() + ) as pool: + s = pool._new_conn()._new_conn() # Get the socket + try: + using_keepalive = ( + s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) > 0 + ) + assert using_keepalive + finally: + s.close() def test_disable_default_socket_options(self): """Test that passing None disables all socket options.""" # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. 
- pool = HTTPConnectionPool(self.host, self.port, socket_options=None) - s = pool._new_conn()._new_conn() - using_nagle = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) == 0 - assert using_nagle - s.close() + with HTTPConnectionPool(self.host, self.port, socket_options=None) as pool: + s = pool._new_conn()._new_conn() + try: + using_nagle = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) == 0 + assert using_nagle + finally: + s.close() def test_defaults_are_applied(self): """Test that modifying the default socket options works.""" # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - # Get the HTTPConnection instance - conn = pool._new_conn() - self.addCleanup(conn.close) - # Update the default socket options - conn.default_socket_options += [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)] - s = conn._new_conn() - self.addCleanup(s.close) - nagle_disabled = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) > 0 - using_keepalive = s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) > 0 - assert nagle_disabled - assert using_keepalive + with HTTPConnectionPool(self.host, self.port) as pool: + # Get the HTTPConnection instance + conn = pool._new_conn() + try: + # Update the default socket options + conn.default_socket_options += [ + (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + ] + s = conn._new_conn() + nagle_disabled = ( + s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) > 0 + ) + using_keepalive = ( + s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) > 0 + ) + assert nagle_disabled + assert using_keepalive + finally: + conn.close() + s.close() def test_connection_error_retries(self): """ ECONNREFUSED error should raise a connection error, with retries """ port = find_unused_port() - pool = HTTPConnectionPool(self.host, port) - try: - pool.request("GET", "/", retries=Retry(connect=3)) - self.fail("Should have failed with a connection error.") - except MaxRetryError as e: - assert type(e.reason) == NewConnectionError + with HTTPConnectionPool(self.host, port) as pool: + try: + pool.request("GET", "/", retries=Retry(connect=3)) + self.fail("Should have failed with a connection error.") + except MaxRetryError as e: + assert type(e.reason) == NewConnectionError def test_timeout_success(self): timeout = Timeout(connect=3, read=5, total=None) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - pool.request("GET", "/") - # This should not raise a "Timeout already started" error - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + pool.request("GET", "/") + # This should not raise a "Timeout already started" error + pool.request("GET", "/") - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - # This should also not raise a "Timeout already started" error - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + # This should also not raise a "Timeout already started" error + pool.request("GET", "/") timeout = Timeout(total=None) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + pool.request("GET", "/") def test_tunnel(self): # note the actual httplib.py has no 
tests for this functionality timeout = Timeout(total=None) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) - conn.set_tunnel(self.host, self.port) - - conn._tunnel = mock.Mock(return_value=None) - pool._make_request(conn, "GET", "/") - conn._tunnel.assert_called_once_with() + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + conn = pool._get_conn() + try: + conn.set_tunnel(self.host, self.port) + conn._tunnel = mock.Mock(return_value=None) + pool._make_request(conn, "GET", "/") + conn._tunnel.assert_called_once_with() + finally: + conn.close() # test that it's not called when tunnel is not set timeout = Timeout(total=None) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) - - conn._tunnel = mock.Mock(return_value=None) - pool._make_request(conn, "GET", "/") - assert not conn._tunnel.called + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + conn = pool._get_conn() + try: + conn._tunnel = mock.Mock(return_value=None) + pool._make_request(conn, "GET", "/") + assert not conn._tunnel.called + finally: + conn.close() def test_redirect(self): - r = self.pool.request( - "GET", "/redirect", fields={"target": "/"}, redirect=False - ) - assert r.status == 303 + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/redirect", fields={"target": "/"}, redirect=False) + assert r.status == 303 - r = self.pool.request("GET", "/redirect", fields={"target": "/"}) - assert r.status == 200 - assert r.data == b"Dummy server!" + r = pool.request("GET", "/redirect", fields={"target": "/"}) + assert r.status == 200 + assert r.data == b"Dummy server!" def test_bad_connect(self): - pool = HTTPConnectionPool("badhost.invalid", self.port) - try: - pool.request("GET", "/", retries=5) - self.fail("should raise timeout exception here") - except MaxRetryError as e: - assert type(e.reason) == NewConnectionError + with HTTPConnectionPool("badhost.invalid", self.port) as pool: + try: + pool.request("GET", "/", retries=5) + self.fail("should raise timeout exception here") + except MaxRetryError as e: + assert type(e.reason) == NewConnectionError def test_keepalive(self): - pool = HTTPConnectionPool(self.host, self.port, block=True, maxsize=1) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port, block=True, maxsize=1) as pool: + r = pool.request("GET", "/keepalive?close=0") + r = pool.request("GET", "/keepalive?close=0") - r = pool.request("GET", "/keepalive?close=0") - r = pool.request("GET", "/keepalive?close=0") - - assert r.status == 200 - assert pool.num_connections == 1 - assert pool.num_requests == 2 + assert r.status == 200 + assert pool.num_connections == 1 + assert pool.num_requests == 2 def test_keepalive_close(self): - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, block=True, maxsize=1, timeout=2 - ) - self.addCleanup(pool.close) + ) as pool: - r = pool.request( - "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} - ) + r = pool.request( + "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} + ) - assert pool.num_connections == 1 + assert pool.num_connections == 1 - # The dummyserver will have responded with Connection:close, - # and httplib will properly cleanup the socket. 
+ # The dummyserver will have responded with Connection:close, + # and httplib will properly cleanup the socket. - # We grab the HTTPConnection object straight from the Queue, - # because _get_conn() is where the check & reset occurs - # pylint: disable-msg=W0212 - conn = pool.pool.get() - assert conn.sock is None - pool._put_conn(conn) + # We grab the HTTPConnection object straight from the Queue, + # because _get_conn() is where the check & reset occurs + # pylint: disable-msg=W0212 + conn = pool.pool.get() + assert conn.sock is None + pool._put_conn(conn) - # Now with keep-alive - r = pool.request( - "GET", "/keepalive?close=0", retries=0, headers={"Connection": "keep-alive"} - ) + # Now with keep-alive + r = pool.request( + "GET", + "/keepalive?close=0", + retries=0, + headers={"Connection": "keep-alive"}, + ) - # The dummyserver responded with Connection:keep-alive, the connection - # persists. - conn = pool.pool.get() - assert conn.sock is not None - pool._put_conn(conn) + # The dummyserver responded with Connection:keep-alive, the connection + # persists. + conn = pool.pool.get() + assert conn.sock is not None + pool._put_conn(conn) - # Another request asking the server to close the connection. This one - # should get cleaned up for the next request. - r = pool.request( - "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} - ) + # Another request asking the server to close the connection. This one + # should get cleaned up for the next request. + r = pool.request( + "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} + ) - assert r.status == 200 + assert r.status == 200 - conn = pool.pool.get() - assert conn.sock is None - pool._put_conn(conn) + conn = pool.pool.get() + assert conn.sock is None + pool._put_conn(conn) - # Next request - r = pool.request("GET", "/keepalive?close=0") + # Next request + r = pool.request("GET", "/keepalive?close=0") def test_post_with_urlencode(self): - data = {"banana": "hammock", "lol": "cat"} - r = self.pool.request("POST", "/echo", fields=data, encode_multipart=False) - assert r.data.decode("utf-8") == urlencode(data) + with HTTPConnectionPool(self.host, self.port) as pool: + data = {"banana": "hammock", "lol": "cat"} + r = pool.request("POST", "/echo", fields=data, encode_multipart=False) + assert r.data.decode("utf-8") == urlencode(data) def test_post_with_multipart(self): - data = {"banana": "hammock", "lol": "cat"} - r = self.pool.request("POST", "/echo", fields=data, encode_multipart=True) - body = r.data.split(b"\r\n") - - encoded_data = encode_multipart_formdata(data)[0] - expected_body = encoded_data.split(b"\r\n") - - # TODO: Get rid of extra parsing stuff when you can specify - # a custom boundary to encode_multipart_formdata - """ - We need to loop the return lines because a timestamp is attached - from within encode_multipart_formdata. When the server echos back - the data, it has the timestamp from when the data was encoded, which - is not equivalent to when we run encode_multipart_formdata on - the data again. 
- """ - for i, line in enumerate(body): - if line.startswith(b"--"): - continue - - assert body[i] == expected_body[i] + with HTTPConnectionPool(self.host, self.port) as pool: + data = {"banana": "hammock", "lol": "cat"} + r = pool.request("POST", "/echo", fields=data, encode_multipart=True) + body = r.data.split(b"\r\n") + + encoded_data = encode_multipart_formdata(data)[0] + expected_body = encoded_data.split(b"\r\n") + + # TODO: Get rid of extra parsing stuff when you can specify + # a custom boundary to encode_multipart_formdata + """ + We need to loop the return lines because a timestamp is attached + from within encode_multipart_formdata. When the server echos back + the data, it has the timestamp from when the data was encoded, which + is not equivalent to when we run encode_multipart_formdata on + the data again. + """ + for i, line in enumerate(body): + if line.startswith(b"--"): + continue + + assert body[i] == expected_body[i] def test_post_with_multipart__iter__(self): - data = {"hello": "world"} - r = self.pool.request( - "POST", - "/echo", - fields=data, - preload_content=False, - multipart_boundary="boundary", - encode_multipart=True, - ) + with HTTPConnectionPool(self.host, self.port) as pool: + data = {"hello": "world"} + r = pool.request( + "POST", + "/echo", + fields=data, + preload_content=False, + multipart_boundary="boundary", + encode_multipart=True, + ) - chunks = [chunk for chunk in r] - assert chunks == [ - b"--boundary\r\n", - b'Content-Disposition: form-data; name="hello"\r\n', - b"\r\n", - b"world\r\n", - b"--boundary--\r\n", - ] + chunks = [chunk for chunk in r] + assert chunks == [ + b"--boundary\r\n", + b'Content-Disposition: form-data; name="hello"\r\n', + b"\r\n", + b"world\r\n", + b"--boundary--\r\n", + ] def test_check_gzip(self): - r = self.pool.request( - "GET", "/encodingrequest", headers={"accept-encoding": "gzip"} - ) - assert r.headers.get("content-encoding") == "gzip" - assert r.data == b"hello, world!" + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request( + "GET", "/encodingrequest", headers={"accept-encoding": "gzip"} + ) + assert r.headers.get("content-encoding") == "gzip" + assert r.data == b"hello, world!" def test_check_deflate(self): - r = self.pool.request( - "GET", "/encodingrequest", headers={"accept-encoding": "deflate"} - ) - assert r.headers.get("content-encoding") == "deflate" - assert r.data == b"hello, world!" + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request( + "GET", "/encodingrequest", headers={"accept-encoding": "deflate"} + ) + assert r.headers.get("content-encoding") == "deflate" + assert r.data == b"hello, world!" 
def test_bad_decode(self): - with pytest.raises(DecodeError): - self.pool.request( - "GET", - "/encodingrequest", - headers={"accept-encoding": "garbage-deflate"}, - ) + with HTTPConnectionPool(self.host, self.port) as pool: + with pytest.raises(DecodeError): + pool.request( + "GET", + "/encodingrequest", + headers={"accept-encoding": "garbage-deflate"}, + ) - with pytest.raises(DecodeError): - self.pool.request( - "GET", "/encodingrequest", headers={"accept-encoding": "garbage-gzip"} - ) + with pytest.raises(DecodeError): + pool.request( + "GET", + "/encodingrequest", + headers={"accept-encoding": "garbage-gzip"}, + ) def test_connection_count(self): - pool = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(pool.close) - - pool.request("GET", "/") - pool.request("GET", "/") - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, maxsize=1) as pool: + pool.request("GET", "/") + pool.request("GET", "/") + pool.request("GET", "/") - assert pool.num_connections == 1 - assert pool.num_requests == 3 + assert pool.num_connections == 1 + assert pool.num_requests == 3 def test_connection_count_bigpool(self): - http_pool = HTTPConnectionPool(self.host, self.port, maxsize=16) - self.addCleanup(http_pool.close) + with HTTPConnectionPool(self.host, self.port, maxsize=16) as http_pool: + http_pool.request("GET", "/") + http_pool.request("GET", "/") + http_pool.request("GET", "/") - http_pool.request("GET", "/") - http_pool.request("GET", "/") - http_pool.request("GET", "/") - - assert http_pool.num_connections == 1 - assert http_pool.num_requests == 3 + assert http_pool.num_connections == 1 + assert http_pool.num_requests == 3 def test_partial_response(self): - pool = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(pool.close) - - req_data = {"lol": "cat"} - resp_data = urlencode(req_data).encode("utf-8") + with HTTPConnectionPool(self.host, self.port, maxsize=1) as pool: + req_data = {"lol": "cat"} + resp_data = urlencode(req_data).encode("utf-8") - r = pool.request("GET", "/echo", fields=req_data, preload_content=False) + r = pool.request("GET", "/echo", fields=req_data, preload_content=False) - assert r.read(5) == resp_data[:5] - assert r.read() == resp_data[5:] + assert r.read(5) == resp_data[:5] + assert r.read() == resp_data[5:] def test_lazy_load_twice(self): # This test is sad and confusing. Need to figure out what's # going on with partial reads and socket reuse. 
- pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, block=True, maxsize=1, timeout=2 - ) + ) as pool: - payload_size = 1024 * 2 - first_chunk = 512 + payload_size = 1024 * 2 + first_chunk = 512 - boundary = "foo" + boundary = "foo" - req_data = {"count": "a" * payload_size} - resp_data = encode_multipart_formdata(req_data, boundary=boundary)[0] + req_data = {"count": "a" * payload_size} + resp_data = encode_multipart_formdata(req_data, boundary=boundary)[0] - req2_data = {"count": "b" * payload_size} - resp2_data = encode_multipart_formdata(req2_data, boundary=boundary)[0] + req2_data = {"count": "b" * payload_size} + resp2_data = encode_multipart_formdata(req2_data, boundary=boundary)[0] - r1 = pool.request( - "POST", - "/echo", - fields=req_data, - multipart_boundary=boundary, - preload_content=False, - ) - - assert r1.read(first_chunk) == resp_data[:first_chunk] - - try: - r2 = pool.request( + r1 = pool.request( "POST", "/echo", - fields=req2_data, + fields=req_data, multipart_boundary=boundary, preload_content=False, - pool_timeout=0.001, ) - # This branch should generally bail here, but maybe someday it will - # work? Perhaps by some sort of magic. Consider it a TODO. + assert r1.read(first_chunk) == resp_data[:first_chunk] + + try: + r2 = pool.request( + "POST", + "/echo", + fields=req2_data, + multipart_boundary=boundary, + preload_content=False, + pool_timeout=0.001, + ) - assert r2.read(first_chunk) == resp2_data[:first_chunk] + # This branch should generally bail here, but maybe someday it will + # work? Perhaps by some sort of magic. Consider it a TODO. - assert r1.read() == resp_data[first_chunk:] - assert r2.read() == resp2_data[first_chunk:] - assert pool.num_requests == 2 + assert r2.read(first_chunk) == resp2_data[:first_chunk] + + assert r1.read() == resp_data[first_chunk:] + assert r2.read() == resp2_data[first_chunk:] + assert pool.num_requests == 2 - except EmptyPoolError: - assert r1.read() == resp_data[first_chunk:] - assert pool.num_requests == 1 + except EmptyPoolError: + assert r1.read() == resp_data[first_chunk:] + assert pool.num_requests == 1 - assert pool.num_connections == 1 + assert pool.num_connections == 1 def test_for_double_release(self): MAXSIZE = 5 # Check default state - pool = HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) - self.addCleanup(pool.close) - assert pool.num_connections == 0 - assert pool.pool.qsize() == MAXSIZE + with HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) as pool: + assert pool.num_connections == 0 + assert pool.pool.qsize() == MAXSIZE - # Make an empty slot for testing - pool.pool.get() - assert pool.pool.qsize() == MAXSIZE - 1 + # Make an empty slot for testing + pool.pool.get() + assert pool.pool.qsize() == MAXSIZE - 1 - # Check state after simple request - pool.urlopen("GET", "/") - assert pool.pool.qsize() == MAXSIZE - 1 + # Check state after simple request + pool.urlopen("GET", "/") + assert pool.pool.qsize() == MAXSIZE - 1 - # Check state without release - pool.urlopen("GET", "/", preload_content=False) - assert pool.pool.qsize() == MAXSIZE - 2 + # Check state without release + pool.urlopen("GET", "/", preload_content=False) + assert pool.pool.qsize() == MAXSIZE - 2 - pool.urlopen("GET", "/") - assert pool.pool.qsize() == MAXSIZE - 2 + pool.urlopen("GET", "/") + assert pool.pool.qsize() == MAXSIZE - 2 - # Check state after read - pool.urlopen("GET", "/").data - assert pool.pool.qsize() == MAXSIZE - 2 + # Check state after read + pool.urlopen("GET", "/").data + assert 
pool.pool.qsize() == MAXSIZE - 2 - pool.urlopen("GET", "/") - assert pool.pool.qsize() == MAXSIZE - 2 + pool.urlopen("GET", "/") + assert pool.pool.qsize() == MAXSIZE - 2 def test_release_conn_parameter(self): MAXSIZE = 5 - pool = HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) - assert pool.pool.qsize() == MAXSIZE + with HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) as pool: + assert pool.pool.qsize() == MAXSIZE - # Make request without releasing connection - pool.request("GET", "/", release_conn=False, preload_content=False) - assert pool.pool.qsize() == MAXSIZE - 1 + # Make request without releasing connection + pool.request("GET", "/", release_conn=False, preload_content=False) + assert pool.pool.qsize() == MAXSIZE - 1 def test_dns_error(self): - pool = HTTPConnectionPool( + with HTTPConnectionPool( "thishostdoesnotexist.invalid", self.port, timeout=0.001 - ) - with pytest.raises(MaxRetryError): - pool.request("GET", "/test", retries=2) + ) as pool: + with pytest.raises(MaxRetryError): + pool.request("GET", "/test", retries=2) def test_source_address(self): for addr, is_ipv6 in VALID_SOURCE_ADDRESSES: if is_ipv6 and not HAS_IPV6_AND_DNS: warnings.warn("No IPv6 support: skipping.", NoIPv6Warning) continue - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, source_address=addr, retries=False - ) - self.addCleanup(pool.close) - r = pool.request("GET", "/source_address") - assert r.data == b(addr[0]) + ) as pool: + r = pool.request("GET", "/source_address") + assert r.data == b(addr[0]) def test_source_address_error(self): for addr in INVALID_SOURCE_ADDRESSES: - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, source_address=addr, retries=False - ) - # FIXME: This assert flakes sometimes. Not sure why. - with pytest.raises(NewConnectionError): - pool.request("GET", "/source_address?{0}".format(addr)) + ) as pool: + # FIXME: This assert flakes sometimes. Not sure why. 
+ with pytest.raises(NewConnectionError): + pool.request("GET", "/source_address?{0}".format(addr)) def test_stream_keepalive(self): x = 2 - for _ in range(x): - response = self.pool.request( - "GET", - "/chunked", - headers={"Connection": "keep-alive"}, - preload_content=False, - retries=False, - ) - for chunk in response.stream(): - assert chunk == b"123" + with HTTPConnectionPool(self.host, self.port) as pool: + for _ in range(x): + response = pool.request( + "GET", + "/chunked", + headers={"Connection": "keep-alive"}, + preload_content=False, + retries=False, + ) + for chunk in response.stream(): + assert chunk == b"123" - assert self.pool.num_connections == 1 - assert self.pool.num_requests == x + assert pool.num_connections == 1 + assert pool.num_requests == x def test_read_chunked_short_circuit(self): - response = self.pool.request("GET", "/chunked", preload_content=False) - response.read() - with pytest.raises(StopIteration): - next(response.read_chunked()) + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.request("GET", "/chunked", preload_content=False) + response.read() + with pytest.raises(StopIteration): + next(response.read_chunked()) def test_read_chunked_on_closed_response(self): - response = self.pool.request("GET", "/chunked", preload_content=False) - response.close() - with pytest.raises(StopIteration): - next(response.read_chunked()) + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.request("GET", "/chunked", preload_content=False) + response.close() + with pytest.raises(StopIteration): + next(response.read_chunked()) def test_chunked_gzip(self): - response = self.pool.request( - "GET", "/chunked_gzip", preload_content=False, decode_content=True - ) + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.request( + "GET", "/chunked_gzip", preload_content=False, decode_content=True + ) - assert b"123" * 4 == response.read() + assert b"123" * 4 == response.read() def test_cleanup_on_connection_error(self): """ @@ -797,295 +814,310 @@ def test_cleanup_on_connection_error(self): assert http.pool.qsize() == http.pool.maxsize def test_mixed_case_hostname(self): - pool = HTTPConnectionPool("LoCaLhOsT", self.port) - self.addCleanup(pool.close) - response = pool.request("GET", "http://LoCaLhOsT:%d/" % self.port) - assert response.status == 200 + with HTTPConnectionPool("LoCaLhOsT", self.port) as pool: + response = pool.request("GET", "http://LoCaLhOsT:%d/" % self.port) + assert response.status == 200 class TestRetry(HTTPDummyServerTestCase): - def setUp(self): - self.pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(self.pool.close) - def test_max_retry(self): - try: - r = self.pool.request("GET", "/redirect", fields={"target": "/"}, retries=0) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass + with HTTPConnectionPool(self.host, self.port) as pool: + try: + r = pool.request("GET", "/redirect", fields={"target": "/"}, retries=0) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass def test_disabled_retry(self): """ Disabled retries should disable redirect handling. 
""" - r = self.pool.request("GET", "/redirect", fields={"target": "/"}, retries=False) - assert r.status == 303 + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/redirect", fields={"target": "/"}, retries=False) + assert r.status == 303 - r = self.pool.request( - "GET", "/redirect", fields={"target": "/"}, retries=Retry(redirect=False) - ) - assert r.status == 303 + r = pool.request( + "GET", + "/redirect", + fields={"target": "/"}, + retries=Retry(redirect=False), + ) + assert r.status == 303 - pool = HTTPConnectionPool( + with HTTPConnectionPool( "thishostdoesnotexist.invalid", self.port, timeout=0.001 - ) - with pytest.raises(NewConnectionError): - pool.request("GET", "/test", retries=False) + ) as pool: + with pytest.raises(NewConnectionError): + pool.request("GET", "/test", retries=False) def test_read_retries(self): """ Should retry for status codes in the whitelist """ - retry = Retry(read=1, status_forcelist=[418]) - resp = self.pool.request( - "GET", - "/successful_retry", - headers={"test-name": "test_read_retries"}, - retries=retry, - ) - assert resp.status == 200 + with HTTPConnectionPool(self.host, self.port) as pool: + retry = Retry(read=1, status_forcelist=[418]) + resp = pool.request( + "GET", + "/successful_retry", + headers={"test-name": "test_read_retries"}, + retries=retry, + ) + assert resp.status == 200 def test_read_total_retries(self): """ HTTP response w/ status code in the whitelist should be retried """ - headers = {"test-name": "test_read_total_retries"} - retry = Retry(total=1, status_forcelist=[418]) - resp = self.pool.request( - "GET", "/successful_retry", headers=headers, retries=retry - ) - assert resp.status == 200 + with HTTPConnectionPool(self.host, self.port) as pool: + headers = {"test-name": "test_read_total_retries"} + retry = Retry(total=1, status_forcelist=[418]) + resp = pool.request( + "GET", "/successful_retry", headers=headers, retries=retry + ) + assert resp.status == 200 def test_retries_wrong_whitelist(self): """HTTP response w/ status code not in whitelist shouldn't be retried""" - retry = Retry(total=1, status_forcelist=[202]) - resp = self.pool.request( - "GET", - "/successful_retry", - headers={"test-name": "test_wrong_whitelist"}, - retries=retry, - ) - assert resp.status == 418 + with HTTPConnectionPool(self.host, self.port) as pool: + retry = Retry(total=1, status_forcelist=[202]) + resp = pool.request( + "GET", + "/successful_retry", + headers={"test-name": "test_wrong_whitelist"}, + retries=retry, + ) + assert resp.status == 418 def test_default_method_whitelist_retried(self): """ urllib3 should retry methods in the default method whitelist """ - retry = Retry(total=1, status_forcelist=[418]) - resp = self.pool.request( - "OPTIONS", - "/successful_retry", - headers={"test-name": "test_default_whitelist"}, - retries=retry, - ) - assert resp.status == 200 + with HTTPConnectionPool(self.host, self.port) as pool: + retry = Retry(total=1, status_forcelist=[418]) + resp = pool.request( + "OPTIONS", + "/successful_retry", + headers={"test-name": "test_default_whitelist"}, + retries=retry, + ) + assert resp.status == 200 def test_retries_wrong_method_list(self): """Method not in our whitelist should not be retried, even if code matches""" - headers = {"test-name": "test_wrong_method_whitelist"} - retry = Retry(total=1, status_forcelist=[418], method_whitelist=["POST"]) - resp = self.pool.request( - "GET", "/successful_retry", headers=headers, retries=retry - ) - assert resp.status == 418 + with 
HTTPConnectionPool(self.host, self.port) as pool: + headers = {"test-name": "test_wrong_method_whitelist"} + retry = Retry(total=1, status_forcelist=[418], method_whitelist=["POST"]) + resp = pool.request( + "GET", "/successful_retry", headers=headers, retries=retry + ) + assert resp.status == 418 def test_read_retries_unsuccessful(self): - headers = {"test-name": "test_read_retries_unsuccessful"} - resp = self.pool.request("GET", "/successful_retry", headers=headers, retries=1) - assert resp.status == 418 + with HTTPConnectionPool(self.host, self.port) as pool: + headers = {"test-name": "test_read_retries_unsuccessful"} + resp = pool.request("GET", "/successful_retry", headers=headers, retries=1) + assert resp.status == 418 def test_retry_reuse_safe(self): """ It should be possible to reuse a Retry object across requests """ - headers = {"test-name": "test_retry_safe"} - retry = Retry(total=1, status_forcelist=[418]) - resp = self.pool.request( - "GET", "/successful_retry", headers=headers, retries=retry - ) - assert resp.status == 200 - resp = self.pool.request( - "GET", "/successful_retry", headers=headers, retries=retry - ) - assert resp.status == 200 + with HTTPConnectionPool(self.host, self.port) as pool: + headers = {"test-name": "test_retry_safe"} + retry = Retry(total=1, status_forcelist=[418]) + resp = pool.request( + "GET", "/successful_retry", headers=headers, retries=retry + ) + assert resp.status == 200 + + with HTTPConnectionPool(self.host, self.port) as pool: + resp = pool.request( + "GET", "/successful_retry", headers=headers, retries=retry + ) + assert resp.status == 200 def test_retry_return_in_response(self): - headers = {"test-name": "test_retry_return_in_response"} - retry = Retry(total=2, status_forcelist=[418]) - resp = self.pool.request( - "GET", "/successful_retry", headers=headers, retries=retry - ) - assert resp.status == 200 - assert resp.retries.total == 1 - assert resp.retries.history == ( - RequestHistory("GET", "/successful_retry", None, 418, None), - ) + with HTTPConnectionPool(self.host, self.port) as pool: + headers = {"test-name": "test_retry_return_in_response"} + retry = Retry(total=2, status_forcelist=[418]) + resp = pool.request( + "GET", "/successful_retry", headers=headers, retries=retry + ) + assert resp.status == 200 + assert resp.retries.total == 1 + assert resp.retries.history == ( + RequestHistory("GET", "/successful_retry", None, 418, None), + ) def test_retry_redirect_history(self): - resp = self.pool.request("GET", "/redirect", fields={"target": "/"}) - assert resp.status == 200 - assert resp.retries.history == ( - RequestHistory("GET", "/redirect?target=%2F", None, 303, "/"), - ) + with HTTPConnectionPool(self.host, self.port) as pool: + resp = pool.request("GET", "/redirect", fields={"target": "/"}) + assert resp.status == 200 + assert resp.retries.history == ( + RequestHistory("GET", "/redirect?target=%2F", None, 303, "/"), + ) def test_multi_redirect_history(self): - r = self.pool.request( - "GET", - "/multi_redirect", - fields={"redirect_codes": "303,302,200"}, - redirect=False, - ) - assert r.status == 303 - assert r.retries.history == tuple() - - r = self.pool.request( - "GET", - "/multi_redirect", - retries=10, - fields={"redirect_codes": "303,302,301,307,302,200"}, - ) - assert r.status == 200 - assert r.data == b"Done redirecting" - - expected = [ - (303, "/multi_redirect?redirect_codes=302,301,307,302,200"), - (302, "/multi_redirect?redirect_codes=301,307,302,200"), - (301, "/multi_redirect?redirect_codes=307,302,200"), - (307, 
"/multi_redirect?redirect_codes=302,200"), - (302, "/multi_redirect?redirect_codes=200"), - ] - actual = [ - (history.status, history.redirect_location) for history in r.retries.history - ] - assert actual == expected + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request( + "GET", + "/multi_redirect", + fields={"redirect_codes": "303,302,200"}, + redirect=False, + ) + assert r.status == 303 + assert r.retries.history == tuple() + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request( + "GET", + "/multi_redirect", + retries=10, + fields={"redirect_codes": "303,302,301,307,302,200"}, + ) + assert r.status == 200 + assert r.data == b"Done redirecting" + + expected = [ + (303, "/multi_redirect?redirect_codes=302,301,307,302,200"), + (302, "/multi_redirect?redirect_codes=301,307,302,200"), + (301, "/multi_redirect?redirect_codes=307,302,200"), + (307, "/multi_redirect?redirect_codes=302,200"), + (302, "/multi_redirect?redirect_codes=200"), + ] + actual = [ + (history.status, history.redirect_location) + for history in r.retries.history + ] + assert actual == expected -class TestRetryAfter(HTTPDummyServerTestCase): - def setUp(self): - self.pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(self.pool.close) +class TestRetryAfter(HTTPDummyServerTestCase): def test_retry_after(self): # Request twice in a second to get a 429 response. - r = self.pool.request( - "GET", - "/retry_after", - fields={"status": "429 Too Many Requests"}, - retries=False, - ) - r = self.pool.request( - "GET", - "/retry_after", - fields={"status": "429 Too Many Requests"}, - retries=False, - ) - assert r.status == 429 + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request( + "GET", + "/retry_after", + fields={"status": "429 Too Many Requests"}, + retries=False, + ) + r = pool.request( + "GET", + "/retry_after", + fields={"status": "429 Too Many Requests"}, + retries=False, + ) + assert r.status == 429 - r = self.pool.request( - "GET", - "/retry_after", - fields={"status": "429 Too Many Requests"}, - retries=True, - ) - assert r.status == 200 - - # Request twice in a second to get a 503 response. - r = self.pool.request( - "GET", - "/retry_after", - fields={"status": "503 Service Unavailable"}, - retries=False, - ) - r = self.pool.request( - "GET", - "/retry_after", - fields={"status": "503 Service Unavailable"}, - retries=False, - ) - assert r.status == 503 + r = pool.request( + "GET", + "/retry_after", + fields={"status": "429 Too Many Requests"}, + retries=True, + ) + assert r.status == 200 - r = self.pool.request( - "GET", - "/retry_after", - fields={"status": "503 Service Unavailable"}, - retries=True, - ) - assert r.status == 200 + # Request twice in a second to get a 503 response. + r = pool.request( + "GET", + "/retry_after", + fields={"status": "503 Service Unavailable"}, + retries=False, + ) + r = pool.request( + "GET", + "/retry_after", + fields={"status": "503 Service Unavailable"}, + retries=False, + ) + assert r.status == 503 - # Ignore Retry-After header on status which is not defined in - # Retry.RETRY_AFTER_STATUS_CODES. - r = self.pool.request( - "GET", "/retry_after", fields={"status": "418 I'm a teapot"}, retries=True - ) - assert r.status == 418 + r = pool.request( + "GET", + "/retry_after", + fields={"status": "503 Service Unavailable"}, + retries=True, + ) + assert r.status == 200 + + # Ignore Retry-After header on status which is not defined in + # Retry.RETRY_AFTER_STATUS_CODES. 
+ r = pool.request( + "GET", + "/retry_after", + fields={"status": "418 I'm a teapot"}, + retries=True, + ) + assert r.status == 418 def test_redirect_after(self): - r = self.pool.request("GET", "/redirect_after", retries=False) - assert r.status == 303 - - t = time.time() - r = self.pool.request("GET", "/redirect_after") - assert r.status == 200 - delta = time.time() - t - assert delta >= 1 - - t = time.time() - timestamp = t + 2 - r = self.pool.request("GET", "/redirect_after?date=" + str(timestamp)) - assert r.status == 200 - delta = time.time() - t - assert delta >= 1 - - # Retry-After is past - t = time.time() - timestamp = t - 1 - r = self.pool.request("GET", "/redirect_after?date=" + str(timestamp)) - delta = time.time() - t - assert r.status == 200 - assert delta < 1 + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/redirect_after", retries=False) + assert r.status == 303 + + t = time.time() + r = pool.request("GET", "/redirect_after") + assert r.status == 200 + delta = time.time() - t + assert delta >= 1 + + t = time.time() + timestamp = t + 2 + r = pool.request("GET", "/redirect_after?date=" + str(timestamp)) + assert r.status == 200 + delta = time.time() - t + assert delta >= 1 + + # Retry-After is past + t = time.time() + timestamp = t - 1 + r = pool.request("GET", "/redirect_after?date=" + str(timestamp)) + delta = time.time() - t + assert r.status == 200 + assert delta < 1 class TestFileBodiesOnRetryOrRedirect(HTTPDummyServerTestCase): - def setUp(self): - self.pool = HTTPConnectionPool(self.host, self.port, timeout=0.1) - self.addCleanup(self.pool.close) - def test_retries_put_filehandle(self): """HTTP PUT retry with a file-like object should not timeout""" - retry = Retry(total=3, status_forcelist=[418]) - # httplib reads in 8k chunks; use a larger content length - content_length = 65535 - data = b"A" * content_length - uploaded_file = io.BytesIO(data) - headers = { - "test-name": "test_retries_put_filehandle", - "Content-Length": str(content_length), - } - resp = self.pool.urlopen( - "PUT", - "/successful_retry", - headers=headers, - retries=retry, - body=uploaded_file, - assert_same_host=False, - redirect=False, - ) - assert resp.status == 200 + with HTTPConnectionPool(self.host, self.port, timeout=0.1) as pool: + retry = Retry(total=3, status_forcelist=[418]) + # httplib reads in 8k chunks; use a larger content length + content_length = 65535 + data = b"A" * content_length + uploaded_file = io.BytesIO(data) + headers = { + "test-name": "test_retries_put_filehandle", + "Content-Length": str(content_length), + } + resp = pool.urlopen( + "PUT", + "/successful_retry", + headers=headers, + retries=retry, + body=uploaded_file, + assert_same_host=False, + redirect=False, + ) + assert resp.status == 200 def test_redirect_put_file(self): """PUT with file object should work with a redirection response""" - retry = Retry(total=3, status_forcelist=[418]) - # httplib reads in 8k chunks; use a larger content length - content_length = 65535 - data = b"A" * content_length - uploaded_file = io.BytesIO(data) - headers = { - "test-name": "test_redirect_put_file", - "Content-Length": str(content_length), - } - url = "/redirect?target=/echo&status=307" - resp = self.pool.urlopen( - "PUT", - url, - headers=headers, - retries=retry, - body=uploaded_file, - assert_same_host=False, - redirect=True, - ) - assert resp.status == 200 - assert resp.data == data + with HTTPConnectionPool(self.host, self.port, timeout=0.1) as pool: + retry = Retry(total=3, 
status_forcelist=[418]) + # httplib reads in 8k chunks; use a larger content length + content_length = 65535 + data = b"A" * content_length + uploaded_file = io.BytesIO(data) + headers = { + "test-name": "test_redirect_put_file", + "Content-Length": str(content_length), + } + url = "/redirect?target=/echo&status=307" + resp = pool.urlopen( + "PUT", + url, + headers=headers, + retries=retry, + body=uploaded_file, + assert_same_host=False, + redirect=True, + ) + assert resp.status == 200 + assert resp.data == data def test_redirect_with_failed_tell(self): """Abort request if failed to get a position from tell()""" @@ -1099,39 +1131,34 @@ def tell(self): # httplib uses fileno if Content-Length isn't supplied, # which is unsupported by BytesIO. headers = {"Content-Length": "8"} - try: - self.pool.urlopen("PUT", url, headers=headers, body=body) - self.fail("PUT successful despite failed rewind.") - except UnrewindableBodyError as e: - assert "Unable to record file position for" in str(e) + with HTTPConnectionPool(self.host, self.port, timeout=0.1) as pool: + try: + pool.urlopen("PUT", url, headers=headers, body=body) + self.fail("PUT successful despite failed rewind.") + except UnrewindableBodyError as e: + assert "Unable to record file position for" in str(e) class TestRetryPoolSize(HTTPDummyServerTestCase): - def setUp(self): + def test_pool_size_retry(self): retries = Retry(total=1, raise_on_status=False, status_forcelist=[404]) - self.pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, maxsize=10, retries=retries, block=True - ) - self.addCleanup(self.pool.close) - - def test_pool_size_retry(self): - self.pool.urlopen("GET", "/not_found", preload_content=False) - assert self.pool.num_connections == 1 + ) as pool: + pool.urlopen("GET", "/not_found", preload_content=False) + assert pool.num_connections == 1 class TestRedirectPoolSize(HTTPDummyServerTestCase): - def setUp(self): + def test_pool_size_redirect(self): retries = Retry( total=1, raise_on_status=False, status_forcelist=[404], redirect=True ) - self.pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, maxsize=10, retries=retries, block=True - ) - self.addCleanup(self.pool.close) - - def test_pool_size_redirect(self): - self.pool.urlopen("GET", "/redirect", preload_content=False) - assert self.pool.num_connections == 1 + ) as pool: + pool.urlopen("GET", "/redirect", preload_content=False) + assert pool.num_connections == 1 if __name__ == "__main__": diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index 2f9ead51..521c03e9 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -80,68 +80,69 @@ class TestHTTPS(HTTPSDummyServerTestCase): tls_protocol_name = None - def setUp(self): - self._pool = HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) - self.addCleanup(self._pool.close) - def test_simple(self): - r = self._pool.request("GET", "/") - assert r.status == 200, r.data + with HTTPSConnectionPool( + self.host, self.port, ca_certs=DEFAULT_CA + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200, r.data @fails_on_travis_gce def test_dotted_fqdn(self): - pool = HTTPSConnectionPool(self.host + ".", self.port, ca_certs=DEFAULT_CA) - r = pool.request("GET", "/") - assert r.status == 200, r.data + with HTTPSConnectionPool( + self.host + ".", self.port, ca_certs=DEFAULT_CA + ) as pool: + r = pool.request("GET", "/") + assert r.status == 200, r.data def 
test_client_intermediate(self): client_cert, client_key = ( DEFAULT_CLIENT_CERTS["certfile"], DEFAULT_CLIENT_CERTS["keyfile"], ) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, key_file=client_key, cert_file=client_cert, ca_certs=DEFAULT_CA, - ) - r = https_pool.request("GET", "/certificate") - subject = json.loads(r.data.decode("utf-8")) - assert subject["organizationalUnitName"].startswith("Testing server cert") + ) as https_pool: + r = https_pool.request("GET", "/certificate") + subject = json.loads(r.data.decode("utf-8")) + assert subject["organizationalUnitName"].startswith("Testing server cert") def test_client_no_intermediate(self): client_cert, client_key = ( DEFAULT_CLIENT_NO_INTERMEDIATE_CERTS["certfile"], DEFAULT_CLIENT_NO_INTERMEDIATE_CERTS["keyfile"], ) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_file=client_cert, key_file=client_key, ca_certs=DEFAULT_CA, - ) - try: - https_pool.request("GET", "/certificate", retries=False) - except SSLError as e: - if not ( - "alert unknown ca" in str(e) - or "invalid certificate chain" in str(e) - or "unknown Cert Authority" in str(e) - or + ) as https_pool: + try: + https_pool.request("GET", "/certificate", retries=False) + except SSLError as e: + if not ( + "alert unknown ca" in str(e) + or "invalid certificate chain" in str(e) + or "unknown Cert Authority" in str(e) + or + # https://github.com/urllib3/urllib3/issues/1422 + "connection closed via error" in str(e) + or "WSAECONNRESET" in str(e) + ): + raise + except ProtocolError as e: # https://github.com/urllib3/urllib3/issues/1422 - "connection closed via error" in str(e) - or "WSAECONNRESET" in str(e) - ): - raise - except ProtocolError as e: - # https://github.com/urllib3/urllib3/issues/1422 - if not ( - "An existing connection was forcibly closed by the remote host" - in str(e) - ): - raise + if not ( + "An existing connection was forcibly closed by the remote host" + in str(e) + ): + raise @requires_ssl_context_keyfile_password def test_client_key_password(self): @@ -149,17 +150,17 @@ def test_client_key_password(self): DEFAULT_CLIENT_CERTS["certfile"], PASSWORD_CLIENT_KEYFILE, ) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, ca_certs=DEFAULT_CA, key_file=client_key, cert_file=client_cert, key_password="letmein", - ) - r = https_pool.request("GET", "/certificate") - subject = json.loads(r.data.decode("utf-8")) - assert subject["organizationalUnitName"].startswith("Testing server cert") + ) as https_pool: + r = https_pool.request("GET", "/certificate") + subject = json.loads(r.data.decode("utf-8")) + assert subject["organizationalUnitName"].startswith("Testing server cert") @requires_ssl_context_keyfile_password def test_client_encrypted_key_requires_password(self): @@ -167,362 +168,344 @@ def test_client_encrypted_key_requires_password(self): DEFAULT_CLIENT_CERTS["certfile"], PASSWORD_CLIENT_KEYFILE, ) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, key_file=client_key, cert_file=client_cert, key_password=None, - ) + ) as https_pool: - with pytest.raises(MaxRetryError) as e: - https_pool.request("GET", "/certificate") + with pytest.raises(MaxRetryError) as e: + https_pool.request("GET", "/certificate") - assert "password is required" in str(e.value) - assert isinstance(e.value.reason, SSLError) + assert "password is required" in str(e.value) + assert isinstance(e.value.reason, SSLError) def test_verified(self): - https_pool 
= HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - conn = https_pool._new_conn() - assert conn.__class__ == VerifiedHTTPSConnection - - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] + ) as https_pool: + + conn = https_pool._new_conn() + assert conn.__class__ == VerifiedHTTPSConnection + + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + + # Modern versions of Python, or systems using PyOpenSSL, don't + # emit warnings. + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + assert not warn.called, warn.call_args_list else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning + assert warn.called + if util.HAS_SNI: + call = warn.call_args_list[0] + else: + call = warn.call_args_list[1] + error = call[0][1] + assert error == InsecurePlatformWarning def test_verified_with_context(self): ctx = util.ssl_.create_urllib3_context(cert_reqs=ssl.CERT_REQUIRED) ctx.load_verify_locations(cafile=DEFAULT_CA) - https_pool = HTTPSConnectionPool(self.host, self.port, ssl_context=ctx) - self.addCleanup(https_pool.close) - - conn = https_pool._new_conn() - assert conn.__class__ == VerifiedHTTPSConnection - - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] + with HTTPSConnectionPool(self.host, self.port, ssl_context=ctx) as https_pool: + + conn = https_pool._new_conn() + assert conn.__class__ == VerifiedHTTPSConnection + + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + + # Modern versions of Python, or systems using PyOpenSSL, don't + # emit warnings. + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + assert not warn.called, warn.call_args_list else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning + assert warn.called + if util.HAS_SNI: + call = warn.call_args_list[0] + else: + call = warn.call_args_list[1] + error = call[0][1] + assert error == InsecurePlatformWarning def test_context_combines_with_ca_certs(self): ctx = util.ssl_.create_urllib3_context(cert_reqs=ssl.CERT_REQUIRED) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, ca_certs=DEFAULT_CA, ssl_context=ctx - ) - self.addCleanup(https_pool.close) - - conn = https_pool._new_conn() - assert conn.__class__ == VerifiedHTTPSConnection - - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. 
- if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] + ) as https_pool: + + conn = https_pool._new_conn() + assert conn.__class__ == VerifiedHTTPSConnection + + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + + # Modern versions of Python, or systems using PyOpenSSL, don't + # emit warnings. + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + assert not warn.called, warn.call_args_list else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning + assert warn.called + if util.HAS_SNI: + call = warn.call_args_list[0] + else: + call = warn.call_args_list[1] + error = call[0][1] + assert error == InsecurePlatformWarning @onlyPy279OrNewer @notSecureTransport # SecureTransport does not support cert directories @notOpenSSL098 # OpenSSL 0.9.8 does not support cert directories def test_ca_dir_verified(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_cert_dir=DEFAULT_CA_DIR - ) - self.addCleanup(https_pool.close) + ) as https_pool: - conn = https_pool._new_conn() - assert conn.__class__ == VerifiedHTTPSConnection + conn = https_pool._new_conn() + assert conn.__class__ == VerifiedHTTPSConnection - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - assert not warn.called, warn.call_args_list + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + assert not warn.called, warn.call_args_list def test_invalid_common_name(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - try: - https_pool.request("GET", "/") - self.fail("Didn't raise SSL invalid common name") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "doesn't match" in str( - e.reason - ) or "certificate verify failed" in str(e.reason) + try: + https_pool.request("GET", "/") + self.fail("Didn't raise SSL invalid common name") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "doesn't match" in str( + e.reason + ) or "certificate verify failed" in str(e.reason) def test_verified_with_bad_ca_certs(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(https_pool.close) + ) as https_pool: - try: - https_pool.request("GET", "/") - self.fail("Didn't raise SSL error with bad CA certs") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "certificate verify failed" in str(e.reason), ( - "Expected 'certificate verify failed'," "instead got: %r" % e.reason - ) + try: + https_pool.request("GET", "/") + self.fail("Didn't raise SSL error with bad CA certs") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "certificate verify failed" in str(e.reason), ( + "Expected 'certificate verify failed'," "instead got: %r" % e.reason + ) def test_verified_without_ca_certs(self): # default is cert_reqs=None which is ssl.CERT_NONE - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, 
cert_reqs="CERT_REQUIRED" - ) - self.addCleanup(https_pool.close) - - try: - https_pool.request("GET", "/") - self.fail( - "Didn't raise SSL error with no CA certs when" "CERT_REQUIRED is set" - ) - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - # there is a different error message depending on whether or - # not pyopenssl is injected - assert ( - "No root certificates specified" in str(e.reason) - or "certificate verify failed" in str(e.reason) - or "invalid certificate chain" in str(e.reason) - ), ( - "Expected 'No root certificates specified', " - "'certificate verify failed', or " - "'invalid certificate chain', " - "instead got: %r" % e.reason - ) + ) as https_pool: + + try: + https_pool.request("GET", "/") + self.fail( + "Didn't raise SSL error with no CA certs when" + "CERT_REQUIRED is set" + ) + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + # there is a different error message depending on whether or + # not pyopenssl is injected + assert ( + "No root certificates specified" in str(e.reason) + or "certificate verify failed" in str(e.reason) + or "invalid certificate chain" in str(e.reason) + ), ( + "Expected 'No root certificates specified', " + "'certificate verify failed', or " + "'invalid certificate chain', " + "instead got: %r" % e.reason + ) def test_no_ssl(self): - pool = HTTPSConnectionPool(self.host, self.port) - pool.ConnectionCls = None - self.addCleanup(pool.close) - with pytest.raises(SSLError): - pool._new_conn() - with pytest.raises(MaxRetryError) as cm: - pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) + with HTTPSConnectionPool(self.host, self.port) as pool: + pool.ConnectionCls = None + with pytest.raises(SSLError): + pool._new_conn() + with pytest.raises(MaxRetryError) as cm: + pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) def test_unverified_ssl(self): """ Test that bare HTTPSConnection can connect, make requests """ - pool = HTTPSConnectionPool(self.host, self.port, cert_reqs=ssl.CERT_NONE) - self.addCleanup(pool.close) + with HTTPSConnectionPool(self.host, self.port, cert_reqs=ssl.CERT_NONE) as pool: - with mock.patch("warnings.warn") as warn: - r = pool.request("GET", "/") - assert r.status == 200 - assert warn.called + with mock.patch("warnings.warn") as warn: + r = pool.request("GET", "/") + assert r.status == 200 + assert warn.called - # Modern versions of Python, or systems using PyOpenSSL, only emit - # the unverified warning. Older systems may also emit other - # warnings, which we want to ignore here. - calls = warn.call_args_list - assert InsecureRequestWarning in [x[0][1] for x in calls] + # Modern versions of Python, or systems using PyOpenSSL, only emit + # the unverified warning. Older systems may also emit other + # warnings, which we want to ignore here. + calls = warn.call_args_list + assert InsecureRequestWarning in [x[0][1] for x in calls] def test_ssl_unverified_with_ca_certs(self): - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_NONE", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(pool.close) + ) as pool: - with mock.patch("warnings.warn") as warn: - r = pool.request("GET", "/") - assert r.status == 200 - assert warn.called - - # Modern versions of Python, or systems using PyOpenSSL, only emit - # the unverified warning. Older systems may also emit other - # warnings, which we want to ignore here. 
- calls = warn.call_args_list - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - category = calls[0][0][1] - elif util.HAS_SNI: - category = calls[1][0][1] - else: - category = calls[2][0][1] - assert category == InsecureRequestWarning + with mock.patch("warnings.warn") as warn: + r = pool.request("GET", "/") + assert r.status == 200 + assert warn.called + + # Modern versions of Python, or systems using PyOpenSSL, only emit + # the unverified warning. Older systems may also emit other + # warnings, which we want to ignore here. + calls = warn.call_args_list + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + category = calls[0][0][1] + elif util.HAS_SNI: + category = calls[1][0][1] + else: + category = calls[2][0][1] + assert category == InsecureRequestWarning def test_assert_hostname_false(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_hostname = False - https_pool.request("GET", "/") + https_pool.assert_hostname = False + https_pool.request("GET", "/") def test_assert_specific_hostname(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_hostname = "localhost" - https_pool.request("GET", "/") + https_pool.assert_hostname = "localhost" + https_pool.request("GET", "/") def test_server_hostname(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA, server_hostname="localhost", - ) - self.addCleanup(https_pool.close) + ) as https_pool: - conn = https_pool._new_conn() - conn.request("GET", "/") + conn = https_pool._new_conn() + conn.request("GET", "/") - # Assert the wrapping socket is using the passed-through SNI name. - # pyopenssl doesn't let you pull the server_hostname back off the - # socket, so only add this assertion if the attribute is there (i.e. - # the python ssl module). - if hasattr(conn.sock, "server_hostname"): - assert conn.sock.server_hostname == "localhost" + # Assert the wrapping socket is using the passed-through SNI name. + # pyopenssl doesn't let you pull the server_hostname back off the + # socket, so only add this assertion if the attribute is there (i.e. + # the python ssl module). 
+ if hasattr(conn.sock, "server_hostname"): + assert conn.sock.server_hostname == "localhost" def test_assert_fingerprint_md5(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_fingerprint = ( - "F2:06:5A:42:10:3F:45:1C:17:FE:E6:" "07:1E:8A:86:E5" - ) + https_pool.assert_fingerprint = ( + "F2:06:5A:42:10:3F:45:1C:17:FE:E6:" "07:1E:8A:86:E5" + ) - https_pool.request("GET", "/") + https_pool.request("GET", "/") def test_assert_fingerprint_sha1(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) - https_pool.request("GET", "/") + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) + https_pool.request("GET", "/") def test_assert_fingerprint_sha256(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_fingerprint = ( - "C5:4D:0B:83:84:89:2E:AE:B4:58:BB:12:" - "F7:A6:C4:76:05:03:88:D8:57:65:51:F3:" - "1E:60:B0:8B:70:18:64:E6" - ) - https_pool.request("GET", "/") + https_pool.assert_fingerprint = ( + "C5:4D:0B:83:84:89:2E:AE:B4:58:BB:12:" + "F7:A6:C4:76:05:03:88:D8:57:65:51:F3:" + "1E:60:B0:8B:70:18:64:E6" + ) + https_pool.request("GET", "/") def test_assert_invalid_fingerprint(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_fingerprint = ( - "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" - ) + https_pool.assert_fingerprint = ( + "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" + ) - def _test_request(pool): - with pytest.raises(MaxRetryError) as cm: - pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) + def _test_request(pool): + with pytest.raises(MaxRetryError) as cm: + pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) - _test_request(https_pool) - https_pool._get_conn() + _test_request(https_pool) + https_pool._get_conn() - # Uneven length - https_pool.assert_fingerprint = "AA:A" - _test_request(https_pool) - https_pool._get_conn() + # Uneven length + https_pool.assert_fingerprint = "AA:A" + _test_request(https_pool) + https_pool._get_conn() - # Invalid length - https_pool.assert_fingerprint = "AA" - _test_request(https_pool) + # Invalid length + https_pool.assert_fingerprint = "AA" + _test_request(https_pool) def test_verify_none_and_bad_fingerprint(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_NONE", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_fingerprint = ( - "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" - ) - with pytest.raises(MaxRetryError) as cm: - https_pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) + https_pool.assert_fingerprint = ( + "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" + ) + with 
pytest.raises(MaxRetryError) as cm: + https_pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) def test_verify_none_and_good_fingerprint(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_NONE", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) - https_pool.request("GET", "/") + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) + https_pool.request("GET", "/") @notSecureTransport def test_good_fingerprint_and_hostname_mismatch(self): @@ -530,147 +513,158 @@ def test_good_fingerprint_and_hostname_mismatch(self): # hostname validation without turning off all validation, which this # test doesn't do (deliberately). We should revisit this if we make # new decisions. - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) + ) as https_pool: - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) - https_pool.request("GET", "/") + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) + https_pool.request("GET", "/") @requires_network def test_https_timeout(self): - timeout = Timeout(connect=0.001) - https_pool = HTTPSConnectionPool( - TARPIT_HOST, - self.port, - timeout=timeout, - retries=False, - cert_reqs="CERT_REQUIRED", - ) - self.addCleanup(https_pool.close) timeout = Timeout(total=None, connect=0.001) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( TARPIT_HOST, self.port, timeout=timeout, retries=False, cert_reqs="CERT_REQUIRED", - ) - self.addCleanup(https_pool.close) - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/") + ) as https_pool: + with pytest.raises(ConnectTimeoutError): + https_pool.request("GET", "/") timeout = Timeout(read=0.01) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, timeout=timeout, retries=False, cert_reqs="CERT_REQUIRED", - ) - self.addCleanup(https_pool.close) - https_pool.ca_certs = DEFAULT_CA - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) + ) as https_pool: + https_pool.ca_certs = DEFAULT_CA + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) timeout = Timeout(total=None) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, timeout=timeout, cert_reqs="CERT_NONE" - ) - self.addCleanup(https_pool.close) - https_pool.request("GET", "/") + ) as https_pool: + https_pool.request("GET", "/") def test_tunnel(self): """ test the _tunnel behavior """ timeout = Timeout(total=None) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, timeout=timeout, cert_reqs="CERT_NONE" - ) - self.addCleanup(https_pool.close) - conn = https_pool._new_conn() - self.addCleanup(conn.close) - conn.set_tunnel(self.host, self.port) - conn._tunnel = mock.Mock() - https_pool._make_request(conn, "GET", "/") - conn._tunnel.assert_called_once_with() + ) as https_pool: + + conn = https_pool._new_conn() + try: + conn.set_tunnel(self.host, self.port) + conn._tunnel = mock.Mock() + https_pool._make_request(conn, "GET", "/") + 
conn._tunnel.assert_called_once_with() + finally: + conn.close() @requires_network def test_enhanced_timeout(self): - def new_pool(timeout, cert_reqs="CERT_REQUIRED"): - https_pool = HTTPSConnectionPool( - TARPIT_HOST, - self.port, - timeout=timeout, - retries=False, - cert_reqs=cert_reqs, - ) - self.addCleanup(https_pool.close) - return https_pool - - https_pool = new_pool(Timeout(connect=0.001)) - conn = https_pool._new_conn() - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/") - with pytest.raises(ConnectTimeoutError): - https_pool._make_request(conn, "GET", "/") - - https_pool = new_pool(Timeout(connect=5)) - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/", timeout=Timeout(connect=0.001)) + with HTTPSConnectionPool( + TARPIT_HOST, + self.port, + timeout=Timeout(connect=0.001), + retries=False, + cert_reqs="CERT_REQUIRED", + ) as https_pool: + conn = https_pool._new_conn() + try: + with pytest.raises(ConnectTimeoutError): + https_pool.request("GET", "/") + with pytest.raises(ConnectTimeoutError): + https_pool._make_request(conn, "GET", "/") + finally: + conn.close() + + with HTTPSConnectionPool( + TARPIT_HOST, + self.port, + timeout=Timeout(connect=5), + retries=False, + cert_reqs="CERT_REQUIRED", + ) as https_pool: + with pytest.raises(ConnectTimeoutError): + https_pool.request("GET", "/", timeout=Timeout(connect=0.001)) - t = Timeout(total=None) - https_pool = new_pool(t) - conn = https_pool._new_conn() - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/", timeout=Timeout(total=None, connect=0.001)) + with HTTPSConnectionPool( + TARPIT_HOST, + self.port, + timeout=Timeout(total=None), + retries=False, + cert_reqs="CERT_REQUIRED", + ) as https_pool: + conn = https_pool._new_conn() + try: + with pytest.raises(ConnectTimeoutError): + https_pool.request( + "GET", "/", timeout=Timeout(total=None, connect=0.001) + ) + finally: + conn.close() def test_enhanced_ssl_connection(self): fingerprint = "92:81:FE:85:F7:0C:26:60:EC:D6:B3:BF:93:CF:F9:71:CC:07:7D:0A" - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA, assert_fingerprint=fingerprint, - ) - self.addCleanup(https_pool.close) + ) as https_pool: - r = https_pool.request("GET", "/") - assert r.status == 200 + r = https_pool.request("GET", "/") + assert r.status == 200 @onlyPy279OrNewer def test_ssl_correct_system_time(self): - self._pool.cert_reqs = "CERT_REQUIRED" - self._pool.ca_certs = DEFAULT_CA + with HTTPSConnectionPool( + self.host, self.port, ca_certs=DEFAULT_CA + ) as https_pool: + https_pool.cert_reqs = "CERT_REQUIRED" + https_pool.ca_certs = DEFAULT_CA - w = self._request_without_resource_warnings("GET", "/") - assert [] == w + w = self._request_without_resource_warnings("GET", "/") + assert [] == w @onlyPy279OrNewer def test_ssl_wrong_system_time(self): - self._pool.cert_reqs = "CERT_REQUIRED" - self._pool.ca_certs = DEFAULT_CA - with mock.patch("urllib3.connection.datetime") as mock_date: - mock_date.date.today.return_value = datetime.date(1970, 1, 1) + with HTTPSConnectionPool( + self.host, self.port, ca_certs=DEFAULT_CA + ) as https_pool: + https_pool.cert_reqs = "CERT_REQUIRED" + https_pool.ca_certs = DEFAULT_CA + with mock.patch("urllib3.connection.datetime") as mock_date: + mock_date.date.today.return_value = datetime.date(1970, 1, 1) - w = self._request_without_resource_warnings("GET", "/") + w = self._request_without_resource_warnings("GET", "/") - assert len(w) == 1 - 
warning = w[0] + assert len(w) == 1 + warning = w[0] - assert SystemTimeWarning == warning.category - assert str(RECENT_DATE) in warning.message.args[0] + assert SystemTimeWarning == warning.category + assert str(RECENT_DATE) in warning.message.args[0] def _request_without_resource_warnings(self, method, url): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - self._pool.request(method, url) + with HTTPSConnectionPool( + self.host, self.port, ca_certs=DEFAULT_CA + ) as https_pool: + https_pool.request(method, url) return [x for x in w if not isinstance(x.message, ResourceWarning)] @@ -678,9 +672,12 @@ def test_set_ssl_version_to_tls_version(self): if self.tls_protocol_name is None: pytest.skip("Skipping base test class") - self._pool.ssl_version = self.certs["ssl_version"] - r = self._pool.request("GET", "/") - assert r.status == 200, r.data + with HTTPSConnectionPool( + self.host, self.port, ca_certs=DEFAULT_CA + ) as https_pool: + https_pool.ssl_version = self.certs["ssl_version"] + r = https_pool.request("GET", "/") + assert r.status == 200, r.data def test_set_cert_default_cert_required(self): conn = VerifiedHTTPSConnection(self.host, self.port) @@ -691,13 +688,17 @@ def test_tls_protocol_name_of_socket(self): if self.tls_protocol_name is None: pytest.skip("Skipping base test class") - conn = self._pool._get_conn() - conn.connect() - - if not hasattr(conn.sock, "version"): - pytest.skip("SSLSocket.version() not available") - - assert conn.sock.version() == self.tls_protocol_name + with HTTPSConnectionPool( + self.host, self.port, ca_certs=DEFAULT_CA + ) as https_pool: + conn = https_pool._get_conn() + try: + conn.connect() + if not hasattr(conn.sock, "version"): + pytest.skip("SSLSocket.version() not available") + assert conn.sock.version() == self.tls_protocol_name + finally: + conn.close() @requiresTLSv1() @@ -730,13 +731,12 @@ def test_warning_for_certs_without_a_san(self): """Ensure that a warning is raised when the cert from the server has no Subject Alternative Name.""" with mock.patch("warnings.warn") as warn: - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=NO_SAN_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 - assert warn.called + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 + assert warn.called class TestHTTPS_IPSAN(HTTPSDummyServerTestCase): @@ -749,12 +749,11 @@ def test_can_validate_ip_san(self): except ImportError: pytest.skip("Only runs on systems with an ipaddress module") - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 class TestHTTPS_IPv6Addr(IPV6HTTPSDummyServerTestCase): @@ -763,12 +762,11 @@ class TestHTTPS_IPv6Addr(IPV6HTTPSDummyServerTestCase): @pytest.mark.skipif(not HAS_IPV6, reason="Only runs on IPv6 systems") def test_strip_square_brackets_before_validating(self): """Test that the fix for #760 works.""" - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "[::1]", self.port, cert_reqs="CERT_REQUIRED", ca_certs=IPV6_ADDR_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 + ) as https_pool: + r = https_pool.request("GET", "/") + assert 
r.status == 200 class TestHTTPS_IPV6SAN(IPV6HTTPSDummyServerTestCase): @@ -781,12 +779,11 @@ def test_can_validate_ipv6_san(self): except ImportError: pytest.skip("Only runs on systems with an ipaddress module") - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "[::1]", self.port, cert_reqs="CERT_REQUIRED", ca_certs=IPV6_SAN_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 if __name__ == "__main__": diff --git a/test/with_dummyserver/test_no_ssl.py b/test/with_dummyserver/test_no_ssl.py index a3cb63ca..f2d8059b 100644 --- a/test/with_dummyserver/test_no_ssl.py +++ b/test/with_dummyserver/test_no_ssl.py @@ -12,17 +12,17 @@ class TestHTTPWithoutSSL(HTTPDummyServerTestCase, TestWithoutSSL): def test_simple(self): - pool = urllib3.HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - r = pool.request("GET", "/") - assert r.status == 200, r.data + with urllib3.HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/") + assert r.status == 200, r.data class TestHTTPSWithoutSSL(HTTPSDummyServerTestCase, TestWithoutSSL): def test_simple(self): - pool = urllib3.HTTPSConnectionPool(self.host, self.port, cert_reqs="NONE") - self.addCleanup(pool.close) - try: - pool.request("GET", "/") - except urllib3.exceptions.SSLError as e: - assert "SSL module is not available" in str(e) + with urllib3.HTTPSConnectionPool( + self.host, self.port, cert_reqs="NONE" + ) as pool: + try: + pool.request("GET", "/") + except urllib3.exceptions.SSLError as e: + assert "SSL module is not available" in str(e) diff --git a/test/with_dummyserver/test_poolmanager.py b/test/with_dummyserver/test_poolmanager.py index 14143e4f..7aa7b036 100644 --- a/test/with_dummyserver/test_poolmanager.py +++ b/test/with_dummyserver/test_poolmanager.py @@ -17,338 +17,333 @@ def setUp(self): self.base_url_alt = "http://%s:%d" % (self.host_alt, self.port) def test_redirect(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/" % self.base_url}, - redirect=False, - ) + with PoolManager() as http: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/" % self.base_url}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/" % self.base_url}, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/" % self.base_url}, + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" 
def test_redirect_twice(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/redirect" % self.base_url}, - redirect=False, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/redirect" % self.base_url}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) + }, + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_redirect_to_relative_url(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "/redirect"}, - redirect=False, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "/redirect"}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", "%s/redirect" % self.base_url, fields={"target": "/redirect"} - ) + r = http.request( + "GET", "%s/redirect" % self.base_url, fields={"target": "/redirect"} + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_cross_host_redirect(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: + + cross_host_location = "%s/echo?a=b" % self.base_url_alt + try: + http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": cross_host_location}, + timeout=1, + retries=0, + ) + self.fail( + "Request succeeded instead of raising an exception like it should." + ) + + except MaxRetryError: + pass - cross_host_location = "%s/echo?a=b" % self.base_url_alt - try: - http.request( + r = http.request( "GET", "%s/redirect" % self.base_url, - fields={"target": cross_host_location}, + fields={"target": "%s/echo?a=b" % self.base_url_alt}, timeout=1, - retries=0, - ) - self.fail( - "Request succeeded instead of raising an exception like it should." 
+ retries=1, ) - except MaxRetryError: - pass - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/echo?a=b" % self.base_url_alt}, - timeout=1, - retries=1, - ) - - assert r._pool.host == self.host_alt + assert r._pool.host == self.host_alt def test_too_many_redirects(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: + + try: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" + % (self.base_url, self.base_url) + }, + retries=1, + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass + + try: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" + % (self.base_url, self.base_url) + }, + retries=Retry(total=None, redirect=1), + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass - try: - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - retries=1, - ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass + def test_redirect_cross_host_remove_headers(self): + with PoolManager() as http: - try: r = http.request( "GET", "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - retries=Retry(total=None, redirect=1), + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"Authorization": "foo"}, ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass - - def test_redirect_cross_host_remove_headers(self): - http = PoolManager() - self.addCleanup(http.clear) - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"Authorization": "foo"}, - ) + assert r.status == 200 - assert r.status == 200 + data = json.loads(r.data.decode("utf-8")) - data = json.loads(r.data.decode("utf-8")) + assert "Authorization" not in data - assert "Authorization" not in data - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"authorization": "foo"}, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"authorization": "foo"}, + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert "authorization" not in data - assert "Authorization" not in data + assert "authorization" not in data + assert "Authorization" not in data def test_redirect_cross_host_no_remove_headers(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"Authorization": "foo"}, - retries=Retry(remove_headers_on_redirect=[]), - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"Authorization": "foo"}, + retries=Retry(remove_headers_on_redirect=[]), + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert 
data["Authorization"] == "foo" + assert data["Authorization"] == "foo" def test_redirect_cross_host_set_removed_headers(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"X-API-Secret": "foo", "Authorization": "bar"}, - retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"X-API-Secret": "foo", "Authorization": "bar"}, + retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert "X-API-Secret" not in data - assert data["Authorization"] == "bar" + assert "X-API-Secret" not in data + assert data["Authorization"] == "bar" - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"x-api-secret": "foo", "authorization": "bar"}, - retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"x-api-secret": "foo", "authorization": "bar"}, + retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert "x-api-secret" not in data - assert "X-API-Secret" not in data - assert data["Authorization"] == "bar" + assert "x-api-secret" not in data + assert "X-API-Secret" not in data + assert data["Authorization"] == "bar" def test_raise_on_redirect(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - retries=Retry(total=None, redirect=1, raise_on_redirect=False), - ) + with PoolManager() as http: - assert r.status == 303 - - def test_raise_on_status(self): - http = PoolManager() - self.addCleanup(http.clear) - - try: - # the default is to raise r = http.request( "GET", - "%s/status" % self.base_url, - fields={"status": "500 Internal Server Error"}, - retries=Retry(total=1, status_forcelist=range(500, 600)), + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) + }, + retries=Retry(total=None, redirect=1, raise_on_redirect=False), ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass - try: - # raise explicitly + assert r.status == 303 + + def test_raise_on_status(self): + with PoolManager() as http: + + try: + # the default is to raise + r = http.request( + "GET", + "%s/status" % self.base_url, + fields={"status": "500 Internal Server Error"}, + retries=Retry(total=1, status_forcelist=range(500, 600)), + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass + + try: + # raise explicitly + r = http.request( + "GET", + "%s/status" % self.base_url, + fields={"status": "500 Internal Server Error"}, + retries=Retry( + total=1, status_forcelist=range(500, 600), raise_on_status=True + ), + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except 
MaxRetryError: + pass + + # don't raise r = http.request( "GET", "%s/status" % self.base_url, fields={"status": "500 Internal Server Error"}, retries=Retry( - total=1, status_forcelist=range(500, 600), raise_on_status=True + total=1, status_forcelist=range(500, 600), raise_on_status=False ), ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass - - # don't raise - r = http.request( - "GET", - "%s/status" % self.base_url, - fields={"status": "500 Internal Server Error"}, - retries=Retry( - total=1, status_forcelist=range(500, 600), raise_on_status=False - ), - ) - - assert r.status == 500 + + assert r.status == 500 def test_missing_port(self): # Can a URL that lacks an explicit port like ':80' succeed, or # will all such URLs fail with an error? - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: - # By globally adjusting `port_by_scheme` we pretend for a moment - # that HTTP's default port is not 80, but is the port at which - # our test server happens to be listening. - port_by_scheme["http"] = self.port - try: - r = http.request("GET", "http://%s/" % self.host, retries=0) - finally: - port_by_scheme["http"] = 80 + # By globally adjusting `port_by_scheme` we pretend for a moment + # that HTTP's default port is not 80, but is the port at which + # our test server happens to be listening. + port_by_scheme["http"] = self.port + try: + r = http.request("GET", "http://%s/" % self.host, retries=0) + finally: + port_by_scheme["http"] = 80 - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_headers(self): - http = PoolManager(headers={"Foo": "bar"}) - self.addCleanup(http.clear) - - r = http.request("GET", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request("POST", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request_encode_url("GET", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request_encode_body("POST", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request_encode_url( - "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - - r = http.request_encode_body( - "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" + with PoolManager(headers={"Foo": "bar"}) as http: + + r = http.request("GET", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + + r = http.request("POST", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + + r = http.request_encode_url("GET", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + + r = http.request_encode_body("POST", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert 
returned_headers.get("Foo") == "bar" + + r = http.request_encode_url( + "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + + r = http.request_encode_body( + "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" def test_http_with_ssl_keywords(self): - http = PoolManager(ca_certs="REQUIRED") - self.addCleanup(http.clear) + with PoolManager(ca_certs="REQUIRED") as http: - r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) - assert r.status == 200 + r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) + assert r.status == 200 def test_http_with_ca_cert_dir(self): - http = PoolManager(ca_certs="REQUIRED", ca_cert_dir="/nosuchdir") - self.addCleanup(http.clear) + with PoolManager(ca_certs="REQUIRED", ca_cert_dir="/nosuchdir") as http: - r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) - assert r.status == 200 + r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) + assert r.status == 200 @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 is not supported on this system") @@ -357,9 +352,8 @@ def setUp(self): self.base_url = "http://[%s]:%d" % (self.host, self.port) def test_ipv6(self): - http = PoolManager() - self.addCleanup(http.clear) - http.request("GET", self.base_url) + with PoolManager() as http: + http.request("GET", self.base_url) if __name__ == "__main__": diff --git a/test/with_dummyserver/test_proxy_poolmanager.py b/test/with_dummyserver/test_proxy_poolmanager.py index 4c4e506d..eed47b1d 100644 --- a/test/with_dummyserver/test_proxy_poolmanager.py +++ b/test/with_dummyserver/test_proxy_poolmanager.py @@ -23,349 +23,348 @@ def setUp(self): self.proxy_url = "http://%s:%d" % (self.proxy_host, self.proxy_port) def test_basic_proxy(self): - http = proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) + with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: - r = http.request("GET", "%s/" % self.http_url) - assert r.status == 200 + r = http.request("GET", "%s/" % self.http_url) + assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url) - assert r.status == 200 + r = http.request("GET", "%s/" % self.https_url) + assert r.status == 200 def test_nagle_proxy(self): """ Test that proxy connections do not have TCP_NODELAY turned on """ - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) - hc2 = http.connection_from_host(self.http_host, self.http_port) - conn = hc2._get_conn() - self.addCleanup(conn.close) - hc2._make_request(conn, "GET", "/") - tcp_nodelay_setting = conn.sock.getsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY - ) - assert tcp_nodelay_setting == 0, ( - "Expected TCP_NODELAY for proxies to be set " - "to zero, instead was %s" % tcp_nodelay_setting - ) + with ProxyManager(self.proxy_url) as http: + hc2 = http.connection_from_host(self.http_host, self.http_port) + conn = hc2._get_conn() + try: + hc2._make_request(conn, "GET", "/") + tcp_nodelay_setting = conn.sock.getsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY + ) + assert tcp_nodelay_setting == 0, ( + "Expected TCP_NODELAY for proxies to be set " + "to zero, instead was %s" % tcp_nodelay_setting + ) + finally: + conn.close() def test_proxy_conn_fail(self): host, port = get_unreachable_address() - 
http = proxy_from_url("http://%s:%s/" % (host, port), retries=1, timeout=0.05) - self.addCleanup(http.clear) - with pytest.raises(MaxRetryError): - http.request("GET", "%s/" % self.https_url) - with pytest.raises(MaxRetryError): - http.request("GET", "%s/" % self.http_url) - - try: - http.request("GET", "%s/" % self.http_url) - self.fail("Failed to raise retry error.") - except MaxRetryError as e: - assert type(e.reason) == ProxyError + with proxy_from_url( + "http://%s:%s/" % (host, port), retries=1, timeout=0.05 + ) as http: + with pytest.raises(MaxRetryError): + http.request("GET", "%s/" % self.https_url) + with pytest.raises(MaxRetryError): + http.request("GET", "%s/" % self.http_url) + + try: + http.request("GET", "%s/" % self.http_url) + self.fail("Failed to raise retry error.") + except MaxRetryError as e: + assert type(e.reason) == ProxyError def test_oldapi(self): - http = ProxyManager(connection_from_url(self.proxy_url), ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) + with ProxyManager( + connection_from_url(self.proxy_url), ca_certs=DEFAULT_CA + ) as http: - r = http.request("GET", "%s/" % self.http_url) - assert r.status == 200 + r = http.request("GET", "%s/" % self.http_url) + assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url) - assert r.status == 200 + r = http.request("GET", "%s/" % self.https_url) + assert r.status == 200 def test_proxy_verified(self): - http = proxy_from_url( + with proxy_from_url( self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(http.clear) - https_pool = http._new_pool("https", self.https_host, self.https_port) - try: - https_pool.request("GET", "/", retries=0) - self.fail("Didn't raise SSL error with wrong CA") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "certificate verify failed" in str(e.reason), ( - "Expected 'certificate verify failed'," "instead got: %r" % e.reason + ) as http: + https_pool = http._new_pool("https", self.https_host, self.https_port) + try: + https_pool.request("GET", "/", retries=0) + self.fail("Didn't raise SSL error with wrong CA") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "certificate verify failed" in str(e.reason), ( + "Expected 'certificate verify failed'," "instead got: %r" % e.reason + ) + + http = proxy_from_url( + self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA ) + https_pool = http._new_pool("https", self.https_host, self.https_port) - http = proxy_from_url(self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA) - https_pool = http._new_pool("https", self.https_host, self.https_port) + conn = https_pool._new_conn() + assert conn.__class__ == VerifiedHTTPSConnection + https_pool.request("GET", "/") # Should succeed without exceptions. - conn = https_pool._new_conn() - assert conn.__class__ == VerifiedHTTPSConnection - https_pool.request("GET", "/") # Should succeed without exceptions. 
- - http = proxy_from_url(self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA) - https_fail_pool = http._new_pool("https", "127.0.0.1", self.https_port) + http = proxy_from_url( + self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA + ) + https_fail_pool = http._new_pool("https", "127.0.0.1", self.https_port) - try: - https_fail_pool.request("GET", "/", retries=0) - self.fail("Didn't raise SSL invalid common name") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "doesn't match" in str(e.reason) + try: + https_fail_pool.request("GET", "/", retries=0) + self.fail("Didn't raise SSL invalid common name") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "doesn't match" in str(e.reason) def test_redirect(self): - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) + with proxy_from_url(self.proxy_url) as http: - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/" % self.http_url}, - redirect=False, - ) + r = http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": "%s/" % self.http_url}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/" % self.http_url}, - ) + r = http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": "%s/" % self.http_url}, + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_cross_host_redirect(self): - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) - - cross_host_location = "%s/echo?a=b" % self.http_url_alt - try: - http.request( + with proxy_from_url(self.proxy_url) as http: + + cross_host_location = "%s/echo?a=b" % self.http_url_alt + try: + http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": cross_host_location}, + timeout=1, + retries=0, + ) + self.fail("We don't want to follow redirects here.") + + except MaxRetryError: + pass + + r = http.request( "GET", "%s/redirect" % self.http_url, - fields={"target": cross_host_location}, + fields={"target": "%s/echo?a=b" % self.http_url_alt}, timeout=1, - retries=0, + retries=1, ) - self.fail("We don't want to follow redirects here.") - - except MaxRetryError: - pass - - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/echo?a=b" % self.http_url_alt}, - timeout=1, - retries=1, - ) - assert r._pool.host != self.http_host_alt + assert r._pool.host != self.http_host_alt def test_cross_protocol_redirect(self): - http = proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) - - cross_protocol_location = "%s/echo?a=b" % self.https_url - try: - http.request( + with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: + + cross_protocol_location = "%s/echo?a=b" % self.https_url + try: + http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": cross_protocol_location}, + timeout=1, + retries=0, + ) + self.fail("We don't want to follow redirects here.") + + except MaxRetryError: + pass + + r = http.request( "GET", "%s/redirect" % self.http_url, - fields={"target": cross_protocol_location}, + fields={"target": "%s/echo?a=b" % self.https_url}, timeout=1, - retries=0, + retries=1, ) - self.fail("We don't want to follow redirects here.") - - except MaxRetryError: - pass - - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/echo?a=b" % 
self.https_url}, - timeout=1, - retries=1, - ) - assert r._pool.host == self.https_host + assert r._pool.host == self.https_host def test_headers(self): - http = proxy_from_url( + with proxy_from_url( self.proxy_url, headers={"Foo": "bar"}, proxy_headers={"Hickory": "dickory"}, ca_certs=DEFAULT_CA, - ) - self.addCleanup(http.clear) - - r = http.request_encode_url("GET", "%s/headers" % self.http_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_url("GET", "%s/headers" % self.http_url_alt) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host_alt, - self.http_port, - ) - - r = http.request_encode_url("GET", "%s/headers" % self.https_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, - ) - - r = http.request_encode_body("POST", "%s/headers" % self.http_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_url( - "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_url( - "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, - ) - - r = http.request_encode_body( - "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_body( - "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, - ) + ) as http: + + r = http.request_encode_url("GET", "%s/headers" % self.http_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_url("GET", "%s/headers" % self.http_url_alt) + returned_headers = 
json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host_alt, + self.http_port, + ) + + r = http.request_encode_url("GET", "%s/headers" % self.https_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") is None + assert returned_headers.get("Host") == "%s:%s" % ( + self.https_host, + self.https_port, + ) + + r = http.request_encode_body("POST", "%s/headers" % self.http_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_url( + "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_url( + "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") is None + assert returned_headers.get("Host") == "%s:%s" % ( + self.https_host, + self.https_port, + ) + + r = http.request_encode_body( + "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_body( + "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") is None + assert returned_headers.get("Host") == "%s:%s" % ( + self.https_host, + self.https_port, + ) def test_headerdict(self): default_headers = HTTPHeaderDict(a="b") proxy_headers = HTTPHeaderDict() proxy_headers.add("foo", "bar") - http = proxy_from_url( + with proxy_from_url( self.proxy_url, headers=default_headers, proxy_headers=proxy_headers - ) - self.addCleanup(http.clear) + ) as http: - request_headers = HTTPHeaderDict(baz="quux") - r = http.request("GET", "%s/headers" % self.http_url, headers=request_headers) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Baz") == "quux" + request_headers = HTTPHeaderDict(baz="quux") + r = http.request( + "GET", "%s/headers" % self.http_url, headers=request_headers + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Baz") == "quux" def test_proxy_pooling(self): - http = proxy_from_url(self.proxy_url, cert_reqs="NONE") - self.addCleanup(http.clear) + with proxy_from_url(self.proxy_url, cert_reqs="NONE") as http: - for x in range(2): - http.urlopen("GET", self.http_url) - assert len(http.pools) == 1 + for x in range(2): + http.urlopen("GET", 
self.http_url) + assert len(http.pools) == 1 - for x in range(2): - http.urlopen("GET", self.http_url_alt) - assert len(http.pools) == 1 + for x in range(2): + http.urlopen("GET", self.http_url_alt) + assert len(http.pools) == 1 - for x in range(2): - http.urlopen("GET", self.https_url) - assert len(http.pools) == 2 + for x in range(2): + http.urlopen("GET", self.https_url) + assert len(http.pools) == 2 - for x in range(2): - http.urlopen("GET", self.https_url_alt) - assert len(http.pools) == 3 + for x in range(2): + http.urlopen("GET", self.https_url_alt) + assert len(http.pools) == 3 def test_proxy_pooling_ext(self): - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) - - hc1 = http.connection_from_url(self.http_url) - hc2 = http.connection_from_host(self.http_host, self.http_port) - hc3 = http.connection_from_url(self.http_url_alt) - hc4 = http.connection_from_host(self.http_host_alt, self.http_port) - assert hc1 == hc2 - assert hc2 == hc3 - assert hc3 == hc4 - - sc1 = http.connection_from_url(self.https_url) - sc2 = http.connection_from_host( - self.https_host, self.https_port, scheme="https" - ) - sc3 = http.connection_from_url(self.https_url_alt) - sc4 = http.connection_from_host( - self.https_host_alt, self.https_port, scheme="https" - ) - assert sc1 == sc2 - assert sc2 != sc3 - assert sc3 == sc4 + with proxy_from_url(self.proxy_url) as http: + + hc1 = http.connection_from_url(self.http_url) + hc2 = http.connection_from_host(self.http_host, self.http_port) + hc3 = http.connection_from_url(self.http_url_alt) + hc4 = http.connection_from_host(self.http_host_alt, self.http_port) + assert hc1 == hc2 + assert hc2 == hc3 + assert hc3 == hc4 + + sc1 = http.connection_from_url(self.https_url) + sc2 = http.connection_from_host( + self.https_host, self.https_port, scheme="https" + ) + sc3 = http.connection_from_url(self.https_url_alt) + sc4 = http.connection_from_host( + self.https_host_alt, self.https_port, scheme="https" + ) + assert sc1 == sc2 + assert sc2 != sc3 + assert sc3 == sc4 @pytest.mark.timeout(0.5) @requires_network def test_https_proxy_timeout(self): - https = proxy_from_url("https://{host}".format(host=TARPIT_HOST)) - self.addCleanup(https.clear) - try: - https.request("GET", self.http_url, timeout=0.001) - self.fail("Failed to raise retry error.") - except MaxRetryError as e: - assert type(e.reason) == ConnectTimeoutError + with proxy_from_url("https://{host}".format(host=TARPIT_HOST)) as https: + try: + https.request("GET", self.http_url, timeout=0.001) + self.fail("Failed to raise retry error.") + except MaxRetryError as e: + assert type(e.reason) == ConnectTimeoutError @pytest.mark.timeout(0.5) @requires_network def test_https_proxy_pool_timeout(self): - https = proxy_from_url("https://{host}".format(host=TARPIT_HOST), timeout=0.001) - self.addCleanup(https.clear) - try: - https.request("GET", self.http_url) - self.fail("Failed to raise retry error.") - except MaxRetryError as e: - assert type(e.reason) == ConnectTimeoutError + with proxy_from_url( + "https://{host}".format(host=TARPIT_HOST), timeout=0.001 + ) as https: + try: + https.request("GET", self.http_url) + self.fail("Failed to raise retry error.") + except MaxRetryError as e: + assert type(e.reason) == ConnectTimeoutError def test_scheme_host_case_insensitive(self): """Assert that upper-case schemes and hosts are normalized.""" - http = proxy_from_url(self.proxy_url.upper(), ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) + with proxy_from_url(self.proxy_url.upper(), ca_certs=DEFAULT_CA) as 
http:

-        r = http.request("GET", "%s/" % self.http_url.upper())
-        assert r.status == 200
+            r = http.request("GET", "%s/" % self.http_url.upper())
+            assert r.status == 200

-        r = http.request("GET", "%s/" % self.https_url.upper())
-        assert r.status == 200
+            r = http.request("GET", "%s/" % self.https_url.upper())
+            assert r.status == 200


 class TestIPv6HTTPProxyManager(IPv6HTTPDummyProxyTestCase):
@@ -377,14 +376,13 @@ def setUp(self):
         self.proxy_url = "http://[%s]:%d" % (self.proxy_host, self.proxy_port)

     def test_basic_ipv6_proxy(self):
-        http = proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA)
-        self.addCleanup(http.clear)
+        with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http:

-        r = http.request("GET", "%s/" % self.http_url)
-        assert r.status == 200
+            r = http.request("GET", "%s/" % self.http_url)
+            assert r.status == 200

-        r = http.request("GET", "%s/" % self.https_url)
-        assert r.status == 200
+            r = http.request("GET", "%s/" % self.https_url)
+            assert r.status == 200


 if __name__ == "__main__":
diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py
index a62fa916..2e5e9db8 100644
--- a/test/with_dummyserver/test_socketlevel.py
+++ b/test/with_dummyserver/test_socketlevel.py
@@ -66,11 +66,10 @@ def multicookie_response_handler(listener):
             sock.close()

         self._start_server(multicookie_response_handler)
-        pool = HTTPConnectionPool(self.host, self.port)
-        self.addCleanup(pool.close)
-        r = pool.request("GET", "/", retries=0)
-        assert r.headers == {"set-cookie": "foo=1, bar=1"}
-        assert r.headers.getlist("set-cookie") == ["foo=1", "bar=1"]
+        with HTTPConnectionPool(self.host, self.port) as pool:
+            r = pool.request("GET", "/", retries=0)
+            assert r.headers == {"set-cookie": "foo=1, bar=1"}
+            assert r.headers.getlist("set-cookie") == ["foo=1", "bar=1"]


 class TestSNI(SocketDummyServerTestCase):
@@ -87,16 +86,15 @@ def socket_handler(listener):
             sock.close()

         self._start_server(socket_handler)
-        pool = HTTPSConnectionPool(self.host, self.port)
-        self.addCleanup(pool.close)
-        try:
-            pool.request("GET", "/", retries=0)
-        except MaxRetryError:  # We are violating the protocol
-            pass
-        done_receiving.wait()
-        assert (
-            self.host.encode("ascii") in self.buf
-        ), "missing hostname in SSL handshake"
+        with HTTPSConnectionPool(self.host, self.port) as pool:
+            try:
+                pool.request("GET", "/", retries=0)
+            except MaxRetryError:  # We are violating the protocol
+                pass
+            done_receiving.wait()
+            assert (
+                self.host.encode("ascii") in self.buf
+            ), "missing hostname in SSL handshake"


 class TestClientCerts(SocketDummyServerTestCase):
@@ -149,19 +147,18 @@ def socket_handler(listener):
             sock.close()

         self._start_server(socket_handler)
-        pool = HTTPSConnectionPool(
+        with HTTPSConnectionPool(
             self.host,
             self.port,
             cert_file=DEFAULT_CERTS["certfile"],
             key_file=DEFAULT_CERTS["keyfile"],
             cert_reqs="REQUIRED",
             ca_certs=DEFAULT_CA,
-        )
-        self.addCleanup(pool.close)
-        pool.request("GET", "/", retries=0)
-        done_receiving.set()
+        ) as pool:
+            pool.request("GET", "/", retries=0)
+            done_receiving.set()

-        assert len(client_certs) == 1
+            assert len(client_certs) == 1

     def test_client_certs_one_file(self):
         """
@@ -194,18 +191,17 @@ def socket_handler(listener):
             sock.close()

         self._start_server(socket_handler)
-        pool = HTTPSConnectionPool(
+        with HTTPSConnectionPool(
             self.host,
             self.port,
             cert_file=COMBINED_CERT_AND_KEY,
             cert_reqs="REQUIRED",
             ca_certs=DEFAULT_CA,
-        )
-        self.addCleanup(pool.close)
-        pool.request("GET", "/", retries=0)
-        done_receiving.set()
+        ) as pool:
+            
pool.request("GET", "/", retries=0) + done_receiving.set() - assert len(client_certs) == 1 + assert len(client_certs) == 1 def test_missing_client_certs_raises_error(self): """ @@ -225,20 +221,19 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(pool.close) - try: - pool.request("GET", "/", retries=0) - except MaxRetryError: - done_receiving.set() - else: - done_receiving.set() - self.fail( - "Expected server to reject connection due to missing client " - "certificates" - ) + ) as pool: + try: + pool.request("GET", "/", retries=0) + except MaxRetryError: + done_receiving.set() + else: + done_receiving.set() + self.fail( + "Expected server to reject connection due to missing client " + "certificates" + ) @requires_ssl_context_keyfile_password def test_client_cert_with_string_password(self): @@ -285,18 +280,17 @@ def socket_handler(listener): password=password, ) - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, ssl_context=ssl_context, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA, - ) - self.addCleanup(pool.close) - pool.request("GET", "/", retries=0) - done_receiving.set() + ) as pool: + pool.request("GET", "/", retries=0) + done_receiving.set() - assert len(client_certs) == 1 + assert len(client_certs) == 1 @requires_ssl_context_keyfile_password def test_load_keyfile_with_invalid_password(self): @@ -349,27 +343,25 @@ def socket_handler(listener): done_closing.set() # let the test know it can proceed self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port) as pool: - response = pool.request("GET", "/", retries=0) - assert response.status == 200 - assert response.data == b"Response 0" + response = pool.request("GET", "/", retries=0) + assert response.status == 200 + assert response.data == b"Response 0" - done_closing.wait() # wait until the socket in our pool gets closed + done_closing.wait() # wait until the socket in our pool gets closed - response = pool.request("GET", "/", retries=0) - assert response.status == 200 - assert response.data == b"Response 1" + response = pool.request("GET", "/", retries=0) + assert response.status == 200 + assert response.data == b"Response 1" def test_connection_refused(self): # Does the pool retry if there is no listener on the port? 
host, port = get_unreachable_address() - http = HTTPConnectionPool(host, port, maxsize=3, block=True) - self.addCleanup(http.close) - with pytest.raises(MaxRetryError): - http.request("GET", "/", retries=0, release_conn=False) - assert http.pool.qsize() == http.pool.maxsize + with HTTPConnectionPool(host, port, maxsize=3, block=True) as http: + with pytest.raises(MaxRetryError): + http.request("GET", "/", retries=0, release_conn=False) + assert http.pool.qsize() == http.pool.maxsize def test_connection_read_timeout(self): timed_out = Event() @@ -383,18 +375,17 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - http = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=0.01, retries=False, maxsize=3, block=True - ) - self.addCleanup(http.close) + ) as http: - try: - with pytest.raises(ReadTimeoutError): - http.request("GET", "/", release_conn=False) - finally: - timed_out.set() + try: + with pytest.raises(ReadTimeoutError): + http.request("GET", "/", release_conn=False) + finally: + timed_out.set() - assert http.pool.qsize() == http.pool.maxsize + assert http.pool.qsize() == http.pool.maxsize def test_read_timeout_dont_retry_method_not_in_whitelist(self): timed_out = Event() @@ -406,14 +397,15 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port, timeout=0.01, retries=True) - self.addCleanup(pool.close) + with HTTPConnectionPool( + self.host, self.port, timeout=0.01, retries=True + ) as pool: - try: - with pytest.raises(ReadTimeoutError): - pool.request("POST", "/") - finally: - timed_out.set() + try: + with pytest.raises(ReadTimeoutError): + pool.request("POST", "/") + finally: + timed_out.set() def test_https_connection_read_timeout(self): """ Handshake timeouts should fail with a Timeout""" @@ -428,13 +420,14 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, timeout=0.01, retries=False) - self.addCleanup(pool.close) - try: - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") - finally: - timed_out.set() + with HTTPSConnectionPool( + self.host, self.port, timeout=0.01, retries=False + ) as pool: + try: + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") + finally: + timed_out.set() def test_timeout_errors_cause_retries(self): def socket_handler(listener): @@ -473,12 +466,11 @@ def socket_handler(listener): try: self._start_server(socket_handler) t = Timeout(connect=0.001, read=0.01) - pool = HTTPConnectionPool(self.host, self.port, timeout=t) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port, timeout=t) as pool: - response = pool.request("GET", "/", retries=1) - assert response.status == 200 - assert response.data == b"Response 2" + response = pool.request("GET", "/", retries=1) + assert response.status == 200 + assert response.data == b"Response 2" finally: socket.setdefaulttimeout(default_timeout) @@ -505,21 +497,20 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - response = pool.urlopen( - "GET", - "/", - retries=0, - preload_content=False, - timeout=Timeout(connect=1, read=0.01), - ) - try: - with pytest.raises(ReadTimeoutError): - response.read() - finally: - timed_out.set() + with HTTPConnectionPool(self.host, self.port) as pool: + + response = pool.urlopen( + "GET", + "/", + retries=0, + 
preload_content=False, + timeout=Timeout(connect=1, read=0.01), + ) + try: + with pytest.raises(ReadTimeoutError): + response.read() + finally: + timed_out.set() def test_delayed_body_read_timeout_with_preload(self): timed_out = Event() @@ -543,16 +534,15 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port) as pool: - try: - with pytest.raises(ReadTimeoutError): - pool.urlopen( - "GET", "/", retries=False, timeout=Timeout(connect=1, read=0.01) - ) - finally: - timed_out.set() + try: + with pytest.raises(ReadTimeoutError): + pool.urlopen( + "GET", "/", retries=False, timeout=Timeout(connect=1, read=0.01) + ) + finally: + timed_out.set() def test_incomplete_response(self): body = "Response" @@ -579,12 +569,11 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port) as pool: - response = pool.request("GET", "/", retries=0, preload_content=False) - with pytest.raises(ProtocolError): - response.read() + response = pool.request("GET", "/", retries=0, preload_content=False) + with pytest.raises(ProtocolError): + response.read() def test_retry_weird_http_version(self): """ Retry class should handle httplib.BadStatusLine errors properly """ @@ -630,12 +619,11 @@ def socket_handler(listener): sock.close() # Close the socket. self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - retry = Retry(read=1) - response = pool.request("GET", "/", retries=retry) - assert response.status == 200 - assert response.data == b"foo" + with HTTPConnectionPool(self.host, self.port) as pool: + retry = Retry(read=1) + response = pool.request("GET", "/", retries=retry) + assert response.status == 200 + assert response.data == b"foo" def test_connection_cleanup_on_read_timeout(self): timed_out = Event() @@ -819,17 +807,16 @@ def socket_handler(listener): complete.set() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port) as pool: - response = pool.request("GET", "/", retries=0, preload_content=False) - assert response.status == 200 - response.close() + response = pool.request("GET", "/", retries=0, preload_content=False) + assert response.status == 200 + response.close() - done_closing.set() # wait until the socket in our pool gets closed - successful = complete.wait(timeout=1) - if not successful: - self.fail("Timed out waiting for connection close") + done_closing.set() # wait until the socket in our pool gets closed + successful = complete.wait(timeout=1) + if not successful: + self.fail("Timed out waiting for connection close") def test_release_conn_param_is_respected_after_timeout_retry(self): """For successful ```urlopen(release_conn=False)```, @@ -923,25 +910,24 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url) - self.addCleanup(proxy.clear) - - r = proxy.request("GET", "http://google.com/") - - assert r.status == 200 - # FIXME: The order of the headers is not predictable right now. We - # should fix that someday (maybe when we migrate to - # OrderedDict/MultiDict). 
- assert sorted(r.data.split(b"\r\n")) == sorted( - [ - b"GET http://google.com/ HTTP/1.1", - b"Host: google.com", - b"Accept-Encoding: identity", - b"Accept: */*", - b"", - b"", - ] - ) + with proxy_from_url(base_url) as proxy: + + r = proxy.request("GET", "http://google.com/") + + assert r.status == 200 + # FIXME: The order of the headers is not predictable right now. We + # should fix that someday (maybe when we migrate to + # OrderedDict/MultiDict). + assert sorted(r.data.split(b"\r\n")) == sorted( + [ + b"GET http://google.com/ HTTP/1.1", + b"Host: google.com", + b"Accept-Encoding: identity", + b"Accept: */*", + b"", + b"", + ] + ) def test_headers(self): def echo_socket_handler(listener): @@ -967,18 +953,17 @@ def echo_socket_handler(listener): # Define some proxy headers. proxy_headers = HTTPHeaderDict({"For The Proxy": "YEAH!"}) - proxy = proxy_from_url(base_url, proxy_headers=proxy_headers) - self.addCleanup(proxy.clear) + with proxy_from_url(base_url, proxy_headers=proxy_headers) as proxy: - conn = proxy.connection_from_url("http://www.google.com/") + conn = proxy.connection_from_url("http://www.google.com/") - r = conn.urlopen("GET", "http://www.google.com/", assert_same_host=False) + r = conn.urlopen("GET", "http://www.google.com/", assert_same_host=False) - assert r.status == 200 - # FIXME: The order of the headers is not predictable right now. We - # should fix that someday (maybe when we migrate to - # OrderedDict/MultiDict). - assert b"For The Proxy: YEAH!\r\n" in r.data + assert r.status == 200 + # FIXME: The order of the headers is not predictable right now. We + # should fix that someday (maybe when we migrate to + # OrderedDict/MultiDict). + assert b"For The Proxy: YEAH!\r\n" in r.data def test_retries(self): close_event = Event() @@ -1010,20 +995,22 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url) - self.addCleanup(proxy.clear) - conn = proxy.connection_from_url("http://www.google.com") - - r = conn.urlopen( - "GET", "http://www.google.com", assert_same_host=False, retries=1 - ) - assert r.status == 200 + with proxy_from_url(base_url) as proxy: + conn = proxy.connection_from_url("http://www.google.com") - close_event.wait(timeout=1) - with pytest.raises(ProxyError): - conn.urlopen( - "GET", "http://www.google.com", assert_same_host=False, retries=False + r = conn.urlopen( + "GET", "http://www.google.com", assert_same_host=False, retries=1 ) + assert r.status == 200 + + close_event.wait(timeout=1) + with pytest.raises(ProxyError): + conn.urlopen( + "GET", + "http://www.google.com", + assert_same_host=False, + retries=False, + ) def test_connect_reconn(self): def proxy_ssl_one(listener): @@ -1079,15 +1066,14 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url, ca_certs=DEFAULT_CA) - self.addCleanup(proxy.clear) + with proxy_from_url(base_url, ca_certs=DEFAULT_CA) as proxy: - url = "https://{0}".format(self.host) - conn = proxy.connection_from_url(url) - r = conn.urlopen("GET", url, retries=0) - assert r.status == 200 - r = conn.urlopen("GET", url, retries=0) - assert r.status == 200 + url = "https://{0}".format(self.host) + conn = proxy.connection_from_url(url) + r = conn.urlopen("GET", url, retries=0) + assert r.status == 200 + r = conn.urlopen("GET", url, retries=0) + assert r.status == 200 def test_connect_ipv6_addr(self): ipv6_addr = 
"2001:4998:c:a06::2:4008" @@ -1127,16 +1113,15 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url, cert_reqs="NONE") - self.addCleanup(proxy.clear) + with proxy_from_url(base_url, cert_reqs="NONE") as proxy: - url = "https://[{0}]".format(ipv6_addr) - conn = proxy.connection_from_url(url) - try: - r = conn.urlopen("GET", url, retries=0) - assert r.status == 200 - except MaxRetryError: - self.fail("Invalid IPv6 format in HTTP CONNECT request") + url = "https://[{0}]".format(ipv6_addr) + conn = proxy.connection_from_url(url) + try: + r = conn.urlopen("GET", url, retries=0) + assert r.status == 200 + except MaxRetryError: + self.fail("Invalid IPv6 format in HTTP CONNECT request") class TestSSL(SocketDummyServerTestCase): @@ -1170,12 +1155,11 @@ def socket_handler(listener): ssl_sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port) - self.addCleanup(pool.close) + with HTTPSConnectionPool(self.host, self.port) as pool: - with pytest.raises(MaxRetryError) as cm: - pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) + with pytest.raises(MaxRetryError) as cm: + pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) def test_ssl_read_timeout(self): timed_out = Event() @@ -1209,21 +1193,20 @@ def socket_handler(listener): ssl_sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) - self.addCleanup(pool.close) - - response = pool.urlopen( - "GET", - "/", - retries=0, - preload_content=False, - timeout=Timeout(connect=1, read=0.01), - ) - try: - with pytest.raises(ReadTimeoutError): - response.read() - finally: - timed_out.set() + with HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) as pool: + + response = pool.urlopen( + "GET", + "/", + retries=0, + preload_content=False, + timeout=Timeout(connect=1, read=0.01), + ) + try: + with pytest.raises(ReadTimeoutError): + response.read() + finally: + timed_out.set() def test_ssl_failed_fingerprint_verification(self): def socket_handler(listener): @@ -1323,10 +1306,9 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) - self.addCleanup(pool.close) - response = pool.urlopen("GET", "/", retries=1) - assert response.data == b"Success" + with HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) as pool: + response = pool.urlopen("GET", "/", retries=1) + assert response.data == b"Success" def test_ssl_load_default_certs_when_empty(self): def socket_handler(listener): @@ -1360,13 +1342,12 @@ def socket_handler(listener): with mock.patch("urllib3.util.ssl_.SSLContext", lambda *_, **__: context): self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port) - self.addCleanup(pool.close) + with HTTPSConnectionPool(self.host, self.port) as pool: - with pytest.raises(MaxRetryError): - pool.request("GET", "/", timeout=0.01) + with pytest.raises(MaxRetryError): + pool.request("GET", "/", timeout=0.01) - context.load_default_certs.assert_called_with() + context.load_default_certs.assert_called_with() def test_ssl_dont_load_default_certs_when_given(self): def socket_handler(listener): @@ -1407,13 +1388,12 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, **kwargs) - self.addCleanup(pool.close) + 
with HTTPSConnectionPool(self.host, self.port, **kwargs) as pool: - with pytest.raises(MaxRetryError): - pool.request("GET", "/", timeout=0.01) + with pytest.raises(MaxRetryError): + pool.request("GET", "/", timeout=0.01) - context.load_default_certs.assert_not_called() + context.load_default_certs.assert_not_called() class TestErrorWrapping(SocketDummyServerTestCase): @@ -1421,19 +1401,17 @@ def test_bad_statusline(self): self.start_response_handler( b"HTTP/1.1 Omg What Is This?\r\n" b"Content-Length: 0\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ProtocolError): - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + with pytest.raises(ProtocolError): + pool.request("GET", "/") def test_unknown_protocol(self): self.start_response_handler( b"HTTP/1000 200 OK\r\n" b"Content-Length: 0\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ProtocolError): - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + with pytest.raises(ProtocolError): + pool.request("GET", "/") class TestHeaders(SocketDummyServerTestCase): @@ -1445,11 +1423,10 @@ def test_httplib_headers_case_insensitive(self): b"Content-type: text/plain\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - HEADERS = {"Content-Length": "0", "Content-type": "text/plain"} - r = pool.request("GET", "/") - assert HEADERS == dict(r.headers.items()) # to preserve case sensitivity + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + HEADERS = {"Content-Length": "0", "Content-type": "text/plain"} + r = pool.request("GET", "/") + assert HEADERS == dict(r.headers.items()) # to preserve case sensitivity def test_headers_are_sent_with_the_original_case(self): headers = {"foo": "bar", "bAz": "quux"} @@ -1483,10 +1460,9 @@ def socket_handler(listener): } expected_headers.update(headers) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.request("GET", "/", headers=HTTPHeaderDict(headers)) - assert expected_headers == parsed_headers + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.request("GET", "/", headers=HTTPHeaderDict(headers)) + assert expected_headers == parsed_headers def test_request_headers_are_sent_in_the_original_order(self): # NOTE: Probability this test gives a false negative is 1/(K!) 
@@ -1527,10 +1503,9 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.request("GET", "/", headers=OrderedDict(expected_request_headers)) - assert expected_request_headers == actual_request_headers + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.request("GET", "/", headers=OrderedDict(expected_request_headers)) + assert expected_request_headers == actual_request_headers @fails_on_travis_gce def test_request_host_header_ignores_fqdn_dot(self): @@ -1558,12 +1533,11 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host + ".", self.port, retries=False) - self.addCleanup(pool.close) - pool.request("GET", "/") - self.assert_header_received( - received_headers, "Host", "%s:%s" % (self.host, self.port) - ) + with HTTPConnectionPool(self.host + ".", self.port, retries=False) as pool: + pool.request("GET", "/") + self.assert_header_received( + received_headers, "Host", "%s:%s" % (self.host, self.port) + ) def test_response_headers_are_returned_in_the_original_order(self): # NOTE: Probability this test gives a false negative is 1/(K!) @@ -1595,13 +1569,12 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - r = pool.request("GET", "/", retries=0) - actual_response_headers = [ - (k, v) for (k, v) in r.headers.items() if k.startswith("X-Header-") - ] - assert expected_response_headers == actual_response_headers + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/", retries=0) + actual_response_headers = [ + (k, v) for (k, v) in r.headers.items() if k.startswith("X-Header-") + ] + assert expected_response_headers == actual_response_headers @pytest.mark.skipif( @@ -1620,23 +1593,22 @@ def _test_broken_header_parsing(self, headers, unparsed_data_check=None): + b"\r\n\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - with LogRecorder() as logs: - pool.request("GET", "/") + with LogRecorder() as logs: + pool.request("GET", "/") - for record in logs: - if ( - "Failed to parse headers" in record.msg - and pool._absolute_url("/") == record.args[0] - ): + for record in logs: if ( - unparsed_data_check is None - or unparsed_data_check in record.getMessage() + "Failed to parse headers" in record.msg + and pool._absolute_url("/") == record.args[0] ): - return - self.fail("Missing log about unparsed headers") + if ( + unparsed_data_check is None + or unparsed_data_check in record.getMessage() + ): + return + self.fail("Missing log about unparsed headers") def test_header_without_name(self): self._test_broken_header_parsing([b": Value", b"Another: Header"]) @@ -1656,14 +1628,13 @@ def _test_okay_header_parsing(self, header): (b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n") + header + b"\r\n\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: - with LogRecorder() as logs: - pool.request("GET", "/") + with LogRecorder() as logs: + pool.request("GET", "/") - for record in logs: - assert "Failed to parse headers" not in record.msg + for record in logs: + assert "Failed to parse headers" not in record.msg def test_header_text_plain(self): 
self._test_okay_header_parsing(b"Content-type: text/plain") @@ -1680,12 +1651,11 @@ def test_chunked_head_response_does_not_hang(self): b"Content-type: text/plain\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - r = pool.request("HEAD", "/", timeout=1, preload_content=False) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request("HEAD", "/", timeout=1, preload_content=False) - # stream will use the read_chunked method here. - assert [] == list(r.stream()) + # stream will use the read_chunked method here. + assert [] == list(r.stream()) def test_empty_head_response_does_not_hang(self): self.start_response_handler( @@ -1694,12 +1664,11 @@ def test_empty_head_response_does_not_hang(self): b"Content-type: text/plain\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - r = pool.request("HEAD", "/", timeout=1, preload_content=False) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request("HEAD", "/", timeout=1, preload_content=False) - # stream will use the read method here. - assert [] == list(r.stream()) + # stream will use the read method here. + assert [] == list(r.stream()) class TestStream(SocketDummyServerTestCase): @@ -1724,14 +1693,13 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - r = pool.request("GET", "/", timeout=1, preload_content=False) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request("GET", "/", timeout=1, preload_content=False) - # Stream should read to the end. - assert [b"hello, world"] == list(r.stream(None)) + # Stream should read to the end. + assert [b"hello, world"] == list(r.stream(None)) - done_event.set() + done_event.set() class TestBadContentLength(SocketDummyServerTestCase): @@ -1756,24 +1724,23 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(conn.close) + with HTTPConnectionPool(self.host, self.port, maxsize=1) as conn: - # Test stream read when content length less than headers claim - get_response = conn.request( - "GET", url="/", preload_content=False, enforce_content_length=True - ) - data = get_response.stream(100) - # Read "good" data before we try to read again. - # This won't trigger till generator is exhausted. - next(data) - try: + # Test stream read when content length less than headers claim + get_response = conn.request( + "GET", url="/", preload_content=False, enforce_content_length=True + ) + data = get_response.stream(100) + # Read "good" data before we try to read again. + # This won't trigger till generator is exhausted. 
next(data) - assert False - except ProtocolError as e: - assert "12 bytes read, 10 more expected" in str(e) + try: + next(data) + assert False + except ProtocolError as e: + assert "12 bytes read, 10 more expected" in str(e) - done_event.set() + done_event.set() def test_enforce_content_length_no_body(self): done_event = Event() @@ -1795,17 +1762,16 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(conn.close) + with HTTPConnectionPool(self.host, self.port, maxsize=1) as conn: - # Test stream on 0 length body - head_response = conn.request( - "HEAD", url="/", preload_content=False, enforce_content_length=True - ) - data = [chunk for chunk in head_response.stream(1)] - assert len(data) == 0 + # Test stream on 0 length body + head_response = conn.request( + "HEAD", url="/", preload_content=False, enforce_content_length=True + ) + data = [chunk for chunk in head_response.stream(1)] + assert len(data) == 0 - done_event.set() + done_event.set() class TestRetryPoolSizeDrainFail(SocketDummyServerTestCase): @@ -1828,10 +1794,9 @@ def socket_handler(listener): self._start_server(socket_handler) retries = Retry(total=1, raise_on_status=False, status_forcelist=[404]) - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, maxsize=10, retries=retries, block=True - ) - self.addCleanup(pool.close) + ) as pool: - pool.urlopen("GET", "/not_found", preload_content=False) - assert pool.num_connections == 1 + pool.urlopen("GET", "/not_found", preload_content=False) + assert pool.num_connections == 1