Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
dvora-h committed May 2, 2024
1 parent 1b40a31 commit f234f8b
Show file tree
Hide file tree
Showing 14 changed files with 41 additions and 38 deletions.
6 changes: 3 additions & 3 deletions redis/_parsers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,9 @@ def string_keys_to_dict(key_string, callback):
),
"COMMAND": parse_command_resp3,
"CONFIG GET": lambda r: {
str_if_bytes(key)
if key is not None
else None: (str_if_bytes(value) if value is not None else None)
str_if_bytes(key) if key is not None else None: (
str_if_bytes(value) if value is not None else None
)
for key, value in r.items()
},
"MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()},
Expand Down
12 changes: 4 additions & 8 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,11 @@


class ResponseCallbackProtocol(Protocol):
def __call__(self, response: Any, **kwargs):
...
def __call__(self, response: Any, **kwargs): ...


class AsyncResponseCallbackProtocol(Protocol):
async def __call__(self, response: Any, **kwargs):
...
async def __call__(self, response: Any, **kwargs): ...


ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
Expand Down Expand Up @@ -1220,13 +1218,11 @@ async def run(


class PubsubWorkerExceptionHandler(Protocol):
def __call__(self, e: BaseException, pubsub: PubSub):
...
def __call__(self, e: BaseException, pubsub: PubSub): ...


class AsyncPubsubWorkerExceptionHandler(Protocol):
async def __call__(self, e: BaseException, pubsub: PubSub):
...
async def __call__(self, e: BaseException, pubsub: PubSub): ...


PSWorkerThreadExcHandlerT = Union[
Expand Down
14 changes: 7 additions & 7 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,10 @@ def __init__(
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.response_callbacks = kwargs["response_callbacks"]
self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
self.result_callbacks[
"CLUSTER SLOTS"
] = lambda cmd, res, **kwargs: parse_cluster_slots(
list(res.values())[0], **kwargs
self.result_callbacks["CLUSTER SLOTS"] = (
lambda cmd, res, **kwargs: parse_cluster_slots(
list(res.values())[0], **kwargs
)
)

self._initialize = True
Expand Down Expand Up @@ -1318,9 +1318,9 @@ async def initialize(self) -> None:
)
tmp_slots[i].append(target_replica_node)
# add this node to the nodes cache
tmp_nodes_cache[
target_replica_node.name
] = target_replica_node
tmp_nodes_cache[target_replica_node.name] = (
target_replica_node
)
else:
# Validate that 2 nodes want to use the same slot cache
# setup
Expand Down
8 changes: 3 additions & 5 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,11 @@ class _Sentinel(enum.Enum):


class ConnectCallbackProtocol(Protocol):
def __call__(self, connection: "AbstractConnection"):
...
def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
async def __call__(self, connection: "AbstractConnection"):
...
async def __call__(self, connection: "AbstractConnection"): ...


ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
Expand Down Expand Up @@ -698,7 +696,7 @@ def _cache_invalidation_process(
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is not None:
if data[1] is None:
self.client_cache.flush()
else:
for key in data[1]:
Expand Down
6 changes: 3 additions & 3 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,9 +1582,9 @@ def initialize(self):
)
tmp_slots[i].append(target_replica_node)
# add this node to the nodes cache
tmp_nodes_cache[
target_replica_node.name
] = target_replica_node
tmp_nodes_cache[target_replica_node.name] = (
target_replica_node
)
else:
# Validate that 2 nodes want to use the same slot cache
# setup
Expand Down
4 changes: 1 addition & 3 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3399,9 +3399,7 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]:
"""
return self.execute_command("SMEMBERS", name, keys=[name])

def smismember(
self, name: str, values: List, *args: List
) -> Union[
def smismember(self, name: str, values: List, *args: List) -> Union[
Awaitable[List[Union[Literal[0], Literal[1]]]],
List[Union[Literal[0], Literal[1]]],
]:
Expand Down
3 changes: 1 addition & 2 deletions redis/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,5 +217,4 @@ class SlotNotCoveredError(RedisClusterException):
pass


class MaxConnectionsError(ConnectionError):
...
class MaxConnectionsError(ConnectionError): ...
6 changes: 2 additions & 4 deletions redis/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,10 @@
class CommandsProtocol(Protocol):
connection_pool: Union["AsyncConnectionPool", "ConnectionPool"]

def execute_command(self, *args, **options):
...
def execute_command(self, *args, **options): ...


class ClusterCommandsProtocol(CommandsProtocol, Protocol):
encoder: "Encoder"

def execute_command(self, *args, **options) -> Union[Any, Awaitable]:
...
def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ...
3 changes: 2 additions & 1 deletion tests/test_asyncio/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ async def test_exec_error_in_response(self, r):
assert await pipe.set("z", "zzz").execute() == [True]
assert await r.get("z") == b"zzz"

@skip_if_redis_enterprise()
async def test_exec_error_raised(self, r):
await r.set("c", "a")
async with r.pipeline() as pipe:
Expand All @@ -139,7 +140,7 @@ async def test_transaction_with_empty_error_command(self, r):
"""
for error_switch in (True, False):
async with r.pipeline() as pipe:
pipe.set("a", 1).mget([]).set("c", 3)
pipe.set("a", 1).mget([]).set("a", 3)
result = await pipe.execute(raise_on_error=error_switch)

assert result[0]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ async def test_client(decoded_r: redis.Redis):

@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
async def test_scores(decoded_r: redis.Redis):
await decoded_r.ft().create_index((TextField("txt"),))

Expand Down Expand Up @@ -1013,6 +1014,7 @@ async def test_phonetic_matcher(decoded_r: redis.Redis):

@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
async def test_scorer(decoded_r: redis.Redis):
await decoded_r.ft().create_index((TextField("description"),))

Expand Down
2 changes: 2 additions & 0 deletions tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,7 @@ def test_rpushx(self, r):

# SCAN COMMANDS
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
@skip_if_server_version_lt("2.8.0")
def test_scan(self, r):
r.set("a", 1)
Expand All @@ -2144,6 +2145,7 @@ def test_scan(self, r):
assert set(keys) == {b"a"}

@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
@skip_if_server_version_lt("6.0.0")
def test_scan_type(self, r):
r.sadd("a-set", 1)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pytest
from redis.exceptions import ResponseError

from .conftest import assert_resp_response, skip_if_server_version_lt
from .conftest import (
assert_resp_response,
skip_if_redis_enterprise,
skip_if_server_version_lt,
)

engine = "lua"
lib = "mylib"
Expand Down Expand Up @@ -51,6 +55,7 @@ def test_function_flush(self, r):
r.function_flush("ABC")

@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
def test_function_list(self, r):
r.function_load(f"#!{engine} name={lib} \n {function}")
res = [
Expand Down
3 changes: 2 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def test_exec_error_in_response(self, r):
assert pipe.set("z", "zzz").execute() == [True]
assert r["z"] == b"zzz"

@skip_if_redis_enterprise()
def test_exec_error_raised(self, r):
r["c"] = "a"
with r.pipeline() as pipe:
Expand All @@ -144,7 +145,7 @@ def test_transaction_with_empty_error_command(self, r):
"""
for error_switch in (True, False):
with r.pipeline() as pipe:
pipe.set("a", 1).mget([]).set("c", 3)
pipe.set("a", 1).mget([]).set("a", 3)
result = pipe.execute(raise_on_error=error_switch)

assert result[0]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def test_client(client):

@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
def test_scores(client):
client.ft().create_index((TextField("txt"),))

Expand Down Expand Up @@ -931,6 +932,7 @@ def test_phonetic_matcher(client):

@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
def test_scorer(client):
client.ft().create_index((TextField("description"),))

Expand Down Expand Up @@ -1942,6 +1944,7 @@ def test_profile(client):

@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
def test_profile_limited(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
Expand Down

0 comments on commit f234f8b

Please sign in to comment.