diff --git a/status/status.go b/status/status.go index 623be39f26ba..82a784962c52 100644 --- a/status/status.go +++ b/status/status.go @@ -88,9 +88,8 @@ func FromError(err error) (s *Status, ok bool) { if err == nil { return nil, true } - if se, ok := err.(interface { - GRPCStatus() *Status - }); ok { + var se interface{ GRPCStatus() *Status } + if errors.As(err, &se) { return se.GRPCStatus(), true } return New(codes.Unknown, err.Error()), false @@ -110,12 +109,8 @@ func Code(err error) codes.Code { if err == nil { return codes.OK } - if se, ok := err.(interface { - GRPCStatus() *Status - }); ok { - return se.GRPCStatus().Code() - } - return codes.Unknown + + return Convert(err).Code() } // FromContextError converts a context error or wrapped context error into a diff --git a/status/status_test.go b/status/status_test.go index 420fb6b8102c..1313e43857ff 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -32,6 +32,7 @@ import ( cpb "google.golang.org/genproto/googleapis/rpc/code" epb "google.golang.org/genproto/googleapis/rpc/errdetails" spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/status" @@ -144,6 +145,37 @@ func (s) TestFromErrorOK(t *testing.T) { } } +func (s) TestFromErrorWrapped(t *testing.T) { + code, message := codes.Internal, "test description" + err := fmt.Errorf("wrapped error: %w", Error(code, message)) + s, ok := FromError(err) + if !ok || s.Code() != code || s.Message() != message || s.Err() == nil { + t.Fatalf("FromError(%v) = %v, %v; want , true", err, s, ok, code, message) + } +} + +func (s) TestCode(t *testing.T) { + code, message := codes.Internal, "test description" + err := Error(code, message) + if s := Code(err); s != code { + t.Fatalf("Code(%v) = %v; want ", err, s, code) + } +} + +func (s) TestCodeOK(t *testing.T) { + if s, code := Code(nil), codes.OK; s != code { + t.Fatalf("Code(%v) = %v; want ", nil, s, code) + } +} + +func (s) TestCodeWrapped(t *testing.T) { + code, message := codes.Internal, "test description" + err := fmt.Errorf("wrapped: %w", Error(code, message)) + if s := Code(err); s != code { + t.Fatalf("Code(%v) = %v; want ", err, s, code) + } +} + type customError struct { Code codes.Code Message string