diff --git a/pumpkin-crates/core/src/engine/notifications/predicate_notification/predicate_tracker.rs b/pumpkin-crates/core/src/engine/notifications/predicate_notification/predicate_tracker.rs index 12da90c00..f69ada14e 100644 --- a/pumpkin-crates/core/src/engine/notifications/predicate_notification/predicate_tracker.rs +++ b/pumpkin-crates/core/src/engine/notifications/predicate_notification/predicate_tracker.rs @@ -1087,7 +1087,7 @@ mod tests { vec![ PredicateType::LowerBound, PredicateType::NotEqual, - PredicateType::Equal + PredicateType::Equal, ] ); assert_eq!(value.get_value(), x); @@ -1097,9 +1097,9 @@ mod tests { value.get_predicate_types().collect::>(), vec![ PredicateType::LowerBound, - PredicateType::UpperBound, PredicateType::NotEqual, - PredicateType::Equal + PredicateType::Equal, + PredicateType::UpperBound, ] ); assert_eq!(value.get_value(), x); @@ -1137,7 +1137,7 @@ mod tests { vec![ PredicateType::LowerBound, PredicateType::NotEqual, - PredicateType::Equal + PredicateType::Equal, ] ); assert_eq!(value.get_value(), x); @@ -1147,9 +1147,9 @@ mod tests { value.get_predicate_types().collect::>(), vec![ PredicateType::LowerBound, - PredicateType::UpperBound, PredicateType::NotEqual, - PredicateType::Equal + PredicateType::Equal, + PredicateType::UpperBound, ] ); assert_eq!(value.get_value(), x); diff --git a/pumpkin-crates/core/src/engine/predicates/predicate.rs b/pumpkin-crates/core/src/engine/predicates/predicate.rs index df6edac0f..b62ae72e1 100644 --- a/pumpkin-crates/core/src/engine/predicates/predicate.rs +++ b/pumpkin-crates/core/src/engine/predicates/predicate.rs @@ -10,6 +10,16 @@ use crate::propagation::DomainEvent; /// ([`DomainId`], [`PredicateType`], value). /// /// To create a [`Predicate`], use [Predicate::new] or the more concise [predicate!] macro. +/// +/// ## Order +/// Predicates have a well-defined order. They are first ordered by the domain, and then by +/// predicate type, and finally by the value. The order is chosen such that for a fixed domain `x`, +/// predicates are ordered as follows: +/// [>= 5], [>= 7], [!= 2], [!= 3], [== 5], [!= 7], [<= 6], [<= 10] +/// +/// From the order, we get the lower-bound predicates first, ordered by non-decreasing bound, then +/// the (not-)equal predicates, ordered by non-decreasing bound, then the upper-bound predicates, +/// ordered by non-increasing bounds. #[derive(Clone, PartialEq, Eq, Copy, Hash)] pub struct Predicate { /// The two most significant bits of the id stored in the [`Predicate`] contains the type of @@ -18,25 +28,72 @@ pub struct Predicate { value: i32, } -const LOWER_BOUND_CODE: u8 = 0; -const UPPER_BOUND_CODE: u8 = 1; -const NOT_EQUAL_CODE: u8 = 2; -const EQUAL_CODE: u8 = 3; +const LOWER_BOUND_CODE: u8 = PredicateType::LowerBound as u8; +const UPPER_BOUND_CODE: u8 = PredicateType::UpperBound as u8; +const NOT_EQUAL_CODE: u8 = PredicateType::NotEqual as u8; +const EQUAL_CODE: u8 = PredicateType::Equal as u8; impl Predicate { /// Creates a new [`Predicate`] (also known as atomic constraint) which represents a domain /// operation. pub fn new(id: DomainId, predicate_type: PredicateType, value: i32) -> Self { - let code = match predicate_type { - PredicateType::LowerBound => LOWER_BOUND_CODE, - PredicateType::UpperBound => UPPER_BOUND_CODE, - PredicateType::NotEqual => NOT_EQUAL_CODE, - PredicateType::Equal => EQUAL_CODE, - }; + let code = predicate_type as u8; let id = id.id() | (code as u32) << 30; Self { id, value } } + /// Returns `true` if `self` implies `other`. + /// + /// # Example + /// ``` + /// # use pumpkin_core::variables::DomainId; + /// # use pumpkin_core::predicate; + /// let x = DomainId::new(0); + /// + /// assert!(predicate![x >= 5].implies(predicate![x >= 3])); + /// assert!(predicate![x >= 5].implies(predicate![x != 1])); + /// assert!(predicate![x == 5].implies(predicate![x <= 5])); + /// ``` + pub fn implies(&self, other: Predicate) -> bool { + if self.get_domain() != other.get_domain() { + // Predicates only imply other predicates on the same domain. + return false; + } + + match self.get_predicate_type() { + PredicateType::LowerBound => match other.get_predicate_type() { + PredicateType::LowerBound => { + self.get_right_hand_side() >= other.get_right_hand_side() + } + PredicateType::NotEqual => self.get_right_hand_side() > other.get_right_hand_side(), + PredicateType::UpperBound | PredicateType::Equal => false, + }, + PredicateType::UpperBound => match other.get_predicate_type() { + PredicateType::UpperBound => { + self.get_right_hand_side() <= other.get_right_hand_side() + } + PredicateType::NotEqual => self.get_right_hand_side() < other.get_right_hand_side(), + PredicateType::LowerBound | PredicateType::Equal => false, + }, + PredicateType::NotEqual => { + other.get_predicate_type() == PredicateType::NotEqual + && self.get_right_hand_side() == other.get_right_hand_side() + } + PredicateType::Equal => match other.get_predicate_type() { + PredicateType::LowerBound => { + self.get_right_hand_side() >= other.get_right_hand_side() + } + PredicateType::UpperBound => { + self.get_right_hand_side() <= other.get_right_hand_side() + } + PredicateType::NotEqual => { + self.get_right_hand_side() != other.get_right_hand_side() + } + PredicateType::Equal => self.get_right_hand_side() == other.get_right_hand_side(), + }, + } + } + fn get_type_code(&self) -> u8 { (self.id >> 30) as u8 } @@ -44,6 +101,40 @@ impl Predicate { pub fn get_predicate_type(&self) -> PredicateType { (*self).into() } + + fn is_bound_predicate(&self) -> bool { + self.is_upper_bound_predicate() || self.is_lower_bound_predicate() + } +} + +impl PartialOrd for Predicate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Predicate { + /// See [`Predicate`] for details on the order. + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match self.get_domain().cmp(&other.get_domain()) { + std::cmp::Ordering::Equal => { + if self.is_bound_predicate() || other.is_bound_predicate() { + match self.get_type_code().cmp(&other.get_type_code()) { + std::cmp::Ordering::Equal => { + self.get_right_hand_side().cmp(&other.get_right_hand_side()) + } + ordering @ (std::cmp::Ordering::Less | std::cmp::Ordering::Greater) => { + ordering + } + } + } else { + self.get_right_hand_side().cmp(&other.get_right_hand_side()) + } + } + + ordering @ (std::cmp::Ordering::Less | std::cmp::Ordering::Greater) => ordering, + } + } } #[derive(Debug, Hash, EnumSetType)] @@ -53,9 +144,9 @@ pub enum PredicateType { // Should correspond with the codes defined previously; `EnumSetType` requires that literals // are used and not expressions LowerBound = 0, - UpperBound = 1, - NotEqual = 2, - Equal = 3, + NotEqual = 1, + Equal = 2, + UpperBound = 3, } impl From for PredicateType { @@ -321,4 +412,174 @@ mod test { let trivially_false = Predicate::trivially_false(); assert!(!trivially_false == trivially_true); } + + #[test] + fn predicates_over_same_domain_are_ordered_by_increasing_lower_bound() { + let x = DomainId::new(0); + let p1 = predicate![x >= 4]; + let p2 = predicate![x >= 6]; + assert!(p1 < p2); + } + + #[test] + fn not_equal_predicates_are_bigger_than_lower_bounds() { + let x = DomainId::new(0); + let p1 = predicate![x >= 4]; + let p2 = predicate![x != 6]; + let p3 = predicate![x != 2]; + + assert!(p1 < p2); + assert!(p1 < p3); + } + + #[test] + fn not_equal_predicates_are_ordered_by_rhs() { + let x = DomainId::new(0); + let p1 = predicate![x != 6]; + let p2 = predicate![x != 2]; + + assert!(p1 > p2); + } + + #[test] + fn equal_predicates_are_ordered_by_rhs() { + let x = DomainId::new(0); + let p1 = predicate![x == 6]; + let p2 = predicate![x == 2]; + + assert!(p1 > p2); + } + + #[test] + fn equal_predicates_bigger_than_lower_bounds() { + let x = DomainId::new(0); + let p1 = predicate![x == 6]; + let p2 = predicate![x >= 2]; + + assert!(p1 > p2); + } + + #[test] + fn equal_predicates_smaller_than_upper_bounds() { + let x = DomainId::new(0); + let p1 = predicate![x == 6]; + let p2 = predicate![x <= 2]; + + assert!(p1 < p2); + } + + #[test] + fn tighter_upper_bound_is_smaller() { + let x = DomainId::new(0); + let p1 = predicate![x <= 6]; + let p2 = predicate![x <= 2]; + + assert!(p1 > p2); + } + + #[test] + fn implies_over_different_domains_is_false() { + let x = DomainId::new(0); + let y = DomainId::new(1); + + assert!(!predicate![x >= 5].implies(predicate![y >= 4])); + } + + #[test] + fn lower_bound_implies() { + let x = DomainId::new(0); + + // Implies weaker bounds + assert!(predicate![x >= 5].implies(predicate![x >= 5])); + assert!(predicate![x >= 5].implies(predicate![x >= 4])); + + // Implies not-equals below bound + assert!(predicate![x >= 5].implies(predicate![x != 4])); + assert!(predicate![x >= 5].implies(predicate![x != 3])); + + // Does not imply stronger bounds + assert!(!predicate![x >= 5].implies(predicate![x >= 6])); + + // Does not imply not-equals at or above bound + assert!(!predicate![x >= 5].implies(predicate![x != 6])); + assert!(!predicate![x >= 5].implies(predicate![x != 5])); + + // Does not imply equals + assert!(!predicate![x >= 5].implies(predicate![x == 6])); + assert!(!predicate![x >= 5].implies(predicate![x == 5])); + assert!(!predicate![x >= 5].implies(predicate![x == 4])); + } + + #[test] + fn upper_bound_implies() { + let x = DomainId::new(0); + + // Implies weaker bounds + assert!(predicate![x <= 5].implies(predicate![x <= 5])); + assert!(predicate![x <= 5].implies(predicate![x <= 6])); + + // Implies not-equals above bound + assert!(predicate![x <= 5].implies(predicate![x != 6])); + assert!(predicate![x <= 5].implies(predicate![x != 7])); + + // Does not imply stronger bounds + assert!(!predicate![x <= 5].implies(predicate![x <= 4])); + + // Does not imply not-equals at or below bound + assert!(!predicate![x <= 5].implies(predicate![x != 4])); + assert!(!predicate![x <= 5].implies(predicate![x != 5])); + + // Does not imply equals + assert!(!predicate![x <= 5].implies(predicate![x == 6])); + assert!(!predicate![x <= 5].implies(predicate![x == 5])); + assert!(!predicate![x <= 5].implies(predicate![x == 4])); + } + + #[test] + fn equals_implies() { + let x = DomainId::new(0); + + // Implies lower bounds at or below + assert!(predicate![x == 5].implies(predicate![x >= 5])); + assert!(predicate![x == 5].implies(predicate![x >= 4])); + + // Implies upper bounds at or above + assert!(predicate![x == 5].implies(predicate![x <= 5])); + assert!(predicate![x == 5].implies(predicate![x <= 6])); + + // Implies not-equals + assert!(predicate![x == 5].implies(predicate![x != 4])); + assert!(predicate![x == 5].implies(predicate![x != 6])); + + // Does not imply not-equals at bound + assert!(!predicate![x == 5].implies(predicate![x != 5])); + + // Does not lower bounds above value + assert!(!predicate![x == 5].implies(predicate![x >= 6])); + + // Does not upper bounds below value + assert!(!predicate![x == 5].implies(predicate![x <= 4])); + } + + #[test] + fn not_equals_implies_nothing() { + let x = DomainId::new(0); + + assert!(!predicate![x != 5].implies(predicate![x <= 4])); + assert!(!predicate![x != 5].implies(predicate![x <= 5])); + assert!(!predicate![x != 5].implies(predicate![x <= 6])); + + assert!(!predicate![x != 5].implies(predicate![x >= 4])); + assert!(!predicate![x != 5].implies(predicate![x >= 5])); + assert!(!predicate![x != 5].implies(predicate![x >= 6])); + + assert!(!predicate![x != 5].implies(predicate![x == 4])); + assert!(!predicate![x != 5].implies(predicate![x == 5])); + assert!(!predicate![x != 5].implies(predicate![x == 6])); + + assert!(!predicate![x != 5].implies(predicate![x != 4])); + assert!(!predicate![x != 5].implies(predicate![x != 6])); + + assert!(predicate![x != 5].implies(predicate![x != 5])); + } }