Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support version selection for database plugins #16982

Merged
merged 8 commits into from
Sep 9, 2022
2 changes: 1 addition & 1 deletion builtin/logical/database/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
return nil, err
}

dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger)
if err != nil {
return nil, fmt.Errorf("unable to create database instance: %w", err)
}
Expand Down
4 changes: 4 additions & 0 deletions builtin/logical/database/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ func TestBackend_config_connection(t *testing.T) {
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
Expand Down Expand Up @@ -289,6 +290,7 @@ func TestBackend_config_connection(t *testing.T) {
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
Expand Down Expand Up @@ -331,6 +333,7 @@ func TestBackend_config_connection(t *testing.T) {
"allowed_roles": []string{"flu", "barre"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
Expand Down Expand Up @@ -728,6 +731,7 @@ func TestBackend_connectionCrud(t *testing.T) {
"allowed_roles": []string{"plugin-role-test"},
"root_credentials_rotate_statements": []string(nil),
"password_policy": "",
"plugin_version": "",
}
req.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
Expand Down
8 changes: 4 additions & 4 deletions builtin/logical/database/dbplugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func TestPlugin_Init(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

dbRaw, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
dbRaw, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand All @@ -163,7 +163,7 @@ func TestPlugin_CreateUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func TestPlugin_RenewUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()

db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", sys, log.NewNullLogger())
db, err := dbplugin.PluginFactory(namespace.RootContext(nil), "test-plugin", "", sys, log.NewNullLogger())
if err != nil {
t.Fatalf("err: %s", err)
}
Expand Down
52 changes: 50 additions & 2 deletions builtin/logical/database/path_config_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import (
"errors"
"fmt"
"net/url"
"sort"

"github.com/fatih/structs"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-version"

v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)

Expand All @@ -22,7 +26,8 @@ var (
// DatabaseConfig is used by the Factory function to configure a Database
// object.
type DatabaseConfig struct {
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
PluginVersion string `json:"plugin_version" structs:"plugin_version" mapstructure:"plugin_version"`
// ConnectionDetails stores the database specific connection settings needed
// by each database type.
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
Expand Down Expand Up @@ -110,6 +115,11 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
that plugin type.`,
},

"plugin_version": {
Type: framework.TypeString,
Description: `The version of the plugin to use.`,
},

"verify_connection": {
Type: framework.TypeBool,
Default: true,
Expand Down Expand Up @@ -281,6 +291,43 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyPluginName), nil
}

if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok {
config.PluginVersion = pluginVersionRaw.(string)
}
if config.PluginVersion != "" {
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}

// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()
} else {
// No version provided. Pin to the current latest version if any versioned
// plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return nil, err
}

var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}

if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})

config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
}

if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok {
config.AllowedRoles = allowedRolesRaw.([]string)
} else if req.Operation == logical.CreateOperation {
Expand All @@ -301,6 +348,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
// ConnectionDetails.
delete(data.Raw, "name")
delete(data.Raw, "plugin_name")
delete(data.Raw, "plugin_version")
delete(data.Raw, "allowed_roles")
delete(data.Raw, "verify_connection")
delete(data.Raw, "root_rotation_statements")
Expand All @@ -326,7 +374,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}

// Create a database plugin and initialize it.
dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
dbw, err := newDatabaseWrapper(ctx, config.PluginName, config.PluginVersion, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse("error creating database object: %s", err), nil
}
Expand Down
6 changes: 3 additions & 3 deletions builtin/logical/database/version_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ type databaseVersionWrapper struct {

// newDatabaseWrapper figures out which version of the database the pluginName is referring to and returns a wrapper object
// that can be used to make operations on the underlying database plugin.
func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (dbw databaseVersionWrapper, err error) {
newDB, err := v5.PluginFactory(ctx, pluginName, sys, logger)
func newDatabaseWrapper(ctx context.Context, pluginName string, pluginVersion string, sys pluginutil.LookRunnerUtil, logger log.Logger) (dbw databaseVersionWrapper, err error) {
newDB, err := v5.PluginFactory(ctx, pluginName, pluginVersion, sys, logger)
if err == nil {
dbw = databaseVersionWrapper{
v5: newDB,
Expand All @@ -32,7 +32,7 @@ func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.L
merr := &multierror.Error{}
merr = multierror.Append(merr, err)

legacyDB, err := v4.PluginFactory(ctx, pluginName, sys, logger)
legacyDB, err := v4.PluginFactory(ctx, pluginName, pluginVersion, sys, logger)
if err == nil {
dbw = databaseVersionWrapper{
v4: legacyDB,
Expand Down
5 changes: 5 additions & 0 deletions builtin/plugin/backend_lazyLoad_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package plugin

import (
"context"
"errors"
"testing"

"github.com/hashicorp/vault/sdk/helper/logging"
Expand Down Expand Up @@ -193,3 +194,7 @@ func (v testSystemView) LookupPluginVersion(context.Context, string, consts.Plug
},
}, nil
}

func (v testSystemView) ListVersionedPlugins(_ context.Context, _ consts.PluginType) ([]pluginutil.VersionedPlugin, error) {
return nil, errors.New("ListVersionedPlugins not implemented for testSystemView")
}
2 changes: 1 addition & 1 deletion sdk/database/dbplugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type Database interface {

// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
func PluginFactory(ctx context.Context, pluginName string, pluginVersion string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
// Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(ctx, pluginName, consts.PluginTypeDatabase)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions sdk/database/dbplugin/v5/plugin_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (

// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
func PluginFactory(ctx context.Context, pluginName string, pluginVersion string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
// Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(ctx, pluginName, consts.PluginTypeDatabase)
pluginRunner, err := sys.LookupPluginVersion(ctx, pluginName, consts.PluginTypeDatabase, pluginVersion)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -43,6 +43,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
config := pluginutil.PluginClientConfig{
Name: pluginName,
PluginType: consts.PluginTypeDatabase,
Version: pluginVersion,
PluginSets: PluginSets,
HandshakeConfig: HandshakeConfig,
Logger: namedLogger,
Expand Down
8 changes: 8 additions & 0 deletions sdk/logical/system_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ type SystemView interface {
// name and version. Returns a PluginRunner or an error if a plugin can not be found.
LookupPluginVersion(ctx context.Context, pluginName string, pluginType consts.PluginType, version string) (*pluginutil.PluginRunner, error)

// ListVersionedPlugins returns information about all plugins of a certain
// type in the catalog, including any versioning information stored for them.
ListVersionedPlugins(ctx context.Context, pluginType consts.PluginType) ([]pluginutil.VersionedPlugin, error)

// NewPluginClient returns a client for managing the lifecycle of plugin
// processes
NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error)
Expand Down Expand Up @@ -176,6 +180,10 @@ func (d StaticSystemView) LookupPluginVersion(_ context.Context, _ string, _ con
return nil, errors.New("LookupPluginVersion is not implemented in StaticSystemView")
}

func (d StaticSystemView) ListVersionedPlugins(_ context.Context, _ consts.PluginType) ([]pluginutil.VersionedPlugin, error) {
return nil, errors.New("ListVersionedPlugins is not implemented in StaticSystemView")
tomhjp marked this conversation as resolved.
Show resolved Hide resolved
}

func (d StaticSystemView) MlockEnabled() bool {
return d.EnableMlock
}
Expand Down
4 changes: 4 additions & 0 deletions sdk/plugin/grpc_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ func (s *gRPCSystemViewClient) LookupPluginVersion(_ context.Context, _ string,
return nil, fmt.Errorf("cannot call LookupPluginVersion from a plugin backend")
}

func (s *gRPCSystemViewClient) ListVersionedPlugins(_ context.Context, _ consts.PluginType) ([]pluginutil.VersionedPlugin, error) {
return nil, fmt.Errorf("cannot call ListVersionedPlugins from a plugin backend")
}

func (s *gRPCSystemViewClient) MlockEnabled() bool {
reply, err := s.client.MlockEnabled(context.Background(), &pb.Empty{})
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions vault/dynamic_system_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,18 @@ func (d dynamicSystemView) LookupPluginVersion(ctx context.Context, name string,
return r, nil
}

// ListVersionedPlugins returns information about all plugins of a certain
// typein the catalog, including any versioning information stored for them.
func (d dynamicSystemView) ListVersionedPlugins(ctx context.Context, pluginType consts.PluginType) ([]pluginutil.VersionedPlugin, error) {
if d.core == nil {
return nil, fmt.Errorf("system view core is nil")
}
if d.core.pluginCatalog == nil {
return nil, fmt.Errorf("system view core plugin catalog is nil")
}
return d.core.pluginCatalog.ListVersionedPlugins(ctx, pluginType)
}

// MlockEnabled returns the configuration setting for enabling mlock on plugins.
func (d dynamicSystemView) MlockEnabled() bool {
return d.core.enableMlock
Expand Down
9 changes: 5 additions & 4 deletions vault/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ const mountStateUnmounting = "unmounting"
type MountEntry struct {
Table string `json:"table"` // The table it belongs to
Path string `json:"path"` // Mount Path
Type string `json:"type"` // Logical backend Type
Type string `json:"type"` // Logical backend Type. NB: This is the plugin name, e.g. my-vault-plugin, NOT plugin type (e.g. auth).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 this is a confusing overloading of terms. A bit of clarity is good.

Description string `json:"description"` // User-provided description
UUID string `json:"uuid"` // Barrier view UUID
BackendAwareUUID string `json:"backend_aware_uuid"` // UUID that can be used by the backend as a helper when a consistent value is needed outside of storage.
Expand All @@ -330,9 +330,9 @@ type MountEntry struct {
synthesizedConfigCache sync.Map

// version info
Version string `json:"version,omitempty"`
Sha string `json:"sha,omitempty"`
RunningVersion string `json:"running_version,omitempty"`
Version string `json:"version,omitempty"` // The semantic version of the mounted plugin, e.g. v1.2.3.
Sha string `json:"sha,omitempty"` // The SHA256 sum of the plugin binary.
RunningVersion string `json:"running_version,omitempty"` // The semantic version of the mounted plugin as reported by the plugin.
RunningSha string `json:"running_sha,omitempty"`
}

Expand Down Expand Up @@ -1489,6 +1489,7 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView
}

conf["plugin_type"] = consts.PluginTypeSecrets.String()
conf["plugin_version"] = entry.Version

backendLogger := c.baseLogger.Named(fmt.Sprintf("secrets.%s.%s", t, entry.Accessor))
c.AddLogger(backendLogger)
Expand Down