Skip to content

Commit

Permalink
test: check multiple databases in the same task use independant conne…
Browse files Browse the repository at this point in the history
…ctions
  • Loading branch information
zevisert committed Apr 11, 2023
1 parent 75969d3 commit 4cd7451
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import re
from unittest.mock import MagicMock, patch

import itertools
import pytest
import sqlalchemy

Expand Down Expand Up @@ -789,15 +789,16 @@ async def test_connect_and_disconnect(database_url):

@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_connection_context(database_url):
"""
Test connection contexts are task-local.
"""
async def test_connection_context_same_task(database_url):
async with Database(database_url) as database:
async with database.connection() as connection_1:
async with database.connection() as connection_2:
assert connection_1 is connection_2


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_connection_context_multiple_tasks(database_url):
async with Database(database_url) as database:
connection_1 = None
connection_2 = None
Expand All @@ -817,9 +818,8 @@ async def get_connection_2():
connection_2 = connection
await test_complete.wait()

loop = asyncio.get_event_loop()
task_1 = loop.create_task(get_connection_1())
task_2 = loop.create_task(get_connection_2())
task_1 = asyncio.create_task(get_connection_1())
task_2 = asyncio.create_task(get_connection_2())
while connection_1 is None or connection_2 is None:
await asyncio.sleep(0.000001)
assert connection_1 is not connection_2
Expand All @@ -828,6 +828,20 @@ async def get_connection_2():
await task_2


@pytest.mark.parametrize(
"database_url1,database_url2",
(
pytest.param(db1, db2, id=f"{db1} | {db2}")
for (db1, db2) in itertools.combinations(DATABASE_URLS, 2)
),
)
@async_adapter
async def test_connection_context_multiple_databases(database_url1, database_url2):
async with Database(database_url1) as database1:
async with Database(database_url2) as database2:
assert database1.connection() is not database2.connection()


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_connection_context_with_raw_connection(database_url):
Expand Down

0 comments on commit 4cd7451

Please sign in to comment.