Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to customize SSL ciphers #3214

Merged
merged 1 commit into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def __init__(
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
Expand Down Expand Up @@ -333,6 +334,7 @@ def __init__(
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
# This arg only used if no pool is passed in
Expand Down
2 changes: 2 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def __init__(
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
cache_enabled: bool = False,
Expand Down Expand Up @@ -347,6 +348,7 @@ def __init__(
"ssl_check_hostname": ssl_check_hostname,
"ssl_keyfile": ssl_keyfile,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)

Expand Down
7 changes: 7 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ def __init__(
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
**kwargs,
):
self.ssl_context: RedisSSLContext = RedisSSLContext(
Expand All @@ -834,6 +835,7 @@ def __init__(
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
ciphers=ssl_ciphers,
)
super().__init__(**kwargs)

Expand Down Expand Up @@ -881,6 +883,7 @@ class RedisSSLContext:
"context",
"check_hostname",
"min_version",
"ciphers",
)

def __init__(
Expand All @@ -892,6 +895,7 @@ def __init__(
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[ssl.TLSVersion] = None,
ciphers: Optional[str] = None,
):
self.keyfile = keyfile
self.certfile = certfile
Expand All @@ -912,6 +916,7 @@ def __init__(
self.ca_data = ca_data
self.check_hostname = check_hostname
self.min_version = min_version
self.ciphers = ciphers
self.context: Optional[ssl.SSLContext] = None

def get(self) -> ssl.SSLContext:
Expand All @@ -925,6 +930,8 @@ def get(self) -> ssl.SSLContext:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
self.context = context
return self.context

Expand Down
2 changes: 2 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __init__(
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
ssl_ciphers=None,
max_connections=None,
single_connection_client=False,
health_check_interval=0,
Expand Down Expand Up @@ -318,6 +319,7 @@ def __init__(
"ssl_ocsp_context": ssl_ocsp_context,
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
connection_pool = ConnectionPool(**kwargs)
Expand Down
5 changes: 5 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ def __init__(
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
ssl_ciphers=None,
**kwargs,
):
"""Constructor
Expand All @@ -783,6 +784,7 @@ def __init__(
ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information.

Raises:
RedisError
Expand Down Expand Up @@ -816,6 +818,7 @@ def __init__(
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
self.ssl_min_version = ssl_min_version
self.ssl_ciphers = ssl_ciphers
super().__init__(**kwargs)

def _connect(self):
Expand All @@ -840,6 +843,8 @@ def _connect(self):
)
if self.ssl_min_version is not None:
context.minimum_version = self.ssl_min_version
if self.ssl_ciphers:
context.set_ciphers(self.ssl_ciphers)
sslsock = context.wrap_socket(sock, server_hostname=self.host)
if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
raise RedisError("cryptography is not installed.")
Expand Down
54 changes: 54 additions & 0 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import binascii
import datetime
import ssl
import warnings
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -2961,6 +2962,59 @@ async def test_ssl_connection(
async with await create_client(ssl=True, ssl_cert_reqs="none") as rc:
assert await rc.ping()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
async def test_ssl_connection_tls12_custom_ciphers(
self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
) as rc:
assert await rc.ping()

async def test_ssl_connection_tls12_custom_ciphers_invalid(
self, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers="foo:bar",
) as rc:
with pytest.raises(RedisClusterException) as e:
assert await rc.ping()
assert "Redis Cluster cannot be connected" in str(e.value)

@pytest.mark.parametrize(
"ssl_ciphers",
[
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256",
],
)
async def test_ssl_connection_tls13_custom_ciphers(
self, ssl_ciphers, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
# TLSv1.3 does not support changing the ciphers
async with await create_client(
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
) as rc:
with pytest.raises(RedisClusterException) as e:
assert await rc.ping()
assert "Redis Cluster cannot be connected" in str(e.value)

async def test_validating_self_signed_certificate(
self, create_client: Callable[..., Awaitable[RedisCluster]]
) -> None:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_asyncio/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,32 @@ async def test_uds_connect(uds_address):
await _assert_connect(conn, path)


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
async def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
await conn.disconnect()


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_min_version",
Expand Down
25 changes: 25 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_tcp_ssl_connect(tcp_address, ssl_min_version):
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


@pytest.mark.ssl
@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES128-GCM-SHA256",
],
)
def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)


@pytest.mark.ssl
@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3")
def test_tcp_ssl_version_mismatch(tcp_address):
Expand Down
62 changes: 62 additions & 0 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,68 @@ def test_validating_self_signed_string_certificate(self, request):
assert r.ping()
r.close()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
"DHE-RSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305",
],
)
def test_ssl_connection_tls12_custom_ciphers(self, request, ssl_ciphers):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_3,
ssl_ciphers=ssl_ciphers,
)
assert r.ping()
r.close()

def test_ssl_connection_tls12_custom_ciphers_invalid(self, request):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers="foo:bar",
)
with pytest.raises(RedisError) as e:
r.ping()
assert "No cipher can be selected" in str(e)
r.close()

@pytest.mark.parametrize(
"ssl_ciphers",
[
"TLS_CHACHA20_POLY1305_SHA256",
"TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256",
],
)
def test_ssl_connection_tls13_custom_ciphers(self, request, ssl_ciphers):
# TLSv1.3 does not support changing the ciphers
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
r = redis.Redis(
host=p[0],
port=p[1],
ssl=True,
ssl_cert_reqs="none",
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
with pytest.raises(RedisError) as e:
r.ping()
assert "No cipher can be selected" in str(e)
r.close()

def _create_oscp_conn(self, request):
ssl_url = request.config.option.redis_ssl_url
p = urlparse(ssl_url)[1].split(":")
Expand Down