Parallelise Model Loading #466
base: main
Conversation
@AlexCheema PTAL, lmk if you need more changes.
@AlexCheema PTAL
- Wrapped debug print statements with a `DEBUG >= 2` condition for better logging control.
- Consolidated shard building and preloading into a single operation using list comprehensions.
- Improved error handling to cover all models in a batch, reducing redundancy.
- Added clearer messaging for unsupported models.
- Simplified code structure for better readability and performance.
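Pieced together from the diff hunks later in this review, the consolidated preloading path described above could look roughly like the following sketch. The function signature, the `node.preload_models` call, the error handling, and the import paths are assumptions rather than code from this PR:

```python
# Assumed import paths; DEBUG and build_base_shard appear in the diffs below.
from exo.helpers import DEBUG
from exo.models import build_base_shard


async def preload_models(models_to_preload, inference_class, node):
  # Build all preloadable shards in one pass; models with no matching shard
  # are collected separately so they can be reported as unsupported.
  shards = [
    shard for model in models_to_preload
    if (shard := build_base_shard(model, inference_class)) is not None
  ]
  unsupported = [
    model for model in models_to_preload
    if not build_base_shard(model, inference_class)
  ]

  if unsupported:
    print(f"Unsupported models for {inference_class}: {unsupported}")
  if DEBUG >= 2:
    print(f"Preloading models: {models_to_preload}")

  try:
    # One error handler covers the whole batch instead of a try/except per model.
    await node.preload_models(shards)
  except Exception as e:
    print(f"Error preloading models {models_to_preload}: {e}")
```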
@AlexCheema, should I run the formatter over the whole codebase, or just the files I edited?
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, TYPE_CHECKING
+rom typing import Optional, Tuple, TYPE_CHECKING
typo?
@@ -133,17 +134,21 @@
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )

-def preemptively_start_download(request_id: str, opaque_status: str):
+async def preemptively_start_download(request_id: str, opaque_status: str):
I don't think this should be made async.
Creating tasks fire-and-forget style, like it was before, is better. If you need to run steps in sequence, create a task for an asynchronous function that awaits each step.
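A minimal sketch of that suggestion: keep the callback synchronous and move the sequential awaits into a single fired-off task. The `_download_then_preload` helper and the stand-in coroutines below are illustrative only, not code from this PR:

```python
import asyncio


# Illustrative stand-ins for the real shard downloader and node;
# only the control flow matters here.
async def ensure_shard(shard: str) -> None:
  await asyncio.sleep(0)  # pretend to download the shard


async def preload_models(shards: list[str]) -> None:
  await asyncio.sleep(0)  # pretend to load the shards into memory


def preemptively_start_download(shard: str) -> None:
  """Stays synchronous; kicks off the work without blocking the caller."""
  async def _download_then_preload() -> None:
    # Steps that must happen in order are awaited inside the task.
    await ensure_shard(shard)
    await preload_models([shard])

  # Fire-and-forget: the surrounding callback returns immediately.
  asyncio.create_task(_download_then_preload())
```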
-  asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
+  await shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__)
+  await node.preload_models([current_shard])
+  return current_shard
why do we return this?
shards = [
  shard for model in models_to_preload
  if (shard := build_base_shard(model, inference_class)) is not None
]
but put these on one line please
unsupported = [
  model for model in models_to_preload
  if not build_base_shard(model, inference_class)
]
one line
if DEBUG >= 2:
  print(f"Preloading models: {models_to_preload}")
one line
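For reference, a sketch of the one-line formatting the reviewer is asking for across these three snippets (same logic as the diff hunks above, just collapsed):

```python
shards = [shard for model in models_to_preload if (shard := build_base_shard(model, inference_class)) is not None]
unsupported = [model for model in models_to_preload if not build_base_shard(model, inference_class)]
if DEBUG >= 2: print(f"Preloading models: {models_to_preload}")
```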
Please respond to my review @vovw
flooded with college work rn, will address these tomorrow
Add parallelised model preloading capability for improved inference startup time
Adds model preloading functionality to improve initial inference latency by allowing models to be loaded into memory before they're needed.
Key changes:
- `--preload-models` CLI arg to specify models for preloading
- `preload_model` method in the inference engine interface

Primary motivation is reducing cold-start latency by preloading models before they're needed, useful for deployments requiring predictable latency.
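As a sketch of the interface addition described above (the exact signature, the `Shard` import path, and the docstring are assumptions, not verified against this PR):

```python
from abc import ABC, abstractmethod

from exo.inference.shard import Shard  # assumed import path


class InferenceEngine(ABC):
  @abstractmethod
  async def preload_model(self, shard: Shard) -> None:
    """Load the weights for `shard` into memory ahead of the first request."""
    ...
```

Invocation might then look like `exo --preload-models llama-3.2-1b`, though the model name and whether the flag takes a single model or a comma-separated list are assumptions.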
Tested with the MLX engine and verified working with preemptive downloads. Built on the existing shard infrastructure; maintains backward compatibility.
Test using:
Prev PR: #360