1
- from dataclasses import dataclass
2
- from typing import Union , Tuple , Callable , Any , Awaitable
1
+ from typing import Union , Tuple , Callable , Any , Awaitable , Optional , List
3
2
4
3
from redis .credentials import StreamingCredentialProvider
5
4
from redis .auth .token_manager import TokenManagerConfig , RetryPolicy , TokenManager , CredentialsListener
6
5
7
- from redis_entraid .identity_provider import EntraIDIdentityProvider
8
-
9
-
10
- @dataclass
11
- class TokenAuthConfig :
12
- """
13
- Configuration for token authentication.
14
-
15
- Requires :class:`EntraIDIdentityProvider`. It's recommended to use an additional factory methods.
16
- See :class:`EntraIDIdentityProvider` for more information.
17
- """
18
- DEFAULT_EXPIRATION_REFRESH_RATIO = 0.8
19
- DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 0
20
- DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS = 100
21
- DEFAULT_MAX_ATTEMPTS = 3
22
- DEFAULT_DELAY_IN_MS = 3
23
-
24
- idp : EntraIDIdentityProvider
25
- expiration_refresh_ratio : float = DEFAULT_EXPIRATION_REFRESH_RATIO
26
- lower_refresh_bound_millis : int = DEFAULT_LOWER_REFRESH_BOUND_MILLIS
27
- token_request_execution_timeout_in_ms : int = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS
28
- max_attempts : int = DEFAULT_MAX_ATTEMPTS
29
- delay_in_ms : int = DEFAULT_DELAY_IN_MS
30
-
31
- def get_token_manager_config (self ) -> TokenManagerConfig :
32
- return TokenManagerConfig (
33
- self .expiration_refresh_ratio ,
34
- self .lower_refresh_bound_millis ,
35
- self .token_request_execution_timeout_in_ms ,
36
- RetryPolicy (
37
- self .max_attempts ,
38
- self .delay_in_ms
39
- )
40
- )
41
-
42
- def get_identity_provider (self ) -> EntraIDIdentityProvider :
43
- return self .idp
6
+ from redis_entraid .identity_provider import ManagedIdentityType , ManagedIdentityIdType , \
7
+ _create_provider_from_managed_identity , ManagedIdentityProviderConfig , ServicePrincipalIdentityProviderConfig , \
8
+ _create_provider_from_service_principal
44
9
10
+ DEFAULT_EXPIRATION_REFRESH_RATIO = 0.7
11
+ DEFAULT_LOWER_REFRESH_BOUND_MILLIS = 0
12
+ DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS = 100
13
+ DEFAULT_MAX_ATTEMPTS = 3
14
+ DEFAULT_DELAY_IN_MS = 3
45
15
46
16
class EntraIdCredentialsProvider (StreamingCredentialProvider ):
47
17
def __init__ (
48
18
self ,
49
- config : TokenAuthConfig ,
19
+ idp_config : Union [ManagedIdentityProviderConfig , ServicePrincipalIdentityProviderConfig ],
20
+ token_manager_config : TokenManagerConfig ,
50
21
initial_delay_in_ms : float = 0 ,
51
22
block_for_initial : bool = False ,
52
23
):
53
24
"""
54
- :param config:
25
+ :param idp_config: Identity provider specific configuration.
26
+ :param token_manager_config: Token manager specific configuration.
55
27
:param initial_delay_in_ms: Initial delay before run background refresh (valid for async only)
56
28
:param block_for_initial: Block execution until initial token will be acquired (valid for async only)
57
29
"""
30
+ if isinstance (idp_config , ManagedIdentityProviderConfig ):
31
+ idp = _create_provider_from_managed_identity (idp_config )
32
+ else :
33
+ idp = _create_provider_from_service_principal (idp_config )
34
+
58
35
self ._token_mgr = TokenManager (
59
- config . get_identity_provider () ,
60
- config . get_token_manager_config ()
36
+ idp ,
37
+ token_manager_config
61
38
)
62
39
self ._listener = CredentialsListener ()
63
40
self ._is_streaming = False
64
41
self ._initial_delay_in_ms = initial_delay_in_ms
65
42
self ._block_for_initial = block_for_initial
66
43
67
44
def get_credentials (self ) -> Union [Tuple [str ], Tuple [str , str ]]:
45
+ """
46
+ Acquire token from the identity provider.
47
+ """
68
48
init_token = self ._token_mgr .acquire_token ()
69
49
70
50
if self ._is_streaming is False :
@@ -77,6 +57,9 @@ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
77
57
return init_token .get_token ().try_get ('oid' ), init_token .get_token ().get_value ()
78
58
79
59
async def get_credentials_async (self ) -> Union [Tuple [str ], Tuple [str , str ]]:
60
+ """
61
+ Acquire token from the identity provider in async mode.
62
+ """
80
63
init_token = await self ._token_mgr .acquire_token_async ()
81
64
82
65
if self ._is_streaming is False :
@@ -98,3 +81,84 @@ def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]):
98
81
99
82
def is_streaming (self ) -> bool :
100
83
return self ._is_streaming
84
+
85
+
86
+ def create_from_managed_identity (
87
+ identity_type : ManagedIdentityType ,
88
+ resource : str ,
89
+ id_type : Optional [ManagedIdentityIdType ] = None ,
90
+ id_value : Optional [str ] = '' ,
91
+ kwargs : Optional [dict ] = {},
92
+ token_manager_config : Optional [TokenManagerConfig ] = TokenManagerConfig (
93
+ DEFAULT_EXPIRATION_REFRESH_RATIO ,
94
+ DEFAULT_LOWER_REFRESH_BOUND_MILLIS ,
95
+ DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS ,
96
+ RetryPolicy (
97
+ DEFAULT_MAX_ATTEMPTS ,
98
+ DEFAULT_DELAY_IN_MS
99
+ )
100
+ )
101
+ ) -> EntraIdCredentialsProvider :
102
+ """
103
+ Create a credential provider from a managed identity type.
104
+
105
+ :param identity_type: Managed identity type.
106
+ :param resource: Identity provider resource.
107
+ :param id_type: Identity provider type.
108
+ :param id_value: Identity provider value.
109
+ :param kwargs: Optional keyword arguments to pass to identity provider. See: :class:`ManagedIdentityClient`
110
+ :param token_manager_config: Token manager specific configuration.
111
+ :return: EntraIdCredentialsProvider instance.
112
+ """
113
+ managed_identity_config = ManagedIdentityProviderConfig (
114
+ identity_type = identity_type ,
115
+ resource = resource ,
116
+ id_type = id_type ,
117
+ id_value = id_value ,
118
+ kwargs = kwargs
119
+ )
120
+
121
+ return EntraIdCredentialsProvider (managed_identity_config , token_manager_config )
122
+
123
+
124
+ def create_from_service_principal (
125
+ client_id : str ,
126
+ client_credential : Any ,
127
+ tenant_id : Optional [str ] = None ,
128
+ scopes : Optional [List [str ]] = None ,
129
+ timeout : Optional [float ] = None ,
130
+ token_kwargs : Optional [dict ] = {},
131
+ app_kwargs : Optional [dict ] = {},
132
+ token_manager_config : Optional [TokenManagerConfig ] = TokenManagerConfig (
133
+ DEFAULT_EXPIRATION_REFRESH_RATIO ,
134
+ DEFAULT_LOWER_REFRESH_BOUND_MILLIS ,
135
+ DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS ,
136
+ RetryPolicy (
137
+ DEFAULT_MAX_ATTEMPTS ,
138
+ DEFAULT_DELAY_IN_MS
139
+ )
140
+ )) -> EntraIdCredentialsProvider :
141
+ """
142
+ Create a credential provider from a service principal.
143
+
144
+ :param client_credential: Service principal credentials.
145
+ :param client_id: Service principal client ID.
146
+ :param scopes: Service principal scopes. Fallback to default scopes if None.
147
+ :param timeout: Service principal timeout.
148
+ :param tenant_id: Service principal tenant ID.
149
+ :param token_kwargs: Optional token arguments to pass to service identity provider.
150
+ :param app_kwargs: Optional keyword arguments to pass to service principal application.
151
+ :param token_manager_config: Token manager specific configuration.
152
+ :return: EntraIdCredentialsProvider instance.
153
+ """
154
+ service_principal_config = ServicePrincipalIdentityProviderConfig (
155
+ client_credential = client_credential ,
156
+ client_id = client_id ,
157
+ scopes = scopes ,
158
+ timeout = timeout ,
159
+ tenant_id = tenant_id ,
160
+ app_kwargs = app_kwargs ,
161
+ token_kwargs = token_kwargs ,
162
+ )
163
+
164
+ return EntraIdCredentialsProvider (service_principal_config , token_manager_config )
0 commit comments