Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws-backend: add support for running instances on multiple subnets #700

Merged
merged 6 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions runner_manager/backend/aws.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Literal, Optional
from copy import deepcopy
from random import shuffle
from typing import List, Literal, Optional, Sequence

from boto3 import client
from botocore.exceptions import ClientError
Expand All @@ -14,6 +16,7 @@
AWSConfig,
AwsInstance,
AWSInstanceConfig,
AwsSubnetListConfig,
Backends,
)
from runner_manager.models.runner import Runner
Expand All @@ -31,15 +34,68 @@

def create(self, runner: Runner) -> Runner:
"""Create a runner."""
instance_resource: AwsInstance = self.instance_config.configure_instance(runner)
try:
instance = self.client.run_instances(**instance_resource)
runner.instance_id = instance["Instances"][0]["InstanceId"]
except Exception as e:
log.error(e)
raise e
if self.instance_config.subnet_id and self.instance_config.subnet_configs:
raise Exception(

Check warning on line 38 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L37-L38

Added lines #L37 - L38 were not covered by tests
"Instance config contains both subnet_id and subnet_configs, only one allowed."
)
if len(self.instance_config.subnet_configs) > 0:
runner = self._create_from_subnet_config(

Check warning on line 42 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L41-L42

Added lines #L41 - L42 were not covered by tests
runner, self.instance_config.subnet_configs
)
log.warn(f"Instance id: {runner.instance_id}")

Check warning on line 45 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L45

Added line #L45 was not covered by tests
else:
instance_resource: AwsInstance = self.instance_config.configure_instance(

Check warning on line 47 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L47

Added line #L47 was not covered by tests
runner
)
try:
runner = self._create(runner, instance_resource)
log.warn(f"Instance id: {runner.instance_id}")
except Exception as e:
log.error(e)
raise e

Check warning on line 55 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L50-L55

Added lines #L50 - L55 were not covered by tests
return super().create(runner)

def _create_from_subnet_config(
self, runner: Runner, subnet_configs: Sequence[AwsSubnetListConfig]
) -> Runner:
# Randomize the order of the Subnets - very coarse load balancing.
# TODO: Skip subnets that have failed recently. Maybe with an increasing backoff.
order = list(range(len(subnet_configs)))
shuffle(order)
for idx, i in enumerate(order):
subnet_config = subnet_configs[i]
try:

Check warning on line 67 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L63-L67

Added lines #L63 - L67 were not covered by tests
# Copy the object to avoid modifying the object we were passed.
count = self.instance_config.max_count - self.instance_config.min_count
log.info(

Check warning on line 70 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L69-L70

Added lines #L69 - L70 were not covered by tests
f"Trying to launch {count} containers on subnet {subnet_config['subnet_id']}"
)
concrete_instance_config = deepcopy(self.instance_config)
concrete_instance_config.subnet_id = subnet_config["subnet_id"]
subnet_security_groups = subnet_config.get("security_group_ids", [])
if subnet_security_groups:
security_groups = list(concrete_instance_config.security_group_ids)
security_groups += subnet_security_groups
concrete_instance_config.security_group_ids = security_groups
instance_resource: AwsInstance = (

Check warning on line 80 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L73-L80

Added lines #L73 - L80 were not covered by tests
concrete_instance_config.configure_instance(runner)
)
return self._create(runner, instance_resource)
except Exception as e:
log.warn(

Check warning on line 85 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L83-L85

Added lines #L83 - L85 were not covered by tests
f"Creating instance in subnet {subnet_config['subnet_id']} failed with '{e}'. Retrying with another subnet."
)
if idx >= len(order) - 1:
raise e
return runner

Check warning on line 90 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L88-L90

Added lines #L88 - L90 were not covered by tests

def _create(self, runner: Runner, instance_resource: AwsInstance) -> Runner:
instance = self.client.run_instances(**instance_resource)

Check warning on line 93 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L93

Added line #L93 was not covered by tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tcarmet, the linter is complaining about this line, but it's not a addition, just moved from the original code. Do you have a preferred way to fix it (I've silenced some warnings before, but I'm not 100% that that's the best plan).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for letting me know, I'm not sure how to fix it as well, but let's merge it as is, I think it's minor and it won't complain on the next lint checks.

# Allow this to raise exception as we don't want to track an instance that
# doesn't have an instance ID.
runner.instance_id = instance["Instances"][0]["InstanceId"] # type: ignore
return runner

Check warning on line 97 in runner_manager/backend/aws.py

View check run for this annotation

Codecov / codecov/patch

runner_manager/backend/aws.py#L96-L97

Added lines #L96 - L97 were not covered by tests

def delete(self, runner: Runner):
"""Delete a runner."""
if runner.instance_id:
Expand Down
13 changes: 11 additions & 2 deletions runner_manager/models/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from pathlib import Path
from string import Template
from typing import Dict, List, Literal, Optional, Sequence, TypedDict
from typing import Dict, List, Literal, NotRequired, Optional, Sequence, TypedDict

from mypy_boto3_ec2.literals import (
InstanceMetadataTagsStateType,
Expand Down Expand Up @@ -134,6 +134,14 @@ class AWSConfig(BackendConfig):
region: str = "us-west-2"


AwsSubnetListConfig = TypedDict(
"AwsSubnetListConfig",
{
"subnet_id": str,
"security_group_ids": NotRequired[Sequence[str]],
},
)

AwsInstance = TypedDict(
"AwsInstance",
{
Expand All @@ -157,7 +165,7 @@ class AWSInstanceConfig(InstanceConfig):

image: str = "ami-0735c191cf914754d" # Ubuntu 22.04 for us-west-2
instance_type: InstanceTypeType = "t3.micro"
subnet_id: str
subnet_id: str = ""
security_group_ids: Sequence[str] = []
max_count: int = 1
min_count: int = 1
Expand All @@ -167,6 +175,7 @@ class AWSInstanceConfig(InstanceConfig):
disk_size_gb: int = 20
iam_instance_profile_arn: str = ""
instance_metadata_tags: InstanceMetadataTagsStateType = "disabled"
subnet_configs: Sequence[AwsSubnetListConfig] = []

def configure_instance(self, runner: Runner) -> AwsInstance:
"""Configure instance."""
Expand Down
107 changes: 104 additions & 3 deletions tests/unit/backend/test_aws.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from unittest.mock import patch

from mypy_boto3_ec2.type_defs import TagTypeDef
from pytest import fixture, mark, raises
Expand Down Expand Up @@ -36,10 +37,68 @@ def aws_group(settings) -> RunnerGroup:
return runner_group


@fixture()
def aws_multi_subnet_group(settings) -> RunnerGroup:
config = AWSConfig()
subnet_id = os.getenv("AWS_SUBNET_ID", "")
runner_group: RunnerGroup = RunnerGroup(
id=3,
name="default",
organization="test",
manager=settings.name,
backend=AWSBackend(
name=Backends.aws,
config=config,
instance_config=AWSInstanceConfig(
subnet_configs=[
{
"subnet_id": subnet_id,
"security_group_ids": [],
}
]
),
),
labels=[
"label",
],
)
return runner_group


@fixture()
def aws_multi_subnet_group_invalid_subnets(settings) -> RunnerGroup:
config = AWSConfig()
runner_group: RunnerGroup = RunnerGroup(
id=3,
name="default",
organization="test",
manager=settings.name,
backend=AWSBackend(
name=Backends.aws,
config=config,
instance_config=AWSInstanceConfig(
subnet_configs=[
{
"subnet_id": "does-not-exist",
},
{
"subnet_id": "also-does-not-exist",
},
]
),
),
labels=[
"label",
],
)
return runner_group


@fixture()
def aws_runner(runner: Runner, aws_group: RunnerGroup) -> Runner:
# Cleanup and return a runner for testing
aws_group.backend.delete(runner)
runner.instance_id = None
return runner


Expand Down Expand Up @@ -70,7 +129,10 @@ def test_aws_instance_config(runner: Runner):
assert instance["TagSpecifications"][1]["ResourceType"] == "volume"


@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found")
@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_create_delete(aws_runner, aws_group):
runner = aws_group.backend.create(aws_runner)
assert runner.instance_id is not None
Expand All @@ -81,7 +143,10 @@ def test_create_delete(aws_runner, aws_group):
Runner.find(Runner.instance_id == runner.instance_id).first()


@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found")
@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_list(aws_runner, aws_group):
runner = aws_group.backend.create(aws_runner)
runners = aws_group.backend.list()
Expand All @@ -91,7 +156,10 @@ def test_list(aws_runner, aws_group):
aws_group.backend.get(runner.instance_id)


@mark.skipif(not os.getenv("AWS_ACCESS_KEY_ID"), reason="AWS credentials not found")
@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_update(aws_runner, aws_group):
runner = aws_group.backend.create(aws_runner)
runner.labels = [RunnerLabel(name="test", type="custom")]
Expand All @@ -100,3 +168,36 @@ def test_update(aws_runner, aws_group):
aws_group.backend.delete(runner)
with raises(NotFoundError):
aws_group.backend.get(runner.instance_id)


@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_create_delete_multi_subnet(aws_runner, aws_multi_subnet_group):
runner = aws_multi_subnet_group.backend.create(aws_runner)
print(f"{runner.instance_id}")
assert runner.instance_id is not None
assert runner.backend == "aws"
assert Runner.find(Runner.instance_id == runner.instance_id).first() == runner
aws_multi_subnet_group.backend.delete(runner)
with raises(NotFoundError):
Runner.find(Runner.instance_id == runner.instance_id).first()


@mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID") and not os.getenv("AWS_PROFILE"),
reason="AWS credentials not found",
)
def test_create_delete_multi_subnet_invalid_subnets(
aws_runner, aws_multi_subnet_group_invalid_subnets
):
with patch.object(
AWSBackend,
"_create",
wraps=aws_multi_subnet_group_invalid_subnets.backend._create,
) as mock:
with raises(Exception):
aws_multi_subnet_group_invalid_subnets.backend.create(aws_runner)
# Check that the code tries once for each subnet.
assert mock.call_count == 2
Loading