Merge pull request #602 from exo-explore/fixexodir
fix exo folder
AlexCheema authored Jan 12, 2025
2 parents b5cbcbc + fcc699a commit c260689
Showing 12 changed files with 37 additions and 46 deletions.
13 changes: 5 additions & 8 deletions exo/helpers.py
@@ -329,19 +329,16 @@ def is_frozen():


 def get_exo_home() -> Path:
-  if os.name == "nt":  # Check if the OS is Windows
-    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
-  else:
-    docs_folder = Path.home() / "Documents"
+  if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
+  else: docs_folder = Path.home() / "Documents"
+  if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
   exo_folder = docs_folder / "Exo"
-  if not exo_folder.exists():
-    exo_folder.mkdir()
+  if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
   return exo_folder

 def get_exo_images_dir() -> Path:
   exo_home = get_exo_home()
   images_dir = exo_home / "Images"
-  if not images_dir.exists():
-    images_dir.mkdir()
+  if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
   return images_dir
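
The folder-creation pattern above can be shown in isolation. A minimal standalone sketch (make_app_dir is a hypothetical helper, not the exo code; it assumes the same Documents locations the diff uses and checks os.name instead of psutil so it has no third-party dependency):

import os
from pathlib import Path

def make_app_dir(app_name: str = "Exo") -> Path:
    # Same assumption as the diff: user documents live under
    # %USERPROFILE%\Documents on Windows and ~/Documents elsewhere.
    if os.name == "nt":
        docs = Path(os.environ["USERPROFILE"]) / "Documents"
    else:
        docs = Path.home() / "Documents"
    # Create-if-missing with exist_ok=True, so a folder created by another
    # process (or an earlier run) does not raise FileExistsError.
    if not docs.exists(): docs.mkdir(exist_ok=True)
    app_dir = docs / app_name
    if not app_dir.exists(): app_dir.mkdir(exist_ok=True)
    return app_dir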

10 changes: 5 additions & 5 deletions exo/inference/debug_inference_engine.py
@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)

-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=token_full,
   )

-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
   )
   token2 = await inference_engine_2.sample(resp2)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
     input_data=token2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
4 changes: 2 additions & 2 deletions exo/inference/dummy_inference_engine.py
@@ -25,9 +25,9 @@ async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) ->
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)

-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
-    return input_data + 1 if self.shard.is_last_layer() else input_data
+    return input_data + 1 if self.shard.is_last_layer() else input_data, None

   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: return
4 changes: 2 additions & 2 deletions exo/inference/inference_engine.py
@@ -23,7 +23,7 @@ async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass

   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     pass

   @abstractmethod
@@ -39,7 +39,7 @@ async def save_session(self, key, value):
   async def clear_session(self):
     self.session.empty()

-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     tokens = await self.encode(shard, prompt)
     if shard.model_id != 'stable-diffusion-2-1-base':
       x = tokens.reshape(1, -1)
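
To make the new contract concrete: infer_prompt and infer_tensor now return an (output, inference_state) pair rather than a bare array, so every caller unpacks a tuple even when it discards the state. A minimal sketch under that assumption (EchoEngine is a made-up class, not one of exo's engines):

from typing import Optional
import numpy as np

class EchoEngine:
    async def infer_tensor(
        self, request_id: str, shard, input_data: np.ndarray,
        inference_state: Optional[dict] = None,
    ) -> tuple[np.ndarray, Optional[dict]]:
        # A real engine would run the shard's layers here; this sketch just
        # passes the data through and threads the (possibly None) state back.
        return input_data, inference_state

# Caller side, mirroring the tests further down: unpack the pair and discard
# the state when it is not needed.
#   output, _ = await EchoEngine().infer_tensor("req-1", shard, x)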
6 changes: 3 additions & 3 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -77,15 +77,15 @@ async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)

-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
     if self.model.model_type != 'StableDiffusionPipeline':
-      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
     else:
-      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
     output_data = np.array(output_data)
     return output_data, inference_state
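
The **(inference_state or {}) change guards against unpacking None as keyword arguments, which raises TypeError; substituting an empty dict keeps the call valid when no state was passed along. A tiny self-contained illustration (run_model is a placeholder, not the MLX model):

def run_model(x, **kwargs):
    return x, kwargs

inference_state = None
# run_model(1, **inference_state)                   # TypeError: argument after ** must be a mapping
out, kw = run_model(1, **(inference_state or {}))   # safe: unpacks an empty dict instead
assert (out, kw) == (1, {})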
16 changes: 5 additions & 11 deletions exo/inference/test_dummy_inference_engine.py
@@ -1,22 +1,16 @@
 import pytest
-import json
 import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard


-class MockShardDownloader:
-  async def ensure_shard(self, shard):
-    pass
-
-
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()
   test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
   test_prompt = "This is a test prompt"

-  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+  result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)

   print(f"Inference result shape: {result.shape}")

@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
   # Initialize the DummyInferenceEngine
-  engine = DummyInferenceEngine(MockShardDownloader())
+  engine = DummyInferenceEngine()

   # Create a test shard
   shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)

   # Test infer_prompt
-  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+  output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")

   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"

   # Test infer_tensor
   input_tensor = np.array([[1, 2, 3]])
-  output = await engine.infer_tensor("test_id", shard, input_tensor)
+  output, _ = await engine.infer_tensor("test_id", shard, input_tensor)

   assert isinstance(output, np.ndarray), "Output should be a numpy array"
   assert output.ndim == 2, "Output should be 2-dimensional"
12 changes: 6 additions & 6 deletions exo/inference/test_inference_engine.py
@@ -11,30 +11,30 @@
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
   token_full = await inference_engine_1.sample(resp_full)
   token_full = token_full.reshape(1, -1)
-  next_resp_full = await inference_engine_1.infer_tensor(
+  next_resp_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=token_full,
   )

   pp = n_layers // 2
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
+  resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
   )
   tokens2 = await inference_engine_1.sample(resp2)
   tokens2 = tokens2.reshape(1, -1)
-  resp3 = await inference_engine_1.infer_tensor(
+  resp3, _ = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     input_data=tokens2,
   )
-  resp4 = await inference_engine_2.infer_tensor(
+  resp4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,
2 changes: 1 addition & 1 deletion exo/inference/tinygrad/inference.py
@@ -104,7 +104,7 @@ async def save_checkpoint(self, shard: Shard, path: str):
     state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
     safe_save(state_dict, path)

-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)
4 changes: 2 additions & 2 deletions exo/networking/grpc/grpc_peer_handle.py
@@ -82,7 +82,7 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendPrompt(request)

@@ -101,7 +101,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
+      inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendTensor(request)
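
Both call sites now guard against a missing state before serializing, so the request field is simply left unset when there is nothing to send. A rough sketch of that pattern with a stand-in serializer (json-based here purely for illustration; exo's serialize_inference_state differs):

import json
from typing import Optional

def serialize_inference_state(state: dict) -> bytes:   # stand-in, not exo's
    return json.dumps(state).encode()

def inference_state_field(state: Optional[dict]) -> Optional[bytes]:
    # Mirrors "None if inference_state is None else serialize_inference_state(...)"
    return None if state is None else serialize_inference_state(state)

assert inference_state_field(None) is None
assert inference_state_field({"step": 3}) == b'{"step": 3}'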
4 changes: 2 additions & 2 deletions exo/networking/grpc/grpc_server.py
@@ -52,7 +52,7 @@ async def SendPrompt(self, request, context):
     )
     prompt = request.prompt
     request_id = request.request_id
-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
     result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
@@ -68,7 +68,7 @@ async def SendTensor(self, request, context):
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id

-    inference_state = self.deserialize_inference_state(request.inference_state)
+    inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)

     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
6 changes: 3 additions & 3 deletions exo/orchestration/node.py
@@ -206,7 +206,7 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Opti
       return None
     else:
       self.outstanding_requests[request_id] = "processing"
-      result,inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
+      result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
       ret = await self.process_inference_result(shard, result, request_id, inference_state)
       return result

@@ -320,7 +320,7 @@ async def _process_example(
         loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
       else:
         self.outstanding_requests[request_id] = "preprocessing"
-        step = await self.inference_engine.infer_tensor(request_id, shard, example)
+        step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
         self.outstanding_requests[request_id] = "waiting"
         loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
         self.outstanding_requests[request_id] = "training"
@@ -336,7 +336,7 @@ async def _process_example(
         loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
       else:
         self.outstanding_requests[request_id] = "preprocessing"
-        step = await self.inference_engine.infer_tensor(request_id, shard, example)
+        step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
         self.outstanding_requests[request_id] = "waiting"
         loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
         self.outstanding_requests.pop(request_id)
2 changes: 1 addition & 1 deletion test/test_tokenizers.py
@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
   strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
   assert text == strip_tokens(decoded) == strip_tokens(reconstructed)

-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = []
 for model_id in model_cards:
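
For reference, the ignore list is compiled into a single anchored regex in which each "*" becomes ".*", so the new entry is skipped by exact prefix match. A trimmed example of how that pattern behaves:

import re

ignore = ["mlx-community/Qwen*", "stabilityai/stable-diffusion-2-1-base"]   # trimmed list for illustration
ignore_pattern = re.compile(r"^(" + "|".join(m.replace("*", ".*") for m in ignore) + r")")

assert ignore_pattern.match("mlx-community/Qwen2.5-7B-Instruct-4bit")       # wildcard entry
assert ignore_pattern.match("stabilityai/stable-diffusion-2-1-base")        # newly ignored model
assert not ignore_pattern.match("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")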
