1
+ use std:: str:: FromStr ;
2
+ use hyper:: http:: HeaderValue ;
3
+ use hyper:: service:: { make_service_fn, service_fn} ;
4
+ use hyper:: { Body , Request , Response , Server , StatusCode , Client , HeaderMap , Method , Uri } ;
5
+ use hyper:: http:: uri:: { Authority , Scheme } ;
6
+ use hyper_tls:: HttpsConnector ;
7
+
8
+
9
+ type GenericError = Box < dyn std:: error:: Error + Send + Sync > ;
10
+ type Result < T > = std:: result:: Result < T , GenericError > ;
11
+
12
+ async fn handle ( mut req : Request < Body > ) -> Result < Response < Body > > {
13
+ let Some ( destination_header) = req. headers_mut ( ) . remove ( "x-destination" ) else {
14
+ return Ok ( validation_error ( "can i hav some of that x-destination header, pwease? 🥺" , None ) ) ;
15
+ } ;
16
+
17
+ let Some ( origin) = req. headers ( ) . get ( "Origin" ) else {
18
+ return Ok ( validation_error ( "can i hav some of that origin header, pwease? 🥺" , None ) ) ;
19
+ } ;
20
+
21
+ let origin = origin. clone ( ) ;
22
+
23
+ let Ok ( authority) = Authority :: from_str ( destination_header. to_str ( ) . unwrap ( ) ) else {
24
+ return Ok ( validation_error ( "your x-destination header looks like an invalid domain 🥺" , Some ( & origin) ) )
25
+ } ;
26
+
27
+ let mut uri_parts = req. uri ( ) . clone ( ) . into_parts ( ) ;
28
+ uri_parts. authority = Some ( authority) ;
29
+ uri_parts. scheme = Some ( Scheme :: HTTPS ) ;
30
+
31
+ * req. headers_mut ( ) . get_mut ( "Host" ) . unwrap ( ) = destination_header;
32
+ * req. uri_mut ( ) = Uri :: from_parts ( uri_parts) . unwrap ( ) ;
33
+
34
+ let client = Client :: builder ( ) . build ( HttpsConnector :: new ( ) ) ;
35
+
36
+ let client_response = match client. request ( req) . await {
37
+ Ok ( result) => result,
38
+ Err ( err) => return Ok ( validation_error ( format ! ( "oops, couldn't connect to destination :(\n {}" , err) . as_str ( ) , Some ( & origin) ) )
39
+ } ;
40
+
41
+ let ( mut parts, body) = client_response. into_parts ( ) ;
42
+
43
+ parts. headers . extend ( get_default_cors ( & origin) ) ;
44
+
45
+ return Ok ( Response :: from_parts ( parts, body) ) ;
46
+ }
47
+
48
+ async fn handle_options ( req : Request < Body > ) -> Result < Response < Body > > {
49
+ let Some ( origin) = req. headers ( ) . get ( "origin" ) else {
50
+ let response = Response :: builder ( )
51
+ . status ( StatusCode :: BAD_REQUEST )
52
+ . body ( Body :: empty ( ) ) . unwrap ( ) ;
53
+
54
+ return Ok ( response) ;
55
+ } ;
56
+
57
+ let mut response = Response :: builder ( ) . status ( StatusCode :: OK ) ;
58
+
59
+ let mut headers = get_default_cors ( origin) ;
60
+ headers. insert ( "Access-Control-Allow-Headers" , HeaderValue :: from_static ( "*" ) ) ;
61
+
62
+ response. headers_mut ( ) . unwrap ( ) . extend ( headers) ;
63
+ return Ok ( response. body ( Body :: empty ( ) ) . unwrap ( ) )
64
+ }
65
+
66
+ async fn route ( req : Request < Body > ) -> Result < Response < Body > > {
67
+ println ! ( "req: {:?}" , req) ;
68
+ match ( req. method ( ) , req. uri ( ) . path ( ) ) {
69
+ ( & Method :: OPTIONS , _) => handle_options ( req) . await ,
70
+ _ => handle ( req) . await
71
+ }
72
+ }
73
+
74
+ fn validation_error ( message : & str , origin : Option < & HeaderValue > ) -> Response < Body > {
75
+ let mut response = Response :: builder ( ) . status ( StatusCode :: BAD_REQUEST ) ;
76
+ if let Some ( origin) = origin {
77
+ let headers = get_default_cors ( origin) ;
78
+ response. headers_mut ( ) . unwrap ( ) . extend ( headers) ;
79
+ }
80
+
81
+ return response. body ( Body :: from ( message. to_string ( ) ) ) . unwrap ( )
82
+ }
83
+
84
+ fn get_default_cors ( origin : & HeaderValue ) -> HeaderMap {
85
+ let mut cors_headers = HeaderMap :: new ( ) ;
86
+ cors_headers. insert ( "Access-Control-Allow-Credentials" , HeaderValue :: from_static ( "true" ) ) ;
87
+ cors_headers. insert ( "Access-Control-Allow-Origin" , origin. clone ( ) ) ;
88
+ cors_headers. insert ( "Access-Control-Allow-Methods" , HeaderValue :: from_static ( "GET, PUT, POST, DELETE, HEAD, PATCH, OPTIONS" ) ) ;
89
+ return cors_headers;
90
+ }
91
+
92
+
93
+ #[ tokio:: main]
94
+ pub async fn main ( ) -> Result < ( ) > {
95
+ // For every connection, we must make a `Service` to handle all
96
+ // incoming HTTP requests on said connection.
97
+ let make_svc = make_service_fn ( |_conn| {
98
+ // This is the `Service` that will handle the connection.
99
+ // `service_fn` is a helper to convert a function that
100
+ // returns a Response into a `Service`.
101
+ async { Ok :: < _ , GenericError > ( service_fn ( route) ) }
102
+ } ) ;
103
+
104
+ let addr = ( [ 127 , 0 , 0 , 1 ] , 3001 ) . into ( ) ;
105
+
106
+ let server = Server :: bind ( & addr) . serve ( make_svc) ;
107
+
108
+ println ! ( "Listening on http://{}" , addr) ;
109
+
110
+ server. await ?;
111
+
112
+ Ok ( ( ) )
113
+ }
0 commit comments