Skip to content

Commit

Permalink
don't close established connections on Listener.Close
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Oct 21, 2023
1 parent a263164 commit ac87a49
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 116 deletions.
36 changes: 0 additions & 36 deletions mock_packet_handler_manager_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 0 additions & 17 deletions packet_handler_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,23 +220,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (
return handler, ok
}

func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
var wg sync.WaitGroup
for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(handler packetHandler) {
// blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
handler.shutdown()
wg.Done()
}(handler)
}
}
h.mutex.Unlock()
wg.Wait()
}

func (h *packetHandlerMap) Close(e error) {
h.mutex.Lock()

Expand Down
17 changes: 0 additions & 17 deletions packet_handler_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,6 @@ var _ = Describe("Packet Handler Map", func() {
Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse())
})

It("closes the server", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
for i := 0; i < 10; i++ {
conn := NewMockPacketHandler(mockCtrl)
if i%2 == 0 {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient)
} else {
conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer)
conn.EXPECT().shutdown()
}
b := make([]byte, 12)
rand.Read(b)
m.Add(protocol.ParseConnectionID(b), conn)
}
m.CloseServer()
})

It("closes", func() {
m := newPacketHandlerMap(nil, nil, utils.DefaultLogger)
testErr := errors.New("shutdown")
Expand Down
51 changes: 23 additions & 28 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ type packetHandlerManager interface {
GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
Close(error)
CloseServer()
connRunner
}

Expand Down Expand Up @@ -104,9 +103,7 @@ type baseServer struct {
protocol.VersionNumber,
) quicConn

serverError error
errorChan chan struct{}
closed bool
errorChan chan struct{} // is closed when the server is closed
running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan rejectedPacket
Expand All @@ -132,7 +129,10 @@ func (l *Listener) Accept(ctx context.Context) (Connection, error) {
return l.baseServer.Accept(ctx)
}

// Close the server. All active connections will be closed.
// Close closes the listener.
// Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted.
// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error.
// Closing the listener doesn't have any effect on already established connections.
func (l *Listener) Close() error {
return l.baseServer.Close()
}
Expand Down Expand Up @@ -321,38 +321,27 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) {
atomic.AddInt32(&s.connQueueLen, -1)
return conn, nil
case <-s.errorChan:
return nil, s.serverError
return nil, ErrServerClosed
}
}

// Close the server
func (s *baseServer) Close() error {
s.mutex.Lock()
if s.closed {
s.mutex.Unlock()
return nil
}
if s.serverError == nil {
s.serverError = ErrServerClosed
}
s.closed = true
close(s.errorChan)
s.mutex.Unlock()

<-s.running
s.onClose()
s.close(true)
return nil
}

func (s *baseServer) setCloseError(e error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
func (s *baseServer) close(notifyOnClose bool) {
select {
case <-s.errorChan: // already closed
return
default:
}
s.closed = true
s.serverError = e
close(s.errorChan)

<-s.running
if notifyOnClose {
s.onClose()
}
}

// Addr returns the server's network address
Expand Down Expand Up @@ -701,15 +690,21 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
func (s *baseServer) handleNewConn(conn quicConn) {
connCtx := conn.Context()
if s.acceptEarlyConns {
// wait until the early connection is ready (or the handshake fails)
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
return
case <-conn.earlyConnReady():
case <-connCtx.Done():
return
}
} else {
// wait until the handshake is complete (or fails)
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
return
case <-conn.HandshakeComplete():
case <-connCtx.Done():
return
Expand Down
70 changes: 56 additions & 14 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any())
})

It("sends a Version Negotiation Packet for unsupported versions", func() {
Expand Down Expand Up @@ -527,6 +529,8 @@ var _ = Describe("Server", func() {
// make sure we're using a server-generated connection ID
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any())
})

It("drops packets if the receive queue is full", func() {
Expand Down Expand Up @@ -565,6 +569,8 @@ var _ = Describe("Server", func() {
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
// shutdown
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
return conn
}

Expand Down Expand Up @@ -956,30 +962,67 @@ var _ = Describe("Server", func() {
})

Context("accepting connections", func() {
It("returns Accept when an error occurs", func() {
testErr := errors.New("test err")

It("returns Accept when closed", func() {
done := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(ErrServerClosed))
close(done)
}()

serv.setCloseError(testErr)
serv.Close()
Eventually(done).Should(BeClosed())
serv.onClose() // shutdown
})

It("returns immediately, if an error occurred before", func() {
testErr := errors.New("test err")
serv.setCloseError(testErr)
serv.Close()
for i := 0; i < 3; i++ {
_, err := serv.Accept(context.Background())
Expect(err).To(MatchError(testErr))
Expect(err).To(MatchError(ErrServerClosed))
}
serv.onClose() // shutdown
})

It("closes connection that are still handshaking after Close", func() {
serv.Close()

serv.newConn = func(
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
_ *protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ ConnectionIDGenerator,
_ protocol.StatelessResetToken,
conf *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ *logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run()
conn.EXPECT().Context().Return(context.Background())
return conn
}
phm.EXPECT().Get(gomock.Any())
phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
phm.EXPECT().GetStatelessResetToken(gomock.Any())
_, ok := fn()
return ok
})
serv.handleInitialImpl(
receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})},
)
})

It("returns when the context is canceled", func() {
Expand Down Expand Up @@ -1343,10 +1386,7 @@ var _ = Describe("Server", func() {
serv.connHandler = phm
})

AfterEach(func() {
phm.EXPECT().CloseServer().MaxTimes(1)
tr.Close()
})
AfterEach(func() { tr.Close() })

It("passes packets to existing connections", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
Expand Down Expand Up @@ -1425,6 +1465,8 @@ var _ = Describe("Server", func() {
conn.EXPECT().earlyConnReady()
conn.EXPECT().Context().Return(context.Background())
close(called)
// shutdown
conn.EXPECT().destroy(gomock.Any())
return conn
}

Expand Down
6 changes: 3 additions & 3 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ func (t *Transport) runSendQueue() {
}
}

// Close closes the underlying connection and waits until listen has returned.
// Close closes the underlying connection.
// If any listener was started, it will be closed as well.
// It is invalid to start new listeners or connections after that.
func (t *Transport) Close() error {
t.close(errors.New("closing"))
Expand All @@ -294,7 +295,6 @@ func (t *Transport) Close() error {
}

func (t *Transport) closeServer() {
t.handlerMap.CloseServer()
t.mutex.Lock()
t.server = nil
if t.isSingleUse {
Expand Down Expand Up @@ -322,7 +322,7 @@ func (t *Transport) close(e error) {
t.handlerMap.Close(e)
}
if t.server != nil {
t.server.setCloseError(e)
t.server.close(false)
}
t.closed = true
}
Expand Down
1 change: 0 additions & 1 deletion transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ var _ = Describe("Transport", func() {
phm := NewMockPacketHandlerManager(mockCtrl)
tr.handlerMap = phm

phm.EXPECT().CloseServer()
Expect(ln.Close()).To(Succeed())

// shutdown
Expand Down

0 comments on commit ac87a49

Please sign in to comment.