Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to compare kms errors with errors.Is #462

Merged
merged 1 commit into from Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 30 additions & 2 deletions kms/apiv1/options.go
Expand Up @@ -72,8 +72,13 @@ func (e NotImplementedError) Error() string {
return "not implemented"
}

func (e NotImplementedError) Is(target error) bool {
_, ok := target.(NotImplementedError)
return ok
}

// AlreadyExistsError is the type of error returned if a key already exists. This
// is currently only implmented for pkcs11 and tpmkms.
// is currently only implemented for pkcs11, tpmkms, and mackms.
type AlreadyExistsError struct {
Message string
}
Expand All @@ -82,7 +87,30 @@ func (e AlreadyExistsError) Error() string {
if e.Message != "" {
return e.Message
}
return "key already exists"
return "already exists"
}

func (e AlreadyExistsError) Is(target error) bool {
_, ok := target.(AlreadyExistsError)
return ok
}

// NotFoundError is the type of error returned if a key or certificate does not
// exist. This is currently only implemented for mackms.
type NotFoundError struct {
Message string
}

func (e NotFoundError) Error() string {
if e.Message != "" {
return e.Message
}
return "not found"
}

func (e NotFoundError) Is(target error) bool {
_, ok := target.(NotFoundError)
return ok
}

// Type represents the KMS type used.
Expand Down
56 changes: 55 additions & 1 deletion kms/apiv1/options_test.go
Expand Up @@ -3,8 +3,12 @@ package apiv1
import (
"context"
"crypto"
"errors"
"fmt"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

type fakeKM struct{}
Expand Down Expand Up @@ -124,7 +128,7 @@ func TestErrAlreadyExists_Error(t *testing.T) {
fields fields
want string
}{
{"default", fields{}, "key already exists"},
{"default", fields{}, "already exists"},
{"custom", fields{"custom message: key already exists"}, "custom message: key already exists"},
}
for _, tt := range tests {
Expand All @@ -139,6 +143,30 @@ func TestErrAlreadyExists_Error(t *testing.T) {
}
}

func TestNotFoundError_Error(t *testing.T) {
type fields struct {
msg string
}
tests := []struct {
name string
fields fields
want string
}{
{"default", fields{}, "not found"},
{"custom", fields{"custom message: not found"}, "custom message: not found"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NotFoundError{
Message: tt.fields.msg,
}
if got := e.Error(); got != tt.want {
t.Errorf("ErrAlreadyExists.Error() = %v, want %v", got, tt.want)
}
})
}
}

func TestTypeOf(t *testing.T) {
type args struct {
rawuri string
Expand Down Expand Up @@ -176,3 +204,29 @@ func TestTypeOf(t *testing.T) {
})
}
}

func TestError_Is(t *testing.T) {
tests := []struct {
name string
err error
target error
want bool
}{
{"ok not implemented", NotImplementedError{}, NotImplementedError{}, true},
{"ok not implemented with message", NotImplementedError{Message: "something"}, NotImplementedError{}, true},
{"ok already exists", AlreadyExistsError{}, AlreadyExistsError{}, true},
{"ok already exists with message", AlreadyExistsError{Message: "something"}, AlreadyExistsError{}, true},
{"ok not found", NotFoundError{}, NotFoundError{}, true},
{"ok not found with message", NotFoundError{Message: "something"}, NotFoundError{}, true},
{"fail not implemented", errors.New("not implemented"), NotImplementedError{}, false},
{"fail already exists", errors.New("already exists"), AlreadyExistsError{}, false},
{"fail not found", errors.New("not found"), NotFoundError{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, errors.Is(tt.err, tt.target))
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", tt.err), tt.target))
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", fmt.Errorf("wrap 2: %w", tt.err)), tt.target))
})
}
}
33 changes: 24 additions & 9 deletions kms/mackms/mackms.go
Expand Up @@ -141,7 +141,7 @@ func (k *MacKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,

key, err := getPrivateKey(u)
if err != nil {
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", err)
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", apiv1Error(err))
}
defer key.Release()

Expand Down Expand Up @@ -263,7 +263,7 @@ func (k *MacKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons

secKeyRef, err := security.SecKeyCreateRandomKey(attrs)
if err != nil {
return nil, fmt.Errorf("mackms CreateKey failed: %w", err)
return nil, fmt.Errorf("mackms CreateKey failed: %w", apiv1Error(err))
}
defer secKeyRef.Release()

Expand Down Expand Up @@ -307,7 +307,7 @@ func (k *MacKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er

key, err := getPrivateKey(u)
if err != nil {
return nil, fmt.Errorf("mackms CreateSigner failed: %w", err)
return nil, fmt.Errorf("mackms CreateSigner failed: %w", apiv1Error(err))
}
defer key.Release()

Expand Down Expand Up @@ -343,7 +343,7 @@ func (k *MacKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certi

cert, err := loadCertificate(u.label, u.serialNumber, nil)
if err != nil {
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", err)
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", apiv1Error(err))
}

return cert, nil
Expand Down Expand Up @@ -375,7 +375,7 @@ func (k *MacKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {

// Store the certificate and update the label if required
if err := storeCertificate(u.label, req.Certificate); err != nil {
return fmt.Errorf("mackms StoreCertificate failed: %w", err)
return fmt.Errorf("mackms StoreCertificate failed: %w", apiv1Error(err))
}

return nil
Expand All @@ -402,7 +402,7 @@ func (k *MacKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([

cert, err := loadCertificate(u.label, u.serialNumber, nil)
if err != nil {
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", err)
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", apiv1Error(err))
}

chain := []*x509.Certificate{cert}
Expand Down Expand Up @@ -453,7 +453,7 @@ func (k *MacKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)

// Store the certificate and update the label if required
if err := storeCertificate(u.label, req.CertificateChain[0]); err != nil {
return fmt.Errorf("mackms StoreCertificateChain failed: %w", err)
return fmt.Errorf("mackms StoreCertificateChain failed: %w", apiv1Error(err))
}

// Store the rest of the chain but do not fail if already exists
Expand Down Expand Up @@ -503,7 +503,7 @@ func (*MacKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
}
// Extract logic to deleteItem to avoid defer on loops
if err := deleteItem(dict, u.hash); err != nil {
return fmt.Errorf("mackms DeleteKey failed: %w", err)
return fmt.Errorf("mackms DeleteKey failed: %w", apiv1Error(err))
}
}

Expand Down Expand Up @@ -548,7 +548,7 @@ func (*MacKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
}

if err := deleteItem(query, nil); err != nil {
return fmt.Errorf("mackms DeleteCertificate failed: %w", err)
return fmt.Errorf("mackms DeleteCertificate failed: %w", apiv1Error(err))
}

return nil
Expand Down Expand Up @@ -1003,3 +1003,18 @@ func ecdhToECDSAPublicKey(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) {
return nil, errors.New("failed to convert *ecdh.PublicKey to *ecdsa.PublicKey")
}
}

func apiv1Error(err error) error {
switch {
case errors.Is(err, security.ErrNotFound):
return apiv1.NotFoundError{
Message: err.Error(),
}
case errors.Is(err, security.ErrAlreadyExists):
return apiv1.AlreadyExistsError{
Message: err.Error(),
}
default:
return err
}
}
39 changes: 38 additions & 1 deletion kms/mackms/mackms_test.go
Expand Up @@ -29,6 +29,8 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"fmt"
"io"
"math/big"
"net/url"
"testing"
Expand Down Expand Up @@ -1143,7 +1145,7 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
_, err := kms.LoadCertificate(&apiv1.LoadCertificateRequest{
Name: "mackms:serial=" + hex.EncodeToString(cert.SerialNumber.Bytes()),
})
assert.ErrorIs(t, err, security.ErrNotFound)
assert.ErrorIs(t, err, apiv1.NotFoundError{})
}

kms := &MacKMS{}
Expand Down Expand Up @@ -1196,3 +1198,38 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
})
}
}

func Test_apiv1Error(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
assertion assert.ErrorAssertionFunc
}{
{"ok not found", args{security.ErrNotFound}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
}},
{"ok not found wrapped", args{fmt.Errorf("something happened: %w", security.ErrNotFound)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
}},
{"ok already exists", args{security.ErrAlreadyExists}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
}},
{"ok already exists wrapped", args{fmt.Errorf("something happened: %w", security.ErrAlreadyExists)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
}},
{"ok other", args{io.ErrUnexpectedEOF}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
}},
{"ok other wrapped", args{fmt.Errorf("something happened: %w", io.ErrUnexpectedEOF)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.assertion(t, apiv1Error(tt.args.err))
})
}
}