Skip to content

Commit

Permalink
Extend Hype to support creating FastAPI apps for functions (#2)
Browse files Browse the repository at this point in the history
* Add http extras

Install http extra by default

* Add Accept header parsing for content negotiation

* Add Prefer header parsing for request preferences

* Add HTTP Problem Details object and FastAPI exception handling

* Add blocking HTTP request preference tests

* Add create_fastapi_app function

* Use hype.up consistently

Rename export to wrap

Rename from_function to validate

* Add missing type annotation for json_schema method

* Raise explicitly from None to silence warnings

* Remove unused imports

* Update uv.lock

* Replace testing group with dev dependencies

* Add missing python-multipart dependency

* Add Python 3.13 to test matrix
  • Loading branch information
mattt authored Oct 23, 2024
1 parent 3805ba3 commit 63f9103
Show file tree
Hide file tree
Showing 17 changed files with 2,099 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
- "3.10"
- "3.11"
- "3.12"
- "3.13"

steps:
- uses: actions/checkout@v4
Expand Down
6 changes: 3 additions & 3 deletions examples/tool_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

import anthropic

from hype.function import export
import hype
from hype.tools.anthropic import create_anthropic_tools

Number = TypeVar("Number", int, float)


@export
@hype.up
def calculate(expression: str) -> Number:
"""
A simple calculator that performs basic arithmetic operations.
Expand Down Expand Up @@ -65,7 +65,7 @@ def evaluate(node: ast.AST) -> Number:
return evaluate(tree.body)


@export
@hype.up
def prime_factors(n: int) -> list[int]:
"""
Calculate the prime factors of a given number efficiently.
Expand Down
14 changes: 13 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@ requires-python = ">=3.10"
dependencies = ["pydantic>=2.0", "docstring-parser>=0.16"]

[project.optional-dependencies]
testing = ["pytest"]
http = [
"fastapi>=0.100.0",
"httpx>=0.27.2",
"opentelemetry-instrumentation-fastapi",
"python-multipart",
"uvicorn",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
dev-dependencies = ["pytest>=8.3.3", "ruff>=0.7.0"]

[tool.uv.pip]
extra = ["http"]

[tool.pylint.main]
disable = [
"C0114", # Missing module docstring
Expand Down
4 changes: 3 additions & 1 deletion src/hype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from hype.function import export as up
from hype.function import wrap as up
from hype.http import create_fastapi_app
from hype.tools.anthropic import create_anthropic_tools
from hype.tools.ollama import create_ollama_tools
from hype.tools.openai import create_openai_tools

__all__ = [
"up",
"create_fastapi_app",
"create_anthropic_tools",
"create_openai_tools",
"create_ollama_tools",
Expand Down
23 changes: 14 additions & 9 deletions src/hype/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
validate_call,
)
from pydantic.fields import FieldInfo
from pydantic.json_schema import models_json_schema
from pydantic.json_schema import JsonSchemaValue, models_json_schema

Input = ParamSpec("Input")
Output = TypeVar("Output")
Expand All @@ -38,21 +38,26 @@ class Function(BaseModel, Generic[Input, Output]):
output: type[BaseModel]

@classmethod
def from_function(cls, func: Callable[Input, Output]) -> "Function[Input, Output]":
name = func.__name__
def validate(cls, value: Callable[Input, Output]) -> "Function[Input, Output]":
if isinstance(value, Function):
return value
if not callable(value):
raise TypeError("value must be callable")

docstring = parse_docstring(func.__doc__ or "")
name = value.__name__

docstring = parse_docstring(value.__doc__ or "")
description = docstring.description

input, output = input_and_output_types(func, docstring)
input, output = input_and_output_types(value, docstring)

function = cls(
name=name,
description=description,
input=input,
output=output,
)
function._wrapped = func
function._wrapped = value
return function

def __call__(self, *args: Input.args, **kwargs: Input.kwargs) -> Output: # pylint: disable=no-member
Expand All @@ -67,7 +72,7 @@ def output_schema(self) -> dict[str, Any]:
return self.output.model_json_schema()

@property
def json_schema(self, title: str | None = None):
def json_schema(self, title: str | None = None) -> JsonSchemaValue:
_, top_level_schema = models_json_schema(
[(self.input, "validation"), (self.output, "validation")],
title=title or self.name,
Expand Down Expand Up @@ -141,5 +146,5 @@ class Output(RootModel[T]): # pylint: disable=redefined-outer-name
return input, output


def export(func: Callable[Input, Output]) -> Function[Input, Output]:
return Function.from_function(func)
def wrap(function: Callable[Input, Output]) -> Function[Input, Output]:
return Function.validate(function)
166 changes: 166 additions & 0 deletions src/hype/http/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import asyncio
import warnings
from contextlib import asynccontextmanager
from typing import Annotated

with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor


from docstring_parser import parse as parse_docstring
from fastapi import APIRouter, FastAPI, File, Header, HTTPException, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel, create_model

from hype.function import Function
from hype.http.prefer import parse_prefer_headers
from hype.http.problem import Problem, problem_exception_handler
from hype.task import Tasks


class FileUploadRequest(BaseModel):
file: UploadFile


class FileUploadResponse(BaseModel):
ok: bool


def create_file_upload_callback_router(source_operation_id: str) -> APIRouter:
router = APIRouter()

@router.put(
"{$callback_url}/files/{$request.body.id}",
response_model=FileUploadResponse,
operation_id=f"{source_operation_id}_file_upload_callback",
summary="File upload callback endpoint",
)
def upload_file(
request: FileUploadRequest = File(...), # pylint: disable=unused-argument
) -> FileUploadResponse:
return FileUploadResponse(ok=True)

return router


def add_fastapi_endpoint(
app: FastAPI,
func: Function,
) -> None:
path = f"/{func.name}"

# Create a new input model with a unique name

name = func.name
docstring = parse_docstring(func._wrapped.__doc__ or "") # pylint: disable=protected-access
summary = docstring.short_description
description = docstring.long_description
operation_id = func.name

input = create_model(
f"{operation_id}_Input",
__base__=func.input,
)

output = create_model(
f"{operation_id}_Output",
__base__=func.output,
)

@app.post(
path,
name=name,
summary=summary,
description=description,
operation_id=operation_id,
callbacks=create_file_upload_callback_router(operation_id).routes,
responses={
"default": {"model": Problem, "description": "Default error response"}
},
)
async def endpoint(
input: input, # type: ignore
prefer: Annotated[list[str] | None, Header()] = None,
) -> output: # type: ignore
preferences = parse_prefer_headers(prefer)

input_dict = input.model_dump(mode="python")
if asyncio.iscoroutinefunction(func):
task = asyncio.create_task(func(**input_dict))
else:
coroutine = asyncio.to_thread(func, **input_dict)
task = asyncio.create_task(coroutine)

id = app.state.tasks.defer(task)
done, _ = await asyncio.wait(
[task],
timeout=preferences.wait,
return_when=asyncio.FIRST_COMPLETED,
)
if done:
return done.pop().result()
else:
# If task was not completed within `wait` seconds, return the 202 response.
return JSONResponse(
status_code=202, content=None, headers={"Location": f"/tasks/{id}"}
)


def create_fastapi_app(
functions: list[Function],
title: str = "Hype API",
summary: str | None = None,
description: str = "",
version: str = "0.1.0",
) -> FastAPI:
@asynccontextmanager
async def lifespan(app: FastAPI): # noqa: ANN202
app.state.tasks = Tasks()

for function in functions:
add_fastapi_endpoint(app, function)

yield

await app.state.tasks.wait_until_empty()

app = FastAPI(
title=title,
summary=summary,
description=description,
version=version,
lifespan=lifespan,
)

FastAPIInstrumentor.instrument_app(app)

app.add_exception_handler(ValueError, problem_exception_handler)
app.add_exception_handler(HTTPException, problem_exception_handler)
app.add_exception_handler(RequestValidationError, problem_exception_handler)

@app.get("/tasks/{id}", include_in_schema=False)
def get_task(id: str) -> JSONResponse:
task = app.state.tasks.get(id)

if task is None:
raise HTTPException(status_code=404, detail="Task not found") from None

return JSONResponse(status_code=200, content=task.to_dict())

@app.post("/tasks/{id}/cancel", include_in_schema=False)
def cancel_task(id: str) -> JSONResponse:
task = app.state.tasks.get(id)

if task is None:
raise HTTPException(status_code=404, detail="Task not found") from None

task.cancel()
return JSONResponse(status_code=200, content=task.to_dict())

@app.get("/openapi.json", include_in_schema=False)
def get_openapi_schema() -> dict:
return app.openapi()

return app
Loading

0 comments on commit 63f9103

Please sign in to comment.