diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index a15b4c6cc0..7f5ac53e89 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -222,18 +222,23 @@ "import json\n", "import cachetools.func\n", "\n", - "sm_client = boto3.client('secretsmanager')\n", - " \n", - "def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n", - " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n", - " def get_sm_user_credentials(secret_id, version_id, version_stage):\n", - " secret = sm_client.get_secret_value(secret_id, version_id)\n", - " return json.loads(secret['SecretString'])\n", - " creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n", - " return creds['username'], creds['password']\n", + "class SecretsManagerProvider(redis.CredentialProvider):\n", + " def __init__(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n", + " self.sm_client = boto3.client('secretsmanager')\n", + " self.secret_id = secret_id\n", + " self.version_id = version_id\n", + " self.version_stage = version_stage\n", "\n", - "secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n", - "creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n", + " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n", + " def get_sm_user_credentials(secret_id, version_id, version_stage):\n", + " secret = self.sm_client.get_secret_value(secret_id, version_id)\n", + " return json.loads(secret['SecretString'])\n", + " creds = get_sm_user_credentials(self.secret_id, self.version_id, self.version_stage)\n", + " return creds['username'], creds['password']\n", + "\n", + "my_secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n", + "creds_provider = SecretsManagerProvider(secret_id=my_secret_id)\n", "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ] @@ -266,19 +271,24 @@ "import boto3\n", "import cachetools.func\n", "\n", - "ec_client = boto3.client('elasticache')\n", + "class ElastiCacheIAMProvider(redis.CredentialProvider):\n", + " def __init__(self, user, endpoint, port=6379, region=\"us-east-1\"):\n", + " self.ec_client = boto3.client('elasticache')\n", + " self.user = user\n", + " self.endpoint = endpoint\n", + " self.port = port\n", + " self.region = region\n", "\n", - "def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n", - " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n", - " def get_iam_auth_token(user, endpoint, port, region):\n", - " return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n", - " iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n", - " return iam_auth_token\n", + " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n", + " def get_iam_auth_token(user, endpoint, port, region):\n", + " return self.ec_client.generate_iam_auth_token(user, endpoint, port, region)\n", + " iam_auth_token = get_iam_auth_token(self.endpoint, self.port, self.user, self.region)\n", + " return iam_auth_token\n", "\n", "username = \"barshaul\"\n", "endpoint = \"test-001.use1.cache.amazonaws.com\"\n", - "creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n", - " endpoint=endpoint)\n", + "creds_provider = ElastiCacheIAMProvider(user=username, endpoint=endpoint)\n", "user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ] diff --git a/docs/examples/ssl_connection_examples.ipynb b/docs/examples/ssl_connection_examples.ipynb index 386e4af452..ab3b4415ae 100644 --- a/docs/examples/ssl_connection_examples.ipynb +++ b/docs/examples/ssl_connection_examples.ipynb @@ -55,6 +55,27 @@ "url_connection.ping()" ] }, + { + "cell_type": "markdown", + "id": "04e70233", + "metadata": {}, + "source": [ + "## Connecting to a Redis instance using ConnectionPool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2903de26", + "metadata": {}, + "outputs": [], + "source": [ + "import redis\n", + "redis_pool = redis.ConnectionPool(host=\"localhost\", port=6666, connection_class=redis.SSLConnection)\n", + "ssl_connection = redis.StrictRedis(connection_pool=redis_pool) \n", + "ssl_connection.ping()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index e77fba30da..056998e9e0 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -267,9 +267,6 @@ async def _read_response( response: Any byte, response = raw[:1], raw[1:] - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - # server returned an error if byte == b"-": response = response.decode("utf-8", errors="replace") @@ -289,22 +286,24 @@ async def _read_response( pass # int value elif byte == b":": - response = int(response) + return int(response) # bulk response + elif byte == b"$" and response == b"-1": + return None elif byte == b"$": - length = int(response) - if length == -1: - return None - response = await self._read(length) + response = await self._read(int(response)) # multi-bulk response + elif byte == b"*" and response == b"-1": + return None elif byte == b"*": - length = int(response) - if length == -1: - return None response = [ - (await self._read_response(disable_decoding)) for _ in range(length) + (await self._read_response(disable_decoding)) + for _ in range(int(response)) # noqa ] - if isinstance(response, bytes) and disable_decoding is False: + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: response = self.encoder.decode(response) return response diff --git a/redis/commands/core.py b/redis/commands/core.py index 28dab81f8b..e2cabb85fa 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -3357,10 +3357,15 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: def smismember( self, name: str, values: List, *args: List - ) -> Union[Awaitable[List[bool]], List[bool]]: + ) -> Union[ + Awaitable[List[Union[Literal[0], Literal[1]]]], + List[Union[Literal[0], Literal[1]]], + ]: """ Return whether each value in ``values`` is a member of the set ``name`` - as a list of ``bool`` in the order of ``values`` + as a list of ``int`` in the order of ``values``: + - 1 if the value is a member of the set. + - 0 if the value is not a member of the set or if key does not exist. For more information see https://redis.io/commands/smismember """ diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 7fd4039203..c02c47ad86 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -31,8 +31,8 @@ def arrindex( name: str, path: str, scalar: int, - start: Optional[int] = 0, - stop: Optional[int] = -1, + start: Optional[int] = None, + stop: Optional[int] = None, ) -> List[Union[int, None]]: """ Return the index of ``scalar`` in the JSON array under ``path`` at key @@ -43,9 +43,13 @@ def arrindex( For more information see `JSON.ARRINDEX `_. """ # noqa - return self.execute_command( - "JSON.ARRINDEX", name, str(path), self._encode(scalar), start, stop - ) + pieces = [name, str(path), self._encode(scalar)] + if start is not None: + pieces.append(start) + if stop is not None: + pieces.append(stop) + + return self.execute_command("JSON.ARRINDEX", *pieces) def arrinsert( self, name: str, path: str, index: int, *args: List[JsonType] diff --git a/redis/connection.py b/redis/connection.py index d35980c167..c4a9685f6a 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -358,9 +358,6 @@ def _read_response(self, disable_decoding=False): byte, response = raw[:1], raw[1:] - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - # server returned an error if byte == b"-": response = response.decode("utf-8", errors="replace") @@ -379,23 +376,24 @@ def _read_response(self, disable_decoding=False): pass # int value elif byte == b":": - response = int(response) + return int(response) # bulk response + elif byte == b"$" and response == b"-1": + return None elif byte == b"$": - length = int(response) - if length == -1: - return None - response = self._buffer.read(length) + response = self._buffer.read(int(response)) # multi-bulk response + elif byte == b"*" and response == b"-1": + return None elif byte == b"*": - length = int(response) - if length == -1: - return None response = [ self._read_response(disable_decoding=disable_decoding) - for i in range(length) + for i in range(int(response)) ] - if isinstance(response, bytes) and disable_decoding is False: + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: response = self.encoder.decode(response) return response diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index b8854d20cd..fc530c63c1 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -145,9 +145,15 @@ async def test_arrappend(modclient: redis.Redis): @pytest.mark.redismod async def test_arrindex(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 1 == await modclient.json().arrindex("arr", Path.root_path(), 1) - assert -1 == await modclient.json().arrindex("arr", Path.root_path(), 1, 2) + r_path = Path.root_path() + await modclient.json().set("arr", r_path, [0, 1, 2, 3, 4]) + assert 1 == await modclient.json().arrindex("arr", r_path, 1) + assert -1 == await modclient.json().arrindex("arr", r_path, 1, 2) + assert 4 == await modclient.json().arrindex("arr", r_path, 4) + assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0) + assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=5000) + assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=-1) + assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=1, stop=3) @pytest.mark.redismod diff --git a/tests/test_json.py b/tests/test_json.py index a776e9e736..8e8da05609 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -166,6 +166,11 @@ def test_arrindex(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 1 == client.json().arrindex("arr", Path.root_path(), 1) assert -1 == client.json().arrindex("arr", Path.root_path(), 1, 2) + assert 4 == client.json().arrindex("arr", Path.root_path(), 4) + assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0) + assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=5000) + assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=-1) + assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=1, stop=3) @pytest.mark.redismod