Support local model with inference-engine mlx #475

Open
wants to merge 35 commits into main
35 commits
89665df
add inference mlx run local model in single node
OKHand-Zy Nov 8, 2024
b2bcc12
Merge exo f1eec9f commit version
OKHand-Zy Nov 15, 2024
0b87eb9
filter merge error
OKHand-Zy Nov 15, 2024
e2f0723
filter f1eec9f model change
OKHand-Zy Nov 15, 2024
2a2e3b2
future: (i-e: mlx) support local model and HF model, terminal complete
OKHand-Zy Nov 18, 2024
d8bbb2b
future: support cli and chatapi local model, complete
OKHand-Zy Nov 19, 2024
9ab3513
filter run_model_cli
OKHand-Zy Nov 20, 2024
a9c345a
add init_exo_function (helpers.py)
OKHand-Zy Nov 20, 2024
1ae2648
Merge branch 'main' 1fa42f3 into support-local-model
OKHand-Zy Nov 20, 2024
cdba915
filter read me
OKHand-Zy Nov 20, 2024
0b06fe1
filter cli local model error
OKHand-Zy Nov 20, 2024
438eae4
filter some mark
OKHand-Zy Nov 20, 2024
32574b9
Merge exo 93d38e2 commits
OKHand-Zy Nov 22, 2024
535cb44
filter name miss
OKHand-Zy Nov 22, 2024
1065242
Merge exo commit 7013041 fix 'CLI reply error'
OKHand-Zy Nov 22, 2024
1e5429d
remove: add local model arg
OKHand-Zy Nov 28, 2024
880682a
filter: other dir no config error
OKHand-Zy Dec 6, 2024
c4e18ae
merge: new image
OKHand-Zy Dec 6, 2024
daeb1eb
Merge exo 17411df branch
OKHand-Zy Dec 6, 2024
dd49f08
Merge exo d4cc2cf commit
OKHand-Zy Dec 12, 2024
cb270ed
filter: chat-api modelpool for Local model
OKHand-Zy Dec 12, 2024
316e58b
add: Local Model Api
OKHand-Zy Dec 12, 2024
961c0dc
filter: file&code to other place
OKHand-Zy Dec 12, 2024
85d41f0
future: enhance lh_helpers but download_file function not complete
OKHand-Zy Dec 16, 2024
aca7eba
filter: add args setting store model ip and port
OKHand-Zy Dec 17, 2024
ef60ff3
future: add other node update local model to model_cards
OKHand-Zy Dec 18, 2024
e96f5ae
Merge exo cfedcec commit
OKHand-Zy Dec 19, 2024
22b4484
filter: _resolve_tokenizer find error
OKHand-Zy Dec 19, 2024
8c3a600
future: local mode auto download prototype
OKHand-Zy Dec 19, 2024
f0e1877
filter:download file path
OKHand-Zy Dec 20, 2024
5eef335
filter: stored_model dir path and http class name
OKHand-Zy Dec 20, 2024
c557fe8
feat:chatAPI can run Local model on multi node
OKHand-Zy Dec 24, 2024
8e05ad7
Merge branch 'auto dowload-local-model' into support-local-model
OKHand-Zy Dec 24, 2024
322b36c
docs: Update the README.md to add instructions for using local models…
OKHand-Zy Dec 24, 2024
677260e
Merge exo main branch fdc3b5a commit
OKHand-Zy Dec 24, 2024
52 changes: 52 additions & 0 deletions README.md
@@ -208,12 +208,64 @@ With a custom prompt:
exo run llama-3.2-3b --prompt "What is the meaning of exo?"
```

### Example Usage on a Single Device with a Local Model
- Step 1. Download the model from Hugging Face.
- Step 2. Put it in the exo folder (`~/.cache/exo`).
- Step 3. Run it, referring to the model either by its full path or by the `Local/` prefix:
```shell
exo --inference-engine mlx --run-model <~/.cache/exo/'model_name'>
# or
exo --inference-engine mlx --run-model <Local/'model_name'>
```

### Example Usage on Multiple Devices with a Local Model
#### Manually Place the Local Model:
- Step 1. Download the Hugging Face model and place it in the exo folder (`~/.cache/exo`) of each node.
- Step 2. Start exo on every node:
```shell
# Every node runs the following command
exo --inference-engine mlx
```
- Step 3. Send a request to the ChatGPT-compatible API:
```shell
curl http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Local/<model_name>",
    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
    "temperature": 0.7
  }'
```
#### Auto Download the Local Model:
- Step 1. Download the Hugging Face model and place it in the exo folder (`~/.cache/exo`) of one node.

- Step 2. Start every node, pointing it at the node that stores the model:
```shell
# Every node runs the following command
exo --stored-model-ip <stored_model_node_ip> --inference-engine mlx
```
- Step 3. Send a request to the ChatGPT-compatible API (a Python equivalent is sketched below):
```shell
curl http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Local/<model_name>",
    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
    "temperature": 0.7
  }'
```
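
The same request can be sent from a short Python script. This is a minimal sketch using only the standard library, assuming a node is listening on `localhost:52415` as in the curl example above and that a model folder named `my-model` (hypothetical) sits under `~/.cache/exo`:

```python
import json
import urllib.request

# Same ChatGPT-compatible request as the curl example above.
payload = {
    "model": "Local/my-model",  # hypothetical model folder under ~/.cache/exo
    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
    "temperature": 0.7,
}

req = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},  # POST is implied by passing data
)

with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# Assuming an OpenAI-style response body.
print(body["choices"][0]["message"]["content"])
```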
### Model Storage

Models by default are stored in `~/.cache/huggingface/hub`.

You can set a different model storage location by setting the `HF_HOME` env var.

### Local Model Storage

Please store local models in `~/.cache/exo` (see the sketch below for how a `Local/<model_name>` id maps to this directory).

You can set a different model storage location by setting the `HOME` env var.
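
For illustration only, here is a minimal Python sketch (not part of exo's codebase) of how a `Local/<model_name>` id maps onto this layout, assuming the defaults described above:

```python
import os
from pathlib import Path

def resolve_local_model_dir(model_id: str) -> Path:
    """Illustrative helper: map 'Local/<model_name>' to ~/.cache/exo/<model_name>.

    Assumes the default layout described above; the real lookup in exo may differ.
    """
    # HOME controls where the exo cache lives, mirroring the note above.
    home = Path(os.environ.get("HOME", str(Path.home())))
    model_name = model_id.split("/", 1)[1]  # drop the 'Local/' prefix
    return home / ".cache" / "exo" / model_name

print(resolve_local_model_dir("Local/my-model"))  # -> ~/.cache/exo/my-model
```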

## Debugging

Enable debug logs with the DEBUG environment variable (0-9).
117 changes: 89 additions & 28 deletions exo/api/chatgpt_api.py
@@ -22,6 +22,9 @@
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4

from exo.download.storedhost.sh_shard_download import SHShardDownloader
from exo.download.storedhost.sh_helpers import get_repo_root, downlaod_tokenizer_config

class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
@@ -251,10 +254,16 @@ async def handle_model_support(self, request):
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
if 'Local/' in model_name:
  downloader = SHShardDownloader(quick_check=True)
  downloader.current_shard = shard
  downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
  status = await downloader.get_shard_download_status()
else:
  downloader = HFShardDownloader(quick_check=True)
  downloader.current_shard = shard
  downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
  status = await downloader.get_shard_download_status()

download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
@@ -284,7 +293,24 @@ async def handle_model_support(self, request):
)

async def handle_get_models(self, request):
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
models_list = []
for model_name, _ in model_cards.items():
if "Local" in model_name:
model_data = {
"id": model_name,
"object": "model",
"owned_by": "Local",
"ready": True
}
else:
model_data = {
"id": model_name,
"object": "model",
"owned_by": "exo",
"ready": True
}
models_list.append(model_data)
return web.json_response(models_list)

async def handle_post_chat_token_encode(self, request):
data = await request.json()
@@ -296,7 +322,11 @@ async def handle_post_chat_token_encode(self, request):
model = self.default_model
shard = build_base_shard(model, self.inference_engine_classname)
messages = [parse_message(msg) for msg in data.get("messages", [])]
if "Local" in shard.model_id:
  model_path = model_cards.get(shard.model_id, {}).get("repo", {}).get(self.inference_engine_classname, {})
  tokenizer = await resolve_tokenizer(model_path)
else:
  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
tokens = tokenizer.encode(prompt)
return web.json_response({
@@ -325,6 +355,7 @@ async def handle_post_chat_completions(self, request):
if not chat_request.model or chat_request.model not in model_cards:
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
chat_request.model = self.default_model

shard = build_base_shard(chat_request.model, self.inference_engine_classname)
if not shard:
supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
Expand All @@ -333,7 +364,13 @@ async def handle_post_chat_completions(self, request):
status=400,
)

if os.path.isfile(chat_request.model):
  model_path = model_cards.get(shard.model_id, {}).get("repo", {}).get(self.inference_engine_classname, {})
  tokenizer = await resolve_tokenizer(model_path)
else:
  if shard.model_id.startswith("Local"):
    await downlaod_tokenizer_config(shard.model_id)
  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
@@ -465,30 +502,54 @@ async def handle_delete_model(self, request):

repo_id = get_repo(shard.model_id, self.inference_engine_classname)
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")

if repo_id.startswith("local"):
lm_dir = get_repo_root(repo_id)
print(lm_dir)
if DEBUG >= 2: print(f"Looking for model files in: {lm_dir}")

if os.path.exists(lm_dir):
if DEBUG >= 2: print(f"Found model files at {lm_dir}, deleting...")
try:
shutil.rmtree(lm_dir)
return web.json_response({
"status": "success",
"message": f"Model {model_name} deleted successfully",
"path": str(lm_dir)
})
except Exception as e:
return web.json_response({
"detail": f"Failed to delete model files: {str(e)}"
}, status=500)
else:
return web.json_response({
"detail": f"Model files not found at {lm_dir}"
}, status=404)

else:
  # Get the HF cache directory using the helper function
  hf_home = get_hf_home()
  cache_dir = get_repo_root(repo_id)

  if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")

  if os.path.exists(cache_dir):
    if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
    try:
      shutil.rmtree(cache_dir)
      return web.json_response({
        "status": "success",
        "message": f"Model {model_name} deleted successfully",
        "path": str(cache_dir)
      })
    except Exception as e:
      return web.json_response({
        "detail": f"Failed to delete model files: {str(e)}"
      }, status=500)
  else:
    return web.json_response({
      "detail": f"Model files not found at {cache_dir}"
    }, status=404)

except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")
31 changes: 24 additions & 7 deletions exo/download/hf/hf_shard_download.py
@@ -16,6 +16,9 @@
import aiohttp
from aiofiles import os as aios

from exo.download.storedhost.sh_helpers import (
get_exo_home, get_lh_weight_map, fetch_lh_file_list, download_model_dir
)

class HFShardDownloader(ShardDownloader):
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
@@ -36,7 +39,12 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
return self.completed_downloads[shard]
if self.quick_check:
repo_root = get_repo_root(repo_name)
if 'Local' in shard.model_id:
  model_name = shard.model_id.split('/')[1]
  snapshots_dir = Path(get_exo_home()/model_name)
else:  # HF model
  snapshots_dir = repo_root/"snapshots"

if snapshots_dir.exists():
visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
if visible_dirs:
@@ -80,11 +88,16 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
async def wrapped_progress_callback(event: RepoProgressEvent):
self._on_progress.trigger_all(shard, event)

if 'Local' in repo_name:
  weight_map = await get_lh_weight_map(repo_name)
  allow_patterns = get_allow_patterns(weight_map, shard)
  return await download_model_dir(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
else:  # HF model
  weight_map = await get_weight_map(repo_name)
  allow_patterns = get_allow_patterns(weight_map, shard)
  return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)


@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
@@ -120,7 +133,11 @@ async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int
downloaded_bytes = 0

async with aiohttp.ClientSession() as session:
if self.current_repo_id.startswith("Local"):  # TODO: decide whether to switch to check_agent
  file_list = await fetch_lh_file_list(session, self.current_repo_id)
else:
  file_list = await fetch_file_list(session, self.current_repo_id, self.revision)

relevant_files = list(
filter_repo_objects(
file_list, allow_patterns=patterns, key=lambda x: x["path"]))
42 changes: 42 additions & 0 deletions exo/download/storedhost/http/http_helpers.py
@@ -0,0 +1,42 @@
from aiohttp import web
import aiohttp  # needed for aiohttp.ClientSession below
import aiohttp_cors
import os
import json
from pathlib import Path
from datetime import datetime
from typing import Dict
from exo import DEBUG

async def download_model(model_name, target_dir):
url = f"http://localhost:52525/models/{model_name}/download"
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
current_file = None
current_writer = None

async for line in response.content:
try:
# Try parsing as metadata
metadata = json.loads(line)
if current_writer:
current_writer.close()  # plain file handle from open(); close synchronously

# Setup new file
filepath = os.path.join(target_dir, metadata['filename'])
os.makedirs(os.path.dirname(filepath), exist_ok=True)
current_file = open(filepath, 'wb')
current_writer = current_file
continue
except json.JSONDecodeError:
pass

# Check for EOF marker
if line.strip() == b'EOF':
if current_writer:
current_writer.close()  # plain file handle from open(); close synchronously
current_writer = None
continue

# Write chunk to current file
if current_writer:
current_writer.write(line)
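
A usage sketch for the function above — illustrative only, assuming a stored-model node is serving on `localhost:52525` (the URL hard-coded in `download_model`) and exposes a model named `my-model`:

```python
import asyncio
import os

async def main():
    # Hypothetical target: the local exo cache folder for this model.
    target_dir = os.path.expanduser("~/.cache/exo/my-model")
    os.makedirs(target_dir, exist_ok=True)
    # Stream every file of the model from the stored-model node into target_dir.
    await download_model("my-model", target_dir)

if __name__ == "__main__":
    asyncio.run(main())
```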