diff --git a/status/status.go b/status/status.go index 623be39f26b..7577625b0c4 100644 --- a/status/status.go +++ b/status/status.go @@ -77,7 +77,8 @@ func FromProto(s *spb.Status) *Status { // FromError returns a Status representation of err. // // - If err was produced by this package or implements the method `GRPCStatus() -// *Status`, the appropriate Status is returned. +// *Status`, or if err wraps a type satisfying this, the appropriate Status is +// returned. // // - If err is nil, a Status is returned with codes.OK and no message. // @@ -88,9 +89,8 @@ func FromError(err error) (s *Status, ok bool) { if err == nil { return nil, true } - if se, ok := err.(interface { - GRPCStatus() *Status - }); ok { + var se interface{ GRPCStatus() *Status } + if errors.As(err, &se) { return se.GRPCStatus(), true } return New(codes.Unknown, err.Error()), false @@ -103,19 +103,16 @@ func Convert(err error) *Status { return s } -// Code returns the Code of the error if it is a Status error, codes.OK if err -// is nil, or codes.Unknown otherwise. +// Code returns the Code of the error if it is a Status error or if it wraps a +// Status error. If that is not the case, it returns codes.OK if err is nil, or +// codes.Unknown otherwise. func Code(err error) codes.Code { // Don't use FromError to avoid allocation of OK status. if err == nil { return codes.OK } - if se, ok := err.(interface { - GRPCStatus() *Status - }); ok { - return se.GRPCStatus().Code() - } - return codes.Unknown + + return Convert(err).Code() } // FromContextError converts a context error or wrapped context error into a diff --git a/status/status_test.go b/status/status_test.go index 420fb6b8102..244cb8151fd 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -32,6 +32,7 @@ import ( cpb "google.golang.org/genproto/googleapis/rpc/code" epb "google.golang.org/genproto/googleapis/rpc/errdetails" spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/status" @@ -192,6 +193,70 @@ func (s) TestFromErrorUnknownError(t *testing.T) { } } +func (s) TestFromErrorWrapped(t *testing.T) { + const code, message = codes.Internal, "test description" + err := fmt.Errorf("wrapped error: %w", Error(code, message)) + s, ok := FromError(err) + if !ok || s.Code() != code || s.Message() != message || s.Err() == nil { + t.Fatalf("FromError(%v) = %v, %v; want , true", err, s, ok, code, message) + } +} + +func (s) TestFromErrorImplementsInterfaceWrapped(t *testing.T) { + const code, message = codes.Internal, "test description" + err := fmt.Errorf("wrapped error: %w", customError{Code: code, Message: message}) + s, ok := FromError(err) + if !ok || s.Code() != code || s.Message() != message || s.Err() == nil { + t.Fatalf("FromError(%v) = %v, %v; want , true", err, s, ok, code, message) + } +} + +func (s) TestCode(t *testing.T) { + const code = codes.Internal + err := Error(code, "test description") + if s := Code(err); s != code { + t.Fatalf("Code(%v) = %v; want ", err, s, code) + } +} + +func (s) TestCodeOK(t *testing.T) { + if s, code := Code(nil), codes.OK; s != code { + t.Fatalf("Code(%v) = %v; want ", nil, s, code) + } +} + +func (s) TestCodeImplementsInterface(t *testing.T) { + const code = codes.Internal + err := customError{Code: code, Message: "test description"} + if s := Code(err); s != code { + t.Fatalf("Code(%v) = %v; want ", err, s, code) + } +} + +func (s) TestCodeUnknownError(t *testing.T) { + const code = codes.Unknown + err := errors.New("unknown error") + if s := Code(err); s != code { + t.Fatalf("Code(%v) = %v; want ", err, s, code) + } +} + +func (s) TestCodeWrapped(t *testing.T) { + const code = codes.Internal + err := fmt.Errorf("wrapped: %w", Error(code, "test description")) + if s := Code(err); s != code { + t.Fatalf("Code(%v) = %v; want ", err, s, code) + } +} + +func (s) TestCodeImplementsInterfaceWrapped(t *testing.T) { + const code = codes.Internal + err := fmt.Errorf("wrapped: %w", customError{Code: code, Message: "test description"}) + if s := Code(err); s != code { + t.Fatalf("Code(%v) = %v; want ", err, s, code) + } +} + func (s) TestConvertKnownError(t *testing.T) { code, message := codes.Internal, "test description" err := Error(code, message)