Skip to content

Commit

Permalink
Allow to control the minimum SSL version (#3127)
Browse files Browse the repository at this point in the history
* Allow to control the minimum SSL version

It's useful for applications that has strict security requirements.

* Add tests for minimum SSL version

The commit updates test_tcp_ssl_connect for both sync and async
connections. Now it sets the minimum SSL version. The test is ran with
both TLSv1.2 and TLSv1.3 (if supported).

A new test case is test_tcp_ssl_version_mismatch. The test added for
both sync and async connections. It uses TLS 1.3 on the client side,
and TLS 1.2 on the server side. It expects a connection error. The
test is skipped if TLS 1.3 is not supported.

* Add example of using a minimum TLS version
  • Loading branch information
poiuj authored and dvora-h committed Feb 25, 2024
1 parent a54617c commit e868a11
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Allow to control the minimum SSL version
* Add an optional lock_name attribute to LockError.
* Fix return types for `get`, `set_path` and `strappend` in JSONCommands
* Connection.register_connect_callback() is made public.
Expand Down
36 changes: 36 additions & 0 deletions docs/examples/ssl_connection_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,42 @@
"ssl_connection.ping()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Connecting to a Redis instance via SSL, while specifying a minimum TLS version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import redis\n",
"import ssl\n",
"\n",
"ssl_conn = redis.Redis(\n",
" host=\"localhost\",\n",
" port=6666,\n",
" ssl=True,\n",
" ssl_min_version=ssl.TLSVersion.TLSv1_3,\n",
")\n",
"ssl_conn.ping()"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
3 changes: 3 additions & 0 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import inspect
import re
import ssl
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -219,6 +220,7 @@ def __init__(
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
Expand Down Expand Up @@ -311,6 +313,7 @@ def __init__(
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
}
)
# This arg only used if no pool is passed in
Expand Down
3 changes: 3 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import collections
import random
import socket
import ssl
import warnings
from typing import (
Any,
Expand Down Expand Up @@ -265,6 +266,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_min_version: Optional[ssl.TLSVersion] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
Expand Down Expand Up @@ -323,6 +325,7 @@ def __init__(
"ssl_certfile": ssl_certfile,
"ssl_check_hostname": ssl_check_hostname,
"ssl_keyfile": ssl_keyfile,
"ssl_min_version": ssl_min_version,
}
)

Expand Down
11 changes: 11 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ def __init__(
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_min_version: Optional[ssl.TLSVersion] = None,
**kwargs,
):
self.ssl_context: RedisSSLContext = RedisSSLContext(
Expand All @@ -747,6 +748,7 @@ def __init__(
ca_certs=ssl_ca_certs,
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
)
super().__init__(**kwargs)

Expand Down Expand Up @@ -779,6 +781,10 @@ def ca_data(self):
def check_hostname(self):
return self.ssl_context.check_hostname

@property
def min_version(self):
return self.ssl_context.min_version


class RedisSSLContext:
__slots__ = (
Expand All @@ -789,6 +795,7 @@ class RedisSSLContext:
"ca_data",
"context",
"check_hostname",
"min_version",
)

def __init__(
Expand All @@ -799,6 +806,7 @@ def __init__(
ca_certs: Optional[str] = None,
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[ssl.TLSVersion] = None,
):
self.keyfile = keyfile
self.certfile = certfile
Expand All @@ -818,6 +826,7 @@ def __init__(
self.ca_certs = ca_certs
self.ca_data = ca_data
self.check_hostname = check_hostname
self.min_version = min_version
self.context: Optional[ssl.SSLContext] = None

def get(self) -> ssl.SSLContext:
Expand All @@ -829,6 +838,8 @@ def get(self) -> ssl.SSLContext:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs or self.ca_data:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
self.context = context
return self.context

Expand Down
2 changes: 2 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(
ssl_validate_ocsp_stapled=False,
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
max_connections=None,
single_connection_client=False,
health_check_interval=0,
Expand Down Expand Up @@ -291,6 +292,7 @@ def __init__(
"ssl_validate_ocsp": ssl_validate_ocsp,
"ssl_ocsp_context": ssl_ocsp_context,
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
"ssl_min_version": ssl_min_version,
}
)
connection_pool = ConnectionPool(**kwargs)
Expand Down
5 changes: 5 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,7 @@ def __init__(
ssl_validate_ocsp_stapled=False,
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
ssl_min_version=None,
**kwargs,
):
"""Constructor
Expand All @@ -702,6 +703,7 @@ def __init__(
ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response
ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert
ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service.
ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module.
Raises:
RedisError
Expand Down Expand Up @@ -734,6 +736,7 @@ def __init__(
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
self.ssl_min_version = ssl_min_version
super().__init__(**kwargs)

def _connect(self):
Expand All @@ -756,6 +759,8 @@ def _connect(self):
context.load_verify_locations(
cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data
)
if self.ssl_min_version is not None:
context.minimum_version = self.ssl_min_version
sslsock = context.wrap_socket(sock, server_hostname=self.host)
if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False:
raise RedisError("cryptography is not installed.")
Expand Down
54 changes: 51 additions & 3 deletions tests/test_asyncio/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SSLConnection,
UnixDomainSocketConnection,
)
from redis.exceptions import ConnectionError

from ..ssl_utils import get_ssl_filename

Expand Down Expand Up @@ -50,7 +51,17 @@ async def test_uds_connect(uds_address):


@pytest.mark.ssl
async def test_tcp_ssl_connect(tcp_address):
@pytest.mark.parametrize(
"ssl_min_version",
[
ssl.TLSVersion.TLSv1_2,
pytest.param(
ssl.TLSVersion.TLSv1_3,
marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"),
),
],
)
async def test_tcp_ssl_connect(tcp_address, ssl_min_version):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
Expand All @@ -60,12 +71,44 @@ async def test_tcp_ssl_connect(tcp_address):
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=10,
ssl_min_version=ssl_min_version,
)
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
await conn.disconnect()


async def _assert_connect(conn, server_address, certfile=None, keyfile=None):
@pytest.mark.ssl
@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3")
async def test_tcp_ssl_version_mismatch(tcp_address):
host, port = tcp_address
certfile = get_ssl_filename("server-cert.pem")
keyfile = get_ssl_filename("server-key.pem")
conn = SSLConnection(
host=host,
port=port,
client_name=_CLIENT_NAME,
ssl_ca_certs=certfile,
socket_timeout=1,
ssl_min_version=ssl.TLSVersion.TLSv1_3,
)
with pytest.raises(ConnectionError):
await _assert_connect(
conn,
tcp_address,
certfile=certfile,
keyfile=keyfile,
ssl_version=ssl.TLSVersion.TLSv1_2,
)
await conn.disconnect()


async def _assert_connect(
conn,
server_address,
certfile=None,
keyfile=None,
ssl_version=None,
):
stop_event = asyncio.Event()
finished = asyncio.Event()

Expand All @@ -82,7 +125,9 @@ async def _handler(reader, writer):
elif certfile:
host, port = server_address
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.minimum_version = ssl.TLSVersion.TLSv1_2
if ssl_version is not None:
context.minimum_version = ssl_version
context.maximum_version = ssl_version
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context)
else:
Expand All @@ -94,6 +139,9 @@ async def _handler(reader, writer):
try:
await conn.connect()
await conn.disconnect()
except ConnectionError:
finished.set()
raise
finally:
stop_event.set()
aserver.close()
Expand Down

0 comments on commit e868a11

Please sign in to comment.