Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: Support fsnotify reloading of certs #6415

Merged
1 change: 1 addition & 0 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ func initRuntime(ctx context.Context, params runCmdParams, args []string, addrSe
params.rt.CertificateFile = params.tlsCertFile
params.rt.CertificateKeyFile = params.tlsPrivateKeyFile
params.rt.CertificateRefresh = params.tlsCertRefresh
params.rt.CertPoolFile = params.tlsCACertFile

if params.tlsCACertFile != "" {
pool, err := loadCertPool(params.tlsCACertFile)
Expand Down
26 changes: 22 additions & 4 deletions runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ import (

"github.com/fsnotify/fsnotify"
"github.com/gorilla/mux"
"github.com/open-policy-agent/opa/internal/compiler"
"github.com/open-policy-agent/opa/internal/pathwatcher"
"github.com/open-policy-agent/opa/internal/ref"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/propagation"
"go.uber.org/automaxprocs/maxprocs"

"github.com/open-policy-agent/opa/bundle"
opa_config "github.com/open-policy-agent/opa/config"
"github.com/open-policy-agent/opa/internal/compiler"
"github.com/open-policy-agent/opa/internal/config"
internal_tracing "github.com/open-policy-agent/opa/internal/distributedtracing"
internal_logging "github.com/open-policy-agent/opa/internal/logging"
"github.com/open-policy-agent/opa/internal/pathwatcher"
"github.com/open-policy-agent/opa/internal/prometheus"
"github.com/open-policy-agent/opa/internal/ref"
"github.com/open-policy-agent/opa/internal/report"
"github.com/open-policy-agent/opa/internal/runtime"
initload "github.com/open-policy-agent/opa/internal/runtime/init"
Expand Down Expand Up @@ -115,6 +115,8 @@ type Params struct {

// CertPool holds the CA certs trusted by the OPA server.
CertPool *x509.CertPool
// CertPoolFile, if set, permits the reloading of the CA cert pool from disk.
CertPoolFile string

// MinVersion contains the minimum TLS version that is acceptable.
// If zero, TLS 1.2 is currently taken as the minimum.
Expand Down Expand Up @@ -537,8 +539,8 @@ func (rt *Runtime) Serve(ctx context.Context) error {
WithPprofEnabled(rt.Params.PprofEnabled).
WithAddresses(*rt.Params.Addrs).
WithH2CEnabled(rt.Params.H2CEnabled).
// always use the initial values for the certificate and ca pool, reloading behavior is configured below
WithCertificate(rt.Params.Certificate).
WithCertificatePaths(rt.Params.CertificateFile, rt.Params.CertificateKeyFile, rt.Params.CertificateRefresh).
WithCertPool(rt.Params.CertPool).
WithAuthentication(rt.Params.Authentication).
WithAuthorization(rt.Params.Authorization).
Expand All @@ -562,6 +564,22 @@ func (rt *Runtime) Serve(ctx context.Context) error {
rt.server = rt.server.WithUnixSocketPermission(rt.Params.UnixSocketPerm)
}

// If a refresh period is set, then we will periodically reload the certificate and ca pool. Otherwise, we will only
// reload cert, key and ca pool files when they change on disk.
if rt.Params.CertificateRefresh > 0 {
charlieegan3 marked this conversation as resolved.
Show resolved Hide resolved
rt.server = rt.server.WithCertRefresh(rt.Params.CertificateRefresh)
}

// if either the cert or the ca pool file is set then these fields will be set on the server and reloaded when they
// change on disk.
if rt.Params.CertificateFile != "" || rt.Params.CertPoolFile != "" {
rt.server = rt.server.WithTLSConfig(&server.TLSConfig{
CertFile: rt.Params.CertificateFile,
KeyFile: rt.Params.CertificateKeyFile,
CertPoolFile: rt.Params.CertPoolFile,
})
}

rt.server, err = rt.server.Init(ctx)
if err != nil {
rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to initialize server.")
Expand Down
173 changes: 148 additions & 25 deletions server/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,175 @@ import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"os"
"time"

"github.com/fsnotify/fsnotify"

"github.com/open-policy-agent/opa/internal/errors"
"github.com/open-policy-agent/opa/internal/pathwatcher"
"github.com/open-policy-agent/opa/logging"
)

func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
s.certMtx.RLock()
defer s.certMtx.RUnlock()
s.tlsConfigMtx.RLock()
defer s.tlsConfigMtx.RUnlock()
return s.cert, nil
}

func (s *Server) certLoop(logger logging.Logger) Loop {
// reloadTLSConfig reloads the TLS config if the cert, key files or cert pool contents have changed.
func (s *Server) reloadTLSConfig(logger logging.Logger) error {
s.tlsConfigMtx.Lock()
defer s.tlsConfigMtx.Unlock()

// reloading of the certificate key pair and the CA pool are independent operations,
// though errors from either operation are aggregated.
var errs error

// if the server has a cert configured, then we need to check the cert and key for changes.
if s.certFile != "" {
newCert, certFileHash, certKeyFileHash, updated, err := reloadCertificateKeyPair(
s.certFile,
s.certKeyFile,
s.certFileHash,
s.certKeyFileHash,
logger,
)
if err != nil {
errs = errors.Join(errs, err)
} else if updated {
s.cert = newCert
s.certFileHash = certFileHash
s.certKeyFileHash = certKeyFileHash

logger.Debug("Refreshed server certificate.")
}
}

// if the server has a cert pool configured, also attempt to reload this
if s.certPoolFile != "" {
johanfylling marked this conversation as resolved.
Show resolved Hide resolved
pool, certPoolFileHash, updated, err := reloadCertificatePool(s.certPoolFile, s.certPoolFileHash, logger)
if err != nil {
errs = errors.Join(errs, err)
} else if updated {
s.certPool = pool
s.certPoolFileHash = certPoolFileHash
logger.Debug("Refreshed server CA certificate pool.")
}
}

return errs
}

// reloadCertificatePool rebuilds the CA cert pool from certPoolFile when the value returned
// by hash for that file differs from certPoolFileHash.
//
// On a rebuild it returns the new pool, the new hash and true. When the hashes match it
// returns (nil, nil, false, nil) without reading the file. An error is returned if the file
// cannot be hashed or read, or if AppendCertsFromPEM reports that nothing was added from its
// contents. The logger argument is accepted but not used here.
func reloadCertificatePool(certPoolFile string, certPoolFileHash []byte, logger logging.Logger) (*x509.CertPool, []byte, bool, error) {
	currentHash, err := hash(certPoolFile)
	if err != nil {
		return nil, nil, false, fmt.Errorf("failed to hash CA cert pool file: %w", err)
	}

	// Matching hashes mean there is nothing to reload.
	if unchanged := bytes.Equal(certPoolFileHash, currentHash); unchanged {
		return nil, nil, false, nil
	}

	pemBytes, err := os.ReadFile(certPoolFile)
	if err != nil {
		return nil, nil, false, fmt.Errorf("failed to read CA cert pool file %q: %w", certPoolFile, err)
	}

	refreshed := x509.NewCertPool()
	if !refreshed.AppendCertsFromPEM(pemBytes) {
		return nil, nil, false, fmt.Errorf("failed to load CA cert pool file %q", certPoolFile)
	}

	return refreshed, currentHash, true, nil
}

// reloadCertificateKeyPair hashes certFile and certKeyFile and, when either hash differs
// from the one supplied, loads the pair again with tls.LoadX509KeyPair.
//
// On a reload it returns the new certificate, both new hashes and true. When neither hash
// changed it returns (nil, nil, nil, false, nil). A warning is logged when only one of the
// two files changed; the reload is still attempted in that case. An error is returned if
// either file cannot be hashed or the pair fails to load.
func reloadCertificateKeyPair(
	certFile, certKeyFile string,
	certFileHash, certKeyFileHash []byte,
	logger logging.Logger,
) (*tls.Certificate, []byte, []byte, bool, error) {
	newCertHash, err := hash(certFile)
	if err != nil {
		return nil, nil, nil, false, fmt.Errorf("failed to hash server certificate file: %w", err)
	}

	newKeyHash, err := hash(certKeyFile)
	if err != nil {
		return nil, nil, nil, false, fmt.Errorf("failed to hash server key file: %w", err)
	}

	certChanged := !bytes.Equal(certFileHash, newCertHash)
	keyChanged := !bytes.Equal(certKeyFileHash, newKeyHash)

	switch {
	case !certChanged && !keyChanged:
		// Nothing on disk moved; keep whatever is currently loaded.
		return nil, nil, nil, false, nil
	case certChanged && !keyChanged:
		logger.Warn("Server certificate file changed but server key file did not change.")
	case keyChanged && !certChanged:
		logger.Warn("Server key file changed but server certificate file did not change.")
	}

	pair, err := tls.LoadX509KeyPair(certFile, certKeyFile)
	if err != nil {
		return nil, nil, nil, false, fmt.Errorf("server certificate key pair was not updated, update failed: %w", err)
	}

	return &pair, newCertHash, newKeyHash, true, nil
}

func (s *Server) certLoopPolling(logger logging.Logger) Loop {
return func() error {
for range time.NewTicker(s.certRefresh).C {
certHash, err := hash(s.certFile)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
}
certKeyHash, err := hash(s.certKeyFile)
err := s.reloadTLSConfig(logger)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
logger.Error(fmt.Sprintf("Failed to reload TLS config: %s", err))
}
}

s.certMtx.Lock()
return nil
}
}

different := !bytes.Equal(s.certFileHash, certHash) ||
!bytes.Equal(s.certKeyFileHash, certKeyHash)
func (s *Server) certLoopNotify(logger logging.Logger) Loop {
return func() error {

var paths []string

if different { // load and store
newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
// if a cert file is set, then we want to watch the cert and key
if s.certFile != "" {
paths = append(paths, s.certFile, s.certKeyFile)
}

// if a cert pool file is set, then we want to watch the cert pool. This might be set without the cert and key
// being set too.
if s.certPoolFile != "" {
paths = append(paths, s.certPoolFile)
}

watcher, err := pathwatcher.CreatePathWatcher(paths)
if err != nil {
return fmt.Errorf("failed to create tls path watcher: %w", err)
}

for evt := range watcher.Events {
removalMask := fsnotify.Remove | fsnotify.Rename
mask := fsnotify.Create | fsnotify.Write | removalMask
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are aware that this will also trigger on partial writes:

        // The pathname was written to; this does *not* mean the write has finished,
        // and a write can be followed by more writes.
        Write

One other thing that I figured out when I did https://pkg.go.dev/github.com/zalando/skipper/secrets#SecretPaths some years ago is that Kubernetes secret mounts use symlinks, and at least inotify does not work as expected for symlinks. A symlink does not change if only the content of the target file changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, thanks for this. Since a partial write would result in OPA failing to reload the file (as it'd be invalid in reloadCertificateKeyPair), the main concern here is missing a write, is that correct?

Were you able to mitigate this in skipper without polling? Perhaps we could have OPA fall back to a polling behaviour when the file in question is a symlink?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just poll every minute or 30s (something like that). Keep it simple. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, I'll need to do some digging by the sounds of things!

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, no worries, I just wanted to point out that it might not work as expected in all environments.

if (evt.Op & mask) != 0 {
err = s.reloadTLSConfig(s.manager.Logger())
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
s.certMtx.Unlock()
continue
logger.Error("failed to reload TLS config: %s", err)
}
s.cert = &newCert
s.certFileHash = certHash
s.certKeyFileHash = certKeyHash
logger.Debug("Refreshed server certificate.")
logger.Info("TLS config reloaded")
}

s.certMtx.Unlock()
}

return nil
Expand Down
76 changes: 61 additions & 15 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,15 @@ type Server struct {
authentication AuthenticationScheme
authorization AuthorizationScheme
cert *tls.Certificate
certMtx sync.RWMutex
tlsConfigMtx sync.RWMutex
certFile string
certFileHash []byte
certKeyFile string
certKeyFileHash []byte
certRefresh time.Duration
certPool *x509.CertPool
certPoolFile string
certPoolFileHash []byte
minTLSVersion uint16
mtx sync.RWMutex
partials map[string]rego.PartialResult
Expand Down Expand Up @@ -153,6 +155,23 @@ type Metrics interface {
InstrumentHandler(handler http.Handler, label string) http.Handler
}

// TLSConfig represents the TLS configuration for the server.
// This configuration is used to configure file watchers to reload each file as it
// changes on disk.
type TLSConfig struct {
// CertFile is the path to the server's serving certificate file.
CertFile string

// KeyFile is the path to the server's key file, completing the key pair for the
// CertFile certificate.
KeyFile string

// CertPoolFile is the path to the CA cert pool file. The contents of this file will be
// reloaded when the file changes on disk and used as the trusted client CAs in the TLS config
// for new connections to the server.
CertPoolFile string
johanfylling marked this conversation as resolved.
Show resolved Hide resolved
}

// Loop will contain all the calls from the server that we'll be listening on.
type Loop func() error

Expand Down Expand Up @@ -274,6 +293,20 @@ func (s *Server) WithCertPool(pool *x509.CertPool) *Server {
return s
}

// WithTLSConfig copies the certificate, key and CA cert pool file paths from tlsConfig onto
// the server and returns the same server so that calls can be chained.
func (s *Server) WithTLSConfig(tlsConfig *TLSConfig) *Server {
	s.certFile, s.certKeyFile, s.certPoolFile = tlsConfig.CertFile, tlsConfig.KeyFile, tlsConfig.CertPoolFile
	return s
}

// WithCertRefresh sets the period on which certs, keys and cert pools are reloaded from disk.
// A period greater than zero makes getListener use the polling cert loop rather than the
// fsnotify-based one. The receiver is returned so that calls can be chained.
func (s *Server) WithCertRefresh(refresh time.Duration) *Server {
s.certRefresh = refresh
return s
}

// WithStore sets the storage used by the server.
func (s *Server) WithStore(store storage.Store) *Server {
s.store = store
Expand Down Expand Up @@ -566,11 +599,13 @@ func (s *Server) getListener(addr string, h http.Handler, t httpListenerType) ([
"cert-file": s.certFile,
"cert-key-file": s.certKeyFile,
})

// if a manual cert refresh period has been set, then use the polling behavior,
// otherwise use the fsnotify default behavior
if s.certRefresh > 0 {
certLoop := s.certLoop(logger)
loops = []Loop{loop, certLoop}
} else {
loops = []Loop{loop}
loops = []Loop{loop, s.certLoopPolling(logger)}
} else if s.certFile != "" || s.certPoolFile != "" {
loops = []Loop{loop, s.certLoopNotify(logger)}
}
default:
err = fmt.Errorf("invalid url scheme %q", parsedURL.Scheme)
Expand Down Expand Up @@ -605,17 +640,28 @@ func (s *Server) getListenerForHTTPSServer(u *url.URL, h http.Handler, t httpLis
Handler: h,
TLSConfig: &tls.Config{
GetCertificate: s.getCertificate,
ClientCAs: s.certPool,
},
}
if s.authentication == AuthenticationTLS {
httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
// GetConfigForClient is used to ensure that a fresh config is provided containing the latest cert pool.
// This is not required, but appears to be how connect-time config updates should be done:
// https://github.com/golang/go/issues/16066#issuecomment-250606132
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
cfg := &tls.Config{
GetCertificate: s.getCertificate,
ClientCAs: s.certPool,
}

if s.minTLSVersion != 0 {
httpsServer.TLSConfig.MinVersion = s.minTLSVersion
} else {
httpsServer.TLSConfig.MinVersion = defaultMinTLSVersion
if s.authentication == AuthenticationTLS {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}

if s.minTLSVersion != 0 {
cfg.MinVersion = s.minTLSVersion
} else {
cfg.MinVersion = defaultMinTLSVersion
}

return cfg, nil
},
},
}

l := newHTTPListener(&httpsServer, t)
Expand Down