Skip to content

Commit

Permalink
Allow to compare kms errors with errors.Is
Browse files Browse the repository at this point in the history
This commit implements the "Is(target error) bool" interface to apiv1
errors so we can compare them with errors.Is even if the message is not
empty.
  • Loading branch information
maraino committed Mar 21, 2024
1 parent 8e55bd9 commit 820dcb6
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 12 deletions.
32 changes: 30 additions & 2 deletions kms/apiv1/options.go
Expand Up @@ -72,8 +72,13 @@ func (e NotImplementedError) Error() string {
return "not implemented"
}

func (e NotImplementedError) Is(target error) bool {
_, ok := target.(NotImplementedError)
return ok
}

// AlreadyExistsError is the type of error returned if a key already exists. This
// is currently only implmented for pkcs11 and tpmkms.
// is currently only implemented for pkcs11, tpmkms, and mackms.
type AlreadyExistsError struct {
Message string
}
Expand All @@ -82,7 +87,30 @@ func (e AlreadyExistsError) Error() string {
if e.Message != "" {
return e.Message
}
return "key already exists"
return "already exists"
}

func (e AlreadyExistsError) Is(target error) bool {
_, ok := target.(AlreadyExistsError)
return ok
}

// NotFoundError is the type of error returned if a key or certificate does not
// exist. This is currently only implemented for mackms.
type NotFoundError struct {
Message string
}

func (e NotFoundError) Error() string {
if e.Message != "" {
return e.Message
}
return "not found"
}

func (e NotFoundError) Is(target error) bool {
_, ok := target.(NotFoundError)
return ok
}

// Type represents the KMS type used.
Expand Down
30 changes: 30 additions & 0 deletions kms/apiv1/options_test.go
Expand Up @@ -3,8 +3,12 @@ package apiv1
import (
"context"
"crypto"
"errors"
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

type fakeKM struct{}
Expand Down Expand Up @@ -176,3 +180,29 @@ func TestTypeOf(t *testing.T) {
})
}
}

func TestError_Is(t *testing.T) {
tests := []struct {
name string
err error
target error
want bool
}{
{"ok not implemented", NotImplementedError{}, NotImplementedError{}, true},
{"ok not implemented with message", NotImplementedError{Message: "something"}, NotImplementedError{}, true},
{"ok already exists", AlreadyExistsError{}, AlreadyExistsError{}, true},
{"ok already exists with message", AlreadyExistsError{Message: "something"}, AlreadyExistsError{}, true},
{"ok not found", NotFoundError{}, NotFoundError{}, true},
{"ok not found with message", NotFoundError{Message: "something"}, NotFoundError{}, true},
{"fail not implemented", errors.New("not implemented"), NotImplementedError{}, false},
{"fail already exists", errors.New("already exists"), AlreadyExistsError{}, false},
{"fail not found", errors.New("not found"), NotFoundError{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, errors.Is(tt.err, tt.target))
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", tt.err), tt.target))
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", fmt.Errorf("wrap 2: %w", tt.err)), tt.target))
})
}
}
33 changes: 24 additions & 9 deletions kms/mackms/mackms.go
Expand Up @@ -141,7 +141,7 @@ func (k *MacKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,

key, err := getPrivateKey(u)
if err != nil {
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", err)
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", apiv1Error(err))
}
defer key.Release()

Expand Down Expand Up @@ -263,7 +263,7 @@ func (k *MacKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons

secKeyRef, err := security.SecKeyCreateRandomKey(attrs)
if err != nil {
return nil, fmt.Errorf("mackms CreateKey failed: %w", err)
return nil, fmt.Errorf("mackms CreateKey failed: %w", apiv1Error(err))
}
defer secKeyRef.Release()

Expand Down Expand Up @@ -307,7 +307,7 @@ func (k *MacKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er

key, err := getPrivateKey(u)
if err != nil {
return nil, fmt.Errorf("mackms CreateSigner failed: %w", err)
return nil, fmt.Errorf("mackms CreateSigner failed: %w", apiv1Error(err))
}
defer key.Release()

Expand Down Expand Up @@ -343,7 +343,7 @@ func (k *MacKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certi

cert, err := loadCertificate(u.label, u.serialNumber, nil)
if err != nil {
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", err)
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", apiv1Error(err))
}

return cert, nil
Expand Down Expand Up @@ -375,7 +375,7 @@ func (k *MacKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {

// Store the certificate and update the label if required
if err := storeCertificate(u.label, req.Certificate); err != nil {
return fmt.Errorf("mackms StoreCertificate failed: %w", err)
return fmt.Errorf("mackms StoreCertificate failed: %w", apiv1Error(err))
}

return nil
Expand All @@ -402,7 +402,7 @@ func (k *MacKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([

cert, err := loadCertificate(u.label, u.serialNumber, nil)
if err != nil {
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", err)
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", apiv1Error(err))
}

chain := []*x509.Certificate{cert}
Expand Down Expand Up @@ -453,7 +453,7 @@ func (k *MacKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)

// Store the certificate and update the label if required
if err := storeCertificate(u.label, req.CertificateChain[0]); err != nil {
return fmt.Errorf("mackms StoreCertificateChain failed: %w", err)
return fmt.Errorf("mackms StoreCertificateChain failed: %w", apiv1Error(err))
}

// Store the rest of the chain but do not fail if already exists
Expand Down Expand Up @@ -503,7 +503,7 @@ func (*MacKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
}
// Extract logic to deleteItem to avoid defer on loops
if err := deleteItem(dict, u.hash); err != nil {
return fmt.Errorf("mackms DeleteKey failed: %w", err)
return fmt.Errorf("mackms DeleteKey failed: %w", apiv1Error(err))
}
}

Expand Down Expand Up @@ -548,7 +548,7 @@ func (*MacKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
}

if err := deleteItem(query, nil); err != nil {
return fmt.Errorf("mackms DeleteCertificate failed: %w", err)
return fmt.Errorf("mackms DeleteCertificate failed: %w", apiv1Error(err))
}

return nil
Expand Down Expand Up @@ -1003,3 +1003,18 @@ func ecdhToECDSAPublicKey(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) {
return nil, errors.New("failed to convert *ecdh.PublicKey to *ecdsa.PublicKey")
}
}

func apiv1Error(err error) error {
switch {
case errors.Is(err, security.ErrNotFound):
return apiv1.NotFoundError{
Message: err.Error(),
}
case errors.Is(err, security.ErrAlreadyExists):
return apiv1.AlreadyExistsError{
Message: err.Error(),
}
default:
return err
}
}
39 changes: 38 additions & 1 deletion kms/mackms/mackms_test.go
Expand Up @@ -29,6 +29,8 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"fmt"
"io"
"math/big"
"net/url"
"testing"
Expand Down Expand Up @@ -1143,7 +1145,7 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
_, err := kms.LoadCertificate(&apiv1.LoadCertificateRequest{
Name: "mackms:serial=" + hex.EncodeToString(cert.SerialNumber.Bytes()),
})
assert.ErrorIs(t, err, security.ErrNotFound)
assert.ErrorIs(t, err, apiv1.NotFoundError{})
}

kms := &MacKMS{}
Expand Down Expand Up @@ -1196,3 +1198,38 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
})
}
}

func Test_apiv1Error(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
assertion assert.ErrorAssertionFunc
}{
{"ok not found", args{security.ErrNotFound}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
}},
{"ok not found wrapped", args{fmt.Errorf("something happened: %w", security.ErrNotFound)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
}},
{"ok already exists", args{security.ErrAlreadyExists}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
}},
{"ok already exists wrapped", args{fmt.Errorf("something happened: %w", security.ErrAlreadyExists)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
}},
{"ok other", args{io.ErrUnexpectedEOF}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
}},
{"ok other wrapped", args{fmt.Errorf("something happened: %w", io.ErrUnexpectedEOF)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.assertion(t, apiv1Error(tt.args.err))
})
}
}

0 comments on commit 820dcb6

Please sign in to comment.