diff --git a/Cargo.toml b/Cargo.toml index 7b1c4d5..aef488d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ edition = "2018" [dependencies] serde_json = "1" serde = "1" +float-cmp = "0.9.0" [dev-dependencies] version-sync = "0.8" diff --git a/src/diff.rs b/src/diff.rs index de97a0d..81f8e79 100644 --- a/src/diff.rs +++ b/src/diff.rs @@ -1,5 +1,6 @@ use crate::core_ext::{Indent, Indexes}; -use crate::{CompareMode, Config, NumericMode}; +use crate::{CompareMode, Config, FloatCompareMode, NumericMode}; +use float_cmp::{ApproxEq, F64Margin}; use serde_json::Value; use std::{collections::HashSet, fmt}; @@ -56,8 +57,11 @@ impl<'a, 'b> DiffFolder<'a, 'b> { fn on_number(&mut self, lhs: &'a Value) { let is_equal = match self.config.numeric_mode { - NumericMode::Strict => self.rhs == lhs, - NumericMode::AssumeFloat => self.rhs.as_f64() == lhs.as_f64(), + NumericMode::Strict => self.eq_values(lhs, self.rhs), + NumericMode::AssumeFloat => match (lhs.as_f64(), self.rhs.as_f64()) { + (Some(lhs), Some(rhs)) => self.eq_floats(lhs, rhs), + (lhs, rhs) => lhs == rhs, + }, }; if !is_equal { self.acc.push(Difference { @@ -69,6 +73,27 @@ impl<'a, 'b> DiffFolder<'a, 'b> { } } + fn eq_values(&self, lhs: &Value, rhs: &Value) -> bool { + if lhs.is_f64() && rhs.is_f64() { + // `as_f64` must return a floating point value if `is_f64` returned true. The inverse + // relation is not guaranteed by serde_json. + self.eq_floats( + lhs.as_f64().expect("float value"), + rhs.as_f64().expect("float value"), + ) + } else { + lhs == rhs + } + } + + fn eq_floats(&self, lhs: f64, rhs: f64) -> bool { + if let FloatCompareMode::Epsilon(epsilon) = self.config.float_compare_mode { + lhs.approx_eq(rhs, F64Margin::default().epsilon(epsilon)) + } else { + lhs == rhs + } + } + fn on_array(&mut self, lhs: &'a Value) { if let Some(rhs) = self.rhs.as_array() { let lhs = lhs.as_array().unwrap(); @@ -401,6 +426,37 @@ mod test { Config::new(CompareMode::Inclusive).numeric_mode(NumericMode::AssumeFloat), ); assert_eq!(diffs, vec![]); + + let actual = json!(1.15); + let expected = json!(1); + let diffs = diff( + &actual, + &expected, + Config::new(CompareMode::Inclusive) + .numeric_mode(NumericMode::AssumeFloat) + .float_compare_mode(FloatCompareMode::Epsilon(0.2)), + ); + assert_eq!(diffs, vec![]); + + let actual = json!(1.25); + let expected = json!(1); + let diffs = diff( + &actual, + &expected, + Config::new(CompareMode::Inclusive) + .numeric_mode(NumericMode::AssumeFloat) + .float_compare_mode(FloatCompareMode::Epsilon(0.2)), + ); + assert_eq!(diffs.len(), 1); + + let actual = json!(2); + let expected = json!(1); + let diffs = diff( + &actual, + &expected, + Config::new(CompareMode::Inclusive).float_compare_mode(FloatCompareMode::Epsilon(2.0)), + ); + assert_eq!(diffs.len(), 1); } #[test] diff --git a/src/lib.rs b/src/lib.rs index 961b349..f14de1c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -299,11 +299,12 @@ where } /// Configuration for how JSON values should be compared. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq)] #[allow(missing_copy_implementations)] pub struct Config { pub(crate) compare_mode: CompareMode, pub(crate) numeric_mode: NumericMode, + float_compare_mode: FloatCompareMode, } impl Config { @@ -314,6 +315,7 @@ impl Config { Self { compare_mode, numeric_mode: NumericMode::Strict, + float_compare_mode: FloatCompareMode::Exact, } } @@ -330,6 +332,14 @@ impl Config { self.compare_mode = compare_mode; self } + + /// Change the config's float compare mode. + /// + /// The default `float_compare_mode` is [`FloatCompareMode::Exact`]. + pub fn float_compare_mode(mut self, float_compare_mode: FloatCompareMode) -> Self { + self.float_compare_mode = float_compare_mode; + self + } } /// Mode for how JSON values should be compared. @@ -355,6 +365,17 @@ pub enum NumericMode { AssumeFloat, } +/// How should floating point numbers be compared. +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum FloatCompareMode { + /// Different floats are never considered equal. + Exact, + /// Floats are considered equal if they differ by at most this epsilon value. + Epsilon(f64), +} + +impl Eq for FloatCompareMode {} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 1bbe8ee..5da3242 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -1,6 +1,6 @@ use assert_json_diff::{ assert_json_eq, assert_json_include, assert_json_matches, assert_json_matches_no_panic, - CompareMode, Config, NumericMode, + CompareMode, Config, FloatCompareMode, NumericMode, }; use serde::Serialize; use serde_json::json; @@ -182,3 +182,79 @@ fn eq_with_serializable_ref() { &user, ); } + +#[derive(Serialize)] +struct Person { + name: String, + height: f64, +} + +#[test] +fn can_pass_with_exact_float_comparison() { + let person = Person { + name: "bob".to_string(), + height: 1.79, + }; + + assert_json_matches!( + &json!({ + "name": "bob", + "height": 1.79 + }), + &person, + Config::new(CompareMode::Strict).float_compare_mode(FloatCompareMode::Exact) + ); +} + +#[test] +#[should_panic] +fn can_fail_with_exact_float_comparison() { + let person = Person { + name: "bob".to_string(), + height: 1.79, + }; + + assert_json_matches!( + &json!({ + "name": "bob", + "height": 1.7900001 + }), + &person, + Config::new(CompareMode::Strict).float_compare_mode(FloatCompareMode::Exact) + ); +} + +#[test] +fn can_pass_with_epsilon_based_float_comparison() { + let person = Person { + name: "bob".to_string(), + height: 1.79, + }; + + assert_json_matches!( + &json!({ + "name": "bob", + "height": 1.7900001 + }), + &person, + Config::new(CompareMode::Strict).float_compare_mode(FloatCompareMode::Epsilon(0.00001)) + ); +} + +#[test] +#[should_panic] +fn can_fail_with_epsilon_based_float_comparison() { + let person = Person { + name: "bob".to_string(), + height: 1.79, + }; + + assert_json_matches!( + &json!({ + "name": "bob", + "height": 1.7901 + }), + &person, + Config::new(CompareMode::Strict).float_compare_mode(FloatCompareMode::Epsilon(0.00001)) + ); +}