Skip to content

Commit

Permalink
server: fsnotify based cert reloading
Browse files Browse the repository at this point in the history
Reload certs, keys and optionally the CA cert pool when they change on
disk.

Signed-off-by: Charlie Egan <charlie@styra.com>
  • Loading branch information
charlieegan3 committed Nov 21, 2023
1 parent a8b57b0 commit a2fbe36
Show file tree
Hide file tree
Showing 5 changed files with 583 additions and 36 deletions.
1 change: 1 addition & 0 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ func initRuntime(ctx context.Context, params runCmdParams, args []string, addrSe
params.rt.CertificateFile = params.tlsCertFile
params.rt.CertificateKeyFile = params.tlsPrivateKeyFile
params.rt.CertificateRefresh = params.tlsCertRefresh
params.rt.CertPoolFile = params.tlsCACertFile

if params.tlsCACertFile != "" {
pool, err := loadCertPool(params.tlsCACertFile)
Expand Down
20 changes: 16 additions & 4 deletions runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ import (

"github.com/fsnotify/fsnotify"
"github.com/gorilla/mux"
"github.com/open-policy-agent/opa/internal/compiler"
"github.com/open-policy-agent/opa/internal/pathwatcher"
"github.com/open-policy-agent/opa/internal/ref"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/propagation"
"go.uber.org/automaxprocs/maxprocs"

"github.com/open-policy-agent/opa/bundle"
opa_config "github.com/open-policy-agent/opa/config"
"github.com/open-policy-agent/opa/internal/compiler"
"github.com/open-policy-agent/opa/internal/config"
internal_tracing "github.com/open-policy-agent/opa/internal/distributedtracing"
internal_logging "github.com/open-policy-agent/opa/internal/logging"
"github.com/open-policy-agent/opa/internal/pathwatcher"
"github.com/open-policy-agent/opa/internal/prometheus"
"github.com/open-policy-agent/opa/internal/ref"
"github.com/open-policy-agent/opa/internal/report"
"github.com/open-policy-agent/opa/internal/runtime"
initload "github.com/open-policy-agent/opa/internal/runtime/init"
Expand Down Expand Up @@ -115,6 +115,8 @@ type Params struct {

// CertPool holds the CA certs trusted by the OPA server.
CertPool *x509.CertPool
// CertPoolFile, if set, permits the reloading of the CA cert pool from disk.
CertPoolFile string

// MinVersion contains the minimum TLS version that is acceptable.
// If zero, TLS 1.2 is currently taken as the minimum.
Expand Down Expand Up @@ -537,8 +539,8 @@ func (rt *Runtime) Serve(ctx context.Context) error {
WithPprofEnabled(rt.Params.PprofEnabled).
WithAddresses(*rt.Params.Addrs).
WithH2CEnabled(rt.Params.H2CEnabled).
// Always use the initial values for the certificate and CA pool; reloading behavior is configured below.
WithCertificate(rt.Params.Certificate).
WithCertificatePaths(rt.Params.CertificateFile, rt.Params.CertificateKeyFile, rt.Params.CertificateRefresh).
WithCertPool(rt.Params.CertPool).
WithAuthentication(rt.Params.Authentication).
WithAuthorization(rt.Params.Authorization).
Expand All @@ -562,6 +564,16 @@ func (rt *Runtime) Serve(ctx context.Context) error {
rt.server = rt.server.WithUnixSocketPermission(rt.Params.UnixSocketPerm)
}

if rt.Params.CertificateRefresh > 0 {
rt.server = rt.server.WithCertificatePaths(rt.Params.CertificateFile, rt.Params.CertificateKeyFile, rt.Params.CertificateRefresh)
} else if rt.Params.Certificate != nil {
rt.server = rt.server.WithTLSConfig(&server.TLSConfig{
CertFile: rt.Params.CertificateFile,
KeyFile: rt.Params.CertificateKeyFile,
CertPoolFile: rt.Params.CertPoolFile,
})
}

rt.server, err = rt.server.Init(ctx)
if err != nil {
rt.logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to initialize server.")
Expand Down
108 changes: 83 additions & 25 deletions server/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,110 @@ import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"os"
"time"

"github.com/fsnotify/fsnotify"

"github.com/open-policy-agent/opa/internal/pathwatcher"
"github.com/open-policy-agent/opa/logging"
)

// getCertificate returns the certificate currently stored in s.cert. Its
// signature matches the GetCertificate callback of crypto/tls.Config; the
// ClientHelloInfo argument is not inspected and the returned error is always
// nil. The read of s.cert is done under a read lock.
func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
s.certMtx.RLock()
defer s.certMtx.RUnlock()
s.tlsConfigMtx.RLock()
defer s.tlsConfigMtx.RUnlock()
return s.cert, nil
}

func (s *Server) certLoop(logger logging.Logger) Loop {
// reloadTLSConfig refreshes the server's TLS material from disk when the
// underlying files have changed.
//
// The certificate and key files are hashed and compared with the hashes
// recorded at the previous load; the key pair is re-parsed only when either
// hash differs. When s.certPoolFile is set, the CA cert pool file is treated
// the same way: it is re-read and a fresh pool is built only when its hash
// differs from the recorded one.
//
// An error is returned as soon as any file cannot be hashed, read or parsed.
// A failure while handling the CA cert pool does not undo a certificate
// refresh that already happened earlier in the same call.
func (s *Server) reloadTLSConfig(logger logging.Logger) error {
	// Hash before taking the write lock so the critical section stays short.
	certHash, err := hash(s.certFile)
	if err != nil {
		return fmt.Errorf("failed to refresh server certificate: %w", err)
	}
	certKeyHash, err := hash(s.certKeyFile)
	if err != nil {
		return fmt.Errorf("failed to refresh server key: %w", err)
	}

	s.tlsConfigMtx.Lock()
	defer s.tlsConfigMtx.Unlock()

	different := !bytes.Equal(s.certFileHash, certHash) ||
		!bytes.Equal(s.certKeyFileHash, certKeyHash)

	if different { // load and store
		newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
		if err != nil {
			return fmt.Errorf("failed to refresh server certificate: %w", err)
		}
		s.cert = &newCert
		s.certFileHash = certHash
		s.certKeyFileHash = certKeyHash
		logger.Debug("Refreshed server certificate.")
	}

	// do not attempt to reload the ca cert pool if it has not been configured
	if s.certPoolFile == "" {
		return nil
	}

	certPoolHash, err := hash(s.certPoolFile)
	if err != nil {
		return fmt.Errorf("failed to refresh CA cert pool: %w", err)
	}

	if !bytes.Equal(s.certPoolFileHash, certPoolHash) {
		caCertPEM, err := os.ReadFile(s.certPoolFile)
		if err != nil {
			return fmt.Errorf("failed to read CA cert pool file: %w", err)
		}

		pool := x509.NewCertPool()
		if ok := pool.AppendCertsFromPEM(caCertPEM); !ok {
			return fmt.Errorf("failed to parse CA cert pool file %q", s.certPoolFile)
		}

		s.certPool = pool
		// Record the hash together with the pool; without it the comparison
		// above never matches and the file is re-read and re-parsed on every
		// call even though it has not changed.
		s.certPoolFileHash = certPoolHash
	}

	return nil
}

// certLoopPolling returns a Loop that ranges over a time.Ticker firing every
// s.certRefresh and calls s.reloadTLSConfig on each tick. A failed reload is
// logged and the loop carries on with the next tick.
// NOTE(review): the ticker is never stopped and a Ticker's channel is never
// closed, so the range loop does not terminate on its own and the trailing
// return is not expected to be reached.
func (s *Server) certLoopPolling(logger logging.Logger) Loop {
return func() error {
for range time.NewTicker(s.certRefresh).C {
certHash, err := hash(s.certFile)
err := s.reloadTLSConfig(logger)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
}
certKeyHash, err := hash(s.certKeyFile)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
logger.Error(fmt.Sprintf("Failed to reload TLS config: %s", err))
}
}

s.certMtx.Lock()
return nil
}
}

different := !bytes.Equal(s.certFileHash, certHash) ||
!bytes.Equal(s.certKeyFileHash, certKeyHash)
func (s *Server) certLoopNotify(logger logging.Logger) Loop {
return func() error {
watcher, err := pathwatcher.CreatePathWatcher([]string{
s.certFile, s.certKeyFile, s.certPoolFile,
})
if err != nil {
return fmt.Errorf("failed to create tls path watcher: %w", err)
}

if different { // load and store
newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
for evt := range watcher.Events {
removalMask := fsnotify.Remove | fsnotify.Rename
mask := fsnotify.Create | fsnotify.Write | removalMask
if (evt.Op & mask) != 0 {
err = s.reloadTLSConfig(s.manager.Logger())
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
s.certMtx.Unlock()
continue
logger.Error("failed to reload TLS config: %s", err)
}
s.cert = &newCert
s.certFileHash = certHash
s.certKeyFileHash = certKeyHash
logger.Debug("Refreshed server certificate.")
logger.Info("TLS config reloaded")
}

s.certMtx.Unlock()
}

return nil
Expand Down
36 changes: 31 additions & 5 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,15 @@ type Server struct {
authentication AuthenticationScheme
authorization AuthorizationScheme
cert *tls.Certificate
certMtx sync.RWMutex
tlsConfigMtx sync.RWMutex
certFile string
certFileHash []byte
certKeyFile string
certKeyFileHash []byte
certRefresh time.Duration
certPool *x509.CertPool
certPoolFile string
certPoolFileHash []byte
minTLSVersion uint16
mtx sync.RWMutex
partials map[string]rego.PartialResult
Expand Down Expand Up @@ -153,6 +155,20 @@ type Metrics interface {
InstrumentHandler(handler http.Handler, label string) http.Handler
}

// TLSConfig holds the paths of the files that make up the server's TLS
// configuration. The paths are used to configure file watchers so that each
// file is reloaded as it changes on disk.
type TLSConfig struct {
// CertFile is the path to the certificate file.
CertFile string

// KeyFile is the path to the key file that is loaded together with
// CertFile as an X.509 key pair.
KeyFile string

// CertPoolFile is the path to the CA cert pool file. When it is empty,
// reloading of the CA cert pool is not attempted.
CertPoolFile string
}

// Loop will contain all the calls from the server that we'll be listening on.
type Loop func() error

Expand Down Expand Up @@ -274,6 +290,14 @@ func (s *Server) WithCertPool(pool *x509.CertPool) *Server {
return s
}

// WithTLSConfig sets the certificate, key and CA cert pool file paths used by
// the server from tlsConfig and returns the server so that calls can be
// chained. A nil tlsConfig leaves the server unchanged instead of panicking.
func (s *Server) WithTLSConfig(tlsConfig *TLSConfig) *Server {
	if tlsConfig == nil {
		return s
	}
	s.certFile = tlsConfig.CertFile
	s.certKeyFile = tlsConfig.KeyFile
	s.certPoolFile = tlsConfig.CertPoolFile
	return s
}

// WithStore sets the storage used by the server.
func (s *Server) WithStore(store storage.Store) *Server {
s.store = store
Expand Down Expand Up @@ -566,11 +590,13 @@ func (s *Server) getListener(addr string, h http.Handler, t httpListenerType) ([
"cert-file": s.certFile,
"cert-key-file": s.certKeyFile,
})

// if a manual cert refresh period has been set, then use the polling behavior,
// otherwise use the fsnotify default behavior
if s.certRefresh > 0 {
certLoop := s.certLoop(logger)
loops = []Loop{loop, certLoop}
} else {
loops = []Loop{loop}
loops = []Loop{loop, s.certLoopPolling(logger)}
} else if s.certFile != "" {
loops = []Loop{loop, s.certLoopNotify(logger)}
}
default:
err = fmt.Errorf("invalid url scheme %q", parsedURL.Scheme)
Expand Down

0 comments on commit a2fbe36

Please sign in to comment.