Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions lore/io/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sys
import tempfile
import threading
import psycopg2

from datetime import datetime

Expand Down Expand Up @@ -74,17 +75,29 @@ def after_replace(func):
sqlalchemy_logger.setLevel(log_levels.get(lore.env.NAME, logging.WARN))


class ResultWrapper(object):
# Used to make psycopg2 results compatible
# with the interface provided by Connection.execute
def __init__(self, results):
self.results = results
def fetchall(self):
return self.results

class Connection(object):
UNLOAD_PREFIX = os.path.join(lore.env.NAME, 'unloads')
IAM_ROLE = os.environ.get('IAM_ROLE', None)

def __init__(self, url, name='connection', watermark=True, **kwargs):
def __init__(self, url, name='connection', watermark=True, allow_raw_adapter_queries=True, **kwargs):
if not sqlalchemy:
raise lore.env.ModuleNotFoundError('No module named sqlalchemy. Please add it to requirements.txt.')

parsed = lore.env.parse_url(url)
self.adapter = parsed.scheme

self._use_psycopg2 = False
if self.adapter == 'postgres' and allow_raw_adapter_queries:
self._use_psycopg2 = True

if self.adapter == 'postgres':
require(lore.dependencies.POSTGRES)
if self.adapter == 'snowflake':
Expand Down Expand Up @@ -425,14 +438,28 @@ def __prepare(self, sql=None, extract=None, filename=None, **kwargs):

return sql

def _connection_execute(self, sql, bindings):
if self._use_psycopg2:
with self._connection.engine.raw_connection().connection as conn:
with conn.cursor() as cursor:
cursor.execute(sql, bindings)
try:
return ResultWrapper(cursor.fetchall())
except psycopg2.ProgrammingError as e:
if 'no results to fetch' in str(e):
return None
raise e
else:
return self._connection.execute(sql, bindings)
Comment on lines +441 to +453

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _connection_execute(self, sql, bindings):
if self._use_psycopg2:
with self._connection.engine.raw_connection().connection as conn:
with conn.cursor() as cursor:
cursor.execute(sql, bindings)
try:
return ResultWrapper(cursor.fetchall())
except psycopg2.ProgrammingError as e:
if 'no results to fetch' in str(e):
return None
raise e
else:
return self._connection.execute(sql, bindings)
def _connection_execute(self, sql, bindings):
if not self._use_psycopg2:
return self._connection.execute(sql, bindings)
try:
with self._connection.engine.raw_connection().connection as conn, conn.cursor() as cursor:
cursor.execute(sql, bindings)
return ResultWrapper(cursor.fetchall())
except psycopg2.ProgrammingError as e:
if 'no results to fetch' not in str(e):
raise


def __execute(self, sql, bindings):
try:
return self._connection.execute(sql, bindings)
return self._connection_execute(sql, bindings)
except (sqlalchemy.exc.DBAPIError, Psycopg2OperationalError, SnowflakeProgrammingError) as e:
if not self._transactions and (isinstance(e, Psycopg2OperationalError) or e.connection_invalidated):
logger.warning('Reconnect and retry due to invalid connection')
self.close()
return self._connection.execute(sql, bindings)
return self._connection_execute(sql, bindings)
elif not self._transactions and (isinstance(e, SnowflakeProgrammingError) or e.connection_invalidated):
if hasattr(e, 'msg') and e.msg and "authenticate" in e.msg.lower():
logger.warning('Reconnect and retry due to unauthenticated connection')
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/io/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy import event
from sqlalchemy.engine import Engine
import pandas

import psycopg2
import lore


Expand Down Expand Up @@ -137,7 +137,7 @@ def insert(delay=0):
lore.io.main.execute(sql='insert into tests_autocommit values (1), (2), (3)')
posts.append(lore.io.main.select(sql='select count(*) from tests_autocommit')[0][0])
time.sleep(delay)
except sqlalchemy.exc.IntegrityError as ex:
except psycopg2.IntegrityError as ex:
thrown.append(True)

slow = Thread(target=insert, args=(1,))
Expand All @@ -163,21 +163,21 @@ def test_close(self):
lore.io.main.close()
reopened = lore.io.main.select(sql='select 1')
self.assertEquals(reopened, [(1,)])
with self.assertRaises(sqlalchemy.exc.ProgrammingError):
with self.assertRaises(psycopg2.ProgrammingError):
lore.io.main.select(sql='select count(*) from tests_close')

def test_reconnect_and_retry(self):
original_execute = lore.io.main._connection.execute
original_execute = lore.io.main._connection_execute

def raise_dbapi_error_on_first_call(sql, bindings):
lore.io.main._connection.execute = original_execute
lore.io.main._connection_execute = original_execute
e = lore.io.connection.Psycopg2OperationalError('server closed the connection unexpectedly. This probably means the server terminated abnormally before or while processing the request.')
raise sqlalchemy.exc.DBAPIError('select 1', [], e, True)

exceptions = lore.env.STDOUT_EXCEPTIONS
lore.env.STDOUT_EXCEPTIONS = False
connection = lore.io.main._connection
lore.io.main._connection.execute = raise_dbapi_error_on_first_call
lore.io.main._connection_execute = raise_dbapi_error_on_first_call

result = lore.io.main.select(sql='select 1')
lore.env.STDOUT_EXCEPTIONS = exceptions
Expand All @@ -192,17 +192,17 @@ def test_tuple_interpolation(self):
self.assertEqual(len(temps), 3)

def test_reconnect_and_retry_on_expired_connection(self):
original_execute = lore.io.main._connection.execute
original_execute = lore.io.main._connection_execute

def raise_snowflake_programming_error_on_first_call(sql, bindings):
lore.io.main._connection.execute = original_execute
lore.io.main._connection_execute = original_execute
e = lore.io.connection.SnowflakeProgrammingError('Authentication token has expired. The user must authenticate again')
raise sqlalchemy.exc.DBAPIError('select 1', [], e, True)

exceptions = lore.env.STDOUT_EXCEPTIONS
lore.env.STDOUT_EXCEPTIONS = False
connection = lore.io.main._connection
lore.io.main._connection.execute = raise_snowflake_programming_error_on_first_call
lore.io.main._connection_execute = raise_snowflake_programming_error_on_first_call

result = lore.io.main.select(sql='select 1')
lore.env.STDOUT_EXCEPTIONS = exceptions
Expand Down