Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class BaseCallbackRequest(BaseModel):
"""File Path to use to run the callback"""
bundle_name: str
bundle_version: str | None
version_data: dict[str, Any] | None = None
"""Optional structured metadata for the pinned bundle version (e.g. an S3 object manifest).

Populated only for pinned runs so the callback initializes the bundle against the same
version the task ran with; ``None`` for unpinned runs.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep only the first line. The docstring under each parameter is only intended to explain what the field is, not its inner workings.

"""
msg: str | None = None
"""Additional Message that can be used for logging to determine failure/task heartbeat timeout"""

Expand Down
6 changes: 5 additions & 1 deletion airflow-core/src/airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,11 @@ def prepare_callback_bundle(self, request: CallbackRequest) -> BaseDagBundle | N
Override to source the bundle from an API.
"""
try:
bundle = DagBundlesManager().get_bundle(name=request.bundle_name, version=request.bundle_version)
bundle = DagBundlesManager().get_bundle(
name=request.bundle_name,
version=request.bundle_version,
version_data=request.version_data,
)
except ValueError:
self.log.error("Bundle %s no longer configured, skipping callback", request.bundle_name)
return None
Expand Down
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/executors/workloads/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,13 @@ def make(
) -> ExecuteCallback:
"""Create an ExecuteCallback workload from a Callback ORM model."""
if not bundle_info:
from airflow.models.dag_version import resolve_pinned_version_data

version_data = resolve_pinned_version_data(dag_run.created_dag_version, dag_run.bundle_version)
bundle_info = BundleInfo(
name=dag_run.dag_model.bundle_name,
version=dag_run.bundle_version,
version_data=version_data,
)
fname = f"executor_callbacks/{dag_run.dag_id}/{dag_run.run_id}/{callback.id}"

Expand Down
7 changes: 3 additions & 4 deletions airflow-core/src/airflow/executors/workloads/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,12 @@ def make(

ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
if not bundle_info:
version_data = None
if ti.dag_version is not None and ti.dag_run.bundle_version is not None:
version_data = ti.dag_version.version_data
from airflow.models.dag_version import resolve_pinned_version_data

bundle_info = BundleInfo(
name=ti.dag_model.bundle_name,
version=ti.dag_run.bundle_version,
version_data=version_data,
version_data=resolve_pinned_version_data(ti.dag_version, ti.dag_run.bundle_version),
)
fname = log_filename_template_renderer()(ti=ti)

Expand Down
14 changes: 13 additions & 1 deletion airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
ConnectionTestState,
)
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dag_version import DagVersion, resolve_pinned_version_data
from airflow.models.dagbag import DBDagBag
from airflow.models.dagbundle import DagBundleModel
from airflow.models.dagrun import DagRun
Expand Down Expand Up @@ -1512,13 +1512,15 @@ def process_executor_events(
if ti.dag_version and ti.dag_run.bundle_version is not None
else ti.dag_run.bundle_version
)
_version_data = resolve_pinned_version_data(ti.dag_version, ti.dag_run.bundle_version)
# Backfill dag_version_id for legacy tasks (Pydantic requires uuid.UUID).
if not _ensure_ti_has_dag_version_id(ti, session, cls.logger()):
continue
request = TaskCallbackRequest(
filepath=ti.dag_model.relative_fileloc or "",
bundle_name=_bundle_name,
bundle_version=_bundle_version,
version_data=_version_data,
ti=ti,
msg=msg,
task_callback_type=(
Expand Down Expand Up @@ -1561,13 +1563,17 @@ def process_executor_events(
_email_bundle_version = (
ti.dag_version.bundle_version if ti.dag_version else ti.dag_run.bundle_version
)
_email_version_data = resolve_pinned_version_data(
ti.dag_version, ti.dag_run.bundle_version
)
# Backfill dag_version_id for legacy tasks (Pydantic requires uuid.UUID).
if not _ensure_ti_has_dag_version_id(ti, session, cls.logger()):
continue
email_request = EmailRequest(
filepath=ti.dag_model.relative_fileloc or "",
bundle_name=_email_bundle_name,
bundle_version=_email_bundle_version,
version_data=_email_version_data,
ti=ti,
msg=msg,
email_type="retry" if ti.is_eligible_to_retry() else "failure",
Expand Down Expand Up @@ -3068,6 +3074,9 @@ def _maybe_requeue_stuck_ti(self, *, ti, session, executor):
if ti.dag_version and ti.dag_run.bundle_version is not None
else ti.dag_run.bundle_version
)
_stuck_version_data = resolve_pinned_version_data(
ti.dag_version, ti.dag_run.bundle_version
)
# Backfill dag_version_id for legacy tasks (Pydantic requires uuid.UUID).
# Note: we cannot use `continue` here because this method is not
# inside a loop. If backfilling fails we simply skip the callback.
Expand All @@ -3076,6 +3085,7 @@ def _maybe_requeue_stuck_ti(self, *, ti, session, executor):
filepath=ti.dag_model.relative_fileloc or "",
bundle_name=_stuck_bundle_name,
bundle_version=_stuck_bundle_version,
version_data=_stuck_version_data,
ti=ti,
msg=msg,
context_from_server=TIRunContext(
Expand Down Expand Up @@ -3539,13 +3549,15 @@ def _purge_task_instances_without_heartbeats(
if ti.dag_version and ti.dag_run.bundle_version is not None
else ti.dag_run.bundle_version
)
_hb_version_data = resolve_pinned_version_data(ti.dag_version, ti.dag_run.bundle_version)
# Backfill dag_version_id for legacy tasks (Pydantic requires uuid.UUID).
if not _ensure_ti_has_dag_version_id(ti, session, self.log):
continue
request = TaskCallbackRequest(
filepath=ti.dag_model.relative_fileloc or "",
bundle_name=_hb_bundle_name,
bundle_version=_hb_bundle_version,
version_data=_hb_version_data,
ti=ti,
msg=str(task_instance_heartbeat_timeout_message_details),
context_from_server=TIRunContext(
Expand Down
20 changes: 19 additions & 1 deletion airflow-core/src/airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import logging
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from uuid import UUID

import sqlalchemy as sa
Expand Down Expand Up @@ -235,3 +235,21 @@ def get_version(
def version(self) -> str:
"""A human-friendly representation of the version."""
return f"{self.dag_id}-{self.version_number}"


def resolve_pinned_version_data(
    dag_version: DagVersion | None, bundle_version: str | None
) -> dict[str, Any] | None:
    """Return ``dag_version.version_data`` for a pinned run, otherwise ``None``."""
    # NOTE(review): a reviewer suggested making this helper private (e.g.
    # ``_resolve_pinned_version_data``). The public name is kept because other
    # modules import the helper under this name.
    #
    # Pinning rule shared by task and callback workloads: the manifest is exposed
    # only when the run is pinned (``bundle_version`` is set) and a ``DagVersion``
    # is available, so the worker initializes the bundle against the exact
    # version the run used.
    if dag_version is None or bundle_version is None:
        # Unpinned runs should follow the latest bundle state, and legacy rows
        # have no ``DagVersion`` to read the manifest from.
        return None
    return dag_version.version_data
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,12 +1449,18 @@ def produce_dag_callback(
)
relevant_ti = None
if not execute:
from airflow.models.dag_version import resolve_pinned_version_data

# Only carry version_data for pinned runs so the callback initializes the bundle
# against the same version the run used.
version_data = resolve_pinned_version_data(self.created_dag_version, self.bundle_version)
return DagCallbackRequest(
filepath=self.dag_model.relative_fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
bundle_name=self.dag_model.bundle_name,
bundle_version=self.bundle_version,
version_data=version_data,
context_from_server=DagRunContext(
dag_run=self,
last_ti=relevant_ti,
Expand Down
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,12 +566,18 @@ def _submit_callback_if_necessary() -> None:
if event.task_instance_state in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED):
if task_instance.dag_model.relative_fileloc is None:
raise RuntimeError("relative_fileloc should not be None for a finished task")
from airflow.models.dag_version import resolve_pinned_version_data

version_data = resolve_pinned_version_data(
task_instance.dag_version, task_instance.dag_run.bundle_version
)
request = TaskCallbackRequest(
filepath=task_instance.dag_model.relative_fileloc,
ti=task_instance,
task_callback_type=event.task_instance_state,
bundle_name=task_instance.dag_model.bundle_name,
bundle_version=task_instance.dag_run.bundle_version,
version_data=version_data,
)
log.info("Sending callback: %s", request)
try:
Expand Down
27 changes: 27 additions & 0 deletions airflow-core/tests/unit/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,33 @@ def test_is_failure_callback_property(

assert request.is_failure_callback == expected_is_failure

def test_version_data_round_trips_and_defaults_none(self):
    """Check the JSON round trip of ``version_data`` and its ``None`` default."""
    # Constructor arguments shared by the pinned and the unpinned request.
    common = {
        "filepath": "filepath",
        "dag_id": "fake_dag",
        "run_id": "fake_run",
        "is_failure_callback": False,
        "bundle_name": "testing",
    }
    manifest = {"schema_version": 1, "files": {"dags/my_dag.py": "ver123"}}

    pinned = DagCallbackRequest(**common, bundle_version="abc123", version_data=manifest)
    restored = DagCallbackRequest.from_json(pinned.to_json())
    assert restored.version_data == manifest

    # Leaving version_data out must give None, both directly and after a round trip.
    unpinned = DagCallbackRequest(**common, bundle_version=None)
    assert unpinned.version_data is None
    restored_unpinned = DagCallbackRequest.from_json(unpinned.to_json())
    assert restored_unpinned.version_data is None



class TestDagRunContext:
def test_dagrun_context_creation(self):
Expand Down
24 changes: 23 additions & 1 deletion airflow-core/tests/unit/dag_processing/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,7 @@ def test_collect_results_processes_remaining_files_when_one_persist_fails(self,
"filepath": "dag_callback_dag.py",
"bundle_name": "testing",
"bundle_version": None,
"version_data": None,
"msg": None,
"dag_id": "dag_id",
"run_id": "run_id",
Expand Down Expand Up @@ -1967,7 +1968,28 @@ def test_prepare_callback_bundle_initializes_versioned_bundle(self, mock_bundle_
bundle.initialize.assert_called_once()

@mock.patch("airflow.dag_processing.manager.DagBundlesManager")
def test_prepare_callback_bundle_forwards_version_data(self, mock_bundle_manager):
    """``prepare_callback_bundle`` passes the request's ``version_data`` to ``get_bundle``."""
    manager = DagFileProcessorManager(max_runs=1)
    bundle = MagicMock(spec=BaseDagBundle)
    bundle.supports_versioning = True
    mock_bundle_manager.return_value.get_bundle.return_value = bundle

    version_data = {"schema_version": 1, "files": {"dags/my_dag.py": "ver123"}}
    request = DagCallbackRequest(
        filepath="file1.py",
        dag_id="dag1",
        run_id="run1",
        is_failure_callback=False,
        bundle_name="testing",
        bundle_version="some_commit_hash",
        version_data=version_data,
        msg=None,
    )

    manager.prepare_callback_bundle(request)
    mock_bundle_manager.return_value.get_bundle.assert_called_once_with(
        name="testing", version="some_commit_hash", version_data=version_data
    )

# Decorator and header restored: the new test above had replaced this ``def``
# line, which merged both test bodies under a single, misleading name.
@mock.patch("airflow.dag_processing.manager.DagBundlesManager")
def test_prepare_callback_bundle_skips_initialize_for_unversioned_request(self, mock_bundle_manager):
    manager = DagFileProcessorManager(max_runs=1)
    bundle = MagicMock(spec=BaseDagBundle)
    bundle.supports_versioning = True
Expand Down
60 changes: 59 additions & 1 deletion airflow-core/tests/unit/executors/test_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from airflow.executors import workloads
from airflow.executors.workloads import TaskInstance, TaskInstanceDTO, base as workloads_base
from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo
from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod
from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod, ExecuteCallback
from airflow.executors.workloads.task import ExecuteTask
from airflow.executors.workloads.types import state_class_for_key
from airflow.models.callback import CallbackKey
Expand Down Expand Up @@ -250,3 +250,61 @@ def test_missing_dag_version_yields_none(self):

assert workload.bundle_info.version == "abc123"
assert workload.bundle_info.version_data is None


class TestExecuteCallbackMakeVersionData:
    """Verify that ExecuteCallback.make() carries version_data into BundleInfo."""

    @staticmethod
    def _make_mocks(bundle_version, version_data, *, has_created_dag_version=True):
        """Return a ``(callback, dag_run)`` pair of mocks with the attributes set below."""
        from unittest.mock import Mock

        callback = Mock(
            id=uuid4(),
            fetch_method=CallbackFetchMethod.IMPORT_PATH,
            data={"path": "my_module.my_callback"},
        )

        dag_run = Mock(dag_id="test_dag", run_id="test_run", bundle_version=bundle_version)
        dag_run.dag_model.bundle_name = "test-bundle"
        dag_run.dag_model.relative_fileloc = "dags/test_dag.py"
        if not has_created_dag_version:
            # Simulate a run that has no created_dag_version at all.
            dag_run.created_dag_version = None
        else:
            dag_run.created_dag_version.version_data = version_data

        return callback, dag_run

    def test_pinned_run_populates_version_data(self):
        """A pinned run exposes created_dag_version.version_data on BundleInfo."""
        manifest = {"schema_version": 1, "files": {"dags/my_dag.py": "ver123"}}
        callback, dag_run = self._make_mocks(bundle_version="abc123", version_data=manifest)

        bundle_info = ExecuteCallback.make(callback=callback, dag_run=dag_run).bundle_info

        assert bundle_info.version == "abc123"
        assert bundle_info.version_data == manifest

    def test_unpinned_run_suppresses_present_version_data(self):
        """An unpinned run exposes no version_data even if created_dag_version has some."""
        manifest = {"schema_version": 1, "files": {"dags/my_dag.py": "ver123"}}
        callback, dag_run = self._make_mocks(bundle_version=None, version_data=manifest)

        bundle_info = ExecuteCallback.make(callback=callback, dag_run=dag_run).bundle_info

        assert bundle_info.version is None
        assert bundle_info.version_data is None

    def test_missing_created_dag_version_yields_none(self):
        """A pinned run whose created_dag_version is None exposes no version_data."""
        callback, dag_run = self._make_mocks(
            bundle_version="abc123", version_data=None, has_created_dag_version=False
        )

        bundle_info = ExecuteCallback.make(callback=callback, dag_run=dag_run).bundle_info

        assert bundle_info.version == "abc123"
        assert bundle_info.version_data is None
1 change: 1 addition & 0 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ def test_process_executor_events_with_callback(
ti=mock.ANY,
bundle_name="dag_maker",
bundle_version=None,
version_data=None,
msg=f"Executor {executor} reported that the task instance "
f"<TaskInstance: test_process_executor_events_with_callback.dummy_task test [queued] ti_id={ti1.id}> "
"finished with state failed, but the task instance's state attribute is queued. "
Expand Down
29 changes: 29 additions & 0 deletions airflow-core/tests/unit/models/test_dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

from datetime import timedelta
from unittest import mock

import pytest
from sqlalchemy import func, select
Expand Down Expand Up @@ -181,3 +182,31 @@ def test_write_dag_without_version_data(self, dag_maker, session):
retrieved = DagVersion.get_latest_version("test_no_version_data", session=session)
assert retrieved.version_data is None
assert retrieved.bundle_version == "abc123"


class TestResolvePinnedVersionData:
    """Exercise the pin guard implemented by resolve_pinned_version_data."""

    # Manifest attached to the mocked DagVersion in the cases that carry data.
    _MANIFEST = {"schema_version": 1}

    @pytest.mark.parametrize(
        ("dag_version", "bundle_version", "expected"),
        [
            pytest.param(
                mock.Mock(version_data=_MANIFEST), "abc123", {"schema_version": 1}, id="pinned-with-data"
            ),
            pytest.param(
                mock.Mock(version_data=_MANIFEST), None, None, id="unpinned-suppresses-present-data"
            ),
            pytest.param(None, "abc123", None, id="missing-dag-version"),
            pytest.param(None, None, None, id="unpinned-and-missing"),
        ],
    )
    def test_resolve_pinned_version_data(self, dag_version, bundle_version, expected):
        """The helper returns the manifest only when both arguments are set."""
        from airflow.models.dag_version import resolve_pinned_version_data

        actual = resolve_pinned_version_data(dag_version, bundle_version)
        assert actual == expected
Loading
Loading