Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace panic with returning errors from key decryption providers #155

Merged
merged 6 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
# Changelog
## 1.7.0

### Changed

* Changed always encrypted key provider error handling not to panic on failure

### Features

* Support DER certificates for server authentication (#152)

### Bug fixes

* Improved speed of CharsetToUTF8 (#154)

## 1.6.0

Expand Down
165 changes: 100 additions & 65 deletions aecmk/akv/keyprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,101 +63,120 @@ func init() {

// DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key.
// The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm.
func (p *Provider) DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte) {
func (p *Provider) DecryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) (decryptedKey []byte, err error) {
decryptedKey = nil
keyData := p.getKeyData(masterKeyPath)
if keyData == nil {
keyData, err := p.getKeyData(ctx, masterKeyPath, aecmk.Decryption)
if err != nil {
return
}
keySize := keyData.publicKey.Size()
cekv := ae.LoadCEKV(encryptedCek)
if cekv.Version != 1 {
panic(fmt.Errorf("Invalid version byte in encrypted key"))
return nil, aecmk.NewError(aecmk.Decryption, "Invalid version byte in encrypted key", nil)
}
if keySize != len(cekv.Ciphertext) {
panic(fmt.Errorf("Encrypted key has wrong ciphertext length"))
return nil, aecmk.NewError(aecmk.Decryption, "Encrypted key has wrong ciphertext length", nil)
}
if keySize != len(cekv.SignedHash) {
panic(fmt.Errorf("Encrypted key signature length mismatch"))
return nil, aecmk.NewError(aecmk.Decryption, "Encrypted key signature length mismatch", nil)
}
if !cekv.VerifySignature(keyData.publicKey) {
panic(fmt.Errorf("Invalid signature hash"))
return nil, aecmk.NewError(aecmk.Decryption, "Invalid signature hash", nil)
}

client := p.getAKVClient(keyData.endpoint)
algorithm := getAlgorithm(encryptionAlgorithm)
client, err := p.getAKVClient(aecmk.Decryption, keyData.endpoint)
if err != nil {
return
}
algorithm, err := getAlgorithm(aecmk.Decryption, encryptionAlgorithm)
if err != nil {
return
}
parameters := azkeys.KeyOperationParameters{
Algorithm: &algorithm,
Value: cekv.Ciphertext,
}
r, err := client.UnwrapKey(context.Background(), keyData.name, keyData.version, parameters, nil)
if err != nil {
panic(fmt.Errorf("Unable to decrypt key %s: %w", masterKeyPath, err))
r, e := client.UnwrapKey(ctx, keyData.name, keyData.version, parameters, nil)
if e != nil {
err = aecmk.NewError(aecmk.Decryption, fmt.Sprintf("Unable to decrypt key %s", masterKeyPath), e)
} else {
decryptedKey = r.Result
}
decryptedKey = r.Result
return
}

// EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm.
func (p *Provider) EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte {
keyData := p.getKeyData(masterKeyPath)
// just validate the algorith
_ = getAlgorithm(encryptionAlgorithm)
func (p *Provider) EncryptColumnEncryptionKey(ctx context.Context, masterKeyPath string, encryptionAlgorithm string, cek []byte) (buf []byte, err error) {
keyData, err := p.getKeyData(ctx, masterKeyPath, aecmk.Encryption)
if err != nil {
return
}
_, err = getAlgorithm(aecmk.Encryption, encryptionAlgorithm)
if err != nil {
return
}
keySize := keyData.publicKey.Size()
enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder()
// Start with version byte == 1
buf := []byte{byte(1)}
tmp := []byte{byte(1)}
// EncryptedColumnEncryptionKey = version + keyPathLength + ciphertextLength + keyPath + ciphertext + signature
// version
keyPathBytes, err := enc.Bytes([]byte(strings.ToLower(masterKeyPath)))
if err != nil {
panic(fmt.Errorf("Unable to serialize key path %w", err))
err = aecmk.NewError(aecmk.Encryption, "Unable to serialize key path", err)
return
}
k := uint16(len(keyPathBytes))
// keyPathLength
buf = append(buf, byte(k), byte(k>>8))
tmp = append(tmp, byte(k), byte(k>>8))

cipherText, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, keyData.publicKey, cek, []byte{})
if err != nil {
panic(fmt.Errorf("Unable to encrypt data %w", err))
err = aecmk.NewError(aecmk.Encryption, "Unable to encrypt data", err)
return
}
l := uint16(len(cipherText))
// ciphertextLength
buf = append(buf, byte(l), byte(l>>8))
tmp = append(tmp, byte(l), byte(l>>8))
// keypath
buf = append(buf, keyPathBytes...)
tmp = append(tmp, keyPathBytes...)
// ciphertext
buf = append(buf, cipherText...)
hash := sha256.Sum256(buf)
client := p.getAKVClient(keyData.endpoint)
tmp = append(tmp, cipherText...)
hash := sha256.Sum256(tmp)
client, err := p.getAKVClient(aecmk.Encryption, keyData.endpoint)
if err != nil {
return
}
signAlgorithm := azkeys.SignatureAlgorithmRS256
parameters := azkeys.SignParameters{
Algorithm: &signAlgorithm,
Value: hash[:],
}
r, err := client.Sign(context.Background(), keyData.name, keyData.version, parameters, nil)
r, err := client.Sign(ctx, keyData.name, keyData.version, parameters, nil)
if err != nil {
panic(err)
err = aecmk.NewError(aecmk.Encryption, "AKV failed to sign data", err)
return
}
if len(r.Result) != keySize {
panic("Signature length doesn't match certificate key size")
err = aecmk.NewError(aecmk.Encryption, "Signature length doesn't match certificate key size", nil)
} else {
// signature
buf = append(tmp, r.Result...)
}
// signature
buf = append(buf, r.Result...)
return buf
return
}

// SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key
// referenced by the masterKeyPath parameter. The input values used to generate the signature should be the
// specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported.
func (p *Provider) SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte {
return nil
func (p *Provider) SignColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) ([]byte, error) {
return nil, nil
}

// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key
// with the specified key path and the specified enclave behavior. Return nil if not supported.
func (p *Provider) VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool {
return nil
func (p *Provider) VerifyColumnMasterKeyMetadata(ctx context.Context, masterKeyPath string, allowEnclaveComputations bool) (*bool, error) {
return nil, nil
}

// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires.
Expand All @@ -167,51 +186,60 @@ func (p *Provider) KeyLifetime() *time.Duration {
return nil
}

func getAlgorithm(encryptionAlgorithm string) (algorithm azkeys.EncryptionAlgorithm) {
func getAlgorithm(op aecmk.Operation, encryptionAlgorithm string) (algorithm azkeys.EncryptionAlgorithm, err error) {
// support both RSA_OAEP and RSA-OAEP
if strings.EqualFold(encryptionAlgorithm, aecmk.KeyEncryptionAlgorithm) {
encryptionAlgorithm = string(azkeys.EncryptionAlgorithmRSAOAEP)
}
if !strings.EqualFold(encryptionAlgorithm, string(azkeys.EncryptionAlgorithmRSAOAEP)) {
panic(fmt.Errorf("Unsupported encryption algorithm %s", encryptionAlgorithm))
err = aecmk.NewError(op, fmt.Sprintf("Unsupported encryption algorithm %s", encryptionAlgorithm), nil)
} else {
algorithm = azkeys.EncryptionAlgorithmRSAOAEP
}
return azkeys.EncryptionAlgorithmRSAOAEP
return
}

// masterKeyPath is a full URL. The AKV client requires it broken down into endpoint, name, and version
// The URL has format '{endpoint}/{host}/keys/{name}/[{version}/]'
func (p *Provider) getKeyData(masterKeyPath string) *keyData {
func (p *Provider) getKeyData(ctx context.Context, masterKeyPath string, op aecmk.Operation) (k *keyData, err error) {
endpoint, keypath, allowed := p.allowedPathAndEndpoint(masterKeyPath)
if !(allowed) {
return nil
err = aecmk.KeyPathNotAllowed(masterKeyPath, op)
return
}
k := &keyData{
k = &keyData{
endpoint: endpoint,
name: keypath[0],
}
if len(keypath) > 1 {
k.version = keypath[1]
}
client := p.getAKVClient(endpoint)
r, err := client.GetKey(context.Background(), k.name, k.version, nil)
client, err := p.getAKVClient(op, endpoint)
if err != nil {
return
}
r, err := client.GetKey(ctx, k.name, k.version, nil)
if err != nil {
panic(fmt.Errorf("Unable to get key from AKV %w", err))
err = aecmk.NewError(op, "Unable to get key from AKV. Name:"+masterKeyPath, err)
}
if r.Key.Kty == nil || (*r.Key.Kty != azkeys.KeyTypeRSA && *r.Key.Kty != azkeys.KeyTypeRSAHSM) {
panic(fmt.Errorf("Key type not supported for Always Encrypted"))
err = aecmk.NewError(op, "Key type not supported for Always Encrypted", nil)
}
k.publicKey = &rsa.PublicKey{
N: new(big.Int).SetBytes(r.Key.N),
E: int(new(big.Int).SetBytes(r.Key.E).Int64()),
if err == nil {
k.publicKey = &rsa.PublicKey{
N: new(big.Int).SetBytes(r.Key.N),
E: int(new(big.Int).SetBytes(r.Key.E).Int64()),
}
}
return k
return
}

func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string, keypath []string, allowed bool) {
allowed = len(p.AllowedLocations) == 0
url, err := url.Parse(masterKeyPath)
if err != nil {
panic(fmt.Errorf("Invalid URL for master key path %s: %w", masterKeyPath, err))
allowed = false
return
}
if !allowed {

Expand All @@ -226,7 +254,8 @@ func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string
if allowed {
pathParts := strings.Split(strings.TrimLeft(url.Path, "/"), "/")
if len(pathParts) < 2 || len(pathParts) > 3 || pathParts[0] != "keys" {
panic(fmt.Errorf("Invalid URL for master key path %s", masterKeyPath))
allowed = false
return
}
keypath = pathParts[1:]
url.Path = ""
Expand All @@ -237,28 +266,34 @@ func (p *Provider) allowedPathAndEndpoint(masterKeyPath string) (endpoint string
return
}

func (p *Provider) getAKVClient(endpoint string) (client *azkeys.Client) {
client, err := azkeys.NewClient(endpoint, p.getCredential(endpoint), nil)
func (p *Provider) getAKVClient(op aecmk.Operation, endpoint string) (client *azkeys.Client, err error) {
credential, err := p.getCredential(op, endpoint)
if err == nil {
client, err = azkeys.NewClient(endpoint, credential, nil)
}
if err != nil {
panic(fmt.Errorf("Unable to create AKV client %w", err))
err = aecmk.NewError(op, "Unable to create AKV client", err)
}
return
}

func (p *Provider) getCredential(endpoint string) azcore.TokenCredential {
func (p *Provider) getCredential(op aecmk.Operation, endpoint string) (credential azcore.TokenCredential, err error) {
if len(p.credentials) == 0 {
credential, err := azidentity.NewDefaultAzureCredential(nil)
credential, err = azidentity.NewDefaultAzureCredential(nil)
if err != nil {
panic(fmt.Errorf("Unable to create a default credential: %w", err))
err = aecmk.NewError(op, "Unable to create a default credential", err)
} else {
p.credentials[wildcard] = credential
}
p.credentials[wildcard] = credential
return credential
return
}
if credential, ok := p.credentials[endpoint]; ok {
return credential
var ok bool
if credential, ok = p.credentials[endpoint]; ok {
return
}
if credential, ok := p.credentials[wildcard]; ok {
return credential
if credential, ok = p.credentials[wildcard]; ok {
return
}
panic(fmt.Errorf("No credential available for AKV path %s", endpoint))
err = aecmk.NewError(op, fmt.Sprintf("No credential available for AKV path %s", endpoint), nil)
return
}
15 changes: 10 additions & 5 deletions aecmk/akv/keyprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package akv

import (
"context"
"crypto/rand"
"net/url"
"testing"
Expand All @@ -26,9 +27,13 @@ func TestEncryptDecryptRoundTrip(t *testing.T) {
plainKey := make([]byte, 32)
_, _ = rand.Read(plainKey)
t.Log("Plainkey:", plainKey)
encryptedKey := p.EncryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, plainKey)
t.Log("Encryptedkey:", encryptedKey)
assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey")
decryptedKey := p.DecryptColumnEncryptionKey(keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey)
assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey)
encryptedKey, err := p.EncryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, plainKey)
if assert.NoError(t, err, "EncryptColumnEncryptionKey") {
t.Log("Encryptedkey:", encryptedKey)
assert.NotEqualValues(t, plainKey, encryptedKey, "encryptedKey is the same as plainKey")
decryptedKey, err := p.DecryptColumnEncryptionKey(context.Background(), keyPath, aecmk.KeyEncryptionAlgorithm, encryptedKey)
if assert.NoError(t, err, "DecryptColumnEncryptionKey") {
assert.Equalf(t, plainKey, decryptedKey, "decryptedkey doesn't match plainKey. %v : %v", decryptedKey, plainKey)
}
}
}
39 changes: 39 additions & 0 deletions aecmk/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package aecmk

import "fmt"

// Operation specifies the action that returned an error
type Operation int

const (
Decryption Operation = iota
Encryption
Validation
)

// Error is the type of all errors returned by key encryption providers
type Error struct {
Operation Operation
err error
msg string
}

func (e *Error) Error() string {
return e.msg
}

func (e *Error) Unwrap() error {
return e.err
}

func NewError(operation Operation, msg string, err error) error {
return &Error{
Operation: operation,
msg: msg,
err: err,
}
}

func KeyPathNotAllowed(path string, operation Operation) error {
return NewError(operation, fmt.Sprintf("Key path not allowed: %s", path), nil)
}