From 8c191050a2dd20dea129d441c1e7d98e22228609 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 5 Jan 2025 02:31:59 +0000 Subject: [PATCH 1/2] download status in parallel, support async ensure shard with using shard_downloader instead --- exo/api/chatgpt_api.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 73762de71..4daa199db 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -245,7 +245,7 @@ async def handle_model_support(self, request): ) await response.prepare(request) - for model_name, pretty in pretty_name.items(): + async def process_model(model_name, pretty): if model_name in model_cards: model_info = model_cards[model_name] @@ -273,6 +273,12 @@ async def handle_model_support(self, request): await response.write(f"data: {json.dumps(model_data)}\n\n".encode()) + # Process all models in parallel + await asyncio.gather(*[ + process_model(model_name, pretty) + for model_name, pretty in pretty_name.items() + ]) + await response.write(b"data: [DONE]\n\n") return response @@ -562,7 +568,7 @@ async def handle_post_download(self, request): if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400) shard = build_base_shard(model_name, self.inference_engine_classname) if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400) - asyncio.create_task(self.node.inference_engine.ensure_shard(shard)) + asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard)) return web.json_response({ "status": "success", From 29244c6369b4a81c13dd41fab6be20121642613b Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sun, 5 Jan 2025 02:33:25 +0000 Subject: [PATCH 2/2] fix args for ensure_shard --- exo/api/chatgpt_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 4daa199db..ef9634003 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -568,7 +568,7 @@ async def handle_post_download(self, request): if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400) shard = build_base_shard(model_name, self.inference_engine_classname) if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400) - asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard)) + asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname)) return web.json_response({ "status": "success",