Skip to content

Commit

Permalink
Refactor Managers into on-demand config (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
mholt committed May 11, 2023
1 parent 53140d5 commit 8728b18
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 47 deletions.
3 changes: 3 additions & 0 deletions certificates.go
Expand Up @@ -113,6 +113,9 @@ func (cert Certificate) HasTag(tag string) bool {
// resolution of ASN.1 UTCTime/GeneralizedTime by including the extra fraction
// of a second of certificate validity beyond the NotAfter value.
func expiresAt(cert *x509.Certificate) time.Time {
if cert == nil {
return time.Time{}
}
return cert.NotAfter.Truncate(time.Second).Add(1 * time.Second)
}

Expand Down
9 changes: 9 additions & 0 deletions certmagic.go
Expand Up @@ -270,6 +270,15 @@ type OnDemandConfig struct {
// request will be denied.
DecisionFunc func(name string) error

// Sources for getting new, unmanaged certificates.
// They will be invoked only during TLS handshakes
// before on-demand certificate management occurs,
// for certificates that are not already loaded into
// the in-memory cache.
//
// TODO: EXPERIMENTAL: subject to change and/or removal.
Managers []Manager

// List of allowed hostnames (SNI values) for
// deferred (on-demand) obtaining of certificates.
// Used only by higher-level functions in this
Expand Down
12 changes: 0 additions & 12 deletions config.go
Expand Up @@ -95,15 +95,6 @@ type Config struct {
// turn until one succeeds.
Issuers []Issuer

// Sources for getting new, unmanaged certificates.
// They will be invoked only during TLS handshakes
// before on-demand certificate management occurs,
// for certificates that are not already loaded into
// the in-memory cache.
//
// TODO: EXPERIMENTAL: subject to change and/or removal.
Managers []Manager

// The source of new private keys for certificates;
// the default KeySource is StandardKeyGenerator.
KeySource KeyGenerator
Expand Down Expand Up @@ -234,9 +225,6 @@ func newWithCache(certCache *Cache, cfg Config) *Config {
cfg.Issuers = []Issuer{NewACMEIssuer(&cfg, DefaultACME)}
}
}
if cfg.Managers == nil {
cfg.Managers = Default.Managers
}
if cfg.RenewalWindowRatio == 0 {
cfg.RenewalWindowRatio = Default.RenewalWindowRatio
}
Expand Down
62 changes: 27 additions & 35 deletions handshake.go
Expand Up @@ -81,7 +81,7 @@ func (cfg *Config) GetCertificateWithContext(ctx context.Context, clientHello *t
}

// get the certificate and serve it up
cert, err := cfg.getCertDuringHandshake(ctx, clientHello, true, true)
cert, err := cfg.getCertDuringHandshake(ctx, clientHello, true)

return &cert.Certificate, err
}
Expand Down Expand Up @@ -253,19 +253,19 @@ func DefaultCertificateSelector(hello *tls.ClientHelloInfo, choices []Certificat
// An error will be returned if and only if no certificate is available.
//
// This function is safe for concurrent use.
func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.ClientHelloInfo, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
log := logWithRemote(cfg.Logger.Named("handshake"), hello)
func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.ClientHelloInfo, loadOrObtainIfNecessary bool) (Certificate, error) {
logger := logWithRemote(cfg.Logger.Named("handshake"), hello)
name := cfg.getNameFromClientHello(hello)

// First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := cfg.getCertificateFromCache(hello)
if matched {
log.Debug("matched certificate in cache",
logger.Debug("matched certificate in cache",
zap.Strings("subjects", cert.Names),
zap.Bool("managed", cert.managed),
zap.Time("expiration", expiresAt(cert.Leaf)),
zap.String("hash", cert.hash))
if cert.managed && cfg.OnDemand != nil && obtainIfNecessary {
if cert.managed && cfg.OnDemand != nil && loadOrObtainIfNecessary {
// On-demand certificates are maintained in the background, but
// maintenance is triggered by handshakes instead of by a timer
// as in maintain.go.
Expand Down Expand Up @@ -294,7 +294,7 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client
timeout.Stop()
}

return cfg.getCertDuringHandshake(ctx, hello, false, false)
return cfg.getCertDuringHandshake(ctx, hello, false)
} else {
// no other goroutine is currently trying to load this cert
wait = make(chan struct{})
Expand All @@ -319,7 +319,7 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client

// If an external Manager is configured, try to get it from them.
// Only continue to use our own logic if it returns empty+nil.
externalCert, err := cfg.getCertFromAnyCertManager(ctx, hello, log)
externalCert, err := cfg.getCertFromAnyCertManager(ctx, hello, logger)
if err != nil {
return Certificate{}, err
}
Expand All @@ -345,45 +345,45 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client
cacheAlmostFull := cacheCapacity > 0 && float64(cacheSize) >= cacheCapacity*.9
loadDynamically := cfg.OnDemand != nil || cacheAlmostFull

if loadDynamically && loadIfNecessary {
if loadDynamically && loadOrObtainIfNecessary {
// Check to see if we have one on disk
loadedCert, err := cfg.loadCertFromStorage(ctx, log, hello)
loadedCert, err := cfg.loadCertFromStorage(ctx, logger, hello)
if err == nil {
return loadedCert, nil
}
log.Debug("did not load cert from storage",
logger.Debug("did not load cert from storage",
zap.String("server_name", hello.ServerName),
zap.Error(err))
if cfg.OnDemand != nil {
// By this point, we need to ask the CA for a certificate
return cfg.obtainOnDemandCertificate(ctx, hello)
}
return loadedCert, nil
}

// Fall back to another certificate if there is one (either DefaultServerName or FallbackServerName)
if defaulted {
log.Debug("fell back to other certificate",
logger.Debug("fell back to default certificate",
zap.Strings("subjects", cert.Names),
zap.Bool("managed", cert.managed),
zap.Time("expiration", expiresAt(cert.Leaf)),
zap.String("hash", cert.hash))
return cert, nil
}

log.Debug("no certificate matching TLS ClientHello",
logger.Debug("no certificate matching TLS ClientHello",
zap.String("server_name", hello.ServerName),
zap.String("remote", hello.Conn.RemoteAddr().String()),
zap.String("identifier", name),
zap.Uint16s("cipher_suites", hello.CipherSuites),
zap.Float64("cert_cache_fill", float64(cacheSize)/cacheCapacity), // may be approximate! because we are not within the lock
zap.Bool("load_if_necessary", loadIfNecessary),
zap.Bool("obtain_if_necessary", obtainIfNecessary),
zap.Bool("load_or_obtain_if_necessary", loadOrObtainIfNecessary),
zap.Bool("on_demand", cfg.OnDemand != nil))

return Certificate{}, fmt.Errorf("no certificate available for '%s'", name)
}

func (cfg *Config) loadCertFromStorage(ctx context.Context, log *zap.Logger, hello *tls.ClientHelloInfo) (Certificate, error) {
func (cfg *Config) loadCertFromStorage(ctx context.Context, logger *zap.Logger, hello *tls.ClientHelloInfo) (Certificate, error) {
name := normalizedName(hello.ServerName)
loadedCert, err := cfg.CacheManagedCertificate(ctx, name)
if errors.Is(err, fs.ErrNotExist) {
Expand All @@ -395,14 +395,14 @@ func (cfg *Config) loadCertFromStorage(ctx context.Context, log *zap.Logger, hel
if err != nil {
return Certificate{}, fmt.Errorf("no matching certificate to load for %s: %w", name, err)
}
log.Debug("loaded certificate from storage",
logger.Debug("loaded certificate from storage",
zap.Strings("subjects", loadedCert.Names),
zap.Bool("managed", loadedCert.managed),
zap.Time("expiration", expiresAt(loadedCert.Leaf)),
zap.String("hash", loadedCert.hash))
loadedCert, err = cfg.handshakeMaintenance(ctx, hello, loadedCert)
if err != nil {
log.Error("maintaining newly-loaded certificate",
logger.Error("maintaining newly-loaded certificate",
zap.String("server_name", name),
zap.Error(err))
}
Expand Down Expand Up @@ -465,10 +465,6 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli

name := cfg.getNameFromClientHello(hello)

getCertWithoutReobtaining := func() (Certificate, error) {
return cfg.loadCertFromStorage(ctx, log, hello)
}

// We must protect this process from happening concurrently, so synchronize.
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
Expand All @@ -486,7 +482,7 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli
timeout.Stop()
}

return getCertWithoutReobtaining()
return cfg.loadCertFromStorage(ctx, log, hello)
}

// looks like it's up to us to do all the work and obtain the cert.
Expand Down Expand Up @@ -525,7 +521,7 @@ func (cfg *Config) obtainOnDemandCertificate(ctx context.Context, hello *tls.Cli

// success; certificate was just placed on disk, so
// we need only restart serving the certificate
return getCertWithoutReobtaining()
return cfg.loadCertFromStorage(ctx, log, hello)
}

// handshakeMaintenance performs a check on cert for expiration and OCSP validity.
Expand Down Expand Up @@ -613,10 +609,6 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
timeLeft := time.Until(expiresAt(currentCert.Leaf))
revoked := currentCert.ocsp != nil && currentCert.ocsp.Status == ocsp.Revoked

getCertWithoutReobtaining := func() (Certificate, error) {
return cfg.loadCertFromStorage(ctx, log, hello)
}

// see if another goroutine is already working on this certificate
obtainCertWaitChansMu.Lock()
wait, ok := obtainCertWaitChans[name]
Expand Down Expand Up @@ -651,7 +643,7 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
timeout.Stop()
}

return getCertWithoutReobtaining()
return cfg.loadCertFromStorage(ctx, log, hello)
}

// looks like it's up to us to do all the work and renew the cert
Expand Down Expand Up @@ -726,7 +718,7 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
return newCert, err
}

return getCertWithoutReobtaining()
return cfg.loadCertFromStorage(ctx, log, hello)
}

// if the certificate hasn't expired, we can serve what we have and renew in the background
Expand All @@ -744,20 +736,20 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
// getCertFromAnyCertManager gets a certificate from cfg's Managers. If there are no Managers defined, this is
// a no-op that returns empty values. Otherwise, it gets a certificate for hello from the first Manager that
// returns a certificate and no error.
func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.ClientHelloInfo, log *zap.Logger) (Certificate, error) {
func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.ClientHelloInfo, logger *zap.Logger) (Certificate, error) {
// fast path if nothing to do
if len(cfg.Managers) == 0 {
if cfg.OnDemand == nil || len(cfg.OnDemand.Managers) == 0 {
return Certificate{}, nil
}

var upstreamCert *tls.Certificate

// try all the GetCertificate methods on external managers; use first one that returns a certificate
for i, certManager := range cfg.Managers {
for i, certManager := range cfg.OnDemand.Managers {
var err error
upstreamCert, err = certManager.GetCertificate(ctx, hello)
if err != nil {
log.Error("getting certificate from external certificate manager",
logger.Error("getting certificate from external certificate manager",
zap.String("sni", hello.ServerName),
zap.Int("cert_manager", i),
zap.Error(err))
Expand All @@ -768,7 +760,7 @@ func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.Cli
}
}
if upstreamCert == nil {
log.Debug("all external certificate managers yielded no certificates and no errors", zap.String("sni", hello.ServerName))
logger.Debug("all external certificate managers yielded no certificates and no errors", zap.String("sni", hello.ServerName))
return Certificate{}, nil
}

Expand All @@ -778,7 +770,7 @@ func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.Cli
return Certificate{}, fmt.Errorf("external certificate manager: %s: filling cert from leaf: %v", hello.ServerName, err)
}

log.Debug("using externally-managed certificate",
logger.Debug("using externally-managed certificate",
zap.String("sni", hello.ServerName),
zap.Strings("names", cert.Names),
zap.Time("expiration", expiresAt(cert.Leaf)))
Expand Down

0 comments on commit 8728b18

Please sign in to comment.