9
9
from kaggle_web_client import (KaggleWebClient ,
10
10
_KAGGLE_URL_BASE_ENV_VAR_NAME ,
11
11
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME ,
12
+ _KAGGLE_IAP_TOKEN_ENV_VAR_NAME ,
12
13
CredentialError , BackendError )
13
14
from kaggle_datasets import KaggleDatasets , _KAGGLE_TPU_NAME_ENV_VAR_NAME
14
15
15
16
_TEST_JWT = 'test-secrets-key'
17
+ _TEST_IAP = 'IAP_TOKEN'
16
18
17
19
_TPU_GCS_BUCKET = 'gs://kds-tpu-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a'
18
20
_AUTOML_GCS_BUCKET = 'gs://kds-automl-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a'
@@ -39,7 +41,7 @@ def do_POST(s):
39
41
class TestDatasets (unittest .TestCase ):
40
42
SERVER_ADDRESS = urlparse (os .getenv (_KAGGLE_URL_BASE_ENV_VAR_NAME , default = "http://127.0.0.1:8001" ))
41
43
42
- def _test_client (self , client_func , expected_path , expected_body , is_tpu = True , success = True ):
44
+ def _test_client (self , client_func , expected_path , expected_body , is_tpu = True , success = True , iap_token = False ):
43
45
_request = {}
44
46
45
47
class GetGcsPathHandler (GcsDatasetsHTTPHandler ):
@@ -63,6 +65,8 @@ def get_response(self):
63
65
env .set (_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME , _TEST_JWT )
64
66
if is_tpu :
65
67
env .set (_KAGGLE_TPU_NAME_ENV_VAR_NAME , 'FAKE_TPU' )
68
+ if iap_token :
69
+ env .set (_KAGGLE_IAP_TOKEN_ENV_VAR_NAME , _TEST_IAP )
66
70
with env :
67
71
with HTTPServer ((self .SERVER_ADDRESS .hostname , self .SERVER_ADDRESS .port ), GetGcsPathHandler ) as httpd :
68
72
threading .Thread (target = httpd .serve_forever ).start ()
@@ -87,6 +91,12 @@ def get_response(self):
87
91
msg = "Fake server did not receive an application/json content type header from the KaggleDatasets client." )
88
92
self .assertIn ('X-Kaggle-Authorization' , headers .keys (),
89
93
msg = "Fake server did not receive an X-Kaggle-Authorization header from the KaggleDatasets client." )
94
+ if iap_token :
95
+ self .assertEqual (f'Bearer { _TEST_IAP } ' , headers .get ('Authorization' ),
96
+ msg = "Fake server did not receive an Authorization header from the KaggleDatasets client." )
97
+ else :
98
+ self .assertNotIn ('Authorization' , headers .keys (),
99
+ msg = "Fake server received an Authorization header from the KaggleDatasets client. It shouldn't." )
90
100
self .assertEqual (f'Bearer { _TEST_JWT } ' , headers .get ('X-Kaggle-Authorization' ),
91
101
msg = "Fake server did not receive the right X-Kaggle-Authorization header from the KaggleDatasets client." )
92
102
@@ -127,3 +137,12 @@ def call_get_gcs_path():
127
137
{'MountSlug' : None , 'IntegrationType' : 2 },
128
138
is_tpu = True ,
129
139
success = False )
140
+
141
+ def test_iap_token (self ):
142
+ def call_get_gcs_path ():
143
+ client = KaggleDatasets ()
144
+ gcs_path = client .get_gcs_path ()
145
+ self ._test_client (call_get_gcs_path ,
146
+ '/requests/CopyDatasetVersionToKnownGcsBucketRequest' ,
147
+ {'MountSlug' : None , 'IntegrationType' : 1 },
148
+ is_tpu = False , iap_token = True )
0 commit comments