@@ -9,13 +9,17 @@ import (
99 "log"
1010 "reflect"
1111 "strings"
12+ "time"
1213)
1314
1415var (
1516 // FloatPrecision is the number of decimal places to round float values
1617 // to when comparing.
1718 FloatPrecision = 10
1819
20+ // TimePrecision is a precision used for time.Time.Truncate(), if it is non-zero.
21+ TimePrecision time.Duration
22+
1923 // MaxDiff specifies the maximum number of differences to return.
2024 MaxDiff = 10
2125
@@ -79,7 +83,11 @@ type cmp struct {
7983 flag map [byte ]bool
8084}
8185
82- var errorType = reflect .TypeOf ((* error )(nil )).Elem ()
86+ var (
87+ errorType = reflect .TypeOf ((* error )(nil )).Elem ()
88+ timeType = reflect .TypeOf (time.Time {})
89+ durationType = reflect .TypeOf (time .Nanosecond )
90+ )
8391
8492// Equal compares variables a and b, recursing into their structure up to
8593// MaxDepth levels deep (if greater than zero), and returns a list of differences,
@@ -203,6 +211,23 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
203211 return
204212 }
205213
214+ fixTimePrecision := func () {
215+ if TimePrecision > 0 {
216+ switch aType {
217+ case timeType , durationType :
218+ aFunc := a .MethodByName ("Truncate" )
219+ bFunc := a .MethodByName ("Truncate" )
220+
221+ if aFunc .CanInterface () && bFunc .CanInterface () {
222+ precision := reflect .ValueOf (TimePrecision )
223+
224+ a = aFunc .Call ([]reflect.Value {precision })[0 ]
225+ b = bFunc .Call ([]reflect.Value {precision })[0 ]
226+ }
227+ }
228+ }
229+ }
230+
206231 switch aKind {
207232
208233 /////////////////////////////////////////////////////////////////////
@@ -221,6 +246,8 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
221246 Iterate through the fields (FirstName, LastName), recurse into their values.
222247 */
223248
249+ fixTimePrecision ()
250+
224251 // Types with an Equal() method, like time.Time, only if struct field
225252 // is exported (CanInterface)
226253 if eqFunc := a .MethodByName ("Equal" ); eqFunc .IsValid () && eqFunc .CanInterface () {
@@ -439,6 +466,7 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
439466 c .saveDiff (a .Bool (), b .Bool ())
440467 }
441468 case reflect .Int , reflect .Int8 , reflect .Int16 , reflect .Int32 , reflect .Int64 :
469+ fixTimePrecision ()
442470 if a .Int () != b .Int () {
443471 c .saveDiff (a .Int (), b .Int ())
444472 }
0 commit comments