Skip to content

Commit

Permalink
Merge pull request #472 from grafana/inline-ca-strings
Browse files Browse the repository at this point in the history
Allow TLS settings to be specified inline
  • Loading branch information
roidelapluie committed May 11, 2023
2 parents f505d86 + bcb00f1 commit 085fa47
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 53 deletions.
192 changes: 141 additions & 51 deletions config/http_config.go
Expand Up @@ -579,8 +579,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
// No need for a RoundTripper that reloads the CA file automatically.
return newRT(tlsConfig)
}

return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT)
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT)
}

type authorizationCredentialsRoundTripper struct {
Expand Down Expand Up @@ -750,7 +749,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
if len(rt.config.TLSConfig.CAFile) == 0 {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport)
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.roundTripperSettings(), tlsTransport)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -817,6 +816,10 @@ func cloneRequest(r *http.Request) *http.Request {

// NewTLSConfig creates a new tls.Config from the given TLSConfig.
func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}

tlsConfig := &tls.Config{
InsecureSkipVerify: cfg.InsecureSkipVerify,
MinVersion: uint16(cfg.MinVersion),
Expand All @@ -831,7 +834,11 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {

// If a CA cert is provided then let's read it in so we can validate the
// scrape target's certificate properly.
if len(cfg.CAFile) > 0 {
if len(cfg.CA) > 0 {
if !updateRootCA(tlsConfig, []byte(cfg.CA)) {
return nil, fmt.Errorf("unable to use inline CA cert")
}
} else if len(cfg.CAFile) > 0 {
b, err := readCAFile(cfg.CAFile)
if err != nil {
return nil, err
Expand All @@ -844,12 +851,9 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
if len(cfg.ServerName) > 0 {
tlsConfig.ServerName = cfg.ServerName
}

// If a client cert & key is provided then configure TLS config accordingly.
if len(cfg.CertFile) > 0 && len(cfg.KeyFile) == 0 {
return nil, fmt.Errorf("client cert file %q specified without client key file", cfg.CertFile)
} else if len(cfg.KeyFile) > 0 && len(cfg.CertFile) == 0 {
return nil, fmt.Errorf("client key file %q specified without client cert file", cfg.KeyFile)
} else if len(cfg.CertFile) > 0 && len(cfg.KeyFile) > 0 {
if cfg.usingClientCert() && cfg.usingClientKey() {
// Verify that client cert and key are valid.
if _, err := cfg.getClientCertificate(nil); err != nil {
return nil, err
Expand All @@ -862,6 +866,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {

// TLSConfig configures the options for TLS connections.
type TLSConfig struct {
// Text of the CA cert to use for the targets.
CA string `yaml:"ca,omitempty" json:"ca,omitempty"`
// Text of the client cert file for the targets.
Cert string `yaml:"cert,omitempty" json:"cert,omitempty"`
// Text of the client key file for the targets.
Key Secret `yaml:"key,omitempty" json:"key,omitempty"`
// The CA cert to use for the targets.
CAFile string `yaml:"ca_file,omitempty" json:"ca_file,omitempty"`
// The client cert file for the targets.
Expand Down Expand Up @@ -891,29 +901,77 @@ func (c *TLSConfig) SetDirectory(dir string) {
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain TLSConfig
return unmarshal((*plain)(c))
if err := unmarshal((*plain)(c)); err != nil {
return err
}
return c.Validate()
}

// readCertAndKey reads the cert and key files from the disk.
func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) {
certData, err := os.ReadFile(certFile)
if err != nil {
return nil, nil, err
// Validate validates the TLSConfig to check that only one of the inlined or
// file-based fields for the TLS CA, client certificate, and client key are
// used.
func (c *TLSConfig) Validate() error {
if len(c.CA) > 0 && len(c.CAFile) > 0 {
return fmt.Errorf("at most one of ca and ca_file must be configured")
}
if len(c.Cert) > 0 && len(c.CertFile) > 0 {
return fmt.Errorf("at most one of cert and cert_file must be configured")
}
if len(c.Key) > 0 && len(c.KeyFile) > 0 {
return fmt.Errorf("at most one of key and key_file must be configured")
}

keyData, err := os.ReadFile(keyFile)
if err != nil {
return nil, nil, err
if c.usingClientCert() && !c.usingClientKey() {
return fmt.Errorf("exactly one of key or key_file must be configured when a client certificate is configured")
} else if c.usingClientKey() && !c.usingClientCert() {
return fmt.Errorf("exactly one of cert or cert_file must be configured when a client key is configured")
}

return certData, keyData, nil
return nil
}

func (c *TLSConfig) usingClientCert() bool {
return len(c.Cert) > 0 || len(c.CertFile) > 0
}

func (c *TLSConfig) usingClientKey() bool {
return len(c.Key) > 0 || len(c.KeyFile) > 0
}

func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings {
return TLSRoundTripperSettings{
CA: c.CA,
CAFile: c.CAFile,
Cert: c.Cert,
CertFile: c.CertFile,
Key: string(c.Key),
KeyFile: c.KeyFile,
}
}

// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
var (
certData, keyData []byte
err error
)

if c.CertFile != "" {
certData, err = os.ReadFile(c.CertFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert (%s): %s", c.CertFile, err)
}
} else {
certData = []byte(c.Cert)
}

if c.KeyFile != "" {
keyData, err = os.ReadFile(c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client key (%s): %s", c.KeyFile, err)
}
} else {
keyData = []byte(c.Key)
}

cert, err := tls.X509KeyPair(certData, keyData)
Expand Down Expand Up @@ -946,30 +1004,32 @@ func updateRootCA(cfg *tls.Config, b []byte) bool {
// tlsRoundTripper is a RoundTripper that updates automatically its TLS
// configuration whenever the content of the CA file changes.
type tlsRoundTripper struct {
caFile string
certFile string
keyFile string
settings TLSRoundTripperSettings

// newRT returns a new RoundTripper.
newRT func(*tls.Config) (http.RoundTripper, error)

mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
hashCertFile []byte
hashKeyFile []byte
hashCAData []byte
hashCertData []byte
hashKeyData []byte
tlsConfig *tls.Config
}

type TLSRoundTripperSettings struct {
CA, CAFile string
Cert, CertFile string
Key, KeyFile string
}

func NewTLSRoundTripper(
cfg *tls.Config,
caFile, certFile, keyFile string,
settings TLSRoundTripperSettings,
newRT func(*tls.Config) (http.RoundTripper, error),
) (http.RoundTripper, error) {
t := &tlsRoundTripper{
caFile: caFile,
certFile: certFile,
keyFile: keyFile,
settings: settings,
newRT: newRT,
tlsConfig: cfg,
}
Expand All @@ -979,44 +1039,74 @@ func NewTLSRoundTripper(
return nil, err
}
t.rt = rt
_, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash()
_, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash()
if err != nil {
return nil, err
}

return t, nil
}

func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) {
b1, err := readCAFile(t.caFile)
if err != nil {
return nil, nil, nil, nil, err
func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) {
var (
caBytes, certBytes, keyBytes []byte

err error
)

if t.settings.CAFile != "" {
caBytes, err = os.ReadFile(t.settings.CAFile)
if err != nil {
return nil, nil, nil, nil, err
}
} else if t.settings.CA != "" {
caBytes = []byte(t.settings.CA)
}

if t.settings.CertFile != "" {
certBytes, err = os.ReadFile(t.settings.CertFile)
if err != nil {
return nil, nil, nil, nil, err
}
} else if t.settings.Cert != "" {
certBytes = []byte(t.settings.Cert)
}
h1 := sha256.Sum256(b1)

var h2, h3 [32]byte
if t.certFile != "" {
b2, b3, err := readCertAndKey(t.certFile, t.keyFile)
if t.settings.KeyFile != "" {
keyBytes, err = os.ReadFile(t.settings.KeyFile)
if err != nil {
return nil, nil, nil, nil, err
}
h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3)
} else if t.settings.Key != "" {
keyBytes = []byte(t.settings.Key)
}

var caHash, certHash, keyHash [32]byte

if len(caBytes) > 0 {
caHash = sha256.Sum256(caBytes)
}
if len(certBytes) > 0 {
certHash = sha256.Sum256(certBytes)
}
if len(keyBytes) > 0 {
keyHash = sha256.Sum256(keyBytes)
}

return b1, h1[:], h2[:], h3[:], nil
return caBytes, caHash[:], certHash[:], keyHash[:], nil
}

// RoundTrip implements the http.RoundTrip interface.
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash()
caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash()
if err != nil {
return nil, err
}

t.mtx.RLock()
equal := bytes.Equal(caHash[:], t.hashCAFile) &&
bytes.Equal(certHash[:], t.hashCertFile) &&
bytes.Equal(keyHash[:], t.hashKeyFile)
equal := bytes.Equal(caHash[:], t.hashCAData) &&
bytes.Equal(certHash[:], t.hashCertData) &&
bytes.Equal(keyHash[:], t.hashKeyData)
rt := t.rt
t.mtx.RUnlock()
if equal {
Expand All @@ -1029,7 +1119,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// using GetClientCertificate.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, caData) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CAFile)
}
rt, err = t.newRT(tlsConfig)
if err != nil {
Expand All @@ -1039,9 +1129,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {

t.mtx.Lock()
t.rt = rt
t.hashCAFile = caHash[:]
t.hashCertFile = certHash[:]
t.hashKeyFile = keyHash[:]
t.hashCAData = caHash[:]
t.hashCertData = certHash[:]
t.hashKeyData = keyHash[:]
t.mtx.Unlock()

return rt.RoundTrip(req)
Expand Down

0 comments on commit 085fa47

Please sign in to comment.