diff --git a/omniduct/databases/presto.py b/omniduct/databases/presto.py index 0d0ea4f..76a2211 100644 --- a/omniduct/databases/presto.py +++ b/omniduct/databases/presto.py @@ -49,7 +49,7 @@ def NAMESPACE_DEFAULT(self): } @override - def _init(self, catalog='default', schema='default', server_protocol='http', source=None): + def _init(self, catalog='default', schema='default', server_protocol='http', source=None, requests_session=None): """ catalog (str): The default catalog to use in database queries. schema (str): The default schema/database to use in database queries. @@ -57,6 +57,9 @@ def _init(self, catalog='default', schema='default', server_protocol='http', sou service ('http' or 'https'). (default='http') source (str): The source of this query (by default "omniduct "). If manually specified, result will be: " / omniduct ". + requests_session (requests.Session): an optional requests.Session object for advanced usage. + Passed through to the pyhive Cursor which supports custom requests sessions for advanced usage + such as custom headers, cookie values, retry logic, etc. """ self.catalog = catalog self.schema = schema @@ -64,6 +67,7 @@ def _init(self, catalog='default', schema='default', server_protocol='http', sou self.source = source self.__presto = None self.connection_fields += ('catalog', 'schema') + self._requests_session = requests_session @property def source(self): @@ -115,7 +119,8 @@ def _execute(self, statement, cursor, wait, session_properties): cursor = cursor or presto.Cursor( host=self.host, port=self.port, username=self.username, password=self.password, catalog=self.catalog, schema=self.schema, session_props=session_properties, - poll_interval=1, source=self.source, protocol=self.server_protocol + poll_interval=1, source=self.source, protocol=self.server_protocol, + requests_session=self._requests_session ) cursor.execute(statement) status = cursor.poll()