Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some column types being parsed twice #582

Merged
merged 4 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 6 additions & 7 deletions databases/backends/common/records.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import enum
import typing
from datetime import date, datetime
from datetime import date, datetime, time

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.row import Row as SQLRow
from sqlalchemy.sql.compiler import _CompileLabel
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import JSON
from sqlalchemy.types import TypeEngine

from databases.interfaces import Record as RecordInterface
Expand Down Expand Up @@ -62,12 +63,10 @@ def __getitem__(self, key: typing.Any) -> typing.Any:
raw = self._row[idx]
processor = datatype._cached_result_processor(self._dialect, None)

if self._dialect.name not in DIALECT_EXCLUDE:
if isinstance(raw, dict):
raw = json.dumps(raw)
if self._dialect.name in DIALECT_EXCLUDE:
if processor is not None and isinstance(raw, (int, str, float)):
return processor(raw)

if processor is not None and (not isinstance(raw, (datetime, date))):
return processor(raw)
return raw

def __iter__(self) -> typing.Iterator:
Expand Down
138 changes: 137 additions & 1 deletion tests/test_databases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import datetime
import decimal
import enum
import functools
import gc
import itertools
Expand Down Expand Up @@ -55,6 +56,47 @@ def process_result_value(self, value, dialect):
sqlalchemy.Column("published", sqlalchemy.DateTime),
)

# Used to test Date
events = sqlalchemy.Table(
"events",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("date", sqlalchemy.Date),
)


# Used to test Time
daily_schedule = sqlalchemy.Table(
"daily_schedule",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("time", sqlalchemy.Time),
)


class TshirtSize(enum.Enum):
SMALL = "SMALL"
MEDIUM = "MEDIUM"
LARGE = "LARGE"
XL = "XL"


class TshirtColor(enum.Enum):
BLUE = 0
GREEN = 1
YELLOW = 2
RED = 3


# Used to test Enum
tshirt_size = sqlalchemy.Table(
"tshirt_size",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("size", sqlalchemy.Enum(TshirtSize)),
sqlalchemy.Column("color", sqlalchemy.Enum(TshirtColor)),
)

# Used to test JSON
session = sqlalchemy.Table(
"session",
Expand Down Expand Up @@ -928,6 +970,52 @@ async def test_datetime_field(database_url):
assert results[0]["published"] == now


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_date_field(database_url):
"""
Test Date columns, to ensure records are coerced to/from proper Python types.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
now = datetime.date.today()

# execute()
query = events.insert()
values = {"date": now}
await database.execute(query, values)

# fetch_all()
query = events.select()
results = await database.fetch_all(query=query)
assert len(results) == 1
assert results[0]["date"] == now


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_time_field(database_url):
"""
Test Time columns, to ensure records are coerced to/from proper Python types.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
now = datetime.datetime.now().time().replace(microsecond=0)

# execute()
query = daily_schedule.insert()
values = {"time": now}
await database.execute(query, values)

# fetch_all()
query = daily_schedule.select()
results = await database.fetch_all(query=query)
assert len(results) == 1
assert results[0]["time"] == now


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_decimal_field(database_url):
Expand Down Expand Up @@ -957,7 +1045,32 @@ async def test_decimal_field(database_url):

@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_json_field(database_url):
async def test_enum_field(database_url):
"""
Test enum columns, to ensure correct cross-database support.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
# execute()
size = TshirtSize.SMALL
color = TshirtColor.GREEN
values = {"size": size, "color": color}
query = tshirt_size.insert()
await database.execute(query, values)

# fetch_all()
query = tshirt_size.select()
results = await database.fetch_all(query=query)

assert len(results) == 1
assert results[0]["size"] == size
assert results[0]["color"] == color


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_json_dict_field(database_url):
"""
Test JSON columns, to ensure correct cross-database support.
"""
Expand All @@ -978,6 +1091,29 @@ async def test_json_field(database_url):
assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1}


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_json_list_field(database_url):
"""
Test JSON columns, to ensure correct cross-database support.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
# execute()
data = ["lemon", "raspberry", "lime", "pumice"]
values = {"data": data}
query = session.insert()
await database.execute(query, values)

# fetch_all()
query = session.select()
results = await database.fetch_all(query=query)

assert len(results) == 1
assert results[0]["data"] == ["lemon", "raspberry", "lime", "pumice"]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_custom_field(database_url):
Expand Down