diff --git a/test/contrib/test_socks.py b/test/contrib/test_socks.py index bc8b6299..259baac7 100644 --- a/test/contrib/test_socks.py +++ b/test/contrib/test_socks.py @@ -232,13 +232,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_local_dns(self): def request_handler(listener): @@ -266,13 +265,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://localhost") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://localhost") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_correct_header_line(self): def request_handler(listener): @@ -304,10 +302,9 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://example.com") - assert response.status == 200 + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://example.com") + assert response.status == 200 def test_connection_timeouts(self): event = threading.Event() @@ -317,12 +314,10 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - - with pytest.raises(ConnectTimeoutError): - pm.request("GET", "http://example.com", timeout=0.001, retries=False) - event.set() + with socks.SOCKSProxyManager(proxy_url) as pm: + with pytest.raises(ConnectTimeoutError): + pm.request("GET", "http://example.com", timeout=0.001, retries=False) + event.set() def test_connection_failure(self): event = threading.Event() @@ -333,12 +328,10 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - - event.wait() - with pytest.raises(NewConnectionError): - pm.request("GET", "http://example.com", retries=False) + with socks.SOCKSProxyManager(proxy_url) as pm: + event.wait() + with pytest.raises(NewConnectionError): + pm.request("GET", "http://example.com", retries=False) def test_proxy_rejection(self): evt = threading.Event() @@ -355,12 +348,10 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - - with pytest.raises(NewConnectionError): - pm.request("GET", "http://example.com", retries=False) - evt.set() + with socks.SOCKSProxyManager(proxy_url) as pm: + with pytest.raises(NewConnectionError): + pm.request("GET", "http://example.com", retries=False) + evt.set() def test_socks_with_password(self): def request_handler(listener): @@ -390,14 +381,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="user", password="pass") - self.addCleanup(pm.clear) - - response = pm.request("GET", "http://16.17.18.19") + with socks.SOCKSProxyManager(proxy_url, username="user", password="pass") as pm: + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_socks_with_auth_in_url(self): """ @@ -432,14 +421,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://user:pass@%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://16.17.18.19") - response = pm.request("GET", "http://16.17.18.19") - - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_socks_with_invalid_password(self): def request_handler(listener): @@ -452,15 +439,15 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="user", password="badpass") - self.addCleanup(pm.clear) - - try: - pm.request("GET", "http://example.com", retries=False) - except NewConnectionError as e: - assert "SOCKS5 authentication failed" in str(e) - else: - self.fail("Did not raise") + with socks.SOCKSProxyManager( + proxy_url, username="user", password="badpass" + ) as pm: + try: + pm.request("GET", "http://example.com", retries=False) + except NewConnectionError as e: + assert "SOCKS5 authentication failed" in str(e) + else: + self.fail("Did not raise") def test_source_address_works(self): expected_port = _get_free_port(self.host) @@ -492,12 +479,11 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager( + with socks.SOCKSProxyManager( proxy_url, source_address=("127.0.0.1", expected_port) - ) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 + ) as pm: + response = pm.request("GET", "http://16.17.18.19") + assert response.status == 200 class TestSOCKS4Proxy(IPV4SocketDummyServerTestCase): @@ -534,13 +520,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.headers["Server"] == "SocksTestServer" - assert response.data == b"" + assert response.status == 200 + assert response.headers["Server"] == "SocksTestServer" + assert response.data == b"" def test_local_dns(self): def request_handler(listener): @@ -568,13 +553,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://localhost") + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://localhost") - assert response.status == 200 - assert response.headers["Server"] == "SocksTestServer" - assert response.data == b"" + assert response.status == 200 + assert response.headers["Server"] == "SocksTestServer" + assert response.data == b"" def test_correct_header_line(self): def request_handler(listener): @@ -606,10 +590,9 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4a://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - response = pm.request("GET", "http://example.com") - assert response.status == 200 + with socks.SOCKSProxyManager(proxy_url) as pm: + response = pm.request("GET", "http://example.com") + assert response.status == 200 def test_proxy_rejection(self): evt = threading.Event() @@ -626,12 +609,10 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4a://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url) - self.addCleanup(pm.clear) - - with pytest.raises(NewConnectionError): - pm.request("GET", "http://example.com", retries=False) - evt.set() + with socks.SOCKSProxyManager(proxy_url) as pm: + with pytest.raises(NewConnectionError): + pm.request("GET", "http://example.com", retries=False) + evt.set() def test_socks4_with_username(self): def request_handler(listener): @@ -659,13 +640,12 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="user") - self.addCleanup(pm.clear) - response = pm.request("GET", "http://16.17.18.19") + with socks.SOCKSProxyManager(proxy_url, username="user") as pm: + response = pm.request("GET", "http://16.17.18.19") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" def test_socks_with_invalid_username(self): def request_handler(listener): @@ -676,15 +656,13 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks4a://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, username="baduser") - self.addCleanup(pm.clear) - - try: - pm.request("GET", "http://example.com", retries=False) - except NewConnectionError as e: - assert "different user-ids" in str(e) - else: - self.fail("Did not raise") + with socks.SOCKSProxyManager(proxy_url, username="baduser") as pm: + try: + pm.request("GET", "http://example.com", retries=False) + except NewConnectionError as e: + assert "different user-ids" in str(e) + else: + self.fail("Did not raise") class TestSOCKSWithTLS(IPV4SocketDummyServerTestCase): @@ -728,10 +706,9 @@ def request_handler(listener): self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) - pm = socks.SOCKSProxyManager(proxy_url, ca_certs=DEFAULT_CA) - self.addCleanup(pm.clear) - response = pm.request("GET", "https://localhost") + with socks.SOCKSProxyManager(proxy_url, ca_certs=DEFAULT_CA) as pm: + response = pm.request("GET", "https://localhost") - assert response.status == 200 - assert response.data == b"" - assert response.headers["Server"] == "SocksTestServer" + assert response.status == 200 + assert response.data == b"" + assert response.headers["Server"] == "SocksTestServer" diff --git a/test/with_dummyserver/test_chunked_transfer.py b/test/with_dummyserver/test_chunked_transfer.py index a19e1380..63fe89de 100644 --- a/test/with_dummyserver/test_chunked_transfer.py +++ b/test/with_dummyserver/test_chunked_transfer.py @@ -29,38 +29,35 @@ def socket_handler(listener): def test_chunks(self): self.start_chunked_handler() chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] - pool = HTTPConnectionPool(self.host, self.port, retries=False) - pool.urlopen("GET", "/", chunks, headers=dict(DNT="1")) - self.addCleanup(pool.close) - - assert b"transfer-encoding" in self.buffer - body = self.buffer.split(b"\r\n\r\n", 1)[1] - lines = body.split(b"\r\n") - # Empty chunks should have been skipped, as this could not be distinguished - # from terminating the transmission - for i, chunk in enumerate([c for c in chunks if c]): - assert lines[i * 2] == hex(len(chunk))[2:].encode("utf-8") - assert lines[i * 2 + 1] == chunk + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", chunks, headers=dict(DNT="1")) + + assert b"transfer-encoding" in self.buffer + body = self.buffer.split(b"\r\n\r\n", 1)[1] + lines = body.split(b"\r\n") + # Empty chunks should have been skipped, as this could not be distinguished + # from terminating the transmission + for i, chunk in enumerate([c for c in chunks if c]): + assert lines[i * 2] == hex(len(chunk))[2:].encode("utf-8") + assert lines[i * 2 + 1] == chunk def _test_body(self, data): self.start_chunked_handler() - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - - pool.urlopen("GET", "/", data) - header, body = self.buffer.split(b"\r\n\r\n", 1) - - assert b"transfer-encoding: chunked" in header.split(b"\r\n") - if data: - bdata = data if isinstance(data, bytes) else data.encode("utf-8") - assert b"\r\n" + bdata + b"\r\n" in body - assert body.endswith(b"\r\n0\r\n\r\n") - - len_str = body.split(b"\r\n", 1)[0] - stated_len = int(len_str, 16) - assert stated_len == len(bdata) - else: - assert body == b"0\r\n\r\n" + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", data) + header, body = self.buffer.split(b"\r\n\r\n", 1) + + assert b"transfer-encoding: chunked" in header.split(b"\r\n") + if data: + bdata = data if isinstance(data, bytes) else data.encode("utf-8") + assert b"\r\n" + bdata + b"\r\n" in body + assert body.endswith(b"\r\n0\r\n\r\n") + + len_str = body.split(b"\r\n", 1)[0] + stated_len = int(len_str, 16) + assert stated_len == len(bdata) + else: + assert body == b"0\r\n\r\n" def test_bytestring_body(self): self._test_body(b"thisshouldbeonechunk\r\nasdf") @@ -80,25 +77,23 @@ def test_empty_iterable_body(self): def test_removes_duplicate_host_header(self): self.start_chunked_handler() chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.urlopen("GET", "/", chunks, headers={"Host": "test.org"}) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", chunks, headers={"Host": "test.org"}) - header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() - header_lines = header_block.split(b"\r\n")[1:] + header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() + header_lines = header_block.split(b"\r\n")[1:] - host_headers = [x for x in header_lines if x.startswith(b"host")] - assert len(host_headers) == 1 + host_headers = [x for x in header_lines if x.startswith(b"host")] + assert len(host_headers) == 1 def test_provides_default_host_header(self): self.start_chunked_handler() chunks = [b"foo", b"bar", b"", b"bazzzzzzzzzzzzzzzzzzzzzz"] - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.urlopen("GET", "/", chunks) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.urlopen("GET", "/", chunks) - header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() - header_lines = header_block.split(b"\r\n")[1:] + header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower() + header_lines = header_block.split(b"\r\n")[1:] - host_headers = [x for x in header_lines if x.startswith(b"host")] - assert len(host_headers) == 1 + host_headers = [x for x in header_lines if x.startswith(b"host")] + assert len(host_headers) == 1 diff --git a/test/with_dummyserver/test_connectionpool.py b/test/with_dummyserver/test_connectionpool.py index 117a1872..97d05e81 100644 --- a/test/with_dummyserver/test_connectionpool.py +++ b/test/with_dummyserver/test_connectionpool.py @@ -47,41 +47,39 @@ def test_timeout_float(self): ready_event = self.start_basic_handler(block_send=block_event, num=2) # Pool-global timeout - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=SHORT_TIMEOUT, retries=False - ) - self.addCleanup(pool.close) - wait_for_socket(ready_event) - with pytest.raises(ReadTimeoutError): + ) as pool: + wait_for_socket(ready_event) + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") + block_event.set() # Release block + + # Shouldn't raise this time + wait_for_socket(ready_event) + block_event.set() # Pre-release block pool.request("GET", "/") - block_event.set() # Release block - - # Shouldn't raise this time - wait_for_socket(ready_event) - block_event.set() # Pre-release block - pool.request("GET", "/") def test_conn_closed(self): block_event = Event() self.start_basic_handler(block_send=block_event, num=1) - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=SHORT_TIMEOUT, retries=False - ) - self.addCleanup(pool.close) - conn = pool._get_conn() - pool._put_conn(conn) - try: - pool.urlopen("GET", "/") - self.fail("The request should fail with a timeout error.") - except ReadTimeoutError: - if conn._sock: - with pytest.raises(socket.error): - conn.sock.recv(1024) - finally: + ) as pool: + conn = pool._get_conn() pool._put_conn(conn) - - block_event.set() + try: + pool.urlopen("GET", "/") + self.fail("The request should fail with a timeout error.") + except ReadTimeoutError: + if conn._sock: + with pytest.raises(socket.error): + conn.sock.recv(1024) + finally: + pool._put_conn(conn) + + block_event.set() def test_timeout(self): # Requests should time out when expected @@ -90,63 +88,61 @@ def test_timeout(self): # Pool-global timeout timeout = Timeout(read=SHORT_TIMEOUT) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout, retries=False) - self.addCleanup(pool.close) - - wait_for_socket(ready_event) - conn = pool._get_conn() - with pytest.raises(ReadTimeoutError): - pool._make_request(conn, "GET", "/") - pool._put_conn(conn) - block_event.set() # Release request + with HTTPConnectionPool( + self.host, self.port, timeout=timeout, retries=False + ) as pool: + wait_for_socket(ready_event) + conn = pool._get_conn() + with pytest.raises(ReadTimeoutError): + pool._make_request(conn, "GET", "/") + pool._put_conn(conn) + block_event.set() # Release request - wait_for_socket(ready_event) - block_event.clear() - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") - block_event.set() # Release request + wait_for_socket(ready_event) + block_event.clear() + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") + block_event.set() # Release request # Request-specific timeouts should raise errors - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=LONG_TIMEOUT, retries=False - ) - self.addCleanup(pool.close) - - conn = pool._get_conn() - wait_for_socket(ready_event) - now = time.time() - with pytest.raises(ReadTimeoutError): - pool._make_request(conn, "GET", "/", timeout=timeout) - delta = time.time() - now - block_event.set() # Release request - - message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" - assert delta < LONG_TIMEOUT, message - pool._put_conn(conn) - - wait_for_socket(ready_event) - now = time.time() - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/", timeout=timeout) - delta = time.time() - now - - message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" - assert delta < LONG_TIMEOUT, message - block_event.set() # Release request - - # Timeout int/float passed directly to request and _make_request should - # raise a request timeout - wait_for_socket(ready_event) - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/", timeout=SHORT_TIMEOUT) - block_event.set() # Release request + ) as pool: + conn = pool._get_conn() + wait_for_socket(ready_event) + now = time.time() + with pytest.raises(ReadTimeoutError): + pool._make_request(conn, "GET", "/", timeout=timeout) + delta = time.time() - now + block_event.set() # Release request + + message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" + assert delta < LONG_TIMEOUT, message + pool._put_conn(conn) - wait_for_socket(ready_event) - conn = pool._new_conn() - # FIXME: This assert flakes sometimes. Not sure why. - with pytest.raises(ReadTimeoutError): - pool._make_request(conn, "GET", "/", timeout=SHORT_TIMEOUT) - block_event.set() # Release request + wait_for_socket(ready_event) + now = time.time() + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/", timeout=timeout) + delta = time.time() - now + + message = "timeout was pool-level LONG_TIMEOUT rather than request-level SHORT_TIMEOUT" + assert delta < LONG_TIMEOUT, message + block_event.set() # Release request + + # Timeout int/float passed directly to request and _make_request should + # raise a request timeout + wait_for_socket(ready_event) + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/", timeout=SHORT_TIMEOUT) + block_event.set() # Release request + + wait_for_socket(ready_event) + conn = pool._new_conn() + # FIXME: This assert flakes sometimes. Not sure why. + with pytest.raises(ReadTimeoutError): + pool._make_request(conn, "GET", "/", timeout=SHORT_TIMEOUT) + block_event.set() # Release request def test_connect_timeout(self): url = "/" @@ -154,47 +150,44 @@ def test_connect_timeout(self): timeout = Timeout(connect=SHORT_TIMEOUT) # Pool-global timeout - pool = HTTPConnectionPool(host, port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - with pytest.raises(ConnectTimeoutError): - pool._make_request(conn, "GET", url) + with HTTPConnectionPool(host, port, timeout=timeout) as pool: + conn = pool._get_conn() + with pytest.raises(ConnectTimeoutError): + pool._make_request(conn, "GET", url) - # Retries - retries = Retry(connect=0) - with pytest.raises(MaxRetryError): - pool.request("GET", url, retries=retries) + # Retries + retries = Retry(connect=0) + with pytest.raises(MaxRetryError): + pool.request("GET", url, retries=retries) # Request-specific connection timeouts big_timeout = Timeout(read=LONG_TIMEOUT, connect=LONG_TIMEOUT) - pool = HTTPConnectionPool(host, port, timeout=big_timeout, retries=False) - self.addCleanup(pool.close) - conn = pool._get_conn() - with pytest.raises(ConnectTimeoutError): - pool._make_request(conn, "GET", url, timeout=timeout) + with HTTPConnectionPool(host, port, timeout=big_timeout, retries=False) as pool: + conn = pool._get_conn() + with pytest.raises(ConnectTimeoutError): + pool._make_request(conn, "GET", url, timeout=timeout) - pool._put_conn(conn) - with pytest.raises(ConnectTimeoutError): - pool.request("GET", url, timeout=timeout) + pool._put_conn(conn) + with pytest.raises(ConnectTimeoutError): + pool.request("GET", url, timeout=timeout) def test_total_applies_connect(self): host, port = TARPIT_HOST, 80 timeout = Timeout(total=None, connect=SHORT_TIMEOUT) - pool = HTTPConnectionPool(host, port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) + with HTTPConnectionPool(host, port, timeout=timeout) as pool: + conn = pool._get_conn() with pytest.raises(ConnectTimeoutError): pool._make_request(conn, "GET", "/") timeout = Timeout(connect=3, read=5, total=SHORT_TIMEOUT) - pool = HTTPConnectionPool(host, port, timeout=timeout) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) - with pytest.raises(ConnectTimeoutError): - pool._make_request(conn, "GET", "/") + with HTTPConnectionPool(host, port, timeout=timeout) as pool: + try: + conn = pool._get_conn() + with pytest.raises(ConnectTimeoutError): + pool._make_request(conn, "GET", "/") + finally: + conn.close() def test_total_timeout(self): block_event = Event() @@ -203,39 +196,42 @@ def test_total_timeout(self): wait_for_socket(ready_event) # This will get the socket to raise an EAGAIN on the read timeout = Timeout(connect=3, read=SHORT_TIMEOUT) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") + with HTTPConnectionPool( + self.host, self.port, timeout=timeout, retries=False + ) as pool: + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") - block_event.set() - wait_for_socket(ready_event) - block_event.clear() + block_event.set() + wait_for_socket(ready_event) + block_event.clear() # The connect should succeed and this should hit the read timeout timeout = Timeout(connect=3, read=5, total=SHORT_TIMEOUT) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") + with HTTPConnectionPool( + self.host, self.port, timeout=timeout, retries=False + ) as pool: + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") def test_create_connection_timeout(self): self.start_basic_handler(block_send=Event(), num=0) # needed for self.port timeout = Timeout(connect=SHORT_TIMEOUT, total=LONG_TIMEOUT) - pool = HTTPConnectionPool( + with HTTPConnectionPool( TARPIT_HOST, self.port, timeout=timeout, retries=False - ) - self.addCleanup(pool.close) - conn = pool._new_conn() - with pytest.raises(ConnectTimeoutError): - conn.connect(connect_timeout=timeout.connect_timeout) + ) as pool: + conn = pool._new_conn() + with pytest.raises(ConnectTimeoutError): + conn.connect(connect_timeout=timeout.connect_timeout) class TestConnectionPool(HTTPDummyServerTestCase): def setUp(self): self.pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(self.pool.close) + + def tearDown(self): + self.pool.close() def test_get(self): r = self.pool.request("GET", "/specific_method", fields={"method": "GET"}) @@ -309,15 +305,16 @@ def test_nagle(self): """ Test that connections have TCP_NODELAY turned on """ # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - conn = pool._get_conn() - self.addCleanup(conn.close) - pool._make_request(conn, "GET", "/") - tcp_nodelay_setting = conn._sock.getsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY - ) - assert tcp_nodelay_setting + with HTTPConnectionPool(self.host, self.port) as pool: + try: + conn = pool._get_conn() + pool._make_request(conn, "GET", "/") + tcp_nodelay_setting = conn._sock.getsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY + ) + assert tcp_nodelay_setting + finally: + conn.close() def test_socket_options(self): """Test that connections accept socket options.""" @@ -351,20 +348,27 @@ def test_defaults_are_applied(self): """Test that modifying the default socket options works.""" # This test needs to be here in order to be run. socket.create_connection actually tries # to connect to the host provided so we need a dummyserver to be running. - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - # Get the HTTPConnection instance - conn = pool._new_conn() - self.addCleanup(conn.close) - # Update the default socket options - conn.default_socket_options += [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)] - conn.connect() - s = conn._sock - self.addCleanup(s.close) - nagle_disabled = s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) > 0 - using_keepalive = s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) > 0 - assert nagle_disabled - assert using_keepalive + with HTTPConnectionPool(self.host, self.port) as pool: + # Get the HTTPConnection instance + conn = pool._new_conn() + try: + # Update the default socket options + conn.default_socket_options += [ + (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + ] + conn.connect() + s = conn._sock + nagle_disabled = ( + s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) > 0 + ) + using_keepalive = ( + s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) > 0 + ) + assert nagle_disabled + assert using_keepalive + finally: + conn.close() + s.close() def test_connection_error_retries(self): """ ECONNREFUSED error should raise a connection error, with retries """ @@ -378,21 +382,18 @@ def test_connection_error_retries(self): def test_timeout_success(self): timeout = Timeout(connect=3, read=5, total=None) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - pool.request("GET", "/") - # This should not raise a "Timeout already started" error - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + pool.request("GET", "/") + # This should not raise a "Timeout already started" error + pool.request("GET", "/") - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - # This should also not raise a "Timeout already started" error - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + # This should also not raise a "Timeout already started" error + pool.request("GET", "/") timeout = Timeout(total=None) - pool = HTTPConnectionPool(self.host, self.port, timeout=timeout) - self.addCleanup(pool.close) - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, timeout=timeout) as pool: + pool.request("GET", "/") def test_bad_connect(self): pool = HTTPConnectionPool("badhost.invalid", self.port) @@ -403,63 +404,62 @@ def test_bad_connect(self): assert type(e.reason) == NewConnectionError def test_keepalive(self): - pool = HTTPConnectionPool(self.host, self.port, block=True, maxsize=1) - self.addCleanup(pool.close) - - r = pool.request("GET", "/keepalive?close=0") - r = pool.request("GET", "/keepalive?close=0") + with HTTPConnectionPool(self.host, self.port, block=True, maxsize=1) as pool: + r = pool.request("GET", "/keepalive?close=0") + r = pool.request("GET", "/keepalive?close=0") - assert r.status == 200 - assert pool.num_connections == 1 - assert pool.num_requests == 2 + assert r.status == 200 + assert pool.num_connections == 1 + assert pool.num_requests == 2 def test_keepalive_close(self): - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, block=True, maxsize=1, timeout=2 - ) - self.addCleanup(pool.close) - - r = pool.request( - "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} - ) + ) as pool: + r = pool.request( + "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} + ) - assert pool.num_connections == 1 + assert pool.num_connections == 1 - # The dummyserver will have responded with Connection:close, - # and httplib will properly cleanup the socket. + # The dummyserver will have responded with Connection:close, + # and httplib will properly cleanup the socket. - # We grab the HTTPConnection object straight from the Queue, - # because _get_conn() is where the check & reset occurs - # pylint: disable-msg=W0212 - conn = pool.pool.get() - assert conn._sock is None - pool._put_conn(conn) + # We grab the HTTPConnection object straight from the Queue, + # because _get_conn() is where the check & reset occurs + # pylint: disable-msg=W0212 + conn = pool.pool.get() + assert conn._sock is None + pool._put_conn(conn) - # Now with keep-alive - r = pool.request( - "GET", "/keepalive?close=0", retries=0, headers={"Connection": "keep-alive"} - ) + # Now with keep-alive + r = pool.request( + "GET", + "/keepalive?close=0", + retries=0, + headers={"Connection": "keep-alive"}, + ) - # The dummyserver responded with Connection:keep-alive, the connection - # persists. - conn = pool.pool.get() - assert conn._sock is not None - pool._put_conn(conn) + # The dummyserver responded with Connection:keep-alive, the connection + # persists. + conn = pool.pool.get() + assert conn._sock is not None + pool._put_conn(conn) - # Another request asking the server to close the connection. This one - # should get cleaned up for the next request. - r = pool.request( - "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} - ) + # Another request asking the server to close the connection. This one + # should get cleaned up for the next request. + r = pool.request( + "GET", "/keepalive?close=1", retries=0, headers={"Connection": "close"} + ) - assert r.status == 200 + assert r.status == 200 - conn = pool.pool.get() - assert conn._sock is None - pool._put_conn(conn) + conn = pool.pool.get() + assert conn._sock is None + pool._put_conn(conn) - # Next request - r = pool.request("GET", "/keepalive?close=0") + # Next request + r = pool.request("GET", "/keepalive?close=0") def test_post_with_urlencode(self): data = {"banana": "hammock", "lol": "cat"} @@ -537,38 +537,32 @@ def test_bad_decode(self): ) def test_connection_count(self): - pool = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(pool.close) - - pool.request("GET", "/") - pool.request("GET", "/") - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, maxsize=1) as pool: + pool.request("GET", "/") + pool.request("GET", "/") + pool.request("GET", "/") - assert pool.num_connections == 1 - assert pool.num_requests == 3 + assert pool.num_connections == 1 + assert pool.num_requests == 3 def test_connection_count_bigpool(self): - http_pool = HTTPConnectionPool(self.host, self.port, maxsize=16) - self.addCleanup(http_pool.close) + with HTTPConnectionPool(self.host, self.port, maxsize=16) as http_pool: + http_pool.request("GET", "/") + http_pool.request("GET", "/") + http_pool.request("GET", "/") - http_pool.request("GET", "/") - http_pool.request("GET", "/") - http_pool.request("GET", "/") - - assert http_pool.num_connections == 1 - assert http_pool.num_requests == 3 + assert http_pool.num_connections == 1 + assert http_pool.num_requests == 3 def test_partial_response(self): - pool = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(pool.close) - - req_data = {"lol": "cat"} - resp_data = urlencode(req_data).encode("utf-8") + with HTTPConnectionPool(self.host, self.port, maxsize=1) as pool: + req_data = {"lol": "cat"} + resp_data = urlencode(req_data).encode("utf-8") - r = pool.request("GET", "/echo", fields=req_data, preload_content=False) + r = pool.request("GET", "/echo", fields=req_data, preload_content=False) - assert r.read(5) == resp_data[:5] - assert r.read() == resp_data[5:] + assert r.read(5) == resp_data[:5] + assert r.read() == resp_data[5:] def test_lazy_load_twice(self): # This test is sad and confusing. Need to figure out what's @@ -632,32 +626,31 @@ def test_for_double_release(self): MAXSIZE = 5 # Check default state - pool = HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) - self.addCleanup(pool.close) - assert pool.num_connections == 0 - assert pool.pool.qsize() == MAXSIZE + with HTTPConnectionPool(self.host, self.port, maxsize=MAXSIZE) as pool: + assert pool.num_connections == 0 + assert pool.pool.qsize() == MAXSIZE - # Make an empty slot for testing - pool.pool.get() - assert pool.pool.qsize() == MAXSIZE - 1 + # Make an empty slot for testing + pool.pool.get() + assert pool.pool.qsize() == MAXSIZE - 1 - # Check state after simple request - pool.urlopen("GET", "/") - assert pool.pool.qsize() == MAXSIZE - 1 + # Check state after simple request + pool.urlopen("GET", "/") + assert pool.pool.qsize() == MAXSIZE - 1 - # Check state without release - pool.urlopen("GET", "/", preload_content=False) - assert pool.pool.qsize() == MAXSIZE - 2 + # Check state without release + pool.urlopen("GET", "/", preload_content=False) + assert pool.pool.qsize() == MAXSIZE - 2 - pool.urlopen("GET", "/") - assert pool.pool.qsize() == MAXSIZE - 2 + pool.urlopen("GET", "/") + assert pool.pool.qsize() == MAXSIZE - 2 - # Check state after read - pool.urlopen("GET", "/").data - assert pool.pool.qsize() == MAXSIZE - 2 + # Check state after read + pool.urlopen("GET", "/").data + assert pool.pool.qsize() == MAXSIZE - 2 - pool.urlopen("GET", "/") - assert pool.pool.qsize() == MAXSIZE - 2 + pool.urlopen("GET", "/") + assert pool.pool.qsize() == MAXSIZE - 2 def test_connections_arent_released(self): MAXSIZE = 5 @@ -679,12 +672,11 @@ def test_source_address(self): if is_ipv6 and not HAS_IPV6_AND_DNS: warnings.warn("No IPv6 support: skipping.", NoIPv6Warning) continue - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, source_address=addr, retries=False - ) - self.addCleanup(pool.close) - r = pool.request("GET", "/source_address") - assert r.data == b(addr[0]) + ) as pool: + r = pool.request("GET", "/source_address") + assert r.data == b(addr[0]) def test_source_address_error(self): for addr in INVALID_SOURCE_ADDRESSES: @@ -720,10 +712,9 @@ def test_chunked_gzip(self): assert b"123" * 4 == response.read() def test_mixed_case_hostname(self): - pool = HTTPConnectionPool("LoCaLhOsT", self.port) - self.addCleanup(pool.close) - response = pool.request("GET", "http://LoCaLhOsT:%d/" % self.port) - assert response.status == 200 + with HTTPConnectionPool("LoCaLhOsT", self.port) as pool: + response = pool.request("GET", "http://LoCaLhOsT:%d/" % self.port) + assert response.status == 200 if __name__ == "__main__": diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index cf8b25e5..f77232e9 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -85,7 +85,9 @@ class TestHTTPS(HTTPSDummyServerTestCase): def setUp(self): self._pool = HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) - self.addCleanup(self._pool.close) + + def tearDown(self): + self._pool.close() def test_simple(self): r = self._pool.request("GET", "/") @@ -182,328 +184,294 @@ def test_client_encrypted_key_requires_password(self): assert "password is required" in str(e.value) def test_verified(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] + ) as https_pool: + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + + # Modern versions of Python, or systems using PyOpenSSL, don't + # emit warnings. + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + assert not warn.called, warn.call_args_list else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning + assert warn.called + if util.HAS_SNI: + call = warn.call_args_list[0] + else: + call = warn.call_args_list[1] + error = call[0][1] + assert error == InsecurePlatformWarning def test_verified_with_context(self): ctx = util.ssl_.create_urllib3_context(cert_reqs=ssl.CERT_REQUIRED) ctx.load_verify_locations(cafile=DEFAULT_CA) - https_pool = HTTPSConnectionPool(self.host, self.port, ssl_context=ctx) - self.addCleanup(https_pool.close) - - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] + with HTTPSConnectionPool(self.host, self.port, ssl_context=ctx) as https_pool: + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + + # Modern versions of Python, or systems using PyOpenSSL, don't + # emit warnings. + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + assert not warn.called, warn.call_args_list else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning + assert warn.called + if util.HAS_SNI: + call = warn.call_args_list[0] + else: + call = warn.call_args_list[1] + error = call[0][1] + assert error == InsecurePlatformWarning def test_context_combines_with_ca_certs(self): ctx = util.ssl_.create_urllib3_context(cert_reqs=ssl.CERT_REQUIRED) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, ca_certs=DEFAULT_CA, ssl_context=ctx - ) - self.addCleanup(https_pool.close) - - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - - # Modern versions of Python, or systems using PyOpenSSL, don't - # emit warnings. - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - assert not warn.called, warn.call_args_list - else: - assert warn.called - if util.HAS_SNI: - call = warn.call_args_list[0] + ) as https_pool: + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + + # Modern versions of Python, or systems using PyOpenSSL, don't + # emit warnings. + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + assert not warn.called, warn.call_args_list else: - call = warn.call_args_list[1] - error = call[0][1] - assert error == InsecurePlatformWarning + assert warn.called + if util.HAS_SNI: + call = warn.call_args_list[0] + else: + call = warn.call_args_list[1] + error = call[0][1] + assert error == InsecurePlatformWarning @onlyPy279OrNewer @notSecureTransport # SecureTransport does not support cert directories @notOpenSSL098 # OpenSSL 0.9.8 does not support cert directories def test_ca_dir_verified(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_cert_dir=DEFAULT_CA_DIR - ) - self.addCleanup(https_pool.close) - - with mock.patch("warnings.warn") as warn: - r = https_pool.request("GET", "/") - assert r.status == 200 - assert not warn.called, warn.call_args_list + ) as https_pool: + with mock.patch("warnings.warn") as warn: + r = https_pool.request("GET", "/") + assert r.status == 200 + assert not warn.called, warn.call_args_list def test_invalid_common_name(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - try: - https_pool.request("GET", "/") - self.fail("Didn't raise SSL invalid common name") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "doesn't match" in str( - e.reason - ) or "certificate verify failed" in str(e.reason) + ) as https_pool: + try: + https_pool.request("GET", "/") + self.fail("Didn't raise SSL invalid common name") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "doesn't match" in str( + e.reason + ) or "certificate verify failed" in str(e.reason) def test_verified_with_bad_ca_certs(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(https_pool.close) - - try: - https_pool.request("GET", "/") - self.fail("Didn't raise SSL error with bad CA certs") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "certificate verify failed" in str(e.reason), ( - "Expected 'certificate verify failed'," "instead got: %r" % e.reason - ) + ) as https_pool: + try: + https_pool.request("GET", "/") + self.fail("Didn't raise SSL error with bad CA certs") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "certificate verify failed" in str(e.reason), ( + "Expected 'certificate verify failed'," "instead got: %r" % e.reason + ) def test_verified_without_ca_certs(self): # default is cert_reqs=None which is ssl.CERT_NONE - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED" - ) - self.addCleanup(https_pool.close) - - try: - https_pool.request("GET", "/") - self.fail( - "Didn't raise SSL error with no CA certs when" "CERT_REQUIRED is set" - ) - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - # there is a different error message depending on whether or - # not pyopenssl is injected - assert ( - "No root certificates specified" in str(e.reason) - or "certificate verify failed" in str(e.reason) - or "invalid certificate chain" in str(e.reason) - ), ( - "Expected 'No root certificates specified', " - "'certificate verify failed', or " - "'invalid certificate chain', " - "instead got: %r" % e.reason - ) + ) as https_pool: + try: + https_pool.request("GET", "/") + self.fail( + "Didn't raise SSL error with no CA certs when" + "CERT_REQUIRED is set" + ) + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + # there is a different error message depending on whether or + # not pyopenssl is injected + assert ( + "No root certificates specified" in str(e.reason) + or "certificate verify failed" in str(e.reason) + or "invalid certificate chain" in str(e.reason) + ), ( + "Expected 'No root certificates specified', " + "'certificate verify failed', or " + "'invalid certificate chain', " + "instead got: %r" % e.reason + ) def test_unverified_ssl(self): """ Test that bare HTTPSConnection can connect, make requests """ - pool = HTTPSConnectionPool(self.host, self.port, cert_reqs=ssl.CERT_NONE) - self.addCleanup(pool.close) - - with mock.patch("warnings.warn") as warn: - r = pool.request("GET", "/") - assert r.status == 200 - assert warn.called + with HTTPSConnectionPool(self.host, self.port, cert_reqs=ssl.CERT_NONE) as pool: + with mock.patch("warnings.warn") as warn: + r = pool.request("GET", "/") + assert r.status == 200 + assert warn.called - # Modern versions of Python, or systems using PyOpenSSL, only emit - # the unverified warning. Older systems may also emit other - # warnings, which we want to ignore here. - calls = warn.call_args_list - assert InsecureRequestWarning in [x[0][1] for x in calls] + # Modern versions of Python, or systems using PyOpenSSL, only emit + # the unverified warning. Older systems may also emit other + # warnings, which we want to ignore here. + calls = warn.call_args_list + assert InsecureRequestWarning in [x[0][1] for x in calls] def test_ssl_unverified_with_ca_certs(self): - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_NONE", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(pool.close) + ) as pool: + with mock.patch("warnings.warn") as warn: + r = pool.request("GET", "/") + assert r.status == 200 + assert warn.called - with mock.patch("warnings.warn") as warn: - r = pool.request("GET", "/") - assert r.status == 200 - assert warn.called - - # Modern versions of Python, or systems using PyOpenSSL, only emit - # the unverified warning. Older systems may also emit other - # warnings, which we want to ignore here. - calls = warn.call_args_list - if ( - sys.version_info >= (2, 7, 9) - or util.IS_PYOPENSSL - or util.IS_SECURETRANSPORT - ): - category = calls[0][0][1] - elif util.HAS_SNI: - category = calls[1][0][1] - else: - category = calls[2][0][1] - assert category == InsecureRequestWarning + # Modern versions of Python, or systems using PyOpenSSL, only emit + # the unverified warning. Older systems may also emit other + # warnings, which we want to ignore here. + calls = warn.call_args_list + if ( + sys.version_info >= (2, 7, 9) + or util.IS_PYOPENSSL + or util.IS_SECURETRANSPORT + ): + category = calls[0][0][1] + elif util.HAS_SNI: + category = calls[1][0][1] + else: + category = calls[2][0][1] + assert category == InsecureRequestWarning def test_assert_hostname_false(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - https_pool.assert_hostname = False - https_pool.request("GET", "/") + ) as https_pool: + https_pool.assert_hostname = False + https_pool.request("GET", "/") def test_assert_specific_hostname(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - https_pool.assert_hostname = "localhost" - https_pool.request("GET", "/") + ) as https_pool: + https_pool.assert_hostname = "localhost" + https_pool.request("GET", "/") def test_server_hostname(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA, server_hostname="localhost", - ) - self.addCleanup(https_pool.close) - conn = https_pool._new_conn() - https_pool._start_conn(conn, connect_timeout=None) - - # Assert the wrapping socket is using the passed-through SNI name. - # pyopenssl doesn't let you pull the server_hostname back off the - # socket, so only add this assertion if the attribute is there (i.e. - # the python ssl module). - # XXX This is highly-specific to SyncBackend - # See https://github.com/python-trio/urllib3/pull/54#discussion_r241683895 - # for potential solutions - sock = conn._sock._sock - if hasattr(sock, "server_hostname"): - assert sock.server_hostname == "localhost" + ) as https_pool: + conn = https_pool._new_conn() + https_pool._start_conn(conn, connect_timeout=None) + + # Assert the wrapping socket is using the passed-through SNI name. + # pyopenssl doesn't let you pull the server_hostname back off the + # socket, so only add this assertion if the attribute is there (i.e. + # the python ssl module). + # XXX This is highly-specific to SyncBackend + # See https://github.com/python-trio/urllib3/pull/54#discussion_r241683895 + # for potential solutions + sock = conn._sock._sock + if hasattr(sock, "server_hostname"): + assert sock.server_hostname == "localhost" def test_assert_fingerprint_md5(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - https_pool.assert_fingerprint = ( - "F2:06:5A:42:10:3F:45:1C:17:FE:E6:" "07:1E:8A:86:E5" - ) + ) as https_pool: + https_pool.assert_fingerprint = ( + "F2:06:5A:42:10:3F:45:1C:17:FE:E6:" "07:1E:8A:86:E5" + ) - https_pool.request("GET", "/") + https_pool.request("GET", "/") def test_assert_fingerprint_sha1(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) - https_pool.request("GET", "/") + ) as https_pool: + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) + https_pool.request("GET", "/") def test_assert_fingerprint_sha256(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "localhost", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - https_pool.assert_fingerprint = ( - "C5:4D:0B:83:84:89:2E:AE:B4:58:BB:12:" - "F7:A6:C4:76:05:03:88:D8:57:65:51:F3:" - "1E:60:B0:8B:70:18:64:E6" - ) - https_pool.request("GET", "/") + ) as https_pool: + https_pool.assert_fingerprint = ( + "C5:4D:0B:83:84:89:2E:AE:B4:58:BB:12:" + "F7:A6:C4:76:05:03:88:D8:57:65:51:F3:" + "1E:60:B0:8B:70:18:64:E6" + ) + https_pool.request("GET", "/") def test_assert_invalid_fingerprint(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - https_pool.assert_fingerprint = ( - "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" - ) + ) as https_pool: + https_pool.assert_fingerprint = ( + "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" + ) - def _test_request(pool): - with pytest.raises(MaxRetryError) as cm: - pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) + def _test_request(pool): + with pytest.raises(MaxRetryError) as cm: + pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) - _test_request(https_pool) - https_pool._get_conn() + _test_request(https_pool) + https_pool._get_conn() - # Uneven length - https_pool.assert_fingerprint = "AA:A" - _test_request(https_pool) - https_pool._get_conn() + # Uneven length + https_pool.assert_fingerprint = "AA:A" + _test_request(https_pool) + https_pool._get_conn() - # Invalid length - https_pool.assert_fingerprint = "AA" - _test_request(https_pool) + # Invalid length + https_pool.assert_fingerprint = "AA" + _test_request(https_pool) def test_verify_none_and_bad_fingerprint(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_NONE", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(https_pool.close) - - https_pool.assert_fingerprint = ( - "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" - ) - with pytest.raises(MaxRetryError) as cm: - https_pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) + ) as https_pool: + https_pool.assert_fingerprint = ( + "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:" "AA:AA:AA:AA:AA:AA:AA:AA:AA" + ) + with pytest.raises(MaxRetryError) as cm: + https_pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) def test_verify_none_and_good_fingerprint(self): - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_NONE", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(https_pool.close) - - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) - https_pool.request("GET", "/") + ) as https_pool: + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) + https_pool.request("GET", "/") @notSecureTransport def test_good_fingerprint_and_hostname_mismatch(self): @@ -511,104 +479,95 @@ def test_good_fingerprint_and_hostname_mismatch(self): # hostname validation without turning off all validation, which this # test doesn't do (deliberately). We should revisit this if we make # new decisions. - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) - https_pool.request("GET", "/") + ) as https_pool: + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) + https_pool.request("GET", "/") @requires_network def test_https_timeout(self): - timeout = Timeout(connect=0.001) - https_pool = HTTPSConnectionPool( - TARPIT_HOST, - self.port, - timeout=timeout, - retries=False, - cert_reqs="CERT_REQUIRED", - ) - self.addCleanup(https_pool.close) - timeout = Timeout(total=None, connect=0.001) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( TARPIT_HOST, self.port, timeout=timeout, retries=False, cert_reqs="CERT_REQUIRED", - ) - self.addCleanup(https_pool.close) - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/") + ) as https_pool: + with pytest.raises(ConnectTimeoutError): + https_pool.request("GET", "/") timeout = Timeout(read=0.01) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, timeout=timeout, retries=False, cert_reqs="CERT_REQUIRED", - ) - self.addCleanup(https_pool.close) - https_pool.assert_fingerprint = ( - "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" - ) + ) as https_pool: + https_pool.assert_fingerprint = ( + "92:81:FE:85:F7:0C:26:60:EC:D6:B3:" "BF:93:CF:F9:71:CC:07:7D:0A" + ) timeout = Timeout(total=None) - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, timeout=timeout, cert_reqs="CERT_NONE" - ) - self.addCleanup(https_pool.close) - https_pool.request("GET", "/") + ) as https_pool: + https_pool.request("GET", "/") @requires_network def test_enhanced_timeout(self): - def new_pool(timeout, cert_reqs="CERT_REQUIRED"): - https_pool = HTTPSConnectionPool( - TARPIT_HOST, - self.port, - timeout=timeout, - retries=False, - cert_reqs=cert_reqs, - ) - self.addCleanup(https_pool.close) - return https_pool - - https_pool = new_pool(Timeout(connect=0.001)) - conn = https_pool._new_conn() - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/") - with pytest.raises(ConnectTimeoutError): - https_pool._make_request(conn, "GET", "/") - - https_pool = new_pool(Timeout(connect=5)) - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/", timeout=Timeout(connect=0.001)) + with HTTPSConnectionPool( + TARPIT_HOST, + self.port, + timeout=Timeout(connect=0.001), + retries=False, + cert_reqs="CERT_REQUIRED", + ) as https_pool: + conn = https_pool._new_conn() + with pytest.raises(ConnectTimeoutError): + https_pool.request("GET", "/") + with pytest.raises(ConnectTimeoutError): + https_pool._make_request(conn, "GET", "/") + + with HTTPSConnectionPool( + TARPIT_HOST, + self.port, + timeout=Timeout(connect=5), + retries=False, + cert_reqs="CERT_REQUIRED", + ) as https_pool: + with pytest.raises(ConnectTimeoutError): + https_pool.request("GET", "/", timeout=Timeout(connect=0.001)) - t = Timeout(total=None) - https_pool = new_pool(t) - conn = https_pool._new_conn() - with pytest.raises(ConnectTimeoutError): - https_pool.request("GET", "/", timeout=Timeout(total=None, connect=0.001)) + with HTTPSConnectionPool( + TARPIT_HOST, + self.port, + timeout=Timeout(total=None), + retries=False, + cert_reqs="CERT_REQUIRED", + ) as https_pool: + conn = https_pool._new_conn() + with pytest.raises(ConnectTimeoutError): + https_pool.request( + "GET", "/", timeout=Timeout(total=None, connect=0.001) + ) def test_enhanced_ssl_connection(self): fingerprint = "92:81:FE:85:F7:0C:26:60:EC:D6:B3:BF:93:CF:F9:71:CC:07:7D:0A" - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA, assert_fingerprint=fingerprint, - ) - self.addCleanup(https_pool.close) - - r = https_pool.urlopen("GET", "/") - assert r.status == 200 + ) as https_pool: + r = https_pool.urlopen("GET", "/") + assert r.status == 200 @onlyPy279OrNewer def test_ssl_correct_system_time(self): @@ -694,13 +653,12 @@ def test_warning_for_certs_without_a_san(self): """Ensure that a warning is raised when the cert from the server has no Subject Alternative Name.""" with mock.patch("warnings.warn") as warn: - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=NO_SAN_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 - assert warn.called + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 + assert warn.called class TestHTTPS_IPSAN(HTTPSDummyServerTestCase): @@ -713,12 +671,11 @@ def test_can_validate_ip_san(self): except ImportError: pytest.skip("Only runs on systems with an ipaddress module") - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 class TestHTTPS_IPv6Addr(IPV6HTTPSDummyServerTestCase): @@ -727,12 +684,11 @@ class TestHTTPS_IPv6Addr(IPV6HTTPSDummyServerTestCase): @pytest.mark.skipif(not HAS_IPV6, reason="Only runs on IPv6 systems") def test_strip_square_brackets_before_validating(self): """Test that the fix for #760 works.""" - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "[::1]", self.port, cert_reqs="CERT_REQUIRED", ca_certs=IPV6_ADDR_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 class TestHTTPS_IPV6SAN(IPV6HTTPSDummyServerTestCase): @@ -745,12 +701,11 @@ def test_can_validate_ipv6_san(self): except ImportError: pytest.skip("Only runs on systems with an ipaddress module") - https_pool = HTTPSConnectionPool( + with HTTPSConnectionPool( "[::1]", self.port, cert_reqs="CERT_REQUIRED", ca_certs=IPV6_SAN_CA - ) - self.addCleanup(https_pool.close) - r = https_pool.request("GET", "/") - assert r.status == 200 + ) as https_pool: + r = https_pool.request("GET", "/") + assert r.status == 200 if __name__ == "__main__": diff --git a/test/with_dummyserver/test_no_ssl.py b/test/with_dummyserver/test_no_ssl.py index d7148160..ce60adca 100644 --- a/test/with_dummyserver/test_no_ssl.py +++ b/test/with_dummyserver/test_no_ssl.py @@ -20,10 +20,9 @@ class TestHTTPWithoutSSL(HTTPDummyServerTestCase, TestWithoutSSL): ) ) def test_simple(self): - pool = urllib3.HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - r = pool.request("GET", "/") - assert r.status == 200, r.data + with urllib3.HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/") + assert r.status == 200, r.data class TestHTTPSWithoutSSL(HTTPSDummyServerTestCase, TestWithoutSSL): diff --git a/test/with_dummyserver/test_poolmanager.py b/test/with_dummyserver/test_poolmanager.py index 8d6f0a5e..0246cec4 100644 --- a/test/with_dummyserver/test_poolmanager.py +++ b/test/with_dummyserver/test_poolmanager.py @@ -19,338 +19,320 @@ def setUp(self): self.base_url_alt = "http://%s:%d" % (self.host_alt, self.port) def test_redirect(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/" % self.base_url}, - redirect=False, - ) + with PoolManager() as http: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/" % self.base_url}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/" % self.base_url}, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/" % self.base_url}, + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_redirect_twice(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/redirect" % self.base_url}, - redirect=False, - ) + with PoolManager() as http: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/redirect" % self.base_url}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) + }, + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_redirect_to_relative_url(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "/redirect"}, - redirect=False, - ) + with PoolManager() as http: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "/redirect"}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", "%s/redirect" % self.base_url, fields={"target": "/redirect"} - ) + r = http.request( + "GET", "%s/redirect" % self.base_url, fields={"target": "/redirect"} + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_cross_host_redirect(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: + cross_host_location = "%s/echo?a=b" % self.base_url_alt + try: + http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": cross_host_location}, + timeout=1, + retries=0, + ) + self.fail( + "Request succeeded instead of raising an exception like it should." + ) - cross_host_location = "%s/echo?a=b" % self.base_url_alt - try: - http.request( + except MaxRetryError: + pass + + r = http.request( "GET", "%s/redirect" % self.base_url, - fields={"target": cross_host_location}, + fields={"target": "%s/echo?a=b" % self.base_url_alt}, timeout=1, - retries=0, - ) - self.fail( - "Request succeeded instead of raising an exception like it should." + retries=1, ) - except MaxRetryError: - pass - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/echo?a=b" % self.base_url_alt}, - timeout=1, - retries=1, - ) - - assert r._pool.host == self.host_alt + assert r._pool.host == self.host_alt def test_too_many_redirects(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: + try: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" + % (self.base_url, self.base_url) + }, + retries=1, + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass - try: - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - retries=1, - ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass + try: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" + % (self.base_url, self.base_url) + }, + retries=Retry(total=None, redirect=1), + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass - try: + def test_redirect_cross_host_remove_headers(self): + with PoolManager() as http: r = http.request( "GET", "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - retries=Retry(total=None, redirect=1), + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"Authorization": "foo"}, ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass - def test_redirect_cross_host_remove_headers(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"Authorization": "foo"}, - ) - - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert "Authorization" not in data + assert "Authorization" not in data - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"authorization": "foo"}, - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"authorization": "foo"}, + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert "authorization" not in data - assert "Authorization" not in data + assert "authorization" not in data + assert "Authorization" not in data def test_redirect_cross_host_no_remove_headers(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"Authorization": "foo"}, - retries=Retry(remove_headers_on_redirect=[]), - ) + with PoolManager() as http: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"Authorization": "foo"}, + retries=Retry(remove_headers_on_redirect=[]), + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert data["Authorization"] == "foo" + assert data["Authorization"] == "foo" def test_redirect_cross_host_set_removed_headers(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"X-API-Secret": "foo", "Authorization": "bar"}, - retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), - ) + with PoolManager() as http: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"X-API-Secret": "foo", "Authorization": "bar"}, + retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert "X-API-Secret" not in data - assert data["Authorization"] == "bar" + assert "X-API-Secret" not in data + assert data["Authorization"] == "bar" - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={"target": "%s/headers" % self.base_url_alt}, - headers={"x-api-secret": "foo", "authorization": "bar"}, - retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), - ) + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={"target": "%s/headers" % self.base_url_alt}, + headers={"x-api-secret": "foo", "authorization": "bar"}, + retries=Retry(remove_headers_on_redirect=["X-API-Secret"]), + ) - assert r.status == 200 + assert r.status == 200 - data = json.loads(r.data.decode("utf-8")) + data = json.loads(r.data.decode("utf-8")) - assert "x-api-secret" not in data - assert "X-API-Secret" not in data - assert data["Authorization"] == "bar" + assert "x-api-secret" not in data + assert "X-API-Secret" not in data + assert data["Authorization"] == "bar" def test_raise_on_redirect(self): - http = PoolManager() - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.base_url, - fields={ - "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) - }, - retries=Retry(total=None, redirect=1, raise_on_redirect=False), - ) + with PoolManager() as http: + r = http.request( + "GET", + "%s/redirect" % self.base_url, + fields={ + "target": "%s/redirect?target=%s/" % (self.base_url, self.base_url) + }, + retries=Retry(total=None, redirect=1, raise_on_redirect=False), + ) - assert r.status == 303 + assert r.status == 303 def test_raise_on_status(self): - http = PoolManager() - self.addCleanup(http.clear) + with PoolManager() as http: + try: + # the default is to raise + r = http.request( + "GET", + "%s/status" % self.base_url, + fields={"status": "500 Internal Server Error"}, + retries=Retry(total=1, status_forcelist=range(500, 600)), + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass - try: - # the default is to raise - r = http.request( - "GET", - "%s/status" % self.base_url, - fields={"status": "500 Internal Server Error"}, - retries=Retry(total=1, status_forcelist=range(500, 600)), - ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass + try: + # raise explicitly + r = http.request( + "GET", + "%s/status" % self.base_url, + fields={"status": "500 Internal Server Error"}, + retries=Retry( + total=1, status_forcelist=range(500, 600), raise_on_status=True + ), + ) + self.fail( + "Failed to raise MaxRetryError exception, returned %r" % r.status + ) + except MaxRetryError: + pass - try: - # raise explicitly + # don't raise r = http.request( "GET", "%s/status" % self.base_url, fields={"status": "500 Internal Server Error"}, retries=Retry( - total=1, status_forcelist=range(500, 600), raise_on_status=True + total=1, status_forcelist=range(500, 600), raise_on_status=False ), ) - self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) - except MaxRetryError: - pass - - # don't raise - r = http.request( - "GET", - "%s/status" % self.base_url, - fields={"status": "500 Internal Server Error"}, - retries=Retry( - total=1, status_forcelist=range(500, 600), raise_on_status=False - ), - ) - assert r.status == 500 + assert r.status == 500 def test_missing_port(self): # Can a URL that lacks an explicit port like ':80' succeed, or # will all such URLs fail with an error? - http = PoolManager() - self.addCleanup(http.clear) - - # By globally adjusting `DEFAULT_PORTS` we pretend for a moment - # that HTTP's default port is not 80, but is the port at which - # our test server happens to be listening. - DEFAULT_PORTS["http"] = self.port - try: - r = http.request("GET", "http://%s/" % self.host, retries=0) - finally: - DEFAULT_PORTS["http"] = 80 + with PoolManager() as http: + # By globally adjusting `DEFAULT_PORTS` we pretend for a moment + # that HTTP's default port is not 80, but is the port at which + # our test server happens to be listening. + DEFAULT_PORTS["http"] = self.port + try: + r = http.request("GET", "http://%s/" % self.host, retries=0) + finally: + DEFAULT_PORTS["http"] = 80 - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_headers(self): - http = PoolManager(headers={"Foo": "bar"}) - self.addCleanup(http.clear) - - r = http.request("GET", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request("POST", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request_encode_url("GET", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request_encode_body("POST", "%s/headers" % self.base_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - - r = http.request_encode_url( - "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - - r = http.request_encode_body( - "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" + with PoolManager(headers={"Foo": "bar"}) as http: + r = http.request("GET", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" - def test_http_with_ssl_keywords(self): - http = PoolManager(ca_certs="REQUIRED") - self.addCleanup(http.clear) + r = http.request("POST", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" - r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) - assert r.status == 200 + r = http.request_encode_url("GET", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" - def test_http_with_ca_cert_dir(self): - http = PoolManager(ca_certs="REQUIRED", ca_cert_dir="/nosuchdir") - self.addCleanup(http.clear) + r = http.request_encode_body("POST", "%s/headers" % self.base_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" - r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) - assert r.status == 200 + r = http.request_encode_url( + "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + + r = http.request_encode_body( + "GET", "%s/headers" % self.base_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + + def test_http_with_ssl_keywords(self): + with PoolManager(ca_certs="REQUIRED") as http: + r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) + assert r.status == 200 + + def test_http_with_ca_cert_dir(self): + with PoolManager(ca_certs="REQUIRED", ca_cert_dir="/nosuchdir") as http: + r = http.request("GET", "http://%s:%s/" % (self.host, self.port)) + assert r.status == 200 def test_cleanup_on_connection_error(self): """ @@ -733,9 +715,8 @@ def setUp(self): self.base_url = "http://[%s]:%d" % (self.host, self.port) def test_ipv6(self): - http = PoolManager() - self.addCleanup(http.clear) - http.request("GET", self.base_url) + with PoolManager() as http: + http.request("GET", self.base_url) if __name__ == "__main__": diff --git a/test/with_dummyserver/test_proxy_poolmanager.py b/test/with_dummyserver/test_proxy_poolmanager.py index ca976338..3482a8e1 100644 --- a/test/with_dummyserver/test_proxy_poolmanager.py +++ b/test/with_dummyserver/test_proxy_poolmanager.py @@ -22,343 +22,332 @@ def setUp(self): self.proxy_url = "http://%s:%d" % (self.proxy_host, self.proxy_port) def test_basic_proxy(self): - http = proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) + with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: + r = http.request("GET", "%s/" % self.http_url) + assert r.status == 200 - r = http.request("GET", "%s/" % self.http_url) - assert r.status == 200 - - r = http.request("GET", "%s/" % self.https_url) - assert r.status == 200 + r = http.request("GET", "%s/" % self.https_url) + assert r.status == 200 def test_nagle_proxy(self): """ Test that proxy connections do not have TCP_NODELAY turned on """ - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) - hc2 = http.connection_from_host(self.http_host, self.http_port) - conn = hc2._get_conn() - self.addCleanup(conn.close) - hc2._make_request(conn, "GET", "/") - tcp_nodelay_setting = conn._sock._getsockopt_tcp_nodelay() - assert tcp_nodelay_setting == 0, ( - "Expected TCP_NODELAY for proxies to be set " - "to zero, instead was %s" % tcp_nodelay_setting - ) + with proxy_from_url(self.proxy_url) as http: + hc2 = http.connection_from_host(self.http_host, self.http_port) + conn = hc2._get_conn() + try: + hc2._make_request(conn, "GET", "/") + tcp_nodelay_setting = conn._sock._getsockopt_tcp_nodelay() + assert tcp_nodelay_setting == 0, ( + "Expected TCP_NODELAY for proxies to be set " + "to zero, instead was %s" % tcp_nodelay_setting + ) + finally: + conn.close() def test_proxy_conn_fail(self): host, port = get_unreachable_address() - http = proxy_from_url("http://%s:%s/" % (host, port), retries=1, timeout=0.05) - self.addCleanup(http.clear) - with pytest.raises(MaxRetryError): - http.request("GET", "%s/" % self.https_url) - with pytest.raises(MaxRetryError): - http.request("GET", "%s/" % self.http_url) - - try: - http.request("GET", "%s/" % self.http_url) - self.fail("Failed to raise retry error.") - except MaxRetryError as e: - assert type(e.reason) == ProxyError + with proxy_from_url( + "http://%s:%s/" % (host, port), retries=1, timeout=0.05 + ) as http: + with pytest.raises(MaxRetryError): + http.request("GET", "%s/" % self.https_url) + with pytest.raises(MaxRetryError): + http.request("GET", "%s/" % self.http_url) + + try: + http.request("GET", "%s/" % self.http_url) + self.fail("Failed to raise retry error.") + except MaxRetryError as e: + assert type(e.reason) == ProxyError def test_oldapi(self): - http = ProxyManager(connection_from_url(self.proxy_url), ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) - - r = http.request("GET", "%s/" % self.http_url) - assert r.status == 200 + with ProxyManager( + connection_from_url(self.proxy_url), ca_certs=DEFAULT_CA + ) as http: + r = http.request("GET", "%s/" % self.http_url) + assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url) - assert r.status == 200 + r = http.request("GET", "%s/" % self.https_url) + assert r.status == 200 def test_proxy_verified(self): - http = proxy_from_url( + with proxy_from_url( self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA_BAD - ) - self.addCleanup(http.clear) - https_pool = http._new_pool("https", self.https_host, self.https_port) - try: - https_pool.request("GET", "/", retries=0) - self.fail("Didn't raise SSL error with wrong CA") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "certificate verify failed" in str(e.reason) - - http = proxy_from_url(self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA) - https_pool = http._new_pool("https", self.https_host, self.https_port) - - https_pool.request("GET", "/") # Should succeed without exceptions. - - http = proxy_from_url(self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA) - https_fail_pool = http._new_pool("https", "127.0.0.1", self.https_port) - - try: - https_fail_pool.request("GET", "/", retries=0) - self.fail("Didn't raise SSL invalid common name") - except MaxRetryError as e: - assert isinstance(e.reason, SSLError) - assert "doesn't match" in str(e.reason) + ) as http: + https_pool = http._new_pool("https", self.https_host, self.https_port) + try: + https_pool.request("GET", "/", retries=0) + self.fail("Didn't raise SSL error with wrong CA") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "certificate verify failed" in str(e.reason) + + with proxy_from_url( + self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA + ) as http: + https_pool = http._new_pool("https", self.https_host, self.https_port) + + https_pool.request("GET", "/") # Should succeed without exceptions. + + with proxy_from_url( + self.proxy_url, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA + ) as http: + https_fail_pool = http._new_pool("https", "127.0.0.1", self.https_port) + + try: + https_fail_pool.request("GET", "/", retries=0) + self.fail("Didn't raise SSL invalid common name") + except MaxRetryError as e: + assert isinstance(e.reason, SSLError) + assert "doesn't match" in str(e.reason) def test_redirect(self): - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) - - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/" % self.http_url}, - redirect=False, - ) + with proxy_from_url(self.proxy_url) as http: + r = http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": "%s/" % self.http_url}, + redirect=False, + ) - assert r.status == 303 + assert r.status == 303 - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/" % self.http_url}, - ) + r = http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": "%s/" % self.http_url}, + ) - assert r.status == 200 - assert r.data == b"Dummy server!" + assert r.status == 200 + assert r.data == b"Dummy server!" def test_cross_host_redirect(self): - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) - - cross_host_location = "%s/echo?a=b" % self.http_url_alt - try: - http.request( + with proxy_from_url(self.proxy_url) as http: + cross_host_location = "%s/echo?a=b" % self.http_url_alt + try: + http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": cross_host_location}, + timeout=1, + retries=0, + ) + self.fail("We don't want to follow redirects here.") + + except MaxRetryError: + pass + + r = http.request( "GET", "%s/redirect" % self.http_url, - fields={"target": cross_host_location}, + fields={"target": "%s/echo?a=b" % self.http_url_alt}, timeout=1, - retries=0, + retries=1, ) - self.fail("We don't want to follow redirects here.") - - except MaxRetryError: - pass - - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/echo?a=b" % self.http_url_alt}, - timeout=1, - retries=1, - ) - assert r._pool.host != self.http_host_alt + assert r._pool.host != self.http_host_alt def test_cross_protocol_redirect(self): - http = proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) - - cross_protocol_location = "%s/echo?a=b" % self.https_url - try: - http.request( + with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: + cross_protocol_location = "%s/echo?a=b" % self.https_url + try: + http.request( + "GET", + "%s/redirect" % self.http_url, + fields={"target": cross_protocol_location}, + timeout=1, + retries=0, + ) + self.fail("We don't want to follow redirects here.") + + except MaxRetryError: + pass + + r = http.request( "GET", "%s/redirect" % self.http_url, - fields={"target": cross_protocol_location}, + fields={"target": "%s/echo?a=b" % self.https_url}, timeout=1, - retries=0, + retries=1, ) - self.fail("We don't want to follow redirects here.") - - except MaxRetryError: - pass - - r = http.request( - "GET", - "%s/redirect" % self.http_url, - fields={"target": "%s/echo?a=b" % self.https_url}, - timeout=1, - retries=1, - ) - assert r._pool.host == self.https_host + assert r._pool.host == self.https_host def test_headers(self): - http = proxy_from_url( + with proxy_from_url( self.proxy_url, headers={"Foo": "bar"}, proxy_headers={"Hickory": "dickory"}, ca_certs=DEFAULT_CA, - ) - self.addCleanup(http.clear) - - r = http.request_encode_url("GET", "%s/headers" % self.http_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_url("GET", "%s/headers" % self.http_url_alt) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host_alt, - self.http_port, - ) - - r = http.request_encode_url("GET", "%s/headers" % self.https_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, - ) - - r = http.request_encode_body("POST", "%s/headers" % self.http_url) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_url( - "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_url( - "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, - ) - - r = http.request_encode_body( - "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") == "dickory" - assert returned_headers.get("Host") == "%s:%s" % ( - self.http_host, - self.http_port, - ) - - r = http.request_encode_body( - "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} - ) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") is None - assert returned_headers.get("Baz") == "quux" - assert returned_headers.get("Hickory") is None - assert returned_headers.get("Host") == "%s:%s" % ( - self.https_host, - self.https_port, - ) + ) as http: + r = http.request_encode_url("GET", "%s/headers" % self.http_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_url("GET", "%s/headers" % self.http_url_alt) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host_alt, + self.http_port, + ) + + r = http.request_encode_url("GET", "%s/headers" % self.https_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") is None + assert returned_headers.get("Host") == "%s:%s" % ( + self.https_host, + self.https_port, + ) + + r = http.request_encode_body("POST", "%s/headers" % self.http_url) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_url( + "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_url( + "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") is None + assert returned_headers.get("Host") == "%s:%s" % ( + self.https_host, + self.https_port, + ) + + r = http.request_encode_body( + "GET", "%s/headers" % self.http_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") == "dickory" + assert returned_headers.get("Host") == "%s:%s" % ( + self.http_host, + self.http_port, + ) + + r = http.request_encode_body( + "GET", "%s/headers" % self.https_url, headers={"Baz": "quux"} + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") is None + assert returned_headers.get("Baz") == "quux" + assert returned_headers.get("Hickory") is None + assert returned_headers.get("Host") == "%s:%s" % ( + self.https_host, + self.https_port, + ) def test_headerdict(self): default_headers = HTTPHeaderDict(a="b") proxy_headers = HTTPHeaderDict() proxy_headers.add("foo", "bar") - http = proxy_from_url( + with proxy_from_url( self.proxy_url, headers=default_headers, proxy_headers=proxy_headers - ) - self.addCleanup(http.clear) - - request_headers = HTTPHeaderDict(baz="quux") - r = http.request("GET", "%s/headers" % self.http_url, headers=request_headers) - returned_headers = json.loads(r.data.decode()) - assert returned_headers.get("Foo") == "bar" - assert returned_headers.get("Baz") == "quux" + ) as http: + request_headers = HTTPHeaderDict(baz="quux") + r = http.request( + "GET", "%s/headers" % self.http_url, headers=request_headers + ) + returned_headers = json.loads(r.data.decode()) + assert returned_headers.get("Foo") == "bar" + assert returned_headers.get("Baz") == "quux" def test_proxy_pooling(self): - http = proxy_from_url(self.proxy_url, cert_reqs="NONE") - self.addCleanup(http.clear) + with proxy_from_url(self.proxy_url, cert_reqs="NONE") as http: + for x in range(2): + http.urlopen("GET", self.http_url) + assert len(http.pools) == 1 - for x in range(2): - http.urlopen("GET", self.http_url) - assert len(http.pools) == 1 + for x in range(2): + http.urlopen("GET", self.http_url_alt) + assert len(http.pools) == 1 - for x in range(2): - http.urlopen("GET", self.http_url_alt) - assert len(http.pools) == 1 + for x in range(2): + http.urlopen("GET", self.https_url) + assert len(http.pools) == 2 - for x in range(2): - http.urlopen("GET", self.https_url) - assert len(http.pools) == 2 - - for x in range(2): - http.urlopen("GET", self.https_url_alt) - assert len(http.pools) == 3 + for x in range(2): + http.urlopen("GET", self.https_url_alt) + assert len(http.pools) == 3 def test_proxy_pooling_ext(self): - http = proxy_from_url(self.proxy_url) - self.addCleanup(http.clear) - - hc1 = http.connection_from_url(self.http_url) - hc2 = http.connection_from_host(self.http_host, self.http_port) - hc3 = http.connection_from_url(self.http_url_alt) - hc4 = http.connection_from_host(self.http_host_alt, self.http_port) - assert hc1 == hc2 - assert hc2 == hc3 - assert hc3 == hc4 - - sc1 = http.connection_from_url(self.https_url) - sc2 = http.connection_from_host( - self.https_host, self.https_port, scheme="https" - ) - sc3 = http.connection_from_url(self.https_url_alt) - sc4 = http.connection_from_host( - self.https_host_alt, self.https_port, scheme="https" - ) - assert sc1 == sc2 - assert sc2 != sc3 - assert sc3 == sc4 + with proxy_from_url(self.proxy_url) as http: + hc1 = http.connection_from_url(self.http_url) + hc2 = http.connection_from_host(self.http_host, self.http_port) + hc3 = http.connection_from_url(self.http_url_alt) + hc4 = http.connection_from_host(self.http_host_alt, self.http_port) + assert hc1 == hc2 + assert hc2 == hc3 + assert hc3 == hc4 + + sc1 = http.connection_from_url(self.https_url) + sc2 = http.connection_from_host( + self.https_host, self.https_port, scheme="https" + ) + sc3 = http.connection_from_url(self.https_url_alt) + sc4 = http.connection_from_host( + self.https_host_alt, self.https_port, scheme="https" + ) + assert sc1 == sc2 + assert sc2 != sc3 + assert sc3 == sc4 @pytest.mark.timeout(0.5) @requires_network def test_https_proxy_timeout(self): - https = proxy_from_url("https://{host}".format(host=TARPIT_HOST)) - self.addCleanup(https.clear) - try: - https.request("GET", self.http_url, timeout=0.001) - self.fail("Failed to raise retry error.") - except MaxRetryError as e: - assert type(e.reason) == ConnectTimeoutError + with proxy_from_url("https://{host}".format(host=TARPIT_HOST)) as https: + try: + https.request("GET", self.http_url, timeout=0.001) + self.fail("Failed to raise retry error.") + except MaxRetryError as e: + assert type(e.reason) == ConnectTimeoutError @pytest.mark.timeout(0.5) @requires_network def test_https_proxy_pool_timeout(self): - https = proxy_from_url("https://{host}".format(host=TARPIT_HOST), timeout=0.001) - self.addCleanup(https.clear) - try: - https.request("GET", self.http_url) - self.fail("Failed to raise retry error.") - except MaxRetryError as e: - assert type(e.reason) == ConnectTimeoutError + with proxy_from_url( + "https://{host}".format(host=TARPIT_HOST), timeout=0.001 + ) as https: + try: + https.request("GET", self.http_url) + self.fail("Failed to raise retry error.") + except MaxRetryError as e: + assert type(e.reason) == ConnectTimeoutError def test_scheme_host_case_insensitive(self): """Assert that upper-case schemes and hosts are normalized.""" - http = proxy_from_url(self.proxy_url.upper(), ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) + with proxy_from_url(self.proxy_url.upper(), ca_certs=DEFAULT_CA) as http: + r = http.request("GET", "%s/" % self.http_url.upper()) + assert r.status == 200 - r = http.request("GET", "%s/" % self.http_url.upper()) - assert r.status == 200 - - r = http.request("GET", "%s/" % self.https_url.upper()) - assert r.status == 200 + r = http.request("GET", "%s/" % self.https_url.upper()) + assert r.status == 200 class TestIPv6HTTPProxyManager(IPv6HTTPDummyProxyTestCase): @@ -370,14 +359,12 @@ def setUp(self): self.proxy_url = "http://[%s]:%d" % (self.proxy_host, self.proxy_port) def test_basic_ipv6_proxy(self): - http = proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) - self.addCleanup(http.clear) - - r = http.request("GET", "%s/" % self.http_url) - assert r.status == 200 + with proxy_from_url(self.proxy_url, ca_certs=DEFAULT_CA) as http: + r = http.request("GET", "%s/" % self.http_url) + assert r.status == 200 - r = http.request("GET", "%s/" % self.https_url) - assert r.status == 200 + r = http.request("GET", "%s/" % self.https_url) + assert r.status == 200 if __name__ == "__main__": diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index f5778465..28af8dbe 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -69,11 +69,10 @@ def multicookie_response_handler(listener): sock.close() self._start_server(multicookie_response_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - r = pool.request("GET", "/", retries=0) - assert r.headers == {"set-cookie": "foo=1, bar=1"} - assert r.headers.getlist("set-cookie") == ["foo=1", "bar=1"] + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/", retries=0) + assert r.headers == {"set-cookie": "foo=1, bar=1"} + assert r.headers.getlist("set-cookie") == ["foo=1", "bar=1"] class TestSNI(SocketDummyServerTestCase): @@ -90,16 +89,15 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - try: - pool.request("GET", "/", retries=0) - except MaxRetryError: # We are violating the protocol - pass - done_receiving.wait() - assert ( - self.host.encode("ascii") in self.buf - ), "missing hostname in SSL handshake" + with HTTPSConnectionPool(self.host, self.port) as pool: + try: + pool.request("GET", "/", retries=0) + except MaxRetryError: # We are violating the protocol + pass + done_receiving.wait() + assert ( + self.host.encode("ascii") in self.buf + ), "missing hostname in SSL handshake" class TestClientCerts(SocketDummyServerTestCase): @@ -152,19 +150,18 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_file=DEFAULT_CERTS["certfile"], key_file=DEFAULT_CERTS["keyfile"], cert_reqs="REQUIRED", ca_certs=DEFAULT_CA, - ) - self.addCleanup(pool.close) - pool.request("GET", "/", retries=0) - done_receiving.set() + ) as pool: + pool.request("GET", "/", retries=0) + done_receiving.set() - assert len(client_certs) == 1 + assert len(client_certs) == 1 def test_client_certs_one_file(self): """ @@ -197,18 +194,17 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_file=COMBINED_CERT_AND_KEY, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA, - ) - self.addCleanup(pool.close) - pool.request("GET", "/", retries=0) - done_receiving.set() + ) as pool: + pool.request("GET", "/", retries=0) + done_receiving.set() - assert len(client_certs) == 1 + assert len(client_certs) == 1 def test_missing_client_certs_raises_error(self): """ @@ -228,20 +224,19 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA - ) - self.addCleanup(pool.close) - try: - pool.request("GET", "/", retries=0) - except MaxRetryError: - done_receiving.set() - else: - done_receiving.set() - self.fail( - "Expected server to reject connection due to missing client " - "certificates" - ) + ) as pool: + try: + pool.request("GET", "/", retries=0) + except MaxRetryError: + done_receiving.set() + else: + done_receiving.set() + self.fail( + "Expected server to reject connection due to missing client " + "certificates" + ) @requires_ssl_context_keyfile_password def test_client_cert_with_string_password(self): @@ -288,18 +283,17 @@ def socket_handler(listener): password=password, ) - pool = HTTPSConnectionPool( + with HTTPSConnectionPool( self.host, self.port, ssl_context=ssl_context, cert_reqs="REQUIRED", ca_certs=DEFAULT_CA, - ) - self.addCleanup(pool.close) - pool.request("GET", "/", retries=0) - done_receiving.set() + ) as pool: + pool.request("GET", "/", retries=0) + done_receiving.set() - assert len(client_certs) == 1 + assert len(client_certs) == 1 @requires_ssl_context_keyfile_password def test_load_keyfile_with_invalid_password(self): @@ -352,27 +346,24 @@ def socket_handler(listener): done_closing.set() # let the test know it can proceed self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - response = pool.request("GET", "/", retries=0) - assert response.status == 200 - assert response.data == b"Response 0" + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.request("GET", "/", retries=0) + assert response.status == 200 + assert response.data == b"Response 0" - done_closing.wait() # wait until the socket in our pool gets closed + done_closing.wait() # wait until the socket in our pool gets closed - response = pool.request("GET", "/", retries=0) - assert response.status == 200 - assert response.data == b"Response 1" + response = pool.request("GET", "/", retries=0) + assert response.status == 200 + assert response.data == b"Response 1" def test_connection_refused(self): # Does the pool retry if there is no listener on the port? host, port = get_unreachable_address() - http = HTTPConnectionPool(host, port, maxsize=3, block=True) - self.addCleanup(http.close) - with pytest.raises(MaxRetryError): - http.request("GET", "/", retries=0, release_conn=False) - assert http.pool.qsize() == http.pool.maxsize + with HTTPConnectionPool(host, port, maxsize=3, block=True) as http: + with pytest.raises(MaxRetryError): + http.request("GET", "/", retries=0, release_conn=False) + assert http.pool.qsize() == http.pool.maxsize def test_connection_read_timeout(self): timed_out = Event() @@ -386,18 +377,16 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - http = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, timeout=0.01, retries=False, maxsize=3, block=True - ) - self.addCleanup(http.close) - - try: - with pytest.raises(ReadTimeoutError): - http.request("GET", "/", release_conn=False) - finally: - timed_out.set() + ) as http: + try: + with pytest.raises(ReadTimeoutError): + http.request("GET", "/", release_conn=False) + finally: + timed_out.set() - assert http.pool.qsize() == http.pool.maxsize + assert http.pool.qsize() == http.pool.maxsize def test_read_timeout_dont_retry_method_not_in_whitelist(self): timed_out = Event() @@ -409,14 +398,14 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port, timeout=0.01, retries=True) - self.addCleanup(pool.close) - - try: - with pytest.raises(ReadTimeoutError): - pool.request("POST", "/") - finally: - timed_out.set() + with HTTPConnectionPool( + self.host, self.port, timeout=0.01, retries=True + ) as pool: + try: + with pytest.raises(ReadTimeoutError): + pool.request("POST", "/") + finally: + timed_out.set() @pytest.mark.skip def test_https_connection_read_timeout(self): @@ -432,13 +421,14 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, timeout=0.01, retries=False) - self.addCleanup(pool.close) - try: - with pytest.raises(ReadTimeoutError): - pool.request("GET", "/") - finally: - timed_out.set() + with HTTPSConnectionPool( + self.host, self.port, timeout=0.01, retries=False + ) as pool: + try: + with pytest.raises(ReadTimeoutError): + pool.request("GET", "/") + finally: + timed_out.set() def test_timeout_errors_cause_retries(self): def socket_handler(listener): @@ -477,12 +467,10 @@ def socket_handler(listener): try: self._start_server(socket_handler) t = Timeout(connect=0.001, read=0.01) - pool = HTTPConnectionPool(self.host, self.port, timeout=t) - self.addCleanup(pool.close) - - response = pool.request("GET", "/", retries=1) - assert response.status == 200 - assert response.data == b"Response 2" + with HTTPConnectionPool(self.host, self.port, timeout=t) as pool: + response = pool.request("GET", "/", retries=1) + assert response.status == 200 + assert response.data == b"Response 2" finally: socket.setdefaulttimeout(default_timeout) @@ -509,21 +497,19 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - response = pool.urlopen( - "GET", - "/", - retries=0, - preload_content=False, - timeout=Timeout(connect=1, read=0.01), - ) - try: - with pytest.raises(ReadTimeoutError): - response.read() - finally: - timed_out.set() + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.urlopen( + "GET", + "/", + retries=0, + preload_content=False, + timeout=Timeout(connect=1, read=0.01), + ) + try: + with pytest.raises(ReadTimeoutError): + response.read() + finally: + timed_out.set() def test_delayed_body_read_timeout_with_preload(self): timed_out = Event() @@ -547,16 +533,14 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - try: - with pytest.raises(ReadTimeoutError): - pool.urlopen( - "GET", "/", retries=False, timeout=Timeout(connect=1, read=0.01) - ) - finally: - timed_out.set() + with HTTPConnectionPool(self.host, self.port) as pool: + try: + with pytest.raises(ReadTimeoutError): + pool.urlopen( + "GET", "/", retries=False, timeout=Timeout(connect=1, read=0.01) + ) + finally: + timed_out.set() def test_incomplete_response(self): body = "Response" @@ -583,12 +567,10 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - response = pool.request("GET", "/", retries=0, preload_content=False) - with pytest.raises(ProtocolError): - response.read() + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.request("GET", "/", retries=0, preload_content=False) + with pytest.raises(ProtocolError): + response.read() def test_retry_weird_http_version(self): """ Retry class should handle httplib.BadStatusLine errors properly """ @@ -634,12 +616,11 @@ def socket_handler(listener): sock.close() # Close the socket. self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - retry = Retry(read=1) - response = pool.request("GET", "/", retries=retry) - assert response.status == 200 - assert response.data == b"foo" + with HTTPConnectionPool(self.host, self.port) as pool: + retry = Retry(read=1) + response = pool.request("GET", "/", retries=retry) + assert response.status == 200 + assert response.data == b"foo" def test_dont_tolerate_bad_versions(self): """We don't tolerate weird versions of HTTP""" @@ -666,13 +647,11 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - with pytest.raises(MaxRetryError) as cm: - pool.request("GET", "/", retries=0) + with HTTPConnectionPool(self.host, self.port) as pool: + with pytest.raises(MaxRetryError) as cm: + pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, BadVersionError) + assert isinstance(cm.value.reason, BadVersionError) def test_connection_cleanup_on_read_timeout(self): timed_out = Event() @@ -856,17 +835,15 @@ def socket_handler(listener): complete.set() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - response = pool.request("GET", "/", retries=0, preload_content=False) - assert response.status == 200 - response.close() + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.request("GET", "/", retries=0, preload_content=False) + assert response.status == 200 + response.close() - done_closing.set() # wait until the socket in our pool gets closed - successful = complete.wait(timeout=1) - if not successful: - self.fail("Timed out waiting for connection close") + done_closing.set() # wait until the socket in our pool gets closed + successful = complete.wait(timeout=1) + if not successful: + self.fail("Timed out waiting for connection close") def test_release_conn_param_is_respected_after_timeout_retry(self): """For successful ```urlopen()```, the connection isn't released, even @@ -978,14 +955,12 @@ def body_uploader(): yield b"third data" self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - response = pool.request("POST", "/", body=body_uploader(), retries=0) + with HTTPConnectionPool(self.host, self.port) as pool: + response = pool.request("POST", "/", body=body_uploader(), retries=0) - # Only the first data should have been received by the server. - assert response.status == 400 - assert response.data == b"a\r\nfirst data\r\n" + # Only the first data should have been received by the server. + assert response.status == 400 + assert response.data == b"a\r\nfirst data\r\n" class TestProxyManager(SocketDummyServerTestCase): @@ -1010,24 +985,22 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url) - self.addCleanup(proxy.clear) - - r = proxy.request("GET", "http://google.com/") - - assert r.status == 200 - # FIXME: The order of the headers is not predictable right now. We - # should fix that someday (maybe when we migrate to - # OrderedDict/MultiDict). - assert sorted(r.data.split(b"\r\n")) == sorted( - [ - b"GET http://google.com/ HTTP/1.1", - b"host: google.com", - b"accept: */*", - b"", - b"", - ] - ) + with proxy_from_url(base_url) as proxy: + r = proxy.request("GET", "http://google.com/") + + assert r.status == 200 + # FIXME: The order of the headers is not predictable right now. We + # should fix that someday (maybe when we migrate to + # OrderedDict/MultiDict). + assert sorted(r.data.split(b"\r\n")) == sorted( + [ + b"GET http://google.com/ HTTP/1.1", + b"host: google.com", + b"accept: */*", + b"", + b"", + ] + ) def test_headers(self): def echo_socket_handler(listener): @@ -1053,18 +1026,16 @@ def echo_socket_handler(listener): # Define some proxy headers. proxy_headers = HTTPHeaderDict({"For-The-Proxy": "YEAH!"}) - proxy = proxy_from_url(base_url, proxy_headers=proxy_headers) - self.addCleanup(proxy.clear) + with proxy_from_url(base_url, proxy_headers=proxy_headers) as proxy: + conn = proxy.connection_from_url("http://www.google.com/") - conn = proxy.connection_from_url("http://www.google.com/") + r = conn.urlopen("GET", "http://www.google.com/") - r = conn.urlopen("GET", "http://www.google.com/") - - assert r.status == 200 - # FIXME: The order of the headers is not predictable right now. We - # should fix that someday (maybe when we migrate to - # OrderedDict/MultiDict). - assert b"for-the-proxy: YEAH!\r\n" in r.data + assert r.status == 200 + # FIXME: The order of the headers is not predictable right now. We + # should fix that someday (maybe when we migrate to + # OrderedDict/MultiDict). + assert b"for-the-proxy: YEAH!\r\n" in r.data def test_retries(self): close_event = Event() @@ -1096,16 +1067,15 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url) - self.addCleanup(proxy.clear) - conn = proxy.connection_from_url("http://www.google.com") + with proxy_from_url(base_url) as proxy: + conn = proxy.connection_from_url("http://www.google.com") - r = conn.urlopen("GET", "http://www.google.com", retries=1) - assert r.status == 200 + r = conn.urlopen("GET", "http://www.google.com", retries=1) + assert r.status == 200 - close_event.wait(timeout=1) - with pytest.raises(ProxyError): - conn.urlopen("GET", "http://www.google.com", retries=False) + close_event.wait(timeout=1) + with pytest.raises(ProxyError): + conn.urlopen("GET", "http://www.google.com", retries=False) def test_connect_reconn(self): def proxy_ssl_one(listener): @@ -1161,15 +1131,13 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url, ca_certs=DEFAULT_CA) - self.addCleanup(proxy.clear) - - url = "https://{0}".format(self.host) - conn = proxy.connection_from_url(url) - r = conn.urlopen("GET", url, retries=0) - assert r.status == 200 - r = conn.urlopen("GET", url, retries=0) - assert r.status == 200 + with proxy_from_url(base_url, ca_certs=DEFAULT_CA) as proxy: + url = "https://{0}".format(self.host) + conn = proxy.connection_from_url(url) + r = conn.urlopen("GET", url, retries=0) + assert r.status == 200 + r = conn.urlopen("GET", url, retries=0) + assert r.status == 200 def test_connect_failing(self): def handler(listener): @@ -1190,18 +1158,16 @@ def handler(listener): self._start_server(handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url) - self.addCleanup(proxy.clear) + with proxy_from_url(base_url) as proxy: + url = "https://{0}".format(self.host) + conn = proxy.connection_from_url(url) - url = "https://{0}".format(self.host) - conn = proxy.connection_from_url(url) + with pytest.raises(FailedTunnelError) as cm: + conn.urlopen("GET", url, retries=0) - with pytest.raises(FailedTunnelError) as cm: - conn.urlopen("GET", url, retries=0) - - exception = cm.value - assert exception.response.status_code == 401 - assert exception.response.headers["x-custom-header"] == "yougotit" + exception = cm.value + assert exception.response.status_code == 401 + assert exception.response.headers["x-custom-header"] == "yougotit" def test_connect_ipv6_addr(self): ipv6_addr = "2001:4998:c:a06::2:4008" @@ -1241,16 +1207,14 @@ def echo_socket_handler(listener): self._start_server(echo_socket_handler) base_url = "http://%s:%d" % (self.host, self.port) - proxy = proxy_from_url(base_url, cert_reqs="NONE") - self.addCleanup(proxy.clear) - - url = "https://[{0}]".format(ipv6_addr) - conn = proxy.connection_from_url(url) - try: - r = conn.urlopen("GET", url, retries=0) - assert r.status == 200 - except MaxRetryError: - self.fail("Invalid IPv6 format in HTTP CONNECT request") + with proxy_from_url(base_url, cert_reqs="NONE") as proxy: + url = "https://[{0}]".format(ipv6_addr) + conn = proxy.connection_from_url(url) + try: + r = conn.urlopen("GET", url, retries=0) + assert r.status == 200 + except MaxRetryError: + self.fail("Invalid IPv6 format in HTTP CONNECT request") class TestSSL(SocketDummyServerTestCase): @@ -1284,12 +1248,10 @@ def socket_handler(listener): ssl_sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - with pytest.raises(MaxRetryError) as cm: - pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) + with HTTPSConnectionPool(self.host, self.port) as pool: + with pytest.raises(MaxRetryError) as cm: + pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) def test_ssl_read_timeout(self): timed_out = Event() @@ -1323,21 +1285,19 @@ def socket_handler(listener): ssl_sock.close() self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) - self.addCleanup(pool.close) - - response = pool.urlopen( - "GET", - "/", - retries=0, - preload_content=False, - timeout=Timeout(connect=1, read=0.01), - ) - try: - with pytest.raises(ReadTimeoutError): - response.read() - finally: - timed_out.set() + with HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) as pool: + response = pool.urlopen( + "GET", + "/", + retries=0, + preload_content=False, + timeout=Timeout(connect=1, read=0.01), + ) + try: + with pytest.raises(ReadTimeoutError): + response.read() + finally: + timed_out.set() def test_ssl_failed_fingerprint_verification(self): def socket_handler(listener): @@ -1438,10 +1398,9 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) - self.addCleanup(pool.close) - response = pool.urlopen("GET", "/", retries=1) - assert response.data == b"Success" + with HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) as pool: + response = pool.urlopen("GET", "/", retries=1) + assert response.data == b"Success" def test_ssl_load_default_certs_when_empty(self): def socket_handler(listener): @@ -1475,13 +1434,11 @@ def socket_handler(listener): with mock.patch("urllib3.util.ssl_.SSLContext", lambda *_, **__: context): self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - - with pytest.raises(MaxRetryError): - pool.request("GET", "/", timeout=0.01) + with HTTPSConnectionPool(self.host, self.port) as pool: + with pytest.raises(MaxRetryError): + pool.request("GET", "/", timeout=0.01) - context.load_default_certs.assert_called_with() + context.load_default_certs.assert_called_with() def test_ssl_dont_load_default_certs_when_given(self): if platform.python_implementation() == "PyPy" and sys.version_info[0] == 2: @@ -1526,13 +1483,11 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPSConnectionPool(self.host, self.port, **kwargs) - self.addCleanup(pool.close) - - with pytest.raises(MaxRetryError): - pool.request("GET", "/", timeout=0.01) + with HTTPSConnectionPool(self.host, self.port, **kwargs) as pool: + with pytest.raises(MaxRetryError): + pool.request("GET", "/", timeout=0.01) - context.load_default_certs.assert_not_called() + context.load_default_certs.assert_not_called() class TestErrorWrapping(SocketDummyServerTestCase): @@ -1540,19 +1495,17 @@ def test_bad_statusline(self): self.start_response_handler( b"HTTP/1.1 Omg What Is This?\r\n" b"Content-Length: 0\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ProtocolError): - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + with pytest.raises(ProtocolError): + pool.request("GET", "/") def test_unknown_protocol(self): self.start_response_handler( b"HTTP/1000 200 OK\r\n" b"Content-Length: 0\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ProtocolError): - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + with pytest.raises(ProtocolError): + pool.request("GET", "/") class TestHeaders(SocketDummyServerTestCase): @@ -1563,11 +1516,10 @@ def test_headers_always_lowercase(self): b"Content-type: text/plain\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - HEADERS = {"content-length": "0", "content-type": "text/plain"} - r = pool.request("GET", "/") - assert HEADERS == dict(r.headers.items()) # to preserve case sensitivity + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + HEADERS = {"content-length": "0", "content-type": "text/plain"} + r = pool.request("GET", "/") + assert HEADERS == dict(r.headers.items()) # to preserve case sensitivity def test_headers_are_sent_with_lower_case(self): headers = {"Foo": "bar", "bAz": "quux"} @@ -1599,10 +1551,9 @@ def socket_handler(listener): for key, value in headers.items(): expected_headers[key.lower()] = value - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.request("GET", "/", headers=HTTPHeaderDict(headers)) - assert expected_headers == parsed_headers + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.request("GET", "/", headers=HTTPHeaderDict(headers)) + assert expected_headers == parsed_headers def test_request_headers_are_sent_in_the_original_order(self): # NOTE: Probability this test gives a false negative is 1/(K!) @@ -1643,10 +1594,9 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.request("GET", "/", headers=OrderedDict(expected_request_headers)) - assert expected_request_headers == actual_request_headers + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.request("GET", "/", headers=OrderedDict(expected_request_headers)) + assert expected_request_headers == actual_request_headers @fails_on_travis_gce def test_request_host_header_ignores_fqdn_dot(self): @@ -1674,12 +1624,11 @@ def socket_handler(listener): self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host + ".", self.port, retries=False) - self.addCleanup(pool.close) - pool.request("GET", "/") - self.assert_header_received( - received_headers, "Host", "%s:%s" % (self.host, self.port) - ) + with HTTPConnectionPool(self.host + ".", self.port, retries=False) as pool: + pool.request("GET", "/") + self.assert_header_received( + received_headers, "Host", "%s:%s" % (self.host, self.port) + ) def test_response_headers_are_returned_in_the_original_order(self): # NOTE: Probability this test gives a false negative is 1/(K!) @@ -1711,13 +1660,12 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port) - self.addCleanup(pool.close) - r = pool.request("GET", "/", retries=0) - actual_response_headers = [ - (k, v) for (k, v) in r.headers.items() if k.startswith("x-header-") - ] - assert expected_response_headers == actual_response_headers + with HTTPConnectionPool(self.host, self.port) as pool: + r = pool.request("GET", "/", retries=0) + actual_response_headers = [ + (k, v) for (k, v) in r.headers.items() if k.startswith("x-header-") + ] + assert expected_response_headers == actual_response_headers def test_integer_values_are_sent_as_decimal_strings(self): headers = {"Foo": 88} @@ -1764,10 +1712,9 @@ def _test_broken_header_parsing(self, headers): + b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - with pytest.raises(ProtocolError): - pool.request("GET", "/") + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + with pytest.raises(ProtocolError): + pool.request("GET", "/") def test_header_without_name(self): self._test_broken_header_parsing([b": Value", b"Another: Header"]) @@ -1785,9 +1732,8 @@ def _test_okay_header_parsing(self, header): (b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n") + header + b"\r\n\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - pool.request("GET", "/") # does not raise + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + pool.request("GET", "/") # does not raise def test_header_text_plain(self): self._test_okay_header_parsing(b"Content-type: text/plain") @@ -1804,12 +1750,11 @@ def test_chunked_head_response_does_not_hang(self): b"Content-type: text/plain\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - r = pool.request("HEAD", "/", timeout=1, preload_content=False) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request("HEAD", "/", timeout=1, preload_content=False) - # stream will use the read_chunked method here. - assert [] == list(r.stream()) + # stream will use the read_chunked method here. + assert [] == list(r.stream()) def test_empty_head_response_does_not_hang(self): self.start_response_handler( @@ -1818,12 +1763,11 @@ def test_empty_head_response_does_not_hang(self): b"Content-type: text/plain\r\n" b"\r\n" ) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - r = pool.request("HEAD", "/", timeout=1, preload_content=False) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request("HEAD", "/", timeout=1, preload_content=False) - # stream will use the read method here. - assert [] == list(r.stream()) + # stream will use the read method here. + assert [] == list(r.stream()) class TestStream(SocketDummyServerTestCase): @@ -1848,14 +1792,13 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - pool = HTTPConnectionPool(self.host, self.port, retries=False) - self.addCleanup(pool.close) - r = pool.request("GET", "/", timeout=1, preload_content=False) + with HTTPConnectionPool(self.host, self.port, retries=False) as pool: + r = pool.request("GET", "/", timeout=1, preload_content=False) - # Stream should read to the end. - assert [b"hello, world"] == list(r.stream(None)) + # Stream should read to the end. + assert [b"hello, world"] == list(r.stream(None)) - done_event.set() + done_event.set() class TestBadContentLength(SocketDummyServerTestCase): @@ -1880,24 +1823,22 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(conn.close) + with HTTPConnectionPool(self.host, self.port, maxsize=1) as conn: + # Test stream read when content length less than headers claim + get_response = conn.request("GET", url="/", preload_content=False) + data = get_response.stream(100) - # Test stream read when content length less than headers claim - get_response = conn.request("GET", url="/", preload_content=False) - data = get_response.stream(100) - - # The first read will work fine. - next(data) - - # The second one will see the EOF condition and barf. - try: + # The first read will work fine. next(data) - assert False - except ProtocolError as e: - assert "received 12 bytes, expected 22" in str(e) - done_event.set() + # The second one will see the EOF condition and barf. + try: + next(data) + assert False + except ProtocolError as e: + assert "received 12 bytes, expected 22" in str(e) + + done_event.set() def test_enforce_content_length_no_body(self): done_event = Event() @@ -1919,15 +1860,13 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - conn = HTTPConnectionPool(self.host, self.port, maxsize=1) - self.addCleanup(conn.close) - - # Test stream on 0 length body - head_response = conn.request("HEAD", url="/", preload_content=False) - data = [chunk for chunk in head_response.stream(1)] - assert len(data) == 0 + with HTTPConnectionPool(self.host, self.port, maxsize=1) as conn: + # Test stream on 0 length body + head_response = conn.request("HEAD", url="/", preload_content=False) + data = [chunk for chunk in head_response.stream(1)] + assert len(data) == 0 - done_event.set() + done_event.set() class TestAutomaticHeaderInsertion(SocketDummyServerTestCase): @@ -1988,10 +1927,8 @@ def socket_handler(listener): self._start_server(socket_handler) retries = Retry(total=1, raise_on_status=False, status_forcelist=[404]) - pool = HTTPConnectionPool( + with HTTPConnectionPool( self.host, self.port, maxsize=10, retries=retries, block=True - ) - self.addCleanup(pool.close) - - pool.urlopen("GET", "/not_found", preload_content=False) - assert pool.num_connections == 1 + ) as pool: + pool.urlopen("GET", "/not_found", preload_content=False) + assert pool.num_connections == 1