Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow TLS settings to be specified inline #472

Merged
merged 3 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
192 changes: 141 additions & 51 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
// No need for a RoundTripper that reloads the CA file automatically.
return newRT(tlsConfig)
}

return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT)
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT)
}

type authorizationCredentialsRoundTripper struct {
Expand Down Expand Up @@ -750,7 +749,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
if len(rt.config.TLSConfig.CAFile) == 0 {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport)
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.roundTripperSettings(), tlsTransport)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -817,6 +816,10 @@ func cloneRequest(r *http.Request) *http.Request {

// NewTLSConfig creates a new tls.Config from the given TLSConfig.
func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}

tlsConfig := &tls.Config{
InsecureSkipVerify: cfg.InsecureSkipVerify,
MinVersion: uint16(cfg.MinVersion),
Expand All @@ -831,7 +834,11 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {

// If a CA cert is provided then let's read it in so we can validate the
// scrape target's certificate properly.
if len(cfg.CAFile) > 0 {
if len(cfg.CA) > 0 {
if !updateRootCA(tlsConfig, []byte(cfg.CA)) {
return nil, fmt.Errorf("unable to use inline CA cert")
}
} else if len(cfg.CAFile) > 0 {
b, err := readCAFile(cfg.CAFile)
if err != nil {
return nil, err
Expand All @@ -844,12 +851,9 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
if len(cfg.ServerName) > 0 {
tlsConfig.ServerName = cfg.ServerName
}

// If a client cert & key is provided then configure TLS config accordingly.
if len(cfg.CertFile) > 0 && len(cfg.KeyFile) == 0 {
return nil, fmt.Errorf("client cert file %q specified without client key file", cfg.CertFile)
} else if len(cfg.KeyFile) > 0 && len(cfg.CertFile) == 0 {
return nil, fmt.Errorf("client key file %q specified without client cert file", cfg.KeyFile)
} else if len(cfg.CertFile) > 0 && len(cfg.KeyFile) > 0 {
if cfg.usingClientCert() && cfg.usingClientKey() {
// Verify that client cert and key are valid.
if _, err := cfg.getClientCertificate(nil); err != nil {
return nil, err
Expand All @@ -862,6 +866,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {

// TLSConfig configures the options for TLS connections.
type TLSConfig struct {
// Text of the CA cert to use for the targets.
CA string `yaml:"ca,omitempty" json:"ca,omitempty"`
// Text of the client cert file for the targets.
Cert string `yaml:"cert,omitempty" json:"cert,omitempty"`
// Text of the client key file for the targets.
Key Secret `yaml:"key,omitempty" json:"key,omitempty"`
// The CA cert to use for the targets.
CAFile string `yaml:"ca_file,omitempty" json:"ca_file,omitempty"`
// The client cert file for the targets.
Expand Down Expand Up @@ -891,29 +901,77 @@ func (c *TLSConfig) SetDirectory(dir string) {
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain TLSConfig
return unmarshal((*plain)(c))
if err := unmarshal((*plain)(c)); err != nil {
return err
}
return c.Validate()
}

// readCertAndKey reads the cert and key files from the disk.
func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) {
certData, err := os.ReadFile(certFile)
if err != nil {
return nil, nil, err
// Validate validates the TLSConfig to check that only one of the inlined or
// file-based fields for the TLS CA, client certificate, and client key are
// used.
func (c *TLSConfig) Validate() error {
if len(c.CA) > 0 && len(c.CAFile) > 0 {
return fmt.Errorf("at most one of ca and ca_file must be configured")
}
if len(c.Cert) > 0 && len(c.CertFile) > 0 {
return fmt.Errorf("at most one of cert and cert_file must be configured")
}
if len(c.Key) > 0 && len(c.KeyFile) > 0 {
return fmt.Errorf("at most one of key and key_file must be configured")
}

keyData, err := os.ReadFile(keyFile)
if err != nil {
return nil, nil, err
if c.usingClientCert() && !c.usingClientKey() {
return fmt.Errorf("exactly one of key or key_file must be configured when a client certificate is configured")
} else if c.usingClientKey() && !c.usingClientCert() {
return fmt.Errorf("exactly one of cert or cert_file must be configured when a client key is configured")
}

return certData, keyData, nil
return nil
}

func (c *TLSConfig) usingClientCert() bool {
return len(c.Cert) > 0 || len(c.CertFile) > 0
}

func (c *TLSConfig) usingClientKey() bool {
return len(c.Key) > 0 || len(c.KeyFile) > 0
}

func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings {
return TLSRoundTripperSettings{
CA: c.CA,
CAFile: c.CAFile,
Cert: c.Cert,
CertFile: c.CertFile,
Key: string(c.Key),
KeyFile: c.KeyFile,
}
}

// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
var (
certData, keyData []byte
err error
)

if c.CertFile != "" {
certData, err = os.ReadFile(c.CertFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert (%s): %s", c.CertFile, err)
}
} else {
certData = []byte(c.Cert)
}

if c.KeyFile != "" {
keyData, err = os.ReadFile(c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client key (%s): %s", c.KeyFile, err)
}
} else {
keyData = []byte(c.Key)
}

cert, err := tls.X509KeyPair(certData, keyData)
Expand Down Expand Up @@ -946,30 +1004,32 @@ func updateRootCA(cfg *tls.Config, b []byte) bool {
// tlsRoundTripper is a RoundTripper that updates automatically its TLS
// configuration whenever the content of the CA file changes.
type tlsRoundTripper struct {
caFile string
certFile string
keyFile string
settings TLSRoundTripperSettings

// newRT returns a new RoundTripper.
newRT func(*tls.Config) (http.RoundTripper, error)

mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
hashCertFile []byte
hashKeyFile []byte
hashCAData []byte
hashCertData []byte
hashKeyData []byte
tlsConfig *tls.Config
}

type TLSRoundTripperSettings struct {
CA, CAFile string
Cert, CertFile string
Key, KeyFile string
}

func NewTLSRoundTripper(
cfg *tls.Config,
caFile, certFile, keyFile string,
settings TLSRoundTripperSettings,
newRT func(*tls.Config) (http.RoundTripper, error),
) (http.RoundTripper, error) {
t := &tlsRoundTripper{
caFile: caFile,
certFile: certFile,
keyFile: keyFile,
settings: settings,
newRT: newRT,
tlsConfig: cfg,
}
Expand All @@ -979,44 +1039,74 @@ func NewTLSRoundTripper(
return nil, err
}
t.rt = rt
_, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash()
_, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash()
if err != nil {
return nil, err
}

return t, nil
}

func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) {
b1, err := readCAFile(t.caFile)
if err != nil {
return nil, nil, nil, nil, err
func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) {
var (
caBytes, certBytes, keyBytes []byte

err error
)

if t.settings.CAFile != "" {
caBytes, err = os.ReadFile(t.settings.CAFile)
if err != nil {
return nil, nil, nil, nil, err
}
} else if t.settings.CA != "" {
caBytes = []byte(t.settings.CA)
}

if t.settings.CertFile != "" {
certBytes, err = os.ReadFile(t.settings.CertFile)
if err != nil {
return nil, nil, nil, nil, err
}
} else if t.settings.Cert != "" {
certBytes = []byte(t.settings.Cert)
}
h1 := sha256.Sum256(b1)

var h2, h3 [32]byte
if t.certFile != "" {
b2, b3, err := readCertAndKey(t.certFile, t.keyFile)
if t.settings.KeyFile != "" {
keyBytes, err = os.ReadFile(t.settings.KeyFile)
if err != nil {
return nil, nil, nil, nil, err
}
h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3)
} else if t.settings.Key != "" {
keyBytes = []byte(t.settings.Key)
}

var caHash, certHash, keyHash [32]byte

if len(caBytes) > 0 {
caHash = sha256.Sum256(caBytes)
}
if len(certBytes) > 0 {
certHash = sha256.Sum256(certBytes)
}
if len(keyBytes) > 0 {
keyHash = sha256.Sum256(keyBytes)
}

return b1, h1[:], h2[:], h3[:], nil
return caBytes, caHash[:], certHash[:], keyHash[:], nil
}

// RoundTrip implements the http.RoundTrip interface.
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash()
caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash()
if err != nil {
return nil, err
}

t.mtx.RLock()
equal := bytes.Equal(caHash[:], t.hashCAFile) &&
bytes.Equal(certHash[:], t.hashCertFile) &&
bytes.Equal(keyHash[:], t.hashKeyFile)
equal := bytes.Equal(caHash[:], t.hashCAData) &&
bytes.Equal(certHash[:], t.hashCertData) &&
bytes.Equal(keyHash[:], t.hashKeyData)
rt := t.rt
t.mtx.RUnlock()
if equal {
Expand All @@ -1029,7 +1119,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// using GetClientCertificate.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, caData) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CAFile)
}
rt, err = t.newRT(tlsConfig)
if err != nil {
Expand All @@ -1039,9 +1129,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {

t.mtx.Lock()
t.rt = rt
t.hashCAFile = caHash[:]
t.hashCertFile = certHash[:]
t.hashKeyFile = keyHash[:]
t.hashCAData = caHash[:]
t.hashCertData = certHash[:]
t.hashKeyData = keyHash[:]
t.mtx.Unlock()

return rt.RoundTrip(req)
Expand Down