diff --git a/netconn.go b/netconn.go index c6f8dc13..aea1a02d 100644 --- a/netconn.go +++ b/netconn.go @@ -50,16 +50,14 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { writeMu: newMu(c), } - var writeCancel context.CancelFunc - nc.writeCtx, writeCancel = context.WithCancel(ctx) - var readCancel context.CancelFunc - nc.readCtx, readCancel = context.WithCancel(ctx) + nc.writeCtx, nc.writeCancel = context.WithCancel(ctx) + nc.readCtx, nc.readCancel = context.WithCancel(ctx) nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { if !nc.writeMu.tryLock() { // If the lock cannot be acquired, then there is an // active write goroutine and so we should cancel the context. - writeCancel() + nc.writeCancel() return } defer nc.writeMu.unlock() @@ -75,7 +73,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { if !nc.readMu.tryLock() { // If the lock cannot be acquired, then there is an // active read goroutine and so we should cancel the context. - readCancel() + nc.readCancel() return } defer nc.readMu.unlock() @@ -98,11 +96,13 @@ type netConn struct { writeMu *mu writeExpired int64 writeCtx context.Context + writeCancel context.CancelFunc readTimer *time.Timer readMu *mu readExpired int64 readCtx context.Context + readCancel context.CancelFunc readEOFed bool reader io.Reader } @@ -111,7 +111,9 @@ var _ net.Conn = &netConn{} func (nc *netConn) Close() error { nc.writeTimer.Stop() + nc.writeCancel() nc.readTimer.Stop() + nc.readCancel() return nc.c.Close(StatusNormalClosure, "") }