Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion crates/factor-outbound-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod wasi;
pub mod wasi_2023_10_18;
pub mod wasi_2023_11_10;

use std::sync::Arc;
use std::{net::SocketAddr, sync::Arc};

use anyhow::Context;
use http::{
Expand Down Expand Up @@ -200,3 +200,25 @@ pub struct AppState {
/// A semaphore to limit the number of concurrent outbound connections.
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
}

/// Removes IPs in the given [`BlockedNetworks`].
///
/// Returns [`ErrorCode::DestinationIpProhibited`] if all IPs are removed.
fn remove_blocked_addrs(
blocked_networks: &BlockedNetworks,
addrs: &mut Vec<SocketAddr>,
) -> Result<(), ErrorCode> {
if addrs.is_empty() {
return Ok(());
}
let blocked_addrs = blocked_networks.remove_blocked(addrs);
if addrs.is_empty() && !blocked_addrs.is_empty() {
tracing::error!(
"error.type" = "destination_ip_prohibited",
?blocked_addrs,
"all destination IP(s) prohibited by runtime config"
);
return Err(ErrorCode::DestinationIpProhibited);
}
Ok(())
}
23 changes: 22 additions & 1 deletion crates/factor-outbound-http/src/spin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::sync::Arc;

use http_body_util::BodyExt;
use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
use spin_world::v1::{
http as spin_http,
http_types::{self, HttpError, Method, Request, Response},
Expand Down Expand Up @@ -90,7 +93,8 @@ impl spin_http::Host for crate::InstanceState {
// Allow reuse of Client's internal connection pool for multiple requests
// in a single component execution
let client = self.spin_http_client.get_or_insert_with(|| {
let mut builder = reqwest::Client::builder();
let mut builder = reqwest::Client::builder()
.dns_resolver(Arc::new(SpinResolver(self.blocked_networks.clone())));
if !self.connection_pooling_enabled {
builder = builder.pool_max_idle_per_host(0);
}
Expand All @@ -113,6 +117,23 @@ impl spin_http::Host for crate::InstanceState {
}
}

struct SpinResolver(BlockedNetworks);

impl reqwest::dns::Resolve for SpinResolver {
fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
let blocked_networks = self.0.clone();
Box::pin(async move {
let mut addrs = tokio::net::lookup_host(name.as_str())
.await
.map_err(Box::new)?
.collect::<Vec<_>>();
// Remove blocked IPs
crate::remove_blocked_addrs(&blocked_networks, &mut addrs).map_err(Box::new)?;
Ok(Box::new(addrs.into_iter()) as reqwest::dns::Addrs)
})
}
}

impl http_types::Host for crate::InstanceState {
fn convert_http_error(&mut self, err: HttpError) -> anyhow::Result<HttpError> {
Ok(err)
Expand Down
10 changes: 1 addition & 9 deletions crates/factor-outbound-http/src/wasi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,15 +587,7 @@ impl ConnectOptions {
};

// Remove blocked IPs
let blocked_addrs = self.blocked_networks.remove_blocked(&mut socket_addrs);
if socket_addrs.is_empty() && !blocked_addrs.is_empty() {
tracing::error!(
"error.type" = "destination_ip_prohibited",
?blocked_addrs,
"all destination IP(s) prohibited by runtime config"
);
return Err(ErrorCode::DestinationIpProhibited);
}
crate::remove_blocked_addrs(&self.blocked_networks, &mut socket_addrs)?;

// If we're limiting concurrent outbound requests, acquire a permit
let permit = match &self.concurrent_outbound_connections_semaphore {
Expand Down
1 change: 1 addition & 0 deletions crates/factor-outbound-redis/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-resource-table = { path = "../table" }
spin-world = { path = "../world" }
tokio = { workspace = true }
tracing = { workspace = true }

[dev-dependencies]
Expand Down
36 changes: 35 additions & 1 deletion crates/factor-outbound-redis/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::net::SocketAddr;

use anyhow::Result;
use redis::io::AsyncDNSResolver;
use redis::AsyncConnectionConfig;
use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, Value};
use spin_core::wasmtime::component::Resource;
use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
use spin_factor_outbound_networking::config::blocked_networks::BlockedNetworks;
use spin_world::v1::{redis as v1, redis_types};
use spin_world::v2::redis::{
self as v2, Connection as RedisConnection, Error, RedisParameter, RedisResult,
Expand All @@ -11,6 +16,7 @@ use tracing::{instrument, Level};

pub struct InstanceState {
pub allowed_hosts: OutboundAllowedHosts,
pub blocked_networks: BlockedNetworks,
pub connections: spin_resource_table::Table<MultiplexedConnection>,
}

Expand All @@ -23,9 +29,11 @@ impl InstanceState {
&mut self,
address: String,
) -> Result<Resource<RedisConnection>, Error> {
let config = AsyncConnectionConfig::new()
.set_dns_resolver(SpinResolver(self.blocked_networks.clone()));
let conn = redis::Client::open(address.as_str())
.map_err(|_| Error::InvalidAddress)?
.get_multiplexed_async_connection()
.get_multiplexed_async_connection_with_config(&config)
.await
.map_err(other_error)?;
self.connections
Expand Down Expand Up @@ -365,3 +373,29 @@ impl FromRedisValue for RedisResults {
Ok(RedisResults(values))
}
}

struct SpinResolver(BlockedNetworks);

impl AsyncDNSResolver for SpinResolver {
fn resolve<'a, 'b: 'a>(
&'a self,
host: &'b str,
port: u16,
) -> redis::RedisFuture<'a, Box<dyn Iterator<Item = std::net::SocketAddr> + Send + 'a>> {
Box::pin(async move {
let mut addrs = tokio::net::lookup_host((host, port))
.await?
.collect::<Vec<_>>();
// Remove blocked IPs
let blocked_addrs = self.0.remove_blocked(&mut addrs);
if addrs.is_empty() && !blocked_addrs.is_empty() {
tracing::error!(
"error.type" = "destination_ip_prohibited",
?blocked_addrs,
"all destination IP(s) prohibited by runtime config"
);
}
Ok(Box::new(addrs.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
})
}
}
7 changes: 3 additions & 4 deletions crates/factor-outbound-redis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ impl Factor for OutboundRedisFactor {
&self,
mut ctx: PrepareContext<T, Self>,
) -> anyhow::Result<Self::InstanceBuilder> {
let allowed_hosts = ctx
.instance_builder::<OutboundNetworkingFactor>()?
.allowed_hosts();
let outbound_networking = ctx.instance_builder::<OutboundNetworkingFactor>()?;
Ok(InstanceState {
allowed_hosts,
allowed_hosts: outbound_networking.allowed_hosts(),
blocked_networks: outbound_networking.blocked_networks(),
connections: spin_resource_table::Table::new(1024),
})
}
Expand Down