Skip to content

Commit

Permalink
Add abstract support for database session properties, and implement f…
Browse files Browse the repository at this point in the history
…or Presto, Hive and PySpark backends.
  • Loading branch information
matthewwardrop committed Aug 20, 2018
1 parent cd9495e commit cf90061
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 52 deletions.
131 changes: 97 additions & 34 deletions omniduct/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,23 @@ class DatabaseClient(Duct, MagicsProvider):
'raw': cursor_formatters.RawCursorFormatter,
}
DEFAULT_CURSOR_FORMATTER = 'pandas'
SUPPORTS_SESSION_PROPERTIES = False

@quirk_docs('_init', mro=True)
def __init__(self, **kwargs):
def __init__(self, session_properties=None, templates=None, template_context=None, **kwargs):
"""
session_properties (dict): A mapping of default session properties
to values. Interpretation is left up to implementations.
templates (dict): A dictionary of name to template mappings. Additional
templates can be added using `.template_add`.
template_context (dict): The default template context to use when
rendering templates.
"""
Duct.__init_with_kwargs__(self, kwargs, port=self.DEFAULT_PORT)

self._templates = kwargs.pop('templates', {})
self._template_context = kwargs.pop('template_context', {})
self.session_properties = session_properties or {}
self._templates = templates or {}
self._template_context = template_context or {}
self._sqlalchemy_engine = None
self._sqlalchemy_metadata = None

Expand All @@ -93,6 +97,41 @@ def __init__(self, **kwargs):
def _init(self):
pass

# Session property management and configuration
@property
def session_properties(self):
"""The default session properties used in statement executions."""
return self._session_properties

@session_properties.setter
def session_properties(self, properties):
self._session_properties = self._get_session_properties(default=properties)

def _get_session_properties(self, overrides=None, default=None):
"""
Retrieve default session properties with optional overrides.
Properties with a value of None will be skipped, in order to allow
overrides to remove default properties.
Parameters:
overrides (dict, None): A dictionary of session property overrides.
default (dict, None): A dictionary of default session properties, if
it is necessary to override `self.session_properties`.
"""
if (default or overrides) and not self.SUPPORTS_SESSION_PROPERTIES:
raise RuntimeError("Session properties are not supported by this backend.")

props = (default or self.session_properties).copy()
props.update(overrides or {})

# Remove any properties with value set to None.
for key, value in props.items():
if value is None:
del props[key]

return props

def __call__(self, query, **kwargs):
"""
Allow use of `DatabaseClient(...)` as a short-hand for
Expand All @@ -101,8 +140,27 @@ def __call__(self, query, **kwargs):
return self.query(query, **kwargs)

# Querying
@classmethod
def statements_split(cls, statements):
def _statement_prepare(self, statement, session_properties, **kwargs):
"""
This method prepares the statement into its final form to be executed
by `self._execute`. This can be used to insert session properties or
transform the statement in any way.
Parameters:
statement (str): The statement to be executed.
session_properties (dict): A mutable dictionary of session properties
and their values (this method can mutate it depending on statement
contents).
**kwargs (dict): Any additional keyword arguments passed through to
`self.execute` (will match the extra keyword arguments added to
`self._execute`).
Returns:
statement (str): The statement to be executed (potentially transformed).
"""
return statement

def _statement_split(self, statements):
"""
This classmethod converts a single string containing one or more SQL
statements into an iterator of strings, each corresponding to one SQL
Expand All @@ -122,25 +180,6 @@ def statements_split(cls, statements):
if statement: # remove empty statements
yield statement

@classmethod
def statement_cleanup(cls, statement):
"""
This classmethod takes an SQL statement and reformats it by consistently
removing comments and replacing all whitespace. It is used by the
`query` method to avoid functionally identical queries hitting different
cache kets. If the statement's language is not to be SQL, this method
should be overloaded appropriately.
Parameters:
statement (str): The statement to be reformatted/cleaned-up.
Returns:
str: The new statement, consistently reformatted.
"""
statement = sqlparse.format(statement, strip_comments=True, reindent=True)
statement = os.linesep.join([line for line in statement.splitlines() if line])
return statement

@classmethod
def statement_hash(cls, statement, cleanup=True):
"""
Expand All @@ -158,13 +197,35 @@ def statement_hash(cls, statement, cleanup=True):
"""
if cleanup:
statement = cls.statement_cleanup(statement)
if sys.version_info.major == 3 or sys.version_info.major == 2 and isinstance(statement, unicode): # noqa: F821
if (
sys.version_info.major == 3
or sys.version_info.major == 2 and isinstance(statement, unicode) # noqa: F821
):
statement = statement.encode('utf8')
return hashlib.sha256(statement).hexdigest()

@classmethod
def statement_cleanup(cls, statement):
"""
This classmethod takes an SQL statement and reformats it by consistently
removing comments and replacing all whitespace. It is used by the
`query` method to avoid functionally identical queries hitting different
cache keys. If the statement's language is not to be SQL, this method
should be overloaded appropriately.
Parameters:
statement (str): The statement to be reformatted/cleaned-up.
Returns:
str: The new statement, consistently reformatted.
"""
statement = sqlparse.format(statement, strip_comments=True, reindent=True)
statement = os.linesep.join([line for line in statement.splitlines() if line])
return statement

@render_statement
@quirk_docs('_execute')
def execute(self, statement, cleanup=True, wait=True, cursor=None, **kwargs):
def execute(self, statement, wait=True, cursor=None, session_properties=None, **kwargs):
"""
This method executes a given statement against the relevant database,
returning the results as a standard DBAPI2 compatible cursor. Where
Expand All @@ -174,13 +235,14 @@ def execute(self, statement, cleanup=True, wait=True, cursor=None, **kwargs):
Parameters:
statement (str): The statement to be executed by the query client
(possibly templated).
cleanup (bool): Whether statement should be cleaned up before
computing the hash used to cache results.
wait (bool): Whether the cursor should be returned before the
server-side query computation is complete and the relevant
results downloaded.
cursor (DBAPI2 cursor): Rather than creating a new cursor, execute
the statement against the provided cursor.
session_properties (dict): Additional session properties and/or
overrides to use for this query. Setting a session property
value to `None` will cause it to be omitted.
**kwargs (dict): Extra keyword arguments to be passed on to
`_execute`, as implemented by subclasses.
template (bool): Whether the statement should be treated as a Jinja2
Expand All @@ -193,15 +255,16 @@ def execute(self, statement, cleanup=True, wait=True, cursor=None, **kwargs):
DBAPI2 cursor: A DBAPI2 compatible cursor instance.
"""

self.connect()
session_properties = self._get_session_properties(overrides=session_properties)

statements = self.statements_split(statement)
statements = [self.statement_cleanup(stmt) if cleanup else stmt for stmt in statements]
statements = list(self._statement_split(
self._statement_prepare(statement, session_properties=session_properties, **kwargs)
))
assert len(statements) > 0, "No non-empty statements were provided."

for statement in statements[:-1]:
cursor = self.connect()._execute(statement, cursor=cursor, wait=True, **kwargs)
cursor = self.connect()._execute(statements[-1], cursor=cursor, wait=wait, **kwargs)
cursor = self.connect()._execute(statement, cursor=cursor, wait=True, session_properties=session_properties, **kwargs)
cursor = self.connect()._execute(statements[-1], cursor=cursor, wait=wait, session_properties=session_properties, **kwargs)

return cursor

Expand Down Expand Up @@ -512,7 +575,7 @@ def push(self, df, table, if_exists='fail', **kwargs):
# Table properties

@abstractmethod
def _execute(self, statement, cursor=None, wait=True, **kwargs):
def _execute(self, statement, cursor, wait, session_properties, **kwargs):
pass

def _push(self, df, table, if_exists='fail', **kwargs):
Expand Down
11 changes: 10 additions & 1 deletion omniduct/databases/hiveserver2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class HiveServer2Client(DatabaseClient, SchemasMixin):

PROTOCOLS = ['hiveserver2']
DEFAULT_PORT = 3623
SUPPORTS_SESSION_PROPERTIES = True

def _init(self, schema=None, driver='pyhive', auth_mechanism='NOSASL',
push_using_hive_cli=False, default_table_props=None, **connection_options):
Expand Down Expand Up @@ -127,7 +128,15 @@ def _disconnect(self):
self._sqlalchemy_metadata = None
self._schemas = None

def _execute(self, statement, cursor=None, wait=True, poll_interval=1):
def _statement_prepare(self, statement, session_properties, **kwargs):
return (
"\n".join(
"SET {key} = {value};".format(key=key, value=value)
for key, value in session_properties.items()
) + statement
)

def _execute(self, statement, cursor, wait, session_properties, poll_interval=1):
"""
Additional Parameters:
poll_interval (int): Default delay in seconds between consecutive
Expand Down
2 changes: 1 addition & 1 deletion omniduct/databases/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _disconnect(self):
self.__driver = None

# Querying
def _execute(self, statement, cursor=None, wait=True):
def _execute(self, statement, cursor, wait, session_properties):
with self.__driver.session() as session:
result = session.run(statement)

Expand Down
25 changes: 13 additions & 12 deletions omniduct/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,21 @@ class PrestoClient(DatabaseClient, SchemasMixin):

PROTOCOLS = ['presto']
DEFAULT_PORT = 3506
SUPPORTS_SESSION_PROPERTIES = True

def _init(self, catalog='default', schema='default', source=None, **connection_options):
def _init(self, catalog='default', schema='default', server_protocol='http', source=None):
"""
catalog (str): The default catalog to use in database queries.
schema (str): The default schema/database to use in database queries.
server_protocol (str): The protocol over which to connect to the Presto REST
service ('http' or 'https'). (default='http')
source (str): The source of this query (by default "omniduct <version>").
If manually specified, result will be: "<source> / omniduct <version>".
connection_options (dict): Additional options to pass on to
`pyhive.presto.connect(...)`.
"""
self.catalog = catalog
self.schema = schema
self.server_protocol = server_protocol
self.source = source
self.connection_options = connection_options
self.__presto = None
self.connection_fields += ('catalog', 'schema')

Expand All @@ -66,13 +67,9 @@ def source(self, source):
# Connection

def _connect(self):
from pyhive import presto # Imported here due to slow import performance in Python 3
from sqlalchemy import create_engine, MetaData
logging.getLogger('pyhive').setLevel(1000) # Silence pyhive logging.
logger.info('Connecting to Presto coordinator...')
self.__presto = presto.connect(self.host, port=self.port, username=self.username, password=self.password,
catalog=self.catalog, schema=self.schema,
poll_interval=1, source=self.source, **self.connection_options)
self._sqlalchemy_engine = create_engine('presto://{}:{}/{}/{}'.format(self.host, self.port, self.catalog, self.schema))
self._sqlalchemy_metadata = MetaData(self._sqlalchemy_engine)

Expand All @@ -88,21 +85,25 @@ def _disconnect(self):
self.__presto.close()
except:
pass
self.__presto = None
self._sqlalchemy_engine = None
self._sqlalchemy_metadata = None
self._schemas = None

# Querying
def _execute(self, statement, cursor=None, wait=True):
def _execute(self, statement, cursor, wait, session_properties):
"""
If something goes wrong, `PrestoClient` will attempt to parse the error
log and present the user with useful debugging information. If that fails,
the full traceback will be raised instead.
"""
from pyhive import presto # Imported here due to slow import performance in Python 3
from pyhive.exc import DatabaseError # Imported here due to slow import performance in Python 3
try:
cursor = cursor or self.__presto.cursor()
cursor = cursor or presto.Cursor(
host=self.host, port=self.port, username=self.username, password=self.password,
catalog=self.catalog, schema=self.schema, session_props=session_properties,
poll_interval=1, source=self.source, protocol=self.server_protocol
)
cursor.execute(statement)
status = cursor.poll()
if wait:
Expand All @@ -124,7 +125,7 @@ def _execute(self, statement, cursor=None, wait=True):
try:
message = e.args[0]
if isinstance(message, six.string_types):
message = ast.literal_eval(re.match("[^{]*({.*})[^}]*$", e.message).group(1))
message = ast.literal_eval(re.match("[^{]*({.*})[^}]*$", message).group(1))

linenumber = message['errorLocation']['lineNumber'] - 1
splt = statement.splitlines()
Expand Down
12 changes: 10 additions & 2 deletions omniduct/databases/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class PySparkClient(DatabaseClient):

PROTOCOLS = ['pyspark']
DEFAULT_PORT = None
SUPPORTS_SESSION_PROPERTIES = True

def _init(self, app_name='omniduct', config=None, master=None, enable_hive_support=False):
"""
Expand Down Expand Up @@ -49,8 +50,15 @@ def _disconnect(self):
self._spark_session.sparkContext.stop()

# Database operations

def _execute(self, statement, cursor=None, wait=True, **kwargs):
def _statement_prepare(self, statement, session_properties):
return (
"\n".join(
"SET {key} = {value};".format(key=key, value=value)
for key, value in session_properties.items()
) + statement
)

def _execute(self, statement, cursor, wait, session_properties, **kwargs):
assert wait is True, "This Spark backend does not support asynchronous operations."
return SparkCursor(self._spark_session.sql(statement))

Expand Down
3 changes: 2 additions & 1 deletion omniduct/databases/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def _disconnect(self):
self._sqlalchemy_metadata = None
self._schemas = None

def _execute(self, statement, query=True, cursor=None, **kwargs):
def _execute(self, statement, cursor, wait, session_properties, query=True, **kwargs):
assert wait, "`SQLAlchemyClient` does not support asynchronous operations."
if cursor:
cursor.execute(statement)
else:
Expand Down
2 changes: 1 addition & 1 deletion omniduct/databases/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _disconnect(self):

# Database operations

def _execute(self, statement, cursor=None, wait=True, **kwargs):
def _execute(self, statement, cursor, wait, session_properties, **kwargs):
raise NotImplementedError

def _table_list(self, **kwargs):
Expand Down

0 comments on commit cf90061

Please sign in to comment.