
Parallelise Model Loading #466

Open · wants to merge 5 commits into main
Conversation

@vovw commented Nov 16, 2024

Adds parallel model preloading capability to improve inference startup time.

Adds model preloading functionality to improve initial inference latency by allowing models to be loaded into memory before they're needed.

Key changes:

  • Added --preload-models CLI arg to specify models for preloading
  • Introduced preload_model method in inference engine interface
  • Implemented preloading in MLX engine using existing shard loading
  • Enhanced preemptive download to also preload models after download
  • Added concurrent model preloading support in StandardNode

The primary motivation is to reduce cold-start latency by loading models before they are needed, which is useful for deployments that require predictable latency.

Tested with the MLX engine and verified to work with preemptive downloads. Built on the existing shard infrastructure and maintains backward compatibility.
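To make the batch preloading concrete, here is a minimal sketch of how the concurrent part could be driven. The helper name and error handling are assumptions for illustration, not the PR's exact code; only the preload_model method name comes from the changes described above.

import asyncio

# Sketch (assumed names/signatures): load several shards concurrently so the
# first inference call finds them already in memory.
async def preload_all(engine, shards):
    results = await asyncio.gather(
        *(engine.preload_model(shard) for shard in shards),
        return_exceptions=True,  # one failed model should not abort the whole batch
    )
    for shard, result in zip(shards, results):
        if isinstance(result, Exception):
            print(f"Failed to preload {shard}: {result}")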

Test using:

exo --preload-models model1,model2
exo --preload-models llama-3.2-1b,llama-3.1-8b
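For illustration, a comma-separated flag like this could be parsed roughly as follows; the flag name matches the PR, but the parsing details are an assumption, not necessarily what exo/main.py does.

import argparse

# Sketch only: the actual handling in exo/main.py may differ.
parser = argparse.ArgumentParser()
parser.add_argument("--preload-models", type=str, default=None,
                    help="Comma-separated list of models to preload at startup")
args = parser.parse_args(["--preload-models", "llama-3.2-1b,llama-3.1-8b"])
models_to_preload = [m.strip() for m in (args.preload_models or "").split(",") if m.strip()]
print(models_to_preload)  # ['llama-3.2-1b', 'llama-3.1-8b']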

Previous PR: #360

vovw commented Nov 16, 2024

@AlexCheema PTAL

lmk if u need more changes.

vovw commented Nov 19, 2024

@AlexCheema PTAL

- Wrapped debug print statements with `DEBUG >= 2` condition for better logging control.
- Consolidated shard building and preloading into a single operation using list comprehensions.
- Improved error handling to cover all models in a batch, reducing redundancy.
- Added clearer messaging for unsupported models.
- Simplified code structure for better readability and performance.
vovw commented Nov 20, 2024

@AlexCheema, should I run the formatter over the whole codebase, or just the files I edited?

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, TYPE_CHECKING
+rom typing import Optional, Tuple, TYPE_CHECKING
Contributor:
typo?

@@ -133,17 +134,21 @@
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)

-def preemptively_start_download(request_id: str, opaque_status: str):
+async def preemptively_start_download(request_id: str, opaque_status: str):
Contributor:
I don't think this should be made async.
creating tasks fire-and-forget-style like it was before is better. if you need to run stuff in sequence, create a task for an asynchronous function that awaits each step.
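A rough sketch of the pattern the reviewer is describing (not the PR's code; shard_downloader, inference_engine, node, and current_shard are taken from the surrounding diff context and assumed to be in scope):

import asyncio

def preemptively_start_download(request_id: str, opaque_status: str):
    # ... existing logic that determines current_shard ...

    async def download_then_preload():
        # run the steps in sequence inside the task
        await shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__)
        await node.preload_models([current_shard])

    # fire-and-forget: the callback itself stays synchronous and returns immediately
    asyncio.create_task(download_then_preload())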

-asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
+await shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__)
+await node.preload_models([current_shard])
+return current_shard
Contributor:
why do we return this?

Comment on lines +240 to +243
shards = [
  shard for model in models_to_preload
  if (shard := build_base_shard(model, inference_class)) is not None
]
Contributor:
put these on one line please

Comment on lines +246 to +249
unsupported = [
  model for model in models_to_preload
  if not build_base_shard(model, inference_class)
]
Contributor:
one line

Comment on lines +236 to +237
if DEBUG >= 2:
  print(f"Preloading models: {models_to_preload}")
Contributor:
one line
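For reference, the one-line form the reviewer is asking for could look like this (same identifiers as the snippets above; a style sketch, not code from the PR):

if DEBUG >= 2: print(f"Preloading models: {models_to_preload}")
shards = [shard for model in models_to_preload if (shard := build_base_shard(model, inference_class)) is not None]
unsupported = [model for model in models_to_preload if not build_base_shard(model, inference_class)]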

@AlexCheema (Contributor) commented:

Please respond to my review @vovw

vovw commented Nov 28, 2024

> Please respond to my review @vovw

Flooded with college work rn, will address these tomorrow.
