diff --git a/runner_manager/backend/aws.py b/runner_manager/backend/aws.py index 60cafc08..2a53aa0d 100644 --- a/runner_manager/backend/aws.py +++ b/runner_manager/backend/aws.py @@ -1,6 +1,6 @@ from copy import deepcopy -from typing import List, Literal, Optional, Sequence from random import shuffle +from typing import List, Literal, Optional, Sequence from boto3 import client from botocore.exceptions import ClientError @@ -15,8 +15,8 @@ from runner_manager.models.backend import ( AWSConfig, AwsInstance, - AwsSubnetListConfig, AWSInstanceConfig, + AwsSubnetListConfig, Backends, ) from runner_manager.models.runner import Runner @@ -35,12 +35,18 @@ def client(self) -> EC2Client: def create(self, runner: Runner) -> Runner: """Create a runner.""" if self.instance_config.subnet_id and self.instance_config.subnet_configs: - raise Exception("Instance config contains both subnet_id and subnet_configs, only one allowed.") + raise Exception( + "Instance config contains both subnet_id and subnet_configs, only one allowed." + ) if len(self.instance_config.subnet_configs) > 0: - runner = self._create_from_subnet_config(runner, self.instance_config.subnet_configs) + runner = self._create_from_subnet_config( + runner, self.instance_config.subnet_configs + ) log.warn(f"Instance id: {runner.instance_id}") else: - instance_resource: AwsInstance = self.instance_config.configure_instance(runner) + instance_resource: AwsInstance = self.instance_config.configure_instance( + runner + ) try: runner = self._create(runner, instance_resource) log.warn(f"Instance id: {runner.instance_id}") @@ -49,7 +55,9 @@ def create(self, runner: Runner) -> Runner: raise e return super().create(runner) - def _create_from_subnet_config(self, runner: Runner, subnet_configs: Sequence[AwsSubnetListConfig]) -> Runner: + def _create_from_subnet_config( + self, runner: Runner, subnet_configs: Sequence[AwsSubnetListConfig] + ) -> Runner: # Randomize the order of the Subnets - very coarse load balancing. # TODO: Skip subnets that have failed recently. Maybe with an increasing backoff. order = list(range(len(subnet_configs))) @@ -61,16 +69,22 @@ def _create_from_subnet_config(self, runner: Runner, subnet_configs: Sequence[Aw try: # Copy the object to avoid modifying the object we were passed. count = self.instance_config.max_count - self.instance_config.min_count - log.info(f"Trying to launch {count} containers on subnet {subnet_config['subnet_id']}") + log.info( + f"Trying to launch {count} containers on subnet {subnet_config['subnet_id']}" + ) concrete_instance_config = deepcopy(self.instance_config) concrete_instance_config.subnet_id = subnet_config["subnet_id"] concrete_instance_config.security_group_ids.extend( subnet_config.get("security_group_ids", []) ) - instance_resource: AwsInstance = concrete_instance_config.configure_instance(runner) + instance_resource: AwsInstance = ( + concrete_instance_config.configure_instance(runner) + ) return self._create(runner, instance_resource) except Exception as e: - log.warn(f"Creating instance in subnet {subnet_config['subnet_id']} failed with '{e}'. Retrying with another subnet.") + log.warn( + f"Creating instance in subnet {subnet_config['subnet_id']} failed with '{e}'. Retrying with another subnet." + ) if idx >= len(order) - 1: raise e return runner diff --git a/runner_manager/models/backend.py b/runner_manager/models/backend.py index b8deb84c..98bb60f1 100644 --- a/runner_manager/models/backend.py +++ b/runner_manager/models/backend.py @@ -1,7 +1,7 @@ from enum import Enum from pathlib import Path from string import Template -from typing import Dict, List, Literal, Optional, Sequence, TypedDict, NotRequired +from typing import Dict, List, Literal, NotRequired, Optional, Sequence, TypedDict from mypy_boto3_ec2.literals import ( InstanceMetadataTagsStateType, @@ -139,7 +139,7 @@ class AWSConfig(BackendConfig): { "subnet_id": str, "security_group_ids": NotRequired[Sequence[str]], - } + }, ) AwsInstance = TypedDict( diff --git a/tests/unit/backend/test_aws.py b/tests/unit/backend/test_aws.py index ffebebf1..2458b4e7 100644 --- a/tests/unit/backend/test_aws.py +++ b/tests/unit/backend/test_aws.py @@ -1,8 +1,8 @@ import os +from unittest.mock import patch from mypy_boto3_ec2.type_defs import TagTypeDef from pytest import fixture, mark, raises -from unittest.mock import patch from redis_om import NotFoundError from runner_manager.backend.aws import AWSBackend @@ -36,6 +36,7 @@ def aws_group(settings) -> RunnerGroup: ) return runner_group + @fixture() def aws_multi_subnet_group(settings) -> RunnerGroup: config = AWSConfig() @@ -63,6 +64,7 @@ def aws_multi_subnet_group(settings) -> RunnerGroup: ) return runner_group + @fixture() def aws_multi_subnet_group_invalid_subnets(settings) -> RunnerGroup: config = AWSConfig() @@ -81,7 +83,7 @@ def aws_multi_subnet_group_invalid_subnets(settings) -> RunnerGroup: }, { "subnet_id": "also-does-not-exist", - } + }, ] ), ), @@ -127,7 +129,10 @@ def test_aws_instance_config(runner: Runner): assert instance["TagSpecifications"][1]["ResourceType"] == "volume" -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), reason="AWS credentials not found") +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) def test_create_delete(aws_runner, aws_group): runner = aws_group.backend.create(aws_runner) assert runner.instance_id is not None @@ -138,7 +143,10 @@ def test_create_delete(aws_runner, aws_group): Runner.find(Runner.instance_id == runner.instance_id).first() -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), reason="AWS credentials not found") +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) def test_list(aws_runner, aws_group): runner = aws_group.backend.create(aws_runner) runners = aws_group.backend.list() @@ -148,7 +156,10 @@ def test_list(aws_runner, aws_group): aws_group.backend.get(runner.instance_id) -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), reason="AWS credentials not found") +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) def test_update(aws_runner, aws_group): runner = aws_group.backend.create(aws_runner) runner.labels = [RunnerLabel(name="test", type="custom")] @@ -159,7 +170,10 @@ def test_update(aws_runner, aws_group): aws_group.backend.get(runner.instance_id) -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), reason="AWS credentials not found") +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) def test_create_delete_multi_subnet(aws_runner, aws_multi_subnet_group): runner = aws_multi_subnet_group.backend.create(aws_runner) print(f"{runner.instance_id}") @@ -171,10 +185,19 @@ def test_create_delete_multi_subnet(aws_runner, aws_multi_subnet_group): Runner.find(Runner.instance_id == runner.instance_id).first() -@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), reason="AWS credentials not found") -def test_create_delete_multi_subnet_invalid_subnets(aws_runner, aws_multi_subnet_group_invalid_subnets): - with patch.object(AWSBackend, '_create', wraps=aws_multi_subnet_group_invalid_subnets.backend._create) as mock: +@mark.skipif( + not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"), + reason="AWS credentials not found", +) +def test_create_delete_multi_subnet_invalid_subnets( + aws_runner, aws_multi_subnet_group_invalid_subnets +): + with patch.object( + AWSBackend, + "_create", + wraps=aws_multi_subnet_group_invalid_subnets.backend._create, + ) as mock: with raises(Exception): aws_multi_subnet_group_invalid_subnets.backend.create(aws_runner) # Check that the code tries once for each subnet. - assert mock.call_count == 2 \ No newline at end of file + assert mock.call_count == 2