Skip to content

Commit 48edef8

Browse files
committed
Merge branch 'issue237-oidc-device-progress-3'
2 parents 9dc4a04 + ee815a0 commit 48edef8

File tree

10 files changed

+593
-211
lines changed

10 files changed

+593
-211
lines changed

CHANGELOG.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12-
- `Connection.authenticate_oidc()`: add argument to set maximum device code flow poll time
12+
- `Connection.authenticate_oidc()`: add argument to set maximum device code flow poll time
13+
- Show progress bar while waiting for OIDC authentication with device code flow,
14+
including special mode for in Jupyter notebooks.
15+
([#237](https://github.com/Open-EO/openeo-python-client/issues/237))
1316

1417
### Changed
1518

openeo/internal/jupyter.py

+10
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@
8282
}
8383

8484

85+
def in_jupyter_context() -> bool:
86+
"""Check if we are running in an interactive Jupyter notebook context."""
87+
try:
88+
from ipykernel.zmqshell import ZMQInteractiveShell
89+
from IPython.core.getipython import get_ipython
90+
except ImportError:
91+
return False
92+
return isinstance(get_ipython(), ZMQInteractiveShell)
93+
94+
8595
def render_component(component: str, data = None, parameters: dict = None):
8696
parameters = parameters or {}
8797
# Special handling for batch job results, show either item or collection depending on the data

openeo/rest/auth/oidc.py

+142-40
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
"""
55

66
import base64
7+
import contextlib
78
import enum
89
import functools
910
import hashlib
1011
import http.server
12+
import inspect
1113
import json
1214
import logging
15+
import math
1316
import random
1417
import string
1518
import threading
@@ -18,14 +21,15 @@
1821
import warnings
1922
import webbrowser
2023
from collections import namedtuple
21-
from queue import Queue, Empty
22-
from typing import Tuple, Callable, Union, List, Optional
24+
from queue import Empty, Queue
25+
from typing import Callable, List, Optional, Tuple, Union
2326

2427
import requests
2528

2629
import openeo
30+
from openeo.internal.jupyter import in_jupyter_context
2731
from openeo.rest import OpenEoClientException
28-
from openeo.util import dict_no_none, url_join
32+
from openeo.util import SimpleProgressBar, clip, dict_no_none, url_join
2933

3034
log = logging.getLogger(__name__)
3135

@@ -659,6 +663,93 @@ def _get_token_endpoint_post_data(self) -> dict:
659663
)
660664

661665

666+
def _like_print(display: Callable) -> Callable:
667+
"""Ensure that display function supports an `end` argument like `print`"""
668+
if display is print or "end" in inspect.signature(display).parameters:
669+
return display
670+
else:
671+
return lambda *args, end="\n", **kwargs: display(*args, **kwargs)
672+
673+
674+
class _BasicDeviceCodePollUi:
675+
"""
676+
Basic (print + carriage return) implementation of the device code
677+
polling loop UI (e.g. show progress bar and status).
678+
"""
679+
680+
def __init__(
681+
self,
682+
timeout: float,
683+
elapsed: Callable[[], float],
684+
max_width: int = 80,
685+
display: Callable = print,
686+
):
687+
self.timeout = timeout
688+
self.elapsed = elapsed
689+
self._max_width = max_width
690+
self._status = "Authorization pending"
691+
self._display = _like_print(display)
692+
self._progress_bar = SimpleProgressBar(width=(max_width - 1) // 2)
693+
694+
def _instructions(self, info: VerificationInfo) -> str:
695+
if info.verification_uri_complete:
696+
return f"Visit {info.verification_uri_complete} to authenticate."
697+
else:
698+
return f"Visit {info.verification_uri} and enter user code {info.user_code!r} to authenticate."
699+
700+
def show_instructions(self, info: VerificationInfo) -> None:
701+
self._display(self._instructions(info=info))
702+
703+
def set_status(self, status: str):
704+
self._status = status
705+
706+
def show_progress(self, status: Optional[str] = None):
707+
if status:
708+
self.set_status(status)
709+
progress_bar = self._progress_bar.get(fraction=1.0 - self.elapsed() / self.timeout)
710+
text = f"{progress_bar} {self._status}"
711+
self._display(f"{text[:self._max_width]: <{self._max_width}s}", end="\r")
712+
713+
def close(self):
714+
self._display("", end="\n")
715+
716+
717+
class _JupyterDeviceCodePollUi(_BasicDeviceCodePollUi):
718+
def __init__(
719+
self,
720+
timeout: float,
721+
elapsed: Callable[[], float],
722+
max_width: int = 80,
723+
):
724+
super().__init__(timeout=timeout, elapsed=elapsed, max_width=max_width)
725+
import IPython.display
726+
727+
self._instructions_display = IPython.display.display({"text/html": " "}, raw=True, display_id=True)
728+
self._progress_display = IPython.display.display({"text/html": " "}, raw=True, display_id=True)
729+
730+
def _instructions(self, info: VerificationInfo) -> str:
731+
url = info.verification_uri_complete if info.verification_uri_complete else info.verification_uri
732+
instructions = f'Visit <a href="{url}" title="Authenticate at {url}">{url}</a>'
733+
instructions += f' <a href="#" onclick="navigator.clipboard.writeText({url!r});return false;" title="Copy authentication URL to clipboard">&#128203;</a>'
734+
if not info.verification_uri_complete:
735+
instructions += f" and enter user code {info.user_code!r}"
736+
instructions += " to authenticate."
737+
return instructions
738+
739+
def show_instructions(self, info: VerificationInfo) -> None:
740+
self._instructions_display.update({"text/html": self._instructions(info=info)}, raw=True)
741+
742+
def show_progress(self, status: Optional[str] = None):
743+
# TODO Add emoticons to status?
744+
if status:
745+
self.set_status(status)
746+
progress_bar = self._progress_bar.get(fraction=1.0 - self.elapsed() / self.timeout)
747+
self._progress_display.update({"text/html": f"<code>{progress_bar}</code> {self._status}"}, raw=True)
748+
749+
def close(self):
750+
pass
751+
752+
662753
class OidcDeviceAuthenticator(OidcAuthenticator):
663754
"""
664755
Implementation of OAuth Device Authorization grant/flow
@@ -721,17 +812,8 @@ def _get_verification_info(self, request_refresh_token: bool = False) -> Verific
721812
def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
722813
# Get verification url and user code
723814
verification_info = self._get_verification_info(request_refresh_token=request_refresh_token)
724-
if verification_info.verification_uri_complete:
725-
self._display(
726-
f"To authenticate: visit {verification_info.verification_uri_complete} ."
727-
)
728-
else:
729-
self._display("To authenticate: visit {u} and enter the user code {c!r}.".format(
730-
u=verification_info.verification_uri, c=verification_info.user_code)
731-
)
732815

733816
# Poll token endpoint
734-
elapsed = create_timer()
735817
token_endpoint = self._provider_config['token_endpoint']
736818
post_data = {
737819
"client_id": self.client_id,
@@ -742,34 +824,54 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
742824
post_data["code_verifier"] = self._pkce.code_verifier
743825
else:
744826
post_data["client_secret"] = self.client_secret
827+
745828
poll_interval = verification_info.interval
746829
log.debug("Start polling token endpoint (interval {i}s)".format(i=poll_interval))
747-
while elapsed() <= self._max_poll_time:
748-
time.sleep(poll_interval)
749830

750-
log.debug("Doing {g!r} token request {u!r} with post data fields {p!r} (client_id {c!r})".format(
751-
g=self.grant_type, c=self.client_id, u=token_endpoint, p=list(post_data.keys()))
752-
)
753-
resp = self._requests.post(url=token_endpoint, data=post_data)
754-
if resp.status_code == 200:
755-
log.info("[{e:5.1f}s] Authorized successfully.".format(e=elapsed()))
756-
self._display("Authorized successfully.")
757-
return self._get_access_token_result(data=resp.json())
758-
else:
759-
try:
760-
error = resp.json()["error"]
761-
except Exception:
762-
error = "unknown"
763-
if error == "authorization_pending":
764-
log.info("[{e:5.1f}s] Authorization pending.".format(e=elapsed()))
765-
elif error == "slow_down":
766-
log.info("[{e:5.1f}s] Polling too fast, will slow down.".format(e=elapsed()))
767-
poll_interval += 5
768-
else:
769-
raise OidcException("Failed to retrieve access token at {u!r}: {s} {r!r} {t!r}".format(
770-
s=resp.status_code, r=resp.reason, u=token_endpoint, t=resp.text
771-
))
772-
773-
raise OidcException("Timeout exceeded {m:.1f}s while polling for access token at {u!r}".format(
774-
u=token_endpoint, m=self._max_poll_time
775-
))
831+
elapsed = create_timer()
832+
next_poll = elapsed() + poll_interval
833+
# TODO: let poll UI determine sleep interval?
834+
sleep = clip(self._max_poll_time / 100, min=1, max=5)
835+
836+
if in_jupyter_context():
837+
poll_ui = _JupyterDeviceCodePollUi(timeout=self._max_poll_time, elapsed=elapsed)
838+
else:
839+
poll_ui = _BasicDeviceCodePollUi(timeout=self._max_poll_time, elapsed=elapsed, display=self._display)
840+
poll_ui.show_instructions(info=verification_info)
841+
842+
with contextlib.closing(poll_ui):
843+
while elapsed() <= self._max_poll_time:
844+
poll_ui.show_progress()
845+
time.sleep(sleep)
846+
847+
if elapsed() >= next_poll:
848+
log.debug(
849+
f"Doing {self.grant_type!r} token request {token_endpoint!r} with post data fields {list(post_data.keys())!r} (client_id {self.client_id!r})"
850+
)
851+
poll_ui.show_progress(status="Polling")
852+
resp = self._requests.post(url=token_endpoint, data=post_data, timeout=5)
853+
if resp.status_code == 200:
854+
log.info(f"[{elapsed():5.1f}s] Authorized successfully.")
855+
poll_ui.show_progress(status="Authorized successfully")
856+
# TODO remove progress bar when authorized succesfully?
857+
return self._get_access_token_result(data=resp.json())
858+
else:
859+
try:
860+
error = resp.json()["error"]
861+
except Exception:
862+
error = "unknown"
863+
log.info(f"[{elapsed():5.1f}s] not authorized yet: {error}")
864+
if error == "authorization_pending":
865+
poll_ui.show_progress(status="Authorization pending")
866+
elif error == "slow_down":
867+
poll_ui.show_progress(status="Slowing down")
868+
poll_interval += 5
869+
else:
870+
# TODO: skip occasional glitches (e.g. see `SkipIntermittentFailures` from openeo-aggregator)
871+
raise OidcException(
872+
f"Failed to retrieve access token at {token_endpoint!r}: {resp.status_code} {resp.reason!r} {resp.text!r}"
873+
)
874+
next_poll = elapsed() + poll_interval
875+
876+
poll_ui.show_progress(status="Timed out")
877+
raise OidcException(f"Timeout ({self._max_poll_time:.1f}s) while polling for access token.")

openeo/rest/auth/testing.py

-10
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,3 @@ def get_request_history(
316316
if (method is None or method.lower() == r.method.lower())
317317
and (url is None or url == r.url)
318318
]
319-
320-
@contextlib.contextmanager
321-
def assert_device_code_poll_sleep(expect_called=True):
322-
"""Fake sleeping, but check it was called with poll interval (or not)."""
323-
with mock.patch("time.sleep") as sleep:
324-
yield
325-
if expect_called:
326-
sleep.assert_called_with(DEVICE_CODE_POLL_INTERVAL)
327-
else:
328-
sleep.assert_not_called()

openeo/util.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Various utilities and helpers.
33
"""
4+
# TODO: split this kitchen-sink in thematic submodules
45
import datetime as dt
56
import functools
67
import json
@@ -10,7 +11,7 @@
1011
import time
1112
from collections import OrderedDict
1213
from pathlib import Path
13-
from typing import Any, Union, Tuple, Callable, Optional
14+
from typing import Any, Callable, Optional, Tuple, Union
1415
from urllib.parse import urljoin
1516

1617
import requests
@@ -602,3 +603,26 @@ def to_bbox_dict(x: Any, *, crs: Optional[str] = None) -> BBoxDict:
602603
def url_join(root_url: str, path: str):
603604
"""Join a base url and sub path properly."""
604605
return urljoin(root_url.rstrip("/") + "/", path.lstrip("/"))
606+
607+
608+
def clip(x: float, min: float, max: float) -> float:
609+
"""Clip given value between minimum and maximum value"""
610+
return min if x < min else (x if x < max else max)
611+
612+
613+
class SimpleProgressBar:
614+
"""Simple ASCII-based progress bar helper."""
615+
616+
__slots__ = ["width", "bar", "fill", "left", "right"]
617+
618+
def __init__(self, width: int = 40, *, bar: str = "#", fill: str = "-", left: str = "[", right: str = "]"):
619+
self.width = int(width)
620+
self.bar = bar[0]
621+
self.fill = fill[0]
622+
self.left = left
623+
self.right = right
624+
625+
def get(self, fraction: float) -> str:
626+
width = self.width - len(self.left) - len(self.right)
627+
bar = self.bar * int(round(width * clip(fraction, min=0, max=1)))
628+
return f"{self.left}{bar:{self.fill}<{width}s}{self.right}"

0 commit comments

Comments
 (0)