Skip to content

Commit

Permalink
Merge branch 'master' into usr/aksinha334/redis-py#issue2598
Browse files Browse the repository at this point in the history
  • Loading branch information
aksinha334 committed Mar 15, 2023
2 parents 39bfa18 + a372ba4 commit 948a142
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 56 deletions.
50 changes: 30 additions & 20 deletions docs/examples/connection_examples.ipynb
Expand Up @@ -222,18 +222,23 @@
"import json\n",
"import cachetools.func\n",
"\n",
"sm_client = boto3.client('secretsmanager')\n",
" \n",
"def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n",
" @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n",
" def get_sm_user_credentials(secret_id, version_id, version_stage):\n",
" secret = sm_client.get_secret_value(secret_id, version_id)\n",
" return json.loads(secret['SecretString'])\n",
" creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n",
" return creds['username'], creds['password']\n",
"class SecretsManagerProvider(redis.CredentialProvider):\n",
" def __init__(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n",
" self.sm_client = boto3.client('secretsmanager')\n",
" self.secret_id = secret_id\n",
" self.version_id = version_id\n",
" self.version_stage = version_stage\n",
"\n",
"secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n",
"creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n",
" def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
" @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n",
" def get_sm_user_credentials(secret_id, version_id, version_stage):\n",
" secret = self.sm_client.get_secret_value(secret_id, version_id)\n",
" return json.loads(secret['SecretString'])\n",
" creds = get_sm_user_credentials(self.secret_id, self.version_id, self.version_stage)\n",
" return creds['username'], creds['password']\n",
"\n",
"my_secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n",
"creds_provider = SecretsManagerProvider(secret_id=my_secret_id)\n",
"user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n",
"user_connection.ping()"
]
Expand Down Expand Up @@ -266,19 +271,24 @@
"import boto3\n",
"import cachetools.func\n",
"\n",
"ec_client = boto3.client('elasticache')\n",
"class ElastiCacheIAMProvider(redis.CredentialProvider):\n",
" def __init__(self, user, endpoint, port=6379, region=\"us-east-1\"):\n",
" self.ec_client = boto3.client('elasticache')\n",
" self.user = user\n",
" self.endpoint = endpoint\n",
" self.port = port\n",
" self.region = region\n",
"\n",
"def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n",
" @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n",
" def get_iam_auth_token(user, endpoint, port, region):\n",
" return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n",
" iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n",
" return iam_auth_token\n",
" def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
" @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n",
" def get_iam_auth_token(user, endpoint, port, region):\n",
" return self.ec_client.generate_iam_auth_token(user, endpoint, port, region)\n",
" iam_auth_token = get_iam_auth_token(self.endpoint, self.port, self.user, self.region)\n",
" return iam_auth_token\n",
"\n",
"username = \"barshaul\"\n",
"endpoint = \"test-001.use1.cache.amazonaws.com\"\n",
"creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n",
" endpoint=endpoint)\n",
"creds_provider = ElastiCacheIAMProvider(user=username, endpoint=endpoint)\n",
"user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n",
"user_connection.ping()"
]
Expand Down
21 changes: 21 additions & 0 deletions docs/examples/ssl_connection_examples.ipynb
Expand Up @@ -55,6 +55,27 @@
"url_connection.ping()"
]
},
{
"cell_type": "markdown",
"id": "04e70233",
"metadata": {},
"source": [
"## Connecting to a Redis instance using ConnectionPool"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2903de26",
"metadata": {},
"outputs": [],
"source": [
"import redis\n",
"redis_pool = redis.ConnectionPool(host=\"localhost\", port=6666, connection_class=redis.SSLConnection)\n",
"ssl_connection = redis.StrictRedis(connection_pool=redis_pool) \n",
"ssl_connection.ping()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
25 changes: 12 additions & 13 deletions redis/asyncio/connection.py
Expand Up @@ -267,9 +267,6 @@ async def _read_response(
response: Any
byte, response = raw[:1], raw[1:]

if byte not in (b"-", b"+", b":", b"$", b"*"):
raise InvalidResponse(f"Protocol Error: {raw!r}")

# server returned an error
if byte == b"-":
response = response.decode("utf-8", errors="replace")
Expand All @@ -289,22 +286,24 @@ async def _read_response(
pass
# int value
elif byte == b":":
response = int(response)
return int(response)
# bulk response
elif byte == b"$" and response == b"-1":
return None
elif byte == b"$":
length = int(response)
if length == -1:
return None
response = await self._read(length)
response = await self._read(int(response))
# multi-bulk response
elif byte == b"*" and response == b"-1":
return None
elif byte == b"*":
length = int(response)
if length == -1:
return None
response = [
(await self._read_response(disable_decoding)) for _ in range(length)
(await self._read_response(disable_decoding))
for _ in range(int(response)) # noqa
]
if isinstance(response, bytes) and disable_decoding is False:
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if disable_decoding is False:
response = self.encoder.decode(response)
return response

Expand Down
9 changes: 7 additions & 2 deletions redis/commands/core.py
Expand Up @@ -3357,10 +3357,15 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]:

def smismember(
self, name: str, values: List, *args: List
) -> Union[Awaitable[List[bool]], List[bool]]:
) -> Union[
Awaitable[List[Union[Literal[0], Literal[1]]]],
List[Union[Literal[0], Literal[1]]],
]:
"""
Return whether each value in ``values`` is a member of the set ``name``
as a list of ``bool`` in the order of ``values``
as a list of ``int`` in the order of ``values``:
- 1 if the value is a member of the set.
- 0 if the value is not a member of the set or if key does not exist.
For more information see https://redis.io/commands/smismember
"""
Expand Down
14 changes: 9 additions & 5 deletions redis/commands/json/commands.py
Expand Up @@ -31,8 +31,8 @@ def arrindex(
name: str,
path: str,
scalar: int,
start: Optional[int] = 0,
stop: Optional[int] = -1,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> List[Union[int, None]]:
"""
Return the index of ``scalar`` in the JSON array under ``path`` at key
Expand All @@ -43,9 +43,13 @@ def arrindex(
For more information see `JSON.ARRINDEX <https://redis.io/commands/json.arrindex>`_.
""" # noqa
return self.execute_command(
"JSON.ARRINDEX", name, str(path), self._encode(scalar), start, stop
)
pieces = [name, str(path), self._encode(scalar)]
if start is not None:
pieces.append(start)
if stop is not None:
pieces.append(stop)

return self.execute_command("JSON.ARRINDEX", *pieces)

def arrinsert(
self, name: str, path: str, index: int, *args: List[JsonType]
Expand Down
24 changes: 11 additions & 13 deletions redis/connection.py
Expand Up @@ -358,9 +358,6 @@ def _read_response(self, disable_decoding=False):

byte, response = raw[:1], raw[1:]

if byte not in (b"-", b"+", b":", b"$", b"*"):
raise InvalidResponse(f"Protocol Error: {raw!r}")

# server returned an error
if byte == b"-":
response = response.decode("utf-8", errors="replace")
Expand All @@ -379,23 +376,24 @@ def _read_response(self, disable_decoding=False):
pass
# int value
elif byte == b":":
response = int(response)
return int(response)
# bulk response
elif byte == b"$" and response == b"-1":
return None
elif byte == b"$":
length = int(response)
if length == -1:
return None
response = self._buffer.read(length)
response = self._buffer.read(int(response))
# multi-bulk response
elif byte == b"*" and response == b"-1":
return None
elif byte == b"*":
length = int(response)
if length == -1:
return None
response = [
self._read_response(disable_decoding=disable_decoding)
for i in range(length)
for i in range(int(response))
]
if isinstance(response, bytes) and disable_decoding is False:
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if disable_decoding is False:
response = self.encoder.decode(response)
return response

Expand Down
12 changes: 9 additions & 3 deletions tests/test_asyncio/test_json.py
Expand Up @@ -145,9 +145,15 @@ async def test_arrappend(modclient: redis.Redis):

@pytest.mark.redismod
async def test_arrindex(modclient: redis.Redis):
await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
assert 1 == await modclient.json().arrindex("arr", Path.root_path(), 1)
assert -1 == await modclient.json().arrindex("arr", Path.root_path(), 1, 2)
r_path = Path.root_path()
await modclient.json().set("arr", r_path, [0, 1, 2, 3, 4])
assert 1 == await modclient.json().arrindex("arr", r_path, 1)
assert -1 == await modclient.json().arrindex("arr", r_path, 1, 2)
assert 4 == await modclient.json().arrindex("arr", r_path, 4)
assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0)
assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=5000)
assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=-1)
assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=1, stop=3)


@pytest.mark.redismod
Expand Down
5 changes: 5 additions & 0 deletions tests/test_json.py
Expand Up @@ -166,6 +166,11 @@ def test_arrindex(client):
client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4])
assert 1 == client.json().arrindex("arr", Path.root_path(), 1)
assert -1 == client.json().arrindex("arr", Path.root_path(), 1, 2)
assert 4 == client.json().arrindex("arr", Path.root_path(), 4)
assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0)
assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=5000)
assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=-1)
assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=1, stop=3)


@pytest.mark.redismod
Expand Down

0 comments on commit 948a142

Please sign in to comment.