1212import warnings
1313import traceback
1414import urllib3
15+ from urllib .parse import urljoin
1516from urllib3 .exceptions import InsecureRequestWarning
1617from json import JSONDecodeError
1718from typing import Union
@@ -128,7 +129,7 @@ def __init__(self, value, field='error_code', default=None, warning=None):
128129class StaticCredentials (Credentials ):
129130 """A credential class that simply takes a set of static credentials."""
130131
131- def __init__ (self , access_key_id , private_key , access_token = '' , method = 'static' ):
132+ def __init__ (self , access_key_id = '' , private_key = '' , access_token = '' , method = 'static' ):
132133 super (StaticCredentials , self ).__init__ (
133134 access_key_id = access_key_id , private_key = private_key ,
134135 access_token = access_token , method = method
@@ -158,6 +159,7 @@ def __init__(self, debug=False, tls_verify=False, strict_errors=False, tls_warni
158159
159160 _loader = Loader ()
160161 _user_agent = self ._make_user_agent_header ()
162+
161163 self ._client_creator = ClientCreator (
162164 _loader ,
163165 Context (),
@@ -274,6 +276,19 @@ def _warning_format(message, category, filename, lineno, line=None):
274276 self .CREDENTIAL_NAME_PATTERN = re .compile (r'[^a-z0-9-]' )
275277 self .OPERATION_REGEX = re .compile (r'operation ([0-9a-zA-Z-]{36}) running' )
276278
279+ # Workload services with special credential and endpoint handling
280+ self .WORKLOAD_SERVICES = ['dfworkload' ]
281+
282+ # substrings to check for in different CRNs
283+ self .CRN_STRINGS = {
284+ 'generic' : ['crn:' ],
285+ 'env' : [':environments:' , ':environment:' ],
286+ 'df' : [':df:' , ':service:' ],
287+ 'flow' : [':df:' , ':flow:' ],
288+ 'readyflow' : [':df:' , 'readyFlow' ],
289+ 'deployment' : [':df:' , ':deployment:' ]
290+ }
291+
277292 def _make_user_agent_header (self ):
278293 cdpy_version = pkg_resources .get_distribution ('cdpy' ).version
279294 return '%s CDPY/%s CDPCLI/%s Python/%s %s/%s' % (
@@ -305,24 +320,44 @@ def _setup_logger(self, log_level, log_format):
305320
306321 self .logger .addHandler (handler )
307322
308- def _build_client (self , service ):
309- if not self .cdp_credentials :
310- self .cdp_credentials = self ._client_creator .context .get_credentials ()
323+ def _build_client (self , service , parameters = None ):
324+ if service in self .WORKLOAD_SERVICES :
325+ if service == 'dfworkload' :
326+ workload_name = 'DF'
327+ else :
328+ workload_name = None
329+ self .throw_error (CdpError ("Workload %s not recognised for client generation" % service ))
330+ if 'environmentCrn' not in parameters :
331+ self .throw_error (CdpError ("environmentCrn must be supplied when connecting to %s" % service ))
332+ df_access_token = self .call (
333+ svc = 'iam' , func = 'generate_workload_auth_token' ,
334+ workloadName = workload_name , environmentCrn = parameters ['environmentCrn' ]
335+ )
336+ token = df_access_token ['token' ]
337+ if not token .startswith ('Bearer ' ):
338+ token = 'Bearer ' + token
339+ credentials = StaticCredentials (access_token = token )
340+ endpoint_url = urljoin (df_access_token ['endpointUrl' ], '/' )
341+ else :
342+ if not self .cdp_credentials :
343+ self .cdp_credentials = self ._client_creator .context .get_credentials ()
344+ credentials = self .cdp_credentials
345+ endpoint_url = self .client_endpoint
311346 try :
312347 # region introduced in client version 0.9.42
313348 client = self ._client_creator .create_client (
314349 service_name = service ,
315350 region = self .cp_region ,
316- explicit_endpoint_url = self . client_endpoint ,
351+ explicit_endpoint_url = endpoint_url ,
317352 tls_verification = self .tls_verify ,
318- credentials = self . cdp_credentials
353+ credentials = credentials
319354 )
320355 except TypeError :
321356 client = self ._client_creator .create_client (
322357 service_name = service ,
323- explicit_endpoint_url = self . client_endpoint ,
358+ explicit_endpoint_url = endpoint_url ,
324359 tls_verification = self .tls_verify ,
325- credentials = self . cdp_credentials
360+ credentials = credentials
326361 )
327362 return client
328363
@@ -358,11 +393,11 @@ def _default_throw_warning(warning: 'CdpWarning'):
358393 def regex_search (pattern , obj ):
359394 return re .search (pattern , obj )
360395
361- def validate_crn (self , obj : str ):
362- if obj is not None and obj . startswith ( 'crn:' ) :
363- pass
364- else :
365- self . throw_error ( CdpError ( "Supplied env_crn %s is not a valid CDP crn" % str (obj )))
396+ def validate_crn (self , obj : str , crn_type = 'generic' ):
397+ for substring in self . CRN_STRINGS [ crn_type ] :
398+ if substring not in obj :
399+ self . throw_error ( CdpError ( "Supplied crn %s of proposed type %s is missing substring %s"
400+ % ( str (obj ), crn_type , substring )))
366401
367402 @staticmethod
368403 def sleep (seconds ):
@@ -388,10 +423,10 @@ def _convert(o):
388423
389424 return json .dumps (data , indent = 2 , default = _convert )
390425
391- def _client (self , service ):
426+ def _client (self , service , parameters = None ):
392427 """Builds a CDP Endpoint client of a given type, and caches it against later reuse"""
393428 if service not in self ._clients :
394- self ._clients [service ] = self ._build_client (service )
429+ self ._clients [service ] = self ._build_client (service , parameters )
395430 return self ._clients [service ]
396431
397432 def read_file (self , file_path ):
@@ -520,7 +555,7 @@ def call(self, svc: str, func: str, ret_field: str = None, squelch: ['Squelch']
520555 Returns (dict, list, None): Output of CDP CLI Call
521556 """
522557 try :
523- call_function = getattr (self ._client (service = svc ), func )
558+ call_function = getattr (self ._client (service = svc , parameters = kwargs ), func )
524559 if self .scrub_inputs :
525560 # Remove unused submission values as the API rejects them
526561 payload = {x : y for x , y in kwargs .items () if y is not None }
0 commit comments