Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to control the minimum SSL version #3127

Merged
merged 3 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Allow to control the minimum SSL version
* Add an optional lock_name attribute to LockError.
* Fix return types for `get`, `set_path` and `strappend` in JSONCommands
* Connection.register_connect_callback() is made public.
Expand Down
36 changes: 36 additions & 0 deletions docs/examples/ssl_connection_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,42 @@
"ssl_connection.ping()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Connecting to a Redis instance via SSL, while specifying a minimum TLS version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import redis\n",
"import ssl\n",
"\n",
"ssl_conn = redis.Redis(\n",
" host=\"localhost\",\n",
" port=6666,\n",
" ssl=True,\n",
" ssl_min_version=ssl.TLSVersion.TLSv1_3,\n",
")\n",
"ssl_conn.ping()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
3 changes: 3 additions & 0 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import inspect
import re
import ssl
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -226,6 +227,7 @@ def __init__(
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
Expand Down Expand Up @@ -332,6 +334,7 @@ def __init__(
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
}
)
# This arg only used if no pool is passed in
Expand Down
3 changes: 3 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import collections
import random
import socket
import ssl
import warnings
from typing import (
Any,
Expand Down Expand Up @@ -271,6 +272,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_min_version: Optional[ssl.TLSVersion] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
cache_enabled: bool = False,
Expand Down Expand Up @@ -344,6 +346,7 @@ def __init__(
"ssl_certfile": ssl_certfile,
"ssl_check_hostname": ssl_check_hostname,
"ssl_keyfile": ssl_keyfile,
"ssl_min_version": ssl_min_version,
}
)

Expand Down
11 changes: 11 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,7 @@ def __init__(
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
**kwargs,
):
self.ssl_context: RedisSSLContext = RedisSSLContext(
Expand All @@ -832,6 +833,7 @@ def __init__(
ca_certs=ssl_ca_certs,
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
)
super().__init__(**kwargs)

Expand Down Expand Up @@ -864,6 +866,10 @@ def ca_data(self):
def check_hostname(self):
return self.ssl_context.check_hostname

@property
def min_version(self):
return self.ssl_context.min_version


class RedisSSLContext:
__slots__ = (
Expand All @@ -874,6 +880,7 @@ class RedisSSLContext:
"ca_data",
"context",
"check_hostname",
"min_version",
)

def __init__(
Expand All @@ -884,6 +891,7 @@ def __init__(
ca_certs: Optional[str] = None,
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[ssl.TLSVersion] = None,
):
self.keyfile = keyfile
self.certfile = certfile
Expand All @@ -903,6 +911,7 @@ def __init__(
self.ca_certs = ca_certs
self.ca_data = ca_data
self.check_hostname = check_hostname
self.min_version = min_version
self.context: Optional[ssl.SSLContext] = None

def get(self) -> ssl.SSLContext:
Expand All @@ -914,6 +923,8 @@ def get(self) -> ssl.SSLContext:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs or self.ca_data:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
self.context = context
return self.context

Expand Down
2 changes: 2 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
ssl_validate_ocsp_stapled=False,
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
max_connections=None,
single_connection_client=False,
health_check_interval=0,
Expand Down Expand Up @@ -311,6 +312,7 @@ def __init__(
"ssl_validate_ocsp": ssl_validate_ocsp,
"ssl_ocsp_context": ssl_ocsp_context,
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
"ssl_min_version": ssl_min_version,
}
)
connection_pool = ConnectionPool(**kwargs)
Expand Down
5 changes: 5 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ def __init__(
ssl_validate_ocsp_stapled=False,
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
**kwargs,
):
"""Constructor
Expand All @@ -787,6 +788,7 @@ def __init__(
ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.

Raises:
RedisError
Expand Down Expand Up @@ -819,6 +821,7 @@ def __init__(
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
self.ssl_min_version = ssl_min_version
super().__init__(**kwargs)

def _connect(self):
Expand All @@ -841,6 +844,8 @@ def _connect(self):
context.load_verify_locations(
cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
)
if self.ssl_min_version is not None:
context.minimum_version = self.ssl_min_version
sslsock = context.wrap_socket(sock, server_hostname=self.host)
if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
raise RedisError("cryptography is not installed.")
Expand Down
54 changes: 51 additions & 3 deletions tests/test_asyncio/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SSLConnection,
UnixDomainSocketConnection,
)
from redis.exceptions import ConnectionError

from ..ssl_utils import get_ssl_filename

Expand Down Expand Up @@ -50,7 +51,17 @@ async def test_uds_connect(uds_address):


@pytest.mark.ssl
async def test_tcp_ssl_connect(tcp_address):
@pytest.mark.parametrize(
"ssl_min_version",
[
ssl.TLSVersion.TLSv1_2,
pytest.param(
ssl.TLSVersion.TLSv1_3,
marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"),
),
],
)
async def test_tcp_ssl_connect(tcp_address, ssl_min_version):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
Expand All @@ -60,12 +71,44 @@ async def test_tcp_ssl_connect(tcp_address):
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl_min_version,
)
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
await conn.disconnect()


async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
@pytest.mark.ssl
@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3")
async def test_tcp_ssl_version_mismatch(tcp_address):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=1,
ssl_min_version=ssl.TLSVersion.TLSv1_3,
)
with pytest.raises(ConnectionError):
await _assert_connect(
conn,
tcp_address,
certfile=certfile,
keyfile=keyfile,
ssl_version=ssl.TLSVersion.TLSv1_2,
)
await conn.disconnect()


async def _assert_connect(
conn,
server_address,
certfile=None,
keyfile=None,
ssl_version=None,
):
stop_event = asyncio.Event()
finished = asyncio.Event()

Expand All @@ -82,7 +125,9 @@ async def _handler(reader, writer):
elif certfile:
host, port = server_address
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
if ssl_version is not None:
context.minimum_version = ssl_version
context.maximum_version = ssl_version
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
else:
Expand All @@ -94,6 +139,9 @@ async def _handler(reader, writer):
try:
await conn.connect()
await conn.disconnect()
except ConnectionError:
finished.set()
raise
finally:
stop_event.set()
aserver.close()
Expand Down