-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add e2b code artifact tool support for the FastAPI template (#339)
- Loading branch information
Showing
11 changed files
with
444 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
"create-llama": patch | ||
--- | ||
|
||
Add e2b code artifact tool for the FastAPI template |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
100 changes: 100 additions & 0 deletions
100
templates/components/engines/python/agent/tools/artifact.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import logging | ||
from typing import Dict, List, Optional | ||
|
||
from llama_index.core.base.llms.types import ChatMessage | ||
from llama_index.core.settings import Settings | ||
from llama_index.core.tools import FunctionTool | ||
from pydantic import BaseModel, Field | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Prompt based on https://github.com/e2b-dev/ai-artifacts | ||
CODE_GENERATION_PROMPT = """You are a skilled software engineer. You do not make mistakes. Generate an artifact. You can install additional dependencies. You can use one of the following templates: | ||
1. code-interpreter-multilang: "Runs code as a Jupyter notebook cell. Strong data analysis angle. Can use complex visualisation to explain results.". File: script.py. Dependencies installed: python, jupyter, numpy, pandas, matplotlib, seaborn, plotly. Port: none. | ||
2. nextjs-developer: "A Next.js 13+ app that reloads automatically. Using the pages router.". File: pages/index.tsx. Dependencies installed: [email protected], typescript, @types/node, @types/react, @types/react-dom, postcss, tailwindcss, shadcn. Port: 3000. | ||
3. vue-developer: "A Vue.js 3+ app that reloads automatically. Only when asked specifically for a Vue app.". File: app.vue. Dependencies installed: vue@latest, [email protected], tailwindcss. Port: 3000. | ||
4. streamlit-developer: "A streamlit app that reloads automatically.". File: app.py. Dependencies installed: streamlit, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 8501. | ||
5. gradio-developer: "A gradio app. Gradio Blocks/Interface should be called demo.". File: app.py. Dependencies installed: gradio, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 7860. | ||
Make sure to use the correct syntax for the programming language you're using. | ||
""" | ||
|
||
|
||
class CodeArtifact(BaseModel): | ||
commentary: str = Field( | ||
..., | ||
description="Describe what you're about to do and the steps you want to take for generating the artifact in great detail.", | ||
) | ||
template: str = Field( | ||
..., description="Name of the template used to generate the artifact." | ||
) | ||
title: str = Field(..., description="Short title of the artifact. Max 3 words.") | ||
description: str = Field( | ||
..., description="Short description of the artifact. Max 1 sentence." | ||
) | ||
additional_dependencies: List[str] = Field( | ||
..., | ||
description="Additional dependencies required by the artifact. Do not include dependencies that are already included in the template.", | ||
) | ||
has_additional_dependencies: bool = Field( | ||
..., | ||
description="Detect if additional dependencies that are not included in the template are required by the artifact.", | ||
) | ||
install_dependencies_command: str = Field( | ||
..., | ||
description="Command to install additional dependencies required by the artifact.", | ||
) | ||
port: Optional[int] = Field( | ||
..., | ||
description="Port number used by the resulted artifact. Null when no ports are exposed.", | ||
) | ||
file_path: str = Field( | ||
..., description="Relative path to the file, including the file name." | ||
) | ||
code: str = Field( | ||
..., | ||
description="Code generated by the artifact. Only runnable code is allowed.", | ||
) | ||
|
||
|
||
class CodeGeneratorTool: | ||
def __init__(self): | ||
pass | ||
|
||
def artifact(self, query: str, old_code: Optional[str] = None) -> Dict: | ||
"""Generate a code artifact based on the input. | ||
Args: | ||
query (str): The description of the application you want to build. | ||
old_code (Optional[str], optional): The existing code to be modified. Defaults to None. | ||
Returns: | ||
Dict: A dictionary containing the generated artifact information. | ||
""" | ||
|
||
if old_code: | ||
user_message = f"{query}\n\nThe existing code is: \n```\n{old_code}\n```" | ||
else: | ||
user_message = query | ||
|
||
messages: List[ChatMessage] = [ | ||
ChatMessage(role="system", content=CODE_GENERATION_PROMPT), | ||
ChatMessage(role="user", content=user_message), | ||
] | ||
try: | ||
sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact) | ||
response = sllm.chat(messages) | ||
data: CodeArtifact = response.raw | ||
return data.model_dump() | ||
except Exception as e: | ||
logger.error(f"Failed to generate artifact: {str(e)}") | ||
raise e | ||
|
||
|
||
def get_tools(**kwargs): | ||
return [FunctionTool.from_defaults(fn=CodeGeneratorTool().artifact)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright 2024 FoundryLabs, Inc. and LlamaIndex, Inc. | ||
# Portions of this file are copied from the e2b project (https://github.com/e2b-dev/ai-artifacts) and then converted to Python | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import base64 | ||
import logging | ||
import os | ||
import uuid | ||
from typing import Dict, List, Optional, Union | ||
|
||
from app.engine.tools.artifact import CodeArtifact | ||
from app.engine.utils.file_helper import save_file | ||
from e2b_code_interpreter import CodeInterpreter, Sandbox | ||
from fastapi import APIRouter, HTTPException, Request | ||
from pydantic import BaseModel | ||
|
||
logger = logging.getLogger("uvicorn") | ||
|
||
sandbox_router = APIRouter() | ||
|
||
SANDBOX_TIMEOUT = 10 * 60 # timeout in seconds | ||
MAX_DURATION = 60 # max duration in seconds | ||
|
||
|
||
class ExecutionResult(BaseModel): | ||
template: str | ||
stdout: List[str] | ||
stderr: List[str] | ||
runtime_error: Optional[Dict[str, Union[str, List[str]]]] = None | ||
output_urls: List[Dict[str, str]] | ||
url: Optional[str] | ||
|
||
def to_response(self): | ||
""" | ||
Convert the execution result to a response object (camelCase) | ||
""" | ||
return { | ||
"template": self.template, | ||
"stdout": self.stdout, | ||
"stderr": self.stderr, | ||
"runtimeError": self.runtime_error, | ||
"outputUrls": self.output_urls, | ||
"url": self.url, | ||
} | ||
|
||
|
||
@sandbox_router.post("") | ||
async def create_sandbox(request: Request): | ||
request_data = await request.json() | ||
|
||
try: | ||
artifact = CodeArtifact(**request_data["artifact"]) | ||
except Exception: | ||
logger.error(f"Could not create artifact from request data: {request_data}") | ||
return HTTPException( | ||
status_code=400, detail="Could not create artifact from the request data" | ||
) | ||
|
||
sbx = None | ||
|
||
# Create an interpreter or a sandbox | ||
if artifact.template == "code-interpreter-multilang": | ||
sbx = CodeInterpreter(api_key=os.getenv("E2B_API_KEY"), timeout=SANDBOX_TIMEOUT) | ||
logger.debug(f"Created code interpreter {sbx}") | ||
else: | ||
sbx = Sandbox( | ||
api_key=os.getenv("E2B_API_KEY"), | ||
template=artifact.template, | ||
metadata={"template": artifact.template, "user_id": "default"}, | ||
timeout=SANDBOX_TIMEOUT, | ||
) | ||
logger.debug(f"Created sandbox {sbx}") | ||
|
||
# Install packages | ||
if artifact.has_additional_dependencies: | ||
if isinstance(sbx, CodeInterpreter): | ||
sbx.notebook.exec_cell(artifact.install_dependencies_command) | ||
logger.debug( | ||
f"Installed dependencies: {', '.join(artifact.additional_dependencies)} in code interpreter {sbx}" | ||
) | ||
elif isinstance(sbx, Sandbox): | ||
sbx.commands.run(artifact.install_dependencies_command) | ||
logger.debug( | ||
f"Installed dependencies: {', '.join(artifact.additional_dependencies)} in sandbox {sbx}" | ||
) | ||
|
||
# Copy code to disk | ||
if isinstance(artifact.code, list): | ||
for file in artifact.code: | ||
sbx.files.write(file.file_path, file.file_content) | ||
logger.debug(f"Copied file to {file.file_path}") | ||
else: | ||
sbx.files.write(artifact.file_path, artifact.code) | ||
logger.debug(f"Copied file to {artifact.file_path}") | ||
|
||
# Execute code or return a URL to the running sandbox | ||
if artifact.template == "code-interpreter-multilang": | ||
result = sbx.notebook.exec_cell(artifact.code or "") | ||
output_urls = _download_cell_results(result.results) | ||
return ExecutionResult( | ||
template=artifact.template, | ||
stdout=result.logs.stdout, | ||
stderr=result.logs.stderr, | ||
runtime_error=result.error, | ||
output_urls=output_urls, | ||
url=None, | ||
).to_response() | ||
else: | ||
return ExecutionResult( | ||
template=artifact.template, | ||
stdout=[], | ||
stderr=[], | ||
runtime_error=None, | ||
output_urls=[], | ||
url=f"https://{sbx.get_host(artifact.port or 80)}", | ||
).to_response() | ||
|
||
|
||
def _download_cell_results(cell_results: Optional[List]) -> List[Dict[str, str]]: | ||
""" | ||
To pull results from code interpreter cell and save them to disk for serving | ||
""" | ||
if not cell_results: | ||
return [] | ||
|
||
output = [] | ||
for result in cell_results: | ||
try: | ||
formats = result.formats() | ||
for ext in formats: | ||
data = result[ext] | ||
|
||
if ext in ["png", "svg", "jpeg", "pdf"]: | ||
file_path = f"output/tools/{uuid.uuid4()}.{ext}" | ||
base64_data = data | ||
buffer = base64.b64decode(base64_data) | ||
file_meta = save_file(content=buffer, file_path=file_path) | ||
output.append( | ||
{ | ||
"type": ext, | ||
"filename": file_meta.filename, | ||
"url": file_meta.url, | ||
} | ||
) | ||
except Exception as e: | ||
logger.error(f"Error processing result: {str(e)}") | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from fastapi import APIRouter | ||
|
||
from .chat import chat_router # noqa: F401 | ||
from .chat_config import config_router # noqa: F401 | ||
from .upload import file_upload_router # noqa: F401 | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(chat_router, prefix="/chat") | ||
api_router.include_router(config_router, prefix="/chat/config") | ||
api_router.include_router(file_upload_router, prefix="/chat/upload") | ||
|
||
# Dynamically adding additional routers if they exist | ||
try: | ||
from .sandbox import sandbox_router # noqa: F401 | ||
|
||
api_router.include_router(sandbox_router, prefix="/sandbox") | ||
except ImportError: | ||
pass |
Oops, something went wrong.