diff --git a/lore/io/connection.py b/lore/io/connection.py index 1f9f811..8a8c7bb 100644 --- a/lore/io/connection.py +++ b/lore/io/connection.py @@ -11,6 +11,7 @@ import sys import tempfile import threading +import psycopg2 from datetime import datetime @@ -74,17 +75,29 @@ def after_replace(func): sqlalchemy_logger.setLevel(log_levels.get(lore.env.NAME, logging.WARN)) +class ResultWrapper(object): + # Used to make psycopg2 results compatible + # with the interface provided by Connection.execute + def __init__(self, results): + self.results = results + def fetchall(self): + return self.results + class Connection(object): UNLOAD_PREFIX = os.path.join(lore.env.NAME, 'unloads') IAM_ROLE = os.environ.get('IAM_ROLE', None) - def __init__(self, url, name='connection', watermark=True, **kwargs): + def __init__(self, url, name='connection', watermark=True, allow_raw_adapter_queries=True, **kwargs): if not sqlalchemy: raise lore.env.ModuleNotFoundError('No module named sqlalchemy. Please add it to requirements.txt.') parsed = lore.env.parse_url(url) self.adapter = parsed.scheme + self._use_psycopg2 = False + if self.adapter == 'postgres' and allow_raw_adapter_queries: + self._use_psycopg2 = True + if self.adapter == 'postgres': require(lore.dependencies.POSTGRES) if self.adapter == 'snowflake': @@ -425,14 +438,28 @@ def __prepare(self, sql=None, extract=None, filename=None, **kwargs): return sql + def _connection_execute(self, sql, bindings): + if self._use_psycopg2: + with self._connection.engine.raw_connection().connection as conn: + with conn.cursor() as cursor: + cursor.execute(sql, bindings) + try: + return ResultWrapper(cursor.fetchall()) + except psycopg2.ProgrammingError as e: + if 'no results to fetch' in str(e): + return None + raise e + else: + return self._connection.execute(sql, bindings) + def __execute(self, sql, bindings): try: - return self._connection.execute(sql, bindings) + return self._connection_execute(sql, bindings) except (sqlalchemy.exc.DBAPIError, Psycopg2OperationalError, SnowflakeProgrammingError) as e: if not self._transactions and (isinstance(e, Psycopg2OperationalError) or e.connection_invalidated): logger.warning('Reconnect and retry due to invalid connection') self.close() - return self._connection.execute(sql, bindings) + return self._connection_execute(sql, bindings) elif not self._transactions and (isinstance(e, SnowflakeProgrammingError) or e.connection_invalidated): if hasattr(e, 'msg') and e.msg and "authenticate" in e.msg.lower(): logger.warning('Reconnect and retry due to unauthenticated connection') diff --git a/tests/unit/io/test_connection.py b/tests/unit/io/test_connection.py index 4071bcd..acbdd87 100644 --- a/tests/unit/io/test_connection.py +++ b/tests/unit/io/test_connection.py @@ -12,7 +12,7 @@ from sqlalchemy import event from sqlalchemy.engine import Engine import pandas - +import psycopg2 import lore @@ -137,7 +137,7 @@ def insert(delay=0): lore.io.main.execute(sql='insert into tests_autocommit values (1), (2), (3)') posts.append(lore.io.main.select(sql='select count(*) from tests_autocommit')[0][0]) time.sleep(delay) - except sqlalchemy.exc.IntegrityError as ex: + except psycopg2.IntegrityError as ex: thrown.append(True) slow = Thread(target=insert, args=(1,)) @@ -163,21 +163,21 @@ def test_close(self): lore.io.main.close() reopened = lore.io.main.select(sql='select 1') self.assertEquals(reopened, [(1,)]) - with self.assertRaises(sqlalchemy.exc.ProgrammingError): + with self.assertRaises(psycopg2.ProgrammingError): lore.io.main.select(sql='select count(*) from tests_close') def test_reconnect_and_retry(self): - original_execute = lore.io.main._connection.execute + original_execute = lore.io.main._connection_execute def raise_dbapi_error_on_first_call(sql, bindings): - lore.io.main._connection.execute = original_execute + lore.io.main._connection_execute = original_execute e = lore.io.connection.Psycopg2OperationalError('server closed the connection unexpectedly. This probably means the server terminated abnormally before or while processing the request.') raise sqlalchemy.exc.DBAPIError('select 1', [], e, True) exceptions = lore.env.STDOUT_EXCEPTIONS lore.env.STDOUT_EXCEPTIONS = False connection = lore.io.main._connection - lore.io.main._connection.execute = raise_dbapi_error_on_first_call + lore.io.main._connection_execute = raise_dbapi_error_on_first_call result = lore.io.main.select(sql='select 1') lore.env.STDOUT_EXCEPTIONS = exceptions @@ -192,17 +192,17 @@ def test_tuple_interpolation(self): self.assertEqual(len(temps), 3) def test_reconnect_and_retry_on_expired_connection(self): - original_execute = lore.io.main._connection.execute + original_execute = lore.io.main._connection_execute def raise_snowflake_programming_error_on_first_call(sql, bindings): - lore.io.main._connection.execute = original_execute + lore.io.main._connection_execute = original_execute e = lore.io.connection.SnowflakeProgrammingError('Authentication token has expired. The user must authenticate again') raise sqlalchemy.exc.DBAPIError('select 1', [], e, True) exceptions = lore.env.STDOUT_EXCEPTIONS lore.env.STDOUT_EXCEPTIONS = False connection = lore.io.main._connection - lore.io.main._connection.execute = raise_snowflake_programming_error_on_first_call + lore.io.main._connection_execute = raise_snowflake_programming_error_on_first_call result = lore.io.main.select(sql='select 1') lore.env.STDOUT_EXCEPTIONS = exceptions