Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from airflow.providers.standard.version_compat import (
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_2_PLUS,
AIRFLOW_V_3_3_PLUS,
BaseOperator,
is_arg_set,
)
Expand All @@ -57,6 +58,26 @@
except ImportError:
from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef]

try:
from airflow.sdk import ResumableJobMixin
except ImportError:
# ResumableJobMixin was added in Airflow 3.3. On older Airflow the durable path is gated off
# (see execute()), so this stub only lets the operator subclass it without behavior change.
class ResumableJobMixin: # type: ignore[no-redef]
"""Stub used on Airflow < 3.3, where the durable path is disabled."""

external_id_key: str = "remote_job_id"

def __init__(self, *, durable: bool = True, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.durable = durable

def execute_resumable(self, context):
external_id = self.submit_job(context)
self.poll_until_complete(external_id, context)
return self.get_job_result(external_id, context)


XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"

Expand All @@ -78,6 +99,10 @@ def __str__(self) -> str:
return f"Dag {self.dag_id} is paused"


class TriggeredDagRunFailed(AirflowException):
"""Raise when a synchronously-waited triggered Dag run reaches a failed state."""


class TriggerDagRunLink(BaseOperatorLink):
"""
Operator link for TriggerDagRunOperator.
Expand Down Expand Up @@ -121,7 +146,7 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
return build_airflow_url_with_query(query)


class TriggerDagRunOperator(BaseOperator):
class TriggerDagRunOperator(ResumableJobMixin, BaseOperator):
"""
Triggers a DAG run for a specified DAG ID.

Expand Down Expand Up @@ -154,6 +179,10 @@ class TriggerDagRunOperator(BaseOperator):
Airflow 3.x this requires Airflow 3.2.0+ (it relies on the task-SDK DAG state endpoint added then);
on Airflow 3.0/3.1 setting this raises ``NotImplementedError``.
:param deferrable: If waiting for completion, whether to defer the task until done, default is ``False``.
:param durable: On Airflow 3.3+ with a synchronous ``wait_for_completion`` (non-deferrable), persist the
triggered run id before polling so a worker crash mid-wait reconnects to the in-flight run on retry
instead of triggering a duplicate. No effect on Airflow < 3.3 or with ``deferrable=True``.
Default ``False``.
:param openlineage_inject_parent_info: whether to include OpenLineage metadata about the parent task
in the triggered DAG run's conf, enabling improved lineage tracking. The metadata is only injected
if OpenLineage is enabled and running. This option does not modify any other part of the conf,
Expand Down Expand Up @@ -181,6 +210,9 @@ class TriggerDagRunOperator(BaseOperator):
ui_color = "#ffefeb"
operator_extra_links = [TriggerDagRunLink()]

# Key under which the triggered run id is persisted to task_state_store for the durable path.
external_id_key = "triggered_dag_run_id"

def __init__(
self,
*,
Expand All @@ -198,10 +230,11 @@ def __init__(
fail_when_dag_is_paused: bool = False,
note: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
durable: bool = False,
openlineage_inject_parent_info: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
super().__init__(durable=durable, **kwargs)
self.trigger_dag_id = trigger_dag_id
self.trigger_run_id = trigger_run_id
self.conf = conf
Expand All @@ -225,6 +258,9 @@ def __init__(
run_after = _validate_datetime_param("run_after", run_after)
self.logical_date = logical_date
self.run_after = run_after
# Parsed datetimes, set in execute() before the durable (Airflow 3.3+) submit path reads them.
self._parsed_logical_date: datetime.datetime | None = None
self._parsed_run_after: datetime.datetime | None = None
if fail_when_dag_is_paused and AIRFLOW_V_3_0_PLUS and not AIRFLOW_V_3_2_PLUS:
raise NotImplementedError(
"Setting `fail_when_dag_is_paused` requires Airflow 3.2.0+ on Airflow 3.x "
Expand Down Expand Up @@ -289,6 +325,17 @@ def execute(self, context: Context):
if dag_model.is_paused:
raise AirflowException(f"Dag {self.trigger_dag_id} is paused")

if AIRFLOW_V_3_3_PLUS and not self.deferrable:
# On Airflow 3.3+ the operator owns its synchronous execution through the execution-API
# accessors (ti.trigger_dag_run / ti.get_dagrun_state) and ResumableJobMixin, rather than
# handing the trigger and wait off to the task runner via DagRunTriggerException.
self._parsed_logical_date = parsed_logical_date
self._parsed_run_after = parsed_run_after if self.run_after is not NOTSET else None
if self.wait_for_completion:
return self.execute_resumable(context)
self.submit_job(context)
return

if AIRFLOW_V_3_0_PLUS:
self._trigger_dag_af_3(
context=context,
Expand All @@ -301,6 +348,80 @@ def execute(self, context: Context):
context=context, run_id=self.trigger_run_id, parsed_logical_date=parsed_logical_date
)

# ResumableJobMixin hooks — drive the synchronous Airflow 3.3+ path through the execution-API
# accessors so the operator owns its execution instead of the task runner.

def submit_job(self, context: Context) -> Any:
ti = context["ti"]
# run_id is computed in execute() before this runs, so it is always set here.
run_id = cast("str", self.trigger_run_id)
created = ti.trigger_dag_run(
self.trigger_dag_id,
run_id,
conf=self.conf,
logical_date=self._parsed_logical_date,
run_after=self._parsed_run_after,
reset_dag_run=self.reset_dag_run,
note=self.note,
)
if not created:
if self.skip_when_already_exists:
raise AirflowSkipException(
"Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists"
)
raise DagRunAlreadyExists(DagRun(dag_id=self.trigger_dag_id, run_id=run_id))
self._record_triggered_run(run_id, context)
return run_id

def get_job_status(self, external_id: Any, context: Context) -> str:
return context["ti"].get_dagrun_state(self.trigger_dag_id, str(external_id))

def is_job_active(self, status: str) -> bool:
return bool(status) and status not in self.allowed_states and status not in self.failed_states

def is_job_succeeded(self, status: str) -> bool:
return status in self.allowed_states

def poll_until_complete(self, external_id: Any, context: Context) -> None:
run_id = str(external_id)
self._record_triggered_run(run_id, context)
ti = context["ti"]
while True:
self.log.info(
"Waiting for %s on %s to reach an allowed state %s ...",
self.trigger_dag_id,
run_id,
self.allowed_states,
)
time.sleep(self.poke_interval)
state = ti.get_dagrun_state(self.trigger_dag_id, run_id)
if state in self.failed_states:
raise TriggeredDagRunFailed(f"{self.trigger_dag_id} failed with failed state {state}")
if state in self.allowed_states:
self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
return

def get_job_result(self, external_id: Any, context: Context) -> Any:
run_id = str(external_id)
self._record_triggered_run(run_id, context)
return run_id

def _record_triggered_run(self, run_id: str, context: Context) -> None:
"""Record the triggered run id as a task attribute and on XCom (run id + extra-link URL)."""
self.trigger_run_id = run_id
ti = context.get("ti") or context.get("task_instance")
if not (ti and hasattr(ti, "xcom_push")):
return
ti.xcom_push(key=XCOM_RUN_ID, value=run_id)
from airflow.utils import helpers

build_url_fn = getattr(helpers, "build_airflow_dagrun_url", None)
if build_url_fn:
ti.xcom_push(
key=TriggerDagRunLink().xcom_key,
value=build_url_fn(dag_id=self.trigger_dag_id, run_id=run_id),
)

def _trigger_dag_af_3(self, context, run_id, parsed_logical_date, parsed_run_after=None):
from airflow.providers.common.compat.sdk import DagRunTriggerException

Expand Down
Loading
Loading