Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix psycopg3 tests #1773

Merged
merged 4 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions debug_toolbar/panels/sql/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from debug_toolbar.panels import Panel
from debug_toolbar.panels.sql import views
from debug_toolbar.panels.sql.forms import SQLSelectForm
from debug_toolbar.panels.sql.tracking import unwrap_cursor, wrap_cursor
from debug_toolbar.panels.sql.tracking import wrap_cursor
from debug_toolbar.panels.sql.utils import contrasting_color_generator, reformat_sql
from debug_toolbar.utils import render_stacktrace

Expand Down Expand Up @@ -190,11 +190,12 @@ def get_urls(cls):
def enable_instrumentation(self):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, self)
wrap_cursor(connection)
connection._djdt_logger = self

def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
connection._djdt_logger = None

def generate_stats(self, request, response):
colors = contrasting_color_generator()
Expand Down
103 changes: 45 additions & 58 deletions debug_toolbar/panels/sql/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
from time import time

import django.test.testcases
from django.db.backends.utils import CursorWrapper
from django.utils.encoding import force_str

from debug_toolbar import settings as dt_settings
Expand Down Expand Up @@ -31,10 +33,15 @@ class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""


def wrap_cursor(connection, panel):
def wrap_cursor(connection):
# If running a Django SimpleTestCase, which isn't allowed to access the database,
# don't perform any monkey patching.
tim-schilling marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(connection.cursor, django.test.testcases._DatabaseFailure):
return
if not hasattr(connection, "_djdt_cursor"):
connection._djdt_cursor = connection.cursor
connection._djdt_chunked_cursor = connection.chunked_cursor
connection._djdt_logger = None

def cursor(*args, **kwargs):
# Per the DB API cursor() does not accept any arguments. There's
Expand All @@ -43,78 +50,55 @@ def cursor(*args, **kwargs):
# See:
# https://github.com/jazzband/django-debug-toolbar/pull/615
# https://github.com/jazzband/django-debug-toolbar/pull/896
logger = connection._djdt_logger
cursor = connection._djdt_cursor(*args, **kwargs)
if logger is None:
return cursor
if allow_sql.get():
wrapper = NormalCursorWrapper
else:
wrapper = ExceptionCursorWrapper
return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)
return wrapper(cursor.cursor, connection, logger)

def chunked_cursor(*args, **kwargs):
# prevent double wrapping
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
logger = connection._djdt_logger
cursor = connection._djdt_chunked_cursor(*args, **kwargs)
if not isinstance(cursor, BaseCursorWrapper):
if logger is not None and not isinstance(cursor, DjDTCursorWrapper):
if allow_sql.get():
wrapper = NormalCursorWrapper
else:
wrapper = ExceptionCursorWrapper
return wrapper(cursor, connection, panel)
return wrapper(cursor.cursor, connection, logger)
return cursor

connection.cursor = cursor
connection.chunked_cursor = chunked_cursor
return cursor


def unwrap_cursor(connection):
if hasattr(connection, "_djdt_cursor"):
# Sometimes the cursor()/chunked_cursor() methods of the DatabaseWrapper
# instance are already monkey patched before wrap_cursor() is called. (In
# particular, Django's SimpleTestCase monkey patches those methods for any
# disallowed databases to raise an exception if they are accessed.) Thus only
# delete our monkey patch if the method we saved is the same as the class
# method. Otherwise, restore the prior monkey patch from our saved method.
if connection._djdt_cursor == connection.__class__.cursor:
del connection.cursor
else:
connection.cursor = connection._djdt_cursor
del connection._djdt_cursor
if connection._djdt_chunked_cursor == connection.__class__.chunked_cursor:
del connection.chunked_cursor
else:
connection.chunked_cursor = connection._djdt_chunked_cursor
del connection._djdt_chunked_cursor


class BaseCursorWrapper:
pass
class DjDTCursorWrapper(CursorWrapper):
def __init__(self, cursor, db, logger):
super().__init__(cursor, db)
# logger must implement a ``record`` method
self.logger = logger


class ExceptionCursorWrapper(BaseCursorWrapper):
class ExceptionCursorWrapper(DjDTCursorWrapper):
"""
Wraps a cursor and raises an exception on any operation.
Used in Templates panel.
"""

def __init__(self, cursor, db, logger):
pass

def __getattr__(self, attr):
raise SQLQueryTriggered()


class NormalCursorWrapper(BaseCursorWrapper):
class NormalCursorWrapper(DjDTCursorWrapper):
"""
Wraps a cursor and logs queries.
"""

def __init__(self, cursor, db, logger):
self.cursor = cursor
# Instance of a BaseDatabaseWrapper subclass
self.db = db
# logger must implement a ``record`` method
self.logger = logger

def _quote_expr(self, element):
if isinstance(element, str):
return "'%s'" % element.replace("'", "''")
Expand Down Expand Up @@ -154,6 +138,21 @@ def _decode(self, param):
except UnicodeDecodeError:
return "(encoded string)"

def _last_executed_query(self, sql, params):
"""Get the last executed query from the connection."""
# Django's psycopg3 backend creates a new cursor in its implementation of the
# .last_executed_query() method. To avoid wrapping that cursor, temporarily set
# the DatabaseWrapper's ._djdt_logger attribute to None. This will cause the
# monkey-patched .cursor() and .chunked_cursor() methods to skip the wrapping
# process during the .last_executed_query() call.
self.db._djdt_logger = None
try:
return self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)
)
finally:
self.db._djdt_logger = self.logger

def _record(self, method, sql, params):
alias = self.db.alias
vendor = self.db.vendor
Expand Down Expand Up @@ -186,17 +185,17 @@ def _record(self, method, sql, params):
params = {
"vendor": vendor,
"alias": alias,
"sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)
),
"sql": self._last_executed_query(sql, params),
"duration": duration,
"raw_sql": sql,
"params": _params,
"raw_params": params,
"stacktrace": get_stack_trace(skip=2),
"start_time": start_time,
"stop_time": stop_time,
"is_slow": duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"],
"is_slow": (
duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
),
"is_select": sql.lower().strip().startswith("select"),
"template_info": template_info,
}
Expand Down Expand Up @@ -241,22 +240,10 @@ def _record(self, method, sql, params):
self.logger.record(**params)

def callproc(self, procname, params=None):
return self._record(self.cursor.callproc, procname, params)
return self._record(super().callproc, procname, params)

def execute(self, sql, params=None):
return self._record(self.cursor.execute, sql, params)
return self._record(super().execute, sql, params)

def executemany(self, sql, param_list):
return self._record(self.cursor.executemany, sql, param_list)

def __getattr__(self, attr):
return getattr(self.cursor, attr)

def __iter__(self):
return iter(self.cursor)

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
self.close()
return self._record(super().executemany, sql, param_list)