diff --git a/contributing-docs/testing/unit_tests.rst b/contributing-docs/testing/unit_tests.rst index 9388eaef576fa..0b6cbb31a5eaf 100644 --- a/contributing-docs/testing/unit_tests.rst +++ b/contributing-docs/testing/unit_tests.rst @@ -721,6 +721,66 @@ You can also use a fixture to create an object that needs the database. conn = request.getfixturevalue(conn) ... +Database fixtures and test isolation +------------------------------------ + +Database tests must not leak rows into the database that other tests can see. A test that +depends on rows left behind by an earlier test, or that breaks because of them, is +order-dependent and flaky. The shared fixtures below own their rows and remove them when the +test finishes, so individual tests should not need defensive pre-cleaning such as +``clear_db_dag_bundles()`` or ``clear_db_teams()`` at the top of a test. + +Why these fixtures exist +........................ + +``dag_maker`` provides one consistent way to create a Dag in the database and remove it again. A +Dag spans several foreign-key-linked rows (``DagModel``, ``SerializedDagModel``, ``DagVersion``, +``DagRun``, ``TaskInstance``, and the ``DagBundleModel`` it belongs to), so creating or deleting +them by hand is error-prone, and a missed row leaks into later tests. ``dag_maker`` keeps that +setup and teardown in one place. + +Bundles and teams were added later with their own ``clear_db_*`` helpers rather than through +``dag_maker``, which left cleanup to each caller and led tests to add defensive +``clear_db_dag_bundles()`` / ``clear_db_teams()`` calls against leaked rows. The fixtures now clean +up after themselves, so that pre-cleaning is no longer required. + +``dag_maker`` +............. + +``dag_maker`` is the primary fixture for tests that need a Dag in the database. It is a context +manager that builds a Dag, serializes it, and writes the ``DagModel``, ``DagRun``, +``SerializedDagModel``, and ``DagVersion`` rows for you: + +.. code-block:: python + + def test_something(dag_maker): + with dag_maker("my_dag") as dag: + EmptyOperator(task_id="task") + dr = dag_maker.create_dagrun() + ... + +On teardown ``dag_maker`` removes everything it created. Because the ``dag_maker`` bundle is shared +across tests, it drops that ``DagBundleModel`` row only once no Dag still references it +(``DagModel.bundle_name`` is a foreign key with no ``ON DELETE`` action, so deleting a referenced +bundle would fail). Prefer ``dag_maker`` over constructing ``DagBag``, ``DagBundleModel``, or +``DagModel`` rows by hand. + +``testing_dag_bundle`` and ``testing_team`` +........................................... + +For tests that need a bundle or a team but do not go through ``dag_maker``, use the +``testing_dag_bundle`` and ``testing_team`` fixtures. Each one lazily creates a shared +``"testing"`` row only if it does not already exist, and tears that row down on exit only when +this fixture is the one that created it, so overlapping usage does not delete a row another +fixture still needs. ``testing_dag_bundle`` drops the ``"testing"`` bundle only once nothing +references it, leaving the cleanup of the test's own dags to whichever fixture owns them. +``testing_team`` deletes its row directly, because every foreign key to ``team.name`` is +``ON DELETE CASCADE`` or ``ON DELETE SET NULL``. + +If you find yourself adding ``clear_db_*`` calls at the start of a test to work around rows left +by another test, that is a sign the other test's fixture is not cleaning up after itself. Fix the +fixture rather than spreading defensive cleanup across tests. + Running Unit tests ------------------ diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index eaac4e37fe01c..f10a6d10dbcff 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -878,6 +878,23 @@ def __call__( def serialized_dag(self) -> SerializedDAG: ... +def _delete_bundle_if_unreferenced(session, bundle_name): + """Delete a DagBundleModel row, but only once no DagModel still references it. + + ``DagModel.bundle_name`` is a foreign key with no ``ON DELETE`` action, and the bundle is + shared across tests, so it can only be dropped after the last referencing Dag is gone. + """ + from sqlalchemy import delete, func, select + + from airflow.models.dag import DagModel + from airflow.models.dagbundle import DagBundleModel + + if not session.scalar( + select(func.count()).select_from(DagModel).where(DagModel.bundle_name == bundle_name) + ): + session.execute(delete(DagBundleModel).where(DagBundleModel.name == bundle_name)) + + @pytest.fixture def dag_maker(request) -> Generator[DagMaker, None, None]: """ @@ -946,6 +963,7 @@ def __init__(self): self.dagbag = DagBag(os.devnull) else: self.dagbag = DagBag(os.devnull, include_examples=False) # type: ignore[call-arg] + self.created_bundle_names: set[str] = set() def __enter__(self): self.serialized_model = None @@ -1380,6 +1398,7 @@ def __call__( ): self.session.add(DagBundleModel(name=self.bundle_name)) self.session.commit() + self.created_bundle_names.add(self.bundle_name) return self @@ -1424,6 +1443,9 @@ def cleanup(self): self.session.execute(delete(DagModel).where(DagModel.dag_id.in_(dag_ids))) self.session.execute(delete(TaskMap).where(TaskMap.dag_id.in_(dag_ids))) self.session.execute(delete(AssetEvent).where(AssetEvent.source_dag_id.in_(dag_ids))) + if AIRFLOW_V_3_0_PLUS: + for bundle_name in self.created_bundle_names: + _delete_bundle_if_unreferenced(self.session, bundle_name) self.session.commit() if self._own_session: self.session.expunge_all() @@ -1743,6 +1765,8 @@ def session(): @pytest.fixture def get_test_dag(): + created = {"bundle": False, "import_error_files": set()} + def _get(dag_id: str): from airflow import settings from airflow.models.serialized_dag import SerializedDagModel @@ -1784,6 +1808,7 @@ def _get(dag_id: str): stacktrace=stacktrace, ) ) + created["import_error_files"].add(str(dag_file)) return @@ -1798,6 +1823,7 @@ def _get(dag_id: str): session = settings.Session() if not session.scalar(select(func.count()).where(DagBundleModel.name == "testing")): session.add(DagBundleModel(name="testing")) + created["bundle"] = True session.flush() SerializedDAG.bulk_write_to_db("testing", None, [dag], session=session) session.commit() @@ -1808,7 +1834,25 @@ def _get(dag_id: str): return dag - return _get + yield _get + + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + + if not AIRFLOW_V_3_0_PLUS: + return + + from sqlalchemy import delete + + from airflow.models.errors import ParseImportError + from airflow.utils.session import create_session + + with create_session() as session: + if created["import_error_files"]: + session.execute( + delete(ParseImportError).where(ParseImportError.filename.in_(created["import_error_files"])) + ) + if created["bundle"]: + _delete_bundle_if_unreferenced(session, "testing") @pytest.fixture @@ -2918,40 +2962,60 @@ def mock_xcom_backend(): def testing_dag_bundle(): from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - if AIRFLOW_V_3_0_PLUS: - from sqlalchemy import func, select + if not AIRFLOW_V_3_0_PLUS: + yield + return + + from sqlalchemy import func, select - from airflow.models.dagbundle import DagBundleModel - from airflow.utils.session import create_session + from airflow.models.dagbundle import DagBundleModel + from airflow.utils.session import create_session + created = False + with create_session() as session: + if ( + session.scalar( + select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing") + ) + == 0 + ): + session.add(DagBundleModel(name="testing")) + created = True + + yield + + if created: with create_session() as session: - if ( - session.scalar( - select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing") - ) - == 0 - ): - testing = DagBundleModel(name="testing") - session.add(testing) + _delete_bundle_if_unreferenced(session, "testing") @pytest.fixture def testing_team(): from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - if AIRFLOW_V_3_0_PLUS: - from sqlalchemy import select + if not AIRFLOW_V_3_0_PLUS: + yield None + return - from airflow.models.team import Team - from airflow.utils.session import create_session + from sqlalchemy import delete, select + from airflow.models.team import Team + from airflow.utils.session import create_session + + created = False + with create_session() as session: + team = session.scalar(select(Team).where(Team.name == "testing")) + if not team: + team = Team(name="testing") + session.add(team) + session.commit() + created = True + yield team + + if created: + # FKs to team.name are CASCADE / SET NULL, so deleting the row is safe. with create_session() as session: - team = session.scalar(select(Team).where(Team.name == "testing")) - if not team: - team = Team(name="testing") - session.add(team) - session.flush() - yield team + session.execute(delete(Team).where(Team.name == "testing")) @pytest.fixture