Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions contributing-docs/testing/unit_tests.rst
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,66 @@ You can also use a fixture to create an object that needs the database.
conn = request.getfixturevalue(conn)
...

Database fixtures and test isolation
------------------------------------

Database tests must not leak rows into the database that other tests can see. A test that
depends on rows left behind by an earlier test, or that breaks because of them, is
order-dependent and flaky. The shared fixtures below own their rows and remove them when the
test finishes, so individual tests should not need defensive pre-cleaning such as
``clear_db_dag_bundles()`` or ``clear_db_teams()`` at the top of a test.

Why these fixtures exist
........................

Comment thread
anishgirianish marked this conversation as resolved.
``dag_maker`` provides one consistent way to create a Dag in the database and remove it again. A
Dag spans several foreign-key-linked rows (``DagModel``, ``SerializedDagModel``, ``DagVersion``,
``DagRun``, ``TaskInstance``, and the ``DagBundleModel`` it belongs to), so creating or deleting
them by hand is error-prone, and a missed row leaks into later tests. ``dag_maker`` keeps that
setup and teardown in one place.

Bundles and teams were added later with their own ``clear_db_*`` helpers rather than through
``dag_maker``, which left cleanup to each caller and led tests to add defensive
``clear_db_dag_bundles()`` / ``clear_db_teams()`` calls against leaked rows. The fixtures now clean
up after themselves, so that pre-cleaning is no longer required.

``dag_maker``
.............

``dag_maker`` is the primary fixture for tests that need a Dag in the database. It is a context
manager that builds a Dag, serializes it, and writes the ``DagModel``, ``DagRun``,
``SerializedDagModel``, and ``DagVersion`` rows for you:

.. code-block:: python

def test_something(dag_maker):
with dag_maker("my_dag") as dag:
EmptyOperator(task_id="task")
dr = dag_maker.create_dagrun()
...

On teardown ``dag_maker`` removes everything it created. Because the ``dag_maker`` bundle is shared
across tests, it drops that ``DagBundleModel`` row only once no Dag still references it
(``DagModel.bundle_name`` is a foreign key with no ``ON DELETE`` action, so deleting a referenced
bundle would fail). Prefer ``dag_maker`` over constructing ``DagBag``, ``DagBundleModel``, or
``DagModel`` rows by hand.

``testing_dag_bundle`` and ``testing_team``
...........................................

For tests that need a bundle or a team but do not go through ``dag_maker``, use the
``testing_dag_bundle`` and ``testing_team`` fixtures. Each one lazily creates a shared
``"testing"`` row only if it does not already exist, and tears that row down on exit only when
this fixture is the one that created it, so overlapping usage does not delete a row another
fixture still needs. ``testing_dag_bundle`` drops the ``"testing"`` bundle only once nothing
references it, leaving the cleanup of the test's own dags to whichever fixture owns them.
``testing_team`` deletes its row directly, because every foreign key to ``team.name`` is
``ON DELETE CASCADE`` or ``ON DELETE SET NULL``.

If you find yourself adding ``clear_db_*`` calls at the start of a test to work around rows left
by another test, that is a sign the other test's fixture is not cleaning up after itself. Fix the
fixture rather than spreading defensive cleanup across tests.
Comment thread
anishgirianish marked this conversation as resolved.

Running Unit tests
------------------

Expand Down
110 changes: 87 additions & 23 deletions devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,23 @@ def __call__(
def serialized_dag(self) -> SerializedDAG: ...


def _delete_bundle_if_unreferenced(session, bundle_name):
"""Delete a DagBundleModel row, but only once no DagModel still references it.

``DagModel.bundle_name`` is a foreign key with no ``ON DELETE`` action, and the bundle is
shared across tests, so it can only be dropped after the last referencing Dag is gone.
"""
from sqlalchemy import delete, func, select

from airflow.models.dag import DagModel
from airflow.models.dagbundle import DagBundleModel

if not session.scalar(
select(func.count()).select_from(DagModel).where(DagModel.bundle_name == bundle_name)
):
session.execute(delete(DagBundleModel).where(DagBundleModel.name == bundle_name))


@pytest.fixture
def dag_maker(request) -> Generator[DagMaker, None, None]:
"""
Expand Down Expand Up @@ -946,6 +963,7 @@ def __init__(self):
self.dagbag = DagBag(os.devnull)
else:
self.dagbag = DagBag(os.devnull, include_examples=False) # type: ignore[call-arg]
self.created_bundle_names: set[str] = set()

def __enter__(self):
self.serialized_model = None
Expand Down Expand Up @@ -1380,6 +1398,7 @@ def __call__(
):
self.session.add(DagBundleModel(name=self.bundle_name))
self.session.commit()
self.created_bundle_names.add(self.bundle_name)

return self

Expand Down Expand Up @@ -1424,6 +1443,9 @@ def cleanup(self):
self.session.execute(delete(DagModel).where(DagModel.dag_id.in_(dag_ids)))
self.session.execute(delete(TaskMap).where(TaskMap.dag_id.in_(dag_ids)))
self.session.execute(delete(AssetEvent).where(AssetEvent.source_dag_id.in_(dag_ids)))
if AIRFLOW_V_3_0_PLUS:
for bundle_name in self.created_bundle_names:
_delete_bundle_if_unreferenced(self.session, bundle_name)
self.session.commit()
if self._own_session:
self.session.expunge_all()
Expand Down Expand Up @@ -1743,6 +1765,8 @@ def session():

@pytest.fixture
def get_test_dag():
created = {"bundle": False, "import_error_files": set()}

def _get(dag_id: str):
from airflow import settings
from airflow.models.serialized_dag import SerializedDagModel
Expand Down Expand Up @@ -1784,6 +1808,7 @@ def _get(dag_id: str):
stacktrace=stacktrace,
)
)
created["import_error_files"].add(str(dag_file))

return

Expand All @@ -1798,6 +1823,7 @@ def _get(dag_id: str):
session = settings.Session()
if not session.scalar(select(func.count()).where(DagBundleModel.name == "testing")):
session.add(DagBundleModel(name="testing"))
created["bundle"] = True
session.flush()
SerializedDAG.bulk_write_to_db("testing", None, [dag], session=session)
session.commit()
Expand All @@ -1808,7 +1834,25 @@ def _get(dag_id: str):

return dag

return _get
yield _get

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if not AIRFLOW_V_3_0_PLUS:
return

from sqlalchemy import delete

from airflow.models.errors import ParseImportError
from airflow.utils.session import create_session

with create_session() as session:
if created["import_error_files"]:
session.execute(
delete(ParseImportError).where(ParseImportError.filename.in_(created["import_error_files"]))
)
if created["bundle"]:
_delete_bundle_if_unreferenced(session, "testing")


@pytest.fixture
Expand Down Expand Up @@ -2918,40 +2962,60 @@ def mock_xcom_backend():
def testing_dag_bundle():
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from sqlalchemy import func, select
if not AIRFLOW_V_3_0_PLUS:
yield
return

from sqlalchemy import func, select

from airflow.models.dagbundle import DagBundleModel
from airflow.utils.session import create_session
from airflow.models.dagbundle import DagBundleModel
from airflow.utils.session import create_session

created = False
with create_session() as session:
if (
session.scalar(
select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing")
)
== 0
):
session.add(DagBundleModel(name="testing"))
created = True

yield

if created:
with create_session() as session:
if (
session.scalar(
select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing")
)
== 0
):
testing = DagBundleModel(name="testing")
session.add(testing)
_delete_bundle_if_unreferenced(session, "testing")


@pytest.fixture
def testing_team():
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from sqlalchemy import select
if not AIRFLOW_V_3_0_PLUS:
yield None
return

from airflow.models.team import Team
from airflow.utils.session import create_session
from sqlalchemy import delete, select

from airflow.models.team import Team
from airflow.utils.session import create_session

created = False
with create_session() as session:
team = session.scalar(select(Team).where(Team.name == "testing"))
if not team:
team = Team(name="testing")
session.add(team)
session.commit()
created = True
yield team

if created:
# FKs to team.name are CASCADE / SET NULL, so deleting the row is safe.
with create_session() as session:
team = session.scalar(select(Team).where(Team.name == "testing"))
if not team:
team = Team(name="testing")
session.add(team)
session.flush()
yield team
session.execute(delete(Team).where(Team.name == "testing"))


@pytest.fixture
Expand Down
Loading