From 98118babaeaf5101b0f31be1382523533067737d Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Thu, 24 Oct 2024 09:10:10 +0700 Subject: [PATCH 01/13] allow update to manual discovery file re-load manual discovery file for each runthrough of the peer network, allowing incremental updates to the peer file even when exo is running --- exo/networking/manual/manual_discovery.py | 148 ++++++++++++------ .../manual/test_data/test_config.json | 2 +- .../manual/test_manual_discovery.py | 55 ++++++- 3 files changed, 148 insertions(+), 57 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 45bc14f3d..f571ce55e 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -9,63 +9,107 @@ class ManualDiscovery(Discovery): - def __init__( - self, - network_config_path: str, - node_id: str, - create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], - ): - self.topology = NetworkTopology.from_path(network_config_path) - self.create_peer_handle = create_peer_handle + def __init__( + self, + network_config_path: str, + node_id: str, + create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], + ): + self.network_config_path = network_config_path + self.node_id = node_id + self.create_peer_handle = create_peer_handle - if node_id not in self.topology.peers: - raise ValueError( - f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}" - ) + if node_id not in self.topology.peers: + raise ValueError( + f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}" + ) - self.listen_task = None + self.listen_task = None + self.known_peers: Dict[str, PeerHandle] = {} - self.known_peers: Dict[str, PeerHandle] = {} - self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers - self.peers_in_network.pop(node_id) + async def start(self) -> None: + self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) - async def start(self) -> None: - self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) + async def stop(self) -> None: + if self.listen_task: + self.listen_task.cancel() - async def stop(self) -> None: - if self.listen_task: - self.listen_task.cancel() + async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: + if wait_for_peers > 0: + while len(self.known_peers) < wait_for_peers: + if DEBUG_DISCOVERY >= 2: + print( + f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers..." + ) + await asyncio.sleep(0.1) + if DEBUG_DISCOVERY >= 2: + print( + f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}" + ) + return list(self.known_peers.values()) - async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: - if wait_for_peers > 0: - while len(self.known_peers) < wait_for_peers: - if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...") - await asyncio.sleep(0.1) - if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}") - return list(self.known_peers.values()) + async def task_find_peers_from_config(self): + if DEBUG_DISCOVERY >= 2: + print("Starting task to find peers from config...") + while True: + peers = self._get_peers().items() + for peer_id, peer_config in peers: + try: + if DEBUG_DISCOVERY >= 2: + print( + f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}" + ) + peer = self.known_peers.get(peer_id) + if not peer: + if DEBUG_DISCOVERY >= 2: + print(f"{peer_id=} not found in known peers. Adding.") + peer = self.create_peer_handle( + peer_id, + f"{peer_config.address}:{peer_config.port}", + peer_config.device_capabilities, + ) + peer = self.create_peer_handle( + peer_id, + f"{peer_config.address}:{peer_config.port}", + peer_config.device_capabilities, + ) + is_healthy = await peer.health_check() + if is_healthy: + if DEBUG_DISCOVERY >= 2: + print( + f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy." + ) + self.known_peers[peer_id] = peer + else: + if DEBUG_DISCOVERY >= 2: + print( + f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy." + ) + try: + del self.known_peers[peer_id] + except KeyError: + pass + except Exception as e: + if DEBUG_DISCOVERY >= 2: + print( + f"Exception occured when attempting to add {peer_id=}: {e}" + ) + await asyncio.sleep(1.0) - async def task_find_peers_from_config(self): - if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") - while True: - for peer_id, peer_config in self.peers_in_network.items(): - try: - if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") - peer = self.known_peers.get(peer_id) - if not peer: - if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.") - peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities) - is_healthy = await peer.health_check() - if is_healthy: - if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.") - self.known_peers[peer_id] = peer - else: - if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.") - try: - del self.known_peers[peer_id] - except KeyError: - pass - except Exception as e: - if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}") - await asyncio.sleep(1.0) + if DEBUG_DISCOVERY >= 2: + print( + f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}" + ) - if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}") + def _get_peers(self): + topology = NetworkTopology.from_path(self.network_config_path) + + if self.node_id not in topology.peers: + raise ValueError( + f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}" + ) + + peers_in_network: Dict[str, PeerConfig] = topology.peers + peers_in_network.pop(self.node_id) + + return peers_in_network diff --git a/exo/networking/manual/test_data/test_config.json b/exo/networking/manual/test_data/test_config.json index b50ef635a..54eced720 100644 --- a/exo/networking/manual/test_data/test_config.json +++ b/exo/networking/manual/test_data/test_config.json @@ -29,4 +29,4 @@ } } } -} +} \ No newline at end of file diff --git a/exo/networking/manual/test_manual_discovery.py b/exo/networking/manual/test_manual_discovery.py index 69f45fa16..39af10bfa 100644 --- a/exo/networking/manual/test_manual_discovery.py +++ b/exo/networking/manual/test_manual_discovery.py @@ -1,3 +1,4 @@ +import json import asyncio import unittest from unittest import mock @@ -44,9 +45,9 @@ async def asyncTearDown(self): async def test_discovery(self): peers1 = await self.discovery1.discover_peers(wait_for_peers=1) - assert len(peers1) == 1 + self.assertEqual(len(peers1), 1) peers2 = await self.discovery2.discover_peers(wait_for_peers=1) - assert len(peers2) == 1 + self.assertEqual(len(peers2), 1) # connect has to be explicitly called after discovery self.peer1.connect.assert_not_called() @@ -76,9 +77,9 @@ async def asyncTearDown(self): async def test_grpc_discovery(self): peers1 = await self.discovery1.discover_peers(wait_for_peers=1) - assert len(peers1) == 1 + self.assertEqual(len(peers1), 1) peers2 = await self.discovery2.discover_peers(wait_for_peers=1) - assert len(peers2) == 1 + self.assertEqual(len(peers2), 1) # Connect await peers1[0].connect() @@ -98,6 +99,52 @@ async def test_grpc_discovery(self): self.assertFalse(await peers1[0].is_connected()) self.assertFalse(await peers2[0].is_connected()) + async def test_dynamic_config_update(self): + initial_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(initial_peers), 1) + + # Save original config for cleanup + with open(root_path, "r") as f: + original_config = json.load(f) + + try: + updated_config = { + "peers": { + **original_config["peers"], + "node3": { + "address": "localhost", + "port": 50053, + "device_capabilities": {"model": "Unknown Model", "chip": "Unknown Chip", "memory": 0, "flops": {"fp32": 0, "fp16": 0, "int8": 0}}, + }, + } + } + + with open(root_path, "w") as f: + json.dump(updated_config, f, indent=2) + + node3 = mock.AsyncMock(spec=Node) + server3 = GRPCServer(node3, "localhost", 50053) + await server3.start() + + try: + # Wait for the config to be reloaded + await asyncio.sleep(1.5) + + updated_peers = await self.discovery1.discover_peers(wait_for_peers=2) + self.assertEqual(len(updated_peers), 2) + + for peer in updated_peers: + await peer.connect() + self.assertTrue(await peer.is_connected()) + + finally: + await server3.stop() + + finally: + # Restore the original config file + with open(root_path, "w") as f: + json.dump(original_config, f, indent=2) + if __name__ == "__main__": asyncio.run(unittest.main()) From 2e8227fccbb06e0e137e00e7d78d8d464724867c Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Thu, 24 Oct 2024 09:16:46 +0700 Subject: [PATCH 02/13] handle intermediate state for when config is being updated --- exo/networking/manual/manual_discovery.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index f571ce55e..961722a7d 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -101,15 +101,17 @@ async def task_find_peers_from_config(self): f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}" ) - def _get_peers(self): + def _get_peers(self): + try: topology = NetworkTopology.from_path(self.network_config_path) if self.node_id not in topology.peers: - raise ValueError( - f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}" - ) + raise ValueError(f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}") peers_in_network: Dict[str, PeerConfig] = topology.peers peers_in_network.pop(self.node_id) + except Exception as e: + if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}") + peers_in_network = {} return peers_in_network From e5eb3259a59065a4ee947c567a307ec1ad059b26 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Thu, 24 Oct 2024 09:43:20 +0700 Subject: [PATCH 03/13] handle when a peer is removed from config, so the known_peers dict gets updated accordingly --- exo/networking/manual/manual_discovery.py | 81 +++++++++++++++---- .../manual/test_manual_discovery.py | 7 ++ 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 961722a7d..815ca0f3e 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -1,6 +1,7 @@ +import os import asyncio from exo.networking.discovery import Discovery -from typing import Dict, List, Callable +from typing import Dict, List, Callable, Optional from exo.topology.device_capabilities import DeviceCapabilities from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig @@ -27,12 +28,16 @@ def __init__( self.listen_task = None self.known_peers: Dict[str, PeerHandle] = {} - async def start(self) -> None: + self._cached_peers: Dict[str, PeerConfig] = {} + self._last_modified_time: Optional[float] = None + + async def start(self) -> None: self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) + self.cleanup_task = asyncio.create_task(self.task_clean_up_peers_from_config()) - async def stop(self) -> None: - if self.listen_task: - self.listen_task.cancel() + async def stop(self) -> None: + if self.listen_task: self.listen_task.cancel() + if self.cleanup_task: self.cleanup_task.cancel() async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: if wait_for_peers > 0: @@ -48,6 +53,41 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: ) return list(self.known_peers.values()) + async def task_clean_up_peers_from_config(self): + if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...") + while True: + peers_from_config = self._get_peers() + if peers_from_config: + peers_to_remove = [peer for peer in self.known_peers.keys() if peer not in peers_from_config] + + for peer in peers_to_remove: + if DEBUG_DISCOVERY >= 2: print(f"{peer} is no longer found in the config but is currently a known peer. Removing from known peers...") + try: del self.known_peers[peer] + except KeyError: pass + + await asyncio.sleep(1.0) + + async def task_find_peers_from_config(self): + if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") + while True: + for peer_id, peer_config in self._get_peers().items(): + try: + if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") + peer = self.known_peers.get(peer_id) + if not peer: + if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.") + peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities) + is_healthy = await peer.health_check() + if is_healthy: + if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.") + self.known_peers[peer_id] = peer + else: + if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.") + try: del self.known_peers[peer_id] + except KeyError: pass + except Exception as e: + if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}") + await asyncio.sleep(1.0) async def task_find_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") @@ -103,15 +143,28 @@ async def task_find_peers_from_config(self): def _get_peers(self): try: - topology = NetworkTopology.from_path(self.network_config_path) + current_mtime = os.path.getmtime(self.network_config_path) - if self.node_id not in topology.peers: - raise ValueError(f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}") + if self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time: + return self._cached_peers - peers_in_network: Dict[str, PeerConfig] = topology.peers - peers_in_network.pop(self.node_id) - except Exception as e: - if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}") - peers_in_network = {} + topology = NetworkTopology.from_path(self.network_config_path) - return peers_in_network + if self.node_id not in topology.peers: + raise ValueError( + f"Node ID {self.node_id} not found in network config file " + f"{self.network_config_path}. Please run with `node_id` set to " + f"one of the keys in the config file: {[k for k, _ in topology.peers]}" + ) + + peers_in_network: Dict[str, PeerConfig] = topology.peers + peers_in_network.pop(self.node_id) + + self._cached_peers = peers_in_network + self._last_modified_time = current_mtime + + return peers_in_network + + except Exception as e: + if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}") + return self._cached_peers diff --git a/exo/networking/manual/test_manual_discovery.py b/exo/networking/manual/test_manual_discovery.py index 39af10bfa..9efe7d1b8 100644 --- a/exo/networking/manual/test_manual_discovery.py +++ b/exo/networking/manual/test_manual_discovery.py @@ -145,6 +145,13 @@ async def test_dynamic_config_update(self): with open(root_path, "w") as f: json.dump(original_config, f, indent=2) + # Wait for the config to be reloaded again + await asyncio.sleep(1.5) + + updated_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(updated_peers), 1) + + if __name__ == "__main__": asyncio.run(unittest.main()) From 8d24df2b4b52346b3b87cba247a5b8fefac39b18 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Thu, 24 Oct 2024 17:03:50 +0700 Subject: [PATCH 04/13] fix test runtime warning --- .../manual/test_manual_discovery.py | 301 ++++++++++-------- 1 file changed, 169 insertions(+), 132 deletions(-) diff --git a/exo/networking/manual/test_manual_discovery.py b/exo/networking/manual/test_manual_discovery.py index 9efe7d1b8..8af24ce96 100644 --- a/exo/networking/manual/test_manual_discovery.py +++ b/exo/networking/manual/test_manual_discovery.py @@ -12,146 +12,183 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - self.peer1 = mock.AsyncMock() - self.peer1.connect = mock.AsyncMock() - self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1) - _ = self.discovery1.start() - - async def asyncTearDown(self): - await self.discovery1.stop() - - async def test_discovery(self): - peers1 = await self.discovery1.discover_peers(wait_for_peers=0) - assert len(peers1) == 0 - - self.peer1.connect.assert_not_called() + async def asyncSetUp(self): + self.peer1 = mock.AsyncMock(spec=Node) + self.peer1.connect = mock.AsyncMock() + self.server1 = GRPCServer(self.peer1, "localhost", 8000) + await self.server1.start() + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle( + peer_id, address, device_capabilities + ), + ) + await self.discovery1.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + await self.server1.stop() + + async def test_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=0) + self.assertEqual(len(peers1), 0) + + self.peer1.connect.assert_not_called() class TestManualDiscovery(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - self.peer1 = mock.AsyncMock() - self.peer2 = mock.AsyncMock() - self.peer1.connect = mock.AsyncMock() - self.peer2.connect = mock.AsyncMock() - self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1) - self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2) - await self.discovery1.start() - await self.discovery2.start() - - async def asyncTearDown(self): - await self.discovery1.stop() - await self.discovery2.stop() - - async def test_discovery(self): - peers1 = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers1), 1) - peers2 = await self.discovery2.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers2), 1) - - # connect has to be explicitly called after discovery - self.peer1.connect.assert_not_called() - self.peer2.connect.assert_not_called() + async def asyncSetUp(self): + self.peer1 = mock.AsyncMock() + self.peer2 = mock.AsyncMock() + self.peer1.connect = mock.AsyncMock() + self.peer2.connect = mock.AsyncMock() + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, + ) + self.discovery2 = ManualDiscovery( + root_path, + "node2", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2, + ) + await self.discovery1.start() + await self.discovery2.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + await self.discovery2.stop() + + async def test_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(peers1), 1) + peers2 = await self.discovery2.discover_peers(wait_for_peers=1) + self.assertEqual(len(peers2), 1) + + # connect has to be explicitly called after discovery + self.peer1.connect.assert_not_called() + self.peer2.connect.assert_not_called() class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - config = NetworkTopology.from_path(root_path) - - self.node1 = mock.AsyncMock(spec=Node) - self.node2 = mock.AsyncMock(spec=Node) - self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port) - self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port) - await self.server1.start() - await self.server2.start() - self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) - self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) - await self.discovery1.start() - await self.discovery2.start() - - async def asyncTearDown(self): - await self.discovery1.stop() - await self.discovery2.stop() - await self.server1.stop() - await self.server2.stop() - - async def test_grpc_discovery(self): - peers1 = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers1), 1) - peers2 = await self.discovery2.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers2), 1) - - # Connect - await peers1[0].connect() - await peers2[0].connect() - self.assertTrue(await peers1[0].is_connected()) - self.assertTrue(await peers2[0].is_connected()) - - # Kill server1 - await self.server1.stop() - - self.assertTrue(await peers1[0].is_connected()) - self.assertFalse(await peers2[0].is_connected()) - - # Kill server2 - await self.server2.stop() - - self.assertFalse(await peers1[0].is_connected()) - self.assertFalse(await peers2[0].is_connected()) - - async def test_dynamic_config_update(self): - initial_peers = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(initial_peers), 1) - - # Save original config for cleanup - with open(root_path, "r") as f: - original_config = json.load(f) - - try: - updated_config = { - "peers": { - **original_config["peers"], - "node3": { - "address": "localhost", - "port": 50053, - "device_capabilities": {"model": "Unknown Model", "chip": "Unknown Chip", "memory": 0, "flops": {"fp32": 0, "fp16": 0, "int8": 0}}, - }, - } - } - - with open(root_path, "w") as f: - json.dump(updated_config, f, indent=2) - - node3 = mock.AsyncMock(spec=Node) - server3 = GRPCServer(node3, "localhost", 50053) - await server3.start() - - try: - # Wait for the config to be reloaded + async def asyncSetUp(self): + config = NetworkTopology.from_path(root_path) + + self.node1 = mock.AsyncMock(spec=Node) + self.node2 = mock.AsyncMock(spec=Node) + self.server1 = GRPCServer( + self.node1, config.peers["node1"].address, config.peers["node1"].port + ) + self.server2 = GRPCServer( + self.node2, config.peers["node2"].address, config.peers["node2"].port + ) + await self.server1.start() + await self.server2.start() + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle( + peer_id, address, description, device_capabilities + ), + ) + self.discovery2 = ManualDiscovery( + root_path, + "node2", + create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle( + peer_id, address, description, device_capabilities + ), + ) + await self.discovery1.start() + await self.discovery2.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + await self.discovery2.stop() + await self.server1.stop() + await self.server2.stop() + + async def test_grpc_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(peers1), 1) + peers2 = await self.discovery2.discover_peers(wait_for_peers=1) + self.assertEqual(len(peers2), 1) + + # Connect + await peers1[0].connect() + await peers2[0].connect() + self.assertTrue(await peers1[0].is_connected()) + self.assertTrue(await peers2[0].is_connected()) + + # Kill server1 + await self.server1.stop() + + self.assertTrue(await peers1[0].is_connected()) + self.assertFalse(await peers2[0].is_connected()) + + # Kill server2 + await self.server2.stop() + + self.assertFalse(await peers1[0].is_connected()) + self.assertFalse(await peers2[0].is_connected()) + + async def test_dynamic_config_update(self): + initial_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(initial_peers), 1) + + # Save original config for cleanup + with open(root_path, "r") as f: + original_config = json.load(f) + + try: + updated_config = { + "peers": { + **original_config["peers"], + "node3": { + "address": "localhost", + "port": 50053, + "device_capabilities": { + "model": "Unknown Model", + "chip": "Unknown Chip", + "memory": 0, + "flops": {"fp32": 0, "fp16": 0, "int8": 0}, + }, + }, + } + } + + with open(root_path, "w") as f: + json.dump(updated_config, f, indent=2) + + node3 = mock.AsyncMock(spec=Node) + server3 = GRPCServer(node3, "localhost", 50053) + await server3.start() + + try: + # Wait for the config to be reloaded + await asyncio.sleep(1.5) + + updated_peers = await self.discovery1.discover_peers(wait_for_peers=2) + self.assertEqual(len(updated_peers), 2) + + for peer in updated_peers: + await peer.connect() + self.assertTrue(await peer.is_connected()) + + finally: + await server3.stop() + + finally: + # Restore the original config file + with open(root_path, "w") as f: + json.dump(original_config, f, indent=2) + + # Wait for the config to be reloaded again await asyncio.sleep(1.5) - updated_peers = await self.discovery1.discover_peers(wait_for_peers=2) - self.assertEqual(len(updated_peers), 2) - - for peer in updated_peers: - await peer.connect() - self.assertTrue(await peer.is_connected()) - - finally: - await server3.stop() - - finally: - # Restore the original config file - with open(root_path, "w") as f: - json.dump(original_config, f, indent=2) - - # Wait for the config to be reloaded again - await asyncio.sleep(1.5) - - updated_peers = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(updated_peers), 1) - + updated_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(updated_peers), 1) if __name__ == "__main__": - asyncio.run(unittest.main()) + asyncio.run(unittest.main()) From 90de7eada9f6e49592272b9cb5ed537d78ea2a0a Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Wed, 6 Nov 2024 10:50:58 +0700 Subject: [PATCH 05/13] changes after rebase --- exo/networking/manual/manual_discovery.py | 172 +++++----------------- 1 file changed, 38 insertions(+), 134 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 815ca0f3e..0bd2689a8 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -1,7 +1,6 @@ -import os import asyncio from exo.networking.discovery import Discovery -from typing import Dict, List, Callable, Optional +from typing import Dict, List, Callable from exo.topology.device_capabilities import DeviceCapabilities from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig @@ -10,67 +9,47 @@ class ManualDiscovery(Discovery): - def __init__( - self, - network_config_path: str, - node_id: str, - create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], - ): - self.network_config_path = network_config_path - self.node_id = node_id - self.create_peer_handle = create_peer_handle - - if node_id not in self.topology.peers: - raise ValueError( - f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}" - ) - - self.listen_task = None - self.known_peers: Dict[str, PeerHandle] = {} - - self._cached_peers: Dict[str, PeerConfig] = {} - self._last_modified_time: Optional[float] = None + def __init__( + self, + network_config_path: str, + node_id: str, + create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], + ): + self.topology = NetworkTopology.from_path(network_config_path) + self.network_config_path = network_config_path + self.node_id = node_id + self.create_peer_handle = create_peer_handle + + if node_id not in self.topology.peers: + raise ValueError( + f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}" + ) + + self.listen_task = None + + self.known_peers: Dict[str, PeerHandle] = {} + self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers + self.peers_in_network.pop(node_id) async def start(self) -> None: - self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) - self.cleanup_task = asyncio.create_task(self.task_clean_up_peers_from_config()) + self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) async def stop(self) -> None: - if self.listen_task: self.listen_task.cancel() - if self.cleanup_task: self.cleanup_task.cancel() - - async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: - if wait_for_peers > 0: - while len(self.known_peers) < wait_for_peers: - if DEBUG_DISCOVERY >= 2: - print( - f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers..." - ) - await asyncio.sleep(0.1) - if DEBUG_DISCOVERY >= 2: - print( - f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}" - ) - return list(self.known_peers.values()) + if self.listen_task: + self.listen_task.cancel() - async def task_clean_up_peers_from_config(self): - if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...") - while True: - peers_from_config = self._get_peers() - if peers_from_config: - peers_to_remove = [peer for peer in self.known_peers.keys() if peer not in peers_from_config] - - for peer in peers_to_remove: - if DEBUG_DISCOVERY >= 2: print(f"{peer} is no longer found in the config but is currently a known peer. Removing from known peers...") - try: del self.known_peers[peer] - except KeyError: pass - - await asyncio.sleep(1.0) + async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: + if wait_for_peers > 0: + while len(self.known_peers) < wait_for_peers: + if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...") + await asyncio.sleep(0.1) + if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}") + return list(self.known_peers.values()) async def task_find_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") while True: - for peer_id, peer_config in self._get_peers().items(): + for peer_id, peer_config in self.peers_in_network.items(): try: if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") peer = self.known_peers.get(peer_id) @@ -83,88 +62,13 @@ async def task_find_peers_from_config(self): self.known_peers[peer_id] = peer else: if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.") - try: del self.known_peers[peer_id] - except KeyError: pass + try: + del self.known_peers[peer_id] + except KeyError: + pass except Exception as e: if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}") await asyncio.sleep(1.0) - async def task_find_peers_from_config(self): - if DEBUG_DISCOVERY >= 2: - print("Starting task to find peers from config...") - while True: - peers = self._get_peers().items() - for peer_id, peer_config in peers: - try: - if DEBUG_DISCOVERY >= 2: - print( - f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}" - ) - peer = self.known_peers.get(peer_id) - if not peer: - if DEBUG_DISCOVERY >= 2: - print(f"{peer_id=} not found in known peers. Adding.") - peer = self.create_peer_handle( - peer_id, - f"{peer_config.address}:{peer_config.port}", - peer_config.device_capabilities, - ) - peer = self.create_peer_handle( - peer_id, - f"{peer_config.address}:{peer_config.port}", - peer_config.device_capabilities, - ) - is_healthy = await peer.health_check() - if is_healthy: - if DEBUG_DISCOVERY >= 2: - print( - f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy." - ) - self.known_peers[peer_id] = peer - else: - if DEBUG_DISCOVERY >= 2: - print( - f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy." - ) - try: - del self.known_peers[peer_id] - except KeyError: - pass - except Exception as e: - if DEBUG_DISCOVERY >= 2: - print( - f"Exception occured when attempting to add {peer_id=}: {e}" - ) - await asyncio.sleep(1.0) - - if DEBUG_DISCOVERY >= 2: - print( - f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}" - ) - - def _get_peers(self): - try: - current_mtime = os.path.getmtime(self.network_config_path) - - if self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time: - return self._cached_peers - - topology = NetworkTopology.from_path(self.network_config_path) - - if self.node_id not in topology.peers: - raise ValueError( - f"Node ID {self.node_id} not found in network config file " - f"{self.network_config_path}. Please run with `node_id` set to " - f"one of the keys in the config file: {[k for k, _ in topology.peers]}" - ) - - peers_in_network: Dict[str, PeerConfig] = topology.peers - peers_in_network.pop(self.node_id) - - self._cached_peers = peers_in_network - self._last_modified_time = current_mtime - return peers_in_network + if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}") - except Exception as e: - if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}") - return self._cached_peers From 0e34ce2169d332443af87f6a525bda029920a515 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Wed, 6 Nov 2024 11:40:03 +0700 Subject: [PATCH 06/13] patch after rebasing to main --- exo/networking/manual/manual_discovery.py | 64 ++++++++++++++++++----- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 0bd2689a8..e93518213 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -1,6 +1,7 @@ +import os import asyncio from exo.networking.discovery import Discovery -from typing import Dict, List, Callable +from typing import Dict, List, Callable, Optional from exo.topology.device_capabilities import DeviceCapabilities from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig @@ -15,28 +16,24 @@ def __init__( node_id: str, create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], ): - self.topology = NetworkTopology.from_path(network_config_path) self.network_config_path = network_config_path self.node_id = node_id self.create_peer_handle = create_peer_handle - if node_id not in self.topology.peers: - raise ValueError( - f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}" - ) - self.listen_task = None - + self.cleanup_task = None self.known_peers: Dict[str, PeerHandle] = {} - self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers - self.peers_in_network.pop(node_id) + + self._cached_peers: Dict[str, PeerConfig] = {} + self._last_modified_time: Optional[float] = None async def start(self) -> None: self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) + self.cleanup_task = asyncio.create_task(self.task_clean_up_peers_from_config()) async def stop(self) -> None: - if self.listen_task: - self.listen_task.cancel() + if self.listen_task: self.listen_task.cancel() + if self.cleanup_task: self.cleanup_task.cancel() async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: if wait_for_peers > 0: @@ -49,7 +46,7 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: async def task_find_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") while True: - for peer_id, peer_config in self.peers_in_network.items(): + for peer_id, peer_config in self._get_peers().items(): try: if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") peer = self.known_peers.get(peer_id) @@ -72,3 +69,44 @@ async def task_find_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}") + async def task_clean_up_peers_from_config(self): + if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...") + while True: + peers_from_config = self._get_peers() + if peers_from_config: + peers_to_remove = [peer for peer in self.known_peers.keys() if peer not in peers_from_config] + + for peer in peers_to_remove: + if DEBUG_DISCOVERY >= 2: print(f"{peer} is no longer found in the config but is currently a known peer. Removing from known peers...") + try: del self.known_peers[peer] + except KeyError: pass + + await asyncio.sleep(1.0) + + def _get_peers(self): + try: + current_mtime = os.path.getmtime(self.network_config_path) + + if self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time: + return self._cached_peers + + topology = NetworkTopology.from_path(self.network_config_path) + + if self.node_id not in topology.peers: + raise ValueError( + f"Node ID {self.node_id} not found in network config file " + f"{self.network_config_path}. Please run with `node_id` set to " + f"one of the keys in the config file: {[k for k, _ in topology.peers]}" + ) + + peers_in_network: Dict[str, PeerConfig] = topology.peers + peers_in_network.pop(self.node_id) + + self._cached_peers = peers_in_network + self._last_modified_time = current_mtime + + return peers_in_network + + except Exception as e: + if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}") + return self._cached_peers From b066c944f3994f2a06ae43aba24755c1215af664 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Fri, 15 Nov 2024 10:20:42 +0700 Subject: [PATCH 07/13] make all I/O ops in manual_discovery.py run inside a ThreadPoolExecutor --- exo/networking/manual/manual_discovery.py | 60 +++++++++++++++-------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index e93518213..4b98514c6 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -1,8 +1,9 @@ import os import asyncio -from exo.networking.discovery import Discovery from typing import Dict, List, Callable, Optional +from concurrent.futures import ThreadPoolExecutor +from exo.networking.discovery import Discovery from exo.topology.device_capabilities import DeviceCapabilities from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig from exo.helpers import DEBUG_DISCOVERY @@ -26,6 +27,8 @@ def __init__( self._cached_peers: Dict[str, PeerConfig] = {} self._last_modified_time: Optional[float] = None + self._file_executor = ThreadPoolExecutor(max_workers=1) + self._lock = asyncio.Lock() async def start(self) -> None: self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) @@ -34,6 +37,7 @@ async def start(self) -> None: async def stop(self) -> None: if self.listen_task: self.listen_task.cancel() if self.cleanup_task: self.cleanup_task.cancel() + self._file_executor.shutdown(wait=True) async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: if wait_for_peers > 0: @@ -46,7 +50,8 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: async def task_find_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") while True: - for peer_id, peer_config in self._get_peers().items(): + peers_from_config = await self._get_peers() + for peer_id, peer_config in peers_from_config.items(): try: if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") peer = self.known_peers.get(peer_id) @@ -72,7 +77,7 @@ async def task_find_peers_from_config(self): async def task_clean_up_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...") while True: - peers_from_config = self._get_peers() + peers_from_config = await self._get_peers() if peers_from_config: peers_to_remove = [peer for peer in self.known_peers.keys() if peer not in peers_from_config] @@ -83,30 +88,45 @@ async def task_clean_up_peers_from_config(self): await asyncio.sleep(1.0) - def _get_peers(self): + async def _get_peers(self): try: - current_mtime = os.path.getmtime(self.network_config_path) - - if self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time: - return self._cached_peers + async with self._lock: + loop = asyncio.get_running_loop() + current_mtime = await loop.run_in_executor( + self._file_executor, + os.path.getmtime, + self.network_config_path + ) - topology = NetworkTopology.from_path(self.network_config_path) + if (self._cached_peers is not None and + self._last_modified_time is not None and + current_mtime <= self._last_modified_time): + return self._cached_peers - if self.node_id not in topology.peers: - raise ValueError( - f"Node ID {self.node_id} not found in network config file " - f"{self.network_config_path}. Please run with `node_id` set to " - f"one of the keys in the config file: {[k for k, _ in topology.peers]}" + topology = await loop.run_in_executor( + self._file_executor, + NetworkTopology.from_path, + self.network_config_path ) - peers_in_network: Dict[str, PeerConfig] = topology.peers - peers_in_network.pop(self.node_id) + if self.node_id not in topology.peers: + raise ValueError( + f"Node ID {self.node_id} not found in network config file " + f"{self.network_config_path}. Please run with `node_id` set to " + f"one of the keys in the config file: {[k for k, _ in topology.peers]}" + ) + + peers_in_network: Dict[str, PeerConfig] = topology.peers + peers_in_network.pop(self.node_id) - self._cached_peers = peers_in_network - self._last_modified_time = current_mtime + self._cached_peers = peers_in_network + self._last_modified_time = current_mtime - return peers_in_network + return peers_in_network except Exception as e: - if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. Please update the config file in order to successfully discover peers. Exception: {e}") + if DEBUG_DISCOVERY >= 2: + print(f"Error when loading network config file from {self.network_config_path}. " + f"Please update the config file in order to successfully discover peers. " + f"Exception: {e}") return self._cached_peers From 18acb97b426d689ba5de2a3c76a164b8c57ffff6 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Fri, 15 Nov 2024 10:39:21 +0700 Subject: [PATCH 08/13] make popping from dict threadsafe --- exo/networking/manual/manual_discovery.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 4b98514c6..9237a8607 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -64,10 +64,7 @@ async def task_find_peers_from_config(self): self.known_peers[peer_id] = peer else: if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.") - try: - del self.known_peers[peer_id] - except KeyError: - pass + self.known_peers.pop(peer_id, None) except Exception as e: if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}") await asyncio.sleep(1.0) @@ -83,8 +80,7 @@ async def task_clean_up_peers_from_config(self): for peer in peers_to_remove: if DEBUG_DISCOVERY >= 2: print(f"{peer} is no longer found in the config but is currently a known peer. Removing from known peers...") - try: del self.known_peers[peer] - except KeyError: pass + self.known_peers.pop(peer, None) await asyncio.sleep(1.0) From a31f9e6c20d31daa540e7a788a5744707ae84fdb Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Fri, 15 Nov 2024 10:52:00 +0700 Subject: [PATCH 09/13] fix test warnings --- exo/networking/manual/test_manual_discovery.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/exo/networking/manual/test_manual_discovery.py b/exo/networking/manual/test_manual_discovery.py index 8af24ce96..90a435c09 100644 --- a/exo/networking/manual/test_manual_discovery.py +++ b/exo/networking/manual/test_manual_discovery.py @@ -41,6 +41,10 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.peer1 = mock.AsyncMock() self.peer2 = mock.AsyncMock() + + self.peer1.id = mock.MagicMock(return_value="node2") + self.peer2.id = mock.MagicMock(return_value="node1") + self.peer1.connect = mock.AsyncMock() self.peer2.connect = mock.AsyncMock() self.discovery1 = ManualDiscovery( From 637446ffa91f44c19ad73107143aa53d099012e5 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Fri, 15 Nov 2024 11:03:35 +0700 Subject: [PATCH 10/13] rm redundant typing --- exo/networking/manual/manual_discovery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 9237a8607..ab2bb9e21 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -112,7 +112,7 @@ async def _get_peers(self): f"one of the keys in the config file: {[k for k, _ in topology.peers]}" ) - peers_in_network: Dict[str, PeerConfig] = topology.peers + peers_in_network = topology.peers peers_in_network.pop(self.node_id) self._cached_peers = peers_in_network From 2eadaa2c0d2de15bfbc1fa5b749e160d2706a23f Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Mon, 25 Nov 2024 18:32:34 +0700 Subject: [PATCH 11/13] rm redundant cleanup task --- exo/networking/manual/manual_discovery.py | 24 ++++------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index ab2bb9e21..63f1cbfe8 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -22,7 +22,6 @@ def __init__( self.create_peer_handle = create_peer_handle self.listen_task = None - self.cleanup_task = None self.known_peers: Dict[str, PeerHandle] = {} self._cached_peers: Dict[str, PeerConfig] = {} @@ -32,11 +31,9 @@ def __init__( async def start(self) -> None: self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) - self.cleanup_task = asyncio.create_task(self.task_clean_up_peers_from_config()) async def stop(self) -> None: if self.listen_task: self.listen_task.cancel() - if self.cleanup_task: self.cleanup_task.cancel() self._file_executor.shutdown(wait=True) async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: @@ -51,6 +48,7 @@ async def task_find_peers_from_config(self): if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") while True: peers_from_config = await self._get_peers() + new_known_peers = {} for peer_id, peer_config in peers_from_config.items(): try: if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") @@ -61,29 +59,15 @@ async def task_find_peers_from_config(self): is_healthy = await peer.health_check() if is_healthy: if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.") - self.known_peers[peer_id] = peer - else: - if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.") - self.known_peers.pop(peer_id, None) + new_known_peers[peer_id] = peer + elif DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.") except Exception as e: if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}") + self.known_peers = new_known_peers await asyncio.sleep(1.0) if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}") - async def task_clean_up_peers_from_config(self): - if DEBUG_DISCOVERY >= 2: print("Starting task to clean up peers from config...") - while True: - peers_from_config = await self._get_peers() - if peers_from_config: - peers_to_remove = [peer for peer in self.known_peers.keys() if peer not in peers_from_config] - - for peer in peers_to_remove: - if DEBUG_DISCOVERY >= 2: print(f"{peer} is no longer found in the config but is currently a known peer. Removing from known peers...") - self.known_peers.pop(peer, None) - - await asyncio.sleep(1.0) - async def _get_peers(self): try: async with self._lock: From 1dfd058c23afa882aad04459bcbfc35271d98415 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Thu, 28 Nov 2024 15:22:45 +0700 Subject: [PATCH 12/13] rm unecessary lock --- exo/networking/manual/manual_discovery.py | 51 +++++++++-------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 63f1cbfe8..69c6cb214 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -27,7 +27,6 @@ def __init__( self._cached_peers: Dict[str, PeerConfig] = {} self._last_modified_time: Optional[float] = None self._file_executor = ThreadPoolExecutor(max_workers=1) - self._lock = asyncio.Lock() async def start(self) -> None: self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) @@ -60,7 +59,8 @@ async def task_find_peers_from_config(self): if is_healthy: if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.") new_known_peers[peer_id] = peer - elif DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.") + elif DEBUG_DISCOVERY >= 2: + print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.") except Exception as e: if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}") self.known_peers = new_known_peers @@ -70,43 +70,32 @@ async def task_find_peers_from_config(self): async def _get_peers(self): try: - async with self._lock: - loop = asyncio.get_running_loop() - current_mtime = await loop.run_in_executor( - self._file_executor, - os.path.getmtime, - self.network_config_path - ) + loop = asyncio.get_running_loop() + current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path) - if (self._cached_peers is not None and - self._last_modified_time is not None and - current_mtime <= self._last_modified_time): - return self._cached_peers + if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time): + return self._cached_peers - topology = await loop.run_in_executor( - self._file_executor, - NetworkTopology.from_path, - self.network_config_path - ) + topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path) - if self.node_id not in topology.peers: - raise ValueError( - f"Node ID {self.node_id} not found in network config file " - f"{self.network_config_path}. Please run with `node_id` set to " - f"one of the keys in the config file: {[k for k, _ in topology.peers]}" - ) + if self.node_id not in topology.peers: + raise ValueError( + f"Node ID {self.node_id} not found in network config file " + f"{self.network_config_path}. Please run with `node_id` set to " + f"one of the keys in the config file: {[k for k, _ in topology.peers]}" + ) - peers_in_network = topology.peers - peers_in_network.pop(self.node_id) + peers_in_network = topology.peers + peers_in_network.pop(self.node_id) - self._cached_peers = peers_in_network - self._last_modified_time = current_mtime + self._cached_peers = peers_in_network + self._last_modified_time = current_mtime - return peers_in_network + return peers_in_network except Exception as e: if DEBUG_DISCOVERY >= 2: print(f"Error when loading network config file from {self.network_config_path}. " - f"Please update the config file in order to successfully discover peers. " - f"Exception: {e}") + f"Please update the config file in order to successfully discover peers. " + f"Exception: {e}") return self._cached_peers From b003292b89d1294b38c0d0dee47b7bf14fe93ce8 Mon Sep 17 00:00:00 2001 From: Ian Paul Date: Sat, 28 Dec 2024 12:31:15 +0700 Subject: [PATCH 13/13] formatting and fixing tests after rebasing --- exo/networking/manual/manual_discovery.py | 4 +- .../manual/test_manual_discovery.py | 329 +++++++++--------- 2 files changed, 158 insertions(+), 175 deletions(-) diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py index 69c6cb214..1c232bb64 100644 --- a/exo/networking/manual/manual_discovery.py +++ b/exo/networking/manual/manual_discovery.py @@ -15,7 +15,7 @@ def __init__( self, network_config_path: str, node_id: str, - create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], + create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle], ): self.network_config_path = network_config_path self.node_id = node_id @@ -54,7 +54,7 @@ async def task_find_peers_from_config(self): peer = self.known_peers.get(peer_id) if not peer: if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.") - peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities) + peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities) is_healthy = await peer.health_check() if is_healthy: if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.") diff --git a/exo/networking/manual/test_manual_discovery.py b/exo/networking/manual/test_manual_discovery.py index 90a435c09..317fba9d8 100644 --- a/exo/networking/manual/test_manual_discovery.py +++ b/exo/networking/manual/test_manual_discovery.py @@ -12,187 +12,170 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - self.peer1 = mock.AsyncMock(spec=Node) - self.peer1.connect = mock.AsyncMock() - self.server1 = GRPCServer(self.peer1, "localhost", 8000) - await self.server1.start() - self.discovery1 = ManualDiscovery( - root_path, - "node1", - create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle( - peer_id, address, device_capabilities - ), - ) - await self.discovery1.start() - - async def asyncTearDown(self): - await self.discovery1.stop() - await self.server1.stop() - - async def test_discovery(self): - peers1 = await self.discovery1.discover_peers(wait_for_peers=0) - self.assertEqual(len(peers1), 0) - - self.peer1.connect.assert_not_called() + async def asyncSetUp(self): + self.peer1 = mock.AsyncMock() + self.peer1.connect = mock.AsyncMock() + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, + ) + await self.discovery1.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + + async def test_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=0) + assert len(peers1) == 0 + + self.peer1.connect.assert_not_called() class TestManualDiscovery(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - self.peer1 = mock.AsyncMock() - self.peer2 = mock.AsyncMock() - - self.peer1.id = mock.MagicMock(return_value="node2") - self.peer2.id = mock.MagicMock(return_value="node1") - - self.peer1.connect = mock.AsyncMock() - self.peer2.connect = mock.AsyncMock() - self.discovery1 = ManualDiscovery( - root_path, - "node1", - create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, - ) - self.discovery2 = ManualDiscovery( - root_path, - "node2", - create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2, - ) - await self.discovery1.start() - await self.discovery2.start() - - async def asyncTearDown(self): - await self.discovery1.stop() - await self.discovery2.stop() - - async def test_discovery(self): - peers1 = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers1), 1) - peers2 = await self.discovery2.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers2), 1) - - # connect has to be explicitly called after discovery - self.peer1.connect.assert_not_called() - self.peer2.connect.assert_not_called() + async def asyncSetUp(self): + self.peer1 = mock.AsyncMock() + self.peer2 = mock.AsyncMock() + self.peer1.connect = mock.AsyncMock() + self.peer2.connect = mock.AsyncMock() + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, + ) + self.discovery2 = ManualDiscovery( + root_path, + "node2", + create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2, + ) + await self.discovery1.start() + await self.discovery2.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + await self.discovery2.stop() + + async def test_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=1) + assert len(peers1) == 1 + peers2 = await self.discovery2.discover_peers(wait_for_peers=1) + assert len(peers2) == 1 + + # connect has to be explicitly called after discovery + self.peer1.connect.assert_not_called() + self.peer2.connect.assert_not_called() class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - config = NetworkTopology.from_path(root_path) - - self.node1 = mock.AsyncMock(spec=Node) - self.node2 = mock.AsyncMock(spec=Node) - self.server1 = GRPCServer( - self.node1, config.peers["node1"].address, config.peers["node1"].port - ) - self.server2 = GRPCServer( - self.node2, config.peers["node2"].address, config.peers["node2"].port - ) - await self.server1.start() - await self.server2.start() - self.discovery1 = ManualDiscovery( - root_path, - "node1", - create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle( - peer_id, address, description, device_capabilities - ), - ) - self.discovery2 = ManualDiscovery( - root_path, - "node2", - create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle( - peer_id, address, description, device_capabilities - ), - ) - await self.discovery1.start() - await self.discovery2.start() - - async def asyncTearDown(self): - await self.discovery1.stop() - await self.discovery2.stop() - await self.server1.stop() - await self.server2.stop() - - async def test_grpc_discovery(self): - peers1 = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers1), 1) - peers2 = await self.discovery2.discover_peers(wait_for_peers=1) - self.assertEqual(len(peers2), 1) - - # Connect - await peers1[0].connect() - await peers2[0].connect() - self.assertTrue(await peers1[0].is_connected()) - self.assertTrue(await peers2[0].is_connected()) - - # Kill server1 - await self.server1.stop() - - self.assertTrue(await peers1[0].is_connected()) - self.assertFalse(await peers2[0].is_connected()) - - # Kill server2 - await self.server2.stop() - - self.assertFalse(await peers1[0].is_connected()) - self.assertFalse(await peers2[0].is_connected()) - - async def test_dynamic_config_update(self): - initial_peers = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(initial_peers), 1) - - # Save original config for cleanup - with open(root_path, "r") as f: - original_config = json.load(f) - - try: - updated_config = { - "peers": { - **original_config["peers"], - "node3": { - "address": "localhost", - "port": 50053, - "device_capabilities": { - "model": "Unknown Model", - "chip": "Unknown Chip", - "memory": 0, - "flops": {"fp32": 0, "fp16": 0, "int8": 0}, - }, - }, - } - } - - with open(root_path, "w") as f: - json.dump(updated_config, f, indent=2) - - node3 = mock.AsyncMock(spec=Node) - server3 = GRPCServer(node3, "localhost", 50053) - await server3.start() - - try: - # Wait for the config to be reloaded - await asyncio.sleep(1.5) - - updated_peers = await self.discovery1.discover_peers(wait_for_peers=2) - self.assertEqual(len(updated_peers), 2) - - for peer in updated_peers: - await peer.connect() - self.assertTrue(await peer.is_connected()) - - finally: - await server3.stop() - - finally: - # Restore the original config file - with open(root_path, "w") as f: - json.dump(original_config, f, indent=2) - - # Wait for the config to be reloaded again + async def asyncSetUp(self): + config = NetworkTopology.from_path(root_path) + + self.node1 = mock.AsyncMock(spec=Node) + self.node2 = mock.AsyncMock(spec=Node) + self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port) + self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port) + await self.server1.start() + await self.server2.start() + self.discovery1 = ManualDiscovery( + root_path, + "node1", + create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), + ) + self.discovery2 = ManualDiscovery( + root_path, + "node2", + create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), + ) + await self.discovery1.start() + await self.discovery2.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + await self.discovery2.stop() + await self.server1.stop() + await self.server2.stop() + + async def test_grpc_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=1) + assert len(peers1) == 1 + peers2 = await self.discovery2.discover_peers(wait_for_peers=1) + assert len(peers2) == 1 + + # Connect + await peers1[0].connect() + await peers2[0].connect() + self.assertTrue(await peers1[0].is_connected()) + self.assertTrue(await peers2[0].is_connected()) + + # Kill server1 + await self.server1.stop() + + self.assertTrue(await peers1[0].is_connected()) + self.assertFalse(await peers2[0].is_connected()) + + # Kill server2 + await self.server2.stop() + + self.assertFalse(await peers1[0].is_connected()) + self.assertFalse(await peers2[0].is_connected()) + + async def test_dynamic_config_update(self): + initial_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(initial_peers), 1) + + # Save original config for cleanup + with open(root_path, "r") as f: + original_config = json.load(f) + + try: + updated_config = { + "peers": { + **original_config["peers"], + "node3": { + "address": "localhost", + "port": 50053, + "device_capabilities": { + "model": "Unknown Model", + "chip": "Unknown Chip", + "memory": 0, + "flops": {"fp32": 0, "fp16": 0, "int8": 0}, + }, + }, + } + } + + with open(root_path, "w") as f: + json.dump(updated_config, f, indent=2) + + node3 = mock.AsyncMock(spec=Node) + server3 = GRPCServer(node3, "localhost", 50053) + await server3.start() + + try: + # Wait for the config to be reloaded await asyncio.sleep(1.5) - updated_peers = await self.discovery1.discover_peers(wait_for_peers=1) - self.assertEqual(len(updated_peers), 1) + updated_peers = await self.discovery1.discover_peers(wait_for_peers=2) + self.assertEqual(len(updated_peers), 2) + + for peer in updated_peers: + await peer.connect() + self.assertTrue(await peer.is_connected()) + + finally: + await server3.stop() + + finally: + # Restore the original config file + with open(root_path, "w") as f: + json.dump(original_config, f, indent=2) + + # Wait for the config to be reloaded again + await asyncio.sleep(1.5) + + updated_peers = await self.discovery1.discover_peers(wait_for_peers=1) + self.assertEqual(len(updated_peers), 1) if __name__ == "__main__": - asyncio.run(unittest.main()) + asyncio.run(unittest.main())