Skip to content

Commit

Permalink
don't close established connections on Listener.Close
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Oct 21, 2023
1 parent a263164 commit a550063
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 148 deletions.
26 changes: 14 additions & 12 deletions integrationtests/self/deadline_test.go
Expand Up @@ -14,7 +14,7 @@ import (
)

var _ = Describe("Stream deadline tests", func() {
setup := func() (*quic.Listener, quic.Stream, quic.Stream) {
setup := func() (serverStr, clientStr quic.Stream, close func()) {
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
strChan := make(chan quic.SendStream)
Expand All @@ -36,19 +36,21 @@ var _ = Describe("Stream deadline tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
clientStr, err := conn.OpenStream()
clientStr, err = conn.OpenStream()
Expect(err).ToNot(HaveOccurred())
_, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream
Expect(err).ToNot(HaveOccurred())
var serverStr quic.Stream
Eventually(strChan).Should(Receive(&serverStr))
return server, serverStr, clientStr
return serverStr, clientStr, func() {
Expect(server.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())
}
}

Context("read deadlines", func() {
It("completes a transfer when the deadline is set", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
done := make(chan struct{})
Expand Down Expand Up @@ -82,8 +84,8 @@ var _ = Describe("Stream deadline tests", func() {
})

It("completes a transfer when the deadline is set concurrently", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
go func() {
Expand Down Expand Up @@ -132,8 +134,8 @@ var _ = Describe("Stream deadline tests", func() {

Context("write deadlines", func() {
It("completes a transfer when the deadline is set", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
done := make(chan struct{})
Expand Down Expand Up @@ -165,8 +167,8 @@ var _ = Describe("Stream deadline tests", func() {
})

It("completes a transfer when the deadline is set concurrently", func() {
server, serverStr, clientStr := setup()
defer server.Close()
serverStr, clientStr, closeFn := setup()
defer closeFn()

const timeout = time.Millisecond
readDone := make(chan struct{})
Expand Down
7 changes: 5 additions & 2 deletions integrationtests/self/handshake_test.go
Expand Up @@ -152,13 +152,14 @@ var _ = Describe("Handshake tests", func() {
Context("Certificate validation", func() {
It("accepts the certificate", func() {
runServer(getTLSConfig())
_, err := quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})

It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
Expand Down Expand Up @@ -187,6 +188,7 @@ var _ = Describe("Handshake tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
Eventually(done).Should(BeClosed())
Expect(server.Addr()).To(Equal(local))
Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
Expand All @@ -196,13 +198,14 @@ var _ = Describe("Handshake tests", func() {

It("works with a long certificate chain", func() {
runServer(getTLSConfigWithLongCertChain())
_, err := quic.DialAddr(
conn, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})

It("errors if the server name doesn't match", func() {
Expand Down
30 changes: 18 additions & 12 deletions integrationtests/self/resumption_test.go
Expand Up @@ -52,7 +52,7 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
Expand All @@ -61,25 +61,27 @@ var _ = Describe("TLS session resumption", func() {
Expect(err).ToNot(HaveOccurred())
var sessionKey string
Eventually(puts).Should(Receive(&sessionKey))
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())

serverConn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
conn1.CloseWithError(0, "")

conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive(Equal(sessionKey)))
Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())

serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
conn2.CloseWithError(0, "")
})

It("doesn't use session resumption, if the config disables it", func() {
Expand All @@ -94,30 +96,32 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Consistently(puts).ShouldNot(Receive())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
conn1.CloseWithError(0, "")

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")

serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expand All @@ -142,30 +146,32 @@ var _ = Describe("TLS session resumption", func() {
cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
tlsConf := getTLSClientConfig()
tlsConf.ClientSessionCache = cache
conn, err := quic.DialAddr(
conn1, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Consistently(puts).ShouldNot(Receive())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
conn1.CloseWithError(0, "")

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
serverConn, err := server.Accept(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())

conn, err = quic.DialAddr(
conn2, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
tlsConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
defer conn2.CloseWithError(0, "")

serverConn, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
Expand Down
12 changes: 9 additions & 3 deletions integrationtests/self/timeout_test.go
Expand Up @@ -185,11 +185,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer server.Close()

serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
Expand Down Expand Up @@ -240,7 +242,7 @@ var _ = Describe("Timeout tests", func() {
Consistently(serverConnClosed).ShouldNot(BeClosed())

// make the go routine return
Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})

Expand All @@ -266,11 +268,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer proxy.Close()

serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
<-conn.Context().Done() // block until the connection is closed
close(serverConnClosed)
}()
Expand Down Expand Up @@ -309,7 +313,7 @@ var _ = Describe("Timeout tests", func() {
Consistently(serverConnClosed).ShouldNot(BeClosed())

// make the go routine return
Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})
})
Expand All @@ -325,11 +329,13 @@ var _ = Describe("Timeout tests", func() {
Expect(err).ToNot(HaveOccurred())
defer server.Close()

serverConnChan := make(chan quic.Connection, 1)
serverConnClosed := make(chan struct{})
go func() {
defer GinkgoRecover()
conn, err := server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
serverConnChan <- conn
conn.AcceptStream(context.Background()) // blocks until the connection is closed
close(serverConnClosed)
}()
Expand Down Expand Up @@ -370,7 +376,7 @@ var _ = Describe("Timeout tests", func() {
_, err = str.Write([]byte("foobar"))
checkTimeoutError(err)

Expect(server.Close()).To(Succeed())
(<-serverConnChan).CloseWithError(0, "")
Eventually(serverConnClosed).Should(BeClosed())
})

Expand Down
1 change: 1 addition & 0 deletions integrationtests/self/uni_stream_test.go
Expand Up @@ -142,5 +142,6 @@ var _ = Describe("Unidirectional Streams", func() {
runReceivingPeer(client)
<-done1
<-done2
client.CloseWithError(0, "")
})
})
36 changes: 0 additions & 36 deletions mock_packet_handler_manager_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 0 additions & 17 deletions packet_handler_map.go
Expand Up @@ -220,23 +220,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (
return handler, ok
}

func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
var wg sync.WaitGroup
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(handler packetHandler) {
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
handler.shutdown()
wg.Done()
}(handler)
}
}
h.mutex.Unlock()
wg.Wait()
}

func (h *packetHandlerMap) Close(e error) {
h.mutex.Lock()

Expand Down
17 changes: 0 additions & 17 deletions packet_handler_map_test.go
Expand Up @@ -159,23 +159,6 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})

It("closes the server", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
if i%2 == 0 {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
} else {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
conn.EXPECT().shutdown()
}
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.CloseServer()
})

It("closes", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
testErr := errors.New("shutdown")
Expand Down

0 comments on commit a550063

Please sign in to comment.