Skip to content

Commit

Permalink
Added new queries to the Database (#356)
Browse files Browse the repository at this point in the history
This commit adds new queries to the Database to:

- Get the `Port` row of the `Token` defined in the parameter
- Get the `Step` rows which have a dependency with the `Port`
  defined in the parameter

Moreover, this commit adds two helper methods that load the `Token`
objects from the rows returned by the `get_dependees` and `get_dependers` queries
  • Loading branch information
LanderOtto committed Jan 23, 2024
1 parent 09d7058 commit ac8f41a
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ jobs:
python -m pip install -r docs/requirements.txt
- name: "Build documentation and check for consistency"
env:
CHECKSUM: "fc9bdd01ef90f0b24d019da7683aa528af10119ef54d0a13cb16ec7adaa04242"
CHECKSUM: "3cecd445a73422c87ea1a37d014b3657e15d18cdf784fb4cc8466b2d05ae8230"
run: |
cd docs/
HASH="$(make checksum | tail -n1)"
Expand Down
15 changes: 15 additions & 0 deletions docs/source/ext/database.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,31 @@ The ``Database`` interface, defined in the ``streamflow.core.persistence`` modul
) -> MutableSequence[MutableMapping[str, Any]]:
...
async def get_input_steps(
self, port_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
...
async def get_output_ports(
self, step_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
...
async def get_output_steps(
self, port_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
...
async def get_port(
self, port_id: int
) -> MutableMapping[str, Any]:
...
async def get_port_from_token(
self, token_id: int
) -> MutableMapping[str, Any]:
...
async def get_port_tokens(
self, port_id: int
) -> MutableSequence[int]:
Expand Down
16 changes: 16 additions & 0 deletions streamflow/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,32 @@ async def get_input_ports(
) -> MutableSequence[MutableMapping[str, Any]]:
...

@abstractmethod
async def get_input_steps(
    self, port_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
    """Return the dependency rows linking the given port to its producer steps.

    A step is an *input step* of a port when the port appears among the
    step's output ports, i.e. the step writes tokens into the port.

    :param port_id: persistent identifier of the ``Port``
    :return: the matching rows of the dependency relation
    """
    ...

@abstractmethod
async def get_output_ports(
    self, step_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
    """Return the dependency rows describing the output ports of a step.

    :param step_id: persistent identifier of the ``Step``
    :return: the matching rows of the dependency relation
    """
    ...

@abstractmethod
async def get_output_steps(
    self, port_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
    """Return the dependency rows linking the given port to its consumer steps.

    A step is an *output step* of a port when the port appears among the
    step's input ports, i.e. the step reads tokens from the port.

    :param port_id: persistent identifier of the ``Port``
    :return: the matching rows of the dependency relation
    """
    ...

@abstractmethod
async def get_port(self, port_id: int) -> MutableMapping[str, Any]:
    """Return the database row of the port with the given identifier.

    :param port_id: persistent identifier of the ``Port``
    :return: the ``port`` row as a mapping from column name to value
    """
    ...

@abstractmethod
async def get_port_from_token(self, token_id: int) -> MutableMapping[str, Any]:
    """Return the row of the port that contains the given token.

    :param token_id: persistent identifier of the ``Token``
    :return: the ``port`` row as a mapping from column name to value
    """
    ...

@abstractmethod
async def get_port_tokens(self, port_id: int) -> MutableSequence[int]:
    """Return the identifiers of the tokens associated with the given port.

    :param port_id: persistent identifier of the ``Port``
    :return: the persistent identifiers of the related ``Token`` objects
    """
    ...
Expand Down
30 changes: 30 additions & 0 deletions streamflow/persistence/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,16 @@ async def get_input_ports(
) as cursor:
return await cursor.fetchall()

async def get_input_steps(
    self, port_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
    """Retrieve the dependency rows of the steps that write into a port.

    A step is an *input step* of a port when the ``dependency`` table links
    them with an OUTPUT-typed relation, i.e. the port is one of the step's
    output ports.

    :param port_id: persistent identifier of the ``Port``
    :return: the matching rows of the ``dependency`` table
    """
    query = "SELECT * FROM dependency WHERE port = :port AND type = :type"
    parameters = {"port": port_id, "type": DependencyType.OUTPUT.value}
    async with self.connection as db:
        async with db.execute(query, parameters) as cursor:
            return await cursor.fetchall()

async def get_output_ports(
self, step_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
Expand All @@ -333,6 +343,16 @@ async def get_output_ports(
) as cursor:
return await cursor.fetchall()

async def get_output_steps(
    self, port_id: int
) -> MutableSequence[MutableMapping[str, Any]]:
    """Retrieve the dependency rows of the steps that read from a port.

    A step is an *output step* of a port when the ``dependency`` table links
    them with an INPUT-typed relation, i.e. the port is one of the step's
    input ports.

    :param port_id: persistent identifier of the ``Port``
    :return: the matching rows of the ``dependency`` table
    """
    query = "SELECT * FROM dependency WHERE port = :port AND type = :type"
    parameters = {"port": port_id, "type": DependencyType.INPUT.value}
    async with self.connection as db:
        async with db.execute(query, parameters) as cursor:
            return await cursor.fetchall()

@cachedmethod(lambda self: self.port_cache)
async def get_port(self, port_id: int) -> MutableMapping[str, Any]:
async with self.connection as db:
Expand All @@ -341,6 +361,16 @@ async def get_port(self, port_id: int) -> MutableMapping[str, Any]:
) as cursor:
return await cursor.fetchone()

async def get_port_from_token(self, token_id: int) -> MutableMapping[str, Any]:
    """Return the row of the port that contains the given token.

    :param token_id: persistent identifier of the ``Token``
    :return: the ``port`` row joined through the token's ``port`` column
    """
    select_port = (
        "SELECT port.* "
        "FROM token JOIN port ON token.port = port.id "
        "WHERE token.id = :token_id"
    )
    async with self.connection as db:
        async with db.execute(select_port, {"token_id": token_id}) as cursor:
            return await cursor.fetchone()

async def get_port_tokens(self, port_id: int) -> MutableSequence[int]:
async with self.connection as db:
async with db.execute(
Expand Down
32 changes: 32 additions & 0 deletions streamflow/persistence/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import asyncio

from streamflow.core.context import StreamFlowContext
from streamflow.core.persistence import DatabaseLoadingContext


async def load_depender_tokens(
    persistent_id: int,
    context: StreamFlowContext,
    loading_context: DatabaseLoadingContext,
):
    """Load the ``Token`` objects that depend on the given token.

    :param persistent_id: persistent identifier of the ``Token``
    :param context: the current ``StreamFlowContext``
    :param loading_context: the context used to materialise ``Token`` objects
    :return: the list of depender ``Token`` objects, one per dependency row
    """
    dependency_rows = await context.database.get_dependers(persistent_id)
    load_tasks = [
        asyncio.create_task(loading_context.load_token(context, row["depender"]))
        for row in dependency_rows
    ]
    return await asyncio.gather(*load_tasks)


async def load_dependee_tokens(
    persistent_id: int,
    context: StreamFlowContext,
    loading_context: DatabaseLoadingContext,
):
    """Load the ``Token`` objects that the given token depends on.

    :param persistent_id: persistent identifier of the ``Token``
    :param context: the current ``StreamFlowContext``
    :param loading_context: the context used to materialise ``Token`` objects
    :return: the list of dependee ``Token`` objects, one per dependency row
    """
    dependency_rows = await context.database.get_dependees(persistent_id)
    load_tasks = [
        asyncio.create_task(loading_context.load_token(context, row["dependee"]))
        for row in dependency_rows
    ]
    return await asyncio.gather(*load_tasks)
44 changes: 44 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from streamflow.core import utils
from streamflow.core.context import StreamFlowContext
from streamflow.workflow.step import ExecuteStep
from tests.utils.workflow import create_workflow


@pytest.mark.asyncio
async def test_get_steps_queries(context: StreamFlowContext):
    """Test get_input_steps and get_output_steps queries"""
    # Build a two-step pipeline: port_a -> step -> port_b -> step_2 -> port_c
    workflow, (port_a, job_port, job_port_2, port_b, port_c) = await create_workflow(
        context, num_port=5
    )
    step = workflow.create_step(
        cls=ExecuteStep, name=utils.random_name(), job_port=job_port
    )
    step_2 = workflow.create_step(
        cls=ExecuteStep, name=utils.random_name(), job_port=job_port_2
    )
    step.add_input_port("in", port_a)
    step.add_output_port("out", port_b)
    step_2.add_input_port("in2", port_b)
    step_2.add_output_port("out2", port_c)
    await workflow.save(context)

    # For each port: the steps expected to produce it (input steps) and the
    # steps expected to consume it (output steps)
    expectations = (
        (port_a, [], [step]),
        (port_b, [step], [step_2]),
        (port_c, [step_2], []),
    )
    for port, producers, consumers in expectations:
        input_rows = await context.database.get_input_steps(port.persistent_id)
        assert [row["step"] for row in input_rows] == [
            s.persistent_id for s in producers
        ]
        output_rows = await context.database.get_output_steps(port.persistent_id)
        assert [row["step"] for row in output_rows] == [
            s.persistent_id for s in consumers
        ]
47 changes: 8 additions & 39 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from __future__ import annotations

import asyncio
from typing import Any, MutableMapping, MutableSequence, cast
from typing import Any, MutableMapping, MutableSequence

import pytest

from streamflow.core import utils
from streamflow.core.context import StreamFlowContext
from streamflow.core.persistence import DatabaseLoadingContext
from streamflow.core.workflow import Port, Status, Step, Token, Workflow
from streamflow.cwl.command import CWLCommand, CWLCommandToken
from streamflow.cwl.translator import _create_command_output_processor_base
from streamflow.persistence.loading_context import DefaultDatabaseLoadingContext
from streamflow.persistence.utils import load_depender_tokens, load_dependee_tokens
from streamflow.workflow.combinator import (
CartesianProductCombinator,
DotProductCombinator,
Expand Down Expand Up @@ -68,40 +67,6 @@ async def _general_test(
return step


async def _load_dependees(
    token_id: int, loading_context: DatabaseLoadingContext, context: StreamFlowContext
) -> MutableSequence[Token]:
    """Load the ``Token`` objects that the given token depends on."""
    dependee_rows = await context.database.get_dependees(token_id)
    load_tasks = [
        asyncio.create_task(loading_context.load_token(context, row["dependee"]))
        for row in dependee_rows
    ]
    return cast(MutableSequence[Token], await asyncio.gather(*load_tasks))


async def _load_dependers(
    token_id: int, loading_context: DatabaseLoadingContext, context: StreamFlowContext
) -> MutableSequence[Token]:
    """Load the ``Token`` objects that depend on the given token."""
    depender_rows = await context.database.get_dependers(token_id)
    load_tasks = [
        asyncio.create_task(loading_context.load_token(context, row["depender"]))
        for row in depender_rows
    ]
    return cast(MutableSequence[Token], await asyncio.gather(*load_tasks))


async def _put_tokens(
token_list: MutableSequence[Token],
in_port: Port,
Expand Down Expand Up @@ -130,7 +95,9 @@ async def _verify_dependency_tokens(
token_reloaded = await context.database.get_token(token_id=token.persistent_id)
assert token_reloaded["port"] == port.persistent_id

depender_list = await _load_dependers(token.persistent_id, loading_context, context)
depender_list = await load_depender_tokens(
token.persistent_id, context, loading_context
)
print(
"depender:",
{token.persistent_id: [t.persistent_id for t in depender_list]},
Expand All @@ -143,7 +110,9 @@ async def _verify_dependency_tokens(
for t1 in depender_list:
assert _contains_id(t1.persistent_id, expected_depender)

dependee_list = await _load_dependees(token.persistent_id, loading_context, context)
dependee_list = await load_dependee_tokens(
token.persistent_id, context, loading_context
)
print(
"dependee:",
{token.persistent_id: [t.persistent_id for t in dependee_list]},
Expand Down

0 comments on commit ac8f41a

Please sign in to comment.