Skip to content

Commit

Permalink
feat(framework) Add utility function to parse IP address from `grpc.S…
Browse files Browse the repository at this point in the history
…ervicerContext` (#4947)
  • Loading branch information
chongshenng authored Feb 25, 2025
1 parent 72aab3f commit 4f525ad
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 84 deletions.
35 changes: 35 additions & 0 deletions src/py/flwr/common/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
"""Flower IP address utils."""


import re
import socket
from ipaddress import ip_address
from typing import Optional

import grpc

IPV6: int = 6


Expand Down Expand Up @@ -101,3 +104,35 @@ def is_port_in_use(address: str) -> bool:
return True

return False


def get_ip_address_from_servicer_context(context: grpc.ServicerContext) -> str:
"""Extract the client's IPv4 or IPv6 address from the gRPC ServicerContext.
Parameters
----------
context : grpc.ServicerContext
The gRPC ServicerContext object. The context.peer() returns a string like
"ipv4:127.0.0.1:56789" for IPv4 and "ipv6:[2001:db8::1]:54321" for IPv6.
Returns
-------
str
If one of the format matches, the function will return the client's IP address,
otherwise, it will raise a ValueError.
"""
peer: str = context.peer()
# Match IPv4: "ipv4:IP:port"
ipv4_match = re.match(r"^ipv4:(?P<ip>[^:]+):", peer)
if ipv4_match:
return ipv4_match.group("ip")

# Match IPv6: "ipv6:[IP]:port"
ipv6_match = re.match(r"^ipv6:\[(?P<ip>[^\]]+)\]:", peer)
if ipv6_match:
return ipv6_match.group("ip")

raise ValueError(
f"Unsupported peer address format: {peer} for the transport protocol. "
"The supported formats are ipv4:IP:port and ipv6:[IP]:port."
)
243 changes: 159 additions & 84 deletions src/py/flwr/common/address_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,50 +15,51 @@
"""Flower IP address utils."""


from .address import parse_address
import pytest

from .address import get_ip_address_from_servicer_context, parse_address

def test_ipv4_correct() -> None:
"""Test if a correct IPv4 address is correctly parsed."""
# Prepare
addresses = [

@pytest.mark.parametrize(
"address, expected",
[
("127.0.0.1:8080", ("127.0.0.1", 8080, False)),
("0.0.0.0:12", ("0.0.0.0", 12, False)),
("0.0.0.0:65535", ("0.0.0.0", 65535, False)),
]

for address, expected in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


def test_ipv4_incorrect() -> None:
],
)
def test_ipv4_correct(address: str, expected: tuple[str, int, bool]) -> None:
"""Test if a correct IPv4 address is correctly parsed."""
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"address",
[
"127.0.0.1", # Missing port
"42.1.1.0:9988898", # Port number out of range
"0.0.0.0:-999999", # Negative port number
"0.0.0.0:-1", # Negative port number
"0.0.0.0:0", # Port number zero
"0.0.0.0:65536", # Port number out of range
],
)
def test_ipv4_incorrect(address: str) -> None:
"""Test if an incorrect IPv4 address returns None."""
# Prepare
addresses = [
"127.0.0.1",
"42.1.1.0:9988898",
"0.0.0.0:-999999",
"0.0.0.0:-1",
"0.0.0.0:0",
"0.0.0.0:65536",
]
# Execute
actual = parse_address(address)

for address in addresses:
# Execute
actual = parse_address(address)
# Assert
assert actual is None

# Assert
assert actual is None


def test_ipv6_correct() -> None:
"""Test if a correct IPv6 address is correctly parsed."""
# Prepare
addresses = [
@pytest.mark.parametrize(
"address, expected",
[
("[::1]:8080", ("::1", 8080, True)),
("[::]:12", ("::", 12, True)),
(
Expand All @@ -76,67 +77,141 @@ def test_ipv6_correct() -> None:
("[::]:123", ("::", 123, True)),
("[0:0:0:0:0:0:0:1]:80", ("0:0:0:0:0:0:0:1", 80, True)),
("[::1]:80", ("::1", 80, True)),
]

for address, expected in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


def test_ipv6_incorrect() -> None:
],
)
def test_ipv6_correct(address: str, expected: tuple[str, int, bool]) -> None:
"""Test if a correct IPv6 address is correctly parsed."""
# Execute
actual = parse_address(address)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"address",
[
"[2001:db8:3333:4444:5555:6666:7777:8888]:9988898", # Port number out of range
"[2001:db8:3333:4444:5555:6666:7777:8888]:-9988898", # Negative port number
"[2001:db8:3333:4444:5555:6666:7777:8888]:-1", # Negative port number
"[2001:db8:3333:4444:5555:6666:7777:8888]:0", # Port number zero
"[2001:db8:3333:4444:5555:6666:7777:8888]:65536", # Port number out of range
],
)
def test_ipv6_incorrect(address: str) -> None:
"""Test if an incorrect IPv6 address returns None."""
# Prepare
addresses = [
"2001:db8:3333:4444:5555:6666:7777:8888:9988898",
"2001:db8:3333:4444:5555:6666:7777:8888:-9988898",
"2001:db8:3333:4444:5555:6666:7777:8888:-1",
"2001:db8:3333:4444:5555:6666:7777:8888:0",
"2001:db8:3333:4444:5555:6666:7777:8888:65536",
]
# Execute
actual = parse_address(address)

for address in addresses:
# Execute
actual = parse_address(address)
# Assert
assert actual is None

# Assert
assert actual is None


def test_domain_correct() -> None:
"""Test if a correct domain address is correctly parsed."""
# Prepare
addresses = [
@pytest.mark.parametrize(
"address, expected",
[
("flower.ai:123", ("flower.ai", 123, None)),
("sub.flower.ai:123", ("sub.flower.ai", 123, None)),
("sub2.sub1.flower.ai:123", ("sub2.sub1.flower.ai", 123, None)),
("s5.s4.s3.s2.s1.flower.ai:123", ("s5.s4.s3.s2.s1.flower.ai", 123, None)),
("localhost:123", ("localhost", 123, None)),
("https://localhost:123", ("https://localhost", 123, None)),
("http://localhost:123", ("http://localhost", 123, None)),
]

for address, expected in addresses:
# Execute
actual = parse_address(address)
],
)
def test_domain_correct(address: str, expected: tuple[str, int, bool]) -> None:
"""Test if a correct domain address is correctly parsed."""
# Execute
actual = parse_address(address)

# Assert
assert actual == expected
# Assert
assert actual == expected


def test_domain_incorrect() -> None:
@pytest.mark.parametrize(
"address",
[
"flower.ai", # Missing port
"flower.ai:65536", # Port number out of range
],
)
def test_domain_incorrect(address: str) -> None:
"""Test if an incorrect domain address returns None."""
# Prepare
addresses = [
"flower.ai",
"flower.ai:65536",
]

for address in addresses:
# Execute
actual = parse_address(address)

# Assert
assert actual is None
# Execute
actual = parse_address(address)

# Assert
assert actual is None


class DummyContext:
"""Dummy context to mimic grpc.ServicerContext for testing purposes."""

def __init__(self, peer_str: str) -> None:
self._peer = peer_str

def peer(self) -> str:
"""."""
return self._peer


@pytest.mark.parametrize(
"peer_str, expected",
[
("ipv4:127.0.0.1:56789", "127.0.0.1"),
("ipv4:0.0.0.0:8080", "0.0.0.0"),
("ipv4:192.168.1.1:12345", "192.168.1.1"),
],
)
def test_servicer_ipv4_correct(peer_str: str, expected: str) -> None:
"""Test if a correct IPv4 address is correctly parsed from grpc.ServicerContext."""
# Prepare dummy context with the given peer string.
context = DummyContext(peer_str)

# Execute
actual = get_ip_address_from_servicer_context(context)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"peer_str, expected",
[
("ipv6:[2001:db8::1]:54321", "2001:db8::1"),
("ipv6:[::1]:8080", "::1"),
("ipv6:[fe80::1ff:fe23:4567:890a]:9999", "fe80::1ff:fe23:4567:890a"),
],
)
def test_servicer_ipv6_correct(peer_str: str, expected: str) -> None:
"""Test if a correct IPv6 address is correctly parsed from grpc.ServicerContext."""
# Prepare dummy context with the given peer string.
context = DummyContext(peer_str)

# Execute
actual = get_ip_address_from_servicer_context(context)

# Assert
assert actual == expected


@pytest.mark.parametrize(
"peer_str",
[
"invalid_string",
"ipv4:127.0.0.1", # missing port
"ipv6:2001:db8::1", # missing brackets and port
"ipv6:[2001:db8::1]56789", # missing colon after the bracket
"ipv6:2001:db8::1:54321", # missing brackets
"unix:/tmp/grpc.sock", # unix domain socket
"",
],
)
def test_servicer_incorrect_format(peer_str: str) -> None:
"""Test if an invalid grpc.ServicerContext.peer() string returns a ValueError."""
# Prepare dummy context with the given peer string.
context = DummyContext(peer_str)

# Execute and Assert
with pytest.raises(ValueError):
get_ip_address_from_servicer_context(context)

0 comments on commit 4f525ad

Please sign in to comment.