Skip to content

Commit 4891db0

Browse files
authored
OAuth implementation (#15)
This PR: * Adds the foundation for OAuth against Databricks accounts on AWS with BYOIDP. * Copies one internal module (oauth.py) that Steve Weis (@sweisdb) wrote for the Databricks CLI. Once the ecosystem-dev team (Serge, Pieter) builds a Python SDK core, we will move this code to their repo and consume it as a dependency. * Provides authenticators, written in a visitor-pattern style, for stamping the auth token; these are intended to be moved later to the repo owned by Serge (@nfx) and Pieter (@pietern).
1 parent b9645f9 commit 4891db0

18 files changed

+1218
-256
lines changed

.github/workflows/code-quality-checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,4 @@ jobs:
154154
# black the code
155155
#----------------------------------------------
156156
- name: Mypy
157-
run: poetry run mypy src
157+
run: poetry run mypy --install-types --non-interactive src

poetry.lock

Lines changed: 338 additions & 158 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@ python = "^3.7.1"
1313
thrift = "^0.13.0"
1414
pandas = "^1.3.0"
1515
pyarrow = "^9.0.0"
16+
requests=">2.18.1"
17+
oauthlib=">=3.1.0"
1618

1719
[tool.poetry.dev-dependencies]
1820
pytest = "^7.1.2"
1921
mypy = "^0.950"
22+
pylint = ">=2.12.0"
2023
black = "^22.3.0"
2124

2225
[tool.poetry.urls]

src/databricks/sql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def TimestampFromTicks(ticks):
4444
return Timestamp(*time.localtime(ticks)[:6])
4545

4646

47-
def connect(server_hostname, http_path, access_token, **kwargs):
47+
def connect(server_hostname, http_path, access_token=None, **kwargs):
    """Create and return a ``Connection``.

    ``access_token`` is optional. The three named arguments and any extra
    keyword arguments are forwarded unchanged to ``Connection``.
    """
    # The import stays inside the function, exactly as in the original code.
    from .client import Connection

    connection = Connection(server_hostname, http_path, access_token, **kwargs)
    return connection

src/databricks/sql/auth/__init__.py

Whitespace-only changes.

src/databricks/sql/auth/auth.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from enum import Enum
2+
from typing import List
3+
4+
from databricks.sql.auth.authenticators import (
5+
AuthProvider,
6+
AccessTokenAuthProvider,
7+
BasicAuthProvider,
8+
DatabricksOAuthProvider,
9+
)
10+
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
11+
12+
13+
class AuthType(Enum):
    """Authentication types that are selected explicitly by name."""

    # Access-token and user/password authentication are not listed because
    # they can be inferred from the supplied credentials; further members can
    # be added here when they are needed.
    DATABRICKS_OAUTH = "databricks-oauth"
17+
18+
19+
class ClientContext:
    """Plain holder for the settings used to pick an authentication provider.

    Every constructor argument is stored, unchanged, on an attribute of the
    same name. Only ``hostname`` is required; everything else defaults to
    ``None``.
    """

    def __init__(
        self,
        hostname: str,
        username: str = None,
        password: str = None,
        access_token: str = None,
        auth_type: str = None,
        oauth_scopes: List[str] = None,
        oauth_client_id: str = None,
        oauth_redirect_port_range: List[int] = None,
        use_cert_as_auth: str = None,
        tls_client_cert_file: str = None,
        oauth_persistence=None,
    ):
        # Target host and the explicitly requested auth type (if any).
        self.hostname = hostname
        self.auth_type = auth_type

        # Credential-based settings.
        self.access_token = access_token
        self.username = username
        self.password = password

        # Client-certificate settings.
        self.use_cert_as_auth = use_cert_as_auth
        self.tls_client_cert_file = tls_client_cert_file

        # OAuth settings.
        self.oauth_client_id = oauth_client_id
        self.oauth_scopes = oauth_scopes
        self.oauth_redirect_port_range = oauth_redirect_port_range
        self.oauth_persistence = oauth_persistence
45+
46+
47+
def get_auth_provider(cfg: ClientContext):
    """Choose the authentication provider that matches ``cfg``.

    The checks run in this order and the first match wins:

    1. ``auth_type`` equals the Databricks OAuth type -> ``DatabricksOAuthProvider``
    2. an access token is set -> ``AccessTokenAuthProvider``
    3. both username and password are set -> ``BasicAuthProvider``
    4. certificate auth is enabled and a client cert file is set -> ``AuthProvider``

    Raises:
        RuntimeError: if none of the above applies.
    """
    if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
        # The OAuth flow needs all three of these settings.
        assert cfg.oauth_redirect_port_range is not None
        assert cfg.oauth_client_id is not None
        assert cfg.oauth_scopes is not None

        return DatabricksOAuthProvider(
            cfg.hostname,
            cfg.oauth_persistence,
            cfg.oauth_redirect_port_range,
            cfg.oauth_client_id,
            cfg.oauth_scopes,
        )

    if cfg.access_token is not None:
        return AccessTokenAuthProvider(cfg.access_token)

    has_basic_credentials = cfg.username is not None and cfg.password is not None
    if has_basic_credentials:
        return BasicAuthProvider(cfg.username, cfg.password)

    if cfg.use_cert_as_auth and cfg.tls_client_cert_file:
        # No-op authenticator: authentication is performed with the SSL
        # certificate, outside of the request headers.
        return AuthProvider()

    raise RuntimeError("No valid authentication settings!")
69+
70+
71+
# OAuth defaults used by the Python SQL connector.
PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python"
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]
# Candidate redirect ports: 8020 through 8024 inclusive.
PYSQL_OAUTH_REDIRECT_PORT_RANGE = [8020, 8021, 8022, 8023, 8024]
74+
75+
76+
def normalize_host_name(hostname: str) -> str:
    """Return ``hostname`` as an ``https://`` URL that ends with ``/``.

    Surrounding whitespace is removed first. An existing ``https://`` prefix
    is recognised case-insensitively and rewritten in lowercase, so
    ``"HTTPS://host"`` yields ``"https://host/"`` rather than a URL with two
    schemes. Any other input gets ``https://`` prepended. A trailing slash is
    appended only when it is missing.

    Args:
        hostname: Bare host name or ``https://`` URL.

    Returns:
        The normalized URL string.
    """
    scheme = "https://"
    host = hostname.strip()

    if host[: len(scheme)].lower() == scheme:
        # Already has the scheme (in any letter case): canonicalise its case.
        host = scheme + host[len(scheme):]
        maybe_scheme = ""
    else:
        maybe_scheme = scheme

    maybe_trailing_slash = "" if host.endswith("/") else "/"
    return f"{maybe_scheme}{host}{maybe_trailing_slash}"
80+
81+
82+
def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
    """Build the auth provider for the Python SQL connector.

    The host name is normalized, the relevant options are read from
    ``kwargs`` (missing ones become ``None``), the connector's OAuth defaults
    are filled in, and the resulting ``ClientContext`` is handed to
    ``get_auth_provider``.

    Options read from ``kwargs``: ``auth_type``, ``access_token``,
    ``_username``, ``_password``, ``_use_cert_as_auth``,
    ``_tls_client_cert_file`` and ``experimental_oauth_persistence``.
    """
    option = kwargs.get
    context = ClientContext(
        hostname=normalize_host_name(hostname),
        # Caller-supplied settings.
        auth_type=option("auth_type"),
        access_token=option("access_token"),
        username=option("_username"),
        password=option("_password"),
        use_cert_as_auth=option("_use_cert_as_auth"),
        tls_client_cert_file=option("_tls_client_cert_file"),
        # Fixed OAuth defaults for this connector.
        oauth_scopes=PYSQL_OAUTH_SCOPES,
        oauth_client_id=PYSQL_OAUTH_CLIENT_ID,
        oauth_redirect_port_range=PYSQL_OAUTH_REDIRECT_PORT_RANGE,
        oauth_persistence=option("experimental_oauth_persistence"),
    )
    provider = get_auth_provider(context)
    return provider
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import base64
2+
import logging
3+
from typing import Dict, List
4+
5+
from databricks.sql.auth.oauth import OAuthManager
6+
7+
# Private API: this is an evolving interface and it will change in the future.
8+
# Please do not depend on it in your applications.
9+
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence
10+
11+
12+
class AuthProvider:
    """Base authentication provider; adds nothing to the request headers."""

    def add_headers(self, request_headers: Dict[str, str]):
        """Leave ``request_headers`` untouched and return ``None``."""
        return None
15+
16+
17+
# Private API: this is an evolving interface and it will change in the future.
18+
# Please do not depend on it in your applications.
19+
class AccessTokenAuthProvider(AuthProvider):
    """Auth provider that sends a fixed access token as a bearer token."""

    def __init__(self, access_token: str):
        """Precompute the ``Authorization`` header value for ``access_token``."""
        bearer_value = "Bearer {}".format(access_token)
        self.__authorization_header_value = bearer_value

    def add_headers(self, request_headers: Dict[str, str]):
        """Set the ``Authorization`` entry of ``request_headers`` in place."""
        header_value = self.__authorization_header_value
        request_headers["Authorization"] = header_value
25+
26+
27+
# Private API: this is an evolving interface and it will change in the future.
28+
# Please do not depend on it in your applications.
29+
class BasicAuthProvider(AuthProvider):
    """Auth provider that sends username/password as HTTP Basic credentials."""

    def __init__(self, username: str, password: str):
        """Precompute the ``Authorization`` header value.

        The value is ``"Basic "`` followed by the Base64 encoding of
        ``username:password`` (encoded as UTF-8).
        """
        raw_pair = f"{username}:{password}".encode("UTF-8")
        encoded_pair = base64.b64encode(raw_pair)
        auth_credentials_base64 = encoded_pair.decode("UTF-8")

        self.__authorization_header_value = f"Basic {auth_credentials_base64}"

    def add_headers(self, request_headers: Dict[str, str]):
        """Set the ``Authorization`` entry of ``request_headers`` in place."""
        header_value = self.__authorization_header_value
        request_headers["Authorization"] = header_value
40+
41+
42+
# Private API: this is an evolving interface and it will change in the future.
43+
# Please do not depend on it in your applications.
44+
class DatabricksOAuthProvider(AuthProvider):
    """Auth provider that stamps a Databricks OAuth bearer token on requests.

    On construction the provider first tries to load a token pair from the
    optional ``oauth_persistence`` store; if no usable pair is found it asks
    ``OAuthManager.get_tokens`` for a new one. Every ``add_headers`` call asks
    the manager to check the access token and refresh it if necessary before
    writing the ``Authorization`` header. New or refreshed tokens are written
    back to the persistence store when one is configured.
    """

    # Separator used to join the scope list into the single string that is
    # passed to the OAuth manager.
    SCOPE_DELIM = " "

    def __init__(
        self,
        hostname: str,
        oauth_persistence: OAuthPersistence,
        redirect_port_range: List[int],
        client_id: str,
        scopes: List[str],
    ):
        """Create the provider and obtain the initial token pair.

        Args:
            hostname: Host the tokens are requested for; also the key used
                with the persistence store.
            oauth_persistence: Optional token store. May be ``None`` (or any
                falsy value), in which case tokens are never read or saved.
            redirect_port_range: Ports handed to ``OAuthManager``.
            client_id: OAuth client id handed to ``OAuthManager``.
            scopes: Scopes to request; joined with ``SCOPE_DELIM``.

        Raises:
            Exception: Any error raised during set-up is logged and re-raised
                unchanged.
        """
        try:
            self.oauth_manager = OAuthManager(
                port_range=redirect_port_range, client_id=client_id
            )
            self._hostname = hostname
            self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes)
            self._oauth_persistence = oauth_persistence
            self._client_id = client_id
            self._access_token = None
            self._refresh_token = None
            self._initial_get_token()
        except Exception as e:
            # The message needs a placeholder for ``e``; without one the
            # logging module fails while formatting the record.
            logging.error("unexpected error: %s", e, exc_info=True)
            raise

    def add_headers(self, request_headers: Dict[str, str]):
        """Refresh the token if needed, then set the ``Authorization`` header."""
        self._update_token_if_expired()
        request_headers["Authorization"] = f"Bearer {self._access_token}"

    def _initial_get_token(self):
        """Populate the access/refresh tokens for the first time.

        Stored tokens are preferred when a persistence store is configured;
        otherwise a new pair is requested from the OAuth manager and, if a
        store is configured, saved to it.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            if self._access_token is None or self._refresh_token is None:
                if self._oauth_persistence:
                    token = self._oauth_persistence.read(self._hostname)
                    if token:
                        self._access_token = token.access_token
                        self._refresh_token = token.refresh_token

            if self._access_token and self._refresh_token:
                # A stored pair was found: let the manager refresh it if needed.
                self._update_token_if_expired()
            else:
                (access_token, refresh_token) = self.oauth_manager.get_tokens(
                    hostname=self._hostname, scope=self._scopes_as_str
                )
                self._access_token = access_token
                self._refresh_token = refresh_token

                # Persistence is optional; without this guard a missing store
                # raised AttributeError right after a successful token fetch.
                if self._oauth_persistence:
                    self._oauth_persistence.persist(
                        self._hostname, OAuthToken(access_token, refresh_token)
                    )
        except Exception as e:
            logging.error(
                "unexpected error in oauth initialization: %s", e, exc_info=True
            )
            raise

    def _update_token_if_expired(self):
        """Ask the OAuth manager to refresh the tokens and store the result.

        Nothing changes when the manager reports that no refresh happened.
        After a refresh the new pair replaces the current one and is saved to
        the persistence store, if one is configured.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            (
                fresh_access_token,
                fresh_refresh_token,
                is_refreshed,
            ) = self.oauth_manager.check_and_refresh_access_token(
                hostname=self._hostname,
                access_token=self._access_token,
                refresh_token=self._refresh_token,
            )
            if not is_refreshed:
                return

            self._access_token = fresh_access_token
            self._refresh_token = fresh_refresh_token

            if self._oauth_persistence:
                token = OAuthToken(self._access_token, self._refresh_token)
                self._oauth_persistence.persist(self._hostname, token)
        except Exception as e:
            logging.error(
                "unexpected error in oauth token update: %s", e, exc_info=True
            )
            raise

0 commit comments

Comments
 (0)