Feat: Create FastAPI #15

Open · wants to merge 3 commits into main
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,4 +1,5 @@
/.venv
.DS_Store
/output
-/__pycache__
+__pycache__
+.env
1 change: 1 addition & 0 deletions genmoji/.env.sample
@@ -0,0 +1 @@
HF_TOKEN=
Empty file added genmoji/domain/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions metaprompt/open-genmoji.md → genmoji/domain/prompts.py
@@ -1,4 +1,4 @@
-You are helping create a prompt for a Emoji generation image model. An emoji must be easily interpreted when small so details must be exaggerated to be clear. Your goal is to use descriptions to achieve this.
+SYSTEM_PROMPT = """You are helping create a prompt for a Emoji generation image model. An emoji must be easily interpreted when small so details must be exaggerated to be clear. Your goal is to use descriptions to achieve this.

You will receive a user description, and you must rephrase it to consist of short phrases separated by periods, adding detail to everything the user provides.

@@ -16,4 +16,4 @@
- "head is turned towards viewer.": ONLY humans or animals
- "detailed texture.": ONLY objects

-Further addon phrases may be added to ensure the clarity of the emoji.
+Further addon phrases may be added to ensure the clarity of the emoji."""
17 changes: 17 additions & 0 deletions genmoji/domain/schemas.py
@@ -0,0 +1,17 @@
from pydantic import BaseModel
from typing import Optional


class DownloadModelRequest(BaseModel):
huggingface_repo: str
model_name: str


class GenerationRequest(BaseModel):
prompt: str
lora: Optional[str] = "flux-dev"
llm_model: Optional[str] = "llama3.1:latest"
direct: Optional[bool] = False
height: Optional[int] = 160
width: Optional[int] = 160
upscale_factor: Optional[int] = 5
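
A minimal sketch of how these models validate an incoming payload (the prompt text is illustrative; `model_dump()` assumes pydantic v2, which FastAPI 0.115 uses — with pydantic v1 it would be `.dict()`):

```python
# Sketch: validating an /inference payload against GenerationRequest.
from domain.schemas import GenerationRequest

req = GenerationRequest(prompt="a horse wearing a suit")
# Omitted optional fields fall back to their declared defaults.
print(req.model_dump())
# {'prompt': 'a horse wearing a suit', 'lora': 'flux-dev',
#  'llm_model': 'llama3.1:latest', 'direct': False,
#  'height': 160, 'width': 160, 'upscale_factor': 5}
```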
File renamed without changes.
15 changes: 8 additions & 7 deletions genmoji.py → genmoji/genmoji.py
@@ -1,13 +1,13 @@
-import sys
-from promptAssistant import get_prompt_response
-from generateImage import generate_image
-from PIL import Image
-import os
import argparse
import json
+import os
+import sys
+from PIL import Image
+from utils.generate_image import generate_image
+from utils.llm_utils import model_inference


-def get_unique_path(base_path):
+def get_unique_path(base_path: str) -> str:
directory = os.path.dirname(base_path)
filename = os.path.basename(base_path)
name, ext = os.path.splitext(filename)
@@ -31,6 +31,7 @@ def main(
upscale_factor: int,
output_path: str = "output/genmoji.png",
lora: str = "flux-dev",
+llm_model: str = "llama3.1:latest"
):
with open("./lora/info.json", "r") as f:
models = json.load(f)
@@ -57,7 +58,7 @@
sys.exit(1)
if not direct:
# Get the response from the prompt assistant
-prompt_response = get_prompt_response(user_prompt, metaprompt)
+prompt_response = model_inference(user_prompt, llm_model).get("message")
print("Prompt Created: " + prompt_response)
elif direct:
prompt_response = user_prompt
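
The body of `get_unique_path` is collapsed in the diff above; from the visible prologue it splits the filename into name and extension, presumably to append a numeric suffix until the path is free. A sketch of that behavior, with the suffix format assumed rather than taken from the PR:

```python
# Hypothetical reconstruction of get_unique_path's collision handling;
# the actual loop is collapsed in the diff, so the suffix format is assumed.
import os

def get_unique_path(base_path: str) -> str:
    directory = os.path.dirname(base_path)
    filename = os.path.basename(base_path)
    name, ext = os.path.splitext(filename)
    candidate = base_path
    counter = 1
    while os.path.exists(candidate):  # bump the suffix until the path is free
        candidate = os.path.join(directory, f"{name}{counter}{ext}")
        counter += 1
    return candidate
```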
105 changes: 105 additions & 0 deletions genmoji/main.py
@@ -0,0 +1,105 @@
import io
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download, login
from PIL import Image
from domain.schemas import DownloadModelRequest, GenerationRequest
from utils.generate_image import generate_image
from utils.llm_utils import list_installed_llms, model_inference
from utils.logger import logger

load_dotenv()

app = FastAPI()
login(token=os.getenv("HF_TOKEN"))


@app.post("/download_model")
async def download_model(input_data: DownloadModelRequest) -> JSONResponse:
"""
Download a model from Hugging Face and store it in ./lora
"""
try:
filename = f"{input_data.model_name}.safetensors"
hf_hub_download(
repo_id=input_data.huggingface_repo, filename=filename, local_dir="./lora"
)
return JSONResponse(content={"response": f"Download {input_data.model_name} complete"})
except Exception as e:
logger.error(f"Error downloading model: {e}")
logger.exception(e)
return JSONResponse(content={"response": f"Error downloading model: {e}"}, status_code=500)


@app.get("/installed_img_genmodels")
async def get_installed_img_gen_models() -> JSONResponse:
"""
Get the list of installed image gen models
"""
try:
models = os.listdir("lora/")
models = [model.replace(".safetensors", "") for model in models if not model.startswith(".") and "safetensors" in model]
return JSONResponse(content={"models": models})
except Exception as e:
logger.error(f"Error listing models: {e}")
logger.exception(e)
return JSONResponse(content=f"Error listing models: {e}", status_code=500)


@app.get("/installed_llms")
async def get_installed_llms() -> JSONResponse:
"""
Get the list of installed LLMs from Ollama
"""
try:
return JSONResponse(content={"models": list_installed_llms()})
except Exception as e:
logger.error(f"Error listing models: {e}")
logger.exception(e)
return JSONResponse(content=f"Error listing models: {e}", status_code=500)


@app.post("/inference")
async def inference(input_data: GenerationRequest) -> StreamingResponse:
"""
Perform model inference to generate an emoji. Uses Ollama to handle LLM inference for LoRA prompt generation.
"""
# Check if the lora file exists
lora_path = f"lora/{input_data.lora}.safetensors"
if not os.path.exists(lora_path):
raise HTTPException(
status_code=404,
detail=f"Error: LoRA {input_data.lora} is not downloaded. Please run use the /download_model endpoint to download it.")

if input_data.direct:
user_prompt = input_data.prompt
else:
# Get the response from the prompt assistant
user_prompt = model_inference(
user_prompt=input_data.prompt,
model_name=input_data.llm_model
).get("message")
if "i cannot" in user_prompt.lower() or "i can't" in user_prompt.lower():
logger.warning("Refusal detected from LLM prompt enhancement. Using raw prompt input, generation may be lacking.")
user_prompt = input_data.prompt
else:
logger.info("Prompt Created: " + user_prompt)

# Generate the image using the response from the prompt assistant
image = generate_image(user_prompt, input_data.lora, input_data.width, input_data.height)

output_width, output_height = image.size
resized_image = image.resize(
(
output_width * input_data.upscale_factor,
output_height * input_data.upscale_factor
),
Image.LANCZOS)

img_io = io.BytesIO()
resized_image.save(img_io, 'PNG')
img_io.seek(0)

return StreamingResponse(img_io, media_type="image/png")
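
A client-side sketch exercising the two POST endpoints, assuming the app is served locally on uvicorn's default port; the endpoint paths and fields come from the diff above, while the host, repo id, and filenames are assumptions:

```python
# Sketch: calling the API with the "requests" package; assumes the
# server is running at localhost:8000.
import requests

BASE = "http://localhost:8000"  # assumed host/port

# Fetch a LoRA once so /inference can find lora/flux-dev.safetensors.
requests.post(f"{BASE}/download_model", json={
    "huggingface_repo": "some-user/some-flux-lora",  # hypothetical repo id
    "model_name": "flux-dev",
}, timeout=600)

# Generate an emoji; omitted fields use the GenerationRequest defaults.
resp = requests.post(f"{BASE}/inference",
                     json={"prompt": "flying pig"}, timeout=600)
resp.raise_for_status()
with open("genmoji.png", "wb") as f:
    f.write(resp.content)  # PNG bytes from the StreamingResponse
```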
6 changes: 6 additions & 0 deletions genmoji/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fastapi==0.115.6
python-dotenv==1.0.1
huggingface-hub==0.27.1
mflux==0.5.1
ollama==0.4.5
uvicorn==0.34.0
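
The requirements pin fastapi and uvicorn, but the diff does not show a launch command. One way to serve the app, assuming it is started from the genmoji/ directory so that `main:app` resolves:

```python
# run.py -- assumed launcher, not part of the PR; run from genmoji/.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
```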
Empty file added genmoji/utils/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions generateImage.py → genmoji/utils/generate_image.py
@@ -1,5 +1,5 @@
-from mflux import Flux1, Config, ModelConfig
import os
+from mflux import Config, Flux1, ModelConfig


def generate_image(prompt: str, lora: str, width: int, height: int):
@@ -8,7 +8,7 @@ def generate_image(prompt: str, lora: str, width: int, height: int):
model_config=ModelConfig.FLUX1_DEV,
quantize=8,
lora_paths=[
f"{os.path.abspath(os.path.dirname(__file__))}/lora/{lora}.safetensors"
f"{os.path.abspath(os.path.dirname(__file__))}/../lora/{lora}.safetensors"
],
lora_scales=[1.0],
)
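
The LoRA path now resolves one directory above utils/, matching the repo-root lora/ folder. A minimal call sketch, assuming the flux-dev LoRA is already downloaded and noting that main.py treats the return value as a PIL image:

```python
# Sketch: direct call into the image generator, bypassing the API.
# Assumes lora/flux-dev.safetensors exists at the repo root.
from utils.generate_image import generate_image

image = generate_image(
    prompt="emoji of flying pink pig. cute. 3D lighting.",
    lora="flux-dev",
    width=160,
    height=160,
)
image.save("pig.png")  # PIL Image, as consumed by main.py
```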
52 changes: 52 additions & 0 deletions genmoji/utils/llm_utils.py
@@ -0,0 +1,52 @@
import os
from ollama import Client, Options, Message, ResponseError
from typing import Dict, List
from domain.prompts import SYSTEM_PROMPT


OLLAMA_HOST_DEFAULT = "http://localhost:11434"
OLLAMA_HOST = os.getenv("OLLAMA_HOST", OLLAMA_HOST_DEFAULT)
client = Client(host=OLLAMA_HOST)


def model_inference(
user_prompt: str,
model_name: str,
max_output: int = 128,
temperature: float = 0.7
) -> Dict[str, str]:
try:
# prime conversation with system prompt and few shot
messages = [
Message(role="system", content=SYSTEM_PROMPT),
Message(role="user", content="a horse wearing a suit"),
Message(role="assistant", content="emoji of horse in black suit and tie with flowing mane. a strong, confident stallion wearing formal attire for a special occasion. cute. 3D lighting. no cast shadows. enlarged head in cartoon style. head is turned towards viewer."),
Message(role="user", content="flying pig"),
Message(role="assistant", content="emoji of flying pink pig. enlarged head in cartoon style. cute. white wings. head is turned towards viewer. 3D lighting. no cast shadows."),
Message(role="user", content=user_prompt)
]
response = client.chat(
model=model_name,
messages=messages,
options=Options(temperature=temperature, num_predict=max_output))

return {"message": response.get("message").content}
except ResponseError as e:
if e.status_code == 404:
return {"message": f"Model {model_name} not found"}
return {"message": f"Error performing inference: {e.error}"}
except Exception as e:
return {"message": f"Error performing inference: {e}"}


def list_installed_llms() -> List[Dict[str, str]]:
models = client.list()

return [
{
"model_name": model.get("model"),
"family": model.get("details", {}).get("family"),
"param_size": model.get("details", {}).get("parameter_size"),
}
for model in models.get("models")
]
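
A sketch of both helpers, assuming an Ollama daemon is reachable at OLLAMA_HOST (or the localhost default) and that `ollama pull llama3.1` has already been run:

```python
# Sketch: prompt enhancement via Ollama; assumes the daemon is running
# locally with the llama3.1 model pulled.
from utils.llm_utils import list_installed_llms, model_inference

for entry in list_installed_llms():
    print(entry["model_name"], entry["family"], entry["param_size"])

result = model_inference("a dancing cactus", model_name="llama3.1:latest")
print(result["message"])  # enhanced emoji prompt, or an error string
```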
25 changes: 25 additions & 0 deletions genmoji/utils/logger.py
@@ -0,0 +1,25 @@
import logging


def setup_logger(name: str) -> logging.Logger:
"""
Sets up a logger instance with a set log format
"""
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)

# Create a formatter for the log messages
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')

# Create a console handler for the log messages
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

# Add the console handler to the logger
logger.addHandler(console_handler)
return logger


logger = setup_logger(__name__)
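
Since `setup_logger` runs at import time, other modules just import the shared instance, as main.py does above:

```python
# Sketch: consuming the module-level logger elsewhere in the package.
from utils.logger import logger

logger.info("model download started")
logger.warning("refusal detected, falling back to raw prompt")
```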
8 changes: 0 additions & 8 deletions metaprompt/open-genmoji.json

This file was deleted.

60 changes: 0 additions & 60 deletions promptAssistant.py

This file was deleted.

36 changes: 0 additions & 36 deletions resize.py

This file was deleted.