Parallelise Model Loading #466
base: main
Changes from all commits
10b4a47
7b9e3e4
8331d92
33c3580
d0fcab2
```diff
@@ -56,6 +56,7 @@
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
+parser.add_argument("--preload-models", type=str, help="Comma-separated list of models to preload")

 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
```

vovw marked this conversation as resolved.
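For context, a minimal sketch of how the new flag's value would be parsed, assuming hypothetical model names (the comma-splitting and stripping mirror the handling added later in this PR):

```python
import argparse

# Minimal parser containing only the new flag, mirroring the argparse line added above.
parser = argparse.ArgumentParser()
parser.add_argument("--preload-models", type=str, help="Comma-separated list of models to preload")

# Hypothetical invocation; "llama-3.2-1b" and "llama-3.2-3b" are placeholder model names.
args = parser.parse_args(["--preload-models", "llama-3.2-1b, llama-3.2-3b"])

# Whitespace around commas is stripped, matching the preload block added later in this PR.
models_to_preload = [model.strip() for model in args.preload_models.split(",")]
print(models_to_preload)  # ['llama-3.2-1b', 'llama-3.2-3b']
```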
```diff
@@ -133,17 +134,21 @@
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )

-def preemptively_start_download(request_id: str, opaque_status: str):
+async def preemptively_start_download(request_id: str, opaque_status: str):
```
Review comment: I don't think this should be made async.
```diff
   try:
     status = json.loads(opaque_status)
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
+      await shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__)
+      await node.preload_models([current_shard])
+      return current_shard
```
Review comment: why do we return this?
```diff
   except Exception as e:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
   return None

 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
```
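The review comment above concerns the callback becoming a coroutine. A minimal sketch (not from the PR) of the behavioural difference: `asyncio.create_task` schedules the download and lets the handler return immediately, while `await` keeps the handler running until the download finishes, so its caller must be able to await it. The name `fake_ensure_shard` is a hypothetical stand-in:

```python
import asyncio

async def fake_ensure_shard(name: str) -> None:
  # Hypothetical stand-in for shard_downloader.ensure_shard().
  await asyncio.sleep(0.1)
  print(f"downloaded {name}")

async def fire_and_forget_handler() -> None:
  # Behaviour on main: schedule the download and return immediately.
  asyncio.create_task(fake_ensure_shard("shard-a"))
  print("handler returned before download finished")

async def awaiting_handler() -> None:
  # Behaviour in this PR: the handler does not return until the download completes,
  # which is why it has to become a coroutine and be awaited by its caller.
  await fake_ensure_shard("shard-b")
  print("handler returned after download finished")

async def main() -> None:
  await fire_and_forget_handler()
  await asyncio.sleep(0.2)  # let the background task finish before moving on
  await awaiting_handler()

asyncio.run(main())
```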
```diff
@@ -225,6 +230,33 @@ def handle_exit():

   await node.start(wait_for_peers=args.wait_for_peers)

+  # Preload models if specified
+  if args.preload_models:
+    models_to_preload = [model.strip() for model in args.preload_models.split(",")]
+    if DEBUG >= 2:
+      print(f"Preloading models: {models_to_preload}")
```
Review comment on lines +236 to +237: one line
```diff
+
+    inference_class = inference_engine.__class__.__name__
+    shards = [
+      shard for model in models_to_preload
+      if (shard := build_base_shard(model, inference_class)) is not None
+    ]
```
Review comment on lines +240 to +243: but these on one line please
```diff
+
+    if len(shards) < len(models_to_preload):
+      unsupported = [
+        model for model in models_to_preload
+        if not build_base_shard(model, inference_class)
+      ]
```
Review comment on lines +246 to +249: one line
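A sketch of how the two comprehensions could be collapsed onto single lines, as the "one line" comments above request; the stub `build_base_shard` and the placeholder values are hypothetical, only the filtering logic is taken from the diff:

```python
# Hypothetical stub standing in for exo's build_base_shard(); returns None for unknown models.
def build_base_shard(model: str, inference_class: str):
  return {"model_id": model} if model.startswith("llama") else None

inference_class = "ExampleEngine"  # placeholder value
models_to_preload = ["llama-3.2-1b", "not-a-model"]

# The same two comprehensions as in the diff, each on a single line.
shards = [shard for model in models_to_preload if (shard := build_base_shard(model, inference_class)) is not None]
unsupported = [model for model in models_to_preload if not build_base_shard(model, inference_class)]

print(shards)       # [{'model_id': 'llama-3.2-1b'}]
print(unsupported)  # ['not-a-model']
```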
print(f"Warning: Unsupported model(s) for {inference_class}: {', '.join(unsupported)}") | ||
|
||
try: | ||
await node.preload_models(shards) | ||
print(f"Successfully preloaded {len(shards)} model(s)") | ||
except Exception as e: | ||
print(f"Error preloading models: {str(e)}") | ||
if DEBUG >= 1: | ||
traceback.print_exc() | ||
|
||
if args.command == "run" or args.run_model: | ||
model_name = args.model_name or args.run_model | ||
if not model_name: | ||
|
Review comment: typo?
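The diff does not show `node.preload_models` itself; given the PR title "Parallelise Model Loading", a plausible reading is that it downloads the given shards concurrently rather than one after another. A minimal sketch of that idea under that assumption — the class, method names, and signatures below are placeholders, not the PR's actual implementation:

```python
import asyncio

class _FakeShardDownloader:
  # Hypothetical stand-in for exo's shard downloader.
  async def ensure_shard(self, shard: str, engine_name: str) -> str:
    await asyncio.sleep(0.1)  # simulate a download
    return f"{shard} ready for {engine_name}"

async def preload_models(downloader: _FakeShardDownloader, shards: list[str], engine_name: str) -> list[str]:
  # Parallelise model loading: start every download at once and wait for all of them,
  # instead of awaiting each shard sequentially.
  return await asyncio.gather(*(downloader.ensure_shard(s, engine_name) for s in shards))

async def main() -> None:
  results = await preload_models(_FakeShardDownloader(), ["shard-a", "shard-b"], "ExampleEngine")
  print(results)

asyncio.run(main())
```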