diff --git a/connection.go b/connection.go index b17f3ce81a3..877f2d03311 100644 --- a/connection.go +++ b/connection.go @@ -243,7 +243,7 @@ var newConnection = func( handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(), + oneRTTStream: newCryptoStream(true), perspective: protocol.PerspectiveServer, tracer: tracer, logger: logger, @@ -391,7 +391,7 @@ var newClientConnection = func( s.logger, ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream() + oneRTTStream := newCryptoStream(true) params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -447,8 +447,8 @@ var newClientConnection = func( } func (s *connection) preSetup() { - s.initialStream = newCryptoStream() - s.handshakeStream = newCryptoStream() + s.initialStream = newCryptoStream(false) + s.handshakeStream = newCryptoStream(false) s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) diff --git a/crypto_stream.go b/crypto_stream.go index 4be2a07ae1a..5ce2125decf 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -30,10 +30,17 @@ type cryptoStreamImpl struct { writeOffset protocol.ByteCount writeBuf []byte + + // Reassemble TLS handshake messages before returning them from GetCryptoData. + // This is only needed because crypto/tls doesn't correctly handle post-handshake messages. + onlyCompleteMsg bool } -func newCryptoStream() cryptoStream { - return &cryptoStreamImpl{queue: newFrameSorter()} +func newCryptoStream(onlyCompleteMsg bool) cryptoStream { + return &cryptoStreamImpl{ + queue: newFrameSorter(), + onlyCompleteMsg: onlyCompleteMsg, + } } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { @@ -71,6 +78,20 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { + if s.onlyCompleteMsg { + if len(s.msgBuf) < 4 { + return nil + } + msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) + if len(s.msgBuf) < msgLen { + return nil + } + msg := make([]byte, msgLen) + copy(msg, s.msgBuf[:msgLen]) + s.msgBuf = s.msgBuf[msgLen:] + return msg + } + b := s.msgBuf s.msgBuf = nil return b diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 9a4a2ee57f9..67de9149966 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "crypto/rand" "fmt" "github.com/quic-go/quic-go/internal/protocol" @@ -15,7 +16,7 @@ var _ = Describe("Crypto Stream", func() { var str cryptoStream BeforeEach(func() { - str = newCryptoStream() + str = newCryptoStream(false) }) Context("handling incoming data", func() { @@ -137,4 +138,23 @@ var _ = Describe("Crypto Stream", func() { Expect(f.Data).To(Equal([]byte("bar"))) }) }) + + It("reassembles data", func() { + str = newCryptoStream(true) + data := make([]byte, 1337) + l := len(data) - 4 + data[1] = uint8(l >> 16) + data[2] = uint8(l >> 8) + data[3] = uint8(l) + rand.Read(data[4:]) + + for i, b := range data { + Expect(str.GetCryptoData()).To(BeEmpty()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.ByteCount(i), + Data: []byte{b}, + })).To(Succeed()) + } + Expect(str.GetCryptoData()).To(Equal(data)) + }) }) diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index a3f549eec7f..be34c22da9e 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -408,6 +408,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. } client.HandleMessage(ticket, protocol.Encryption1RTT) } + if sendPostHandshakeMessageToClient { fmt.Println("sending post handshake message to the client at", messageToReplaceEncLevel) client.HandleMessage(data, messageToReplaceEncLevel)