Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion data_juicer/core/data/ray_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ def process_batch_arrow(table: pyarrow.Table):
process_batch_arrow, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE
)
cached_columns.add(Fields.stats)
if op.use_ray_actor():
prepare_for_ray_map_batches = getattr(op, "_prepare_for_ray_map_batches", None)
use_instance_for_ray_tasks = bool(prepare_for_ray_map_batches and prepare_for_ray_map_batches())
if op.use_ray_actor() and not use_instance_for_ray_tasks:
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
self.data = self.data.map_batches(
op.__class__,
Expand All @@ -301,6 +303,8 @@ def process_batch_arrow(table: pyarrow.Table):
compute=compute,
runtime_env=op.runtime_env,
)
if use_instance_for_ray_tasks:
self.data = self.data.materialize()
if op.stats_export_path is not None:
self.data.write_json(op.stats_export_path, force_ascii=False)
# Wrap process method with tracer for sample-level collection
Expand Down
11 changes: 11 additions & 0 deletions data_juicer/ops/deduplicator/ray_basic_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def _ensure_actors(self):
RemoteDedupSet = self._RemoteDedupSet or get_remote_dedup_set()
self._dedup_sets = [RemoteDedupSet.remote() for _ in range(self.dedup_set_num)]

def prepare_for_ray_execution(self):
    """Create shared actors before this backend is serialized to Ray tasks.

    Uses the injected ``_RemoteDedupSet`` class when one is set, otherwise
    the class returned by ``get_remote_dedup_set()``, and stores
    ``dedup_set_num`` freshly created remote handles on ``self._dedup_sets``
    (any previous value is replaced).
    """
    dedup_set_cls = self._RemoteDedupSet or get_remote_dedup_set()
    handles = []
    for _ in range(self.dedup_set_num):
        handles.append(dedup_set_cls.remote())
    self._dedup_sets = handles
Comment thread
macroguo-ghy marked this conversation as resolved.
Outdated

def is_unique(self, md5_value: str):
self._ensure_actors()
dedup_set_id = int.from_bytes(md5_value.encode(), byteorder="little") % MERSENNE_PRIME % self.dedup_set_num
Expand Down Expand Up @@ -133,6 +138,12 @@ def __init__(
else:
raise ValueError(f"Unknown backend: {backend}")

def _prepare_for_ray_map_batches(self):
    """Prepare the backend for Ray ``map_batches`` and report whether it did.

    :return: ``True`` when the backend is an ``ActorBackend`` (its
        ``prepare_for_ray_execution()`` is called first), ``False`` for any
        other backend, in which case nothing is prepared.
    """
    if not isinstance(self.backend, ActorBackend):
        return False
    self.backend.prepare_for_ray_execution()
    return True
Comment thread
macroguo-ghy marked this conversation as resolved.
Outdated

def calculate_hash(self, sample, context=False):
    """Calculate hash value for the sample.

    This base implementation is only a placeholder: it always raises, so
    concrete deduplicators have to provide their own version.

    :param sample: the sample to hash.
    :param context: unused here; its meaning is defined by the overriding
        implementations (NOTE(review): confirm against subclasses).
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
Expand Down
159 changes: 156 additions & 3 deletions tests/ops/deduplicator/test_ray_document_deduplicator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from unittest.mock import patch

from data_juicer.core.data import NestedDataset as Dataset

Expand All @@ -9,8 +10,39 @@

class RayDocumentDeduplicatorTest(DataJuicerTestCaseBase):

def _run_ray_cross_block_dedup(self, samples, op):
dataset = self._build_ray_cross_block_dataset(samples)
dataset.process([op])
return dataset.data.take_all()

def _build_ray_cross_block_dataset(self, samples):
    """Wrap ``samples`` in a ``RayDataset`` that requests one block per sample.

    Each sample is copied into a new dict that starts with an empty
    ``Fields.stats`` entry; a ``Fields.stats`` key already present in the
    sample overrides that empty default. Automatic op parallelism is
    switched off both in ``cfg`` and via the keyword argument.

    :param samples: iterable of sample dicts.
    :return: the constructed ``RayDataset``.
    """
    import ray

    from data_juicer.core.data.ray_dataset import RayDataset
    from data_juicer.utils.constant import Fields

    rows = []
    for sample in samples:
        rows.append({Fields.stats: {}, **sample})
    # Ask Ray for as many blocks as there are rows.
    ray_data = ray.data.from_items(rows, override_num_blocks=len(rows))
    return RayDataset(
        ray_data,
        cfg={'auto_op_parallelism': False},
        auto_op_parallelism=False,
    )

def _run_doc_dedup(self, dataset: Dataset, target_list, op):
    """Run ``op`` on ``dataset`` through a ``RayDataset`` and compare results.

    The op is executed exactly once. An additional ``run_single_op`` call
    whose result was discarded has been dropped: it only ran the
    deduplicator a second time before the real run.

    Rows are projected onto ``op.text_key`` and both lists are sorted on
    that same key, so the helper also works when the op uses a text key
    other than ``'text'``.

    :param dataset: source dataset; its ``to_list()`` rows feed Ray.
    :param target_list: expected rows, each keyed by ``op.text_key``.
        Sorted in place.
    :param op: the deduplicator under test.
    """
    import ray

    from data_juicer.core.data.ray_dataset import RayDataset

    text_key = op.text_key
    ray_dataset = RayDataset(
        ray.data.from_items(dataset.to_list()),
        cfg={'auto_op_parallelism': False},
        auto_op_parallelism=False,
    )
    ray_dataset.process([op])
    res_list = [
        {text_key: sample[text_key]}
        for sample in ray_dataset.data.take_all()
    ]
    # Ray does not guarantee row order, so compare order-insensitively.
    res_list.sort(key=lambda row: row[text_key])
    target_list.sort(key=lambda row: row[text_key])
    self.assertEqual(res_list, target_list)
Expand Down Expand Up @@ -47,7 +79,14 @@ def test_english_deduplication(self):
'This paper proposed a novel method on LLM pretraining.'
}]
dataset = self.generate_dataset(ds_list)
op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False)
op = RayDocumentDeduplicator(
lowercase=False,
ignore_non_character=False,
dedup_set_num=1,
batch_size=1,
num_proc=2,
auto_op_parallelism=False,
)
self._run_doc_dedup(dataset, tgt_list, op)

@TEST_TAG("ray")
Expand Down Expand Up @@ -94,9 +133,123 @@ def test_chinese_deduplication(self):
},
]
dataset = self.generate_dataset(ds_list)
op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False)
op = RayDocumentDeduplicator(
lowercase=False,
ignore_non_character=False,
dedup_set_num=1,
batch_size=1,
num_proc=2,
auto_op_parallelism=False,
)
self._run_doc_dedup(dataset, tgt_list, op)

@TEST_TAG("ray")
def test_ray_actor_backend_deduplicates_across_blocks(self):
    """Eight identical texts in a cross-block dataset must reduce to one row."""
    duplicated_text = 'duplicate across ray blocks'
    op = RayDocumentDeduplicator(
        lowercase=False,
        ignore_non_character=False,
        dedup_set_num=1,
        batch_size=1,
        num_proc=4,
        auto_op_parallelism=False,
    )
    samples = [{'text': duplicated_text} for _ in range(8)]

    survivors = self._run_ray_cross_block_dedup(samples, op)

    self.assertEqual(len(survivors), 1)
    self.assertEqual(survivors[0]['text'], duplicated_text)

@TEST_TAG("ray")
def test_ray_actor_execution_mode_still_shares_dedup_sets(self):
    """With ``ray_execution_mode='actor'`` duplicates still reduce to one row."""
    duplicated_text = 'duplicate with actor execution mode'
    op = RayDocumentDeduplicator(
        lowercase=False,
        ignore_non_character=False,
        dedup_set_num=1,
        batch_size=1,
        num_proc=4,
        auto_op_parallelism=False,
        ray_execution_mode='actor',
    )
    samples = [{'text': duplicated_text} for _ in range(8)]

    survivors = self._run_ray_cross_block_dedup(samples, op)

    self.assertEqual(len(survivors), 1)
    self.assertEqual(survivors[0]['text'], duplicated_text)

@TEST_TAG("ray")
def test_ray_basic_deduplicator_subclasses_share_dedup_sets(self):
    """Image and video deduplicators must also reduce eight identical samples to one."""
    from data_juicer.ops.deduplicator.ray_image_deduplicator import RayImageDeduplicator
    from data_juicer.ops.deduplicator.ray_video_deduplicator import RayVideoDeduplicator

    # Each operator class is paired with the empty-media sample it is fed.
    sample_by_op_cls = (
        (RayImageDeduplicator, {'images': []}),
        (RayVideoDeduplicator, {'videos': []}),
    )
    for dedup_cls, empty_sample in sample_by_op_cls:
        with self.subTest(op_cls=dedup_cls.__name__):
            op = dedup_cls(
                dedup_set_num=1,
                batch_size=1,
                num_proc=4,
                auto_op_parallelism=False,
            )
            repeated_samples = [empty_sample] * 8

            survivors = self._run_ray_cross_block_dedup(repeated_samples, op)

            self.assertEqual(len(survivors), 1)

@TEST_TAG("ray")
def test_repeated_execution_keeps_materialized_dedup_result(self):
    """Reading the processed dataset twice must give the same single row.

    ``count()`` and ``take_all()`` are called one after the other on the
    processed data; both have to report exactly one surviving sample.
    """
    duplicated_text = 'duplicate across repeated executions'
    op = RayDocumentDeduplicator(
        lowercase=False,
        ignore_non_character=False,
        dedup_set_num=1,
        batch_size=1,
        num_proc=4,
        auto_op_parallelism=False,
    )
    samples = [{'text': duplicated_text} for _ in range(8)]
    ray_dataset = self._build_ray_cross_block_dataset(samples)

    ray_dataset.process([op])

    # First read of the result.
    self.assertEqual(ray_dataset.data.count(), 1)
    # Second read of the same result.
    survivors = ray_dataset.data.take_all()
    self.assertEqual(len(survivors), 1)
    self.assertEqual(survivors[0]['text'], duplicated_text)

@TEST_TAG("ray")
def test_stats_export_does_not_consume_dedup_state_before_filter(self):
    """Setting ``stats_export_path`` must not change the dedup outcome.

    ``ray.data.Dataset.write_json`` is replaced by a stand-in that calls
    ``count()`` on the dataset instead of writing anything.
    """
    duplicated_text = 'duplicate with stats export'

    def materializing_write_json(dataset, *args, **kwargs):
        return dataset.count()

    with patch('ray.data.Dataset.write_json', materializing_write_json):
        op = RayDocumentDeduplicator(
            lowercase=False,
            ignore_non_character=False,
            dedup_set_num=1,
            batch_size=1,
            num_proc=4,
            auto_op_parallelism=False,
            stats_export_path='mock_stats_export_path',
        )
        samples = [{'text': duplicated_text} for _ in range(8)]
        ray_dataset = self._build_ray_cross_block_dataset(samples)

        ray_dataset.process([op])
        survivors = ray_dataset.data.take_all()

    self.assertEqual(len(survivors), 1)
    self.assertEqual(survivors[0]['text'], duplicated_text)


if __name__ == '__main__':
unittest.main()
Loading