Skip to content

Commit

Permalink
Merge pull request #383 from ianpaul10/feat/manual-disc-follow-up
Browse files Browse the repository at this point in the history
Support changing manual configuration while running
  • Loading branch information
AlexCheema authored Dec 28, 2024
2 parents 496a3b4 + b003292 commit a174c78
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 29 deletions.
74 changes: 52 additions & 22 deletions exo/networking/manual/manual_discovery.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from typing import Dict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor

from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
Expand All @@ -13,28 +15,25 @@ def __init__(
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
):
self.topology = NetworkTopology.from_path(network_config_path)
self.network_config_path = network_config_path
self.node_id = node_id
self.create_peer_handle = create_peer_handle

if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)

self.listen_task = None

self.known_peers: Dict[str, PeerHandle] = {}
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
self.peers_in_network.pop(node_id)

self._cached_peers: Dict[str, PeerConfig] = {}
self._last_modified_time: Optional[float] = None
self._file_executor = ThreadPoolExecutor(max_workers=1)

async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())

async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
if self.listen_task: self.listen_task.cancel()
self._file_executor.shutdown(wait=True)

async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
Expand All @@ -47,7 +46,9 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
for peer_id, peer_config in self.peers_in_network.items():
peers_from_config = await self._get_peers()
new_known_peers = {}
for peer_id, peer_config in peers_from_config.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
Expand All @@ -57,15 +58,44 @@ async def task_find_peers_from_config(self):
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
try:
del self.known_peers[peer_id]
except KeyError:
pass
new_known_peers[peer_id] = peer
elif DEBUG_DISCOVERY >= 2:
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
self.known_peers = new_known_peers
await asyncio.sleep(1.0)

if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

async def _get_peers(self):
try:
loop = asyncio.get_running_loop()
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)

if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
return self._cached_peers

topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)

if self.node_id not in topology.peers:
raise ValueError(
f"Node ID {self.node_id} not found in network config file "
f"{self.network_config_path}. Please run with `node_id` set to "
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
)

peers_in_network = topology.peers
peers_in_network.pop(self.node_id)

self._cached_peers = peers_in_network
self._last_modified_time = current_mtime

return peers_in_network

except Exception as e:
if DEBUG_DISCOVERY >= 2:
print(f"Error when loading network config file from {self.network_config_path}. "
f"Please update the config file in order to successfully discover peers. "
f"Exception: {e}")
return self._cached_peers
2 changes: 1 addition & 1 deletion exo/networking/manual/test_data/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@
}
}
}
}
}
90 changes: 84 additions & 6 deletions exo/networking/manual/test_manual_discovery.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
Expand All @@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
_ = self.discovery1.start()
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
await self.discovery1.start()

async def asyncTearDown(self):
await self.discovery1.stop()
Expand All @@ -33,8 +38,16 @@ async def asyncSetUp(self):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
)
await self.discovery1.start()
await self.discovery2.start()

Expand Down Expand Up @@ -63,8 +76,16 @@ async def asyncSetUp(self):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
await self.discovery1.start()
await self.discovery2.start()

Expand Down Expand Up @@ -98,6 +119,63 @@ async def test_grpc_discovery(self):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())

async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)

# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)

try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
},
},
}
}

with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)

node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()

try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)

updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)

for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())

finally:
await server3.stop()

finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)

# Wait for the config to be reloaded again
await asyncio.sleep(1.5)

updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(updated_peers), 1)


if __name__ == "__main__":
asyncio.run(unittest.main())

0 comments on commit a174c78

Please sign in to comment.