Skip to content

Commit

Permalink
Make it possible to customize SSL ciphers (#3212)
Browse files Browse the repository at this point in the history
Given that Python 3.10 changed the default list of TLS ciphers, it is a
good idea to allow customization of the list of ciphers when using
Redis with TLS. In some situations the client is unusable right now
with older servers and Python >= 3.10.

Also whitelist a dev dependency vulnerability, and bump version to 5.0.4.

---------

Co-authored-by: Gabriel Erzse <gabriel.erzse@redis.com>
  • Loading branch information
gerzse and Gabriel Erzse committed Apr 23, 2024
1 parent 1784b37 commit e71119d
Show file tree
Hide file tree
Showing 11 changed files with 187 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/integration.yaml
Expand Up @@ -35,6 +35,7 @@ jobs:
inputs: requirements.txt dev_requirements.txt
ignore-vulns: |
GHSA-w596-4wvx-j9j6 # subversion related git pull, dependency for pytest. There is no impact here.
PYSEC-2024-48 # black vulnerability in 22.3.0, can't upgrade due to python 3.7 support, no impact
lint:
name: Code linters
Expand Down
2 changes: 2 additions & 0 deletions redis/asyncio/client.py
Expand Up @@ -221,6 +221,7 @@ def __init__(
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
Expand Down Expand Up @@ -314,6 +315,7 @@ def __init__(
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
# This arg only used if no pool is passed in
Expand Down
2 changes: 2 additions & 0 deletions redis/asyncio/cluster.py
Expand Up @@ -267,6 +267,7 @@ def __init__(
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
Expand Down Expand Up @@ -326,6 +327,7 @@ def __init__(
"ssl_check_hostname": ssl_check_hostname,
"ssl_keyfile": ssl_keyfile,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)

Expand Down
7 changes: 7 additions & 0 deletions redis/asyncio/connection.py
Expand Up @@ -739,6 +739,7 @@ def __init__(
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
**kwargs,
):
self.ssl_context: RedisSSLContext = RedisSSLContext(
Expand All @@ -749,6 +750,7 @@ def __init__(
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
ciphers=ssl_ciphers,
)
super().__init__(**kwargs)

Expand Down Expand Up @@ -796,6 +798,7 @@ class RedisSSLContext:
"context",
"check_hostname",
"min_version",
"ciphers",
)

def __init__(
Expand All @@ -807,6 +810,7 @@ def __init__(
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[ssl.TLSVersion] = None,
ciphers: Optional[str] = None,
):
self.keyfile = keyfile
self.certfile = certfile
Expand All @@ -827,6 +831,7 @@ def __init__(
self.ca_data = ca_data
self.check_hostname = check_hostname
self.min_version = min_version
self.ciphers = ciphers
self.context: Optional[ssl.SSLContext] = None

def get(self) -> ssl.SSLContext:
Expand All @@ -840,6 +845,8 @@ def get(self) -> ssl.SSLContext:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
self.context = context
return self.context

Expand Down
2 changes: 2 additions & 0 deletions redis/client.py
Expand Up @@ -198,6 +198,7 @@ def __init__(
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
ssl_ciphers=None,
max_connections=None,
single_connection_client=False,
health_check_interval=0,
Expand Down Expand Up @@ -298,6 +299,7 @@ def __init__(
"ssl_ocsp_context": ssl_ocsp_context,
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
connection_pool = ConnectionPool(**kwargs)
Expand Down
5 changes: 5 additions & 0 deletions redis/connection.py
Expand Up @@ -685,6 +685,7 @@ def __init__(
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
ssl_ciphers=None,
**kwargs,
):
"""Constructor
Expand All @@ -704,6 +705,7 @@ def __init__(
ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.
Raises:
RedisError
Expand Down Expand Up @@ -737,6 +739,7 @@ def __init__(
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
self.ssl_min_version = ssl_min_version
self.ssl_ciphers = ssl_ciphers
super().__init__(**kwargs)

def _connect(self):
Expand All @@ -761,6 +764,8 @@ def _connect(self):
)
if self.ssl_min_version is not None:
context.minimum_version = self.ssl_min_version
if self.ssl_ciphers:
context.set_ciphers(self.ssl_ciphers)
sslsock = context.wrap_socket(sock, server_hostname=self.host)
if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
raise RedisError("cryptography is not installed.")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -8,7 +8,7 @@
long_description_content_type="text/markdown",
keywords=["Redis", "key-value store", "database"],
license="MIT",
version="5.0.3",
version="5.0.4",
packages=find_packages(
include=[
"redis",
Expand Down
54 changes: 54 additions & 0 deletions tests/test_asyncio/test_cluster.py
@@ -1,6 +1,7 @@
import asyncio
import binascii
import datetime
import ssl
import warnings
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -2951,6 +2952,59 @@ async def test_ssl_connection(
async with await create_client(ssl=True, ssl_cert_reqs="none") as rc:
assert await rc.ping()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
async def test_ssl_connection_tls12_custom_ciphers(
self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
) as rc:
assert await rc.ping()

async def test_ssl_connection_tls12_custom_ciphers_invalid(
self, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers="foo:bar",
) as rc:
with pytest.raises(RedisClusterException) as e:
assert await rc.ping()
assert "Redis Cluster cannot be connected" in str(e.value)

@pytest.mark.parametrize(
"ssl_ciphers",
[
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256",
],
)
async def test_ssl_connection_tls13_custom_ciphers(
self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
# TLSv1.3 does not support changing the ciphers
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
) as rc:
with pytest.raises(RedisClusterException) as e:
assert await rc.ping()
assert "Redis Cluster cannot be connected" in str(e.value)

async def test_validating_self_signed_certificate(
self, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_asyncio/test_connect.py
Expand Up @@ -50,6 +50,32 @@ async def test_uds_connect(uds_address):
await _assert_connect(conn, path)


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
async def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
await conn.disconnect()


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_min_version",
Expand Down
25 changes: 25 additions & 0 deletions tests/test_connect.py
Expand Up @@ -71,6 +71,31 @@ def test_tcp_ssl_connect(tcp_address, ssl_min_version):
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


@pytest.mark.ssl
@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3")
def test_tcp_ssl_version_mismatch(tcp_address):
Expand Down
62 changes: 62 additions & 0 deletions tests/test_ssl.py
Expand Up @@ -73,6 +73,68 @@ def test_validating_self_signed_string_certificate(self, request):
)
assert r.ping()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"DHE-RSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305",
],
)
def test_ssl_connection_tls12_custom_ciphers(self, request, ssl_ciphers):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_3,
ssl_ciphers=ssl_ciphers,
)
assert r.ping()
r.close()

def test_ssl_connection_tls12_custom_ciphers_invalid(self, request):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers="foo:bar",
)
with pytest.raises(RedisError) as e:
r.ping()
assert "No cipher can be selected" in str(e)
r.close()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256",
],
)
def test_ssl_connection_tls13_custom_ciphers(self, request, ssl_ciphers):
# TLSv1.3 does not support changing the ciphers
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
with pytest.raises(RedisError) as e:
r.ping()
assert "No cipher can be selected" in str(e)
r.close()

def _create_oscp_conn(self, request):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
Expand Down

0 comments on commit e71119d

Please sign in to comment.