Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/68936.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``TriggerDagRunOperator`` gains an opt-in ``durable`` flag that, on a synchronous ``wait_for_completion``, persists the triggered run id before polling so a worker crash mid-wait reconnects to the in-flight run on retry instead of triggering a duplicate run.
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ class TriggerDagRunOperator(BaseOperator):
Airflow 3.x this requires Airflow 3.2.0+ (it relies on the task-SDK DAG state endpoint added then);
on Airflow 3.0/3.1 setting this raises ``NotImplementedError``.
:param deferrable: If waiting for completion, whether to defer the task until done, default is ``False``.
:param durable: If ``True`` and waiting for completion synchronously (non-deferrable), persist the
triggered run id before polling so that a worker crash mid-wait reconnects to the in-flight run on
retry instead of triggering a duplicate. Requires Airflow 3.3+ (task_state_store). Default ``False``.
:param openlineage_inject_parent_info: whether to include OpenLineage metadata about the parent task
in the triggered DAG run's conf, enabling improved lineage tracking. The metadata is only injected
if OpenLineage is enabled and running. This option does not modify any other part of the conf,
Expand Down Expand Up @@ -198,6 +201,7 @@ def __init__(
fail_when_dag_is_paused: bool = False,
note: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
durable: bool = False,
openlineage_inject_parent_info: bool = True,
**kwargs,
) -> None:
Expand All @@ -221,6 +225,7 @@ def __init__(
self.openlineage_inject_parent_info = openlineage_inject_parent_info
self.note = note
self.deferrable = deferrable
self.durable = durable
logical_date = _validate_datetime_param("logical_date", logical_date)
run_after = _validate_datetime_param("run_after", run_after)
self.logical_date = logical_date
Expand Down Expand Up @@ -325,6 +330,9 @@ def _trigger_dag_af_3(self, context, run_id, parsed_logical_date, parsed_run_aft
if parsed_run_after and "run_after" in parameters:
kwargs_accepted["run_after"] = parsed_run_after

if self.durable and "durable" in parameters:
kwargs_accepted["durable"] = self.durable

if isinstance(context, Mapping):
from airflow.utils import helpers

Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def __init__(
failed_states: list[str],
poke_interval: int,
deferrable: bool,
durable: bool = False,
note: str | None = None,
):
super().__init__()
Expand All @@ -324,6 +325,7 @@ def __init__(
self.failed_states = failed_states
self.poke_interval = poke_interval
self.deferrable = deferrable
self.durable = durable
self.note = note


Expand Down
130 changes: 89 additions & 41 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,51 +1854,101 @@ def _finalize_task_failure(
)


_TRIGGERED_RUN_ID_KEY = "triggered_dag_run_id"


def _evaluate_prior_triggered_run(
run_id: str, drte: DagRunTriggerException, log: Logger
) -> Literal["succeeded", "reconnect", "resubmit"]:
"""
Classify a run triggered on a prior attempt so the synchronous wait can resume safely.

``"succeeded"`` — already finished in an allowed state; skip resubmission.
``"reconnect"`` — still running; resume the wait without resubmitting.
``"resubmit"`` — failed, gone, or state unreadable; trigger a fresh run.
"""
comms_msg = SUPERVISOR_COMMS.send(GetDagRunState(dag_id=drte.trigger_dag_id, run_id=run_id))
state = comms_msg.state if isinstance(comms_msg, DagRunStateResult) else None
if state in drte.allowed_states:
log.info("Run triggered on a prior attempt already succeeded; not resubmitting.", run_id=run_id)
return "succeeded"
if state is None or state in drte.failed_states:
log.warning(
"Run triggered on a prior attempt is not resumable; resubmitting.", run_id=run_id, state=state
)
return "resubmit"
log.info("Reconnecting to run triggered on a prior attempt.", run_id=run_id, state=state)
return "reconnect"


def _handle_trigger_dag_run(
drte: DagRunTriggerException, context: Context, ti: RuntimeTaskInstance, log: Logger
) -> tuple[ToSupervisor, TaskInstanceState]:
"""Handle exception from TriggerDagRunOperator."""
log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
comms_msg = SUPERVISOR_COMMS.send(
TriggerDagRun(
dag_id=drte.trigger_dag_id,
run_id=drte.dag_run_id,
logical_date=drte.logical_date,
run_after=drte.run_after,
conf=drte.conf,
reset_dag_run=drte.reset_dag_run,
note=drte.note,
),
)

if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
if drte.skip_when_already_exists:
log.info(
"Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
# Crash-safety for the synchronous wait: persist the triggered run id before polling so a worker
# crash mid-wait reconnects to the in-flight run on retry instead of triggering a duplicate.
task_state_store = context.get("task_state_store")
durable = drte.durable and drte.wait_for_completion and not drte.deferrable

run_id = drte.dag_run_id
reconnecting = False
if durable and task_state_store is not None:
stored_run_id = task_state_store.get(_TRIGGERED_RUN_ID_KEY)
if stored_run_id is not None:
prior_run_id = str(stored_run_id)
decision = _evaluate_prior_triggered_run(prior_run_id, drte, log)
if decision == "succeeded":
ti.xcom_push(key="trigger_run_id", value=prior_run_id)
return _handle_current_task_success(context, ti)
if decision == "reconnect":
run_id = prior_run_id
reconnecting = True

if not reconnecting:
log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
comms_msg = SUPERVISOR_COMMS.send(
TriggerDagRun(
dag_id=drte.trigger_dag_id,
)
msg = TaskState(
state=TaskInstanceState.SKIPPED,
end_date=datetime.now(tz=timezone.utc),
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.SKIPPED
else:
log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
msg = TaskState(
state=TaskInstanceState.FAILED,
end_date=datetime.now(tz=timezone.utc),
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.FAILED
run_id=run_id,
logical_date=drte.logical_date,
run_after=drte.run_after,
conf=drte.conf,
reset_dag_run=drte.reset_dag_run,
note=drte.note,
),
)

if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
if drte.skip_when_already_exists:
log.info(
"Dag Run already exists, skipping task as skip_when_already_exists is set to True.",
dag_id=drte.trigger_dag_id,
)
msg = TaskState(
state=TaskInstanceState.SKIPPED,
end_date=datetime.now(tz=timezone.utc),
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.SKIPPED
else:
log.error("Dag Run already exists, marking task as failed.", dag_id=drte.trigger_dag_id)
msg = TaskState(
state=TaskInstanceState.FAILED,
end_date=datetime.now(tz=timezone.utc),
rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.FAILED

return msg, state

return msg, state
log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)

log.info("Dag Run triggered successfully.", trigger_dag_id=drte.trigger_dag_id)
if durable and task_state_store is not None:
# Persist before polling so a crash mid-wait reconnects on retry instead of resubmitting.
task_state_store.set(_TRIGGERED_RUN_ID_KEY, run_id)

# Store the run id from the dag run (either created or found above) to
# be used when creating the extra link on the webserver.
ti.xcom_push(key="trigger_run_id", value=drte.dag_run_id)
# Store the run id (created above or reconnected to) for the webserver extra link.
ti.xcom_push(key="trigger_run_id", value=run_id)

if drte.wait_for_completion:
if drte.deferrable:
Expand All @@ -1923,14 +1973,12 @@ def _handle_trigger_dag_run(
log.info(
"Waiting for dag run to complete execution in allowed state.",
dag_id=drte.trigger_dag_id,
run_id=drte.dag_run_id,
run_id=run_id,
allowed_state=drte.allowed_states,
)
time.sleep(drte.poke_interval)

comms_msg = SUPERVISOR_COMMS.send(
GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
)
comms_msg = SUPERVISOR_COMMS.send(GetDagRunState(dag_id=drte.trigger_dag_id, run_id=run_id))
if TYPE_CHECKING:
assert isinstance(comms_msg, DagRunStateResult)
if comms_msg.state in drte.failed_states:
Expand Down
Loading
Loading