Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add utility function to parse IP address from grpc.ServicerContext #4947

Merged
merged 9 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/py/flwr/common/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
"""Flower IP address utils."""


import re
import socket
from ipaddress import ip_address
from typing import Optional

import grpc

IPV6: int = 6


Expand Down Expand Up @@ -101,3 +104,35 @@ def is_port_in_use(address: str) -> bool:
return True

return False


def get_ip_address_from_servicer_context(context: grpc.ServicerContext) -> str:
"""Extract the client's IPv4 or IPv6 address from the gRPC ServicerContext.

Parameters
----------
context : grpc.ServicerContext
The gRPC ServicerContext object. The context.peer() returns a string like
"ipv4:127.0.0.1:56789" for IPv4 and "ipv6:[2001:db8::1]:54321" for IPv6.

Returns
-------
str
If one of the format matches, the function will return the client's IP address,
otherwise, it will raise a ValueError.
"""
peer: str = context.peer()
# Match IPv4: "ipv4:IP:port"
ipv4_match = re.match(r"^ipv4:(?P<ip>[^:]+):", peer)
if ipv4_match:
return ipv4_match.group("ip")

# Match IPv6: "ipv6:[IP]:port"
ipv6_match = re.match(r"^ipv6:\[(?P<ip>[^\]]+)\]:", peer)
if ipv6_match:
return ipv6_match.group("ip")

raise ValueError(
f"Unsupported peer address format: {peer} for the transport protocol. "
"The supported formats are ipv4:IP:port and ipv6:[IP]:port."
)
243 changes: 159 additions & 84 deletions src/py/flwr/common/address_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,50 +15,51 @@
"""Flower IP address utils."""


from .address import parse_address
import pytest

from .address import get_ip_address_from_servicer_context, parse_address

def test_ipv4_correct() -> None:
"""Test if a correct IPv4 address is correctly parsed."""
# Prepare
addresses = [

@pytest.mark.parametrize(
"address, expected",
[
("127.0.0.1:8080", ("127.0.0.1", 8080, False)),
("0.0.0.0:12", ("0.0.0.0", 12, False)),
("0.0.0.0:65535", ("0.0.0.0", 65535, False)),
]

for address, expected in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


def test_ipv4_incorrect() -> None:
],
)
def test_ipv4_correct(address: str, expected: tuple[str, int, bool]) -> None:
"""Test if a correct IPv4 address is correctly parsed."""
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"address",
[
"127.0.0.1", # Missing port
"42.1.1.0:9988898", # Port number out of range
"0.0.0.0:-999999", # Negative port number
"0.0.0.0:-1", # Negative port number
"0.0.0.0:0", # Port number zero
"0.0.0.0:65536", # Port number out of range
],
)
def test_ipv4_incorrect(address: str) -> None:
"""Test if an incorrect IPv4 address returns None."""
# Prepare
addresses = [
"127.0.0.1",
"42.1.1.0:9988898",
"0.0.0.0:-999999",
"0.0.0.0:-1",
"0.0.0.0:0",
"0.0.0.0:65536",
]
# Execute
actual = parse_address(address)

for address in addresses:
# Execute
actual = parse_address(address)
# Assert
assert actual is None

# Assert
assert actual is None


def test_ipv6_correct() -> None:
"""Test if a correct IPv6 address is correctly parsed."""
# Prepare
addresses = [
@pytest.mark.parametrize(
"address, expected",
[
("[::1]:8080", ("::1", 8080, True)),
("[::]:12", ("::", 12, True)),
(
Expand All @@ -76,67 +77,141 @@ def test_ipv6_correct() -> None:
("[::]:123", ("::", 123, True)),
("[0:0:0:0:0:0:0:1]:80", ("0:0:0:0:0:0:0:1", 80, True)),
("[::1]:80", ("::1", 80, True)),
]

for address, expected in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


def test_ipv6_incorrect() -> None:
],
)
def test_ipv6_correct(address: str, expected: tuple[str, int, bool]) -> None:
"""Test if a correct IPv6 address is correctly parsed."""
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"address",
[
"[2001:db8:3333:4444:5555:6666:7777:8888]:9988898", # Port number out of range
"[2001:db8:3333:4444:5555:6666:7777:8888]:-9988898", # Negative port number
"[2001:db8:3333:4444:5555:6666:7777:8888]:-1", # Negative port number
"[2001:db8:3333:4444:5555:6666:7777:8888]:0", # Port number zero
"[2001:db8:3333:4444:5555:6666:7777:8888]:65536", # Port number out of range
],
)
def test_ipv6_incorrect(address: str) -> None:
"""Test if an incorrect IPv6 address returns None."""
# Prepare
addresses = [
"2001:db8:3333:4444:5555:6666:7777:8888:9988898",
"2001:db8:3333:4444:5555:6666:7777:8888:-9988898",
"2001:db8:3333:4444:5555:6666:7777:8888:-1",
"2001:db8:3333:4444:5555:6666:7777:8888:0",
"2001:db8:3333:4444:5555:6666:7777:8888:65536",
]
# Execute
actual = parse_address(address)

for address in addresses:
# Execute
actual = parse_address(address)
# Assert
assert actual is None

# Assert
assert actual is None


def test_domain_correct() -> None:
"""Test if a correct domain address is correctly parsed."""
# Prepare
addresses = [
@pytest.mark.parametrize(
"address, expected",
[
("flower.ai:123", ("flower.ai", 123, None)),
("sub.flower.ai:123", ("sub.flower.ai", 123, None)),
("sub2.sub1.flower.ai:123", ("sub2.sub1.flower.ai", 123, None)),
("s5.s4.s3.s2.s1.flower.ai:123", ("s5.s4.s3.s2.s1.flower.ai", 123, None)),
("localhost:123", ("localhost", 123, None)),
("https://localhost:123", ("https://localhost", 123, None)),
("http://localhost:123", ("http://localhost", 123, None)),
]

for address, expected in addresses:
# Execute
actual = parse_address(address)
],
)
def test_domain_correct(address: str, expected: tuple[str, int, bool]) -> None:
"""Test if a correct domain address is correctly parsed."""
# Execute
actual = parse_address(address)

# Assert
assert actual == expected
# Assert
assert actual == expected


def test_domain_incorrect() -> None:
@pytest.mark.parametrize(
"address",
[
"flower.ai", # Missing port
"flower.ai:65536", # Port number out of range
],
)
def test_domain_incorrect(address: str) -> None:
"""Test if an incorrect domain address returns None."""
# Prepare
addresses = [
"flower.ai",
"flower.ai:65536",
]

for address in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual is None
# Execute
actual = parse_address(address)

# Assert
assert actual is None


class DummyContext:
"""Dummy context to mimic grpc.ServicerContext for testing purposes."""

def __init__(self, peer_str: str) -> None:
self._peer = peer_str

def peer(self) -> str:
"""."""
return self._peer


@pytest.mark.parametrize(
"peer_str, expected",
[
("ipv4:127.0.0.1:56789", "127.0.0.1"),
("ipv4:0.0.0.0:8080", "0.0.0.0"),
("ipv4:192.168.1.1:12345", "192.168.1.1"),
],
)
def test_servicer_ipv4_correct(peer_str: str, expected: str) -> None:
"""Test if a correct IPv4 address is correctly parsed from grpc.ServicerContext."""
# Prepare dummy context with the given peer string.
context = DummyContext(peer_str)

# Execute
actual = get_ip_address_from_servicer_context(context)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"peer_str, expected",
[
("ipv6:[2001:db8::1]:54321", "2001:db8::1"),
("ipv6:[::1]:8080", "::1"),
("ipv6:[fe80::1ff:fe23:4567:890a]:9999", "fe80::1ff:fe23:4567:890a"),
],
)
def test_servicer_ipv6_correct(peer_str: str, expected: str) -> None:
"""Test if a correct IPv6 address is correctly parsed from grpc.ServicerContext."""
# Prepare dummy context with the given peer string.
context = DummyContext(peer_str)

# Execute
actual = get_ip_address_from_servicer_context(context)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"peer_str",
[
"invalid_string",
"ipv4:127.0.0.1", # missing port
"ipv6:2001:db8::1", # missing brackets and port
"ipv6:[2001:db8::1]56789", # missing colon after the bracket
"ipv6:2001:db8::1:54321", # missing brackets
"unix:/tmp/grpc.sock", # unix domain socket
"",
],
)
def test_servicer_incorrect_format(peer_str: str) -> None:
"""Test if an invalid grpc.ServicerContext.peer() string returns a ValueError."""
# Prepare dummy context with the given peer string.
context = DummyContext(peer_str)

# Execute and Assert
with pytest.raises(ValueError):
get_ip_address_from_servicer_context(context)