Skip to content

Commit

Permalink
feat: add configuration possibilities for CORS middleware (#705)
Browse files Browse the repository at this point in the history
fixes #701
  • Loading branch information
chriswk authored Feb 6, 2025
1 parent 43591b9 commit bf4360e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 6 deletions.
89 changes: 89 additions & 0 deletions server/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;

use actix_cors::Cors;
use actix_http::Method;
use cidr::{Ipv4Cidr, Ipv6Cidr};
use clap::{ArgGroup, Args, Parser, Subcommand, ValueEnum};

Expand Down Expand Up @@ -411,6 +413,53 @@ pub struct TlsOptions {
pub tls_server_port: u16,
}

pub fn parse_http_method(value: &str) -> Result<actix_http::Method, String> {
Method::from_bytes(value.as_bytes()).map_err(|f| format!("Failed to format method: {f:?}"))
}

#[derive(Args, Debug, Clone)]
pub struct CorsOptions {
#[clap(env, long, value_delimiter = ',')]
pub cors_origin: Option<Vec<String>>,
#[clap(env, long, value_delimiter = ',')]
pub cors_allowed_headers: Option<Vec<String>>,
#[clap(env, long, default_value_t = 172800)]
pub cors_max_age: usize,
#[clap(env, long, value_delimiter = ',')]
pub cors_exposed_headers: Option<Vec<String>>,
#[clap(env, long, value_delimiter = ',', value_parser = parse_http_method)]
pub cors_methods: Option<Vec<actix_http::Method>>,
}

impl CorsOptions {
pub fn middleware(&self) -> Cors {
let mut cors_middleware = Cors::default()
.max_age(self.cors_max_age)
.allow_any_method()
.allow_any_header();
if let Some(origins) = self.cors_origin.clone() {
for origin in origins {
cors_middleware = cors_middleware.allowed_origin(&origin);
}
cors_middleware = cors_middleware.supports_credentials();
} else {
cors_middleware = cors_middleware.allow_any_origin().send_wildcard();
}
if let Some(allowed_headers) = self.cors_allowed_headers.clone() {
for header in allowed_headers {
cors_middleware = cors_middleware.allowed_header(header);
}
}
if let Some(allowed_methods) = self.cors_methods.clone() {
cors_middleware = cors_middleware.allowed_methods(allowed_methods);
}
if let Some(exposed_headers) = self.cors_exposed_headers.clone() {
cors_middleware = cors_middleware.expose_headers(exposed_headers);
}
cors_middleware
}
}

#[derive(Args, Debug, Clone)]
pub struct HttpServerArgs {
/// Which port should this server listen for HTTP traffic on
Expand All @@ -430,6 +479,9 @@ pub struct HttpServerArgs {

#[clap(flatten)]
pub tls: TlsOptions,

#[clap(flatten)]
pub cors: CorsOptions,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -478,6 +530,7 @@ impl HttpServerArgs {

#[cfg(test)]
mod tests {
use actix_web::http;
use clap::Parser;
use tracing::info;
use tracing_test::traced_test;
Expand Down Expand Up @@ -766,6 +819,42 @@ mod tests {
}
}

#[test]
pub fn cors_origin_can_be_set_via_cli() {
let args = vec![
"unleash-edge",
"--cors-origin",
"example.com",
"--cors-origin",
"otherexample.com",
"--cors-origin",
"one.com,two.com",
"edge",
"-u http://localhost:4242",
];
let args = CliArgs::parse_from(args);
assert_eq!(args.http.cors.cors_origin.clone().unwrap().len(), 4);
let _middleware = args.http.cors.middleware();
}

#[test]
pub fn can_set_custom_cors_method() {
let args = vec![
"unleash-edge",
"--cors-methods",
"GET",
"--cors-methods",
"PATCH",
"edge",
"-u http://localhost:4242",
];
let cli = CliArgs::parse_from(args);
assert_eq!(
cli.http.cors.cors_methods,
Some(vec![http::Method::GET, http::Method::PATCH])
);
}

#[test]
pub fn proxy_trusted_servers_accept_both_ipv4_and_ipv6_cidr_addresses() {
let args = vec![
Expand Down
8 changes: 2 additions & 6 deletions server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::sync::Arc;

use actix_cors::Cors;
use actix_middleware_etag::Etag;
use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
Expand Down Expand Up @@ -51,6 +50,7 @@ async fn main() -> Result<(), anyhow::Error> {
let schedule_args = args.clone();
let mode_arg = args.clone().mode;
let http_args = args.clone().http;
let cors_arg = http_args.cors.clone();
let token_header = args.clone().token_header;
let request_timeout = args.edge_request_timeout;
let keepalive_timeout = args.edge_keepalive_timeout;
Expand Down Expand Up @@ -96,11 +96,7 @@ async fn main() -> Result<(), anyhow::Error> {
let qs_config =
serde_qs::actix::QsQueryConfig::default().qs_config(serde_qs::Config::new(5, false));

let cors_middleware = Cors::default()
.allow_any_origin()
.send_wildcard()
.allow_any_header()
.allow_any_method();
let cors_middleware = cors_arg.middleware();
let mut app = App::new()
.app_data(qs_config)
.app_data(web::Data::new(token_header.clone()))
Expand Down

0 comments on commit bf4360e

Please sign in to comment.