Parallelise Model Loading #466

Open · wants to merge 5 commits into base: main
5 changes: 4 additions & 1 deletion exo/inference/dummy_inference_engine.py
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, TYPE_CHECKING
rom typing import Optional, Tuple, TYPE_CHECKING
Contributor: typo?

import numpy as np
import random
import string
@@ -41,3 +41,6 @@ async def ensure_shard(self, shard: Shard):
await asyncio.sleep(0.1) # Simulate a short delay
self.shard = shard
print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

async def preload_model(self, shard: Shard) -> None:
await self.ensure_shard(shard)
4 changes: 4 additions & 0 deletions exo/inference/inference_engine.py
@@ -29,6 +29,10 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.n
output_data = await self.infer_tensor(request_id, shard, tokens)
return output_data

@abstractmethod
async def preload_model(self, shard: Shard) -> None:
pass
vovw marked this conversation as resolved.

inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
"tinygrad": "TinygradDynamicShardInferenceEngine",
3 changes: 3 additions & 0 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -73,3 +73,6 @@ def load_shard_wrapper():
model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
self.shard = shard
self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)

async def preload_model(self, shard: Shard) -> None:
await self.ensure_shard(shard)
3 changes: 3 additions & 0 deletions exo/inference/tinygrad/inference.py
@@ -98,3 +98,6 @@ async def ensure_shard(self, shard: Shard):
self.tokenizer = await resolve_tokenizer(tokenizer_path)
self.shard = shard
self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)

async def preload_model(self, shard: Shard) -> None:
await self.ensure_shard(shard)
36 changes: 34 additions & 2 deletions exo/main.py
@@ -56,6 +56,7 @@
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--preload-models", type=str, help="Comma-separated list of models to preload")
vovw marked this conversation as resolved.
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")

@@ -133,17 +134,21 @@
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)

def preemptively_start_download(request_id: str, opaque_status: str):

async def preemptively_start_download(request_id: str, opaque_status: str):
Contributor: I don't think this should be made async. Creating tasks fire-and-forget style, like it was before, is better; if you need to run things in sequence, create a task for an asynchronous function that awaits each step.

try:
status = json.loads(opaque_status)
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
await shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__)
await node.preload_models([current_shard])
return current_shard
Contributor: Why do we return this?

except Exception as e:
if DEBUG >= 2:
print(f"Failed to preemptively start download: {e}")
traceback.print_exc()
return None


node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
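
A minimal sketch of the fire-and-forget pattern the reviewer describes, reusing the shard_downloader, node, and inference_engine names from the surrounding diff (the helper name _download_then_preload is hypothetical, not part of the PR):

def preemptively_start_download(request_id: str, opaque_status: str):
    async def _download_then_preload(current_shard):
        # The sequential steps live in one coroutine scheduled as a single task,
        # so the callback itself stays synchronous and non-blocking.
        await shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__)
        await node.preload_models([current_shard])

    try:
        status = json.loads(opaque_status)
        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
            asyncio.create_task(_download_then_preload(current_shard))  # fire and forget
    except Exception as e:
        if DEBUG >= 2:
            print(f"Failed to preemptively start download: {e}")
            traceback.print_exc()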
@@ -225,6 +230,33 @@ def handle_exit():

await node.start(wait_for_peers=args.wait_for_peers)

# Preload models if specified
if args.preload_models:
models_to_preload = [model.strip() for model in args.preload_models.split(",")]
if DEBUG >= 2:
print(f"Preloading models: {models_to_preload}")
Contributor (on lines +236 to +237): one line


inference_class = inference_engine.__class__.__name__
shards = [
shard for model in models_to_preload
if (shard := build_base_shard(model, inference_class)) is not None
]
Contributor (on lines +240 to +243): but these on one line please


if len(shards) < len(models_to_preload):
unsupported = [
model for model in models_to_preload
if not build_base_shard(model, inference_class)
]
Contributor (on lines +246 to +249): one line

print(f"Warning: Unsupported model(s) for {inference_class}: {', '.join(unsupported)}")

try:
await node.preload_models(shards)
print(f"Successfully preloaded {len(shards)} model(s)")
except Exception as e:
print(f"Error preloading models: {str(e)}")
if DEBUG >= 1:
traceback.print_exc()

if args.command == "run" or args.run_model:
model_name = args.model_name or args.run_model
if not model_name:
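
A sketch of how the preload block could collapse onto single lines as the reviewers ask, keeping the same names and behaviour as the diff above (not the merged code):

models_to_preload = [model.strip() for model in args.preload_models.split(",")]
if DEBUG >= 2: print(f"Preloading models: {models_to_preload}")
inference_class = inference_engine.__class__.__name__
shards = [shard for model in models_to_preload if (shard := build_base_shard(model, inference_class)) is not None]
unsupported = [model for model in models_to_preload if not build_base_shard(model, inference_class)]
if unsupported: print(f"Warning: Unsupported model(s) for {inference_class}: {', '.join(unsupported)}")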
4 changes: 4 additions & 0 deletions exo/orchestration/node.py
@@ -45,3 +45,7 @@ def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
@abstractmethod
def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
pass

@abstractmethod
async def preload_models(self, shards: List[Shard]) -> None:
pass
24 changes: 14 additions & 10 deletions exo/orchestration/standard_node.py
@@ -18,6 +18,7 @@
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader


class StandardNode(Node):
def __init__(
self,
@@ -101,11 +102,11 @@ async def broadcast_supported_engines(self, supported_engines_names: List[str]):

def get_topology_inference_engines(self) -> List[List[str]]:
return self.topology_inference_engines_pool

async def encode_prompt(self, shard: Shard, prompt):
toks = await self.inference_engine.encode(shard, prompt)
return toks

async def process_result(
self,
shard,
@@ -114,15 +115,15 @@ async def process_result(
):
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)

if request_id not in self.buffered_logits:
self.buffered_logits[request_id] = []

self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]

if shard.is_last_layer():
result = await self.inference_engine.sample(result)

await self.inference_engine.ensure_shard(shard)
is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens

@@ -131,7 +132,7 @@
if result.size == 1: # we got a new token out
self.buffered_token_output[request_id][0].append(result.item())
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)

if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")

if is_finished:
@@ -196,7 +197,7 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Opti
return None
else:
result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
ret = await self.process_result(shard, result, request_id)
ret = await self.process_result(shard, result, request_id)
return result

async def process_tensor(
@@ -255,7 +256,7 @@ async def _process_tensor(
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
try:
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
ret = await self.process_result(shard, result, request_id)
ret = await self.process_result(shard, result, request_id)
return ret
except Exception as e:
print(f"Error processing tensor for shard {shard}: {e}")
@@ -272,7 +273,7 @@ async def forward_to_next_shard(
if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
return

next_partition_index = self.get_partition_index(offset = 1)
next_partition_index = self.get_partition_index(offset=1)
if DEBUG >= 1: print(f"Next partition index: {next_partition_index}")
if next_partition_index is not None:
target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
@@ -299,7 +300,7 @@ def get_partition_index(self, offset: int = 0):
current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
if current_partition_index is None:
raise ValueError(f"No current partition found for node: {self.id}")
return (current_partition_index + offset) % len(partitions)
return (current_partition_index+offset) % len(partitions)

def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
if index is None:
@@ -429,7 +430,7 @@ def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
self.on_token.trigger_all(request_id, tokens, is_finished)

async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
async def send_result_to_peer(peer):
try:
@@ -461,3 +462,6 @@ async def send_status_to_peer(peer):
@property
def current_topology(self) -> Topology:
return self.topology

async def preload_models(self, shards: List[Shard]) -> None:
await asyncio.gather(*(asyncio.create_task(self.inference_engine.preload_model(shard)) for shard in shards))
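
One note on this last hunk: asyncio.gather already wraps bare coroutines into tasks, so the explicit asyncio.create_task call adds nothing. A slightly simpler, equivalent sketch:

async def preload_models(self, shards: List[Shard]) -> None:
    # gather schedules each preload coroutine concurrently and waits for all of them
    await asyncio.gather(*(self.inference_engine.preload_model(shard) for shard in shards))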