Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce AbstractConnection so that UnixDomainSocketConnection can call super().__init__ #2588

Merged
merged 2 commits into from Mar 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
278 changes: 120 additions & 158 deletions redis/connection.py
Expand Up @@ -6,6 +6,7 @@
import sys
import threading
import weakref
from abc import abstractmethod
from io import SEEK_END
from itertools import chain
from queue import Empty, Full, LifoQueue
Expand Down Expand Up @@ -585,20 +586,13 @@ def pack(self, *args):
return output


class Connection:
"Manages TCP communication to and from a Redis server"
class AbstractConnection:
"Manages communication to and from a Redis server"

def __init__(
self,
host="localhost",
port=6379,
db=0,
password=None,
socket_timeout=None,
socket_connect_timeout=None,
socket_keepalive=False,
socket_keepalive_options=None,
socket_type=0,
retry_on_timeout=False,
retry_on_error=SENTINEL,
encoding="utf-8",
Expand Down Expand Up @@ -629,18 +623,11 @@ def __init__(
"2. 'credential_provider'"
)
self.pid = os.getpid()
self.host = host
self.port = int(port)
self.db = db
self.client_name = client_name
self.credential_provider = credential_provider
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
self.socket_keepalive = socket_keepalive
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
Expand Down Expand Up @@ -673,11 +660,9 @@ def __repr__(self):
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
return f"{self.__class__.__name__}<{repr_args}>"

@abstractmethod
def repr_pieces(self):
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
if self.client_name:
pieces.append(("client_name", self.client_name))
return pieces
pass

def __del__(self):
try:
Expand Down Expand Up @@ -740,75 +725,17 @@ def connect(self):
if callback:
callback(self)

@abstractmethod
def _connect(self):
"Create a TCP socket connection"
# we want to mimic what socket.create_connection does to support
# ipv4/ipv6, but we want to set options prior to calling
# socket.connect()
err = None
for res in socket.getaddrinfo(
self.host, self.port, self.socket_type, socket.SOCK_STREAM
):
family, socktype, proto, canonname, socket_address = res
sock = None
try:
sock = socket.socket(family, socktype, proto)
# TCP_NODELAY
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# TCP_KEEPALIVE
if self.socket_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
for k, v in self.socket_keepalive_options.items():
sock.setsockopt(socket.IPPROTO_TCP, k, v)

# set the socket_connect_timeout before we connect
sock.settimeout(self.socket_connect_timeout)

# connect
sock.connect(socket_address)

# set the socket_timeout now that we're connected
sock.settimeout(self.socket_timeout)
return sock

except OSError as _:
err = _
if sock is not None:
sock.close()

if err is not None:
raise err
raise OSError("socket.getaddrinfo returned an empty list")
pass

@abstractmethod
def _host_error(self):
try:
host_error = f"{self.host}:{self.port}"
except AttributeError:
host_error = "connection"

return host_error
pass

@abstractmethod
def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"

host_error = self._host_error()

if len(exception.args) == 1:
try:
return f"Error connecting to {host_error}. \
{exception.args[0]}."
except AttributeError:
return f"Connection Error: {exception.args[0]}"
else:
try:
return (
f"Error {exception.args[0]} connecting to "
f"{host_error}. {exception.args[1]}."
)
except AttributeError:
return f"Connection Error: {exception.args[0]}"
pass

def on_connect(self):
"Initialize the connection, authenticate and select a database"
Expand Down Expand Up @@ -992,6 +919,101 @@ def pack_commands(self, commands):
return output


class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"

def __init__(
self,
host="localhost",
port=6379,
socket_timeout=None,
socket_connect_timeout=None,
socket_keepalive=False,
socket_keepalive_options=None,
socket_type=0,
**kwargs,
):
self.host = host
self.port = int(port)
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
self.socket_keepalive = socket_keepalive
self.socket_keepalive_options = socket_keepalive_options or {}
self.socket_type = socket_type
super().__init__(**kwargs)

def repr_pieces(self):
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
if self.client_name:
pieces.append(("client_name", self.client_name))
return pieces

def _connect(self):
"Create a TCP socket connection"
# we want to mimic what socket.create_connection does to support
# ipv4/ipv6, but we want to set options prior to calling
# socket.connect()
err = None
for res in socket.getaddrinfo(
self.host, self.port, self.socket_type, socket.SOCK_STREAM
):
family, socktype, proto, canonname, socket_address = res
sock = None
try:
sock = socket.socket(family, socktype, proto)
# TCP_NODELAY
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# TCP_KEEPALIVE
if self.socket_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
for k, v in self.socket_keepalive_options.items():
sock.setsockopt(socket.IPPROTO_TCP, k, v)

# set the socket_connect_timeout before we connect
sock.settimeout(self.socket_connect_timeout)

# connect
sock.connect(socket_address)

# set the socket_timeout now that we're connected
sock.settimeout(self.socket_timeout)
return sock

except OSError as _:
err = _
if sock is not None:
sock.close()

if err is not None:
raise err
raise OSError("socket.getaddrinfo returned an empty list")

def _host_error(self):
return f"{self.host}:{self.port}"

def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"

host_error = self._host_error()

if len(exception.args) == 1:
try:
return f"Error connecting to {host_error}. \
{exception.args[0]}."
except AttributeError:
return f"Connection Error: {exception.args[0]}"
else:
try:
return (
f"Error {exception.args[0]} connecting to "
f"{host_error}. {exception.args[1]}."
)
except AttributeError:
return f"Connection Error: {exception.args[0]}"


class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
This class extends the Connection class, adding SSL functionality, and making
Expand Down Expand Up @@ -1037,8 +1059,6 @@ def __init__(
if not ssl_available:
raise RedisError("Python wasn't built with SSL support")

super().__init__(**kwargs)

self.keyfile = ssl_keyfile
self.certfile = ssl_certfile
if ssl_cert_reqs is None:
Expand All @@ -1064,6 +1084,7 @@ def __init__(
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
self.ssl_ocsp_context = ssl_ocsp_context
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
super().__init__(**kwargs)

def _connect(self):
"Wrap the socket with SSL support"
Expand Down Expand Up @@ -1133,77 +1154,12 @@ def _connect(self):
return sslsock


class UnixDomainSocketConnection(Connection):
def __init__(
self,
path="",
db=0,
username=None,
password=None,
socket_timeout=None,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
retry_on_timeout=False,
retry_on_error=SENTINEL,
parser_class=DefaultParser,
socket_read_size=65536,
health_check_interval=0,
client_name=None,
retry=None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
command_packer=None,
):
"""
Initialize a new UnixDomainSocketConnection.
To specify a retry policy for specific errors, first set
`retry_on_error` to a list of the error/s to retry on, then set
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
if (username or password) and credential_provider is not None:
raise DataError(
"'username' and 'password' cannot be passed along with 'credential_"
"provider'. Please provide only one of the following arguments: \n"
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
self.pid = os.getpid()
class UnixDomainSocketConnection(AbstractConnection):
"Manages UDS communication to and from a Redis server"

def __init__(self, path="", **kwargs):
self.path = path
self.db = db
self.client_name = client_name
self.credential_provider = credential_provider
self.password = password
self.username = username
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
if retry_on_error is SENTINEL:
retry_on_error = []
if retry_on_timeout:
# Add TimeoutError to the errors list to retry on
retry_on_error.append(TimeoutError)
self.retry_on_error = retry_on_error
if self.retry_on_error:
if retry is None:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
# Update the retry's supported errors with the specified errors
self.retry.update_supported_errors(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
self.next_health_check = 0
self.redis_connect_func = redis_connect_func
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._socket_read_size = socket_read_size
self.set_parser(parser_class)
self._connect_callbacks = []
self._buffer_cutoff = 6000
self._command_packer = self._construct_command_packer(command_packer)
super().__init__(**kwargs)

def repr_pieces(self):
pieces = [("path", self.path), ("db", self.db)]
Expand All @@ -1218,15 +1174,21 @@ def _connect(self):
sock.connect(self.path)
return sock

def _host_error(self):
return self.path

def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return f"Error connecting to unix socket: {self.path}. {exception.args[0]}."
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{self.path}. {exception.args[1]}."
f"{host_error}. {exception.args[1]}."
)


Expand Down