Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't close established connections on Listener.Close, when using a Transport #4072

Merged
merged 2 commits into from Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 14 additions & 12 deletions integrationtests/self/deadline_test.go
Expand Up @@ -14,7 +14,7 @@ import (
)

var _ = Describe("Stream deadline tests", func() {
setup := func() (*quic.Listener, quic.Stream, quic.Stream) {
setup := func() (serverStr, clientStr quic.Stream, close func()) {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
strChan := make(chan quic.SendStream)
Expand All @@ -36,19 +36,21 @@ var _ = Describe("Stream deadline tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
clientStr, err := conn.OpenStream()
clientStr, err = conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
Expect(err).ToNot(HaveOccurred())
var serverStr quic.Stream
Eventually(strChan).Should(Receive(&serverStr))
return server, serverStr, clientStr
return serverStr, clientStr, func() {
Expect(server.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())
}
}

Context("read deadlines", func() {
It("completes a transfer when the deadline is set", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
done := make(chan struct{})
Expand Down Expand Up @@ -82,8 +84,8 @@ var _ = Describe("Stream deadline tests", func() {
})

It("completes a transfer when the deadline is set concurrently", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
go func() {
Expand Down Expand Up @@ -132,8 +134,8 @@ var _ = Describe("Stream deadline tests", func() {

Context("write deadlines", func() {
It("completes a transfer when the deadline is set", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
done := make(chan struct{})
Expand Down Expand Up @@ -165,8 +167,8 @@ var _ = Describe("Stream deadline tests", func() {
})

It("completes a transfer when the deadline is set concurrently", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
readDone := make(chan struct{})
Expand Down
7 changes: 5 additions & 2 deletions integrationtests/self/handshake_test.go
Expand Up @@ -152,13 +152,14 @@ var _ = Describe("Handshake tests", func() {
Context("Certificate validation", func() {
It("accepts the certificate", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})

It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
Expand Down Expand Up @@ -187,6 +188,7 @@ var _ = Describe("Handshake tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
Eventually(done).Should(BeClosed())
Expect(server.Addr()).To(Equal(local))
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
Expand All @@ -196,13 +198,14 @@ var _ = Describe("Handshake tests", func() {

It("works with a long certificate chain", func() {
runServer(getTLSConfigWithLongCertChain())
_, err := quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})

It("errors if the server name doesn't match", func() {
Expand Down
30 changes: 18 additions & 12 deletions integrationtests/self/resumption_test.go
Expand Up @@ -52,34 +52,36 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
var sessionKey string
Eventually(puts).Should(Receive(&sessionKey))
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())

serverConn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive(Equal(sessionKey)))
Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())

serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
conn2.CloseWithError(0, "")
})

It("doesn't use session resumption, if the config disables it", func() {
Expand All @@ -94,30 +96,32 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn1.CloseWithError(0, "")
Consistently(puts).ShouldNot(Receive())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")

serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expand All @@ -142,30 +146,32 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Consistently(puts).ShouldNot(Receive())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn1.CloseWithError(0, "")

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")

serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expand Down
12 changes: 9 additions & 3 deletions integrationtests/self/timeout_test.go
Expand Up @@ -185,11 +185,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer server.Close()

serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
Expand Down Expand Up @@ -240,7 +242,7 @@ var _ = Describe("Timeout tests", func() {
Consistently(serverConnClosed).ShouldNot(BeClosed())

// make the go routine return
Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})

Expand All @@ -266,11 +268,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()

serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
<-conn.Context().Done() // block until the connection is closed
close(serverConnClosed)
}()
Expand Down Expand Up @@ -309,7 +313,7 @@ var _ = Describe("Timeout tests", func() {
Consistently(serverConnClosed).ShouldNot(BeClosed())

// make the go routine return
Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})
})
Expand All @@ -325,11 +329,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer server.Close()

serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
Expand Down Expand Up @@ -370,7 +376,7 @@ var _ = Describe("Timeout tests", func() {
_, err = str.Write([]byte("foobar"))
checkTimeoutError(err)

Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})

Expand Down
1 change: 1 addition & 0 deletions integrationtests/self/uni_stream_test.go
Expand Up @@ -142,5 +142,6 @@ var _ = Describe("Unidirectional Streams", func() {
runReceivingPeer(client)
<-done1
<-done2
client.CloseWithError(0, "")
})
})
36 changes: 0 additions & 36 deletions mock_packet_handler_manager_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 0 additions & 17 deletions packet_handler_map.go
Expand Up @@ -220,23 +220,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (
return handler, ok
}

func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
var wg sync.WaitGroup
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(handler packetHandler) {
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
handler.shutdown()
wg.Done()
}(handler)
}
}
h.mutex.Unlock()
wg.Wait()
}

func (h *packetHandlerMap) Close(e error) {
h.mutex.Lock()

Expand Down
17 changes: 0 additions & 17 deletions packet_handler_map_test.go
Expand Up @@ -159,23 +159,6 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})

It("closes the server", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
if i%2 == 0 {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
} else {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
conn.EXPECT().shutdown()
}
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.CloseServer()
})

It("closes", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
testErr := errors.New("shutdown")
Expand Down