From f7b49d3ed7572a8f0e1adc100d323f1d4adbe418 Mon Sep 17 00:00:00 2001 From: Bogdan Drutu Date: Sat, 4 Feb 2023 16:26:48 -0800 Subject: [PATCH] Add support to compare uintptr Signed-off-by: Bogdan Drutu --- assert/assertion_compare.go | 22 ++++++++++++++++++++++ assert/assertion_compare_test.go | 7 +++++++ 2 files changed, 29 insertions(+) diff --git a/assert/assertion_compare.go b/assert/assertion_compare.go index 95d8e59da..32b29f515 100644 --- a/assert/assertion_compare.go +++ b/assert/assertion_compare.go @@ -28,6 +28,8 @@ var ( uint32Type = reflect.TypeOf(uint32(1)) uint64Type = reflect.TypeOf(uint64(1)) + uintptrType = reflect.TypeOf(uintptr(1)) + float32Type = reflect.TypeOf(float32(1)) float64Type = reflect.TypeOf(float64(1)) @@ -345,6 +347,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true } + case reflect.Uintptr: + { + uintptrObj1, ok := obj1.(uintptr) + if !ok { + uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr) + } + uintptrObj2, ok := obj2.(uintptr) + if !ok { + uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr) + } + if uintptrObj1 > uintptrObj2 { + return compareGreater, true + } + if uintptrObj1 == uintptrObj2 { + return compareEqual, true + } + if uintptrObj1 < uintptrObj2 { + return compareLess, true + } + } } return compareEqual, false diff --git a/assert/assertion_compare_test.go b/assert/assertion_compare_test.go index a38d88060..4394cb46a 100644 --- a/assert/assertion_compare_test.go +++ b/assert/assertion_compare_test.go @@ -22,6 +22,7 @@ func TestCompare(t *testing.T) { type customFloat32 float32 type customFloat64 float64 type customString string + type customUintptr uintptr for _, currCase := range []struct { less interface{} greater interface{} @@ -52,6 +53,8 @@ func TestCompare(t *testing.T) { {less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"}, {less: float64(1.23), greater: float64(2.34), cType: "float64"}, {less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"}, + {less: uintptr(1), greater: uintptr(2), cType: "uintptr"}, + {less: customUintptr(1), greater: customUintptr(2), cType: "uint64"}, } { resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) if !isComparable { @@ -148,6 +151,7 @@ func TestGreater(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"1" is not greater than "2"`}, {less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than "2.34"`}, {less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than "2.34"`}, + {less: uintptr(1), greater: uintptr(2), msg: `"1" is not greater than "2"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, Greater(out, currCase.less, currCase.greater)) @@ -189,6 +193,7 @@ func TestGreaterOrEqual(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"1" is not greater than or equal to "2"`}, {less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than or equal to "2.34"`}, {less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than or equal to "2.34"`}, + {less: uintptr(1), greater: uintptr(2), msg: `"1" is not greater than or equal to "2"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, GreaterOrEqual(out, currCase.less, currCase.greater)) @@ -230,6 +235,7 @@ func TestLess(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"2" is not less than "1"`}, {less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than "1.23"`}, {less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than "1.23"`}, + {less: uintptr(1), greater: uintptr(2), msg: `"2" is not less than "1"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, Less(out, currCase.greater, currCase.less)) @@ -271,6 +277,7 @@ func TestLessOrEqual(t *testing.T) { {less: uint64(1), greater: uint64(2), msg: `"2" is not less than or equal to "1"`}, {less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than or equal to "1.23"`}, {less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than or equal to "1.23"`}, + {less: uintptr(1), greater: uintptr(2), msg: `"2" is not less than or equal to "1"`}, } { out := &outputT{buf: bytes.NewBuffer(nil)} False(t, LessOrEqual(out, currCase.greater, currCase.less))