Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion data_juicer/core/data/dj_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import os
import traceback
from abc import ABC, abstractmethod
from collections.abc import Mapping
from functools import wraps
from time import time
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, is_caching_enabled
from datasets import Dataset, DatasetDict, Features, is_caching_enabled
from datasets.formatting.formatting import LazyBatch

from data_juicer.core.data.schema import Schema
Expand Down Expand Up @@ -138,6 +139,26 @@ def nested_obj_factory(obj):
return obj


def _merge_feature_dicts(base_features, output_feature_hints):
"""Merge partial output feature hints into existing dataset features."""
merged = copy.deepcopy(dict(base_features or {}))

for key, feature in dict(output_feature_hints or {}).items():
if key in merged and isinstance(merged[key], Mapping) and isinstance(feature, Mapping):
merged[key] = _merge_feature_dicts(merged[key], feature)
else:
merged[key] = copy.deepcopy(feature)

return merged
Comment thread
cmgzn marked this conversation as resolved.


def merge_features(base_features, output_feature_hints):
"""Return HuggingFace Features with partial output hints merged in."""
if output_feature_hints is None:
return base_features
return Features(_merge_feature_dicts(base_features, output_feature_hints))


class NestedQueryDict(dict):
"""Enhanced dict for better usability."""

Expand Down Expand Up @@ -399,6 +420,10 @@ def map(self, *args, **kargs):
"""Override the map func, which is called by most common operations,
such that the processed samples can be accessed by nested manner."""

output_feature_hints = kargs.pop("output_feature_hints", None)
if output_feature_hints is not None:
kargs["features"] = merge_features(kargs.get("features", self.features), output_feature_hints)

args, kargs = self.update_args(args, kargs)

if cache_utils.CACHE_COMPRESS:
Expand Down
46 changes: 41 additions & 5 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,35 @@ def use_ray_actor(self):
def process(self, *args, **kwargs):
raise NotImplementedError

def output_feature_hints(self, input_features):
"""Return partial HuggingFace feature hints for fields written by this OP.

Most OPs can return None. Provide hints when this OP adds or rewrites
fields whose type cannot be inferred reliably from early HuggingFace
map batches, especially empty lists that later contain concrete values,
nested lists, list-of-struct fields, or numpy arrays.

The returned value is a partial feature tree. It is merged into the
input dataset features and forwarded to ``Dataset.map(features=...)``;
it is not required to describe the full output schema. HuggingFace will
cast mapped values to the declared features, so hints must match the
values returned by the OP.

Example:
return {
Fields.meta: {
MetaKeys.bbox_tag: List(List(Value("float32"))),
}
}
"""
return None

def _map_with_output_feature_hints(self, dataset, *args, **kwargs):
output_feature_hints = self.output_feature_hints(dataset.features)
if output_feature_hints is not None:
kwargs["output_feature_hints"] = output_feature_hints
return dataset.map(*args, **kwargs)

def use_cuda(self):
return self.accelerator == "cuda" and is_cuda_available()

Expand Down Expand Up @@ -698,7 +727,8 @@ def run(self, dataset, *, exporter=None, tracer=None):
)

try:
new_dataset = dataset.map(
new_dataset = self._map_with_output_feature_hints(
dataset,
self.process,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
Expand Down Expand Up @@ -831,7 +861,8 @@ def process_single(self, sample):

def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
dataset = super(Filter, self).run(dataset)
new_dataset = dataset.map(
new_dataset = self._map_with_output_feature_hints(
dataset,
self.compute_stats,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
Expand Down Expand Up @@ -918,8 +949,12 @@ def process(self, dataset, show_num=0):

def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
dataset = super(Deduplicator, self).run(dataset)
new_dataset = dataset.map(
self.compute_hash, num_proc=self.runtime_np(), with_rank=self.use_cuda(), desc=self._name + "_compute_hash"
new_dataset = self._map_with_output_feature_hints(
dataset,
self.compute_hash,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
desc=self._name + "_compute_hash",
)
if reduce:
show_num = tracer.show_num if tracer else 0
Expand Down Expand Up @@ -1065,7 +1100,8 @@ def run(self, dataset, *, exporter=None, tracer=None):
batch_size=self.batch_size,
desc="Adding new column for aggregation",
)
new_dataset = dataset.map(
new_dataset = self._map_with_output_feature_hints(
dataset,
self.process,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, Optional

import numpy as np
from datasets import List, Value

import data_juicer
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
Expand Down Expand Up @@ -192,6 +193,9 @@ def __init__(
# update num_proc with the min num_proc of all fusible filters
self.num_proc = min([op.runtime_np() for op in self.fused_ops]) if self.fused_ops else 1

def output_feature_hints(self, input_features):
return {Fields.meta: {MetaKeys.bbox_tag: List(List(Value("float32")))}}

def _prepare_op_args(self, op_name, args_dict):
for key in self.FIXED_ARGS[op_name]:
if key not in args_dict:
Expand Down Expand Up @@ -224,7 +228,7 @@ def process_single(self, samples, rank=None):
new_samples_s1 = self.fused_ops[0].process_single(new_samples_s1, rank=rank)

if not new_samples_s1:
return {Fields.meta: {MetaKeys.bbox_tag: np.zeros((1, 4), dtype=np.float32)}}
return {Fields.meta: {MetaKeys.bbox_tag: []}}

# Step2: compare the differences between the two captions and identify
# the "valid object".
Expand Down Expand Up @@ -402,7 +406,7 @@ def process_single(self, samples, rank=None):
os.remove(temp_image_path)
for temp_image_path in crop_image2_path_to_bbox_dict:
os.remove(temp_image_path)
return {Fields.meta: {MetaKeys.bbox_tag: np.zeros((1, 4), dtype=np.float32)}}
return {Fields.meta: {MetaKeys.bbox_tag: []}}

filtered_bboxes = []
for temp_sub_image_pairs in filtered_sub_image_pairs:
Expand All @@ -414,7 +418,7 @@ def process_single(self, samples, rank=None):
iou_thresh = 0.5
filtered_bboxes = iou_filter(filtered_bboxes, iou_thresh)
samples[Fields.meta] = {}
samples[Fields.meta][MetaKeys.bbox_tag] = filtered_bboxes
samples[Fields.meta][MetaKeys.bbox_tag] = filtered_bboxes.tolist()

# Step8: clear the cache
for temp_image_path in crop_image1_path_to_bbox_dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, Optional

import numpy as np
from datasets import List, Value

import data_juicer
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
Expand Down Expand Up @@ -36,7 +37,7 @@ class Difference_Caption_Generator_Mapper(Mapper):
- The key metric is the similarity score between the captions, computed using a CLIP
model.
- If no valid bounding boxes or differences are found, it returns empty captions and
zeroed bounding boxes.
empty bounding boxes.
- Uses 'cuda' as the accelerator if any of the fused operations support it.
- Caches temporary images during processing and clears them afterward."""

Expand Down Expand Up @@ -121,26 +122,37 @@ def _prepare_op_args(self, op_name, args_dict):
args_dict["accelerator"] = self.accelerator
return args_dict

def output_feature_hints(self, input_features):
return {
Fields.meta: {
"region_caption1": List(Value("string")),
"region_caption2": List(Value("string")),
MetaKeys.bbox_tag: List(List(Value("float32"))),
"bbox_difference_captions": List(Value("string")),
}
}

def _empty_difference_meta(self):
return {
Fields.meta: {
"region_caption1": [],
"region_caption2": [],
MetaKeys.bbox_tag: [],
"bbox_difference_captions": [],
}
}

def process_single(self, samples, rank=None):
random_num = str(random.random()).split(".")[-1]
if not os.path.exists(DATA_JUICER_ASSETS_CACHE):
os.makedirs(DATA_JUICER_ASSETS_CACHE, exist_ok=True)
cache_image_list = []

if (
len(samples[Fields.meta][MetaKeys.bbox_tag]) == 1
and np.sum(samples[Fields.meta][MetaKeys.bbox_tag][0]) == 0
):
bbox_tag = samples[Fields.meta][MetaKeys.bbox_tag]
if len(bbox_tag) == 0 or len(bbox_tag) == 1 and np.sum(bbox_tag[0]) == 0:
for temp_image_path in cache_image_list:
os.remove(temp_image_path)
return {
Fields.meta: {
"region_caption1": [""],
"region_caption2": [""],
MetaKeys.bbox_tag: np.zeros((1, 4), dtype=np.float32),
"bbox_difference_captions": [""],
}
}
return self._empty_difference_meta()

# fused_ops 1.mllm_mapper 2.image_text_matching_filter 3.text_pair_similarity_filter
# keys of sample: "image_path1", "image_path2", Fields.meta[MetaKeys.bbox_tag]
Expand Down Expand Up @@ -269,14 +281,7 @@ def process_single(self, samples, rank=None):
if len(filtered_caption_pairs) == 0:
for temp_image_path in cache_image_list:
os.remove(temp_image_path)
return {
Fields.meta: {
"region_caption1": [""],
"region_caption2": [""],
MetaKeys.bbox_tag: np.zeros((1, 4), dtype=np.float32),
"bbox_difference_captions": [""],
}
}
return self._empty_difference_meta()

# Step4: determine whether there are differences between the two captions.
filtered_caption_pairs = data_juicer.core.NestedDataset.from_list(filtered_caption_pairs)
Expand Down Expand Up @@ -313,14 +318,7 @@ def process_single(self, samples, rank=None):
if len(effective_bboxes) == 0:
for temp_image_path in cache_image_list:
os.remove(temp_image_path)
return {
Fields.meta: {
"region_caption1": [""],
"region_caption2": [""],
MetaKeys.bbox_tag: np.zeros((1, 4), dtype=np.float32),
"bbox_difference_captions": [""],
}
}
return self._empty_difference_meta()

# Step5: Mark the difference area with a red box
text_mllm_samples = []
Expand Down
24 changes: 24 additions & 0 deletions docs/DeveloperGuide.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,30 @@ class StatsKeysConstant(object):
return False
```

If the OP adds or rewrites fields whose type can be ambiguous in early
HuggingFace map batches, declare partial output feature hints by overriding
`output_feature_hints()`. This is mainly needed for empty lists that later
contain concrete values, nested lists, list-of-struct fields, or numpy array
outputs. The hints are merged into the input dataset features and forwarded
to `Dataset.map(features=...)`; they are not a complete output schema.
HuggingFace will cast mapped values to the declared features, so the hints
must match the values returned by the OP.

Prefer `datasets.List` for list-like hints. Avoid `Sequence(dict)` in this
hook: HuggingFace treats it as a struct of lists rather than a list of
structs.

```python
from datasets import List, Value

def output_feature_hints(self, input_features):
return {
Fields.meta: {
MetaKeys.bbox_tag: List(List(Value("float32"))),
}
}
```

3. After implementation, add it to the OP dictionary in the `__init__.py` file in `data_juicer/ops/filter/` directory.

```python
Expand Down
23 changes: 22 additions & 1 deletion docs/DeveloperGuide_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@ class StatsKeysConstant(object):
return False
```

如果算子会新增或改写在 HuggingFace map 早期 batch 中类型不明确的字段,请通过覆写
`output_feature_hints()` 声明局部输出特征提示。这主要适用于前几条样本为空 list、
后续样本才出现具体值的字段,或嵌套 list、list-of-struct、numpy array 等输出。该
hint 会被合并到输入 dataset features 中,并传给 `Dataset.map(features=...)`;
它不是完整的输出 schema。HuggingFace 会按声明的 features 对 map 结果做 cast,
因此 hint 必须和算子实际返回值一致。

对于 list 类型的 hints,优先使用 `datasets.List`。避免在这个 hook 中使用
`Sequence(dict)`:HuggingFace 会把它解释为 struct of lists,而不是 list of
structs。

```python
from datasets import List, Value

def output_feature_hints(self, input_features):
return {
Fields.meta: {
MetaKeys.bbox_tag: List(List(Value("float32"))),
}
}
```

3. 实现后,将其添加到 `data_juicer/ops/filter` 目录下 `__init__.py` 文件中的算子字典中:

Expand Down Expand Up @@ -601,4 +622,4 @@ class PerplexityFilter(Filter):

- 欢迎添加您新配方的相应参考文献,或提出一些新需求、以及改进现有配方的想法。

- 我们非常欢迎共建,并将重点[注明致谢](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)!
- 我们非常欢迎共建,并将重点[注明致谢](https://github.com/datajuicer/data-juicer?tab=readme-ov-file#contribution-and-acknowledgements)!
38 changes: 37 additions & 1 deletion tests/core/data/test_dj_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from datasets import Dataset, DatasetDict
from datasets import Dataset, DatasetDict, List, Value
from datasets.formatting.formatting import LazyBatch
from data_juicer.core.data import NestedDataset, wrap_func_with_nested_access
from data_juicer.core.data.dj_dataset import nested_obj_factory, NestedDatasetDict, NestedQueryDict
Expand Down Expand Up @@ -283,6 +283,42 @@ def __init__(self):
result = nested_obj_factory(None)
self.assertIsNone(result)

def test_map_output_feature_hints_allow_empty_nested_list_first_batch(self):
dataset = NestedDataset(
Dataset.from_list(
[
{'text': 'empty', 'meta': {'existing': 1}},
{'text': 'hit', 'meta': {'existing': 2}},
]
)
)

def add_bbox(batch):
metas = batch['meta']
existing_values = metas['existing'] if isinstance(metas, dict) else [meta['existing'] for meta in metas]
return {
'meta': [
{
'existing': existing,
'bbox': [] if text == 'empty' else [[1.0, 2.0, 3.0, 4.0]],
}
for text, existing in zip(batch['text'], existing_values)
]
}

mapped = dataset.map(
add_bbox,
batched=True,
batch_size=1,
output_feature_hints={'meta': {'bbox': List(List(Value('float32')))}},
)

self.assertEqual(mapped[0]['meta']['bbox'], [])
self.assertEqual(mapped[1]['meta']['bbox'], [[1.0, 2.0, 3.0, 4.0]])
self.assertEqual(mapped[0]['meta']['existing'], 1)
self.assertIn('existing', mapped.features['meta'])
self.assertIn('bbox', mapped.features['meta'])

def test_nested_dataset(self):
import pyarrow as pa
table = pa.Table.from_pydict({"text": ["hello", "world"]})
Expand Down
Loading
Loading