diff --git a/README.md b/README.md index 4ec6a83..0d43aa0 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ pip install --upgrade sfconn `getconn` and `getsess` are wrapper functions over native Snowflake functions with added functionality (mainly mapping `private_key_file` value as described above). -**Note:** `getsess` is meant to work with *Snowpark* applications. As such, it will raise a `NotImplementedError` exception if `snowflake-snowpark-python` package is not available. +**Note:** `getsess()` function will be available only if `snowflake-snowpark-python` package is available at run-time. **Usage:** ```python @@ -117,14 +117,16 @@ def with_session(logger = None) -> Callable[[Session, ...], None]: `with_session()` decorator function: 1. Similar to `with_connection()` but creates a `snowflake.snowpark.Session` object instead of a connection object -1. **Note:** this decorator will raise an `NotImplementedError` exception if `snowflake-snowpark-python` package is not available. +1. **Note:** this decorator will be available only if `snowflake-snowpark-python` package is available at run-time. + +**Note**: Decorator function parenthesis cannot be omitted even if no arguments are supplied to the decorator functions **Example:** ```python from sfconn import with_connection_args, with_connection -@with_connection +@with_connection() def main(con, show_account: bool): with con.cursor() as csr: csr.execute('SELECT CURRENT_USER()') diff --git a/sfconn.nix b/sfconn.nix index 2320634..8fae8ac 100644 --- a/sfconn.nix +++ b/sfconn.nix @@ -8,7 +8,7 @@ }: buildPythonPackage rec { pname = "sfconn"; - version = "0.3.1"; + version = "0.3.2"; pyproject = true; src = ./.; diff --git a/sfconn/__init__.py b/sfconn/__init__.py index 6a7e581..0210a4d 100644 --- a/sfconn/__init__.py +++ b/sfconn/__init__.py @@ -1,13 +1,38 @@ "connection package" -__version__ = "0.3.1" +__version__ = "0.3.2" + +from functools import singledispatch from snowflake.connector import DatabaseError, DataError, InterfaceError, ProgrammingError from snowflake.connector.cursor import ResultMetadata -from .conn import Connection, Cursor, available_connections, default_connection_name, getconn, getsess +from .conn import Connection, Cursor, available_connections, default_connection_name, getconn from .jwt import get_token -from .utils import pytype, with_connection, with_connection_args, with_session +from .utils import pytype_conn, with_connection, with_connection_args + + +@singledispatch +def pytype(meta, best_match: bool = False) -> type: # type: ignore + raise TypeError(f"{meta} is not an instance of ResultMetadata or DataType") + + +@pytype.register(ResultMetadata) +def _(meta: ResultMetadata, best_match: bool = False): + return pytype_conn(meta, best_match) + + +try: + from snowflake.snowpark.types import DataType + + from .utils_snowpark import getsess, pytype_sess, with_session + + @pytype.register(DataType) + def _(meta: DataType, _: bool = False): + return pytype_sess(meta) + +except ImportError: + pass __all__ = [ "DatabaseError", diff --git a/sfconn/conn.py b/sfconn/conn.py index da02e74..7579b5f 100644 --- a/sfconn/conn.py +++ b/sfconn/conn.py @@ -116,25 +116,3 @@ def getconn(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any) Connection object returned by Snowflake python connector """ return Connection(**conn_opts(keyfile_pfx_map=keyfile_pfx_map, **kwargs)) # type: ignore - - -try: - from snowflake.snowpark import Session - - def getsess(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any) -> Session: - """create a Session object using named configuration - - Args: - keyfile_pfx_map: if specified must be a a pair of Path values specified as :, which will - be used to temporarily change private_key_file path value if it starts with prefix - **kwargs: Any parameter that is valid for snowflake.connector.connect() method - - Returns: - Session object returned by Snowflake python connector - """ - return Session.builder.configs(conn_opts(keyfile_pfx_map=keyfile_pfx_map, **kwargs)).create() - -except ImportError: - - def getsess(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any) -> Session: - raise NotImplementedError("Unable to import snowflake.snowpark.Session; is snowflake-snowpark-python installed?") diff --git a/sfconn/utils.py b/sfconn/utils.py index 71e737b..270e560 100644 --- a/sfconn/utils.py +++ b/sfconn/utils.py @@ -5,31 +5,34 @@ from argparse import SUPPRESS, ArgumentParser, ArgumentTypeError from decimal import Decimal from functools import wraps +from logging import Logger from pathlib import Path -from typing import Any, Callable, cast +from typing import Any, Callable, Concatenate, ParamSpec, TypeAlias, TypeVar, cast from snowflake.connector.constants import FIELD_TYPES from snowflake.connector.cursor import ResultMetadata -from .conn import conn_opts, getconn, getsess +from .conn import Connection, getconn -_loglevel = logging.WARNING +P = ParamSpec("P") +R = TypeVar("R") +ConnFn: TypeAlias = Callable[Concatenate[Connection, P], R] +ArgsFn: TypeAlias = Callable[ + [Concatenate[tuple[Path, Path] | None, str | None, str | None, str | None, str | None, str | None, int, P]], R +] -def init_logging(logger: logging.Logger) -> None: + +def init_logging(logger: Logger, loglevel: int = logging.WARNING) -> None: "initialize the logging system" h = logging.StreamHandler() h.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) logger.addHandler(h) - logger.setLevel(_loglevel) - + logger.setLevel(loglevel) -def with_connection_options(fl: Callable[..., Any] | logging.Logger | None = None) -> Callable[..., Any]: - "wraps application entry function that expects a connection" - logger = fl if isinstance(fl, logging.Logger) else None - - def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]: +def with_connection(logger: Logger | None = None) -> Callable[[ConnFn[P, R]], ArgsFn[P, R]]: + def wrapper(fn: ConnFn[P, R]) -> ArgsFn[P, R]: @wraps(fn) def wrapped( keyfile_pfx_map: tuple[Path, Path] | None, @@ -39,70 +42,30 @@ def wrapped( schema: str | None, warehouse: str | None, loglevel: int, - **kwargs: Any, - ) -> Any: + *args: P.args, + **kwargs: P.kwargs, + ) -> R: "script entry-point" - global _loglevel - - _loglevel = loglevel init_logging(logging.getLogger(__name__)) if logger is not None: - init_logging(logger) - _opts = conn_opts( - keyfile_pfx_map=keyfile_pfx_map, - connection_name=connection_name, - database=database, - role=role, - schema=schema, - warehouse=warehouse, - ) - return fn(_opts, **kwargs) - - return wrapped - - return wrapper if fl is None or isinstance(fl, logging.Logger) else wrapper(fl) + init_logging(logger, loglevel) - -def with_connection(fl: Callable[..., Any] | logging.Logger | None = None) -> Callable[..., Any]: - "wraps application entry function that expects a connection" - - logger = fl if isinstance(fl, logging.Logger) else None - - def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]: - @wraps(fn) - @with_connection_options(logger) - def wrapped(opts: dict[str, Any], **kwargs: Any) -> Any: - "script entry-point" try: - with getconn(**opts) as cnx: - return fn(cnx, **kwargs) + with getconn( + keyfile_pfx_map=keyfile_pfx_map, + connection_name=connection_name, + database=database, + role=role, + schema=schema, + warehouse=warehouse, + ) as cnx: + return fn(cnx, *args, **kwargs) except Exception as err: raise SystemExit(str(err)) - return wrapped - - return wrapper if fl is None or isinstance(fl, logging.Logger) else wrapper(fl) + return wrapped # type: ignore - -def with_session(fl: Callable[..., Any] | logging.Logger | None = None) -> Callable[..., Any]: - "wraps application entry function that expects a connection" - - logger = fl if isinstance(fl, logging.Logger) else None - - def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]: - @wraps(fn) - @with_connection_options(logger) - def wrapped(opts: dict[str, Any], **kwargs: Any) -> Any: - "script entry-point" - try: - with getsess(**opts) as session: - return fn(session, **kwargs) - except Exception as err: - raise SystemExit(str(err)) - - return wrapped - - return wrapper if fl is None or isinstance(fl, logging.Logger) else wrapper(fl) + return wrapper def add_conn_args(parser: ArgumentParser) -> None: @@ -152,7 +115,7 @@ def wrapped(args: list[str] | None = None) -> Any: return getargs -def _pytype(meta: ResultMetadata, best_match: bool = False) -> type[Any]: +def pytype_conn(meta: ResultMetadata, best_match: bool = False) -> type: """convert Python DB API data type to python type Args: @@ -185,38 +148,3 @@ def _pytype(meta: ResultMetadata, best_match: bool = False) -> type[Any]: type_ = TYPE_MAP.get(sql_type_name, str) return type_ if best_match else str if type_ in [dict, object, list] else type_ - - -try: - import snowflake.snowpark.types as T - - def pytype(meta: ResultMetadata | T.DataType, best_match: bool = False) -> type[Any]: - """convert Python DB API or Snowpark data type to python type - - Args: - meta: an individual value returned as part of cursor.description or snowflake.snowpark.types.DataType - best_match: return Python type that is best suited, rather than the actual type used by the connector - - Returns: - Python type that best matches Snowflake's type, or str in other cases - """ - if isinstance(meta, ResultMetadata): - return _pytype(meta, best_match) - - types = { - T.LongType: int, - T.DateType: dt.date, - T.TimeType: dt.time, - T.TimestampType: dt.datetime, - T.BooleanType: bool, - T.DecimalType: Decimal, - T.DoubleType: float, - T.BinaryType: bytearray, - T.ArrayType: list, - T.VariantType: object, - T.MapType: dict, - } - return next((py_t for sp_t, py_t in types.items() if isinstance(meta, sp_t)), str) - -except ImportError: - pytype = _pytype # type: ignore diff --git a/sfconn/utils_snowpark.py b/sfconn/utils_snowpark.py new file mode 100644 index 0000000..97d6e64 --- /dev/null +++ b/sfconn/utils_snowpark.py @@ -0,0 +1,101 @@ +"Utility functions" + +import datetime as dt +import logging +from decimal import Decimal +from functools import wraps +from logging import Logger +from pathlib import Path +from typing import Any, Callable, Concatenate, ParamSpec, TypeAlias, TypeVar + +import snowflake.snowpark.types as T +from snowflake.snowpark import Session + +from .conn import conn_opts +from .utils import init_logging + +P = ParamSpec("P") +R = TypeVar("R") + + +ConnFn: TypeAlias = Callable[Concatenate[Session, P], R] +ArgsFn: TypeAlias = Callable[ + [Concatenate[tuple[Path, Path] | None, str | None, str | None, str | None, str | None, str | None, int, P]], R +] + + +def getsess(*, keyfile_pfx_map: tuple[Path, Path] | None = None, **kwargs: Any) -> Session: + """create a Session object using named configuration + + Args: + keyfile_pfx_map: if specified must be a a pair of Path values specified as :, which will + be used to temporarily change private_key_file path value if it starts with prefix + **kwargs: Any parameter that is valid for snowflake.connector.connect() method + + Returns: + Session object returned by Snowflake python connector + """ + return Session.builder.configs(conn_opts(keyfile_pfx_map=keyfile_pfx_map, **kwargs)).create() + + +def with_session(logger: Logger | None = None) -> Callable[[ConnFn[P, R]], ArgsFn[P, R]]: + def _decorate_conn_fn(fn: ConnFn[P, R]) -> ArgsFn[P, R]: + @wraps(fn) + def wrapped( + keyfile_pfx_map: tuple[Path, Path] | None, + connection_name: str | None, + database: str | None, + role: str | None, + schema: str | None, + warehouse: str | None, + loglevel: int, + *args: P.args, + **kwargs: P.kwargs, + ) -> R: + "script entry-point" + init_logging(logging.getLogger(__name__)) + if logger is not None: + init_logging(logger, loglevel) + + try: + with getsess( + keyfile_pfx_map=keyfile_pfx_map, + connection_name=connection_name, + database=database, + role=role, + schema=schema, + warehouse=warehouse, + ) as session: + return fn(session, *args, **kwargs) + except Exception as err: + raise SystemExit(str(err)) + + return wrapped # type: ignore + + return _decorate_conn_fn + + +def pytype_sess(meta: T.DataType) -> type: + """convert Python DB API or Snowpark data type to python type + + Args: + meta: an instance of snowflake.snowpark.types.DataType + + Returns: + Python type that matches Snowflake's type, or str in other cases + """ + types = { + T.LongType: int, + T.DateType: dt.date, + T.TimeType: dt.time, + T.TimestampType: dt.datetime, + T.BooleanType: bool, + T.DecimalType: Decimal, + T.DoubleType: float, + T.BinaryType: bytearray, + T.ArrayType: list, + T.VariantType: object, + T.MapType: dict, + } + + return next((py_t for sp_t, py_t in types.items() if isinstance(meta, sp_t)), str) diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 33acc6d..0d4996a 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -19,7 +19,7 @@ class Result: def getargs(_: ArgumentParser) -> Any: pass - @with_connection + @with_connection() def main(cnx: Connection) -> None: with cnx.cursor() as csr: rows = list(csr.run_query(Result, "select current_user() as user, current_role() as role"))