-
-
Notifications
You must be signed in to change notification settings - Fork 272
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat before request #1125
base: main
Are you sure you want to change the base?
Feat before request #1125
Changes from all commits
0f16260
8e19354
2658b29
87a4c02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
import pathlib | ||
from collections import defaultdict | ||
from typing import Optional | ||
import uuid | ||
|
||
from integration_tests.subroutes import di_subrouter, sub_router | ||
from robyn import Headers, Request, Response, Robyn, WebSocket, WebSocketConnector, jsonify, serve_file, serve_html | ||
|
@@ -126,6 +127,8 @@ def shutdown_handler(): | |
|
||
@app.before_request() | ||
def global_before_request(request: Request): | ||
if "trace_id" not in request.headers: | ||
request.headers["trace_id"] = uuid.uuid4().hex | ||
request.headers.set("global_before", "global_before_request") | ||
return request | ||
|
||
|
@@ -136,6 +139,37 @@ def global_after_request(response: Response): | |
return response | ||
|
||
|
||
@app.after_request() | ||
def global_after_request(response: Response, request: Request): | ||
response.headers.set("global_after", "global_after_request") | ||
response.headers["trace_id"] = request.headers["trace_id"] | ||
return response | ||
|
||
|
||
@app.get("/sync/global/middlewares/with_request") | ||
def sync_global_middlewares_with_request(request: Request): | ||
print(request.headers) | ||
return "sync global middlewares with request" | ||
|
||
|
||
@app.after_request("/sync/global/middlewares/with_request") | ||
def sync_global_middlewares_before_with_request(response: Response, request: Request): | ||
response.headers["trace_id"] = request.headers["trace_id"] | ||
return response | ||
|
||
|
||
@app.get("/async/global/middlewares/with_request") | ||
def sync_global_middlewares_with_request(request: Request): | ||
print(request.headers) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same question here. |
||
return "async global middlewares with request" | ||
|
||
|
||
@app.after_request("/async/global/middlewares/with_request") | ||
def sync_global_middlewares_before_with_request(response: Response, request: Request): | ||
response.headers["trace_id"] = request.headers["trace_id"] | ||
return response | ||
|
||
|
||
@app.get("/sync/global/middlewares") | ||
def sync_global_middlewares(request: Request): | ||
print(request.headers) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,16 +3,16 @@ | |
/// i.e. the functions that have the responsibility of parsing and executing functions. | ||
pub mod web_socket_executors; | ||
|
||
use std::sync::Arc; | ||
use crate::types::{ | ||
function_info::FunctionInfo, request::Request, response::Response, MiddlewareReturn, | ||
}; | ||
|
||
use anyhow::Result; | ||
use log::debug; | ||
use pyo3::prelude::*; | ||
use pyo3::types::PyTuple; | ||
use pyo3_asyncio::TaskLocals; | ||
|
||
use crate::types::{ | ||
function_info::FunctionInfo, request::Request, response::Response, MiddlewareReturn, | ||
}; | ||
use std::sync::Arc; | ||
|
||
#[inline] | ||
fn get_function_output<'a, T>( | ||
|
@@ -40,7 +40,13 @@ where | |
handler.call1((function_args,)) | ||
} | ||
} | ||
_ => handler.call((function_args,), Some(kwargs)), | ||
_ => { | ||
if let Ok(tuple) = function_args.downcast::<PyTuple>(py) { | ||
handler.call(tuple, Some(kwargs)) | ||
} else { | ||
handler.call((function_args,), Some(kwargs)) | ||
} | ||
Comment on lines
+44
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you please explain this line of code to me? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In Rust, this syntax (function_args,) creates a nested tuple in Python, similar to ((arg1,arg2),). |
||
} | ||
} | ||
} | ||
|
||
|
@@ -79,6 +85,60 @@ where | |
}) | ||
} | ||
} | ||
// Execute middleware function after receiving a response | ||
// | ||
// This function handles post-request middleware logic that can receive both | ||
// the response and the original request as parameters. | ||
// T represents the response type, R represents the request type. | ||
// | ||
// The function determines whether to pass just the response or both response and request | ||
// to the middleware function based on the number of parameters it accepts. | ||
pub async fn execute_middleware_after_request<T, T2>( | ||
response: &T, | ||
request: &T2, | ||
function: &FunctionInfo, | ||
) -> Result<MiddlewareReturn> | ||
where | ||
T: for<'a> FromPyObject<'a> + ToPyObject, | ||
T2: for<'a> FromPyObject<'a> + ToPyObject, | ||
{ | ||
if function.is_async { | ||
let output: Py<PyAny> = Python::with_gil(|py| { | ||
let result = if function.number_of_params == 2 { | ||
pyo3_asyncio::tokio::into_future(get_function_output( | ||
function, | ||
py, | ||
&(response, request), | ||
)?) | ||
} else { | ||
pyo3_asyncio::tokio::into_future(get_function_output(function, py, response)?) | ||
}; | ||
result | ||
})? | ||
.await?; | ||
|
||
Python::with_gil(|py| -> Result<MiddlewareReturn> { | ||
let output_response = output.extract::<Response>(py); | ||
match output_response { | ||
Ok(o) => Ok(MiddlewareReturn::Response(o)), | ||
Err(_) => Ok(MiddlewareReturn::Request(output.extract::<Request>(py)?)), | ||
} | ||
}) | ||
} else { | ||
Python::with_gil(|py| -> Result<MiddlewareReturn> { | ||
let output = if function.number_of_params == 2 { | ||
get_function_output(function, py, &(response, request))? | ||
} else { | ||
get_function_output(function, py, response)? | ||
}; | ||
debug!("Middleware output: {:?}", output); | ||
match output.extract::<Response>() { | ||
Ok(o) => Ok(MiddlewareReturn::Response(o)), | ||
Err(_) => Ok(MiddlewareReturn::Request(output.extract::<Request>()?)), | ||
} | ||
}) | ||
} | ||
} | ||
|
||
#[inline] | ||
pub async fn execute_http_function( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
use crate::executors::{ | ||
execute_http_function, execute_middleware_function, execute_startup_handler, | ||
execute_http_function, execute_middleware_after_request, execute_middleware_function, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can do slightly better here. i.e. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is your suggestion then? |
||
execute_startup_handler, | ||
}; | ||
|
||
use crate::routers::const_router::ConstRouter; | ||
|
@@ -518,26 +519,27 @@ async fn index( | |
after_middlewares.push(function); | ||
} | ||
for after_middleware in after_middlewares { | ||
response = match execute_middleware_function(&response, &after_middleware).await { | ||
Ok(MiddlewareReturn::Request(_)) => { | ||
error!("After middleware returned a request"); | ||
return Response::internal_server_error(Some(&response.headers)); | ||
} | ||
Ok(MiddlewareReturn::Response(r)) => { | ||
let response = r; | ||
response = | ||
match execute_middleware_after_request(&response, &request, &after_middleware).await { | ||
Ok(MiddlewareReturn::Request(_)) => { | ||
error!("After middleware returned a request"); | ||
return Response::internal_server_error(Some(&response.headers)); | ||
} | ||
Ok(MiddlewareReturn::Response(r)) => { | ||
let response = r; | ||
|
||
debug!("Response returned: {:?}", response); | ||
response | ||
} | ||
Err(e) => { | ||
error!( | ||
"Error while executing after middleware function for endpoint `{}`: {}", | ||
req.uri().path(), | ||
get_traceback(e.downcast_ref::<PyErr>().unwrap()) | ||
); | ||
return Response::internal_server_error(Some(&response.headers)); | ||
} | ||
}; | ||
debug!("Response returned: {:?}", response); | ||
response | ||
} | ||
Err(e) => { | ||
error!( | ||
"Error while executing after middleware function for endpoint `{}`: {}", | ||
req.uri().path(), | ||
get_traceback(e.downcast_ref::<PyErr>().unwrap()) | ||
); | ||
return Response::internal_server_error(Some(&response.headers)); | ||
} | ||
}; | ||
} | ||
|
||
debug!("Response returned: {:?}", response); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Used for testing