Fixes for cross-platform operability #607

Open · wants to merge 5 commits into main
275 changes: 121 additions & 154 deletions exo/api/chatgpt_api.py

Large diffs are not rendered by default.

42 changes: 22 additions & 20 deletions exo/helpers.py
@@ -7,7 +7,8 @@
import platform
import psutil
import uuid
import netifaces
from scapy.all import get_if_addr, get_if_list
import re
import subprocess
from pathlib import Path
import tempfile
@@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
def get_all_ip_addresses_and_interfaces():
try:
ip_addresses = []
for interface in netifaces.interfaces():
ifaddresses = netifaces.ifaddresses(interface)
if netifaces.AF_INET in ifaddresses:
for link in ifaddresses[netifaces.AF_INET]:
ip = link['addr']
ip_addresses.append((ip, interface))
for interface in get_if_list():
ip = get_if_addr(interface)
# Include all addresses, including loopback
# Filter out link-local addresses
if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
# Remove "\\Device\\NPF_" prefix from interface name
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
ip_addresses.append((ip, simplified_interface))
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]


async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
try:
# Use the shared subprocess_pool
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
['system_profiler', 'SPNetworkDataType', '-json'],
capture_output=True,
text=True,
close_fds=True
).stdout)
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
)

data = json.loads(output)

@@ -276,15 +277,15 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:

return None


async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# On macOS, try to get interface type using networksetup
if psutil.MACOS:
macos_type = await get_macos_interface_type(ifname)
if macos_type is not None: return macos_type

# Local container/virtual interfaces
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
'bridge' in ifname):
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
return (7, "Container Virtual")

# Loopback interface
@@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# Other physical interfaces
return (2, "Other")


async def shutdown(signal, loop, server):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -329,16 +331,16 @@ def is_frozen():


def get_exo_home() -> Path:
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
else: docs_folder = Path.home() / "Documents"
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
else: docs_folder = Path.home()/"Documents"
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
exo_folder = docs_folder / "Exo"
exo_folder = docs_folder/"Exo"
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
return exo_folder


def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home / "Images"
images_dir = exo_home/"Images"
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
return images_dir
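
A trimmed, standalone sketch of the scapy-based enumeration this file switches to (assuming scapy is installed; the \Device\NPF_ prefix only appears on Windows, where Npcap names interfaces by GUID):

import re
from scapy.all import get_if_addr, get_if_list

def list_usable_ipv4_interfaces():
    results = set()
    for interface in get_if_list():
        ip = get_if_addr(interface)  # scapy reports one IPv4 address per interface
        # Skip link-local (169.254.x.x) and unassigned (0.0.x.x) addresses.
        if ip.startswith("169.254.") or ip.startswith("0.0."):
            continue
        # Windows/Npcap names look like \Device\NPF_{GUID}; keep only the GUID.
        results.add((ip, re.sub(r"^\\Device\\NPF_", "", interface)))
    return sorted(results)

if __name__ == "__main__":
    for ip, name in list_usable_ipv4_interfaces():
        print(f"{name}: {ip}")

Unlike netifaces, get_if_addr() returns a single IPv4 string per interface, which is why the inner per-address loop from the old code disappears.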

13 changes: 8 additions & 5 deletions exo/inference/inference_engine.py
@@ -5,6 +5,7 @@
from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
from exo.download.shard_download import ShardDownloader


class InferenceEngine(ABC):
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
@abstractmethod
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass

@abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray:
pass
@@ -32,13 +33,13 @@ async def load_checkpoint(self, shard: Shard, path: str):

async def save_checkpoint(self, shard: Shard, path: str):
pass

async def save_session(self, key, value):
self.session[key] = value

async def clear_session(self):
self.session.empty()

async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
tokens = await self.encode(shard, prompt)
if shard.model_id != 'stable-diffusion-2-1-base':
@@ -49,13 +50,15 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inferen

return output_data, inference_state


inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
"tinygrad": "TinygradDynamicShardInferenceEngine",
"dummy": "DummyInferenceEngine",
}

def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):

def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
if inference_engine_name == "mlx":
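
The visible part of get_inference_engine dispatches on the engine name against the mapping above. A hypothetical, self-contained sketch of that registry pattern (DummyEngine and make_engine here are illustrative stand-ins, not the project's real classes):

class DummyEngine:
    def __init__(self, shard_downloader):
        self.shard_downloader = shard_downloader

# The real mapping lists the MLX, tinygrad and dummy engine class names.
ENGINE_REGISTRY = {"dummy": DummyEngine}

def make_engine(name: str, shard_downloader):
    if name not in ENGINE_REGISTRY:
        raise ValueError(f"Unsupported inference engine: {name}")
    return ENGINE_REGISTRY[name](shard_downloader)

engine = make_engine("dummy", shard_downloader=None)
print(type(engine).__name__)  # -> DummyEngine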
61 changes: 27 additions & 34 deletions exo/networking/grpc/grpc_peer_handle.py
@@ -12,7 +12,13 @@
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG
import json
import mlx.core as mx
import platform

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx


class GRPCPeerHandle(PeerHandle):
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -37,11 +43,9 @@ def device_capabilities(self) -> DeviceCapabilities:

async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(self.address, options=[
("grpc.max_metadata_size", 32*1024*1024),
('grpc.max_receive_message_length', 32*1024*1024),
('grpc.max_send_message_length', 32*1024*1024)
])
self.channel = grpc.aio.insecure_channel(
self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()

@@ -109,7 +113,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
return None

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
@@ -131,7 +135,7 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr
return loss, grads
else:
return loss

async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
@@ -166,10 +170,7 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(
model=capabilities.model,
chip=capabilities.chip,
memory=capabilities.memory,
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
)
topology.update_node(node_id, device_capabilities)
for node_id, peer_connections in response.peer_graph.items():
@@ -193,28 +194,20 @@ def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.I
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state
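
The import shim at the top of this file (mirrored in grpc_server.py below) only pulls in MLX on Apple-silicon macOS and otherwise lets NumPy stand in under the mx alias. A minimal sketch of that gate, with a hypothetical uses_mlx_backend() helper added for clarity:

import platform
import numpy as np

def uses_mlx_backend() -> bool:
    # MLX is only available on Apple-silicon macOS.
    return platform.system().lower() == "darwin" and platform.machine().lower() == "arm64"

if uses_mlx_backend():
    import mlx.core as mx  # real MLX arrays on Apple silicon
else:
    import numpy as mx  # elsewhere, numpy provides mx.array(...) for construction

# Either backend yields something np.array() can consume, which is what the
# serialization path relies on before calling tobytes()/shape/dtype.
x = mx.array([1.0, 2.0, 3.0])
print(type(x), np.array(x).dtype)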
44 changes: 21 additions & 23 deletions exo/networking/grpc/grpc_server.py
@@ -3,13 +3,19 @@
import numpy as np
from asyncio import CancelledError

import platform

from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
import json
import mlx.core as mx

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx


class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -74,7 +80,7 @@ async def SendTensor(self, request, context):
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

async def SendExample(self, request, context):
shard = Shard(
model_id=request.shard.model_id,
@@ -96,7 +102,7 @@ async def SendExample(self, request, context):
else:
loss = await self.node.process_example(shard, example, target, length, train, request_id)
return node_service_pb2.Loss(loss=loss, grads=None)

async def CollectTopology(self, request, context):
max_depth = request.max_depth
visited = set(request.visited)
@@ -112,12 +118,7 @@ async def CollectTopology(self, request, context):
for node_id, cap in topology.nodes.items()
}
peer_graph = {
node_id: node_service_pb2.PeerConnections(
connections=[
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
for conn in connections
]
)
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
for node_id, connections in topology.peer_graph.items()
}
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -131,7 +132,7 @@ async def SendResult(self, request, context):
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()

@@ -145,21 +146,18 @@ async def SendOpaqueStatus(self, request, context):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)

def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}

for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)

for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
for tensor in tensor_list.tensors
]

inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]

if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)

return inference_state
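
A minimal round-trip sketch of the wire format serialize_inference_state and deserialize_inference_state share (raw bytes plus shape and dtype), using a plain dict as a stand-in for the node_service_pb2.Tensor message:

import numpy as np

def to_wire(arr: np.ndarray) -> dict:
    # Same three fields the proto Tensor carries: bytes, shape, dtype string.
    return {"tensor_data": arr.tobytes(), "shape": list(arr.shape), "dtype": str(arr.dtype)}

def from_wire(msg: dict) -> np.ndarray:
    # Mirrors deserialize_inference_state: frombuffer + reshape restores the array.
    return np.frombuffer(msg["tensor_data"], dtype=msg["dtype"]).reshape(msg["shape"])

original = np.arange(6, dtype=np.float32).reshape(2, 3)
restored = from_wire(to_wire(original))
assert np.array_equal(original, restored)
print(restored.shape, restored.dtype)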