Skip to content

Commit

Permalink
restructured flake.nix
Browse files Browse the repository at this point in the history
  • Loading branch information
padhia committed Oct 6, 2024
1 parent 38a427b commit bb4968d
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 46 deletions.
48 changes: 38 additions & 10 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,47 @@
description = "Snowflake connection helper functions";

inputs = {
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
nix-utils.url = "github:padhia/nix-utils";
snowflake.url = "github:padhia/snowflake/next";
nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";

nix-utils.url = "github:padhia/nix-utils/next";
nix-utils.inputs.nixpkgs.follows = "nixpkgs";
snowflake.inputs.nixpkgs.follows = "nixpkgs";

snowflake.url = "github:padhia/snowflake/next";
snowflake.inputs = {
nixpkgs.follows = "nixpkgs";
flake-utils.follows = "flake-utils";
};
};

outputs = { self, nixpkgs, nix-utils, snowflake }:
nix-utils.lib.mkPyFlake {
pkgs = { sfconn = import ./sfconn.nix; sfconn02x = import ./sfconn02x.nix; };
defaultPkg = "sfconn";
deps = [ "snowflake-connector-python" "snowflake-snowpark-python" "pyjwt" "pytest" ];
pyFlakes = [ snowflake ];
outputs = { self, nixpkgs, nix-utils, flake-utils, snowflake }:
let
inherit (nix-utils.lib) pyDevShell extendPyPkgsWith;

overlays.default = final: prev:
extendPyPkgsWith prev {
sfconn = ./sfconn.nix;
sfconn02x = ./sfconn02x.nix;
};

buildSystem = system:
let
pkgs = import nixpkgs {
inherit system;
config.allowUnfree = true;
overlays = [ snowflake.overlays.default self.overlays.default ];
};
in {
devShells.default = pyDevShell {
inherit pkgs;
name = "sfconn";
extra = [ "snowflake-snowpark-python" ];
pyVer = "311";
};
};

in {
inherit overlays;
inherit (flake-utils.lib.eachDefaultSystem buildSystem) devShells;
};
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README.md"
requires-python = ">=3.11"
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
"snowflake-connector-python>=3.7.0",
"snowflake-connector-python[secure-local-storage]>=3.7.0",
"pyjwt",
]
dynamic = ["version"]
Expand Down
19 changes: 12 additions & 7 deletions sfconn.nix
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,23 @@
buildPythonPackage,
setuptools,
snowflake-connector-python,
keyring,
pyjwt,
pytest
}:
buildPythonPackage rec {
pname = "sfconn";
version = "0.3.3";
pname = "sfconn";
version = "0.3.3";
pyproject = true;
src = ./.;
src = ./.;

propagatedBuildInputs = [ snowflake-connector-python pyjwt ];
nativeBuildInputs = [ setuptools pytest ];
doCheck = false;
dependencies = [
snowflake-connector-python
keyring
pyjwt
];

build-system = [ setuptools ];
doCheck = false;

meta = with lib; {
homepage = "https://github.com/padhia/sfconn";
Expand Down
11 changes: 7 additions & 4 deletions sfconn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from snowflake.connector.cursor import ResultMetadata

from .conn import Connection, Cursor, available_connections, default_connection_name, getconn
from .jwt import get_token
from .utils import pytype_conn, with_connection, with_connection_args, with_jwt, with_jwt_args
from .jwt import RestInfo, get_rest_info, get_token
from .utils import pytype_conn, set_loglevel, with_connection, with_connection_args, with_rest, with_rest_args


@singledispatch
Expand Down Expand Up @@ -47,10 +47,13 @@ def _(meta: DataType, _: bool = False):
"getconn",
"getsess",
"get_token",
"get_rest_info",
"RestInfo",
"pytype",
"set_loglevel",
"with_connection",
"with_connection_args",
"with_session",
"with_jwt",
"with_jwt_args",
"with_rest",
"with_rest_args",
]
7 changes: 7 additions & 0 deletions sfconn/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def _parse_keyfile_pfx_map() -> tuple[Path, Path] | None:
_default_keyfile_pfx_map = _parse_keyfile_pfx_map()


def _mask_opts(opts: dict[str, Any]) -> dict[str, Any]:
return opts | {k: "*****" for k in ["password", "passcode", "token"] if k in opts}


class Connection(SnowflakeConnection):
"A Connection class that overrides the cursor() method to return a custom Cursor class"

Expand Down Expand Up @@ -101,6 +105,9 @@ def fix_keyfile_path(path: str) -> str:
if "private_key_file" in opts:
opts["private_key_file"] = fix_keyfile_path(cast(str, opts["private_key_file"]))

if logger.getEffectiveLevel() >= logging.DEBUG:
logger.debug(_mask_opts(opts))

return opts


Expand Down
2 changes: 1 addition & 1 deletion sfconn/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def execute_debug(self, sql: str, params: Params = None) -> Self:
Returns:
Self
"""
if logger.getEffectiveLevel() <= DEBUG:
if logger.getEffectiveLevel() >= DEBUG:
fi = getframeinfo(currentframe().f_back.f_back) # type: ignore
logger.debug(
"Running SQL, file: %s, line: %d, function: %s\n%s;",
Expand Down
76 changes: 67 additions & 9 deletions sfconn/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import datetime as dt
import hashlib
from pathlib import Path
from typing import Any, cast
from typing import Any, NamedTuple, cast

import jwt

Expand Down Expand Up @@ -34,26 +34,53 @@ def _clean_account_name(account: str) -> str:
return account


def get_token(
class RestInfo(NamedTuple):
token: str
url: str
database: str | None = None
schema: str | None = None
role: str | None = None
warehouse: str | None = None

def headers(self, user_agent: str = "sfconn-app") -> dict[str, str]:
return {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json",
"Accept": "application/json",
"User-Agent": f"{user_agent}",
"X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
}

@property
def conn_opts(self) -> dict[str, str]:
opts = dict(database=self.database, schema=self.schema, role=self.role, warehouse=self.warehouse)
return {k: v for k, v in opts.items() if v is not None}


def get_rest_info(
connection_name: str | None = None,
keyfile_pfx_map: tuple[Path, Path] | None = None,
lifetime: dt.timedelta = LIFETIME,
) -> str:
"""get a JWT when using key-pair authentication
*,
keyfile_pfx_map: tuple[Path, Path] | None = None,
**kwargs: str | None
) -> RestInfo:
"""get Jwt object using key-pair authentication
Args
conn: A connection name to be looked up from the config_file, optional, default to None for the default connection
lifetime: issued token's lifetime
lifetime: issued token's lifetime (default 59 minutes)
keyfile_pfx_map: if specified must be a a pair of Path values specified as <from-path>:<to-path>, which will
be used to temporarily change private_key_file path value if it starts with <from-pahd> prefix
Returns:
a JWT
Jwt object
Exceptions:
ValueError: if `conn` doesn't support key-pair authentication
*: Any exceptions raised by either conn_opts() or class PrivateKey
"""

opts = conn_opts(connection_name=connection_name, keyfile_pfx_map=keyfile_pfx_map)
opts = conn_opts(connection_name=connection_name, keyfile_pfx_map=keyfile_pfx_map, **kwargs)
keyf = cast(str | None, opts.get("private_key_file"))
if keyf is None:
raise ValueError(f"'{connection_name}' does not use key-pair authentication to support creating a JWT")
Expand All @@ -70,4 +97,35 @@ def get_token(
"exp": int((now + lifetime).timestamp()),
}

return jwt.encode(payload, key=key.key, algorithm=ALGORITHM) # type: ignore
return RestInfo(
token=jwt.encode(payload, key=key.key, algorithm=ALGORITHM),
url=f"https://{opts['account']}.snowflakecomputing.com/api/v2/statements",
database=opts.get("database"),
schema=opts.get("schema"),
role=opts.get("role"),
warehouse=opts.get("warehouse"),
)


def get_token(
connection_name: str | None = None,
lifetime: dt.timedelta = LIFETIME,
*,
keyfile_pfx_map: tuple[Path, Path] | None = None,
) -> str:
"""get a JWT when using key-pair authentication
Args
conn: A connection name to be looked up from the config_file, optional, default to None for the default connection
lifetime: issued token's lifetime (default 59 minutes)
keyfile_pfx_map: if specified must be a a pair of Path values specified as <from-path>:<to-path>, which will
be used to temporarily change private_key_file path value if it starts with <from-pahd> prefix
Returns:
a JWT
Exceptions:
ValueError: if `conn` doesn't support key-pair authentication
*: Any exceptions raised by either conn_opts() or class PrivateKey
"""
return get_rest_info(connection_name=connection_name, lifetime=lifetime, keyfile_pfx_map=keyfile_pfx_map).token
33 changes: 25 additions & 8 deletions sfconn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from snowflake.connector.cursor import ResultMetadata

from .conn import Connection, getconn
from .jwt import get_token
from .jwt import RestInfo, get_rest_info

P = ParamSpec("P")
R = TypeVar("R")
Expand All @@ -32,6 +32,11 @@ def init_logging(logger: Logger, loglevel: int = logging.WARNING) -> None:
logger.setLevel(loglevel)


def set_loglevel(loglevel: int = logging.WARNING) -> None:
"set logging level for the module, default WARNING"
init_logging(logging.getLogger(".".join(__name__.split(".")[:-1])), loglevel)


def with_connection(logger: Logger | None = None) -> Callable[[ConnFn[P, R]], ArgsFn[P, R]]:
def wrapper(fn: ConnFn[P, R]) -> ArgsFn[P, R]:
@wraps(fn)
Expand All @@ -47,7 +52,7 @@ def wrapped(
**kwargs: P.kwargs,
) -> R:
"script entry-point"
init_logging(logging.getLogger(__name__))
set_loglevel(loglevel)
if logger is not None:
init_logging(logger, loglevel)

Expand Down Expand Up @@ -116,7 +121,7 @@ def wrapped(args: list[str] | None = None) -> Any:
return getargs


def with_jwt_args(doc: str | None, **kwargs: Any) -> Callable[..., Callable[..., Any]]:
def with_rest_args(doc: str | None, **kwargs: Any) -> Callable[..., Callable[..., Any]]:
"""Function decorator that instantiates and adds snowflake JWT as first argument"""

def getargs(fn: Callable[[ArgumentParser], None]) -> Callable[..., Any]:
Expand All @@ -140,25 +145,37 @@ def wrapped(args: list[str] | None = None) -> Any:
return getargs


def with_jwt(logger: Logger | None = None) -> Callable[[Callable[Concatenate[str, P], R]], ArgsFn[P, R]]:
def wrapper(fn: Callable[Concatenate[str, P], R]) -> ArgsFn[P, R]:
def with_rest(logger: Logger | None = None) -> Callable[[Callable[Concatenate[RestInfo, P], R]], ArgsFn[P, R]]:
def wrapper(fn: Callable[Concatenate[RestInfo, P], R]) -> ArgsFn[P, R]:
@wraps(fn)
def wrapped(
keyfile_pfx_map: tuple[Path, Path] | None,
connection_name: str | None,
database: str | None,
role: str | None,
schema: str | None,
warehouse: str | None,
lifetime: dt.timedelta,
loglevel: int,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"script entry-point"
init_logging(logging.getLogger(__name__))
set_loglevel(loglevel)
if logger is not None:
init_logging(logger, loglevel)

try:
token = get_token(keyfile_pfx_map=keyfile_pfx_map, connection_name=connection_name, lifetime=lifetime)
return fn(token, *args, **kwargs)
rest_info = get_rest_info(
keyfile_pfx_map=keyfile_pfx_map,
connection_name=connection_name,
lifetime=lifetime,
database=database,
schema=schema,
role=role,
warehouse=warehouse,
)
return fn(rest_info, *args, **kwargs)
except Exception as err:
raise SystemExit(str(err))

Expand Down
17 changes: 11 additions & 6 deletions sfconn02x.nix
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,27 @@
fetchPypi,
setuptools,
snowflake-connector-python,
keyring,
pyjwt,
pytest
}:
buildPythonPackage rec {
pname = "sfconn";
version = "0.2.5";
pname = "sfconn";
version = "0.2.5";
pyproject = true;

src = fetchPypi {
inherit pname version;
hash = "sha256-jdhR9UgHH2klrTtI0bSWN4/FSYXxJdlDhKMRW7c+AdQ=";
};

propagatedBuildInputs = [ snowflake-connector-python pyjwt ];
nativeBuildInputs = [ setuptools pytest ];
doCheck = false;
dependencies = [
snowflake-connector-python
keyring
pyjwt
];

build-system = [ setuptools ];
doCheck = false;

meta = with lib; {
homepage = "https://github.com/padhia/sfconn";
Expand Down

0 comments on commit bb4968d

Please sign in to comment.