From 54904a190b9950ce99df857dafa2185a387773b7 Mon Sep 17 00:00:00 2001 From: Sirui Huang Date: Wed, 17 Jun 2026 15:14:22 -0700 Subject: [PATCH 1/3] Add BlockRefCounter Implementation and Tests Signed-off-by: Sirui Huang --- python/ray/data/BUILD.bazel | 14 + .../_internal/execution/block_ref_counter.py | 72 +++++ .../ray/data/tests/test_block_ref_counter.py | 297 ++++++++++++++++++ 3 files changed, 383 insertions(+) create mode 100644 python/ray/data/_internal/execution/block_ref_counter.py create mode 100644 python/ray/data/tests/test_block_ref_counter.py diff --git a/python/ray/data/BUILD.bazel b/python/ray/data/BUILD.bazel index 194a98fe617..802fba6de44 100644 --- a/python/ray/data/BUILD.bazel +++ b/python/ray/data/BUILD.bazel @@ -1369,6 +1369,20 @@ py_test( ], ) +py_test( + name = "test_block_ref_counter", + size = "small", + srcs = ["tests/test_block_ref_counter.py"], + tags = [ + "exclusive", + "team:data", + ], + deps = [ + ":conftest", + "//:ray_lib", + ], +) + py_test( name = "test_map_operator", size = "medium", diff --git a/python/ray/data/_internal/execution/block_ref_counter.py b/python/ray/data/_internal/execution/block_ref_counter.py new file mode 100644 index 00000000000..6cd5bed5f5f --- /dev/null +++ b/python/ray/data/_internal/execution/block_ref_counter.py @@ -0,0 +1,72 @@ +import threading +from collections import defaultdict +from typing import Dict + +import ray +import ray._private.worker + + +class BlockRefCounter: + """Tracks object-store memory usage per operator via Ray Core callbacks. + + The callback fires when: + - All Python ObjectRefs wrapping the block's ObjectID are garbage-collected, AND + - All Ray tasks that received the block as an argument have completed. + """ + + def __init__(self): + # Object ID binaries of currently live blocks; used by _on_object_freed + # to distinguish a racing clear() from a real callback. + self._registered_ids: set[bytes] = set() + # (producer_id -> total live bytes); maintained incrementally for O(1) reads. + self._bytes_by_producer: Dict[str, int] = defaultdict(int) + self._lock = threading.Lock() + + def on_block_produced( + self, + block_ref: "ray.ObjectRef", + size_bytes: int, + producer_id: str, + ) -> None: + """Register a block and attribute its memory to producer_id. + + Registers a Ray Core out-of-scope callback so that when all references + to block_ref are gone the bytes are automatically removed from the + producer's usage. + """ + id_binary = block_ref.binary() + with self._lock: + self._registered_ids.add(id_binary) + self._bytes_by_producer[producer_id] += size_bytes + + def _on_object_freed(id_bytes: bytes) -> None: + with self._lock: + if id_bytes not in self._registered_ids: + # Already cleared (e.g. by clear()), nothing to do. + return + self._registered_ids.discard(id_bytes) + self._bytes_by_producer[producer_id] -= size_bytes + if self._bytes_by_producer[producer_id] == 0: + del self._bytes_by_producer[producer_id] + + core_worker = ray._private.worker.global_worker.core_worker # type: ignore[attr-defined] + registered = core_worker.add_object_out_of_scope_callback( + block_ref, _on_object_freed + ) + if not registered: + _on_object_freed(id_binary) + + def get_object_store_memory_usage(self, producer_id: str) -> int: + """Total bytes of live blocks attributed to producer_id.""" + with self._lock: + return self._bytes_by_producer.get(producer_id, 0) + + def clear(self) -> None: + """Reset all accounting, e.g. on executor shutdown. + + Any previously registered Ray Core callbacks firing after clear() + will be silently ignored because _registered_ids is empty. + """ + with self._lock: + self._registered_ids.clear() + self._bytes_by_producer.clear() diff --git a/python/ray/data/tests/test_block_ref_counter.py b/python/ray/data/tests/test_block_ref_counter.py new file mode 100644 index 00000000000..4cf9caed780 --- /dev/null +++ b/python/ray/data/tests/test_block_ref_counter.py @@ -0,0 +1,297 @@ +import gc +import threading +import time +import unittest.mock as mock + +import pytest + +import ray +from ray.data._internal.execution.block_ref_counter import BlockRefCounter +from ray.tests.conftest import * # noqa + + +class _FakeRef: + """Minimal stand-in for ray.ObjectRef. Has a .binary() that returns bytes.""" + + def __init__(self, uid: int): + self._binary = uid.to_bytes(28, "big") + + def binary(self) -> bytes: + return self._binary + + +def _register_block(counter, ref, size_bytes, producer_id): + """Call on_block_produced on an existing counter with a mocked core worker. + + Returns the captured _on_out_of_scope callback so tests can fire it directly. + """ + captured_callback = None + + class _MockCoreWorker: + def add_object_out_of_scope_callback(self, block_ref, cb): + nonlocal captured_callback + captured_callback = cb + return True + + with mock.patch( + "ray._private.worker.global_worker", + mock.Mock(core_worker=_MockCoreWorker()), + ): + counter.on_block_produced(ref, size_bytes, producer_id) + + return captured_callback, ref.binary() + + +class TestBlockRefCounterAccounting: + def test_single_block_produced_and_released(self): + counter = BlockRefCounter() + ref = _FakeRef(1) + callback, id_binary = _register_block(counter, ref, 100, "op_a") + + assert counter.get_object_store_memory_usage("op_a") == 100 + callback(id_binary) + assert counter.get_object_store_memory_usage("op_a") == 0 + + def test_multiple_blocks_same_producer(self): + counter = BlockRefCounter() + ref1, ref2 = _FakeRef(1), _FakeRef(2) + cb1, bin1 = _register_block(counter, ref1, 100, "op_a") + cb2, bin2 = _register_block(counter, ref2, 200, "op_a") + + assert counter.get_object_store_memory_usage("op_a") == 300 + cb1(bin1) + assert counter.get_object_store_memory_usage("op_a") == 200 + cb2(bin2) + assert counter.get_object_store_memory_usage("op_a") == 0 + + def test_multiple_producers_isolated(self): + counter = BlockRefCounter() + ref1, ref2 = _FakeRef(1), _FakeRef(2) + cb1, bin1 = _register_block(counter, ref1, 100, "op_a") + _register_block(counter, ref2, 200, "op_b") + + assert counter.get_object_store_memory_usage("op_a") == 100 + assert counter.get_object_store_memory_usage("op_b") == 200 + + cb1(bin1) + assert counter.get_object_store_memory_usage("op_a") == 0 + assert counter.get_object_store_memory_usage("op_b") == 200 + + +class TestBlockRefCounterClear: + def test_clear_resets_usage(self): + counter = BlockRefCounter() + _register_block(counter, _FakeRef(1), 100, "op_a") + assert counter.get_object_store_memory_usage("op_a") == 100 + + counter.clear() + assert counter.get_object_store_memory_usage("op_a") == 0 + + def test_callback_after_clear_is_noop(self): + """A callback firing after clear() must not crash or corrupt state.""" + counter = BlockRefCounter() + ref = _FakeRef(1) + callback, id_binary = _register_block(counter, ref, 100, "op_a") + + counter.clear() + callback(id_binary) # must be a silent no-op + assert counter.get_object_store_memory_usage("op_a") == 0 + + def test_new_blocks_after_clear_are_tracked(self): + """After clear(), new registrations work normally.""" + counter = BlockRefCounter() + _register_block(counter, _FakeRef(1), 50, "op_b") + counter.clear() + assert counter.get_object_store_memory_usage("op_b") == 0 + + _register_block(counter, _FakeRef(2), 50, "op_b") + assert counter.get_object_store_memory_usage("op_b") == 50 + + def test_clear_races_with_object_already_freed(self): + """clear() between byte-increment and the registered=False undo must not go negative. + + If add_object_out_of_scope_callback returns False (object already gone), + on_block_produced calls _on_object_freed to undo the increment. If clear() + fires in that window, the undo must be a no-op (id_binary is no longer in + _registered_ids), not a double-decrement. + """ + counter = BlockRefCounter() + ref = _FakeRef(1) + + class _ClearOnRegisterCoreWorker: + def add_object_out_of_scope_callback(self, block_ref, cb): + counter.clear() # race: clear fires before finally runs + return False # object already out of scope + + with mock.patch( + "ray._private.worker.global_worker", + mock.Mock(core_worker=_ClearOnRegisterCoreWorker()), + ): + counter.on_block_produced(ref, 100, "op_a") + + assert counter.get_object_store_memory_usage("op_a") == 0 + + +class TestBlockRefCounterThreadSafety: + def test_concurrent_callbacks_dont_corrupt_state(self): + """Multiple threads firing callbacks concurrently must not go negative.""" + counter = BlockRefCounter() + producer_id = "op_concurrent" + n = 50 + refs = [_FakeRef(i) for i in range(n)] + callbacks = [] + + for ref in refs: + cb, id_binary = _register_block(counter, ref, 10, producer_id) + callbacks.append((cb, id_binary)) + + threads = [threading.Thread(target=cb, args=(idb,)) for cb, idb in callbacks] + for t in threads: + t.start() + for t in threads: + t.join() + + assert counter.get_object_store_memory_usage(producer_id) == 0 + + +@ray.remote +def _hold_ref_for(block_ref, sleep_s: float) -> bool: + """Hold *block_ref* as a task argument for *sleep_s* seconds, then return. + + Because Ray keeps the object alive for the duration of any task that + received it as an argument, this lets tests verify the callback has + not fired while the task is still running. + """ + import time as _time + + _time.sleep(sleep_s) + return True + + +def _wait_for_counter( + counter: BlockRefCounter, + producer_id: str, + expected: int, + timeout_s: float = 10.0, + poll_interval_s: float = 0.05, +) -> bool: + """Poll until *counter* reports *expected* bytes for *producer_id*. + + Calls ``gc.collect()`` on every iteration so that any pending Python-level + ObjectRef destructors have a chance to run. Returns True if the expected + value is reached before *timeout_s* elapses, False otherwise. + """ + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + gc.collect() + if counter.get_object_store_memory_usage(producer_id) == expected: + return True + time.sleep(poll_interval_s) + return False + + +class TestBlockRefCounterLifecycle: + """Integration tests that exercise the full add_object_out_of_scope_callback path. + + All tests in this class require a live Ray cluster (ray_start_regular_shared). + They verify that the out-of-scope callback fires at exactly the right moment: + not before the last reference drops, and not after it. + + Three cases are covered: + 1. Basic lifecycle: callback fires after the last Python ObjectRef is GC'd. + 2. Two Python refs: callback fires only after both refs are dropped. + 3. Task ref: callback fires only after the holding task finishes and all + Python refs are dropped. This matches the real operator lifecycle where + a block stays live until the task that received it as an argument completes. + """ + + # Byte count attributed to the test operator. The actual object put into + # the store is much smaller; we only care that the counter tracks *this* + # number faithfully. + _SIZE_BYTES = 1 * 1024 * 1024 # 1 MB + + def _make_block(self) -> "ray.ObjectRef": + import numpy as np + + return ray.put(np.zeros(128, dtype=np.float64)) + + def test_callback_fires_after_last_python_ref_deleted( + self, ray_start_regular_shared + ): + """Counter reaches 0 once the only Python ObjectRef is GC'd.""" + counter = BlockRefCounter() + ref = self._make_block() + + counter.on_block_produced(ref, self._SIZE_BYTES, "op_basic") + assert counter.get_object_store_memory_usage("op_basic") == self._SIZE_BYTES + + del ref # last Python ref gone + assert _wait_for_counter(counter, "op_basic", 0), ( + "Counter did not reach 0 after all Python refs were deleted; " + f"remaining: {counter.get_object_store_memory_usage('op_basic')} bytes" + ) + + def test_second_python_ref_keeps_counter_alive(self, ray_start_regular_shared): + """Counter stays non-zero while a second Python ObjectRef is alive. + + Dropping one of two refs that point at the same ObjectID must NOT fire + the callback. Only the final ref drop may do so. + """ + counter = BlockRefCounter() + ref1 = self._make_block() + ref2 = ref1 # second Python ref to the same ObjectID + + counter.on_block_produced(ref1, self._SIZE_BYTES, "op_two_refs") + assert counter.get_object_store_memory_usage("op_two_refs") == self._SIZE_BYTES + + del ref1 + gc.collect() + time.sleep(0.3) # give GC ample time; counter must still be non-zero + + assert ( + counter.get_object_store_memory_usage("op_two_refs") == self._SIZE_BYTES + ), "Callback fired too early — counter decremented while ref2 was still alive" + + del ref2 # last ref gone; callback must now fire + assert _wait_for_counter( + counter, "op_two_refs", 0 + ), "Counter did not reach 0 after the last Python ref was deleted" + + def test_task_ref_keeps_counter_alive_until_task_completes( + self, ray_start_regular_shared + ): + """Counter stays non-zero while a running Ray task holds the block. + + Ray keeps any object alive for the duration of a task that received it + as an argument. The callback should not fire until both conditions hold: + (a) the task has completed, and (b) all Python refs are dropped. + """ + counter = BlockRefCounter() + ref = self._make_block() + + counter.on_block_produced(ref, self._SIZE_BYTES, "op_task") + assert counter.get_object_store_memory_usage("op_task") == self._SIZE_BYTES + + # Submit a task that sleeps for 1 s while holding the block, then drop + # the Python ref so only the task's argument reference remains. + task_future = _hold_ref_for.remote(ref, 1.0) + del ref + gc.collect() + time.sleep(0.3) # task is still running; callback must NOT have fired + + assert ( + counter.get_object_store_memory_usage("op_task") == self._SIZE_BYTES + ), "Callback fired too early: counter decremented while task was still running" + + ray.get(task_future) # task completes; now both refs are gone + assert _wait_for_counter( + counter, "op_task", 0 + ), "Counter did not reach 0 after task completed and Python ref was deleted" + + +if __name__ == "__main__": + import sys + + import pytest + + sys.exit(pytest.main(["-v", __file__])) From 1244004251ede30ba0a3216f0dd6c35375569a7c Mon Sep 17 00:00:00 2001 From: Sirui Huang Date: Wed, 17 Jun 2026 15:17:48 -0700 Subject: [PATCH 2/3] Wire blockRefCounter through operators Signed-off-by: Sirui Huang --- .../execution/interfaces/physical_operator.py | 23 ++++++++++++++- .../operators/actor_pool_map_operator.py | 4 +-- .../operators/base_physical_operator.py | 16 ++++++++++- .../execution/operators/hash_shuffle.py | 6 ++-- .../execution/operators/input_data_buffer.py | 4 +-- .../execution/operators/map_operator.py | 12 +++++--- .../execution/operators/output_splitter.py | 4 +-- .../execution/operators/union_operator.py | 4 +-- .../_internal/execution/resource_manager.py | 8 ++++++ .../_internal/execution/streaming_executor.py | 7 +++++ .../execution/streaming_executor_state.py | 1 - python/ray/data/tests/test_operators.py | 2 ++ .../ray/data/tests/test_streaming_executor.py | 28 +++++++++++++------ 13 files changed, 94 insertions(+), 25 deletions(-) diff --git a/python/ray/data/_internal/execution/interfaces/physical_operator.py b/python/ray/data/_internal/execution/interfaces/physical_operator.py index 545458506a4..fc492400629 100644 --- a/python/ray/data/_internal/execution/interfaces/physical_operator.py +++ b/python/ray/data/_internal/execution/interfaces/physical_operator.py @@ -24,6 +24,7 @@ ActorPoolInfo, AutoscalingActorPool, ) +from ray.data._internal.execution.block_ref_counter import BlockRefCounter from ray.data._internal.execution.interfaces.execution_options import ( ExecutionOptions, ExecutionResources, @@ -134,6 +135,8 @@ def __init__( self, task_index: int, streaming_gen: ObjectRefGenerator, + block_ref_counter: BlockRefCounter, + producer_id: str, output_ready_callback: Callable[[RefBundle], None] = lambda bundle: None, task_done_callback: TaskDoneCallbackType = lambda exc, worker_stats, driver_stats: None, block_ready_callback: Callable[ @@ -149,6 +152,9 @@ def __init__( Args: task_index: Index of the task. Used for callbacks. streaming_gen: The streaming generator of this task. It should yield blocks. + block_ref_counter: The centralized block reference counter. on_block_produced + is called for each block yielded by this task. + producer_id: The id of the operator that produces the blocks from this task. output_ready_callback: The callback to call when a new RefBundle is output from the generator. task_done_callback: The callback to call when the task is done. @@ -171,6 +177,8 @@ def __init__( self._block_ready_callback = block_ready_callback self._metadata_ready_callback = metadata_ready_callback self._operator_name = operator_name + self._block_ref_counter: BlockRefCounter = block_ref_counter + self._producer_id: str = producer_id # If the generator hasn't produced block metadata yet, or if the block metadata # object isn't available after we get a reference, we need store the pending @@ -292,6 +300,9 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int: meta_with_schema_bytes ) meta = meta_with_schema.metadata + self._block_ref_counter.on_block_produced( + self._pending_block_ref, meta.size_bytes or 0, self._producer_id + ) self._output_ready_callback( RefBundle( [BlockEntry(self._pending_block_ref, meta)], @@ -444,6 +455,7 @@ def __init__( self._id = str(uuid.uuid4()) # Initialize metrics after data_context is set self._metrics = OpRuntimeMetrics(self) + self._block_ref_counter: Optional[BlockRefCounter] = None def __reduce__(self): raise ValueError("Operator is not serializable.") @@ -743,12 +755,21 @@ def num_output_splits(self) -> int: """ return self._num_output_splits - def start(self, options: ExecutionOptions) -> None: + def start( + self, + options: ExecutionOptions, + block_ref_counter: Optional[BlockRefCounter] = None, + ) -> None: """Called by the executor when execution starts for an operator. Args: options: The global options used for the overall execution. + block_ref_counter: The executor-wide shared counter for tracking + object-store memory. If omitted, a fresh per-operator counter is used. """ + self._block_ref_counter = ( + block_ref_counter if block_ref_counter is not None else BlockRefCounter() + ) self._started = True def can_add_input(self) -> bool: diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index 5fe2f29f443..de38d336c55 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -265,9 +265,9 @@ def _apply_default_actor_task_remote_args( return ray_actor_task_remote_args - def start(self, options: ExecutionOptions): + def start(self, options: ExecutionOptions, block_ref_counter=None): self._actor_locality_enabled = options.actor_locality_enabled - super().start(options) + super().start(options, block_ref_counter) self._actor_cls = ray.remote(**self._ray_remote_args)(self._map_worker_cls) self._actor_pool.scale( diff --git a/python/ray/data/_internal/execution/operators/base_physical_operator.py b/python/ray/data/_internal/execution/operators/base_physical_operator.py index 388e2790608..e48562a07e5 100644 --- a/python/ray/data/_internal/execution/operators/base_physical_operator.py +++ b/python/ray/data/_internal/execution/operators/base_physical_operator.py @@ -183,9 +183,23 @@ def all_inputs_done(self) -> None: ) # NOTE: We don't account object store memory use from intermediate `bulk_fn` # outputs (e.g., map outputs for map-reduce). - output_buffer, self._stats = self._bulk_fn(self._input_buffer.to_list(), ctx) + + # Snapshot input refs before calling bulk_fn. Some bulk_fns (e.g. + # randomize_blocks) forward input ObjectRefs unchanged to the output. + # We only call on_block_produced for genuinely new refs to avoid + # double-counting; forwarded refs stay attributed to their original producer. + input_bundles = self._input_buffer.to_list() + input_refs = {entry.ref for bundle in input_bundles for entry in bundle.blocks} + output_buffer, self._stats = self._bulk_fn(input_bundles, ctx) self._output_buffer = FIFOBundleQueue(output_buffer) + for bundle in output_buffer: + for entry in bundle.blocks: + if entry.ref not in input_refs: + self._block_ref_counter.on_block_produced( + entry.ref, entry.metadata.size_bytes or 0, self.id + ) + while self._input_buffer.has_next(): refs = self._input_buffer.get_next() self._metrics.on_input_dequeued(refs, input_index=0) diff --git a/python/ray/data/_internal/execution/operators/hash_shuffle.py b/python/ray/data/_internal/execution/operators/hash_shuffle.py index edebd54beef..fe9259319a0 100644 --- a/python/ray/data/_internal/execution/operators/hash_shuffle.py +++ b/python/ray/data/_internal/execution/operators/hash_shuffle.py @@ -657,8 +657,8 @@ def __init__( self._reduce_bar = None self._reduce_metrics = OpRuntimeMetrics(self) - def start(self, options: ExecutionOptions) -> None: - super().start(options) + def start(self, options: ExecutionOptions, block_ref_counter=None) -> None: + super().start(options, block_ref_counter) self._aggregator_pool.start() @@ -1007,6 +1007,8 @@ def _on_aggregation_done( ExecutionResources.from_resource_dict(finalize_task_resource_bundle) ), operator_name=self.name, + block_ref_counter=self._block_ref_counter, + producer_id=self.id, ) self._finalizing_tasks[partition_id] = data_task diff --git a/python/ray/data/_internal/execution/operators/input_data_buffer.py b/python/ray/data/_internal/execution/operators/input_data_buffer.py index 66cdde25c81..5b89ea02c7d 100644 --- a/python/ray/data/_internal/execution/operators/input_data_buffer.py +++ b/python/ray/data/_internal/execution/operators/input_data_buffer.py @@ -45,7 +45,7 @@ def __init__( self._input_data_index = 0 self.mark_execution_finished() - def start(self, options: ExecutionOptions) -> None: + def start(self, options: ExecutionOptions, block_ref_counter=None) -> None: if not self._is_input_initialized: self._input_data = self._input_data_factory( self.target_max_block_size_override @@ -57,7 +57,7 @@ def start(self, options: ExecutionOptions) -> None: # so we record input metrics here for bundle in self._input_data: self._metrics.on_input_received(bundle) - super().start(options) + super().start(options, block_ref_counter) def has_next(self) -> bool: return self._input_data_index < len(self._input_data) diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 60df0b491fb..03d7983306b 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -481,8 +481,8 @@ def create( else: raise ValueError(f"Unsupported execution strategy {compute_strategy}") - def start(self, options: "ExecutionOptions"): - super().start(options) + def start(self, options: "ExecutionOptions", block_ref_counter=None): + super().start(options, block_ref_counter) # Create output queue with desired ordering semantics. if options.preserve_order: self._output_queue = ReorderingBundleQueue() @@ -655,8 +655,12 @@ def _task_done_callback( data_task = DataOpTask( task_index, gen, - lambda output: _output_ready_callback(task_index, output), - functools.partial(_task_done_callback, task_index), + self._block_ref_counter, + self.id, + output_ready_callback=lambda output: _output_ready_callback( + task_index, output + ), + task_done_callback=functools.partial(_task_done_callback, task_index), operator_name=self.name, ) self._metrics.on_task_submitted( diff --git a/python/ray/data/_internal/execution/operators/output_splitter.py b/python/ray/data/_internal/execution/operators/output_splitter.py index f436179a77b..b3fdc566970 100644 --- a/python/ray/data/_internal/execution/operators/output_splitter.py +++ b/python/ray/data/_internal/execution/operators/output_splitter.py @@ -124,13 +124,13 @@ def num_output_rows_total(self) -> Optional[int]: # The total number of rows is the same as the number of input rows. return self.input_dependencies[0].num_output_rows_total() - def start(self, options: ExecutionOptions) -> None: + def start(self, options: ExecutionOptions, block_ref_counter=None) -> None: if options.preserve_order: # If preserve_order is set, we need to ignore locality hints to ensure determinism. self._locality_hints = None self._max_buffer_size = 0 - super().start(options) + super().start(options, block_ref_counter) def throttling_disabled(self) -> bool: """Disables resource-based throttling. diff --git a/python/ray/data/_internal/execution/operators/union_operator.py b/python/ray/data/_internal/execution/operators/union_operator.py index caf84f3d4a0..62a0f8be1ff 100644 --- a/python/ray/data/_internal/execution/operators/union_operator.py +++ b/python/ray/data/_internal/execution/operators/union_operator.py @@ -59,12 +59,12 @@ def _input_queues(self) -> List["BaseBundleQueue"]: def _output_queues(self) -> List["BaseBundleQueue"]: return [self._output_buffer] - def start(self, options: ExecutionOptions): + def start(self, options: ExecutionOptions, block_ref_counter=None): # Whether to preserve deterministic ordering of output blocks. # When True, blocks are emitted in round-robin order across inputs, # ensuring the same input always produces the same output order. self._preserve_order = options.preserve_order - super().start(options) + super().start(options, block_ref_counter) def num_outputs_total(self) -> Optional[int]: num_outputs = 0 diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index 0349b2d42c1..5f86d73c8c0 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -8,6 +8,7 @@ from ray._common.utils import env_bool, env_float from ray.data._internal.execution import create_resource_allocator +from ray.data._internal.execution.block_ref_counter import BlockRefCounter from ray.data._internal.execution.interfaces.execution_options import ( ExecutionOptions, ExecutionResources, @@ -137,6 +138,8 @@ def __init__( # operator's output usage. self._output_operator = terminal_operator_from_topology(topology) + self._block_ref_counter = BlockRefCounter() + self._op_resource_allocator: Optional[ "OpResourceAllocator" ] = create_resource_allocator(self, data_context) @@ -168,6 +171,11 @@ def get_external_consumer_bytes(self) -> int: """Get the bytes buffered by external consumers.""" return self._external_consumer_bytes + @property + def block_ref_counter(self) -> BlockRefCounter: + """The centralized block reference counter for this executor.""" + return self._block_ref_counter + def _estimate_object_store_memory_usage( self, op: "PhysicalOperator", state: "OpState" ) -> int: diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 5cc070b03ac..54ea87108b9 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -217,6 +217,10 @@ def execute( self._data_context, ) + counter = self._resource_manager.block_ref_counter + for op in self._topology: + op.start(self._options, counter) + # Constructed once per executor (not per scheduling iteration) so the # guard's idle-detection state accumulates across scheduling iterations. self._output_backpressure_guard = OutputBackpressureGuard( @@ -332,6 +336,9 @@ def shutdown(self, force: bool, exception: Optional[Exception] = None): op.shutdown(timer, force=force) self._clear_topology_queues_post_shutdown(force, exception) + # Queues have been drained; any remaining Ray Core callbacks that fire + # after this point should be no-ops. + self._resource_manager.block_ref_counter.clear() min_ = round(timer.min(), 3) max_ = round(timer.max(), 3) diff --git a/python/ray/data/_internal/execution/streaming_executor_state.py b/python/ray/data/_internal/execution/streaming_executor_state.py index dd68e6adcb2..1c691f13bf2 100644 --- a/python/ray/data/_internal/execution/streaming_executor_state.py +++ b/python/ray/data/_internal/execution/streaming_executor_state.py @@ -575,7 +575,6 @@ def setup_state(op: PhysicalOperator) -> OpState: # Create state. op_state = OpState(op, inqueues) topology[op] = op_state - op.start(options) return op_state setup_state(dag) diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index d7ed2636b81..00314259438 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -182,6 +182,8 @@ def all_transform(bundles: List[RefBundle], ctx): DataContext.get_current().target_max_block_size, ) + op1.start(ExecutionOptions()) + op2.start(ExecutionOptions()) while input_op.has_next(): op1.add_input(input_op.get_next(), 0) op1.all_inputs_done() diff --git a/python/ray/data/tests/test_streaming_executor.py b/python/ray/data/tests/test_streaming_executor.py index d82ef735308..5db06dfaf29 100644 --- a/python/ray/data/tests/test_streaming_executor.py +++ b/python/ray/data/tests/test_streaming_executor.py @@ -23,6 +23,7 @@ from ray.data._internal.execution.backpressure_policy.resource_budget_backpressure_policy import ( ResourceBudgetBackpressurePolicy, ) +from ray.data._internal.execution.block_ref_counter import BlockRefCounter from ray.data._internal.execution.execution_callback import ExecutionCallback from ray.data._internal.execution.interfaces import ( ExecutionOptions, @@ -1391,6 +1392,13 @@ def ensure_block_metadata_stored_in_plasma(monkeypatch): monkeypatch.setenv("RAY_max_direct_call_object_size", 0) +def _make_data_op_task(task_index, streaming_gen, **kwargs): + """Create a DataOpTask with a default BlockRefCounter and producer_id for tests.""" + kwargs.setdefault("block_ref_counter", BlockRefCounter()) + kwargs.setdefault("producer_id", "test_op") + return DataOpTask(task_index, streaming_gen, **kwargs) + + class TestDataOpTask: def test_on_data_ready_single_output(self, ray_start_regular_shared): streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB]) @@ -1398,7 +1406,9 @@ def test_on_data_ready_single_output(self, ray_start_regular_shared): def verify_output(bundle): assert bundle.size_bytes() == pytest.approx(128 * MiB), bundle.size_bytes() - data_op_task = DataOpTask(0, streaming_gen, output_ready_callback=verify_output) + data_op_task = _make_data_op_task( + 0, streaming_gen, output_ready_callback=verify_output + ) bytes_read = 0 while not data_op_task.has_finished: @@ -1414,7 +1424,9 @@ def test_on_data_ready_multiple_outputs(self, ray_start_regular_shared): def verify_output(bundle): assert bundle.size_bytes() == pytest.approx(128 * MiB), bundle.size_bytes() - data_op_task = DataOpTask(0, streaming_gen, output_ready_callback=verify_output) + data_op_task = _make_data_op_task( + 0, streaming_gen, output_ready_callback=verify_output + ) bytes_read = 0 while not data_op_task.has_finished: @@ -1435,7 +1447,7 @@ def verify_exception(exc, task_exec_stats, task_exec_driver_stats): assert task_exec_stats is None assert task_exec_driver_stats is None - data_op_task = DataOpTask( + data_op_task = _make_data_op_task( 0, streaming_gen, task_done_callback=verify_exception, @@ -1448,11 +1460,11 @@ def verify_exception(exc, task_exec_stats, task_exec_driver_stats): def test_operator_name_parameter(self, ray_start_regular_shared): streaming_gen = create_stub_streaming_gen(block_nbytes=[1]) - task = DataOpTask(0, streaming_gen, operator_name="MapBatches(fn)") + task = _make_data_op_task(0, streaming_gen, operator_name="MapBatches(fn)") assert task._operator_name == "MapBatches(fn)" streaming_gen2 = create_stub_streaming_gen(block_nbytes=[1]) - task_default = DataOpTask(1, streaming_gen2) + task_default = _make_data_op_task(1, streaming_gen2) assert task_default._operator_name == "Unknown" @pytest.mark.parametrize( @@ -1489,7 +1501,7 @@ def remove_and_add_back_worker_node(_): new_worker_node = cluster.add_node(num_cpus=1) # noqa: F841 cluster.wait_for_nodes() - data_op_task = DataOpTask( + data_op_task = _make_data_op_task( 0, streaming_gen, **{preempt_on: remove_and_add_back_worker_node} ) @@ -1520,7 +1532,7 @@ def test_on_data_ready_with_preemption_after_wait( # Create a streaming generator that produces a single 128 MiB output block. streaming_gen = create_stub_streaming_gen(block_nbytes=[128 * MiB]) - data_op_task = DataOpTask(0, streaming_gen) + data_op_task = _make_data_op_task(0, streaming_gen) # Wait for the block to be ready, then remove the worker node. ray.wait([streaming_gen], fetch_local=False) @@ -1560,7 +1572,7 @@ def capture_done(exc, task_exec_stats, task_exec_driver_stats): captured_stats["task_exec_stats"] = task_exec_stats captured_stats["task_exec_driver_stats"] = task_exec_driver_stats - data_op_task = DataOpTask( + data_op_task = _make_data_op_task( 0, streaming_gen, task_done_callback=capture_done, From 95bcd5220a6ff76929f86f6f744920762b785239 Mon Sep 17 00:00:00 2001 From: Sirui Huang Date: Wed, 17 Jun 2026 17:12:23 -0700 Subject: [PATCH 3/3] Add missing type notations + missing hash shuffle change Signed-off-by: Sirui Huang --- .../execution/interfaces/physical_operator.py | 17 ++++++++--------- .../operators/actor_pool_map_operator.py | 8 +++++++- .../operators/base_physical_operator.py | 13 +++++++------ .../execution/operators/hash_shuffle.py | 7 ++++++- .../execution/operators/input_data_buffer.py | 11 +++++++++-- .../execution/operators/map_operator.py | 8 +++++++- .../execution/operators/output_splitter.py | 11 +++++++++-- .../execution/operators/union_operator.py | 11 +++++++++-- .../data/_internal/gpu_shuffle/hash_shuffle.py | 12 +++++++++--- 9 files changed, 71 insertions(+), 27 deletions(-) diff --git a/python/ray/data/_internal/execution/interfaces/physical_operator.py b/python/ray/data/_internal/execution/interfaces/physical_operator.py index fc492400629..25f9034cadb 100644 --- a/python/ray/data/_internal/execution/interfaces/physical_operator.py +++ b/python/ray/data/_internal/execution/interfaces/physical_operator.py @@ -135,7 +135,7 @@ def __init__( self, task_index: int, streaming_gen: ObjectRefGenerator, - block_ref_counter: BlockRefCounter, + block_ref_counter: Optional[BlockRefCounter], producer_id: str, output_ready_callback: Callable[[RefBundle], None] = lambda bundle: None, task_done_callback: TaskDoneCallbackType = lambda exc, worker_stats, driver_stats: None, @@ -177,7 +177,7 @@ def __init__( self._block_ready_callback = block_ready_callback self._metadata_ready_callback = metadata_ready_callback self._operator_name = operator_name - self._block_ref_counter: BlockRefCounter = block_ref_counter + self._block_ref_counter: Optional[BlockRefCounter] = block_ref_counter self._producer_id: str = producer_id # If the generator hasn't produced block metadata yet, or if the block metadata @@ -300,9 +300,10 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int: meta_with_schema_bytes ) meta = meta_with_schema.metadata - self._block_ref_counter.on_block_produced( - self._pending_block_ref, meta.size_bytes or 0, self._producer_id - ) + if self._block_ref_counter is not None: + self._block_ref_counter.on_block_produced( + self._pending_block_ref, meta.size_bytes or 0, self._producer_id + ) self._output_ready_callback( RefBundle( [BlockEntry(self._pending_block_ref, meta)], @@ -765,11 +766,9 @@ def start( Args: options: The global options used for the overall execution. block_ref_counter: The executor-wide shared counter for tracking - object-store memory. If omitted, a fresh per-operator counter is used. + object-store memory. """ - self._block_ref_counter = ( - block_ref_counter if block_ref_counter is not None else BlockRefCounter() - ) + self._block_ref_counter = block_ref_counter self._started = True def can_add_input(self) -> bool: diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index de38d336c55..f5b48b6db32 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: import pyarrow as pa + + from ray.data._internal.execution.block_ref_counter import BlockRefCounter import ray from ray.actor import ActorHandle from ray.core.generated import gcs_pb2 @@ -265,7 +267,11 @@ def _apply_default_actor_task_remote_args( return ray_actor_task_remote_args - def start(self, options: ExecutionOptions, block_ref_counter=None): + def start( + self, + options: ExecutionOptions, + block_ref_counter: Optional["BlockRefCounter"] = None, + ): self._actor_locality_enabled = options.actor_locality_enabled super().start(options, block_ref_counter) diff --git a/python/ray/data/_internal/execution/operators/base_physical_operator.py b/python/ray/data/_internal/execution/operators/base_physical_operator.py index e48562a07e5..f0220394d2c 100644 --- a/python/ray/data/_internal/execution/operators/base_physical_operator.py +++ b/python/ray/data/_internal/execution/operators/base_physical_operator.py @@ -193,12 +193,13 @@ def all_inputs_done(self) -> None: output_buffer, self._stats = self._bulk_fn(input_bundles, ctx) self._output_buffer = FIFOBundleQueue(output_buffer) - for bundle in output_buffer: - for entry in bundle.blocks: - if entry.ref not in input_refs: - self._block_ref_counter.on_block_produced( - entry.ref, entry.metadata.size_bytes or 0, self.id - ) + if self._block_ref_counter is not None: + for bundle in output_buffer: + for entry in bundle.blocks: + if entry.ref not in input_refs: + self._block_ref_counter.on_block_produced( + entry.ref, entry.metadata.size_bytes or 0, self.id + ) while self._input_buffer.has_next(): refs = self._input_buffer.get_next() diff --git a/python/ray/data/_internal/execution/operators/hash_shuffle.py b/python/ray/data/_internal/execution/operators/hash_shuffle.py index fe9259319a0..6064103e869 100644 --- a/python/ray/data/_internal/execution/operators/hash_shuffle.py +++ b/python/ray/data/_internal/execution/operators/hash_shuffle.py @@ -79,6 +79,7 @@ ) if typing.TYPE_CHECKING: + from ray.data._internal.execution.block_ref_counter import BlockRefCounter from ray.data._internal.progress.base_progress import BaseProgressBar logger = logging.getLogger(__name__) @@ -657,7 +658,11 @@ def __init__( self._reduce_bar = None self._reduce_metrics = OpRuntimeMetrics(self) - def start(self, options: ExecutionOptions, block_ref_counter=None) -> None: + def start( + self, + options: ExecutionOptions, + block_ref_counter: Optional["BlockRefCounter"] = None, + ) -> None: super().start(options, block_ref_counter) self._aggregator_pool.start() diff --git a/python/ray/data/_internal/execution/operators/input_data_buffer.py b/python/ray/data/_internal/execution/operators/input_data_buffer.py index 5b89ea02c7d..16bb37993f4 100644 --- a/python/ray/data/_internal/execution/operators/input_data_buffer.py +++ b/python/ray/data/_internal/execution/operators/input_data_buffer.py @@ -1,4 +1,7 @@ -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional + +if TYPE_CHECKING: + from ray.data._internal.execution.block_ref_counter import BlockRefCounter from ray.data._internal.execution.interfaces import ( ExecutionOptions, @@ -45,7 +48,11 @@ def __init__( self._input_data_index = 0 self.mark_execution_finished() - def start(self, options: ExecutionOptions, block_ref_counter=None) -> None: + def start( + self, + options: ExecutionOptions, + block_ref_counter: Optional["BlockRefCounter"] = None, + ) -> None: if not self._is_input_initialized: self._input_data = self._input_data_factory( self.target_max_block_size_override diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 03d7983306b..469c75f6222 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: import pyarrow as pa + from ray.data._internal.execution.block_ref_counter import BlockRefCounter + import ray from ray import ObjectRef from ray._raylet import ObjectRefGenerator, StreamingGeneratorStats @@ -481,7 +483,11 @@ def create( else: raise ValueError(f"Unsupported execution strategy {compute_strategy}") - def start(self, options: "ExecutionOptions", block_ref_counter=None): + def start( + self, + options: "ExecutionOptions", + block_ref_counter: Optional["BlockRefCounter"] = None, + ): super().start(options, block_ref_counter) # Create output queue with desired ordering semantics. if options.preserve_order: diff --git a/python/ray/data/_internal/execution/operators/output_splitter.py b/python/ray/data/_internal/execution/operators/output_splitter.py index b3fdc566970..5180506156a 100644 --- a/python/ray/data/_internal/execution/operators/output_splitter.py +++ b/python/ray/data/_internal/execution/operators/output_splitter.py @@ -2,7 +2,10 @@ import math import time from dataclasses import replace -from typing import Any, Collection, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple + +if TYPE_CHECKING: + from ray.data._internal.execution.block_ref_counter import BlockRefCounter from typing_extensions import override @@ -124,7 +127,11 @@ def num_output_rows_total(self) -> Optional[int]: # The total number of rows is the same as the number of input rows. return self.input_dependencies[0].num_output_rows_total() - def start(self, options: ExecutionOptions, block_ref_counter=None) -> None: + def start( + self, + options: ExecutionOptions, + block_ref_counter: Optional["BlockRefCounter"] = None, + ) -> None: if options.preserve_order: # If preserve_order is set, we need to ignore locality hints to ensure determinism. self._locality_hints = None diff --git a/python/ray/data/_internal/execution/operators/union_operator.py b/python/ray/data/_internal/execution/operators/union_operator.py index 62a0f8be1ff..d07bb4639a3 100644 --- a/python/ray/data/_internal/execution/operators/union_operator.py +++ b/python/ray/data/_internal/execution/operators/union_operator.py @@ -1,7 +1,10 @@ -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from typing_extensions import override +if TYPE_CHECKING: + from ray.data._internal.execution.block_ref_counter import BlockRefCounter + from ray.data._internal.execution.bundle_queue import BaseBundleQueue, FIFOBundleQueue from ray.data._internal.execution.interfaces import ( ExecutionOptions, @@ -59,7 +62,11 @@ def _input_queues(self) -> List["BaseBundleQueue"]: def _output_queues(self) -> List["BaseBundleQueue"]: return [self._output_buffer] - def start(self, options: ExecutionOptions, block_ref_counter=None): + def start( + self, + options: ExecutionOptions, + block_ref_counter: Optional["BlockRefCounter"] = None, + ): # Whether to preserve deterministic ordering of output blocks. # When True, blocks are emitted in round-robin order across inputs, # ensuring the same input always produces the same output order. diff --git a/python/ray/data/_internal/gpu_shuffle/hash_shuffle.py b/python/ray/data/_internal/gpu_shuffle/hash_shuffle.py index af3345d66bb..245e27bc9c4 100644 --- a/python/ray/data/_internal/gpu_shuffle/hash_shuffle.py +++ b/python/ray/data/_internal/gpu_shuffle/hash_shuffle.py @@ -38,7 +38,7 @@ from ray.data.context import DataContext if typing.TYPE_CHECKING: - + from ray.data._internal.execution.block_ref_counter import BlockRefCounter from ray.data._internal.progress.base_progress import BaseProgressBar logger = logging.getLogger(__name__) @@ -491,8 +491,12 @@ def __init__( # Lifecycle # ------------------------------------------------------------------ - def start(self, options: ExecutionOptions) -> None: - super().start(options) + def start( + self, + options: ExecutionOptions, + block_ref_counter: Optional["BlockRefCounter"] = None, + ) -> None: + super().start(options, block_ref_counter) self._rank_pool.start() def _add_input_inner(self, bundle: RefBundle, input_index: int) -> None: @@ -626,6 +630,8 @@ def _on_extraction_done( data_task = DataOpTask( task_index=rank_idx, streaming_gen=block_gen, + block_ref_counter=self._block_ref_counter, + producer_id=self.id, output_ready_callback=_on_bundle_ready, task_done_callback=functools.partial( _on_extraction_done, rank=rank_idx