feat: add support for streaming responses #1087

Open · wants to merge 14 commits into main
update
sansyrox committed Dec 16, 2024
commit 67b7c5dae1aefb490fd75ed16804288df59ad127
207 changes: 207 additions & 0 deletions docs_src/src/pages/documentation/streaming.md
@@ -0,0 +1,207 @@
# Streaming Responses in Robyn

Robyn supports streaming responses for various use cases including real-time data, large file downloads, and server-sent events (SSE). This document explains how to use streaming responses effectively.

## Basic Usage

### Simple Streaming Response

Pass a generator to `StreamingResponse` through the `description` argument:

```python
from robyn import Robyn, StreamingResponse

app = Robyn(__file__)

@app.get("/stream")
async def stream():
    async def generator():
        for i in range(5):
            yield f"Chunk {i}\n".encode()

    return StreamingResponse(
        status_code=200,
        headers={"Content-Type": "text/plain"},
        description=generator()
    )
```

## Supported Types

Robyn's streaming response system supports multiple data types (a combined example follows the list):

1. **Binary Data** (`bytes`)
```python
yield b"Binary data"
```

2. **Text Data** (`str`)
```python
yield "String data".encode()
```

3. **Numbers** (`int`, `float`)
```python
yield str(42).encode()
```

4. **JSON Data**
```python
import json
data = {"key": "value"}
yield json.dumps(data).encode()
```
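
These types can be mixed in a single generator. The sketch below mirrors the `/stream/mixed` route from this PR's integration tests; the route path and payload values are illustrative:

```python
import json

@app.get("/stream/mixed")
async def mixed_stream():
    async def generator():
        yield b"Binary chunk\n"                # bytes
        yield "String chunk\n".encode()        # str, encoded
        yield str(42).encode() + b"\n"         # int, converted then encoded
        yield json.dumps({"message": "JSON chunk", "number": 123}).encode() + b"\n"  # JSON

    return StreamingResponse(
        status_code=200,
        headers={"Content-Type": "text/plain"},
        description=generator()
    )
```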

## Use Cases

### 1. Server-Sent Events (SSE)

SSE allows real-time updates from server to client:

```python
import json

@app.get("/events")
async def sse():
    async def event_generator():
        data = {"type": "start", "message": "hello"}  # example payload
        yield f"event: message\ndata: {json.dumps(data)}\n\n".encode()

    return StreamingResponse(
        status_code=200,
        headers={
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive"
        },
        description=event_generator()
    )
```

Client usage:
```javascript
const evtSource = new EventSource("/events");
evtSource.onmessage = (event) => {
    console.log(JSON.parse(event.data));
};
```

### 2. Large File Downloads

Stream large files in chunks to manage memory usage:

```python
@app.get("/download")
async def download():
    async def file_generator():
        chunk_size = 8192  # 8KB chunks
        with open("large_file.bin", "rb") as f:
            while chunk := f.read(chunk_size):
                yield chunk

    return StreamingResponse(
        status_code=200,
        headers={
            "Content-Type": "application/octet-stream",
            "Content-Disposition": "attachment; filename=file.bin"
        },
        description=file_generator()
    )
```

### 3. CSV Generation

Stream CSV data as it's generated:

```python
@app.get("/csv")
async def csv():
    async def csv_generator():
        yield "header1,header2\n".encode()
        for item in data:  # `data` is assumed to be an iterable of records
            yield f"{item.field1},{item.field2}\n".encode()

    return StreamingResponse(
        status_code=200,
        headers={
            "Content-Type": "text/csv",
            "Content-Disposition": "attachment; filename=data.csv"
        },
        description=csv_generator()
    )
```

## Best Practices

1. **Always encode your data**
- Convert strings to bytes using `.encode()`
- Use `json.dumps().encode()` for JSON data

2. **Set appropriate headers**
- Use correct Content-Type
- Add Content-Disposition for downloads
- Set Cache-Control for SSE

3. **Handle errors gracefully**
```python
async def generator():
    try:
        for item in items:
            yield process(item)
    except Exception as e:
        yield f"Error: {str(e)}".encode()
```

4. **Memory management**
   - Use appropriate chunk sizes
   - Don't hold the entire dataset in memory
   - Clean up resources after streaming (see the sketch below)
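
A minimal sketch of the cleanup point, assuming a hypothetical `open_connection()` helper: a `try`/`finally` inside the generator releases the resource when the generator finishes or is closed.

```python
@app.get("/stream/rows")
async def stream_rows():
    async def generator():
        conn = open_connection()  # hypothetical resource acquisition
        try:
            for row in conn.iter_rows():  # hypothetical row iterator
                yield f"{row}\n".encode()
        finally:
            conn.close()  # runs even if streaming stops early

    return StreamingResponse(
        status_code=200,
        headers={"Content-Type": "text/plain"},
        description=generator()
    )
```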

## Testing

Test streaming responses using the test client:

```python
import pytest

@pytest.mark.asyncio
async def test_stream():
    async with app.test_client() as client:
        response = await client.get("/stream")
        chunks = []
        async for chunk in response.content:
            chunks.append(chunk)
        # Assert on the collected chunks
```
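
The integration tests in this PR instead exercise the endpoints over HTTP with `aiohttp` against a running server (assumed here to listen on `127.0.0.1:8080`):

```python
import aiohttp
import pytest

@pytest.mark.asyncio
async def test_sync_stream():
    async with aiohttp.ClientSession() as client:
        async with client.get("http://127.0.0.1:8080/stream/sync") as response:
            assert response.status == 200
            assert response.headers["Content-Type"] == "text/plain"

            chunks = []
            async for chunk in response.content:
                chunks.append(chunk.decode())

            assert chunks == [f"Chunk {i}\n" for i in range(5)]
```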

## Common Issues

1. **Forgetting to encode data**
```python
# Wrong
yield "data" # Will fail
# Correct
yield "data".encode()
```

2. **Not setting correct headers**
```python
# SSE needs specific headers
headers = {
    "Content-Type": "text/event-stream",
    "Cache-Control": "no-cache",
    "Connection": "keep-alive"
}
```

3. **Memory leaks**
```python
# Wrong: accumulates every chunk in memory
data = []
async def generator():
    for i in range(1000000):
        data.append(i)  # memory leak
        yield str(i).encode()

# Correct: yield without accumulating
async def generator():
    for i in range(1000000):
        yield str(i).encode()
```

## Performance Considerations

1. Use appropriate chunk sizes (typically 8KB-64KB)
2. Implement backpressure handling
3. Consider using async file I/O for large files (see the sketch below)
4. Monitor memory usage during streaming
5. Implement timeouts for long-running streams
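
A minimal sketch of points 1 and 3, assuming the third-party `aiofiles` package (not part of Robyn) is installed; file reads are awaited in 64KB chunks so they do not block the event loop:

```python
import aiofiles

@app.get("/download/async")
async def download_async():
    async def file_generator():
        chunk_size = 64 * 1024  # 64KB chunks
        async with aiofiles.open("large_file.bin", "rb") as f:
            while chunk := await f.read(chunk_size):
                yield chunk

    return StreamingResponse(
        status_code=200,
        headers={
            "Content-Type": "application/octet-stream",
            "Content-Disposition": "attachment; filename=large_file.bin"
        },
        description=file_generator()
    )
```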
58 changes: 25 additions & 33 deletions integration_tests/base_routes.py
@@ -5,7 +5,7 @@

from integration_tests.subroutes import di_subrouter, sub_router
from integration_tests.views import AsyncView, SyncView
from robyn import Headers, Request, Response, Robyn, WebSocket, WebSocketConnector, jsonify, serve_file, serve_html
from robyn import Headers, Request, Response, Robyn, WebSocket, WebSocketConnector, jsonify, serve_file, serve_html, StreamingResponse
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
from robyn.robyn import QueryParams, Url
from robyn.templating import JinjaTemplate
@@ -159,7 +159,10 @@ def sync_before_request(request: Request):
@app.after_request("/sync/middlewares")
def sync_after_request(response: Response):
response.headers.set("after", "sync_after_request")
response.description = response.description + " after"
if isinstance(response.description, bytes):
response.description = response.description + b" after"
else:
response.description = response.description + " after"
return response


@@ -180,7 +183,10 @@ async def async_before_request(request: Request):
@app.after_request("/async/middlewares")
async def async_after_request(response: Response):
response.headers.set("after", "async_after_request")
response.description = response.description + " after"
if isinstance(response.description, bytes):
response.description = response.description + b" after"
else:
response.description = response.description + " after"
return response


@@ -1085,56 +1091,42 @@ def create_item(request, body: CreateItemBody, query: CreateItemQueryParamsParam
# --- Streaming responses ---

@app.get("/stream/sync")
def sync_stream():
def number_generator():
async def sync_stream():
def generator():
for i in range(5):
yield f"Chunk {i}\n".encode()

return Response(
headers = Headers({"Content-Type": "text/plain"})
return StreamingResponse(
status_code=200,
headers={"Content-Type": "text/plain"},
description=number_generator()
description=generator(),
headers=headers
)

@app.get("/stream/async")
async def async_stream():
async def async_generator():
import asyncio
async def generator():
for i in range(5):
await asyncio.sleep(1) # Simulate async work
yield f"Async Chunk {i}\n".encode()

return Response(
return StreamingResponse(
status_code=200,
headers={"Content-Type": "text/plain"},
description=async_generator()
description=generator()
)

@app.get("/stream/mixed")
async def mixed_stream():
async def mixed_generator():
import asyncio
# Binary data
async def generator():
yield b"Binary chunk\n"
await asyncio.sleep(0.5)

# String data
yield "String chunk\n".encode()
await asyncio.sleep(0.5)

# Integer data
yield str(42).encode() + b"\n"
await asyncio.sleep(0.5)

# JSON data
import json
data = {"message": "JSON chunk", "number": 123}
yield json.dumps(data).encode() + b"\n"
yield json.dumps({"message": "JSON chunk", "number": 123}).encode() + b"\n"

return Response(
return StreamingResponse(
status_code=200,
headers={"Content-Type": "text/plain"},
description=mixed_generator()
description=generator()
)

@app.get("/stream/events")
@@ -1156,7 +1148,7 @@ async def event_generator():
data = json.dumps({'status': 'complete', 'results': [1, 2, 3]}, indent=2)
yield f"event: complete\ndata: {data}\n\n".encode()

return Response(
return StreamingResponse(
status_code=200,
headers={
"Content-Type": "text/event-stream",
@@ -1178,7 +1170,7 @@ async def file_generator():
chunk = b"X" * min(chunk_size, total_size - offset)
yield chunk

return Response(
return StreamingResponse(
status_code=200,
headers={
"Content-Type": "application/octet-stream",
@@ -1202,7 +1194,7 @@ async def csv_generator():
row = f"{i},item-{i},{random.randint(1, 100)}\n"
yield row.encode()

return Response(
return StreamingResponse(
status_code=200,
headers={
"Content-Type": "text/csv",
5 changes: 4 additions & 1 deletion integration_tests/conftest.py
@@ -8,10 +8,12 @@
from typing import List

import pytest
import pytest_asyncio
from robyn import Robyn
from integration_tests.base_routes import app

from integration_tests.helpers.network_helpers import get_network_host


def spawn_process(command: List[str]) -> subprocess.Popen:
if platform.system() == "Windows":
command[0] = "python"
@@ -127,3 +129,4 @@ def env_file():
env_path.unlink()
del os.environ["ROBYN_PORT"]
del os.environ["ROBYN_HOST"]

288 changes: 118 additions & 170 deletions integration_tests/test_streaming_responses.py
@@ -14,181 +14,129 @@

import json
import pytest
from robyn import Robyn
from robyn.robyn import Request
from integration_tests.base_routes import app
import aiohttp

# Mark all tests in this module as async
pytestmark = pytest.mark.asyncio

@pytest.mark.asyncio
async def test_sync_stream():
"""Test basic synchronous streaming response.
Verifies that:
1. Response has correct content type
2. Chunks are received in correct order
3. Each chunk has expected format
"""
async with app.test_client() as client:
response = await client.get("/stream/sync")
assert response.status_code == 200
assert response.headers["Content-Type"] == "text/plain"

chunks = []
async for chunk in response.content:
chunks.append(chunk.decode())

assert len(chunks) == 5
for i, chunk in enumerate(chunks):
assert chunk == f"Chunk {i}\n"


@pytest.mark.asyncio
"""Test basic synchronous streaming response."""
async with aiohttp.ClientSession() as client:
async with client.get("http://127.0.0.1:8080/stream/sync") as response:
assert response.status == 200
assert response.headers["Content-Type"] == "text/plain"

chunks = []
async for chunk in response.content:
chunks.append(chunk.decode())

assert len(chunks) == 5
for i, chunk in enumerate(chunks):
assert chunk == f"Chunk {i}\n"

async def test_async_stream():
"""Test asynchronous streaming response.
Verifies that:
1. Response has correct content type
2. Chunks are received in correct order with delays
3. Each chunk has expected format
"""
async with app.test_client() as client:
response = await client.get("/stream/async")
assert response.status_code == 200
assert response.headers["Content-Type"] == "text/plain"

chunks = []
async for chunk in response.content:
chunks.append(chunk.decode())

assert len(chunks) == 5
for i, chunk in enumerate(chunks):
assert chunk == f"Async Chunk {i}\n"


@pytest.mark.asyncio
"""Test asynchronous streaming response."""
async with aiohttp.ClientSession() as client:
async with client.get("http://127.0.0.1:8080/stream/async") as response:
assert response.status == 200
assert response.headers["Content-Type"] == "text/plain"

chunks = []
async for chunk in response.content:
chunks.append(chunk.decode())

assert len(chunks) == 5
for i, chunk in enumerate(chunks):
assert chunk == f"Async Chunk {i}\n"

async def test_mixed_stream():
"""Test streaming of mixed content types.
Verifies that:
1. Response handles different content types:
- Binary data
- String data
- Integer data
- JSON data
2. Each chunk is correctly encoded
"""
async with app.test_client() as client:
response = await client.get("/stream/mixed")
assert response.status_code == 200
assert response.headers["Content-Type"] == "text/plain"

expected = [
b"Binary chunk\n",
b"String chunk\n",
b"42\n",
json.dumps({"message": "JSON chunk", "number": 123}).encode() + b"\n"
]

chunks = []
async for chunk in response.content:
chunks.append(chunk)

assert len(chunks) == len(expected)
for chunk, expected_chunk in zip(chunks, expected):
assert chunk == expected_chunk


@pytest.mark.asyncio
"""Test streaming of mixed content types."""
async with aiohttp.ClientSession() as client:
async with client.get("http://127.0.0.1:8080/stream/mixed") as response:
assert response.status == 200
assert response.headers["Content-Type"] == "text/plain"

expected = [
b"Binary chunk\n",
b"String chunk\n",
b"42\n",
json.dumps({"message": "JSON chunk", "number": 123}).encode() + b"\n"
]

chunks = []
async for chunk in response.content:
chunks.append(chunk)

assert len(chunks) == len(expected)
for chunk, expected_chunk in zip(chunks, expected):
assert chunk == expected_chunk

async def test_server_sent_events():
"""Test Server-Sent Events (SSE) streaming.
Verifies that:
1. Response has correct SSE headers
2. Events are properly formatted with:
- Event type
- Event ID (when provided)
- Event data
"""
async with app.test_client() as client:
response = await client.get("/stream/events")
assert response.status_code == 200
assert response.headers["Content-Type"] == "text/event-stream"
assert response.headers["Cache-Control"] == "no-cache"
assert response.headers["Connection"] == "keep-alive"

events = []
async for chunk in response.content:
events.append(chunk.decode())

# Test first event (message)
assert "event: message\n" in events[0]
assert "data: {" in events[0]
event_data = json.loads(events[0].split("data: ")[1].strip())
assert "time" in event_data
assert event_data["type"] == "start"

# Test second event (with ID)
assert "id: 1\n" in events[1]
assert "event: update\n" in events[1]
event_data = json.loads(events[1].split("data: ")[1].strip())
assert event_data["progress"] == 50

# Test third event (complete)
assert "event: complete\n" in events[2]
event_data = json.loads(events[2].split("data: ")[1].strip())
assert event_data["status"] == "complete"
assert event_data["results"] == [1, 2, 3]


@pytest.mark.asyncio
"""Test Server-Sent Events (SSE) streaming."""
async with aiohttp.ClientSession() as client:
async with client.get("http://127.0.0.1:8080/stream/events") as response:
assert response.status == 200
assert response.headers["Content-Type"] == "text/event-stream"
assert response.headers["Cache-Control"] == "no-cache"
assert response.headers["Connection"] == "keep-alive"

events = []
async for chunk in response.content:
events.append(chunk.decode())

# Test first event (message)
assert "event: message\n" in events[0]
assert "data: {" in events[0]
event_data = json.loads(events[0].split("data: ")[1].strip())
assert "time" in event_data
assert event_data["type"] == "start"

# Test second event (with ID)
assert "id: 1\n" in events[1]
assert "event: update\n" in events[1]
event_data = json.loads(events[1].split("data: ")[1].strip())
assert event_data["progress"] == 50

# Test third event (complete)
assert "event: complete\n" in events[2]
event_data = json.loads(events[2].split("data: ")[1].strip())
assert event_data["status"] == "complete"
assert event_data["results"] == [1, 2, 3]

async def test_large_file_stream():
"""Test streaming of large files in chunks.
Verifies that:
1. Response has correct headers for file download
2. Content is streamed in correct chunk sizes
3. Total content length matches expected size
"""
async with app.test_client() as client:
response = await client.get("/stream/large-file")
assert response.status_code == 200
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["Content-Disposition"] == "attachment; filename=large-file.bin"

total_size = 0
async for chunk in response.content:
assert len(chunk) <= 1024 # Max chunk size
total_size += len(chunk)

assert total_size == 10 * 1024 # 10KB total


@pytest.mark.asyncio
"""Test streaming of large files in chunks."""
async with aiohttp.ClientSession() as client:
async with client.get("http://127.0.0.1:8080/stream/large-file") as response:
assert response.status == 200
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["Content-Disposition"] == "attachment; filename=large-file.bin"

total_size = 0
async for chunk in response.content:
assert len(chunk) <= 1024 # Max chunk size
total_size += len(chunk)

assert total_size == 10 * 1024 # 10KB total

async def test_csv_stream():
"""Test streaming of CSV data.
Verifies that:
1. Response has correct CSV headers
2. CSV content is properly formatted
3. All rows are received in correct order
"""
async with app.test_client() as client:
response = await client.get("/stream/csv")
assert response.status_code == 200
assert response.headers["Content-Type"] == "text/csv"
assert response.headers["Content-Disposition"] == "attachment; filename=data.csv"

lines = []
async for chunk in response.content:
lines.extend(chunk.decode().splitlines())

# Verify header
assert lines[0] == "id,name,value"

# Verify data rows
assert len(lines) == 6 # Header + 5 data rows
for i, line in enumerate(lines[1:], 0):
id_, name, value = line.split(',')
assert int(id_) == i
assert name == f"item-{i}"
assert 1 <= int(value) <= 100
"""Test streaming of CSV data."""
async with aiohttp.ClientSession() as client:
async with client.get("http://127.0.0.1:8080/stream/csv") as response:
assert response.status == 200
assert response.headers["Content-Type"] == "text/csv"
assert response.headers["Content-Disposition"] == "attachment; filename=data.csv"

lines = []
async for chunk in response.content:
lines.extend(chunk.decode().splitlines())

# Verify header
assert lines[0] == "id,name,value"

# Verify data rows
assert len(lines) == 6 # Header + 5 data rows
for i, line in enumerate(lines[1:], 0):
id_, name, value = line.split(',')
assert int(id_) == i
assert name == f"item-{i}"
assert 1 <= int(value) <= 100
163 changes: 162 additions & 1 deletion poetry.lock
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -82,6 +82,10 @@ optional = true

[tool.poetry.group.test.dependencies]
pytest = "7.2.1"
pytest-asyncio = "0.21.0"
pytest-cov = "4.0.0"
pytest-xdist = "3.6.1"
pytest-timeout = "2.1.0"
pytest-codspeed = "1.2.2"
requests = "2.28.2"
nox = "2023.4.22"
3 changes: 2 additions & 1 deletion robyn/__init__.py
@@ -20,7 +20,7 @@
from robyn.processpool import run_processes
from robyn.reloader import compile_rust_files
from robyn.responses import html, serve_file, serve_html
from robyn.robyn import FunctionInfo, Headers, HttpMethod, Request, Response, WebSocketConnector, get_version
from robyn.robyn import FunctionInfo, Headers, HttpMethod, Request, Response, WebSocketConnector, get_version, StreamingResponse
from robyn.router import MiddlewareRouter, MiddlewareType, Router, WebSocketRouter
from robyn.types import Directory
from robyn.ws import WebSocket
@@ -673,6 +673,7 @@ def cors_middleware(request):
"Robyn",
"Request",
"Response",
"StreamingResponse",
"status_codes",
"jsonify",
"serve_file",
16 changes: 15 additions & 1 deletion robyn/responses.py
@@ -1,6 +1,6 @@
import mimetypes
import os
from typing import Optional
from typing import Optional, Any

from robyn.robyn import Headers, Response

@@ -18,6 +18,20 @@ def __init__(
self.headers = headers or Headers({"Content-Disposition": "attachment"})


class StreamingResponse:
def __init__(
self,
status_code: int = 200,
description: Optional[Any] = None,
headers: Optional[Headers] = None,
):
self.status_code = status_code
self.description = description or []
self.headers = headers or Headers({})
self.response_type = "stream"
self.file_path = None


def html(html: str) -> Response:
"""
This function will help in serving a simple html string
24 changes: 18 additions & 6 deletions robyn/router.py
@@ -12,7 +12,7 @@
from robyn.dependency_injection import DependencyMap
from robyn.jsonify import jsonify
from robyn.responses import FileResponse
from robyn.robyn import FunctionInfo, Headers, HttpMethod, Identity, MiddlewareType, QueryParams, Request, Response, Url
from robyn.robyn import FunctionInfo, Headers, HttpMethod, Identity, MiddlewareType, QueryParams, Request, Response, StreamingResponse, Url
from robyn.types import Body, Files, FormData, IPAddress, Method, PathParams
from robyn.ws import WebSocket

@@ -47,29 +47,41 @@ def __init__(self) -> None:
super().__init__()
self.routes: List[Route] = []

def _format_tuple_response(self, res: tuple) -> Response:
def _format_tuple_response(self, res: tuple) -> Union[Response, StreamingResponse]:
if len(res) != 3:
raise ValueError("Tuple should have 3 elements")

description, headers, status_code = res
description = self._format_response(description).description
formatted_response = self._format_response(description)

# Handle StreamingResponse case
if isinstance(formatted_response, StreamingResponse):
formatted_response.headers.update(headers)
formatted_response.status_code = status_code
return formatted_response

# Regular Response case
new_headers: Headers = Headers(headers)
if new_headers.contains("Content-Type"):
headers.set("Content-Type", new_headers.get("Content-Type"))

return Response(
status_code=status_code,
headers=headers,
description=description,
description=formatted_response.description,
)

def _format_response(
self,
res: Union[Dict, Response, bytes, tuple, str],
) -> Response:
res: Union[Dict, Response, StreamingResponse, bytes, tuple, str],
) -> Union[Response, StreamingResponse]:
if isinstance(res, Response):
return res

# Special handling for StreamingResponse
if isinstance(res, StreamingResponse):
return res

if isinstance(res, dict):
return Response(
status_code=status_codes.HTTP_200_OK,
119 changes: 119 additions & 0 deletions src/base_routes.rs
@@ -0,0 +1,119 @@
use actix_web::{web, HttpRequest, HttpResponse};
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::types::{Headers, Response, StreamingResponse};

pub async fn handle_request(
req: HttpRequest,
path: web::Path<String>,
query: web::Query<std::collections::HashMap<String, String>>,
payload: web::Payload,
app_state: web::Data<PyObject>,
) -> HttpResponse {
let path = path.into_inner();
let query = query.into_inner();

Python::with_gil(|py| {
let app = app_state.as_ref();
let args = PyDict::new(py);

// Convert query params to Python dict
let query_dict = PyDict::new(py);
for (key, value) in query {
query_dict.set_item(key, value).unwrap();
}

// Create headers dict
let headers = Headers::new(None);

// Call the route handler
let result = app.call_method1(
py,
"handle_request",
(path, req.method().as_str(), query_dict, headers),
);

match result {
Ok(response) => {
// Try to extract as StreamingResponse first
match response.extract::<StreamingResponse>(py) {
Ok(streaming_response) => streaming_response.respond_to(&req),
Err(_) => {
// If not a StreamingResponse, try as regular Response
match response.extract::<Response>(py) {
Ok(response) => response.respond_to(&req),
Err(e) => {
// If extraction fails, return 500 error
let headers = Headers::new(None);
Response::internal_server_error(Some(&headers)).respond_to(&req)
}
}
}
}
}
Err(e) => {
// Handle Python error by returning 500
let headers = Headers::new(None);
Response::internal_server_error(Some(&headers)).respond_to(&req)
}
}
})
}

pub async fn handle_request_with_body(
req: HttpRequest,
path: web::Path<String>,
query: web::Query<std::collections::HashMap<String, String>>,
payload: web::Payload,
app_state: web::Data<PyObject>,
) -> HttpResponse {
let path = path.into_inner();
let query = query.into_inner();

Python::with_gil(|py| {
let app = app_state.as_ref();
let args = PyDict::new(py);

// Convert query params to Python dict
let query_dict = PyDict::new(py);
for (key, value) in query {
query_dict.set_item(key, value).unwrap();
}

// Create headers dict
let headers = Headers::new(None);

// Call the route handler
let result = app.call_method1(
py,
"handle_request_with_body",
(path, req.method().as_str(), query_dict, headers, payload),
);

match result {
Ok(response) => {
// Try to extract as StreamingResponse first
match response.extract::<StreamingResponse>(py) {
Ok(streaming_response) => streaming_response.respond_to(&req),
Err(_) => {
// If not a StreamingResponse, try as regular Response
match response.extract::<Response>(py) {
Ok(response) => response.respond_to(&req),
Err(e) => {
// If extraction fails, return 500 error
let headers = Headers::new(None);
Response::internal_server_error(Some(&headers)).respond_to(&req)
}
}
}
}
}
Err(e) => {
// Handle Python error by returning 500
let headers = Headers::new(None);
Response::internal_server_error(Some(&headers)).respond_to(&req)
}
}
})
}
3 changes: 2 additions & 1 deletion src/lib.rs
@@ -17,7 +17,7 @@ use types::{
identity::Identity,
multimap::QueryParams,
request::PyRequest,
response::PyResponse,
response::{PyResponse, PyStreamingResponse},
HttpMethod, Url,
};

@@ -42,6 +42,7 @@ pub fn robyn(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<Identity>()?;
m.add_class::<PyRequest>()?;
m.add_class::<PyResponse>()?;
m.add_class::<PyStreamingResponse>()?;
m.add_class::<Url>()?;
m.add_class::<QueryParams>()?;
m.add_class::<MiddlewareType>()?;
27 changes: 15 additions & 12 deletions src/types/mod.rs
@@ -1,4 +1,3 @@
use log::debug;
use pyo3::{
exceptions::PyValueError,
prelude::*,
@@ -79,8 +78,7 @@ pub fn get_body_from_pyobject(body: &PyAny) -> PyResult<Vec<u8>> {
} else if let Ok(b) = body.downcast::<PyBytes>() {
Ok(b.as_bytes().to_vec())
} else {
debug!("Could not convert specified body to bytes");
Ok(vec![])
Err(PyValueError::new_err("Body must be either string or bytes"))
}
}

@@ -89,26 +87,31 @@ pub fn get_description_from_pyobject(description: &PyAny) -> PyResult<Vec<u8>> {
Ok(s.to_string().into_bytes())
} else if let Ok(b) = description.downcast::<PyBytes>() {
Ok(b.as_bytes().to_vec())
} else if let Ok(i) = description.extract::<i64>() {
Ok(i.to_string().into_bytes())
} else {
debug!("Could not convert specified response description to bytes");
Ok(vec![])
Err(PyValueError::new_err("Description must be string, bytes, or integer"))
}
}

pub fn check_body_type(py: Python, body: &Py<PyAny>) -> PyResult<()> {
if body.downcast::<PyString>(py).is_err() && body.downcast::<PyBytes>(py).is_err() {
let body_ref = body.as_ref(py);
if !body_ref.is_instance_of::<PyString>() && !body_ref.is_instance_of::<PyBytes>() {
return Err(PyValueError::new_err(
"Could not convert specified body to bytes",
"Body must be either string or bytes"
));
};
}
Ok(())
}

pub fn check_description_type(py: Python, body: &Py<PyAny>) -> PyResult<()> {
if body.downcast::<PyString>(py).is_err() && body.downcast::<PyBytes>(py).is_err() {
pub fn check_description_type(py: Python, description: &Py<PyAny>) -> PyResult<()> {
let desc_ref = description.as_ref(py);
if !desc_ref.is_instance_of::<PyString>() &&
!desc_ref.is_instance_of::<PyBytes>() &&
!desc_ref.is_instance_of::<pyo3::types::PyInt>() {
return Err(PyValueError::new_err(
"Could not convert specified response description to bytes",
"Description must be string, bytes, or integer"
));
};
}
Ok(())
}
2 changes: 1 addition & 1 deletion src/types/request.rs
@@ -180,7 +180,7 @@ impl Request {
Self {
query_params,
headers,
method: req.method().as_str().to_owned(),
method: req.method().as_str().to_string(),
path_params: HashMap::new(),
body,
url,
379 changes: 306 additions & 73 deletions src/types/response.rs
