Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[server] fsnotify based cert reloading #6118

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 33 additions & 28 deletions server/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"crypto/sha256"
"crypto/tls"
"fmt"
"io"
"os"
"time"
Expand All @@ -21,39 +22,43 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
return s.cert, nil
}

func (s *Server) reloadTLSConfig(logger logging.Logger) error {
certHash, err := hash(s.certFile)
if err != nil {
return fmt.Errorf("failed to refresh server certificate: %w", err)
}
certKeyHash, err := hash(s.certKeyFile)
if err != nil {
return fmt.Errorf("failed to refresh server certificate: %w", err)
}

s.certMtx.Lock()
defer s.certMtx.Unlock()

different := !bytes.Equal(s.certFileHash, certHash) ||
!bytes.Equal(s.certKeyFileHash, certKeyHash)

if different { // load and store
newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
if err != nil {
return fmt.Errorf("failed to refresh server certificate: %w", err)
}
s.cert = &newCert
s.certFileHash = certHash
s.certKeyFileHash = certKeyHash
logger.Debug("Refreshed server certificate.")
}

return nil
}

func (s *Server) certLoop(logger logging.Logger) Loop {
return func() error {
for range time.NewTicker(s.certRefresh).C {
certHash, err := hash(s.certFile)
err := s.reloadTLSConfig(logger)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
logger.Error("Failed to reload TLS config: %v", err)
}
certKeyHash, err := hash(s.certKeyFile)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
}

s.certMtx.Lock()

different := !bytes.Equal(s.certFileHash, certHash) ||
!bytes.Equal(s.certKeyFileHash, certKeyHash)

if different { // load and store
newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
s.certMtx.Unlock()
continue
}
s.cert = &newCert
s.certFileHash = certHash
s.certKeyFileHash = certKeyHash
logger.Debug("Refreshed server certificate.")
}

s.certMtx.Unlock()
}

return nil
Expand Down
53 changes: 53 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import (
"sync"
"time"

"github.com/fsnotify/fsnotify"

"github.com/open-policy-agent/opa/internal/pathwatcher"
serverEncodingPlugin "github.com/open-policy-agent/opa/plugins/server/encoding"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -153,6 +156,17 @@ type Metrics interface {
InstrumentHandler(handler http.Handler, label string) http.Handler
}

// TLSConfig represents the TLS configuration for the server.
// This configuration is used to configure file watchers to reload each file as it
// changes on disk.
type TLSConfig struct {
// CertFile is the path to the certificate file.
CertFile string

// KeyFile is the path to the key file.
KeyFile string
}

// Loop will contain all the calls from the server that we'll be listening on.
type Loop func() error

Expand Down Expand Up @@ -182,6 +196,38 @@ func (s *Server) Init(ctx context.Context) (*Server, error) {
return nil, err
}

err = s.reloadTLSConfig(s.manager.Logger())
if err != nil {
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}

done := make(chan struct{})
watcher, err := pathwatcher.CreatePathWatcher([]string{
s.certFile, s.certKeyFile,
})
if err != nil {
return nil, fmt.Errorf("failed to create path watcher: %w", err)
}
go func() {
for {
s.manager.Logger().Info("watching for TLS config changes")
select {
case evt := <-watcher.Events:
removalMask := fsnotify.Remove | fsnotify.Rename
mask := fsnotify.Create | fsnotify.Write | removalMask
if (evt.Op & mask) != 0 {
err = s.reloadTLSConfig(s.manager.Logger())
if err != nil {
s.manager.Logger().Error("failed to reload TLS config: %w", err)
}
}
case <-done:
watcher.Close()
return
}
}
}()

s.partials = map[string]rego.PartialResult{}
s.preparedEvalQueries = newCache(pqMaxCacheSize)
s.defaultDecisionPath = s.generateDefaultDecisionPath()
Expand Down Expand Up @@ -274,6 +320,13 @@ func (s *Server) WithCertPool(pool *x509.CertPool) *Server {
return s
}

// WithTLSConfig sets the TLS configuration used by the server.
func (s *Server) WithTLSConfig(tlsConfig *TLSConfig) *Server {
s.certFile = tlsConfig.CertFile
s.certKeyFile = tlsConfig.KeyFile
return s
}

// WithStore sets the storage used by the server.
func (s *Server) WithStore(store storage.Store) *Server {
s.store = store
Expand Down