From f2f8c342091c2a093b911d38327436c4adc80f66 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Mon, 3 Jul 2023 13:55:12 +0300 Subject: [PATCH] Merge master to 5.0 (#2827) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: do not use asyncio's timeout lib before 3.11.2 (#2659) There's an issue in asyncio's timeout lib before 3.11.3 that causes async calls to raise `CancelledError`. This is a cpython issue that was fixed in this commit [1] and cherry-picked to previous versions, meaning 3.11.3 will work correctly. Check [2] for more info. [1] https://github.com/python/cpython/commit/04adf2df395ded81922c71360a5d66b597471e49 [2] https://github.com/redis/redis-py/issues/2633 * UnixDomainSocketConnection missing constructor argument (#2630) * removing useless files (#2642) * Fix issue 2660: PytestUnraisableExceptionWarning from asycio client (#2669) * Fixing cancelled async futures (#2666) Co-authored-by: James R T Co-authored-by: dvora-h * Fix async (#2673) * Version 4.5.4 (#2674) * Really do not use asyncio's timeout lib before 3.11.2 (#2699) 480253037afe4c12e38a0f98cadd3019a3724254 made async-timeout required only on Python 3.11.2 and earlier. However, according to PEP-508, python_version marker is compared to first two numbers of Python version tuple - so it will evaluate to True also on 3.11.3, and install a package as a dependency. * asyncio: Fix memory leak caused by hiredis (#2693) (#2694) * Update example of Redisearch creating index (#2703) When creating index, fields should be passed inside an iterable (e.g. list or tuple) * Improving Vector Similarity Search Example (#2661) * update vss docs * add embeddings creation and storage examples * update based on feedback * fix version and link * include more realistic search examples and clean up indices * completely remove initial cap reference --------- Co-authored-by: Chayim * Fix incorrect usage of once flag in async Sentinel (#2718) In the execute_command of the async Sentinel, the once flag was being used incorrectly, with its meaning inverted. To fix we just needed to invert the if and else bodies. This isn't being caught by the tests currently because the tests of commands that use this flag do not check their results/effects (for example the "test_ckquorum" test). * Fix topk list example. (#2724) * Improve error output for master discovery (#2720) Make MasterNotFoundError exception more precise in the case of ConnectionError and TimeoutError to help the user to identify configuration errors Co-authored-by: Marc Schöchlin * return response in case of KeyError (#2628) * return response in case of KeyError * fix code linters error * fix linters 2 * fix linters 3 * Add WITHSCORES to ZREVRANK Command (#2725) * add withscores to zrevrank * change 0 -> 2 * fix errors * split test * Fix `ClusterCommandProtocol` not itself being marked as a protocol (#2729) * Fix `ClusterCommandProtocol` not itself being marked as a protocol * Update CHANGES * Fix potential race condition during disconnection (#2719) When the disconnect() function is called twice in parallel it is possible that one thread deletes the self._sock reference, while the other thread will attempt to call .close() on it, leading to an AttributeError. This situation can routinely be encountered by closing the connection in a PubSubWorkerThread error handler in a blocking thread (ie. with sleep_time==None), and then calling .close() on the PubSub object. 
The main thread will then run into the disconnect() function, and the listener thread is woken up by the closure and will race into the disconnect() function, too. This can be fixed easily by copying the object reference before doing the None-check, similar to what we do in the redis.client.close() function. * add "address_remap" feature to RedisCluster (#2726) * add cluster "host_port_remap" feature for asyncio.RedisCluster * Add a unittest for asyncio.RedisCluster * Add host_port_remap to _sync_ RedisCluster * add synchronous tests * rename arg to `address_remap` and take and return an address tuple. * Add class documentation * Add CHANGES * nermina changes from NRedisStack (#2736) * Updated AWS Elasticache IAM Connection Example (#2702) Co-authored-by: Nick Gerow * pinning urllib3 to fix CI (#2748) * Add RedisCluster.remap_host_port, Update tests for CWE 404 (#2706) * Use provided redis address. Bind to IPv4 * Add missing "await" and perform the correct test for pipe empty * Wait for a send event, rather than rely on sleep time. Expect cancel errors. * set delay to 0 except for operation we want to cancel This speeds up the unit tests considerably by eliminating unnecessary delay. * Release resources in test * Fix cluster test to use address_remap and multiple proxies. * Use context manager to manage DelayProxy * Mark failing pipeline tests * lint * Use a common "master_host" test fixture * Update redismodules.rst (#2747) Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * Add support for cluster myshardid (#2704) * feat: adding support for cluster myshardid * lint fix * fix: comment fix and async test * fix: adding version check * fix lint: * linters --------- Co-authored-by: Anuragkillswitch <70265851+Anuragkillswitch@users.noreply.github.com> Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> Co-authored-by: dvora-h * clean warnings (#2731) * fix parse_slowlog_get (#2732) * Optionally disable disconnects in read_response (#2695) * Add regression tests and fixes for issue #1128 * Fix tests for resumable read_response to use "disconnect_on_error" * undo previous fix attempts in async client and cluster * re-enable cluster test * Suggestions from code review * Add CHANGES * Add client no-touch (#2745) * Add client no-touch * Update redis/commands/core.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * Update test_commands.py Improve test_client_no_touch * Update test_commands.py Add async version test case * Chore remove whitespace Oops --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * fix create single_connection_client from url (#2752) * Fix `xadd` allow non negative maxlen (#2739) * Fix xadd allow non negative maxlen * Update change log --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * Version 4.5.5 (#2753) * Kristjan/issue #2754: Add missing argument to SentinelManagedConnection.read_response() (#2756) * Increase timeout for a test which would hang completely if failing. Timeouts in virtualized CI backends can occasionally fail if too short.
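The `address_remap` callable mentioned above is invoked with the `(host, port)` tuple a cluster node advertises about itself and must return the address the client should actually connect to, which is useful when nodes sit behind a proxy or NAT. A minimal illustrative sketch follows; the host name and port mapping are invented for the example and are not part of this patch:

```python
# Hypothetical proxy setup: the cluster reports internal ports 7000-7002,
# but clients can only reach the nodes through proxy ports 17000-17002.
from redis.cluster import RedisCluster

PORT_MAP = {7000: 17000, 7001: 17001, 7002: 17002}  # internal -> external (made up)

def remap(address):
    # Called with the (host, port) the node thinks it has; return the
    # address that is reachable from the client's side of the network.
    host, port = address
    return "cluster-proxy.example.com", PORT_MAP.get(port, port)

rc = RedisCluster(host="cluster-proxy.example.com", port=17000, address_remap=remap)
rc.ping()
```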
* add "disconnect_on_error" argument to SentinelManagedConnection * update Changes * lint * support JSON.MERGE Command (#2761) * support JSON.MERGE Command * linters * try with abc instead person * change @skip_ifmodversion_lt to latest ReJSON 2.4.7 * change version * fix test * linters * add async test * Issue #2749: Remove unnecessary __del__ handlers (#2755) * Remove unnecessary __del__ handlers There normally should be no logic attached to del. Cleanly disconnecting network resources is not needed at that time. * add CHANGES * Add WITHSCORE to ZRANK (#2758) * add withscore to zrank with tests * fix test * Fix JSON.MERGE Summary (#2786) * Fix JSON.MERGE Summary * linters * Fixed key error in parse_xinfo_stream (#2788) * insert newline to prevent sphinx from assuming code block (#2796) * Introduce OutOfMemoryError exception for Redis write command rejections due to OOM errors (#2778) * expose OutOfMemoryError as explicit exception type - handle "OOM" error code string by raising explicit exception type instance - enables callers to avoid string matching after catching ResponseError * add OutOfMemoryError exception class docstring * Provide more info in the exception docstring * Fix formatting * Again * linters --------- Co-authored-by: Chayim Co-authored-by: Igor Malinovskiy Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * Add unit tests for the `connect` method of all Redis connection classes (#2631) * tests: move certificate discovery to a separate module * tests: add 'connect' tests for all Redis connection classes --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * Fix dead weakref in sentinel connection causing ReferenceError (#2767) (#2771) * Fix dead weakref in sentinel conn (#2767) * Update CHANGES --------- Co-authored-by: Igor Malinovskiy Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * chore(documentation): fix redirects and some small cleanups (#2801) * Add waitaof (#2760) * Add waitaof * Update test_commands.py add test_waitaof * Update test_commands.py Add test_waitaof * Fix doc string --------- Co-authored-by: Chayim Co-authored-by: Igor Malinovskiy * Extract abstract async connection class (#2734) * make 'socket_timeout' and 'socket_connect_timeout' equivalent for TCP and UDS connections * abstract asynio connection in analogy with the synchronous connection --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * Fix type hint for retry_on_error in async cluster (#2804) * fix(asyncio.cluster): fixup retry_on_error type hint This parameter accepts a list of _classes of Exceptions_, not a list of instantiated Exceptions. Fixup the type hint accordingly. 
* chore: update changelog --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> * Fix CI (#2809) * Support JSON.MSET Command (#2766) * support JSON.MERGE Command * linters * try with abc instead person * change @skip_ifmodversion_lt to latest ReJSON 2.4.7 * change version * fix test * linters * add async test * Support JSON.MSET command * trying to run CI * linters * add async test * reminder do delete the integration changes * delete the line from integration * fix the interface * change docstring --------- Co-authored-by: Chayim Co-authored-by: dvora-h * Version 4.6.0 (#2810) * master changes * linters * fix test_cwe_404 cluster test --------- Co-authored-by: Thiago Bellini Ribeiro Co-authored-by: woutdenolf Co-authored-by: Chayim Co-authored-by: shacharPash <93581407+shacharPash@users.noreply.github.com> Co-authored-by: James R T Co-authored-by: Mirek Długosz Co-authored-by: Oran Avraham <252748+oranav@users.noreply.github.com> Co-authored-by: mzdehbashi-github <85902780+mzdehbashi-github@users.noreply.github.com> Co-authored-by: Tyler Hutcherson Co-authored-by: Felipe Machado <462154+felipou@users.noreply.github.com> Co-authored-by: AYMEN Mohammed <53928879+AYMENJD@users.noreply.github.com> Co-authored-by: Marc Schöchlin Co-authored-by: Marc Schöchlin Co-authored-by: Avasam Co-authored-by: Markus Gerstel <2102431+Anthchirp@users.noreply.github.com> Co-authored-by: Kristján Valur Jónsson Co-authored-by: Nick Gerow Co-authored-by: Nick Gerow Co-authored-by: Cristian Matache Co-authored-by: Anurag Bandyopadhyay Co-authored-by: Anuragkillswitch <70265851+Anuragkillswitch@users.noreply.github.com> Co-authored-by: Seongchuel Ahn Co-authored-by: Alibi Co-authored-by: Smit Parmar Co-authored-by: Brad MacPhee Co-authored-by: Igor Malinovskiy Co-authored-by: Shahar Lev Co-authored-by: Vladimir Mihailenco Co-authored-by: Kevin James --- CHANGES | 12 + CONTRIBUTING.md | 20 +- docs/examples/connection_examples.ipynb | 56 +- docs/examples/opentelemetry/main.py | 3 +- .../search_vector_similarity_examples.ipynb | 610 +++++++++++++++++- docs/opentelemetry.rst | 73 +-- docs/redismodules.rst | 10 +- redis/__init__.py | 2 + redis/asyncio/__init__.py | 2 + redis/asyncio/client.py | 31 +- redis/asyncio/cluster.py | 45 +- redis/asyncio/connection.py | 338 +++++----- redis/asyncio/sentinel.py | 17 +- redis/client.py | 25 +- redis/cluster.py | 31 + redis/commands/cluster.py | 9 +- redis/commands/core.py | 56 +- redis/commands/json/__init__.py | 2 + redis/commands/json/commands.py | 42 +- redis/connection.py | 51 +- redis/exceptions.py | 13 + redis/parsers/base.py | 9 +- redis/parsers/resp2.py | 5 +- redis/parsers/resp3.py | 5 +- redis/sentinel.py | 108 +++- redis/typing.py | 2 +- setup.py | 2 +- tests/asynctests | 285 -------- tests/conftest.py | 2 +- tests/ssl_utils.py | 14 + tests/synctests | 421 ------------ tests/test_asyncio/conftest.py | 8 - tests/test_asyncio/test_cluster.py | 137 +++- tests/test_asyncio/test_commands.py | 78 +++ tests/test_asyncio/test_connect.py | 144 +++++ tests/test_asyncio/test_connection.py | 29 +- tests/test_asyncio/test_connection_pool.py | 20 +- tests/test_asyncio/test_cwe_404.py | 250 +++++++ tests/test_asyncio/test_json.py | 46 ++ tests/test_asyncio/test_pubsub.py | 7 +- tests/test_asyncio/test_sentinel.py | 2 +- tests/test_cluster.py | 123 ++++ tests/test_commands.py | 83 +++ tests/test_connect.py | 184 ++++++ tests/test_connection.py | 10 +- tests/test_connection_pool.py | 9 +- tests/test_json.py | 42 ++ tests/test_sentinel.py | 9 + 
tests/test_ssl.py | 15 +- 49 files changed, 2341 insertions(+), 1156 deletions(-) delete mode 100644 tests/asynctests create mode 100644 tests/ssl_utils.py delete mode 100644 tests/synctests create mode 100644 tests/test_asyncio/test_connect.py create mode 100644 tests/test_asyncio/test_cwe_404.py create mode 100644 tests/test_connect.py diff --git a/CHANGES b/CHANGES index b0744c6038..49f87cd35d 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,13 @@ + * Fix incorrect redis.asyncio.Cluster type hint for `retry_on_error` + * Fix dead weakref in sentinel connection causing ReferenceError (#2767) + * Fix #2768, Fix KeyError: 'first-entry' in parse_xinfo_stream. + * Fix #2749, remove unnecessary __del__ logic to close connections. + * Fix #2754, adding a missing argument to SentinelManagedConnection + * Fix `xadd` command to accept non-negative `maxlen` including 0 + * Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624) + * Add `address_remap` parameter to `RedisCluster` + * Fix incorrect usage of once flag in async Sentinel + * asyncio: Fix memory leak caused by hiredis (#2693) * Allow data to drain from async PythonParser when reading during a disconnect() * Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602) * Add test and fix async HiredisParser when reading during a disconnect() (#2349) @@ -40,6 +50,8 @@ * Fix Sentinel.execute_command doesn't execute across the entire sentinel cluster bug (#2458) * Added a replacement for the default cluster node in the event of failure (#2463) * Fix for Unhandled exception related to self.host with unix socket (#2496) + * Improve error output for master discovery + * Make `ClusterCommandsProtocol` an actual Protocol * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2909f04f0b..90a538be46 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,15 +2,15 @@ ## Introduction -First off, thank you for considering contributing to redis-py. We value -community contributions! +We appreciate your interest in considering contributing to redis-py. +Community contributions mean a lot to us. -## Contributions We Need +## Contributions we need -You may already know what you want to contribute \-- a fix for a bug you +You may already know how you'd like to contribute, whether it's a fix for a bug you encountered, or a new feature your team wants to use. -If you don't know what to contribute, keep an open mind! Improving +If you don't know where to start, consider improving documentation, bug triaging, and writing tutorials are all examples of helpful contributions that mean less work for you. @@ -166,19 +166,19 @@ When filing an issue, make sure to answer these five questions: 4. What did you expect to see? 5. What did you see instead? -## How to Suggest a Feature or Enhancement +## Suggest a feature or enhancement If you'd like to contribute a new feature, make sure you check our issue list to see if someone has already proposed it. Work may already -be under way on the feature you want -- or we may have rejected a +be underway on the feature you want or we may have rejected a feature like it already. If you don't see anything, open a new issue that describes the feature you would like and how it should work. -## Code Review Process +## Code review process -The core team looks at Pull Requests on a regular basis. We will give -feedback as as soon as possible. 
After feedback, we expect a response +The core team regularly looks at pull requests. We will provide +feedback as as soon as possible. After receiving our feedback, please respond within two weeks. After that time, we may close your PR if it isn't showing any activity. diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index 7f5ac53e89..d15d964af7 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -267,28 +267,60 @@ } ], "source": [ + "from typing import Tuple, Union\n", + "from urllib.parse import ParseResult, urlencode, urlunparse\n", + "\n", + "import botocore.session\n", "import redis\n", - "import boto3\n", - "import cachetools.func\n", + "from botocore.model import ServiceId\n", + "from botocore.signers import RequestSigner\n", + "from cachetools import TTLCache, cached\n", "\n", "class ElastiCacheIAMProvider(redis.CredentialProvider):\n", - " def __init__(self, user, endpoint, port=6379, region=\"us-east-1\"):\n", - " self.ec_client = boto3.client('elasticache')\n", + " def __init__(self, user, cluster_name, region=\"us-east-1\"):\n", " self.user = user\n", - " self.endpoint = endpoint\n", - " self.port = port\n", + " self.cluster_name = cluster_name\n", " self.region = region\n", "\n", + " session = botocore.session.get_session()\n", + " self.request_signer = RequestSigner(\n", + " ServiceId(\"elasticache\"),\n", + " self.region,\n", + " \"elasticache\",\n", + " \"v4\",\n", + " session.get_credentials(),\n", + " session.get_component(\"event_emitter\"),\n", + " )\n", + "\n", + " # Generated IAM tokens are valid for 15 minutes\n", + " @cached(cache=TTLCache(maxsize=128, ttl=900))\n", " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n", - " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n", - " def get_iam_auth_token(user, endpoint, port, region):\n", - " return self.ec_client.generate_iam_auth_token(user, endpoint, port, region)\n", - " iam_auth_token = get_iam_auth_token(self.endpoint, self.port, self.user, self.region)\n", - " return iam_auth_token\n", + " query_params = {\"Action\": \"connect\", \"User\": self.user}\n", + " url = urlunparse(\n", + " ParseResult(\n", + " scheme=\"https\",\n", + " netloc=self.cluster_name,\n", + " path=\"/\",\n", + " query=urlencode(query_params),\n", + " params=\"\",\n", + " fragment=\"\",\n", + " )\n", + " )\n", + " signed_url = self.request_signer.generate_presigned_url(\n", + " {\"method\": \"GET\", \"url\": url, \"body\": {}, \"headers\": {}, \"context\": {}},\n", + " operation_name=\"connect\",\n", + " expires_in=900,\n", + " region_name=self.region,\n", + " )\n", + " # RequestSigner only seems to work if the URL has a protocol, but\n", + " # Elasticache only accepts the URL without a protocol\n", + " # So strip it off the signed URL before returning\n", + " return (self.user, signed_url.removeprefix(\"https://\"))\n", "\n", "username = \"barshaul\"\n", + "cluster_name = \"test-001\"\n", "endpoint = \"test-001.use1.cache.amazonaws.com\"\n", - "creds_provider = ElastiCacheIAMProvider(user=username, endpoint=endpoint)\n", + "creds_provider = ElastiCacheIAMProvider(user=username, cluster_name=cluster_name)\n", "user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ] diff --git a/docs/examples/opentelemetry/main.py b/docs/examples/opentelemetry/main.py index b140dd0148..9ef6723305 100755 --- a/docs/examples/opentelemetry/main.py +++ 
b/docs/examples/opentelemetry/main.py @@ -2,12 +2,11 @@ import time +import redis import uptrace from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor -import redis - tracer = trace.get_tracer("app_or_package_name", "1.0.0") diff --git a/docs/examples/search_vector_similarity_examples.ipynb b/docs/examples/search_vector_similarity_examples.ipynb index 2b0261097c..bd1df3c1ef 100644 --- a/docs/examples/search_vector_similarity_examples.ipynb +++ b/docs/examples/search_vector_similarity_examples.ipynb @@ -1,81 +1,643 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Vector Similarity\n", - "## Adding Vector Fields" + "**Vectors** (also called \"Embeddings\"), represent an AI model's impression (or understanding) of a piece of unstructured data like text, images, audio, videos, etc. Vector Similarity Search (VSS) is the process of finding vectors in the vector database that are similar to a given query vector. Popular VSS uses include recommendation systems, image and video search, document retrieval, and question answering.\n", + "\n", + "## Index Creation\n", + "Before doing vector search, first define the schema and create an index." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, + "outputs": [], + "source": [ + "import redis\n", + "from redis.commands.search.field import TagField, VectorField\n", + "from redis.commands.search.indexDefinition import IndexDefinition, IndexType\n", + "from redis.commands.search.query import Query\n", + "\n", + "r = redis.Redis(host=\"localhost\", port=6379)\n", + "\n", + "INDEX_NAME = \"index\" # Vector Index Name\n", + "DOC_PREFIX = \"doc:\" # RediSearch Key Prefix for the Index\n", + "\n", + "def create_index(vector_dimensions: int):\n", + " try:\n", + " # check to see if index exists\n", + " r.ft(INDEX_NAME).info()\n", + " print(\"Index already exists!\")\n", + " except:\n", + " # schema\n", + " schema = (\n", + " TagField(\"tag\"), # Tag Field Name\n", + " VectorField(\"vector\", # Vector Field Name\n", + " \"FLAT\", { # Vector Index Type: FLAT or HNSW\n", + " \"TYPE\": \"FLOAT32\", # FLOAT32 or FLOAT64\n", + " \"DIM\": vector_dimensions, # Number of Vector Dimensions\n", + " \"DISTANCE_METRIC\": \"COSINE\", # Vector Search Distance Metric\n", + " }\n", + " ),\n", + " )\n", + "\n", + " # index Definition\n", + " definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.HASH)\n", + "\n", + " # create Index\n", + " r.ft(INDEX_NAME).create_index(fields=schema, definition=definition)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll start by working with vectors that have 1536 dimensions." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# define vector dimensions\n", + "VECTOR_DIMENSIONS = 1536\n", + "\n", + "# create the index\n", + "create_index(vector_dimensions=VECTOR_DIMENSIONS)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Vectors to Redis\n", + "\n", + "Next, we add vectors (dummy data) to Redis using `hset`. The search index listens to keyspace notifications and will include any written HASH objects prefixed by `DOC_PREFIX`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install numpy" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# instantiate a redis pipeline\n", + "pipe = r.pipeline()\n", + "\n", + "# define some dummy data\n", + "objects = [\n", + " {\"name\": \"a\", \"tag\": \"foo\"},\n", + " {\"name\": \"b\", \"tag\": \"foo\"},\n", + " {\"name\": \"c\", \"tag\": \"bar\"},\n", + "]\n", + "\n", + "# write data\n", + "for obj in objects:\n", + " # define key\n", + " key = f\"doc:{obj['name']}\"\n", + " # create a random \"dummy\" vector\n", + " obj[\"vector\"] = np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + " # HSET\n", + " pipe.hset(key, mapping=obj)\n", + "\n", + "res = pipe.execute()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Searching\n", + "You can use VSS queries with the `.ft(...).search(...)` query command. To use a VSS query, you must specify the option `.dialect(2)`.\n", + "\n", + "There are two supported types of vector queries in Redis: `KNN` and `Range`. `Hybrid` queries can work in both settings and combine elements of traditional search and VSS." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### KNN Queries\n", + "KNN queries are for finding the topK most similar vectors given a query vector." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { "text/plain": [ - "b'OK'" + "[Document {'id': 'doc:b', 'payload': None, 'score': '0.2376562953'},\n", + " Document {'id': 'doc:c', 'payload': None, 'score': '0.240063905716'}]" ] }, - "execution_count": 1, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import redis\n", - "from redis.commands.search.field import VectorField\n", - "from redis.commands.search.query import Query\n", + "query = (\n", + " Query(\"*=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"id\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", "\n", - "r = redis.Redis(host='localhost', port=36379)\n", + "query_params = {\n", + " \"vec\": np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + "}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Range Queries\n", + "Range queries provide a way to filter results by the distance between a vector field in Redis and a query vector based on some pre-defined threshold (radius)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document {'id': 'doc:a', 'payload': None, 'score': '0.243115246296'},\n", + " Document {'id': 'doc:c', 'payload': None, 'score': '0.24981123209'},\n", + " Document {'id': 'doc:b', 'payload': None, 'score': '0.251443207264'}]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = (\n", + " Query(\"@vector:[VECTOR_RANGE $radius $vec]=>{$YIELD_DISTANCE_AS: score}\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"id\", \"score\")\n", + " .paging(0, 3)\n", + " .dialect(2)\n", + ")\n", "\n", - "schema = (VectorField(\"v\", \"HNSW\", {\"TYPE\": \"FLOAT32\", \"DIM\": 2, \"DISTANCE_METRIC\": \"L2\"}),)\n", - "r.ft().create_index(schema)" + "# Find all vectors within 0.8 of the query vector\n", + "query_params = {\n", + " \"radius\": 0.8,\n", + " \"vec\": np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + "}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Searching" + "See additional Range Query examples in [this Jupyter notebook](https://github.com/RediSearch/RediSearch/blob/master/docs/docs/vecsim-range_queries_examples.ipynb)." ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "### Querying vector fields" + "### Hybrid Queries\n", + "Hybrid queries contain both traditional filters (numeric, tags, text) and VSS in one single Redis command." ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document {'id': 'doc:b', 'payload': None, 'score': '0.24422544241', 'tag': 'foo'},\n", + " Document {'id': 'doc:a', 'payload': None, 'score': '0.259926855564', 'tag': 'foo'}]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "query = (\n", + " Query(\"(@tag:{ foo })=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"id\", \"tag\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", + "\n", + "query_params = {\n", + " \"vec\": np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + "}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See additional Hybrid Query examples in [this Jupyter notebook](https://github.com/RediSearch/RediSearch/blob/master/docs/docs/vecsim-hybrid_queries_examples.ipynb)." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vector Creation and Storage Examples\n", + "The above examples use dummy data as vectors. However, in reality, most use cases leverage production-grade AI models for creating embeddings. Below we will take some sample text data, pass it to the OpenAI and Cohere API's respectively, and then write them to Redis." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "texts = [\n", + " \"Today is a really great day!\",\n", + " \"The dog next door barks really loudly.\",\n", + " \"My cat escaped and got out before I could close the door.\",\n", + " \"It's supposed to rain and thunder tomorrow.\"\n", + "]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### OpenAI Embeddings\n", + "Before working with OpenAI Embeddings, we clean up our existing search index and create a new one." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# delete index\n", + "r.ft(INDEX_NAME).dropindex(delete_documents=True)\n", + "\n", + "# make a new one\n", + "create_index(vector_dimensions=VECTOR_DIMENSIONS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install openai" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "# set your OpenAI API key - get one at https://platform.openai.com\n", + "openai.api_key = \"YOUR OPENAI API KEY\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Create Embeddings with OpenAI text-embedding-ada-002\n", + "# https://openai.com/blog/new-and-improved-embedding-model\n", + "response = openai.Embedding.create(input=texts, engine=\"text-embedding-ada-002\")\n", + "embeddings = np.array([r[\"embedding\"] for r in response[\"data\"]], dtype=np.float32)\n", + "\n", + "# Write to Redis\n", + "pipe = r.pipeline()\n", + "for i, embedding in enumerate(embeddings):\n", + " pipe.hset(f\"doc:{i}\", mapping = {\n", + " \"vector\": embedding.tobytes(),\n", + " \"content\": texts[i],\n", + " \"tag\": \"openai\"\n", + " })\n", + "res = pipe.execute()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0.00509819, 0.0010873 , -0.00228475, ..., -0.00457579,\n", + " 0.01329307, -0.03167175],\n", + " [-0.00357223, -0.00550784, -0.01314328, ..., -0.02915693,\n", + " 0.01470436, -0.01367203],\n", + " [-0.01284631, 0.0034875 , -0.01719686, ..., -0.01537451,\n", + " 0.01953256, -0.05048691],\n", + " [-0.01145045, -0.00785481, 0.00206323, ..., -0.02070181,\n", + " -0.01629098, -0.00300795]], dtype=float32)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Search with OpenAI Embeddings\n", + "\n", + "Now that we've created embeddings with OpenAI, we can also perform a search to find relevant documents to some input text.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0.00062901, -0.0070723 , -0.00148926, ..., -0.01904645,\n", + " -0.00436092, -0.01117944], dtype=float32)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"animals\"\n", + "\n", + "# create query embedding\n", + "response = openai.Embedding.create(input=[text], engine=\"text-embedding-ada-002\")\n", + "query_embedding = np.array([r[\"embedding\"] for r in response[\"data\"]], dtype=np.float32)[0]\n", + "\n", + "query_embedding" + ] + }, + { 
+ "cell_type": "code", + "execution_count": 14, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Result{2 total, docs: [Document {'id': 'a', 'payload': None, '__v_score': '0'}, Document {'id': 'b', 'payload': None, '__v_score': '3.09485009821e+26'}]}" + "[Document {'id': 'doc:1', 'payload': None, 'score': '0.214349985123', 'content': 'The dog next door barks really loudly.', 'tag': 'openai'},\n", + " Document {'id': 'doc:2', 'payload': None, 'score': '0.237052619457', 'content': 'My cat escaped and got out before I could close the door.', 'tag': 'openai'}]" ] }, - "execution_count": 2, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "r.hset(\"a\", \"v\", \"aaaaaaaa\")\n", - "r.hset(\"b\", \"v\", \"aaaabaaa\")\n", - "r.hset(\"c\", \"v\", \"aaaaabaa\")\n", + "# query for similar documents that have the openai tag\n", + "query = (\n", + " Query(\"(@tag:{ openai })=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"content\", \"tag\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", + "\n", + "query_params = {\"vec\": query_embedding.tobytes()}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs\n", "\n", - "q = Query(\"*=>[KNN 2 @v $vec]\").return_field(\"__v_score\").dialect(2)\n", - "r.ft().search(q, query_params={\"vec\": \"aaaaaaaa\"})" + "# the two pieces of content related to animals are returned" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cohere Embeddings\n", + "Before working with Cohere Embeddings, we clean up our existing search index and create a new one." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# delete index\n", + "r.ft(INDEX_NAME).dropindex(delete_documents=True)\n", + "\n", + "# make a new one for cohere embeddings (1024 dimensions)\n", + "VECTOR_DIMENSIONS = 1024\n", + "create_index(vector_dimensions=VECTOR_DIMENSIONS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install cohere" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import cohere\n", + "\n", + "co = cohere.Client(\"YOUR COHERE API KEY\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Create Embeddings with Cohere\n", + "# https://docs.cohere.ai/docs/embeddings\n", + "response = co.embed(texts=texts, model=\"small\")\n", + "embeddings = np.array(response.embeddings, dtype=np.float32)\n", + "\n", + "# Write to Redis\n", + "for i, embedding in enumerate(embeddings):\n", + " r.hset(f\"doc:{i}\", mapping = {\n", + " \"vector\": embedding.tobytes(),\n", + " \"content\": texts[i],\n", + " \"tag\": \"cohere\"\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.3034668 , -0.71533203, -0.2836914 , ..., 0.81152344,\n", + " 1.0253906 , -0.8095703 ],\n", + " [-0.02560425, -1.4912109 , 0.24267578, ..., -0.89746094,\n", + " 0.15625 , -3.203125 ],\n", + " [ 0.10125732, 0.7246094 , -0.29516602, ..., -1.9638672 ,\n", + " 1.6630859 , -0.23291016],\n", + " [-2.09375 , 0.8588867 , -0.23352051, ..., -0.01541138,\n", + " 0.17053223, -3.4042969 ]], dtype=float32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings" + ] + 
}, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Search with Cohere Embeddings\n", + "\n", + "Now that we've created embeddings with Cohere, we can also perform a search to find relevant documents to some input text." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.49682617, 1.7070312 , 0.3466797 , ..., 0.58984375,\n", + " 0.1060791 , -2.9023438 ], dtype=float32)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"animals\"\n", + "\n", + "# create query embedding\n", + "response = co.embed(texts=[text], model=\"small\")\n", + "query_embedding = np.array(response.embeddings[0], dtype=np.float32)\n", + "\n", + "query_embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document {'id': 'doc:1', 'payload': None, 'score': '0.658673524857', 'content': 'The dog next door barks really loudly.', 'tag': 'cohere'},\n", + " Document {'id': 'doc:2', 'payload': None, 'score': '0.662699103355', 'content': 'My cat escaped and got out before I could close the door.', 'tag': 'cohere'}]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# query for similar documents that have the cohere tag\n", + "query = (\n", + " Query(\"(@tag:{ cohere })=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"content\", \"tag\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", + "\n", + "query_params = {\"vec\": query_embedding.tobytes()}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs\n", + "\n", + "# the two pieces of content related to animals are returned" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Find more example apps, tutorials, and projects using Redis Vector Similarity Search [in this GitHub organization](https://github.com/RedisVentures)." ] } ], @@ -98,7 +660,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.9.12" }, "orig_nbformat": 4 }, diff --git a/docs/opentelemetry.rst b/docs/opentelemetry.rst index 96781028e3..d006a60461 100644 --- a/docs/opentelemetry.rst +++ b/docs/opentelemetry.rst @@ -4,7 +4,7 @@ Integrating OpenTelemetry What is OpenTelemetry? ---------------------- -`OpenTelemetry `_ is an open-source observability framework for traces, metrics, and logs. +`OpenTelemetry `_ is an open-source observability framework for traces, metrics, and logs. It is a merger of OpenCensus and OpenTracing projects hosted by Cloud Native Computing Foundation. OpenTelemetry allows developers to collect and export telemetry data in a vendor agnostic way. With OpenTelemetry, you can instrument your application once and then add or change vendors without changing the instrumentation, for example, here is a list of `popular DataDog competitors `_ that support OpenTelemetry. @@ -97,7 +97,7 @@ See `OpenTelemetry Python Tracing API `_ that supports distributed tracing, metrics, and logs. You can use it to monitor applications and set up automatic alerts to receive notifications via email, Slack, Telegram, and more. +Uptrace is an `open source APM `_ that supports distributed tracing, metrics, and logs. 
You can use it to monitor applications and set up automatic alerts to receive notifications via email, Slack, Telegram, and more. You can use Uptrace to monitor redis-py using this `GitHub example `_ as a starting point. @@ -111,9 +111,9 @@ Monitoring Redis Server performance In addition to monitoring redis-py client, you can also monitor Redis Server performance using OpenTelemetry Collector Agent. -OpenTelemetry Collector is a proxy/middleman between your application and a `distributed tracing tool `_ such as Uptrace or Jaeger. Collector receives telemetry data, processes it, and then exports the data to APM tools that can store it permanently. +OpenTelemetry Collector is a proxy/middleman between your application and a `distributed tracing tool `_ such as Uptrace or Jaeger. Collector receives telemetry data, processes it, and then exports the data to APM tools that can store it permanently. -For example, you can use the Redis receiver provided by Otel Collector to `monitor Redis performance `_: +For example, you can use the `OpenTelemetry Redis receiver ` provided by Otel Collector to monitor Redis performance: .. image:: images/opentelemetry/redis-metrics.png :alt: Redis metrics @@ -123,55 +123,50 @@ See introduction to `OpenTelemetry Collector `_ using alerting rules. For example, the following rule uses the group by node expression to create an alert whenever an individual Redis shard is down: +Uptrace also allows you to monitor `OpenTelemetry metrics `_ using alerting rules. For example, the following monitor uses the group by node expression to create an alert whenever an individual Redis shard is down: .. code-block:: python - # /etc/uptrace/uptrace.yml - - alerting: - rules: - - name: Redis shard is down - metrics: - - redis_up as $redis_up - query: - - group by cluster # monitor each cluster, - - group by bdb # each database, - - group by node # and each shard - - $redis_up == 0 - # shard should be down for 5 minutes to trigger an alert - for: 5m + monitors: + - name: Redis shard is down + metrics: + - redis_up as $redis_up + query: + - group by cluster # monitor each cluster, + - group by bdb # each database, + - group by node # and each shard + - $redis_up + min_allowed_value: 1 + # shard should be down for 5 minutes to trigger an alert + for_duration: 5m You can also create queries with more complex expressions. For example, the following rule creates an alert when the keyspace hit rate is lower than 75%: .. code-block:: python - # /etc/uptrace/uptrace.yml - - alerting: - rules: - - name: Redis read hit rate < 75% - metrics: - - redis_keyspace_read_hits as $hits - - redis_keyspace_read_misses as $misses - query: - - group by cluster - - group by bdb - - group by node - - $hits / ($hits + $misses) < 0.75 - for: 5m + monitors: + - name: Redis read hit rate < 75% + metrics: + - redis_keyspace_read_hits as $hits + - redis_keyspace_read_misses as $misses + query: + - group by cluster + - group by bdb + - group by node + - $hits / ($hits + $misses) as hit_rate + min_allowed_value: 0.75 + for_duration: 5m See `Alerting and Notifications `_ for details. What's next? ------------ -Next, you can learn how to configure `uptrace-python `_ to export spans, metrics, and logs to Uptrace. +Next, you can learn how to configure `uptrace-python `_ to export spans, metrics, and logs to Uptrace. 
You may also be interested in the following guides: -- `OpenTelemetry Django `_ -- `OpenTelemetry Flask `_ -- `OpenTelemetry FastAPI `_ -- `OpenTelemetry SQLAlchemy `_ -- `OpenTelemetry instrumentations `_ +- `OpenTelemetry Django `_ +- `OpenTelemetry Flask `_ +- `OpenTelemetry FastAPI `_ +- `OpenTelemetry SQLAlchemy `_ diff --git a/docs/redismodules.rst b/docs/redismodules.rst index 2b0b3c6533..27757cb692 100644 --- a/docs/redismodules.rst +++ b/docs/redismodules.rst @@ -44,7 +44,7 @@ These are the commands for interacting with the `RedisBloom module None: - if self.connection is not None: + if hasattr(self, "connection") and (self.connection is not None): _warnings.warn( f"Unclosed client session {self!r}", ResourceWarning, source=self ) @@ -713,6 +717,11 @@ async def reset(self): self.pending_unsubscribe_patterns = set() def close(self) -> Awaitable[NoReturn]: + # In case a connection property does not yet exist + # (due to a crash earlier in the Redis() constructor), return + # immediately as there is nothing to clean-up. + if not hasattr(self, "connection"): + return return self.reset() async def on_connect(self, connection: Connection): @@ -806,7 +815,11 @@ async def parse_response(self, block: bool = True, timeout: float = 0): read_timeout = None if block else timeout response = await self._execute( - conn, conn.read_response, timeout=read_timeout, push_request=True + conn, + conn.read_response, + timeout=read_timeout, + disconnect_on_error=False, + push_request=True, ) if conn.health_check_interval and response in self.health_check_response: @@ -1404,16 +1417,10 @@ async def execute(self, raise_on_error: bool = True): conn = cast(Connection, conn) try: - return await asyncio.shield( - conn.retry.call_with_retry( - lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_reset(conn, error), - ) + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), ) - except asyncio.CancelledError: - # not supposed to be possible, yet here we are - await conn.disconnect(nowait=True) - raise finally: await self.reset() diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 1c4222c885..5c7aecfe23 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -5,12 +5,14 @@ import warnings from typing import ( Any, + Callable, Deque, Dict, Generator, List, Mapping, Optional, + Tuple, Type, TypeVar, Union, @@ -141,6 +143,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand maximum number of connections are already created, a :class:`~.MaxConnectionsError` is raised. This error may be retried as defined by :attr:`connection_error_retry_attempts` + :param address_remap: + | An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. 
| Rest of the arguments will be passed to the :class:`~redis.asyncio.connection.Connection` instances when created @@ -235,7 +243,7 @@ def __init__( socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, socket_timeout: Optional[float] = None, retry: Optional["Retry"] = None, - retry_on_error: Optional[List[Exception]] = None, + retry_on_error: Optional[List[Type[Exception]]] = None, # SSL related kwargs ssl: bool = False, ssl_ca_certs: Optional[str] = None, @@ -245,6 +253,7 @@ def __init__( ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, protocol: Optional[int] = 2, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -337,7 +346,12 @@ def __init__( if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.nodes_manager = NodesManager( + startup_nodes, + require_full_coverage, + kwargs, + address_remap=address_remap, + ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps @@ -1002,18 +1016,10 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: await connection.send_packed_command(connection.pack_command(*args), False) # Read response - return await asyncio.shield( - self._parse_and_release(connection, args[0], **kwargs) - ) - - async def _parse_and_release(self, connection, *args, **kwargs): try: - return await self.parse_response(connection, *args, **kwargs) - except asyncio.CancelledError: - # should not be possible - await connection.disconnect(nowait=True) - raise + return await self.parse_response(connection, args[0], **kwargs) finally: + # Release connection self._free.append(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: @@ -1052,6 +1058,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", + "address_remap", ) def __init__( @@ -1059,10 +1066,12 @@ def __init__( startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs + self.address_remap = address_remap self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1221,6 +1230,7 @@ async def initialize(self) -> None: if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: @@ -1239,6 +1249,7 @@ async def initialize(self) -> None: for replica_node in replica_nodes: host = replica_node[0] port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = tmp_nodes_cache.get( get_node_name(host, port) @@ -1312,6 +1323,16 @@ async def close(self, attr: str = "nodes_cache") -> None: ) ) + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. 
+ """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index bf6274922e..fc69b9091a 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -8,6 +8,7 @@ import sys import threading import weakref +from abc import abstractmethod from itertools import chain from types import MappingProxyType from typing import ( @@ -25,7 +26,9 @@ ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse -if sys.version_info.major >= 3 and sys.version_info.minor >= 11: +# the functionality is available in 3.11.x but has a major issue before +# 3.11.3. See https://github.com/redis/redis-py/issues/2633 +if sys.version_info >= (3, 11, 3): from asyncio import timeout as async_timeout else: from async_timeout import timeout as async_timeout @@ -78,25 +81,23 @@ class _Sentinel(enum.Enum): class ConnectCallbackProtocol(Protocol): - def __call__(self, connection: "Connection"): + def __call__(self, connection: "AbstractConnection"): ... class AsyncConnectCallbackProtocol(Protocol): - async def __call__(self, connection: "Connection"): + async def __call__(self, connection: "AbstractConnection"): ... ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] -class Connection: - """Manages TCP communication to and from a Redis server""" +class AbstractConnection: + """Manages communication to and from a Redis server""" __slots__ = ( "pid", - "host", - "port", "db", "username", "client_name", @@ -104,9 +105,6 @@ class Connection: "password", "socket_timeout", "socket_connect_timeout", - "socket_keepalive", - "socket_keepalive_options", - "socket_type", "redis_connect_func", "retry_on_timeout", "retry_on_error", @@ -129,15 +127,10 @@ class Connection: def __init__( self, *, - host: str = "localhost", - port: Union[str, int] = 6379, db: Union[str, int] = 0, password: Optional[str] = None, socket_timeout: Optional[float] = None, socket_connect_timeout: Optional[float] = None, - socket_keepalive: bool = False, - socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, - socket_type: int = 0, retry_on_timeout: bool = False, retry_on_error: Union[list, _Sentinel] = SENTINEL, encoding: str = "utf-8", @@ -162,18 +155,15 @@ def __init__( "2. 
'credential_provider'" ) self.pid = os.getpid() - self.host = host - self.port = int(port) self.db = db self.client_name = client_name self.credential_provider = credential_provider self.password = password self.username = username self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None - self.socket_keepalive = socket_keepalive - self.socket_keepalive_options = socket_keepalive_options or {} - self.socket_type = socket_type + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -194,7 +184,6 @@ def __init__( self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval self.next_health_check: float = -1 - self.ssl_context: Optional[RedisSSLContext] = None self.encoder = encoder_class(encoding, encoding_errors, decode_responses) self.redis_connect_func = redis_connect_func self._reader: Optional[asyncio.StreamReader] = None @@ -218,23 +207,9 @@ def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) return f"{self.__class__.__name__}<{repr_args}>" + @abstractmethod def repr_pieces(self): - pieces = [("host", self.host), ("port", self.port), ("db", self.db)] - if self.client_name: - pieces.append(("client_name", self.client_name)) - return pieces - - def __del__(self): - try: - if self.is_connected: - loop = asyncio.get_running_loop() - coro = self.disconnect() - if loop.is_running(): - loop.create_task(coro) - else: - loop.run_until_complete(coro) - except Exception: - pass + pass @property def is_connected(self): @@ -293,51 +268,17 @@ async def connect(self): if task and inspect.isawaitable(task): await task + @abstractmethod async def _connect(self): - """Create a TCP socket connection""" - async with async_timeout(self.socket_connect_timeout): - reader, writer = await asyncio.open_connection( - host=self.host, - port=self.port, - ssl=self.ssl_context.get() if self.ssl_context else None, - ) - self._reader = reader - self._writer = writer - sock = writer.transport.get_extra_info("socket") - if sock: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - try: - # TCP_KEEPALIVE - if self.socket_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - for k, v in self.socket_keepalive_options.items(): - sock.setsockopt(socket.SOL_TCP, k, v) + pass - except (OSError, TypeError): - # `socket_keepalive_options` might contain invalid options - # causing an error. Do not leave the connection open. - writer.close() - raise + @abstractmethod + def _host_error(self) -> str: + pass - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - if not exception.args: - # asyncio has a bug where on Connection reset by peer, the - # exception is not instanciated, so args is empty. This is the - # workaround. - # See: https://github.com/redis/redis-py/issues/2237 - # See: https://github.com/python/cpython/issues/94061 - return ( - f"Error connecting to {self.host}:{self.port}. Connection reset by peer" - ) - elif len(exception.args) == 1: - return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}." - else: - return ( - f"Error {exception.args[0]} connecting to {self.host}:{self.port}. " - f"{exception.args[0]}." 
- ) + @abstractmethod + def _error_message(self, exception: BaseException) -> str: + pass async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" @@ -491,7 +432,11 @@ async def send_packed_command( raise ConnectionError( f"Error {err_no} while writing to socket. {errmsg}." ) from e - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. await self.disconnect(nowait=True) raise @@ -507,18 +452,20 @@ async def can_read_destructive(self): return await self._parser.can_read_destructive() except OSError as e: await self.disconnect(nowait=True) - raise ConnectionError( - f"Error while reading from {self.host}:{self.port}: {e.args}" - ) + host_error = self._host_error() + raise ConnectionError(f"Error while reading from {host_error}: {e.args}") async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + *, + disconnect_on_error: bool = True, push_request: Optional[bool] = False, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout + host_error = self._host_error() try: if ( read_timeout is not None @@ -544,22 +491,22 @@ async def read_response( ) except asyncio.TimeoutError: if timeout is not None: - # user requested timeout, return None + # user requested timeout, return None. Operation can be retried return None # it was a self.socket_timeout error. - await self.disconnect(nowait=True) - raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + if disconnect_on_error: + await self.disconnect(nowait=True) + raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: - await self.disconnect(nowait=True) - raise ConnectionError( - f"Error while reading from {self.host}:{self.port} : {e.args}" - ) - except asyncio.CancelledError: - # need this check for 3.7, where CancelledError - # is subclass of Exception, not BaseException - raise - except Exception: - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) + raise ConnectionError(f"Error while reading from {host_error} : {e.args}") + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. 
+ if disconnect_on_error: + await self.disconnect(nowait=True) raise if self.health_check_interval: @@ -647,7 +594,90 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] return output +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + *, + host: str = "localhost", + port: Union[str, int] = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_type: int = 0, + **kwargs, + ): + self.host = host + self.port = int(port) + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connection_arguments(self) -> Mapping: + return {"host": self.host, "port": self.port} + + async def _connect(self): + """Create a TCP socket connection""" + async with async_timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_connection( + **self._connection_arguments() + ) + self._reader = reader + self._writer = writer + sock = writer.transport.get_extra_info("socket") + if sock: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + try: + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.SOL_TCP, k, v) + + except (OSError, TypeError): + # `socket_keepalive_options` might contain invalid options + # causing an error. Do not leave the connection open. + writer.close() + raise + + def _host_error(self) -> str: + return f"{self.host}:{self.port}" + + def _error_message(self, exception: BaseException) -> str: + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if not exception.args: + # asyncio has a bug where on Connection reset by peer, the + # exception is not instanciated, so args is empty. This is the + # workaround. + # See: https://github.com/redis/redis-py/issues/2237 + # See: https://github.com/python/cpython/issues/94061 + return f"Error connecting to {host_error}. Connection reset by peer" + elif len(exception.args) == 1: + return f"Error connecting to {host_error}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[0]}." + ) + + class SSLConnection(Connection): + """Manages SSL connections to and from the Redis server(s). 
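To make the refactor above concrete, here is a hedged sketch of what a custom transport has to provide now that the base class only drives the generic connect/read logic: the four hooks below are the whole contract. The subclass name is invented for illustration.

import asyncio

from redis.asyncio.connection import AbstractConnection


class PlainTCPConnection(AbstractConnection):
    """Hypothetical subclass: bare TCP transport without keepalive handling."""

    def __init__(self, *, host: str = "localhost", port: int = 6379, **kwargs):
        self.host = host
        self.port = int(port)
        super().__init__(**kwargs)

    def repr_pieces(self):
        return [("host", self.host), ("port", self.port), ("db", self.db)]

    async def _connect(self):
        # Only the raw stream setup lives here; AUTH/SELECT run in on_connect().
        self._reader, self._writer = await asyncio.open_connection(
            host=self.host, port=self.port
        )

    def _host_error(self) -> str:
        return f"{self.host}:{self.port}"

    def _error_message(self, exception: BaseException) -> str:
        return f"Error connecting to {self._host_error()}: {exception.args!r}"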
+ This class extends the Connection class, adding SSL functionality, and making + use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) + """ + def __init__( self, ssl_keyfile: Optional[str] = None, @@ -658,7 +688,6 @@ def __init__( ssl_check_hostname: bool = False, **kwargs, ): - super().__init__(**kwargs) self.ssl_context: RedisSSLContext = RedisSSLContext( keyfile=ssl_keyfile, certfile=ssl_certfile, @@ -667,6 +696,12 @@ def __init__( ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, ) + super().__init__(**kwargs) + + def _connection_arguments(self) -> Mapping: + kwargs = super()._connection_arguments() + kwargs["ssl"] = self.ssl_context.get() + return kwargs @property def keyfile(self): @@ -746,77 +781,12 @@ def get(self) -> ssl.SSLContext: return self.context -class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] - def __init__( - self, - *, - path: str = "", - db: Union[str, int] = 0, - username: Optional[str] = None, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - retry_on_timeout: bool = False, - retry_on_error: Union[list, _Sentinel] = SENTINEL, - parser_class: Type[BaseParser] = DefaultParser, - socket_read_size: int = 65536, - health_check_interval: float = 0.0, - client_name: str = None, - retry: Optional[Retry] = None, - redis_connect_func=None, - credential_provider: Optional[CredentialProvider] = None, - ): - """ - Initialize a new UnixDomainSocketConnection. - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object - """ - if (username or password) and credential_provider is not None: - raise DataError( - "'username' and 'password' cannot be passed along with 'credential_" - "provider'. Please provide only one of the following arguments: \n" - "1. 'password' and (optional) 'username'\n" - "2. 
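A small note on the pattern above, as a hedged sketch: the SSL subclass only widens _connection_arguments(), and the inherited Connection._connect() passes the extra "ssl" key straight to asyncio.open_connection(), so TLS is layered on without re-implementing the TCP setup. Host and certificate path below are placeholders.

from redis.asyncio.connection import SSLConnection

conn = SSLConnection(
    host="redis.example.com",         # placeholder host
    port=6380,
    ssl_ca_certs="/path/to/ca.pem",   # placeholder certificate path
    ssl_check_hostname=True,
)
# When connect() runs, _connection_arguments() returns
#   {"host": ..., "port": ..., "ssl": <ssl.SSLContext>}
# and the inherited _connect() unpacks it into asyncio.open_connection().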
'credential_provider'" - ) - self.pid = os.getpid() +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, *, path: str = "", **kwargs): self.path = path - self.db = db - self.client_name = client_name - self.credential_provider = credential_provider - self.password = password - self.username = username - self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None - self.retry_on_timeout = retry_on_timeout - if retry_on_error is SENTINEL: - retry_on_error = [] - if retry_on_timeout: - retry_on_error.append(TimeoutError) - self.retry_on_error = retry_on_error - if retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) - else: - self.retry = Retry(NoBackoff(), 0) - self.health_check_interval = health_check_interval - self.next_health_check = -1 - self.redis_connect_func = redis_connect_func - self.encoder = Encoder(encoding, encoding_errors, decode_responses) - self._sock = None - self._reader = None - self._writer = None - self._socket_read_size = socket_read_size - self.set_parser(parser_class) - self._connect_callbacks = [] - self._buffer_cutoff = 6000 + super().__init__(**kwargs) def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: pieces = [("path", self.path), ("db", self.db)] @@ -831,15 +801,21 @@ async def _connect(self): self._writer = writer await self.on_connect() - def _error_message(self, exception): + def _host_error(self) -> str: + return self.path + + def _error_message(self, exception: BaseException) -> str: # args for socket.error can either be (errno, "message") # or just "message" + host_error = self._host_error() if len(exception.args) == 1: - return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) else: return ( f"Error {exception.args[0]} connecting to unix socket: " - f"{self.path}. {exception.args[1]}." + f"{host_error}. {exception.args[1]}." ) @@ -871,7 +847,7 @@ def to_bool(value) -> Optional[bool]: class ConnectKwargs(TypedDict, total=False): username: str password: str - connection_class: Type[Connection] + connection_class: Type[AbstractConnection] host: str port: int db: int @@ -993,7 +969,7 @@ class initializer. 
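A short usage sketch for the rewritten Unix-socket class above: it now keeps only the path itself and forwards every shared option (db, timeouts, credentials, retry policy) to AbstractConnection, so it plugs into the pool the same way the TCP class does. The socket path is a placeholder.

from redis.asyncio import Redis
from redis.asyncio.connection import ConnectionPool, UnixDomainSocketConnection

pool = ConnectionPool(
    connection_class=UnixDomainSocketConnection,
    path="/var/run/redis/redis.sock",   # placeholder socket path
    db=0,
    socket_timeout=5.0,                 # handled by AbstractConnection
)
client = Redis(connection_pool=pool)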
In the case of conflicting arguments, querystring def __init__( self, - connection_class: Type[Connection] = Connection, + connection_class: Type[AbstractConnection] = Connection, max_connections: Optional[int] = None, **connection_kwargs, ): @@ -1016,8 +992,8 @@ def __init__( self._fork_lock = threading.Lock() self._lock = asyncio.Lock() self._created_connections: int - self._available_connections: List[Connection] - self._in_use_connections: Set[Connection] + self._available_connections: List[AbstractConnection] + self._in_use_connections: Set[AbstractConnection] self.reset() # lgtm [py/init-calls-subclass] self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) @@ -1140,7 +1116,7 @@ def make_connection(self): self._created_connections += 1 return self.connection_class(**self.connection_kwargs) - async def release(self, connection: Connection): + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool""" self._checkpid() async with self._lock: @@ -1161,7 +1137,7 @@ async def release(self, connection: Connection): await connection.disconnect() return - def owns_connection(self, connection: Connection): + def owns_connection(self, connection: AbstractConnection): return connection.pid == self.pid async def disconnect(self, inuse_connections: bool = True): @@ -1175,7 +1151,7 @@ async def disconnect(self, inuse_connections: bool = True): self._checkpid() async with self._lock: if inuse_connections: - connections: Iterable[Connection] = chain( + connections: Iterable[AbstractConnection] = chain( self._available_connections, self._in_use_connections ) else: @@ -1233,14 +1209,14 @@ def __init__( self, max_connections: int = 50, timeout: Optional[int] = 20, - connection_class: Type[Connection] = Connection, + connection_class: Type[AbstractConnection] = Connection, queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout - self._connections: List[Connection] + self._connections: List[AbstractConnection] super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -1330,7 +1306,7 @@ async def get_connection(self, command_name, *keys, **options): return connection - async def release(self, connection: Connection): + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool.""" # Make sure we haven't changed process. 
self._checkpid() diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index ec17886fc6..501e234c3c 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -67,11 +67,14 @@ async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + *, + disconnect_on_error: Optional[float] = True, ): try: return await super().read_response( disable_decoding=disable_decoding, timeout=timeout, + disconnect_on_error=disconnect_on_error, ) except ReadOnlyError: if self.connection_pool.is_master: @@ -220,13 +223,13 @@ async def execute_command(self, *args, **kwargs): kwargs.pop("once") if once: + await random.choice(self.sentinels).execute_command(*args, **kwargs) + else: tasks = [ asyncio.Task(sentinel.execute_command(*args, **kwargs)) for sentinel in self.sentinels ] await asyncio.gather(*tasks) - else: - await random.choice(self.sentinels).execute_command(*args, **kwargs) return True def __repr__(self): @@ -254,10 +257,12 @@ async def discover_master(self, service_name: str): Returns a pair (address, port) or raises MasterNotFoundError if no master is found. """ + collected_errors = list() for sentinel_no, sentinel in enumerate(self.sentinels): try: masters = await sentinel.sentinel_masters() - except (ConnectionError, TimeoutError): + except (ConnectionError, TimeoutError) as e: + collected_errors.append(f"{sentinel} - {e!r}") continue state = masters.get(service_name) if state and self.check_master_state(state, service_name): @@ -267,7 +272,11 @@ async def discover_master(self, service_name: str): self.sentinels[0], ) return state["ip"], state["port"] - raise MasterNotFoundError(f"No master found for {service_name!r}") + + error_info = "" + if len(collected_errors) > 0: + error_info = f" : {', '.join(collected_errors)}" + raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}") def filter_slaves( self, slaves: Iterable[Mapping] diff --git a/redis/client.py b/redis/client.py index 165af1094e..09156bace6 100755 --- a/redis/client.py +++ b/redis/client.py @@ -323,7 +323,7 @@ def parse_xinfo_stream(response, **options): else: data = {str_if_bytes(k): v for k, v in response.items()} if not options.get("full", False): - first = data["first-entry"] + first = data.get("first-entry") if first is not None: data["first-entry"] = (first[0], pairs_to_dict(first[1])) last = data["last-entry"] @@ -435,9 +435,13 @@ def parse_item(item): # an O(N) complexity) instead of the command. if isinstance(item[3], list): result["command"] = space.join(item[3]) + result["client_address"] = item[4] + result["client_name"] = item[5] else: result["complexity"] = item[3] result["command"] = space.join(item[4]) + result["client_address"] = item[5] + result["client_name"] = item[6] return result return [parse_item(item) for item in response] @@ -533,10 +537,13 @@ def parse_geosearch_generic(response, **options): Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' commands according to 'withdist', 'withhash' and 'withcoord' labels. """ - if options["store"] or options["store_dist"]: - # `store` and `store_dist` cant be combined - # with other command arguments. - # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' + try: + if options["store"] or options["store_dist"]: + # `store` and `store_dist` cant be combined + # with other command arguments. 
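The corrected "once" handling above can be summarised with a hedged sketch that mirrors the fixed branch: once=True now targets a single randomly chosen sentinel, while the default fans the command out to every sentinel and waits for all replies. The helper below is illustrative, not part of the library.

import asyncio
import random


async def execute_on_sentinels(sentinels, *args, once: bool = False):
    # Mirrors the fixed logic in the async Sentinel.execute_command().
    if once:
        await random.choice(sentinels).execute_command(*args)
    else:
        await asyncio.gather(
            *(sentinel.execute_command(*args) for sentinel in sentinels)
        )
    return True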
+ # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' + return response + except KeyError: # it means the command was sent via execute_command return response if type(response) != list: @@ -976,8 +983,12 @@ class initializer. In the case of conflicting arguments, querystring arguments always win. """ + single_connection_client = kwargs.pop("single_connection_client", False) connection_pool = ConnectionPool.from_url(url, **kwargs) - return cls(connection_pool=connection_pool) + return cls( + connection_pool=connection_pool, + single_connection_client=single_connection_client, + ) def __init__( self, @@ -1625,7 +1636,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response(push_request=True) + return conn.read_response(disconnect_on_error=False, push_request=True) response = self._execute(conn, try_read) diff --git a/redis/cluster.py b/redis/cluster.py index c09faa1042..0fc715f838 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -116,6 +116,13 @@ def parse_cluster_shards(resp, **options): return shards +def parse_cluster_myshardid(resp, **options): + """ + Parse CLUSTER MYSHARDID response. + """ + return resp.decode("utf-8") + + PRIMARY = "primary" REPLICA = "replica" SLOT_ID = "slot-id" @@ -236,6 +243,7 @@ class AbstractRedisCluster: "SLOWLOG LEN", "SLOWLOG RESET", "WAIT", + "WAITAOF", "SAVE", "MEMORY PURGE", "MEMORY MALLOC-STATS", @@ -346,6 +354,7 @@ class AbstractRedisCluster: CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { "CLUSTER SLOTS": parse_cluster_slots, "CLUSTER SHARDS": parse_cluster_shards, + "CLUSTER MYSHARDID": parse_cluster_myshardid, } RESULT_CALLBACKS = dict_merge( @@ -473,6 +482,7 @@ def __init__( read_from_replicas: bool = False, dynamic_startup_nodes: bool = True, url: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): """ @@ -521,6 +531,12 @@ def __init__( reinitialize_steps to 1. To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. + :param address_remap: + An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. 
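Because the address_remap hook described above is the user-facing half of this feature, a hedged usage sketch may help. The addresses and port-forward mapping are invented for illustration, e.g. cluster nodes reached through a local tunnel.

from redis.cluster import RedisCluster

# Hypothetical mapping: nodes advertise internal addresses, but the client can
# only reach them through locally forwarded ports.
PORT_FORWARDS = {
    ("10.0.0.1", 6379): ("127.0.0.1", 16379),
    ("10.0.0.2", 6380): ("127.0.0.1", 16380),
    ("10.0.0.3", 6381): ("127.0.0.1", 16381),
}


def address_remap(address):
    # Receives the (host, port) tuple a node reports for itself and returns
    # the address the client should actually dial.
    return PORT_FORWARDS.get(address, address)


rc = RedisCluster(host="127.0.0.1", port=16379, address_remap=address_remap)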
:**kwargs: Extra arguments that will be sent into Redis instance when created @@ -601,6 +617,7 @@ def __init__( from_url=from_url, require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, + address_remap=address_remap, **kwargs, ) @@ -1276,6 +1293,7 @@ def __init__( lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): self.nodes_cache = {} @@ -1287,6 +1305,7 @@ def __init__( self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class + self.address_remap = address_remap self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1509,6 +1528,7 @@ def initialize(self): if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = self._get_or_create_cluster_node( host, port, PRIMARY, tmp_nodes_cache @@ -1525,6 +1545,7 @@ def initialize(self): for replica_node in replica_nodes: host = str_if_bytes(replica_node[0]) port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = self._get_or_create_cluster_node( host, port, REPLICA, tmp_nodes_cache @@ -1598,6 +1619,16 @@ def reset(self): # The read_load_balancer is None, do nothing pass + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPubSub(PubSub): """ diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index a23a94a3d3..cd93a85aba 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: from redis.asyncio.cluster import TargetNodesT - # Not complete, but covers the major ones # https://redis.io/commands READ_COMMANDS = frozenset( @@ -634,6 +633,14 @@ def cluster_shards(self, target_nodes=None): """ return self.execute_command("CLUSTER SHARDS", target_nodes=target_nodes) + def cluster_myshardid(self, target_nodes=None): + """ + Returns the shard ID of the node. + + For more information see https://redis.io/commands/cluster-myshardid/ + """ + return self.execute_command("CLUSTER MYSHARDID", target_nodes=target_nodes) + def cluster_links(self, target_node: "TargetNodesT") -> ResponseT: """ Each node in a Redis Cluster maintains a pair of long-lived TCP link with each diff --git a/redis/commands/core.py b/redis/commands/core.py index 9b3c37e196..6abcd5a2ec 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -761,6 +761,17 @@ def client_no_evict(self, mode: str) -> Union[Awaitable[str], str]: """ return self.execute_command("CLIENT NO-EVICT", mode) + def client_no_touch(self, mode: str) -> Union[Awaitable[str], str]: + """ + # The command controls whether commands sent by the client will alter + # the LRU/LFU of the keys they access. + # When turned on, the current client will not change LFU/LRU stats, + # unless it sends the TOUCH command. + + For more information see https://redis.io/commands/client-no-touch + """ + return self.execute_command("CLIENT NO-TOUCH", mode) + def command(self, **kwargs): """ Returns dict reply of details about all Redis commands. 
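A brief usage sketch for the new wrapper above; it needs a server that understands CLIENT NO-TOUCH (Redis 7.2+), and the key name is illustrative.

import redis

r = redis.Redis(host="localhost", port=6379)

r.client_no_touch("ON")      # reads on this connection stop updating LRU/LFU
value = r.get("hot-key")     # access no longer skews eviction statistics
r.client_no_touch("OFF")     # restore the default behaviour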
@@ -1330,6 +1341,21 @@ def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: """ return self.execute_command("WAIT", num_replicas, timeout, **kwargs) + def waitaof( + self, num_local: int, num_replicas: int, timeout: int, **kwargs + ) -> ResponseT: + """ + This command blocks the current client until all previous write + commands by that client are acknowledged as having been fsynced + to the AOF of the local Redis and/or at least the specified number + of replicas. + + For more information see https://redis.io/commands/waitaof + """ + return self.execute_command( + "WAITAOF", num_local, num_replicas, timeout, **kwargs + ) + def hello(self): """ This function throws a NotImplementedError since it is intentionally @@ -3489,8 +3515,8 @@ def xadd( raise DataError("Only one of ```maxlen``` or ```minid``` may be specified") if maxlen is not None: - if not isinstance(maxlen, int) or maxlen < 1: - raise DataError("XADD maxlen must be a positive integer") + if not isinstance(maxlen, int) or maxlen < 0: + raise DataError("XADD maxlen must be non-negative integer") pieces.append(b"MAXLEN") if approximate: pieces.append(b"~") @@ -4647,13 +4673,22 @@ def zrevrangebyscore( options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) - def zrank(self, name: KeyT, value: EncodableT) -> ResponseT: + def zrank( + self, + name: KeyT, + value: EncodableT, + withscore: bool = False, + ) -> ResponseT: """ Returns a 0-based value indicating the rank of ``value`` in sorted set - ``name`` + ``name``. + The optional WITHSCORE argument supplements the command's + reply with the score of the element returned. For more information see https://redis.io/commands/zrank """ + if withscore: + return self.execute_command("ZRANK", name, value, "WITHSCORE") return self.execute_command("ZRANK", name, value) def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: @@ -4697,13 +4732,22 @@ def zremrangebyscore( """ return self.execute_command("ZREMRANGEBYSCORE", name, min, max) - def zrevrank(self, name: KeyT, value: EncodableT) -> ResponseT: + def zrevrank( + self, + name: KeyT, + value: EncodableT, + withscore: bool = False, + ) -> ResponseT: """ Returns a 0-based value indicating the descending rank of - ``value`` in sorted set ``name`` + ``value`` in sorted set ``name``. + The optional ``withscore`` argument supplements the command's + reply with the score of the element returned. 
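To illustrate the optional flag added above (server 7.2+), a hedged sketch; key and member names are arbitrary, and the exact reply types depend on protocol and decoding settings.

import redis

r = redis.Redis(decode_responses=True)
r.zadd("board", {"alice": 10, "bob": 25, "carol": 40})

r.zrank("board", "bob")                     # -> 1 (ascending rank only)
r.zrank("board", "bob", withscore=True)     # -> rank plus score, e.g. [1, '25']
r.zrevrank("board", "bob", withscore=True)  # -> descending rank plus score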
For more information see https://redis.io/commands/zrevrank """ + if withscore: + return self.execute_command("ZREVRANK", name, value, "WITHSCORE") return self.execute_command("ZREVRANK", name, value) def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py index a9e91fe74d..1980a25c03 100644 --- a/redis/commands/json/__init__.py +++ b/redis/commands/json/__init__.py @@ -36,6 +36,8 @@ def __init__( "JSON.MGET": bulk_of_jsons(self._decode), "JSON.SET": lambda r: r and nativestr(r) == "OK", "JSON.DEBUG": self._decode, + "JSON.MSET": lambda r: r and nativestr(r) == "OK", + "JSON.MERGE": lambda r: r and nativestr(r) == "OK", "JSON.TOGGLE": self._decode, "JSON.RESP": self._decode, } diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index c02c47ad86..3abe155796 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -1,6 +1,6 @@ import os from json import JSONDecodeError, loads -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from redis.exceptions import DataError from redis.utils import deprecated_function @@ -253,6 +253,46 @@ def set( pieces.append("XX") return self.execute_command("JSON.SET", *pieces) + def mset(self, triplets: List[Tuple[str, str, JsonType]]) -> Optional[str]: + """ + Set the JSON value at key ``name`` under the ``path`` to ``obj`` + for one or more keys. + + ``triplets`` is a list of one or more triplets of key, path, value. + + For the purpose of using this within a pipeline, this command is also + aliased to JSON.MSET. + + For more information see `JSON.MSET `_. + """ + pieces = [] + for triplet in triplets: + pieces.extend([triplet[0], str(triplet[1]), self._encode(triplet[2])]) + return self.execute_command("JSON.MSET", *pieces) + + def merge( + self, + name: str, + path: str, + obj: JsonType, + decode_keys: Optional[bool] = False, + ) -> Optional[str]: + """ + Merges a given JSON value into matching paths. Consequently, JSON values + at matching paths are updated, deleted, or expanded with new children + + ``decode_keys`` If set to True, the keys of ``obj`` will be decoded + with utf-8. + + For more information see `JSON.MERGE `_. 
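A short usage sketch for the two new RedisJSON wrappers above; it assumes a server with a RedisJSON build that supports JSON.MSET and JSON.MERGE, and the key names are illustrative.

import redis

r = redis.Redis(decode_responses=True)

# JSON.MSET sets several (key, path, value) triplets in one atomic command.
r.json().mset(
    [
        ("user:1", "$", {"name": "Ada", "tags": ["admin"]}),
        ("user:2", "$", {"name": "Bob"}),
    ]
)

# JSON.MERGE applies a merge-patch at the matching path: new fields are added,
# existing fields are updated, and fields set to null are removed.
r.json().merge("user:1", "$", {"tags": ["admin", "ops"], "active": True})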
+ """ + if decode_keys: + obj = decode_dict_keys(obj) + + pieces = [name, str(path), self._encode(obj)] + + return self.execute_command("JSON.MERGE", *pieces) + def set_file( self, name: str, diff --git a/redis/connection.py b/redis/connection.py index 8c5c5a6ea7..845350df17 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -129,6 +129,8 @@ def __init__( self, db=0, password=None, + socket_timeout=None, + socket_connect_timeout=None, retry_on_timeout=False, retry_on_error=SENTINEL, encoding="utf-8", @@ -165,6 +167,10 @@ def __init__( self.credential_provider = credential_provider self.password = password self.username = username + self.socket_timeout = socket_timeout + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -363,20 +369,22 @@ def on_connect(self): def disconnect(self, *args): "Disconnects from the Redis server" self._parser.on_disconnect() - if self._sock is None: + + conn_sock = self._sock + self._sock = None + if conn_sock is None: return if os.getpid() == self.pid: try: - self._sock.shutdown(socket.SHUT_RDWR) + conn_sock.shutdown(socket.SHUT_RDWR) except OSError: pass try: - self._sock.close() + conn_sock.close() except OSError: pass - self._sock = None def _send_ping(self): """Send PING, expect PONG in return""" @@ -416,7 +424,11 @@ def send_packed_command(self, command, check_health=True): errno = e.args[0] errmsg = e.args[1] raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. self.disconnect() raise @@ -441,7 +453,13 @@ def can_read(self, timeout=0): self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") - def read_response(self, disable_decoding=False, push_request=False): + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): """Read the response from a previously sent command""" host_error = self._host_error() @@ -454,15 +472,21 @@ def read_response(self, disable_decoding=False, push_request=False): else: response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise ConnectionError( f"Error while reading from {host_error}" f" : {e.args}" ) - except Exception: - self.disconnect() + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. 
+ if disconnect_on_error: + self.disconnect() raise if self.health_check_interval: @@ -514,8 +538,6 @@ def __init__( self, host="localhost", port=6379, - socket_timeout=None, - socket_connect_timeout=None, socket_keepalive=False, socket_keepalive_options=None, socket_type=0, @@ -523,8 +545,6 @@ def __init__( ): self.host = host self.port = int(port) - self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} self.socket_type = socket_type @@ -759,8 +779,9 @@ def repr_pieces(self): def _connect(self): "Create a Unix domain socket connection" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.settimeout(self.socket_timeout) + sock.settimeout(self.socket_connect_timeout) sock.connect(self.path) + sock.settimeout(self.socket_timeout) return sock def _host_error(self): diff --git a/redis/exceptions.py b/redis/exceptions.py index 8a8bf423eb..7cf15a7d07 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -49,6 +49,18 @@ class NoScriptError(ResponseError): pass +class OutOfMemoryError(ResponseError): + """ + Indicates the database is full. Can only occur when either: + * Redis maxmemory-policy=noeviction + * Redis maxmemory-policy=volatile* and there are no evictable keys + + For more information see `Memory optimization in Redis `_. # noqa + """ + + pass + + class ExecAbortError(ResponseError): pass @@ -131,6 +143,7 @@ class AskError(ResponseError): pertain to this hash slot, but only if the key in question exists, otherwise the query is forwarded using a -ASK redirection to the node that is target of the migration. + src node: MIGRATING to dst node get > ASK error ask dst node > ASKING command diff --git a/redis/parsers/base.py b/redis/parsers/base.py index b98a44ef2f..f77296df6a 100644 --- a/redis/parsers/base.py +++ b/redis/parsers/base.py @@ -17,6 +17,7 @@ ModuleError, NoPermissionError, NoScriptError, + OutOfMemoryError, ReadOnlyError, RedisError, ResponseError, @@ -64,6 +65,7 @@ class BaseParser(ABC): MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, **NO_AUTH_SET_ERROR, }, + "OOM": OutOfMemoryError, "WRONGPASS": AuthenticationError, "EXECABORT": ExecAbortError, "LOADING": BusyLoadingError, @@ -73,12 +75,13 @@ class BaseParser(ABC): "NOPERM": NoPermissionError, } - def parse_error(self, response): + @classmethod + def parse_error(cls, response): "Parse an error response" error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: + if error_code in cls.EXCEPTION_CLASSES: response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] + exception_class = cls.EXCEPTION_CLASSES[error_code] if isinstance(exception_class, dict): exception_class = exception_class.get(response, ResponseError) return exception_class(response) diff --git a/redis/parsers/resp2.py b/redis/parsers/resp2.py index 0acd21164f..d5adc1a898 100644 --- a/redis/parsers/resp2.py +++ b/redis/parsers/resp2.py @@ -10,11 +10,12 @@ class _RESP2Parser(_RESPBase): """RESP2 protocol implementation""" def read_response(self, disable_decoding=False): - pos = self._buffer.get_pos() + pos = self._buffer.get_pos() if self._buffer else None try: result = self._read_response(disable_decoding=disable_decoding) except BaseException: - self._buffer.rewind(pos) + if self._buffer: + self._buffer.rewind(pos) raise else: self._buffer.purge() diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index 
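Since the parser change above maps the server's "OOM" error prefix to a dedicated exception class, callers can now tell a full database apart from other command failures. A hedged usage sketch; key and value are illustrative.

import redis
from redis.exceptions import OutOfMemoryError, ResponseError

r = redis.Redis()
try:
    r.set("cache:item", "payload")
except OutOfMemoryError:
    # maxmemory is reached and the configured policy cannot evict anything;
    # shed load or trim data instead of retrying blindly.
    pass
except ResponseError:
    # any other -ERR style reply from the server
    raise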
b443e45ae6..1275686710 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -20,13 +20,14 @@ def handle_push_response(self, response): return response def read_response(self, disable_decoding=False, push_request=False): - pos = self._buffer.get_pos() + pos = self._buffer.get_pos() if self._buffer else None try: result = self._read_response( disable_decoding=disable_decoding, push_request=push_request ) except BaseException: - self._buffer.rewind(pos) + if self._buffer: + self._buffer.rewind(pos) raise else: self._buffer.purge() diff --git a/redis/sentinel.py b/redis/sentinel.py index d70b7142b5..0ba179b9ca 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -1,5 +1,6 @@ import random import weakref +from typing import Optional from redis.client import Redis from redis.commands import SentinelCommands @@ -53,9 +54,14 @@ def _connect_retry(self): def connect(self): return self.retry.call_with_retry(self._connect_retry, lambda error: None) - def read_response(self, disable_decoding=False): + def read_response( + self, disable_decoding=False, *, disconnect_on_error: Optional[bool] = False + ): try: - return super().read_response(disable_decoding=disable_decoding) + return super().read_response( + disable_decoding=disable_decoding, + disconnect_on_error=disconnect_on_error, + ) except ReadOnlyError: if self.connection_pool.is_master: # When talking to a master, a ReadOnlyError when likely @@ -72,6 +78,54 @@ class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): pass +class SentinelConnectionPoolProxy: + def __init__( + self, + connection_pool, + is_master, + check_connection, + service_name, + sentinel_manager, + ): + self.connection_pool_ref = weakref.ref(connection_pool) + self.is_master = is_master + self.check_connection = check_connection + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.reset() + + def reset(self): + self.master_address = None + self.slave_rr_counter = None + + def get_master_address(self): + master_address = self.sentinel_manager.discover_master(self.service_name) + if self.is_master and self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + connection_pool = self.connection_pool_ref() + if connection_pool is not None: + connection_pool.disconnect(inuse_connections=False) + return master_address + + def rotate_slaves(self): + slaves = self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + class SentinelConnectionPool(ConnectionPool): """ Sentinel backed connection pool. 
@@ -89,8 +143,15 @@ def __init__(self, service_name, sentinel_manager, **kwargs): ) self.is_master = kwargs.pop("is_master", True) self.check_connection = kwargs.pop("check_connection", False) + self.proxy = SentinelConnectionPoolProxy( + connection_pool=self, + is_master=self.is_master, + check_connection=self.check_connection, + service_name=service_name, + sentinel_manager=sentinel_manager, + ) super().__init__(**kwargs) - self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.connection_kwargs["connection_pool"] = self.proxy self.service_name = service_name self.sentinel_manager = sentinel_manager @@ -100,8 +161,11 @@ def __repr__(self): def reset(self): super().reset() - self.master_address = None - self.slave_rr_counter = None + self.proxy.reset() + + @property + def master_address(self): + return self.proxy.master_address def owns_connection(self, connection): check = not self.is_master or ( @@ -111,31 +175,11 @@ def owns_connection(self, connection): return check and parent.owns_connection(connection) def get_master_address(self): - master_address = self.sentinel_manager.discover_master(self.service_name) - if self.is_master: - if self.master_address != master_address: - self.master_address = master_address - # disconnect any idle connections so that they reconnect - # to the new master the next time that they are used. - self.disconnect(inuse_connections=False) - return master_address + return self.proxy.get_master_address() def rotate_slaves(self): "Round-robin slave balancer" - slaves = self.sentinel_manager.discover_slaves(self.service_name) - if slaves: - if self.slave_rr_counter is None: - self.slave_rr_counter = random.randint(0, len(slaves) - 1) - for _ in range(len(slaves)): - self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) - slave = slaves[self.slave_rr_counter] - yield slave - # Fallback to the master connection - try: - yield self.get_master_address() - except MasterNotFoundError: - pass - raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + return self.proxy.rotate_slaves() class Sentinel(SentinelCommands): @@ -230,10 +274,12 @@ def discover_master(self, service_name): Returns a pair (address, port) or raises MasterNotFoundError if no master is found. """ + collected_errors = list() for sentinel_no, sentinel in enumerate(self.sentinels): try: masters = sentinel.sentinel_masters() - except (ConnectionError, TimeoutError): + except (ConnectionError, TimeoutError) as e: + collected_errors.append(f"{sentinel} - {e!r}") continue state = masters.get(service_name) if state and self.check_master_state(state, service_name): @@ -243,7 +289,11 @@ def discover_master(self, service_name): self.sentinels[0], ) return state["ip"], state["port"] - raise MasterNotFoundError(f"No master found for {service_name!r}") + + error_info = "" + if len(collected_errors) > 0: + error_info = f" : {', '.join(collected_errors)}" + raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}") def filter_slaves(self, slaves): "Remove slaves that are in an ODOWN or SDOWN state" diff --git a/redis/typing.py b/redis/typing.py index 7c5908ff0c..e555f57f5b 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -58,7 +58,7 @@ def execute_command(self, *args, **options): ... 
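With the change above, a failed discovery now explains why each sentinel was skipped. A hedged sketch of what callers can expect; the addresses are invented and the quoted message shape follows the formatting in this diff.

from redis.sentinel import MasterNotFoundError, Sentinel

sentinel = Sentinel([("10.0.0.5", 26379), ("10.0.0.6", 26379)], socket_timeout=0.5)
try:
    host, port = sentinel.discover_master("mymaster")
except MasterNotFoundError as exc:
    # Message now lists each unreachable sentinel and its error, roughly:
    #   No master found for 'mymaster' : <redis.Redis ...> - ConnectionError(...), ...
    print(exc)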
-class ClusterCommandsProtocol(CommandsProtocol): +class ClusterCommandsProtocol(CommandsProtocol, Protocol): encoder: "Encoder" def execute_command(self, *args, **options) -> Union[Any, Awaitable]: diff --git a/setup.py b/setup.py index e6fa0bd062..b68ceaaf18 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ install_requires=[ 'importlib-metadata >= 1.0; python_version < "3.8"', 'typing-extensions; python_version<"3.8"', - 'async-timeout>=4.0.2; python_version<"3.11"', + 'async-timeout>=4.0.2; python_full_version<="3.11.2"', ], classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/tests/asynctests b/tests/asynctests deleted file mode 100644 index 4f0fea9223..0000000000 --- a/tests/asynctests +++ /dev/null @@ -1,285 +0,0 @@ -test_response_callbacks -test_case_insensitive_command_names -test_command_on_invalid_key_type -test_acl_cat_no_category -test_acl_cat_with_category -test_acl_deluser -test_acl_genpass -test_acl_getuser_setuser -test_acl_list -test_acl_log -test_acl_setuser_categories_without_prefix_fails -test_acl_setuser_commands_without_prefix_fails -test_acl_setuser_add_passwords_and_nopass_fails -test_acl_users -test_acl_whoami -test_client_list -test_client_list_type -test_client_id -test_client_unblock -test_client_getname -test_client_setname -test_client_kill -test_client_kill_filter_invalid_params -test_client_kill_filter_by_id -test_client_kill_filter_by_addr -test_client_list_after_client_setname -test_client_pause -test_config_get -test_config_resetstat -test_config_set -test_dbsize -test_echo -test_info -test_lastsave -test_object -test_ping -test_slowlog_get -test_slowlog_get_limit -test_slowlog_length -test_time -test_never_decode_option -test_empty_response_option -test_append -test_bitcount -test_bitop_not_empty_string -test_bitop_not -test_bitop_not_in_place -test_bitop_single_string -test_bitop_string_operands -test_bitpos -test_bitpos_wrong_arguments -test_decr -test_decrby -test_delete -test_delete_with_multiple_keys -test_delitem -test_unlink -test_unlink_with_multiple_keys -test_dump_and_restore -test_dump_and_restore_and_replace -test_dump_and_restore_absttl -test_exists -test_exists_contains -test_expire -test_expireat_datetime -test_expireat_no_key -test_expireat_unixtime -test_get_and_set -test_get_set_bit -test_getrange -test_getset -test_incr -test_incrby -test_incrbyfloat -test_keys -test_mget -test_mset -test_msetnx -test_pexpire -test_pexpireat_datetime -test_pexpireat_no_key -test_pexpireat_unixtime -test_psetex -test_psetex_timedelta -test_pttl -test_pttl_no_key -test_randomkey -test_rename -test_renamenx -test_set_nx -test_set_xx -test_set_px -test_set_px_timedelta -test_set_ex -test_set_ex_timedelta -test_set_multipleoptions -test_set_keepttl -test_setex -test_setnx -test_setrange -test_strlen -test_substr -test_ttl -test_ttl_nokey -test_type -test_blpop -test_brpop -test_brpoplpush -test_brpoplpush_empty_string -test_lindex -test_linsert -test_llen -test_lpop -test_lpush -test_lpushx -test_lrange -test_lrem -test_lset -test_ltrim -test_rpop -test_rpoplpush -test_rpush -test_lpos -test_rpushx -test_scan -test_scan_type -test_scan_iter -test_sscan -test_sscan_iter -test_hscan -test_hscan_iter -test_zscan -test_zscan_iter -test_sadd -test_scard -test_sdiff -test_sdiffstore -test_sinter -test_sinterstore -test_sismember -test_smembers -test_smove -test_spop -test_spop_multi_value -test_srandmember -test_srandmember_multi_value -test_srem -test_sunion -test_sunionstore -test_zadd -test_zadd_nx -test_zadd_xx -test_zadd_ch 
-test_zadd_incr -test_zadd_incr_with_xx -test_zcard -test_zcount -test_zincrby -test_zlexcount -test_zinterstore_sum -test_zinterstore_max -test_zinterstore_min -test_zinterstore_with_weight -test_zpopmax -test_zpopmin -test_bzpopmax -test_bzpopmin -test_zrange -test_zrangebylex -test_zrevrangebylex -test_zrangebyscore -test_zrank -test_zrem -test_zrem_multiple_keys -test_zremrangebylex -test_zremrangebyrank -test_zremrangebyscore -test_zrevrange -test_zrevrangebyscore -test_zrevrank -test_zscore -test_zunionstore_sum -test_zunionstore_max -test_zunionstore_min -test_zunionstore_with_weight -test_pfadd -test_pfcount -test_pfmerge -test_hget_and_hset -test_hset_with_multi_key_values -test_hset_without_data -test_hdel -test_hexists -test_hgetall -test_hincrby -test_hincrbyfloat -test_hkeys -test_hlen -test_hmget -test_hmset -test_hsetnx -test_hvals -test_hstrlen -test_sort_basic -test_sort_limited -test_sort_by -test_sort_get -test_sort_get_multi -test_sort_get_groups_two -test_sort_groups_string_get -test_sort_groups_just_one_get -test_sort_groups_no_get -test_sort_groups_three_gets -test_sort_desc -test_sort_alpha -test_sort_store -test_sort_all_options -test_sort_issue_924 -test_cluster_addslots -test_cluster_count_failure_reports -test_cluster_countkeysinslot -test_cluster_delslots -test_cluster_failover -test_cluster_forget -test_cluster_info -test_cluster_keyslot -test_cluster_meet -test_cluster_nodes -test_cluster_replicate -test_cluster_reset -test_cluster_saveconfig -test_cluster_setslot -test_cluster_slaves -test_readwrite -test_readonly_invalid_cluster_state -test_readonly -test_geoadd -test_geoadd_invalid_params -test_geodist -test_geodist_units -test_geodist_missing_one_member -test_geodist_invalid_units -test_geohash -test_geopos -test_geopos_no_value -test_old_geopos_no_value -test_georadius -test_georadius_no_values -test_georadius_units -test_georadius_with -test_georadius_count -test_georadius_sort -test_georadius_store -test_georadius_store_dist -test_georadiusmember -test_xack -test_xadd -test_xclaim -test_xclaim_trimmed -test_xdel -test_xgroup_create -test_xgroup_create_mkstream -test_xgroup_delconsumer -test_xgroup_destroy -test_xgroup_setid -test_xinfo_consumers -test_xinfo_stream -test_xlen -test_xpending -test_xpending_range -test_xrange -test_xread -test_xreadgroup -test_xrevrange -test_xtrim -test_bitfield_operations -test_bitfield_ro -test_memory_stats -test_memory_usage -test_module_list -test_binary_get_set -test_binary_lists -test_22_info -test_large_responses -test_floating_point_encoding diff --git a/tests/conftest.py b/tests/conftest.py index 1d9bc44375..50459420ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -427,7 +427,7 @@ def mock_cluster_resp_slaves(request, **kwargs): def master_host(request): url = request.config.getoption("--redis-url") parts = urlparse(url) - yield parts.hostname, parts.port + return parts.hostname, (parts.port or 6379) def wait_for_command(client, monitor, command, key=None): diff --git a/tests/ssl_utils.py b/tests/ssl_utils.py new file mode 100644 index 0000000000..ab9c2e8944 --- /dev/null +++ b/tests/ssl_utils.py @@ -0,0 +1,14 @@ +import os + + +def get_ssl_filename(name): + root = os.path.join(os.path.dirname(__file__), "..") + cert_dir = os.path.abspath(os.path.join(root, "dockers", "stunnel", "keys")) + if not os.path.isdir(cert_dir): # github actions package validation case + cert_dir = os.path.abspath( + os.path.join(root, "..", "dockers", "stunnel", "keys") + ) + if not os.path.isdir(cert_dir): + raise 
IOError(f"No SSL certificates found. They should be in {cert_dir}") + + return os.path.join(cert_dir, name) diff --git a/tests/synctests b/tests/synctests deleted file mode 100644 index b0de2d1ba9..0000000000 --- a/tests/synctests +++ /dev/null @@ -1,421 +0,0 @@ -test_response_callbacks -test_case_insensitive_command_names -test_auth -test_command_on_invalid_key_type -test_acl_cat_no_category -test_acl_cat_with_category -test_acl_dryrun -test_acl_deluser -test_acl_genpass -test_acl_getuser_setuser -test_acl_help -test_acl_list -test_acl_log -test_acl_setuser_categories_without_prefix_fails -test_acl_setuser_commands_without_prefix_fails -test_acl_setuser_add_passwords_and_nopass_fails -test_acl_users -test_acl_whoami -test_client_list -test_client_info -test_client_list_types_not_replica -test_client_list_replica -test_client_list_client_id -test_client_id -test_client_trackinginfo -test_client_tracking -test_client_unblock -test_client_getname -test_client_setname -test_client_kill -test_client_kill_filter_invalid_params -test_client_kill_filter_by_id -test_client_kill_filter_by_addr -test_client_list_after_client_setname -test_client_kill_filter_by_laddr -test_client_kill_filter_by_user -test_client_pause -test_client_pause_all -test_client_unpause -test_client_no_evict -test_client_reply -test_client_getredir -test_hello_notI_implemented -test_config_get -test_config_get_multi_params -test_config_resetstat -test_config_set -test_config_set_multi_params -test_failover -test_dbsize -test_echo -test_info -test_info_multi_sections -test_lastsave -test_lolwut -test_reset -test_object -test_ping -test_quit -test_role -test_select -test_slowlog_get -test_slowlog_get_limit -test_slowlog_length -test_time -test_bgsave -test_never_decode_option -test_empty_response_option -test_append -test_bitcount -test_bitcount_mode -test_bitop_not_empty_string -test_bitop_not -test_bitop_not_in_place -test_bitop_single_string -test_bitop_string_operands -test_bitpos -test_bitpos_wrong_arguments -test_bitpos_mode -test_copy -test_copy_and_replace -test_copy_to_another_database -test_decr -test_decrby -test_delete -test_delete_with_multiple_keys -test_delitem -test_unlink -test_unlink_with_multiple_keys -test_lcs -test_dump_and_restore -test_dump_and_restore_and_replace -test_dump_and_restore_absttl -test_exists -test_exists_contains -test_expire -test_expire_option_nx -test_expire_option_xx -test_expire_option_gt -test_expire_option_lt -test_expireat_datetime -test_expireat_no_key -test_expireat_unixtime -test_expiretime -test_expireat_option_nx -test_expireat_option_xx -test_expireat_option_gt -test_expireat_option_lt -test_get_and_set -test_getdel -test_getex -test_getitem_and_setitem -test_getitem_raises_keyerror_for_missing_key -test_getitem_does_not_raise_keyerror_for_empty_string -test_get_set_bit -test_getrange -test_getset -test_incr -test_incrby -test_incrbyfloat -test_keys -test_mget -test_lmove -test_blmove -test_mset -test_msetnx -test_pexpire -test_pexpire_option_nx -test_pexpire_option_xx -test_pexpire_option_gt -test_pexpire_option_lt -test_pexpireat_datetime -test_pexpireat_no_key -test_pexpireat_unixtime -test_pexpireat_option_nx -test_pexpireat_option_xx -test_pexpireat_option_gt -test_pexpireat_option_lt -test_pexpiretime -test_psetex -test_psetex_timedelta -test_pttl -test_pttl_no_key -test_hrandfield -test_randomkey -test_rename -test_renamenx -test_set_nx -test_set_xx -test_set_px -test_set_px_timedelta -test_set_ex -test_set_ex_str -test_set_ex_timedelta -test_set_exat_timedelta 
-test_set_pxat_timedelta -test_set_multipleoptions -test_set_keepttl -test_set_get -test_setex -test_setnx -test_setrange -test_stralgo_lcs -test_stralgo_negative -test_strlen -test_substr -test_ttl -test_ttl_nokey -test_type -test_blpop -test_brpop -test_brpoplpush -test_brpoplpush_empty_string -test_blmpop -test_lmpop -test_lindex -test_linsert -test_llen -test_lpop -test_lpop_count -test_lpush -test_lpushx -test_lpushx_with_list -test_lrange -test_lrem -test_lset -test_ltrim -test_rpop -test_rpop_count -test_rpoplpush -test_rpush -test_lpos -test_rpushx -test_scan -test_scan_type -test_scan_iter -test_sscan -test_sscan_iter -test_hscan -test_hscan_iter -test_zscan -test_zscan_iter -test_sadd -test_scard -test_sdiff -test_sdiffstore -test_sinter -test_sintercard -test_sinterstore -test_sismember -test_smembers -test_smismember -test_smove -test_spop -test_spop_multi_value -test_srandmember -test_srandmember_multi_value -test_srem -test_sunion -test_sunionstore -test_debug_segfault -test_script_debug -test_zadd -test_zadd_nx -test_zadd_xx -test_zadd_ch -test_zadd_incr -test_zadd_incr_with_xx -test_zadd_gt_lt -test_zcard -test_zcount -test_zdiff -test_zdiffstore -test_zincrby -test_zlexcount -test_zinter -test_zintercard -test_zinterstore_sum -test_zinterstore_max -test_zinterstore_min -test_zinterstore_with_weight -test_zpopmax -test_zpopmin -test_zrandemember -test_bzpopmax -test_bzpopmin -test_zmpop -test_bzmpop -test_zrange -test_zrange_errors -test_zrange_params -test_zrangestore -test_zrangebylex -test_zrevrangebylex -test_zrangebyscore -test_zrank -test_zrem -test_zrem_multiple_keys -test_zremrangebylex -test_zremrangebyrank -test_zremrangebyscore -test_zrevrange -test_zrevrangebyscore -test_zrevrank -test_zscore -test_zunion -test_zunionstore_sum -test_zunionstore_max -test_zunionstore_min -test_zunionstore_with_weight -test_zmscore -test_pfadd -test_pfcount -test_pfmerge -test_hget_and_hset -test_hset_with_multi_key_values -test_hset_with_key_values_passed_as_list -test_hset_without_data -test_hdel -test_hexists -test_hgetall -test_hincrby -test_hincrbyfloat -test_hkeys -test_hlen -test_hmget -test_hmset -test_hsetnx -test_hvals -test_hstrlen -test_sort_basic -test_sort_limited -test_sort_by -test_sort_get -test_sort_get_multi -test_sort_get_groups_two -test_sort_groups_string_get -test_sort_groups_just_one_get -test_sort_groups_no_get -test_sort_groups_three_gets -test_sort_desc -test_sort_alpha -test_sort_store -test_sort_all_options -test_sort_ro -test_sort_issue_924 -test_cluster_addslots -test_cluster_count_failure_reports -test_cluster_countkeysinslot -test_cluster_delslots -test_cluster_failover -test_cluster_forget -test_cluster_info -test_cluster_keyslot -test_cluster_meet -test_cluster_nodes -test_cluster_replicate -test_cluster_reset -test_cluster_saveconfig -test_cluster_setslot -test_cluster_slaves -test_readwrite -test_readonly_invalid_cluster_state -test_readonly -test_geoadd -test_geoadd_nx -test_geoadd_xx -test_geoadd_ch -test_geoadd_invalid_params -test_geodist -test_geodist_units -test_geodist_missing_one_member -test_geodist_invalid_units -test_geohash -test_geopos -test_geopos_no_value -test_old_geopos_no_value -test_geosearch -test_geosearch_member -test_geosearch_sort -test_geosearch_with -test_geosearch_negative -test_geosearchstore -test_geosearchstore_dist -test_georadius -test_georadius_no_values -test_georadius_units -test_georadius_with -test_georadius_count -test_georadius_sort -test_georadius_store -test_georadius_store_dist -test_georadiusmember 
-test_georadiusmember_count -test_xack -test_xadd -test_xadd_nomkstream -test_xadd_minlen_and_limit -test_xadd_explicit_ms -test_xautoclaim -test_xautoclaim_negative -test_xclaim -test_xclaim_trimmed -test_xdel -test_xgroup_create -test_xgroup_create_mkstream -test_xgroup_create_entriesread -test_xgroup_delconsumer -test_xgroup_createconsumer -test_xgroup_destroy -test_xgroup_setid -test_xinfo_consumers -test_xinfo_stream -test_xinfo_stream_full -test_xlen -test_xpending -test_xpending_range -test_xpending_range_idle -test_xpending_range_negative -test_xrange -test_xread -test_xreadgroup -test_xrevrange -test_xtrim -test_xtrim_minlen_and_length_args -test_bitfield_operations -test -test_bitfield_ro -test_memory_help -test_memory_doctor -test_memory_malloc_stats -test_memory_stats -test_memory_usage -test_latency_histogram_not_implemented -test_latency_graph_not_implemented -test_latency_doctor_not_implemented -test_latency_history -test_latency_latest -test_latency_reset -test_module_list -test_command_count -test_command_docs -test_command_list -test_command_getkeys -test_command -test_command_getkeysandflags -test_module -test_module_loadex -test_restore -test_restore_idletime -test_restore_frequency -test_replicaof -test_shutdown -test_shutdown_with_params -test_sync -test_psync -test_binary_get_set -test_binary_lists -test_22_info -test_large_responses -test_floating_point_encoding diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index ac18f6c12d..a7d121fa49 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,7 +1,6 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager from typing import Union -from urllib.parse import urlparse import pytest import pytest_asyncio @@ -207,13 +206,6 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): return _gen_cluster_mock_resp(r, response) -@pytest_asyncio.fixture(scope="session") -def master_host(request): - url = request.config.getoption("--redis-url") - parts = urlparse(url) - return parts.hostname - - async def wait_for_command( client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None ): diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 1d12877696..2c722826e1 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,7 +1,6 @@ import asyncio import binascii import datetime -import os import warnings from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union from urllib.parse import urlparse @@ -10,7 +9,7 @@ import pytest_asyncio from _pytest.fixtures import FixtureRequest from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster -from redis.asyncio.connection import Connection, SSLConnection +from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name @@ -36,6 +35,7 @@ skip_unless_arch_bits, ) +from ..ssl_utils import get_ssl_filename from .compat import mock pytestmark = pytest.mark.onlycluster @@ -49,6 +49,59 @@ ] +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.send_event = asyncio.Event() + self.server = None + self.task = None + self.n_connections = 0 + + async def 
start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + self.n_connections += 1 + pipe1 = asyncio.create_task(self.pipe(reader, redis_writer)) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer)) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def aclose(self): + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + await self.server.wait_closed() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + while True: + data = await reader.read(1000) + if not data: + break + writer.write(data) + await writer.drain() + + @pytest_asyncio.fixture() async def slowlog(r: RedisCluster) -> None: """ @@ -340,23 +393,6 @@ async def test_from_url(self, request: FixtureRequest) -> None: rc = RedisCluster.from_url("rediss://localhost:16379") assert rc.connection_kwargs["connection_class"] is SSLConnection - async def test_asynckills(self, r) -> None: - - await r.set("foo", "foo") - await r.set("bar", "bar") - - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(1) - t.cancel() - try: - await t - except asyncio.CancelledError: - pytest.fail("connection is left open with unread response") - - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" - async def test_max_connections( self, create_redis: Callable[..., RedisCluster] ) -> None: @@ -826,6 +862,49 @@ async def test_default_node_is_replaced_after_exception(self, r): # Rollback to the old default node r.replace_default_node(curr_default_node) + async def test_address_remap(self, create_redis, master_host): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + hostname, master_port = master_host + ports = [master_port + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (hostname, port)) for port in ports + ] + await asyncio.gather(*[p.start() for p in proxies]) + try: + # create cluster: + r = await create_redis( + cls=RedisCluster, flushdb=False, address_remap=address_remap + ) + try: + assert await r.ping() is True + assert await r.set("byte_string", b"giraffe") + assert await r.get("byte_string") == b"giraffe" + finally: + await r.close() + finally: + await asyncio.gather(*[p.aclose() for p in proxies]) + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + class TestClusterRedisCommands: """ @@ -927,6 +1006,13 @@ async def test_cluster_myid(self, r: RedisCluster) -> None: myid = await r.cluster_myid(node) assert len(myid) == 40 + @skip_if_server_version_lt("7.2.0") + @skip_if_redis_enterprise() + async def test_cluster_myshardid(self, r: RedisCluster) 
-> None: + node = r.get_random_node() + myshardid = await r.cluster_myshardid(node) + assert len(myshardid) == 40 + @skip_if_redis_enterprise() async def test_cluster_slots(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, default_cluster_slots) @@ -2690,17 +2776,8 @@ class TestSSL: appropriate port. """ - ROOT = os.path.join(os.path.dirname(__file__), "../..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "dockers", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "dockers", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") @pytest_asyncio.fixture() def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 7e7a40adf3..bcedda80ea 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1,9 +1,11 @@ """ Tests async overrides of commands from their mixins """ +import asyncio import binascii import datetime import re +import sys from string import ascii_letters import pytest @@ -20,6 +22,11 @@ skip_unless_arch_bits, ) +if sys.version_info >= (3, 11, 3): + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + REDIS_6_VERSION = "5.9.0" @@ -446,6 +453,28 @@ async def test_client_pause(self, r: redis.Redis): with pytest.raises(exceptions.RedisError): await r.client_pause(timeout="not an integer") + @skip_if_server_version_lt("7.2.0") + @pytest.mark.onlynoncluster + async def test_client_no_touch(self, r: redis.Redis): + assert await r.client_no_touch("ON") == b"OK" + assert await r.client_no_touch("OFF") == b"OK" + with pytest.raises(TypeError): + await r.client_no_touch() + + @skip_if_server_version_lt("7.2.0") + @pytest.mark.onlycluster + async def test_waitaof(self, r): + # must return a list of 2 elements + assert len(await r.waitaof(0, 0, 0)) == 2 + assert len(await r.waitaof(1, 0, 0)) == 2 + assert len(await r.waitaof(1, 0, 1000)) == 2 + + # value is out of range, value must be between 0 and 1 + with pytest.raises(exceptions.ResponseError): + await r.waitaof(2, 0, 0) + with pytest.raises(exceptions.ResponseError): + await r.waitaof(-1, 0, 0) + async def test_config_get(self, r: redis.Redis): data = await r.config_get() assert "maxmemory" in data @@ -1696,6 +1725,15 @@ async def test_zrank(self, r: redis.Redis): assert await r.zrank("a", "a2") == 1 assert await r.zrank("a", "a6") is None + @skip_if_server_version_lt("7.2.0") + async def test_zrank_withscore(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrank("a", "a1") == 0 + assert await r.zrank("a", "a2") == 1 + assert await r.zrank("a", "a6") is None + assert await r.zrank("a", "a3", withscore=True) == [2, "3"] + assert await r.zrank("a", "a6", withscore=True) is None + async def test_zrem(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zrem("a", "a2") == 1 @@ -1784,6 +1822,15 @@ async def test_zrevrank(self, r: redis.Redis): assert await r.zrevrank("a", "a2") == 3 assert await r.zrevrank("a", "a6") is None +
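+    # ZREVRANK with WITHSCORE (new in Redis 7.2) returns a [rank, score] pair instead of a bare rank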
@skip_if_server_version_lt("7.2.0") + async def test_zrevrank_withscore(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrevrank("a", "a1") == 4 + assert await r.zrevrank("a", "a2") == 3 + assert await r.zrevrank("a", "a6") is None + assert await r.zrevrank("a", "a3", withscore=True) == [2, "3"] + assert await r.zrevrank("a", "a6", withscore=True) is None + async def test_zscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zscore("a", "a1") == 1.0 @@ -3079,6 +3126,37 @@ async def test_module_list(self, r: redis.Redis): for x in await r.module_list(): assert isinstance(x, dict) + @pytest.mark.onlynoncluster + async def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An unhandled BaseException + will leave the socket with an unread response to a previous + command. + """ + ready = asyncio.Event() + + async def helper(): + with pytest.raises(asyncio.CancelledError): + # blocking pop + ready.set() + await r.brpop(["nonexist"]) + # If the following is not done, further Timeout operations will fail, + # because the timeout won't catch its CancelledError if the task + # has a pending cancel. Python documentation probably should reflect this. + if sys.version_info >= (3, 11): + asyncio.current_task().uncancel() + # if all is well, we can continue. The following should not hang. + await r.set("status", "down") + + task = asyncio.create_task(helper()) + await ready.wait() + await asyncio.sleep(0.01) + # the task is now sleeping, let's send it an exception + task.cancel() + # If all is well, the task should finish right away, otherwise fail with Timeout + async with async_timeout(1.0): + await task + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py new file mode 100644 index 0000000000..bead7208f5 --- /dev/null +++ b/tests/test_asyncio/test_connect.py @@ -0,0 +1,144 @@ +import asyncio +import logging +import re +import socket +import ssl + +import pytest +from redis.asyncio.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, +) + +from ..ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +async def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + await _assert_connect(conn, tcp_address) + + +async def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection( + path=path, client_name=_CLIENT_NAME, socket_timeout=10 + ) + await _assert_connect(conn, path) + + +@pytest.mark.ssl +async def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +async def _assert_connect(conn, server_address,
certfile=None, keyfile=None): + stop_event = asyncio.Event() + finished = asyncio.Event() + + async def _handler(reader, writer): + try: + return await _redis_request_handler(reader, writer, stop_event) + finally: + finished.set() + + if isinstance(server_address, str): + server = await asyncio.start_unix_server(_handler, path=server_address) + elif certfile: + host, port = server_address + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.load_cert_chain(certfile=certfile, keyfile=keyfile) + server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) + else: + host, port = server_address + server = await asyncio.start_server(_handler, host=host, port=port) + + async with server as aserver: + await aserver.start_serving() + try: + await conn.connect() + await conn.disconnect() + finally: + stop_event.set() + aserver.close() + await aserver.wait_closed() + await finished.wait() + + +async def _redis_request_handler(reader, writer, stop_event): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while not stop_event.is_set() or buffer: + _logger.info(str(stop_event.is_set())) + try: + buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) + except TimeoutError: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response from %s", resp) + writer.write(resp) + await writer.drain() + command = None + _logger.info("Exit handler") diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 926b432b62..ee4a107566 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -43,27 +43,6 @@ async def test_invalid_response(create_redis): await r.connection.disconnect() -async def test_asynckills(): - - for b in [True, False]: - r = Redis(single_connection_client=b) - - await r.set("foo", "foo") - await r.set("bar", "bar") - - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(1) - t.cancel() - try: - await t - except asyncio.CancelledError: - pytest.fail("connection left open with unread response") - - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" - - @pytest.mark.onlynoncluster async def test_single_connection(): """Test that concurrent requests on a single client are synchronised.""" @@ -204,7 +183,7 @@ async def test_connection_parse_response_resume(r: redis.Redis): conn._parser._stream = MockStream(message, interrupt_every=2) for i in range(100): try: - response = await conn.read_response() + response = await conn.read_response(disconnect_on_error=False) break except MockStream.TestError: pass @@ -293,3 +272,9 @@ async def open_connection(*args, **kwargs): vals = await asyncio.gather(do_read(), do_close()) assert vals == [b"Hello, 
World!", None] + + +@pytest.mark.onlynoncluster +def test_create_single_connection_client_from_url(): + client = Redis.from_url("redis://localhost:6379/0?", single_connection_client=True) + assert client.single_connection_client is True diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 20c2c79c84..7672dc74b4 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -135,14 +135,14 @@ async def test_connection_creation(self): assert connection.kwargs == connection_kwargs async def test_multiple_connections(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") c2 = await pool.get_connection("_") assert c1 != c2 async def test_max_connections(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool( max_connections=2, connection_kwargs=connection_kwargs ) as pool: @@ -152,7 +152,7 @@ async def test_max_connections(self, master_host): await pool.get_connection("_") async def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") await pool.release(c1) @@ -236,7 +236,7 @@ async def test_multiple_connections(self, master_host): async def test_connection_pool_blocks_until_timeout(self, master_host): """When out of connections, block for timeout seconds, then raise""" - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs ) as pool: @@ -271,7 +271,7 @@ async def target(): assert (stop - start) <= 0.2 async def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") await pool.release(c1) @@ -607,10 +607,18 @@ async def test_busy_loading_from_pipeline(self, r): @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise() async def test_read_only_error(self, r): - """READONLY errors get turned in ReadOnlyError exceptions""" + """READONLY errors get turned into ReadOnlyError exceptions""" with pytest.raises(redis.ReadOnlyError): await r.execute_command("DEBUG", "ERROR", "READONLY blah blah") + @skip_if_redis_enterprise() + async def test_oom_error(self, r): + """OOM errors get turned into OutOfMemoryError exceptions""" + with pytest.raises(redis.OutOfMemoryError): + # note: don't use the DEBUG OOM command since it's not the same + # as the db being full + await r.execute_command("DEBUG", "ERROR", "OOM blah blah") + def test_connect_from_url_tcp(self): connection = redis.Redis.from_url("redis://localhost") pool = connection.connection_pool diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py new file mode 100644 index 0000000000..ff588861e4 --- /dev/null +++ b/tests/test_asyncio/test_cwe_404.py @@ -0,0 +1,250 @@ +import asyncio +import contextlib + +import pytest +from redis.asyncio import Redis +from redis.asyncio.cluster import RedisCluster +from redis.asyncio.connection import async_timeout + 
+ +class DelayProxy: + def __init__(self, addr, redis_addr, delay: float = 0.0): + self.addr = addr + self.redis_addr = redis_addr + self.delay = delay + self.send_event = asyncio.Event() + self.server = None + self.task = None + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, *args): + await self.stop() + + async def start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + @contextlib.contextmanager + def set_delay(self, delay: float = 0.0): + """ + Allow overriding the delay for parts of tests which aren't time-dependent, + to speed up execution. + """ + old_delay = self.delay + self.delay = delay + try: + yield + finally: + self.delay = old_delay + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + pipe1 = asyncio.create_task( + self.pipe(reader, redis_writer, "to redis:", self.send_event) + ) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:")) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def stop(self): + # clean up enough so that we can reuse the event loop + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + loop = self.server.get_loop() + await loop.shutdown_asyncgens() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + name="", + event: asyncio.Event = None, + ): + while True: + data = await reader.read(1000) + if not data: + break + # print(f"{name} read {len(data)} delay {self.delay}") + if event: + event.set() + await asyncio.sleep(self.delay) + writer.write(data) + await writer.drain() + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone(delay, master_host): + + # create a tcp socket proxy that relays data to Redis and back, + # inserting a configurable delay while the operation under test is in flight + async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: + + for b in [True, False]: + # note that we connect to proxy, rather than to Redis directly + async with Redis( + host="127.0.0.1", port=5380, single_connection_client=b + ) as r: + + await r.set("foo", "foo") + await r.set("bar", "bar") + + async def op(r): + with dp.set_delay(delay * 2): + return await r.get( + "foo" + ) # <-- this is the operation we want to cancel + + dp.send_event.clear() + t = asyncio.create_task(op(r)) + # Wait until the task has sent, and then some, to make sure it has + # settled on the read.
+ await dp.send_event.wait() + await asyncio.sleep(0.01) # a little extra time for prudence + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # make sure that our previous request, cancelled while waiting for + # a response, didn't leave the connection open and in a bad state + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone_pipeline(delay, master_host): + async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: + for b in [True, False]: + async with Redis( + host="127.0.0.1", port=5380, single_connection_client=b + ) as r: + await r.set("foo", "foo") + await r.set("bar", "bar") + + pipe = r.pipeline() + + pipe2 = r.pipeline() + pipe2.get("bar") + pipe2.ping() + pipe2.get("foo") + + async def op(pipe): + with dp.set_delay(delay * 2): + return await pipe.get( + "foo" + ).execute() # <-- this is the operation we want to cancel + + dp.send_event.clear() + t = asyncio.create_task(op(pipe)) + # wait until task has settled on the read + await dp.send_event.wait() + await asyncio.sleep(0.01) + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # we have now cancelled the pipeline in the middle of a request, + # make sure that the connection is still usable + pipe.get("bar") + pipe.ping() + pipe.get("foo") + await pipe.reset() + + # check that the pipeline is empty after reset + assert await pipe.execute() == [] + + # validating that the pipeline can be used as it could previously + pipe.get("bar") + pipe.ping() + pipe.get("foo") + assert await pipe.execute() == [b"bar", True, b"foo"] + assert await pipe2.execute() == [b"bar", True, b"foo"] + + +@pytest.mark.onlycluster +async def test_cluster(master_host): + + delay = 0.1 + cluster_port = 16379 + remap_base = 7372 + n_nodes = 6 + hostname, _ = master_host + + def remap(address): + host, port = address + return host, remap_base + port - cluster_port + + proxies = [] + for i in range(n_nodes): + port = cluster_port + i + remapped = remap_base + i + forward_addr = hostname, port + proxy = DelayProxy(addr=("127.0.0.1", remapped), redis_addr=forward_addr) + proxies.append(proxy) + + def all_clear(): + for p in proxies: + p.send_event.clear() + + async def wait_for_send(): + await asyncio.wait( + [asyncio.create_task(p.send_event.wait()) for p in proxies], + return_when=asyncio.FIRST_COMPLETED, + ) + + @contextlib.contextmanager + def set_delay(delay: float): + with contextlib.ExitStack() as stack: + for p in proxies: + stack.enter_context(p.set_delay(delay)) + yield + + async with contextlib.AsyncExitStack() as stack: + for p in proxies: + await stack.enter_async_context(p) + + with contextlib.closing( + RedisCluster.from_url( + f"redis://127.0.0.1:{remap_base}", address_remap=remap + ) + ) as r: + await r.initialize() + await r.set("foo", "foo") + await r.set("bar", "bar") + + async def op(r): + with set_delay(delay): + return await r.get("foo") + + all_clear() + t = asyncio.create_task(op(r)) + # Wait for whichever DelayProxy gets the request first + await wait_for_send() + await asyncio.sleep(0.01) + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # try a number of requests to exercise all the connections + async def doit(): + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + await asyncio.gather(*[doit() for _ in range(10)]) diff --git a/tests/test_asyncio/test_json.py
b/tests/test_asyncio/test_json.py index 78176f4710..6f3e8c3251 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -49,6 +49,41 @@ async def test_nonascii_setgetdelete(decoded_r: redis.Redis): assert await decoded_r.exists("notascii") == 0 +@pytest.mark.redismod +@skip_ifmodversion_lt("2.6.0", "ReJSON") +async def test_json_merge(decoded_r: redis.Redis): + # Test with root path $ + assert await decoded_r.json().set( + "person_data", + "$", + {"person1": {"personal_data": {"name": "John"}}}, + ) + assert await decoded_r.json().merge( + "person_data", "$", {"person1": {"personal_data": {"hobbies": "reading"}}} + ) + assert await decoded_r.json().get("person_data") == { + "person1": {"personal_data": {"name": "John", "hobbies": "reading"}} + } + + # Test with root path path $.person1.personal_data + assert await decoded_r.json().merge( + "person_data", "$.person1.personal_data", {"country": "Israel"} + ) + assert await decoded_r.json().get("person_data") == { + "person1": { + "personal_data": {"name": "John", "hobbies": "reading", "country": "Israel"} + } + } + + # Test with null value to delete a value + assert await decoded_r.json().merge( + "person_data", "$.person1.personal_data", {"name": None} + ) + assert await decoded_r.json().get("person_data") == { + "person1": {"personal_data": {"country": "Israel", "hobbies": "reading"}} + } + + @pytest.mark.redismod async def test_jsonsetexistentialmodifiersshouldsucceed(decoded_r: redis.Redis): obj = {"foo": "bar"} @@ -76,6 +111,17 @@ async def test_mgetshouldsucceed(decoded_r: redis.Redis): assert await decoded_r.json().mget([1, 2], Path.root_path()) == [1, 2] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.6.0", "ReJSON") +async def test_mset(decoded_r: redis.Redis): + await decoded_r.json().mset( + [("1", Path.root_path(), 1), ("2", Path.root_path(), 2)] + ) + + assert await decoded_r.json().mget(["1"], Path.root_path()) == [1] + assert await decoded_r.json().mget(["1", "2"], Path.root_path()) == [1, 2] + + @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release async def test_clear(decoded_r: redis.Redis): diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 8354abe45b..8cac17dac5 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -5,7 +5,9 @@ from typing import Optional from unittest.mock import patch -if sys.version_info.major >= 3 and sys.version_info.minor >= 11: +# the functionality is available in 3.11.x but has a major issue before +# 3.11.3. 
See https://github.com/redis/redis-py/issues/2633 +if sys.version_info >= (3, 11, 3): from asyncio import timeout as async_timeout else: from async_timeout import timeout as async_timeout @@ -984,6 +986,9 @@ async def get_msg_or_timeout(timeout=0.1): # the timeout on the read should not cause disconnect assert pubsub.connection.is_connected + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="requires python 3.8 or higher" + ) async def test_base_exception(self, r: redis.Redis): """ Manually trigger a BaseException inside the parser's .read_response method diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 4f32ecdc08..2091f2cb87 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -14,7 +14,7 @@ @pytest_asyncio.fixture(scope="module") def master_ip(master_host): - yield socket.gethostbyname(master_host) + yield socket.gethostbyname(master_host[0]) class SentinelTestClient: diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 834831fabd..a3a2a6beab 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,5 +1,9 @@ import binascii import datetime +import select +import socket +import socketserver +import threading import warnings from queue import LifoQueue, Queue from time import sleep @@ -53,6 +57,73 @@ ] +class ProxyRequestHandler(socketserver.BaseRequestHandler): + def recv(self, sock): + """A recv with a timeout""" + r = select.select([sock], [], [], 0.01) + if not r[0]: + return None + return sock.recv(1000) + + def handle(self): + self.server.proxy.n_connections += 1 + conn = socket.create_connection(self.server.proxy.redis_addr) + stop = False + + def from_server(): + # read from server and pass to client + while not stop: + data = self.recv(conn) + if data is None: + continue + if not data: + self.request.shutdown(socket.SHUT_WR) + return + self.request.sendall(data) + + thread = threading.Thread(target=from_server) + thread.start() + try: + while True: + # read from client and send to server + data = self.request.recv(1000) + if not data: + return + conn.sendall(data) + finally: + conn.shutdown(socket.SHUT_WR) + stop = True # for safety + thread.join() + conn.close() + + +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler) + self.server.proxy = self + self.server.socket_reuse_address = True + self.thread = None + self.n_connections = 0 + + def start(self): + # test that we can connect to redis + s = socket.create_connection(self.redis_addr, timeout=2) + s.close() + # Start a thread with the server -- that thread will then start one + # more thread for each request + self.thread = threading.Thread(target=self.server.serve_forever) + # Exit the server thread when the main thread terminates + self.thread.daemon = True + self.thread.start() + + def close(self): + self.server.shutdown() + + @pytest.fixture() def slowlog(request, r): """ @@ -823,6 +894,51 @@ def raise_connection_error(): assert "myself" not in nodes.get(curr_default_node.name).get("flags") assert r.get_default_node() != curr_default_node + def test_address_remap(self, request, master_host): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + hostname, master_port = master_host + 
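+        # the master_host fixture yields a (hostname, port) tuple; the test assumes the remaining cluster nodes listen on the ports immediately following the master's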
ports = [master_port + i for i in range(n)] + + def address_remap(address): + # remap the first n nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (hostname, port)) for port in ports + ] + for p in proxies: + p.start() + try: + # create cluster: + r = _get_client( + RedisCluster, request, flushdb=False, address_remap=address_remap + ) + try: + assert r.ping() is True + assert r.set("byte_string", b"giraffe") + assert r.get("byte_string") == b"giraffe" + finally: + r.close() + finally: + for p in proxies: + p.close() + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + @pytest.mark.onlycluster class TestClusterRedisCommands: @@ -1046,6 +1162,13 @@ def test_cluster_shards(self, r): for attribute in node.keys(): assert attribute in attributes + @skip_if_server_version_lt("7.2.0") + @skip_if_redis_enterprise() + def test_cluster_myshardid(self, r): + myshardid = r.cluster_myshardid() + assert isinstance(myshardid, str) + assert len(myshardid) > 0 + @skip_if_redis_enterprise() def test_cluster_addslots(self, r): node = r.get_random_node() diff --git a/tests/test_commands.py b/tests/test_commands.py index a024167877..1f17552c15 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,9 +1,12 @@ import binascii import datetime import re +import threading import time +from asyncio import CancelledError from string import ascii_letters from unittest import mock +from unittest.mock import patch import pytest import redis @@ -708,6 +711,28 @@ def test_client_no_evict(self, r): with pytest.raises(TypeError): r.client_no_evict() + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.2.0") + def test_client_no_touch(self, r): + assert r.client_no_touch("ON") == b"OK" + assert r.client_no_touch("OFF") == b"OK" + with pytest.raises(TypeError): + r.client_no_touch() + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.2.0") + def test_waitaof(self, r): + # must return a list of 2 elements + assert len(r.waitaof(0, 0, 0)) == 2 + assert len(r.waitaof(1, 0, 0)) == 2 + assert len(r.waitaof(1, 0, 1000)) == 2 + + # value is out of range, value must be between 0 and 1 + with pytest.raises(exceptions.ResponseError): + r.waitaof(2, 0, 0) + with pytest.raises(exceptions.ResponseError): + r.waitaof(-1, 0, 0) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("3.2.0") def test_client_reply(self, r, r_timeout): @@ -876,6 +901,8 @@ def test_slowlog_get(self, r, slowlog): # make sure other attributes are typed correctly assert isinstance(slowlog[0]["start_time"], int) assert isinstance(slowlog[0]["duration"], int) + assert isinstance(slowlog[0]["client_address"], bytes) + assert isinstance(slowlog[0]["client_name"], bytes) # Mock result if we didn't get slowlog complexity info.
if "complexity" not in slowlog[0]: @@ -2710,6 +2737,15 @@ def test_zrank(self, r): assert r.zrank("a", "a2") == 1 assert r.zrank("a", "a6") is None + @skip_if_server_version_lt("7.2.0") + def test_zrank_withscore(self, r: redis.Redis): + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrank("a", "a1") == 0 + assert r.rank("a", "a2") == 1 + assert r.zrank("a", "a6") is None + assert r.zrank("a", "a3", withscore=True) == [2, "3"] + assert r.zrank("a", "a6", withscore=True) is None + def test_zrem(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert r.zrem("a", "a2") == 1 @@ -2796,6 +2832,15 @@ def test_zrevrank(self, r): assert r.zrevrank("a", "a2") == 3 assert r.zrevrank("a", "a6") is None + @skip_if_server_version_lt("7.2.0") + def test_zrevrank_withscore(self, r): + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrevrank("a", "a1") == 4 + assert r.zrevrank("a", "a2") == 3 + assert r.zrevrank("a", "a6") is None + assert r.zrevrank("a", "a3", withscore=True) == [2, "3"] + assert r.zrevrank("a", "a6", withscore=True) is None + def test_zscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert r.zscore("a", "a1") == 1.0 @@ -3663,6 +3708,12 @@ def test_geosearchstore_dist(self, r): # instead of save the geo score, the distance is saved. assert r.zscore("places_barcelona", "place1") == 88.05060698409301 + @skip_if_server_version_lt("3.2.0") + def test_georadius_Issue2609(self, r): + # test for issue #2609 (Geo search functions don't work with execute_command) + r.geoadd(name="my-key", values=[1, 2, "data"]) + assert r.execute_command("GEORADIUS", "my-key", 1, 2, 400, "m") == [b"data"] + @skip_if_server_version_lt("3.2.0") def test_georadius(self, r): values = (2.1909389952632, 41.433791470673, "place1") + ( @@ -4928,6 +4979,38 @@ def test_psync(self, r): res = r2.psync(r2.client_id(), 1) assert b"FULLRESYNC" in res + @pytest.mark.onlynoncluster + def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + + ok = False + + def helper(): + with pytest.raises(CancelledError): + # blocking pop + with patch.object( + r.connection._parser, "read_response", side_effect=CancelledError + ): + r.brpop(["nonexist"]) + # if all is well, we can continue. 
+ r.set("status", "down") # should not hang + nonlocal ok + ok = True + + thread = threading.Thread(target=helper) + thread.start() + thread.join(0.1) + try: + assert not thread.is_alive() + assert ok + finally: + # disconnect here so that fixture cleanup can proceed + r.connection.disconnect() + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_connect.py b/tests/test_connect.py new file mode 100644 index 0000000000..b233c67e83 --- /dev/null +++ b/tests/test_connect.py @@ -0,0 +1,184 @@ +import logging +import re +import socket +import socketserver +import ssl +import threading + +import pytest +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection + +from .ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, tcp_address) + + +def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, path) + + +@pytest.mark.ssl +def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +def _assert_connect(conn, server_address, certfile=None, keyfile=None): + if isinstance(server_address, str): + server = _RedisUDSServer(server_address, _RedisRequestHandler) + else: + server = _RedisTCPServer( + server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile + ) + with server as aserver: + t = threading.Thread(target=aserver.serve_forever) + t.start() + try: + aserver.wait_online() + conn.connect() + conn.disconnect() + finally: + aserver.stop() + t.join(timeout=5) + + +class _RedisTCPServer(socketserver.TCPServer): + def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None: + self._ready_event = threading.Event() + self._stop_requested = False + self._certfile = certfile + self._keyfile = keyfile + super().__init__(*args, **kw) + + def service_actions(self): + self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + + def get_request(self): + if self._certfile is None: + return super().get_request() + newsocket, fromaddr = self.socket.accept() + connstream = ssl.wrap_socket( + newsocket, + server_side=True, + certfile=self._certfile, + keyfile=self._keyfile, + ssl_version=ssl.PROTOCOL_TLSv1_2, + ) + return connstream, fromaddr + + +class _RedisUDSServer(socketserver.UnixStreamServer): + def __init__(self, *args, **kw) -> None: + self._ready_event = threading.Event() + self._stop_requested = False + super().__init__(*args, **kw) + + def service_actions(self): + 
self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + + +class _RedisRequestHandler(socketserver.StreamRequestHandler): + def setup(self): + _logger.info("%s connected", self.client_address) + + def finish(self): + _logger.info("%s disconnected", self.client_address) + + def handle(self): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while self.server.is_serving() or buffer: + try: + buffer += self.request.recv(1024) + except socket.timeout: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response %s", resp) + self.request.sendall(resp) + command = None + _logger.info("Exit handler") diff --git a/tests/test_connection.py b/tests/test_connection.py index 1ae3d73ede..64ae4c5d1f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -156,7 +156,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): conn._parser._sock = mock_socket for i in range(100): try: - response = conn.read_response() + response = conn.read_response(disconnect_on_error=False) break except MockSocket.TestError: pass @@ -201,3 +201,11 @@ def test_pack_command(Class): actual = Class().pack_command(*cmd)[0] assert actual == expected, f"actual = {actual}, expected = {expected}" + + +@pytest.mark.onlynoncluster +def test_create_single_connection_client_from_url(): + client = redis.Redis.from_url( + "redis://localhost:6379/0?", single_connection_client=True + ) + assert client.connection is not None diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 888e0226eb..ab0fc6be98 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -528,10 +528,17 @@ def test_busy_loading_from_pipeline(self, r): @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise() def test_read_only_error(self, r): - "READONLY errors get turned in ReadOnlyError exceptions" + "READONLY errors get turned into ReadOnlyError exceptions" with pytest.raises(redis.ReadOnlyError): r.execute_command("DEBUG", "ERROR", "READONLY blah blah") + def test_oom_error(self, r): + "OOM errors get turned into OutOfMemoryError exceptions" + with pytest.raises(redis.OutOfMemoryError): + # note: don't use the DEBUG OOM command since it's not the same + # as the db being full + r.execute_command("DEBUG", "ERROR", "OOM blah blah") + def test_connect_from_url_tcp(self): connection = redis.Redis.from_url("redis://localhost") pool = connection.connection_pool diff --git a/tests/test_json.py b/tests/test_json.py index a1271386d9..fb608ff425 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -47,6 +47,39 @@ def test_json_get_jset(client): assert client.exists("foo") == 0 +@pytest.mark.redismod 
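+# JSON.MERGE follows JSON Merge Patch semantics (RFC 7396): nested objects are merged recursively and a null value deletes the matching key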
+@skip_ifmodversion_lt("2.6.0", "ReJSON") # todo: update after the release +def test_json_merge(client): + # Test with root path $ + assert client.json().set( + "person_data", + "$", + {"person1": {"personal_data": {"name": "John"}}}, + ) + assert client.json().merge( + "person_data", "$", {"person1": {"personal_data": {"hobbies": "reading"}}} + ) + assert client.json().get("person_data") == { + "person1": {"personal_data": {"name": "John", "hobbies": "reading"}} + } + + # Test with root path path $.person1.personal_data + assert client.json().merge( + "person_data", "$.person1.personal_data", {"country": "Israel"} + ) + assert client.json().get("person_data") == { + "person1": { + "personal_data": {"name": "John", "hobbies": "reading", "country": "Israel"} + } + } + + # Test with null value to delete a value + assert client.json().merge("person_data", "$.person1.personal_data", {"name": None}) + assert client.json().get("person_data") == { + "person1": {"personal_data": {"country": "Israel", "hobbies": "reading"}} + } + + @pytest.mark.redismod def test_nonascii_setgetdelete(client): assert client.json().set("notascii", Path.root_path(), "hyvää-élève") @@ -85,6 +118,15 @@ def test_mgetshouldsucceed(client): assert client.json().mget([1, 2], Path.root_path()) == [1, 2] +@pytest.mark.redismod +@skip_ifmodversion_lt("2.6.0", "ReJSON") # todo: update after the release +def test_mset(client): + client.json().mset([("1", Path.root_path(), 1), ("2", Path.root_path(), 2)]) + + assert client.json().mget(["1"], Path.root_path()) == [1] + assert client.json().mget(["1", "2"], Path.root_path()) == [1, 2] + + @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release def test_clear(client): diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index d797a0467b..b7bcc27de2 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -97,6 +97,15 @@ def test_discover_master_error(sentinel): sentinel.discover_master("xxx") +@pytest.mark.onlynoncluster +def test_dead_pool(sentinel): + master = sentinel.master_for("mymaster", db=9) + conn = master.connection_pool.get_connection("_") + conn.disconnect() + del master + conn.connect() + + @pytest.mark.onlynoncluster def test_discover_master_sentinel_down(cluster, sentinel, master_ip): # Put first sentinel 'foo' down diff --git a/tests/test_ssl.py b/tests/test_ssl.py index f33e45a60b..465fdabb89 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,4 +1,3 @@ -import os import socket import ssl from urllib.parse import urlparse @@ -8,6 +7,7 @@ from redis.exceptions import ConnectionError, RedisError from .conftest import skip_if_cryptography, skip_if_nocryptography +from .ssl_utils import get_ssl_filename @pytest.mark.ssl @@ -18,17 +18,8 @@ class TestSSL: and connecting to the appropriate port. """ - ROOT = os.path.join(os.path.dirname(__file__), "..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "dockers", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "dockers", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. 
They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") def test_ssl_with_invalid_cert(self, request): ssl_url = request.config.option.redis_ssl_url