Skip to content

Commit 190e8b1

Browse files
committed
init
0 parents  commit 190e8b1

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

.gitignore

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Generated by Cargo
2+
# will have compiled files and executables
3+
debug/
4+
target/
5+
6+
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
7+
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
8+
Cargo.lock
9+
10+
# These are backup files generated by rustfmt
11+
**/*.rs.bk
12+
13+
# MSVC Windows builds of rustc generate these, which store debugging information
14+
*.pdb
15+
16+
.idea

Cargo.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "silly-cors"
3+
version = "0.1.0"
4+
authors = ["dubzer"]
5+
repository = "https://github.com/dubzer/silly-cors"
6+
edition = "2021"
7+
8+
[dependencies]
9+
hyper = { version = "0.14.21", features = ["full"] }
10+
hyper-tls = "0.5.0"
11+
tokio = { version = "1", features = ["full"] }

src/main.rs

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
use std::str::FromStr;
2+
use hyper::http::HeaderValue;
3+
use hyper::service::{make_service_fn, service_fn};
4+
use hyper::{Body, Request, Response, Server, StatusCode, Client, HeaderMap, Method, Uri};
5+
use hyper::http::uri::{Authority, Scheme};
6+
use hyper_tls::HttpsConnector;
7+
8+
9+
type GenericError = Box<dyn std::error::Error + Send + Sync>;
10+
type Result<T> = std::result::Result<T, GenericError>;
11+
12+
async fn handle(mut req: Request<Body>) -> Result<Response<Body>> {
13+
let Some(destination_header) = req.headers_mut().remove("x-destination") else {
14+
return Ok(validation_error("can i hav some of that x-destination header, pwease? 🥺", None));
15+
};
16+
17+
let Some(origin) = req.headers().get("Origin") else {
18+
return Ok(validation_error("can i hav some of that origin header, pwease? 🥺", None));
19+
};
20+
21+
let origin = origin.clone();
22+
23+
let Ok(authority) = Authority::from_str(destination_header.to_str().unwrap()) else {
24+
return Ok(validation_error("your x-destination header looks like an invalid domain 🥺", Some(&origin)))
25+
};
26+
27+
let mut uri_parts = req.uri().clone().into_parts();
28+
uri_parts.authority = Some(authority);
29+
uri_parts.scheme = Some(Scheme::HTTPS);
30+
31+
*req.headers_mut().get_mut("Host").unwrap() = destination_header;
32+
*req.uri_mut() = Uri::from_parts(uri_parts).unwrap();
33+
34+
let client = Client::builder().build(HttpsConnector::new());
35+
36+
let client_response = match client.request(req).await {
37+
Ok(result) => result,
38+
Err(err) => return Ok(validation_error(format!("oops, couldn't connect to destination :(\n{}", err).as_str(), Some(&origin)))
39+
};
40+
41+
let (mut parts, body) = client_response.into_parts();
42+
43+
parts.headers.extend(get_default_cors(&origin));
44+
45+
return Ok(Response::from_parts(parts, body));
46+
}
47+
48+
async fn handle_options(req: Request<Body>) -> Result<Response<Body>> {
49+
let Some(origin) = req.headers().get("origin") else {
50+
let response = Response::builder()
51+
.status(StatusCode::BAD_REQUEST)
52+
.body(Body::empty()).unwrap();
53+
54+
return Ok(response);
55+
};
56+
57+
let mut response = Response::builder().status(StatusCode::OK);
58+
59+
let mut headers = get_default_cors(origin);
60+
headers.insert("Access-Control-Allow-Headers", HeaderValue::from_static("*"));
61+
62+
response.headers_mut().unwrap().extend(headers);
63+
return Ok(response.body(Body::empty()).unwrap())
64+
}
65+
66+
async fn route(req: Request<Body>) -> Result<Response<Body>> {
67+
println!("req: {:?}", req);
68+
match (req.method(), req.uri().path()) {
69+
(&Method::OPTIONS, _) => handle_options(req).await,
70+
_ => handle(req).await
71+
}
72+
}
73+
74+
fn validation_error(message: &str, origin: Option<&HeaderValue>) -> Response<Body> {
75+
let mut response = Response::builder().status(StatusCode::BAD_REQUEST);
76+
if let Some(origin) = origin {
77+
let headers = get_default_cors(origin);
78+
response.headers_mut().unwrap().extend(headers);
79+
}
80+
81+
return response.body(Body::from(message.to_string())).unwrap()
82+
}
83+
84+
fn get_default_cors(origin: &HeaderValue) -> HeaderMap {
85+
let mut cors_headers = HeaderMap::new();
86+
cors_headers.insert("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
87+
cors_headers.insert("Access-Control-Allow-Origin", origin.clone());
88+
cors_headers.insert("Access-Control-Allow-Methods", HeaderValue::from_static("GET, PUT, POST, DELETE, HEAD, PATCH, OPTIONS"));
89+
return cors_headers;
90+
}
91+
92+
93+
#[tokio::main]
94+
pub async fn main() -> Result<()> {
95+
// For every connection, we must make a `Service` to handle all
96+
// incoming HTTP requests on said connection.
97+
let make_svc = make_service_fn(|_conn| {
98+
// This is the `Service` that will handle the connection.
99+
// `service_fn` is a helper to convert a function that
100+
// returns a Response into a `Service`.
101+
async { Ok::<_, GenericError>(service_fn(route)) }
102+
});
103+
104+
let addr = ([127, 0, 0, 1], 3001).into();
105+
106+
let server = Server::bind(&addr).serve(make_svc);
107+
108+
println!("Listening on http://{}", addr);
109+
110+
server.await?;
111+
112+
Ok(())
113+
}

0 commit comments

Comments
 (0)