From 09a756f3cf1c4bc38d28edd2ffc7def473d006ef Mon Sep 17 00:00:00 2001 From: "guohongyu.7" Date: Thu, 14 May 2026 10:58:21 +0800 Subject: [PATCH 1/4] fix(ray): share dedup state per execution --- data_juicer/core/data/ray_dataset.py | 6 +- .../deduplicator/ray_basic_deduplicator.py | 11 ++ .../test_ray_document_deduplicator.py | 159 +++++++++++++++++- 3 files changed, 172 insertions(+), 4 deletions(-) diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 538a90bbb95..7503fdd91a0 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -275,7 +275,9 @@ def process_batch_arrow(table: pyarrow.Table): process_batch_arrow, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE ) cached_columns.add(Fields.stats) - if op.use_ray_actor(): + prepare_for_ray_map_batches = getattr(op, "_prepare_for_ray_map_batches", None) + use_instance_for_ray_tasks = bool(prepare_for_ray_map_batches and prepare_for_ray_map_batches()) + if op.use_ray_actor() and not use_instance_for_ray_tasks: compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) self.data = self.data.map_batches( op.__class__, @@ -301,6 +303,8 @@ def process_batch_arrow(table: pyarrow.Table): compute=compute, runtime_env=op.runtime_env, ) + if use_instance_for_ray_tasks: + self.data = self.data.materialize() if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) # Wrap process method with tracer for sample-level collection diff --git a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py index b7b50172a42..e949f682a18 100644 --- a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py @@ -74,6 +74,11 @@ def _ensure_actors(self): RemoteDedupSet = self._RemoteDedupSet or get_remote_dedup_set() self._dedup_sets = [RemoteDedupSet.remote() for _ in range(self.dedup_set_num)] + def prepare_for_ray_execution(self): + """Create shared actors before this backend is serialized to Ray tasks.""" + RemoteDedupSet = self._RemoteDedupSet or get_remote_dedup_set() + self._dedup_sets = [RemoteDedupSet.remote() for _ in range(self.dedup_set_num)] + def is_unique(self, md5_value: str): self._ensure_actors() dedup_set_id = int.from_bytes(md5_value.encode(), byteorder="little") % MERSENNE_PRIME % self.dedup_set_num @@ -133,6 +138,12 @@ def __init__( else: raise ValueError(f"Unknown backend: {backend}") + def _prepare_for_ray_map_batches(self): + if isinstance(self.backend, ActorBackend): + self.backend.prepare_for_ray_execution() + return True + return False + def calculate_hash(self, sample, context=False): """Calculate hash value for the sample.""" raise NotImplementedError diff --git a/tests/ops/deduplicator/test_ray_document_deduplicator.py b/tests/ops/deduplicator/test_ray_document_deduplicator.py index e8bf23183af..c40331feedc 100644 --- a/tests/ops/deduplicator/test_ray_document_deduplicator.py +++ b/tests/ops/deduplicator/test_ray_document_deduplicator.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch from data_juicer.core.data import NestedDataset as Dataset @@ -9,8 +10,39 @@ class RayDocumentDeduplicatorTest(DataJuicerTestCaseBase): + def _run_ray_cross_block_dedup(self, samples, op): + dataset = self._build_ray_cross_block_dataset(samples) + dataset.process([op]) + return dataset.data.take_all() + + def _build_ray_cross_block_dataset(self, samples): + import ray + + from data_juicer.core.data.ray_dataset import RayDataset + from data_juicer.utils.constant import Fields + + ds_list = [{Fields.stats: {}, **sample} for sample in samples] + return RayDataset( + ray.data.from_items(ds_list, override_num_blocks=len(ds_list)), + cfg={'auto_op_parallelism': False}, + auto_op_parallelism=False, + ) + def _run_doc_dedup(self, dataset: Dataset, target_list, op): - res_list = self.run_single_op(dataset, op, [op.text_key]) + import ray + + from data_juicer.core.data.ray_dataset import RayDataset + + dataset = RayDataset( + ray.data.from_items(list(dataset)), + cfg={'auto_op_parallelism': False}, + auto_op_parallelism=False, + ) + dataset.process([op]) + res_list = [ + {op.text_key: sample[op.text_key]} + for sample in dataset.data.take_all() + ] res_list.sort(key=lambda x: x['text']) target_list.sort(key=lambda x: x['text']) self.assertEqual(res_list, target_list) @@ -47,7 +79,14 @@ def test_english_deduplication(self): 'This paper proposed a novel method on LLM pretraining.' }] dataset = self.generate_dataset(ds_list) - op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False) + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=2, + auto_op_parallelism=False, + ) self._run_doc_dedup(dataset, tgt_list, op) @TEST_TAG("ray") @@ -94,9 +133,123 @@ def test_chinese_deduplication(self): }, ] dataset = self.generate_dataset(ds_list) - op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False) + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=2, + auto_op_parallelism=False, + ) self._run_doc_dedup(dataset, tgt_list, op) + @TEST_TAG("ray") + def test_ray_actor_backend_deduplicates_across_blocks(self): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + + res_list = self._run_ray_cross_block_dedup( + [{'text': 'duplicate across ray blocks'} for _ in range(8)], + op, + ) + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate across ray blocks') + + @TEST_TAG("ray") + def test_ray_actor_execution_mode_still_shares_dedup_sets(self): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ray_execution_mode='actor', + ) + + res_list = self._run_ray_cross_block_dedup( + [{'text': 'duplicate with actor execution mode'} for _ in range(8)], + op, + ) + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate with actor execution mode') + + @TEST_TAG("ray") + def test_ray_basic_deduplicator_subclasses_share_dedup_sets(self): + from data_juicer.ops.deduplicator.ray_image_deduplicator import RayImageDeduplicator + from data_juicer.ops.deduplicator.ray_video_deduplicator import RayVideoDeduplicator + + cases = [ + (RayImageDeduplicator, {'images': []}), + (RayVideoDeduplicator, {'videos': []}), + ] + for op_cls, sample in cases: + with self.subTest(op_cls=op_cls.__name__): + op = op_cls( + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + + res_list = self._run_ray_cross_block_dedup([sample for _ in range(8)], op) + + self.assertEqual(len(res_list), 1) + + @TEST_TAG("ray") + def test_repeated_execution_keeps_materialized_dedup_result(self): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + dataset = self._build_ray_cross_block_dataset([{ + 'text': 'duplicate across repeated executions', + } for _ in range(8)]) + + dataset.process([op]) + self.assertEqual(dataset.data.count(), 1) + res_list = dataset.data.take_all() + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate across repeated executions') + + @TEST_TAG("ray") + def test_stats_export_does_not_consume_dedup_state_before_filter(self): + def materializing_write_json(dataset, *args, **kwargs): + return dataset.count() + + with patch('ray.data.Dataset.write_json', materializing_write_json): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + stats_export_path='mock_stats_export_path', + ) + dataset = self._build_ray_cross_block_dataset([{ + 'text': 'duplicate with stats export', + } for _ in range(8)]) + + dataset.process([op]) + res_list = dataset.data.take_all() + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate with stats export') + if __name__ == '__main__': unittest.main() From 14899c8b99bf961e7284c029cfb0da4256749b01 Mon Sep 17 00:00:00 2001 From: "guohongyu.7" Date: Thu, 14 May 2026 11:59:11 +0800 Subject: [PATCH 2/4] Fix Ray document deduplicator test dataset conversion --- tests/ops/deduplicator/test_ray_document_deduplicator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ops/deduplicator/test_ray_document_deduplicator.py b/tests/ops/deduplicator/test_ray_document_deduplicator.py index c40331feedc..e13c8e7ab47 100644 --- a/tests/ops/deduplicator/test_ray_document_deduplicator.py +++ b/tests/ops/deduplicator/test_ray_document_deduplicator.py @@ -34,7 +34,7 @@ def _run_doc_dedup(self, dataset: Dataset, target_list, op): from data_juicer.core.data.ray_dataset import RayDataset dataset = RayDataset( - ray.data.from_items(list(dataset)), + ray.data.from_items(dataset.to_list()), cfg={'auto_op_parallelism': False}, auto_op_parallelism=False, ) From 5dce71592150fdc5569689c9d65bafa5cff02eac Mon Sep 17 00:00:00 2001 From: "guohongyu.7" Date: Thu, 14 May 2026 14:18:19 +0800 Subject: [PATCH 3/4] Address Ray deduplicator review comments --- .../deduplicator/ray_basic_deduplicator.py | 6 ++-- .../test_ray_document_deduplicator.py | 35 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py index e949f682a18..1ff09a88861 100644 --- a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py @@ -76,8 +76,7 @@ def _ensure_actors(self): def prepare_for_ray_execution(self): """Create shared actors before this backend is serialized to Ray tasks.""" - RemoteDedupSet = self._RemoteDedupSet or get_remote_dedup_set() - self._dedup_sets = [RemoteDedupSet.remote() for _ in range(self.dedup_set_num)] + self._ensure_actors() def is_unique(self, md5_value: str): self._ensure_actors() @@ -141,8 +140,7 @@ def __init__( def _prepare_for_ray_map_batches(self): if isinstance(self.backend, ActorBackend): self.backend.prepare_for_ray_execution() - return True - return False + return True def calculate_hash(self, sample, context=False): """Calculate hash value for the sample.""" diff --git a/tests/ops/deduplicator/test_ray_document_deduplicator.py b/tests/ops/deduplicator/test_ray_document_deduplicator.py index e13c8e7ab47..12d15c5d78e 100644 --- a/tests/ops/deduplicator/test_ray_document_deduplicator.py +++ b/tests/ops/deduplicator/test_ray_document_deduplicator.py @@ -250,6 +250,41 @@ def materializing_write_json(dataset, *args, **kwargs): self.assertEqual(len(res_list), 1) self.assertEqual(res_list[0]['text'], 'duplicate with stats export') + def test_prepare_for_ray_execution_reuses_existing_actor_handles(self): + from data_juicer.ops.deduplicator.ray_basic_deduplicator import ActorBackend + + class RemoteDedupSet: + calls = 0 + + @classmethod + def remote(cls): + cls.calls += 1 + return object() + + backend = ActorBackend(dedup_set_num=2, RemoteDedupSet=RemoteDedupSet) + + backend.prepare_for_ray_execution() + dedup_sets = backend._dedup_sets + backend.prepare_for_ray_execution() + + self.assertIs(backend._dedup_sets, dedup_sets) + self.assertEqual(RemoteDedupSet.calls, 2) + + def test_redis_backend_requests_ray_materialization(self): + from data_juicer.ops.deduplicator.ray_basic_deduplicator import RedisBackend + + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + op.backend = RedisBackend.__new__(RedisBackend) + + self.assertTrue(op._prepare_for_ray_map_batches()) + if __name__ == '__main__': unittest.main() From c551ad07f4796c9427b7effcad35e094ee23c05d Mon Sep 17 00:00:00 2001 From: fengrui-z Date: Mon, 8 Jun 2026 14:48:44 +0800 Subject: [PATCH 4/4] fix: log when ray_execution_mode is overridden for stateful dedup ops When a dedup operator has ray_execution_mode='actor' but gets downgraded to task mode to preserve shared dedup state, emit an info log so users understand why their config was overridden. Co-Authored-By: Claude Opus 4.6 --- data_juicer/core/data/ray_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 7503fdd91a0..bfb87714317 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -277,6 +277,11 @@ def process_batch_arrow(table: pyarrow.Table): cached_columns.add(Fields.stats) prepare_for_ray_map_batches = getattr(op, "_prepare_for_ray_map_batches", None) use_instance_for_ray_tasks = bool(prepare_for_ray_map_batches and prepare_for_ray_map_batches()) + if use_instance_for_ray_tasks and op.use_ray_actor(): + logger.info( + f"{op._name}: overriding ray_execution_mode from actor to task " + f"to preserve shared dedup state across workers" + ) if op.use_ray_actor() and not use_instance_for_ray_tasks: compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) self.data = self.data.map_batches(