diff --git a/.github/workflows/rust-test.yml b/.github/workflows/rust-test.yml index 36658829..e7bb3593 100644 --- a/.github/workflows/rust-test.yml +++ b/.github/workflows/rust-test.yml @@ -35,6 +35,10 @@ jobs: ports: # Maps tcp port 5432 on service container to the host - 5432:5432 + redis: + image: redis + ports: + - 6379:6379 steps: - uses: actions/checkout@v4 @@ -57,4 +61,5 @@ jobs: export POSTGRES_USER="postgres" export POSTGRES_PASSWORD="postgres" export POSTGRES_DBNAME="test" + export REDIS_URL="redis://redis:6379" cargo test --verbose diff --git a/Cargo.lock b/Cargo.lock index 2236c6fe..4b7d976f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -457,8 +457,10 @@ dependencies = [ "async-trait", "chrono", "deadpool-postgres", + "deadpool-redis", "deadpool-sqlite", "encoding_rs", + "futures-util", "log", "openssl", "postgres-openssl", @@ -568,6 +570,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "deadpool-redis" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfae6799b68a735270e4344ee3e834365f707c72da09c9a8bb89b45cc3351395" +dependencies = [ + "deadpool 0.12.1", + "redis 0.27.5", +] + [[package]] name = "deadpool-runtime" version = "0.1.4" @@ -2077,6 +2089,27 @@ dependencies = [ "url", ] +[[package]] +name = "redis" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cccf17a692ce51b86564334614d72dcae1def0fd5ecebc9f02956da74352b5" +dependencies = [ + "arc-swap", + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "num-bigint", + "percent-encoding", + "pin-project-lite", + "ryu", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.5.8" @@ -2436,7 +2469,7 @@ dependencies = [ "ppp", "quick-xml", "rdkafka", - "redis", + "redis 0.25.4", "regex", "roxmltree", "rustls-pemfile", diff --git a/common/Cargo.toml b/common/Cargo.toml index 570702c7..71c3cd19 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ 
-23,9 +23,11 @@ chrono = { version = "0.4.26", default-features = false, features = ["clock"] } encoding_rs = "0.8.32" deadpool-postgres = "0.14.0" deadpool-sqlite = "0.5.0" +deadpool-redis = "0.18.0" openssl = "0.10.70" postgres-openssl = "0.5.0" strum = { version = "0.26.1", features = ["derive"] } +futures-util = "0.3.28" [dev-dependencies] tempfile = "3.16.0" @@ -68,4 +70,4 @@ assets = [ { source = "../openwec.conf.sample.toml", dest = "/usr/share/doc/openwec/", mode = "0644", doc = true }, { source = "../README.md", dest = "/usr/share/doc/openwec/", mode = "0644", doc = true }, { source = "../doc/*", dest = "/usr/share/doc/openwec/doc/", mode = "0644", doc = true }, -] \ No newline at end of file +] diff --git a/common/src/database/mod.rs b/common/src/database/mod.rs index 26a340c3..0d9f640f 100644 --- a/common/src/database/mod.rs +++ b/common/src/database/mod.rs @@ -7,6 +7,7 @@ use crate::{ bookmark::BookmarkData, database::postgres::PostgresDatabase, database::sqlite::SQLiteDatabase, + database::redis::RedisDatabase, heartbeat::{HeartbeatData, HeartbeatsCache}, settings::Settings, subscription::{ @@ -21,6 +22,7 @@ use self::schema::{Migration, Version}; pub mod postgres; pub mod schema; pub mod sqlite; +pub mod redis; pub type Db = Arc; @@ -40,6 +42,13 @@ pub async fn db_from_settings(settings: &Settings) -> Result { schema::postgres::register_migrations(&mut db); Ok(Arc::new(db)) } + crate::settings::Database::Redis(redis) => { + let mut db = RedisDatabase::new(redis.connection_url()) + .await + .context("Failed to initialize Redis client")?; + schema::redis::register_migrations(&mut db); + Ok(Arc::new(db)) + } } } diff --git a/common/src/database/redis.rs b/common/src/database/redis.rs new file mode 100644 index 00000000..e0ac3a1d --- /dev/null +++ b/common/src/database/redis.rs @@ -0,0 +1,865 @@ +// Some of the following code is inspired from +// https://github.com/SkylerLipthay/schemamama_postgres. 
As stated by its +// license (MIT), we include below its copyright notice and permission notice: +// +// The MIT License (MIT) +// +// Copyright (c) 2024 Axoflow +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+// +// +use anyhow::{anyhow, Context, Result}; +use async_trait::async_trait; +use deadpool_redis::redis::AsyncCommands; +use deadpool_redis::{Config, Connection, Pool, Runtime}; +use log::warn; +use std::collections::btree_map::Entry::Vacant; +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::sync::Arc; +use std::time::SystemTime; + +use crate::bookmark::BookmarkData; +use crate::database::Database; +use crate::heartbeat::{HeartbeatData, HeartbeatKey, HeartbeatValue, HeartbeatsCache}; +use crate::subscription::{ + SubscriptionData, SubscriptionMachine, SubscriptionMachineState, SubscriptionStatsCounters +}; +use crate::database::schema::redis::v1::subscription::*; +use futures_util::stream::StreamExt; + +use super::schema::{Migration, MigrationBase, Version}; + + +use deadpool_redis::redis::{self, ToRedisArgs}; +use strum::{Display, EnumString}; + +#[derive(Debug, Eq, Hash, PartialEq, EnumString, Display)] +pub enum RedisDomain { + Users, + Subscription, + Machine, + Heartbeat, + BookMark, + Ip, + FirstSeen, + LastSeen, + LastEventSeen, + #[strum(serialize = "*")] + Any, +} + +impl ToRedisArgs for RedisDomain { + fn write_redis_args(&self, out: &mut W) { + out.write_arg(self.as_str().as_bytes()); + } +} + +impl RedisDomain { + pub fn as_str(&self) -> &str { + match self { + RedisDomain::Users => "users", + RedisDomain::Subscription => "subscription", + RedisDomain::Machine => "machine", + RedisDomain::Heartbeat => "heartbeat", + RedisDomain::BookMark => "bookmark", + RedisDomain::Ip => "ip", + RedisDomain::FirstSeen => "first_seen", + RedisDomain::LastSeen => "last_seen", + RedisDomain::LastEventSeen => "last_event_seen", + RedisDomain::Any => "*", + } + } +} + +const MIGRATION_TABLE_NAME: &str = "__schema_migrations"; + +#[async_trait] +pub trait RedisMigration: Migration { + /// Called when this migration is to be executed. + async fn up(&self, conn: &mut Connection) -> Result<()>; + + /// Called when this migration is to be reversed. 
+ async fn down(&self, conn: &mut Connection) -> Result<()>; + + fn to_base(&self) -> Arc { + Arc::new(MigrationBase::new(self.version(), self.description())) + } +} + +enum MachineStatusFilter { + Alive, + Active, + Dead, +} + +impl MachineStatusFilter { + fn is_match(&self, last_seen: &i64, last_event_seen: &Option, start_time: i64) -> bool { + match self { + MachineStatusFilter::Alive => { + *last_seen > start_time && last_event_seen.map_or(true, |event_time| event_time <= start_time) + }, + MachineStatusFilter::Active => { + last_event_seen.map_or(false, |event_time| event_time > start_time) + }, + MachineStatusFilter::Dead => { + *last_seen <= start_time && last_event_seen.map_or(true, |event_time| event_time <= start_time) + } + } + } +} + +pub struct RedisDatabase { + pool: Pool, + migrations: BTreeMap>, +} + +#[derive(Default)] +struct HeartbeatFilter { + subscription: Option, + machine: Option, + ip: Option, +} + +impl RedisDatabase { + pub async fn new(connection_url: &str) -> Result { + let config = Config::from_url(connection_url); + let pool = config.create_pool(Some(Runtime::Tokio1))?; + let db = RedisDatabase { + pool, + migrations: BTreeMap::new(), + }; + + Ok(db) + } + + /// Register a migration. If a migration with the same version is already registered, a warning + /// is logged and the registration fails. 
+ pub fn register_migration(&mut self, migration: Arc) { + let version = migration.version(); + if let Vacant(e) = self.migrations.entry(version) { + e.insert(migration); + } else { + warn!("Migration with version {:?} is already registered", version); + } + } + + async fn get_heartbeats_by_field( + &self, + fields: HeartbeatFilter + ) -> Result> { + + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + + let key = format!("{}:{}:{}", RedisDomain::Heartbeat, + fields.subscription.unwrap_or_else(|| RedisDomain::Any.to_string()), + fields.machine.unwrap_or_else(|| RedisDomain::Any.to_string())); + + let keys = list_keys(&mut conn, &key).await?; + let mut heartbeats = Vec::::new(); + + for key in keys { + let heartbeat_data : HashMap = conn.hgetall(&key).await.context("Failed to get heartbeat data")?; + if !heartbeat_data.is_empty() { + + // cache subs + let subscription_uuid = option_to_result( + heartbeat_data.get(RedisDomain::Subscription.as_str()), + anyhow!("RedisError: No Heartbea/{} present!", RedisDomain::Subscription.as_str()))?; + + let subscription_data_opt = self.get_subscription_by_identifier(&subscription_uuid).await?; + + if subscription_data_opt.is_none() { + continue; + } + + let subscription_data = subscription_data_opt.ok_or_else(|| { + anyhow::anyhow!("Subscription data not found for UUID: {}", subscription_uuid) + })?; + + if fields.ip.is_some() && heartbeat_data.get(RedisDomain::Ip.as_str()) != fields.ip.as_ref() { + continue; + } + + let hb = HeartbeatData::new( + option_to_result( + heartbeat_data.get(RedisDomain::Machine.as_str()), + anyhow!("RedisError: No Heartbea/{} present!", RedisDomain::Machine.as_str()))?, + option_to_result( + heartbeat_data.get(RedisDomain::Ip.as_str()), + anyhow!("RedisError: No Heartbea/{} present!", RedisDomain::Ip.as_str()))?, + subscription_data, + heartbeat_data.get(RedisDomain::FirstSeen.as_str()) + .and_then(|value| value.parse::().ok()) + .with_context(|| format!("Failed to 
parse integer for field '{}'", RedisDomain::FirstSeen))?, + heartbeat_data.get(RedisDomain::LastSeen.as_str()) + .and_then(|value| value.parse::().ok()) + .with_context(|| format!("Failed to parse integer for field '{}'", RedisDomain::LastSeen))?, + heartbeat_data.get(RedisDomain::LastEventSeen.as_str()) + .and_then(|value| value.parse::().ok()), + ); + heartbeats.push(hb); + } else { + log::warn!("No bookmard found for key: {}", key); + } + } + + Ok(heartbeats) + } + +} + +async fn list_keys(con: &mut Connection, key: &str) -> Result> { + let mut res = Vec::new(); + let mut iter = con.scan_match::<&str, String>(key).await.context("Unable to list keys")?; + + while let Some(key) = iter.next().await { + res.push(key); + } + + Ok(res) +} + +async fn set_heartbeat_inner(conn: &mut Connection, subscription: &str, machine: &str, value: HeartbeatValue) -> Result<()> { + let redis_key = format!("{}:{}:{}", RedisDomain::Heartbeat, subscription.to_uppercase(), machine); + let key_exists = conn.exists(&redis_key).await.unwrap_or(true); + + let mut items:Vec<(RedisDomain, String)> = vec![ + (RedisDomain::Subscription, subscription.to_uppercase()), + (RedisDomain::Machine, machine.to_string()), + (RedisDomain::Ip, value.ip), + (RedisDomain::LastSeen, value.last_seen.to_string()), + ]; + + if !key_exists { + items.push((RedisDomain::FirstSeen, value.last_seen.to_string())); + } + if let Some(last_event_seen) = value.last_event_seen { + items.push((RedisDomain::LastEventSeen, last_event_seen.to_string())); + } + + let _: () = conn.hset_multiple(&redis_key, &items).await.context("Failed to store bookmark data")?; + + Ok(()) +} + +async fn set_heartbeat(conn: &mut Connection, key: &HeartbeatKey, value: &HeartbeatValue) -> Result<()> { + set_heartbeat_inner(conn, &key.subscription, &key.machine, value.to_owned()).await +} + +fn option_to_result(option: Option<&T>, err: E) -> Result +where + T: Clone, +{ + option.cloned().ok_or(err) +} + +#[allow(unused)] +#[async_trait] +impl 
Database for RedisDatabase { + async fn get_bookmark(&self, machine: &str, subscription: &str) -> Result> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let key = format!("{}:{}:{}", RedisDomain::BookMark, subscription.to_uppercase(), machine); + Ok(conn.hget(&key, RedisDomain::BookMark.as_str()).await.context("Failed to get bookmark data")?) + } + + async fn get_bookmarks(&self, subscription: &str) -> Result> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let key = format!("{}:{}:{}", RedisDomain::BookMark, subscription.to_uppercase(), RedisDomain::Any); + let keys = list_keys(&mut conn, &key).await?; + let mut bookmarks = Vec::::new(); + + for key in keys { + let bookmark_data : HashMap = conn.hgetall(&key).await.context("Failed to get bookmark data")?; + if !bookmark_data.is_empty() { + bookmarks.push(BookmarkData { + subscription: option_to_result( + bookmark_data.get(RedisDomain::Subscription.as_str()), + anyhow!("RedisError: No Bookmark/{} present!", RedisDomain::Subscription.as_str()))?.clone(), + machine: option_to_result( + bookmark_data.get(RedisDomain::Machine.as_str()), + anyhow!("RedisError: No Bookmark/{} present!", RedisDomain::Machine.as_str()))?.clone(), + bookmark: option_to_result( + bookmark_data.get(RedisDomain::BookMark.as_str()), + anyhow!("RedisError: No Bookmark/{} present!", RedisDomain::BookMark.as_str()))?.clone(), + }); + } else { + log::warn!("No bookmard found for key: {}", key); + } + } + + Ok(bookmarks) + } + + async fn store_bookmark( + &self, + machine: &str, + subscription: &str, + bookmark: &str, + ) -> Result<()> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let key = format!("{}:{}:{}", RedisDomain::BookMark, subscription.to_uppercase(), machine); + + let items:Vec<(RedisDomain, String)> = vec![ + (RedisDomain::Subscription, subscription.to_uppercase()), + (RedisDomain::Machine, machine.to_string()), + 
(RedisDomain::BookMark, bookmark.to_string()), + ]; + + let _: () = conn.hset_multiple(&key, &items).await.context("Failed to store bookmark data")?; + + Ok(()) + } + + async fn delete_bookmarks( + &self, + machine: Option<&str>, + subscription: Option<&str>, + ) -> Result<()> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let compose_key = |subscription: &str, machine: &str| -> String { + format!("{}:{}:{}", RedisDomain::BookMark, subscription.to_uppercase(), machine) + }; + let key : String = match (subscription, machine) { + (Some(subscription), Some(machine)) => { + compose_key(subscription, machine) + }, + (Some(subscription), None) => { + compose_key(subscription, RedisDomain::Any.as_str()) + }, + (None, Some(machine)) => { + compose_key(RedisDomain::Any.as_str(), machine) + }, + (None, None) => { + compose_key(RedisDomain::Any.as_str(), RedisDomain::Any.as_str()) + } + }; + + let keys = list_keys(&mut conn, &key).await?; + if (!keys.is_empty()) + { + let _ : usize = conn.del(keys.as_slice()).await.context("Failed to delete bookmark data")?; + } + + Ok(()) + } + + async fn get_heartbeats_by_machine( + &self, + machine: &str, + subscription: Option<&str>, + ) -> Result> { + let mut fields = HeartbeatFilter{ + subscription: subscription.map(String::from), + machine: Some(machine.to_string()), + ..Default::default() + }; + self.get_heartbeats_by_field(fields).await + } + + async fn get_heartbeats_by_ip( + &self, + ip: &str, + subscription: Option<&str>, + ) -> Result> { + let mut fields = HeartbeatFilter{ + ip: Some(ip.to_string()), + subscription: subscription.map(String::from), + ..Default::default()}; + self.get_heartbeats_by_field(fields).await + } + + async fn get_heartbeats(&self) -> Result> { + let fields = HeartbeatFilter::default(); + self.get_heartbeats_by_field(fields).await + } + + async fn get_heartbeats_by_subscription( + &self, + subscription: &str, + ) -> Result> { + let mut fields = HeartbeatFilter{ + 
subscription: Some(subscription.to_string()), + ..Default::default() + }; + + self.get_heartbeats_by_field(fields).await + } + + async fn store_heartbeat( + &self, + machine: &str, + ip: String, + subscription: &str, + is_event: bool, + ) -> Result<()> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH)? + .as_secs(); + + let hbv = HeartbeatValue{ + ip, + last_seen: now, + last_event_seen: if is_event { Some(now) } else { None }, + }; + + set_heartbeat_inner(&mut conn, subscription, machine, hbv).await + } + + async fn store_heartbeats(&self, heartbeats: &HeartbeatsCache) -> Result<()> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + for (key, value) in heartbeats.iter() { + set_heartbeat(&mut conn, key, value).await?; + } + Ok(()) + } + + async fn get_subscriptions(&self) -> Result> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + + let key = format!("{}:{}:{}", RedisDomain::Subscription, RedisDomain::Any, RedisDomain::Any); + + let keys = list_keys(&mut conn, &key).await?; + + let mut subscriptions = Vec::new(); + + for key in keys { + let subscription_redis_data: Option = conn.get(&key).await.context("Failed to get subscription data")?; + + if let Some(subscription_redis_data) = subscription_redis_data { + match serde_json::from_str::(&subscription_redis_data) { + Ok(subscription) => subscriptions.push(subscription.into()), + Err(err) => { + log::warn!("Failed to deserialize subscription data for key {}: {}", key, err); + } + } + } else { + log::warn!("No subscription found for key: {}", key); + } + } + + Ok(subscriptions) + } + + async fn get_subscription_by_identifier( + &self, + identifier: &str, + ) -> Result> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let key = format!("{}:{}:{}", RedisDomain::Subscription, RedisDomain::Any, 
RedisDomain::Any); + + let keys = list_keys(&mut conn, &key).await?; + + let filtered: Vec = keys.into_iter() + .filter(|key| key.split(':').skip(1).any(|elt| elt == identifier)) + .collect(); + + if !filtered.is_empty() { + let result: Option = conn.get(&filtered[0]).await.context("Failed to get subscription data")?; + if result.is_some() { + let subscription_redis_data: SubscriptionRedisData = serde_json::from_str(&result.unwrap()).context("Failed to deserialize subscription data")?; + return Ok(Some(subscription_redis_data.into())); + } + } + Ok(None) + } + + async fn store_subscription(&self, subscription: &SubscriptionData) -> Result<()> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + + let key_filter = format!("{}:{}:{}",RedisDomain::Subscription, subscription.uuid().to_string().to_uppercase(), RedisDomain::Any); + let keys = list_keys(&mut conn, &key_filter).await?; + if (!keys.is_empty()) { + let _:() = conn.del(keys).await?; + } + + + let subscription_redis_data = SubscriptionRedisData::from(subscription); + + let key = format!("{}:{}:{}", RedisDomain::Subscription, subscription.uuid().to_string().to_uppercase(), subscription.name()); + let value = serde_json::to_string(&subscription_redis_data).context("Failed to serialize subscription data")?; + let _ : String = conn.set(key, value).await.context("Failed to store subscription data")?; + Ok(()) + } + + async fn delete_subscription(&self, uuid: &str) -> Result<()> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let key = format!("{}:{}:{}", RedisDomain::Subscription, uuid.to_uppercase(), RedisDomain::Any); + + let keys = list_keys(&mut conn, &key).await?; + + self.delete_bookmarks(None, Some(uuid)).await.context("Failed to delete subscription releated bookmark data")?; + if !keys.is_empty() { + let _: () = conn.del(keys).await.context("Failed to delete subscription data")?; + } + Ok(()) + } + + /// Fails if `setup_schema` hasn't 
previously been called or if the query otherwise fails. + async fn current_version(&self) -> Result> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let key = MIGRATION_TABLE_NAME; + let versions:Vec = conn.zrange(key, -1, -1).await.context("There is no version info stored in DB.")?; + let last_version = versions.last().and_then(|v| v.parse::().ok()); + Ok(last_version) + } + + /// Fails if `setup_schema` hasn't previously been called or if the query otherwise fails. + async fn migrated_versions(&self) -> Result> { + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + let key = MIGRATION_TABLE_NAME; + let versions:Vec = conn.zrange(key, 0, -1).await.context("There is no version info stored in DB.")?; + let result : BTreeSet = versions.into_iter().map(|v| v.parse::().with_context(|| format!("Failed to parse version: {}", v))).collect::>()?; + Ok(result) + } + + /// Fails if `setup_schema` hasn't previously been called or if the migration otherwise fails. + async fn apply_migration(&self, version: Version) -> Result<()> { + let migration = self + .migrations + .get(&version) + .ok_or_else(|| anyhow!("Could not retrieve migration with version {}", version))? + .clone(); + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + migration.up(&mut conn).await?; + let key = MIGRATION_TABLE_NAME; + let version = migration.version(); + let added_count: i64 = conn.zadd(key, version, version).await.with_context(|| format!("Unable to add version: {}", version))?; + if added_count > 0 { + println!("Successfully added version {} to sorted set", version); + } else { + println!("Version {} was not added (it may already exist)", version); + } + Ok(()) + } + + /// Fails if `setup_schema` hasn't previously been called or if the migration otherwise fails. 
+ async fn revert_migration(&self, version: Version) -> Result<()> { + let migration = self + .migrations + .get(&version) + .ok_or_else(|| anyhow!("Could not retrieve migration with version {}", version))? + .clone(); + let mut conn = self.pool.get().await.context("Failed to get Redis connection")?; + migration.down(&mut conn).await?; + let key = MIGRATION_TABLE_NAME; + let version = migration.version(); + let removed_count: i64 = conn.zrem(key, version).await.context("Failed to remove version")?; + if removed_count > 0 { + println!("Successfully removed version: {}", version); + } else { + println!("Version {} not found in the sorted set.", version); + } + Ok(()) + } + + /// Create the tables required to keep track of schema state. If the tables already + /// exist, this function has no operation. + async fn setup_schema(&self) -> Result<()> { + Ok(()) + } + + async fn migrations(&self) -> BTreeMap> { + let mut base_migrations = BTreeMap::new(); + for (version, migration) in self.migrations.iter() { + base_migrations.insert(*version, migration.to_base()); + } + base_migrations + } + + async fn get_stats( + &self, + subscription: &str, + start_time: i64, + ) -> Result { + let mut fields = HeartbeatFilter{ + subscription: Some(subscription.to_string()), + ..Default::default() + }; + let heartbeats = self.get_heartbeats_by_field(fields).await?; + + let total_machines_count = i64::try_from(heartbeats.len())?; + let mut alive_machines_count = 0; + let mut active_machines_count = 0; + let mut dead_machines_count = 0; + + for hb in heartbeats.iter() { + match hb { + HeartbeatData{last_seen, last_event_seen, ..} if MachineStatusFilter::Alive.is_match(last_seen, last_event_seen, start_time) => { + alive_machines_count += 1; + }, + HeartbeatData{last_seen, last_event_seen, ..} if MachineStatusFilter::Active.is_match(last_seen, last_event_seen, start_time) => { + active_machines_count += 1; + }, + HeartbeatData{last_seen, last_event_seen, ..} if 
MachineStatusFilter::Dead.is_match(last_seen, last_event_seen, start_time) => { + dead_machines_count += 1; + }, + _ => {}, + }; + } + + Ok(SubscriptionStatsCounters::new( + total_machines_count, + alive_machines_count, + active_machines_count, + dead_machines_count, + )) + } + + async fn get_machines( + &self, + subscription: &str, + start_time: i64, + stat_type: Option, + ) -> Result> { + let mut fields = HeartbeatFilter{ + subscription: Some(subscription.to_string()), + ..Default::default() + }; + + let heartbeats = self.get_heartbeats_by_field(fields).await?; + let mut result = Vec::::new(); + + for hb in heartbeats.iter() { + + match stat_type { + None => {}, + Some(SubscriptionMachineState::Active) => { + if !MachineStatusFilter::Active.is_match(&hb.last_seen, &hb.last_event_seen, start_time) { + continue; + } + }, + Some(SubscriptionMachineState::Alive) => { + if !MachineStatusFilter::Alive.is_match(&hb.last_seen, &hb.last_event_seen, start_time) { + continue; + } + }, + Some(SubscriptionMachineState::Dead) => { + if !MachineStatusFilter::Dead.is_match(&hb.last_seen, &hb.last_event_seen, start_time) { + continue; + } + }, + } + result.push(SubscriptionMachine::new(hb.machine().to_string(), hb.ip().to_string())); + } + + Ok(result) + } +} + + +#[cfg(test)] +mod tests { + + use std::{env, str::FromStr}; + use uuid::Uuid; + + use crate::{ + database::schema::{self, Migrator}, migration, subscription::SubscriptionUuid + }; + + use super::*; + use anyhow::Ok; + use serial_test::serial; + + #[allow(unused)] + async fn cleanup_db(db: &RedisDatabase) -> Result<()> { + let mut con = db.pool.get().await?; + let _ : () = deadpool_redis::redis::cmd("FLUSHALL").query_async(&mut con).await?; + Ok(()) + } + + async fn drop_migrations_table(db: &RedisDatabase) -> Result<()> { + let mut conn = db.pool.get().await.context("Failed to get Redis connection")?; + let key = MIGRATION_TABLE_NAME; + let _:() = conn.del(key).await?; + Ok(()) + } + + async fn redis_db() -> Result { + 
let connection_string = env::var("REDIS_URL").unwrap_or("redis://127.0.0.1:6379".to_string()); + RedisDatabase::new(connection_string.as_str()).await + } + + async fn db_with_migrations() -> Result> { + let mut db = redis_db().await?; + schema::redis::register_migrations(&mut db); + cleanup_db(&db).await?; + drop_migrations_table(&db).await?; + Ok(Arc::new(db)) + } + + #[tokio::test] + #[serial] + async fn test_open_and_close() -> Result<()> { + redis_db() + .await + .expect("Could not connect to database"); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_bookmarks() -> Result<()> { + crate::database::tests::test_bookmarks(db_with_migrations().await?).await?; + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_heartbeats() -> Result<()> { + crate::database::tests::test_heartbeats(db_with_migrations().await?).await?; + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_heartbeats_cache() -> Result<()> { + crate::database::tests::test_heartbeats_cache(db_with_migrations().await?).await?; + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_subscriptions() -> Result<()> { + crate::database::tests::test_subscriptions(db_with_migrations().await?).await?; + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_stats() -> Result<()> { + crate::database::tests::test_stats_and_machines(db_with_migrations().await?).await?; + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_current_version_empty() -> Result<()> { + let db = db_with_migrations().await?; + let res = db.current_version().await?; + assert_eq!(res, None); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_current_version() -> Result<()> { + let db = redis_db().await?; + let mut con = db.pool.get().await?; + let members = vec![(1.0, 1),(2.0, 2),(3.0, 3)]; + let _:() = con.zadd_multiple(MIGRATION_TABLE_NAME, &members).await?; + let res = db.current_version().await?; + assert_eq!(res, Some(3)); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn 
test_migrated_versions() -> Result<()> { + let db = redis_db().await?; + let mut con = db.pool.get().await?; + let members = vec![(1.0, 1),(2.0, 2),(3.0, 3)]; + let _:() = con.zadd_multiple(MIGRATION_TABLE_NAME, &members).await?; + let res = db.migrated_versions().await?; + assert_eq!(res, BTreeSet::::from_iter(vec![1,2,3])); + Ok(()) + } + + struct CreateUsers; + migration!(CreateUsers, 1, "create users table"); + + #[async_trait] + impl RedisMigration for CreateUsers { + async fn up(&self, conn: &mut Connection) -> Result<()> { + let key = format!("{}", RedisDomain::Users); + let _:() = conn.set(key, "").await?; + Ok(()) + } + + async fn down(&self, conn: &mut Connection) -> Result<()> { + let key = format!("{}", RedisDomain::Users); + let _:() = conn.del(key).await?; + Ok(()) + } + } + + #[tokio::test] + #[serial] + async fn test_register() -> Result<()> { + let mut db = redis_db() + .await + .expect("Could not connect to database"); + + drop_migrations_table(&db).await?; + db.register_migration(Arc::new(CreateUsers)); + + db.setup_schema().await.expect("Could not setup schema"); + + let db_arc = Arc::new(db); + + let migrator = Migrator::new(db_arc.clone()); + + migrator.up(None, false).await.unwrap(); + + assert_eq!(db_arc.current_version().await.unwrap(), Some(1)); + + migrator.down(None, false).await.unwrap(); + + assert_eq!(db_arc.current_version().await.unwrap(), None); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_list_keys() -> Result<()> { + let db = redis_db().await?; + cleanup_db(&db).await?; + let mut con = db.pool.get().await?; + db.store_bookmark("machine1", "subscription", "bookmark1").await?; + db.store_bookmark("machine2", "subscription", "bookmark2").await?; + let key = "BookMark:SUBSCRIPTION:*"; + let keys = list_keys(&mut con, key).await?; + assert!(keys.len() == 2); + Ok(()) + } + + #[tokio::test] + #[serial] + async fn test_heartbeat_without_subscription() -> Result<()> { + let db = redis_db().await?; + cleanup_db(&db).await?; 
+ db.store_heartbeat("machine1", "127.0.0.1".to_string(), "subscription1", true).await?; + db.store_heartbeat("machine2", "192.168.0.1".to_string(), "subscription1", true).await?; + db.store_heartbeat("machine2", "0.0.0.1".to_string(), "b00bf259-3ba9-4faf-b58e-d0e9a3275778", true).await?; + let mut subs = SubscriptionData::new("subscription2", "query"); + subs.set_uuid(SubscriptionUuid(Uuid::from_str("b00bf259-3ba9-4faf-b58e-d0e9a3275778")?)); + db.store_subscription(&subs).await?; + let mut fields = HeartbeatFilter::default(); + fields.machine = Some("machine2".to_string()); + let heartbeat_data = db.get_heartbeats_by_field(fields).await?; + + assert!(heartbeat_data.len() == 1, "expected: {} received:{}", 1, heartbeat_data.len()); + assert_eq!(heartbeat_data[0].machine(), "machine2"); + + Ok(()) + } + +} diff --git a/common/src/database/schema/mod.rs b/common/src/database/schema/mod.rs index dcee7757..c1ba4566 100644 --- a/common/src/database/schema/mod.rs +++ b/common/src/database/schema/mod.rs @@ -34,6 +34,7 @@ use super::Database; pub mod postgres; pub mod sqlite; +pub mod redis; /// The version type alias used to uniquely reference migrations. 
pub type Version = i64; diff --git a/common/src/database/schema/redis/mod.rs b/common/src/database/schema/redis/mod.rs new file mode 100644 index 00000000..8bd0bc99 --- /dev/null +++ b/common/src/database/schema/redis/mod.rs @@ -0,0 +1,7 @@ +use crate::database::redis::RedisDatabase; + +pub mod v1; + +pub fn register_migrations(_redis_db: &mut RedisDatabase) { + // for future changes +} diff --git a/common/src/database/schema/redis/v1/mod.rs b/common/src/database/schema/redis/v1/mod.rs new file mode 100644 index 00000000..8e061c6e --- /dev/null +++ b/common/src/database/schema/redis/v1/mod.rs @@ -0,0 +1,3 @@ +pub mod subscription; + +pub const VERSION: &str = module_path!(); diff --git a/common/src/database/schema/redis/v1/subscription.rs b/common/src/database/schema/redis/v1/subscription.rs new file mode 100644 index 00000000..84b1450b --- /dev/null +++ b/common/src/database/schema/redis/v1/subscription.rs @@ -0,0 +1,57 @@ +use serde::{Deserialize, Serialize}; +use crate::subscription::*; + +use super::VERSION; + +#[derive(Debug, PartialEq, Clone, Eq, Serialize, Deserialize)] +pub struct SubscriptionRedisData { + version: String, + uuid: SubscriptionUuid, + internal_version: InternalVersion, + revision: Option, + uri: Option, + enabled: bool, + princs_filter: PrincsFilter, + parameters: SubscriptionParameters, + outputs: Vec, +} + +impl SubscriptionRedisData { + pub fn from_subscription_data(from: &SubscriptionData) -> Self { + Self { + version: VERSION.to_string(), + uuid: *from.uuid(), + internal_version: from.internal_version(), + revision: from.revision().cloned(), + uri: from.uri().cloned(), + enabled: from.enabled(), + princs_filter: from.princs_filter().clone(), + parameters: from.parameters().clone(), + outputs: from.outputs().to_vec(), + } + } + pub fn into_subscription_data(self) -> SubscriptionData { + let mut sd = SubscriptionData::new(&self.parameters.name, &self.parameters.query); + sd.set_revision(self.revision). + set_uuid(self.uuid). 
+ set_uri(self.uri). + set_enabled(self.enabled). + set_princs_filter(self.princs_filter). + set_parameters(self.parameters). + set_outputs(self.outputs); + sd.set_internal_version(self.internal_version); + sd + } +} + +impl From<&SubscriptionData> for SubscriptionRedisData { + fn from(value: &SubscriptionData) -> Self { + SubscriptionRedisData::from_subscription_data(value) + } +} + +impl From<SubscriptionRedisData> for SubscriptionData { + fn from(value: SubscriptionRedisData) -> Self { + value.into_subscription_data() + } +} diff --git a/common/src/heartbeat.rs b/common/src/heartbeat.rs index e4a79a9e..e31fb948 100644 --- a/common/src/heartbeat.rs +++ b/common/src/heartbeat.rs @@ -13,9 +13,9 @@ pub struct HeartbeatData { #[serde(serialize_with = "utils::serialize_timestamp")] first_seen: Timestamp, #[serde(serialize_with = "utils::serialize_timestamp")] - last_seen: Timestamp, + pub last_seen: Timestamp, #[serde(serialize_with = "utils::serialize_option_timestamp")] - last_event_seen: Option<Timestamp>, + pub last_event_seen: Option<Timestamp>, } fn serialize_subscription_data( diff --git a/common/src/settings.rs b/common/src/settings.rs index 6d7e42d8..3b5de367 100644 --- a/common/src/settings.rs +++ b/common/src/settings.rs @@ -18,6 +18,7 @@ pub enum Authentication { pub enum Database { SQLite(SQLite), Postgres(Postgres), + Redis(Redis), } #[derive(Debug, Deserialize, Clone)] @@ -102,6 +103,18 @@ impl Kerberos { } } +#[derive(Debug, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Redis { + connection_url: String, +} + +impl Redis { + pub fn connection_url(&self) -> &str { + &self.connection_url + } +} + #[derive(Debug, Deserialize, Clone)] #[serde(deny_unknown_fields)] pub struct SQLite { diff --git a/common/src/subscription.rs b/common/src/subscription.rs index bfcdfb3b..e003f906 100644 --- a/common/src/subscription.rs +++ b/common/src/subscription.rs @@ -265,7 +265,7 @@ impl SubscriptionOutputFormat { } } -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, 
Serialize, Deserialize)] pub enum PrincsFilterOperation { Only, Except, @@ -294,7 +294,7 @@ impl PrincsFilterOperation { } } -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] pub struct PrincsFilter { operation: Option<PrincsFilterOperation>, princs: HashSet<String>, @@ -383,7 +383,7 @@ impl PrincsFilter { } } -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] pub enum ContentFormat { Raw, RenderedText, @@ -412,7 +412,7 @@ impl FromStr for ContentFormat { } } -#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy, Serialize)] +#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy, Serialize, Deserialize)] pub struct SubscriptionUuid(pub Uuid); impl Display for SubscriptionUuid { @@ -424,7 +424,7 @@ impl Display for SubscriptionUuid { // We use the newtype pattern so that the compiler can check that // we don't use one instead of the other -#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy)] +#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy, Serialize, Deserialize)] pub struct InternalVersion(pub Uuid); impl Display for InternalVersion { @@ -447,7 +447,7 @@ impl Display for PublicVersion { /// of the subscription is updated and clients are expected to update /// their configuration. 
/// Every elements must implement the Hash trait -#[derive(Debug, PartialEq, Clone, Eq, Hash)] +#[derive(Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)] pub struct SubscriptionParameters { pub name: String, pub query: String, @@ -464,7 +464,7 @@ pub struct SubscriptionParameters { pub data_locale: Option<String>, } -#[derive(Debug, PartialEq, Clone, Eq)] +#[derive(Debug, PartialEq, Clone, Eq, Serialize, Deserialize)] pub struct SubscriptionData { // Unique identifier of the subscription uuid: SubscriptionUuid, @@ -661,6 +661,15 @@ impl SubscriptionData { Ok(PublicVersion(Uuid::from_u64_pair(result, result))) } + pub fn parameters(&self) -> &SubscriptionParameters { + &self.parameters + } + + pub fn set_parameters(&mut self, parameters: SubscriptionParameters) -> &mut SubscriptionData { + self.parameters = parameters; + self + } + /// Get a reference to the subscription's name. pub fn name(&self) -> &str { self.parameters.name.as_ref()