Skip to content

Commit

Permalink
Utilize the allow_sql context var to record SQL queries.
Browse files Browse the repository at this point in the history
This also switches our CursorWrappers to inherit from the Django
class django.db.backends.utils.CursorWrapper, which reduces
the amount of code we need to maintain.

This also explicitly disallows specific cursor methods from
being used. This is because the psycopg3 backend's mogrify
function creates a new cursor in order to determine the last
used connection. This means that the ExceptionCursorWrapper
technically needs to access both connection and cursor. Rather
than defining an odd allow list, an explicit deny list made
more sense.

One area of concern is that this wouldn't cover the __iter__
function.
  • Loading branch information
tim-schilling authored and living180 committed May 10, 2023
1 parent ab42456 commit 5a7d015
Showing 1 changed file with 102 additions and 114 deletions.
216 changes: 102 additions & 114 deletions debug_toolbar/panels/sql/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from time import time

from django.db.backends.utils import CursorWrapper
from django.utils.encoding import force_str

from debug_toolbar import settings as dt_settings
Expand Down Expand Up @@ -43,23 +44,16 @@ def cursor(*args, **kwargs):
# See:
# https://github.com/jazzband/django-debug-toolbar/pull/615
# https://github.com/jazzband/django-debug-toolbar/pull/896
cursor = connection._djdt_cursor(*args, **kwargs)
# Do not wrap cursors that are created during post-processing in ._record()
if connection._djdt_in_record:
return cursor
if allow_sql.get():
wrapper = NormalCursorWrapper
else:
wrapper = ExceptionCursorWrapper
return wrapper(cursor, connection, panel)
return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)

def chunked_cursor(*args, **kwargs):
# prevent double wrapping
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
cursor = connection._djdt_chunked_cursor(*args, **kwargs)
# Do not wrap cursors that are created during post-processing in ._record()
if connection._djdt_in_record:
return cursor
if not isinstance(cursor, BaseCursorWrapper):
if allow_sql.get():
wrapper = NormalCursorWrapper
Expand All @@ -70,7 +64,6 @@ def chunked_cursor(*args, **kwargs):

connection.cursor = cursor
connection.chunked_cursor = chunked_cursor
connection._djdt_in_record = False


def unwrap_cursor(connection):
Expand All @@ -93,8 +86,11 @@ def unwrap_cursor(connection):
del connection._djdt_chunked_cursor


class BaseCursorWrapper:
pass
class BaseCursorWrapper(CursorWrapper):
    """Base for the toolbar's cursor wrappers.

    Extends Django's ``django.db.backends.utils.CursorWrapper`` (which
    stores ``cursor`` and ``db``) and additionally keeps a reference to
    the panel logger used to record executed queries.
    """

    def __init__(self, cursor, db, logger):
        super().__init__(cursor, db)
        # logger must implement a ``record`` method
        self.logger = logger


class ExceptionCursorWrapper(BaseCursorWrapper):
Expand All @@ -103,25 +99,28 @@ class ExceptionCursorWrapper(BaseCursorWrapper):
Used in Templates panel.
"""

def __init__(self, cursor, db, logger):
pass

def __getattr__(self, attr):
raise SQLQueryTriggered()
# This allows the cursor to access connection and close which
# are needed in psycopg to determine the last_executed_query via
# the mogrify function.
if attr in (
"callproc",
"execute",
"executemany",
"fetchone",
"fetchmany",
"fetchall",
"nextset",
):
raise SQLQueryTriggered(f"Attr: {attr} was accessed")
return super().__getattr__(attr)


class NormalCursorWrapper(BaseCursorWrapper):
"""
Wraps a cursor and logs queries.
"""

def __init__(self, cursor, db, logger):
self.cursor = cursor
# Instance of a BaseDatabaseWrapper subclass
self.db = db
# logger must implement a ``record`` method
self.logger = logger

def _quote_expr(self, element):
if isinstance(element, str):
return "'%s'" % element.replace("'", "''")
Expand Down Expand Up @@ -161,6 +160,17 @@ def _decode(self, param):
except UnicodeDecodeError:
return "(encoded string)"

def _get_last_executed_query(self, sql, params):
    """Return the last executed query, with recording disabled.

    The psycopg3 backend's ``last_executed_query`` relies on ``mogrify``,
    which opens a fresh cursor on the connection; the ``allow_sql``
    context var is temporarily set to ``False`` so that cursor is not
    hooked/recorded by the toolbar.
    """
    token = allow_sql.set(False)
    quoted_params = self._quote_params(params)
    sql_query = self.db.ops.last_executed_query(self.cursor, sql, quoted_params)
    allow_sql.reset(token)
    return sql_query

def _record(self, method, sql, params):
alias = self.db.alias
vendor = self.db.vendor
Expand All @@ -174,106 +184,84 @@ def _record(self, method, sql, params):
try:
return method(sql, params)
finally:
# In certain cases the following code can cause Django to create additional
# CursorWrapper instances (in particular, the
# self.db.ops.last_executed_query() call with psycopg3). However, we do not
# want to wrap such cursors, so set the following flag to avoid that.
self.db._djdt_in_record = True
stop_time = time()
duration = (stop_time - start_time) * 1000
_params = ""
try:
stop_time = time()
duration = (stop_time - start_time) * 1000
_params = ""
_params = json.dumps(self._decode(params))
except TypeError:
pass # object not JSON serializable
template_info = get_template_info()

# Sql might be an object (such as psycopg Composed).
# For logging purposes, make sure it's str.
if vendor == "postgresql" and not isinstance(sql, str):
sql = sql.as_string(conn)
else:
sql = str(sql)

params = {
"vendor": vendor,
"alias": alias,
"sql": self._get_last_executed_query(sql, params),
"duration": duration,
"raw_sql": sql,
"params": _params,
"raw_params": params,
"stacktrace": get_stack_trace(skip=2),
"start_time": start_time,
"stop_time": stop_time,
"is_slow": (
duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
),
"is_select": sql.lower().strip().startswith("select"),
"template_info": template_info,
}

if vendor == "postgresql":
# If an erroneous query was run on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
_params = json.dumps(self._decode(params))
except TypeError:
pass # object not JSON serializable
template_info = get_template_info()

# Sql might be an object (such as psycopg Composed).
# For logging purposes, make sure it's str.
if vendor == "postgresql" and not isinstance(sql, str):
sql = sql.as_string(conn)
else:
sql = str(sql)

params = {
"vendor": vendor,
"alias": alias,
"sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)
),
"duration": duration,
"raw_sql": sql,
"params": _params,
"raw_params": params,
"stacktrace": get_stack_trace(skip=2),
"start_time": start_time,
"stop_time": stop_time,
"is_slow": (
duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
),
"is_select": sql.lower().strip().startswith("select"),
"template_info": template_info,
}

if vendor == "postgresql":
# If an erroneous query was run on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = "unknown"
# PostgreSQL does not expose any sort of transaction ID, so it is
# necessary to generate synthetic transaction IDs here. If the
# connection was not in a transaction when the query started, and was
# after the query finished, a new transaction definitely started, so get
# a new transaction ID from logger.new_transaction_id(). If the query
# was in a transaction both before and after executing, make the
# assumption that it is the same transaction and get the current
# transaction ID from logger.current_transaction_id(). There is an edge
# case where Django can start a transaction before the first query
# executes, so in that case logger.current_transaction_id() will
# generate a new transaction ID since one does not already exist.
final_conn_status = conn.info.transaction_status
if final_conn_status == STATUS_IN_TRANSACTION:
if initial_conn_status == STATUS_IN_TRANSACTION:
trans_id = self.logger.current_transaction_id(alias)
else:
trans_id = self.logger.new_transaction_id(alias)
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = "unknown"
# PostgreSQL does not expose any sort of transaction ID, so it is
# necessary to generate synthetic transaction IDs here. If the
# connection was not in a transaction when the query started, and was
# after the query finished, a new transaction definitely started, so get
# a new transaction ID from logger.new_transaction_id(). If the query
# was in a transaction both before and after executing, make the
# assumption that it is the same transaction and get the current
# transaction ID from logger.current_transaction_id(). There is an edge
# case where Django can start a transaction before the first query
# executes, so in that case logger.current_transaction_id() will
# generate a new transaction ID since one does not already exist.
final_conn_status = conn.info.transaction_status
if final_conn_status == STATUS_IN_TRANSACTION:
if initial_conn_status == STATUS_IN_TRANSACTION:
trans_id = self.logger.current_transaction_id(alias)
else:
trans_id = None
trans_id = self.logger.new_transaction_id(alias)
else:
trans_id = None

params.update(
{
"trans_id": trans_id,
"trans_status": conn.info.transaction_status,
"iso_level": iso_level,
}
)
params.update(
{
"trans_id": trans_id,
"trans_status": conn.info.transaction_status,
"iso_level": iso_level,
}
)

# We keep `sql` to maintain backwards compatibility
self.logger.record(**params)
finally:
self.db._djdt_in_record = False
# We keep `sql` to maintain backwards compatibility
self.logger.record(**params)

def callproc(self, procname, params=None):
return self._record(self.cursor.callproc, procname, params)
return self._record(super().callproc, procname, params)

def execute(self, sql, params=None):
return self._record(self.cursor.execute, sql, params)
return self._record(super().execute, sql, params)

def executemany(self, sql, param_list):
return self._record(self.cursor.executemany, sql, param_list)

def __getattr__(self, attr):
return getattr(self.cursor, attr)

def __iter__(self):
return iter(self.cursor)

def __enter__(self):
return self

def __exit__(self, type, value, traceback):
self.close()
return self._record(super().executemany, sql, param_list)

0 comments on commit 5a7d015

Please sign in to comment.