
Commit

Merge pull request #474 from pranav4501/stable-stable-diffusion-mlx
Stable diffusion mlx
AlexCheema authored Jan 12, 2025
2 parents bd2e8e7 + 5f3d000 commit b5cbcbc
Showing 24 changed files with 2,272 additions and 210 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -171,3 +171,5 @@ cython_debug/
 
 **/*.xcodeproj/*
 .aider*
+
+exo/tinychat/images/*.png
106 changes: 105 additions & 1 deletion exo/api/chatgpt_api.py
@@ -12,11 +12,17 @@
 import signal
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict, shutdown
+from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
+from PIL import Image
+import numpy as np
+import base64
+from io import BytesIO
+import mlx.core as mx
+import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
@@ -185,6 +191,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -195,10 +202,12 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
+
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
       self.app.router.add_static("/", self.static_dir, name="static")
+      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
@@ -457,6 +466,85 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
+
+  async def handle_post_image_generations(self, request):
+    data = await request.json()
+
+    if DEBUG >= 2: print(f"Handling image generations request from {request.remote}: {data}")
+    stream = data.get("stream", False)
+    model = data.get("model", "")
+    prompt = data.get("prompt", "")
+    image_url = data.get("image_url", "")
+    if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
+    shard = build_base_shard(model, self.inference_engine_classname)
+    if DEBUG >= 2: print(f"shard: {shard}")
+    if not shard:
+      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+
+    request_id = str(uuid.uuid4())
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+    try:
+      if image_url != "" and image_url != None:
+        img = self.base64_decode(image_url)
+      else:
+        img = None
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
+
+
+      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream', "Cache-Control": "no-cache"})
+      await response.prepare(request)
+
+      def get_progress_bar(current_step, total_steps, bar_length=50):
+        # Calculate the fraction of steps completed
+        percent = float(current_step) / total_steps
+        # Build the arrow portion of the bar; the remainder is padded with spaces
+        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
+        spaces = ' ' * (bar_length - len(arrow))
+
+        # Create the progress bar string
+        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
+        return progress_bar
+
+      async def stream_image(_request_id: str, result, is_finished: bool):
+        if isinstance(result, list):
+          await response.write(json.dumps({'progress': get_progress_bar(result[0], result[1])}).encode('utf-8') + b'\n')
+
+        elif isinstance(result, np.ndarray):
+          im = Image.fromarray(np.array(result))
+          images_folder = get_exo_images_dir()
+          # Save the image to a file
+          image_filename = f"{_request_id}.png"
+          image_path = images_folder / image_filename
+          im.save(image_path)
+          image_url = request.app.router['static_images'].url_for(filename=image_filename)
+          base_url = f"{request.scheme}://{request.host}"
+          # Construct the full URL correctly
+          full_image_url = base_url + str(image_url)
+
+          await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+        if is_finished:
+          await response.write_eof()
+
+
+      stream_task = None
+      def on_result(_request_id: str, result, is_finished: bool):
+        nonlocal stream_task
+        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+        return _request_id == request_id and is_finished
+
+      await callback.wait(on_result, timeout=self.response_timeout*10)
+
+      if stream_task:
+        # Wait for the stream task to complete before returning
+        await stream_task
+
+      return response
+
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
@@ -598,3 +686,19 @@ async def run(self, host: str = "0.0.0.0", port: int = 52415):
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
+
+  def base64_decode(self, base64_string):
+    # decode and reshape image
+    if base64_string.startswith('data:image'):
+      base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    img = Image.open(BytesIO(image_data))
+    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    if W != img.width or H != img.height:
+      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+    img = mx.array(np.array(img))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = img[None]
+    return img
+
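For reference, a client of the new endpoint might look like the sketch below. This is a minimal example, not part of the commit: it assumes the requests package is installed, the API is running locally on the default port 52415 (from ChatGPTAPI.run above), and the model name matches the id checked in inference_engine.py. It parses the newline-delimited JSON frames that stream_image writes.

import json
import requests

resp = requests.post(
  "http://localhost:52415/v1/image/generations",
  json={"model": "stable-diffusion-2-1-base", "prompt": "a watercolor fox"},
  stream=True,
)
for line in resp.iter_lines():
  if not line:
    continue
  event = json.loads(line)
  if "progress" in event:
    print(event["progress"])  # e.g. Progress: [--->   ] 10% (5/50)
  elif "images" in event:
    # Final frame: a URL served by the /images/ static route registered above
    print(event["images"][0]["url"])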
4 changes: 4 additions & 0 deletions exo/download/hf/hf_helpers.py
@@ -303,6 +303,10 @@ async def download_repo_files(
       await f.write(json.dumps(file_list))
       if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
 
+  model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
+  if model_index_exists:
+    allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
+
   filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
   total_files = len(filtered_file_list)
   total_bytes = sum(file["size"] for file in filtered_file_list)
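To see what the static allow-list admits, here is a small standalone sketch, not part of the commit; it uses fnmatch, which is roughly what repo-pattern filtering boils down to, and the file paths are hypothetical examples of a diffusers-style repo layout:

from fnmatch import fnmatch

allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
files = [
  "model_index.json",                          # kept by *.json
  "tokenizer/vocab.json",                      # kept by **/*.json
  "tokenizer/merges.txt",                      # kept by **/*.txt
  "unet/diffusion_pytorch_model.safetensors",  # kept by **/*model.safetensors
  "unet/diffusion_pytorch_model.bin",          # dropped: no pattern matches
]
for path in files:
  print(path, any(fnmatch(path, p) for p in allow_patterns))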
18 changes: 11 additions & 7 deletions exo/download/hf/hf_shard_download.py
@@ -104,15 +104,19 @@ async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int
         print(f"No snapshot directory found for {self.current_repo_id}")
       return None
 
+    if not await aios.path.exists(snapshot_dir/"model_index.json"):
+      # Get the weight map to know what files we need
+      weight_map = await get_weight_map(self.current_repo_id, self.revision)
+      if not weight_map:
+        if DEBUG >= 2:
+          print(f"No weight map found for {self.current_repo_id}")
+        return None
-    weight_map = await get_weight_map(self.current_repo_id, self.revision)
-    if not weight_map:
-      if DEBUG >= 2:
-        print(f"No weight map found for {self.current_repo_id}")
-      return None
 
+      # Get all files needed for this shard
+      patterns = get_allow_patterns(weight_map, self.current_shard)
+    else:
+      patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
 
-    # Get all files needed for this shard
-    patterns = get_allow_patterns(weight_map, self.current_shard)
 
     # Check download status for all relevant files
     status = {}
21 changes: 20 additions & 1 deletion exo/helpers.py
@@ -325,4 +325,23 @@ async def shutdown(signal, loop, server):
 def is_frozen():
   return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
     or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
\ No newline at end of file
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
+
+
+def get_exo_home() -> Path:
+  if os.name == "nt":  # Check if the OS is Windows
+    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
+  else:
+    docs_folder = Path.home() / "Documents"
+  exo_folder = docs_folder / "Exo"
+  if not exo_folder.exists():
+    exo_folder.mkdir()
+  return exo_folder
+
+def get_exo_images_dir() -> Path:
+  exo_home = get_exo_home()
+  images_dir = exo_home / "Images"
+  if not images_dir.exists():
+    images_dir.mkdir()
+  return images_dir
+
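A quick usage sketch (not part of the commit): generated images land under Documents/Exo/Images in the user's home directory, on Windows under %USERPROFILE%\Documents\Exo\Images, and both directories are created on first use.

from exo.helpers import get_exo_home, get_exo_images_dir

print(get_exo_home())        # e.g. /Users/alex/Documents/Exo
print(get_exo_images_dir())  # e.g. /Users/alex/Documents/Exo/Images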
12 changes: 8 additions & 4 deletions exo/inference/inference_engine.py
@@ -39,11 +39,15 @@ async def save_session(self, key, value):
   async def clear_session(self):
     self.session.empty()
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    x = tokens.reshape(1, -1)
-    output_data = await self.infer_tensor(request_id, shard, x)
-    return output_data
+    if shard.model_id != 'stable-diffusion-2-1-base':
+      x = tokens.reshape(1, -1)
+    else:
+      x = tokens
+    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
+
+    return output_data, inference_state
 
 inference_engine_classes = {
   "mlx": "MLXDynamicShardInferenceEngine",
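A hedged sketch of the updated calling convention: infer_prompt (and infer_tensor) now take an optional inference_state dict and return it alongside the output, so callers can thread diffusion state such as an init image through each step. Here engine and shard are placeholders for a concrete InferenceEngine and a shard built with build_base_shard:

async def generate(engine, shard, prompt):
  # Text-to-image: no init image, so the state starts with image=None
  output, state = await engine.infer_prompt(
    "req-1", shard, prompt, inference_state={"image": None}
  )
  return output, state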