diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py index ec81e232fc8cf..617395a989d06 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py @@ -18,7 +18,7 @@ import os import traceback -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import yaml from openlineage.client import OpenLineageClient, set_producer @@ -50,12 +50,14 @@ get_dag_job_dependency_facet, get_processing_engine_facet, ) +from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from datetime import datetime from airflow.providers.openlineage.extractors import OperatorLineage + from airflow.providers.openlineage.plugins.facets import AirflowRunFacet from airflow.sdk.execution_time.secrets_masker import SecretsMasker, _secrets_masker from airflow.utils.state import DagRunState else: @@ -172,10 +174,24 @@ def emit(self, event: RunEvent): event_type = event.eventType.value.lower() if event.eventType else "" transport_type = f"{self._client.transport.kind}".lower() + team_name = None + + facets = event.run.facets or {} + airflow_facet = cast("AirflowRunFacet | None", facets.get("airflow")) + + if airflow_facet: + team_name = airflow_facet.dagRun.get("dag_team_name") + try: with Stats.timer( "ol.emit.attempts", - tags={"event_type": event_type, "transport_type": transport_type}, + tags=prune_dict( + { + "event_type": event_type, + "transport_type": transport_type, + "team_name": team_name, + } + ), ): self._client.emit(redacted_event) self.log.info( @@ -184,7 +200,11 @@ def emit(self, event: RunEvent): event.run.runId, ) except Exception as e: - Stats.incr("ol.emit.failed") + Stats.incr( + "ol.emit.failed", + tags=prune_dict({"team_name": team_name}), + ) + 
self.log.warning( "Failed to emit OpenLineage `%s` event of id `%s` with the following exception: `%s`", event_type.upper(), diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py index ec24c37128f89..4c059343e8090 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py @@ -41,6 +41,7 @@ from airflow.providers.openlineage.utils.utils import ( AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS, + DagRunInfo, get_airflow_dag_run_facet, get_airflow_debug_facet, get_airflow_job_facet, @@ -57,6 +58,7 @@ print_warning, ) from airflow.settings import configure_orm +from airflow.utils.helpers import prune_dict from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: @@ -241,9 +243,19 @@ def on_running(): if not doc: doc, doc_type = get_dag_documentation(dag) + team_name = None + team_name = DagRunInfo.team_name(dagrun) + if controls.extract_operator_metadata: with Stats.timer( - "ol.extract", tags={"event_type": event_type, "operator_name": operator_name} + "ol.extract", + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ): task_metadata = self.extractor_manager.extract_metadata( dagrun=dagrun, @@ -257,6 +269,7 @@ def on_running(): "Skipping OpenLineage operator metadata extraction for task `%s` due to emission_policy.", task_instance.task_id, ) + task_metadata = OperatorLineage() redacted_event = self.adapter.start_task( @@ -291,10 +304,23 @@ def on_running(): }, ) event_size = len(Serde.to_json(redacted_event).encode("utf-8")) + + airflow_facet = redacted_event.run.facets.get("airflow") + team_name = None + + if airflow_facet: + team_name = getattr(airflow_facet.dagRun, "dag_team_name", None) + Stats.gauge( "ol.event.size", event_size, - tags={"event_type": event_type, 
"operator_name": operator_name}, + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ) self._execute(on_running, "on_running", use_fork=True) @@ -386,9 +412,19 @@ def on_success(): if not doc: doc, doc_type = get_dag_documentation(dag) + team_name = None + team_name = DagRunInfo.team_name(dagrun) + if controls.extract_operator_metadata: with Stats.timer( - "ol.extract", tags={"event_type": event_type, "operator_name": operator_name} + "ol.extract", + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ): task_metadata = self.extractor_manager.extract_metadata( dagrun=dagrun, @@ -402,6 +438,7 @@ def on_success(): "Skipping OpenLineage operator metadata extraction for task `%s` due to emission_policy.", task_instance.task_id, ) + task_metadata = OperatorLineage() redacted_event = self.adapter.complete_task( @@ -435,10 +472,23 @@ def on_success(): }, ) event_size = len(Serde.to_json(redacted_event).encode("utf-8")) + + airflow_facet = redacted_event.run.facets.get("airflow") + team_name = None + + if airflow_facet: + team_name = getattr(airflow_facet.dagRun, "dag_team_name", None) + Stats.gauge( "ol.event.size", event_size, - tags={"event_type": event_type, "operator_name": operator_name}, + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ) self._execute(on_success, "on_success", use_fork=True) @@ -545,9 +595,19 @@ def on_failure(): if not doc: doc, doc_type = get_dag_documentation(dag) + team_name = None + team_name = DagRunInfo.team_name(dagrun) + if controls.extract_operator_metadata: with Stats.timer( - "ol.extract", tags={"event_type": event_type, "operator_name": operator_name} + "ol.extract", + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ): task_metadata = 
self.extractor_manager.extract_metadata( dagrun=dagrun, @@ -595,10 +655,23 @@ def on_failure(): }, ) event_size = len(Serde.to_json(redacted_event).encode("utf-8")) + + airflow_facet = redacted_event.run.facets.get("airflow") + team_name = None + + if airflow_facet: + team_name = getattr(airflow_facet.dagRun, "dag_team_name", None) + Stats.gauge( "ol.event.size", event_size, - tags={"event_type": event_type, "operator_name": operator_name}, + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ) self._execute(on_failure, "on_failure", use_fork=True) @@ -681,9 +754,19 @@ def on_skipped(): if not doc: doc, doc_type = get_dag_documentation(dag) + team_name = None + team_name = DagRunInfo.team_name(dagrun) + if controls.extract_operator_metadata: with Stats.timer( - "ol.extract", tags={"event_type": event_type, "operator_name": operator_name} + "ol.extract", + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ): task_metadata = self.extractor_manager.extract_metadata( dagrun=dagrun, @@ -730,10 +813,23 @@ def on_skipped(): }, ) event_size = len(Serde.to_json(redacted_event).encode("utf-8")) + + airflow_facet = redacted_event.run.facets.get("airflow") + team_name = None + + if airflow_facet: + team_name = getattr(airflow_facet.dagRun, "dag_team_name", None) + Stats.gauge( "ol.event.size", event_size, - tags={"event_type": event_type, "operator_name": operator_name}, + tags=prune_dict( + { + "event_type": event_type, + "operator_name": operator_name, + "team_name": team_name, + } + ), ) self._execute(on_skipped, "on_skipped", use_fork=True) diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 6bf7658754615..bbe08c38f2e9c 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ 
b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -24,7 +24,7 @@ from contextlib import suppress from functools import wraps from importlib import metadata -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar import attrs from openlineage.client.facet_v2 import ( @@ -50,6 +50,7 @@ BaseOperator, BaseSensorOperator, MappedOperator, + conf as airflow_conf, ) from airflow.providers.openlineage import ( __version__ as OPENLINEAGE_PROVIDER_VERSION, @@ -72,6 +73,7 @@ from airflow.providers.openlineage.version_compat import ( AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, get_base_airflow_version_tuple, ) from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG @@ -86,6 +88,9 @@ if not AIRFLOW_V_3_0_PLUS: from airflow.utils.session import NEW_SESSION, provide_session +if AIRFLOW_V_3_3_PLUS: + from airflow.models.dagbundle import DagBundleModel + if TYPE_CHECKING: from typing import TypeAlias @@ -980,9 +985,12 @@ class DagRunInfo(InfoJsonEncodable): "dag_bundle_version": lambda dagrun: DagRunInfo.dag_version_info(dagrun, "bundle_version"), "dag_version_id": lambda dagrun: DagRunInfo.dag_version_info(dagrun, "version_id"), "dag_version_number": lambda dagrun: DagRunInfo.dag_version_info(dagrun, "version_number"), + "dag_team_name": lambda dagrun: DagRunInfo.team_name(dagrun) if AIRFLOW_V_3_3_PLUS else None, "deadlines": lambda dagrun: DagRunInfo.deadlines(dagrun), } + _team_name_cache: ClassVar[dict[str, str | None]] = {} + @classmethod def duration(cls, dagrun: DagRun) -> float | None: if not getattr(dagrun, "end_date", None) or not isinstance(dagrun.end_date, datetime.datetime): @@ -1053,6 +1061,21 @@ def dag_version_info(cls, dagrun: DagRun, key: str) -> str | int | None: return current_version.version_number raise ValueError(f"Unsupported key: {key}`") + @classmethod + def team_name(cls, dagrun: DagRun) -> str | None: + """Extract the team name for the DagRun.""" 
+ if not AIRFLOW_V_3_3_PLUS or not airflow_conf.getboolean("core", "multi_team", fallback=False): + return None + + bundle_name = cls.dag_version_info(dagrun, "bundle_name") + if not isinstance(bundle_name, str): + return None + + if bundle_name not in cls._team_name_cache: + cls._team_name_cache[bundle_name] = DagBundleModel.get_team_name(bundle_name) + + return cls._team_name_cache[bundle_name] + class TaskInstanceInfo(InfoJsonEncodable): """Defines encoding TaskInstance object to JSON.""" diff --git a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py index 114631640bc3c..663f36ccc16e0 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/version_compat.py +++ b/providers/openlineage/src/airflow/providers/openlineage/version_compat.py @@ -34,6 +34,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) +AIRFLOW_V_3_3_PLUS = get_base_airflow_version_tuple() >= (3, 3, 0) -__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_2_PLUS"] +__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_2_PLUS", "AIRFLOW_V_3_3_PLUS"] diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py b/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py index a3a252b8cbb18..5c8d32f0bc761 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_adapter.py @@ -21,6 +21,7 @@ import os import pathlib import uuid +from types import SimpleNamespace from unittest import mock from unittest.mock import ANY, MagicMock, call, patch @@ -49,6 +50,7 @@ from airflow.providers.openlineage.plugins.facets import ( AirflowDagRunFacet, AirflowDebugRunFacet, + AirflowRunFacet, AirflowStateRunFacet, ) from 
airflow.providers.openlineage.token_provider import ( @@ -65,7 +67,7 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker from tests_common.test_utils.taskinstance import create_task_instance -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS stats_reference = f"{Stats.__module__}.Stats" @@ -304,14 +306,59 @@ def test_create_client_overrides_env_vars(): assert client.transport.kind == "console" +@pytest.mark.parametrize( + ("team_name", "expected_tags"), + [ + pytest.param( + None, + { + "event_type": "start", + "transport_type": ANY, + }, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": "start", + "transport_type": ANY, + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], +) @mock.patch(f"{stats_reference}.timer") @mock.patch(f"{stats_reference}.incr") -def test_emit_start_event(mock_stats_incr, mock_stats_timer): +def test_emit_start_event( + mock_stats_incr, + mock_stats_timer, + team_name, + expected_tags, +): + client = MagicMock() adapter = OpenLineageAdapter(client) run_id = str(uuid.uuid4()) event_time = datetime.datetime.now().isoformat() + + run_facets = None + if team_name is not None: + run_facets = { + "airflow": AirflowRunFacet( + dag={}, + dagRun={"dag_team_name": team_name}, + taskInstance={}, + task={}, + taskUuid="task_uuid", + ) + } + adapter.start_task( run_id=run_id, job_name="job", @@ -322,7 +369,7 @@ def test_emit_start_event(mock_stats_incr, mock_stats_timer): owners=[], tags=[], task=None, - run_facets=None, + run_facets=run_facets, ) assert ( @@ -340,6 +387,7 @@ def test_emit_start_event(mock_stats_incr, mock_stats_timer): "processing_engine": processing_engine_run.ProcessingEngineRunFacet( version=ANY, 
name="Airflow", openlineageAdapterVersion=ANY ), + **({"airflow": ANY} if team_name is not None else {}), }, ), job=Job( @@ -361,7 +409,10 @@ def test_emit_start_event(mock_stats_incr, mock_stats_timer): ) mock_stats_incr.assert_not_called() - mock_stats_timer.assert_called_with("ol.emit.attempts", tags={"event_type": ANY, "transport_type": ANY}) + mock_stats_timer.assert_called_with( + "ol.emit.attempts", + tags=expected_tags, + ) @mock.patch(f"{stats_reference}.timer") @@ -477,19 +528,64 @@ def test_emit_start_event_with_additional_information(mock_stats_incr, mock_stat mock_stats_timer.assert_called_with("ol.emit.attempts", tags={"event_type": ANY, "transport_type": ANY}) +@pytest.mark.parametrize( + ("team_name", "expected_tags"), + [ + pytest.param( + None, + { + "event_type": "complete", + "transport_type": ANY, + }, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": "complete", + "transport_type": ANY, + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], +) @mock.patch(f"{stats_reference}.timer") @mock.patch(f"{stats_reference}.incr") -def test_emit_complete_event(mock_stats_incr, mock_stats_timer): +def test_emit_complete_event( + mock_stats_incr, + mock_stats_timer, + team_name, + expected_tags, +): client = MagicMock() adapter = OpenLineageAdapter(client) run_id = str(uuid.uuid4()) event_time = datetime.datetime.now().isoformat() + + task = OperatorLineage() + + if team_name is not None: + task.run_facets = { + "airflow": AirflowRunFacet( + dag=None, + dagRun={"dag_team_name": team_name}, + taskInstance=None, + task=None, + taskUuid="task_uuid", + ) + } + adapter.complete_task( run_id=run_id, end_time=event_time, job_name="job", - task=OperatorLineage(), + task=task, owners=[], tags=[], job_description=None, @@ -505,8 +601,11 @@ def test_emit_complete_event(mock_stats_incr, mock_stats_timer): run=Run( runId=run_id, 
facets={ + **({"airflow": ANY} if team_name is not None else {}), "processing_engine": processing_engine_run.ProcessingEngineRunFacet( - version=ANY, name="Airflow", openlineageAdapterVersion=ANY + version=ANY, + name="Airflow", + openlineageAdapterVersion=ANY, ), "nominalTime": nominal_time_run.NominalTimeRunFacet( nominalStartTime="2022-01-01T00:00:00", @@ -519,7 +618,9 @@ def test_emit_complete_event(mock_stats_incr, mock_stats_timer): name="job", facets={ "jobType": job_type_job.JobTypeJobFacet( - processingType="BATCH", integration="AIRFLOW", jobType="TASK" + processingType="BATCH", + integration="AIRFLOW", + jobType="TASK", ) }, ), @@ -532,7 +633,10 @@ def test_emit_complete_event(mock_stats_incr, mock_stats_timer): ) mock_stats_incr.assert_not_called() - mock_stats_timer.assert_called_with("ol.emit.attempts", tags={"event_type": ANY, "transport_type": ANY}) + mock_stats_timer.assert_called_with( + "ol.emit.attempts", + tags=expected_tags, + ) @mock.patch(f"{stats_reference}.timer") @@ -650,19 +754,63 @@ def test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s mock_stats_timer.assert_called_with("ol.emit.attempts", tags={"event_type": ANY, "transport_type": ANY}) +@pytest.mark.parametrize( + ("team_name", "expected_tags"), + [ + pytest.param( + None, + { + "event_type": "fail", + "transport_type": ANY, + }, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": "fail", + "transport_type": ANY, + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], +) @mock.patch(f"{stats_reference}.timer") @mock.patch(f"{stats_reference}.incr") -def test_emit_failed_event(mock_stats_incr, mock_stats_timer): +def test_emit_failed_event( + mock_stats_incr, + mock_stats_timer, + team_name, + expected_tags, +): client = MagicMock() adapter = OpenLineageAdapter(client) run_id = str(uuid.uuid4()) event_time = 
datetime.datetime.now().isoformat() + + task = OperatorLineage() + if team_name is not None: + task.run_facets = { + "airflow": AirflowRunFacet( + dag=None, + dagRun={"dag_team_name": team_name}, + taskInstance=None, + task=None, + taskUuid="task_uuid", + ) + } + adapter.fail_task( run_id=run_id, end_time=event_time, job_name="job", - task=OperatorLineage(), + task=task, owners=[], tags=[], job_description=None, @@ -678,6 +826,7 @@ def test_emit_failed_event(mock_stats_incr, mock_stats_timer): run=Run( runId=run_id, facets={ + **({"airflow": ANY} if team_name is not None else {}), "processing_engine": processing_engine_run.ProcessingEngineRunFacet( version=ANY, name="Airflow", openlineageAdapterVersion=ANY ), @@ -705,7 +854,10 @@ def test_emit_failed_event(mock_stats_incr, mock_stats_timer): ) mock_stats_incr.assert_not_called() - mock_stats_timer.assert_called_with("ol.emit.attempts", tags={"event_type": ANY, "transport_type": ANY}) + mock_stats_timer.assert_called_with( + "ol.emit.attempts", + tags=expected_tags, + ) @mock.patch(f"{stats_reference}.timer") @@ -1335,20 +1487,76 @@ def test_emit_dag_failed_event( mock_stats_timer.assert_called_with("ol.emit.attempts", tags={"event_type": ANY, "transport_type": ANY}) +@pytest.mark.parametrize( + ("team_name", "expected_timer_tags", "expected_failed_tags"), + [ + pytest.param( + None, + { + "event_type": ANY, + "transport_type": ANY, + }, + {}, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": ANY, + "transport_type": ANY, + "team_name": "team_a", + }, + { + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], +) @patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.get_or_create_openlineage_client") @patch("airflow.providers.openlineage.plugins.adapter.OpenLineageRedactor") @patch(f"{stats_reference}.timer") @patch(f"{stats_reference}.incr") def 
test_openlineage_adapter_stats_emit_failed( - mock_stats_incr, mock_stats_timer, mock_redact, mock_get_client + mock_stats_incr, + mock_stats_timer, + mock_redact, + mock_get_client, + team_name, + expected_timer_tags, + expected_failed_tags, ): adapter = OpenLineageAdapter() mock_get_client.return_value.emit.side_effect = Exception() - adapter.emit(MagicMock()) + event = SimpleNamespace( + eventType=SimpleNamespace(value="COMPLETE"), + run=SimpleNamespace( + runId="run-id", + facets={}, + ), + ) - mock_stats_timer.assert_called_with("ol.emit.attempts", tags={"event_type": ANY, "transport_type": ANY}) - mock_stats_incr.assert_has_calls([mock.call("ol.emit.failed")]) + if team_name is not None: + event.run.facets["airflow"] = SimpleNamespace( + dagRun={"dag_team_name": team_name}, + ) + + adapter.emit(event) + + mock_stats_timer.assert_called_once_with( + "ol.emit.attempts", + tags=expected_timer_tags, + ) + + mock_stats_incr.assert_called_once_with( + "ol.emit.failed", + tags=expected_failed_tags, + ) def test_build_dag_run_id_is_valid_uuid(): diff --git a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py index f3132746cd660..ea82d6c7e79ee 100644 --- a/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py +++ b/providers/openlineage/tests/unit/openlineage/plugins/test_listener.py @@ -23,6 +23,7 @@ from concurrent.futures import Future from contextlib import suppress from datetime import datetime +from types import SimpleNamespace from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock, patch @@ -49,7 +50,12 @@ from tests_common.test_utils.dag import create_scheduler_dag from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.taskinstance import create_task_instance -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS +from 
tests_common.test_utils.version_compat import ( + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_PLUS, + AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, +) if AIRFLOW_V_3_1_PLUS: from airflow._shared.timezones import timezone @@ -1344,6 +1350,35 @@ def mock_task_id(dag_id, task_id, try_number, logical_date, map_index): return listener, task_instance + @pytest.mark.parametrize( + ("team_name", "expected_tags"), + [ + pytest.param( + None, + { + "event_type": "running", + "operator_name": "emptyoperator", + }, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": "running", + "operator_name": "emptyoperator", + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], + ) + @mock.patch("airflow.providers.openlineage.plugins.listener.Serde.to_json") + @mock.patch("airflow.providers.openlineage.plugins.listener.DagRunInfo.team_name") + @mock.patch("airflow.providers.openlineage.plugins.listener.Stats") @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.resolve_task_emission_policy") @@ -1363,6 +1398,11 @@ def test_adapter_start_task_is_called_with_proper_arguments( mock_disabled, mock_debug_facet, mock_debug_mode, + mock_stats, + mock_team_name, + mock_to_json, + team_name, + expected_tags, ): """Tests that the 'start_task' method of the OpenLineageAdapter is invoked with the correct arguments. 
@@ -1380,6 +1420,26 @@ def test_adapter_start_task_is_called_with_proper_arguments( mock_get_task_parent_run_facet.return_value = {"parent": 4} mock_debug_facet.return_value = {"debug": "packages"} mock_disabled.return_value = EmissionPolicy.defaults() + mock_team_name.return_value = team_name + mock_to_json.return_value = "{}" + + fake_event = SimpleNamespace( + run=SimpleNamespace( + facets=( + {} + if team_name is None + else { + "airflow": SimpleNamespace( + dagRun=SimpleNamespace( + dag_team_name=team_name, + ) + ) + } + ) + ) + ) + + listener.adapter.start_task.return_value = fake_event listener.on_task_instance_running(None, task_instance) listener.adapter.start_task.assert_called_once_with( @@ -1402,6 +1462,17 @@ def test_adapter_start_task_is_called_with_proper_arguments( }, ) + mock_stats.timer.assert_any_call( + "ol.extract", + tags=expected_tags, + ) + + mock_stats.gauge.assert_called_once_with( + "ol.event.size", + mock.ANY, + tags=expected_tags, + ) + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet") @@ -1644,6 +1715,35 @@ def test_adapter_start_task_is_called_with_dag_description_when_task_doc_is_empt assert listener.adapter.start_task.call_args.kwargs["job_description"] == "Test DAG Description" assert listener.adapter.start_task.call_args.kwargs["job_description_type"] == "text/plain" + @pytest.mark.parametrize( + ("team_name", "expected_tags"), + [ + pytest.param( + None, + { + "event_type": "fail", + "operator_name": "emptyoperator", + }, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": "fail", + "operator_name": "emptyoperator", + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], + ) + 
@mock.patch("airflow.providers.openlineage.plugins.listener.Serde.to_json") + @mock.patch("airflow.providers.openlineage.plugins.listener.DagRunInfo.team_name") + @mock.patch("airflow.providers.openlineage.plugins.listener.Stats") @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.resolve_task_emission_policy") @@ -1661,7 +1761,12 @@ def test_adapter_fail_task_is_called_with_proper_arguments( mock_disabled, mock_debug_facet, mock_debug_mode, + mock_stats, + mock_team_name, + mock_to_json, time_machine, + team_name, + expected_tags, ): """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments. @@ -1679,8 +1784,29 @@ def test_adapter_fail_task_is_called_with_proper_arguments( mock_get_task_parent_run_facet.return_value = {"parent": 4} mock_debug_facet.return_value = {"debug": "packages"} mock_disabled.return_value = EmissionPolicy.defaults() + mock_team_name.return_value = team_name + mock_to_json.return_value = "{}" err = ValueError("test") + + fake_event = SimpleNamespace( + run=SimpleNamespace( + facets=( + {} + if team_name is None + else { + "airflow": SimpleNamespace( + dagRun=SimpleNamespace( + dag_team_name=team_name, + ) + ) + } + ) + ) + ) + + listener.adapter.fail_task.return_value = fake_event + listener.on_task_instance_failed(previous_state=None, task_instance=task_instance, error=err) listener.adapter.fail_task.assert_called_once_with( end_time="2023-01-03T13:01:01+00:00", @@ -1702,6 +1828,17 @@ def test_adapter_fail_task_is_called_with_proper_arguments( error=err, ) + mock_stats.timer.assert_any_call( + "ol.extract", + tags=expected_tags, + ) + + mock_stats.gauge.assert_called_once_with( + "ol.event.size", + mock.ANY, + tags=expected_tags, + ) + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) 
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.resolve_task_emission_policy") @@ -1839,6 +1976,35 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_ adapter.fail_task(**expected_args) assert mock_emit.assert_called_once + @pytest.mark.parametrize( + ("team_name", "expected_tags"), + [ + pytest.param( + None, + { + "event_type": "complete", + "operator_name": "emptyoperator", + }, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": "complete", + "operator_name": "emptyoperator", + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], + ) + @mock.patch("airflow.providers.openlineage.plugins.listener.Serde.to_json") + @mock.patch("airflow.providers.openlineage.plugins.listener.DagRunInfo.team_name") + @mock.patch("airflow.providers.openlineage.plugins.listener.Stats") @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.resolve_task_emission_policy") @@ -1856,7 +2022,12 @@ def test_adapter_complete_task_is_called_with_proper_arguments( mock_disabled, mock_debug_facet, mock_debug_mode, + mock_stats, + mock_team_name, + mock_to_json, time_machine, + team_name, + expected_tags, ): """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments. 
@@ -1874,6 +2045,26 @@ def test_adapter_complete_task_is_called_with_proper_arguments( mock_get_task_parent_run_facet.return_value = {"parent": 4} mock_debug_facet.return_value = {"debug": "packages"} mock_disabled.return_value = EmissionPolicy.defaults() + mock_team_name.return_value = team_name + mock_to_json.return_value = "{}" + + fake_event = SimpleNamespace( + run=SimpleNamespace( + facets=( + {} + if team_name is None + else { + "airflow": SimpleNamespace( + dagRun=SimpleNamespace( + dag_team_name=team_name, + ) + ) + } + ) + ) + ) + + listener.adapter.complete_task.return_value = fake_event listener.on_task_instance_success(None, task_instance) calls = listener.adapter.complete_task.call_args_list @@ -1897,6 +2088,17 @@ def test_adapter_complete_task_is_called_with_proper_arguments( }, ) + mock_stats.timer.assert_any_call( + "ol.extract", + tags=expected_tags, + ) + + mock_stats.gauge.assert_called_once_with( + "ol.event.size", + mock.ANY, + tags=expected_tags, + ) + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.resolve_task_emission_policy") @@ -2206,6 +2408,121 @@ def test_listener_on_task_instance_skipped_do_not_call_adapter_when_disabled_ope listener.extractor_manager.extract_metadata.assert_not_called() listener.adapter.complete_task.assert_not_called() + @pytest.mark.parametrize( + ("team_name", "expected_tags"), + [ + pytest.param( + None, + { + "event_type": "complete", + "operator_name": "emptyoperator", + }, + id="without_team", + ), + pytest.param( + "team_a", + { + "event_type": "complete", + "operator_name": "emptyoperator", + "team_name": "team_a", + }, + id="with_team", + marks=pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, + reason="team_name metrics require Airflow 3.3+", + ), + ), + ], + ) + 
@mock.patch("airflow.providers.openlineage.plugins.listener.Serde.to_json") + @mock.patch("airflow.providers.openlineage.plugins.listener.DagRunInfo.team_name") + @mock.patch("airflow.providers.openlineage.plugins.listener.Stats") + @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.resolve_task_emission_policy") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") + @mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets") + @mock.patch( + "airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", + new=regular_call, + ) + def test_adapter_complete_task_is_called_with_proper_arguments_on_skip( + self, + mock_get_user_provided_run_facets, + mock_get_airflow_run_facet, + mock_get_task_parent_run_facet, + mock_disabled, + mock_debug_facet, + mock_debug_mode, + mock_stats, + mock_team_name, + mock_to_json, + time_machine, + team_name, + expected_tags, + ): + time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False) + + listener, task_instance = self._create_listener_and_task_instance() + + mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2, "parent": 99} + mock_get_airflow_run_facet.return_value = {"airflow": {"task": "..."}} + mock_get_task_parent_run_facet.return_value = {"parent": 4} + mock_debug_facet.return_value = {"debug": "packages"} + mock_disabled.return_value = EmissionPolicy.defaults() + mock_team_name.return_value = team_name + mock_to_json.return_value = "{}" + + fake_event = SimpleNamespace( + run=SimpleNamespace( + facets=( + {} + if team_name is None + else { + "airflow": SimpleNamespace( + dagRun=SimpleNamespace( + dag_team_name=team_name, + ) + ) + } + ) + ) + ) + 
listener.adapter.complete_task.return_value = fake_event + + listener.on_task_instance_skipped(previous_state=None, task_instance=task_instance) + + listener.adapter.complete_task.assert_called_once_with( + end_time="2023-01-03T13:01:01+00:00", + job_name="dag_id.task_id", + run_id="2020-01-01T01:01:01+00:00.dag_id.task_id.1.-1", + task=listener.extractor_manager.extract_metadata(), + owners=["task_owner"], + tags={"tag1", "tag2"}, + job_description="TASK Description", + job_description_type="text/markdown", + nominal_start_time=None, + nominal_end_time=None, + run_facets={ + "parent": 4, + "custom_user_facet": 2, + "airflow": {"task": "..."}, + "debug": "packages", + }, + ) + + mock_stats.timer.assert_any_call( + "ol.extract", + tags=expected_tags, + ) + + mock_stats.gauge.assert_called_once_with( + "ol.event.size", + mock.ANY, + tags=expected_tags, + ) + @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._fork_execute") @mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit") @mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True) diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index 36fe5125a7f0f..767690b1523aa 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -92,6 +92,7 @@ AIRFLOW_V_3_0_3_PLUS, AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, ) BASH_OPERATOR_PATH = "airflow.providers.standard.operators.bash" @@ -276,6 +277,7 @@ def test_get_airflow_dag_run_facet(): "dag_bundle_version": "bundle_version", "dag_version_id": "version_id", "dag_version_number": "version_number", + "dag_team_name": None, "triggering_user_name": "user1", "partition_key": "some_partition_key", "partition_date": "2024-06-01T02:03:34+00:00", @@ -331,6 +333,57 @@ def test_dag_run_version(key): 
assert result == key +@pytest.mark.db_test +@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="multi-team requires Airflow 3.3+") +@patch("airflow.providers.openlineage.utils.utils.DagBundleModel.get_team_name") +@patch("airflow.providers.openlineage.utils.utils.airflow_conf.getboolean", return_value=True) +def test_dag_run_team_name( + mock_getboolean, + mock_get_team_name, +): + DagRunInfo._team_name_cache.clear() + + dagrun_mock = MagicMock(DagRun) + dagrun_mock.dag_versions = [ + MagicMock( + bundle_name="bundle_name", + bundle_version="bundle_version", + id="version_id", + version_number="version_number", + ) + ] + + mock_get_team_name.return_value = "team_a" + + assert DagRunInfo.team_name(dagrun_mock) == "team_a" + assert DagRunInfo.team_name(dagrun_mock) == "team_a" + + mock_get_team_name.assert_called_once_with("bundle_name") + + +@pytest.mark.db_test +@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="multi-team requires Airflow 3.3+") +def test_dag_run_team_name_no_bundle(): + dagrun_mock = MagicMock(DagRun) + del dagrun_mock.dag_versions + + assert DagRunInfo.team_name(dagrun_mock) is None + + +@pytest.mark.db_test +@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="multi-team requires Airflow 3.3+") +@patch("airflow.providers.openlineage.utils.utils.DagBundleModel.get_team_name") +@patch("airflow.providers.openlineage.utils.utils.airflow_conf.getboolean", return_value=False) +def test_dag_run_team_name_multi_team_disabled(mock_getboolean, mock_get_team_name): + + DagRunInfo._team_name_cache.clear() + + dagrun_mock = MagicMock(DagRun) + + assert DagRunInfo.team_name(dagrun_mock) is None + mock_get_team_name.assert_not_called() + + def test_get_fully_qualified_class_name_serialized_operator(): op_module_path = BASH_OPERATOR_PATH op_name = "BashOperator" @@ -2965,6 +3018,7 @@ def test_dagrun_info_af3(mocked_dag_versions): "conf": {"a": 1}, "clear_number": 0, "dag_id": "dag_id", + "dag_team_name": None, "data_interval_end": "2024-06-01T00:00:00+00:00", 
"data_interval_start": "2024-06-01T00:00:00+00:00", "duration": 74.000546, @@ -3011,6 +3065,7 @@ def test_dagrun_info_af2(): "conf": {"a": 1}, "clear_number": 0, "dag_id": "dag_id", + "dag_team_name": None, "data_interval_end": "2024-06-01T00:00:00+00:00", "data_interval_start": "2024-06-01T00:00:00+00:00", "duration": 74.000546,