Skip to content

Commit 73f0708

Browse files
motusbpkroth
andauthored
Use token-based authentication instead of the access key in the AzureFileShareService (#779)
Make `AzureFileShareService` class use token-based credential instead of the access key. This PR is part of #777 **NOTE:** * We have to use late initialization of the `_share_client` to avoid issues with JSON schema validation tests. * More PRs will follow: * Use to `azcopy` on the remote VMs instead of mounting the file share using the access key * Remove references to `"storageAccountKey"` from configurations and tests and document the new mechanism --------- Co-authored-by: Brian Kroth <[email protected]>
1 parent 7dce3d1 commit 73f0708

File tree

3 files changed

+45
-28
lines changed

3 files changed

+45
-28
lines changed

mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mlos_bench.services.base_fileshare import FileShareService
1515
from mlos_bench.services.base_service import Service
16+
from mlos_bench.services.types.authenticator_type import SupportsAuth
1617
from mlos_bench.util import check_required_params
1718

1819
_LOG = logging.getLogger(__name__)
@@ -52,23 +53,30 @@ def __init__(
5253
parent,
5354
self.merge_methods(methods, [self.upload, self.download]),
5455
)
55-
5656
check_required_params(
5757
self.config,
5858
{
5959
"storageAccountName",
6060
"storageFileShareName",
61-
"storageAccountKey",
6261
},
6362
)
64-
65-
self._share_client = ShareClient.from_share_url(
66-
AzureFileShareService._SHARE_URL.format(
67-
account_name=self.config["storageAccountName"],
68-
fs_name=self.config["storageFileShareName"],
69-
),
70-
credential=self.config["storageAccountKey"],
71-
)
63+
self._share_client: Optional[ShareClient] = None
64+
65+
def _get_share_client(self) -> ShareClient:
66+
"""Get the Azure file share client object."""
67+
if self._share_client is None:
68+
assert self._parent is not None and isinstance(
69+
self._parent, SupportsAuth
70+
), "Authorization service not provided. Include service-auth.jsonc?"
71+
self._share_client = ShareClient.from_share_url(
72+
self._SHARE_URL.format(
73+
account_name=self.config["storageAccountName"],
74+
fs_name=self.config["storageFileShareName"],
75+
),
76+
credential=self._parent.get_access_token(),
77+
token_intent="backup",
78+
)
79+
return self._share_client
7280

7381
def download(
7482
self,
@@ -78,7 +86,7 @@ def download(
7886
recursive: bool = True,
7987
) -> None:
8088
super().download(params, remote_path, local_path, recursive)
81-
dir_client = self._share_client.get_directory_client(remote_path)
89+
dir_client = self._get_share_client().get_directory_client(remote_path)
8290
if dir_client.exists():
8391
os.makedirs(local_path, exist_ok=True)
8492
for content in dir_client.list_directories_and_files():
@@ -91,7 +99,7 @@ def download(
9199
# Ensure parent folders exist
92100
folder, _ = os.path.split(local_path)
93101
os.makedirs(folder, exist_ok=True)
94-
file_client = self._share_client.get_file_client(remote_path)
102+
file_client = self._get_share_client().get_file_client(remote_path)
95103
try:
96104
data = file_client.download_file()
97105
with open(local_path, "wb") as output_file:
@@ -147,7 +155,7 @@ def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[
147155
# Ensure parent folders exist
148156
folder, _ = os.path.split(remote_path)
149157
self._remote_makedirs(folder)
150-
file_client = self._share_client.get_file_client(remote_path)
158+
file_client = self._get_share_client().get_file_client(remote_path)
151159
with open(local_path, "rb") as file_data:
152160
_LOG.debug("Upload file: %s -> %s", local_path, remote_path)
153161
file_client.upload_file(file_data)
@@ -167,6 +175,6 @@ def _remote_makedirs(self, remote_path: str) -> None:
167175
if not folder:
168176
continue
169177
path += folder + "/"
170-
dir_client = self._share_client.get_directory_client(path)
178+
dir_client = self._get_share_client().get_directory_client(path)
171179
if not dir_client.exists():
172180
dir_client.create_directory()

mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@ def test_download_file(
2626
local_folder = "some/local/folder"
2727
remote_path = f"{remote_folder}/{filename}"
2828
local_path = f"{local_folder}/{filename}"
29-
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
29+
3030
config: dict = {}
31-
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object(
31+
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
32+
mock_share_client, "get_file_client"
33+
) as mock_get_file_client, patch.object(
3234
mock_share_client, "get_directory_client"
3335
) as mock_get_directory_client:
36+
3437
mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False))
3538

3639
azure_fileshare.download(config, remote_path, local_path)
@@ -81,8 +84,9 @@ def test_download_folder_non_recursive(
8184
local_folder = "some/local/folder"
8285
dir_client_returns = make_dir_client_returns(remote_folder)
8386
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
87+
8488
config: dict = {}
85-
with patch.object(
89+
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
8690
mock_share_client, "get_directory_client"
8791
) as mock_get_directory_client, patch.object(
8892
mock_share_client, "get_file_client"
@@ -114,15 +118,14 @@ def test_download_folder_recursive(
114118
remote_folder = "a/remote/folder"
115119
local_folder = "some/local/folder"
116120
dir_client_returns = make_dir_client_returns(remote_folder)
117-
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
121+
118122
config: dict = {}
119-
with patch.object(
123+
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
120124
mock_share_client, "get_directory_client"
121125
) as mock_get_directory_client, patch.object(
122126
mock_share_client, "get_file_client"
123127
) as mock_get_file_client:
124128
mock_get_directory_client.side_effect = lambda x: dir_client_returns[x]
125-
126129
azure_fileshare.download(config, remote_folder, local_folder, recursive=True)
127130

128131
mock_get_file_client.assert_has_calls(
@@ -157,9 +160,11 @@ def test_upload_file(
157160
local_path = f"{local_folder}/{filename}"
158161
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
159162
mock_isdir.return_value = False
160-
config: dict = {}
161163

162-
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
164+
config: dict = {}
165+
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
166+
mock_share_client, "get_file_client"
167+
) as mock_get_file_client:
163168
azure_fileshare.upload(config, local_path, remote_path)
164169

165170
mock_get_file_client.assert_called_with(remote_path)
@@ -228,9 +233,11 @@ def test_upload_directory_non_recursive(
228233
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
229234
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
230235
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
231-
config: dict = {}
232236

233-
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
237+
config: dict = {}
238+
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
239+
mock_share_client, "get_file_client"
240+
) as mock_get_file_client:
234241
azure_fileshare.upload(config, local_folder, remote_folder, recursive=False)
235242

236243
mock_get_file_client.assert_called_with(f"{remote_folder}/a_file_1.csv")
@@ -252,9 +259,11 @@ def test_upload_directory_recursive(
252259
mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)]
253260
mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)]
254261
mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access
255-
config: dict = {}
256262

257-
with patch.object(mock_share_client, "get_file_client") as mock_get_file_client:
263+
config: dict = {}
264+
with patch.object(azure_fileshare, "_share_client") as mock_share_client, patch.object(
265+
mock_share_client, "get_file_client"
266+
) as mock_get_file_client:
258267
azure_fileshare.upload(config, local_folder, remote_folder, recursive=True)
259268

260269
mock_get_file_client.assert_has_calls(

mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> A
102102

103103

104104
@pytest.fixture
105-
def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService:
105+
def azure_fileshare(azure_auth_service: AzureAuthService) -> AzureFileShareService:
106106
"""Creates a dummy AzureFileShareService for tests that require it."""
107107
with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"):
108108
return AzureFileShareService(
@@ -112,5 +112,5 @@ def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> Azu
112112
"storageAccountKey": "TEST_ACCOUNT_KEY",
113113
},
114114
global_config={},
115-
parent=config_persistence_service,
115+
parent=azure_auth_service,
116116
)

0 commit comments

Comments
 (0)