diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..0e0ace6a4a --- /dev/null +++ b/.flake8 @@ -0,0 +1,21 @@ +[flake8] +max-line-length = 88 +exclude = + *.egg-info, + *.pyc, + .git, + .tox, + .venv*, + build, + docs/*, + dist, + docker, + venv*, + .venv*, + whitelist.py, + tasks.py +ignore = + F405 + W503 + E203 + E126 \ No newline at end of file diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index f49a4fcd46..1bab506c32 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -16,6 +16,10 @@ on: schedule: - cron: '0 1 * * *' # nightly build +concurrency: + group: ${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + permissions: contents: read # to fetch code (actions/checkout) @@ -48,7 +52,7 @@ jobs: run-tests: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: max-parallel: 15 fail-fast: false @@ -68,32 +72,77 @@ jobs: - name: run tests run: | pip install -U setuptools wheel + pip install -r requirements.txt pip install -r dev_requirements.txt - tox -e ${{matrix.test-type}}-${{matrix.connection-type}} + if [ "${{matrix.connection-type}}" == "hiredis" ]; then + pip install hiredis + fi + invoke devenv + sleep 5 # time to settle + invoke ${{matrix.test-type}}-tests + - uses: actions/upload-artifact@v2 if: success() || failure() with: - name: pytest-results-${{matrix.test-type}} + name: pytest-results-${{matrix.test-type}}-${{matrix.connection-type}}-${{matrix.python-version}} path: '${{matrix.test-type}}*results.xml' + - name: Upload codecov coverage uses: codecov/codecov-action@v3 + if: ${{matrix.python-version == '3.11'}} with: fail_ci_if_error: false - # - name: View Test Results - # uses: dorny/test-reporter@v1 - # if: success() || failure() - # with: - # name: Test Results ${{matrix.python-version}} ${{matrix.test-type}}-${{matrix.connection-type}} - # path: '${{matrix.test-type}}*results.xml' - # reporter: java-junit - # list-suites: failed - # list-tests: failed - # max-annotations: 10 + + - name: View Test Results + uses: dorny/test-reporter@v1 + if: success() || failure() + continue-on-error: true + with: + name: Test Results ${{matrix.python-version}} ${{matrix.test-type}}-${{matrix.connection-type}} + path: '*.xml' + reporter: java-junit + list-suites: all + list-tests: all + max-annotations: 10 + fail-on-error: 'false' + + resp3_tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.7', '3.11'] + test-type: ['standalone', 'cluster'] + connection-type: ['hiredis', 'plain'] + protocol: ['3'] + env: + ACTIONS_ALLOW_UNSECURE_COMMANDS: true + name: RESP3 [${{ matrix.python-version }} ${{matrix.test-type}}-${{matrix.connection-type}}] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + - name: run tests + run: | + pip install -U setuptools wheel + pip install -r requirements.txt + pip install -r dev_requirements.txt + if [ "${{matrix.connection-type}}" == "hiredis" ]; then + pip install hiredis + fi + invoke devenv + sleep 5 # time to settle + invoke ${{matrix.test-type}}-tests + invoke ${{matrix.test-type}}-tests --uvloop build_and_test_package: name: Validate building and installing the package runs-on: ubuntu-latest + needs: [run-tests] strategy: + fail-fast: false matrix: extension: ['tar.gz', 'whl'] steps: diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000000..039f0337a2 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +profile=black +multi_line_output=3 +src_paths = ["redis", "tests"] +skip_glob=benchmarks/* \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8518547518..90a538be46 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,8 +38,9 @@ Here's how to get started with your code contribution: a. python -m venv .venv b. source .venv/bin/activate c. pip install -r dev_requirements.txt + c. pip install -r requirements.txt -4. If you need a development environment, run `invoke devenv` +4. If you need a development environment, run `invoke devenv`. Note: this relies on docker-compose to build environments, and assumes that you have a version supporting [docker profiles](https://docs.docker.com/compose/profiles/). 5. While developing, make sure the tests pass by running `invoke tests` 6. If you like the change and think the project could use it, send a pull request @@ -59,7 +60,6 @@ can execute docker and its various commands. - Three sentinel Redis nodes - A redis cluster - An stunnel docker, fronting the master Redis node -- A Redis node, running unstable - the latest redis The replica node, is a replica of the master node, using the [leader-follower replication](https://redis.io/topics/replication) diff --git a/benchmarks/socket_read_size.py b/benchmarks/socket_read_size.py index 3427956ced..544c733178 100644 --- a/benchmarks/socket_read_size.py +++ b/benchmarks/socket_read_size.py @@ -1,12 +1,12 @@ from base import Benchmark -from redis.connection import HiredisParser, PythonParser +from redis.connection import PythonParser, _HiredisParser class SocketReadBenchmark(Benchmark): ARGUMENTS = ( - {"name": "parser", "values": [PythonParser, HiredisParser]}, + {"name": "parser", "values": [PythonParser, _HiredisParser]}, { "name": "value_size", "values": [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000], diff --git a/dev_requirements.txt b/dev_requirements.txt index 8ffb1e944f..cdb3774ab6 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,15 +1,14 @@ click==8.0.4 black==22.3.0 flake8==5.0.4 +flake8-isort==6.0.0 flynt~=0.69.0 -isort==5.10.1 mock==4.0.3 packaging>=20.4 pytest==7.2.0 -pytest-timeout==2.0.1 +pytest-timeout==2.1.0 pytest-asyncio>=0.20.2 tox==3.27.1 -tox-docker==3.1.0 invoke==1.7.3 pytest-cov>=4.0.0 vulture>=2.3.0 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000..17d4b23977 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,109 @@ +--- + +version: "3.8" + +services: + + redis: + image: redis/redis-stack-server:edge + container_name: redis-standalone + ports: + - 6379:6379 + environment: + - "REDIS_ARGS=--enable-debug-command yes --enable-module-command yes" + profiles: + - standalone + - sentinel + - replica + - all + + replica: + image: redis/redis-stack-server:edge + container_name: redis-replica + depends_on: + - redis + environment: + - "REDIS_ARGS=--replicaof redis 6379" + ports: + - 6380:6379 + profiles: + - replica + - all + + cluster: + container_name: redis-cluster + build: + context: . + dockerfile: dockers/Dockerfile.cluster + ports: + - 16379:16379 + - 16380:16380 + - 16381:16381 + - 16382:16382 + - 16383:16383 + - 16384:16384 + volumes: + - "./dockers/cluster.redis.conf:/redis.conf:ro" + profiles: + - cluster + - all + + stunnel: + image: redisfab/stunnel:latest + depends_on: + - redis + ports: + - 6666:6666 + profiles: + - all + - standalone + - ssl + volumes: + - "./dockers/stunnel/conf:/etc/stunnel/conf.d:ro" + - "./dockers/stunnel/keys:/etc/stunnel/keys:ro" + + sentinel: + image: redis/redis-stack-server:edge + container_name: redis-sentinel + depends_on: + - redis + environment: + - "REDIS_ARGS=--port 26379" + entrypoint: "/opt/redis-stack/bin/redis-sentinel /redis.conf --port 26379" + ports: + - 26379:26379 + volumes: + - "./dockers/sentinel.conf:/redis.conf" + profiles: + - sentinel + - all + + sentinel2: + image: redis/redis-stack-server:edge + container_name: redis-sentinel2 + depends_on: + - redis + environment: + - "REDIS_ARGS=--port 26380" + entrypoint: "/opt/redis-stack/bin/redis-sentinel /redis.conf --port 26380" + ports: + - 26380:26380 + volumes: + - "./dockers/sentinel.conf:/redis.conf" + profiles: + - sentinel + - all + + sentinel3: + image: redis/redis-stack-server:edge + container_name: redis-sentinel3 + depends_on: + - redis + entrypoint: "/opt/redis-stack/bin/redis-sentinel /redis.conf --port 26381" + ports: + - 26381:26381 + volumes: + - "./dockers/sentinel.conf:/redis.conf" + profiles: + - sentinel + - all diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile deleted file mode 100644 index c76d15db36..0000000000 --- a/docker/base/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py:6.2.6 -FROM redis:6.2.6-buster - -CMD ["redis-server", "/redis.conf"] diff --git a/docker/base/Dockerfile.cluster b/docker/base/Dockerfile.cluster deleted file mode 100644 index 5c246dcf28..0000000000 --- a/docker/base/Dockerfile.cluster +++ /dev/null @@ -1,11 +0,0 @@ -# produces redisfab/redis-py-cluster:6.2.6 -FROM redis:6.2.6-buster - -COPY create_cluster.sh /create_cluster.sh -RUN chmod +x /create_cluster.sh - -EXPOSE 16379 16380 16381 16382 16383 16384 - -ENV START_PORT=16379 -ENV END_PORT=16384 -CMD /create_cluster.sh diff --git a/docker/base/Dockerfile.cluster4 b/docker/base/Dockerfile.cluster4 deleted file mode 100644 index 3158d6edd4..0000000000 --- a/docker/base/Dockerfile.cluster4 +++ /dev/null @@ -1,9 +0,0 @@ -# produces redisfab/redis-py-cluster:4.0 -FROM redis:4.0-buster - -COPY create_cluster4.sh /create_cluster4.sh -RUN chmod +x /create_cluster4.sh - -EXPOSE 16391 16392 16393 16394 16395 16396 - -CMD [ "/create_cluster4.sh"] \ No newline at end of file diff --git a/docker/base/Dockerfile.cluster5 b/docker/base/Dockerfile.cluster5 deleted file mode 100644 index 3becfc853a..0000000000 --- a/docker/base/Dockerfile.cluster5 +++ /dev/null @@ -1,9 +0,0 @@ -# produces redisfab/redis-py-cluster:5.0 -FROM redis:5.0-buster - -COPY create_cluster5.sh /create_cluster5.sh -RUN chmod +x /create_cluster5.sh - -EXPOSE 16385 16386 16387 16388 16389 16390 - -CMD [ "/create_cluster5.sh"] \ No newline at end of file diff --git a/docker/base/Dockerfile.redis4 b/docker/base/Dockerfile.redis4 deleted file mode 100644 index 7528ac1631..0000000000 --- a/docker/base/Dockerfile.redis4 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py:4.0 -FROM redis:4.0-buster - -CMD ["redis-server", "/redis.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.redis5 b/docker/base/Dockerfile.redis5 deleted file mode 100644 index 6bcbe20bfc..0000000000 --- a/docker/base/Dockerfile.redis5 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py:5.0 -FROM redis:5.0-buster - -CMD ["redis-server", "/redis.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.redismod_cluster b/docker/base/Dockerfile.redismod_cluster deleted file mode 100644 index 5b80e495fb..0000000000 --- a/docker/base/Dockerfile.redismod_cluster +++ /dev/null @@ -1,12 +0,0 @@ -# produces redisfab/redis-py-modcluster:6.2.6 -FROM redislabs/redismod:edge - -COPY create_redismod_cluster.sh /create_redismod_cluster.sh -RUN chmod +x /create_redismod_cluster.sh - -EXPOSE 46379 46380 46381 46382 46383 46384 - -ENV START_PORT=46379 -ENV END_PORT=46384 -ENTRYPOINT [] -CMD /create_redismod_cluster.sh diff --git a/docker/base/Dockerfile.sentinel b/docker/base/Dockerfile.sentinel deleted file mode 100644 index ef659e3004..0000000000 --- a/docker/base/Dockerfile.sentinel +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py-sentinel:6.2.6 -FROM redis:6.2.6-buster - -CMD ["redis-sentinel", "/sentinel.conf"] diff --git a/docker/base/Dockerfile.sentinel4 b/docker/base/Dockerfile.sentinel4 deleted file mode 100644 index 45bb03e88e..0000000000 --- a/docker/base/Dockerfile.sentinel4 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py-sentinel:4.0 -FROM redis:4.0-buster - -CMD ["redis-sentinel", "/sentinel.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.sentinel5 b/docker/base/Dockerfile.sentinel5 deleted file mode 100644 index 6958154e46..0000000000 --- a/docker/base/Dockerfile.sentinel5 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py-sentinel:5.0 -FROM redis:5.0-buster - -CMD ["redis-sentinel", "/sentinel.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.stunnel b/docker/base/Dockerfile.stunnel deleted file mode 100644 index bf4510907c..0000000000 --- a/docker/base/Dockerfile.stunnel +++ /dev/null @@ -1,11 +0,0 @@ -# produces redisfab/stunnel:latest -FROM ubuntu:18.04 - -RUN apt-get update -qq --fix-missing -RUN apt-get upgrade -qqy -RUN apt install -qqy stunnel -RUN mkdir -p /etc/stunnel/conf.d -RUN echo "foreground = yes\ninclude = /etc/stunnel/conf.d" > /etc/stunnel/stunnel.conf -RUN chown -R root:root /etc/stunnel/ - -CMD ["/usr/bin/stunnel"] diff --git a/docker/base/Dockerfile.unstable b/docker/base/Dockerfile.unstable deleted file mode 100644 index ab5b7fc6fb..0000000000 --- a/docker/base/Dockerfile.unstable +++ /dev/null @@ -1,18 +0,0 @@ -# produces redisfab/redis-py:unstable -FROM ubuntu:bionic as builder -RUN apt-get update -RUN apt-get upgrade -y -RUN apt-get install -y build-essential git -RUN mkdir /build -WORKDIR /build -RUN git clone https://github.com/redis/redis -WORKDIR /build/redis -RUN make - -FROM ubuntu:bionic as runner -COPY --from=builder /build/redis/src/redis-server /usr/bin/redis-server -COPY --from=builder /build/redis/src/redis-cli /usr/bin/redis-cli -COPY --from=builder /build/redis/src/redis-sentinel /usr/bin/redis-sentinel - -EXPOSE 6379 -CMD ["redis-server", "/redis.conf"] diff --git a/docker/base/Dockerfile.unstable_cluster b/docker/base/Dockerfile.unstable_cluster deleted file mode 100644 index 2e3ed55371..0000000000 --- a/docker/base/Dockerfile.unstable_cluster +++ /dev/null @@ -1,11 +0,0 @@ -# produces redisfab/redis-py-cluster:6.2.6 -FROM redisfab/redis-py:unstable-bionic - -COPY create_cluster.sh /create_cluster.sh -RUN chmod +x /create_cluster.sh - -EXPOSE 6372 6373 6374 6375 6376 6377 - -ENV START_PORT=6372 -ENV END_PORT=6377 -CMD ["/create_cluster.sh"] diff --git a/docker/base/Dockerfile.unstable_sentinel b/docker/base/Dockerfile.unstable_sentinel deleted file mode 100644 index fe6d062de8..0000000000 --- a/docker/base/Dockerfile.unstable_sentinel +++ /dev/null @@ -1,17 +0,0 @@ -# produces redisfab/redis-py-sentinel:unstable -FROM ubuntu:bionic as builder -RUN apt-get update -RUN apt-get upgrade -y -RUN apt-get install -y build-essential git -RUN mkdir /build -WORKDIR /build -RUN git clone https://github.com/redis/redis -WORKDIR /build/redis -RUN make - -FROM ubuntu:bionic as runner -COPY --from=builder /build/redis/src/redis-server /usr/bin/redis-server -COPY --from=builder /build/redis/src/redis-cli /usr/bin/redis-cli -COPY --from=builder /build/redis/src/redis-sentinel /usr/bin/redis-sentinel - -CMD ["redis-sentinel", "/sentinel.conf"] diff --git a/docker/base/README.md b/docker/base/README.md deleted file mode 100644 index a2f26a8106..0000000000 --- a/docker/base/README.md +++ /dev/null @@ -1 +0,0 @@ -Dockers in this folder are built, and uploaded to the redisfab dockerhub store. diff --git a/docker/base/create_cluster4.sh b/docker/base/create_cluster4.sh deleted file mode 100755 index a39da58784..0000000000 --- a/docker/base/create_cluster4.sh +++ /dev/null @@ -1,26 +0,0 @@ -#! /bin/bash -mkdir -p /nodes -touch /nodes/nodemap -for PORT in $(seq 16391 16396); do - mkdir -p /nodes/$PORT - if [[ -e /redis.conf ]]; then - cp /redis.conf /nodes/$PORT/redis.conf - else - touch /nodes/$PORT/redis.conf - fi - cat << EOF >> /nodes/$PORT/redis.conf -port ${PORT} -cluster-enabled yes -daemonize yes -logfile /redis.log -dir /nodes/$PORT -EOF - redis-server /nodes/$PORT/redis.conf - if [ $? -ne 0 ]; then - echo "Redis failed to start, exiting." - exit 3 - fi - echo 127.0.0.1:$PORT >> /nodes/nodemap -done -echo yes | redis-cli --cluster create $(seq -f 127.0.0.1:%g 16391 16396) --cluster-replicas 1 -tail -f /redis.log \ No newline at end of file diff --git a/docker/base/create_cluster5.sh b/docker/base/create_cluster5.sh deleted file mode 100755 index 0c63d8e910..0000000000 --- a/docker/base/create_cluster5.sh +++ /dev/null @@ -1,26 +0,0 @@ -#! /bin/bash -mkdir -p /nodes -touch /nodes/nodemap -for PORT in $(seq 16385 16390); do - mkdir -p /nodes/$PORT - if [[ -e /redis.conf ]]; then - cp /redis.conf /nodes/$PORT/redis.conf - else - touch /nodes/$PORT/redis.conf - fi - cat << EOF >> /nodes/$PORT/redis.conf -port ${PORT} -cluster-enabled yes -daemonize yes -logfile /redis.log -dir /nodes/$PORT -EOF - redis-server /nodes/$PORT/redis.conf - if [ $? -ne 0 ]; then - echo "Redis failed to start, exiting." - exit 3 - fi - echo 127.0.0.1:$PORT >> /nodes/nodemap -done -echo yes | redis-cli --cluster create $(seq -f 127.0.0.1:%g 16385 16390) --cluster-replicas 1 -tail -f /redis.log \ No newline at end of file diff --git a/docker/base/create_redismod_cluster.sh b/docker/base/create_redismod_cluster.sh deleted file mode 100755 index 20443a4c42..0000000000 --- a/docker/base/create_redismod_cluster.sh +++ /dev/null @@ -1,46 +0,0 @@ -#! /bin/bash - -mkdir -p /nodes -touch /nodes/nodemap -if [ -z ${START_PORT} ]; then - START_PORT=46379 -fi -if [ -z ${END_PORT} ]; then - END_PORT=46384 -fi -if [ ! -z "$3" ]; then - START_PORT=$2 - START_PORT=$3 -fi -echo "STARTING: ${START_PORT}" -echo "ENDING: ${END_PORT}" - -for PORT in `seq ${START_PORT} ${END_PORT}`; do - mkdir -p /nodes/$PORT - if [[ -e /redis.conf ]]; then - cp /redis.conf /nodes/$PORT/redis.conf - else - touch /nodes/$PORT/redis.conf - fi - cat << EOF >> /nodes/$PORT/redis.conf -port ${PORT} -cluster-enabled yes -daemonize yes -logfile /redis.log -dir /nodes/$PORT -EOF - - set -x - redis-server /nodes/$PORT/redis.conf - if [ $? -ne 0 ]; then - echo "Redis failed to start, exiting." - continue - fi - echo 127.0.0.1:$PORT >> /nodes/nodemap -done -if [ -z "${REDIS_PASSWORD}" ]; then - echo yes | redis-cli --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 -else - echo yes | redis-cli -a ${REDIS_PASSWORD} --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 -fi -tail -f /redis.log diff --git a/docker/cluster/redis.conf b/docker/cluster/redis.conf deleted file mode 100644 index dff658c79b..0000000000 --- a/docker/cluster/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -# Redis Cluster config file will be shared across all nodes. -# Do not change the following configurations that are already set: -# port, cluster-enabled, daemonize, logfile, dir diff --git a/docker/redis4/master/redis.conf b/docker/redis4/master/redis.conf deleted file mode 100644 index b7ed0ebf00..0000000000 --- a/docker/redis4/master/redis.conf +++ /dev/null @@ -1,2 +0,0 @@ -port 6381 -save "" diff --git a/docker/redis4/sentinel/sentinel_1.conf b/docker/redis4/sentinel/sentinel_1.conf deleted file mode 100644 index cfee17c051..0000000000 --- a/docker/redis4/sentinel/sentinel_1.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26385 - -sentinel monitor redis-py-test 127.0.0.1 6381 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis4/sentinel/sentinel_2.conf b/docker/redis4/sentinel/sentinel_2.conf deleted file mode 100644 index 68d930aea8..0000000000 --- a/docker/redis4/sentinel/sentinel_2.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26386 - -sentinel monitor redis-py-test 127.0.0.1 6381 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis4/sentinel/sentinel_3.conf b/docker/redis4/sentinel/sentinel_3.conf deleted file mode 100644 index 60abf65c9b..0000000000 --- a/docker/redis4/sentinel/sentinel_3.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26387 - -sentinel monitor redis-py-test 127.0.0.1 6381 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis5/master/redis.conf b/docker/redis5/master/redis.conf deleted file mode 100644 index e479c48b28..0000000000 --- a/docker/redis5/master/redis.conf +++ /dev/null @@ -1,2 +0,0 @@ -port 6382 -save "" diff --git a/docker/redis5/replica/redis.conf b/docker/redis5/replica/redis.conf deleted file mode 100644 index a2dc9e0945..0000000000 --- a/docker/redis5/replica/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -port 6383 -save "" -replicaof master 6382 diff --git a/docker/redis5/sentinel/sentinel_1.conf b/docker/redis5/sentinel/sentinel_1.conf deleted file mode 100644 index c748a0ba72..0000000000 --- a/docker/redis5/sentinel/sentinel_1.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26382 - -sentinel monitor redis-py-test 127.0.0.1 6382 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis5/sentinel/sentinel_2.conf b/docker/redis5/sentinel/sentinel_2.conf deleted file mode 100644 index 0a50c9a623..0000000000 --- a/docker/redis5/sentinel/sentinel_2.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26383 - -sentinel monitor redis-py-test 127.0.0.1 6382 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis5/sentinel/sentinel_3.conf b/docker/redis5/sentinel/sentinel_3.conf deleted file mode 100644 index a0e350ba0f..0000000000 --- a/docker/redis5/sentinel/sentinel_3.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26384 - -sentinel monitor redis-py-test 127.0.0.1 6383 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis6.2/master/redis.conf b/docker/redis6.2/master/redis.conf deleted file mode 100644 index 15a31b5a38..0000000000 --- a/docker/redis6.2/master/redis.conf +++ /dev/null @@ -1,2 +0,0 @@ -port 6379 -save "" diff --git a/docker/redis6.2/replica/redis.conf b/docker/redis6.2/replica/redis.conf deleted file mode 100644 index a76d402c5e..0000000000 --- a/docker/redis6.2/replica/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -port 6380 -save "" -replicaof master 6379 diff --git a/docker/redis6.2/sentinel/sentinel_2.conf b/docker/redis6.2/sentinel/sentinel_2.conf deleted file mode 100644 index 955621b872..0000000000 --- a/docker/redis6.2/sentinel/sentinel_2.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26380 - -sentinel monitor redis-py-test 127.0.0.1 6379 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis6.2/sentinel/sentinel_3.conf b/docker/redis6.2/sentinel/sentinel_3.conf deleted file mode 100644 index 62c40512f1..0000000000 --- a/docker/redis6.2/sentinel/sentinel_3.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26381 - -sentinel monitor redis-py-test 127.0.0.1 6379 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis7/master/redis.conf b/docker/redis7/master/redis.conf deleted file mode 100644 index ef57c1fe99..0000000000 --- a/docker/redis7/master/redis.conf +++ /dev/null @@ -1,4 +0,0 @@ -port 6379 -save "" -enable-debug-command yes -enable-module-command yes \ No newline at end of file diff --git a/docker/redismod_cluster/redis.conf b/docker/redismod_cluster/redis.conf deleted file mode 100644 index 48f06668a8..0000000000 --- a/docker/redismod_cluster/redis.conf +++ /dev/null @@ -1,8 +0,0 @@ -loadmodule /usr/lib/redis/modules/redisai.so -loadmodule /usr/lib/redis/modules/redisearch.so -loadmodule /usr/lib/redis/modules/redisgraph.so -loadmodule /usr/lib/redis/modules/redistimeseries.so -loadmodule /usr/lib/redis/modules/rejson.so -loadmodule /usr/lib/redis/modules/redisbloom.so -loadmodule /var/opt/redislabs/lib/modules/redisgears.so Plugin /var/opt/redislabs/modules/rg/plugin/gears_python.so Plugin /var/opt/redislabs/modules/rg/plugin/gears_jvm.so JvmOptions -Djava.class.path=/var/opt/redislabs/modules/rg/gear_runtime-jar-with-dependencies.jar JvmPath /var/opt/redislabs/modules/rg/OpenJDK/jdk-11.0.9.1+1/ - diff --git a/docker/unstable/redis.conf b/docker/unstable/redis.conf deleted file mode 100644 index 93a55cf3b3..0000000000 --- a/docker/unstable/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -port 6378 -protected-mode no -save "" diff --git a/docker/unstable_cluster/redis.conf b/docker/unstable_cluster/redis.conf deleted file mode 100644 index f307a63757..0000000000 --- a/docker/unstable_cluster/redis.conf +++ /dev/null @@ -1,4 +0,0 @@ -# Redis Cluster config file will be shared across all nodes. -# Do not change the following configurations that are already set: -# port, cluster-enabled, daemonize, logfile, dir -protected-mode no diff --git a/dockers/Dockerfile.cluster b/dockers/Dockerfile.cluster new file mode 100644 index 0000000000..204232a665 --- /dev/null +++ b/dockers/Dockerfile.cluster @@ -0,0 +1,7 @@ +FROM redis/redis-stack-server:latest as rss + +COPY dockers/create_cluster.sh /create_cluster.sh +RUN ls -R /opt/redis-stack +RUN chmod a+x /create_cluster.sh + +ENTRYPOINT [ "/create_cluster.sh"] diff --git a/dockers/cluster.redis.conf b/dockers/cluster.redis.conf new file mode 100644 index 0000000000..26da33567a --- /dev/null +++ b/dockers/cluster.redis.conf @@ -0,0 +1,6 @@ +protected-mode no +loadmodule /opt/redis-stack/lib/redisearch.so +loadmodule /opt/redis-stack/lib/redisgraph.so +loadmodule /opt/redis-stack/lib/redistimeseries.so +loadmodule /opt/redis-stack/lib/rejson.so +loadmodule /opt/redis-stack/lib/redisbloom.so diff --git a/docker/base/create_cluster.sh b/dockers/create_cluster.sh old mode 100755 new mode 100644 similarity index 75% rename from docker/base/create_cluster.sh rename to dockers/create_cluster.sh index fcb1b1cd8d..da9a0cb606 --- a/docker/base/create_cluster.sh +++ b/dockers/create_cluster.sh @@ -31,7 +31,8 @@ dir /nodes/$PORT EOF set -x - redis-server /nodes/$PORT/redis.conf + /opt/redis-stack/bin/redis-server /nodes/$PORT/redis.conf + sleep 1 if [ $? -ne 0 ]; then echo "Redis failed to start, exiting." continue @@ -39,8 +40,8 @@ EOF echo 127.0.0.1:$PORT >> /nodes/nodemap done if [ -z "${REDIS_PASSWORD}" ]; then - echo yes | redis-cli --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 + echo yes | /opt/redis-stack/bin/redis-cli --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 else - echo yes | redis-cli -a ${REDIS_PASSWORD} --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 + echo yes | opt/redis-stack/bin/redis-cli -a ${REDIS_PASSWORD} --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 fi tail -f /redis.log diff --git a/docker/redis6.2/sentinel/sentinel_1.conf b/dockers/sentinel.conf similarity index 73% rename from docker/redis6.2/sentinel/sentinel_1.conf rename to dockers/sentinel.conf index bd2d830af3..1a33f53344 100644 --- a/docker/redis6.2/sentinel/sentinel_1.conf +++ b/dockers/sentinel.conf @@ -1,6 +1,4 @@ -port 26379 - sentinel monitor redis-py-test 127.0.0.1 6379 2 sentinel down-after-milliseconds redis-py-test 5000 sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 +sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/stunnel/README b/dockers/stunnel/README similarity index 100% rename from docker/stunnel/README rename to dockers/stunnel/README diff --git a/docker/stunnel/conf/redis.conf b/dockers/stunnel/conf/redis.conf similarity index 83% rename from docker/stunnel/conf/redis.conf rename to dockers/stunnel/conf/redis.conf index 84f6d40133..a150d8b011 100644 --- a/docker/stunnel/conf/redis.conf +++ b/dockers/stunnel/conf/redis.conf @@ -1,6 +1,6 @@ [redis] accept = 6666 -connect = master:6379 +connect = redis:6379 cert = /etc/stunnel/keys/server-cert.pem key = /etc/stunnel/keys/server-key.pem verify = 0 diff --git a/docker/stunnel/create_certs.sh b/dockers/stunnel/create_certs.sh similarity index 100% rename from docker/stunnel/create_certs.sh rename to dockers/stunnel/create_certs.sh diff --git a/docker/stunnel/keys/ca-cert.pem b/dockers/stunnel/keys/ca-cert.pem similarity index 100% rename from docker/stunnel/keys/ca-cert.pem rename to dockers/stunnel/keys/ca-cert.pem diff --git a/docker/stunnel/keys/ca-key.pem b/dockers/stunnel/keys/ca-key.pem similarity index 100% rename from docker/stunnel/keys/ca-key.pem rename to dockers/stunnel/keys/ca-key.pem diff --git a/docker/stunnel/keys/client-cert.pem b/dockers/stunnel/keys/client-cert.pem similarity index 100% rename from docker/stunnel/keys/client-cert.pem rename to dockers/stunnel/keys/client-cert.pem diff --git a/docker/stunnel/keys/client-key.pem b/dockers/stunnel/keys/client-key.pem similarity index 100% rename from docker/stunnel/keys/client-key.pem rename to dockers/stunnel/keys/client-key.pem diff --git a/docker/stunnel/keys/client-req.pem b/dockers/stunnel/keys/client-req.pem similarity index 100% rename from docker/stunnel/keys/client-req.pem rename to dockers/stunnel/keys/client-req.pem diff --git a/docker/stunnel/keys/server-cert.pem b/dockers/stunnel/keys/server-cert.pem similarity index 100% rename from docker/stunnel/keys/server-cert.pem rename to dockers/stunnel/keys/server-cert.pem diff --git a/docker/stunnel/keys/server-key.pem b/dockers/stunnel/keys/server-key.pem similarity index 100% rename from docker/stunnel/keys/server-key.pem rename to dockers/stunnel/keys/server-key.pem diff --git a/docker/stunnel/keys/server-req.pem b/dockers/stunnel/keys/server-req.pem similarity index 100% rename from docker/stunnel/keys/server-req.pem rename to dockers/stunnel/keys/server-req.pem diff --git a/docs/examples/opentelemetry/main.py b/docs/examples/opentelemetry/main.py index b140dd0148..9ef6723305 100755 --- a/docs/examples/opentelemetry/main.py +++ b/docs/examples/opentelemetry/main.py @@ -2,12 +2,11 @@ import time +import redis import uptrace from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor -import redis - tracer = trace.get_tracer("app_or_package_name", "1.0.0") diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000..f1b716ae96 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,13 @@ +[pytest] +addopts = -s +markers = + redismod: run only the redis module tests + pipeline: pipeline tests + onlycluster: marks tests to be run only with cluster mode redis + onlynoncluster: marks tests to be run only with standalone redis + ssl: marker for only the ssl tests + asyncio: marker for async tests + replica: replica tests + experimental: run only experimental tests +asyncio_mode = auto +timeout = 30 diff --git a/redis/__init__.py b/redis/__init__.py index d7b74edf41..495d2d99bb 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,5 +1,6 @@ import sys +from redis import asyncio # noqa from redis.backoff import default_backoff from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster diff --git a/redis/_parsers/__init__.py b/redis/_parsers/__init__.py new file mode 100644 index 0000000000..6cc32e3cae --- /dev/null +++ b/redis/_parsers/__init__.py @@ -0,0 +1,20 @@ +from .base import BaseParser, _AsyncRESPBase +from .commands import AsyncCommandsParser, CommandsParser +from .encoders import Encoder +from .hiredis import _AsyncHiredisParser, _HiredisParser +from .resp2 import _AsyncRESP2Parser, _RESP2Parser +from .resp3 import _AsyncRESP3Parser, _RESP3Parser + +__all__ = [ + "AsyncCommandsParser", + "_AsyncHiredisParser", + "_AsyncRESPBase", + "_AsyncRESP2Parser", + "_AsyncRESP3Parser", + "CommandsParser", + "Encoder", + "BaseParser", + "_HiredisParser", + "_RESP2Parser", + "_RESP3Parser", +] diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py new file mode 100644 index 0000000000..f77296df6a --- /dev/null +++ b/redis/_parsers/base.py @@ -0,0 +1,232 @@ +import sys +from abc import ABC +from asyncio import IncompleteReadError, StreamReader, TimeoutError +from typing import List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from ..exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ConnectionError, + ExecAbortError, + ModuleError, + NoPermissionError, + NoScriptError, + OutOfMemoryError, + ReadOnlyError, + RedisError, + ResponseError, +) +from ..typing import EncodableT +from .encoders import Encoder +from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer + +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) +# user send an AUTH cmd to a server without authorization configured +NO_AUTH_SET_ERROR = { + # Redis >= 6.0 + "AUTH called without any password " + "configured for the default user. Are you sure " + "your configuration is correct?": AuthenticationError, + # Redis < 6.0 + "Client sent AUTH, but no password is set": AuthenticationError, +} + + +class BaseParser(ABC): + + EXCEPTION_CLASSES = { + "ERR": { + "max number of clients reached": ConnectionError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments " + "for 'auth' command": AuthenticationWrongNumberOfArgsError, + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments " + "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + **NO_AUTH_SET_ERROR, + }, + "OOM": OutOfMemoryError, + "WRONGPASS": AuthenticationError, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + @classmethod + def parse_error(cls, response): + "Parse an error response" + error_code = response.split(" ")[0] + if error_code in cls.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = cls.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection): + raise NotImplementedError() + + +class _RESPBase(BaseParser): + """Base class for sync-based resp parsing""" + + def __init__(self, socket_read_size): + self.socket_read_size = socket_read_size + self.encoder = None + self._sock = None + self._buffer = None + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection): + "Called when the socket connects" + self._sock = connection._sock + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + "Called when the socket disconnects" + self._sock = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + def can_read(self, timeout): + return self._buffer and self._buffer.can_read(timeout) + + +class AsyncBaseParser(BaseParser): + """Base parsing class for the python-backed async parser""" + + __slots__ = "_stream", "_read_size" + + def __init__(self, socket_read_size: int): + self._stream: Optional[StreamReader] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + async def can_read_destructive(self) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class _AsyncRESPBase(AsyncBaseParser): + """Base class for async resp parsing""" + + __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() + + def on_connect(self, connection): + """Called when the stream connects""" + self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + self.encoder = connection.encoder + self._clear() + self._connected = True + + def on_disconnect(self): + """Called when the stream disconnects""" + self._connected = False + + async def can_read_destructive(self) -> bool: + if not self._connected: + raise RedisError("Buffer is closed.") + if self._buffer: + return True + try: + async with async_timeout(0): + return await self._stream.read(1) + except TimeoutError: + return False + + async def _read(self, length: int) -> bytes: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result + + async def _readline(self) -> bytes: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. + """ + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result diff --git a/redis/commands/parser.py b/redis/_parsers/commands.py similarity index 62% rename from redis/commands/parser.py rename to redis/_parsers/commands.py index 115230a9d2..d3b4a99ed3 100644 --- a/redis/commands/parser.py +++ b/redis/_parsers/commands.py @@ -1,6 +1,11 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + from redis.exceptions import RedisError, ResponseError from redis.utils import str_if_bytes +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + class CommandsParser: """ @@ -16,7 +21,7 @@ def __init__(self, redis_connection): self.initialize(redis_connection) def initialize(self, r): - commands = r.execute_command("COMMAND") + commands = r.command() uppercase_commands = [] for cmd in commands: if any(x.isupper() for x in cmd): @@ -117,12 +122,9 @@ def _get_moveable_keys(self, redis_conn, *args): So, don't use this function with EVAL or EVALSHA. """ - pieces = [] - cmd_name = args[0] # The command name should be splitted into separate arguments, # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] - pieces = pieces + cmd_name.split() - pieces = pieces + list(args[1:]) + pieces = args[0].split() + list(args[1:]) try: keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) except ResponseError as e: @@ -153,14 +155,102 @@ def _get_pubsub_keys(self, *args): # the second argument is a part of the command name, e.g. # ['PUBSUB', 'NUMSUB', 'foo']. pubsub_type = args[1].upper() - if pubsub_type in ["CHANNELS", "NUMSUB"]: + if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]: keys = args[2:] elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: # format example: # SUBSCRIBE channel [channel ...] keys = list(args[1:]) - elif command == "PUBLISH": + elif command in ["PUBLISH", "SPUBLISH"]: # format example: # PUBLISH channel message keys = [args[1]] return keys + + +class AsyncCommandsParser: + """ + Parses Redis commands to get command keys. + + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. + + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. + """ + + __slots__ = ("commands", "node") + + def __init__(self) -> None: + self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} + + async def initialize(self, node: Optional["ClusterNode"] = None) -> None: + if node: + self.node = node + + commands = await self.node.execute_command("COMMAND") + for cmd, command in commands.items(): + if "movablekeys" in command["flags"]: + commands[cmd] = -1 + elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: + commands[cmd] = 0 + elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: + commands[cmd] = 1 + self.commands = {cmd.upper(): command for cmd, command in commands.items()} + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + if len(args) < 2: + # The command has no keys in it + return None + + try: + command = self.commands[args[0]] + except KeyError: + # try to split the command name and to take only the main command + # e.g. 'memory' for 'memory usage' + args = args[0].split() + list(args[1:]) + cmd_name = args[0].upper() + if cmd_name not in self.commands: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + await self.initialize() + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name} command doesn't exist in Redis commands" + ) + + command = self.commands[cmd_name] + + if command == 1: + return (args[1],) + if command == 0: + return None + if command == -1: + return await self._get_moveable_keys(*args) + + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) + last_key_pos + return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] + + async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + try: + keys = await self.node.execute_command("COMMAND GETKEYS", *args) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys diff --git a/redis/_parsers/encoders.py b/redis/_parsers/encoders.py new file mode 100644 index 0000000000..6fdf0ad882 --- /dev/null +++ b/redis/_parsers/encoders.py @@ -0,0 +1,44 @@ +from ..exceptions import DataError + + +class Encoder: + "Encode strings to bytes-like and decode bytes-like to strings" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): + return value + elif isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. Convert to a " + "bytes, string, int or float first." + ) + elif isinstance(value, (int, float)): + value = repr(value).encode() + elif not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = type(value).__name__ + raise DataError( + f"Invalid input of type: '{typename}'. " + f"Convert to a bytes, string, int or float first." + ) + if isinstance(value, str): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py new file mode 100644 index 0000000000..f27e3b12c0 --- /dev/null +++ b/redis/_parsers/helpers.py @@ -0,0 +1,851 @@ +import datetime + +from redis.utils import str_if_bytes + + +def timestamp_to_datetime(response): + "Converts a unix timestamp to a Python datetime object" + if not response: + return None + try: + response = int(response) + except ValueError: + return None + return datetime.datetime.fromtimestamp(response) + + +def parse_debug_object(response): + "Parse the results of Redis's DEBUG OBJECT command into a Python dict" + # The 'type' of the object is the first item in the response, but isn't + # prefixed with a name + response = str_if_bytes(response) + response = "type:" + response + response = dict(kv.split(":") for kv in response.split()) + + # parse some expected int values from the string response + # note: this cmd isn't spec'd so these may not appear in all redis versions + int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") + for field in int_fields: + if field in response: + response[field] = int(response[field]) + + return response + + +def parse_info(response): + """Parse the result of Redis's INFO command into a Python dict""" + info = {} + response = str_if_bytes(response) + + def get_value(value): + if "," not in value or "=" not in value: + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + return value + else: + sub_dict = {} + for item in value.split(","): + k, v = item.rsplit("=", 1) + sub_dict[k] = get_value(v) + return sub_dict + + for line in response.splitlines(): + if line and not line.startswith("#"): + if line.find(":") != -1: + # Split, the info fields keys and values. + # Note that the value may contain ':'. but the 'host:' + # pseudo-command is the only case where the key contains ':' + key, value = line.split(":", 1) + if key == "cmdstat_host": + key, value = line.rsplit(":", 1) + + if key == "module": + # Hardcode a list for key 'modules' since there could be + # multiple lines that started with 'module' + info.setdefault("modules", []).append(get_value(value)) + else: + info[key] = get_value(value) + else: + # if the line isn't splittable, append it to the "__raw__" key + info.setdefault("__raw__", []).append(line) + + return info + + +def parse_memory_stats(response, **kwargs): + """Parse the results of MEMORY STATS""" + stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) + for key, value in stats.items(): + if key.startswith("db."): + stats[key] = pairs_to_dict( + value, decode_keys=True, decode_string_values=True + ) + return stats + + +SENTINEL_STATE_TYPES = { + "can-failover-its-master": int, + "config-epoch": int, + "down-after-milliseconds": int, + "failover-timeout": int, + "info-refresh": int, + "last-hello-message": int, + "last-ok-ping-reply": int, + "last-ping-reply": int, + "last-ping-sent": int, + "master-link-down-time": int, + "master-port": int, + "num-other-sentinels": int, + "num-slaves": int, + "o-down-time": int, + "pending-commands": int, + "parallel-syncs": int, + "port": int, + "quorum": int, + "role-reported-time": int, + "s-down-time": int, + "slave-priority": int, + "slave-repl-offset": int, + "voted-leader-epoch": int, +} + + +def parse_sentinel_state(item): + result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) + flags = set(result["flags"].split(",")) + for name, flag in ( + ("is_master", "master"), + ("is_slave", "slave"), + ("is_sdown", "s_down"), + ("is_odown", "o_down"), + ("is_sentinel", "sentinel"), + ("is_disconnected", "disconnected"), + ("is_master_down", "master_down"), + ): + result[name] = flag in flags + return result + + +def parse_sentinel_master(response): + return parse_sentinel_state(map(str_if_bytes, response)) + + +def parse_sentinel_state_resp3(response): + result = {} + for key in response: + try: + value = SENTINEL_STATE_TYPES[key](str_if_bytes(response[key])) + result[str_if_bytes(key)] = value + except Exception: + result[str_if_bytes(key)] = response[str_if_bytes(key)] + flags = set(result["flags"].split(",")) + result["flags"] = flags + return result + + +def parse_sentinel_masters(response): + result = {} + for item in response: + state = parse_sentinel_state(map(str_if_bytes, item)) + result[state["name"]] = state + return result + + +def parse_sentinel_masters_resp3(response): + return [parse_sentinel_state(master) for master in response] + + +def parse_sentinel_slaves_and_sentinels(response): + return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] + + +def parse_sentinel_slaves_and_sentinels_resp3(response): + return [parse_sentinel_state_resp3(item) for item in response] + + +def parse_sentinel_get_master(response): + return response and (response[0], int(response[1])) or None + + +def pairs_to_dict(response, decode_keys=False, decode_string_values=False): + """Create a dict given a list of key/value pairs""" + if response is None: + return {} + if decode_keys or decode_string_values: + # the iter form is faster, but I don't know how to make that work + # with a str_if_bytes() map + keys = response[::2] + if decode_keys: + keys = map(str_if_bytes, keys) + values = response[1::2] + if decode_string_values: + values = map(str_if_bytes, values) + return dict(zip(keys, values)) + else: + it = iter(response) + return dict(zip(it, it)) + + +def pairs_to_dict_typed(response, type_info): + it = iter(response) + result = {} + for key, value in zip(it, it): + if key in type_info: + try: + value = type_info[key](value) + except Exception: + # if for some reason the value can't be coerced, just use + # the string value + pass + result[key] = value + return result + + +def zset_score_pairs(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscores"): + return response + score_cast_func = options.get("score_cast_func", float) + it = iter(response) + return list(zip(it, map(score_cast_func, it))) + + +def sort_return_tuples(response, **options): + """ + If ``groups`` is specified, return the response as a list of + n-element tuples with n being the value found in options['groups'] + """ + if not response or not options.get("groups"): + return response + n = options["groups"] + return list(zip(*[response[i::n] for i in range(n)])) + + +def parse_stream_list(response): + if response is None: + return None + data = [] + for r in response: + if r is not None: + data.append((r[0], pairs_to_dict(r[1]))) + else: + data.append((None, None)) + return data + + +def pairs_to_dict_with_str_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(map(pairs_to_dict_with_str_keys, response)) + + +def parse_xclaim(response, **options): + if options.get("parse_justid", False): + return response + return parse_stream_list(response) + + +def parse_xautoclaim(response, **options): + if options.get("parse_justid", False): + return response[1] + response[1] = parse_stream_list(response[1]) + return response + + +def parse_xinfo_stream(response, **options): + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(k): v for k, v in response.items()} + if not options.get("full", False): + first = data.get("first-entry") + if first is not None: + data["first-entry"] = (first[0], pairs_to_dict(first[1])) + last = data["last-entry"] + if last is not None: + data["last-entry"] = (last[0], pairs_to_dict(last[1])) + else: + data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} + if isinstance(data["groups"][0], list): + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] + ] + else: + data["groups"] = [ + {str_if_bytes(k): v for k, v in group.items()} + for group in data["groups"] + ] + return data + + +def parse_xread(response): + if response is None: + return [] + return [[r[0], parse_stream_list(r[1])] for r in response] + + +def parse_xread_resp3(response): + if response is None: + return {} + return {key: [parse_stream_list(value)] for key, value in response.items()} + + +def parse_xpending(response, **options): + if options.get("parse_detail", False): + return parse_xpending_range(response) + consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] + return { + "pending": response[0], + "min": response[1], + "max": response[2], + "consumers": consumers, + } + + +def parse_xpending_range(response): + k = ("message_id", "consumer", "time_since_delivered", "times_delivered") + return [dict(zip(k, r)) for r in response] + + +def float_or_none(response): + if response is None: + return None + return float(response) + + +def bool_ok(response): + return str_if_bytes(response) == "OK" + + +def parse_zadd(response, **options): + if response is None: + return None + if options.get("as_score"): + return float(response) + return int(response) + + +def parse_client_list(response, **options): + clients = [] + for c in str_if_bytes(response).splitlines(): + # Values might contain '=' + clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) + return clients + + +def parse_config_get(response, **options): + response = [str_if_bytes(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} + + +def parse_scan(response, **options): + cursor, r = response + return int(cursor), r + + +def parse_hscan(response, **options): + cursor, r = response + return int(cursor), r and pairs_to_dict(r) or {} + + +def parse_zscan(response, **options): + score_cast_func = options.get("score_cast_func", float) + cursor, r = response + it = iter(r) + return int(cursor), list(zip(it, map(score_cast_func, it))) + + +def parse_zmscore(response, **options): + # zmscore: list of scores (double precision floating point number) or nil + return [float(score) if score is not None else None for score in response] + + +def parse_slowlog_get(response, **options): + space = " " if options.get("decode_responses", False) else b" " + + def parse_item(item): + result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} + # Redis Enterprise injects another entry at index [3], which has + # the complexity info (i.e. the value N in case the command has + # an O(N) complexity) instead of the command. + if isinstance(item[3], list): + result["command"] = space.join(item[3]) + result["client_address"] = item[4] + result["client_name"] = item[5] + else: + result["complexity"] = item[3] + result["command"] = space.join(item[4]) + result["client_address"] = item[5] + result["client_name"] = item[6] + return result + + return [parse_item(item) for item in response] + + +def parse_stralgo(response, **options): + """ + Parse the response from `STRALGO` command. + Without modifiers the returned value is string. + When LEN is given the command returns the length of the result + (i.e integer). + When IDX is given the command returns a dictionary with the LCS + length and all the ranges in both the strings, start and end + offset for each string, where there are matches. + When WITHMATCHLEN is given, each array representing a match will + also have the length of the match at the beginning of the array. + """ + if options.get("len", False): + return int(response) + if options.get("idx", False): + if options.get("withmatchlen", False): + matches = [ + [(int(match[-1]))] + list(map(tuple, match[:-1])) + for match in response[1] + ] + else: + matches = [list(map(tuple, match)) for match in response[1]] + return { + str_if_bytes(response[0]): matches, + str_if_bytes(response[2]): int(response[3]), + } + return str_if_bytes(response) + + +def parse_cluster_info(response, **options): + response = str_if_bytes(response) + return dict(line.split(":") for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(" ") + node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] + addr = addr.split("@")[0] + node_dict = { + "node_id": node_id, + "flags": flags, + "master_id": master_id, + "last_ping_sent": ping, + "last_pong_rcvd": pong, + "epoch": epoch, + "slots": [], + "migrations": [], + "connected": True if connected == "connected" else False, + } + if len(line_items) >= 9: + slots, migrations = _parse_slots(line_items[8:]) + node_dict["slots"], node_dict["migrations"] = slots, migrations + return addr, node_dict + + +def _parse_slots(slot_ranges): + slots, migrations = [], [] + for s_range in slot_ranges: + if "->-" in s_range: + slot_id, dst_node_id = s_range[1:-1].split("->-", 1) + migrations.append( + {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"} + ) + elif "-<-" in s_range: + slot_id, src_node_id = s_range[1:-1].split("-<-", 1) + migrations.append( + {"slot": slot_id, "node_id": src_node_id, "state": "importing"} + ) + else: + s_range = [sl for sl in s_range.split("-")] + slots.append(s_range) + + return slots, migrations + + +def parse_cluster_nodes(response, **options): + """ + @see: https://redis.io/commands/cluster-nodes # string / bytes + @see: https://redis.io/commands/cluster-replicas # list of string / bytes + """ + if isinstance(response, (str, bytes)): + response = response.splitlines() + return dict(_parse_node_line(str_if_bytes(node)) for node in response) + + +def parse_geosearch_generic(response, **options): + """ + Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' + commands according to 'withdist', 'withhash' and 'withcoord' labels. + """ + try: + if options["store"] or options["store_dist"]: + # `store` and `store_dist` cant be combined + # with other command arguments. + # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' + return response + except KeyError: # it means the command was sent via execute_command + return response + + if type(response) != list: + response_list = [response] + else: + response_list = response + + if not options["withdist"] and not options["withcoord"] and not options["withhash"]: + # just a bunch of places + return response_list + + cast = { + "withdist": float, + "withcoord": lambda ll: (float(ll[0]), float(ll[1])), + "withhash": int, + } + + # zip all output results with each casting function to get + # the properly native Python value. + f = [lambda x: x] + f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] + return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] + + +def parse_command(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + commands[cmd_name] = cmd_dict + return commands + + +def parse_command_resp3(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = command[1] + cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]} + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + cmd_dict["acl_categories"] = command[6] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + + commands[cmd_name] = cmd_dict + return commands + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +def parse_client_kill(response, **options): + if isinstance(response, int): + return response + return str_if_bytes(response) == "OK" + + +def parse_acl_getuser(response, **options): + if response is None: + return None + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(key): value for key, value in response.items()} + + # convert everything but user-defined data in 'keys' to native strings + data["flags"] = list(map(str_if_bytes, data["flags"])) + data["passwords"] = list(map(str_if_bytes, data["passwords"])) + data["commands"] = str_if_bytes(data["commands"]) + if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): + data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) + if data["keys"] == [""]: + data["keys"] = [] + if "channels" in data: + if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): + data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) + if data["channels"] == [""]: + data["channels"] = [] + if "selectors" in data: + if data["selectors"] != [] and isinstance(data["selectors"][0], list): + data["selectors"] = [ + list(map(str_if_bytes, selector)) for selector in data["selectors"] + ] + elif data["selectors"] != []: + data["selectors"] = [ + {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} + for selector in data["selectors"] + ] + + # split 'commands' into separate 'categories' and 'commands' lists + commands, categories = [], [] + for command in data["commands"].split(" "): + categories.append(command) if "@" in command else commands.append(command) + + data["commands"] = commands + data["categories"] = categories + data["enabled"] = "on" in data["flags"] + return data + + +def parse_acl_log(response, **options): + if response is None: + return None + if isinstance(response, list): + data = [] + for log in response: + log_data = pairs_to_dict(log, True, True) + client_info = log_data.get("client-info", "") + log_data["client-info"] = parse_client_info(client_info) + + # float() is lossy comparing to the "double" in C + log_data["age-seconds"] = float(log_data["age-seconds"]) + data.append(log_data) + else: + data = bool_ok(response) + return data + + +def parse_client_info(value): + """ + Parsing client-info in ACL Log in following format. + "key1=value1 key2=value2 key3=value3" + """ + client_info = {} + infos = str_if_bytes(value).split(" ") + for info in infos: + key, value = info.split("=") + client_info[key] = value + + # Those fields are defined as int in networking.c + for int_key in { + "id", + "age", + "idle", + "db", + "sub", + "psub", + "multi", + "qbuf", + "qbuf-free", + "obl", + "argv-mem", + "oll", + "omem", + "tot-mem", + }: + client_info[int_key] = int(client_info[int_key]) + return client_info + + +def parse_set_result(response, **options): + """ + Handle SET result since GET argument is available since Redis 6.2. + Parsing SET result into: + - BOOL + - String when GET argument is used + """ + if options.get("get"): + # Redis will return a getCommand result. + # See `setGenericCommand` in t_string.c + return response + return response and str_if_bytes(response) == "OK" + + +def string_keys_to_dict(key_string, callback): + return dict.fromkeys(key_string.split(), callback) + + +_RedisCallbacks = { + **string_keys_to_dict( + "AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX " + "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE", + bool, + ), + **string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float), + **string_keys_to_dict( + "ASKING FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH", + bool_ok, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + **string_keys_to_dict( + "GEORADIUS GEORADIUSBYMEMBER GEOSEARCH", + parse_geosearch_generic, + ), + **string_keys_to_dict("XRANGE XREVRANGE", parse_stream_list), + "ACL GETUSER": parse_acl_getuser, + "ACL LOAD": bool_ok, + "ACL LOG": parse_acl_log, + "ACL SETUSER": bool_ok, + "ACL SAVE": bool_ok, + "CLIENT INFO": parse_client_info, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT PAUSE": bool_ok, + "CLIENT SETNAME": bool_ok, + "CLIENT UNBLOCK": bool, + "CLUSTER ADDSLOTS": bool_ok, + "CLUSTER ADDSLOTSRANGE": bool_ok, + "CLUSTER DELSLOTS": bool_ok, + "CLUSTER DELSLOTSRANGE": bool_ok, + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SET-CONFIG-EPOCH": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + "COMMAND": parse_command, + "CONFIG RESETSTAT": bool_ok, + "CONFIG SET": bool_ok, + "FUNCTION DELETE": bool_ok, + "FUNCTION FLUSH": bool_ok, + "FUNCTION RESTORE": bool_ok, + "GEODIST": float_or_none, + "HSCAN": parse_hscan, + "INFO": parse_info, + "LASTSAVE": timestamp_to_datetime, + "MEMORY PURGE": bool_ok, + "MODULE LOAD": bool, + "MODULE UNLOAD": bool, + "PING": lambda r: str_if_bytes(r) == "PONG", + "PUBSUB NUMSUB": parse_pubsub_numsub, + "QUIT": bool_ok, + "SET": parse_set_result, + "SCAN": parse_scan, + "SCRIPT EXISTS": lambda r: list(map(bool, r)), + "SCRIPT FLUSH": bool_ok, + "SCRIPT KILL": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "SENTINEL CKQUORUM": bool_ok, + "SENTINEL FAILOVER": bool_ok, + "SENTINEL FLUSHCONFIG": bool_ok, + "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + "SENTINEL MONITOR": bool_ok, + "SENTINEL RESET": bool_ok, + "SENTINEL REMOVE": bool_ok, + "SENTINEL SET": bool_ok, + "SLOWLOG GET": parse_slowlog_get, + "SLOWLOG RESET": bool_ok, + "SORT": sort_return_tuples, + "SSCAN": parse_scan, + "TIME": lambda x: (int(x[0]), int(x[1])), + "XAUTOCLAIM": parse_xautoclaim, + "XCLAIM": parse_xclaim, + "XGROUP CREATE": bool_ok, + "XGROUP DESTROY": bool, + "XGROUP SETID": bool_ok, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + "ZSCAN": parse_zscan, +} + + +_RedisCallbacksRESP2 = { + **string_keys_to_dict( + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + ), + **string_keys_to_dict( + "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZRANK ZREVRANGE " + "ZREVRANGEBYSCORE ZREVRANK ZUNION", + zset_score_pairs, + ), + **string_keys_to_dict("ZINCRBY ZSCORE", float_or_none), + **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + **string_keys_to_dict( + "BZPOPMAX BZPOPMIN", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "ACL GENPASS": str_if_bytes, + "ACL HELP": lambda r: list(map(str_if_bytes, r)), + "ACL LIST": lambda r: list(map(str_if_bytes, r)), + "ACL USERS": lambda r: list(map(str_if_bytes, r)), + "ACL WHOAMI": str_if_bytes, + "CLIENT GETNAME": str_if_bytes, + "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), + "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), + "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + "CONFIG GET": parse_config_get, + "DEBUG OBJECT": parse_debug_object, + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + ), + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "MEMORY STATS": parse_memory_stats, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + "RESET": str_if_bytes, + "SENTINEL MASTER": parse_sentinel_master, + "SENTINEL MASTERS": parse_sentinel_masters, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "STRALGO": parse_stralgo, + "XINFO CONSUMERS": parse_list_of_dicts, + "XINFO GROUPS": parse_list_of_dicts, + "ZADD": parse_zadd, + "ZMSCORE": parse_zmscore, +} + + +_RedisCallbacksRESP3 = { + **string_keys_to_dict( + "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " + "ZUNION HGETALL XREADGROUP", + lambda r, **kwargs: r, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), + "ACL LOG": lambda r: [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r + ] + if isinstance(r, list) + else bool_ok(r), + "COMMAND": parse_command_resp3, + "CONFIG GET": lambda r: { + str_if_bytes(key) + if key is not None + else None: str_if_bytes(value) + if value is not None + else None + for key, value in r.items() + }, + "MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()}, + "SENTINEL MASTER": parse_sentinel_state_resp3, + "SENTINEL MASTERS": parse_sentinel_masters_resp3, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, + "STRALGO": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + } + if isinstance(r, dict) + else str_if_bytes(r), + "XINFO CONSUMERS": lambda r: [ + {str_if_bytes(key): value for key, value in x.items()} for x in r + ], + "XINFO GROUPS": lambda r: [ + {str_if_bytes(key): value for key, value in d.items()} for d in r + ], +} diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py new file mode 100644 index 0000000000..b3247b71ec --- /dev/null +++ b/redis/_parsers/hiredis.py @@ -0,0 +1,217 @@ +import asyncio +import socket +import sys +from typing import Callable, List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from redis.compat import TypedDict + +from ..exceptions import ConnectionError, InvalidResponse, RedisError +from ..typing import EncodableT +from ..utils import HIREDIS_AVAILABLE +from .base import AsyncBaseParser, BaseParser +from .socket import ( + NONBLOCKING_EXCEPTION_ERROR_NUMBERS, + NONBLOCKING_EXCEPTIONS, + SENTINEL, + SERVER_CLOSED_CONNECTION_ERROR, +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class _HiredisParser(BaseParser): + "Parser class for connections using Hiredis" + + def __init__(self, socket_read_size): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not installed") + self.socket_read_size = socket_read_size + self._buffer = bytearray(socket_read_size) + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection, **kwargs): + import hiredis + + self._sock = connection._sock + self._socket_timeout = connection.socket_timeout + kwargs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + "errors": connection.encoder.encoding_errors, + } + + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + + def on_disconnect(self): + self._sock = None + self._reader = None + self._next_response = False + + def can_read(self, timeout): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): + sock = self._sock + custom_timeout = timeout is not SENTINEL + try: + if custom_timeout: + sock.settimeout(timeout) + bufflen = self._sock.recv_into(self._buffer) + if bufflen == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + self._reader.feed(self._buffer, 0, bufflen) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + if custom_timeout: + sock.settimeout(self._socket_timeout) + + def read_response(self, disable_decoding=False): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + + while response is False: + self.read_from_socket() + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response + + +class _AsyncHiredisParser(AsyncBaseParser): + """Async implementation of parser class for connections using Hiredis""" + + __slots__ = ("_reader",) + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader = None + + def on_connect(self, connection): + import hiredis + + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + self._connected = True + + def on_disconnect(self): + self._connected = False + + async def can_read_destructive(self): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._reader.gets(): + return True + try: + async with async_timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: + return False + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: + # If `on_disconnect()` has been called, prohibit any more reads + # even if they could happen because data might be present. + # We still allow reads in progress to finish + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response diff --git a/redis/_parsers/resp2.py b/redis/_parsers/resp2.py new file mode 100644 index 0000000000..d5adc1a898 --- /dev/null +++ b/redis/_parsers/resp2.py @@ -0,0 +1,132 @@ +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP2Parser(_RESPBase): + """RESP2 protocol implementation""" + + def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() if self._buffer else None + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + if self._buffer: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = self._buffer.read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for i in range(int(response)) + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class _AsyncRESP2Parser(_AsyncRESPBase): + """Async class for the RESP2 protocol""" + + async def read_response(self, disable_decoding: bool = False): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = await self._read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding)) + for _ in range(int(response)) # noqa + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py new file mode 100644 index 0000000000..1275686710 --- /dev/null +++ b/redis/_parsers/resp3.py @@ -0,0 +1,261 @@ +from logging import getLogger +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP3Parser(_RESPBase): + """RESP3 protocol implementation""" + + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + def read_response(self, disable_decoding=False, push_request=False): + pos = self._buffer.get_pos() if self._buffer else None + try: + result = self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + except BaseException: + if self._buffer: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False, push_request=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = self._buffer.read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response + elif byte == b"$": + response = self._buffer.read(int(response)) + # verbatim string response + elif byte == b"=": + response = self._buffer.read(int(response))[4:] + # array response + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + ] + try: + response = set(response) + except TypeError: + pass + # map response + elif byte == b"%": + # we use this approach and not dict comprehension here + # because this dict comprehension fails in python 3.7 + resp_dict = {} + for _ in range(int(response)): + key = self._read_response(disable_decoding=disable_decoding) + resp_dict[key] = self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + response = resp_dict + # push response + elif byte == b">": + response = [ + self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func + + +class _AsyncRESP3Parser(_AsyncRESPBase): + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + async def read_response( + self, disable_decoding: bool = False, push_request: bool = False + ): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False, push_request: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._stream or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # if byte not in (b"-", b"+", b":", b"$", b"*"): + # raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = await self._read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response + elif byte == b"$": + response = await self._read(int(response)) + # verbatim string response + elif byte == b"=": + response = (await self._read(int(response)))[4:] + # array response + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + try: + response = set(response) + except TypeError: + pass + # map response + elif byte == b"%": + response = { + (await self._read_response(disable_decoding=disable_decoding)): ( + await self._read_response(disable_decoding=disable_decoding) + ) + for _ in range(int(response)) + } + # push response + elif byte == b">": + response = [ + ( + await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return await ( + self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + ) + else: + return res + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func diff --git a/redis/_parsers/socket.py b/redis/_parsers/socket.py new file mode 100644 index 0000000000..8147243bba --- /dev/null +++ b/redis/_parsers/socket.py @@ -0,0 +1,162 @@ +import errno +import io +import socket +from io import SEEK_END +from typing import Optional, Union + +from ..exceptions import ConnectionError, TimeoutError +from ..utils import SSL_AVAILABLE + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} + +if SSL_AVAILABLE: + import ssl + + if hasattr(ssl, "SSLWantReadError"): + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 + else: + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." +SENTINEL = object() + +SYM_CRLF = b"\r\n" + + +class SocketBuffer: + def __init__( + self, socket: socket.socket, socket_read_size: int, socket_timeout: float + ): + self._sock = socket + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer = io.BytesIO() + + def unread_bytes(self) -> int: + """ + Remaining unread length of buffer + """ + pos = self._buffer.tell() + end = self._buffer.seek(0, SEEK_END) + self._buffer.seek(pos) + return end - pos + + def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, object] = SENTINEL, + raise_on_timeout: Optional[bool] = True, + ) -> bool: + sock = self._sock + socket_read_size = self.socket_read_size + marker = 0 + custom_timeout = timeout is not SENTINEL + + buf = self._buffer + current_pos = buf.tell() + buf.seek(0, SEEK_END) + if custom_timeout: + sock.settimeout(timeout) + try: + while True: + data = self._sock.recv(socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + marker += data_length + + if length is not None and length > marker: + continue + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + buf.seek(current_pos) + if custom_timeout: + sock.settimeout(self.socket_timeout) + + def can_read(self, timeout: float) -> bool: + return bool(self.unread_bytes()) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # BufferIO will return less than requested if buffer is short + data = self._buffer.read(length) + missing = length - len(data) + if missing: + # fill up the buffer and read the remainder + self._read_from_socket(missing) + data += self._buffer.read(missing) + return data[:-2] + + def readline(self) -> bytes: + buf = self._buffer + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + self._read_from_socket() + data += buf.readline() + + return data[:-2] + + def get_pos(self) -> int: + """ + Get current read position + """ + return self._buffer.tell() + + def rewind(self, pos: int) -> None: + """ + Rewind the buffer to a specific position, to re-start reading + """ + self._buffer.seek(pos) + + def purge(self) -> None: + """ + After a successful read, purge the read part of buffer + """ + unread = self.unread_bytes() + + # Only if we have read all of the buffer do we truncate, to + # reduce the amount of memory thrashing. This heuristic + # can be changed or removed later. + if unread > 0: + return + + if unread > 0: + # move unread data to the front + view = self._buffer.getbuffer() + view[:unread] = view[-unread:] + self._buffer.truncate(unread) + self._buffer.seek(0) + + def close(self) -> None: + try: + self._buffer.close() + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. + pass + self._buffer = None + self._sock = None diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index 2a82df251e..3545ab44c2 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -7,7 +7,6 @@ SSLConnection, UnixDomainSocketConnection, ) -from redis.asyncio.parser import CommandsParser from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, @@ -39,7 +38,6 @@ "BlockingConnectionPool", "BusyLoadingError", "ChildDeadlockedError", - "CommandsParser", "Connection", "ConnectionError", "ConnectionPool", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 7479b742b6..111df24185 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -24,6 +24,12 @@ cast, ) +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) from redis.asyncio.connection import ( Connection, ConnectionPool, @@ -37,7 +43,6 @@ NEVER_DECODE, AbstractRedis, CaseInsensitiveDict, - bool_ok, ) from redis.commands import ( AsyncCoreCommands, @@ -57,7 +62,7 @@ WatchError, ) from redis.typing import ChannelT, EncodableT, KeyT -from redis.utils import safe_str, str_if_bytes +from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -180,6 +185,7 @@ def __init__( auto_close_connection_pool: bool = True, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): """ Initialize a new Redis client. @@ -217,6 +223,7 @@ def __init__( "health_check_interval": health_check_interval, "client_name": client_name, "redis_connect_func": redis_connect_func, + "protocol": protocol, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -255,7 +262,12 @@ def __init__( self.single_connection_client = single_connection_client self.connection: Optional[Connection] = None - self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) + + if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.response_callbacks.update(_RedisCallbacksRESP3) + else: + self.response_callbacks.update(_RedisCallbacksRESP2) # If using a single connection client, we need to lock creation-of and use-of # the client in order to avoid race conditions such as using asyncio.gather @@ -657,6 +669,7 @@ def __init__( shard_hint: Optional[str] = None, ignore_subscribe_messages: bool = False, encoder=None, + push_handler_func: Optional[Callable] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -665,18 +678,21 @@ def __init__( # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder + self.push_handler_func = push_handler_func if self.encoder is None: self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: - self.health_check_response: Iterable[Union[str, bytes]] = [ - "pong", + self.health_check_response = [ + ["pong", self.HEALTH_CHECK_MESSAGE], self.HEALTH_CHECK_MESSAGE, ] else: self.health_check_response = [ - b"pong", + [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)], self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] + if self.push_handler_func is None: + _set_info_logger() self.channels = {} self.pending_unsubscribe_channels = set() self.patterns = {} @@ -761,6 +777,8 @@ async def connect(self): self.connection.register_connect_callback(self.on_connect) else: await self.connection.connect() + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) async def _disconnect_raise_connect(self, conn, error): """ @@ -802,10 +820,14 @@ async def parse_response(self, block: bool = True, timeout: float = 0): read_timeout = None if block else timeout response = await self._execute( - conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False + conn, + conn.read_response, + timeout=read_timeout, + disconnect_on_error=False, + push_request=True, ) - if conn.health_check_interval and response == self.health_check_response: + if conn.health_check_interval and response in self.health_check_response: # ignore the health check message as user might not expect it return None return response @@ -933,8 +955,8 @@ def ping(self, message=None) -> Awaitable: """ Ping the Redis server """ - message = "" if message is None else message - return self.execute_command("PING", message) + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) async def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -942,6 +964,10 @@ async def handle_message(self, response, ignore_subscribe_messages=False): with a message handler, the handler is invoked instead of a parsed message being returned. """ + if response is None: + return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 989c6ccda8..9e2a40ce1b 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -18,16 +18,15 @@ Union, ) -from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import ( - Connection, - DefaultParser, - Encoder, - SSLConnection, - parse_url, +from redis._parsers import AsyncCommandsParser, Encoder +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, ) +from redis.asyncio.client import ResponseCallbackT +from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url from redis.asyncio.lock import Lock -from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import default_backoff from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis @@ -258,6 +257,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + protocol: Optional[int] = 2, address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: @@ -299,6 +299,7 @@ def __init__( "socket_keepalive_options": socket_keepalive_options, "socket_timeout": socket_timeout, "retry": retry, + "protocol": protocol, } if ssl: @@ -331,7 +332,11 @@ def __init__( self.retry.update_supported_errors(retry_on_error) kwargs.update({"retry": self.retry}) - kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy() + kwargs["response_callbacks"] = _RedisCallbacks.copy() + if kwargs.get("protocol") in ["3", 3]: + kwargs["response_callbacks"].update(_RedisCallbacksRESP3) + else: + kwargs["response_callbacks"].update(_RedisCallbacksRESP2) self.connection_kwargs = kwargs if startup_nodes: @@ -358,7 +363,7 @@ def __init__( self.cluster_error_retry_attempts = cluster_error_retry_attempts self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 - self.commands_parser = CommandsParser() + self.commands_parser = AsyncCommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index d6195e1801..22c5030e6c 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -33,35 +33,31 @@ else: from async_timeout import timeout as async_timeout - from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.compat import Protocol, TypedDict +from redis.connection import DEFAULT_RESP_VERSION from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - OutOfMemoryError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.typing import EncodableT, EncodedT +from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -hiredis = None -if HIREDIS_AVAILABLE: - import hiredis +from .._parsers import ( + BaseParser, + Encoder, + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, +) SYM_STAR = b"*" SYM_DOLLAR = b"$" @@ -69,367 +65,19 @@ SYM_LF = b"\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." - class _Sentinel(enum.Enum): sentinel = object() SENTINEL = _Sentinel.sentinel -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class _HiredisReaderArgs(TypedDict, total=False): - protocolError: Callable[[str], Exception] - replyError: Callable[[str], Exception] - encoding: Optional[str] - errors: Optional[str] - - -class Encoder: - """Encode strings to bytes-like and decode bytes-like to strings""" - - __slots__ = "encoding", "encoding_errors", "decode_responses" - - def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value: EncodableT) -> EncodedT: - """Return a bytestring or bytes-like representation of the value""" - if isinstance(value, str): - return value.encode(self.encoding, self.encoding_errors) - if isinstance(value, (bytes, memoryview)): - return value - if isinstance(value, (int, float)): - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) - return repr(value).encode() - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - - def decode(self, value: EncodableT, force=False) -> EncodableT: - """Return a unicode string from the bytes-like representation""" - if self.decode_responses or force: - if isinstance(value, bytes): - return value.decode(self.encoding, self.encoding_errors) - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) - return value - - -ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] - - -class BaseParser: - """Plain Python parsing class""" - - __slots__ = "_stream", "_read_size", "_connected" - - EXCEPTION_CLASSES: ExceptionMappingT = { - "ERR": { - "max number of clients reached": ConnectionError, - "Client sent AUTH, but no password is set": AuthenticationError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - "OOM": OutOfMemoryError, - } - - def __init__(self, socket_read_size: int): - self._stream: Optional[asyncio.StreamReader] = None - self._read_size = socket_read_size - self._connected = False - - @classmethod - def parse_error(cls, response: str) -> ResponseError: - """Parse an error response""" - error_code = response.split(" ")[0] - if error_code in cls.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = cls.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - def on_disconnect(self): - raise NotImplementedError() - - def on_connect(self, connection: "AbstractConnection"): - raise NotImplementedError() - - async def can_read_destructive(self) -> bool: - raise NotImplementedError() - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: - raise NotImplementedError() - - -class PythonParser(BaseParser): - """Plain Python parsing class""" - - __slots__ = ("encoder", "_buffer", "_pos", "_chunks") - - def __init__(self, socket_read_size: int): - super().__init__(socket_read_size) - self.encoder: Optional[Encoder] = None - self._buffer = b"" - self._chunks = [] - self._pos = 0 - - def _clear(self): - self._buffer = b"" - self._chunks.clear() - - def on_connect(self, connection: "AbstractConnection"): - """Called when the stream connects""" - self._stream = connection._reader - if self._stream is None: - raise RedisError("Buffer is closed.") - self.encoder = connection.encoder - self._clear() - self._connected = True - - def on_disconnect(self): - """Called when the stream disconnects""" - self._connected = False - - async def can_read_destructive(self) -> bool: - if not self._connected: - raise RedisError("Buffer is closed.") - if self._buffer: - return True - try: - async with async_timeout(0): - return await self._stream.read(1) - except asyncio.TimeoutError: - return False - - async def read_response(self, disable_decoding: bool = False): - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._chunks: - # augment parsing buffer with previously read data - self._buffer += b"".join(self._chunks) - self._chunks.clear() - self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) - # Successfully parsing a response allows us to clear our parsing buffer - self._clear() - return response - - async def _read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None]: - raw = await self._readline() - response: Any - byte, response = raw[:1], raw[1:] - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - self._clear() # Successful parse - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. - return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - return int(response) - # bulk response - elif byte == b"$" and response == b"-1": - return None - elif byte == b"$": - response = await self._read(int(response)) - # multi-bulk response - elif byte == b"*" and response == b"-1": - return None - elif byte == b"*": - response = [ - (await self._read_response(disable_decoding)) - for _ in range(int(response)) # noqa - ] - else: - raise InvalidResponse(f"Protocol Error: {raw!r}") - - if disable_decoding is False: - response = self.encoder.decode(response) - return response - - async def _read(self, length: int) -> bytes: - """ - Read `length` bytes of data. These are assumed to be followed - by a '\r\n' terminator which is subsequently discarded. - """ - want = length + 2 - end = self._pos + want - if len(self._buffer) >= end: - result = self._buffer[self._pos : end - 2] - else: - tail = self._buffer[self._pos :] - try: - data = await self._stream.readexactly(want - len(tail)) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += want - return result - - async def _readline(self) -> bytes: - """ - read an unknown number of bytes up to the next '\r\n' - line separator, which is discarded. - """ - found = self._buffer.find(b"\r\n", self._pos) - if found >= 0: - result = self._buffer[self._pos : found] - else: - tail = self._buffer[self._pos :] - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += len(result) + 2 - return result - - -class HiredisParser(BaseParser): - """Parser class for connections using Hiredis""" - - __slots__ = ("_reader",) - - def __init__(self, socket_read_size: int): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not available.") - super().__init__(socket_read_size=socket_read_size) - self._reader: Optional[hiredis.Reader] = None - - def on_connect(self, connection: "AbstractConnection"): - self._stream = connection._reader - kwargs: _HiredisReaderArgs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - } - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - kwargs["errors"] = connection.encoder.encoding_errors - - self._reader = hiredis.Reader(**kwargs) - self._connected = True - - def on_disconnect(self): - self._connected = False - - async def can_read_destructive(self): - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._reader.gets(): - return True - try: - async with async_timeout(0): - return await self.read_from_socket() - except asyncio.TimeoutError: - return False - - async def read_from_socket(self): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, List[EncodableT]]: - # If `on_disconnect()` has been called, prohibit any more reads - # even if they could happen because data might be present. - # We still allow reads in progress to finish - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - - response = self._reader.gets() - while response is False: - await self.read_from_socket() - response = self._reader.gets() - - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response -DefaultParser: Type[Union[PythonParser, HiredisParser]] +DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _AsyncHiredisParser else: - DefaultParser = PythonParser + DefaultParser = _AsyncRESP2Parser class ConnectCallbackProtocol(Protocol): @@ -465,6 +113,7 @@ class AbstractConnection: "last_active_at", "encoder", "ssl_context", + "protocol", "_reader", "_writer", "_parser", @@ -496,6 +145,7 @@ def __init__( redis_connect_func: Optional[ConnectCallbackT] = None, encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): if (username or password) and credential_provider is not None: raise DataError( @@ -542,6 +192,16 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 + try: + p = int(protocol) + except TypeError: + p = DEFAULT_RESP_VERSION + except ValueError: + raise ConnectionError("protocol must be an integer") + finally: + if p < 2 or p > 3: + raise ConnectionError("protocol must be either 2 or 3") + self.protocol = protocol def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) @@ -623,7 +283,9 @@ def _error_message(self, exception: BaseException) -> str: async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) + parser = self._parser + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -631,8 +293,25 @@ async def on_connect(self) -> None: or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = await self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + elif auth_args: await self.send_command("AUTH", *auth_args, check_health=False) try: @@ -648,6 +327,20 @@ async def on_connect(self) -> None: if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + elif self.protocol not in [2, "2"]: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + await self.send_command("HELLO", self.protocol) + response = await self.read_response() + # if response.get(b"proto") != self.protocol and response.get( + # "proto" + # ) != self.protocol: + # raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: await self.send_command("CLIENT", "SETNAME", self.client_name) @@ -768,16 +461,30 @@ async def read_response( timeout: Optional[float] = None, *, disconnect_on_error: bool = True, + push_request: Optional[bool] = False, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout host_error = self._host_error() try: - if read_timeout is not None: + if ( + read_timeout is not None + and self.protocol in ["3", 3] + and not HIREDIS_AVAILABLE + ): + async with async_timeout(read_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + elif read_timeout is not None: async with async_timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) + elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) else: response = await self._parser.read_response( disable_decoding=disable_decoding diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py deleted file mode 100644 index 5faf8f8c57..0000000000 --- a/redis/asyncio/parser.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union - -from redis.exceptions import RedisError, ResponseError - -if TYPE_CHECKING: - from redis.asyncio.cluster import ClusterNode - - -class CommandsParser: - """ - Parses Redis commands to get command keys. - - COMMAND output is used to determine key locations. - Commands that do not have a predefined key location are flagged with 'movablekeys', - and these commands' keys are determined by the command 'COMMAND GETKEYS'. - - NOTE: Due to a bug in redis<7.0, this does not work properly - for EVAL or EVALSHA when the `numkeys` arg is 0. - - issue: https://github.com/redis/redis/issues/9493 - - fix: https://github.com/redis/redis/pull/9733 - - So, don't use this with EVAL or EVALSHA. - """ - - __slots__ = ("commands", "node") - - def __init__(self) -> None: - self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} - - async def initialize(self, node: Optional["ClusterNode"] = None) -> None: - if node: - self.node = node - - commands = await self.node.execute_command("COMMAND") - for cmd, command in commands.items(): - if "movablekeys" in command["flags"]: - commands[cmd] = -1 - elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: - commands[cmd] = 0 - elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: - commands[cmd] = 1 - self.commands = {cmd.upper(): command for cmd, command in commands.items()} - - # As soon as this PR is merged into Redis, we should reimplement - # our logic to use COMMAND INFO changes to determine the key positions - # https://github.com/redis/redis/pull/8324 - async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - if len(args) < 2: - # The command has no keys in it - return None - - try: - command = self.commands[args[0]] - except KeyError: - # try to split the command name and to take only the main command - # e.g. 'memory' for 'memory usage' - args = args[0].split() + list(args[1:]) - cmd_name = args[0].upper() - if cmd_name not in self.commands: - # We'll try to reinitialize the commands cache, if the engine - # version has changed, the commands may not be current - await self.initialize() - if cmd_name not in self.commands: - raise RedisError( - f"{cmd_name} command doesn't exist in Redis commands" - ) - - command = self.commands[cmd_name] - - if command == 1: - return (args[1],) - if command == 0: - return None - if command == -1: - return await self._get_moveable_keys(*args) - - last_key_pos = command["last_key_pos"] - if last_key_pos < 0: - last_key_pos = len(args) + last_key_pos - return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] - - async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - try: - keys = await self.node.execute_command("COMMAND GETKEYS", *args) - except ResponseError as e: - message = e.__str__() - if ( - "Invalid arguments" in message - or "The command has no key arguments" in message - ): - return None - else: - raise e - return keys diff --git a/redis/client.py b/redis/client.py index ab626ccdf4..66e2c7b84f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,5 +1,4 @@ import copy -import datetime import re import threading import time @@ -7,6 +6,12 @@ from itertools import chain from typing import Optional +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -18,7 +23,6 @@ from redis.exceptions import ( ConnectionError, ExecAbortError, - ModuleError, PubSubError, RedisError, ResponseError, @@ -27,7 +31,7 @@ ) from redis.lock import Lock from redis.retry import Retry -from redis.utils import safe_str, str_if_bytes +from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -36,21 +40,6 @@ NEVER_DECODE = "NEVER_DECODE" -def timestamp_to_datetime(response): - "Converts a unix timestamp to a Python datetime object" - if not response: - return None - try: - response = int(response) - except ValueError: - return None - return datetime.datetime.fromtimestamp(response) - - -def string_keys_to_dict(key_string, callback): - return dict.fromkeys(key_string.split(), callback) - - class CaseInsensitiveDict(dict): "Case insensitive dict implementation. Assumes string keys only." @@ -78,778 +67,11 @@ def update(self, data): super().update(data) -def parse_debug_object(response): - "Parse the results of Redis's DEBUG OBJECT command into a Python dict" - # The 'type' of the object is the first item in the response, but isn't - # prefixed with a name - response = str_if_bytes(response) - response = "type:" + response - response = dict(kv.split(":") for kv in response.split()) - - # parse some expected int values from the string response - # note: this cmd isn't spec'd so these may not appear in all redis versions - int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") - for field in int_fields: - if field in response: - response[field] = int(response[field]) - - return response - - -def parse_object(response, infotype): - """Parse the results of an OBJECT command""" - if infotype in ("idletime", "refcount"): - return int_or_none(response) - return response - - -def parse_info(response): - """Parse the result of Redis's INFO command into a Python dict""" - info = {} - response = str_if_bytes(response) - - def get_value(value): - if "," not in value or "=" not in value: - try: - if "." in value: - return float(value) - else: - return int(value) - except ValueError: - return value - else: - sub_dict = {} - for item in value.split(","): - k, v = item.rsplit("=", 1) - sub_dict[k] = get_value(v) - return sub_dict - - for line in response.splitlines(): - if line and not line.startswith("#"): - if line.find(":") != -1: - # Split, the info fields keys and values. - # Note that the value may contain ':'. but the 'host:' - # pseudo-command is the only case where the key contains ':' - key, value = line.split(":", 1) - if key == "cmdstat_host": - key, value = line.rsplit(":", 1) - - if key == "module": - # Hardcode a list for key 'modules' since there could be - # multiple lines that started with 'module' - info.setdefault("modules", []).append(get_value(value)) - else: - info[key] = get_value(value) - else: - # if the line isn't splittable, append it to the "__raw__" key - info.setdefault("__raw__", []).append(line) - - return info - - -def parse_memory_stats(response, **kwargs): - """Parse the results of MEMORY STATS""" - stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) - for key, value in stats.items(): - if key.startswith("db."): - stats[key] = pairs_to_dict( - value, decode_keys=True, decode_string_values=True - ) - return stats - - -SENTINEL_STATE_TYPES = { - "can-failover-its-master": int, - "config-epoch": int, - "down-after-milliseconds": int, - "failover-timeout": int, - "info-refresh": int, - "last-hello-message": int, - "last-ok-ping-reply": int, - "last-ping-reply": int, - "last-ping-sent": int, - "master-link-down-time": int, - "master-port": int, - "num-other-sentinels": int, - "num-slaves": int, - "o-down-time": int, - "pending-commands": int, - "parallel-syncs": int, - "port": int, - "quorum": int, - "role-reported-time": int, - "s-down-time": int, - "slave-priority": int, - "slave-repl-offset": int, - "voted-leader-epoch": int, -} - - -def parse_sentinel_state(item): - result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) - flags = set(result["flags"].split(",")) - for name, flag in ( - ("is_master", "master"), - ("is_slave", "slave"), - ("is_sdown", "s_down"), - ("is_odown", "o_down"), - ("is_sentinel", "sentinel"), - ("is_disconnected", "disconnected"), - ("is_master_down", "master_down"), - ): - result[name] = flag in flags - return result - - -def parse_sentinel_master(response): - return parse_sentinel_state(map(str_if_bytes, response)) - - -def parse_sentinel_masters(response): - result = {} - for item in response: - state = parse_sentinel_state(map(str_if_bytes, item)) - result[state["name"]] = state - return result - - -def parse_sentinel_slaves_and_sentinels(response): - return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] - - -def parse_sentinel_get_master(response): - return response and (response[0], int(response[1])) or None - - -def pairs_to_dict(response, decode_keys=False, decode_string_values=False): - """Create a dict given a list of key/value pairs""" - if response is None: - return {} - if decode_keys or decode_string_values: - # the iter form is faster, but I don't know how to make that work - # with a str_if_bytes() map - keys = response[::2] - if decode_keys: - keys = map(str_if_bytes, keys) - values = response[1::2] - if decode_string_values: - values = map(str_if_bytes, values) - return dict(zip(keys, values)) - else: - it = iter(response) - return dict(zip(it, it)) - - -def pairs_to_dict_typed(response, type_info): - it = iter(response) - result = {} - for key, value in zip(it, it): - if key in type_info: - try: - value = type_info[key](value) - except Exception: - # if for some reason the value can't be coerced, just use - # the string value - pass - result[key] = value - return result - - -def zset_score_pairs(response, **options): - """ - If ``withscores`` is specified in the options, return the response as - a list of (value, score) pairs - """ - if not response or not options.get("withscores"): - return response - score_cast_func = options.get("score_cast_func", float) - it = iter(response) - return list(zip(it, map(score_cast_func, it))) - - -def sort_return_tuples(response, **options): - """ - If ``groups`` is specified, return the response as a list of - n-element tuples with n being the value found in options['groups'] - """ - if not response or not options.get("groups"): - return response - n = options["groups"] - return list(zip(*[response[i::n] for i in range(n)])) - - -def int_or_none(response): - if response is None: - return None - return int(response) - - -def parse_stream_list(response): - if response is None: - return None - data = [] - for r in response: - if r is not None: - data.append((r[0], pairs_to_dict(r[1]))) - else: - data.append((None, None)) - return data - - -def pairs_to_dict_with_str_keys(response): - return pairs_to_dict(response, decode_keys=True) - - -def parse_list_of_dicts(response): - return list(map(pairs_to_dict_with_str_keys, response)) - - -def parse_xclaim(response, **options): - if options.get("parse_justid", False): - return response - return parse_stream_list(response) - - -def parse_xautoclaim(response, **options): - if options.get("parse_justid", False): - return response[1] - response[1] = parse_stream_list(response[1]) - return response - - -def parse_xinfo_stream(response, **options): - data = pairs_to_dict(response, decode_keys=True) - if not options.get("full", False): - first = data.get("first-entry") - if first is not None: - data["first-entry"] = (first[0], pairs_to_dict(first[1])) - last = data["last-entry"] - if last is not None: - data["last-entry"] = (last[0], pairs_to_dict(last[1])) - else: - data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} - data["groups"] = [ - pairs_to_dict(group, decode_keys=True) for group in data["groups"] - ] - return data - - -def parse_xread(response): - if response is None: - return [] - return [[r[0], parse_stream_list(r[1])] for r in response] - - -def parse_xpending(response, **options): - if options.get("parse_detail", False): - return parse_xpending_range(response) - consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] - return { - "pending": response[0], - "min": response[1], - "max": response[2], - "consumers": consumers, - } - - -def parse_xpending_range(response): - k = ("message_id", "consumer", "time_since_delivered", "times_delivered") - return [dict(zip(k, r)) for r in response] - - -def float_or_none(response): - if response is None: - return None - return float(response) - - -def bool_ok(response): - return str_if_bytes(response) == "OK" - - -def parse_zadd(response, **options): - if response is None: - return None - if options.get("as_score"): - return float(response) - return int(response) - - -def parse_client_list(response, **options): - clients = [] - for c in str_if_bytes(response).splitlines(): - # Values might contain '=' - clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) - return clients - - -def parse_config_get(response, **options): - response = [str_if_bytes(i) if i is not None else None for i in response] - return response and pairs_to_dict(response) or {} - - -def parse_scan(response, **options): - cursor, r = response - return int(cursor), r - - -def parse_hscan(response, **options): - cursor, r = response - return int(cursor), r and pairs_to_dict(r) or {} - - -def parse_zscan(response, **options): - score_cast_func = options.get("score_cast_func", float) - cursor, r = response - it = iter(r) - return int(cursor), list(zip(it, map(score_cast_func, it))) - - -def parse_zmscore(response, **options): - # zmscore: list of scores (double precision floating point number) or nil - return [float(score) if score is not None else None for score in response] - - -def parse_slowlog_get(response, **options): - space = " " if options.get("decode_responses", False) else b" " - - def parse_item(item): - result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} - # Redis Enterprise injects another entry at index [3], which has - # the complexity info (i.e. the value N in case the command has - # an O(N) complexity) instead of the command. - if isinstance(item[3], list): - result["command"] = space.join(item[3]) - result["client_address"] = item[4] - result["client_name"] = item[5] - else: - result["complexity"] = item[3] - result["command"] = space.join(item[4]) - result["client_address"] = item[5] - result["client_name"] = item[6] - return result - - return [parse_item(item) for item in response] - - -def parse_stralgo(response, **options): - """ - Parse the response from `STRALGO` command. - Without modifiers the returned value is string. - When LEN is given the command returns the length of the result - (i.e integer). - When IDX is given the command returns a dictionary with the LCS - length and all the ranges in both the strings, start and end - offset for each string, where there are matches. - When WITHMATCHLEN is given, each array representing a match will - also have the length of the match at the beginning of the array. - """ - if options.get("len", False): - return int(response) - if options.get("idx", False): - if options.get("withmatchlen", False): - matches = [ - [(int(match[-1]))] + list(map(tuple, match[:-1])) - for match in response[1] - ] - else: - matches = [list(map(tuple, match)) for match in response[1]] - return { - str_if_bytes(response[0]): matches, - str_if_bytes(response[2]): int(response[3]), - } - return str_if_bytes(response) - - -def parse_cluster_info(response, **options): - response = str_if_bytes(response) - return dict(line.split(":") for line in response.splitlines() if line) - - -def _parse_node_line(line): - line_items = line.split(" ") - node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] - addr = addr.split("@")[0] - node_dict = { - "node_id": node_id, - "flags": flags, - "master_id": master_id, - "last_ping_sent": ping, - "last_pong_rcvd": pong, - "epoch": epoch, - "slots": [], - "migrations": [], - "connected": True if connected == "connected" else False, - } - if len(line_items) >= 9: - slots, migrations = _parse_slots(line_items[8:]) - node_dict["slots"], node_dict["migrations"] = slots, migrations - return addr, node_dict - - -def _parse_slots(slot_ranges): - slots, migrations = [], [] - for s_range in slot_ranges: - if "->-" in s_range: - slot_id, dst_node_id = s_range[1:-1].split("->-", 1) - migrations.append( - {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"} - ) - elif "-<-" in s_range: - slot_id, src_node_id = s_range[1:-1].split("-<-", 1) - migrations.append( - {"slot": slot_id, "node_id": src_node_id, "state": "importing"} - ) - else: - s_range = [sl for sl in s_range.split("-")] - slots.append(s_range) - - return slots, migrations - - -def parse_cluster_nodes(response, **options): - """ - @see: https://redis.io/commands/cluster-nodes # string / bytes - @see: https://redis.io/commands/cluster-replicas # list of string / bytes - """ - if isinstance(response, (str, bytes)): - response = response.splitlines() - return dict(_parse_node_line(str_if_bytes(node)) for node in response) - - -def parse_geosearch_generic(response, **options): - """ - Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' - commands according to 'withdist', 'withhash' and 'withcoord' labels. - """ - try: - if options["store"] or options["store_dist"]: - # `store` and `store_dist` cant be combined - # with other command arguments. - # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' - return response - except KeyError: # it means the command was sent via execute_command - return response - - if type(response) != list: - response_list = [response] - else: - response_list = response - - if not options["withdist"] and not options["withcoord"] and not options["withhash"]: - # just a bunch of places - return response_list - - cast = { - "withdist": float, - "withcoord": lambda ll: (float(ll[0]), float(ll[1])), - "withhash": int, - } - - # zip all output results with each casting function to get - # the properly native Python value. - f = [lambda x: x] - f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] - return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] - - -def parse_command(response, **options): - commands = {} - for command in response: - cmd_dict = {} - cmd_name = str_if_bytes(command[0]) - cmd_dict["name"] = cmd_name - cmd_dict["arity"] = int(command[1]) - cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] - cmd_dict["first_key_pos"] = command[3] - cmd_dict["last_key_pos"] = command[4] - cmd_dict["step_count"] = command[5] - if len(command) > 7: - cmd_dict["tips"] = command[7] - cmd_dict["key_specifications"] = command[8] - cmd_dict["subcommands"] = command[9] - commands[cmd_name] = cmd_dict - return commands - - -def parse_pubsub_numsub(response, **options): - return list(zip(response[0::2], response[1::2])) - - -def parse_client_kill(response, **options): - if isinstance(response, int): - return response - return str_if_bytes(response) == "OK" - - -def parse_acl_getuser(response, **options): - if response is None: - return None - data = pairs_to_dict(response, decode_keys=True) - - # convert everything but user-defined data in 'keys' to native strings - data["flags"] = list(map(str_if_bytes, data["flags"])) - data["passwords"] = list(map(str_if_bytes, data["passwords"])) - data["commands"] = str_if_bytes(data["commands"]) - if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): - data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) - if data["keys"] == [""]: - data["keys"] = [] - if "channels" in data: - if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): - data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) - if data["channels"] == [""]: - data["channels"] = [] - if "selectors" in data: - data["selectors"] = [ - list(map(str_if_bytes, selector)) for selector in data["selectors"] - ] - - # split 'commands' into separate 'categories' and 'commands' lists - commands, categories = [], [] - for command in data["commands"].split(" "): - if "@" in command: - categories.append(command) - else: - commands.append(command) - - data["commands"] = commands - data["categories"] = categories - data["enabled"] = "on" in data["flags"] - return data - - -def parse_acl_log(response, **options): - if response is None: - return None - if isinstance(response, list): - data = [] - for log in response: - log_data = pairs_to_dict(log, True, True) - client_info = log_data.get("client-info", "") - log_data["client-info"] = parse_client_info(client_info) - - # float() is lossy comparing to the "double" in C - log_data["age-seconds"] = float(log_data["age-seconds"]) - data.append(log_data) - else: - data = bool_ok(response) - return data - - -def parse_client_info(value): - """ - Parsing client-info in ACL Log in following format. - "key1=value1 key2=value2 key3=value3" - """ - client_info = {} - infos = str_if_bytes(value).split(" ") - for info in infos: - key, value = info.split("=") - client_info[key] = value - - # Those fields are defined as int in networking.c - for int_key in { - "id", - "age", - "idle", - "db", - "sub", - "psub", - "multi", - "qbuf", - "qbuf-free", - "obl", - "argv-mem", - "oll", - "omem", - "tot-mem", - }: - client_info[int_key] = int(client_info[int_key]) - return client_info - - -def parse_module_result(response): - if isinstance(response, ModuleError): - raise response - return True - - -def parse_set_result(response, **options): - """ - Handle SET result since GET argument is available since Redis 6.2. - Parsing SET result into: - - BOOL - - String when GET argument is used - """ - if options.get("get"): - # Redis will return a getCommand result. - # See `setGenericCommand` in t_string.c - return response - return response and str_if_bytes(response) == "OK" +class AbstractRedis: + pass -class AbstractRedis: - RESPONSE_CALLBACKS = { - **string_keys_to_dict( - "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " - "HEXISTS HMSET MOVE MSETNX PERSIST " - "PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", - bool, - ), - **string_keys_to_dict( - "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " - "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " - "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " - "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " - "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", - int, - ), - **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), - **string_keys_to_dict( - # these return OK, or int if redis-server is >=1.3.4 - "LPUSH RPUSH", - lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", - ), - **string_keys_to_dict("SORT", sort_return_tuples), - **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), - **string_keys_to_dict( - "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READONLY READWRITE " - "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", - bool_ok, - ), - **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), - **string_keys_to_dict( - "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() - ), - **string_keys_to_dict( - "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " - "ZREVRANGE ZREVRANGEBYSCORE", - zset_score_pairs, - ), - **string_keys_to_dict( - "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None - ), - **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), - **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), - **string_keys_to_dict("XREAD XREADGROUP", parse_xread), - **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), - "ACL CAT": lambda r: list(map(str_if_bytes, r)), - "ACL DELUSER": int, - "ACL GENPASS": str_if_bytes, - "ACL GETUSER": parse_acl_getuser, - "ACL HELP": lambda r: list(map(str_if_bytes, r)), - "ACL LIST": lambda r: list(map(str_if_bytes, r)), - "ACL LOAD": bool_ok, - "ACL LOG": parse_acl_log, - "ACL SAVE": bool_ok, - "ACL SETUSER": bool_ok, - "ACL USERS": lambda r: list(map(str_if_bytes, r)), - "ACL WHOAMI": str_if_bytes, - "CLIENT GETNAME": str_if_bytes, - "CLIENT ID": int, - "CLIENT KILL": parse_client_kill, - "CLIENT LIST": parse_client_list, - "CLIENT INFO": parse_client_info, - "CLIENT SETNAME": bool_ok, - "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, - "CLIENT PAUSE": bool_ok, - "CLIENT GETREDIR": int, - "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), - "CLUSTER ADDSLOTS": bool_ok, - "CLUSTER ADDSLOTSRANGE": bool_ok, - "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), - "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), - "CLUSTER DELSLOTS": bool_ok, - "CLUSTER DELSLOTSRANGE": bool_ok, - "CLUSTER FAILOVER": bool_ok, - "CLUSTER FORGET": bool_ok, - "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), - "CLUSTER INFO": parse_cluster_info, - "CLUSTER KEYSLOT": lambda x: int(x), - "CLUSTER MEET": bool_ok, - "CLUSTER NODES": parse_cluster_nodes, - "CLUSTER REPLICAS": parse_cluster_nodes, - "CLUSTER REPLICATE": bool_ok, - "CLUSTER RESET": bool_ok, - "CLUSTER SAVECONFIG": bool_ok, - "CLUSTER SET-CONFIG-EPOCH": bool_ok, - "CLUSTER SETSLOT": bool_ok, - "CLUSTER SLAVES": parse_cluster_nodes, - "COMMAND": parse_command, - "COMMAND COUNT": int, - "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), - "CONFIG GET": parse_config_get, - "CONFIG RESETSTAT": bool_ok, - "CONFIG SET": bool_ok, - "DEBUG OBJECT": parse_debug_object, - "FUNCTION DELETE": bool_ok, - "FUNCTION FLUSH": bool_ok, - "FUNCTION RESTORE": bool_ok, - "GEOHASH": lambda r: list(map(str_if_bytes, r)), - "GEOPOS": lambda r: list( - map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) - ), - "GEOSEARCH": parse_geosearch_generic, - "GEORADIUS": parse_geosearch_generic, - "GEORADIUSBYMEMBER": parse_geosearch_generic, - "HGETALL": lambda r: r and pairs_to_dict(r) or {}, - "HSCAN": parse_hscan, - "INFO": parse_info, - "LASTSAVE": timestamp_to_datetime, - "MEMORY PURGE": bool_ok, - "MEMORY STATS": parse_memory_stats, - "MEMORY USAGE": int_or_none, - "MODULE LOAD": parse_module_result, - "MODULE UNLOAD": parse_module_result, - "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], - "OBJECT": parse_object, - "PING": lambda r: str_if_bytes(r) == "PONG", - "QUIT": bool_ok, - "STRALGO": parse_stralgo, - "PUBSUB NUMSUB": parse_pubsub_numsub, - "RANDOMKEY": lambda r: r and r or None, - "RESET": str_if_bytes, - "SCAN": parse_scan, - "SCRIPT EXISTS": lambda r: list(map(bool, r)), - "SCRIPT FLUSH": bool_ok, - "SCRIPT KILL": bool_ok, - "SCRIPT LOAD": str_if_bytes, - "SENTINEL CKQUORUM": bool_ok, - "SENTINEL FAILOVER": bool_ok, - "SENTINEL FLUSHCONFIG": bool_ok, - "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, - "SENTINEL MASTER": parse_sentinel_master, - "SENTINEL MASTERS": parse_sentinel_masters, - "SENTINEL MONITOR": bool_ok, - "SENTINEL RESET": bool_ok, - "SENTINEL REMOVE": bool_ok, - "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, - "SENTINEL SET": bool_ok, - "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, - "SET": parse_set_result, - "SLOWLOG GET": parse_slowlog_get, - "SLOWLOG LEN": int, - "SLOWLOG RESET": bool_ok, - "SSCAN": parse_scan, - "TIME": lambda x: (int(x[0]), int(x[1])), - "XCLAIM": parse_xclaim, - "XAUTOCLAIM": parse_xautoclaim, - "XGROUP CREATE": bool_ok, - "XGROUP DELCONSUMER": int, - "XGROUP DESTROY": bool, - "XGROUP SETID": bool_ok, - "XINFO CONSUMERS": parse_list_of_dicts, - "XINFO GROUPS": parse_list_of_dicts, - "XINFO STREAM": parse_xinfo_stream, - "XPENDING": parse_xpending, - "ZADD": parse_zadd, - "ZSCAN": parse_zscan, - "ZMSCORE": parse_zmscore, - } - - -class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): +class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ Implementation of the Redis protocol. @@ -953,6 +175,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): """ Initialize a new Redis client. @@ -1001,6 +224,7 @@ def __init__( "client_name": client_name, "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, + "protocol": protocol, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -1046,7 +270,12 @@ def __init__( if single_connection_client: self.connection = self.connection_pool.get_connection("_") - self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) + + if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.response_callbacks.update(_RedisCallbacksRESP3) + else: + self.response_callbacks.update(_RedisCallbacksRESP2) def __repr__(self): return f"{type(self).__name__}<{repr(self.connection_pool)}>" @@ -1376,8 +605,8 @@ class PubSub: will be returned and it's safe to start listening again. """ - PUBLISH_MESSAGE_TYPES = ("message", "pmessage") - UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe") HEALTH_CHECK_MESSAGE = "redis-py-health-check" def __init__( @@ -1386,6 +615,7 @@ def __init__( shard_hint=None, ignore_subscribe_messages=False, encoder=None, + push_handler_func=None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -1395,6 +625,7 @@ def __init__( # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder + self.push_handler_func = push_handler_func if self.encoder is None: self.encoder = self.connection_pool.get_encoder() self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) @@ -1402,6 +633,8 @@ def __init__( self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] else: self.health_check_response = [b"pong", self.health_check_response_b] + if self.push_handler_func is None: + _set_info_logger() self.reset() def __enter__(self): @@ -1425,9 +658,11 @@ def reset(self): self.connection.clear_connect_callbacks() self.connection_pool.release(self.connection) self.connection = None - self.channels = {} self.health_check_response_counter = 0 + self.channels = {} self.pending_unsubscribe_channels = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() self.subscribed_event.clear() @@ -1442,16 +677,23 @@ def on_connect(self, connection): # before passing them to [p]subscribe. self.pending_unsubscribe_channels.clear() self.pending_unsubscribe_patterns.clear() + self.pending_unsubscribe_shard_channels.clear() if self.channels: - channels = {} - for k, v in self.channels.items(): - channels[self.encoder.decode(k, force=True)] = v + channels = { + self.encoder.decode(k, force=True): v for k, v in self.channels.items() + } self.subscribe(**channels) if self.patterns: - patterns = {} - for k, v in self.patterns.items(): - patterns[self.encoder.decode(k, force=True)] = v + patterns = { + self.encoder.decode(k, force=True): v for k, v in self.patterns.items() + } self.psubscribe(**patterns) + if self.shard_channels: + shard_channels = { + self.encoder.decode(k, force=True): v + for k, v in self.shard_channels.items() + } + self.ssubscribe(**shard_channels) @property def subscribed(self): @@ -1472,6 +714,8 @@ def execute_command(self, *args): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: @@ -1537,7 +781,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response(disconnect_on_error=False) + return conn.read_response(disconnect_on_error=False, push_request=True) response = self._execute(conn, try_read) @@ -1658,6 +902,45 @@ def unsubscribe(self, *args): self.pending_unsubscribe_channels.update(channels) return self.execute_command("UNSUBSCRIBE", *args) + def ssubscribe(self, *args, target_node=None, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_s_channels = dict.fromkeys(args) + new_s_channels.update(kwargs) + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_s_channels = self._normalize_keys(new_s_channels) + self.shard_channels.update(new_s_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 + self.pending_unsubscribe_shard_channels.difference_update(new_s_channels) + return ret_val + + def sunsubscribe(self, *args, target_node=None): + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = self._normalize_keys(dict.fromkeys(args)) + else: + s_channels = self.shard_channels + self.pending_unsubscribe_shard_channels.update(s_channels) + return self.execute_command("SUNSUBSCRIBE", *args) + def listen(self): "Listen for messages on channels this client has been subscribed to" while self.subscribed: @@ -1692,12 +975,14 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): return self.handle_message(response, ignore_subscribe_messages) return None + get_sharded_message = get_message + def ping(self, message=None): """ Ping the Redis server """ - message = "" if message is None else message - return self.execute_command("PING", message) + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -1707,6 +992,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): """ if response is None: return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { @@ -1737,12 +1024,17 @@ def handle_message(self, response, ignore_subscribe_messages=False): if pattern in self.pending_unsubscribe_patterns: self.pending_unsubscribe_patterns.remove(pattern) self.patterns.pop(pattern, None) + elif message_type == "sunsubscribe": + s_channel = response[1] + if s_channel in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(s_channel) + self.shard_channels.pop(s_channel, None) else: channel = response[1] if channel in self.pending_unsubscribe_channels: self.pending_unsubscribe_channels.remove(channel) self.channels.pop(channel, None) - if not self.channels and not self.patterns: + if not self.channels and not self.patterns and not self.shard_channels: # There are no subscriptions anymore, set subscribed_event flag # to false self.subscribed_event.clear() @@ -1751,6 +1043,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): # if there's a message handler, invoke it if message_type == "pmessage": handler = self.patterns.get(message["pattern"], None) + elif message_type == "smessage": + handler = self.shard_channels.get(message["channel"], None) else: handler = self.channels.get(message["channel"], None) if handler: @@ -1771,6 +1065,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): for pattern, handler in self.patterns.items(): if handler is None: raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + for s_channel, handler in self.shard_channels.items(): + if handler is None: + raise PubSubError( + f"Shard Channel: '{s_channel}' has no handler registered" + ) thread = PubSubWorkerThread( self, sleep_time, daemon=daemon, exception_handler=exception_handler diff --git a/redis/cluster.py b/redis/cluster.py index be8e4623a7..c179511b0c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,10 +6,13 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from redis._parsers import CommandsParser, Encoder +from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan -from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands -from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.client import CaseInsensitiveDict, PubSub, Redis +from redis.commands import READ_COMMANDS, RedisClusterCommands +from redis.commands.helpers import list_or_args +from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -31,6 +34,7 @@ from redis.lock import Lock from redis.retry import Retry from redis.utils import ( + HIREDIS_AVAILABLE, dict_merge, list_keys_to_dict, merge_result, @@ -145,6 +149,7 @@ def parse_cluster_myshardid(resp, **options): "queue_class", "retry", "retry_on_timeout", + "protocol", "socket_connect_timeout", "socket_keepalive", "socket_keepalive_options", @@ -227,6 +232,8 @@ class AbstractRedisCluster: "PUBSUB CHANNELS", "PUBSUB NUMPAT", "PUBSUB NUMSUB", + "PUBSUB SHARDCHANNELS", + "PUBSUB SHARDNUMSUB", "PING", "INFO", "SHUTDOWN", @@ -253,7 +260,6 @@ class AbstractRedisCluster: "CLIENT INFO", "CLIENT KILL", "READONLY", - "READWRITE", "CLUSTER INFO", "CLUSTER MEET", "CLUSTER NODES", @@ -353,11 +359,13 @@ class AbstractRedisCluster: } RESULT_CALLBACKS = dict_merge( - list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), + list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub), list_keys_to_dict( ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) ), - list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), + list_keys_to_dict( + ["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result + ), list_keys_to_dict( [ "PING", @@ -1632,7 +1640,15 @@ class ClusterPubSub(PubSub): https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html """ - def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): + def __init__( + self, + redis_cluster, + node=None, + host=None, + port=None, + push_handler_func=None, + **kwargs, + ): """ When a pubsub instance is created without specifying a node, a single node will be transparently chosen for the pubsub connection on the @@ -1654,8 +1670,13 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): else redis_cluster.get_redis_connection(self.node).connection_pool ) self.cluster = redis_cluster + self.node_pubsub_mapping = {} + self._pubsubs_generator = self._pubsubs_generator() super().__init__( - **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder + connection_pool=connection_pool, + encoder=redis_cluster.encoder, + push_handler_func=push_handler_func, + **kwargs, ) def set_pubsub_node(self, cluster, node=None, host=None, port=None): @@ -1707,9 +1728,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port): f"Node {host}:{port} doesn't exist in the cluster" ) - def execute_command(self, *args, **kwargs): + def execute_command(self, *args): """ - Execute a publish/subscribe command. + Execute a subscribe/unsubscribe command. Taken code from redis-py and tweak to make it work within a cluster. """ @@ -1739,9 +1760,94 @@ def execute_command(self, *args, **kwargs): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection self._execute(connection, connection.send_command, *args) + def _get_node_pubsub(self, node): + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + pubsub = node.redis_connection.pubsub( + push_handler_func=self.push_handler_func + ) + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + + def _sharded_message_generator(self): + for _ in range(len(self.node_pubsub_mapping)): + pubsub = next(self._pubsubs_generator) + message = pubsub.get_message() + if message is not None: + return message + return None + + def _pubsubs_generator(self): + while True: + for pubsub in self.node_pubsub_mapping.values(): + yield pubsub + + def get_sharded_message( + self, ignore_subscribe_messages=False, timeout=0.0, target_node=None + ): + if target_node: + message = self.node_pubsub_mapping[target_node.name].get_message( + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + else: + message = self._sharded_message_generator() + if message is None: + return None + elif str_if_bytes(message["type"]) == "sunsubscribe": + if message["channel"] in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(message["channel"]) + self.shard_channels.pop(message["channel"], None) + node = self.cluster.get_node_from_key(message["channel"]) + if self.node_pubsub_mapping[node.name].subscribed is False: + self.node_pubsub_mapping.pop(node.name) + if not self.channels and not self.patterns and not self.shard_channels: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None + return message + + def ssubscribe(self, *args, **kwargs): + if args: + args = list_or_args(args[0], args[1:]) + s_channels = dict.fromkeys(args) + s_channels.update(kwargs) + for s_channel, handler in s_channels.items(): + node = self.cluster.get_node_from_key(s_channel) + pubsub = self._get_node_pubsub(node) + if handler: + pubsub.ssubscribe(**{s_channel: handler}) + else: + pubsub.ssubscribe(s_channel) + self.shard_channels.update(pubsub.shard_channels) + self.pending_unsubscribe_shard_channels.difference_update( + self._normalize_keys({s_channel: None}) + ) + if pubsub.subscribed and not self.subscribed: + self.subscribed_event.set() + self.health_check_response_counter = 0 + + def sunsubscribe(self, *args): + if args: + args = list_or_args(args[0], args[1:]) + else: + args = self.shard_channels + + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + p = self._get_node_pubsub(node) + p.sunsubscribe(s_channel) + self.pending_unsubscribe_shard_channels.update( + p.pending_unsubscribe_shard_channels + ) + def get_redis_connection(self): """ Get the Redis connection of the pubsub connected node. @@ -1749,6 +1855,15 @@ def get_redis_connection(self): if self.node is not None: return self.node.redis_connection + def disconnect(self): + """ + Disconnect the pubsub connection. + """ + if self.connection: + self.connection.disconnect() + for pubsub in self.node_pubsub_mapping.values(): + pubsub.connection.disconnect() + class ClusterPipeline(RedisCluster): """ diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index f3f08286c8..a94d9764a6 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,7 +1,6 @@ from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args -from .parser import CommandsParser from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands from .sentinel import AsyncSentinelCommands, SentinelCommands @@ -10,7 +9,6 @@ "AsyncRedisClusterCommands", "AsyncRedisModuleCommands", "AsyncSentinelCommands", - "CommandsParser", "CoreCommands", "READ_COMMANDS", "RedisClusterCommands", diff --git a/redis/commands/bf/__init__.py b/redis/commands/bf/__init__.py index 4da060e995..bfa9456879 100644 --- a/redis/commands/bf/__init__.py +++ b/redis/commands/bf/__init__.py @@ -1,4 +1,4 @@ -from redis.client import bool_ok +from redis._parsers.helpers import bool_ok from ..helpers import parse_to_list from .commands import * # noqa @@ -91,20 +91,29 @@ class CMSBloom(CMSCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { CMS_INITBYDIM: bool_ok, CMS_INITBYPROB: bool_ok, # CMS_INCRBY: spaceHolder, # CMS_QUERY: spaceHolder, CMS_MERGE: bool_ok, + } + + _RESP2_MODULE_CALLBACKS = { CMS_INFO: CMSInfo, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CMSCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -112,21 +121,30 @@ class TOPKBloom(TOPKCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { TOPK_RESERVE: bool_ok, - TOPK_ADD: parse_to_list, - TOPK_INCRBY: parse_to_list, # TOPK_QUERY: spaceHolder, # TOPK_COUNT: spaceHolder, - TOPK_LIST: parse_to_list, + } + + _RESP2_MODULE_CALLBACKS = { + TOPK_ADD: parse_to_list, + TOPK_INCRBY: parse_to_list, TOPK_INFO: TopKInfo, + TOPK_LIST: parse_to_list, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TOPKCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -134,7 +152,7 @@ class CFBloom(CFCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { CF_RESERVE: bool_ok, # CF_ADD: spaceHolder, # CF_ADDNX: spaceHolder, @@ -145,14 +163,23 @@ def __init__(self, client, **kwargs): # CF_COUNT: spaceHolder, # CF_SCANDUMP: spaceHolder, # CF_LOADCHUNK: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { CF_INFO: CFInfo, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CFCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -160,28 +187,35 @@ class TDigestBloom(TDigestCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { TDIGEST_CREATE: bool_ok, # TDIGEST_RESET: bool_ok, # TDIGEST_ADD: spaceHolder, # TDIGEST_MERGE: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { + TDIGEST_BYRANK: parse_to_list, + TDIGEST_BYREVRANK: parse_to_list, TDIGEST_CDF: parse_to_list, - TDIGEST_QUANTILE: parse_to_list, + TDIGEST_INFO: TDigestInfo, TDIGEST_MIN: float, TDIGEST_MAX: float, TDIGEST_TRIMMED_MEAN: float, - TDIGEST_INFO: TDigestInfo, - TDIGEST_RANK: parse_to_list, - TDIGEST_REVRANK: parse_to_list, - TDIGEST_BYRANK: parse_to_list, - TDIGEST_BYREVRANK: parse_to_list, + TDIGEST_QUANTILE: parse_to_list, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TDigestCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -189,7 +223,7 @@ class BFBloom(BFCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { BF_RESERVE: bool_ok, # BF_ADD: spaceHolder, # BF_MADD: spaceHolder, @@ -199,12 +233,21 @@ def __init__(self, client, **kwargs): # BF_SCANDUMP: spaceHolder, # BF_LOADCHUNK: spaceHolder, # BF_CARD: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { BF_INFO: BFInfo, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = BFCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) diff --git a/redis/commands/bf/commands.py b/redis/commands/bf/commands.py index c45523c99b..447f844508 100644 --- a/redis/commands/bf/commands.py +++ b/redis/commands/bf/commands.py @@ -60,7 +60,6 @@ class BFCommands: """Bloom Filter commands.""" - # region Bloom Filter Functions def create(self, key, errorRate, capacity, expansion=None, noScale=None): """ Create a new Bloom Filter `key` with desired probability of false positives @@ -178,7 +177,6 @@ def card(self, key): class CFCommands: """Cuckoo Filter commands.""" - # region Cuckoo Filter Functions def create( self, key, capacity, expansion=None, bucket_size=None, max_iterations=None ): @@ -488,7 +486,6 @@ def byrevrank(self, key, rank, *ranks): class CMSCommands: """Count-Min Sketch Commands""" - # region Count-Min Sketch Functions def initbydim(self, key, width, depth): """ Initialize a Count-Min Sketch `key` to dimensions (`width`, `depth`) specified by user. diff --git a/redis/commands/bf/info.py b/redis/commands/bf/info.py index c526e6ca4c..e1f0208609 100644 --- a/redis/commands/bf/info.py +++ b/redis/commands/bf/info.py @@ -16,6 +16,15 @@ def __init__(self, args): self.insertedNum = response["Number of items inserted"] self.expansionRate = response["Expansion rate"] + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + class CFInfo(object): size = None @@ -38,6 +47,15 @@ def __init__(self, args): self.expansionRate = response["Expansion rate"] self.maxIteration = response["Max iterations"] + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + class CMSInfo(object): width = None @@ -50,6 +68,9 @@ def __init__(self, args): self.depth = response["depth"] self.count = response["count"] + def __getitem__(self, item): + return getattr(self, item) + class TopKInfo(object): k = None @@ -64,6 +85,9 @@ def __init__(self, args): self.depth = response["depth"] self.decay = response["decay"] + def __getitem__(self, item): + return getattr(self, item) + class TDigestInfo(object): compression = None @@ -85,3 +109,12 @@ def __init__(self, args): self.unmerged_weight = response["Unmerged weight"] self.total_compressions = response["Total compressions"] self.memory_usage = response["Memory usage"] + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) diff --git a/redis/commands/core.py b/redis/commands/core.py index 392ddb542c..6abcd5a2ec 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -3365,9 +3365,13 @@ def sinterstore( args = list_or_args(keys, args) return self.execute_command("SINTERSTORE", dest, *args) - def sismember(self, name: str, value: str) -> Union[Awaitable[bool], bool]: + def sismember( + self, name: str, value: str + ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: """ - Return a boolean indicating if ``value`` is a member of set ``name`` + Return whether ``value`` is a member of set ``name``: + - 1 if the value is a member of the set. + - 0 if the value is not a member of the set or if key does not exist. For more information see https://redis.io/commands/sismember """ @@ -5147,6 +5151,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT """ return self.execute_command("PUBLISH", channel, message, **kwargs) + def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: + """ + Posts a message to the given shard channel. + Returns the number of clients that received the message + + For more information see https://redis.io/commands/spublish + """ + return self.execute_command("SPUBLISH", shard_channel, message) + def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a list of channels that have at least one subscriber @@ -5155,6 +5168,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) + def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + """ + Return a list of shard_channels that have at least one subscriber + + For more information see https://redis.io/commands/pubsub-shardchannels + """ + return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs) + def pubsub_numpat(self, **kwargs) -> ResponseT: """ Returns the number of subscriptions to patterns @@ -5172,6 +5193,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) + def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT: + """ + Return a list of (shard_channel, number of subscribers) tuples + for each channel given in ``*args`` + + For more information see https://redis.io/commands/pubsub-shardnumsub + """ + return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs) + AsyncPubSubCommands = PubSubCommands diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py index 4f2a0c5ffc..e895e6a2ba 100644 --- a/redis/commands/json/__init__.py +++ b/redis/commands/json/__init__.py @@ -31,37 +31,54 @@ def __init__( :type json.JSONEncoder: An instance of json.JSONEncoder """ # Set the module commands' callbacks - self.MODULE_CALLBACKS = { - "JSON.CLEAR": int, - "JSON.DEL": int, - "JSON.FORGET": int, - "JSON.GET": self._decode, + self._MODULE_CALLBACKS = { + "JSON.ARRPOP": self._decode, + "JSON.DEBUG": self._decode, + "JSON.MERGE": lambda r: r and nativestr(r) == "OK", "JSON.MGET": bulk_of_jsons(self._decode), - "JSON.SET": lambda r: r and nativestr(r) == "OK", "JSON.MSET": lambda r: r and nativestr(r) == "OK", - "JSON.MERGE": lambda r: r and nativestr(r) == "OK", - "JSON.NUMINCRBY": self._decode, - "JSON.NUMMULTBY": self._decode, + "JSON.RESP": self._decode, + "JSON.SET": lambda r: r and nativestr(r) == "OK", "JSON.TOGGLE": self._decode, - "JSON.STRAPPEND": self._decode, - "JSON.STRLEN": self._decode, + } + + _RESP2_MODULE_CALLBACKS = { "JSON.ARRAPPEND": self._decode, "JSON.ARRINDEX": self._decode, "JSON.ARRINSERT": self._decode, "JSON.ARRLEN": self._decode, - "JSON.ARRPOP": self._decode, "JSON.ARRTRIM": self._decode, - "JSON.OBJLEN": self._decode, + "JSON.CLEAR": int, + "JSON.DEL": int, + "JSON.FORGET": int, + "JSON.GET": self._decode, + "JSON.NUMINCRBY": self._decode, + "JSON.NUMMULTBY": self._decode, "JSON.OBJKEYS": self._decode, - "JSON.RESP": self._decode, - "JSON.DEBUG": self._decode, + "JSON.STRAPPEND": self._decode, + "JSON.OBJLEN": self._decode, + "JSON.STRLEN": self._decode, + "JSON.TOGGLE": self._decode, + } + + _RESP3_MODULE_CALLBACKS = { + "JSON.GET": lambda response: [ + [self._decode(r) for r in res] for res in response + ] + if response + else response } self.client = client self.execute_command = client.execute_command self.MODULE_VERSION = version - for key, value in self.MODULE_CALLBACKS.items(): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for key, value in self._MODULE_CALLBACKS.items(): self.client.set_response_callback(key, value) self.__encoder__ = encoder @@ -117,7 +134,7 @@ def pipeline(self, transaction=True, shard_hint=None): else: p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 70e9c279e5..e635f91e99 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -1,7 +1,17 @@ import redis from ...asyncio.client import Pipeline as AsyncioPipeline -from .commands import AsyncSearchCommands, SearchCommands +from .commands import ( + AGGREGATE_CMD, + CONFIG_CMD, + INFO_CMD, + PROFILE_CMD, + SEARCH_CMD, + SPELLCHECK_CMD, + SYNDUMP_CMD, + AsyncSearchCommands, + SearchCommands, +) class Search(SearchCommands): @@ -85,11 +95,20 @@ def __init__(self, client, index_name="idx"): If conn is not None, we employ an already existing redis connection """ - self.MODULE_CALLBACKS = {} + self._MODULE_CALLBACKS = {} self.client = client self.index_name = index_name self.execute_command = client.execute_command self._pipeline = client.pipeline + self._RESP2_MODULE_CALLBACKS = { + INFO_CMD: self._parse_info, + SEARCH_CMD: self._parse_search, + AGGREGATE_CMD: self._parse_aggregate, + PROFILE_CMD: self._parse_profile, + SPELLCHECK_CMD: self._parse_spellcheck, + CONFIG_CMD: self._parse_config_get, + SYNDUMP_CMD: self._parse_syndump, + } def pipeline(self, transaction=True, shard_hint=None): """Creates a pipeline for the SEARCH module, that can be used for executing @@ -97,7 +116,7 @@ def pipeline(self, transaction=True, shard_hint=None): """ p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) @@ -155,7 +174,7 @@ def pipeline(self, transaction=True, shard_hint=None): """ p = AsyncPipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 3bd7d47aa8..742474523f 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -63,6 +63,86 @@ class SearchCommands: """Search commands.""" + def _parse_results(self, cmd, res, **kwargs): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + return res + else: + return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) + + def _parse_info(self, res, **kwargs): + it = map(to_string, res) + return dict(zip(it, it)) + + def _parse_search(self, res, **kwargs): + return Result( + res, + not kwargs["query"]._no_content, + duration=kwargs["duration"], + has_payload=kwargs["query"]._with_payloads, + with_scores=kwargs["query"]._with_scores, + ) + + def _parse_aggregate(self, res, **kwargs): + return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) + + def _parse_profile(self, res, **kwargs): + query = kwargs["query"] + if isinstance(query, AggregateRequest): + result = self._get_aggregate_result(res[0], query, query._cursor) + else: + result = Result( + res[0], + not query._no_content, + duration=kwargs["duration"], + has_payload=query._with_payloads, + with_scores=query._with_scores, + ) + + return result, parse_to_dict(res[1]) + + def _parse_spellcheck(self, res, **kwargs): + corrections = {} + if res == 0: + return corrections + + for _correction in res: + if isinstance(_correction, int) and _correction == 0: + continue + + if len(_correction) != 3: + continue + if not _correction[2]: + continue + if not _correction[2][0]: + continue + + # For spellcheck output + # 1) 1) "TERM" + # 2) "{term1}" + # 3) 1) 1) "{score1}" + # 2) "{suggestion1}" + # 2) 1) "{score2}" + # 2) "{suggestion2}" + # + # Following dictionary will be made + # corrections = { + # '{term1}': [ + # {'score': '{score1}', 'suggestion': '{suggestion1}'}, + # {'score': '{score2}', 'suggestion': '{suggestion2}'} + # ] + # } + corrections[_correction[1]] = [ + {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] + ] + + return corrections + + def _parse_config_get(self, res, **kwargs): + return {kvs[0]: kvs[1] for kvs in res} if res else {} + + def _parse_syndump(self, res, **kwargs): + return {res[i]: res[i + 1] for i in range(0, len(res), 2)} + def batch_indexer(self, chunk_size=100): """ Create a new batch indexer from the client with a given chunk size @@ -368,8 +448,7 @@ def info(self): """ res = self.execute_command(INFO_CMD, self.index_name) - it = map(to_string, res) - return dict(zip(it, it)) + return self._parse_results(INFO_CMD, res) def get_params_args( self, query_params: Union[Dict[str, Union[str, int, float]], None] @@ -422,12 +501,8 @@ def search( if isinstance(res, Pipeline): return res - return Result( - res, - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) def explain( @@ -473,7 +548,9 @@ def aggregate( cmd += self.get_params_args(query_params) raw = self.execute_command(*cmd) - return self._get_aggregate_result(raw, query, has_cursor) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) def _get_aggregate_result(self, raw, query, has_cursor): if has_cursor: @@ -531,18 +608,9 @@ def profile( res = self.execute_command(*cmd) - if isinstance(query, AggregateRequest): - result = self._get_aggregate_result(res[0], query, query._cursor) - else: - result = Result( - res[0], - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, - ) - - return result, parse_to_dict(res[1]) + return self._parse_results( + PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + ) def spellcheck(self, query, distance=None, include=None, exclude=None): """ @@ -568,43 +636,9 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): if exclude: cmd.extend(["TERMS", "EXCLUDE", exclude]) - raw = self.execute_command(*cmd) - - corrections = {} - if raw == 0: - return corrections - - for _correction in raw: - if isinstance(_correction, int) and _correction == 0: - continue - - if len(_correction) != 3: - continue - if not _correction[2]: - continue - if not _correction[2][0]: - continue - - # For spellcheck output - # 1) 1) "TERM" - # 2) "{term1}" - # 3) 1) 1) "{score1}" - # 2) "{suggestion1}" - # 2) 1) "{score2}" - # 2) "{suggestion2}" - # - # Following dictionary will be made - # corrections = { - # '{term1}': [ - # {'score': '{score1}', 'suggestion': '{suggestion1}'}, - # {'score': '{score2}', 'suggestion': '{suggestion2}'} - # ] - # } - corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] - ] + res = self.execute_command(*cmd) - return corrections + return self._parse_results(SPELLCHECK_CMD, res) def dict_add(self, name, *terms): """Adds terms to a dictionary. @@ -670,12 +704,8 @@ def config_get(self, option): For more information see `FT.CONFIG GET `_. """ # noqa cmd = [CONFIG_CMD, "GET", option] - res = {} - raw = self.execute_command(*cmd) - if raw: - for kvs in raw: - res[kvs[0]] = kvs[1] - return res + res = self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) def tagvals(self, tagfield): """ @@ -810,12 +840,12 @@ def sugget( if with_payloads: args.append(WITHPAYLOADS) - ret = self.execute_command(*args) + res = self.execute_command(*args) results = [] - if not ret: + if not res: return results - parser = SuggestionParser(with_scores, with_payloads, ret) + parser = SuggestionParser(with_scores, with_payloads, res) return [s for s in parser] def synupdate(self, groupid, skipinitial=False, *terms): @@ -851,8 +881,8 @@ def syndump(self): For more information see `FT.SYNDUMP `_. """ # noqa - raw = self.execute_command(SYNDUMP_CMD, self.index_name) - return {raw[i]: raw[i + 1] for i in range(0, len(raw), 2)} + res = self.execute_command(SYNDUMP_CMD, self.index_name) + return self._parse_results(SYNDUMP_CMD, res) class AsyncSearchCommands(SearchCommands): @@ -865,8 +895,7 @@ async def info(self): """ res = await self.execute_command(INFO_CMD, self.index_name) - it = map(to_string, res) - return dict(zip(it, it)) + return self._parse_results(INFO_CMD, res) async def search( self, @@ -891,12 +920,8 @@ async def search( if isinstance(res, Pipeline): return res - return Result( - res, - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) async def aggregate( @@ -927,7 +952,9 @@ async def aggregate( cmd += self.get_params_args(query_params) raw = await self.execute_command(*cmd) - return self._get_aggregate_result(raw, query, has_cursor) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) async def spellcheck(self, query, distance=None, include=None, exclude=None): """ @@ -953,28 +980,9 @@ async def spellcheck(self, query, distance=None, include=None, exclude=None): if exclude: cmd.extend(["TERMS", "EXCLUDE", exclude]) - raw = await self.execute_command(*cmd) + res = await self.execute_command(*cmd) - corrections = {} - if raw == 0: - return corrections - - for _correction in raw: - if isinstance(_correction, int) and _correction == 0: - continue - - if len(_correction) != 3: - continue - if not _correction[2]: - continue - if not _correction[2][0]: - continue - - corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] - ] - - return corrections + return self._parse_results(SPELLCHECK_CMD, res) async def config_set(self, option, value): """Set runtime configuration option. @@ -1001,11 +1009,8 @@ async def config_get(self, option): """ # noqa cmd = [CONFIG_CMD, "GET", option] res = {} - raw = await self.execute_command(*cmd) - if raw: - for kvs in raw: - res[kvs[0]] = kvs[1] - return res + res = await self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) async def load_document(self, id): """ diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py index 4a6886f237..498f5118f1 100644 --- a/redis/commands/timeseries/__init__.py +++ b/redis/commands/timeseries/__init__.py @@ -1,4 +1,5 @@ import redis +from redis._parsers.helpers import bool_ok from ..helpers import parse_to_list from .commands import ( @@ -32,27 +33,36 @@ class TimeSeries(TimeSeriesCommands): def __init__(self, client=None, **kwargs): """Create a new RedisTimeSeries client.""" # Set the module commands' callbacks - self.MODULE_CALLBACKS = { - CREATE_CMD: redis.client.bool_ok, - ALTER_CMD: redis.client.bool_ok, - CREATERULE_CMD: redis.client.bool_ok, + self._MODULE_CALLBACKS = { + ALTER_CMD: bool_ok, + CREATE_CMD: bool_ok, + CREATERULE_CMD: bool_ok, + DELETERULE_CMD: bool_ok, + } + + _RESP2_MODULE_CALLBACKS = { DEL_CMD: int, - DELETERULE_CMD: redis.client.bool_ok, - RANGE_CMD: parse_range, - REVRANGE_CMD: parse_range, - MRANGE_CMD: parse_m_range, - MREVRANGE_CMD: parse_m_range, GET_CMD: parse_get, - MGET_CMD: parse_m_get, INFO_CMD: TSInfo, + MGET_CMD: parse_m_get, + MRANGE_CMD: parse_m_range, + MREVRANGE_CMD: parse_m_range, + RANGE_CMD: parse_range, + REVRANGE_CMD: parse_range, QUERYINDEX_CMD: parse_to_list, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.execute_command = client.execute_command - for key, value in self.MODULE_CALLBACKS.items(): - self.client.set_response_callback(key, value) + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in self._MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) def pipeline(self, transaction=True, shard_hint=None): """Creates a pipeline for the TimeSeries module, that can be used @@ -83,7 +93,7 @@ def pipeline(self, transaction=True, shard_hint=None): else: p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/commands/timeseries/info.py b/redis/commands/timeseries/info.py index 65f3baacd0..3a384dc049 100644 --- a/redis/commands/timeseries/info.py +++ b/redis/commands/timeseries/info.py @@ -80,3 +80,12 @@ def __init__(self, args): self.duplicate_policy = response["duplicatePolicy"] if type(self.duplicate_policy) == bytes: self.duplicate_policy = self.duplicate_policy.decode() + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) diff --git a/redis/compat.py b/redis/compat.py index 738687f645..e478493467 100644 --- a/redis/compat.py +++ b/redis/compat.py @@ -2,8 +2,5 @@ try: from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] except ImportError: - from typing_extensions import ( # lgtm [py/unused-import] - Literal, - Protocol, - TypedDict, - ) + from typing_extensions import Literal # lgtm [py/unused-import] + from typing_extensions import Protocol, TypedDict diff --git a/redis/connection.py b/redis/connection.py index bec456c9ce..66debed2ea 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,65 +1,39 @@ import copy -import errno -import io import os import socket +import ssl import sys import threading import weakref from abc import abstractmethod -from io import SEEK_END from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional, Union +from typing import Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from redis.backoff import NoBackoff -from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider -from redis.exceptions import ( +from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser +from .backoff import NoBackoff +from .credentials import CredentialProvider, UsernamePasswordCredentialProvider +from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - OutOfMemoryError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.retry import Retry -from redis.utils import ( +from .retry import Retry +from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, + SSL_AVAILABLE, str_if_bytes, ) -try: - import ssl - - ssl_available = True -except ImportError: - ssl_available = False - -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} - -if ssl_available: - if hasattr(ssl, "SSLWantReadError"): - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 - else: - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 - -NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) - if HIREDIS_AVAILABLE: import hiredis @@ -68,454 +42,15 @@ SYM_CRLF = b"\r\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." +DEFAULT_RESP_VERSION = 2 SENTINEL = object() -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class Encoder: - "Encode strings to bytes-like and decode bytes-like to strings" - - def __init__(self, encoding, encoding_errors, decode_responses): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value): - "Return a bytestring or bytes-like representation of the value" - if isinstance(value, (bytes, memoryview)): - return value - elif isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. Convert to a " - "bytes, string, int or float first." - ) - elif isinstance(value, (int, float)): - value = repr(value).encode() - elif not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = type(value).__name__ - raise DataError( - f"Invalid input of type: '{typename}'. " - f"Convert to a bytes, string, int or float first." - ) - if isinstance(value, str): - value = value.encode(self.encoding, self.encoding_errors) - return value - - def decode(self, value, force=False): - "Return a unicode string from the bytes-like representation" - if self.decode_responses or force: - if isinstance(value, memoryview): - value = value.tobytes() - if isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - return value - - -class BaseParser: - EXCEPTION_CLASSES = { - "ERR": { - "max number of clients reached": ConnectionError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments " - "for 'auth' command": AuthenticationWrongNumberOfArgsError, - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments " - "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "OOM": OutOfMemoryError, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - @classmethod - def parse_error(cls, response): - "Parse an error response" - error_code = response.split(" ")[0] - if error_code in cls.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = cls.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - -class SocketBuffer: - def __init__( - self, socket: socket.socket, socket_read_size: int, socket_timeout: float - ): - self._sock = socket - self.socket_read_size = socket_read_size - self.socket_timeout = socket_timeout - self._buffer = io.BytesIO() - - def unread_bytes(self) -> int: - """ - Remaining unread length of buffer - """ - pos = self._buffer.tell() - end = self._buffer.seek(0, SEEK_END) - self._buffer.seek(pos) - return end - pos - - def _read_from_socket( - self, - length: Optional[int] = None, - timeout: Union[float, object] = SENTINEL, - raise_on_timeout: Optional[bool] = True, - ) -> bool: - sock = self._sock - socket_read_size = self.socket_read_size - marker = 0 - custom_timeout = timeout is not SENTINEL - - buf = self._buffer - current_pos = buf.tell() - buf.seek(0, SEEK_END) - if custom_timeout: - sock.settimeout(timeout) - try: - while True: - data = self._sock.recv(socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - marker += data_length - - if length is not None and length > marker: - continue - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - buf.seek(current_pos) - if custom_timeout: - sock.settimeout(self.socket_timeout) - - def can_read(self, timeout: float) -> bool: - return bool(self.unread_bytes()) or self._read_from_socket( - timeout=timeout, raise_on_timeout=False - ) - - def read(self, length: int) -> bytes: - length = length + 2 # make sure to read the \r\n terminator - # BufferIO will return less than requested if buffer is short - data = self._buffer.read(length) - missing = length - len(data) - if missing: - # fill up the buffer and read the remainder - self._read_from_socket(missing) - data += self._buffer.read(missing) - return data[:-2] - - def readline(self) -> bytes: - buf = self._buffer - data = buf.readline() - while not data.endswith(SYM_CRLF): - # there's more data in the socket that we need - self._read_from_socket() - data += buf.readline() - - return data[:-2] - - def get_pos(self) -> int: - """ - Get current read position - """ - return self._buffer.tell() - - def rewind(self, pos: int) -> None: - """ - Rewind the buffer to a specific position, to re-start reading - """ - self._buffer.seek(pos) - - def purge(self) -> None: - """ - After a successful read, purge the read part of buffer - """ - unread = self.unread_bytes() - - # Only if we have read all of the buffer do we truncate, to - # reduce the amount of memory thrashing. This heuristic - # can be changed or removed later. - if unread > 0: - return - - if unread > 0: - # move unread data to the front - view = self._buffer.getbuffer() - view[:unread] = view[-unread:] - self._buffer.truncate(unread) - self._buffer.seek(0) - - def close(self) -> None: - try: - self._buffer.close() - except Exception: - # issue #633 suggests the purge/close somehow raised a - # BadFileDescriptor error. Perhaps the client ran out of - # memory or something else? It's probably OK to ignore - # any error being raised from purge/close since we're - # removing the reference to the instance below. - pass - self._buffer = None - self._sock = None - - -class PythonParser(BaseParser): - "Plain Python parsing class" - - def __init__(self, socket_read_size): - self.socket_read_size = socket_read_size - self.encoder = None - self._sock = None - self._buffer = None - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def on_connect(self, connection): - "Called when the socket connects" - self._sock = connection._sock - self._buffer = SocketBuffer( - self._sock, self.socket_read_size, connection.socket_timeout - ) - self.encoder = connection.encoder - - def on_disconnect(self): - "Called when the socket disconnects" - self._sock = None - if self._buffer is not None: - self._buffer.close() - self._buffer = None - self.encoder = None - - def can_read(self, timeout): - return self._buffer and self._buffer.can_read(timeout) - - def read_response(self, disable_decoding=False): - pos = self._buffer.get_pos() if self._buffer else None - try: - result = self._read_response(disable_decoding=disable_decoding) - except BaseException: - if self._buffer: - self._buffer.rewind(pos) - raise - else: - self._buffer.purge() - return result - - def _read_response(self, disable_decoding=False): - raw = self._buffer.readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - byte, response = raw[:1], raw[1:] - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. - return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - return int(response) - # bulk response - elif byte == b"$" and response == b"-1": - return None - elif byte == b"$": - response = self._buffer.read(int(response)) - # multi-bulk response - elif byte == b"*" and response == b"-1": - return None - elif byte == b"*": - response = [ - self._read_response(disable_decoding=disable_decoding) - for i in range(int(response)) - ] - else: - raise InvalidResponse(f"Protocol Error: {raw!r}") - - if disable_decoding is False: - response = self.encoder.decode(response) - return response - - -class HiredisParser(BaseParser): - "Parser class for connections using Hiredis" - - def __init__(self, socket_read_size): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not installed") - self.socket_read_size = socket_read_size - self._buffer = bytearray(socket_read_size) - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def on_connect(self, connection, **kwargs): - self._sock = connection._sock - self._socket_timeout = connection.socket_timeout - kwargs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - "errors": connection.encoder.encoding_errors, - } - - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - self._reader = hiredis.Reader(**kwargs) - self._next_response = False - - def on_disconnect(self): - self._sock = None - self._reader = None - self._next_response = False - - def can_read(self, timeout): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._next_response is False: - self._next_response = self._reader.gets() - if self._next_response is False: - return self.read_from_socket(timeout=timeout, raise_on_timeout=False) - return True - - def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): - sock = self._sock - custom_timeout = timeout is not SENTINEL - try: - if custom_timeout: - sock.settimeout(timeout) - bufflen = self._sock.recv_into(self._buffer) - if bufflen == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - self._reader.feed(self._buffer, 0, bufflen) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - if custom_timeout: - sock.settimeout(self._socket_timeout) - - def read_response(self, disable_decoding=False): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - # _next_response might be cached from a can_read() call - if self._next_response is not False: - response = self._next_response - self._next_response = False - return response - - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() - - while response is False: - self.read_from_socket() - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response - - -DefaultParser: BaseParser +DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _HiredisParser else: - DefaultParser = PythonParser + DefaultParser = _RESP2Parser class HiredisRespSerializer: @@ -609,6 +144,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, command_packer=None, ): """ @@ -661,6 +197,17 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 + try: + p = int(protocol) + except TypeError: + p = DEFAULT_RESP_VERSION + except ValueError: + raise ConnectionError("protocol must be an integer") + finally: + if p < 2 or p > 3: + raise ConnectionError("protocol must be either 2 or 3") + # p = DEFAULT_RESP_VERSION + self.protocol = p self._command_packer = self._construct_command_packer(command_packer) def __repr__(self): @@ -747,7 +294,9 @@ def _error_message(self, exception): def on_connect(self): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) + parser = self._parser + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -755,6 +304,24 @@ def on_connect(self): or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() + + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = self.read_response() + # if response.get(b"proto") != self.protocol and response.get( + # "proto" + # ) != self.protocol: + # raise ConnectionError("Invalid RESP version") + elif auth_args: # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -772,6 +339,21 @@ def on_connect(self): if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + elif self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + self.send_command("HELLO", self.protocol) + response = self.read_response() + if ( + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol + ): + raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: self.send_command("CLIENT", "SETNAME", self.client_name) @@ -872,14 +454,23 @@ def can_read(self, timeout=0): raise ConnectionError(f"Error while reading from {host_error}: {e.args}") def read_response( - self, disable_decoding=False, *, disconnect_on_error: bool = True + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, ): """Read the response from a previously sent command""" host_error = self._host_error() try: - response = self._parser.read_response(disable_decoding=disable_decoding) + if self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + response = self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: if disconnect_on_error: self.disconnect() @@ -1073,7 +664,7 @@ def __init__( Raises: RedisError """ # noqa - if not ssl_available: + if not SSL_AVAILABLE: raise RedisError("Python wasn't built with SSL support") self.keyfile = ssl_keyfile @@ -1174,8 +765,9 @@ def _connect(self): class UnixDomainSocketConnection(AbstractConnection): "Manages UDS communication to and from a Redis server" - def __init__(self, path="", **kwargs): + def __init__(self, path="", socket_timeout=None, **kwargs): self.path = path + self.socket_timeout = socket_timeout super().__init__(**kwargs) def repr_pieces(self): diff --git a/redis/ocsp.py b/redis/ocsp.py index ab8a35a33d..b0420b4711 100644 --- a/redis/ocsp.py +++ b/redis/ocsp.py @@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.hashes import SHA1, Hash from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from cryptography.x509 import ocsp - from redis.exceptions import AuthorizationError, ConnectionError diff --git a/redis/typing.py b/redis/typing.py index 47a255652a..56a1e99ba7 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,14 +1,23 @@ # from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Iterable, + Mapping, + Type, + TypeVar, + Union, +) from redis.compat import Protocol if TYPE_CHECKING: + from redis._parsers import Encoder from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool - from redis.asyncio.connection import Encoder as AsyncEncoder - from redis.connection import ConnectionPool, Encoder + from redis.connection import ConnectionPool Number = Union[int, float] @@ -39,6 +48,8 @@ AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] @@ -48,7 +59,7 @@ def execute_command(self, *args, **options): class ClusterCommandsProtocol(CommandsProtocol, Protocol): - encoder: Union["AsyncEncoder", "Encoder"] + encoder: "Encoder" def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ... diff --git a/redis/utils.py b/redis/utils.py index d95e62c042..148d15246b 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,3 +1,4 @@ +import logging from contextlib import contextmanager from functools import wraps from typing import Any, Dict, Mapping, Union @@ -12,6 +13,13 @@ HIREDIS_AVAILABLE = False HIREDIS_PACK_AVAILABLE = False +try: + import ssl # noqa + + SSL_AVAILABLE = True +except ImportError: + SSL_AVAILABLE = False + try: import cryptography # noqa @@ -110,3 +118,16 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def _set_info_logger(): + """ + Set up a logger that log info logs to stdout. + (This is used by the default push response handler) + """ + if "push_response" not in logging.root.manager.loggerDict.keys(): + logger = logging.getLogger("push_response") + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + logger.addHandler(handler) diff --git a/setup.py b/setup.py index 65122b647b..3a752d44d3 100644 --- a/setup.py +++ b/setup.py @@ -8,10 +8,11 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="4.6.0", + version="5.0.0rc2", packages=find_packages( include=[ "redis", + "redis._parsers", "redis.asyncio", "redis.commands", "redis.commands.bf", @@ -19,6 +20,7 @@ "redis.commands.search", "redis.commands.timeseries", "redis.commands.graph", + "redis.parsers", ] ), url="https://github.com/redis/redis-py", diff --git a/tasks.py b/tasks.py index 64b3aef80f..5162566183 100644 --- a/tasks.py +++ b/tasks.py @@ -1,69 +1,81 @@ +# https://github.com/pyinvoke/invoke/issues/833 +import inspect import os import shutil from invoke import run, task -with open("tox.ini") as fp: - lines = fp.read().split("\n") - dockers = [line.split("=")[1].strip() for line in lines if line.find("name") != -1] +if not hasattr(inspect, "getargspec"): + inspect.getargspec = inspect.getfullargspec @task def devenv(c): - """Builds a development environment: downloads, and starts all dockers - specified in the tox.ini file. - """ + """Brings up the test environment, by wrapping docker compose.""" clean(c) - cmd = "tox -e devenv" - for d in dockers: - cmd += f" --docker-dont-stop={d}" + cmd = "docker-compose --profile all up -d" run(cmd) @task def build_docs(c): """Generates the sphinx documentation.""" - run("tox -e docs") + run("pip install -r docs/requirements.txt") + run("make html") @task def linters(c): """Run code linters""" - run("tox -e linters") + run("flake8 tests redis") + run("black --target-version py37 --check --diff tests redis") + run("isort --check-only --diff tests redis") + run("vulture redis whitelist.py --min-confidence 80") + run("flynt --fail-on-change --dry-run tests redis") @task def all_tests(c): - """Run all linters, and tests in redis-py. This assumes you have all - the python versions specified in the tox.ini file. - """ + """Run all linters, and tests in redis-py.""" linters(c) tests(c) @task -def tests(c): +def tests(c, uvloop=False, protocol=2): """Run the redis-py test suite against the current python, with and without hiredis. """ print("Starting Redis tests") - run("tox -e '{standalone,cluster}'-'{plain,hiredis}'") + standalone_tests(c, uvloop=uvloop, protocol=protocol) + cluster_tests(c, uvloop=uvloop, protocol=protocol) @task -def standalone_tests(c): - """Run all Redis tests against the current python, - with and without hiredis.""" - print("Starting Redis tests") - run("tox -e standalone-'{plain,hiredis,ocsp}'") +def standalone_tests(c, uvloop=False, protocol=2): + """Run tests against a standalone redis instance""" + if uvloop: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop --junit-xml=standalone-uvloop-results.xml" + ) + else: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --junit-xml=standalone-results.xml" + ) @task -def cluster_tests(c): - """Run all Redis Cluster tests against the current python, - with and without hiredis.""" - print("Starting RedisCluster tests") - run("tox -e cluster-'{plain,hiredis}'") +def cluster_tests(c, uvloop=False, protocol=2): + """Run tests against a redis cluster""" + cluster_url = "redis://localhost:16379/0" + if uvloop: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --junit-xml=cluster-uvloop-results.xml --uvloop" + ) + else: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_clusteclient.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --junit-xml=cluster-results.xml" + ) @task @@ -73,7 +85,7 @@ def clean(c): shutil.rmtree("build") if os.path.isdir("dist"): shutil.rmtree("dist") - run(f"docker rm -f {' '.join(dockers)}") + run("docker-compose --profile all rm -s -f") @task diff --git a/tests/conftest.py b/tests/conftest.py index 4cd4c3c160..b3c410e51b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,18 +6,17 @@ from urllib.parse import urlparse import pytest -from packaging.version import Version - import redis +from packaging.version import Version from redis.backoff import NoBackoff from redis.connection import parse_url from redis.exceptions import RedisClusterException from redis.retry import Retry REDIS_INFO = {} -default_redis_url = "redis://localhost:6379/9" -default_redismod_url = "redis://localhost:36379" -default_redis_unstable_url = "redis://localhost:6378" +default_redis_url = "redis://localhost:6379/0" +default_protocol = "2" +default_redismod_url = "redis://localhost:6379" # default ssl client ignores verification for the purpose of testing default_redis_ssl_url = "rediss://localhost:6666" @@ -73,6 +72,7 @@ def format_usage(self): def pytest_addoption(parser): + parser.addoption( "--redis-url", default=default_redis_url, @@ -81,14 +81,11 @@ def pytest_addoption(parser): ) parser.addoption( - "--redismod-url", - default=default_redismod_url, + "--protocol", + default=default_protocol, action="store", - help="Connection string to redis server" - " with loaded modules," - " defaults to `%(default)s`", + help="Protocol version, defaults to `%(default)s`", ) - parser.addoption( "--redis-ssl-url", default=default_redis_ssl_url, @@ -105,13 +102,6 @@ def pytest_addoption(parser): " defaults to `%(default)s`", ) - parser.addoption( - "--redis-unstable-url", - default=default_redis_unstable_url, - action="store", - help="Redis unstable (latest version) connection string " - "defaults to %(default)s`", - ) parser.addoption( "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop" ) @@ -152,10 +142,8 @@ def pytest_sessionstart(session): # store REDIS_INFO in config so that it is available from "condition strings" session.config.REDIS_INFO = REDIS_INFO - # module info, if the second redis is running + # module info try: - redismod_url = session.config.getoption("--redismod-url") - info = _get_info(redismod_url) REDIS_INFO["modules"] = info["modules"] except redis.exceptions.ConnectionError: pass @@ -289,6 +277,9 @@ def _get_client( redis_url = request.config.getoption("--redis-url") else: redis_url = from_url + if "protocol" not in redis_url: + kwargs["protocol"] = request.config.getoption("--protocol") + cluster_mode = REDIS_INFO["cluster_enabled"] if not cluster_mode: url_options = parse_url(redis_url) @@ -332,20 +323,15 @@ def cluster_teardown(client, flushdb): client.disconnect_connection_pools() -# specifically set to the zero database, because creating -# an index on db != 0 raises a ResponseError in redis @pytest.fixture() -def modclient(request, **kwargs): - rmurl = request.config.getoption("--redismod-url") - with _get_client( - redis.Redis, request, from_url=rmurl, decode_responses=True, **kwargs - ) as client: +def r(request): + with _get_client(redis.Redis, request) as client: yield client @pytest.fixture() -def r(request): - with _get_client(redis.Redis, request) as client: +def decoded_r(request): + with _get_client(redis.Redis, request, decode_responses=True) as client: yield client @@ -385,7 +371,7 @@ def mock_cluster_resp_ok(request, **kwargs): @pytest.fixture() def mock_cluster_resp_int(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, "2") + return _gen_cluster_mock_resp(r, 2) @pytest.fixture() @@ -444,15 +430,6 @@ def master_host(request): return parts.hostname, (parts.port or 6379) -@pytest.fixture() -def unstable_r(request): - url = request.config.getoption("--redis-unstable-url") - with _get_client( - redis.Redis, request, from_url=url, decode_responses=True - ) as client: - yield client - - def wait_for_command(client, monitor, command, key=None): # issue a command with a key name that's local to this process. # if we find a command with our key before the command we're waiting @@ -472,3 +449,34 @@ def wait_for_command(client, monitor, command, key=None): return monitor_response if key in monitor_response["command"]: return None + + +def is_resp2_connection(r): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): + protocol = r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.cluster.AbstractRedisCluster): + protocol = r.nodes_manager.connection_kwargs.get("protocol") + return protocol in ["2", 2, None] + + +def get_protocol_version(r): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): + return r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.cluster.AbstractRedisCluster): + return r.nodes_manager.connection_kwargs.get("protocol") + + +def assert_resp_response(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response == resp2_expected + else: + assert response == resp3_expected + + +def assert_resp_response_in(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response in resp2_expected + else: + assert response in resp3_expected diff --git a/tests/ssl_utils.py b/tests/ssl_utils.py index 50937638a7..ab9c2e8944 100644 --- a/tests/ssl_utils.py +++ b/tests/ssl_utils.py @@ -3,10 +3,10 @@ def get_ssl_filename(name): root = os.path.join(os.path.dirname(__file__), "..") - cert_dir = os.path.abspath(os.path.join(root, "docker", "stunnel", "keys")) + cert_dir = os.path.abspath(os.path.join(root, "dockers", "stunnel", "keys")) if not os.path.isdir(cert_dir): # github actions package validation case cert_dir = os.path.abspath( - os.path.join(root, "..", "docker", "stunnel", "keys") + os.path.join(root, "..", "dockers", "stunnel", "keys") ) if not os.path.isdir(cert_dir): raise IOError(f"No SSL certificates found. They should be in {cert_dir}") diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 121a13b41b..e5da3f8f46 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -4,18 +4,14 @@ import pytest import pytest_asyncio -from packaging.version import Version - import redis.asyncio as redis +from packaging.version import Version +from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser from redis.asyncio.client import Monitor -from redis.asyncio.connection import ( - HIREDIS_AVAILABLE, - HiredisParser, - PythonParser, - parse_url, -) +from redis.asyncio.connection import parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff +from redis.utils import HIREDIS_AVAILABLE from tests.conftest import REDIS_INFO from .compat import mock @@ -31,14 +27,14 @@ async def _get_info(redis_url): @pytest_asyncio.fixture( params=[ pytest.param( - (True, PythonParser), + (True, _AsyncRESP2Parser), marks=pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', reason="cluster mode enabled" ), ), - (False, PythonParser), + (False, _AsyncRESP2Parser), pytest.param( - (True, HiredisParser), + (True, _AsyncHiredisParser), marks=[ pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', @@ -50,7 +46,7 @@ async def _get_info(redis_url): ], ), pytest.param( - (False, HiredisParser), + (False, _AsyncHiredisParser), marks=pytest.mark.skipif( not HIREDIS_AVAILABLE, reason="hiredis is not installed" ), @@ -73,8 +69,12 @@ async def client_factory( url: str = request.config.getoption("--redis-url"), cls=redis.Redis, flushdb=True, + protocol=request.config.getoption("--protocol"), **kwargs, ): + if "protocol" not in url: + kwargs["protocol"] = request.config.getoption("--protocol") + cluster_mode = REDIS_INFO["cluster_enabled"] if not cluster_mode: single = kwargs.pop("single_connection_client", False) or single_connection @@ -133,10 +133,8 @@ async def r2(create_redis): @pytest_asyncio.fixture() -async def modclient(request, create_redis): - return await create_redis( - url=request.config.getoption("--redismod-url"), decode_responses=True - ) +async def decoded_r(create_redis): + return await create_redis(decode_responses=True) def _gen_cluster_mock_resp(r, response): @@ -156,7 +154,7 @@ async def mock_cluster_resp_ok(create_redis, **kwargs): @pytest_asyncio.fixture() async def mock_cluster_resp_int(create_redis, **kwargs): r = await create_redis(**kwargs) - return _gen_cluster_mock_resp(r, "2") + return _gen_cluster_mock_resp(r, 2) @pytest_asyncio.fixture() diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index 9f4a805c4c..0535ddfe02 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -1,94 +1,99 @@ from math import inf import pytest - import redis.asyncio as redis from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_ifmodversion_lt, +) def intlist(obj): return [int(v) for v in obj] -# @pytest.fixture -# async def client(modclient): -# assert isinstance(modawait modclient.bf(), redis.commands.bf.BFBloom) -# assert isinstance(modawait modclient.cf(), redis.commands.bf.CFBloom) -# assert isinstance(modawait modclient.cms(), redis.commands.bf.CMSBloom) -# assert isinstance(modawait modclient.tdigest(), redis.commands.bf.TDigestBloom) -# assert isinstance(modawait modclient.topk(), redis.commands.bf.TOPKBloom) - -# modawait modclient.flushdb() -# return modclient - - @pytest.mark.redismod -async def test_create(modclient: redis.Redis): +async def test_create(decoded_r: redis.Redis): """Test CREATE/RESERVE calls""" - assert await modclient.bf().create("bloom", 0.01, 1000) - assert await modclient.bf().create("bloom_e", 0.01, 1000, expansion=1) - assert await modclient.bf().create("bloom_ns", 0.01, 1000, noScale=True) - assert await modclient.cf().create("cuckoo", 1000) - assert await modclient.cf().create("cuckoo_e", 1000, expansion=1) - assert await modclient.cf().create("cuckoo_bs", 1000, bucket_size=4) - assert await modclient.cf().create("cuckoo_mi", 1000, max_iterations=10) - assert await modclient.cms().initbydim("cmsDim", 100, 5) - assert await modclient.cms().initbyprob("cmsProb", 0.01, 0.01) - assert await modclient.topk().reserve("topk", 5, 100, 5, 0.9) + assert await decoded_r.bf().create("bloom", 0.01, 1000) + assert await decoded_r.bf().create("bloom_e", 0.01, 1000, expansion=1) + assert await decoded_r.bf().create("bloom_ns", 0.01, 1000, noScale=True) + assert await decoded_r.cf().create("cuckoo", 1000) + assert await decoded_r.cf().create("cuckoo_e", 1000, expansion=1) + assert await decoded_r.cf().create("cuckoo_bs", 1000, bucket_size=4) + assert await decoded_r.cf().create("cuckoo_mi", 1000, max_iterations=10) + assert await decoded_r.cms().initbydim("cmsDim", 100, 5) + assert await decoded_r.cms().initbyprob("cmsProb", 0.01, 0.01) + assert await decoded_r.topk().reserve("topk", 5, 100, 5, 0.9) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_create(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 100) +async def test_tdigest_create(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) -# region Test Bloom Filter @pytest.mark.redismod -async def test_bf_add(modclient: redis.Redis): - assert await modclient.bf().create("bloom", 0.01, 1000) - assert 1 == await modclient.bf().add("bloom", "foo") - assert 0 == await modclient.bf().add("bloom", "foo") - assert [0] == intlist(await modclient.bf().madd("bloom", "foo")) - assert [0, 1] == await modclient.bf().madd("bloom", "foo", "bar") - assert [0, 0, 1] == await modclient.bf().madd("bloom", "foo", "bar", "baz") - assert 1 == await modclient.bf().exists("bloom", "foo") - assert 0 == await modclient.bf().exists("bloom", "noexist") - assert [1, 0] == intlist(await modclient.bf().mexists("bloom", "foo", "noexist")) +async def test_bf_add(decoded_r: redis.Redis): + assert await decoded_r.bf().create("bloom", 0.01, 1000) + assert 1 == await decoded_r.bf().add("bloom", "foo") + assert 0 == await decoded_r.bf().add("bloom", "foo") + assert [0] == intlist(await decoded_r.bf().madd("bloom", "foo")) + assert [0, 1] == await decoded_r.bf().madd("bloom", "foo", "bar") + assert [0, 0, 1] == await decoded_r.bf().madd("bloom", "foo", "bar", "baz") + assert 1 == await decoded_r.bf().exists("bloom", "foo") + assert 0 == await decoded_r.bf().exists("bloom", "noexist") + assert [1, 0] == intlist(await decoded_r.bf().mexists("bloom", "foo", "noexist")) @pytest.mark.redismod -async def test_bf_insert(modclient: redis.Redis): - assert await modclient.bf().create("bloom", 0.01, 1000) - assert [1] == intlist(await modclient.bf().insert("bloom", ["foo"])) - assert [0, 1] == intlist(await modclient.bf().insert("bloom", ["foo", "bar"])) - assert [1] == intlist(await modclient.bf().insert("captest", ["foo"], capacity=10)) - assert [1] == intlist(await modclient.bf().insert("errtest", ["foo"], error=0.01)) - assert 1 == await modclient.bf().exists("bloom", "foo") - assert 0 == await modclient.bf().exists("bloom", "noexist") - assert [1, 0] == intlist(await modclient.bf().mexists("bloom", "foo", "noexist")) - info = await modclient.bf().info("bloom") - assert 2 == info.insertedNum - assert 1000 == info.capacity - assert 1 == info.filterNum +async def test_bf_insert(decoded_r: redis.Redis): + assert await decoded_r.bf().create("bloom", 0.01, 1000) + assert [1] == intlist(await decoded_r.bf().insert("bloom", ["foo"])) + assert [0, 1] == intlist(await decoded_r.bf().insert("bloom", ["foo", "bar"])) + assert [1] == intlist(await decoded_r.bf().insert("captest", ["foo"], capacity=10)) + assert [1] == intlist(await decoded_r.bf().insert("errtest", ["foo"], error=0.01)) + assert 1 == await decoded_r.bf().exists("bloom", "foo") + assert 0 == await decoded_r.bf().exists("bloom", "noexist") + assert [1, 0] == intlist(await decoded_r.bf().mexists("bloom", "foo", "noexist")) + info = await decoded_r.bf().info("bloom") + assert_resp_response( + decoded_r, + 2, + info.get("insertedNum"), + info.get("Number of items inserted"), + ) + assert_resp_response( + decoded_r, + 1000, + info.get("capacity"), + info.get("Capacity"), + ) + assert_resp_response( + decoded_r, + 1, + info.get("filterNum"), + info.get("Number of filters"), + ) @pytest.mark.redismod -async def test_bf_scandump_and_loadchunk(modclient: redis.Redis): +async def test_bf_scandump_and_loadchunk(decoded_r: redis.Redis): # Store a filter - await modclient.bf().create("myBloom", "0.0001", "1000") + await decoded_r.bf().create("myBloom", "0.0001", "1000") # test is probabilistic and might fail. It is OK to change variables if # certain to not break anything async def do_verify(): res = 0 for x in range(1000): - await modclient.bf().add("myBloom", x) - rv = await modclient.bf().exists("myBloom", x) + await decoded_r.bf().add("myBloom", x) + rv = await decoded_r.bf().exists("myBloom", x) assert rv - rv = await modclient.bf().exists("myBloom", f"nonexist_{x}") + rv = await decoded_r.bf().exists("myBloom", f"nonexist_{x}") res += rv == x assert res < 5 @@ -96,52 +101,62 @@ async def do_verify(): cmds = [] if HIREDIS_AVAILABLE: with pytest.raises(ModuleError): - cur = await modclient.bf().scandump("myBloom", 0) + cur = await decoded_r.bf().scandump("myBloom", 0) return - cur = await modclient.bf().scandump("myBloom", 0) + cur = await decoded_r.bf().scandump("myBloom", 0) first = cur[0] cmds.append(cur) while True: - cur = await modclient.bf().scandump("myBloom", first) + cur = await decoded_r.bf().scandump("myBloom", first) first = cur[0] if first == 0: break else: cmds.append(cur) - prev_info = await modclient.bf().execute_command("bf.debug", "myBloom") + prev_info = await decoded_r.bf().execute_command("bf.debug", "myBloom") # Remove the filter - await modclient.bf().client.delete("myBloom") + await decoded_r.bf().client.delete("myBloom") # Now, load all the commands: for cmd in cmds: - await modclient.bf().loadchunk("myBloom", *cmd) + await decoded_r.bf().loadchunk("myBloom", *cmd) - cur_info = await modclient.bf().execute_command("bf.debug", "myBloom") + cur_info = await decoded_r.bf().execute_command("bf.debug", "myBloom") assert prev_info == cur_info await do_verify() - await modclient.bf().client.delete("myBloom") - await modclient.bf().create("myBloom", "0.0001", "10000000") + await decoded_r.bf().client.delete("myBloom") + await decoded_r.bf().create("myBloom", "0.0001", "10000000") @pytest.mark.redismod -async def test_bf_info(modclient: redis.Redis): +async def test_bf_info(decoded_r: redis.Redis): expansion = 4 # Store a filter - await modclient.bf().create("nonscaling", "0.0001", "1000", noScale=True) - info = await modclient.bf().info("nonscaling") - assert info.expansionRate is None + await decoded_r.bf().create("nonscaling", "0.0001", "1000", noScale=True) + info = await decoded_r.bf().info("nonscaling") + assert_resp_response( + decoded_r, + None, + info.get("expansionRate"), + info.get("Expansion rate"), + ) - await modclient.bf().create("expanding", "0.0001", "1000", expansion=expansion) - info = await modclient.bf().info("expanding") - assert info.expansionRate == 4 + await decoded_r.bf().create("expanding", "0.0001", "1000", expansion=expansion) + info = await decoded_r.bf().info("expanding") + assert_resp_response( + decoded_r, + 4, + info.get("expansionRate"), + info.get("Expansion rate"), + ) try: # noScale mean no expansion - await modclient.bf().create( + await decoded_r.bf().create( "myBloom", "0.0001", "1000", expansion=expansion, noScale=True ) assert False @@ -150,95 +165,96 @@ async def test_bf_info(modclient: redis.Redis): @pytest.mark.redismod -async def test_bf_card(modclient: redis.Redis): +async def test_bf_card(decoded_r: redis.Redis): # return 0 if the key does not exist - assert await modclient.bf().card("not_exist") == 0 + assert await decoded_r.bf().card("not_exist") == 0 # Store a filter - assert await modclient.bf().add("bf1", "item_foo") == 1 - assert await modclient.bf().card("bf1") == 1 + assert await decoded_r.bf().add("bf1", "item_foo") == 1 + assert await decoded_r.bf().card("bf1") == 1 - # Error when key is of a type other than Bloom filter. + # Error when key is of a type other than Bloom filtedecoded_r. with pytest.raises(redis.ResponseError): - await modclient.set("setKey", "value") - await modclient.bf().card("setKey") + await decoded_r.set("setKey", "value") + await decoded_r.bf().card("setKey") -# region Test Cuckoo Filter @pytest.mark.redismod -async def test_cf_add_and_insert(modclient: redis.Redis): - assert await modclient.cf().create("cuckoo", 1000) - assert await modclient.cf().add("cuckoo", "filter") - assert not await modclient.cf().addnx("cuckoo", "filter") - assert 1 == await modclient.cf().addnx("cuckoo", "newItem") - assert [1] == await modclient.cf().insert("captest", ["foo"]) - assert [1] == await modclient.cf().insert("captest", ["foo"], capacity=1000) - assert [1] == await modclient.cf().insertnx("captest", ["bar"]) - assert [1] == await modclient.cf().insertnx("captest", ["food"], nocreate="1") - assert [0, 0, 1] == await modclient.cf().insertnx("captest", ["foo", "bar", "baz"]) - assert [0] == await modclient.cf().insertnx("captest", ["bar"], capacity=1000) - assert [1] == await modclient.cf().insert("empty1", ["foo"], capacity=1000) - assert [1] == await modclient.cf().insertnx("empty2", ["bar"], capacity=1000) - info = await modclient.cf().info("captest") - assert 5 == info.insertedNum - assert 0 == info.deletedNum - assert 1 == info.filterNum +async def test_cf_add_and_insert(decoded_r: redis.Redis): + assert await decoded_r.cf().create("cuckoo", 1000) + assert await decoded_r.cf().add("cuckoo", "filter") + assert not await decoded_r.cf().addnx("cuckoo", "filter") + assert 1 == await decoded_r.cf().addnx("cuckoo", "newItem") + assert [1] == await decoded_r.cf().insert("captest", ["foo"]) + assert [1] == await decoded_r.cf().insert("captest", ["foo"], capacity=1000) + assert [1] == await decoded_r.cf().insertnx("captest", ["bar"]) + assert [1] == await decoded_r.cf().insertnx("captest", ["food"], nocreate="1") + assert [0, 0, 1] == await decoded_r.cf().insertnx("captest", ["foo", "bar", "baz"]) + assert [0] == await decoded_r.cf().insertnx("captest", ["bar"], capacity=1000) + assert [1] == await decoded_r.cf().insert("empty1", ["foo"], capacity=1000) + assert [1] == await decoded_r.cf().insertnx("empty2", ["bar"], capacity=1000) + info = await decoded_r.cf().info("captest") + assert_resp_response( + decoded_r, 5, info.get("insertedNum"), info.get("Number of items inserted") + ) + assert_resp_response( + decoded_r, 0, info.get("deletedNum"), info.get("Number of items deleted") + ) + assert_resp_response( + decoded_r, 1, info.get("filterNum"), info.get("Number of filters") + ) @pytest.mark.redismod -async def test_cf_exists_and_del(modclient: redis.Redis): - assert await modclient.cf().create("cuckoo", 1000) - assert await modclient.cf().add("cuckoo", "filter") - assert await modclient.cf().exists("cuckoo", "filter") - assert not await modclient.cf().exists("cuckoo", "notexist") - assert 1 == await modclient.cf().count("cuckoo", "filter") - assert 0 == await modclient.cf().count("cuckoo", "notexist") - assert await modclient.cf().delete("cuckoo", "filter") - assert 0 == await modclient.cf().count("cuckoo", "filter") - - -# region Test Count-Min Sketch +async def test_cf_exists_and_del(decoded_r: redis.Redis): + assert await decoded_r.cf().create("cuckoo", 1000) + assert await decoded_r.cf().add("cuckoo", "filter") + assert await decoded_r.cf().exists("cuckoo", "filter") + assert not await decoded_r.cf().exists("cuckoo", "notexist") + assert 1 == await decoded_r.cf().count("cuckoo", "filter") + assert 0 == await decoded_r.cf().count("cuckoo", "notexist") + assert await decoded_r.cf().delete("cuckoo", "filter") + assert 0 == await decoded_r.cf().count("cuckoo", "filter") + + @pytest.mark.redismod -async def test_cms(modclient: redis.Redis): - assert await modclient.cms().initbydim("dim", 1000, 5) - assert await modclient.cms().initbyprob("prob", 0.01, 0.01) - assert await modclient.cms().incrby("dim", ["foo"], [5]) - assert [0] == await modclient.cms().query("dim", "notexist") - assert [5] == await modclient.cms().query("dim", "foo") - assert [10, 15] == await modclient.cms().incrby("dim", ["foo", "bar"], [5, 15]) - assert [10, 15] == await modclient.cms().query("dim", "foo", "bar") - info = await modclient.cms().info("dim") - assert 1000 == info.width - assert 5 == info.depth - assert 25 == info.count +async def test_cms(decoded_r: redis.Redis): + assert await decoded_r.cms().initbydim("dim", 1000, 5) + assert await decoded_r.cms().initbyprob("prob", 0.01, 0.01) + assert await decoded_r.cms().incrby("dim", ["foo"], [5]) + assert [0] == await decoded_r.cms().query("dim", "notexist") + assert [5] == await decoded_r.cms().query("dim", "foo") + assert [10, 15] == await decoded_r.cms().incrby("dim", ["foo", "bar"], [5, 15]) + assert [10, 15] == await decoded_r.cms().query("dim", "foo", "bar") + info = await decoded_r.cms().info("dim") + assert info["width"] + assert 1000 == info["width"] + assert 5 == info["depth"] + assert 25 == info["count"] @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_cms_merge(modclient: redis.Redis): - assert await modclient.cms().initbydim("A", 1000, 5) - assert await modclient.cms().initbydim("B", 1000, 5) - assert await modclient.cms().initbydim("C", 1000, 5) - assert await modclient.cms().incrby("A", ["foo", "bar", "baz"], [5, 3, 9]) - assert await modclient.cms().incrby("B", ["foo", "bar", "baz"], [2, 3, 1]) - assert [5, 3, 9] == await modclient.cms().query("A", "foo", "bar", "baz") - assert [2, 3, 1] == await modclient.cms().query("B", "foo", "bar", "baz") - assert await modclient.cms().merge("C", 2, ["A", "B"]) - assert [7, 6, 10] == await modclient.cms().query("C", "foo", "bar", "baz") - assert await modclient.cms().merge("C", 2, ["A", "B"], ["1", "2"]) - assert [9, 9, 11] == await modclient.cms().query("C", "foo", "bar", "baz") - assert await modclient.cms().merge("C", 2, ["A", "B"], ["2", "3"]) - assert [16, 15, 21] == await modclient.cms().query("C", "foo", "bar", "baz") - - -# endregion - - -# region Test Top-K +async def test_cms_merge(decoded_r: redis.Redis): + assert await decoded_r.cms().initbydim("A", 1000, 5) + assert await decoded_r.cms().initbydim("B", 1000, 5) + assert await decoded_r.cms().initbydim("C", 1000, 5) + assert await decoded_r.cms().incrby("A", ["foo", "bar", "baz"], [5, 3, 9]) + assert await decoded_r.cms().incrby("B", ["foo", "bar", "baz"], [2, 3, 1]) + assert [5, 3, 9] == await decoded_r.cms().query("A", "foo", "bar", "baz") + assert [2, 3, 1] == await decoded_r.cms().query("B", "foo", "bar", "baz") + assert await decoded_r.cms().merge("C", 2, ["A", "B"]) + assert [7, 6, 10] == await decoded_r.cms().query("C", "foo", "bar", "baz") + assert await decoded_r.cms().merge("C", 2, ["A", "B"], ["1", "2"]) + assert [9, 9, 11] == await decoded_r.cms().query("C", "foo", "bar", "baz") + assert await decoded_r.cms().merge("C", 2, ["A", "B"], ["2", "3"]) + assert [16, 15, 21] == await decoded_r.cms().query("C", "foo", "bar", "baz") + + @pytest.mark.redismod -async def test_topk(modclient: redis.Redis): +async def test_topk(decoded_r: redis.Redis): # test list with empty buckets - assert await modclient.topk().reserve("topk", 3, 50, 4, 0.9) + assert await decoded_r.topk().reserve("topk", 3, 50, 4, 0.9) assert [ None, None, @@ -257,7 +273,7 @@ async def test_topk(modclient: redis.Redis): None, "D", None, - ] == await modclient.topk().add( + ] == await decoded_r.topk().add( "topk", "A", "B", @@ -277,17 +293,17 @@ async def test_topk(modclient: redis.Redis): "E", 1, ) - assert [1, 1, 0, 0, 1, 0, 0] == await modclient.topk().query( + assert [1, 1, 0, 0, 1, 0, 0] == await decoded_r.topk().query( "topk", "A", "B", "C", "D", "E", "F", "G" ) with pytest.deprecated_call(): - assert [4, 3, 2, 3, 3, 0, 1] == await modclient.topk().count( + assert [4, 3, 2, 3, 3, 0, 1] == await decoded_r.topk().count( "topk", "A", "B", "C", "D", "E", "F", "G" ) # test full list - assert await modclient.topk().reserve("topklist", 3, 50, 3, 0.9) - assert await modclient.topk().add( + assert await decoded_r.topk().reserve("topklist", 3, 50, 3, 0.9) + assert await decoded_r.topk().add( "topklist", "A", "B", @@ -306,192 +322,196 @@ async def test_topk(modclient: redis.Redis): "E", "E", ) - assert ["A", "B", "E"] == await modclient.topk().list("topklist") - res = await modclient.topk().list("topklist", withcount=True) + assert ["A", "B", "E"] == await decoded_r.topk().list("topklist") + res = await decoded_r.topk().list("topklist", withcount=True) assert ["A", 4, "B", 3, "E", 3] == res - info = await modclient.topk().info("topklist") - assert 3 == info.k - assert 50 == info.width - assert 3 == info.depth - assert 0.9 == round(float(info.decay), 1) + info = await decoded_r.topk().info("topklist") + assert 3 == info["k"] + assert 50 == info["width"] + assert 3 == info["depth"] + assert 0.9 == round(float(info["decay"]), 1) @pytest.mark.redismod -async def test_topk_incrby(modclient: redis.Redis): - await modclient.flushdb() - assert await modclient.topk().reserve("topk", 3, 10, 3, 1) - assert [None, None, None] == await modclient.topk().incrby( +async def test_topk_incrby(decoded_r: redis.Redis): + await decoded_r.flushdb() + assert await decoded_r.topk().reserve("topk", 3, 10, 3, 1) + assert [None, None, None] == await decoded_r.topk().incrby( "topk", ["bar", "baz", "42"], [3, 6, 2] ) - res = await modclient.topk().incrby("topk", ["42", "xyzzy"], [8, 4]) + res = await decoded_r.topk().incrby("topk", ["42", "xyzzy"], [8, 4]) assert [None, "bar"] == res with pytest.deprecated_call(): - assert [3, 6, 10, 4, 0] == await modclient.topk().count( + assert [3, 6, 10, 4, 0] == await decoded_r.topk().count( "topk", "bar", "baz", "42", "xyzzy", 4 ) -# region Test T-Digest @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_reset(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 10) +async def test_tdigest_reset(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 10) # reset on empty histogram - assert await modclient.tdigest().reset("tDigest") + assert await decoded_r.tdigest().reset("tDigest") # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(10))) + assert await decoded_r.tdigest().add("tDigest", list(range(10))) - assert await modclient.tdigest().reset("tDigest") + assert await decoded_r.tdigest().reset("tDigest") # assert we have 0 unmerged nodes - assert 0 == (await modclient.tdigest().info("tDigest")).unmerged_nodes + info = await decoded_r.tdigest().info("tDigest") + assert_resp_response( + decoded_r, 0, info.get("unmerged_nodes"), info.get("Unmerged nodes") + ) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_merge(modclient: redis.Redis): - assert await modclient.tdigest().create("to-tDigest", 10) - assert await modclient.tdigest().create("from-tDigest", 10) +async def test_tdigest_merge(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("to-tDigest", 10) + assert await decoded_r.tdigest().create("from-tDigest", 10) # insert data-points into sketch - assert await modclient.tdigest().add("from-tDigest", [1.0] * 10) - assert await modclient.tdigest().add("to-tDigest", [2.0] * 10) + assert await decoded_r.tdigest().add("from-tDigest", [1.0] * 10) + assert await decoded_r.tdigest().add("to-tDigest", [2.0] * 10) # merge from-tdigest into to-tdigest - assert await modclient.tdigest().merge("to-tDigest", 1, "from-tDigest") + assert await decoded_r.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 110 weight on to-histogram - info = await modclient.tdigest().info("to-tDigest") - total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) - assert 20.0 == total_weight_to + info = await decoded_r.tdigest().info("to-tDigest") + if is_resp2_connection(decoded_r): + assert 20 == float(info["merged_weight"]) + float(info["unmerged_weight"]) + else: + assert 20 == float(info["Merged weight"]) + float(info["Unmerged weight"]) # test override - assert await modclient.tdigest().create("from-override", 10) - assert await modclient.tdigest().create("from-override-2", 10) - assert await modclient.tdigest().add("from-override", [3.0] * 10) - assert await modclient.tdigest().add("from-override-2", [4.0] * 10) - assert await modclient.tdigest().merge( + assert await decoded_r.tdigest().create("from-override", 10) + assert await decoded_r.tdigest().create("from-override-2", 10) + assert await decoded_r.tdigest().add("from-override", [3.0] * 10) + assert await decoded_r.tdigest().add("from-override-2", [4.0] * 10) + assert await decoded_r.tdigest().merge( "to-tDigest", 2, "from-override", "from-override-2", override=True ) - assert 3.0 == await modclient.tdigest().min("to-tDigest") - assert 4.0 == await modclient.tdigest().max("to-tDigest") + assert 3.0 == await decoded_r.tdigest().min("to-tDigest") + assert 4.0 == await decoded_r.tdigest().max("to-tDigest") @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_min_and_max(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 100) +async def test_tdigest_min_and_max(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", [1, 2, 3]) + assert await decoded_r.tdigest().add("tDigest", [1, 2, 3]) # min/max - assert 3 == await modclient.tdigest().max("tDigest") - assert 1 == await modclient.tdigest().min("tDigest") + assert 3 == await decoded_r.tdigest().max("tDigest") + assert 1 == await decoded_r.tdigest().min("tDigest") @pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") -async def test_tdigest_quantile(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 500) +async def test_tdigest_quantile(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 500) # insert data-points into sketch - assert await modclient.tdigest().add( + assert await decoded_r.tdigest().add( "tDigest", list([x * 0.01 for x in range(1, 10000)]) ) # assert min min/max have same result as quantile 0 and 1 assert ( - await modclient.tdigest().max("tDigest") - == (await modclient.tdigest().quantile("tDigest", 1))[0] + await decoded_r.tdigest().max("tDigest") + == (await decoded_r.tdigest().quantile("tDigest", 1))[0] ) assert ( - await modclient.tdigest().min("tDigest") - == (await modclient.tdigest().quantile("tDigest", 0.0))[0] + await decoded_r.tdigest().min("tDigest") + == (await decoded_r.tdigest().quantile("tDigest", 0.0))[0] ) - assert 1.0 == round((await modclient.tdigest().quantile("tDigest", 0.01))[0], 2) - assert 99.0 == round((await modclient.tdigest().quantile("tDigest", 0.99))[0], 2) + assert 1.0 == round((await decoded_r.tdigest().quantile("tDigest", 0.01))[0], 2) + assert 99.0 == round((await decoded_r.tdigest().quantile("tDigest", 0.99))[0], 2) # test multiple quantiles - assert await modclient.tdigest().create("t-digest", 100) - assert await modclient.tdigest().add("t-digest", [1, 2, 3, 4, 5]) - res = await modclient.tdigest().quantile("t-digest", 0.5, 0.8) + assert await decoded_r.tdigest().create("t-digest", 100) + assert await decoded_r.tdigest().add("t-digest", [1, 2, 3, 4, 5]) + res = await decoded_r.tdigest().quantile("t-digest", 0.5, 0.8) assert [3.0, 5.0] == res @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_cdf(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 100) +async def test_tdigest_cdf(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(1, 10))) - assert 0.1 == round((await modclient.tdigest().cdf("tDigest", 1.0))[0], 1) - assert 0.9 == round((await modclient.tdigest().cdf("tDigest", 9.0))[0], 1) - res = await modclient.tdigest().cdf("tDigest", 1.0, 9.0) + assert await decoded_r.tdigest().add("tDigest", list(range(1, 10))) + assert 0.1 == round((await decoded_r.tdigest().cdf("tDigest", 1.0))[0], 1) + assert 0.9 == round((await decoded_r.tdigest().cdf("tDigest", 9.0))[0], 1) + res = await decoded_r.tdigest().cdf("tDigest", 1.0, 9.0) assert [0.1, 0.9] == [round(x, 1) for x in res] @pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") -async def test_tdigest_trimmed_mean(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 100) +async def test_tdigest_trimmed_mean(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(1, 10))) - assert 5 == await modclient.tdigest().trimmed_mean("tDigest", 0.1, 0.9) - assert 4.5 == await modclient.tdigest().trimmed_mean("tDigest", 0.4, 0.5) + assert await decoded_r.tdigest().add("tDigest", list(range(1, 10))) + assert 5 == await decoded_r.tdigest().trimmed_mean("tDigest", 0.1, 0.9) + assert 4.5 == await decoded_r.tdigest().trimmed_mean("tDigest", 0.4, 0.5) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_rank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(0, 20))) - assert -1 == (await modclient.tdigest().rank("t-digest", -1))[0] - assert 0 == (await modclient.tdigest().rank("t-digest", 0))[0] - assert 10 == (await modclient.tdigest().rank("t-digest", 10))[0] - assert [-1, 20, 9] == await modclient.tdigest().rank("t-digest", -20, 20, 9) +async def test_tdigest_rank(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == (await decoded_r.tdigest().rank("t-digest", -1))[0] + assert 0 == (await decoded_r.tdigest().rank("t-digest", 0))[0] + assert 10 == (await decoded_r.tdigest().rank("t-digest", 10))[0] + assert [-1, 20, 9] == await decoded_r.tdigest().rank("t-digest", -20, 20, 9) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_revrank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(0, 20))) - assert -1 == (await modclient.tdigest().revrank("t-digest", 20))[0] - assert 19 == (await modclient.tdigest().revrank("t-digest", 0))[0] - assert [-1, 19, 9] == await modclient.tdigest().revrank("t-digest", 21, 0, 10) +async def test_tdigest_revrank(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == (await decoded_r.tdigest().revrank("t-digest", 20))[0] + assert 19 == (await decoded_r.tdigest().revrank("t-digest", 0))[0] + assert [-1, 19, 9] == await decoded_r.tdigest().revrank("t-digest", 21, 0, 10) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_byrank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(1, 11))) - assert 1 == (await modclient.tdigest().byrank("t-digest", 0))[0] - assert 10 == (await modclient.tdigest().byrank("t-digest", 9))[0] - assert (await modclient.tdigest().byrank("t-digest", 100))[0] == inf +async def test_tdigest_byrank(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(1, 11))) + assert 1 == (await decoded_r.tdigest().byrank("t-digest", 0))[0] + assert 10 == (await decoded_r.tdigest().byrank("t-digest", 9))[0] + assert (await decoded_r.tdigest().byrank("t-digest", 100))[0] == inf with pytest.raises(redis.ResponseError): - (await modclient.tdigest().byrank("t-digest", -1))[0] + (await decoded_r.tdigest().byrank("t-digest", -1))[0] @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_byrevrank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(1, 11))) - assert 10 == (await modclient.tdigest().byrevrank("t-digest", 0))[0] - assert 1 == (await modclient.tdigest().byrevrank("t-digest", 9))[0] - assert (await modclient.tdigest().byrevrank("t-digest", 100))[0] == -inf +async def test_tdigest_byrevrank(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(1, 11))) + assert 10 == (await decoded_r.tdigest().byrevrank("t-digest", 0))[0] + assert 1 == (await decoded_r.tdigest().byrevrank("t-digest", 9))[0] + assert (await decoded_r.tdigest().byrevrank("t-digest", 100))[0] == -inf with pytest.raises(redis.ResponseError): - (await modclient.tdigest().byrevrank("t-digest", -1))[0] + (await decoded_r.tdigest().byrevrank("t-digest", -1))[0] # @pytest.mark.redismod -# async def test_pipeline(modclient: redis.Redis): -# pipeline = await modclient.bf().pipeline() -# assert not await modclient.bf().execute_command("get pipeline") +# async def test_pipeline(decoded_r: redis.Redis): +# pipeline = await decoded_r.bf().pipeline() +# assert not await decoded_r.bf().execute_command("get pipeline") # -# assert await modclient.bf().create("pipeline", 0.01, 1000) +# assert await decoded_r.bf().create("pipeline", 0.01, 1000) # for i in range(100): # pipeline.add("pipeline", i) # for i in range(100): -# assert not (await modclient.bf().exists("pipeline", i)) +# assert not (await decoded_r.bf().exists("pipeline", i)) # # pipeline.execute() # # for i in range(100): -# assert await modclient.bf().exists("pipeline", i) +# assert await decoded_r.bf().exists("pipeline", i) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index c41d4a2168..ee498e71f7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -8,10 +8,9 @@ import pytest import pytest_asyncio from _pytest.fixtures import FixtureRequest - +from redis._parsers import AsyncCommandsParser from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster from redis.asyncio.connection import Connection, SSLConnection, async_timeout -from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name @@ -30,6 +29,7 @@ ) from redis.utils import str_if_bytes from tests.conftest import ( + assert_resp_response, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -152,7 +152,7 @@ async def execute_command(*_args, **_kwargs): execute_command_mock.side_effect = execute_command with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: @@ -602,7 +602,7 @@ def map_7007(self): mocks["send_packed_command"].return_value = "MOCK_OK" mocks["connect"].return_value = None with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: @@ -964,7 +964,7 @@ async def test_client_setname(self, r: RedisCluster) -> None: node = r.get_random_node() await r.client_setname("redis_py_test", target_nodes=node) client_name = await r.client_getname(target_nodes=node) - assert client_name == "redis_py_test" + assert_resp_response(r, client_name, "redis_py_test", b"redis_py_test") async def test_exists(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} @@ -1443,7 +1443,7 @@ async def test_client_trackinginfo(self, r: RedisCluster) -> None: node = r.get_primaries()[0] res = await r.client_trackinginfo(target_nodes=node) assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @skip_if_server_version_lt("2.9.50") async def test_client_pause(self, r: RedisCluster) -> None: @@ -1609,24 +1609,68 @@ async def test_cluster_renamenx(self, r: RedisCluster) -> None: async def test_cluster_blpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) is None await r.rpush("{foo}c", "1") - assert await r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, await r.blpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) async def test_cluster_brpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) is None await r.rpush("{foo}c", "1") - assert await r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, await r.brpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) async def test_cluster_brpoplpush(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") @@ -1699,7 +1743,8 @@ async def test_cluster_zdiff(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert await r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + response = await r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: @@ -1707,7 +1752,8 @@ async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert await r.zrange("{foo}out", 0, -1) == [b"a3"] - assert await r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + response = await r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zinter(self, r: RedisCluster) -> None: @@ -1721,32 +1767,41 @@ async def test_cluster_zinter(self, r: RedisCluster) -> None: ["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True ) # aggregate with SUM - assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) + assert_resp_response( + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8], [b"a1", 9]] + ) # aggregate with MAX - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] + ) + assert_resp_response( + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] + ) # aggregate with MIN - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] + ) + assert_resp_response( + r, response, [(b"a1", 1), (b"a3", 1)], [[b"a1", 1], [b"a3", 1]] + ) # with weights - assert await r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a3", 20), (b"a1", 23)] + res = await r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) + assert_resp_response( + r, res, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] + ) async def test_cluster_zinterstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1758,10 +1813,12 @@ async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1773,10 +1830,12 @@ async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1785,66 +1844,86 @@ async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b2", - 20, - ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b1", - 10, - ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a2", - 2, + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a1", - 1, + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], + ) + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], ) assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None await r.zadd("{foo}c", {"c1": 100}) - assert await r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + await r.bzpopmax("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmin(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b1", - 10, - ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b2", - 20, - ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a1", - 1, + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"b", b"b1", 10], ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a2", - 2, + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"a", b"a1", 1], + ) + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"a", b"a2", 2], ) assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None await r.zadd("{foo}c", {"c1": 100}) - assert await r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + await r.bzpopmin("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("6.2.0") async def test_cluster_zrangestore(self, r: RedisCluster) -> None: @@ -1853,10 +1932,12 @@ async def test_cluster_zrangestore(self, r: RedisCluster) -> None: assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert await r.zrangestore("{foo}b", "{foo}a", 1, 2) assert await r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert await r.zrange("{foo}b", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert await r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1883,36 +1964,49 @@ async def test_cluster_zunion(self, r: RedisCluster) -> None: b"a3", b"a1", ] - assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert await r.zunion( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + await r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) async def test_cluster_zunionstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1924,12 +2018,12 @@ async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1941,12 +2035,12 @@ async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1955,12 +2049,12 @@ async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") async def test_cluster_pfcount(self, r: RedisCluster) -> None: @@ -2444,7 +2538,7 @@ async def mocked_execute_command(self, *args, **kwargs): assert "Redis Cluster cannot be connected" in str(e.value) with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index c0259680c0..08e66b050f 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -10,11 +10,19 @@ import pytest import pytest_asyncio - import redis from redis import exceptions -from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + parse_info, +) +from redis.client import EMPTY_RESPONSE, NEVER_DECODE from tests.conftest import ( + assert_resp_response, + assert_resp_response_in, + is_resp2_connection, skip_if_server_version_gte, skip_if_server_version_lt, skip_unless_arch_bits, @@ -78,14 +86,19 @@ class TestResponseCallbacks: """Tests for the response callback system""" async def test_response_callbacks(self, r: redis.Redis): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS - assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) + callbacks = _RedisCallbacks + if is_resp2_connection(r): + callbacks.update(_RedisCallbacksRESP2) + else: + callbacks.update(_RedisCallbacksRESP3) + assert r.response_callbacks == callbacks + assert id(r.response_callbacks) != id(_RedisCallbacks) r.set_response_callback("GET", lambda x: "static") await r.set("a", "foo") assert await r.get("a") == "static" async def test_case_insensitive_command_names(self, r: redis.Redis): - assert r.response_callbacks["del"] == r.response_callbacks["DEL"] + assert r.response_callbacks["ping"] == r.response_callbacks["PING"] class TestRedisCommands: @@ -99,13 +112,13 @@ async def test_command_on_invalid_key_type(self, r: redis.Redis): async def test_acl_cat_no_category(self, r: redis.Redis): categories = await r.acl_cat() assert isinstance(categories, list) - assert "read" in categories + assert "read" in categories or b"read" in categories @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_cat_with_category(self, r: redis.Redis): commands = await r.acl_cat("read") assert isinstance(commands, list) - assert "get" in commands + assert "get" in commands or b"get" in commands @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_deluser(self, r_teardown): @@ -119,36 +132,32 @@ async def test_acl_deluser(self, r_teardown): @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_genpass(self, r: redis.Redis): password = await r.acl_genpass() - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) - @skip_if_server_version_lt(REDIS_6_VERSION) - @skip_if_server_version_gte("7.0.0") + @skip_if_server_version_lt("7.0.0") async def test_acl_getuser_setuser(self, r_teardown): username = "redis-py-user" r = r_teardown(username) # test enabled=False assert await r.acl_setuser(username, enabled=False, reset=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": False, - "flags": ["off", "allchannels", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "off" in acl["flags"] + assert acl["enabled"] is False # test nopass=True assert await r.acl_setuser(username, enabled=True, reset=True, nopass=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": True, - "flags": ["on", "allchannels", "nopass", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "on" in acl["flags"] + assert "nopass" in acl["flags"] + assert acl["enabled"] is True # test all args assert await r.acl_setuser( @@ -161,12 +170,11 @@ async def test_acl_getuser_setuser(self, r_teardown): keys=["cache:*", "objects:*"], ) acl = await r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True - assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert "on" in acl["flags"] + assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 # test reset=False keeps existing ACL and applies new ACL on top @@ -188,12 +196,10 @@ async def test_acl_getuser_setuser(self, r_teardown): keys=["objects:*"], ) acl = await r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} assert set(acl["commands"]) == {"+get", "+mget"} assert acl["enabled"] is True - assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert "on" in acl["flags"] + assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 # test removal of passwords @@ -229,14 +235,13 @@ async def test_acl_getuser_setuser(self, r_teardown): assert len((await r.acl_getuser(username))["passwords"]) == 1 @skip_if_server_version_lt(REDIS_6_VERSION) - @skip_if_server_version_gte("7.0.0") async def test_acl_list(self, r_teardown): username = "redis-py-user" r = r_teardown(username) - + start = await r.acl_list() assert await r.acl_setuser(username, enabled=False, reset=True) users = await r.acl_list() - assert f"user {username} off sanitize-payload &* -@all" in users + assert len(users) == len(start) + 1 @skip_if_server_version_lt(REDIS_6_VERSION) @pytest.mark.onlynoncluster @@ -271,7 +276,8 @@ async def test_acl_log(self, r_teardown, create_redis): assert len(await r.acl_log()) == 2 assert len(await r.acl_log(count=1)) == 1 assert isinstance((await r.acl_log())[0], dict) - assert "client-info" in (await r.acl_log(count=1))[0] + expected = (await r.acl_log(count=1))[0] + assert_resp_response_in(r, "client-info", expected, expected.keys()) assert await r.acl_log_reset() @skip_if_server_version_lt(REDIS_6_VERSION) @@ -307,7 +313,7 @@ async def test_acl_users(self, r: redis.Redis): @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_whoami(self, r: redis.Redis): username = await r.acl_whoami() - assert isinstance(username, str) + assert isinstance(username, (str, bytes)) @pytest.mark.onlynoncluster async def test_client_list(self, r: redis.Redis): @@ -345,7 +351,9 @@ async def test_client_getname(self, r: redis.Redis): @pytest.mark.onlynoncluster async def test_client_setname(self, r: redis.Redis): assert await r.client_setname("redis_py_test") - assert await r.client_getname() == "redis_py_test" + assert_resp_response( + r, await r.client_getname(), "redis_py_test", b"redis_py_test" + ) @skip_if_server_version_lt("2.6.9") @pytest.mark.onlynoncluster @@ -944,6 +952,19 @@ async def test_pttl_no_key(self, r: redis.Redis): """PTTL on servers 2.8 and after return -2 when the key doesn't exist""" assert await r.pttl("a") == -2 + @skip_if_server_version_lt("6.2.0") + async def test_hrandfield(self, r): + assert await r.hrandfield("key") is None + await r.hset("key", mapping={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + assert await r.hrandfield("key") is not None + assert len(await r.hrandfield("key", 2)) == 2 + # with values + assert_resp_response(r, len(await r.hrandfield("key", 2, True)), 4, 2) + # without duplications + assert len(await r.hrandfield("key", 10)) == 5 + # with duplications + assert len(await r.hrandfield("key", -10)) == 10 + @pytest.mark.onlynoncluster async def test_randomkey(self, r: redis.Redis): assert await r.randomkey() is None @@ -1080,25 +1101,45 @@ async def test_type(self, r: redis.Redis): async def test_blpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") - assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"3") - assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"4") - assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"1") - assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) assert await r.blpop(["b", "a"], timeout=1) is None await r.rpush("c", "1") - assert await r.blpop("c", timeout=1) == (b"c", b"1") + assert_resp_response( + r, await r.blpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"] + ) @pytest.mark.onlynoncluster async def test_brpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") - assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"4") - assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"3") - assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"2") - assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) assert await r.brpop(["b", "a"], timeout=1) is None await r.rpush("c", "1") - assert await r.brpop("c", timeout=1) == (b"c", b"1") + assert_resp_response( + r, await r.brpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"] + ) @pytest.mark.onlynoncluster async def test_brpoplpush(self, r: redis.Redis): @@ -1403,7 +1444,10 @@ async def test_spop_multi_value(self, r: redis.Redis): for value in values: assert value in s - assert await r.spop("a", 1) == list(set(s) - set(values)) + response = await r.spop("a", 1) + assert_resp_response( + r, response, list(set(s) - set(values)), set(s) - set(values) + ) async def test_srandmember(self, r: redis.Redis): s = [b"1", b"2", b"3"] @@ -1441,11 +1485,13 @@ async def test_sunionstore(self, r: redis.Redis): async def test_zadd(self, r: redis.Redis): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} await r.zadd("a", mapping) - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -1462,23 +1508,24 @@ async def test_zadd(self, r: redis.Redis): async def test_zadd_nx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) async def test_zadd_xx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert await r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a1", 99.0)], [[b"a1", 99.0]]) async def test_zadd_ch(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 99.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a1", 99.0)], [[b"a2", 2.0], [b"a1", 99.0]] + ) async def test_zadd_incr(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 @@ -1502,6 +1549,25 @@ async def test_zcount(self, r: redis.Redis): assert await r.zcount("a", 1, "(" + str(2)) == 1 assert await r.zcount("a", 10, 20) == 0 + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiff(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiff(["a", "b"]) == [b"a3"] + response = await r.zdiff(["a", "b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiffstore(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiffstore("out", ["a", "b"]) + assert await r.zrange("out", 0, -1) == [b"a3"] + response = await r.zrange("out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) + async def test_zincrby(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zincrby("a", 1, "a2") == 3.0 @@ -1521,7 +1587,10 @@ async def test_zinterstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"]) == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8.0], [b"a1", 9.0]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_max(self, r: redis.Redis): @@ -1529,7 +1598,10 @@ async def test_zinterstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_min(self, r: redis.Redis): @@ -1537,7 +1609,10 @@ async def test_zinterstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1), (b"a3", 3)], [[b"a1", 1], [b"a3", 3]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_with_weight(self, r: redis.Redis): @@ -1545,49 +1620,104 @@ async def test_zinterstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmax("a") == [(b"a3", 3)] + response = await r.zpopmax("a") + assert_resp_response(r, response, [(b"a3", 3)], [b"a3", 3.0]) # with count - assert await r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + response = await r.zpopmax("a", count=2) + assert_resp_response( + r, response, [(b"a2", 2), (b"a1", 1)], [[b"a2", 2], [b"a1", 1]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmin("a") == [(b"a1", 1)] + response = await r.zpopmin("a") + assert_resp_response(r, response, [(b"a1", 1)], [b"a1", 1.0]) # with count - assert await r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + response = await r.zpopmin("a", count=2) + assert_resp_response( + r, response, [(b"a2", 2), (b"a3", 3)], [[b"a2", 2], [b"a3", 3]] + ) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster async def test_bzpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"a", b"a2", 2), + [b"a", b"a2", 2], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"a", b"a1", 1), + [b"a", b"a1", 1], + ) assert await r.bzpopmax(["b", "a"], timeout=1) is None await r.zadd("c", {"c1": 100}) - assert await r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, await r.bzpopmax("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster async def test_bzpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"a", b"a1", 1), + [b"a", b"a1", 1], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"a", b"a2", 2), + [b"a", b"a2", 2], + ) assert await r.bzpopmin(["b", "a"], timeout=1) is None await r.zadd("c", {"c1": 100}) - assert await r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, await r.bzpopmin("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) async def test_zrange(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1595,20 +1725,20 @@ async def test_zrange(self, r: redis.Redis): assert await r.zrange("a", 1, 2) == [b"a2", b"a3"] # withscores - assert await r.zrange("a", 0, 1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] - assert await r.zrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, 1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) + response = await r.zrange("a", 1, 2, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a3", 3.0)], [[b"a2", 2.0], [b"a3", 3.0]] + ) # custom score function - assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + # assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] @skip_if_server_version_lt("2.8.9") async def test_zrangebylex(self, r: redis.Redis): @@ -1642,16 +1772,24 @@ async def test_zrangebyscore(self, r: redis.Redis): assert await r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert await r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] + response = await r.zrangebyscore("a", 2, 4, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) # custom score function - assert await r.zrangebyscore( + response = await r.zrangebyscore( "a", 2, 4, withscores=True, score_cast_func=int - ) == [(b"a2", 2), (b"a3", 3), (b"a4", 4)] + ) + assert_resp_response( + r, + response, + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) async def test_zrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1708,20 +1846,20 @@ async def test_zrevrange(self, r: redis.Redis): assert await r.zrevrange("a", 1, 2) == [b"a2", b"a1"] # withscores - assert await r.zrevrange("a", 0, 1, withscores=True) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - assert await r.zrevrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 1.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 3.0), (b"a2", 2.0)], [[b"a3", 3.0], [b"a2", 2.0]] + ) + response = await r.zrevrange("a", 1, 2, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a1", 1.0)], [[b"a2", 2.0], [b"a1", 1.0]] + ) # custom score function - assert await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) + assert_resp_response( + r, response, [(b"a3", 3), (b"a2", 2)], [[b"a3", 3], [b"a2", 2]] + ) async def test_zrevrangebyscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1731,16 +1869,24 @@ async def test_zrevrangebyscore(self, r: redis.Redis): assert await r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] # withscores - assert await r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrangebyscore("a", 4, 2, withscores=True) + assert_resp_response( + r, + response, + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) # custom score function - assert await r.zrevrangebyscore( + response = await r.zrevrangebyscore( "a", 4, 2, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + ) + assert_resp_response( + r, + response, + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) async def test_zrevrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1769,12 +1915,13 @@ async def test_zunionstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"]) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 3.0), (b"a4", 4.0), (b"a3", 8.0), (b"a1", 9.0)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_max(self, r: redis.Redis): @@ -1782,12 +1929,13 @@ async def test_zunionstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + respponse = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + respponse, + [(b"a2", 2.0), (b"a4", 4.0), (b"a3", 5.0), (b"a1", 6.0)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_min(self, r: redis.Redis): @@ -1795,12 +1943,13 @@ async def test_zunionstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_with_weight(self, r: redis.Redis): @@ -1808,12 +1957,13 @@ async def test_zunionstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 5.0), (b"a4", 12.0), (b"a3", 20.0), (b"a1", 23.0)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) # HYPERLOGLOG TESTS @skip_if_server_version_lt("2.8.9") @@ -2254,11 +2404,12 @@ async def test_geohash(self, r: redis.Redis): ) await r.geoadd("barcelona", values) - assert await r.geohash("barcelona", "place1", "place2", "place3") == [ - "sp3e9yg3kd0", - "sp3e9cbc3t0", - None, - ] + assert_resp_response( + r, + await r.geohash("barcelona", "place1", "place2", "place3"), + ["sp3e9yg3kd0", "sp3e9cbc3t0", None], + [b"sp3e9yg3kd0", b"sp3e9cbc3t0", None], + ) @skip_if_server_version_lt("3.2.0") async def test_geopos(self, r: redis.Redis): @@ -2270,10 +2421,18 @@ async def test_geopos(self, r: redis.Redis): await r.geoadd("barcelona", values) # redis uses 52 bits precision, hereby small errors may be introduced. - assert await r.geopos("barcelona", "place1", "place2") == [ - (2.19093829393386841, 41.43379028184083523), - (2.18737632036209106, 41.40634178640635099), - ] + assert_resp_response( + r, + await r.geopos("barcelona", "place1", "place2"), + [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ], + [ + [2.19093829393386841, 41.43379028184083523], + [2.18737632036209106, 41.40634178640635099], + ], + ) @skip_if_server_version_lt("4.0.0") async def test_geopos_no_value(self, r: redis.Redis): @@ -2682,7 +2841,7 @@ async def test_xgroup_setid(self, r: redis.Redis): ] assert await r.xinfo_groups(stream) == expected - @skip_if_server_version_lt("5.0.0") + @skip_if_server_version_lt("7.2.0") async def test_xinfo_consumers(self, r: redis.Redis): stream = "stream" group = "group" @@ -2698,8 +2857,8 @@ async def test_xinfo_consumers(self, r: redis.Redis): info = await r.xinfo_consumers(stream, group) assert len(info) == 2 expected = [ - {"name": consumer1.encode(), "pending": 1}, - {"name": consumer2.encode(), "pending": 2}, + {"name": consumer1.encode(), "pending": 1, "inactive": 2}, + {"name": consumer2.encode(), "pending": 2, "inactive": 2}, ] # we can't determine the idle time, so just make sure it's an int @@ -2808,28 +2967,30 @@ async def test_xread(self, r: redis.Redis): m1 = await r.xadd(stream, {"foo": "bar"}) m2 = await r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + strem_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert await r.xread(streams={stream: 0}) == expected + res = await r.xread(streams={stream: 0}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert await r.xread(streams={stream: 0}, count=1) == expected + res = await r.xread(streams={stream: 0}, count=1) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) - expected = [[stream.encode(), [await get_stream_message(r, stream, m2)]]] + expected_entries = [await get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert await r.xread(streams={stream: m1}) == expected - - # xread starting at the last message returns an empty list - assert await r.xread(streams={stream: m2}) == [] + res = await r.xread(streams={stream: m1}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) @skip_if_server_version_lt("5.0.0") async def test_xreadgroup(self, r: redis.Redis): @@ -2840,26 +3001,27 @@ async def test_xreadgroup(self, r: redis.Redis): m2 = await r.xadd(stream, {"bing": "baz"}) await r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + strem_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert ( - await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) - == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} ) await r.xgroup_destroy(stream, group) @@ -2868,35 +3030,34 @@ async def test_xreadgroup(self, r: redis.Redis): # will only find messages added after this await r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response(r, res, [], {}) # xreadgroup with noack does not have any items in the PEL await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") - assert ( - len( - ( - await r.xreadgroup( - group, consumer, streams={stream: ">"}, noack=True - ) - )[0][1] - ) - == 2 - ) - # now there should be nothing pending - assert ( - len((await r.xreadgroup(group, consumer, streams={stream: "0"}))[0][1]) == 0 - ) + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[strem_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[strem_name][0]) == 0 await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] await r.xreadgroup(group, consumer, streams={stream: ">"}) await r.xtrim(stream, 0) - assert await r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + assert_resp_response( + r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]} + ) @skip_if_server_version_lt("5.0.0") async def test_xrevrange(self, r: redis.Redis): diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 8e3209fdc6..bead7208f5 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -5,7 +5,6 @@ import ssl import pytest - from redis.asyncio.connection import ( Connection, SSLConnection, diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 158b8545e2..09960fd7e2 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -4,16 +4,15 @@ from unittest.mock import patch import pytest - import redis -from redis.asyncio import Redis -from redis.asyncio.connection import ( - BaseParser, - Connection, - HiredisParser, - PythonParser, - UnixDomainSocketConnection, +from redis._parsers import ( + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, + _AsyncRESPBase, ) +from redis.asyncio import Redis +from redis.asyncio.connection import Connection, UnixDomainSocketConnection from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError @@ -31,11 +30,11 @@ async def test_invalid_response(create_redis): raw = b"x" fake_stream = MockStream(raw + b"\r\n") - parser: BaseParser = r.connection._parser + parser: _AsyncRESPBase = r.connection._parser with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - if isinstance(parser, PythonParser): + if isinstance(parser, _AsyncRESPBase): assert str(cm.value) == f"Protocol Error: {raw!r}" else: assert ( @@ -91,22 +90,22 @@ async def get_conn(_): @skip_if_server_version_lt("4.0.0") @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_loading_external_modules(modclient): +async def test_loading_external_modules(r): def inner(): pass - modclient.load_external_module("myfuncname", inner) - assert getattr(modclient, "myfuncname") == inner - assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) + r.load_external_module("myfuncname", inner) + assert getattr(r, "myfuncname") == inner + assert isinstance(getattr(r, "myfuncname"), types.FunctionType) # and call it from redis.commands import RedisModuleCommands j = RedisModuleCommands.json - modclient.load_external_module("sometestfuncname", j) + r.load_external_module("sometestfuncname", j) # d = {'hello': 'world!'} - # mod = j(modclient) + # mod = j(r) # mod.set("fookey", ".", d) # assert mod.get('fookey') == d @@ -197,7 +196,9 @@ async def test_connection_parse_response_resume(r: redis.Redis): @pytest.mark.onlynoncluster @pytest.mark.parametrize( - "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] + "parser_class", + [_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser], + ids=["AsyncRESP2Parser", "AsyncRESP3Parser", "AsyncHiredisParser"], ) async def test_connection_disconect_race(parser_class): """ @@ -211,7 +212,7 @@ async def test_connection_disconect_race(parser_class): This test verifies that a read in progress can finish even if the `disconnect()` method is called. """ - if parser_class == HiredisParser and not HIREDIS_AVAILABLE: + if parser_class == _AsyncHiredisParser and not HIREDIS_AVAILABLE: pytest.skip("Hiredis not available") args = {} diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 24d9902142..7672dc74b4 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -4,7 +4,6 @@ import pytest import pytest_asyncio - import redis.asyncio as redis from redis.asyncio.connection import Connection, to_bool from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt @@ -246,8 +245,9 @@ async def test_connection_pool_blocks_until_timeout(self, master_host): start = asyncio.get_running_loop().time() with pytest.raises(redis.ConnectionError): await pool.get_connection("_") - # we should have waited at least 0.1 seconds - assert asyncio.get_running_loop().time() - start >= 0.1 + + # we should have waited at least some period of time + assert asyncio.get_running_loop().time() - start >= 0.05 await c1.disconnect() async def test_connection_pool_blocks_until_conn_available(self, master_host): @@ -267,7 +267,8 @@ async def target(): start = asyncio.get_running_loop().time() await asyncio.gather(target(), pool.get_connection("_")) - assert asyncio.get_running_loop().time() - start >= 0.1 + stop = asyncio.get_running_loop().time() + assert (stop - start) <= 0.2 async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0]} @@ -666,6 +667,7 @@ async def r(self, create_redis, server): @pytest.mark.onlynoncluster +@pytest.mark.xfail(strict=False) class TestHealthCheck: interval = 60 diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index 8e213cdb26..4429f7453b 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -5,7 +5,6 @@ import pytest import pytest_asyncio - import redis from redis import AuthenticationError, DataError, ResponseError from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 21f2ddde2a..ff588861e4 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -2,7 +2,6 @@ import contextlib import pytest - from redis.asyncio import Redis from redis.asyncio.cluster import RedisCluster from redis.asyncio.connection import async_timeout @@ -183,7 +182,7 @@ async def op(pipe): async def test_cluster(master_host): delay = 0.1 - cluster_port = 6372 + cluster_port = 16379 remap_base = 7372 n_nodes = 6 hostname, _ = master_host diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py index 3efcf69e5b..162ccb367d 100644 --- a/tests/test_asyncio/test_encoding.py +++ b/tests/test_asyncio/test_encoding.py @@ -1,6 +1,5 @@ import pytest import pytest_asyncio - import redis.asyncio as redis from redis.exceptions import DataError @@ -90,6 +89,7 @@ async def r(self, create_redis): yield redis await redis.flushall() + @pytest.mark.xfail async def test_basic_command(self, r: redis.Redis): await r.set("hello", "world") diff --git a/tests/test_asyncio/test_graph.py b/tests/test_asyncio/test_graph.py index e7a772fc0f..22195901e6 100644 --- a/tests/test_asyncio/test_graph.py +++ b/tests/test_asyncio/test_graph.py @@ -1,5 +1,4 @@ import pytest - import redis.asyncio as redis from redis.commands.graph import Edge, Node, Path from redis.commands.graph.execution_plan import Operation @@ -8,15 +7,15 @@ @pytest.mark.redismod -async def test_bulk(modclient): +async def test_bulk(decoded_r): with pytest.raises(NotImplementedError): - await modclient.graph().bulk() - await modclient.graph().bulk(foo="bar!") + await decoded_r.graph().bulk() + await decoded_r.graph().bulk(foo="bar!") @pytest.mark.redismod -async def test_graph_creation(modclient: redis.Redis): - graph = modclient.graph() +async def test_graph_creation(decoded_r: redis.Redis): + graph = decoded_r.graph() john = Node( label="person", @@ -60,8 +59,8 @@ async def test_graph_creation(modclient: redis.Redis): @pytest.mark.redismod -async def test_array_functions(modclient: redis.Redis): - graph = modclient.graph() +async def test_array_functions(decoded_r: redis.Redis): + graph = decoded_r.graph() query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" await graph.query(query) @@ -83,12 +82,12 @@ async def test_array_functions(modclient: redis.Redis): @pytest.mark.redismod -async def test_path(modclient: redis.Redis): +async def test_path(decoded_r: redis.Redis): node0 = Node(node_id=0, label="L1") node1 = Node(node_id=1, label="L1") edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1}) - graph = modclient.graph() + graph = decoded_r.graph() graph.add_node(node0) graph.add_node(node1) graph.add_edge(edge01) @@ -103,20 +102,20 @@ async def test_path(modclient: redis.Redis): @pytest.mark.redismod -async def test_param(modclient: redis.Redis): +async def test_param(decoded_r: redis.Redis): params = [1, 2.3, "str", True, False, None, [0, 1, 2]] query = "RETURN $param" for param in params: - result = await modclient.graph().query(query, {"param": param}) + result = await decoded_r.graph().query(query, {"param": param}) expected_results = [[param]] assert expected_results == result.result_set @pytest.mark.redismod -async def test_map(modclient: redis.Redis): +async def test_map(decoded_r: redis.Redis): query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" - actual = (await modclient.graph().query(query)).result_set[0][0] + actual = (await decoded_r.graph().query(query)).result_set[0][0] expected = { "a": 1, "b": "str", @@ -130,40 +129,40 @@ async def test_map(modclient: redis.Redis): @pytest.mark.redismod -async def test_point(modclient: redis.Redis): +async def test_point(decoded_r: redis.Redis): query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" expected_lat = 32.070794860 expected_lon = 34.820751118 - actual = (await modclient.graph().query(query)).result_set[0][0] + actual = (await decoded_r.graph().query(query)).result_set[0][0] assert abs(actual["latitude"] - expected_lat) < 0.001 assert abs(actual["longitude"] - expected_lon) < 0.001 query = "RETURN point({latitude: 32, longitude: 34.0})" expected_lat = 32 expected_lon = 34 - actual = (await modclient.graph().query(query)).result_set[0][0] + actual = (await decoded_r.graph().query(query)).result_set[0][0] assert abs(actual["latitude"] - expected_lat) < 0.001 assert abs(actual["longitude"] - expected_lon) < 0.001 @pytest.mark.redismod -async def test_index_response(modclient: redis.Redis): - result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") +async def test_index_response(decoded_r: redis.Redis): + result_set = await decoded_r.graph().query("CREATE INDEX ON :person(age)") assert 1 == result_set.indices_created - result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") + result_set = await decoded_r.graph().query("CREATE INDEX ON :person(age)") assert 0 == result_set.indices_created - result_set = await modclient.graph().query("DROP INDEX ON :person(age)") + result_set = await decoded_r.graph().query("DROP INDEX ON :person(age)") assert 1 == result_set.indices_deleted with pytest.raises(ResponseError): - await modclient.graph().query("DROP INDEX ON :person(age)") + await decoded_r.graph().query("DROP INDEX ON :person(age)") @pytest.mark.redismod -async def test_stringify_query_result(modclient: redis.Redis): - graph = modclient.graph() +async def test_stringify_query_result(decoded_r: redis.Redis): + graph = decoded_r.graph() john = Node( alias="a", @@ -216,14 +215,14 @@ async def test_stringify_query_result(modclient: redis.Redis): @pytest.mark.redismod -async def test_optional_match(modclient: redis.Redis): +async def test_optional_match(decoded_r: redis.Redis): # Build a graph of form (a)-[R]->(b) node0 = Node(node_id=0, label="L1", properties={"value": "a"}) node1 = Node(node_id=1, label="L1", properties={"value": "b"}) edge01 = Edge(node0, "R", node1, edge_id=0) - graph = modclient.graph() + graph = decoded_r.graph() graph.add_node(node0) graph.add_node(node1) graph.add_edge(edge01) @@ -241,17 +240,17 @@ async def test_optional_match(modclient: redis.Redis): @pytest.mark.redismod -async def test_cached_execution(modclient: redis.Redis): - await modclient.graph().query("CREATE ()") +async def test_cached_execution(decoded_r: redis.Redis): + await decoded_r.graph().query("CREATE ()") - uncached_result = await modclient.graph().query( + uncached_result = await decoded_r.graph().query( "MATCH (n) RETURN n, $param", {"param": [0]} ) assert uncached_result.cached_execution is False # loop to make sure the query is cached on each thread on server for x in range(0, 64): - cached_result = await modclient.graph().query( + cached_result = await decoded_r.graph().query( "MATCH (n) RETURN n, $param", {"param": [0]} ) assert uncached_result.result_set == cached_result.result_set @@ -261,51 +260,51 @@ async def test_cached_execution(modclient: redis.Redis): @pytest.mark.redismod -async def test_slowlog(modclient: redis.Redis): +async def test_slowlog(decoded_r: redis.Redis): create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - await modclient.graph().query(create_query) + await decoded_r.graph().query(create_query) - results = await modclient.graph().slowlog() + results = await decoded_r.graph().slowlog() assert results[0][1] == "GRAPH.QUERY" assert results[0][2] == create_query @pytest.mark.redismod @pytest.mark.xfail(strict=False) -async def test_query_timeout(modclient: redis.Redis): +async def test_query_timeout(decoded_r: redis.Redis): # Build a sample graph with 1000 nodes. - await modclient.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") + await decoded_r.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") # Issue a long-running query with a 1-millisecond timeout. with pytest.raises(ResponseError): - await modclient.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) + await decoded_r.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) assert False is False with pytest.raises(Exception): - await modclient.graph().query("RETURN 1", timeout="str") + await decoded_r.graph().query("RETURN 1", timeout="str") assert False is False @pytest.mark.redismod -async def test_read_only_query(modclient: redis.Redis): +async def test_read_only_query(decoded_r: redis.Redis): with pytest.raises(Exception): # Issue a write query, specifying read-only true, # this call should fail. - await modclient.graph().query("CREATE (p:person {name:'a'})", read_only=True) + await decoded_r.graph().query("CREATE (p:person {name:'a'})", read_only=True) assert False is False @pytest.mark.redismod -async def test_profile(modclient: redis.Redis): +async def test_profile(decoded_r: redis.Redis): q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" - profile = (await modclient.graph().profile(q)).result_set + profile = (await decoded_r.graph().profile(q)).result_set assert "Create | Records produced: 3" in profile assert "Unwind | Records produced: 3" in profile q = "MATCH (p:Person) WHERE p.v > 1 RETURN p" - profile = (await modclient.graph().profile(q)).result_set + profile = (await decoded_r.graph().profile(q)).result_set assert "Results | Records produced: 2" in profile assert "Project | Records produced: 2" in profile assert "Filter | Records produced: 2" in profile @@ -314,16 +313,16 @@ async def test_profile(modclient: redis.Redis): @pytest.mark.redismod @skip_if_redis_enterprise() -async def test_config(modclient: redis.Redis): +async def test_config(decoded_r: redis.Redis): config_name = "RESULTSET_SIZE" config_value = 3 # Set configuration - response = await modclient.graph().config(config_name, config_value, set=True) + response = await decoded_r.graph().config(config_name, config_value, set=True) assert response == "OK" # Make sure config been updated. - response = await modclient.graph().config(config_name, set=False) + response = await decoded_r.graph().config(config_name, set=False) expected_response = [config_name, config_value] assert response == expected_response @@ -331,46 +330,46 @@ async def test_config(modclient: redis.Redis): config_value = 1 << 20 # 1MB # Set configuration - response = await modclient.graph().config(config_name, config_value, set=True) + response = await decoded_r.graph().config(config_name, config_value, set=True) assert response == "OK" # Make sure config been updated. - response = await modclient.graph().config(config_name, set=False) + response = await decoded_r.graph().config(config_name, set=False) expected_response = [config_name, config_value] assert response == expected_response # reset to default - await modclient.graph().config("QUERY_MEM_CAPACITY", 0, set=True) - await modclient.graph().config("RESULTSET_SIZE", -100, set=True) + await decoded_r.graph().config("QUERY_MEM_CAPACITY", 0, set=True) + await decoded_r.graph().config("RESULTSET_SIZE", -100, set=True) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_list_keys(modclient: redis.Redis): - result = await modclient.graph().list_keys() +async def test_list_keys(decoded_r: redis.Redis): + result = await decoded_r.graph().list_keys() assert result == [] - await modclient.graph("G").query("CREATE (n)") - result = await modclient.graph().list_keys() + await decoded_r.graph("G").query("CREATE (n)") + result = await decoded_r.graph().list_keys() assert result == ["G"] - await modclient.graph("X").query("CREATE (m)") - result = await modclient.graph().list_keys() + await decoded_r.graph("X").query("CREATE (m)") + result = await decoded_r.graph().list_keys() assert result == ["G", "X"] - await modclient.delete("G") - await modclient.rename("X", "Z") - result = await modclient.graph().list_keys() + await decoded_r.delete("G") + await decoded_r.rename("X", "Z") + result = await decoded_r.graph().list_keys() assert result == ["Z"] - await modclient.delete("Z") - result = await modclient.graph().list_keys() + await decoded_r.delete("Z") + result = await decoded_r.graph().list_keys() assert result == [] @pytest.mark.redismod -async def test_multi_label(modclient: redis.Redis): - redis_graph = modclient.graph("g") +async def test_multi_label(decoded_r: redis.Redis): + redis_graph = decoded_r.graph("g") node = Node(label=["l", "ll"]) redis_graph.add_node(node) @@ -395,8 +394,8 @@ async def test_multi_label(modclient: redis.Redis): @pytest.mark.redismod -async def test_execution_plan(modclient: redis.Redis): - redis_graph = modclient.graph("execution_plan") +async def test_execution_plan(decoded_r: redis.Redis): + redis_graph = decoded_r.graph("execution_plan") create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), @@ -414,8 +413,8 @@ async def test_execution_plan(modclient: redis.Redis): @pytest.mark.redismod -async def test_explain(modclient: redis.Redis): - redis_graph = modclient.graph("execution_plan") +async def test_explain(decoded_r: redis.Redis): + redis_graph = decoded_r.graph("execution_plan") # graph creation / population create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index 12bdacdb9a..6f3e8c3251 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -1,318 +1,338 @@ import pytest - import redis.asyncio as redis from redis import exceptions from redis.commands.json.path import Path -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import assert_resp_response, skip_ifmodversion_lt @pytest.mark.redismod -async def test_json_setbinarykey(modclient: redis.Redis): +async def test_json_setbinarykey(decoded_r: redis.Redis): d = {"hello": "world", b"some": "value"} with pytest.raises(TypeError): - modclient.json().set("somekey", Path.root_path(), d) - assert await modclient.json().set("somekey", Path.root_path(), d, decode_keys=True) + decoded_r.json().set("somekey", Path.root_path(), d) + assert await decoded_r.json().set("somekey", Path.root_path(), d, decode_keys=True) @pytest.mark.redismod -async def test_json_setgetdeleteforget(modclient: redis.Redis): - assert await modclient.json().set("foo", Path.root_path(), "bar") - assert await modclient.json().get("foo") == "bar" - assert await modclient.json().get("baz") is None - assert await modclient.json().delete("foo") == 1 - assert await modclient.json().forget("foo") == 0 # second delete - assert await modclient.exists("foo") == 0 +async def test_json_setgetdeleteforget(decoded_r: redis.Redis): + assert await decoded_r.json().set("foo", Path.root_path(), "bar") + assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) + assert await decoded_r.json().get("baz") is None + assert await decoded_r.json().delete("foo") == 1 + assert await decoded_r.json().forget("foo") == 0 # second delete + assert await decoded_r.exists("foo") == 0 @pytest.mark.redismod -async def test_jsonget(modclient: redis.Redis): - await modclient.json().set("foo", Path.root_path(), "bar") - assert await modclient.json().get("foo") == "bar" +async def test_jsonget(decoded_r: redis.Redis): + await decoded_r.json().set("foo", Path.root_path(), "bar") + assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) @pytest.mark.redismod -async def test_json_get_jset(modclient: redis.Redis): - assert await modclient.json().set("foo", Path.root_path(), "bar") - assert "bar" == await modclient.json().get("foo") - assert await modclient.json().get("baz") is None - assert 1 == await modclient.json().delete("foo") - assert await modclient.exists("foo") == 0 +async def test_json_get_jset(decoded_r: redis.Redis): + assert await decoded_r.json().set("foo", Path.root_path(), "bar") + assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) + assert await decoded_r.json().get("baz") is None + assert 1 == await decoded_r.json().delete("foo") + assert await decoded_r.exists("foo") == 0 @pytest.mark.redismod -@skip_ifmodversion_lt("2.6.0", "ReJSON") # todo: update after the release -async def test_json_merge(modclient: redis.Redis): +async def test_nonascii_setgetdelete(decoded_r: redis.Redis): + assert await decoded_r.json().set("notascii", Path.root_path(), "hyvää-élève") + res = "hyvää-élève" + assert_resp_response( + decoded_r, await decoded_r.json().get("notascii", no_escape=True), res, [[res]] + ) + assert 1 == await decoded_r.json().delete("notascii") + assert await decoded_r.exists("notascii") == 0 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.6.0", "ReJSON") +async def test_json_merge(decoded_r: redis.Redis): # Test with root path $ - assert await modclient.json().set( + assert await decoded_r.json().set( "person_data", "$", {"person1": {"personal_data": {"name": "John"}}}, ) - assert await modclient.json().merge( + assert await decoded_r.json().merge( "person_data", "$", {"person1": {"personal_data": {"hobbies": "reading"}}} ) - assert await modclient.json().get("person_data") == { + assert await decoded_r.json().get("person_data") == { "person1": {"personal_data": {"name": "John", "hobbies": "reading"}} } # Test with root path path $.person1.personal_data - assert await modclient.json().merge( + assert await decoded_r.json().merge( "person_data", "$.person1.personal_data", {"country": "Israel"} ) - assert await modclient.json().get("person_data") == { + assert await decoded_r.json().get("person_data") == { "person1": { "personal_data": {"name": "John", "hobbies": "reading", "country": "Israel"} } } # Test with null value to delete a value - assert await modclient.json().merge( + assert await decoded_r.json().merge( "person_data", "$.person1.personal_data", {"name": None} ) - assert await modclient.json().get("person_data") == { + assert await decoded_r.json().get("person_data") == { "person1": {"personal_data": {"country": "Israel", "hobbies": "reading"}} } @pytest.mark.redismod -async def test_nonascii_setgetdelete(modclient: redis.Redis): - assert await modclient.json().set("notascii", Path.root_path(), "hyvää-élève") - assert "hyvää-élève" == await modclient.json().get("notascii", no_escape=True) - assert 1 == await modclient.json().delete("notascii") - assert await modclient.exists("notascii") == 0 - - -@pytest.mark.redismod -async def test_jsonsetexistentialmodifiersshouldsucceed(modclient: redis.Redis): +async def test_jsonsetexistentialmodifiersshouldsucceed(decoded_r: redis.Redis): obj = {"foo": "bar"} - assert await modclient.json().set("obj", Path.root_path(), obj) + assert await decoded_r.json().set("obj", Path.root_path(), obj) # Test that flags prevent updates when conditions are unmet - assert await modclient.json().set("obj", Path("foo"), "baz", nx=True) is None - assert await modclient.json().set("obj", Path("qaz"), "baz", xx=True) is None + assert await decoded_r.json().set("obj", Path("foo"), "baz", nx=True) is None + assert await decoded_r.json().set("obj", Path("qaz"), "baz", xx=True) is None # Test that flags allow updates when conditions are met - assert await modclient.json().set("obj", Path("foo"), "baz", xx=True) - assert await modclient.json().set("obj", Path("qaz"), "baz", nx=True) + assert await decoded_r.json().set("obj", Path("foo"), "baz", xx=True) + assert await decoded_r.json().set("obj", Path("qaz"), "baz", nx=True) # Test that flags are mutually exlusive with pytest.raises(Exception): - await modclient.json().set("obj", Path("foo"), "baz", nx=True, xx=True) + await decoded_r.json().set("obj", Path("foo"), "baz", nx=True, xx=True) @pytest.mark.redismod -async def test_mgetshouldsucceed(modclient: redis.Redis): - await modclient.json().set("1", Path.root_path(), 1) - await modclient.json().set("2", Path.root_path(), 2) - assert await modclient.json().mget(["1"], Path.root_path()) == [1] +async def test_mgetshouldsucceed(decoded_r: redis.Redis): + await decoded_r.json().set("1", Path.root_path(), 1) + await decoded_r.json().set("2", Path.root_path(), 2) + assert await decoded_r.json().mget(["1"], Path.root_path()) == [1] - assert await modclient.json().mget([1, 2], Path.root_path()) == [1, 2] + assert await decoded_r.json().mget([1, 2], Path.root_path()) == [1, 2] @pytest.mark.redismod -@skip_ifmodversion_lt("2.6.0", "ReJSON") # todo: update after the release -async def test_mset(modclient: redis.Redis): - await modclient.json().mset( +@skip_ifmodversion_lt("2.6.0", "ReJSON") +async def test_mset(decoded_r: redis.Redis): + await decoded_r.json().mset( [("1", Path.root_path(), 1), ("2", Path.root_path(), 2)] ) - assert await modclient.json().mget(["1"], Path.root_path()) == [1] - assert await modclient.json().mget(["1", "2"], Path.root_path()) == [1, 2] + assert await decoded_r.json().mget(["1"], Path.root_path()) == [1] + assert await decoded_r.json().mget(["1", "2"], Path.root_path()) == [1, 2] @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -async def test_clear(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 1 == await modclient.json().clear("arr", Path.root_path()) - assert [] == await modclient.json().get("arr") +async def test_clear(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 1 == await decoded_r.json().clear("arr", Path.root_path()) + assert_resp_response(decoded_r, await decoded_r.json().get("arr"), [], [[[]]]) @pytest.mark.redismod -async def test_type(modclient: redis.Redis): - await modclient.json().set("1", Path.root_path(), 1) - assert "integer" == await modclient.json().type("1", Path.root_path()) - assert "integer" == await modclient.json().type("1") +async def test_type(decoded_r: redis.Redis): + await decoded_r.json().set("1", Path.root_path(), 1) + assert_resp_response( + decoded_r, + await decoded_r.json().type("1", Path.root_path()), + "integer", + ["integer"], + ) + assert_resp_response( + decoded_r, await decoded_r.json().type("1"), "integer", ["integer"] + ) @pytest.mark.redismod -async def test_numincrby(modclient): - await modclient.json().set("num", Path.root_path(), 1) - assert 2 == await modclient.json().numincrby("num", Path.root_path(), 1) - assert 2.5 == await modclient.json().numincrby("num", Path.root_path(), 0.5) - assert 1.25 == await modclient.json().numincrby("num", Path.root_path(), -1.25) +async def test_numincrby(decoded_r): + await decoded_r.json().set("num", Path.root_path(), 1) + assert_resp_response( + decoded_r, await decoded_r.json().numincrby("num", Path.root_path(), 1), 2, [2] + ) + res = await decoded_r.json().numincrby("num", Path.root_path(), 0.5) + assert_resp_response(decoded_r, res, 2.5, [2.5]) + res = await decoded_r.json().numincrby("num", Path.root_path(), -1.25) + assert_resp_response(decoded_r, res, 1.25, [1.25]) @pytest.mark.redismod -async def test_nummultby(modclient: redis.Redis): - await modclient.json().set("num", Path.root_path(), 1) +async def test_nummultby(decoded_r: redis.Redis): + await decoded_r.json().set("num", Path.root_path(), 1) with pytest.deprecated_call(): - assert 2 == await modclient.json().nummultby("num", Path.root_path(), 2) - assert 5 == await modclient.json().nummultby("num", Path.root_path(), 2.5) - assert 2.5 == await modclient.json().nummultby("num", Path.root_path(), 0.5) + res = await decoded_r.json().nummultby("num", Path.root_path(), 2) + assert_resp_response(decoded_r, res, 2, [2]) + res = await decoded_r.json().nummultby("num", Path.root_path(), 2.5) + assert_resp_response(decoded_r, res, 5, [5]) + res = await decoded_r.json().nummultby("num", Path.root_path(), 0.5) + assert_resp_response(decoded_r, res, 2.5, [2.5]) @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -async def test_toggle(modclient: redis.Redis): - await modclient.json().set("bool", Path.root_path(), False) - assert await modclient.json().toggle("bool", Path.root_path()) - assert await modclient.json().toggle("bool", Path.root_path()) is False +async def test_toggle(decoded_r: redis.Redis): + await decoded_r.json().set("bool", Path.root_path(), False) + assert await decoded_r.json().toggle("bool", Path.root_path()) + assert await decoded_r.json().toggle("bool", Path.root_path()) is False # check non-boolean value - await modclient.json().set("num", Path.root_path(), 1) + await decoded_r.json().set("num", Path.root_path(), 1) with pytest.raises(exceptions.ResponseError): - await modclient.json().toggle("num", Path.root_path()) + await decoded_r.json().toggle("num", Path.root_path()) @pytest.mark.redismod -async def test_strappend(modclient: redis.Redis): - await modclient.json().set("jsonkey", Path.root_path(), "foo") - assert 6 == await modclient.json().strappend("jsonkey", "bar") - assert "foobar" == await modclient.json().get("jsonkey", Path.root_path()) +async def test_strappend(decoded_r: redis.Redis): + await decoded_r.json().set("jsonkey", Path.root_path(), "foo") + assert 6 == await decoded_r.json().strappend("jsonkey", "bar") + res = await decoded_r.json().get("jsonkey", Path.root_path()) + assert_resp_response(decoded_r, res, "foobar", [["foobar"]]) @pytest.mark.redismod -async def test_strlen(modclient: redis.Redis): - await modclient.json().set("str", Path.root_path(), "foo") - assert 3 == await modclient.json().strlen("str", Path.root_path()) - await modclient.json().strappend("str", "bar", Path.root_path()) - assert 6 == await modclient.json().strlen("str", Path.root_path()) - assert 6 == await modclient.json().strlen("str") +async def test_strlen(decoded_r: redis.Redis): + await decoded_r.json().set("str", Path.root_path(), "foo") + assert 3 == await decoded_r.json().strlen("str", Path.root_path()) + await decoded_r.json().strappend("str", "bar", Path.root_path()) + assert 6 == await decoded_r.json().strlen("str", Path.root_path()) + assert 6 == await decoded_r.json().strlen("str") @pytest.mark.redismod -async def test_arrappend(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [1]) - assert 2 == await modclient.json().arrappend("arr", Path.root_path(), 2) - assert 4 == await modclient.json().arrappend("arr", Path.root_path(), 3, 4) - assert 7 == await modclient.json().arrappend("arr", Path.root_path(), *[5, 6, 7]) +async def test_arrappend(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [1]) + assert 2 == await decoded_r.json().arrappend("arr", Path.root_path(), 2) + assert 4 == await decoded_r.json().arrappend("arr", Path.root_path(), 3, 4) + assert 7 == await decoded_r.json().arrappend("arr", Path.root_path(), *[5, 6, 7]) @pytest.mark.redismod -async def test_arrindex(modclient: redis.Redis): +async def test_arrindex(decoded_r: redis.Redis): r_path = Path.root_path() - await modclient.json().set("arr", r_path, [0, 1, 2, 3, 4]) - assert 1 == await modclient.json().arrindex("arr", r_path, 1) - assert -1 == await modclient.json().arrindex("arr", r_path, 1, 2) - assert 4 == await modclient.json().arrindex("arr", r_path, 4) - assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0) - assert 4 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=5000) - assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=0, stop=-1) - assert -1 == await modclient.json().arrindex("arr", r_path, 4, start=1, stop=3) + await decoded_r.json().set("arr", r_path, [0, 1, 2, 3, 4]) + assert 1 == await decoded_r.json().arrindex("arr", r_path, 1) + assert -1 == await decoded_r.json().arrindex("arr", r_path, 1, 2) + assert 4 == await decoded_r.json().arrindex("arr", r_path, 4) + assert 4 == await decoded_r.json().arrindex("arr", r_path, 4, start=0) + assert 4 == await decoded_r.json().arrindex("arr", r_path, 4, start=0, stop=5000) + assert -1 == await decoded_r.json().arrindex("arr", r_path, 4, start=0, stop=-1) + assert -1 == await decoded_r.json().arrindex("arr", r_path, 4, start=1, stop=3) @pytest.mark.redismod -async def test_arrinsert(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 4]) - assert 5 - -await modclient.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) - assert [0, 1, 2, 3, 4] == await modclient.json().get("arr") +async def test_arrinsert(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 4]) + assert 5 == await decoded_r.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) + res = [0, 1, 2, 3, 4] + assert_resp_response(decoded_r, await decoded_r.json().get("arr"), res, [[res]]) # test prepends - await modclient.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9]) - await modclient.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"]) - assert await modclient.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] + await decoded_r.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9]) + await decoded_r.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"]) + res = [["some", "thing"], 5, 6, 7, 8, 9] + assert_resp_response(decoded_r, await decoded_r.json().get("val2"), res, [[res]]) @pytest.mark.redismod -async def test_arrlen(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 5 == await modclient.json().arrlen("arr", Path.root_path()) - assert 5 == await modclient.json().arrlen("arr") - assert await modclient.json().arrlen("fakekey") is None +async def test_arrlen(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 5 == await decoded_r.json().arrlen("arr", Path.root_path()) + assert 5 == await decoded_r.json().arrlen("arr") + assert await decoded_r.json().arrlen("fakekey") is None @pytest.mark.redismod -async def test_arrpop(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 4 == await modclient.json().arrpop("arr", Path.root_path(), 4) - assert 3 == await modclient.json().arrpop("arr", Path.root_path(), -1) - assert 2 == await modclient.json().arrpop("arr", Path.root_path()) - assert 0 == await modclient.json().arrpop("arr", Path.root_path(), 0) - assert [1] == await modclient.json().get("arr") +async def test_arrpop(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 4 == await decoded_r.json().arrpop("arr", Path.root_path(), 4) + assert 3 == await decoded_r.json().arrpop("arr", Path.root_path(), -1) + assert 2 == await decoded_r.json().arrpop("arr", Path.root_path()) + assert 0 == await decoded_r.json().arrpop("arr", Path.root_path(), 0) + assert_resp_response(decoded_r, await decoded_r.json().get("arr"), [1], [[[1]]]) # test out of bounds - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 4 == await modclient.json().arrpop("arr", Path.root_path(), 99) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 4 == await decoded_r.json().arrpop("arr", Path.root_path(), 99) # none test - await modclient.json().set("arr", Path.root_path(), []) - assert await modclient.json().arrpop("arr") is None + await decoded_r.json().set("arr", Path.root_path(), []) + assert await decoded_r.json().arrpop("arr") is None @pytest.mark.redismod -async def test_arrtrim(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 3 == await modclient.json().arrtrim("arr", Path.root_path(), 1, 3) - assert [1, 2, 3] == await modclient.json().get("arr") +async def test_arrtrim(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 3 == await decoded_r.json().arrtrim("arr", Path.root_path(), 1, 3) + res = await decoded_r.json().get("arr") + assert_resp_response(decoded_r, res, [1, 2, 3], [[[1, 2, 3]]]) # <0 test, should be 0 equivalent - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 0 == await modclient.json().arrtrim("arr", Path.root_path(), -1, 3) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 0 == await decoded_r.json().arrtrim("arr", Path.root_path(), -1, 3) # testing stop > end - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 2 == await modclient.json().arrtrim("arr", Path.root_path(), 3, 99) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 2 == await decoded_r.json().arrtrim("arr", Path.root_path(), 3, 99) # start > array size and stop - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 0 == await modclient.json().arrtrim("arr", Path.root_path(), 9, 1) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 0 == await decoded_r.json().arrtrim("arr", Path.root_path(), 9, 1) # all larger - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 0 == await modclient.json().arrtrim("arr", Path.root_path(), 9, 11) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 0 == await decoded_r.json().arrtrim("arr", Path.root_path(), 9, 11) @pytest.mark.redismod -async def test_resp(modclient: redis.Redis): +async def test_resp(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": 1, "qaz": True} - await modclient.json().set("obj", Path.root_path(), obj) - assert "bar" == await modclient.json().resp("obj", Path("foo")) - assert 1 == await modclient.json().resp("obj", Path("baz")) - assert await modclient.json().resp("obj", Path("qaz")) - assert isinstance(await modclient.json().resp("obj"), list) + await decoded_r.json().set("obj", Path.root_path(), obj) + assert "bar" == await decoded_r.json().resp("obj", Path("foo")) + assert 1 == await decoded_r.json().resp("obj", Path("baz")) + assert await decoded_r.json().resp("obj", Path("qaz")) + assert isinstance(await decoded_r.json().resp("obj"), list) @pytest.mark.redismod -async def test_objkeys(modclient: redis.Redis): +async def test_objkeys(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": "qaz"} - await modclient.json().set("obj", Path.root_path(), obj) - keys = await modclient.json().objkeys("obj", Path.root_path()) + await decoded_r.json().set("obj", Path.root_path(), obj) + keys = await decoded_r.json().objkeys("obj", Path.root_path()) keys.sort() exp = list(obj.keys()) exp.sort() assert exp == keys - await modclient.json().set("obj", Path.root_path(), obj) - keys = await modclient.json().objkeys("obj") + await decoded_r.json().set("obj", Path.root_path(), obj) + keys = await decoded_r.json().objkeys("obj") assert keys == list(obj.keys()) - assert await modclient.json().objkeys("fakekey") is None + assert await decoded_r.json().objkeys("fakekey") is None @pytest.mark.redismod -async def test_objlen(modclient: redis.Redis): +async def test_objlen(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": "qaz"} - await modclient.json().set("obj", Path.root_path(), obj) - assert len(obj) == await modclient.json().objlen("obj", Path.root_path()) + await decoded_r.json().set("obj", Path.root_path(), obj) + assert len(obj) == await decoded_r.json().objlen("obj", Path.root_path()) - await modclient.json().set("obj", Path.root_path(), obj) - assert len(obj) == await modclient.json().objlen("obj") + await decoded_r.json().set("obj", Path.root_path(), obj) + assert len(obj) == await decoded_r.json().objlen("obj") # @pytest.mark.redismod -# async def test_json_commands_in_pipeline(modclient: redis.Redis): -# async with modclient.json().pipeline() as p: +# async def test_json_commands_in_pipeline(decoded_r: redis.Redis): +# async with decoded_r.json().pipeline() as p: # p.set("foo", Path.root_path(), "bar") # p.get("foo") # p.delete("foo") # assert [True, "bar", 1] == await p.execute() -# assert await modclient.keys() == [] -# assert await modclient.get("foo") is None +# assert await decoded_r.keys() == [] +# assert await decoded_r.get("foo") is None # # now with a true, json object -# await modclient.flushdb() -# p = await modclient.json().pipeline() +# await decoded_r.flushdb() +# p = await decoded_r.json().pipeline() # d = {"hello": "world", "oh": "snap"} # with pytest.deprecated_call(): # p.jsonset("foo", Path.root_path(), d) @@ -320,23 +340,24 @@ async def test_objlen(modclient: redis.Redis): # p.exists("notarealkey") # p.delete("foo") # assert [True, d, 0, 1] == p.execute() -# assert await modclient.keys() == [] -# assert await modclient.get("foo") is None +# assert await decoded_r.keys() == [] +# assert await decoded_r.get("foo") is None @pytest.mark.redismod -async def test_json_delete_with_dollar(modclient: redis.Redis): +async def test_json_delete_with_dollar(decoded_r: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} - assert await modclient.json().set("doc1", "$", doc1) - assert await modclient.json().delete("doc1", "$..a") == 2 - r = await modclient.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + assert await decoded_r.json().set("doc1", "$", doc1) + assert await decoded_r.json().delete("doc1", "$..a") == 2 + res = [{"nested": {"b": 3}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} - assert await modclient.json().set("doc2", "$", doc2) - assert await modclient.json().delete("doc2", "$..a") == 1 - res = await modclient.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert await decoded_r.json().set("doc2", "$", doc2) + assert await decoded_r.json().delete("doc2", "$..a") == 1 + res = await decoded_r.json().get("doc2", "$") + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -350,8 +371,8 @@ async def test_json_delete_with_dollar(modclient: redis.Redis): ], } ] - assert await modclient.json().set("doc3", "$", doc3) - assert await modclient.json().delete("doc3", '$.[0]["nested"]..ciao') == 3 + assert await decoded_r.json().set("doc3", "$", doc3) + assert await decoded_r.json().delete("doc3", '$.[0]["nested"]..ciao') == 3 doc3val = [ [ @@ -367,29 +388,29 @@ async def test_json_delete_with_dollar(modclient: redis.Redis): } ] ] - res = await modclient.json().get("doc3", "$") - assert res == doc3val + res = await decoded_r.json().get("doc3", "$") + assert_resp_response(decoded_r, res, doc3val, [doc3val]) # Test async default path - assert await modclient.json().delete("doc3") == 1 - assert await modclient.json().get("doc3", "$") is None + assert await decoded_r.json().delete("doc3") == 1 + assert await decoded_r.json().get("doc3", "$") is None - await modclient.json().delete("not_a_document", "..a") + await decoded_r.json().delete("not_a_document", "..a") @pytest.mark.redismod -async def test_json_forget_with_dollar(modclient: redis.Redis): +async def test_json_forget_with_dollar(decoded_r: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} - assert await modclient.json().set("doc1", "$", doc1) - assert await modclient.json().forget("doc1", "$..a") == 2 - r = await modclient.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + assert await decoded_r.json().set("doc1", "$", doc1) + assert await decoded_r.json().forget("doc1", "$..a") == 2 + res = [{"nested": {"b": 3}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} - assert await modclient.json().set("doc2", "$", doc2) - assert await modclient.json().forget("doc2", "$..a") == 1 - res = await modclient.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert await decoded_r.json().set("doc2", "$", doc2) + assert await decoded_r.json().forget("doc2", "$..a") == 1 + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -403,8 +424,8 @@ async def test_json_forget_with_dollar(modclient: redis.Redis): ], } ] - assert await modclient.json().set("doc3", "$", doc3) - assert await modclient.json().forget("doc3", '$.[0]["nested"]..ciao') == 3 + assert await decoded_r.json().set("doc3", "$", doc3) + assert await decoded_r.json().forget("doc3", '$.[0]["nested"]..ciao') == 3 doc3val = [ [ @@ -420,161 +441,165 @@ async def test_json_forget_with_dollar(modclient: redis.Redis): } ] ] - res = await modclient.json().get("doc3", "$") - assert res == doc3val + res = await decoded_r.json().get("doc3", "$") + assert_resp_response(decoded_r, res, doc3val, [doc3val]) # Test async default path - assert await modclient.json().forget("doc3") == 1 - assert await modclient.json().get("doc3", "$") is None + assert await decoded_r.json().forget("doc3") == 1 + assert await decoded_r.json().get("doc3", "$") is None - await modclient.json().forget("not_a_document", "..a") + await decoded_r.json().forget("not_a_document", "..a") @pytest.mark.redismod -async def test_json_mget_dollar(modclient: redis.Redis): +async def test_json_mget_dollar(decoded_r: redis.Redis): # Test mget with multi paths - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": 1, "b": 2, "nested": {"a": 3}, "c": None, "nested2": {"a": None}}, ) - await modclient.json().set( + await decoded_r.json().set( "doc2", "$", {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}}, ) # Compare also to single JSON.GET - assert await modclient.json().get("doc1", "$..a") == [1, 3, None] - assert await modclient.json().get("doc2", "$..a") == [4, 6, [None]] + res = [1, 3, None] + assert_resp_response( + decoded_r, await decoded_r.json().get("doc1", "$..a"), res, [res] + ) + res = [4, 6, [None]] + assert_resp_response( + decoded_r, await decoded_r.json().get("doc2", "$..a"), res, [res] + ) # Test mget with single path - await modclient.json().mget("doc1", "$..a") == [1, 3, None] + await decoded_r.json().mget("doc1", "$..a") == [1, 3, None] # Test mget with multi path - res = await modclient.json().mget(["doc1", "doc2"], "$..a") + res = await decoded_r.json().mget(["doc1", "doc2"], "$..a") assert res == [[1, 3, None], [4, 6, [None]]] # Test missing key - res = await modclient.json().mget(["doc1", "missing_doc"], "$..a") + res = await decoded_r.json().mget(["doc1", "missing_doc"], "$..a") assert res == [[1, 3, None], None] - res = await modclient.json().mget(["missing_doc1", "missing_doc2"], "$..a") + res = await decoded_r.json().mget(["missing_doc1", "missing_doc2"], "$..a") assert res == [None, None] @pytest.mark.redismod -async def test_numby_commands_dollar(modclient: redis.Redis): +async def test_numby_commands_dollar(decoded_r: redis.Redis): # Test NUMINCRBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) # Test multi - assert await modclient.json().numincrby("doc1", "$..a", 2) == [None, 4, 7.0, None] + assert await decoded_r.json().numincrby("doc1", "$..a", 2) == [None, 4, 7.0, None] - res = await modclient.json().numincrby("doc1", "$..a", 2.5) + res = await decoded_r.json().numincrby("doc1", "$..a", 2.5) assert res == [None, 6.5, 9.5, None] # Test single - assert await modclient.json().numincrby("doc1", "$.b[1].a", 2) == [11.5] + assert await decoded_r.json().numincrby("doc1", "$.b[1].a", 2) == [11.5] - assert await modclient.json().numincrby("doc1", "$.b[2].a", 2) == [None] - assert await modclient.json().numincrby("doc1", "$.b[1].a", 3.5) == [15.0] + assert await decoded_r.json().numincrby("doc1", "$.b[2].a", 2) == [None] + assert await decoded_r.json().numincrby("doc1", "$.b[1].a", 3.5) == [15.0] # Test NUMMULTBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) # test list with pytest.deprecated_call(): - res = await modclient.json().nummultby("doc1", "$..a", 2) + res = await decoded_r.json().nummultby("doc1", "$..a", 2) assert res == [None, 4, 10, None] - res = await modclient.json().nummultby("doc1", "$..a", 2.5) + res = await decoded_r.json().nummultby("doc1", "$..a", 2.5) assert res == [None, 10.0, 25.0, None] # Test single with pytest.deprecated_call(): - assert await modclient.json().nummultby("doc1", "$.b[1].a", 2) == [50.0] - assert await modclient.json().nummultby("doc1", "$.b[2].a", 2) == [None] - assert await modclient.json().nummultby("doc1", "$.b[1].a", 3) == [150.0] + assert await decoded_r.json().nummultby("doc1", "$.b[1].a", 2) == [50.0] + assert await decoded_r.json().nummultby("doc1", "$.b[2].a", 2) == [None] + assert await decoded_r.json().nummultby("doc1", "$.b[1].a", 3) == [150.0] # test missing keys with pytest.raises(exceptions.ResponseError): - await modclient.json().numincrby("non_existing_doc", "$..a", 2) - await modclient.json().nummultby("non_existing_doc", "$..a", 2) + await decoded_r.json().numincrby("non_existing_doc", "$..a", 2) + await decoded_r.json().nummultby("non_existing_doc", "$..a", 2) # Test legacy NUMINCRBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) - await modclient.json().numincrby("doc1", ".b[0].a", 3) == 5 + await decoded_r.json().numincrby("doc1", ".b[0].a", 3) == 5 # Test legacy NUMMULTBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) with pytest.deprecated_call(): - await modclient.json().nummultby("doc1", ".b[0].a", 3) == 6 + await decoded_r.json().nummultby("doc1", ".b[0].a", 3) == 6 @pytest.mark.redismod -async def test_strappend_dollar(modclient: redis.Redis): +async def test_strappend_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) # Test multi - await modclient.json().strappend("doc1", "bar", "$..a") == [6, 8, None] + await decoded_r.json().strappend("doc1", "bar", "$..a") == [6, 8, None] + + res = [{"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) - await modclient.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] # Test single - await modclient.json().strappend("doc1", "baz", "$.nested1.a") == [11] + await decoded_r.json().strappend("doc1", "baz", "$.nested1.a") == [11] - await modclient.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}} - ] + res = [{"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().strappend("non_existing_doc", "$..a", "err") + await decoded_r.json().strappend("non_existing_doc", "$..a", "err") # Test multi - await modclient.json().strappend("doc1", "bar", ".*.a") == 8 - await modclient.json().get("doc1", "$") == [ - {"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + await decoded_r.json().strappend("doc1", "bar", ".*.a") == 8 + res = [{"a": "foobar", "nested1": {"a": "hellobarbazbar"}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing path with pytest.raises(exceptions.ResponseError): - await modclient.json().strappend("doc1", "piu") + await decoded_r.json().strappend("doc1", "piu") @pytest.mark.redismod -async def test_strlen_dollar(modclient: redis.Redis): +async def test_strlen_dollar(decoded_r: redis.Redis): # Test multi - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) - assert await modclient.json().strlen("doc1", "$..a") == [3, 5, None] + assert await decoded_r.json().strlen("doc1", "$..a") == [3, 5, None] - res2 = await modclient.json().strappend("doc1", "bar", "$..a") - res1 = await modclient.json().strlen("doc1", "$..a") + res2 = await decoded_r.json().strappend("doc1", "bar", "$..a") + res1 = await decoded_r.json().strlen("doc1", "$..a") assert res1 == res2 # Test single - await modclient.json().strlen("doc1", "$.nested1.a") == [8] - await modclient.json().strlen("doc1", "$.nested2.a") == [None] + await decoded_r.json().strlen("doc1", "$.nested1.a") == [8] + await decoded_r.json().strlen("doc1", "$.nested2.a") == [None] # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().strlen("non_existing_doc", "$..a") + await decoded_r.json().strlen("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_arrappend_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_arrappend_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -584,31 +609,33 @@ async def test_arrappend_dollar(modclient: redis.Redis): }, ) # Test multi - await modclient.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] - assert await modclient.json().get("doc1", "$") == [ + await decoded_r.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrappend("doc1", "$.nested1.a", "baz") == [6] - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().arrappend("doc1", "$.nested1.a", "baz") == [6] + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") # Test legacy - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -618,33 +645,35 @@ async def test_arrappend_dollar(modclient: redis.Redis): }, ) # Test multi (all paths are updated, but return result of last path) - assert await modclient.json().arrappend("doc1", "..a", "bar", "racuda") == 5 + assert await decoded_r.json().arrappend("doc1", "..a", "bar", "racuda") == 5 - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrappend("doc1", ".nested1.a", "baz") == 6 - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().arrappend("doc1", ".nested1.a", "baz") == 6 + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_arrinsert_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_arrinsert_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -654,35 +683,37 @@ async def test_arrinsert_dollar(modclient: redis.Redis): }, ) # Test multi - res = await modclient.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") + res = await decoded_r.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") assert res == [3, 5, None] - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", "baz", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_arrlen_dollar(modclient: redis.Redis): +async def test_arrlen_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -693,20 +724,20 @@ async def test_arrlen_dollar(modclient: redis.Redis): ) # Test multi - assert await modclient.json().arrlen("doc1", "$..a") == [1, 3, None] - res = await modclient.json().arrappend("doc1", "$..a", "non", "abba", "stanza") + assert await decoded_r.json().arrlen("doc1", "$..a") == [1, 3, None] + res = await decoded_r.json().arrappend("doc1", "$..a", "non", "abba", "stanza") assert res == [4, 6, None] - await modclient.json().clear("doc1", "$.a") - assert await modclient.json().arrlen("doc1", "$..a") == [0, 6, None] + await decoded_r.json().clear("doc1", "$.a") + assert await decoded_r.json().arrlen("doc1", "$..a") == [0, 6, None] # Test single - assert await modclient.json().arrlen("doc1", "$.nested1.a") == [6] + assert await decoded_r.json().arrlen("doc1", "$.nested1.a") == [6] # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -716,19 +747,19 @@ async def test_arrlen_dollar(modclient: redis.Redis): }, ) # Test multi (return result of last path) - assert await modclient.json().arrlen("doc1", "$..a") == [1, 3, None] - assert await modclient.json().arrappend("doc1", "..a", "non", "abba", "stanza") == 6 + assert await decoded_r.json().arrlen("doc1", "$..a") == [1, 3, None] + assert await decoded_r.json().arrappend("doc1", "..a", "non", "abba", "stanza") == 6 # Test single - assert await modclient.json().arrlen("doc1", ".nested1.a") == 6 + assert await decoded_r.json().arrlen("doc1", ".nested1.a") == 6 # Test missing key - assert await modclient.json().arrlen("non_existing_doc", "..a") is None + assert await decoded_r.json().arrlen("non_existing_doc", "..a") is None @pytest.mark.redismod -async def test_arrpop_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_arrpop_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -738,19 +769,18 @@ async def test_arrpop_dollar(modclient: redis.Redis): }, ) - # # # Test multi - assert await modclient.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] + # Test multi + assert await decoded_r.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrpop("non_existing_doc", "..a") + await decoded_r.json().arrpop("non_existing_doc", "..a") # # Test legacy - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -760,20 +790,19 @@ async def test_arrpop_dollar(modclient: redis.Redis): }, ) # Test multi (all paths are updated, but return result of last path) - await modclient.json().arrpop("doc1", "..a", "1") is None - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + await decoded_r.json().arrpop("doc1", "..a", "1") is None + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrpop("non_existing_doc", "..a") + await decoded_r.json().arrpop("non_existing_doc", "..a") @pytest.mark.redismod -async def test_arrtrim_dollar(modclient: redis.Redis): +async def test_arrtrim_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -783,27 +812,24 @@ async def test_arrtrim_dollar(modclient: redis.Redis): }, ) # Test multi - assert await modclient.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] + res = [{"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) - assert await modclient.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": []}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] + res = [{"a": [], "nested1": {"a": []}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrtrim("non_existing_doc", "..a", "0", 1) + await decoded_r.json().arrtrim("non_existing_doc", "..a", "0", 1) # Test legacy - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -814,22 +840,21 @@ async def test_arrtrim_dollar(modclient: redis.Redis): ) # Test multi (all paths are updated, but return result of last path) - assert await modclient.json().arrtrim("doc1", "..a", "1", "-1") == 2 + assert await decoded_r.json().arrtrim("doc1", "..a", "1", "-1") == 2 # Test single - assert await modclient.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrtrim("non_existing_doc", "..a", 1, 1) + await decoded_r.json().arrtrim("non_existing_doc", "..a", 1, 1) @pytest.mark.redismod -async def test_objkeys_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_objkeys_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -840,26 +865,26 @@ async def test_objkeys_dollar(modclient: redis.Redis): ) # Test single - assert await modclient.json().objkeys("doc1", "$.nested1.a") == [["foo", "bar"]] + assert await decoded_r.json().objkeys("doc1", "$.nested1.a") == [["foo", "bar"]] # Test legacy - assert await modclient.json().objkeys("doc1", ".*.a") == ["foo", "bar"] + assert await decoded_r.json().objkeys("doc1", ".*.a") == ["foo", "bar"] # Test single - assert await modclient.json().objkeys("doc1", ".nested2.a") == ["baz"] + assert await decoded_r.json().objkeys("doc1", ".nested2.a") == ["baz"] # Test missing key - assert await modclient.json().objkeys("non_existing_doc", "..a") is None + assert await decoded_r.json().objkeys("non_existing_doc", "..a") is None # Test non existing doc with pytest.raises(exceptions.ResponseError): - assert await modclient.json().objkeys("non_existing_doc", "$..a") == [] + assert await decoded_r.json().objkeys("non_existing_doc", "$..a") == [] - assert await modclient.json().objkeys("doc1", "$..nowhere") == [] + assert await decoded_r.json().objkeys("doc1", "$..nowhere") == [] @pytest.mark.redismod -async def test_objlen_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_objlen_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -869,28 +894,28 @@ async def test_objlen_dollar(modclient: redis.Redis): }, ) # Test multi - assert await modclient.json().objlen("doc1", "$..a") == [None, 2, 1] + assert await decoded_r.json().objlen("doc1", "$..a") == [None, 2, 1] # Test single - assert await modclient.json().objlen("doc1", "$.nested1.a") == [2] + assert await decoded_r.json().objlen("doc1", "$.nested1.a") == [2] # Test missing key, and path with pytest.raises(exceptions.ResponseError): - await modclient.json().objlen("non_existing_doc", "$..a") + await decoded_r.json().objlen("non_existing_doc", "$..a") - assert await modclient.json().objlen("doc1", "$.nowhere") == [] + assert await decoded_r.json().objlen("doc1", "$.nowhere") == [] # Test legacy - assert await modclient.json().objlen("doc1", ".*.a") == 2 + assert await decoded_r.json().objlen("doc1", ".*.a") == 2 # Test single - assert await modclient.json().objlen("doc1", ".nested2.a") == 1 + assert await decoded_r.json().objlen("doc1", ".nested2.a") == 1 # Test missing key - assert await modclient.json().objlen("non_existing_doc", "..a") is None + assert await decoded_r.json().objlen("non_existing_doc", "..a") is None # Test missing path # with pytest.raises(exceptions.ResponseError): - await modclient.json().objlen("doc1", ".nowhere") + await decoded_r.json().objlen("doc1", ".nowhere") @pytest.mark.redismod @@ -914,23 +939,28 @@ def load_types_data(nested_key_name): @pytest.mark.redismod -async def test_type_dollar(modclient: redis.Redis): +async def test_type_dollar(decoded_r: redis.Redis): jdata, jtypes = load_types_data("a") - await modclient.json().set("doc1", "$", jdata) + await decoded_r.json().set("doc1", "$", jdata) # Test multi - assert await modclient.json().type("doc1", "$..a") == jtypes + assert_resp_response( + decoded_r, await decoded_r.json().type("doc1", "$..a"), jtypes, [jtypes] + ) # Test single - assert await modclient.json().type("doc1", "$.nested2.a") == [jtypes[1]] + res = await decoded_r.json().type("doc1", "$.nested2.a") + assert_resp_response(decoded_r, res, [jtypes[1]], [[jtypes[1]]]) # Test missing key - assert await modclient.json().type("non_existing_doc", "..a") is None + assert_resp_response( + decoded_r, await decoded_r.json().type("non_existing_doc", "..a"), None, [None] + ) @pytest.mark.redismod -async def test_clear_dollar(modclient: redis.Redis): +async def test_clear_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -942,14 +972,15 @@ async def test_clear_dollar(modclient: redis.Redis): ) # Test multi - assert await modclient.json().clear("doc1", "$..a") == 3 + assert await decoded_r.json().clear("doc1", "$..a") == 3 - assert await modclient.json().get("doc1", "$") == [ + res = [ {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}} ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -959,8 +990,8 @@ async def test_clear_dollar(modclient: redis.Redis): "nested3": {"a": {"baz": 50}}, }, ) - assert await modclient.json().clear("doc1", "$.nested1.a") == 1 - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().clear("doc1", "$.nested1.a") == 1 + res = [ { "nested1": {"a": {}}, "a": ["foo"], @@ -968,19 +999,22 @@ async def test_clear_dollar(modclient: redis.Redis): "nested3": {"a": {"baz": 50}}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing path (async defaults to root) - assert await modclient.json().clear("doc1") == 1 - assert await modclient.json().get("doc1", "$") == [{}] + assert await decoded_r.json().clear("doc1") == 1 + assert_resp_response( + decoded_r, await decoded_r.json().get("doc1", "$"), [{}], [[{}]] + ) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().clear("non_existing_doc", "$..a") + await decoded_r.json().clear("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_toggle_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_toggle_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -991,8 +1025,8 @@ async def test_toggle_dollar(modclient: redis.Redis): }, ) # Test multi - assert await modclient.json().toggle("doc1", "$..a") == [None, 1, None, 0] - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().toggle("doc1", "$..a") == [None, 1, None, 0] + res = [ { "a": ["foo"], "nested1": {"a": True}, @@ -1000,7 +1034,8 @@ async def test_toggle_dollar(modclient: redis.Redis): "nested3": {"a": False}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().toggle("non_existing_doc", "$..a") + await decoded_r.json().toggle("non_existing_doc", "$..a") diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index d78f74164d..75484a2791 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -2,7 +2,6 @@ import pytest import pytest_asyncio - from redis.asyncio.lock import Lock from redis.exceptions import LockError, LockNotOwnedError diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py index 3551579ec0..73ee3cf811 100644 --- a/tests/test_asyncio/test_monitor.py +++ b/tests/test_asyncio/test_monitor.py @@ -1,5 +1,4 @@ import pytest - from tests.conftest import skip_if_redis_enterprise, skip_ifnot_redis_enterprise from .conftest import wait_for_command diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 3df57eb90f..edd2f6d147 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -1,5 +1,4 @@ import pytest - import redis from tests.conftest import skip_if_server_version_lt @@ -21,7 +20,6 @@ async def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert await pipe.execute() == [ True, @@ -29,7 +27,6 @@ async def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] async def test_pipeline_memoryview(self, r): diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index ba70782e42..858576584f 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -14,11 +14,11 @@ import pytest import pytest_asyncio - import redis.asyncio as redis from redis.exceptions import ConnectionError from redis.typing import EncodableT -from tests.conftest import skip_if_server_version_lt +from redis.utils import HIREDIS_AVAILABLE +from tests.conftest import get_protocol_version, skip_if_server_version_lt from .compat import create_task, mock @@ -422,6 +422,24 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): assert expect in info.exconly() +@pytest.mark.onlynoncluster +class TestPubSubRESP3Handler: + def my_handler(self, message): + self.message = ["my handler", message] + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + async def test_push_handler(self, r): + if get_protocol_version(r) in [2, "2", None]: + return + p = r.pubsub(push_handler_func=self.my_handler) + await p.subscribe("foo") + assert await wait_for_message(p) is None + assert self.message == ["my handler", [b"subscribe", b"foo", 1]] + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == ["my handler", [b"message", b"foo", b"test message"]] + + @pytest.mark.onlynoncluster class TestPubSubAutoDecoding: """These tests only validate that we get unicode values back""" @@ -658,18 +676,15 @@ async def loop(): nonlocal interrupt await pubsub.subscribe("foo") while True: - # print("loop") try: try: await pubsub.connect() await loop_step() - # print("succ") except redis.ConnectionError: await asyncio.sleep(0.1) except asyncio.CancelledError: # we use a cancel to interrupt the "listen" # when we perform a disconnect - # print("cancel", interrupt) if interrupt: interrupt = False else: @@ -902,7 +917,6 @@ async def loop(self): try: if self.state == 4: break - # print("state a ", self.state) got_msg = await self.get_message() assert got_msg if self.state in (1, 2): @@ -920,7 +934,6 @@ async def loop(self): async def loop_step_get_message(self): # get a single message via get_message message = await self.pubsub.get_message(timeout=0.1) - # print(message) if message is not None: await self.messages.put(message) return True @@ -1000,13 +1013,15 @@ async def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert pubsub.connection.is_connected - with patch("redis.asyncio.connection.PythonParser.read_response") as mock1: + with patch("redis._parsers._AsyncRESP2Parser.read_response") as mock1, patch( + "redis._parsers._AsyncHiredisParser.read_response" + ) as mock2, patch("redis._parsers._AsyncRESP3Parser.read_response") as mock3: mock1.side_effect = BaseException("boom") - with patch("redis.asyncio.connection.HiredisParser.read_response") as mock2: - mock2.side_effect = BaseException("boom") + mock2.side_effect = BaseException("boom") + mock3.side_effect = BaseException("boom") - with pytest.raises(BaseException): - await get_msg() + with pytest.raises(BaseException): + await get_msg() # the timeout on the read should not cause disconnect assert pubsub.connection.is_connected diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index 86e6ddfa0d..2912ca786c 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -1,5 +1,4 @@ import pytest - from redis.asyncio import Redis from redis.asyncio.connection import Connection, UnixDomainSocketConnection from redis.asyncio.retry import Retry diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index 3776d12cb7..8375ecd787 100644 --- a/tests/test_asyncio/test_scripting.py +++ b/tests/test_asyncio/test_scripting.py @@ -1,6 +1,5 @@ import pytest import pytest_asyncio - from redis import exceptions from tests.conftest import skip_if_server_version_lt diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 8707cdf61b..149b26d958 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -5,7 +5,6 @@ from io import TextIOWrapper import pytest - import redis.asyncio as redis import redis.commands.search import redis.commands.search.aggregation as aggregations @@ -16,7 +15,12 @@ from redis.commands.search.query import GeoFilter, NumericFilter, Query from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from tests.conftest import skip_if_redis_enterprise, skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_if_redis_enterprise, + skip_ifmodversion_lt, +) WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -32,12 +36,16 @@ async def waitForIndex(env, idx, timeout=None): while True: res = await env.execute_command("FT.INFO", idx) try: - res.index("indexing") + if int(res[res.index("indexing") + 1]) == 0: + break except ValueError: break - - if int(res[res.index("indexing") + 1]) == 0: - break + except AttributeError: + try: + if int(res["indexing"]) == 0: + break + except ValueError: + break time.sleep(delay) if timeout is not None: @@ -46,23 +54,23 @@ async def waitForIndex(env, idx, timeout=None): break -def getClient(modclient: redis.Redis): +def getClient(decoded_r: redis.Redis): """ Gets a client client attached to an index name which is ready to be created """ - return modclient + return decoded_r -async def createIndex(modclient, num_docs=100, definition=None): +async def createIndex(decoded_r, num_docs=100, definition=None): try: - await modclient.create_index( + await decoded_r.create_index( (TextField("play", weight=5.0), TextField("txt"), NumericField("chapter")), definition=definition, ) except redis.ResponseError: - await modclient.dropindex(delete_documents=True) - return createIndex(modclient, num_docs=num_docs, definition=definition) + await decoded_r.dropindex(delete_documents=True) + return createIndex(decoded_r, num_docs=num_docs, definition=definition) chapters = {} bzfp = TextIOWrapper(bz2.BZ2File(WILL_PLAY_TEXT), encoding="utf8") @@ -80,7 +88,7 @@ async def createIndex(modclient, num_docs=100, definition=None): if len(chapters) == num_docs: break - indexer = modclient.batch_indexer(chunk_size=50) + indexer = decoded_r.batch_indexer(chunk_size=50) assert isinstance(indexer, AsyncSearch.BatchIndexer) assert 50 == indexer.chunk_size @@ -90,12 +98,12 @@ async def createIndex(modclient, num_docs=100, definition=None): @pytest.mark.redismod -async def test_client(modclient: redis.Redis): +async def test_client(decoded_r: redis.Redis): num_docs = 500 - await createIndex(modclient.ft(), num_docs=num_docs) - await waitForIndex(modclient, "idx") + await createIndex(decoded_r.ft(), num_docs=num_docs) + await waitForIndex(decoded_r, "idx") # verify info - info = await modclient.ft().info() + info = await decoded_r.ft().info() for k in [ "index_name", "index_options", @@ -115,145 +123,268 @@ async def test_client(modclient: redis.Redis): ]: assert k in info - assert modclient.ft().index_name == info["index_name"] + assert decoded_r.ft().index_name == info["index_name"] assert num_docs == int(info["num_docs"]) - res = await modclient.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" + res = await decoded_r.ft().search("henry iv") + if is_resp2_connection(decoded_r): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc.play == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = await decoded_r.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = (await decoded_r.ft().search(Query("kings").no_content())).total + vtotal = ( + await decoded_r.ft().search(Query("kings").no_content().verbatim()) + ).total + assert total > vtotal + + # test in fields + txt_total = ( + await decoded_r.ft().search(Query("henry").no_content().limit_fields("txt")) + ).total + play_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play") + ) + ).total + both_total = ( + await ( + decoded_r.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + ) + ) + ).total + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await decoded_r.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" assert len(doc.txt) > 0 - # test no content - res = await modclient.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = (await modclient.ft().search(Query("kings").no_content())).total - vtotal = (await modclient.ft().search(Query("kings").no_content().verbatim())).total - assert total > vtotal - - # test in fields - txt_total = ( - await modclient.ft().search(Query("henry").no_content().limit_fields("txt")) - ).total - play_total = ( - await modclient.ft().search(Query("henry").no_content().limit_fields("play")) - ).total - both_total = ( - await ( - modclient.ft().search( + # test in-keys + ids = [x.id for x in (await decoded_r.ft().search(Query("henry"))).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == (await decoded_r.ft().search(Query("henry king"))).total + assert ( + 3 + == ( + await decoded_r.ft().search(Query("henry king").slop(0).in_order()) + ).total + ) + assert ( + 52 + == ( + await decoded_r.ft().search(Query("king henry").slop(0).in_order()) + ).total + ) + assert 53 == (await decoded_r.ft().search(Query("henry king").slop(0))).total + assert 167 == (await decoded_r.ft().search(Query("henry king").slop(100))).total + + # test delete document + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") + + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res.total + await decoded_r.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["extra_attributes"]["play"] == "Henry IV" + assert len(doc["extra_attributes"]["txt"]) > 0 + + # test no content + res = await decoded_r.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == len(res["results"]) + for doc in res["results"]: + assert "extra_attributes" not in doc.keys() + + # test verbatim vs no verbatim + total = (await decoded_r.ft().search(Query("kings").no_content()))[ + "total_results" + ] + vtotal = (await decoded_r.ft().search(Query("kings").no_content().verbatim()))[ + "total_results" + ] + assert total > vtotal + + # test in fields + txt_total = ( + await decoded_r.ft().search(Query("henry").no_content().limit_fields("txt")) + )["total_results"] + play_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play") + ) + )["total_results"] + both_total = ( + await decoded_r.ft().search( Query("henry").no_content().limit_fields("play", "txt") ) + )["total_results"] + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await decoded_r.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 + + # test in-keys + ids = [ + x["id"] for x in (await decoded_r.ft().search(Query("henry")))["results"] + ] + assert 10 == len(ids) + subset = ids[:5] + docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert ( + 193 == (await decoded_r.ft().search(Query("henry king")))["total_results"] + ) + assert ( + 3 + == (await decoded_r.ft().search(Query("henry king").slop(0).in_order()))[ + "total_results" + ] + ) + assert ( + 52 + == (await decoded_r.ft().search(Query("king henry").slop(0).in_order()))[ + "total_results" + ] + ) + assert ( + 53 + == (await decoded_r.ft().search(Query("henry king").slop(0)))[ + "total_results" + ] + ) + assert ( + 167 + == (await decoded_r.ft().search(Query("henry king").slop(100)))[ + "total_results" + ] ) - ).total - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = await modclient.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in (await modclient.ft().search(Query("henry"))).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = await modclient.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == (await modclient.ft().search(Query("henry king"))).total - assert ( - 3 == (await modclient.ft().search(Query("henry king").slop(0).in_order())).total - ) - assert ( - 52 - == (await modclient.ft().search(Query("king henry").slop(0).in_order())).total - ) - assert 53 == (await modclient.ft().search(Query("henry king").slop(0))).total - assert 167 == (await modclient.ft().search(Query("henry king").slop(100))).total - # test delete document - await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await modclient.ft().search(Query("death of a salesman")) - assert 1 == res.total + # test delete document + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] - assert 1 == await modclient.ft().delete_document("doc-5ghs2") - res = await modclient.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == await modclient.ft().delete_document("doc-5ghs2") + assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") - await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await modclient.ft().search(Query("death of a salesman")) - assert 1 == res.total - await modclient.ft().delete_document("doc-5ghs2") + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + await decoded_r.ft().delete_document("doc-5ghs2") @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_scores(modclient: redis.Redis): - await modclient.ft().create_index((TextField("txt"),)) +async def test_scores(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("txt"),)) - await modclient.hset("doc1", mapping={"txt": "foo baz"}) - await modclient.hset("doc2", mapping={"txt": "foo bar"}) + await decoded_r.hset("doc1", mapping={"txt": "foo baz"}) + await decoded_r.hset("doc2", mapping={"txt": "foo bar"}) q = Query("foo ~bar").with_scores() - res = await modclient.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == res["results"][1]["id"] @pytest.mark.redismod -async def test_stopwords(modclient: redis.Redis): +async def test_stopwords(decoded_r: redis.Redis): stopwords = ["foo", "bar", "baz"] - await modclient.ft().create_index((TextField("txt"),), stopwords=stopwords) - await modclient.hset("doc1", mapping={"txt": "foo bar"}) - await modclient.hset("doc2", mapping={"txt": "hello world"}) - await waitForIndex(modclient, "idx") + await decoded_r.ft().create_index((TextField("txt"),), stopwords=stopwords) + await decoded_r.hset("doc1", mapping={"txt": "foo bar"}) + await decoded_r.hset("doc2", mapping={"txt": "hello world"}) + await waitForIndex(decoded_r, "idx") q1 = Query("foo bar").no_content() q2 = Query("foo bar hello world").no_content() - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + if is_resp2_connection(decoded_r): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] @pytest.mark.redismod -async def test_filters(modclient: redis.Redis): +async def test_filters(decoded_r: redis.Redis): await ( - modclient.ft().create_index( + decoded_r.ft().create_index( (TextField("txt"), NumericField("num"), GeoField("loc")) ) ) await ( - modclient.hset( + decoded_r.hset( "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} ) ) await ( - modclient.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) + decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) ) - await waitForIndex(modclient, "idx") + await waitForIndex(decoded_r, "idx") # Test numerical filter q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() q2 = ( @@ -261,64 +392,90 @@ async def test_filters(modclient: redis.Redis): .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) .no_content() ) - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + + if is_resp2_connection(decoded_r): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] # Test geo filter q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + + if is_resp2_connection(decoded_r): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res @pytest.mark.redismod -async def test_sort_by(modclient: redis.Redis): +async def test_sort_by(decoded_r: redis.Redis): await ( - modclient.ft().create_index( + decoded_r.ft().create_index( (TextField("txt"), NumericField("num", sortable=True)) ) ) - await modclient.hset("doc1", mapping={"txt": "foo bar", "num": 1}) - await modclient.hset("doc2", mapping={"txt": "foo baz", "num": 2}) - await modclient.hset("doc3", mapping={"txt": "foo qux", "num": 3}) + await decoded_r.hset("doc1", mapping={"txt": "foo bar", "num": 1}) + await decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2}) + await decoded_r.hset("doc3", mapping={"txt": "foo qux", "num": 3}) # Test sort q1 = Query("foo").sort_by("num", asc=True).no_content() q2 = Query("foo").sort_by("num", asc=False).no_content() - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + + if is_resp2_connection(decoded_r): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert "doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") -async def test_drop_index(modclient: redis.Redis): +async def test_drop_index(decoded_r: redis.Redis): """ Ensure the index gets dropped by data remains by default """ for x in range(20): for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: idx = "HaveIt" - index = getClient(modclient) + index = getClient(decoded_r) await index.hset("index:haveit", mapping={"name": "haveit"}) idef = IndexDefinition(prefix=["index:"]) await index.ft(idx).create_index((TextField("name"),), definition=idef) @@ -329,14 +486,14 @@ async def test_drop_index(modclient: redis.Redis): @pytest.mark.redismod -async def test_example(modclient: redis.Redis): +async def test_example(decoded_r: redis.Redis): # Creating the index definition and schema await ( - modclient.ft().create_index((TextField("title", weight=5.0), TextField("body"))) + decoded_r.ft().create_index((TextField("title", weight=5.0), TextField("body"))) ) # Indexing a document - await modclient.hset( + await decoded_r.hset( "doc1", mapping={ "title": "RediSearch", @@ -347,12 +504,12 @@ async def test_example(modclient: redis.Redis): # Searching with complex parameters: q = Query("search engine").verbatim().no_content().paging(0, 5) - res = await modclient.ft().search(q) + res = await decoded_r.ft().search(q) assert res is not None @pytest.mark.redismod -async def test_auto_complete(modclient: redis.Redis): +async def test_auto_complete(decoded_r: redis.Redis): n = 0 with open(TITLES_CSV) as f: cr = csv.reader(f) @@ -360,10 +517,10 @@ async def test_auto_complete(modclient: redis.Redis): for row in cr: n += 1 term, score = row[0], float(row[1]) - assert n == await modclient.ft().sugadd("ac", Suggestion(term, score=score)) + assert n == await decoded_r.ft().sugadd("ac", Suggestion(term, score=score)) - assert n == await modclient.ft().suglen("ac") - ret = await modclient.ft().sugget("ac", "bad", with_scores=True) + assert n == await decoded_r.ft().suglen("ac") + ret = await decoded_r.ft().sugget("ac", "bad", with_scores=True) assert 2 == len(ret) assert "badger" == ret[0].string assert isinstance(ret[0].score, float) @@ -372,29 +529,29 @@ async def test_auto_complete(modclient: redis.Redis): assert isinstance(ret[1].score, float) assert 1.0 != ret[1].score - ret = await modclient.ft().sugget("ac", "bad", fuzzy=True, num=10) + ret = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) assert 10 == len(ret) assert 1.0 == ret[0].score strs = {x.string for x in ret} for sug in strs: - assert 1 == await modclient.ft().sugdel("ac", sug) + assert 1 == await decoded_r.ft().sugdel("ac", sug) # make sure a second delete returns 0 for sug in strs: - assert 0 == await modclient.ft().sugdel("ac", sug) + assert 0 == await decoded_r.ft().sugdel("ac", sug) # make sure they were actually deleted - ret2 = await modclient.ft().sugget("ac", "bad", fuzzy=True, num=10) + ret2 = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) for sug in ret2: assert sug.string not in strs # Test with payload - await modclient.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) - await modclient.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) - await modclient.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) + await decoded_r.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) + await decoded_r.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) + await decoded_r.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) sugs = await ( - modclient.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) + decoded_r.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) ) assert 3 == len(sugs) for sug in sugs: @@ -403,8 +560,8 @@ async def test_auto_complete(modclient: redis.Redis): @pytest.mark.redismod -async def test_no_index(modclient: redis.Redis): - await modclient.ft().create_index( +async def test_no_index(decoded_r: redis.Redis): + await decoded_r.ft().create_index( ( TextField("field"), TextField("text", no_index=True, sortable=True), @@ -414,37 +571,60 @@ async def test_no_index(modclient: redis.Redis): ) ) - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"field": "aaa", "text": "1", "numeric": "1", "geo": "1,1", "tag": "1"}, ) - await modclient.hset( + await decoded_r.hset( "doc2", mapping={"field": "aab", "text": "2", "numeric": "2", "geo": "2,2", "tag": "2"}, ) - await waitForIndex(modclient, "idx") + await waitForIndex(decoded_r, "idx") + + if is_resp2_connection(decoded_r): + res = await decoded_r.ft().search(Query("@text:aa*")) + assert 0 == res.total - res = await modclient.ft().search(Query("@text:aa*")) - assert 0 == res.total + res = await decoded_r.ft().search(Query("@field:aa*")) + assert 2 == res.total - res = await modclient.ft().search(Query("@field:aa*")) - assert 2 == res.total + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = await decoded_r.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] - res = await modclient.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] + + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] # Ensure exception is raised for non-indexable, non-sortable fields with pytest.raises(Exception): @@ -458,51 +638,68 @@ async def test_no_index(modclient: redis.Redis): @pytest.mark.redismod -async def test_explain(modclient: redis.Redis): +async def test_explain(decoded_r: redis.Redis): await ( - modclient.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) + decoded_r.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) ) - res = await modclient.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") + res = await decoded_r.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") assert res @pytest.mark.redismod -async def test_explaincli(modclient: redis.Redis): +async def test_explaincli(decoded_r: redis.Redis): with pytest.raises(NotImplementedError): - await modclient.ft().explain_cli("foo") + await decoded_r.ft().explain_cli("foo") @pytest.mark.redismod -async def test_summarize(modclient: redis.Redis): - await createIndex(modclient.ft()) - await waitForIndex(modclient, "idx") +async def test_summarize(decoded_r: redis.Redis): + await createIndex(decoded_r.ft()) + await waitForIndex(decoded_r, "idx") q = Query("king henry").paging(0, 1) q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") - doc = sorted((await modclient.ft().search(q)).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + if is_resp2_connection(decoded_r): + doc = sorted((await decoded_r.ft().search(q)).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query("king henry").paging(0, 1).summarize().highlight() - doc = sorted((await modclient.ft().search(q)).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + doc = sorted((await decoded_r.ft().search(q)).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) + else: + doc = sorted((await decoded_r.ft().search(q))["results"])[0] + assert "Henry IV" == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] + ) + + q = Query("king henry").paging(0, 1).summarize().highlight() + + doc = sorted((await decoded_r.ft().search(q))["results"])[0] + assert "Henry ... " == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] + ) @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") -async def test_alias(modclient: redis.Redis): - index1 = getClient(modclient) - index2 = getClient(modclient) +async def test_alias(decoded_r: redis.Redis): + index1 = getClient(decoded_r) + index2 = getClient(decoded_r) def1 = IndexDefinition(prefix=["index1:"]) def2 = IndexDefinition(prefix=["index2:"]) @@ -515,25 +712,46 @@ async def test_alias(modclient: redis.Redis): await index1.hset("index1:lonestar", mapping={"name": "lonestar"}) await index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - res = (await ftindex1.search("*")).docs[0] - assert "index1:lonestar" == res.id + if is_resp2_connection(decoded_r): + res = (await ftindex1.search("*")).docs[0] + assert "index1:lonestar" == res.id - # create alias and check for results - await ftindex1.aliasadd("spaceballs") - alias_client = getClient(modclient).ft("spaceballs") - res = (await alias_client.search("*")).docs[0] - assert "index1:lonestar" == res.id + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = getClient(decoded_r).ft("spaceballs") + res = (await alias_client.search("*")).docs[0] + assert "index1:lonestar" == res.id - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await ftindex2.aliasadd("spaceballs") + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + await ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(decoded_r).ft("spaceballs") + + res = (await alias_client2.search("*")).docs[0] + assert "index2:yogurt" == res.id + else: + res = (await ftindex1.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] - # update alias and ensure new results - await ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(modclient).ft("spaceballs") + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = getClient(await decoded_r).ft("spaceballs") + res = (await alias_client.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] - res = (await alias_client2.search("*")).docs[0] - assert "index2:yogurt" == res.id + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + await ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(await decoded_r).ft("spaceballs") + + res = (await alias_client2.search("*"))["results"][0] + assert "index2:yogurt" == res["id"] await ftindex2.aliasdel("spaceballs") with pytest.raises(Exception): @@ -541,34 +759,51 @@ async def test_alias(modclient: redis.Redis): @pytest.mark.redismod -async def test_alias_basic(modclient: redis.Redis): +@pytest.mark.xfail(strict=False) +async def test_alias_basic(decoded_r: redis.Redis): # Creating a client with one index - client = getClient(modclient) + client = getClient(decoded_r) await client.flushdb() - index1 = getClient(modclient).ft("testAlias") + index1 = getClient(decoded_r).ft("testAlias") await index1.create_index((TextField("txt"),)) await index1.client.hset("doc1", mapping={"txt": "text goes here"}) - index2 = getClient(modclient).ft("testAlias2") + index2 = getClient(decoded_r).ft("testAlias2") await index2.create_index((TextField("txt"),)) await index2.client.hset("doc2", mapping={"txt": "text goes here"}) # add the actual alias and check await index1.aliasadd("myalias") - alias_client = getClient(modclient).ft("myalias") - res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - await index2.aliasupdate("myalias") - alias_client2 = getClient(modclient).ft("myalias") - res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id) - assert "doc1" == res[0].id + alias_client = getClient(decoded_r).ft("myalias") + if is_resp2_connection(decoded_r): + res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + await index2.aliasupdate("myalias") + alias_client2 = getClient(decoded_r).ft("myalias") + res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id) + assert "doc1" == res[0].id + else: + res = sorted((await alias_client.search("*"))["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + await index2.aliasupdate("myalias") + alias_client2 = getClient(client).ft("myalias") + res = sorted( + (await alias_client2.search("*"))["results"], key=lambda x: x["id"] + ) + assert "doc1" == res[0]["id"] # delete the alias and expect an error if we try to query again await index2.aliasdel("myalias") @@ -577,56 +812,79 @@ async def test_alias_basic(modclient: redis.Redis): @pytest.mark.redismod -async def test_tags(modclient: redis.Redis): - await modclient.ft().create_index((TextField("txt"), TagField("tags"))) +async def test_tags(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("txt"), TagField("tags"))) tags = "foo,foo bar,hello;world" tags2 = "soba,ramen" - await modclient.hset("doc1", mapping={"txt": "fooz barz", "tags": tags}) - await modclient.hset("doc2", mapping={"txt": "noodles", "tags": tags2}) - await waitForIndex(modclient, "idx") + await decoded_r.hset("doc1", mapping={"txt": "fooz barz", "tags": tags}) + await decoded_r.hset("doc2", mapping={"txt": "noodles", "tags": tags2}) + await waitForIndex(decoded_r, "idx") q = Query("@tags:{foo}") - res = await modclient.ft().search(q) - assert 1 == res.total + if is_resp2_connection(decoded_r): + res = await decoded_r.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{foo bar}") + res = await decoded_r.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{foo\\ bar}") + res = await decoded_r.ft().search(q) + assert 1 == res.total - q = Query("@tags:{foo bar}") - res = await modclient.ft().search(q) - assert 1 == res.total + q = Query("@tags:{hello\\;world}") + res = await decoded_r.ft().search(q) + assert 1 == res.total - q = Query("@tags:{foo\\ bar}") - res = await modclient.ft().search(q) - assert 1 == res.total + q2 = await decoded_r.ft().tagvals("tags") + assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() + else: + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] - q = Query("@tags:{hello\\;world}") - res = await modclient.ft().search(q) - assert 1 == res.total + q = Query("@tags:{foo bar}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] - q2 = await modclient.ft().tagvals("tags") - assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() + q = Query("@tags:{foo\\ bar}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] + + q = Query("@tags:{hello\\;world}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] + + q2 = await decoded_r.ft().tagvals("tags") + assert set(tags.split(",") + tags2.split(",")) == q2 @pytest.mark.redismod -async def test_textfield_sortable_nostem(modclient: redis.Redis): +async def test_textfield_sortable_nostem(decoded_r: redis.Redis): # Creating the index definition with sortable and no_stem - await modclient.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) + await decoded_r.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) # Now get the index info to confirm its contents - response = await modclient.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] + response = await decoded_r.ft().info() + if is_resp2_connection(decoded_r): + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + else: + assert "SORTABLE" in response["attributes"][0]["flags"] + assert "NOSTEM" in response["attributes"][0]["flags"] @pytest.mark.redismod -async def test_alter_schema_add(modclient: redis.Redis): +async def test_alter_schema_add(decoded_r: redis.Redis): # Creating the index definition and schema - await modclient.ft().create_index(TextField("title")) + await decoded_r.ft().create_index(TextField("title")) # Using alter to add a field - await modclient.ft().alter_schema_add(TextField("body")) + await decoded_r.ft().alter_schema_add(TextField("body")) # Indexing a document - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"title": "MyTitle", "body": "Some content only in the body"} ) @@ -634,167 +892,236 @@ async def test_alter_schema_add(modclient: redis.Redis): q = Query("only in the body") # Ensure we find the result searching on the added body field - res = await modclient.ft().search(q) - assert 1 == res.total + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 1 == res.total + else: + assert 1 == res["total_results"] @pytest.mark.redismod -async def test_spell_check(modclient: redis.Redis): - await modclient.ft().create_index((TextField("f1"), TextField("f2"))) +async def test_spell_check(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) await ( - modclient.hset( + decoded_r.hset( "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} ) ) - await modclient.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) - await waitForIndex(modclient, "idx") - - # test spellcheck - res = await modclient.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = await modclient.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = await modclient.ft().spellcheck("vlis") - assert res == {} - res = await modclient.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - await modclient.ft().dict_add("dict", "lore", "lorem", "lorm") - res = await modclient.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = await modclient.ft().spellcheck("lorm", exclude="dict") - assert res == {} + await decoded_r.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) + await waitForIndex(decoded_r, "idx") + + if is_resp2_connection(decoded_r): + # test spellcheck + res = await decoded_r.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = await decoded_r.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = await decoded_r.ft().spellcheck("vlis") + assert res == {} + res = await decoded_r.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await decoded_r.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = await decoded_r.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test spellcheck + res = await decoded_r.ft().spellcheck("impornant") + assert "important" in res["results"]["impornant"][0].keys() + + res = await decoded_r.ft().spellcheck("contnt") + assert "content" in res["results"]["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = await decoded_r.ft().spellcheck("vlis") + assert res == {"results": {"vlis": []}} + res = await decoded_r.ft().spellcheck("vlis", distance=2) + assert "valid" in res["results"]["vlis"][0].keys() + + # test spellcheck include + await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await decoded_r.ft().spellcheck("lorm", include="dict") + assert len(res["results"]["lorm"]) == 3 + assert "lorem" in res["results"]["lorm"][0].keys() + assert "lore" in res["results"]["lorm"][1].keys() + assert "lorm" in res["results"]["lorm"][2].keys() + assert ( + res["results"]["lorm"][0]["lorem"], + res["results"]["lorm"][1]["lore"], + ) == (0.5, 0) + + # test spellcheck exclude + res = await decoded_r.ft().spellcheck("lorm", exclude="dict") + assert res == {"results": {}} @pytest.mark.redismod -async def test_dict_operations(modclient: redis.Redis): - await modclient.ft().create_index((TextField("f1"), TextField("f2"))) +async def test_dict_operations(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) # Add three items - res = await modclient.ft().dict_add("custom_dict", "item1", "item2", "item3") + res = await decoded_r.ft().dict_add("custom_dict", "item1", "item2", "item3") assert 3 == res # Remove one item - res = await modclient.ft().dict_del("custom_dict", "item2") + res = await decoded_r.ft().dict_del("custom_dict", "item2") assert 1 == res # Dump dict and inspect content - res = await modclient.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res + res = await decoded_r.ft().dict_dump("custom_dict") + assert_resp_response(decoded_r, res, ["item1", "item3"], {"item1", "item3"}) # Remove rest of the items before reload - await modclient.ft().dict_del("custom_dict", *res) + await decoded_r.ft().dict_del("custom_dict", *res) @pytest.mark.redismod -async def test_phonetic_matcher(modclient: redis.Redis): - await modclient.ft().create_index((TextField("name"),)) - await modclient.hset("doc1", mapping={"name": "Jon"}) - await modclient.hset("doc2", mapping={"name": "John"}) - - res = await modclient.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name +async def test_phonetic_matcher(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("name"),)) + await decoded_r.hset("doc1", mapping={"name": "Jon"}) + await decoded_r.hset("doc2", mapping={"name": "John"}) + + res = await decoded_r.ft().search(Query("Jon")) + if is_resp2_connection(decoded_r): + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + else: + assert 1 == res["total_results"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] # Drop and create index with phonetic matcher - await modclient.flushdb() - - await modclient.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - await modclient.hset("doc1", mapping={"name": "Jon"}) - await modclient.hset("doc2", mapping={"name": "John"}) - - res = await modclient.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) + await decoded_r.flushdb() + + await decoded_r.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) + await decoded_r.hset("doc1", mapping={"name": "Jon"}) + await decoded_r.hset("doc2", mapping={"name": "John"}) + + res = await decoded_r.ft().search(Query("Jon")) + if is_resp2_connection(decoded_r): + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted(d.name for d in res.docs) + else: + assert 2 == res["total_results"] + assert ["John", "Jon"] == sorted( + d["extra_attributes"]["name"] for d in res["results"] + ) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_scorer(modclient: redis.Redis): - await modclient.ft().create_index((TextField("description"),)) +async def test_scorer(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("description"),)) - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} ) - await modclient.hset( + await decoded_r.hset( "doc2", mapping={ "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa }, ) - # default scorer is TFIDF - res = await modclient.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = await ( - modclient.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - ) - assert 0.1111111111111111 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score + if is_resp2_connection(decoded_r): + # default scorer is TFIDF + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = await ( + decoded_r.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + ) + assert 0.1111111111111111 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res.docs[0].score + else: + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.1111111111111111 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res["results"][0]["score"] @pytest.mark.redismod -async def test_get(modclient: redis.Redis): - await modclient.ft().create_index((TextField("f1"), TextField("f2"))) +async def test_get(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - assert [None] == await modclient.ft().get("doc1") - assert [None, None] == await modclient.ft().get("doc2", "doc1") + assert [None] == await decoded_r.ft().get("doc1") + assert [None, None] == await decoded_r.ft().get("doc2", "doc1") - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"} ) - await modclient.hset( + await decoded_r.hset( "doc2", mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"} ) assert [ ["f1", "some valid content dd2", "f2", "this is sample text f2"] - ] == await modclient.ft().get("doc2") + ] == await decoded_r.ft().get("doc2") assert [ ["f1", "some valid content dd1", "f2", "this is sample text f1"], ["f1", "some valid content dd2", "f2", "this is sample text f2"], - ] == await modclient.ft().get("doc1", "doc2") + ] == await decoded_r.ft().get("doc1", "doc2") @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.2.0", "search") -async def test_config(modclient: redis.Redis): - assert await modclient.ft().config_set("TIMEOUT", "100") +async def test_config(decoded_r: redis.Redis): + assert await decoded_r.ft().config_set("TIMEOUT", "100") with pytest.raises(redis.ResponseError): - await modclient.ft().config_set("TIMEOUT", "null") - res = await modclient.ft().config_get("*") + await decoded_r.ft().config_set("TIMEOUT", "null") + res = await decoded_r.ft().config_get("*") assert "100" == res["TIMEOUT"] - res = await modclient.ft().config_get("TIMEOUT") + res = await decoded_r.ft().config_get("TIMEOUT") assert "100" == res["TIMEOUT"] @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_aggregations_groupby(modclient: redis.Redis): +async def test_aggregations_groupby(decoded_r: redis.Redis): # Creating the index definition and schema - await modclient.ft().create_index( + await decoded_r.ft().create_index( ( NumericField("random_num"), TextField("title"), @@ -804,7 +1131,7 @@ async def test_aggregations_groupby(modclient: redis.Redis): ) # Indexing a document - await modclient.hset( + await decoded_r.hset( "search", mapping={ "title": "RediSearch", @@ -813,7 +1140,7 @@ async def test_aggregations_groupby(modclient: redis.Redis): "random_num": 10, }, ) - await modclient.hset( + await decoded_r.hset( "ai", mapping={ "title": "RedisAI", @@ -822,7 +1149,7 @@ async def test_aggregations_groupby(modclient: redis.Redis): "random_num": 3, }, ) - await modclient.hset( + await decoded_r.hset( "json", mapping={ "title": "RedisJson", @@ -833,207 +1160,408 @@ async def test_aggregations_groupby(modclient: redis.Redis): ) for dialect in [1, 2]: - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count()) - .dialect(dialect) - ) + if is_resp2_connection(decoded_r): + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinct("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinctish("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.sum("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.min("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.max("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.avg("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "7" # (10+3+8)/3 + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "7" # (10+3+8)/3 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.stddev("random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.quantile("@random_num", 0.5)) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.tolist("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.first_value("@title").alias("first")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.random_sample("@title", 2).alias("random")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount"] == "3" - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliascount_distinctishtitle"] + == "3" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasstddevrandom_num"] + == "3.60555127546" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] + == "8" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"] == {"parent": "redis", "first": "RediSearch"} + + req = ( + aggregations.AggregateRequest("redis") + .group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert "random" in res["extra_attributes"].keys() + assert len(res["extra_attributes"]["random"]) == 2 + assert res["extra_attributes"]["random"][0] in [ + "RediSearch", + "RedisAI", + "RedisJson", + ] @pytest.mark.redismod -async def test_aggregations_sort_by_and_limit(modclient: redis.Redis): - await modclient.ft().create_index((TextField("t1"), TextField("t2"))) +async def test_aggregations_sort_by_and_limit(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("t1"), TextField("t2"))) - await modclient.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) - await modclient.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) + await decoded_r.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) + await decoded_r.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = await modclient.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] + if is_resp2_connection(decoded_r): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = await decoded_r.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = await decoded_r.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = await decoded_r.ft().aggregate(req) + assert len(res.rows) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = await decoded_r.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] + else: + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = (await decoded_r.ft().aggregate(req))["results"] + assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"} + assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"} - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = await modclient.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = (await decoded_r.ft().aggregate(req))["results"] + assert res[0]["extra_attributes"] == {"t1": "a"} + assert res[1]["extra_attributes"] == {"t1": "b"} - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = await modclient.ft().aggregate(req) - assert len(res.rows) == 1 + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = await decoded_r.ft().aggregate(req) + assert len(res["results"]) == 1 - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = await modclient.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["t1", "b"] + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = await decoded_r.ft().aggregate(req) + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"t1": "b"} @pytest.mark.redismod @pytest.mark.experimental -async def test_withsuffixtrie(modclient: redis.Redis): +async def test_withsuffixtrie(decoded_r: redis.Redis): # create index - assert await modclient.ft().create_index((TextField("txt"),)) - await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = await modclient.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert await modclient.ft().dropindex("idx") - - # create withsuffixtrie index (text field) - assert await modclient.ft().create_index((TextField("t", withsuffixtrie=True))) - await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = await modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert await modclient.ft().dropindex("idx") - - # create withsuffixtrie index (tag field) - assert await modclient.ft().create_index((TagField("t", withsuffixtrie=True))) - await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = await modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert await decoded_r.ft().create_index((TextField("txt"),)) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + if is_resp2_connection(decoded_r): + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (text field) + assert await decoded_r.ft().create_index((TextField("t", withsuffixtrie=True))) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert await decoded_r.ft().create_index((TagField("t", withsuffixtrie=True))) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + else: + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (text fiels) + assert await decoded_r.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert await decoded_r.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] @pytest.mark.redismod @skip_if_redis_enterprise() -async def test_search_commands_in_pipeline(modclient: redis.Redis): - p = await modclient.ft().pipeline() +async def test_search_commands_in_pipeline(decoded_r: redis.Redis): + p = await decoded_r.ft().pipeline() p.create_index((TextField("txt"),)) p.hset("doc1", mapping={"txt": "foo bar"}) p.hset("doc2", mapping={"txt": "foo bar"}) q = Query("foo bar").with_payloads() await p.search(q) res = await p.execute() - assert res[:3] == ["OK", True, True] - assert 2 == res[3][0] - assert "doc1" == res[3][1] - assert "doc2" == res[3][4] - assert res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] + if is_resp2_connection(decoded_r): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert ( + res[3]["results"][0]["extra_attributes"] + == res[3]["results"][1]["extra_attributes"] + == {"txt": "foo bar"} + ) @pytest.mark.redismod -async def test_query_timeout(modclient: redis.Redis): +async def test_query_timeout(decoded_r: redis.Redis): q1 = Query("foo").timeout(5000) assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): - await modclient.ft().search(q2) + await decoded_r.ft().search(q2) diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 7866056374..2091f2cb87 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -2,7 +2,6 @@ import pytest import pytest_asyncio - import redis.asyncio.sentinel from redis import exceptions from redis.asyncio.sentinel import ( diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index a6e9f37a63..e784690c77 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -1,7 +1,6 @@ import socket import pytest - from redis.asyncio.retry import Retry from redis.asyncio.sentinel import SentinelManagedConnection from redis.backoff import NoBackoff diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index a7109938f2..48ffdfd889 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -2,230 +2,253 @@ from time import sleep import pytest - import redis.asyncio as redis -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_ifmodversion_lt, +) @pytest.mark.redismod -async def test_create(modclient: redis.Redis): - assert await modclient.ts().create(1) - assert await modclient.ts().create(2, retention_msecs=5) - assert await modclient.ts().create(3, labels={"Redis": "Labs"}) - assert await modclient.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) - info = await modclient.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] +async def test_create(decoded_r: redis.Redis): + assert await decoded_r.ts().create(1) + assert await decoded_r.ts().create(2, retention_msecs=5) + assert await decoded_r.ts().create(3, labels={"Redis": "Labs"}) + assert await decoded_r.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) + info = await decoded_r.ts().info(4) + assert_resp_response( + decoded_r, 20, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Series" == info["labels"]["Time"] # Test for a chunk size of 128 Bytes - assert await modclient.ts().create("time-serie-1", chunk_size=128) - info = await modclient.ts().info("time-serie-1") - assert 128, info.chunk_size + assert await decoded_r.ts().create("time-serie-1", chunk_size=128) + info = await decoded_r.ts().info("time-serie-1") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def test_create_duplicate_policy(modclient: redis.Redis): +async def test_create_duplicate_policy(decoded_r: redis.Redis): # Test for duplicate policy for duplicate_policy in ["block", "last", "first", "min", "max"]: ts_name = f"time-serie-ooo-{duplicate_policy}" - assert await modclient.ts().create(ts_name, duplicate_policy=duplicate_policy) - info = await modclient.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy + assert await decoded_r.ts().create(ts_name, duplicate_policy=duplicate_policy) + info = await decoded_r.ts().info(ts_name) + assert_resp_response( + decoded_r, + duplicate_policy, + info.get("duplicate_policy"), + info.get("duplicatePolicy"), + ) @pytest.mark.redismod -async def test_alter(modclient: redis.Redis): - assert await modclient.ts().create(1) - res = await modclient.ts().info(1) - assert 0 == res.retention_msecs - assert await modclient.ts().alter(1, retention_msecs=10) - res = await modclient.ts().info(1) - assert {} == res.labels - res = await modclient.ts().info(1) - assert 10 == res.retention_msecs - assert await modclient.ts().alter(1, labels={"Time": "Series"}) - res = await modclient.ts().info(1) - assert "Series" == res.labels["Time"] - res = await modclient.ts().info(1) - assert 10 == res.retention_msecs +async def test_alter(decoded_r: redis.Redis): + assert await decoded_r.ts().create(1) + res = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 0, res.get("retention_msecs"), res.get("retentionTime") + ) + assert await decoded_r.ts().alter(1, retention_msecs=10) + res = await decoded_r.ts().info(1) + assert {} == (await decoded_r.ts().info(1))["labels"] + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert await decoded_r.ts().alter(1, labels={"Time": "Series"}) + res = await decoded_r.ts().info(1) + assert "Series" == (await decoded_r.ts().info(1))["labels"]["Time"] + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 10, info.get("retention_msecs"), info.get("retentionTime") + ) @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def test_alter_diplicate_policy(modclient: redis.Redis): - assert await modclient.ts().create(1) - info = await modclient.ts().info(1) - assert info.duplicate_policy is None - assert await modclient.ts().alter(1, duplicate_policy="min") - info = await modclient.ts().info(1) - assert "min" == info.duplicate_policy +async def test_alter_diplicate_policy(decoded_r: redis.Redis): + assert await decoded_r.ts().create(1) + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + assert await decoded_r.ts().alter(1, duplicate_policy="min") + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod -async def test_add(modclient: redis.Redis): - assert 1 == await modclient.ts().add(1, 1, 1) - assert 2 == await modclient.ts().add(2, 2, 3, retention_msecs=10) - assert 3 == await modclient.ts().add(3, 3, 2, labels={"Redis": "Labs"}) - assert 4 == await modclient.ts().add( +async def test_add(decoded_r: redis.Redis): + assert 1 == await decoded_r.ts().add(1, 1, 1) + assert 2 == await decoded_r.ts().add(2, 2, 3, retention_msecs=10) + assert 3 == await decoded_r.ts().add(3, 3, 2, labels={"Redis": "Labs"}) + assert 4 == await decoded_r.ts().add( 4, 4, 2, retention_msecs=10, labels={"Redis": "Labs", "Time": "Series"} ) - res = await modclient.ts().add(5, "*", 1) + res = await decoded_r.ts().add(5, "*", 1) assert abs(time.time() - round(float(res) / 1000)) < 1.0 - info = await modclient.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] + info = await decoded_r.ts().info(4) + assert_resp_response( + decoded_r, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Labs" == info["labels"]["Redis"] # Test for a chunk size of 128 Bytes on TS.ADD - assert await modclient.ts().add("time-serie-1", 1, 10.0, chunk_size=128) - info = await modclient.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert await decoded_r.ts().add("time-serie-1", 1, 10.0, chunk_size=128) + info = await decoded_r.ts().info("time-serie-1") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def test_add_duplicate_policy(modclient: redis.Redis): +async def test_add_duplicate_policy(r: redis.Redis): # Test for duplicate policy BLOCK - assert 1 == await modclient.ts().add("time-serie-add-ooo-block", 1, 5.0) + assert 1 == await r.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception): - await modclient.ts().add( - "time-serie-add-ooo-block", 1, 5.0, duplicate_policy="block" - ) + await r.ts().add("time-serie-add-ooo-block", 1, 5.0, duplicate_policy="block") # Test for duplicate policy LAST - assert 1 == await modclient.ts().add("time-serie-add-ooo-last", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-last", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-last", 1, 10.0, duplicate_policy="last" ) - res = await modclient.ts().get("time-serie-add-ooo-last") + res = await r.ts().get("time-serie-add-ooo-last") assert 10.0 == res[1] # Test for duplicate policy FIRST - assert 1 == await modclient.ts().add("time-serie-add-ooo-first", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-first", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-first", 1, 10.0, duplicate_policy="first" ) - res = await modclient.ts().get("time-serie-add-ooo-first") + res = await r.ts().get("time-serie-add-ooo-first") assert 5.0 == res[1] # Test for duplicate policy MAX - assert 1 == await modclient.ts().add("time-serie-add-ooo-max", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-max", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-max", 1, 10.0, duplicate_policy="max" ) - res = await modclient.ts().get("time-serie-add-ooo-max") + res = await r.ts().get("time-serie-add-ooo-max") assert 10.0 == res[1] # Test for duplicate policy MIN - assert 1 == await modclient.ts().add("time-serie-add-ooo-min", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-min", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-min", 1, 10.0, duplicate_policy="min" ) - res = await modclient.ts().get("time-serie-add-ooo-min") + res = await r.ts().get("time-serie-add-ooo-min") assert 5.0 == res[1] @pytest.mark.redismod -async def test_madd(modclient: redis.Redis): - await modclient.ts().create("a") - assert [1, 2, 3] == await modclient.ts().madd( +async def test_madd(decoded_r: redis.Redis): + await decoded_r.ts().create("a") + assert [1, 2, 3] == await decoded_r.ts().madd( [("a", 1, 5), ("a", 2, 10), ("a", 3, 15)] ) @pytest.mark.redismod -async def test_incrby_decrby(modclient: redis.Redis): +async def test_incrby_decrby(decoded_r: redis.Redis): for _ in range(100): - assert await modclient.ts().incrby(1, 1) + assert await decoded_r.ts().incrby(1, 1) sleep(0.001) - assert 100 == (await modclient.ts().get(1))[1] + assert 100 == (await decoded_r.ts().get(1))[1] for _ in range(100): - assert await modclient.ts().decrby(1, 1) + assert await decoded_r.ts().decrby(1, 1) sleep(0.001) - assert 0 == (await modclient.ts().get(1))[1] + assert 0 == (await decoded_r.ts().get(1))[1] - assert await modclient.ts().incrby(2, 1.5, timestamp=5) - assert (5, 1.5) == await modclient.ts().get(2) - assert await modclient.ts().incrby(2, 2.25, timestamp=7) - assert (7, 3.75) == await modclient.ts().get(2) - assert await modclient.ts().decrby(2, 1.5, timestamp=15) - assert (15, 2.25) == await modclient.ts().get(2) + assert await decoded_r.ts().incrby(2, 1.5, timestamp=5) + assert_resp_response(decoded_r, await decoded_r.ts().get(2), (5, 1.5), [5, 1.5]) + assert await decoded_r.ts().incrby(2, 2.25, timestamp=7) + assert_resp_response(decoded_r, await decoded_r.ts().get(2), (7, 3.75), [7, 3.75]) + assert await decoded_r.ts().decrby(2, 1.5, timestamp=15) + assert_resp_response(decoded_r, await decoded_r.ts().get(2), (15, 2.25), [15, 2.25]) # Test for a chunk size of 128 Bytes on TS.INCRBY - assert await modclient.ts().incrby("time-serie-1", 10, chunk_size=128) - info = await modclient.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert await decoded_r.ts().incrby("time-serie-1", 10, chunk_size=128) + info = await decoded_r.ts().info("time-serie-1") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) # Test for a chunk size of 128 Bytes on TS.DECRBY - assert await modclient.ts().decrby("time-serie-2", 10, chunk_size=128) - info = await modclient.ts().info("time-serie-2") - assert 128 == info.chunk_size + assert await decoded_r.ts().decrby("time-serie-2", 10, chunk_size=128) + info = await decoded_r.ts().info("time-serie-2") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod -async def test_create_and_delete_rule(modclient: redis.Redis): +async def test_create_and_delete_rule(decoded_r: redis.Redis): # test rule creation time = 100 - await modclient.ts().create(1) - await modclient.ts().create(2) - await modclient.ts().createrule(1, 2, "avg", 100) + await decoded_r.ts().create(1) + await decoded_r.ts().create(2) + await decoded_r.ts().createrule(1, 2, "avg", 100) for i in range(50): - await modclient.ts().add(1, time + i * 2, 1) - await modclient.ts().add(1, time + i * 2 + 1, 2) - await modclient.ts().add(1, time * 2, 1.5) - assert round((await modclient.ts().get(2))[1], 5) == 1.5 - info = await modclient.ts().info(1) - assert info.rules[0][1] == 100 + await decoded_r.ts().add(1, time + i * 2, 1) + await decoded_r.ts().add(1, time + i * 2 + 1, 2) + await decoded_r.ts().add(1, time * 2, 1.5) + assert round((await decoded_r.ts().get(2))[1], 5) == 1.5 + info = await decoded_r.ts().info(1) + if is_resp2_connection(decoded_r): + assert info.rules[0][1] == 100 + else: + assert info["rules"]["2"][0] == 100 # test rule deletion - await modclient.ts().deleterule(1, 2) - info = await modclient.ts().info(1) - assert not info.rules + await decoded_r.ts().deleterule(1, 2) + info = await decoded_r.ts().info(1) + assert not info["rules"] @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_del_range(modclient: redis.Redis): +async def test_del_range(decoded_r: redis.Redis): try: - await modclient.ts().delete("test", 0, 100) + await decoded_r.ts().delete("test", 0, 100) except Exception as e: assert e.__str__() != "" for i in range(100): - await modclient.ts().add(1, i, i % 7) - assert 22 == await modclient.ts().delete(1, 0, 21) - assert [] == await modclient.ts().range(1, 0, 21) - assert [(22, 1.0)] == await modclient.ts().range(1, 22, 22) + await decoded_r.ts().add(1, i, i % 7) + assert 22 == await decoded_r.ts().delete(1, 0, 21) + assert [] == await decoded_r.ts().range(1, 0, 21) + assert_resp_response( + decoded_r, await decoded_r.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]] + ) @pytest.mark.redismod -async def test_range(modclient: redis.Redis): +async def test_range(r: redis.Redis): for i in range(100): - await modclient.ts().add(1, i, i % 7) - assert 100 == len(await modclient.ts().range(1, 0, 200)) + await r.ts().add(1, i, i % 7) + assert 100 == len(await r.ts().range(1, 0, 200)) for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - assert 200 == len(await modclient.ts().range(1, 0, 500)) + await r.ts().add(1, i + 200, i % 7) + assert 200 == len(await r.ts().range(1, 0, 500)) # last sample isn't returned assert 20 == len( - await modclient.ts().range( - 1, 0, 500, aggregation_type="avg", bucket_size_msec=10 - ) + await r.ts().range(1, 0, 500, aggregation_type="avg", bucket_size_msec=10) ) - assert 10 == len(await modclient.ts().range(1, 0, 500, count=10)) + assert 10 == len(await r.ts().range(1, 0, 500, count=10)) @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_range_advanced(modclient: redis.Redis): +async def test_range_advanced(decoded_r: redis.Redis): for i in range(100): - await modclient.ts().add(1, i, i % 7) - await modclient.ts().add(1, i + 200, i % 7) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(1, i + 200, i % 7) assert 2 == len( - await modclient.ts().range( + await decoded_r.ts().range( 1, 0, 500, @@ -234,35 +257,38 @@ async def test_range_advanced(modclient: redis.Redis): filter_by_max_value=2, ) ) - assert [(0, 10.0), (10, 1.0)] == await modclient.ts().range( + res = await decoded_r.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" ) - assert [(0, 5.0), (5, 6.0)] == await modclient.ts().range( + assert_resp_response(decoded_r, res, [(0, 10.0), (10, 1.0)], [[0, 10.0], [10, 1.0]]) + res = await decoded_r.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) - assert [(0, 2.55), (10, 3.0)] == await modclient.ts().range( + assert_resp_response(decoded_r, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]]) + res = await decoded_r.ts().range( 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 ) + assert_resp_response(decoded_r, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]]) @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_rev_range(modclient: redis.Redis): +async def test_rev_range(decoded_r: redis.Redis): for i in range(100): - await modclient.ts().add(1, i, i % 7) - assert 100 == len(await modclient.ts().range(1, 0, 200)) + await decoded_r.ts().add(1, i, i % 7) + assert 100 == len(await decoded_r.ts().range(1, 0, 200)) for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - assert 200 == len(await modclient.ts().range(1, 0, 500)) + await decoded_r.ts().add(1, i + 200, i % 7) + assert 200 == len(await decoded_r.ts().range(1, 0, 500)) # first sample isn't returned assert 20 == len( - await modclient.ts().revrange( + await decoded_r.ts().revrange( 1, 0, 500, aggregation_type="avg", bucket_size_msec=10 ) ) - assert 10 == len(await modclient.ts().revrange(1, 0, 500, count=10)) + assert 10 == len(await decoded_r.ts().revrange(1, 0, 500, count=10)) assert 2 == len( - await modclient.ts().revrange( + await decoded_r.ts().revrange( 1, 0, 500, @@ -271,287 +297,471 @@ async def test_rev_range(modclient: redis.Redis): filter_by_max_value=2, ) ) - assert [(10, 1.0), (0, 10.0)] == await modclient.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + assert_resp_response( + decoded_r, + await decoded_r.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ), + [(10, 1.0), (0, 10.0)], + [[10, 1.0], [0, 10.0]], ) - assert [(1, 10.0), (0, 1.0)] == await modclient.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + assert_resp_response( + decoded_r, + await decoded_r.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + ), + [(1, 10.0), (0, 1.0)], + [[1, 10.0], [0, 1.0]], ) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def testMultiRange(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This", "team": "ny"}) - await modclient.ts().create( +async def test_multi_range(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This", "team": "ny"}) + await decoded_r.ts().create( 2, labels={"Test": "This", "Taste": "That", "team": "sf"} ) for i in range(100): - await modclient.ts().add(1, i, i % 7) - await modclient.ts().add(2, i, i % 11) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(2, i, i % 11) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(decoded_r): + assert 100 == len(res[0]["1"][1]) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - res = await modclient.ts().mrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = await decoded_r.ts().mrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert 100 == len(res["1"][2]) + + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) - # test withlabels - assert {} == res[0]["1"][0] - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) + + # test withlabels + assert {} == res["1"][0] + res = await decoded_r.ts().mrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res["1"][0] @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_multi_range_advanced(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This", "team": "ny"}) - await modclient.ts().create( +async def test_multi_range_advanced(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This", "team": "ny"}) + await decoded_r.ts().create( 2, labels={"Test": "This", "Taste": "That", "team": "sf"} ) for i in range(100): - await modclient.ts().add(1, i, i % 7) - await modclient.ts().add(2, i, i % 11) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(2, i, i % 11) # test with selected labels - res = await modclient.ts().mrange( + res = await decoded_r.ts().mrange( 0, 200, filters=["Test=This"], select_labels=["team"] ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(decoded_r): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test with filterby - res = await modclient.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + # test with filterby + res = await decoded_r.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - # test groupby - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - - # test align - res = await modclient.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = await modclient.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + # test groupby + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + + # test align + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test with filterby + res = await decoded_r.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[15, 1.0], [16, 2.0]] == res["1"][2] + + # test groupby + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[0, 0.0], [1, 2.0], [2, 4.0], [3, 6.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=ny"][3] + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=sf"][3] + + # test align + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[0, 10.0], [10, 1.0]] == res["1"][2] + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [[0, 5.0], [5, 6.0]] == res["1"][2] @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_multi_reverse_range(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This", "team": "ny"}) - await modclient.ts().create( +async def test_multi_reverse_range(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This", "team": "ny"}) + await decoded_r.ts().create( 2, labels={"Test": "This", "Taste": "That", "team": "sf"} ) for i in range(100): - await modclient.ts().add(1, i, i % 7) - await modclient.ts().add(2, i, i % 11) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(2, i, i % 11) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(decoded_r): + assert 100 == len(res[0]["1"][1]) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - res = await modclient.ts().mrevrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrevrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] - # test withlabels - res = await modclient.ts().mrevrange( - 0, 200, filters=["Test=This"], with_labels=True - ) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + # test withlabels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - # test with selected labels - res = await modclient.ts().mrevrange( - 0, 200, filters=["Test=This"], select_labels=["team"] - ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + # test with selected labels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test filterby - res = await modclient.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + # test filterby + res = await decoded_r.ts().mrevrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] - # test groupby - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] - - # test align - res = await modclient.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] - res = await modclient.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=1, - ) - assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + # test groupby + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + + # test align + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + else: + assert 100 == len(res["1"][2]) + + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) + + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrevrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) + assert {} == res["1"][0] + + # test withlabels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res["1"][0] + + # test with selected labels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test filterby + res = await decoded_r.ts().mrevrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[16, 2.0], [15, 1.0]] == res["1"][2] + + # test groupby + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[3, 6.0], [2, 4.0], [1, 2.0], [0, 0.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=ny"][3] + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=sf"][3] + + # test align + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[10, 1.0], [0, 10.0]] == res["1"][2] + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [[1, 10.0], [0, 1.0]] == res["1"][2] @pytest.mark.redismod -async def test_get(modclient: redis.Redis): +async def test_get(decoded_r: redis.Redis): name = "test" - await modclient.ts().create(name) - assert await modclient.ts().get(name) is None - await modclient.ts().add(name, 2, 3) - assert 2 == (await modclient.ts().get(name))[0] - await modclient.ts().add(name, 3, 4) - assert 4 == (await modclient.ts().get(name))[1] + await decoded_r.ts().create(name) + assert not await decoded_r.ts().get(name) + await decoded_r.ts().add(name, 2, 3) + assert 2 == (await decoded_r.ts().get(name))[0] + await decoded_r.ts().add(name, 3, 4) + assert 4 == (await decoded_r.ts().get(name))[1] @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_mget(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This"}) - await modclient.ts().create(2, labels={"Test": "This", "Taste": "That"}) - act_res = await modclient.ts().mget(["Test=This"]) +async def test_mget(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This"}) + await decoded_r.ts().create(2, labels={"Test": "This", "Taste": "That"}) + act_res = await decoded_r.ts().mget(["Test=This"]) exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res - await modclient.ts().add(1, "*", 15) - await modclient.ts().add(2, "*", 25) - res = await modclient.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] - res = await modclient.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] + exp_res_resp3 = {"1": [{}, []], "2": [{}, []]} + assert_resp_response(decoded_r, act_res, exp_res, exp_res_resp3) + await decoded_r.ts().add(1, "*", 15) + await decoded_r.ts().add(2, "*", 25) + res = await decoded_r.ts().mget(["Test=This"]) + if is_resp2_connection(decoded_r): + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + else: + assert 15 == res["1"][1][1] + assert 25 == res["2"][1][1] + res = await decoded_r.ts().mget(["Taste=That"]) + if is_resp2_connection(decoded_r): + assert 25 == res[0]["2"][2] + else: + assert 25 == res["2"][1][1] # test with_labels - assert {} == res[0]["2"][0] - res = await modclient.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + if is_resp2_connection(decoded_r): + assert {} == res[0]["2"][0] + else: + assert {} == res["2"][0] + res = await decoded_r.ts().mget(["Taste=That"], with_labels=True) + if is_resp2_connection(decoded_r): + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + else: + assert {"Taste": "That", "Test": "This"} == res["2"][0] @pytest.mark.redismod -async def test_info(modclient: redis.Redis): - await modclient.ts().create( +async def test_info(decoded_r: redis.Redis): + await decoded_r.ts().create( 1, retention_msecs=5, labels={"currentLabel": "currentData"} ) - info = await modclient.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 5, info.get("retention_msecs"), info.get("retentionTime") + ) + assert info["labels"]["currentLabel"] == "currentData" @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def testInfoDuplicatePolicy(modclient: redis.Redis): - await modclient.ts().create( +async def testInfoDuplicatePolicy(decoded_r: redis.Redis): + await decoded_r.ts().create( 1, retention_msecs=5, labels={"currentLabel": "currentData"} ) - info = await modclient.ts().info(1) - assert info.duplicate_policy is None + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) - await modclient.ts().create("time-serie-2", duplicate_policy="min") - info = await modclient.ts().info("time-serie-2") - assert "min" == info.duplicate_policy + await decoded_r.ts().create("time-serie-2", duplicate_policy="min") + info = await decoded_r.ts().info("time-serie-2") + assert_resp_response( + decoded_r, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_query_index(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This"}) - await modclient.ts().create(2, labels={"Test": "This", "Taste": "That"}) - assert 2 == len(await modclient.ts().queryindex(["Test=This"])) - assert 1 == len(await modclient.ts().queryindex(["Taste=That"])) - assert [2] == await modclient.ts().queryindex(["Taste=That"]) +async def test_query_index(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This"}) + await decoded_r.ts().create(2, labels={"Test": "This", "Taste": "That"}) + assert 2 == len(await decoded_r.ts().queryindex(["Test=This"])) + assert 1 == len(await decoded_r.ts().queryindex(["Taste=That"])) + assert_resp_response( + decoded_r, await decoded_r.ts().queryindex(["Taste=That"]), [2], {"2"} + ) # @pytest.mark.redismod -# async def test_pipeline(modclient: redis.Redis): -# pipeline = await modclient.ts().pipeline() +# async def test_pipeline(r: redis.Redis): +# pipeline = await r.ts().pipeline() # pipeline.create("with_pipeline") # for i in range(100): # pipeline.add("with_pipeline", i, 1.1 * i) # pipeline.execute() -# info = await modclient.ts().info("with_pipeline") +# info = await r.ts().info("with_pipeline") # assert info.lastTimeStamp == 99 # assert info.total_samples == 100 -# assert await modclient.ts().get("with_pipeline")[1] == 99 * 1.1 +# assert await r.ts().get("with_pipeline")[1] == 99 * 1.1 @pytest.mark.redismod -async def test_uncompressed(modclient: redis.Redis): - await modclient.ts().create("compressed") - await modclient.ts().create("uncompressed", uncompressed=True) - compressed_info = await modclient.ts().info("compressed") - uncompressed_info = await modclient.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage +async def test_uncompressed(decoded_r: redis.Redis): + await decoded_r.ts().create("compressed") + await decoded_r.ts().create("uncompressed", uncompressed=True) + compressed_info = await decoded_r.ts().info("compressed") + uncompressed_info = await decoded_r.ts().info("uncompressed") + if is_resp2_connection(decoded_r): + assert compressed_info.memory_usage != uncompressed_info.memory_usage + else: + assert compressed_info["memoryUsage"] != uncompressed_info["memoryUsage"] diff --git a/tests/test_bloom.py b/tests/test_bloom.py index 30d3219404..a82fece470 100644 --- a/tests/test_bloom.py +++ b/tests/test_bloom.py @@ -1,12 +1,11 @@ from math import inf import pytest - import redis.commands.bf from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -from .conftest import skip_ifmodversion_lt +from .conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt def intlist(obj): @@ -14,15 +13,15 @@ def intlist(obj): @pytest.fixture -def client(modclient): - assert isinstance(modclient.bf(), redis.commands.bf.BFBloom) - assert isinstance(modclient.cf(), redis.commands.bf.CFBloom) - assert isinstance(modclient.cms(), redis.commands.bf.CMSBloom) - assert isinstance(modclient.tdigest(), redis.commands.bf.TDigestBloom) - assert isinstance(modclient.topk(), redis.commands.bf.TOPKBloom) +def client(decoded_r): + assert isinstance(decoded_r.bf(), redis.commands.bf.BFBloom) + assert isinstance(decoded_r.cf(), redis.commands.bf.CFBloom) + assert isinstance(decoded_r.cms(), redis.commands.bf.CMSBloom) + assert isinstance(decoded_r.tdigest(), redis.commands.bf.TDigestBloom) + assert isinstance(decoded_r.topk(), redis.commands.bf.TOPKBloom) - modclient.flushdb() - return modclient + decoded_r.flushdb() + return decoded_r @pytest.mark.redismod @@ -61,7 +60,6 @@ def test_tdigest_create(client): assert client.tdigest().create("tDigest", 100) -# region Test Bloom Filter @pytest.mark.redismod def test_bf_add(client): assert client.bf().create("bloom", 0.01, 1000) @@ -86,9 +84,24 @@ def test_bf_insert(client): assert 0 == client.bf().exists("bloom", "noexist") assert [1, 0] == intlist(client.bf().mexists("bloom", "foo", "noexist")) info = client.bf().info("bloom") - assert 2 == info.insertedNum - assert 1000 == info.capacity - assert 1 == info.filterNum + assert_resp_response( + client, + 2, + info.get("insertedNum"), + info.get("Number of items inserted"), + ) + assert_resp_response( + client, + 1000, + info.get("capacity"), + info.get("Capacity"), + ) + assert_resp_response( + client, + 1, + info.get("filterNum"), + info.get("Number of filters"), + ) @pytest.mark.redismod @@ -149,11 +162,21 @@ def test_bf_info(client): # Store a filter client.bf().create("nonscaling", "0.0001", "1000", noScale=True) info = client.bf().info("nonscaling") - assert info.expansionRate is None + assert_resp_response( + client, + None, + info.get("expansionRate"), + info.get("Expansion rate"), + ) client.bf().create("expanding", "0.0001", "1000", expansion=expansion) info = client.bf().info("expanding") - assert info.expansionRate == 4 + assert_resp_response( + client, + 4, + info.get("expansionRate"), + info.get("Expansion rate"), + ) try: # noScale mean no expansion @@ -180,7 +203,6 @@ def test_bf_card(client): client.bf().card("setKey") -# region Test Cuckoo Filter @pytest.mark.redismod def test_cf_add_and_insert(client): assert client.cf().create("cuckoo", 1000) @@ -196,9 +218,15 @@ def test_cf_add_and_insert(client): assert [1] == client.cf().insert("empty1", ["foo"], capacity=1000) assert [1] == client.cf().insertnx("empty2", ["bar"], capacity=1000) info = client.cf().info("captest") - assert 5 == info.insertedNum - assert 0 == info.deletedNum - assert 1 == info.filterNum + assert_resp_response( + client, 5, info.get("insertedNum"), info.get("Number of items inserted") + ) + assert_resp_response( + client, 0, info.get("deletedNum"), info.get("Number of items deleted") + ) + assert_resp_response( + client, 1, info.get("filterNum"), info.get("Number of filters") + ) @pytest.mark.redismod @@ -214,7 +242,6 @@ def test_cf_exists_and_del(client): assert 0 == client.cf().count("cuckoo", "filter") -# region Test Count-Min Sketch @pytest.mark.redismod def test_cms(client): assert client.cms().initbydim("dim", 1000, 5) @@ -225,9 +252,10 @@ def test_cms(client): assert [10, 15] == client.cms().incrby("dim", ["foo", "bar"], [5, 15]) assert [10, 15] == client.cms().query("dim", "foo", "bar") info = client.cms().info("dim") - assert 1000 == info.width - assert 5 == info.depth - assert 25 == info.count + assert info["width"] + assert 1000 == info["width"] + assert 5 == info["depth"] + assert 25 == info["count"] @pytest.mark.redismod @@ -248,10 +276,6 @@ def test_cms_merge(client): assert [16, 15, 21] == client.cms().query("C", "foo", "bar", "baz") -# endregion - - -# region Test Top-K @pytest.mark.redismod def test_topk(client): # test list with empty buckets @@ -326,10 +350,10 @@ def test_topk(client): assert ["A", "B", "E"] == client.topk().list("topklist") assert ["A", 4, "B", 3, "E", 3] == client.topk().list("topklist", withcount=True) info = client.topk().info("topklist") - assert 3 == info.k - assert 50 == info.width - assert 3 == info.depth - assert 0.9 == round(float(info.decay), 1) + assert 3 == info["k"] + assert 50 == info["width"] + assert 3 == info["depth"] + assert 0.9 == round(float(info["decay"]), 1) @pytest.mark.redismod @@ -346,7 +370,6 @@ def test_topk_incrby(client): ) -# region Test T-Digest @pytest.mark.redismod @pytest.mark.experimental def test_tdigest_reset(client): @@ -357,8 +380,11 @@ def test_tdigest_reset(client): assert client.tdigest().add("tDigest", list(range(10))) assert client.tdigest().reset("tDigest") - # assert we have 0 unmerged nodes - assert 0 == client.tdigest().info("tDigest").unmerged_nodes + # assert we have 0 unmerged + info = client.tdigest().info("tDigest") + assert_resp_response( + client, 0, info.get("unmerged_nodes"), info.get("Unmerged nodes") + ) @pytest.mark.redismod @@ -373,8 +399,10 @@ def test_tdigest_merge(client): assert client.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 110 weight on to-histogram info = client.tdigest().info("to-tDigest") - total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) - assert 20 == total_weight_to + if is_resp2_connection(client): + assert 20 == float(info["merged_weight"]) + float(info["unmerged_weight"]) + else: + assert 20 == float(info["Merged weight"]) + float(info["Unmerged weight"]) # test override assert client.tdigest().create("from-override", 10) assert client.tdigest().create("from-override-2", 10) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 705e753bd6..31c31026be 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -10,8 +10,8 @@ from unittest.mock import DEFAULT, Mock, call, patch import pytest - from redis import Redis +from redis._parsers import CommandsParser from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import ( PRIMARY, @@ -22,7 +22,6 @@ RedisCluster, get_node_name, ) -from redis.commands import CommandsParser from redis.connection import BlockingConnectionPool, Connection, ConnectionPool from redis.crc import key_slot from redis.exceptions import ( @@ -43,6 +42,7 @@ from .conftest import ( _get_client, + assert_resp_response, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -1000,7 +1000,7 @@ def test_client_setname(self, r): node = r.get_random_node() r.client_setname("redis_py_test", target_nodes=node) client_name = r.client_getname(target_nodes=node) - assert client_name == "redis_py_test" + assert_resp_response(r, client_name, "redis_py_test", b"redis_py_test") def test_exists(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} @@ -1595,7 +1595,7 @@ def test_client_trackinginfo(self, r): node = r.get_primaries()[0] res = r.client_trackinginfo(target_nodes=node) assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @skip_if_server_version_lt("2.9.50") def test_client_pause(self, r): @@ -1757,24 +1757,68 @@ def test_cluster_renamenx(self, r): def test_cluster_blpop(self, r): r.rpush("{foo}a", "1", "2") r.rpush("{foo}b", "3", "4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) assert r.blpop(["{foo}b", "{foo}a"], timeout=1) is None r.rpush("{foo}c", "1") - assert r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, r.blpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) def test_cluster_brpop(self, r): r.rpush("{foo}a", "1", "2") r.rpush("{foo}b", "3", "4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) assert r.brpop(["{foo}b", "{foo}a"], timeout=1) is None r.rpush("{foo}c", "1") - assert r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, r.brpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) def test_cluster_brpoplpush(self, r): r.rpush("{foo}a", "1", "2") @@ -1847,7 +1891,13 @@ def test_cluster_zdiff(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + response = r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response( + r, + response, + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @skip_if_server_version_lt("6.2.0") def test_cluster_zdiffstore(self, r): @@ -1855,7 +1905,8 @@ def test_cluster_zdiffstore(self, r): r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert r.zrange("{foo}out", 0, -1) == [b"a3"] - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + response = r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") def test_cluster_zinter(self, r): @@ -1866,31 +1917,42 @@ def test_cluster_zinter(self, r): # invalid aggregation with pytest.raises(DataError): r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] - # with weights - assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MAX"), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MIN"), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) + assert_resp_response( + r, + r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a3", 20.0), (b"a1", 23.0)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zinterstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) def test_cluster_zinterstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1900,7 +1962,12 @@ def test_cluster_zinterstore_max(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zinterstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1910,38 +1977,98 @@ def test_cluster_zinterstore_min(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) def test_cluster_zinterstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmax(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2}) r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], + ) assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + r.bzpopmax("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmin(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2}) r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"a", b"a1", 1], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"a", b"a2", 2], + ) assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + r.bzpopmin("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("6.2.0") def test_cluster_zrangestore(self, r): @@ -1950,7 +2077,12 @@ def test_cluster_zrangestore(self, r): assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("{foo}b", "{foo}a", 1, 2) assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}b", 0, 1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1972,39 +2104,45 @@ def test_cluster_zunion(self, r): r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zunionstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) def test_cluster_zunionstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -2014,12 +2152,12 @@ def test_cluster_zunionstore_max(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zunionstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2029,24 +2167,24 @@ def test_cluster_zunionstore_min(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) def test_cluster_zunionstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") def test_cluster_pfcount(self, r): @@ -3068,7 +3206,12 @@ def test_pipeline_readonly(self, r): with r.pipeline() as readonly_pipe: readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) - assert readonly_pipe.execute() == [b"a1", [(b"z1", 1.0), (b"z2", 4)]] + assert_resp_response( + r, + readonly_pipe.execute(), + [b"a1", [(b"z1", 1.0), (b"z2", 4)]], + [b"a1", [[b"z1", 1.0], [b"z2", 4.0]]], + ) def test_moved_redirection_on_slave_with_default(self, r): """ diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index 6c3ede9cdf..e3b44a147f 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,8 +1,11 @@ import pytest +from redis._parsers import CommandsParser -from redis.commands import CommandsParser - -from .conftest import skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import ( + assert_resp_response, + skip_if_redis_enterprise, + skip_if_server_version_lt, +) class TestCommandsParser: @@ -51,13 +54,40 @@ def test_get_moveable_keys(self, r): ] args7 = ["MIGRATE", "192.168.1.34", 6379, "key1", 0, 5000] - assert sorted(commands_parser.get_keys(r, *args1)) == ["key1", "key2"] - assert sorted(commands_parser.get_keys(r, *args2)) == ["mystream", "writers"] - assert sorted(commands_parser.get_keys(r, *args3)) == ["out", "zset1", "zset2"] - assert sorted(commands_parser.get_keys(r, *args4)) == ["Sicily", "out"] - assert sorted(commands_parser.get_keys(r, *args5)) == ["foo"] - assert sorted(commands_parser.get_keys(r, *args6)) == ["key1", "key2", "key3"] - assert sorted(commands_parser.get_keys(r, *args7)) == ["key1"] + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args1)), + ["key1", "key2"], + [b"key1", b"key2"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args2)), + ["mystream", "writers"], + [b"mystream", b"writers"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args3)), + ["out", "zset1", "zset2"], + [b"out", b"zset1", b"zset2"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args4)), + ["Sicily", "out"], + [b"Sicily", b"out"], + ) + assert sorted(commands_parser.get_keys(r, *args5)) in [["foo"], [b"foo"]] + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args6)), + ["key1", "key2", "key3"], + [b"key1", b"key2", b"key3"], + ) + assert_resp_response( + r, sorted(commands_parser.get_keys(r, *args7)), ["key1"], [b"key1"] + ) # A bug in redis<7.0 causes this to fail: https://github.com/redis/redis/issues/9493 @skip_if_server_version_lt("7.0.0") diff --git a/tests/test_commands.py b/tests/test_commands.py index 2213e81f72..fdf41dc5fa 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -9,13 +9,21 @@ from unittest.mock import patch import pytest - import redis from redis import exceptions -from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + parse_info, +) +from redis.client import EMPTY_RESPONSE, NEVER_DECODE from .conftest import ( _get_client, + assert_resp_response, + assert_resp_response_in, + is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_gte, skip_if_server_version_lt, @@ -58,17 +66,23 @@ class TestResponseCallbacks: "Tests for the response callback system" def test_response_callbacks(self, r): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS - assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) + callbacks = _RedisCallbacks + if is_resp2_connection(r): + callbacks.update(_RedisCallbacksRESP2) + else: + callbacks.update(_RedisCallbacksRESP3) + assert r.response_callbacks == callbacks + assert id(r.response_callbacks) != id(_RedisCallbacks) r.set_response_callback("GET", lambda x: "static") r["a"] = "foo" assert r["a"] == "static" def test_case_insensitive_command_names(self, r): - assert r.response_callbacks["del"] == r.response_callbacks["DEL"] + assert r.response_callbacks["ping"] == r.response_callbacks["PING"] class TestRedisCommands: + @pytest.mark.onlynoncluster @skip_if_redis_enterprise() def test_auth(self, r, request): # sending an AUTH command before setting a user/password on the @@ -103,7 +117,6 @@ def teardown(): # connection field is not set in Redis Cluster, but that's ok # because the problem discussed above does not apply to Redis Cluster pass - r.auth(temp_pass) r.config_set("requirepass", "") r.acl_deluser(username) @@ -129,13 +142,13 @@ def test_command_on_invalid_key_type(self, r): def test_acl_cat_no_category(self, r): categories = r.acl_cat() assert isinstance(categories, list) - assert "read" in categories + assert "read" in categories or b"read" in categories @skip_if_server_version_lt("6.0.0") def test_acl_cat_with_category(self, r): commands = r.acl_cat("read") assert isinstance(commands, list) - assert "get" in commands + assert "get" in commands or b"get" in commands @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() @@ -149,9 +162,8 @@ def teardown(): r.acl_setuser(username, keys=["*"], commands=["+set"]) assert r.acl_dryrun(username, "set", "key", "value") == b"OK" - assert r.acl_dryrun(username, "get", "key").startswith( - b"This user has no permissions to run the" - ) + no_permissions_message = b"user has no permissions to run the" + assert no_permissions_message in r.acl_dryrun(username, "get", "key") @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -182,7 +194,7 @@ def teardown(): @skip_if_redis_enterprise() def test_acl_genpass(self, r): password = r.acl_genpass() - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) with pytest.raises(exceptions.DataError): r.acl_genpass("value") @@ -190,11 +202,12 @@ def test_acl_genpass(self, r): r.acl_genpass(5555) r.acl_genpass(555) - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_acl_getuser_setuser(self, r, request): + r.flushall() username = "redis-py-user" def teardown(): @@ -229,19 +242,19 @@ def teardown(): enabled=True, reset=True, passwords=["+pass1", "+pass2"], - categories=["+set", "+@hash", "-geo"], + categories=["+set", "+@hash", "-@geo"], commands=["+get", "+mget", "-hset"], keys=["cache:*", "objects:*"], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"+@hash", "+@set", "-@all", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert "on" in acl["flags"] assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 - # test reset=False keeps existing ACL and applies new ACL on top + # # test reset=False keeps existing ACL and applies new ACL on top assert r.acl_setuser( username, enabled=True, @@ -260,14 +273,13 @@ def teardown(): keys=["objects:*"], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} assert set(acl["commands"]) == {"+get", "+mget"} assert acl["enabled"] is True assert "on" in acl["flags"] assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 - # test removal of passwords + # # test removal of passwords assert r.acl_setuser( username, enabled=True, reset=True, passwords=["+pass1", "+pass2"] ) @@ -275,7 +287,7 @@ def teardown(): assert r.acl_setuser(username, enabled=True, passwords=["-pass2"]) assert len(r.acl_getuser(username)["passwords"]) == 1 - # Resets and tests that hashed passwords are set properly. + # # Resets and tests that hashed passwords are set properly. hashed_password = ( "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8" ) @@ -299,7 +311,7 @@ def teardown(): ) assert len(r.acl_getuser(username)["passwords"]) == 1 - # test selectors + # # test selectors assert r.acl_setuser( username, enabled=True, @@ -312,16 +324,19 @@ def teardown(): selectors=[("+set", "%W~app*")], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"+@hash", "+@set", "-@all", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert "on" in acl["flags"] assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 assert set(acl["channels"]) == {"&message:*"} - assert acl["selectors"] == [ - ["commands", "-@all +set", "keys", "%W~app*", "channels", ""] - ] + assert_resp_response( + r, + acl["selectors"], + [["commands", "-@all +set", "keys", "%W~app*", "channels", ""]], + [{"commands": "-@all +set", "keys": "%W~app*", "channels": ""}], + ) @skip_if_server_version_lt("6.0.0") def test_acl_help(self, r): @@ -333,6 +348,7 @@ def test_acl_help(self, r): @skip_if_redis_enterprise() def test_acl_list(self, r, request): username = "redis-py-user" + start = r.acl_list() def teardown(): r.acl_deluser(username) @@ -341,7 +357,7 @@ def teardown(): assert r.acl_setuser(username, enabled=False, reset=True) users = r.acl_list() - assert len(users) == 2 + assert len(users) == len(start) + 1 @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -383,8 +399,13 @@ def teardown(): assert len(r.acl_log()) == 2 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) - assert "client-info" in r.acl_log(count=1)[0] - assert r.acl_log_reset() + expected = r.acl_log(count=1)[0] + assert_resp_response_in( + r, + "client-info", + expected, + expected.keys(), + ) @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -434,7 +455,7 @@ def test_acl_users(self, r): @skip_if_server_version_lt("6.0.0") def test_acl_whoami(self, r): username = r.acl_whoami() - assert isinstance(username, str) + assert isinstance(username, (str, bytes)) @pytest.mark.onlynoncluster def test_client_list(self, r): @@ -489,7 +510,7 @@ def test_client_id(self, r): def test_client_trackinginfo(self, r): res = r.client_trackinginfo() assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @@ -531,7 +552,7 @@ def test_client_getname(self, r): @skip_if_server_version_lt("2.6.9") def test_client_setname(self, r): assert r.client_setname("redis_py_test") - assert r.client_getname() == "redis_py_test" + assert_resp_response(r, r.client_getname(), "redis_py_test", b"redis_py_test") @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.9") @@ -722,7 +743,7 @@ def test_waitaof(self, r): @skip_if_server_version_lt("3.2.0") def test_client_reply(self, r, r_timeout): assert r_timeout.client_reply("ON") == b"OK" - with pytest.raises(exceptions.TimeoutError): + with pytest.raises(exceptions.RedisError): r_timeout.client_reply("OFF") r_timeout.client_reply("SKIP") @@ -834,7 +855,7 @@ def test_lolwut(self, r): @skip_if_server_version_lt("6.2.0") @skip_if_redis_enterprise() def test_reset(self, r): - assert r.reset() == "RESET" + assert_resp_response(r, r.reset(), "RESET", b"RESET") def test_object(self, r): r["a"] = "foo" @@ -1147,8 +1168,12 @@ def test_lcs(self, r): r.mset({"foo": "ohmytext", "bar": "mynewtext"}) assert r.lcs("foo", "bar") == b"mytext" assert r.lcs("foo", "bar", len=True) == 6 - result = [b"matches", [[[4, 7], [5, 8]]], b"len", 6] - assert r.lcs("foo", "bar", idx=True, minmatchlen=3) == result + assert_resp_response( + r, + r.lcs("foo", "bar", idx=True, minmatchlen=3), + [b"matches", [[[4, 7], [5, 8]]], b"len", 6], + {b"matches": [[[4, 7], [5, 8]]], b"len": 6}, + ) with pytest.raises(redis.ResponseError): assert r.lcs("foo", "bar", len=True, idx=True) @@ -1562,7 +1587,7 @@ def test_hrandfield(self, r): assert r.hrandfield("key") is not None assert len(r.hrandfield("key", 2)) == 2 # with values - assert len(r.hrandfield("key", 2, True)) == 4 + assert_resp_response(r, len(r.hrandfield("key", 2, withvalues=True)), 4, 2) # without duplications assert len(r.hrandfield("key", 10)) == 5 # with duplications @@ -1715,17 +1740,26 @@ def test_stralgo_lcs(self, r): assert r.stralgo("LCS", key1, key2, specific_argument="keys") == res # test other labels assert r.stralgo("LCS", value1, value2, len=True) == len(res) - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True), + {"len": len(res), "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]]}, + ) + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]]}, + ) + assert_resp_response( + r, + r.stralgo( + "LCS", value1, value2, idx=True, withmatchlen=True, minmatchlen=4 + ), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]}, + ) @skip_if_server_version_lt("6.0.0") @skip_if_server_version_gte("7.0.0") @@ -1788,25 +1822,41 @@ def test_type(self, r): def test_blpop(self, r): r.rpush("a", "1", "2") r.rpush("b", "3", "4") - assert r.blpop(["b", "a"], timeout=1) == (b"b", b"3") - assert r.blpop(["b", "a"], timeout=1) == (b"b", b"4") - assert r.blpop(["b", "a"], timeout=1) == (b"a", b"1") - assert r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) assert r.blpop(["b", "a"], timeout=1) is None r.rpush("c", "1") - assert r.blpop("c", timeout=1) == (b"c", b"1") + assert_resp_response(r, r.blpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"]) @pytest.mark.onlynoncluster def test_brpop(self, r): r.rpush("a", "1", "2") r.rpush("b", "3", "4") - assert r.brpop(["b", "a"], timeout=1) == (b"b", b"4") - assert r.brpop(["b", "a"], timeout=1) == (b"b", b"3") - assert r.brpop(["b", "a"], timeout=1) == (b"a", b"2") - assert r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) assert r.brpop(["b", "a"], timeout=1) is None r.rpush("c", "1") - assert r.brpop("c", timeout=1) == (b"c", b"1") + assert_resp_response(r, r.brpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"]) @pytest.mark.onlynoncluster def test_brpoplpush(self, r): @@ -2174,8 +2224,12 @@ def test_spop_multi_value(self, r): for value in values: assert value in s - - assert r.spop("a", 1) == list(set(s) - set(values)) + assert_resp_response( + r, + r.spop("a", 1), + list(set(s) - set(values)), + set(s) - set(values), + ) def test_srandmember(self, r): s = [b"1", b"2", b"3"] @@ -2226,11 +2280,12 @@ def test_script_debug(self, r): def test_zadd(self, r): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} r.zadd("a", mapping) - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -2247,17 +2302,32 @@ def test_zadd(self, r): def test_zadd_nx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) def test_zadd_xx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 99.0)], + [[b"a1", 99.0]], + ) def test_zadd_ch(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a2", 2.0), (b"a1", 99.0)] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a2", 2.0), (b"a1", 99.0)], + [[b"a2", 2.0], [b"a1", 99.0]], + ) def test_zadd_incr(self, r): assert r.zadd("a", {"a1": 1}) == 1 @@ -2305,7 +2375,12 @@ def test_zdiff(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiff(["a", "b"]) == [b"a3"] - assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] + assert_resp_response( + r, + r.zdiff(["a", "b"], withscores=True), + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.2.0") @@ -2314,7 +2389,12 @@ def test_zdiffstore(self, r): r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiffstore("out", ["a", "b"]) assert r.zrange("out", 0, -1) == [b"a3"] - assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] + assert_resp_response( + r, + r.zrange("out", 0, -1, withscores=True), + [(b"a3", 3.0)], + [[b"a3", 3.0]], + ) def test_zincrby(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2340,22 +2420,33 @@ def test_zinter(self, r): with pytest.raises(exceptions.DataError): r.zinter(["a", "b", "c"], aggregate="foo", withscores=True) # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zinter(["a", "b", "c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a3", 1), - ] + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") @@ -2372,7 +2463,12 @@ def test_zinterstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"]) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): @@ -2380,7 +2476,12 @@ def test_zinterstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): @@ -2388,7 +2489,12 @@ def test_zinterstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1], [b"a3", 3]], + ) @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): @@ -2396,23 +2502,36 @@ def test_zinterstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmax("a") == [(b"a3", 3)] - + assert_resp_response(r, r.zpopmax("a"), [(b"a3", 3)], [b"a3", 3.0]) # with count - assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + assert_resp_response( + r, + r.zpopmax("a", count=2), + [(b"a2", 2), (b"a1", 1)], + [[b"a2", 2], [b"a1", 1]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmin("a") == [(b"a1", 1)] - + assert_resp_response(r, r.zpopmin("a"), [(b"a1", 1)], [b"a1", 1.0]) # with count - assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zpopmin("a", count=2), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) @skip_if_server_version_lt("6.2.0") def test_zrandemember(self, r): @@ -2420,7 +2539,12 @@ def test_zrandemember(self, r): assert r.zrandmember("a") is not None assert len(r.zrandmember("a", 2)) == 2 # with scores - assert len(r.zrandmember("a", 2, True)) == 4 + assert_resp_response( + r, + len(r.zrandmember("a", 2, withscores=True)), + 4, + 2, + ) # without duplications assert len(r.zrandmember("a", 10)) == 5 # with duplications @@ -2431,49 +2555,86 @@ def test_zrandemember(self, r): def test_bzpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2}) r.zadd("b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"b", b"b2", 20), [b"b", b"b2", 20] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"b", b"b1", 10), [b"b", b"b1", 10] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"a", b"a2", 2), [b"a", b"a2", 2] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"a", b"a1", 1), [b"a", b"a1", 1] + ) assert r.bzpopmax(["b", "a"], timeout=1) is None r.zadd("c", {"c1": 100}) - assert r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, r.bzpopmax("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("4.9.0") def test_bzpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2}) r.zadd("b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"b", b"b1", 10), [b"b", b"b1", 10] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"b", b"b2", 20), [b"b", b"b2", 20] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"a", b"a1", 1), [b"a", b"a1", 1] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"a", b"a2", 2), [b"a", b"a2", 2] + ) assert r.bzpopmin(["b", "a"], timeout=1) is None r.zadd("c", {"c1": 100}) - assert r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, r.bzpopmin("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_zmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - assert r.zmpop("2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.zmpop("2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - assert r.zmpop("2", ["b", "a"], max=True) == [b"b", [[b"b1", b"10"]]] + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_bzmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - assert r.bzmpop(1, "2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.bzmpop(1, "2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.bzmpop(1, "2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - res = [b"b", [[b"b1", b"10"]]] - assert r.bzmpop(0, "2", ["b", "a"], max=True) == res + assert_resp_response( + r, + r.bzmpop(0, "2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) assert r.bzmpop(1, "2", ["foo", "bar"], max=True) is None def test_zrange(self, r): @@ -2484,14 +2645,24 @@ def test_zrange(self, r): assert r.zrange("a", 0, 2, desc=True) == [b"a3", b"a2", b"a1"] # withscores - assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] + assert_resp_response( + r, + r.zrange("a", 0, 1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) - # custom score function - assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + # # custom score function + # assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): @@ -2523,14 +2694,20 @@ def test_zrange_params(self, r): b"a3", b"a2", ] - assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrange( - "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + assert_resp_response( + r, + r.zrange("a", 2, 4, byscore=True, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrange( + "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int + ), + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) # rev assert r.zrange("a", 0, 1, desc=True) == [b"a5", b"a4"] @@ -2543,7 +2720,12 @@ def test_zrangestore(self, r): assert r.zrange("b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("b", "a", 1, 2) assert r.zrange("b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) # reversed order assert r.zrangestore("b", "a", 1, 2, desc=True) assert r.zrange("b", 0, -1) == [b"a1", b"a2"] @@ -2578,16 +2760,18 @@ def test_zrangebyscore(self, r): # slicing with start/num assert r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int), + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) def test_zrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2644,32 +2828,45 @@ def test_zrevrange(self, r): assert r.zrevrange("a", 1, 2) == [b"a2", b"a1"] # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [(b"a3", 3.0), (b"a2", 2.0)] - assert r.zrevrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a1", 1.0)] + assert_resp_response( + r, + r.zrevrange("a", 0, 1, withscores=True), + [(b"a3", 3.0), (b"a2", 2.0)], + [[b"a3", 3.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrevrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a1", 1.0)], + [[b"a2", 2.0], [b"a1", 1.0]], + ) - # custom score function - assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + # # custom score function + # assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a3", 3.0), + # (b"a2", 2.0), + # ] def test_zrevrangebyscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) assert r.zrevrangebyscore("a", 4, 2) == [b"a4", b"a3", b"a2"] # slicing with start/num assert r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] + # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) # custom score function - assert r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int) == [ - (b"a4", 4), - (b"a3", 3), - (b"a2", 2), - ] + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) def test_zrevrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2700,33 +2897,33 @@ def test_zunion(self, r): r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["a", "b", "c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a2", 1), - (b"a3", 1), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1], [b"a2", 1], [b"a3", 1], [b"a4", 4]], + ) # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): @@ -2734,12 +2931,12 @@ def test_zunionstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"]) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def test_zunionstore_max(self, r): @@ -2747,12 +2944,12 @@ def test_zunionstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): @@ -2760,12 +2957,12 @@ def test_zunionstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1], [b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): @@ -2773,12 +2970,12 @@ def test_zunionstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): @@ -3293,11 +3490,12 @@ def test_geohash(self, r): "place2", ) r.geoadd("barcelona", values) - assert r.geohash("barcelona", "place1", "place2", "place3") == [ - "sp3e9yg3kd0", - "sp3e9cbc3t0", - None, - ] + assert_resp_response( + r, + r.geohash("barcelona", "place1", "place2", "place3"), + ["sp3e9yg3kd0", "sp3e9cbc3t0", None], + [b"sp3e9yg3kd0", b"sp3e9cbc3t0", None], + ) @skip_unless_arch_bits(64) @skip_if_server_version_lt("3.2.0") @@ -3309,10 +3507,18 @@ def test_geopos(self, r): ) r.geoadd("barcelona", values) # redis uses 52 bits precision, hereby small errors may be introduced. - assert r.geopos("barcelona", "place1", "place2") == [ - (2.19093829393386841, 41.43379028184083523), - (2.18737632036209106, 41.40634178640635099), - ] + assert_resp_response( + r, + r.geopos("barcelona", "place1", "place2"), + [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ], + [ + [2.19093829393386841, 41.43379028184083523], + [2.18737632036209106, 41.40634178640635099], + ], + ) @skip_if_server_version_lt("4.0.0") def test_geopos_no_value(self, r): @@ -3852,7 +4058,7 @@ def test_xadd_explicit_ms(self, r: redis.Redis): ms = message_id[: message_id.index(b"-")] assert ms == b"9999999999999999999" - @skip_if_server_version_lt("6.2.0") + @skip_if_server_version_lt("7.0.0") def test_xautoclaim(self, r): stream = "stream" group = "group" @@ -3867,7 +4073,7 @@ def test_xautoclaim(self, r): # trying to claim a message that isn't already pending doesn't # do anything response = r.xautoclaim(stream, group, consumer2, min_idle_time=0) - assert response == [b"0-0", []] + assert response == [b"0-0", [], []] # read the group as consumer1 to initially claim the messages r.xreadgroup(group, consumer1, streams={stream: ">"}) @@ -4111,7 +4317,7 @@ def test_xgroup_setid(self, r): ] assert r.xinfo_groups(stream) == expected - @skip_if_server_version_lt("5.0.0") + @skip_if_server_version_lt("7.2.0") def test_xinfo_consumers(self, r): stream = "stream" group = "group" @@ -4127,8 +4333,8 @@ def test_xinfo_consumers(self, r): info = r.xinfo_consumers(stream, group) assert len(info) == 2 expected = [ - {"name": consumer1.encode(), "pending": 1}, - {"name": consumer2.encode(), "pending": 2}, + {"name": consumer1.encode(), "pending": 1, "inactive": 2}, + {"name": consumer2.encode(), "pending": 2, "inactive": 2}, ] # we can't determine the idle time, so just make sure it's an int @@ -4159,7 +4365,12 @@ def test_xinfo_stream_full(self, r): info = r.xinfo_stream(stream, full=True) assert info["length"] == 1 - assert m1 in info["entries"] + assert_resp_response_in( + r, + m1, + info["entries"], + info["entries"].keys(), + ) assert len(info["groups"]) == 1 @skip_if_server_version_lt("5.0.0") @@ -4300,25 +4511,39 @@ def test_xread(self, r): m1 = r.xadd(stream, {"foo": "bar"}) m2 = r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + stream_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert r.xread(streams={stream: 0}) == expected + assert_resp_response( + r, + r.xread(streams={stream: 0}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert r.xread(streams={stream: 0}, count=1) == expected + assert_resp_response( + r, + r.xread(streams={stream: 0}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) - expected = [[stream.encode(), [get_stream_message(r, stream, m2)]]] + expected_entries = [get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert r.xread(streams={stream: m1}) == expected + assert_resp_response( + r, + r.xread(streams={stream: m1}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) # xread starting at the last message returns an empty list - assert r.xread(streams={stream: m2}) == [] + assert_resp_response(r, r.xread(streams={stream: m2}), [], {}) @skip_if_server_version_lt("5.0.0") def test_xreadgroup(self, r): @@ -4329,21 +4554,32 @@ def test_xreadgroup(self, r): m2 = r.xadd(stream, {"bing": "baz"}) r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + stream_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) == expected + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) @@ -4351,27 +4587,37 @@ def test_xreadgroup(self, r): # will only find messages added after this r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + assert_resp_response( + r, r.xreadgroup(group, consumer, streams={stream: ">"}), [], {} + ) # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") - assert ( - len(r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True)[0][1]) - == 2 - ) - # now there should be nothing pending - assert len(r.xreadgroup(group, consumer, streams={stream: "0"})[0][1]) == 0 + res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[stream_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[stream_name][0]) == 0 r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] r.xreadgroup(group, consumer, streams={stream: ">"}) r.xtrim(stream, 0) - assert r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: "0"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) @skip_if_server_version_lt("5.0.0") def test_xrevrange(self, r): @@ -4637,7 +4883,7 @@ def test_command_list(self, r: redis.Redis): @skip_if_redis_enterprise() def test_command_getkeys(self, r): res = r.command_getkeys("MSET", "a", "b", "c", "d", "e", "f") - assert res == ["a", "c", "e"] + assert_resp_response(r, res, ["a", "c", "e"], [b"a", b"c", b"e"]) res = r.command_getkeys( "EVAL", '"not consulted"', @@ -4650,7 +4896,9 @@ def test_command_getkeys(self, r): "arg3", "argN", ) - assert res == ["key1", "key2", "key3"] + assert_resp_response( + r, res, ["key1", "key2", "key3"], [b"key1", b"key2", b"key3"] + ) @skip_if_server_version_lt("2.8.13") def test_command(self, r): @@ -4664,12 +4912,17 @@ def test_command(self, r): @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_command_getkeysandflags(self, r: redis.Redis): - res = [ - [b"mylist1", [b"RW", b"access", b"delete"]], - [b"mylist2", [b"RW", b"insert"]], - ] - assert res == r.command_getkeysandflags( - "LMOVE", "mylist1", "mylist2", "left", "left" + assert_resp_response( + r, + r.command_getkeysandflags("LMOVE", "mylist1", "mylist2", "left", "left"), + [ + [b"mylist1", [b"RW", b"access", b"delete"]], + [b"mylist2", [b"RW", b"insert"]], + ], + [ + [b"mylist1", {b"RW", b"access", b"delete"}], + [b"mylist2", {b"RW", b"insert"}], + ], ) @pytest.mark.onlynoncluster @@ -4765,6 +5018,8 @@ def test_shutdown_with_params(self, r: redis.Redis): @skip_if_server_version_lt("2.8.0") @skip_if_redis_enterprise() def test_sync(self, r): + r.flushdb() + time.sleep(1) r2 = redis.Redis(port=6380, decode_responses=False) res = r2.sync() assert b"REDIS" in res diff --git a/tests/test_connect.py b/tests/test_connect.py index b4ec7020e1..b233c67e83 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -6,7 +6,6 @@ import threading import pytest - from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection from .ssl_utils import get_ssl_filename diff --git a/tests/test_connection.py b/tests/test_connection.py index 31268a9e8b..760b23c9c1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,16 +4,10 @@ from unittest.mock import patch import pytest - import redis +from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff -from redis.connection import ( - Connection, - HiredisParser, - PythonParser, - SSLConnection, - UnixDomainSocketConnection, -) +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -35,22 +29,22 @@ def test_invalid_response(r): @skip_if_server_version_lt("4.0.0") @pytest.mark.redismod -def test_loading_external_modules(modclient): +def test_loading_external_modules(r): def inner(): pass - modclient.load_external_module("myfuncname", inner) - assert getattr(modclient, "myfuncname") == inner - assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) + r.load_external_module("myfuncname", inner) + assert getattr(r, "myfuncname") == inner + assert isinstance(getattr(r, "myfuncname"), types.FunctionType) # and call it from redis.commands import RedisModuleCommands j = RedisModuleCommands.json - modclient.load_external_module("sometestfuncname", j) + r.load_external_module("sometestfuncname", j) # d = {'hello': 'world!'} - # mod = j(modclient) + # mod = j(r) # mod.set("fookey", ".", d) # assert mod.get('fookey') == d @@ -134,7 +128,9 @@ def test_connect_timeout_error_without_retry(self): @pytest.mark.onlynoncluster @pytest.mark.parametrize( - "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] + "parser_class", + [_RESP2Parser, _RESP3Parser, _HiredisParser], + ids=["RESP2Parser", "RESP3Parser", "HiredisParser"], ) def test_connection_parse_response_resume(r: redis.Redis, parser_class): """ @@ -142,7 +138,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): be that PythonParser or HiredisParser, can be interrupted at IO time and then resume parsing. """ - if parser_class is HiredisParser and not HIREDIS_AVAILABLE: + if parser_class is _HiredisParser and not HIREDIS_AVAILABLE: pytest.skip("Hiredis not available)") args = dict(r.connection_pool.connection_kwargs) args["parser_class"] = parser_class @@ -154,7 +150,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): ) mock_socket = MockSocket(message, interrupt_every=2) - if isinstance(conn._parser, PythonParser): + if isinstance(conn._parser, _RESP2Parser) or isinstance(conn._parser, _RESP3Parser): conn._parser._buffer._sock = mock_socket else: conn._parser._sock = mock_socket diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 155bffe56a..ab0fc6be98 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -5,9 +5,9 @@ from unittest import mock import pytest - import redis -from redis.connection import ssl_available, to_bool +from redis.connection import to_bool +from redis.utils import SSL_AVAILABLE from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt from .test_pubsub import wait_for_message @@ -425,7 +425,7 @@ class MyConnection(redis.UnixDomainSocketConnection): assert pool.connection_class == MyConnection -@pytest.mark.skipif(not ssl_available, reason="SSL not installed") +@pytest.mark.skipif(not SSL_AVAILABLE, reason="SSL not installed") class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url("rediss://my.host") diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 9aeb1ef1d5..aade04e082 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -4,7 +4,6 @@ from typing import Optional, Tuple, Union import pytest - import redis from redis import AuthenticationError, DataError, ResponseError from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -198,6 +197,12 @@ def test_change_username_password_on_existing_connection(self, r, request): password = "origin_password" new_username = "new_username" new_password = "new_password" + + def teardown(): + r.acl_deluser(new_username) + + request.addfinalizer(teardown) + init_acl_user(r, request, username, password) r2 = _get_client( redis.Redis, request, flushdb=False, username=username, password=password diff --git a/tests/test_encoding.py b/tests/test_encoding.py index cb9c4e20be..331cd5108c 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -1,5 +1,4 @@ import pytest - import redis from redis.connection import Connection from redis.utils import HIREDIS_PACK_AVAILABLE diff --git a/tests/test_function.py b/tests/test_function.py index 7ce66a38e6..22db904273 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,8 +1,7 @@ import pytest - from redis.exceptions import ResponseError -from .conftest import skip_if_server_version_lt +from .conftest import assert_resp_response, skip_if_server_version_lt engine = "lua" lib = "mylib" @@ -64,12 +63,22 @@ def test_function_list(self, r): [[b"name", b"myfunc", b"description", None, b"flags", [b"no-writes"]]], ] ] - assert r.function_list() == res - assert r.function_list(library="*lib") == res - assert ( - r.function_list(withcode=True)[0][7] - == f"#!{engine} name={lib} \n {function}".encode() + resp3_res = [ + { + b"library_name": b"mylib", + b"engine": b"LUA", + b"functions": [ + {b"name": b"myfunc", b"description": None, b"flags": {b"no-writes"}} + ], + } + ] + assert_resp_response(r, r.function_list(), res, resp3_res) + assert_resp_response(r, r.function_list(library="*lib"), res, resp3_res) + res[0].extend( + [b"library_code", f"#!{engine} name={lib} \n {function}".encode()] ) + resp3_res[0][b"library_code"] = f"#!{engine} name={lib} \n {function}".encode() + assert_resp_response(r, r.function_list(withcode=True), res, resp3_res) @pytest.mark.onlycluster def test_function_list_on_cluster(self, r): diff --git a/tests/test_graph.py b/tests/test_graph.py index 37e5ca43aa..42f1d9e5df 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest - +from redis import Redis from redis.commands.graph import Edge, Node, Path from redis.commands.graph.execution_plan import Operation from redis.commands.graph.query_result import ( @@ -20,13 +20,14 @@ QueryResult, ) from redis.exceptions import ResponseError -from tests.conftest import skip_if_redis_enterprise +from tests.conftest import _get_client, skip_if_redis_enterprise @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(request): + r = _get_client(Redis, request, decode_responses=True) + r.flushdb() + return r @pytest.mark.redismod diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py index b5b7362389..581ebfab5d 100644 --- a/tests/test_graph_utils/test_edge.py +++ b/tests/test_graph_utils/test_edge.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.graph import edge, node diff --git a/tests/test_graph_utils/test_node.py b/tests/test_graph_utils/test_node.py index cd4e936719..c3b34ac6ff 100644 --- a/tests/test_graph_utils/test_node.py +++ b/tests/test_graph_utils/test_node.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.graph import node diff --git a/tests/test_graph_utils/test_path.py b/tests/test_graph_utils/test_path.py index d581269307..1bd38efab4 100644 --- a/tests/test_graph_utils/test_path.py +++ b/tests/test_graph_utils/test_path.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.graph import edge, node, path diff --git a/tests/test_json.py b/tests/test_json.py index c41ad5e2e1..fb608ff425 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,17 +1,17 @@ import pytest - import redis -from redis import exceptions +from redis import Redis, exceptions from redis.commands.json.decoders import decode_list, unstring from redis.commands.json.path import Path -from .conftest import skip_ifmodversion_lt +from .conftest import _get_client, assert_resp_response, skip_ifmodversion_lt @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(request): + r = _get_client(Redis, request, decode_responses=True) + r.flushdb() + return r @pytest.mark.redismod @@ -25,7 +25,7 @@ def test_json_setbinarykey(client): @pytest.mark.redismod def test_json_setgetdeleteforget(client): assert client.json().set("foo", Path.root_path(), "bar") - assert client.json().get("foo") == "bar" + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) assert client.json().get("baz") is None assert client.json().delete("foo") == 1 assert client.json().forget("foo") == 0 # second delete @@ -35,13 +35,13 @@ def test_json_setgetdeleteforget(client): @pytest.mark.redismod def test_jsonget(client): client.json().set("foo", Path.root_path(), "bar") - assert client.json().get("foo") == "bar" + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) @pytest.mark.redismod def test_json_get_jset(client): assert client.json().set("foo", Path.root_path(), "bar") - assert "bar" == client.json().get("foo") + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) assert client.json().get("baz") is None assert 1 == client.json().delete("foo") assert client.exists("foo") == 0 @@ -83,7 +83,10 @@ def test_json_merge(client): @pytest.mark.redismod def test_nonascii_setgetdelete(client): assert client.json().set("notascii", Path.root_path(), "hyvää-élève") - assert "hyvää-élève" == client.json().get("notascii", no_escape=True) + res = "hyvää-élève" + assert_resp_response( + client, client.json().get("notascii", no_escape=True), res, [[res]] + ) assert 1 == client.json().delete("notascii") assert client.exists("notascii") == 0 @@ -129,22 +132,30 @@ def test_mset(client): def test_clear(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 1 == client.json().clear("arr", Path.root_path()) - assert [] == client.json().get("arr") + assert_resp_response(client, client.json().get("arr"), [], [[[]]]) @pytest.mark.redismod def test_type(client): client.json().set("1", Path.root_path(), 1) - assert "integer" == client.json().type("1", Path.root_path()) - assert "integer" == client.json().type("1") + assert_resp_response( + client, client.json().type("1", Path.root_path()), "integer", ["integer"] + ) + assert_resp_response(client, client.json().type("1"), "integer", ["integer"]) @pytest.mark.redismod def test_numincrby(client): client.json().set("num", Path.root_path(), 1) - assert 2 == client.json().numincrby("num", Path.root_path(), 1) - assert 2.5 == client.json().numincrby("num", Path.root_path(), 0.5) - assert 1.25 == client.json().numincrby("num", Path.root_path(), -1.25) + assert_resp_response( + client, client.json().numincrby("num", Path.root_path(), 1), 2, [2] + ) + assert_resp_response( + client, client.json().numincrby("num", Path.root_path(), 0.5), 2.5, [2.5] + ) + assert_resp_response( + client, client.json().numincrby("num", Path.root_path(), -1.25), 1.25, [1.25] + ) @pytest.mark.redismod @@ -152,9 +163,15 @@ def test_nummultby(client): client.json().set("num", Path.root_path(), 1) with pytest.deprecated_call(): - assert 2 == client.json().nummultby("num", Path.root_path(), 2) - assert 5 == client.json().nummultby("num", Path.root_path(), 2.5) - assert 2.5 == client.json().nummultby("num", Path.root_path(), 0.5) + assert_resp_response( + client, client.json().nummultby("num", Path.root_path(), 2), 2, [2] + ) + assert_resp_response( + client, client.json().nummultby("num", Path.root_path(), 2.5), 5, [5] + ) + assert_resp_response( + client, client.json().nummultby("num", Path.root_path(), 0.5), 2.5, [2.5] + ) @pytest.mark.redismod @@ -173,7 +190,9 @@ def test_toggle(client): def test_strappend(client): client.json().set("jsonkey", Path.root_path(), "foo") assert 6 == client.json().strappend("jsonkey", "bar") - assert "foobar" == client.json().get("jsonkey", Path.root_path()) + assert_resp_response( + client, client.json().get("jsonkey", Path.root_path()), "foobar", [["foobar"]] + ) # @pytest.mark.redismod @@ -219,12 +238,14 @@ def test_arrindex(client): def test_arrinsert(client): client.json().set("arr", Path.root_path(), [0, 4]) assert 5 - -client.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) - assert [0, 1, 2, 3, 4] == client.json().get("arr") + res = [0, 1, 2, 3, 4] + assert_resp_response(client, client.json().get("arr"), res, [[res]]) # test prepends client.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9]) client.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"]) - assert client.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] + res = [["some", "thing"], 5, 6, 7, 8, 9] + assert_resp_response(client, client.json().get("val2"), res, [[res]]) @pytest.mark.redismod @@ -242,7 +263,7 @@ def test_arrpop(client): assert 3 == client.json().arrpop("arr", Path.root_path(), -1) assert 2 == client.json().arrpop("arr", Path.root_path()) assert 0 == client.json().arrpop("arr", Path.root_path(), 0) - assert [1] == client.json().get("arr") + assert_resp_response(client, client.json().get("arr"), [1], [[[1]]]) # test out of bounds client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) @@ -257,7 +278,7 @@ def test_arrpop(client): def test_arrtrim(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 3 == client.json().arrtrim("arr", Path.root_path(), 1, 3) - assert [1, 2, 3] == client.json().get("arr") + assert_resp_response(client, client.json().get("arr"), [1, 2, 3], [[[1, 2, 3]]]) # <0 test, should be 0 equivalent client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) @@ -319,7 +340,7 @@ def test_json_commands_in_pipeline(client): p.set("foo", Path.root_path(), "bar") p.get("foo") p.delete("foo") - assert [True, "bar", 1] == p.execute() + assert_resp_response(client, p.execute(), [True, "bar", 1], [True, [["bar"]], 1]) assert client.keys() == [] assert client.get("foo") is None @@ -332,7 +353,7 @@ def test_json_commands_in_pipeline(client): p.jsonget("foo") p.exists("notarealkey") p.delete("foo") - assert [True, d, 0, 1] == p.execute() + assert_resp_response(client, p.execute(), [True, d, 0, 1], [True, [[d]], 0, 1]) assert client.keys() == [] assert client.get("foo") is None @@ -342,14 +363,14 @@ def test_json_delete_with_dollar(client): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert client.json().set("doc1", "$", doc1) assert client.json().delete("doc1", "$..a") == 2 - r = client.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + res = [{"nested": {"b": 3}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert client.json().set("doc2", "$", doc2) assert client.json().delete("doc2", "$..a") == 1 - res = client.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(client, client.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -380,8 +401,7 @@ def test_json_delete_with_dollar(client): } ] ] - res = client.json().get("doc3", "$") - assert res == doc3val + assert_resp_response(client, client.json().get("doc3", "$"), doc3val, [doc3val]) # Test default path assert client.json().delete("doc3") == 1 @@ -395,14 +415,14 @@ def test_json_forget_with_dollar(client): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert client.json().set("doc1", "$", doc1) assert client.json().forget("doc1", "$..a") == 2 - r = client.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + res = [{"nested": {"b": 3}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert client.json().set("doc2", "$", doc2) assert client.json().forget("doc2", "$..a") == 1 - res = client.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(client, client.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -433,8 +453,7 @@ def test_json_forget_with_dollar(client): } ] ] - res = client.json().get("doc3", "$") - assert res == doc3val + assert_resp_response(client, client.json().get("doc3", "$"), doc3val, [doc3val]) # Test default path assert client.json().forget("doc3") == 1 @@ -457,8 +476,10 @@ def test_json_mget_dollar(client): {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}}, ) # Compare also to single JSON.GET - assert client.json().get("doc1", "$..a") == [1, 3, None] - assert client.json().get("doc2", "$..a") == [4, 6, [None]] + res = [1, 3, None] + assert_resp_response(client, client.json().get("doc1", "$..a"), res, [res]) + res = [4, 6, [None]] + assert_resp_response(client, client.json().get("doc2", "$..a"), res, [res]) # Test mget with single path client.json().mget("doc1", "$..a") == [1, 3, None] @@ -525,15 +546,14 @@ def test_strappend_dollar(client): # Test multi client.json().strappend("doc1", "bar", "$..a") == [6, 8, None] - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single client.json().strappend("doc1", "baz", "$.nested1.a") == [11] - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -541,9 +561,8 @@ def test_strappend_dollar(client): # Test multi client.json().strappend("doc1", "bar", ".*.a") == 8 - client.json().get("doc1", "$") == [ - {"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing path with pytest.raises(exceptions.ResponseError): @@ -585,23 +604,25 @@ def test_arrappend_dollar(client): ) # Test multi client.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test single assert client.json().arrappend("doc1", "$.nested1.a", "baz") == [6] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -620,22 +641,25 @@ def test_arrappend_dollar(client): # Test multi (all paths are updated, but return result of last path) assert client.json().arrappend("doc1", "..a", "bar", "racuda") == 5 - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrappend("doc1", ".nested1.a", "baz") == 6 - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -656,22 +680,25 @@ def test_arrinsert_dollar(client): # Test multi assert client.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") == [3, 5, None] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", "baz", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -743,9 +770,8 @@ def test_arrpop_dollar(client): # # # Test multi assert client.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -763,9 +789,8 @@ def test_arrpop_dollar(client): ) # Test multi (all paths are updated, but return result of last path) client.json().arrpop("doc1", "..a", "1") is None - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # # Test missing key with pytest.raises(exceptions.ResponseError): @@ -786,19 +811,17 @@ def test_arrtrim_dollar(client): ) # Test multi assert client.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) assert client.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": []}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": []}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -820,9 +843,8 @@ def test_arrtrim_dollar(client): # Test single assert client.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -920,13 +942,17 @@ def test_type_dollar(client): jdata, jtypes = load_types_data("a") client.json().set("doc1", "$", jdata) # Test multi - assert client.json().type("doc1", "$..a") == jtypes + assert_resp_response(client, client.json().type("doc1", "$..a"), jtypes, [jtypes]) # Test single - assert client.json().type("doc1", "$.nested2.a") == [jtypes[1]] + assert_resp_response( + client, client.json().type("doc1", "$.nested2.a"), [jtypes[1]], [[jtypes[1]]] + ) # Test missing key - assert client.json().type("non_existing_doc", "..a") is None + assert_resp_response( + client, client.json().type("non_existing_doc", "..a"), None, [None] + ) @pytest.mark.redismod @@ -944,9 +970,10 @@ def test_clear_dollar(client): # Test multi assert client.json().clear("doc1", "$..a") == 3 - assert client.json().get("doc1", "$") == [ + res = [ {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}} ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test single client.json().set( @@ -960,7 +987,7 @@ def test_clear_dollar(client): }, ) assert client.json().clear("doc1", "$.nested1.a") == 1 - assert client.json().get("doc1", "$") == [ + res = [ { "nested1": {"a": {}}, "a": ["foo"], @@ -968,10 +995,11 @@ def test_clear_dollar(client): "nested3": {"a": {"baz": 50}}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing path (defaults to root) assert client.json().clear("doc1") == 1 - assert client.json().get("doc1", "$") == [{}] + assert_resp_response(client, client.json().get("doc1", "$"), [{}], [[{}]]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -992,7 +1020,7 @@ def test_toggle_dollar(client): ) # Test multi assert client.json().toggle("doc1", "$..a") == [None, 1, None, 0] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo"], "nested1": {"a": True}, @@ -1000,6 +1028,7 @@ def test_toggle_dollar(client): "nested3": {"a": False}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -1075,7 +1104,7 @@ def test_resp_dollar(client): client.json().set("doc1", "$", data) # Test multi res = client.json().resp("doc1", "$..a") - assert res == [ + resp2 = [ [ "{", "A1_B1", @@ -1131,10 +1160,67 @@ def test_resp_dollar(client): ["{", "A2_B4_C1", "bar"], ], ] + resp3 = [ + [ + "{", + "A1_B1", + 10, + "A1_B2", + "false", + "A1_B3", + [ + "{", + "A1_B3_C1", + None, + "A1_B3_C2", + [ + "[", + "A1_B3_C2_D1_1", + "A1_B3_C2_D1_2", + -19.5, + "A1_B3_C2_D1_4", + "A1_B3_C2_D1_5", + ["{", "A1_B3_C2_D1_6_E1", "true"], + ], + "A1_B3_C3", + ["[", 1], + ], + "A1_B4", + ["{", "A1_B4_C1", "foo"], + ], + [ + "{", + "A2_B1", + 20, + "A2_B2", + "false", + "A2_B3", + [ + "{", + "A2_B3_C1", + None, + "A2_B3_C2", + [ + "[", + "A2_B3_C2_D1_1", + "A2_B3_C2_D1_2", + -37.5, + "A2_B3_C2_D1_4", + "A2_B3_C2_D1_5", + ["{", "A2_B3_C2_D1_6_E1", "false"], + ], + "A2_B3_C3", + ["[", 2], + ], + "A2_B4", + ["{", "A2_B4_C1", "bar"], + ], + ] + assert_resp_response(client, res, resp2, resp3) # Test single - resSingle = client.json().resp("doc1", "$.L1.a") - assert resSingle == [ + res = client.json().resp("doc1", "$.L1.a") + resp2 = [ [ "{", "A1_B1", @@ -1163,6 +1249,36 @@ def test_resp_dollar(client): ["{", "A1_B4_C1", "foo"], ] ] + resp3 = [ + [ + "{", + "A1_B1", + 10, + "A1_B2", + "false", + "A1_B3", + [ + "{", + "A1_B3_C1", + None, + "A1_B3_C2", + [ + "[", + "A1_B3_C2_D1_1", + "A1_B3_C2_D1_2", + -19.5, + "A1_B3_C2_D1_4", + "A1_B3_C2_D1_5", + ["{", "A1_B3_C2_D1_6_E1", "true"], + ], + "A1_B3_C3", + ["[", 1], + ], + "A1_B4", + ["{", "A1_B4_C1", "foo"], + ] + ] + assert_resp_response(client, res, resp2, resp3) # Test missing path client.json().resp("doc1", "$.nowhere") @@ -1217,10 +1333,13 @@ def test_arrindex_dollar(client): }, ) - assert client.json().get("store", "$.store.book[?(@.price<10)].size") == [ - [10, 20, 30, 40], - [5, 10, 20, 30], - ] + assert_resp_response( + client, + client.json().get("store", "$.store.book[?(@.price<10)].size"), + [[10, 20, 30, 40], [5, 10, 20, 30]], + [[[10, 20, 30, 40], [5, 10, 20, 30]]], + ) + assert client.json().arrindex( "store", "$.store.book[?(@.price<10)].size", "20" ) == [-1, -1] @@ -1241,13 +1360,14 @@ def test_arrindex_dollar(client): ], ) - assert client.json().get("test_num", "$..arr") == [ + res = [ [0, 1, 3.0, 3, 2, 1, 0, 3], [5, 4, 3, 2, 1, 0, 1, 2, 3.0, 2, 4, 5], [2, 4, 6], "3", [], ] + assert_resp_response(client, client.json().get("test_num", "$..arr"), res, [res]) assert client.json().arrindex("test_num", "$..arr", 3) == [3, 2, -1, None, -1] @@ -1273,13 +1393,14 @@ def test_arrindex_dollar(client): ], ], ) - assert client.json().get("test_string", "$..arr") == [ + res = [ ["bazzz", "bar", 2, "baz", 2, "ba", "baz", 3], [None, "baz2", "buzz", 2, 1, 0, 1, "2", "baz", 2, 4, 5], ["baz2", 4, 6], "3", [], ] + assert_resp_response(client, client.json().get("test_string", "$..arr"), res, [res]) assert client.json().arrindex("test_string", "$..arr", "baz") == [ 3, @@ -1365,13 +1486,14 @@ def test_arrindex_dollar(client): ], ], ) - assert client.json().get("test_None", "$..arr") == [ + res = [ ["bazzz", "None", 2, None, 2, "ba", "baz", 3], ["zaz", "baz2", "buzz", 2, 1, 0, 1, "2", None, 2, 4, 5], ["None", 4, 6], None, [], ] + assert_resp_response(client, client.json().get("test_None", "$..arr"), res, [res]) # Test with none-scalar value assert client.json().arrindex( @@ -1412,7 +1534,7 @@ def test_custom_decoder(client): cj = client.json(encoder=ujson, decoder=ujson) assert cj.set("foo", Path.root_path(), "bar") - assert "bar" == cj.get("foo") + assert_resp_response(client, cj.get("foo"), "bar", [["bar"]]) assert cj.get("baz") is None assert 1 == cj.delete("foo") assert client.exists("foo") == 0 @@ -1434,7 +1556,7 @@ def test_set_file(client): nojsonfile.write(b"Hello World") assert client.json().set_file("test", Path.root_path(), jsonfile.name) - assert client.json().get("test") == obj + assert_resp_response(client, client.json().get("test"), obj, [[obj]]) with pytest.raises(json.JSONDecodeError): client.json().set_file("test2", Path.root_path(), nojsonfile.name) @@ -1456,4 +1578,7 @@ def test_set_path(client): result = {jsonfile: True, nojsonfile: False} assert client.json().set_path(Path.root_path(), root) == result - assert client.json().get(jsonfile.rsplit(".")[0]) == {"hello": "world"} + res = {"hello": "world"} + assert_resp_response( + client, client.json().get(jsonfile.rsplit(".")[0]), res, [[res]] + ) diff --git a/tests/test_lock.py b/tests/test_lock.py index 10ad7e1539..b4b9b32917 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -1,7 +1,6 @@ import time import pytest - from redis.client import Redis from redis.exceptions import LockError, LockNotOwnedError from redis.lock import Lock diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 32f5e23d53..5cda3190a6 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -2,7 +2,6 @@ import multiprocessing import pytest - import redis from redis.connection import Connection, ConnectionPool from redis.exceptions import ConnectionError diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 716cd0fbf6..7b048eec01 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,4 @@ import pytest - import redis from .conftest import skip_if_server_version_lt, wait_for_command @@ -19,7 +18,6 @@ def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert pipe.execute() == [ True, @@ -27,7 +25,6 @@ def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] def test_pipeline_memoryview(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 5d86934de6..ba097e3194 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -3,24 +3,39 @@ import socket import threading import time +from collections import defaultdict from unittest import mock from unittest.mock import patch import pytest - import redis from redis.exceptions import ConnectionError +from redis.utils import HIREDIS_AVAILABLE -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import ( + _get_client, + is_resp2_connection, + skip_if_redis_enterprise, + skip_if_server_version_lt, +) -def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False): +def wait_for_message( + pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None, func=None +): now = time.time() timeout = now + timeout while now < timeout: - message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages - ) + if node: + message = pubsub.get_sharded_message( + ignore_subscribe_messages=ignore_subscribe_messages, target_node=node + ) + elif func: + message = func(ignore_subscribe_messages=ignore_subscribe_messages) + else: + message = pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) if message is not None: return message time.sleep(0.01) @@ -47,6 +62,15 @@ def make_subscribe_test_data(pubsub, type): "unsub_func": pubsub.unsubscribe, "keys": ["foo", "bar", "uni" + chr(4456) + "code"], } + elif type == "shard_channel": + return { + "p": pubsub, + "sub_type": "ssubscribe", + "unsub_type": "sunsubscribe", + "sub_func": pubsub.ssubscribe, + "unsub_func": pubsub.sunsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } elif type == "pattern": return { "p": pubsub, @@ -87,6 +111,44 @@ def test_pattern_subscribe_unsubscribe(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribe_unsubscribe(**kwargs) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe_cluster(self, r): + node_channels = defaultdict(int) + p = r.pubsub() + keys = { + "foo": r.get_node_from_key("foo"), + "bar": r.get_node_from_key("bar"), + "uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"), + } + + for key, node in keys.items(): + assert p.ssubscribe(key) is None + + # should be a message for each shard_channel we just subscribed to + for key, node in keys.items(): + node_channels[node.name] += 1 + assert wait_for_message(p, node=node) == make_message( + "ssubscribe", key, node_channels[node.name] + ) + + for key in keys.keys(): + assert p.sunsubscribe(key) is None + + # should be a message for each shard_channel we just unsubscribed + # from + for key, node in keys.items(): + node_channels[node.name] -= 1 + assert wait_for_message(p, node=node) == make_message( + "sunsubscribe", key, node_channels[node.name] + ) + def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -130,6 +192,12 @@ def test_resubscribe_to_patterns_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_resubscribe_on_reconnection(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_resubscribe_to_shard_channels_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_resubscribe_on_reconnection(**kwargs) + def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -186,38 +254,111 @@ def test_subscribe_property_with_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribed_property(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribed_property(**kwargs) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels_cluster(self, r): + p = r.pubsub() + keys = ["foo", "bar", "uni" + chr(4456) + "code"] + nodes = [r.get_node_from_key(key) for key in keys] + assert p.subscribed is False + p.ssubscribe(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all shard_channels + p.sunsubscribe() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + p.ssubscribe(keys[0]) + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + + # unsubscribe again + p.sunsubscribe() + assert p.subscribed is True + # subscribe to another shard_channel before reading the unsubscribe response + p.ssubscribe(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert wait_for_message(p, node=nodes[1]) == make_message( + "ssubscribe", keys[1], 1 + ) + p.sunsubscribe() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert wait_for_message(p, node=nodes[1]) == make_message( + "sunsubscribe", keys[1], 0 + ) + # now we're finally unsubscribed + assert p.subscribed is False + + @skip_if_server_version_lt("7.0.0") def test_ignore_all_subscribe_messages(self, r): p = r.pubsub(ignore_subscribe_messages=True) checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - assert wait_for_message(p) is None + assert wait_for_message(p, func=get_func) is None assert p.subscribed is False + @skip_if_server_version_lt("7.0.0") def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - message = wait_for_message(p, ignore_subscribe_messages=True) + message = wait_for_message(p, ignore_subscribe_messages=True, func=get_func) assert message is None assert p.subscribed is False @@ -230,6 +371,12 @@ def test_sub_unsub_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_resub(**kwargs) + def _test_sub_unsub_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -244,6 +391,26 @@ def _test_sub_unsub_resub( assert wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe(key) + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + def test_sub_unsub_all_resub_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_sub_unsub_all_resub(**kwargs) @@ -252,6 +419,12 @@ def test_sub_unsub_all_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_all_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_all_resub(**kwargs) + def _test_sub_unsub_all_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -266,6 +439,26 @@ def _test_sub_unsub_all_resub( assert wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe() + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + class TestPubSubMessages: def setup_method(self, method): @@ -284,6 +477,32 @@ def test_published_message_to_channel(self, r): assert isinstance(message, dict) assert message == make_message("message", "foo", "test message") + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p) == make_message("ssubscribe", "foo", 1) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel_cluster(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", "foo", 1 + ) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p, func=p.get_sharded_message) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + def test_published_message_to_pattern(self, r): p = r.pubsub() p.subscribe("foo") @@ -315,6 +534,15 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.ssubscribe(foo=self.message_handler) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert r.spublish("foo", "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == make_message("smessage", "foo", "test message") + @pytest.mark.onlynoncluster def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -336,6 +564,17 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") + @skip_if_server_version_lt("7.0.0") + def test_unicode_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + p.ssubscribe(**channels) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert r.spublish(channel, "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == make_message("smessage", channel, "test message") + @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html # #known-limitations-with-pubsub @@ -352,6 +591,36 @@ def test_unicode_pattern_message_handler(self, r): ) +class TestPubSubRESP3Handler: + def my_handler(self, message): + self.message = ["my handler", message] + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + def test_push_handler(self, r): + if is_resp2_connection(r): + return + p = r.pubsub(push_handler_func=self.my_handler) + p.subscribe("foo") + assert wait_for_message(p) is None + assert self.message == ["my handler", [b"subscribe", b"foo", 1]] + assert r.publish("foo", "test message") == 1 + assert wait_for_message(p) is None + assert self.message == ["my handler", [b"message", b"foo", b"test message"]] + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + @skip_if_server_version_lt("7.0.0") + def test_push_handler_sharded_pubsub(self, r): + if is_resp2_connection(r): + return + p = r.pubsub(push_handler_func=self.my_handler) + p.ssubscribe("foo") + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == ["my handler", [b"ssubscribe", b"foo", 1]] + assert r.spublish("foo", "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == ["my handler", [b"smessage", b"foo", b"test message"]] + + class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back" @@ -388,6 +657,19 @@ def test_pattern_subscribe_unsubscribe(self, r): p.punsubscribe(self.pattern) assert wait_for_message(p) == self.make_message("punsubscribe", self.pattern, 0) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe(self, r): + p = r.pubsub() + p.ssubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "ssubscribe", self.channel, 1 + ) + + p.sunsubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "sunsubscribe", self.channel, 0 + ) + def test_channel_publish(self, r): p = r.pubsub() p.subscribe(self.channel) @@ -407,6 +689,18 @@ def test_pattern_publish(self, r): "pmessage", self.channel, self.data, pattern=self.pattern ) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_publish(self, r): + p = r.pubsub() + p.ssubscribe(self.channel) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "ssubscribe", self.channel, 1 + ) + r.spublish(self.channel, self.data) + assert wait_for_message(p, func=p.get_sharded_message) == self.make_message( + "smessage", self.channel, self.data + ) + def test_channel_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.subscribe(**{self.channel: self.message_handler}) @@ -445,6 +739,30 @@ def test_pattern_message_handler(self, r): "pmessage", self.channel, new_data, pattern=self.pattern ) + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.ssubscribe(**{self.channel: self.message_handler}) + assert wait_for_message(p, func=p.get_sharded_message) is None + r.spublish(self.channel, self.data) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == self.make_message("smessage", self.channel, self.data) + + # test that we reconnected to the correct channel + self.message = None + try: + # cluster mode + p.disconnect() + except AttributeError: + # standalone mode + p.connection.disconnect() + # should reconnect + assert wait_for_message(p, func=p.get_sharded_message) is None + new_data = self.data + "new data" + r.spublish(self.channel, new_data) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == self.make_message("smessage", self.channel, new_data) + def test_context_manager(self, r): with r.pubsub() as pubsub: pubsub.subscribe("foo") @@ -474,6 +792,38 @@ def test_pubsub_channels(self, r): expected = [b"bar", b"baz", b"foo", b"quux"] assert all([channel in r.pubsub_channels() for channel in expected]) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardchannels(self, r): + p = r.pubsub() + p.ssubscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert wait_for_message(p)["type"] == "ssubscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all([channel in r.pubsub_shardchannels() for channel in expected]) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardchannels_cluster(self, r): + channels = { + b"foo": r.get_node_from_key("foo"), + b"bar": r.get_node_from_key("bar"), + b"baz": r.get_node_from_key("baz"), + b"quux": r.get_node_from_key("quux"), + } + p = r.pubsub() + p.ssubscribe("foo", "bar", "baz", "quux") + for node in channels.values(): + assert wait_for_message(p, node=node)["type"] == "ssubscribe" + for channel, node in channels.items(): + assert channel in r.pubsub_shardchannels(target_nodes=node) + assert all( + [ + channel in r.pubsub_shardchannels(target_nodes="all") + for channel in channels.keys() + ] + ) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.0") def test_pubsub_numsub(self, r): @@ -500,6 +850,32 @@ def test_pubsub_numpat(self, r): assert wait_for_message(p)["type"] == "psubscribe" assert r.pubsub_numpat() == 3 + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_pubsub_shardnumsub(self, r): + channels = { + b"foo": r.get_node_from_key("foo"), + b"bar": r.get_node_from_key("bar"), + b"baz": r.get_node_from_key("baz"), + } + p1 = r.pubsub() + p1.ssubscribe(*channels.keys()) + for node in channels.values(): + assert wait_for_message(p1, node=node)["type"] == "ssubscribe" + p2 = r.pubsub() + p2.ssubscribe("bar", "baz") + for i in range(2): + assert ( + wait_for_message(p2, func=p2.get_sharded_message)["type"] + == "ssubscribe" + ) + p3 = r.pubsub() + p3.ssubscribe("baz") + assert wait_for_message(p3, node=channels[b"baz"])["type"] == "ssubscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels + class TestPubSubPings: @skip_if_server_version_lt("3.0.0") @@ -767,13 +1143,15 @@ def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert is_connected() - with patch("redis.connection.PythonParser.read_response") as mock1: + with patch("redis._parsers._RESP2Parser.read_response") as mock1, patch( + "redis._parsers._HiredisParser.read_response" + ) as mock2, patch("redis._parsers._RESP3Parser.read_response") as mock3: mock1.side_effect = BaseException("boom") - with patch("redis.connection.HiredisParser.read_response") as mock2: - mock2.side_effect = BaseException("boom") + mock2.side_effect = BaseException("boom") + mock3.side_effect = BaseException("boom") - with pytest.raises(BaseException): - get_msg() + with pytest.raises(BaseException): + get_msg() # the timeout on the read should not cause disconnect assert is_connected() diff --git a/tests/test_retry.py b/tests/test_retry.py index 3cfea5c09e..e9d3015897 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,7 +1,6 @@ from unittest.mock import patch import pytest - from redis.backoff import ExponentialBackoff, NoBackoff from redis.client import Redis from redis.connection import Connection, UnixDomainSocketConnection diff --git a/tests/test_scripting.py b/tests/test_scripting.py index b6b5f9fb70..899dc69482 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,5 +1,4 @@ import pytest - import redis from redis import exceptions from redis.commands.core import Script diff --git a/tests/test_search.py b/tests/test_search.py index 7a2428151e..2e42aaba57 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -5,7 +5,6 @@ from io import TextIOWrapper import pytest - import redis import redis.commands.search import redis.commands.search.aggregation as aggregations @@ -24,7 +23,13 @@ from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from .conftest import skip_if_redis_enterprise, skip_ifmodversion_lt +from .conftest import ( + _get_client, + assert_resp_response, + is_resp2_connection, + skip_if_redis_enterprise, + skip_ifmodversion_lt, +) WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -40,12 +45,16 @@ def waitForIndex(env, idx, timeout=None): while True: res = env.execute_command("FT.INFO", idx) try: - res.index("indexing") + if int(res[res.index("indexing") + 1]) == 0: + break except ValueError: break - - if int(res[res.index("indexing") + 1]) == 0: - break + except AttributeError: + try: + if int(res["indexing"]) == 0: + break + except ValueError: + break time.sleep(delay) if timeout is not None: @@ -98,9 +107,10 @@ def createIndex(client, num_docs=100, definition=None): @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(request): + r = _get_client(redis.Redis, request, decode_responses=True) + r.flushdb() + return r @pytest.mark.redismod @@ -133,84 +143,170 @@ def test_client(client): assert num_docs == int(info["num_docs"]) res = client.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc["id"] - assert doc.play == "Henry IV" - assert doc["play"] == "Henry IV" + if is_resp2_connection(client): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc["id"] + assert doc.play == "Henry IV" + assert doc["play"] == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content()).total + vtotal = client.ft().search(Query("kings").no_content().verbatim()).total + assert total > vtotal + + # test in fields + txt_total = ( + client.ft().search(Query("henry").no_content().limit_fields("txt")).total + ) + play_total = ( + client.ft().search(Query("henry").no_content().limit_fields("play")).total + ) + both_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("play", "txt")) + .total + ) + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" assert len(doc.txt) > 0 - # test no content - res = client.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content()).total - vtotal = client.ft().search(Query("kings").no_content().verbatim()).total - assert total > vtotal - - # test in fields - txt_total = ( - client.ft().search(Query("henry").no_content().limit_fields("txt")).total - ) - play_total = ( - client.ft().search(Query("henry").no_content().limit_fields("play")).total - ) - both_total = ( - client.ft() - .search(Query("henry").no_content().limit_fields("play", "txt")) - .total - ) - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in client.ft().search(Query("henry")).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total - assert 53 == client.ft().search(Query("henry king").slop(0)).total - assert 167 == client.ft().search(Query("henry king").slop(100)).total - - # test delete document - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == client.ft().delete_document("doc-5ghs2") - - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - client.ft().delete_document("doc-5ghs2") + # test in-keys + ids = [x.id for x in client.ft().search(Query("henry")).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king")).total + assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total + assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total + assert 53 == client.ft().search(Query("henry king").slop(0)).total + assert 167 == client.ft().search(Query("henry king").slop(100)).total + + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + client.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["extra_attributes"]["play"] == "Henry IV" + assert len(doc["extra_attributes"]["txt"]) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == len(res["results"]) + for doc in res["results"]: + assert "extra_attributes" not in doc.keys() + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content())["total_results"] + vtotal = client.ft().search(Query("kings").no_content().verbatim())[ + "total_results" + ] + assert total > vtotal + + # test in fields + txt_total = client.ft().search(Query("henry").no_content().limit_fields("txt"))[ + "total_results" + ] + play_total = client.ft().search( + Query("henry").no_content().limit_fields("play") + )["total_results"] + both_total = client.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + )["total_results"] + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 + + # test in-keys + ids = [x["id"] for x in client.ft().search(Query("henry"))["results"]] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king"))["total_results"] + assert ( + 3 + == client.ft().search(Query("henry king").slop(0).in_order())[ + "total_results" + ] + ) + assert ( + 52 + == client.ft().search(Query("king henry").slop(0).in_order())[ + "total_results" + ] + ) + assert 53 == client.ft().search(Query("henry king").slop(0))["total_results"] + assert 167 == client.ft().search(Query("henry king").slop(100))["total_results"] + + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + client.ft().delete_document("doc-5ghs2") @pytest.mark.redismod @@ -223,12 +319,16 @@ def test_scores(client): q = Query("foo ~bar").with_scores() res = client.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == res["results"][1]["id"] @pytest.mark.redismod @@ -241,8 +341,12 @@ def test_stopwords(client): q1 = Query("foo bar").no_content() q2 = Query("foo bar hello world").no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total + if is_resp2_connection(client): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] @pytest.mark.redismod @@ -262,25 +366,40 @@ def test_filters(client): .no_content() ) res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id + if is_resp2_connection(client): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] # Test geo filter q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id + if is_resp2_connection(client): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] + + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res @pytest.mark.redismod @@ -295,14 +414,24 @@ def test_sort_by(client): q2 = Query("foo").sort_by("num", asc=False).no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id + if is_resp2_connection(client): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert "doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] @pytest.mark.redismod @@ -417,27 +546,50 @@ def test_no_index(client): ) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - res = client.ft().search(Query("@text:aa*")) - assert 0 == res.total + if is_resp2_connection(client): + res = client.ft().search(Query("@text:aa*")) + assert 0 == res.total + + res = client.ft().search(Query("@field:aa*")) + assert 2 == res.total - res = client.ft().search(Query("@field:aa*")) - assert 2 == res.total + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = client.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] + + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] # Ensure exception is raised for non-indexable, non-sortable fields with pytest.raises(Exception): @@ -472,21 +624,38 @@ def test_summarize(client): q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + if is_resp2_connection(client): + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query("king henry").paging(0, 1).summarize().highlight() - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) + else: + doc = sorted(client.ft().search(q)["results"])[0] + assert "Henry IV" == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] + ) + + q = Query("king henry").paging(0, 1).summarize().highlight() + + doc = sorted(client.ft().search(q)["results"])[0] + assert "Henry ... " == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] + ) @pytest.mark.redismod @@ -506,25 +675,46 @@ def test_alias(client): index1.hset("index1:lonestar", mapping={"name": "lonestar"}) index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - res = ftindex1.search("*").docs[0] - assert "index1:lonestar" == res.id + if is_resp2_connection(client): + res = ftindex1.search("*").docs[0] + assert "index1:lonestar" == res.id - # create alias and check for results - ftindex1.aliasadd("spaceballs") - alias_client = getClient(client).ft("spaceballs") - res = alias_client.search("*").docs[0] - assert "index1:lonestar" == res.id + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = getClient(client).ft("spaceballs") + res = alias_client.search("*").docs[0] + assert "index1:lonestar" == res.id - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - ftindex2.aliasadd("spaceballs") + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(client).ft("spaceballs") + + res = alias_client2.search("*").docs[0] + assert "index2:yogurt" == res.id + else: + res = ftindex1.search("*")["results"][0] + assert "index1:lonestar" == res["id"] - # update alias and ensure new results - ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(client).ft("spaceballs") + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = getClient(client).ft("spaceballs") + res = alias_client.search("*")["results"][0] + assert "index1:lonestar" == res["id"] - res = alias_client2.search("*").docs[0] - assert "index2:yogurt" == res.id + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(client).ft("spaceballs") + + res = alias_client2.search("*")["results"][0] + assert "index2:yogurt" == res["id"] ftindex2.aliasdel("spaceballs") with pytest.raises(Exception): @@ -532,9 +722,9 @@ def test_alias(client): @pytest.mark.redismod +@pytest.mark.xfail(strict=False) def test_alias_basic(client): # Creating a client with one index - getClient(client).flushdb() index1 = getClient(client).ft("testAlias") index1.create_index((TextField("txt"),)) @@ -547,18 +737,32 @@ def test_alias_basic(client): # add the actual alias and check index1.aliasadd("myalias") alias_client = getClient(client).ft("myalias") - res = sorted(alias_client.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - index2.aliasupdate("myalias") - alias_client2 = getClient(client).ft("myalias") - res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id + if is_resp2_connection(client): + res = sorted(alias_client.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = getClient(client).ft("myalias") + res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + else: + res = sorted(alias_client.search("*")["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = getClient(client).ft("myalias") + res = sorted(alias_client2.search("*")["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] # delete the alias and expect an error if we try to query again index2.aliasdel("myalias") @@ -573,8 +777,12 @@ def test_textfield_sortable_nostem(client): # Now get the index info to confirm its contents response = client.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] + if is_resp2_connection(client): + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + else: + assert "SORTABLE" in response["attributes"][0]["flags"] + assert "NOSTEM" in response["attributes"][0]["flags"] @pytest.mark.redismod @@ -595,7 +803,10 @@ def test_alter_schema_add(client): # Ensure we find the result searching on the added body field res = client.ft().search(q) - assert 1 == res.total + if is_resp2_connection(client): + assert 1 == res.total + else: + assert 1 == res["total_results"] @pytest.mark.redismod @@ -608,33 +819,64 @@ def test_spell_check(client): client.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = client.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = client.ft().spellcheck("lorm", exclude="dict") - assert res == {} + if is_resp2_connection(client): + + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = client.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" in res["results"]["impornant"][0].keys() + + res = client.ft().spellcheck("contnt") + assert "content" in res["results"]["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {"results": {"vlis": []}} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" in res["results"]["vlis"][0].keys() + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["results"]["lorm"]) == 3 + assert "lorem" in res["results"]["lorm"][0].keys() + assert "lore" in res["results"]["lorm"][1].keys() + assert "lorm" in res["results"]["lorm"][2].keys() + assert ( + res["results"]["lorm"][0]["lorem"], + res["results"]["lorm"][1]["lore"], + ) == (0.5, 0) + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {"results": {}} @pytest.mark.redismod @@ -650,7 +892,7 @@ def test_dict_operations(client): # Dump dict and inspect content res = client.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res + assert_resp_response(client, res, ["item1", "item3"], {"item1", "item3"}) # Remove rest of the items before reload client.ft().dict_del("custom_dict", *res) @@ -663,8 +905,12 @@ def test_phonetic_matcher(client): client.hset("doc2", mapping={"name": "John"}) res = client.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + else: + assert 1 == res["total_results"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] # Drop and create index with phonetic matcher client.flushdb() @@ -674,8 +920,14 @@ def test_phonetic_matcher(client): client.hset("doc2", mapping={"name": "John"}) res = client.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) + if is_resp2_connection(client): + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted(d.name for d in res.docs) + else: + assert 2 == res["total_results"] + assert ["John", "Jon"] == sorted( + d["extra_attributes"]["name"] for d in res["results"] + ) @pytest.mark.redismod @@ -694,20 +946,36 @@ def test_scorer(client): ) # default scorer is TFIDF - res = client.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.1111111111111111 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score + if is_resp2_connection(client): + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.1111111111111111 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res.docs[0].score + else: + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.1111111111111111 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res["results"][0]["score"] @pytest.mark.redismod @@ -788,101 +1056,212 @@ def test_aggregations_groupby(client): }, ) - req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) + if is_resp2_connection(client): + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinct("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinctish("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.sum("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.min("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.max("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.avg("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - index = res.index("__generated_aliasavgrandom_num") - assert res[index + 1] == "7" # (10+3+8)/3 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + index = res.index("__generated_aliasavgrandom_num") + assert res[index + 1] == "7" # (10+3+8)/3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.stddev("random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.quantile("@random_num", 0.5) - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.tolist("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.first_value("@title").alias("first") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) - res = client.ft().aggregate(req).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] + res = client.ft().aggregate(req).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.random_sample("@title", 2).alias("random") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount_distinctishtitle"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasstddevrandom_num"] + == "3.60555127546" + ) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] == "8" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"] == {"parent": "redis", "first": "RediSearch"} + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert "random" in res["extra_attributes"].keys() + assert len(res["extra_attributes"]["random"]) == 2 + assert res["extra_attributes"]["random"][0] in [ + "RediSearch", + "RedisAI", + "RedisJson", + ] @pytest.mark.redismod @@ -892,30 +1271,56 @@ def test_aggregations_sort_by_and_limit(client): client.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) client.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] + if is_resp2_connection(client): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) + assert len(res.rows) == 1 - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["t1", "b"] + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = client.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] + else: + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req)["results"] + assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"} + assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"} + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req)["results"] + assert res[0]["extra_attributes"] == {"t1": "a"} + assert res[1]["extra_attributes"] == {"t1": "b"} + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) + assert len(res["results"]) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = client.ft().aggregate(req) + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"t1": "b"} @pytest.mark.redismod @@ -924,20 +1329,36 @@ def test_aggregations_load(client): client.ft().client.hset("doc1", mapping={"t1": "hello", "t2": "world"}) - # load t1 - req = aggregations.AggregateRequest("*").load("t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello"] + if is_resp2_connection(client): + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello"] - # load t2 - req = aggregations.AggregateRequest("*").load("t2") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "world"] + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "world"] - # load all - req = aggregations.AggregateRequest("*").load() - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello", "t2", "world"] + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello", "t2", "world"] + else: + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t1": "hello"} + + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t2": "world"} + + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t1": "hello", "t2": "world"} @pytest.mark.redismod @@ -962,8 +1383,17 @@ def test_aggregations_apply(client): CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" ) res = client.ft().aggregate(req) - res_set = set([res.rows[0][1], res.rows[1][1]]) - assert res_set == set(["6373878785249699840", "6373878758592700416"]) + if is_resp2_connection(client): + res_set = set([res.rows[0][1], res.rows[1][1]]) + assert res_set == set(["6373878785249699840", "6373878758592700416"]) + else: + res_set = set( + [ + res["results"][0]["extra_attributes"]["CreatedDateTimeUTC"], + res["results"][1]["extra_attributes"]["CreatedDateTimeUTC"], + ], + ) + assert res_set == set(["6373878785249699840", "6373878758592700416"]) @pytest.mark.redismod @@ -982,19 +1412,34 @@ def test_aggregations_filter(client): .dialect(dialect) ) res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["name", "foo", "age", "19"] - - req = ( - aggregations.AggregateRequest("*") - .filter("@age > 15") - .sort_by("@age") - .dialect(dialect) - ) - res = client.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["age", "19"] - assert res.rows[1] == ["age", "25"] + if is_resp2_connection(client): + assert len(res.rows) == 1 + assert res.rows[0] == ["name", "foo", "age", "19"] + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res.rows) == 2 + assert res.rows[0] == ["age", "19"] + assert res.rows[1] == ["age", "25"] + else: + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"name": "foo", "age": "19"} + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"age": "19"} + assert res["results"][1]["extra_attributes"] == {"age": "25"} @pytest.mark.redismod @@ -1060,7 +1505,11 @@ def test_skip_initial_scan(client): q = Query("@foo:bar") client.ft().create_index((TextField("foo"),), skip_initial_scan=True) - assert 0 == client.ft().search(q).total + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 0 + else: + assert res["total_results"] == 0 @pytest.mark.redismod @@ -1148,10 +1597,15 @@ def test_create_client_definition_json(client): client.json().set("king:2", Path.root_path(), {"name": "james"}) res = client.ft().search("henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].payload is None - assert res.docs[0].json == '{"name":"henry"}' - assert res.total == 1 + if is_resp2_connection(client): + assert res.docs[0].id == "king:1" + assert res.docs[0].payload is None + assert res.docs[0].json == '{"name":"henry"}' + assert res.total == 1 + else: + assert res["results"][0]["id"] == "king:1" + assert res["results"][0]["extra_attributes"]["$"] == '{"name":"henry"}' + assert res["total_results"] == 1 @pytest.mark.redismod @@ -1169,11 +1623,17 @@ def test_fields_as_name(client): res = client.json().set("doc:1", Path.root_path(), {"name": "Jon", "age": 25}) assert res - total = client.ft().search(Query("Jon").return_fields("name", "just_a_number")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "Jon" == total[0].name - assert "25" == total[0].just_a_number + res = client.ft().search(Query("Jon").return_fields("name", "just_a_number")) + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "doc:1" == res.docs[0].id + assert "Jon" == res.docs[0].name + assert "25" == res.docs[0].just_a_number + else: + assert 1 == len(res["results"]) + assert "doc:1" == res["results"][0]["id"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] + assert "25" == res["results"][0]["extra_attributes"]["just_a_number"] @pytest.mark.redismod @@ -1184,11 +1644,16 @@ def test_casesensitive(client): client.ft().client.hset("1", "t", "HELLO") client.ft().client.hset("2", "t", "hello") - res = client.ft().search("@t:{HELLO}").docs + res = client.ft().search("@t:{HELLO}") - assert 2 == len(res) - assert "1" == res[0].id - assert "2" == res[1].id + if is_resp2_connection(client): + assert 2 == len(res.docs) + assert "1" == res.docs[0].id + assert "2" == res.docs[1].id + else: + assert 2 == len(res["results"]) + assert "1" == res["results"][0]["id"] + assert "2" == res["results"][1]["id"] # create casesensitive index client.ft().dropindex() @@ -1196,9 +1661,13 @@ def test_casesensitive(client): client.ft().create_index(SCHEMA) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - res = client.ft().search("@t:{HELLO}").docs - assert 1 == len(res) - assert "1" == res[0].id + res = client.ft().search("@t:{HELLO}") + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "1" == res.docs[0].id + else: + assert 1 == len(res["results"]) + assert "1" == res["results"][0]["id"] @pytest.mark.redismod @@ -1217,15 +1686,26 @@ def test_search_return_fields(client): client.ft().create_index(SCHEMA, definition=definition) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "riceratops" == total[0].txt + if is_resp2_connection(client): + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "riceratops" == total[0].txt - total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "telmatosaurus" == total[0].txt + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "telmatosaurus" == total[0].txt + else: + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "riceratops" == total["results"][0]["extra_attributes"]["txt"] + + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"] @pytest.mark.redismod @@ -1242,9 +1722,14 @@ def test_synupdate(client): client.hset("doc2", mapping={"title": "he is another baby", "body": "another test"}) res = client.ft().search(Query("child").expander("SYNONYM")) - assert res.docs[0].id == "doc2" - assert res.docs[0].title == "he is another baby" - assert res.docs[0].body == "another test" + if is_resp2_connection(client): + assert res.docs[0].id == "doc2" + assert res.docs[0].title == "he is another baby" + assert res.docs[0].body == "another test" + else: + assert res["results"][0]["id"] == "doc2" + assert res["results"][0]["extra_attributes"]["title"] == "he is another baby" + assert res["results"][0]["extra_attributes"]["body"] == "another test" @pytest.mark.redismod @@ -1284,15 +1769,28 @@ def test_create_json_with_alias(client): client.json().set("king:1", Path.root_path(), {"name": "henry", "num": 42}) client.json().set("king:2", Path.root_path(), {"name": "james", "num": 3.14}) - res = client.ft().search("@name:henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","num":42}' - assert res.total == 1 - - res = client.ft().search("@num:[0 10]") - assert res.docs[0].id == "king:2" - assert res.docs[0].json == '{"name":"james","num":3.14}' - assert res.total == 1 + if is_resp2_connection(client): + res = client.ft().search("@name:henry") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","num":42}' + assert res.total == 1 + + res = client.ft().search("@num:[0 10]") + assert res.docs[0].id == "king:2" + assert res.docs[0].json == '{"name":"james","num":3.14}' + assert res.total == 1 + else: + res = client.ft().search("@name:henry") + assert res["results"][0]["id"] == "king:1" + assert res["results"][0]["extra_attributes"]["$"] == '{"name":"henry","num":42}' + assert res["total_results"] == 1 + + res = client.ft().search("@num:[0 10]") + assert res["results"][0]["id"] == "king:2" + assert ( + res["results"][0]["extra_attributes"]["$"] == '{"name":"james","num":3.14}' + ) + assert res["total_results"] == 1 # Tests returns an error if path contain special characters (user should # use an alias) @@ -1316,15 +1814,32 @@ def test_json_with_multipath(client): "king:1", Path.root_path(), {"name": "henry", "country": {"name": "england"}} ) - res = client.ft().search("@name:{henry}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 + if is_resp2_connection(client): + res = client.ft().search("@name:{henry}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + + res = client.ft().search("@name:{england}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + else: + res = client.ft().search("@name:{henry}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 - res = client.ft().search("@name:{england}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 + res = client.ft().search("@name:{england}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 @pytest.mark.redismod @@ -1341,21 +1856,40 @@ def test_json_with_jsonpath(client): client.json().set("doc:1", Path.root_path(), {"prod:name": "RediSearch"}) - # query for a supported field succeeds - res = client.ft().search(Query("@name:RediSearch")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].json == '{"prod:name":"RediSearch"}' + if is_resp2_connection(client): + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].json == '{"prod:name":"RediSearch"}' + + # query for an unsupported field + res = client.ft().search("@name_unsupported:RediSearch") + assert res.total == 1 + + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].name == "RediSearch" + else: + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert ( + res["results"][0]["extra_attributes"]["$"] == '{"prod:name":"RediSearch"}' + ) - # query for an unsupported field - res = client.ft().search("@name_unsupported:RediSearch") - assert res.total == 1 + # query for an unsupported field + res = client.ft().search("@name_unsupported:RediSearch") + assert res["total_results"] == 1 - # return of a supported field succeeds - res = client.ft().search(Query("@name:RediSearch").return_field("name")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].name == "RediSearch" + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert res["results"][0]["extra_attributes"]["name"] == "RediSearch" @pytest.mark.redismod @@ -1368,24 +1902,43 @@ def test_profile(client): # check using Query q = Query("hello|world").no_content() - res, det = client.ft().profile(q) - assert det["Iterators profile"]["Counter"] == 2.0 - assert len(det["Iterators profile"]["Child iterators"]) == 2 - assert det["Iterators profile"]["Type"] == "UNION" - assert det["Parsing time"] < 0.5 - assert len(res.docs) == 2 # check also the search result - - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") - ) - res, det = client.ft().profile(req) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "WILDCARD" - assert isinstance(det["Parsing time"], float) - assert len(res.rows) == 2 # check also the search result + if is_resp2_connection(client): + res, det = client.ft().profile(q) + assert det["Iterators profile"]["Counter"] == 2.0 + assert len(det["Iterators profile"]["Child iterators"]) == 2 + assert det["Iterators profile"]["Type"] == "UNION" + assert det["Parsing time"] < 0.5 + assert len(res.docs) == 2 # check also the search result + + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res, det = client.ft().profile(req) + assert det["Iterators profile"]["Counter"] == 2 + assert det["Iterators profile"]["Type"] == "WILDCARD" + assert isinstance(det["Parsing time"], float) + assert len(res.rows) == 2 # check also the search result + else: + res = client.ft().profile(q) + assert res["profile"]["Iterators profile"][0]["Counter"] == 2.0 + assert res["profile"]["Iterators profile"][0]["Type"] == "UNION" + assert res["profile"]["Parsing time"] < 0.5 + assert len(res["results"]) == 2 # check also the search result + + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res = client.ft().profile(req) + assert res["profile"]["Iterators profile"][0]["Counter"] == 2 + assert res["profile"]["Iterators profile"][0]["Type"] == "WILDCARD" + assert isinstance(res["profile"]["Parsing time"], float) + assert len(res["results"]) == 2 # check also the search result @pytest.mark.redismod @@ -1398,134 +1951,174 @@ def test_profile_limited(client): client.ft().client.hset("4", "t", "helowa") q = Query("%hell% hel*") - res, det = client.ft().profile(q, limited=True) - assert ( - det["Iterators profile"]["Child iterators"][0]["Child iterators"] - == "The number of iterators in the union is 3" - ) - assert ( - det["Iterators profile"]["Child iterators"][1]["Child iterators"] - == "The number of iterators in the union is 4" - ) - assert det["Iterators profile"]["Type"] == "INTERSECT" - assert len(res.docs) == 3 # check also the search result + if is_resp2_connection(client): + res, det = client.ft().profile(q, limited=True) + assert ( + det["Iterators profile"]["Child iterators"][0]["Child iterators"] + == "The number of iterators in the union is 3" + ) + assert ( + det["Iterators profile"]["Child iterators"][1]["Child iterators"] + == "The number of iterators in the union is 4" + ) + assert det["Iterators profile"]["Type"] == "INTERSECT" + assert len(res.docs) == 3 # check also the search result + else: + res = client.ft().profile(q, limited=True) + iterators_profile = res["profile"]["Iterators profile"] + assert ( + iterators_profile[0]["Child iterators"][0]["Child iterators"] + == "The number of iterators in the union is 3" + ) + assert ( + iterators_profile[0]["Child iterators"][1]["Child iterators"] + == "The number of iterators in the union is 4" + ) + assert iterators_profile[0]["Type"] == "INTERSECT" + assert len(res["results"]) == 3 # check also the search result @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_profile_query_params(modclient: redis.Redis): - modclient.flushdb() - modclient.ft().create_index( +def test_profile_query_params(client): + client.ft().create_index( ( VectorField( "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} ), ) ) - modclient.hset("a", "v", "aaaaaaaa") - modclient.hset("b", "v", "aaaabaaa") - modclient.hset("c", "v", "aaaaabaa") + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) - res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "VECTOR" - assert res.total == 2 - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") + if is_resp2_connection(client): + res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + assert det["Iterators profile"]["Counter"] == 2.0 + assert det["Iterators profile"]["Type"] == "VECTOR" + assert res.total == 2 + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + res = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + assert res["profile"]["Iterators profile"][0]["Counter"] == 2 + assert res["profile"]["Iterators profile"][0]["Type"] == "VECTOR" + assert res["total_results"] == 2 + assert "a" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_vector_field(modclient): - modclient.flushdb() - modclient.ft().create_index( +def test_vector_field(client): + client.flushdb() + client.ft().create_index( ( VectorField( "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} ), ) ) - modclient.hset("a", "v", "aaaaaaaa") - modclient.hset("b", "v", "aaaabaaa") - modclient.hset("c", "v", "aaaaabaa") + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) - res = modclient.ft().search(q, query_params={"vec": "aaaaaaaa"}) + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") + if is_resp2_connection(client): + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert "a" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_vector_field_error(modclient): - modclient.flushdb() +def test_vector_field_error(r): + r.flushdb() # sortable tag with pytest.raises(Exception): - modclient.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),)) + r.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),)) # not supported algorithm with pytest.raises(Exception): - modclient.ft().create_index((VectorField("v", "SORT", {}),)) + r.ft().create_index((VectorField("v", "SORT", {}),)) @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_text_params(modclient): - modclient.flushdb() - modclient.ft().create_index((TextField("name"),)) +def test_text_params(client): + client.flushdb() + client.ft().create_index((TextField("name"),)) - modclient.hset("doc1", mapping={"name": "Alice"}) - modclient.hset("doc2", mapping={"name": "Bob"}) - modclient.hset("doc3", mapping={"name": "Carol"}) + client.hset("doc1", mapping={"name": "Alice"}) + client.hset("doc2", mapping={"name": "Bob"}) + client.hset("doc3", mapping={"name": "Carol"}) params_dict = {"name1": "Alice", "name2": "Bob"} q = Query("@name:($name1 | $name2 )").dialect(2) - res = modclient.ft().search(q, query_params=params_dict) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id + res = client.ft().search(q, query_params=params_dict) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_numeric_params(modclient): - modclient.flushdb() - modclient.ft().create_index((NumericField("numval"),)) +def test_numeric_params(client): + client.flushdb() + client.ft().create_index((NumericField("numval"),)) - modclient.hset("doc1", mapping={"numval": 101}) - modclient.hset("doc2", mapping={"numval": 102}) - modclient.hset("doc3", mapping={"numval": 103}) + client.hset("doc1", mapping={"numval": 101}) + client.hset("doc2", mapping={"numval": 102}) + client.hset("doc3", mapping={"numval": 103}) params_dict = {"min": 101, "max": 102} q = Query("@numval:[$min $max]").dialect(2) - res = modclient.ft().search(q, query_params=params_dict) + res = client.ft().search(q, query_params=params_dict) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id + if is_resp2_connection(client): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_geo_params(modclient): +def test_geo_params(client): - modclient.flushdb() - modclient.ft().create_index((GeoField("g"))) - modclient.hset("doc1", mapping={"g": "29.69465, 34.95126"}) - modclient.hset("doc2", mapping={"g": "29.69350, 34.94737"}) - modclient.hset("doc3", mapping={"g": "29.68746, 34.94882"}) + client.ft().create_index((GeoField("g"))) + client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) + client.hset("doc2", mapping={"g": "29.69350, 34.94737"}) + client.hset("doc3", mapping={"g": "29.68746, 34.94882"}) params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} q = Query("@g:[$lon $lat $radius $units]").dialect(2) - res = modclient.ft().search(q, query_params=params_dict) - assert 3 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "doc3" == res.docs[2].id + res = client.ft().search(q, query_params=params_dict) + if is_resp2_connection(client): + assert 3 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + assert "doc3" == res.docs[2].id + else: + assert 3 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] + assert "doc3" == res["results"][2]["id"] @pytest.mark.redismod @@ -1538,29 +2131,41 @@ def test_search_commands_in_pipeline(client): q = Query("foo bar").with_payloads() p.search(q) res = p.execute() - assert res[:3] == ["OK", True, True] - assert 2 == res[3][0] - assert "doc1" == res[3][1] - assert "doc2" == res[3][4] - assert res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] + if is_resp2_connection(client): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert ( + res[3]["results"][0]["extra_attributes"] + == res[3]["results"][1]["extra_attributes"] + == {"txt": "foo bar"} + ) @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.4.3", "search") -def test_dialect_config(modclient: redis.Redis): - assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "1"} - assert modclient.ft().config_set("DEFAULT_DIALECT", 2) - assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} +def test_dialect_config(client): + assert client.ft().config_get("DEFAULT_DIALECT") + client.ft().config_set("DEFAULT_DIALECT", 2) + assert client.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} with pytest.raises(redis.ResponseError): - modclient.ft().config_set("DEFAULT_DIALECT", 0) + client.ft().config_set("DEFAULT_DIALECT", 0) @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_dialect(modclient: redis.Redis): - modclient.ft().create_index( +def test_dialect(client): + client.ft().create_index( ( TagField("title"), TextField("t1"), @@ -1571,68 +2176,94 @@ def test_dialect(modclient: redis.Redis): ), ) ) - modclient.hset("h", "t1", "hello") + client.hset("h", "t1", "hello") with pytest.raises(redis.ResponseError) as err: - modclient.ft().explain(Query("(*)").dialect(1)) + client.ft().explain(Query("(*)").dialect(1)) assert "Syntax error" in str(err) - assert "WILDCARD" in modclient.ft().explain(Query("(*)").dialect(2)) + assert "WILDCARD" in client.ft().explain(Query("(*)").dialect(2)) with pytest.raises(redis.ResponseError) as err: - modclient.ft().explain(Query("$hello").dialect(1)) + client.ft().explain(Query("$hello").dialect(1)) assert "Syntax error" in str(err) q = Query("$hello").dialect(2) expected = "UNION {\n hello\n +hello(expanded)\n}\n" - assert expected in modclient.ft().explain(q, query_params={"hello": "hello"}) + assert expected in client.ft().explain(q, query_params={"hello": "hello"}) expected = "NUMERIC {0.000000 <= @num <= 10.000000}\n" - assert expected in modclient.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) + assert expected in client.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) with pytest.raises(redis.ResponseError) as err: - modclient.ft().explain(Query("@title:(@num:[0 10])").dialect(2)) + client.ft().explain(Query("@title:(@num:[0 10])").dialect(2)) assert "Syntax error" in str(err) @pytest.mark.redismod -def test_expire_while_search(modclient: redis.Redis): - modclient.ft().create_index((TextField("txt"),)) - modclient.hset("hset:1", "txt", "a") - modclient.hset("hset:2", "txt", "b") - modclient.hset("hset:3", "txt", "c") - assert 3 == modclient.ft().search(Query("*")).total - modclient.pexpire("hset:2", 300) - for _ in range(500): - modclient.ft().search(Query("*")).docs[1] - time.sleep(1) - assert 2 == modclient.ft().search(Query("*")).total +def test_expire_while_search(client: redis.Redis): + client.ft().create_index((TextField("txt"),)) + client.hset("hset:1", "txt", "a") + client.hset("hset:2", "txt", "b") + client.hset("hset:3", "txt", "c") + if is_resp2_connection(client): + assert 3 == client.ft().search(Query("*")).total + client.pexpire("hset:2", 300) + for _ in range(500): + client.ft().search(Query("*")).docs[1] + time.sleep(1) + assert 2 == client.ft().search(Query("*")).total + else: + assert 3 == client.ft().search(Query("*"))["total_results"] + client.pexpire("hset:2", 300) + for _ in range(500): + client.ft().search(Query("*"))["results"][1] + time.sleep(1) + assert 2 == client.ft().search(Query("*"))["total_results"] @pytest.mark.redismod @pytest.mark.experimental -def test_withsuffixtrie(modclient: redis.Redis): +def test_withsuffixtrie(client: redis.Redis): # create index - assert modclient.ft().create_index((TextField("txt"),)) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert modclient.ft().dropindex("idx") - - # create withsuffixtrie index (text fiels) - assert modclient.ft().create_index((TextField("t", withsuffixtrie=True))) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert modclient.ft().dropindex("idx") - - # create withsuffixtrie index (tag field) - assert modclient.ft().create_index((TagField("t", withsuffixtrie=True))) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert client.ft().create_index((TextField("txt"),)) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + if is_resp2_connection(client): + info = client.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (text fiels) + assert client.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert client.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + else: + info = client.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (text fiels) + assert client.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert client.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] @pytest.mark.redismod -def test_query_timeout(modclient: redis.Redis): +def test_query_timeout(r: redis.Redis): q1 = Query("foo").timeout(5000) assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): - modclient.ft().search(q2) + r.ft().search(q2) diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index e5e3d26fa7..b7bcc27de2 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -1,7 +1,6 @@ import socket import pytest - import redis.sentinel from redis import exceptions from redis.sentinel import ( diff --git a/tests/test_ssl.py b/tests/test_ssl.py index c1a981d310..465fdabb89 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -3,7 +3,6 @@ from urllib.parse import urlparse import pytest - import redis from redis.exceptions import ConnectionError, RedisError diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 6ced5359f7..80490af4ef 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -3,16 +3,15 @@ from time import sleep import pytest - import redis -from .conftest import skip_ifmodversion_lt +from .conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(decoded_r): + decoded_r.flushdb() + return decoded_r @pytest.mark.redismod @@ -22,13 +21,15 @@ def test_create(client): assert client.ts().create(3, labels={"Redis": "Labs"}) assert client.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) info = client.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] + assert_resp_response( + client, 20, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Series" == info["labels"]["Time"] # Test for a chunk size of 128 Bytes assert client.ts().create("time-serie-1", chunk_size=128) info = client.ts().info("time-serie-1") - assert 128, info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -39,19 +40,33 @@ def test_create_duplicate_policy(client): ts_name = f"time-serie-ooo-{duplicate_policy}" assert client.ts().create(ts_name, duplicate_policy=duplicate_policy) info = client.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy + assert_resp_response( + client, + duplicate_policy, + info.get("duplicate_policy"), + info.get("duplicatePolicy"), + ) @pytest.mark.redismod def test_alter(client): assert client.ts().create(1) - assert 0 == client.ts().info(1).retention_msecs + info = client.ts().info(1) + assert_resp_response( + client, 0, info.get("retention_msecs"), info.get("retentionTime") + ) assert client.ts().alter(1, retention_msecs=10) - assert {} == client.ts().info(1).labels - assert 10, client.ts().info(1).retention_msecs + assert {} == client.ts().info(1)["labels"] + info = client.ts().info(1) + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) assert client.ts().alter(1, labels={"Time": "Series"}) - assert "Series" == client.ts().info(1).labels["Time"] - assert 10 == client.ts().info(1).retention_msecs + assert "Series" == client.ts().info(1)["labels"]["Time"] + info = client.ts().info(1) + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) @pytest.mark.redismod @@ -59,10 +74,14 @@ def test_alter(client): def test_alter_diplicate_policy(client): assert client.ts().create(1) info = client.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + client, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) assert client.ts().alter(1, duplicate_policy="min") info = client.ts().info(1) - assert "min" == info.duplicate_policy + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @@ -77,13 +96,15 @@ def test_add(client): assert abs(time.time() - float(client.ts().add(5, "*", 1)) / 1000) < 1.0 info = client.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Labs" == info["labels"]["Redis"] # Test for a chunk size of 128 Bytes on TS.ADD assert client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -142,21 +163,21 @@ def test_incrby_decrby(client): assert 0 == client.ts().get(1)[1] assert client.ts().incrby(2, 1.5, timestamp=5) - assert (5, 1.5) == client.ts().get(2) + assert_resp_response(client, client.ts().get(2), (5, 1.5), [5, 1.5]) assert client.ts().incrby(2, 2.25, timestamp=7) - assert (7, 3.75) == client.ts().get(2) + assert_resp_response(client, client.ts().get(2), (7, 3.75), [7, 3.75]) assert client.ts().decrby(2, 1.5, timestamp=15) - assert (15, 2.25) == client.ts().get(2) + assert_resp_response(client, client.ts().get(2), (15, 2.25), [15, 2.25]) # Test for a chunk size of 128 Bytes on TS.INCRBY assert client.ts().incrby("time-serie-1", 10, chunk_size=128) info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) # Test for a chunk size of 128 Bytes on TS.DECRBY assert client.ts().decrby("time-serie-2", 10, chunk_size=128) info = client.ts().info("time-serie-2") - assert 128 == info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @@ -172,12 +193,15 @@ def test_create_and_delete_rule(client): client.ts().add(1, time * 2, 1.5) assert round(client.ts().get(2)[1], 5) == 1.5 info = client.ts().info(1) - assert info.rules[0][1] == 100 + if is_resp2_connection(client): + assert info.rules[0][1] == 100 + else: + assert info["rules"]["2"][0] == 100 # test rule deletion client.ts().deleterule(1, 2) info = client.ts().info(1) - assert not info.rules + assert not info["rules"] @pytest.mark.redismod @@ -192,7 +216,7 @@ def test_del_range(client): client.ts().add(1, i, i % 7) assert 22 == client.ts().delete(1, 0, 21) assert [] == client.ts().range(1, 0, 21) - assert [(22, 1.0)] == client.ts().range(1, 22, 22) + assert_resp_response(client, client.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]]) @pytest.mark.redismod @@ -227,15 +251,16 @@ def test_range_advanced(client): filter_by_max_value=2, ) ) - assert [(0, 10.0), (10, 1.0)] == client.ts().range( + res = client.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" ) - assert [(0, 5.0), (5, 6.0)] == client.ts().range( + assert_resp_response(client, res, [(0, 10.0), (10, 1.0)], [[0, 10.0], [10, 1.0]]) + res = client.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) - assert [(0, 2.55), (10, 3.0)] == client.ts().range( - 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 - ) + assert_resp_response(client, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]]) + res = client.ts().range(1, 0, 10, aggregation_type="twa", bucket_size_msec=10) + assert_resp_response(client, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]]) @pytest.mark.redismod @@ -249,14 +274,18 @@ def test_range_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - res = timeseries.range("t1", 0, 20) - assert res == [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)] - res = timeseries.range("t2", 0, 10) - assert res == [(0, 4.0)] + assert_resp_response( + client, + timeseries.range("t1", 0, 20), + [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)], + [[1, 1.0], [2, 3.0], [11, 7.0], [13, 1.0]], + ) + assert_resp_response(client, timeseries.range("t2", 0, 10), [(0, 4.0)], [[0, 4.0]]) res = timeseries.range("t2", 0, 10, latest=True) - assert res == [(0, 4.0), (10, 8.0)] - res = timeseries.range("t2", 0, 9, latest=True) - assert res == [(0, 4.0)] + assert_resp_response(client, res, [(0, 4.0), (10, 8.0)], [[0, 4.0], [10, 8.0]]) + assert_resp_response( + client, timeseries.range("t2", 0, 9, latest=True), [(0, 4.0)], [[0, 4.0]] + ) @pytest.mark.redismod @@ -269,17 +298,27 @@ def test_range_bucket_timestamp(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 - ) - assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range( - "t1", - 0, - 100, - align=0, - aggregation_type="max", - bucket_size_msec=10, - bucket_timestamp="+", + assert_resp_response( + client, + timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(10, 4.0), (50, 3.0), (70, 5.0)], + [[10, 4.0], [50, 3.0], [70, 5.0]], + ) + assert_resp_response( + client, + timeseries.range( + "t1", + 0, + 100, + align=0, + aggregation_type="max", + bucket_size_msec=10, + bucket_timestamp="+", + ), + [(20, 4.0), (60, 3.0), (80, 5.0)], + [[20, 4.0], [60, 3.0], [80, 5.0]], ) @@ -293,8 +332,13 @@ def test_range_empty(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + assert_resp_response( + client, + timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(10, 4.0), (50, 3.0), (70, 5.0)], + [[10, 4.0], [50, 3.0], [70, 5.0]], ) res = timeseries.range( "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True @@ -302,7 +346,7 @@ def test_range_empty(client: redis.Redis): for i in range(len(res)): if math.isnan(res[i][1]): res[i] = (res[i][0], None) - assert [ + resp2_expected = [ (10, 4.0), (20, None), (30, None), @@ -310,7 +354,17 @@ def test_range_empty(client: redis.Redis): (50, 3.0), (60, None), (70, 5.0), - ] == res + ] + resp3_expected = [ + [10, 4.0], + (20, None), + (30, None), + (40, None), + [50, 3.0], + (60, None), + [70, 5.0], + ] + assert_resp_response(client, res, resp2_expected, resp3_expected) @pytest.mark.redismod @@ -337,14 +391,27 @@ def test_rev_range(client): filter_by_max_value=2, ) ) - assert [(10, 1.0), (0, 10.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + assert_resp_response( + client, + client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ), + [(10, 1.0), (0, 10.0)], + [[10, 1.0], [0, 10.0]], ) - assert [(1, 10.0), (0, 1.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + assert_resp_response( + client, + client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + ), + [(1, 10.0), (0, 1.0)], + [[1, 10.0], [0, 1.0]], ) - assert [(10, 3.0), (0, 2.55)] == client.ts().revrange( - 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 + assert_resp_response( + client, + client.ts().revrange(1, 0, 10, aggregation_type="twa", bucket_size_msec=10), + [(10, 3.0), (0, 2.55)], + [[10, 3.0], [0, 2.55]], ) @@ -360,11 +427,11 @@ def test_revrange_latest(client: redis.Redis): timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) res = timeseries.revrange("t2", 0, 10) - assert res == [(0, 4.0)] + assert_resp_response(client, res, [(0, 4.0)], [[0, 4.0]]) res = timeseries.revrange("t2", 0, 10, latest=True) - assert res == [(10, 8.0), (0, 4.0)] + assert_resp_response(client, res, [(10, 8.0), (0, 4.0)], [[10, 8.0], [0, 4.0]]) res = timeseries.revrange("t2", 0, 9, latest=True) - assert res == [(0, 4.0)] + assert_resp_response(client, res, [(0, 4.0)], [[0, 4.0]]) @pytest.mark.redismod @@ -377,17 +444,27 @@ def test_revrange_bucket_timestamp(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 - ) - assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range( - "t1", - 0, - 100, - align=0, - aggregation_type="max", - bucket_size_msec=10, - bucket_timestamp="+", + assert_resp_response( + client, + timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(70, 5.0), (50, 3.0), (10, 4.0)], + [[70, 5.0], [50, 3.0], [10, 4.0]], + ) + assert_resp_response( + client, + timeseries.range( + "t1", + 0, + 100, + align=0, + aggregation_type="max", + bucket_size_msec=10, + bucket_timestamp="+", + ), + [(20, 4.0), (60, 3.0), (80, 5.0)], + [[20, 4.0], [60, 3.0], [80, 5.0]], ) @@ -401,8 +478,13 @@ def test_revrange_empty(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + assert_resp_response( + client, + timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(70, 5.0), (50, 3.0), (10, 4.0)], + [[70, 5.0], [50, 3.0], [10, 4.0]], ) res = timeseries.revrange( "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True @@ -410,7 +492,7 @@ def test_revrange_empty(client: redis.Redis): for i in range(len(res)): if math.isnan(res[i][1]): res[i] = (res[i][0], None) - assert [ + resp2_expected = [ (70, 5.0), (60, None), (50, 3.0), @@ -418,7 +500,17 @@ def test_revrange_empty(client: redis.Redis): (30, None), (20, None), (10, 4.0), - ] == res + ] + resp3_expected = [ + [70, 5.0], + (60, None), + [50, 3.0], + (40, None), + (30, None), + (20, None), + [10, 4.0], + ] + assert_resp_response(client, res, resp2_expected, resp3_expected) @pytest.mark.redismod @@ -432,23 +524,42 @@ def test_mrange(client): res = client.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 100 == len(res[0]["1"][1]) - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert 100 == len(res["1"][2]) + + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) + + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) - # test withlabels - assert {} == res[0]["1"][0] - res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + # test withlabels + assert {} == res["1"][0] + res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res["1"][0] @pytest.mark.redismod @@ -463,49 +574,106 @@ def test_multi_range_advanced(client): # test with selected labels res = client.ts().mrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(client): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test with filterby - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + # test with filterby + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - # test groupby - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + # test groupby + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - # test align - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test with filterby + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[15, 1.0], [16, 2.0]] == res["1"][2] + + # test groupby + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[0, 0.0], [1, 2.0], [2, 4.0], [3, 6.0]] == res["Test=This"][3] + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["Test=This"][3] + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=ny"][3] + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=sf"][3] + + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[0, 10.0], [10, 1.0]] == res["1"][2] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [[0, 5.0], [5, 6.0]] == res["1"][2] @pytest.mark.redismod @@ -527,10 +695,15 @@ def test_mrange_latest(client: redis.Redis): timeseries.add("t3", 2, 3) timeseries.add("t3", 11, 7) timeseries.add("t3", 13, 1) - assert client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True) == [ - {"t2": [{}, [(0, 4.0), (10, 8.0)]]}, - {"t4": [{}, [(0, 4.0), (10, 8.0)]]}, - ] + assert_resp_response( + client, + client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True), + [{"t2": [{}, [(0, 4.0), (10, 8.0)]]}, {"t4": [{}, [(0, 4.0), (10, 8.0)]]}], + { + "t2": [{}, {"aggregators": []}, [[0, 4.0], [10, 8.0]]], + "t4": [{}, {"aggregators": []}, [[0, 4.0], [10, 8.0]]], + }, + ) @pytest.mark.redismod @@ -545,10 +718,16 @@ def test_multi_reverse_range(client): res = client.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 100 == len(res[0]["1"][1]) + else: + assert 100 == len(res["1"][2]) res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 10 == len(res[0]["1"][1]) + else: + assert 10 == len(res["1"][2]) for i in range(100): client.ts().add(1, i + 200, i % 7) @@ -556,17 +735,28 @@ def test_multi_reverse_range(client): 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] + if is_resp2_connection(client): + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] + else: + assert 20 == len(res["1"][2]) + assert {} == res["1"][0] # test withlabels res = client.ts().mrevrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + if is_resp2_connection(client): + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert {"Test": "This", "team": "ny"} == res["1"][0] # test with selected labels res = client.ts().mrevrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(client): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] # test filterby res = client.ts().mrevrange( @@ -577,23 +767,36 @@ def test_multi_reverse_range(client): filter_by_min_value=1, filter_by_max_value=2, ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + else: + assert [[16, 2.0], [15, 1.0]] == res["1"][2] # test groupby res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + if is_resp2_connection(client): + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + else: + assert [[3, 6.0], [2, 4.0], [1, 2.0], [0, 0.0]] == res["Test=This"][3] res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="Test", reduce="max" ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + if is_resp2_connection(client): + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + else: + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["Test=This"][3] res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="team", reduce="min" ) assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + if is_resp2_connection(client): + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + else: + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=ny"][3] + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=sf"][3] # test align res = client.ts().mrevrange( @@ -604,7 +807,10 @@ def test_multi_reverse_range(client): bucket_size_msec=10, align="-", ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + else: + assert [[10, 1.0], [0, 10.0]] == res["1"][2] res = client.ts().mrevrange( 0, 10, @@ -613,7 +819,10 @@ def test_multi_reverse_range(client): bucket_size_msec=10, align=1, ) - assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + else: + assert [[1, 10.0], [0, 1.0]] == res["1"][2] @pytest.mark.redismod @@ -635,16 +844,22 @@ def test_mrevrange_latest(client: redis.Redis): timeseries.add("t3", 2, 3) timeseries.add("t3", 11, 7) timeseries.add("t3", 13, 1) - assert client.ts().mrevrange( - 0, 10, filters=["is_compaction=true"], latest=True - ) == [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}] + assert_resp_response( + client, + client.ts().mrevrange(0, 10, filters=["is_compaction=true"], latest=True), + [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}], + { + "t2": [{}, {"aggregators": []}, [[10, 8.0], [0, 4.0]]], + "t4": [{}, {"aggregators": []}, [[10, 8.0], [0, 4.0]]], + }, + ) @pytest.mark.redismod def test_get(client): name = "test" client.ts().create(name) - assert client.ts().get(name) is None + assert not client.ts().get(name) client.ts().add(name, 2, 3) assert 2 == client.ts().get(name)[0] client.ts().add(name, 3, 4) @@ -662,8 +877,10 @@ def test_get_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - assert (0, 4.0) == timeseries.get("t2") - assert (10, 8.0) == timeseries.get("t2", latest=True) + assert_resp_response(client, timeseries.get("t2"), (0, 4.0), [0, 4.0]) + assert_resp_response( + client, timeseries.get("t2", latest=True), (10, 8.0), [10, 8.0] + ) @pytest.mark.redismod @@ -673,19 +890,33 @@ def test_mget(client): client.ts().create(2, labels={"Test": "This", "Taste": "That"}) act_res = client.ts().mget(["Test=This"]) exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res + exp_res_resp3 = {"1": [{}, []], "2": [{}, []]} + assert_resp_response(client, act_res, exp_res, exp_res_resp3) client.ts().add(1, "*", 15) client.ts().add(2, "*", 25) res = client.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] + if is_resp2_connection(client): + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + else: + assert 15 == res["1"][1][1] + assert 25 == res["2"][1][1] res = client.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] + if is_resp2_connection(client): + assert 25 == res[0]["2"][2] + else: + assert 25 == res["2"][1][1] # test with_labels - assert {} == res[0]["2"][0] + if is_resp2_connection(client): + assert {} == res[0]["2"][0] + else: + assert {} == res["2"][0] res = client.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + if is_resp2_connection(client): + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + else: + assert {"Taste": "That", "Test": "This"} == res["2"][0] @pytest.mark.redismod @@ -700,18 +931,20 @@ def test_mget_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - assert timeseries.mget(filters=["is_compaction=true"]) == [{"t2": [{}, 0, 4.0]}] - assert [{"t2": [{}, 10, 8.0]}] == timeseries.mget( - filters=["is_compaction=true"], latest=True - ) + res = timeseries.mget(filters=["is_compaction=true"]) + assert_resp_response(client, res, [{"t2": [{}, 0, 4.0]}], {"t2": [{}, [0, 4.0]]}) + res = timeseries.mget(filters=["is_compaction=true"], latest=True) + assert_resp_response(client, res, [{"t2": [{}, 10, 8.0]}], {"t2": [{}, [10, 8.0]]}) @pytest.mark.redismod def test_info(client): client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" + assert_resp_response( + client, 5, info.get("retention_msecs"), info.get("retentionTime") + ) + assert info["labels"]["currentLabel"] == "currentData" @pytest.mark.redismod @@ -719,11 +952,15 @@ def test_info(client): def testInfoDuplicatePolicy(client): client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + client, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) client.ts().create("time-serie-2", duplicate_policy="min") info = client.ts().info("time-serie-2") - assert "min" == info.duplicate_policy + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @@ -733,7 +970,7 @@ def test_query_index(client): client.ts().create(2, labels={"Test": "This", "Taste": "That"}) assert 2 == len(client.ts().queryindex(["Test=This"])) assert 1 == len(client.ts().queryindex(["Taste=That"])) - assert [2] == client.ts().queryindex(["Taste=That"]) + assert_resp_response(client, client.ts().queryindex(["Taste=That"]), [2], {"2"}) @pytest.mark.redismod @@ -745,8 +982,12 @@ def test_pipeline(client): pipeline.execute() info = client.ts().info("with_pipeline") - assert info.last_timestamp == 99 - assert info.total_samples == 100 + assert_resp_response( + client, 99, info.get("last_timestamp"), info.get("lastTimestamp") + ) + assert_resp_response( + client, 100, info.get("total_samples"), info.get("totalSamples") + ) assert client.ts().get("with_pipeline")[1] == 99 * 1.1 @@ -756,4 +997,7 @@ def test_uncompressed(client): client.ts().create("uncompressed", uncompressed=True) compressed_info = client.ts().info("compressed") uncompressed_info = client.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage + if is_resp2_connection(client): + assert compressed_info.memory_usage != uncompressed_info.memory_usage + else: + assert compressed_info["memoryUsage"] != uncompressed_info["memoryUsage"] diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 553c77b3c6..0000000000 --- a/tox.ini +++ /dev/null @@ -1,379 +0,0 @@ -[pytest] -addopts = -s -markers = - redismod: run only the redis module tests - pipeline: pipeline tests - onlycluster: marks tests to be run only with cluster mode redis - onlynoncluster: marks tests to be run only with standalone redis - ssl: marker for only the ssl tests - asyncio: marker for async tests - replica: replica tests - experimental: run only experimental tests -asyncio_mode = auto - -[tox] -minversion = 3.2.0 -requires = tox-docker -envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py37,py38,py39,pypy3},linters,docs - -[docker:master] -name = master -image = redisfab/redis-py:6.2.6 -ports = - 6379:6379/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6379)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis6.2/master/redis.conf:/redis.conf - -[docker:replica] -name = replica -image = redisfab/redis-py:6.2.6 -links = - master:master -ports = - 6380:6380/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6380)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis6.2/replica/redis.conf:/redis.conf - -[docker:unstable] -name = unstable -image = redisfab/redis-py:unstable -ports = - 6378:6378/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6378)) else False" -volumes = - bind:rw:{toxinidir}/docker/unstable/redis.conf:/redis.conf - -[docker:unstable_cluster] -name = unstable_cluster -image = redisfab/redis-py-cluster:unstable -ports = - 6372:6372/tcp - 6373:6373/tcp - 6374:6374/tcp - 6375:6375/tcp - 6376:6376/tcp - 6377:6377/tcp -healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(6372,6377)]) else False" -volumes = - bind:rw:{toxinidir}/docker/unstable_cluster/redis.conf:/redis.conf - -[docker:sentinel_1] -name = sentinel_1 -image = redisfab/redis-py-sentinel:6.2.6 -links = - master:master -ports = - 26379:26379/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26379)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis6.2/sentinel/sentinel_1.conf:/sentinel.conf - -[docker:sentinel_2] -name = sentinel_2 -image = redisfab/redis-py-sentinel:6.2.6 -links = - master:master -ports = - 26380:26380/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26380)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis6.2/sentinel/sentinel_2.conf:/sentinel.conf - -[docker:sentinel_3] -name = sentinel_3 -image = redisfab/redis-py-sentinel:6.2.6 -links = - master:master -ports = - 26381:26381/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26381)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis6.2/sentinel/sentinel_3.conf:/sentinel.conf - -[docker:redis_stack] -name = redis_stack -image = redis/redis-stack-server:edge -ports = - 36379:6379/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',36379)) else False" - -[docker:redis_cluster] -name = redis_cluster -image = redisfab/redis-py-cluster:6.2.6 -ports = - 16379:16379/tcp - 16380:16380/tcp - 16381:16381/tcp - 16382:16382/tcp - 16383:16383/tcp - 16384:16384/tcp -healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(16379,16384)]) else False" -volumes = - bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf - -[docker:redismod_cluster] -name = redismod_cluster -image = redisfab/redis-py-modcluster:edge -ports = - 46379:46379/tcp - 46380:46380/tcp - 46381:46381/tcp - 46382:46382/tcp - 46383:46383/tcp - 46384:46384/tcp -healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(46379,46384)]) else False" -volumes = - bind:rw:{toxinidir}/docker/redismod_cluster/redis.conf:/redis.conf - -[docker:stunnel] -name = stunnel -image = redisfab/stunnel:latest -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6666)) else False" -links = - master:master -ports = - 6666:6666/tcp -volumes = - bind:ro:{toxinidir}/docker/stunnel/conf:/etc/stunnel/conf.d - bind:ro:{toxinidir}/docker/stunnel/keys:/etc/stunnel/keys - -[docker:redis5_master] -name = redis5_master -image = redisfab/redis-py:5.0-buster -ports = - 6382:6382/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6382)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/master/redis.conf:/redis.conf - -[docker:redis5_replica] -name = redis5_replica -image = redisfab/redis-py:5.0-buster -links = - redis5_master:redis5_master -ports = - 6383:6383/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6383)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/replica/redis.conf:/redis.conf - -[docker:redis5_sentinel_1] -name = redis5_sentinel_1 -image = redisfab/redis-py-sentinel:5.0-buster -links = - redis5_master:redis5_master -ports = - 26382:26382/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26382)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/sentinel/sentinel_1.conf:/sentinel.conf - -[docker:redis5_sentinel_2] -name = redis5_sentinel_2 -image = redisfab/redis-py-sentinel:5.0-buster -links = - redis5_master:redis5_master -ports = - 26383:26383/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26383)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/sentinel/sentinel_2.conf:/sentinel.conf - -[docker:redis5_sentinel_3] -name = redis5_sentinel_3 -image = redisfab/redis-py-sentinel:5.0-buster -links = - redis5_master:redis5_master -ports = - 26384:26384/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26384)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/sentinel/sentinel_3.conf:/sentinel.conf - -[docker:redis5_cluster] -name = redis5_cluster -image = redisfab/redis-py-cluster:5.0-buster -ports = - 16385:16385/tcp - 16386:16386/tcp - 16387:16387/tcp - 16388:16388/tcp - 16389:16389/tcp - 16390:16390/tcp -healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(16385,16390)]) else False" -volumes = - bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf - -[docker:redis4_master] -name = redis4_master -image = redisfab/redis-py:4.0-buster -ports = - 6381:6381/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6381)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/master/redis.conf:/redis.conf - -[docker:redis4_sentinel_1] -name = redis4_sentinel_1 -image = redisfab/redis-py-sentinel:4.0-buster -links = - redis4_master:redis4_master -ports = - 26385:26385/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26385)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/sentinel/sentinel_1.conf:/sentinel.conf - -[docker:redis4_sentinel_2] -name = redis4_sentinel_2 -image = redisfab/redis-py-sentinel:4.0-buster -links = - redis4_master:redis4_master -ports = - 26386:26386/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26386)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/sentinel/sentinel_2.conf:/sentinel.conf - -[docker:redis4_sentinel_3] -name = redis4_sentinel_3 -image = redisfab/redis-py-sentinel:4.0-buster -links = - redis4_master:redis4_master -ports = - 26387:26387/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26387)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/sentinel/sentinel_3.conf:/sentinel.conf - -[docker:redis4_cluster] -name = redis4_cluster -image = redisfab/redis-py-cluster:4.0-buster -ports = - 16391:16391/tcp - 16392:16392/tcp - 16393:16393/tcp - 16394:16394/tcp - 16395:16395/tcp - 16396:16396/tcp -healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(16391,16396)]) else False" -volumes = - bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf - -[isort] -profile = black -multi_line_output = 3 - -[testenv] -deps = - -r {toxinidir}/requirements.txt - -r {toxinidir}/dev_requirements.txt -docker = - unstable - unstable_cluster - master - replica - sentinel_1 - sentinel_2 - sentinel_3 - redis_cluster - redis_stack - stunnel -extras = - hiredis: hiredis - ocsp: cryptography, pyopenssl, requests -setenv = - CLUSTER_URL = "redis://localhost:16379/0" - UNSTABLE_CLUSTER_URL = "redis://localhost:6372/0" -commands = - standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --junit-xml=standalone-results.xml {posargs} - standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --junit-xml=standalone-uvloop-results.xml --uvloop {posargs} - cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} --junit-xml=cluster-results.xml {posargs} - cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} --junit-xml=cluster-uvloop-results.xml --uvloop {posargs} - -[testenv:redis5] -deps = - -r {toxinidir}/requirements.txt - -r {toxinidir}/dev_requirements.txt -docker = - redis5_master - redis5_replica - redis5_sentinel_1 - redis5_sentinel_2 - redis5_sentinel_3 - redis5_cluster -extras = - hiredis: hiredis - cryptography: cryptography, requests -setenv = - CLUSTER_URL = "redis://localhost:16385/0" -commands = - standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster and not redismod' {posargs} - cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs} - -[testenv:redis4] -deps = - -r {toxinidir}/requirements.txt - -r {toxinidir}/dev_requirements.txt -docker = - redis4_master - redis4_sentinel_1 - redis4_sentinel_2 - redis4_sentinel_3 - redis4_cluster -extras = - hiredis: hiredis - cryptography: cryptography, requests -setenv = - CLUSTER_URL = "redis://localhost:16391/0" -commands = - standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster and not redismod' {posargs} - cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs} - -[testenv:devenv] -skipsdist = true -skip_install = true -deps = -r {toxinidir}/dev_requirements.txt -docker = {[testenv]docker} - -[testenv:linters] -deps_files = dev_requirements.txt -docker = -commands = - flake8 - black --target-version py37 --check --diff . - isort --check-only --diff . - vulture redis whitelist.py --min-confidence 80 - flynt --fail-on-change --dry-run . -skipsdist = true -skip_install = true - -[testenv:docs] -deps = -r docs/requirements.txt -docker = -changedir = {toxinidir}/docs -allowlist_externals = make -commands = make html - -[flake8] -max-line-length = 88 -exclude = - *.egg-info, - *.pyc, - .git, - .tox, - .venv*, - build, - docs/*, - dist, - docker, - venv*, - .venv*, - whitelist.py -ignore = - F405 - W503 - E203 - E126 diff --git a/whitelist.py b/whitelist.py index 8c9cee3c29..29cd529e4d 100644 --- a/whitelist.py +++ b/whitelist.py @@ -14,6 +14,5 @@ exc_value # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) traceback # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) AsyncConnectionPool # unused import (//data/repos/redis/redis-py/redis/typing.py:9) -AsyncEncoder # unused import (//data/repos/redis/redis-py/redis/typing.py:10) AsyncRedis # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49) TargetNodesT # unused import (//data/repos/redis/redis-py/redis/commands/cluster.py:46)