From 75dc3084192e01a52b98ca71b133372e5657e88a Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Mon, 29 Jun 2026 09:58:02 -0700 Subject: [PATCH 1/3] Add ti.trigger_dag_run() Task SDK accessor for triggering dag runs Airflow already exposes the dag-run poll half to task code as ti.get_dagrun_state(); the trigger half had no first-class accessor and was only reachable through the DagRunTriggerException side channel. Adding the symmetric ti.trigger_dag_run() routes a trigger through the same execution-API endpoint and scoped token the task runner already uses, so an operator can own its trigger-and-wait execution directly. --- .../airflow/sdk/execution_time/task_runner.py | 33 +++++++++++++++++++ task-sdk/src/airflow/sdk/types.py | 12 +++++++ .../execution_time/test_task_runner.py | 32 ++++++++++++++++++ 3 files changed, 77 insertions(+) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index ac8889daf0286..bc1f73c1863a0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -838,6 +838,39 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: return response.state + @staticmethod + def trigger_dag_run( + dag_id: str, + run_id: str, + *, + conf: dict[str, Any] | None = None, + logical_date: datetime | None = None, + run_after: datetime | None = None, + reset_dag_run: bool = False, + note: str | None = None, + ) -> bool: + """ + Trigger a new run of ``dag_id`` through the execution API. + + Counterpart to :meth:`get_dagrun_state`. Returns ``True`` when a run was created and + ``False`` when ``run_id`` already existed and ``reset_dag_run`` was not set (with + ``reset_dag_run=True`` the existing run is cleared and ``True`` is returned). 
+ """ + response = SUPERVISOR_COMMS.send( + msg=TriggerDagRun( + dag_id=dag_id, + run_id=run_id, + conf=conf, + logical_date=logical_date, + run_after=run_after, + reset_dag_run=reset_dag_run, + note=note, + ) + ) + if isinstance(response, ErrorResponse) and response.error == ErrorType.DAGRUN_ALREADY_EXISTS: + return False + return True + @staticmethod def get_dag(dag_id: str) -> DagResult: """Return the DAG with the given dag_id.""" diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index efa3c9f6d9426..67895bbf2f545 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -220,6 +220,18 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: ... + @staticmethod + def trigger_dag_run( + dag_id: str, + run_id: str, + *, + conf: dict[str, Any] | None = None, + logical_date: AwareDatetime | None = None, + run_after: AwareDatetime | None = None, + reset_dag_run: bool = False, + note: str | None = None, + ) -> bool: ... + @staticmethod def get_dag(dag_id: str) -> DagResult: ... 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 64493f6305f3c..e2172cfb55b50 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2831,6 +2831,38 @@ def test_get_dagrun_state(self, mock_supervisor_comms): ) assert state == "running" + def test_trigger_dag_run_returns_true_when_created(self, mock_supervisor_comms): + """trigger_dag_run sends a TriggerDagRun and returns True when a new run is created.""" + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + created = RuntimeTaskInstance.trigger_dag_run( + dag_id="test_dag", + run_id="run1", + conf={"k": "v"}, + reset_dag_run=True, + ) + + mock_supervisor_comms.send.assert_called_once_with( + msg=TriggerDagRun( + dag_id="test_dag", + run_id="run1", + conf={"k": "v"}, + logical_date=None, + run_after=None, + reset_dag_run=True, + note=None, + ), + ) + assert created is True + + def test_trigger_dag_run_returns_false_when_already_exists(self, mock_supervisor_comms): + """trigger_dag_run returns False when the run id already exists and reset is not requested.""" + mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS) + + created = RuntimeTaskInstance.trigger_dag_run(dag_id="test_dag", run_id="run1") + + assert created is False + def test_get_task_states(self, mock_supervisor_comms): """Test that get_task_states sends the correct request and returns the states.""" mock_supervisor_comms.send.return_value = TaskStatesResult(task_states={"run1": {"task1": "running"}}) From 32d5c546003aabe7d3e9e6d77c7bab2210241a2e Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Mon, 29 Jun 2026 10:11:19 -0700 Subject: [PATCH 2/3] Make TriggerDagRunOperator own its synchronous execution and add durable reconnect On Airflow 3, TriggerDagRunOperator.execute() raised DagRunTriggerException and the 
task runner did the trigger and the wait loop, so the synchronous wait-and-reconnect contract was duplicated between the operator and the runner and could drift. With the new ti.trigger_dag_run() accessor the operator does the submit and poll itself and reuses ResumableJobMixin directly, keeping that contract in one place. The opt-in durable flag persists the triggered run id before polling so a worker crash mid-wait reconnects to the in-flight run on retry instead of triggering a duplicate. Deferrable still needs the triggerer handoff, so it keeps the exception path; Airflow < 3.3 and Airflow 2 are unchanged. --- .../standard/operators/trigger_dagrun.py | 125 +++++- .../standard/operators/test_trigger_dagrun.py | 361 +++++++++++++----- 2 files changed, 379 insertions(+), 107 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 60d96cbe3f495..431bfd2e5f742 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -46,6 +46,7 @@ from airflow.providers.standard.version_compat import ( AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, BaseOperator, is_arg_set, ) @@ -57,6 +58,26 @@ except ImportError: from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef] +try: + from airflow.sdk import ResumableJobMixin +except ImportError: + # ResumableJobMixin was added in Airflow 3.3. On older Airflow the durable path is gated off + # (see execute()), so this stub only lets the operator subclass it without behavior change. 
+ class ResumableJobMixin: # type: ignore[no-redef] + """Stub used on Airflow < 3.3, where the durable path is disabled.""" + + external_id_key: str = "remote_job_id" + + def __init__(self, *, durable: bool = True, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.durable = durable + + def execute_resumable(self, context): + external_id = self.submit_job(context) + self.poll_until_complete(external_id, context) + return self.get_job_result(external_id, context) + + XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso" XCOM_RUN_ID = "trigger_run_id" @@ -78,6 +99,10 @@ def __str__(self) -> str: return f"Dag {self.dag_id} is paused" +class TriggeredDagRunFailed(AirflowException): + """Raise when a synchronously-waited triggered Dag run reaches a failed state.""" + + class TriggerDagRunLink(BaseOperatorLink): """ Operator link for TriggerDagRunOperator. @@ -121,7 +146,7 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: return build_airflow_url_with_query(query) -class TriggerDagRunOperator(BaseOperator): +class TriggerDagRunOperator(ResumableJobMixin, BaseOperator): """ Triggers a DAG run for a specified DAG ID. @@ -154,6 +179,10 @@ class TriggerDagRunOperator(BaseOperator): Airflow 3.x this requires Airflow 3.2.0+ (it relies on the task-SDK DAG state endpoint added then); on Airflow 3.0/3.1 setting this raises ``NotImplementedError``. :param deferrable: If waiting for completion, whether to defer the task until done, default is ``False``. + :param durable: On Airflow 3.3+ with a synchronous ``wait_for_completion`` (non-deferrable), persist the + triggered run id before polling so a worker crash mid-wait reconnects to the in-flight run on retry + instead of triggering a duplicate. No effect on Airflow < 3.3 or with ``deferrable=True``. + Default ``False``. :param openlineage_inject_parent_info: whether to include OpenLineage metadata about the parent task in the triggered DAG run's conf, enabling improved lineage tracking. 
The metadata is only injected if OpenLineage is enabled and running. This option does not modify any other part of the conf, @@ -181,6 +210,9 @@ class TriggerDagRunOperator(BaseOperator): ui_color = "#ffefeb" operator_extra_links = [TriggerDagRunLink()] + # Key under which the triggered run id is persisted to task_state_store for the durable path. + external_id_key = "triggered_dag_run_id" + def __init__( self, *, @@ -198,10 +230,11 @@ def __init__( fail_when_dag_is_paused: bool = False, note: str | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + durable: bool = False, openlineage_inject_parent_info: bool = True, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(durable=durable, **kwargs) self.trigger_dag_id = trigger_dag_id self.trigger_run_id = trigger_run_id self.conf = conf @@ -225,6 +258,9 @@ def __init__( run_after = _validate_datetime_param("run_after", run_after) self.logical_date = logical_date self.run_after = run_after + # Parsed datetimes, set in execute() before the durable (Airflow 3.3+) submit path reads them. + self._parsed_logical_date: datetime.datetime | None = None + self._parsed_run_after: datetime.datetime | None = None if fail_when_dag_is_paused and AIRFLOW_V_3_0_PLUS and not AIRFLOW_V_3_2_PLUS: raise NotImplementedError( "Setting `fail_when_dag_is_paused` requires Airflow 3.2.0+ on Airflow 3.x " @@ -289,6 +325,17 @@ def execute(self, context: Context): if dag_model.is_paused: raise AirflowException(f"Dag {self.trigger_dag_id} is paused") + if AIRFLOW_V_3_3_PLUS and not self.deferrable: + # On Airflow 3.3+ the operator owns its synchronous execution through the execution-API + # accessors (ti.trigger_dag_run / ti.get_dagrun_state) and ResumableJobMixin, rather than + # handing the trigger and wait off to the task runner via DagRunTriggerException. 
+ self._parsed_logical_date = parsed_logical_date + self._parsed_run_after = parsed_run_after if self.run_after is not NOTSET else None + if self.wait_for_completion: + return self.execute_resumable(context) + self.submit_job(context) + return + if AIRFLOW_V_3_0_PLUS: self._trigger_dag_af_3( context=context, @@ -301,6 +348,80 @@ def execute(self, context: Context): context=context, run_id=self.trigger_run_id, parsed_logical_date=parsed_logical_date ) + # ResumableJobMixin hooks — drive the synchronous Airflow 3.3+ path through the execution-API + # accessors so the operator owns its execution instead of the task runner. + + def submit_job(self, context: Context) -> Any: + ti = context["ti"] + # run_id is computed in execute() before this runs, so it is always set here. + run_id = cast("str", self.trigger_run_id) + created = ti.trigger_dag_run( + self.trigger_dag_id, + run_id, + conf=self.conf, + logical_date=self._parsed_logical_date, + run_after=self._parsed_run_after, + reset_dag_run=self.reset_dag_run, + note=self.note, + ) + if not created: + if self.skip_when_already_exists: + raise AirflowSkipException( + "Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists" + ) + raise DagRunAlreadyExists(DagRun(dag_id=self.trigger_dag_id, run_id=run_id)) + self._record_triggered_run(run_id, context) + return run_id + + def get_job_status(self, external_id: Any, context: Context) -> str: + return context["ti"].get_dagrun_state(self.trigger_dag_id, str(external_id)) + + def is_job_active(self, status: str) -> bool: + return bool(status) and status not in self.allowed_states and status not in self.failed_states + + def is_job_succeeded(self, status: str) -> bool: + return status in self.allowed_states + + def poll_until_complete(self, external_id: Any, context: Context) -> None: + run_id = str(external_id) + self._record_triggered_run(run_id, context) + ti = context["ti"] + while True: + self.log.info( + "Waiting for %s on %s to reach an allowed 
state %s ...", + self.trigger_dag_id, + run_id, + self.allowed_states, + ) + time.sleep(self.poke_interval) + state = ti.get_dagrun_state(self.trigger_dag_id, run_id) + if state in self.failed_states: + raise TriggeredDagRunFailed(f"{self.trigger_dag_id} failed with failed state {state}") + if state in self.allowed_states: + self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state) + return + + def get_job_result(self, external_id: Any, context: Context) -> Any: + run_id = str(external_id) + self._record_triggered_run(run_id, context) + return run_id + + def _record_triggered_run(self, run_id: str, context: Context) -> None: + """Record the triggered run id as a task attribute and on XCom (run id + extra-link URL).""" + self.trigger_run_id = run_id + ti = context.get("ti") or context.get("task_instance") + if not (ti and hasattr(ti, "xcom_push")): + return + ti.xcom_push(key=XCOM_RUN_ID, value=run_id) + from airflow.utils import helpers + + build_url_fn = getattr(helpers, "build_airflow_dagrun_url", None) + if build_url_fn: + ti.xcom_push( + key=TriggerDagRunLink().xcom_key, + value=build_url_fn(dag_id=self.trigger_dag_id, run_id=run_id), + ) + def _trigger_dag_af_3(self, context, run_id, parsed_logical_date, parsed_run_after=None): from airflow.providers.common.compat.sdk import DagRunTriggerException diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py index 2bce35574dc01..af5f5e852bc0c 100644 --- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py +++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py @@ -30,8 +30,12 @@ from airflow.models.dagrun import DagRun from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance -from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred, conf -from airflow.providers.standard.operators.trigger_dagrun 
import TriggerDagRunOperator +from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, TaskDeferred, conf +from airflow.providers.standard.operators.trigger_dagrun import ( + XCOM_RUN_ID, + TriggerDagRunOperator, + TriggeredDagRunFailed, +) from airflow.providers.standard.triggers.external_task import DagStateTrigger from airflow.utils.session import create_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -42,10 +46,9 @@ AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, ) -if AIRFLOW_V_3_0_PLUS: - from airflow.providers.common.compat.sdk import DagRunTriggerException if AIRFLOW_V_3_1_PLUS: from airflow.sdk import timezone else: @@ -108,77 +111,84 @@ def teardown_method(self): session.execute(delete(DagBundleModel).where(DagBundleModel.name == "test_bundle")) session.commit() - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @staticmethod + def _ti(*, created: bool = True, states=None): + """Build a mocked task instance exposing the execution-API trigger/poll accessors.""" + ti = mock.MagicMock() + ti.stats_tags = {} + ti.trigger_dag_run.return_value = created + if states is not None: + ti.get_dagrun_state.side_effect = states + return ti + + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") def test_trigger_dagrun_with_run_after(self): - """ - Test TriggerDagRunOperator. - - We only verify that the operator runs and raises correct exception. The actual execution logic - after the exception is in Task SDK code. 
- """ + """The operator triggers via the execution-API accessor, deriving the run id from run_after.""" with time_machine.travel("2025-02-18T08:04:46Z", tick=False): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "bar"}, run_after=timezone.datetime(2025, 2, 19, 12, 0, 0), + openlineage_inject_parent_info=False, ) + ti = self._ti() - # Ensure correct exception is raised - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={}) + task.execute(context={"ti": ti}) - assert exc_info.value.trigger_dag_id == TRIGGERED_DAG_ID - assert exc_info.value.conf == {"foo": "bar"} - assert exc_info.value.logical_date is None - assert exc_info.value.reset_dag_run is False - assert exc_info.value.skip_when_already_exists is False - assert exc_info.value.wait_for_completion is False - assert exc_info.value.allowed_states == [DagRunState.SUCCESS] - assert exc_info.value.failed_states == [DagRunState.FAILED] - if getattr(exc_info, "note", None) is not None: - assert exc_info.value.note == "Test note" + ti.trigger_dag_run.assert_called_once_with( + TRIGGERED_DAG_ID, + mock.ANY, + conf={"foo": "bar"}, + logical_date=None, + run_after=task.run_after, + reset_dag_run=False, + note=None, + ) + assert task.allowed_states == [DagRunState.SUCCESS] + assert task.failed_states == [DagRunState.FAILED] expected_run_id = DagRun.generate_run_id( run_type=DagRunType.MANUAL, run_after=task.run_after ).rsplit("_", 1)[0] # rsplit because last few characters are random. 
- assert exc_info.value.dag_run_id.rsplit("_", 1)[0] == expected_run_id + triggered_run_id = ti.trigger_dag_run.call_args.args[1] + assert triggered_run_id.rsplit("_", 1)[0] == expected_run_id assert task.trigger_run_id.rsplit("_", 1)[0] == expected_run_id # run_id is saved as attribute - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") def test_trigger_dagrun(self): - """ - Test TriggerDagRunOperator. - - We only verify that the operator runs and raises correct exception. The actual execution logic - after the exception is in Task SDK code. - """ + """The operator triggers via the execution-API accessor with a generated run id and note.""" with time_machine.travel("2025-02-18T08:04:46Z", tick=False): task = TriggerDagRunOperator( - task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, conf={"foo": "bar"}, note="Test note" + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf={"foo": "bar"}, + note="Test note", + openlineage_inject_parent_info=False, ) + ti = self._ti() - # Ensure correct exception is raised - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={}) - - assert exc_info.value.trigger_dag_id == TRIGGERED_DAG_ID - assert exc_info.value.conf == {"foo": "bar"} - assert exc_info.value.logical_date is not None - assert exc_info.value.reset_dag_run is False - assert exc_info.value.skip_when_already_exists is False - assert exc_info.value.wait_for_completion is False - assert exc_info.value.allowed_states == [DagRunState.SUCCESS] - assert exc_info.value.failed_states == [DagRunState.FAILED] - if getattr(exc_info, "note", None) is not None: - assert exc_info.value.note == "Test note" + task.execute(context={"ti": ti}) + # With no logical_date/run_after, the operator derives both from utcnow(); under frozen + # time generate_run_id is deterministic (no random suffix 
when logical_date is set). expected_run_id = DagRun.generate_run_id( - run_type=DagRunType.MANUAL, run_after=timezone.utcnow() - ).rsplit("_", 1)[0] - # rsplit because last few characters are random. - assert exc_info.value.dag_run_id == expected_run_id + run_type=DagRunType.MANUAL, + logical_date=timezone.utcnow(), + run_after=timezone.utcnow(), + ) + ti.trigger_dag_run.assert_called_once_with( + TRIGGERED_DAG_ID, + expected_run_id, + conf={"foo": "bar"}, + logical_date=timezone.utcnow(), + run_after=None, + reset_dag_run=False, + note="Test note", + ) + assert task.allowed_states == [DagRunState.SUCCESS] + assert task.failed_states == [DagRunState.FAILED] assert task.trigger_run_id == expected_run_id # run_id is saved as attribute @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") @@ -203,7 +213,7 @@ def test_extra_operator_link(self, mock_xcom_get_one, dag_maker): expected_url = f"{base_url}dags/{TRIGGERED_DAG_ID}/runs/test_run_id" assert link == expected_url, f"Expected {expected_url}, but got {link}" - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") def test_trigger_dagrun_pushes_extra_link_xcom_before_exception(self): """ Eagerly push the "Triggered DAG" extra-link URL so the UI button is available @@ -220,44 +230,47 @@ def test_trigger_dagrun_pushes_extra_link_xcom_before_exception(self): task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="custom_run_id", + openlineage_inject_parent_info=False, ) - ti_mock = mock.MagicMock() - with pytest.raises(DagRunTriggerException): - task.execute(context={"task_instance": ti_mock}) + ti = self._ti() + task.execute(context={"ti": ti}) expected_url = build_url_fn(dag_id=TRIGGERED_DAG_ID, run_id="custom_run_id") - ti_mock.xcom_push.assert_called_once_with( + ti.xcom_push.assert_any_call( 
key=TriggerDagRunLink().xcom_key, value=expected_url, ) - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") def test_trigger_dagrun_custom_run_id(self): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="custom_run_id", + openlineage_inject_parent_info=False, ) + ti = self._ti() - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={}) + task.execute(context={"ti": ti}) - assert exc_info.value.dag_run_id == "custom_run_id" + assert ti.trigger_dag_run.call_args.args[1] == "custom_run_id" + assert task.trigger_run_id == "custom_run_id" - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") def test_trigger_dagrun_with_logical_date(self): """Test TriggerDagRunOperator with custom logical_date.""" task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", trigger_dag_id=TRIGGERED_DAG_ID, logical_date=timezone.datetime(2021, 1, 2, 3, 4, 5), + openlineage_inject_parent_info=False, ) + ti = self._ti() - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={}) + task.execute(context={"ti": ti}) - assert exc_info.value.logical_date == timezone.datetime(2021, 1, 2, 3, 4, 5) + assert ti.trigger_dag_run.call_args.kwargs["logical_date"] == timezone.datetime(2021, 1, 2, 3, 4, 5) def test_trigger_dagrun_operator_templated_invalid_conf(self, dag_maker): """Test passing a conf that is not JSON Serializable raise error.""" @@ -400,9 +413,7 @@ def test_trigger_dagrun_fails_when_target_dag_is_paused(self): mock_ti.get_dag.assert_called_once_with(TRIGGERED_DAG_ID) - @pytest.mark.skipif( - not AIRFLOW_V_3_2_PLUS, reason="Needs the task-SDK GetDag endpoint 
added in Airflow 3.2.0" - ) + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") def test_trigger_dagrun_proceeds_when_target_dag_is_not_paused(self): task = TriggerDagRunOperator( task_id="test_task", @@ -410,16 +421,15 @@ def test_trigger_dagrun_proceeds_when_target_dag_is_not_paused(self): fail_when_dag_is_paused=True, openlineage_inject_parent_info=False, ) - mock_ti = mock.MagicMock() + mock_ti = self._ti() mock_ti.get_dag.return_value.is_paused = False - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={"ti": mock_ti}) + task.execute(context={"ti": mock_ti}) - assert exc_info.value.trigger_dag_id == TRIGGERED_DAG_ID + assert mock_ti.trigger_dag_run.call_args.args[0] == TRIGGERED_DAG_ID mock_ti.get_dag.assert_called_once_with(TRIGGERED_DAG_ID) - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") def test_trigger_dagrun_with_str_conf(self): """ Test TriggerDagRunOperator conf is proper json string formatted @@ -429,14 +439,14 @@ def test_trigger_dagrun_with_str_conf(self): task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, conf='{"foo": "bar"}', + openlineage_inject_parent_info=False, ) + ti = self._ti() - # Ensure correct exception is raised - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={}) + task.execute(context={"ti": ti}) - assert exc_info.value.trigger_dag_id == TRIGGERED_DAG_ID - assert exc_info.value.conf == {"foo": "bar"} + assert ti.trigger_dag_run.call_args.args[0] == TRIGGERED_DAG_ID + assert ti.trigger_dag_run.call_args.kwargs["conf"] == {"foo": "bar"} @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") def test_trigger_dagrun_with_str_conf_error(self): @@ -454,7 +464,7 @@ def test_trigger_dagrun_with_str_conf_error(self): 
task.execute(context={}) @pytest.mark.parametrize("original_conf", (None, {}, {"foo": "bar"})) - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") @mock.patch(f"{TRIGGER_OP_PATH}.safe_inject_openlineage_properties_into_dagrun_conf") def test_trigger_dagrun_conf_openlineage_injection_disabled_with_explicit_false_arg( self, mock_inject, original_conf @@ -467,16 +477,16 @@ def test_trigger_dagrun_conf_openlineage_injection_disabled_with_explicit_false_ conf=original_conf, openlineage_inject_parent_info=False, ) + ti = self._ti() - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={"ti": mock.MagicMock()}) + task.execute(context={"ti": ti}) # Injection function should not be called mock_inject.assert_not_called() # Conf should remain unchanged - assert exc_info.value.conf == original_conf + assert ti.trigger_dag_run.call_args.kwargs["conf"] == original_conf - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") def test_trigger_dagrun_conf_openlineage_injection_disabled_when_ol_not_accessible( self, mock_is_accessible @@ -493,15 +503,14 @@ def test_trigger_dagrun_conf_openlineage_injection_disabled_when_ol_not_accessib trigger_dag_id=TRIGGERED_DAG_ID, conf=original_conf, ) + ti = self._ti() - ti = mock.MagicMock() - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={"ti": ti}) + task.execute(context={"ti": ti}) # Conf should remain unchanged when OL is unavailable - assert exc_info.value.conf == original_conf + assert ti.trigger_dag_run.call_args.kwargs["conf"] == original_conf - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, 
reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") @pytest.mark.parametrize( ("provider_version", "should_modify"), [ @@ -546,20 +555,18 @@ def _mock_version(package): conf=original_conf, ) - mock_ti = mock.MagicMock() + mock_ti = self._ti() if should_modify: # When version is sufficient, mock _get_openlineage_parent_info to return data with mock.patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", return_value=ol_parent_info): - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={"ti": mock_ti}) - # Conf should be modified - assert exc_info.value.conf == injected_conf + task.execute(context={"ti": mock_ti}) + # Conf should be modified + assert mock_ti.trigger_dag_run.call_args.kwargs["conf"] == injected_conf else: # When version is insufficient, _get_openlineage_parent_info will raise - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={"ti": mock_ti}) + task.execute(context={"ti": mock_ti}) # Conf should remain unchanged - assert exc_info.value.conf == original_conf + assert mock_ti.trigger_dag_run.call_args.kwargs["conf"] == original_conf @pytest.mark.parametrize( "exception", @@ -569,7 +576,7 @@ def _mock_version(package): RuntimeError("Runtime issue"), ], ) - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") def test_trigger_dagrun_conf_openlineage_injection_preserves_conf_on_exception( self, mock_is_accessible, exception @@ -592,15 +599,14 @@ def test_trigger_dagrun_conf_openlineage_injection_preserves_conf_on_exception( conf=original_conf, ) - mock_ti = mock.MagicMock() - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={"ti": mock_ti}) + 
mock_ti = self._ti() + task.execute(context={"ti": mock_ti}) # Conf should remain unchanged when any exception occurs during injection - assert exc_info.value.conf == original_conf + assert mock_ti.trigger_dag_run.call_args.kwargs["conf"] == original_conf @pytest.mark.parametrize("original_conf", (None, {}, {"foo": "bar"})) - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 3.3+") @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") @mock.patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info") def test_trigger_dagrun_conf_openlineage_injection_valid_data( @@ -629,12 +635,11 @@ def test_trigger_dagrun_conf_openlineage_injection_valid_data( conf=original_conf, ) - mock_ti = mock.MagicMock() - with pytest.raises(DagRunTriggerException) as exc_info: - task.execute(context={"ti": mock_ti}) + mock_ti = self._ti() + task.execute(context={"ti": mock_ti}) # Conf should contain injected OpenLineage metadata - assert exc_info.value.conf == injected_conf + assert mock_ti.trigger_dag_run.call_args.kwargs["conf"] == injected_conf # Verify _get_openlineage_parent_info was called with ti mock_get_parent_info.assert_called_once_with(ti=mock_ti) @@ -1391,3 +1396,149 @@ def test_trigger_dagrun_conf_openlineage_injection_valid_data( assert dagrun.conf == injected_conf # Verify _get_openlineage_parent_info was called mock_get_parent_info.assert_called_once() + + +class _FakeTaskStateStore: + """Minimal in-memory task_state_store stub for durable TriggerDagRunOperator tests.""" + + def __init__(self, initial: dict | None = None): + self._store = dict(initial or {}) + + def get(self, key, default=None): + return self._store.get(key, default) + + def set(self, key, value, *, retention=None): + self._store[key] = value + + +@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Operator-owned execution path requires Airflow 
3.3+") +class TestTriggerDagRunOperatorOwnedExecution: + """Airflow 3.3+: the operator owns synchronous execution via the execution-API accessors.""" + + @staticmethod + def _ti(*, created: bool = True, states=None): + ti = mock.MagicMock() + ti.stats_tags = {} + ti.trigger_dag_run.return_value = created + if states is not None: + ti.get_dagrun_state.side_effect = states + return ti + + @staticmethod + def _task(**kwargs): + kwargs.setdefault("openlineage_inject_parent_info", False) + kwargs.setdefault("poke_interval", 0) + return TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, **kwargs) + + def test_fire_and_forget_triggers_via_accessor(self): + task = self._task(trigger_run_id="run1", wait_for_completion=False, conf={"foo": "bar"}) + ti = self._ti() + + task.execute(context={"ti": ti}) + + ti.trigger_dag_run.assert_called_once_with( + TRIGGERED_DAG_ID, + "run1", + conf={"foo": "bar"}, + logical_date=mock.ANY, + run_after=None, + reset_dag_run=False, + note=None, + ) + ti.xcom_push.assert_any_call(key=XCOM_RUN_ID, value="run1") + assert task.trigger_run_id == "run1" + + def test_sync_wait_polls_until_allowed_state(self): + task = self._task(trigger_run_id="run1", wait_for_completion=True) + ti = self._ti(states=[DagRunState.RUNNING, DagRunState.SUCCESS]) + + task.execute(context={"ti": ti}) + + ti.trigger_dag_run.assert_called_once() + assert ti.get_dagrun_state.call_count == 2 + assert all(c.args == (TRIGGERED_DAG_ID, "run1") for c in ti.get_dagrun_state.call_args_list) + + def test_sync_wait_raises_on_failed_state(self): + task = self._task(trigger_run_id="run1", wait_for_completion=True) + ti = self._ti(states=[DagRunState.FAILED]) + + with pytest.raises(TriggeredDagRunFailed): + task.execute(context={"ti": ti}) + + def test_skip_when_already_exists(self): + task = self._task(trigger_run_id="run1", wait_for_completion=False, skip_when_already_exists=True) + ti = self._ti(created=False) + + with pytest.raises(AirflowSkipException): + 
task.execute(context={"ti": ti}) + + def test_already_exists_without_skip_raises(self): + task = self._task(trigger_run_id="run1", wait_for_completion=False) + ti = self._ti(created=False) + + with pytest.raises(DagRunAlreadyExists): + task.execute(context={"ti": ti}) + + def test_reset_dag_run_is_passed_to_accessor(self): + task = self._task(trigger_run_id="run1", wait_for_completion=False, reset_dag_run=True) + ti = self._ti() + + task.execute(context={"ti": ti}) + + assert ti.trigger_dag_run.call_args.kwargs["reset_dag_run"] is True + + def test_durable_persists_run_id_before_polling(self): + task = self._task(trigger_run_id="run1", wait_for_completion=True, durable=True) + store = _FakeTaskStateStore() + seen = {} + + def _state(_dag_id, _run_id): + seen.setdefault("at_first_poll", store.get(task.external_id_key)) + return DagRunState.SUCCESS + + ti = self._ti() + ti.get_dagrun_state.side_effect = _state + + task.execute(context={"ti": ti, "task_state_store": store}) + + assert store.get(task.external_id_key) == "run1" + # The id must already be persisted by the first poll so a crash mid-wait reconnects on retry. 
+ assert seen["at_first_poll"] == "run1" + + def test_durable_reconnects_to_running_prior_run(self): + task = self._task(trigger_run_id="new_run", wait_for_completion=True, durable=True) + store = _FakeTaskStateStore({"triggered_dag_run_id": "prior_run"}) + ti = self._ti(states=[DagRunState.RUNNING, DagRunState.SUCCESS]) + + task.execute(context={"ti": ti, "task_state_store": store}) + + ti.trigger_dag_run.assert_not_called() # reconnected, never re-triggered + assert all(c.args[1] == "prior_run" for c in ti.get_dagrun_state.call_args_list) + + def test_durable_short_circuits_on_succeeded_prior_run(self): + task = self._task(trigger_run_id="new_run", wait_for_completion=True, durable=True) + store = _FakeTaskStateStore({"triggered_dag_run_id": "prior_run"}) + ti = self._ti(states=[DagRunState.SUCCESS]) + + task.execute(context={"ti": ti, "task_state_store": store}) + + ti.trigger_dag_run.assert_not_called() + assert ti.get_dagrun_state.call_count == 1 # only the resume check, no poll loop + + def test_durable_resubmits_after_failed_prior_run(self): + task = self._task(trigger_run_id="fresh_run", wait_for_completion=True, durable=True) + store = _FakeTaskStateStore({"triggered_dag_run_id": "dead_run"}) + ti = self._ti(states=[DagRunState.FAILED, DagRunState.SUCCESS]) + + task.execute(context={"ti": ti, "task_state_store": store}) + + ti.trigger_dag_run.assert_called_once() # resubmitted fresh + assert store.get(task.external_id_key) == "fresh_run" + + def test_durable_falls_back_to_fresh_trigger_without_task_state_store(self): + task = self._task(trigger_run_id="run1", wait_for_completion=True, durable=True) + ti = self._ti(states=[DagRunState.SUCCESS]) + + task.execute(context={"ti": ti}) # no task_state_store in context + + ti.trigger_dag_run.assert_called_once() From 0d21a67adf711a9c362f585cc5af2d5ef90f6e04 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Mon, 29 Jun 2026 22:05:29 -0700 Subject: [PATCH 3/3] Fix TriggerDagRunOperator task SDK 
tests for operator-owned execution On Airflow 3.3+ the operator owns its synchronous trigger and wait through the execution-API accessors, so the task instance no longer reaches the task runner's _handle_trigger_dag_run for that path. Two tests still asserted the old task-runner call sequence and failed. Fix them: assert the resolved task state in the wait test (keyed by message type, so it holds on both paths and is timezone-independent), and pin the original-error guard to the < 3.3 task-runner path it actually covers. The 3.3+ behaviour is verified in the provider test_trigger_dagrun.py. --- .../execution_time/test_task_runner.py | 81 +++++-------------- 1 file changed, 19 insertions(+), 62 deletions(-) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index e2172cfb55b50..ad6ed5f2956bd 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -5001,6 +5001,11 @@ def test_handle_trigger_dag_run_conflict( ] mock_supervisor_comms.assert_has_calls(expected_calls) + # Force the < 3.3 task-runner path: the original-error guard below is specific to + # ``_handle_trigger_dag_run`` raising inside ``run()``'s ``except`` block. On Airflow + # 3.3+ the operator owns execution and the send error is handled normally (covered by + # test_trigger_dagrun.py), so there is no unbound-``state`` hazard to guard there. 
+ @mock.patch("airflow.providers.standard.operators.trigger_dagrun.AIRFLOW_V_3_3_PLUS", False) @time_machine.travel("2025-01-01 00:00:00", tick=False) def test_handle_trigger_dag_run_reraises_original_error(self, create_runtime_ti, mock_supervisor_comms): """ @@ -5058,10 +5063,11 @@ def test_handle_trigger_dag_run_wait_for_completion( create_runtime_ti, mock_supervisor_comms, ): - """ - Test that TriggerDagRunOperator (with wait_for_completion) sends the correct message to the Supervisor + """A wait_for_completion trigger resolves the task to the state implied by the triggered run. - It also polls the Supervisor for the DagRun state until it completes execution. + Keyed by message type rather than call order so it holds for both the operator-owned + (Airflow 3.3+) and task-runner (< 3.3) paths; the exact call sequence is asserted in + test_trigger_dagrun.py. """ from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator @@ -5078,72 +5084,23 @@ def test_handle_trigger_dag_run_wait_for_completion( dag_id="test_handle_trigger_dag_run_wait_for_completion", run_id="test_run", task=task ) + def _send_side_effect(*args, **kwargs): + sent = kwargs.get("msg", args[0] if args else None) + if isinstance(sent, TriggerDagRun): + return OKResponse(ok=True) + if isinstance(sent, GetDagRunState): + return DagRunStateResult(state=target_dr_state) + return None + + mock_supervisor_comms.send.side_effect = _send_side_effect + log = mock.MagicMock() - mock_supervisor_comms.send.side_effect = [ - # Set RTIF - None, - # Account for the extra link XCom message sent by TriggerDagRunLink - None, - # Successful Dag Run trigger - OKResponse(ok=True), - # Set XCOM, - None, - # Dag Run is still running - DagRunStateResult(state=DagRunState.RUNNING), - # Dag Run completes execution on the next poll - DagRunStateResult(state=target_dr_state), - # Succeed/Fail task - None, - ] with mock.patch("time.sleep", return_value=None): state, msg, _ = run(ti, 
ti.get_template_context(), log) assert state == expected_task_state assert msg.state == expected_task_state - expected_calls = [ - mock.call.send( - msg=SetXCom( - key="_link_TriggerDagRunLink", - value="/dags/test_dag/runs/test_run_id", - dag_id="test_handle_trigger_dag_run_wait_for_completion", - task_id="test_task", - run_id="test_run", - map_index=-1, - ), - ), - mock.call.send( - msg=TriggerDagRun( - dag_id="test_dag", - run_id="test_run_id", - logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - ), - ), - mock.call.send( - msg=SetXCom( - key="trigger_run_id", - value="test_run_id", - dag_id="test_handle_trigger_dag_run_wait_for_completion", - task_id="test_task", - run_id="test_run", - map_index=-1, - ), - ), - mock.call.send( - msg=GetDagRunState( - dag_id="test_dag", - run_id="test_run_id", - ), - ), - mock.call.send( - msg=GetDagRunState( - dag_id="test_dag", - run_id="test_run_id", - ), - ), - ] - mock_supervisor_comms.assert_has_calls(expected_calls) - def test_handle_trigger_dag_run_wait_for_completion_failed_state_retries( self, create_runtime_ti, mock_supervisor_comms ):