Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support version selection for database plugins #16982

Merged
merged 8 commits into from
Sep 9, 2022
2 changes: 1 addition & 1 deletion builtin/logical/database/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
return nil, err
}

dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger)
if err != nil {
return nil, fmt.Errorf("unable to create database instance: %w", err)
}
Expand Down
18 changes: 11 additions & 7 deletions builtin/logical/database/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

sys := vault.TestDynamicSystemView(cores[0].Core, nil)
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_Postgres", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin-muxed", consts.PluginTypeDatabase, "TestBackend_PluginMain_PostgresMultiplexed", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodb-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_Mongo", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodb-database-plugin-muxed", consts.PluginTypeDatabase, "TestBackend_PluginMain_MongoMultiplexed", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodbatlas-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_MongoAtlas", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodbatlas-database-plugin-muxed", consts.PluginTypeDatabase, "TestBackend_PluginMain_MongoAtlasMultiplexed", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_Postgres", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodb-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_Mongo", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodb-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MongoMultiplexed", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodbatlas-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MongoAtlas", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "mongodbatlas-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MongoAtlasMultiplexed", []string{}, "")

return cluster, sys
}
Expand Down Expand Up @@ -236,6 +236,7 @@ func TestBackend_config_connection(t *testing.T) {
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
Expand Down Expand Up @@ -289,6 +290,7 @@ func TestBackend_config_connection(t *testing.T) {
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
Expand Down Expand Up @@ -331,6 +333,7 @@ func TestBackend_config_connection(t *testing.T) {
"allowed_roles": []string{"flu", "barre"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
Expand Down Expand Up @@ -728,6 +731,7 @@ func TestBackend_connectionCrud(t *testing.T) {
"allowed_roles": []string{"plugin-role-test"},
"root_credentials_rotate_statements": []string(nil),
"password_policy": "",
"plugin_version": "",
}
req.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
Expand Down Expand Up @@ -1503,7 +1507,7 @@ func TestBackend_AsyncClose(t *testing.T) {
// Test that having a plugin that takes a LONG time to close will not cause the cleanup function to take
// longer than 750ms.
cluster, sys := getCluster(t)
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "hanging-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_Hanging", []string{}, "")
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "hanging-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_Hanging", []string{}, "")
t.Cleanup(cluster.Cleanup)

config := logical.TestBackendConfig()
Expand Down
10 changes: 5 additions & 5 deletions builtin/logical/database/dbplugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
cores := cluster.Cores

sys := vault.TestDynamicSystemView(cores[0].Core, nil)
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", consts.PluginTypeDatabase, "TestPlugin_GRPC_Main", []string{}, "")
vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", consts.PluginTypeDatabase, "", "TestPlugin_GRPC_Main", []string{}, "")

return cluster, sys
}
Expand Down Expand Up @@ -139,7 +139,7 @@ func TestPlugin_Init(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

dbRaw, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
dbRaw, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand All @@ -163,7 +163,7 @@ func TestPlugin_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func TestPlugin_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down
57 changes: 55 additions & 2 deletions builtin/logical/database/path_config_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import (
"errors"
"fmt"
"net/url"
"sort"

"github.com/fatih/structs"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-version"

v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)

Expand All @@ -22,7 +26,8 @@ var (
// DatabaseConfig is used by the Factory function to configure a Database
// object.
type DatabaseConfig struct {
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
PluginVersion string `json:"plugin_version" structs:"plugin_version" mapstructure:"plugin_version"`
// ConnectionDetails stores the database specific connection settings needed
// by each database type.
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
Expand Down Expand Up @@ -110,6 +115,11 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
that plugin type.`,
},

"plugin_version": {
Type: framework.TypeString,
Description: `The version of the plugin to use.`,
},

"verify_connection": {
Type: framework.TypeBool,
Default: true,
Expand Down Expand Up @@ -281,6 +291,48 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyPluginName), nil
}

if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok {
config.PluginVersion = pluginVersionRaw.(string)
}

unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase)
switch {
case config.PluginVersion != "":
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}

// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()
case err == nil && !unversionedPlugin.Builtin:
// We'll select the unversioned plugin that's been registered.
case req.Operation == logical.CreateOperation:
// No version provided and no unversioned plugin of that name available.
// Pin to the current latest version if any versioned plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return nil, err
}

var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}

if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})

config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
}

if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok {
config.AllowedRoles = allowedRolesRaw.([]string)
} else if req.Operation == logical.CreateOperation {
Expand All @@ -301,6 +353,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
// ConnectionDetails.
delete(data.Raw, "name")
delete(data.Raw, "plugin_name")
delete(data.Raw, "plugin_version")
delete(data.Raw, "allowed_roles")
delete(data.Raw, "verify_connection")
delete(data.Raw, "root_rotation_statements")
Expand All @@ -326,7 +379,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}

// Create a database plugin and initialize it.
dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse("error creating database object: %s", err), nil
}
Expand Down
6 changes: 3 additions & 3 deletions builtin/logical/database/version_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ type databaseVersionWrapper struct {

// newDatabaseWrapper figures out which version of the database the pluginName is referring to and returns a wrapper object
// that can be used to make operations on the underlying database plugin.
func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (dbw databaseVersionWrapper, err error) {
newDB, err := v5.PluginFactory(ctx, pluginName, sys, logger)
func newDatabaseWrapper(ctx context.Context, pluginName string, pluginVersion string, sys pluginutil.LookRunnerUtil, logger log.Logger) (dbw databaseVersionWrapper, err error) {
newDB, err := v5.PluginFactory(ctx, pluginName, pluginVersion, sys, logger)
if err == nil {
dbw = databaseVersionWrapper{
v5: newDB,
Expand All @@ -32,7 +32,7 @@ func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.L
merr := &multierror.Error{}
merr = multierror.Append(merr, err)

legacyDB, err := v4.PluginFactory(ctx, pluginName, sys, logger)
legacyDB, err := v4.PluginFactory(ctx, pluginName, pluginVersion, sys, logger)
if err == nil {
dbw = databaseVersionWrapper{
v4: legacyDB,
Expand Down