Skip to content

Commit

Permalink
Merge pull request #462 from smallstep/mariano/error-is
Browse files Browse the repository at this point in the history
Allow to compare kms errors with errors.Is
  • Loading branch information
maraino committed Mar 21, 2024
2 parents 8e55bd9 + ca39242 commit ac197b0
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 13 deletions.
32 changes: 30 additions & 2 deletions kms/apiv1/options.go
Expand Up @@ -72,8 +72,13 @@ func (e NotImplementedError) Error() string {
return "not implemented"
}

func (e NotImplementedError) Is(target error) bool {
_, ok := target.(NotImplementedError)
return ok
}

// AlreadyExistsError is the type of error returned if a key already exists. This
// is currently only implmented for pkcs11 and tpmkms.
// is currently only implemented for pkcs11, tpmkms, and mackms.
type AlreadyExistsError struct {
Message string
}
Expand All @@ -82,7 +87,30 @@ func (e AlreadyExistsError) Error() string {
if e.Message != "" {
return e.Message
}
return "key already exists"
return "already exists"
}

func (e AlreadyExistsError) Is(target error) bool {
_, ok := target.(AlreadyExistsError)
return ok
}

// NotFoundError is the type of error returned if a key or certificate does not
// exist. This is currently only implemented for mackms.
type NotFoundError struct {
Message string
}

func (e NotFoundError) Error() string {
if e.Message != "" {
return e.Message
}
return "not found"
}

func (e NotFoundError) Is(target error) bool {
_, ok := target.(NotFoundError)
return ok
}

// Type represents the KMS type used.
Expand Down
56 changes: 55 additions & 1 deletion kms/apiv1/options_test.go
Expand Up @@ -3,8 +3,12 @@ package apiv1
import (
"context"
"crypto"
"errors"
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

type fakeKM struct{}
Expand Down Expand Up @@ -124,7 +128,7 @@ func TestErrAlreadyExists_Error(t *testing.T) {
fields fields
want string
}{
{"default", fields{}, "key already exists"},
{"default", fields{}, "already exists"},
{"custom", fields{"custom message: key already exists"}, "custom message: key already exists"},
}
for _, tt := range tests {
Expand All @@ -139,6 +143,30 @@ func TestErrAlreadyExists_Error(t *testing.T) {
}
}

func TestNotFoundError_Error(t *testing.T) {
type fields struct {
msg string
}
tests := []struct {
name string
fields fields
want string
}{
{"default", fields{}, "not found"},
{"custom", fields{"custom message: not found"}, "custom message: not found"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NotFoundError{
Message: tt.fields.msg,
}
if got := e.Error(); got != tt.want {
t.Errorf("ErrAlreadyExists.Error() = %v, want %v", got, tt.want)
}
})
}
}

func TestTypeOf(t *testing.T) {
type args struct {
rawuri string
Expand Down Expand Up @@ -176,3 +204,29 @@ func TestTypeOf(t *testing.T) {
})
}
}

func TestError_Is(t *testing.T) {
tests := []struct {
name string
err error
target error
want bool
}{
{"ok not implemented", NotImplementedError{}, NotImplementedError{}, true},
{"ok not implemented with message", NotImplementedError{Message: "something"}, NotImplementedError{}, true},
{"ok already exists", AlreadyExistsError{}, AlreadyExistsError{}, true},
{"ok already exists with message", AlreadyExistsError{Message: "something"}, AlreadyExistsError{}, true},
{"ok not found", NotFoundError{}, NotFoundError{}, true},
{"ok not found with message", NotFoundError{Message: "something"}, NotFoundError{}, true},
{"fail not implemented", errors.New("not implemented"), NotImplementedError{}, false},
{"fail already exists", errors.New("already exists"), AlreadyExistsError{}, false},
{"fail not found", errors.New("not found"), NotFoundError{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, errors.Is(tt.err, tt.target))
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", tt.err), tt.target))
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", fmt.Errorf("wrap 2: %w", tt.err)), tt.target))
})
}
}
33 changes: 24 additions & 9 deletions kms/mackms/mackms.go
Expand Up @@ -141,7 +141,7 @@ func (k *MacKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,

key, err := getPrivateKey(u)
if err != nil {
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", err)
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", apiv1Error(err))
}
defer key.Release()

Expand Down Expand Up @@ -263,7 +263,7 @@ func (k *MacKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons

secKeyRef, err := security.SecKeyCreateRandomKey(attrs)
if err != nil {
return nil, fmt.Errorf("mackms CreateKey failed: %w", err)
return nil, fmt.Errorf("mackms CreateKey failed: %w", apiv1Error(err))
}
defer secKeyRef.Release()

Expand Down Expand Up @@ -307,7 +307,7 @@ func (k *MacKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er

key, err := getPrivateKey(u)
if err != nil {
return nil, fmt.Errorf("mackms CreateSigner failed: %w", err)
return nil, fmt.Errorf("mackms CreateSigner failed: %w", apiv1Error(err))
}
defer key.Release()

Expand Down Expand Up @@ -343,7 +343,7 @@ func (k *MacKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certi

cert, err := loadCertificate(u.label, u.serialNumber, nil)
if err != nil {
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", err)
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", apiv1Error(err))
}

return cert, nil
Expand Down Expand Up @@ -375,7 +375,7 @@ func (k *MacKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {

// Store the certificate and update the label if required
if err := storeCertificate(u.label, req.Certificate); err != nil {
return fmt.Errorf("mackms StoreCertificate failed: %w", err)
return fmt.Errorf("mackms StoreCertificate failed: %w", apiv1Error(err))
}

return nil
Expand All @@ -402,7 +402,7 @@ func (k *MacKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([

cert, err := loadCertificate(u.label, u.serialNumber, nil)
if err != nil {
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", err)
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", apiv1Error(err))
}

chain := []*x509.Certificate{cert}
Expand Down Expand Up @@ -453,7 +453,7 @@ func (k *MacKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)

// Store the certificate and update the label if required
if err := storeCertificate(u.label, req.CertificateChain[0]); err != nil {
return fmt.Errorf("mackms StoreCertificateChain failed: %w", err)
return fmt.Errorf("mackms StoreCertificateChain failed: %w", apiv1Error(err))
}

// Store the rest of the chain but do not fail if already exists
Expand Down Expand Up @@ -503,7 +503,7 @@ func (*MacKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
}
// Extract logic to deleteItem to avoid defer on loops
if err := deleteItem(dict, u.hash); err != nil {
return fmt.Errorf("mackms DeleteKey failed: %w", err)
return fmt.Errorf("mackms DeleteKey failed: %w", apiv1Error(err))
}
}

Expand Down Expand Up @@ -548,7 +548,7 @@ func (*MacKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
}

if err := deleteItem(query, nil); err != nil {
return fmt.Errorf("mackms DeleteCertificate failed: %w", err)
return fmt.Errorf("mackms DeleteCertificate failed: %w", apiv1Error(err))
}

return nil
Expand Down Expand Up @@ -1003,3 +1003,18 @@ func ecdhToECDSAPublicKey(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) {
return nil, errors.New("failed to convert *ecdh.PublicKey to *ecdsa.PublicKey")
}
}

func apiv1Error(err error) error {
switch {
case errors.Is(err, security.ErrNotFound):
return apiv1.NotFoundError{
Message: err.Error(),
}
case errors.Is(err, security.ErrAlreadyExists):
return apiv1.AlreadyExistsError{
Message: err.Error(),
}
default:
return err
}
}
39 changes: 38 additions & 1 deletion kms/mackms/mackms_test.go
Expand Up @@ -29,6 +29,8 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"fmt"
"io"
"math/big"
"net/url"
"testing"
Expand Down Expand Up @@ -1143,7 +1145,7 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
_, err := kms.LoadCertificate(&apiv1.LoadCertificateRequest{
Name: "mackms:serial=" + hex.EncodeToString(cert.SerialNumber.Bytes()),
})
assert.ErrorIs(t, err, security.ErrNotFound)
assert.ErrorIs(t, err, apiv1.NotFoundError{})
}

kms := &MacKMS{}
Expand Down Expand Up @@ -1196,3 +1198,38 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
})
}
}

func Test_apiv1Error(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
assertion assert.ErrorAssertionFunc
}{
{"ok not found", args{security.ErrNotFound}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
}},
{"ok not found wrapped", args{fmt.Errorf("something happened: %w", security.ErrNotFound)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
}},
{"ok already exists", args{security.ErrAlreadyExists}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
}},
{"ok already exists wrapped", args{fmt.Errorf("something happened: %w", security.ErrAlreadyExists)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
}},
{"ok other", args{io.ErrUnexpectedEOF}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
}},
{"ok other wrapped", args{fmt.Errorf("something happened: %w", io.ErrUnexpectedEOF)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.assertion(t, apiv1Error(tt.args.err))
})
}
}

0 comments on commit ac197b0

Please sign in to comment.