Skip to content

Commit

Permalink
fixed connection_name to be removed; better type hints; decorators wi…
Browse files Browse the repository at this point in the history
…th no args require parenthesis
  • Loading branch information
padhia committed Jul 19, 2024
1 parent c0aa1ab commit b914773
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 132 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip install --upgrade sfconn

`getconn` and `getsess` are wrapper functions over native Snowflake functions with added functionality (mainly mapping `private_key_file` value as described above).

**Note:** `getsess` is meant to work with *Snowpark* applications. As such, it will raise a `NotImplementedError` exception if `snowflake-snowpark-python` package is not available.
**Note:** `getsess()` function will be available only if `snowflake-snowpark-python` package is available at run-time.

**Usage:**
```python
Expand Down Expand Up @@ -117,14 +117,16 @@ def with_session(logger = None) -> Callable[[Session, ...], None]:

`with_session()` decorator function:
1. Similar to `with_connection()` but creates a `snowflake.snowpark.Session` object instead of a connection object
1. **Note:** this decorator will raise an `NotImplementedError` exception if `snowflake-snowpark-python` package is not available.
1. **Note:** this decorator will be available only if `snowflake-snowpark-python` package is available at run-time.

**Note**: Decorator function parenthesis cannot be omitted even if no arguments are supplied to the decorator functions

**Example:**

```python
from sfconn import with_connection_args, with_connection

@with_connection
@with_connection()
def main(con, show_account: bool):
with con.cursor() as csr:
csr.execute('SELECT CURRENT_USER()')
Expand Down
2 changes: 1 addition & 1 deletion sfconn.nix
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
}:
buildPythonPackage rec {
pname = "sfconn";
version = "0.3.1";
version = "0.3.2";
pyproject = true;
src = ./.;

Expand Down
31 changes: 28 additions & 3 deletions sfconn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
"connection package"

__version__ = "0.3.1"
__version__ = "0.3.2"

from functools import singledispatch

from snowflake.connector import DatabaseError, DataError, InterfaceError, ProgrammingError
from snowflake.connector.cursor import ResultMetadata

from .conn import Connection, Cursor, available_connections, default_connection_name, getconn, getsess
from .conn import Connection, Cursor, available_connections, default_connection_name, getconn
from .jwt import get_token
from .utils import pytype, with_connection, with_connection_args, with_session
from .utils import pytype_conn, with_connection, with_connection_args


@singledispatch
def pytype(meta, best_match: bool = False) -> type: # type: ignore
raise TypeError(f"{meta} is not an instance of ResultMetadata or DataType")


@pytype.register(ResultMetadata)
def _(meta: ResultMetadata, best_match: bool = False):
return pytype_conn(meta, best_match)


try:
from snowflake.snowpark.types import DataType

from .utils_snowpark import getsess, pytype_sess, with_session

@pytype.register(DataType)
def _(meta: DataType, _: bool = False):
return pytype_sess(meta)

except ImportError:
pass

__all__ = [
"DatabaseError",
Expand Down
22 changes: 0 additions & 22 deletions sfconn/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,25 +116,3 @@ def getconn(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any)
Connection object returned by Snowflake python connector
"""
return Connection(**conn_opts(keyfile_pfx_map=keyfile_pfx_map, **kwargs)) # type: ignore


try:
from snowflake.snowpark import Session

def getsess(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any) -> Session:
"""create a Session object using named configuration
Args:
keyfile_pfx_map: if specified must be a a pair of Path values specified as <from-path>:<to-path>, which will
be used to temporarily change private_key_file path value if it starts with <from-pahd> prefix
**kwargs: Any parameter that is valid for snowflake.connector.connect() method
Returns:
Session object returned by Snowflake python connector
"""
return Session.builder.configs(conn_opts(keyfile_pfx_map=keyfile_pfx_map, **kwargs)).create()

except ImportError:

def getsess(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any) -> Session:
raise NotImplementedError("Unable to import snowflake.snowpark.Session; is snowflake-snowpark-python installed?")
132 changes: 30 additions & 102 deletions sfconn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,34 @@
from argparse import SUPPRESS, ArgumentParser, ArgumentTypeError
from decimal import Decimal
from functools import wraps
from logging import Logger
from pathlib import Path
from typing import Any, Callable, cast
from typing import Any, Callable, Concatenate, ParamSpec, TypeAlias, TypeVar, cast

from snowflake.connector.constants import FIELD_TYPES
from snowflake.connector.cursor import ResultMetadata

from .conn import conn_opts, getconn, getsess
from .conn import Connection, getconn

_loglevel = logging.WARNING
P = ParamSpec("P")
R = TypeVar("R")

ConnFn: TypeAlias = Callable[Concatenate[Connection, P], R]
ArgsFn: TypeAlias = Callable[
[Concatenate[tuple[Path, Path] | None, str | None, str | None, str | None, str | None, str | None, int, P]], R
]

def init_logging(logger: logging.Logger) -> None:

def init_logging(logger: Logger, loglevel: int = logging.WARNING) -> None:
"initialize the logging system"
h = logging.StreamHandler()
h.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
logger.addHandler(h)
logger.setLevel(_loglevel)

logger.setLevel(loglevel)

def with_connection_options(fl: Callable[..., Any] | logging.Logger | None = None) -> Callable[..., Any]:
"wraps application entry function that expects a connection"

logger = fl if isinstance(fl, logging.Logger) else None

def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
def with_connection(logger: Logger | None = None) -> Callable[[ConnFn[P, R]], ArgsFn[P, R]]:
def wrapper(fn: ConnFn[P, R]) -> ArgsFn[P, R]:
@wraps(fn)
def wrapped(
keyfile_pfx_map: tuple[Path, Path] | None,
Expand All @@ -39,70 +42,30 @@ def wrapped(
schema: str | None,
warehouse: str | None,
loglevel: int,
**kwargs: Any,
) -> Any:
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"script entry-point"
global _loglevel

_loglevel = loglevel
init_logging(logging.getLogger(__name__))
if logger is not None:
init_logging(logger)
_opts = conn_opts(
keyfile_pfx_map=keyfile_pfx_map,
connection_name=connection_name,
database=database,
role=role,
schema=schema,
warehouse=warehouse,
)
return fn(_opts, **kwargs)

return wrapped

return wrapper if fl is None or isinstance(fl, logging.Logger) else wrapper(fl)
init_logging(logger, loglevel)


def with_connection(fl: Callable[..., Any] | logging.Logger | None = None) -> Callable[..., Any]:
"wraps application entry function that expects a connection"

logger = fl if isinstance(fl, logging.Logger) else None

def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
@wraps(fn)
@with_connection_options(logger)
def wrapped(opts: dict[str, Any], **kwargs: Any) -> Any:
"script entry-point"
try:
with getconn(**opts) as cnx:
return fn(cnx, **kwargs)
with getconn(
keyfile_pfx_map=keyfile_pfx_map,
connection_name=connection_name,
database=database,
role=role,
schema=schema,
warehouse=warehouse,
) as cnx:
return fn(cnx, *args, **kwargs)
except Exception as err:
raise SystemExit(str(err))

return wrapped

return wrapper if fl is None or isinstance(fl, logging.Logger) else wrapper(fl)
return wrapped # type: ignore


def with_session(fl: Callable[..., Any] | logging.Logger | None = None) -> Callable[..., Any]:
"wraps application entry function that expects a connection"

logger = fl if isinstance(fl, logging.Logger) else None

def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
@wraps(fn)
@with_connection_options(logger)
def wrapped(opts: dict[str, Any], **kwargs: Any) -> Any:
"script entry-point"
try:
with getsess(**opts) as session:
return fn(session, **kwargs)
except Exception as err:
raise SystemExit(str(err))

return wrapped

return wrapper if fl is None or isinstance(fl, logging.Logger) else wrapper(fl)
return wrapper


def add_conn_args(parser: ArgumentParser) -> None:
Expand Down Expand Up @@ -152,7 +115,7 @@ def wrapped(args: list[str] | None = None) -> Any:
return getargs


def _pytype(meta: ResultMetadata, best_match: bool = False) -> type[Any]:
def pytype_conn(meta: ResultMetadata, best_match: bool = False) -> type:
"""convert Python DB API data type to python type
Args:
Expand Down Expand Up @@ -185,38 +148,3 @@ def _pytype(meta: ResultMetadata, best_match: bool = False) -> type[Any]:
type_ = TYPE_MAP.get(sql_type_name, str)

return type_ if best_match else str if type_ in [dict, object, list] else type_


try:
import snowflake.snowpark.types as T

def pytype(meta: ResultMetadata | T.DataType, best_match: bool = False) -> type[Any]:
"""convert Python DB API or Snowpark data type to python type
Args:
meta: an individual value returned as part of cursor.description or snowflake.snowpark.types.DataType
best_match: return Python type that is best suited, rather than the actual type used by the connector
Returns:
Python type that best matches Snowflake's type, or str in other cases
"""
if isinstance(meta, ResultMetadata):
return _pytype(meta, best_match)

types = {
T.LongType: int,
T.DateType: dt.date,
T.TimeType: dt.time,
T.TimestampType: dt.datetime,
T.BooleanType: bool,
T.DecimalType: Decimal,
T.DoubleType: float,
T.BinaryType: bytearray,
T.ArrayType: list,
T.VariantType: object,
T.MapType: dict,
}
return next((py_t for sp_t, py_t in types.items() if isinstance(meta, sp_t)), str)

except ImportError:
pytype = _pytype # type: ignore
101 changes: 101 additions & 0 deletions sfconn/utils_snowpark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"Utility functions"

import datetime as dt
import logging
from decimal import Decimal
from functools import wraps
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Concatenate, ParamSpec, TypeAlias, TypeVar

import snowflake.snowpark.types as T
from snowflake.snowpark import Session

from .conn import conn_opts
from .utils import init_logging

P = ParamSpec("P")
R = TypeVar("R")


ConnFn: TypeAlias = Callable[Concatenate[Session, P], R]
ArgsFn: TypeAlias = Callable[
[Concatenate[tuple[Path, Path] | None, str | None, str | None, str | None, str | None, str | None, int, P]], R
]


def getsess(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any) -> Session:
"""create a Session object using named configuration
Args:
keyfile_pfx_map: if specified must be a a pair of Path values specified as <from-path>:<to-path>, which will
be used to temporarily change private_key_file path value if it starts with <from-pahd> prefix
**kwargs: Any parameter that is valid for snowflake.connector.connect() method
Returns:
Session object returned by Snowflake python connector
"""
return Session.builder.configs(conn_opts(keyfile_pfx_map=keyfile_pfx_map, **kwargs)).create()


def with_session(logger: Logger | None = None) -> Callable[[ConnFn[P, R]], ArgsFn[P, R]]:
def _decorate_conn_fn(fn: ConnFn[P, R]) -> ArgsFn[P, R]:
@wraps(fn)
def wrapped(
keyfile_pfx_map: tuple[Path, Path] | None,
connection_name: str | None,
database: str | None,
role: str | None,
schema: str | None,
warehouse: str | None,
loglevel: int,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"script entry-point"
init_logging(logging.getLogger(__name__))
if logger is not None:
init_logging(logger, loglevel)

try:
with getsess(
keyfile_pfx_map=keyfile_pfx_map,
connection_name=connection_name,
database=database,
role=role,
schema=schema,
warehouse=warehouse,
) as session:
return fn(session, *args, **kwargs)
except Exception as err:
raise SystemExit(str(err))

return wrapped # type: ignore

return _decorate_conn_fn


def pytype_sess(meta: T.DataType) -> type:
"""convert Python DB API or Snowpark data type to python type
Args:
meta: an instance of snowflake.snowpark.types.DataType
Returns:
Python type that matches Snowflake's type, or str in other cases
"""
types = {
T.LongType: int,
T.DateType: dt.date,
T.TimeType: dt.time,
T.TimestampType: dt.datetime,
T.BooleanType: bool,
T.DecimalType: Decimal,
T.DoubleType: float,
T.BinaryType: bytearray,
T.ArrayType: list,
T.VariantType: object,
T.MapType: dict,
}

return next((py_t for sp_t, py_t in types.items() if isinstance(meta, sp_t)), str)
2 changes: 1 addition & 1 deletion tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Result:
def getargs(_: ArgumentParser) -> Any:
pass

@with_connection
@with_connection()
def main(cnx: Connection) -> None:
with cnx.cursor() as csr:
rows = list(csr.run_query(Result, "select current_user() as user, current_role() as role"))
Expand Down

0 comments on commit b914773

Please sign in to comment.