From 457e00ab8816cfbd2c6fc11d8b75be2bc9df8bf8 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Wed, 24 Jun 2026 02:02:26 -0700 Subject: [PATCH 1/4] Add durable option to TriggerDagRunOperator to reconnect on retry With wait_for_completion the trigger-and-wait runs in the task runner. A worker crash while polling makes the retry recompute a fresh run_id and trigger a duplicate child run (or fail with DagRunAlreadyExists), even though the run the first attempt started is healthy and still running. The opt-in durable flag persists the triggered run_id to task_state_store before polling, so the retry reconnects to the in-flight run instead of resubmitting. --- .../standard/operators/trigger_dagrun.py | 8 + task-sdk/src/airflow/sdk/exceptions.py | 2 + .../airflow/sdk/execution_time/task_runner.py | 125 ++++++++---- .../execution_time/test_task_runner.py | 179 ++++++++++++++++++ 4 files changed, 275 insertions(+), 39 deletions(-) diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 60d96cbe3f495..e2946acf59875 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -154,6 +154,9 @@ class TriggerDagRunOperator(BaseOperator): Airflow 3.x this requires Airflow 3.2.0+ (it relies on the task-SDK DAG state endpoint added then); on Airflow 3.0/3.1 setting this raises ``NotImplementedError``. :param deferrable: If waiting for completion, whether to defer the task until done, default is ``False``. + :param durable: If ``True`` and waiting for completion synchronously (non-deferrable), persist the + triggered run id before polling so that a worker crash mid-wait reconnects to the in-flight run on + retry instead of triggering a duplicate. Requires Airflow 3.3+ (task_state_store). Default ``False``. :param openlineage_inject_parent_info: whether to include OpenLineage metadata about the parent task in the triggered DAG run's conf, enabling improved lineage tracking. The metadata is only injected if OpenLineage is enabled and running. This option does not modify any other part of the conf, @@ -198,6 +201,7 @@ def __init__( fail_when_dag_is_paused: bool = False, note: str | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + durable: bool = False, openlineage_inject_parent_info: bool = True, **kwargs, ) -> None: @@ -221,6 +225,7 @@ def __init__( self.openlineage_inject_parent_info = openlineage_inject_parent_info self.note = note self.deferrable = deferrable + self.durable = durable logical_date = _validate_datetime_param("logical_date", logical_date) run_after = _validate_datetime_param("run_after", run_after) self.logical_date = logical_date @@ -325,6 +330,9 @@ def _trigger_dag_af_3(self, context, run_id, parsed_logical_date, parsed_run_aft if parsed_run_after and "run_after" in parameters: kwargs_accepted["run_after"] = parsed_run_after + if self.durable and "durable" in parameters: + kwargs_accepted["durable"] = self.durable + if isinstance(context, Mapping): from airflow.utils import helpers diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 6f43d5421ecf2..c2aac7d3a3c54 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -309,6 +309,7 @@ def __init__( failed_states: list[str], poke_interval: int, deferrable: bool, + durable: bool = False, note: str | None = None, ): super().__init__() @@ -324,6 +325,7 @@ def __init__( self.failed_states = failed_states self.poke_interval = poke_interval self.deferrable = deferrable + self.durable = durable self.note = note diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index f3fee689928a0..a2e8d735b44c2 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1854,51 +1854,100 @@ def _finalize_task_failure( ) +_TRIGGERED_RUN_ID_KEY = "triggered_dag_run_id" + + +def _evaluate_prior_triggered_run( + run_id: str, drte: DagRunTriggerException, log: Logger +) -> Literal["succeeded", "reconnect", "resubmit"]: + """ + Classify a run triggered on a prior attempt so the synchronous wait can resume safely. + + ``"succeeded"`` — already finished in an allowed state; skip resubmission. + ``"reconnect"`` — still running; resume the wait without resubmitting. + ``"resubmit"`` — failed, gone, or state unreadable; trigger a fresh run. + """ + comms_msg = SUPERVISOR_COMMS.send(GetDagRunState(dag_id=drte.trigger_dag_id, run_id=run_id)) + state = comms_msg.state if isinstance(comms_msg, DagRunStateResult) else None + if state in drte.allowed_states: + log.info("Run triggered on a prior attempt already succeeded; not resubmitting.", run_id=run_id) + return "succeeded" + if state is None or state in drte.failed_states: + log.warning( + "Run triggered on a prior attempt is not resumable; resubmitting.", run_id=run_id, state=state + ) + return "resubmit" + log.info("Reconnecting to run triggered on a prior attempt.", run_id=run_id, state=state) + return "reconnect" + + def _handle_trigger_dag_run( drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" - log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - comms_msg = SUPERVISOR_COMMS.send( - TriggerDagRun( - dag_id=drte.trigger_dag_id, - run_id=drte.dag_run_id, - logical_date=drte.logical_date, - run_after=drte.run_after, - conf=drte.conf, - reset_dag_run=drte.reset_dag_run, - note=drte.note, - ), + # Crash-safety for the synchronous wait: persist the triggered run id before polling so a worker + # crash mid-wait reconnects to the in-flight run on retry instead of triggering a duplicate. + task_state_store = context.get("task_state_store") + durable = ( + drte.durable and drte.wait_for_completion and not drte.deferrable and task_state_store is not None ) - if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: - if drte.skip_when_already_exists: - log.info( - "Dag Run already exists, skipping task as skip_when_already_exists is set to True.", + run_id = drte.dag_run_id + reconnecting = False + if durable and (stored_run_id := task_state_store.get(_TRIGGERED_RUN_ID_KEY)): + decision = _evaluate_prior_triggered_run(stored_run_id, drte, log) + if decision == "succeeded": + ti.xcom_push(key="trigger_run_id", value=stored_run_id) + return _handle_current_task_success(context, ti) + if decision == "reconnect": + run_id = stored_run_id + reconnecting = True + + if not reconnecting: + log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) + comms_msg = SUPERVISOR_COMMS.send( + TriggerDagRun( dag_id=drte.trigger_dag_id, - ) - msg = TaskState( - state=TaskInstanceState.SKIPPED, - end_date=datetime.now(tz=timezone.utc), - rendered_map_index=ti.rendered_map_index, - ) - state = TaskInstanceState.SKIPPED - else: - log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id) - msg = TaskState( - state=TaskInstanceState.FAILED, - end_date=datetime.now(tz=timezone.utc), - rendered_map_index=ti.rendered_map_index, - ) - state = TaskInstanceState.FAILED + run_id=run_id, + logical_date=drte.logical_date, + run_after=drte.run_after, + conf=drte.conf, + reset_dag_run=drte.reset_dag_run, + note=drte.note, + ), + ) - return msg, state + if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: + if drte.skip_when_already_exists: + log.info( + "Dag Run already exists, skipping task as skip_when_already_exists is set to True.", + dag_id=drte.trigger_dag_id, + ) + msg = TaskState( + state=TaskInstanceState.SKIPPED, + end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, + ) + state = TaskInstanceState.SKIPPED + else: + log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id) + msg = TaskState( + state=TaskInstanceState.FAILED, + end_date=datetime.now(tz=timezone.utc), + rendered_map_index=ti.rendered_map_index, + ) + state = TaskInstanceState.FAILED + + return msg, state - log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id) + log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id) - # Store the run id from the dag run (either created or found above) to - # be used when creating the extra link on the webserver. - ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id) + if durable: + # Persist before polling so a crash mid-wait reconnects on retry instead of resubmitting. + task_state_store.set(_TRIGGERED_RUN_ID_KEY, run_id) + + # Store the run id (created above or reconnected to) for the webserver extra link. + ti.xcom_push(key="trigger_run_id", value=run_id) if drte.wait_for_completion: if drte.deferrable: @@ -1923,14 +1972,12 @@ def _handle_trigger_dag_run( log.info( "Waiting for dag run to complete execution in allowed state.", dag_id=drte.trigger_dag_id, - run_id=drte.dag_run_id, + run_id=run_id, allowed_state=drte.allowed_states, ) time.sleep(drte.poke_interval) - comms_msg = SUPERVISOR_COMMS.send( - GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) - ) + comms_msg = SUPERVISOR_COMMS.send(GetDagRunState(dag_id=drte.trigger_dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 64493f6305f3c..95dbba4019ab5 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -4884,6 +4884,19 @@ class CustomOperator(BaseOperator): assert log.exception.mock_calls == expected_exception_logs +class _FakeTaskStateStore: + """Minimal in-memory task_state_store stub for durable TriggerDagRunOperator tests.""" + + def __init__(self, initial: dict | None = None): + self._store = dict(initial or {}) + + def get(self, key, default=None): + return self._store.get(key, default) + + def set(self, key, value, *, retention=None): + self._store[key] = value + + class TestTriggerDagRunOperator: """Tests to verify various aspects of TriggerDagRunOperator""" @@ -5153,6 +5166,172 @@ def _send_side_effect(*args, **kwargs): assert state == TaskInstanceState.UP_FOR_RETRY + def test_handle_trigger_dag_run_persists_run_id_before_polling( + self, create_runtime_ti, mock_supervisor_comms + ): + """A durable synchronous wait persists the triggered run id before it starts polling.""" + from airflow.sdk.execution_time.task_runner import _TRIGGERED_RUN_ID_KEY + + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id="test_dag", + trigger_run_id="fresh_run_id", + poke_interval=5, + wait_for_completion=True, + deferrable=False, + durable=True, + ) + ti = create_runtime_ti(dag_id="test_persist", run_id="test_run", task=task) + store = _FakeTaskStateStore() + context = ti.get_template_context() + context["task_state_store"] = store + + persisted_at_first_poll = {} + + def _send(*args, **kwargs): + msg = kwargs.get("msg") or (args[0] if args else None) + if isinstance(msg, TriggerDagRun): + return OKResponse(ok=True) + if isinstance(msg, GetDagRunState): + persisted_at_first_poll.setdefault("value", store.get(_TRIGGERED_RUN_ID_KEY)) + return DagRunStateResult(state=DagRunState.SUCCESS) + return None + + mock_supervisor_comms.send.side_effect = _send + log = mock.MagicMock() + with mock.patch("time.sleep", return_value=None): + state, _, _ = run(ti, context, log) + + assert state == TaskInstanceState.SUCCESS + assert store.get(_TRIGGERED_RUN_ID_KEY) == "fresh_run_id" + # the id must already be persisted by the time the first poll runs, so a crash mid-wait reconnects + assert persisted_at_first_poll["value"] == "fresh_run_id" + + def test_handle_trigger_dag_run_reconnects_to_running_prior_run( + self, create_runtime_ti, mock_supervisor_comms + ): + """On retry, a still-running run triggered on a prior attempt is resumed, not re-triggered.""" + from airflow.sdk.execution_time.task_runner import _TRIGGERED_RUN_ID_KEY + + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id="test_dag", + trigger_run_id="new_run_id", + poke_interval=5, + wait_for_completion=True, + deferrable=False, + durable=True, + ) + ti = create_runtime_ti(dag_id="test_reconnect", run_id="test_run", task=task) + store = _FakeTaskStateStore({_TRIGGERED_RUN_ID_KEY: "prior_run_id"}) + context = ti.get_template_context() + context["task_state_store"] = store + + poll_states = iter([DagRunState.RUNNING, DagRunState.SUCCESS]) + triggered, polls = [], [] + + def _send(*args, **kwargs): + msg = kwargs.get("msg") or (args[0] if args else None) + if isinstance(msg, TriggerDagRun): + triggered.append(msg) + return OKResponse(ok=True) + if isinstance(msg, GetDagRunState): + polls.append(msg) + return DagRunStateResult(state=next(poll_states)) + return None + + mock_supervisor_comms.send.side_effect = _send + log = mock.MagicMock() + with mock.patch("time.sleep", return_value=None): + state, _, _ = run(ti, context, log) + + assert state == TaskInstanceState.SUCCESS + assert triggered == [] # reconnected — never resubmitted + assert all(m.run_id == "prior_run_id" for m in polls) # polled the prior run, not the new run_id + + def test_handle_trigger_dag_run_returns_success_for_already_succeeded_prior_run( + self, create_runtime_ti, mock_supervisor_comms + ): + """On retry, a prior run that already succeeded short-circuits to success without re-triggering.""" + from airflow.sdk.execution_time.task_runner import _TRIGGERED_RUN_ID_KEY + + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id="test_dag", + trigger_run_id="new_run_id", + poke_interval=5, + wait_for_completion=True, + deferrable=False, + durable=True, + ) + ti = create_runtime_ti(dag_id="test_already_succeeded", run_id="test_run", task=task) + store = _FakeTaskStateStore({_TRIGGERED_RUN_ID_KEY: "prior_run_id"}) + context = ti.get_template_context() + context["task_state_store"] = store + + triggered, polls = [], [] + + def _send(*args, **kwargs): + msg = kwargs.get("msg") or (args[0] if args else None) + if isinstance(msg, TriggerDagRun): + triggered.append(msg) + return OKResponse(ok=True) + if isinstance(msg, GetDagRunState): + polls.append(msg) + return DagRunStateResult(state=DagRunState.SUCCESS) + return None + + mock_supervisor_comms.send.side_effect = _send + log = mock.MagicMock() + state, _, _ = run(ti, context, log) + + assert state == TaskInstanceState.SUCCESS + assert triggered == [] # never resubmitted + assert len(polls) == 1 # only the resume check, no poll loop + assert polls[0].run_id == "prior_run_id" + + def test_handle_trigger_dag_run_resubmits_after_failed_prior_run( + self, create_runtime_ti, mock_supervisor_comms + ): + """On retry, a prior run in a failed state triggers a fresh run rather than reconnecting.""" + from airflow.sdk.execution_time.task_runner import _TRIGGERED_RUN_ID_KEY + + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id="test_dag", + trigger_run_id="fresh_run_id", + poke_interval=5, + wait_for_completion=True, + deferrable=False, + durable=True, + ) + ti = create_runtime_ti(dag_id="test_resubmit", run_id="test_run", task=task) + store = _FakeTaskStateStore({_TRIGGERED_RUN_ID_KEY: "dead_run_id"}) + context = ti.get_template_context() + context["task_state_store"] = store + + poll_states = iter([DagRunState.FAILED, DagRunState.SUCCESS]) + triggered = [] + + def _send(*args, **kwargs): + msg = kwargs.get("msg") or (args[0] if args else None) + if isinstance(msg, TriggerDagRun): + triggered.append(msg) + return OKResponse(ok=True) + if isinstance(msg, GetDagRunState): + return DagRunStateResult(state=next(poll_states)) + return None + + mock_supervisor_comms.send.side_effect = _send + log = mock.MagicMock() + with mock.patch("time.sleep", return_value=None): + state, _, _ = run(ti, context, log) + + assert state == TaskInstanceState.SUCCESS + assert len(triggered) == 1 # a fresh run was triggered after the prior one failed + assert triggered[0].run_id == "fresh_run_id" + assert store.get(_TRIGGERED_RUN_ID_KEY) == "fresh_run_id" # store overwritten with the new run + def test_handle_trigger_dag_run_wait_for_completion_failed_state_retry_policy_fail( self, create_runtime_ti, mock_supervisor_comms ): From ec3c791ad5f2823a6b6524c70e2caf85e4b0584b Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Wed, 24 Jun 2026 02:06:10 -0700 Subject: [PATCH 2/4] Add newsfragment for TriggerDagRunOperator durable flag --- airflow-core/newsfragments/68936.feature.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 airflow-core/newsfragments/68936.feature.rst diff --git a/airflow-core/newsfragments/68936.feature.rst b/airflow-core/newsfragments/68936.feature.rst new file mode 100644 index 0000000000000..fed2dd43f93aa --- /dev/null +++ b/airflow-core/newsfragments/68936.feature.rst @@ -0,0 +1,3 @@ +``TriggerDagRunOperator`` gains an opt-in ``durable`` flag. When waiting synchronously for the +triggered run to complete, the run id is persisted before polling so that a worker crash mid-wait +reconnects to the in-flight run on retry instead of triggering a duplicate run. From c3334232da4eee1ad26fa5226adbd1540c5c6f47 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Wed, 24 Jun 2026 10:16:26 -0700 Subject: [PATCH 3/4] Narrow task_state_store and coerce stored run id for mypy --- .../airflow/sdk/execution_time/task_runner.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index a2e8d735b44c2..dab7441618983 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1888,20 +1888,21 @@ def _handle_trigger_dag_run( # Crash-safety for the synchronous wait: persist the triggered run id before polling so a worker # crash mid-wait reconnects to the in-flight run on retry instead of triggering a duplicate. task_state_store = context.get("task_state_store") - durable = ( - drte.durable and drte.wait_for_completion and not drte.deferrable and task_state_store is not None - ) + durable = drte.durable and drte.wait_for_completion and not drte.deferrable run_id = drte.dag_run_id reconnecting = False - if durable and (stored_run_id := task_state_store.get(_TRIGGERED_RUN_ID_KEY)): - decision = _evaluate_prior_triggered_run(stored_run_id, drte, log) - if decision == "succeeded": - ti.xcom_push(key="trigger_run_id", value=stored_run_id) - return _handle_current_task_success(context, ti) - if decision == "reconnect": - run_id = stored_run_id - reconnecting = True + if durable and task_state_store is not None: + stored_run_id = task_state_store.get(_TRIGGERED_RUN_ID_KEY) + if stored_run_id is not None: + prior_run_id = str(stored_run_id) + decision = _evaluate_prior_triggered_run(prior_run_id, drte, log) + if decision == "succeeded": + ti.xcom_push(key="trigger_run_id", value=prior_run_id) + return _handle_current_task_success(context, ti) + if decision == "reconnect": + run_id = prior_run_id + reconnecting = True if not reconnecting: log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) @@ -1942,7 +1943,7 @@ def _handle_trigger_dag_run( log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id) - if durable: + if durable and task_state_store is not None: # Persist before polling so a crash mid-wait reconnects on retry instead of resubmitting. task_state_store.set(_TRIGGERED_RUN_ID_KEY, run_id) From 001febf38ad5d7876e35ca223993d4d2a44051c6 Mon Sep 17 00:00:00 2001 From: 1fanwang <1fannnw@gmail.com> Date: Wed, 24 Jun 2026 11:09:10 -0700 Subject: [PATCH 4/4] Make the durable TriggerDagRunOperator newsfragment a single line --- airflow-core/newsfragments/68936.feature.rst | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/airflow-core/newsfragments/68936.feature.rst b/airflow-core/newsfragments/68936.feature.rst index fed2dd43f93aa..c429bf95c571f 100644 --- a/airflow-core/newsfragments/68936.feature.rst +++ b/airflow-core/newsfragments/68936.feature.rst @@ -1,3 +1 @@ -``TriggerDagRunOperator`` gains an opt-in ``durable`` flag. When waiting synchronously for the -triggered run to complete, the run id is persisted before polling so that a worker crash mid-wait -reconnects to the in-flight run on retry instead of triggering a duplicate run. +``TriggerDagRunOperator`` gains an opt-in ``durable`` flag that, on a synchronous ``wait_for_completion``, persists the triggered run id before polling so a worker crash mid-wait reconnects to the in-flight run on retry instead of triggering a duplicate run.