Skip to content

Commit

Permalink
add tls.ClientHelloInfo.Conn for recursive GetConfigForClient calls
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 4, 2023
1 parent 18d3846 commit 3970c33
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 53 deletions.
37 changes: 11 additions & 26 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,20 @@ var _ = Describe("Handshake tests", func() {

It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
var local, remote net.Addr
var local2, remote2 net.Addr
done := make(chan struct{})
tlsConf := &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
defer close(done)
local = info.Conn.LocalAddr()
remote = info.Conn.RemoteAddr()
return getTLSConfig(), nil
conf := getTLSConfig()
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
defer close(done)
local2 = info.Conn.LocalAddr()
remote2 = info.Conn.RemoteAddr()
return &(conf.Certificates[0]), nil
}
return conf, nil
},
}
runServer(tlsConf)
Expand All @@ -162,30 +169,8 @@ var _ = Describe("Handshake tests", func() {
Eventually(done).Should(BeClosed())
Expect(server.Addr()).To(Equal(local))
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
})

It("has the right local and remote address on the tls.Config.GetCertificate ClientHelloInfo.Conn", func() {
var local, remote net.Addr
done := make(chan struct{})
tlsConf := getTLSConfig()
tlsConf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
defer close(done)
local = info.Conn.LocalAddr()
remote = info.Conn.RemoteAddr()
cert := tlsConf.Certificates[0]
return &cert, nil
}
runServer(tlsConf)
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Eventually(done).Should(BeClosed())
Expect(server.Addr()).To(Equal(local))
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
Expect(local).To(Equal(local2))
Expect(remote).To(Equal(remote2))
})

It("works with a long certificate chain", func() {
Expand Down
36 changes: 24 additions & 12 deletions internal/handshake/crypto_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,25 +127,37 @@ func NewCryptoSetupServer(

quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT)
if quicConf.TLSConfig.GetConfigForClient != nil {
gcfc := quicConf.TLSConfig.GetConfigForClient
quicConf.TLSConfig.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)

cs.tlsConf = quicConf.TLSConfig
cs.conn = qtls.QUICServer(quicConf)

return cs
}

// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gcfc(info)
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr)
}
return c, err
}
}
if quicConf.TLSConfig.GetCertificate != nil {
gc := quicConf.TLSConfig.GetCertificate
quicConf.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}

cs.tlsConf = quicConf.TLSConfig
cs.conn = qtls.QUICServer(quicConf)

return cs
}

func newCryptoSetup(
Expand Down
92 changes: 77 additions & 15 deletions internal/handshake/crypto_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ const (
)

var _ = Describe("Crypto Setup TLS", func() {
generateCert := func() tls.Certificate {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Expect(err).ToNot(HaveOccurred())
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour), // valid for an hour
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
Expect(err).ToNot(HaveOccurred())
return tls.Certificate{
PrivateKey: priv,
Certificate: [][]byte{certDER},
}
}

var clientConf, serverConf *tls.Config

BeforeEach(func() {
Expand Down Expand Up @@ -86,26 +105,69 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})

Context("doing the handshake", func() {
generateCert := func() tls.Certificate {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
var (
local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
)

It("wraps GetCertificate", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
cert := generateCert()
return &cert, nil
},
}
addConnToClientHelloInfo(tlsConf, local, remote)
_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{},
SignatureAlgorithm: x509.SHA256WithRSA,
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour), // valid for an hour
BasicConstraintsValid: true,
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
})

It("wraps GetConfigForClient", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
return &tls.Config{}, nil
},
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
addConnToClientHelloInfo(tlsConf, local, remote)
_, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
return tls.Certificate{
PrivateKey: priv,
Certificate: [][]byte{certDER},
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
})

It("wraps GetConfigForClient, recursively", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{}
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
conf := tlsConf.Clone()
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
cert := generateCert()
return &cert, nil
}
return conf, nil
}
}
addConnToClientHelloInfo(tlsConf, local, remote)
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
})
})

Context("doing the handshake", func() {
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
rttStats := &utils.RTTStats{}
rttStats.UpdateRTT(rtt, 0, time.Now())
Expand Down

0 comments on commit 3970c33

Please sign in to comment.