4
4
"""
5
5
6
6
import base64
7
+ import contextlib
7
8
import enum
8
9
import functools
9
10
import hashlib
10
11
import http .server
12
+ import inspect
11
13
import json
12
14
import logging
15
+ import math
13
16
import random
14
17
import string
15
18
import threading
18
21
import warnings
19
22
import webbrowser
20
23
from collections import namedtuple
21
- from queue import Queue , Empty
22
- from typing import Tuple , Callable , Union , List , Optional
24
+ from queue import Empty , Queue
25
+ from typing import Callable , List , Optional , Tuple , Union
23
26
24
27
import requests
25
28
26
29
import openeo
30
+ from openeo .internal .jupyter import in_jupyter_context
27
31
from openeo .rest import OpenEoClientException
28
- from openeo .util import dict_no_none , url_join
32
+ from openeo .util import SimpleProgressBar , clip , dict_no_none , url_join
29
33
30
34
log = logging .getLogger (__name__ )
31
35
@@ -659,6 +663,93 @@ def _get_token_endpoint_post_data(self) -> dict:
659
663
)
660
664
661
665
666
+ def _like_print (display : Callable ) -> Callable :
667
+ """Ensure that display function supports an `end` argument like `print`"""
668
+ if display is print or "end" in inspect .signature (display ).parameters :
669
+ return display
670
+ else :
671
+ return lambda * args , end = "\n " , ** kwargs : display (* args , ** kwargs )
672
+
673
+
674
+ class _BasicDeviceCodePollUi :
675
+ """
676
+ Basic (print + carriage return) implementation of the device code
677
+ polling loop UI (e.g. show progress bar and status).
678
+ """
679
+
680
+ def __init__ (
681
+ self ,
682
+ timeout : float ,
683
+ elapsed : Callable [[], float ],
684
+ max_width : int = 80 ,
685
+ display : Callable = print ,
686
+ ):
687
+ self .timeout = timeout
688
+ self .elapsed = elapsed
689
+ self ._max_width = max_width
690
+ self ._status = "Authorization pending"
691
+ self ._display = _like_print (display )
692
+ self ._progress_bar = SimpleProgressBar (width = (max_width - 1 ) // 2 )
693
+
694
+ def _instructions (self , info : VerificationInfo ) -> str :
695
+ if info .verification_uri_complete :
696
+ return f"Visit { info .verification_uri_complete } to authenticate."
697
+ else :
698
+ return f"Visit { info .verification_uri } and enter user code { info .user_code !r} to authenticate."
699
+
700
+ def show_instructions (self , info : VerificationInfo ) -> None :
701
+ self ._display (self ._instructions (info = info ))
702
+
703
+ def set_status (self , status : str ):
704
+ self ._status = status
705
+
706
+ def show_progress (self , status : Optional [str ] = None ):
707
+ if status :
708
+ self .set_status (status )
709
+ progress_bar = self ._progress_bar .get (fraction = 1.0 - self .elapsed () / self .timeout )
710
+ text = f"{ progress_bar } { self ._status } "
711
+ self ._display (f"{ text [:self ._max_width ]: <{self ._max_width }s} " , end = "\r " )
712
+
713
+ def close (self ):
714
+ self ._display ("" , end = "\n " )
715
+
716
+
717
+ class _JupyterDeviceCodePollUi (_BasicDeviceCodePollUi ):
718
+ def __init__ (
719
+ self ,
720
+ timeout : float ,
721
+ elapsed : Callable [[], float ],
722
+ max_width : int = 80 ,
723
+ ):
724
+ super ().__init__ (timeout = timeout , elapsed = elapsed , max_width = max_width )
725
+ import IPython .display
726
+
727
+ self ._instructions_display = IPython .display .display ({"text/html" : " " }, raw = True , display_id = True )
728
+ self ._progress_display = IPython .display .display ({"text/html" : " " }, raw = True , display_id = True )
729
+
730
+ def _instructions (self , info : VerificationInfo ) -> str :
731
+ url = info .verification_uri_complete if info .verification_uri_complete else info .verification_uri
732
+ instructions = f'Visit <a href="{ url } " title="Authenticate at { url } ">{ url } </a>'
733
+ instructions += f' <a href="#" onclick="navigator.clipboard.writeText({ url !r} );return false;" title="Copy authentication URL to clipboard">📋</a>'
734
+ if not info .verification_uri_complete :
735
+ instructions += f" and enter user code { info .user_code !r} "
736
+ instructions += " to authenticate."
737
+ return instructions
738
+
739
+ def show_instructions (self , info : VerificationInfo ) -> None :
740
+ self ._instructions_display .update ({"text/html" : self ._instructions (info = info )}, raw = True )
741
+
742
+ def show_progress (self , status : Optional [str ] = None ):
743
+ # TODO Add emoticons to status?
744
+ if status :
745
+ self .set_status (status )
746
+ progress_bar = self ._progress_bar .get (fraction = 1.0 - self .elapsed () / self .timeout )
747
+ self ._progress_display .update ({"text/html" : f"<code>{ progress_bar } </code> { self ._status } " }, raw = True )
748
+
749
+ def close (self ):
750
+ pass
751
+
752
+
662
753
class OidcDeviceAuthenticator (OidcAuthenticator ):
663
754
"""
664
755
Implementation of OAuth Device Authorization grant/flow
@@ -721,17 +812,8 @@ def _get_verification_info(self, request_refresh_token: bool = False) -> Verific
721
812
def get_tokens (self , request_refresh_token : bool = False ) -> AccessTokenResult :
722
813
# Get verification url and user code
723
814
verification_info = self ._get_verification_info (request_refresh_token = request_refresh_token )
724
- if verification_info .verification_uri_complete :
725
- self ._display (
726
- f"To authenticate: visit { verification_info .verification_uri_complete } ."
727
- )
728
- else :
729
- self ._display ("To authenticate: visit {u} and enter the user code {c!r}." .format (
730
- u = verification_info .verification_uri , c = verification_info .user_code )
731
- )
732
815
733
816
# Poll token endpoint
734
- elapsed = create_timer ()
735
817
token_endpoint = self ._provider_config ['token_endpoint' ]
736
818
post_data = {
737
819
"client_id" : self .client_id ,
@@ -742,34 +824,54 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
742
824
post_data ["code_verifier" ] = self ._pkce .code_verifier
743
825
else :
744
826
post_data ["client_secret" ] = self .client_secret
827
+
745
828
poll_interval = verification_info .interval
746
829
log .debug ("Start polling token endpoint (interval {i}s)" .format (i = poll_interval ))
747
- while elapsed () <= self ._max_poll_time :
748
- time .sleep (poll_interval )
749
830
750
- log .debug ("Doing {g!r} token request {u!r} with post data fields {p!r} (client_id {c!r})" .format (
751
- g = self .grant_type , c = self .client_id , u = token_endpoint , p = list (post_data .keys ()))
752
- )
753
- resp = self ._requests .post (url = token_endpoint , data = post_data )
754
- if resp .status_code == 200 :
755
- log .info ("[{e:5.1f}s] Authorized successfully." .format (e = elapsed ()))
756
- self ._display ("Authorized successfully." )
757
- return self ._get_access_token_result (data = resp .json ())
758
- else :
759
- try :
760
- error = resp .json ()["error" ]
761
- except Exception :
762
- error = "unknown"
763
- if error == "authorization_pending" :
764
- log .info ("[{e:5.1f}s] Authorization pending." .format (e = elapsed ()))
765
- elif error == "slow_down" :
766
- log .info ("[{e:5.1f}s] Polling too fast, will slow down." .format (e = elapsed ()))
767
- poll_interval += 5
768
- else :
769
- raise OidcException ("Failed to retrieve access token at {u!r}: {s} {r!r} {t!r}" .format (
770
- s = resp .status_code , r = resp .reason , u = token_endpoint , t = resp .text
771
- ))
772
-
773
- raise OidcException ("Timeout exceeded {m:.1f}s while polling for access token at {u!r}" .format (
774
- u = token_endpoint , m = self ._max_poll_time
775
- ))
831
+ elapsed = create_timer ()
832
+ next_poll = elapsed () + poll_interval
833
+ # TODO: let poll UI determine sleep interval?
834
+ sleep = clip (self ._max_poll_time / 100 , min = 1 , max = 5 )
835
+
836
+ if in_jupyter_context ():
837
+ poll_ui = _JupyterDeviceCodePollUi (timeout = self ._max_poll_time , elapsed = elapsed )
838
+ else :
839
+ poll_ui = _BasicDeviceCodePollUi (timeout = self ._max_poll_time , elapsed = elapsed , display = self ._display )
840
+ poll_ui .show_instructions (info = verification_info )
841
+
842
+ with contextlib .closing (poll_ui ):
843
+ while elapsed () <= self ._max_poll_time :
844
+ poll_ui .show_progress ()
845
+ time .sleep (sleep )
846
+
847
+ if elapsed () >= next_poll :
848
+ log .debug (
849
+ f"Doing { self .grant_type !r} token request { token_endpoint !r} with post data fields { list (post_data .keys ())!r} (client_id { self .client_id !r} )"
850
+ )
851
+ poll_ui .show_progress (status = "Polling" )
852
+ resp = self ._requests .post (url = token_endpoint , data = post_data , timeout = 5 )
853
+ if resp .status_code == 200 :
854
+ log .info (f"[{ elapsed ():5.1f} s] Authorized successfully." )
855
+ poll_ui .show_progress (status = "Authorized successfully" )
856
+ # TODO remove progress bar when authorized succesfully?
857
+ return self ._get_access_token_result (data = resp .json ())
858
+ else :
859
+ try :
860
+ error = resp .json ()["error" ]
861
+ except Exception :
862
+ error = "unknown"
863
+ log .info (f"[{ elapsed ():5.1f} s] not authorized yet: { error } " )
864
+ if error == "authorization_pending" :
865
+ poll_ui .show_progress (status = "Authorization pending" )
866
+ elif error == "slow_down" :
867
+ poll_ui .show_progress (status = "Slowing down" )
868
+ poll_interval += 5
869
+ else :
870
+ # TODO: skip occasional glitches (e.g. see `SkipIntermittentFailures` from openeo-aggregator)
871
+ raise OidcException (
872
+ f"Failed to retrieve access token at { token_endpoint !r} : { resp .status_code } { resp .reason !r} { resp .text !r} "
873
+ )
874
+ next_poll = elapsed () + poll_interval
875
+
876
+ poll_ui .show_progress (status = "Timed out" )
877
+ raise OidcException (f"Timeout ({ self ._max_poll_time :.1f} s) while polling for access token." )
0 commit comments