Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat before request #1125

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
from collections import defaultdict
from typing import Optional
import uuid

from integration_tests.subroutes import di_subrouter, sub_router
from robyn import Headers, Request, Response, Robyn, WebSocket, WebSocketConnector, jsonify, serve_file, serve_html
Expand Down Expand Up @@ -126,6 +127,8 @@ def shutdown_handler():

@app.before_request()
def global_before_request(request: Request):
if "trace_id" not in request.headers:
request.headers["trace_id"] = uuid.uuid4().hex
request.headers.set("global_before", "global_before_request")
return request

Expand All @@ -136,6 +139,37 @@ def global_after_request(response: Response):
return response


@app.after_request()
def global_after_request(response: Response, request: Request):
response.headers.set("global_after", "global_after_request")
response.headers["trace_id"] = request.headers["trace_id"]
return response


@app.get("/sync/global/middlewares/with_request")
def sync_global_middlewares_with_request(request: Request):
print(request.headers)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used for testing

return "sync global middlewares with request"


@app.after_request("/sync/global/middlewares/with_request")
def sync_global_middlewares_before_with_request(response: Response, request: Request):
response.headers["trace_id"] = request.headers["trace_id"]
return response


@app.get("/async/global/middlewares/with_request")
def sync_global_middlewares_with_request(request: Request):
print(request.headers)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question here.

return "async global middlewares with request"


@app.after_request("/async/global/middlewares/with_request")
def sync_global_middlewares_before_with_request(response: Response, request: Request):
response.headers["trace_id"] = request.headers["trace_id"]
return response


@app.get("/sync/global/middlewares")
def sync_global_middlewares(request: Request):
print(request.headers)
Expand Down
72 changes: 66 additions & 6 deletions src/executors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
/// i.e. the functions that have the responsibility of parsing and executing functions.
pub mod web_socket_executors;

use std::sync::Arc;
use crate::types::{
function_info::FunctionInfo, request::Request, response::Response, MiddlewareReturn,
};

use anyhow::Result;
use log::debug;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use pyo3_asyncio::TaskLocals;

use crate::types::{
function_info::FunctionInfo, request::Request, response::Response, MiddlewareReturn,
};
use std::sync::Arc;

#[inline]
fn get_function_output<'a, T>(
Expand Down Expand Up @@ -40,7 +40,13 @@ where
handler.call1((function_args,))
}
}
_ => handler.call((function_args,), Some(kwargs)),
_ => {
if let Ok(tuple) = function_args.downcast::<PyTuple>(py) {
handler.call(tuple, Some(kwargs))
} else {
handler.call((function_args,), Some(kwargs))
}
Comment on lines +44 to +48
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please explain this line of code to me?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Rust, this syntax (function_args,) creates a nested tuple in Python, similar to ((arg1,arg2),).

}
}
}

Expand Down Expand Up @@ -79,6 +85,60 @@ where
})
}
}
// Execute middleware function after receiving a response
//
// This function handles post-request middleware logic that can receive both
// the response and the original request as parameters.
// T represents the response type, R represents the request type.
//
// The function determines whether to pass just the response or both response and request
// to the middleware function based on the number of parameters it accepts.
pub async fn execute_middleware_after_request<T, T2>(
response: &T,
request: &T2,
function: &FunctionInfo,
) -> Result<MiddlewareReturn>
where
T: for<'a> FromPyObject<'a> + ToPyObject,
T2: for<'a> FromPyObject<'a> + ToPyObject,
{
if function.is_async {
let output: Py<PyAny> = Python::with_gil(|py| {
let result = if function.number_of_params == 2 {
pyo3_asyncio::tokio::into_future(get_function_output(
function,
py,
&(response, request),
)?)
} else {
pyo3_asyncio::tokio::into_future(get_function_output(function, py, response)?)
};
result
})?
.await?;

Python::with_gil(|py| -> Result<MiddlewareReturn> {
let output_response = output.extract::<Response>(py);
match output_response {
Ok(o) => Ok(MiddlewareReturn::Response(o)),
Err(_) => Ok(MiddlewareReturn::Request(output.extract::<Request>(py)?)),
}
})
} else {
Python::with_gil(|py| -> Result<MiddlewareReturn> {
let output = if function.number_of_params == 2 {
get_function_output(function, py, &(response, request))?
} else {
get_function_output(function, py, response)?
};
debug!("Middleware output: {:?}", output);
match output.extract::<Response>() {
Ok(o) => Ok(MiddlewareReturn::Response(o)),
Err(_) => Ok(MiddlewareReturn::Request(output.extract::<Request>()?)),
}
})
}
}

#[inline]
pub async fn execute_http_function(
Expand Down
42 changes: 22 additions & 20 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::executors::{
execute_http_function, execute_middleware_function, execute_startup_handler,
execute_http_function, execute_middleware_after_request, execute_middleware_function,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can do slightly better here.

i.e. execute_middleware_after_request and execute_middleware_function are not correct nomenclature imo.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is your suggestion then?

execute_startup_handler,
};

use crate::routers::const_router::ConstRouter;
Expand Down Expand Up @@ -518,26 +519,27 @@ async fn index(
after_middlewares.push(function);
}
for after_middleware in after_middlewares {
response = match execute_middleware_function(&response, &after_middleware).await {
Ok(MiddlewareReturn::Request(_)) => {
error!("After middleware returned a request");
return Response::internal_server_error(Some(&response.headers));
}
Ok(MiddlewareReturn::Response(r)) => {
let response = r;
response =
match execute_middleware_after_request(&response, &request, &after_middleware).await {
Ok(MiddlewareReturn::Request(_)) => {
error!("After middleware returned a request");
return Response::internal_server_error(Some(&response.headers));
}
Ok(MiddlewareReturn::Response(r)) => {
let response = r;

debug!("Response returned: {:?}", response);
response
}
Err(e) => {
error!(
"Error while executing after middleware function for endpoint `{}`: {}",
req.uri().path(),
get_traceback(e.downcast_ref::<PyErr>().unwrap())
);
return Response::internal_server_error(Some(&response.headers));
}
};
debug!("Response returned: {:?}", response);
response
}
Err(e) => {
error!(
"Error while executing after middleware function for endpoint `{}`: {}",
req.uri().path(),
get_traceback(e.downcast_ref::<PyErr>().unwrap())
);
return Response::internal_server_error(Some(&response.headers));
}
};
}

debug!("Response returned: {:?}", response);
Expand Down