diff --git a/redis/client.py b/redis/client.py index 0af7e050d6..e34dcdafb5 100755 --- a/redis/client.py +++ b/redis/client.py @@ -739,6 +739,7 @@ def __init__( self.health_check_response = [b"pong", self.health_check_response_b] if self.push_handler_func is None: _set_info_logger() + self._connection_lock = threading.Lock() self.reset() def __enter__(self) -> "PubSub": @@ -812,14 +813,16 @@ def execute_command(self, *args): # subscribed to one or more channels if self.connection is None: - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) - # register a callback that re-subscribes to any channels we - # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) - if self.push_handler_func is not None and not HIREDIS_AVAILABLE: - self.connection._parser.set_push_handler(self.push_handler_func) + with self._connection_lock: + if self.connection is None: + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index fb46772af3..d8d33e2b7a 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1153,3 +1153,22 @@ def get_msg(): # the timeout on the read should not cause disconnect assert is_connected() + + +@pytest.mark.onlynoncluster +class TestConnectionLeak: + def test_connection_leak(self, r: redis.Redis): + pubsub = r.pubsub() + + def test(): + tid = threading.get_ident() + pubsub.subscribe(f"foo{tid}") + + threads = [threading.Thread(target=test) for _ in range(10)] + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert r.connection_pool._created_connections == 2