Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assert.ErrorAs: log target type #1345

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 36 additions & 10 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -1997,7 +1997,7 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error()
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)

return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+
"expected: %q\n"+
Expand All @@ -2020,7 +2020,7 @@ func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool {
expectedText = target.Error()
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, false)

return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+
"found: %q\n"+
Expand All @@ -2038,24 +2038,50 @@ func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{
return true
}

chain := buildErrorChainString(err)
chain := buildErrorChainString(err, true)

return Fail(t, fmt.Sprintf("Should be in error chain:\n"+
"expected: %q\n"+
"expected: %T\n"+
"in chain: %s", target, chain,
), msgAndArgs...)
}

func buildErrorChainString(err error) string {
func unwrapAll(err error) (errs []error) {
errs = append(errs, err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind explaining this one? I believe this should only be the default cause since you're going to have recursive definitions of error if you explore each value in the slice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This basically "flattens" the provided error (err) into a slice so we can access the type and Error() string of each constituent error. We start by appending err to the slice because it itself has a type and Error() string.

errors.As failing for a given error means that the target error type is not present in any of these constituent errors. So, we need to exhaustively log all of the types in the error message.

I'm not 100% sure if this answers your question.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aaa, so the usecase I am thinking of is the following:

errs := unwrap(errors.Join(errors.New("a"), errors.New("b"))

This would mean the resulting values would be:

[]error(error([]{"a", "b"}, "a", "b")

Which I don't believe you want to repeat those values

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, I do want to repeat those values - I want 3 lines in the error message, and 3 types.

for {
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
}
errs = append(errs, err)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
}
return
default:
return
}
}
Comment on lines +2051 to +2067
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be simpler if all the switch cases recurse, but the above implementation is fine.

Suggested change
for {
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
}
errs = append(errs, err)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
}
return
default:
return
}
}
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
}
errs = append(errs, unwrapAll(err)...)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
}
}
return

}

func buildErrorChainString(err error, withType bool) string {
if err == nil {
return ""
}

e := errors.Unwrap(err)
chain := fmt.Sprintf("%q", err.Error())
for e != nil {
chain += fmt.Sprintf("\n\t%q", e.Error())
e = errors.Unwrap(e)
var chain string
errs := unwrapAll(err)
craig65535 marked this conversation as resolved.
Show resolved Hide resolved
for i := range errs {
if i != 0 {
chain += "\n\t"
}
chain += fmt.Sprintf("%q", errs[i].Error())
if withType {
chain += fmt.Sprintf(" (%T)", errs[i])
}
}
return chain
}
43 changes: 34 additions & 9 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3131,23 +3131,48 @@ func TestNotErrorIs(t *testing.T) {
}

func TestErrorAs(t *testing.T) {
mockT := new(testing.T)
tests := []struct {
err error
result bool
err error
result bool
resultErrMsg string
}{
{fmt.Errorf("wrap: %w", &customError{}), true},
{io.EOF, false},
{nil, false},
{
err: fmt.Errorf("wrap: %w", &customError{}),
result: true,
},
{
err: io.EOF,
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: **assert.customError\n" +
"in chain: \"EOF\" (*errors.errorString)\n",
},
{
err: nil,
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: **assert.customError\n" +
"in chain: \n",
},
{
err: fmt.Errorf("abc: %w", errors.New("def")),
result: false,
resultErrMsg: "" +
"Should be in error chain:\n" +
"expected: **assert.customError\n" +
"in chain: \"abc: def\" (*fmt.wrapError)\n" +
"\t\"def\" (*errors.errorString)\n",
},
}
for _, tt := range tests {
tt := tt
var target *customError
t.Run(fmt.Sprintf("ErrorAs(%#v,%#v)", tt.err, target), func(t *testing.T) {
mockT := new(captureTestingT)
res := ErrorAs(mockT, tt.err, &target)
if res != tt.result {
t.Errorf("ErrorAs(%#v,%#v) should return %t)", tt.err, target, tt.result)
}
mockT.checkResultAndErrMsg(t, tt.result, res, tt.resultErrMsg)
})
}
}