assert: allow comparing time.Time

pull/1164/head
Torkel Rogstad 2022-01-17 14:10:48 +01:00 committed by Boyan Soubachov
parent 7bcf74e94f
commit 087b655c75
2 changed files with 30 additions and 1 deletions

View File

@ -3,6 +3,7 @@ package assert
import (
"fmt"
"reflect"
"time"
)
type CompareType int
@ -30,6 +31,8 @@ var (
float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@ -299,6 +302,27 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return compareLess, true
}
}
// Check for known struct types we can check for compare results.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !obj1Value.CanConvert(timeType) {
break
}
// time.Time can compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
}
return compareEqual, false

View File

@ -6,6 +6,7 @@ import (
"reflect"
"runtime"
"testing"
"time"
)
func TestCompare(t *testing.T) {
@ -22,6 +23,7 @@ func TestCompare(t *testing.T) {
type customFloat32 float32
type customFloat64 float64
type customString string
type customTime time.Time
for _, currCase := range []struct {
less interface{}
greater interface{}
@ -52,6 +54,8 @@ func TestCompare(t *testing.T) {
{less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"},
{less: float64(1.23), greater: float64(2.34), cType: "float64"},
{less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"},
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
} {
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
@ -59,7 +63,8 @@ func TestCompare(t *testing.T) {
}
if resLess != compareLess {
t.Errorf("object less should be less than greater for type " + currCase.cType)
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
currCase.less, currCase.greater)
}
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())