From 1a037f4b0579b54c60c70a46b0c88d5e3b647a1f Mon Sep 17 00:00:00 2001 From: cmgzn Date: Mon, 8 Jun 2026 18:26:48 +0800 Subject: [PATCH 01/13] test: add unit tests for file_utils, empty_formatter, formatter, config functions - file_utils: byte_size_to_size_str, is_remote_path, get_all_files_paths_under, single_partition_write_with_filename, read_single_partition, expand_outdir_and_mkdir - empty_formatter: multiple feature_keys, string-to-list conversion, null_value, zero length - formatter: audio/video relative-to-absolute path conversion, mixed media keys - config: resolve_job_id, validate_work_dir_config, resolve_job_directories --- tests/config/test_config_functions.py | 597 +++++++++++++++++++++++ tests/core/executor/test_pipeline_dag.py | 317 ++++++++++++ tests/download/test_download.py | 68 +++ tests/format/test_empty_formatter.py | 44 ++ tests/format/test_formatter.py | 322 +++++++++++- tests/format/test_json_formatter.py | 77 +++ tests/ops/test_mixins.py | 595 ++++++++++++++++++++++ tests/utils/test_file_utils.py | 304 ++++++++++++ uv.lock | 16 +- 9 files changed, 2306 insertions(+), 34 deletions(-) create mode 100644 tests/config/test_config_functions.py create mode 100644 tests/ops/test_mixins.py diff --git a/tests/config/test_config_functions.py b/tests/config/test_config_functions.py new file mode 100644 index 00000000000..3735c203cf0 --- /dev/null +++ b/tests/config/test_config_functions.py @@ -0,0 +1,597 @@ +"""Tests for pure-logic utility functions in data_juicer/config/config.py. + +Covers: timing_context, _generate_module_name, load_custom_operators, +sort_op_by_types_and_names, _parse_cli_to_config, _parse_value, +config_backup, validate_config_for_resumption, prepare_side_configs (edge), +namespace_to_arg_list. +""" +import json +import os +import shutil +import sys +import tempfile +import time +import unittest + +import yaml +from jsonargparse import Namespace + +from data_juicer.config.config import ( + _generate_module_name, + _parse_cli_to_config, + _parse_value, + config_backup, + load_custom_operators, + prepare_side_configs, + resolve_job_directories, + resolve_job_id, + sort_op_by_types_and_names, + timing_context, + validate_config_for_resumption, + validate_work_dir_config, +) +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class TimingContextTest(DataJuicerTestCaseBase): + + def test_timing_context_runs_block(self): + executed = False + with timing_context("test block"): + executed = True + self.assertTrue(executed) + + def test_timing_context_measures_time(self): + start = time.time() + with timing_context("sleep"): + time.sleep(0.05) + elapsed = time.time() - start + self.assertGreaterEqual(elapsed, 0.04) + + +class GenerateModuleNameTest(DataJuicerTestCaseBase): + + def test_simple_path(self): + self.assertEqual(_generate_module_name("/foo/bar/my_module.py"), "my_module") + + def test_nested_path(self): + self.assertEqual(_generate_module_name("/a/b/c/ops.py"), "ops") + + def test_no_extension(self): + self.assertEqual(_generate_module_name("/foo/bar/module"), "module") + + def test_basename_only(self): + self.assertEqual(_generate_module_name("script.py"), "script") + + +class LoadCustomOperatorsTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + + def tearDown(self): + # Clean up any modules we loaded + for key in list(sys.modules.keys()): + if key.startswith("_test_custom_op"): + del sys.modules[key] + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_load_single_file(self): + op_file = os.path.join(self.tmp_dir, "_test_custom_op_file.py") + with open(op_file, "w") as f: + f.write("LOADED = True\n") + + load_custom_operators([op_file]) + self.assertIn("_test_custom_op_file", sys.modules) + self.assertTrue(sys.modules["_test_custom_op_file"].LOADED) + + def test_load_package(self): + pkg_dir = os.path.join(self.tmp_dir, "_test_custom_op_pkg") + os.makedirs(pkg_dir) + with open(os.path.join(pkg_dir, "__init__.py"), "w") as f: + f.write("PKG_LOADED = True\n") + + load_custom_operators([pkg_dir]) + self.assertIn("_test_custom_op_pkg", sys.modules) + + def test_load_package_missing_init_raises(self): + pkg_dir = os.path.join(self.tmp_dir, "_test_custom_op_noinit") + os.makedirs(pkg_dir) + # No __init__.py + + with self.assertRaises(ValueError) as ctx: + load_custom_operators([pkg_dir]) + self.assertIn("__init__.py", str(ctx.exception)) + + def test_load_nonexistent_path_raises(self): + with self.assertRaises(ValueError) as ctx: + load_custom_operators(["/nonexistent/path/to/module.py"]) + self.assertIn("neither a file nor a directory", str(ctx.exception)) + + def test_load_duplicate_module_raises(self): + op_file = os.path.join(self.tmp_dir, "_test_custom_op_dup.py") + with open(op_file, "w") as f: + f.write("X = 1\n") + + load_custom_operators([op_file]) + with self.assertRaises(RuntimeError) as ctx: + load_custom_operators([op_file]) + self.assertIn("already loaded", str(ctx.exception)) + + def test_load_file_with_syntax_error_raises(self): + op_file = os.path.join(self.tmp_dir, "_test_custom_op_bad.py") + with open(op_file, "w") as f: + f.write("def broken(\n") # syntax error + + with self.assertRaises(RuntimeError) as ctx: + load_custom_operators([op_file]) + self.assertIn("Error loading", str(ctx.exception)) + + +class SortOpByTypesAndNamesTest(DataJuicerTestCaseBase): + + def test_sorts_by_type_then_name(self): + ops = [ + ("z_filter", "FilterClass"), + ("a_mapper", "MapperClass"), + ("b_deduplicator", "DedupClass"), + ("c_mapper", "MapperClass2"), + ("a_selector", "SelectorClass"), + ("b_grouper", "GrouperClass"), + ("a_aggregator", "AggClass"), + ] + result = sort_op_by_types_and_names(ops) + names = [name for name, _ in result] + self.assertEqual(names, [ + "a_mapper", "c_mapper", # mappers first, sorted + "z_filter", # filters + "b_deduplicator", # deduplicators + "a_selector", # selectors + "b_grouper", # groupers + "a_aggregator", # aggregators + ]) + + def test_empty_list(self): + result = sort_op_by_types_and_names([]) + self.assertEqual(result, []) + + def test_single_type(self): + ops = [("b_mapper", "B"), ("a_mapper", "A")] + result = sort_op_by_types_and_names(ops) + self.assertEqual([n for n, _ in result], ["a_mapper", "b_mapper"]) + + +class ParseValueTest(DataJuicerTestCaseBase): + + def test_parse_true(self): + self.assertIs(_parse_value("true"), True) + self.assertIs(_parse_value("True"), True) + self.assertIs(_parse_value("TRUE"), True) + + def test_parse_false(self): + self.assertIs(_parse_value("false"), False) + self.assertIs(_parse_value("False"), False) + + def test_parse_integer(self): + self.assertEqual(_parse_value("42"), 42) + self.assertIsInstance(_parse_value("42"), int) + + def test_parse_negative_integer(self): + self.assertEqual(_parse_value("-7"), -7) + + def test_parse_float(self): + self.assertAlmostEqual(_parse_value("3.14"), 3.14) + self.assertIsInstance(_parse_value("3.14"), float) + + def test_parse_float_scientific(self): + self.assertAlmostEqual(_parse_value("1e-3"), 0.001) + + def test_parse_string(self): + self.assertEqual(_parse_value("hello"), "hello") + self.assertIsInstance(_parse_value("hello"), str) + + def test_parse_path(self): + self.assertEqual(_parse_value("/path/to/file.yaml"), "/path/to/file.yaml") + + +class ParseCliToConfigTest(DataJuicerTestCaseBase): + + def test_empty_args(self): + self.assertEqual(_parse_cli_to_config([]), {}) + + def test_key_value_pair(self): + result = _parse_cli_to_config(["--name", "test"]) + self.assertEqual(result.get("name"), "test") + + def test_key_equals_value(self): + result = _parse_cli_to_config(["--count=5"]) + self.assertEqual(result.get("count"), 5) + + def test_boolean_flag(self): + result = _parse_cli_to_config(["--debug"]) + self.assertEqual(result.get("debug"), True) + + def test_multiple_values(self): + result = _parse_cli_to_config(["--items", "a", "b", "c"]) + self.assertEqual(result.get("items"), ["a", "b", "c"]) + + def test_mixed_args(self): + result = _parse_cli_to_config([ + "--name", "test", + "--count=10", + "--verbose", + ]) + self.assertEqual(result.get("name"), "test") + self.assertEqual(result.get("count"), 10) + self.assertEqual(result.get("verbose"), True) + + +class ConfigBackupTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.work_dir = os.path.join(self.tmp_dir, "work") + os.makedirs(self.work_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_backup_copies_config_file(self): + cfg_file = os.path.join(self.tmp_dir, "test.yaml") + with open(cfg_file, "w") as f: + f.write("key: value\n") + + target_path = os.path.join(self.work_dir, "test.yaml") + cfg = Namespace( + config=[cfg_file], + work_dir=self.work_dir, + backed_up_config_path=target_path, + _original_args=[], + ) + + config_backup(cfg) + self.assertTrue(os.path.exists(target_path)) + with open(target_path) as f: + self.assertEqual(f.read(), "key: value\n") + + def test_backup_skips_if_exists(self): + cfg_file = os.path.join(self.tmp_dir, "test.yaml") + with open(cfg_file, "w") as f: + f.write("original\n") + + target_path = os.path.join(self.work_dir, "test.yaml") + with open(target_path, "w") as f: + f.write("already_there\n") + + cfg = Namespace( + config=[cfg_file], + work_dir=self.work_dir, + backed_up_config_path=target_path, + _original_args=[], + ) + config_backup(cfg) + + with open(target_path) as f: + self.assertEqual(f.read(), "already_there\n") + + def test_backup_no_config_does_nothing(self): + cfg = Namespace(config=None, work_dir=self.work_dir) + # Should not raise + config_backup(cfg) + + def test_backup_fallback_without_backed_up_path(self): + cfg_file = os.path.join(self.tmp_dir, "fallback.yaml") + with open(cfg_file, "w") as f: + f.write("fallback_data\n") + + cfg = Namespace( + config=[cfg_file], + work_dir=self.work_dir, + _original_args=[], + ) + + config_backup(cfg) + expected_path = os.path.join(self.work_dir, "fallback.yaml") + self.assertTrue(os.path.exists(expected_path)) + + +class ValidateConfigForResumptionTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.work_dir = os.path.join(self.tmp_dir, "work") + os.makedirs(self.work_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_matching_configs_returns_true(self): + config_content = "dataset_path: ./data\nprocess:\n - clean_mapper:\n" + # Write "original" config in work_dir + orig_cfg = os.path.join(self.work_dir, "config.yaml") + with open(orig_cfg, "w") as f: + f.write(config_content) + + # Write a matching CLI args file with specific args + cli_yaml = os.path.join(self.work_dir, "cli.yaml") + with open(cli_yaml, "w") as f: + yaml.dump({"arguments": ["--np", "4"]}, f) + + # Write "current" config somewhere else (same content) + cur_cfg = os.path.join(self.tmp_dir, "current.yaml") + with open(cur_cfg, "w") as f: + f.write(config_content) + + cfg = Namespace(config=[cur_cfg]) + # Pass same CLI args as saved in cli.yaml + result = validate_config_for_resumption( + cfg, self.work_dir, original_args=["--np", "4"] + ) + self.assertTrue(result) + self.assertTrue(cfg._same_yaml_config) + + def test_mismatched_configs_returns_false(self): + orig_cfg = os.path.join(self.work_dir, "config.yaml") + with open(orig_cfg, "w") as f: + f.write("dataset_path: ./old_data\n") + + cur_cfg = os.path.join(self.tmp_dir, "current.yaml") + with open(cur_cfg, "w") as f: + f.write("dataset_path: ./new_data\n") + + cfg = Namespace(config=[cur_cfg]) + result = validate_config_for_resumption(cfg, self.work_dir) + self.assertFalse(result) + self.assertFalse(cfg._same_yaml_config) + + def test_no_config_files_returns_false(self): + empty_dir = os.path.join(self.tmp_dir, "empty") + os.makedirs(empty_dir) + + cfg = Namespace() + result = validate_config_for_resumption(cfg, empty_dir) + self.assertFalse(result) + + def test_no_current_config_returns_false(self): + orig_cfg = os.path.join(self.work_dir, "config.yaml") + with open(orig_cfg, "w") as f: + f.write("x: 1\n") + + cfg = Namespace() # no config attribute + result = validate_config_for_resumption(cfg, self.work_dir) + self.assertFalse(result) + + def test_cli_yaml_comparison(self): + config_content = "dataset_path: ./data\n" + orig_cfg = os.path.join(self.work_dir, "config.yaml") + with open(orig_cfg, "w") as f: + f.write(config_content) + + # Write CLI args + cli_yaml = os.path.join(self.work_dir, "cli.yaml") + with open(cli_yaml, "w") as f: + yaml.dump({"arguments": ["--np", "4"]}, f) + + cur_cfg = os.path.join(self.tmp_dir, "current.yaml") + with open(cur_cfg, "w") as f: + f.write(config_content) + + cfg = Namespace(config=[cur_cfg]) + # Pass different CLI args + result = validate_config_for_resumption( + cfg, self.work_dir, original_args=["--np", "8"] + ) + self.assertFalse(result) + + +class PrepareSideConfigsEdgeCasesTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_json_file(self): + cfg_file = os.path.join(self.tmp_dir, "config.json") + with open(cfg_file, "w") as f: + json.dump({"key": "value"}, f) + result = prepare_side_configs(cfg_file) + self.assertEqual(result, {"key": "value"}) + + def test_unsupported_extension_raises(self): + cfg_file = os.path.join(self.tmp_dir, "config.toml") + with open(cfg_file, "w") as f: + f.write("") + with self.assertRaises(TypeError): + prepare_side_configs(cfg_file) + + def test_unsupported_type_raises(self): + with self.assertRaises(TypeError): + prepare_side_configs(12345) + + def test_namespace_input(self): + ns = Namespace(a=1, b=2) + result = prepare_side_configs(ns) + self.assertEqual(result, ns) + + +class ResolveJobIdTest(DataJuicerTestCaseBase): + """Test resolve_job_id: auto-generates or preserves user-provided job_id.""" + + def test_auto_generates_job_id_when_missing(self): + cfg = Namespace() + result = resolve_job_id(cfg) + self.assertTrue(hasattr(result, "job_id")) + self.assertIsInstance(result.job_id, str) + self.assertGreater(len(result.job_id), 0) + self.assertFalse(result._user_provided_job_id) + + def test_auto_generated_format(self): + """Auto-generated job_id should be timestamp_hash format.""" + cfg = Namespace() + resolve_job_id(cfg) + parts = cfg.job_id.split("_") + # Format: YYYYMMDD_HHMMSS_hexhash + self.assertEqual(len(parts), 3) + self.assertEqual(len(parts[0]), 8) # date + self.assertEqual(len(parts[1]), 6) # time + self.assertEqual(len(parts[2]), 6) # hex hash + + def test_preserves_user_provided_job_id(self): + cfg = Namespace(job_id="my_custom_job") + result = resolve_job_id(cfg) + self.assertEqual(result.job_id, "my_custom_job") + self.assertTrue(result._user_provided_job_id) + + def test_uniqueness(self): + """Two calls should produce different job_ids.""" + cfg1 = Namespace() + cfg2 = Namespace() + resolve_job_id(cfg1) + resolve_job_id(cfg2) + self.assertNotEqual(cfg1.job_id, cfg2.job_id) + + +class ValidateWorkDirConfigTest(DataJuicerTestCaseBase): + """Test validate_work_dir_config: ensures {job_id} is at end of path.""" + + def test_valid_job_id_at_end(self): + # Should not raise + validate_work_dir_config("./outputs/project/{job_id}") + + def test_valid_absolute_path(self): + validate_work_dir_config("/data/experiments/{job_id}") + + def test_valid_no_job_id_placeholder(self): + """Paths without {job_id} are valid (job_id appended later).""" + validate_work_dir_config("./outputs/project") + + def test_invalid_job_id_not_at_end(self): + with self.assertRaises(ValueError) as ctx: + validate_work_dir_config("./outputs/{job_id}/results") + self.assertIn("last part", str(ctx.exception)) + + def test_invalid_job_id_in_middle(self): + with self.assertRaises(ValueError) as ctx: + validate_work_dir_config("./{job_id}/outputs/data") + self.assertIn("last part", str(ctx.exception)) + + def test_trailing_slash_still_valid(self): + validate_work_dir_config("./outputs/{job_id}/") + + +class ResolveJobDirectoriesTest(DataJuicerTestCaseBase): + """Test resolve_job_directories: sets up all job-specific directories.""" + + def test_basic_directory_resolution(self): + cfg = Namespace( + work_dir="./outputs/project", + job_id="test_job_123", + config=["test.yaml"], + event_log_dir=None, + checkpoint_dir=None, + partition_dir=None, + ) + result = resolve_job_directories(cfg) + self.assertTrue(result.work_dir.endswith("test_job_123")) + self.assertEqual(result.event_log_dir, + os.path.join(result.work_dir, "logs")) + self.assertEqual(result.checkpoint_dir, + os.path.join(result.work_dir, "checkpoints")) + self.assertEqual(result.partition_dir, + os.path.join(result.work_dir, "partitions")) + self.assertEqual(result.metadata_dir, + os.path.join(result.work_dir, "metadata")) + self.assertEqual(result.results_dir, + os.path.join(result.work_dir, "results")) + + def test_job_id_placeholder_substitution(self): + cfg = Namespace( + work_dir="./outputs/{job_id}", + job_id="abc123", + config=["cfg.yaml"], + event_log_dir=None, + checkpoint_dir=None, + partition_dir=None, + ) + result = resolve_job_directories(cfg) + self.assertTrue(result.work_dir.endswith("abc123")) + self.assertNotIn("{job_id}", result.work_dir) + + def test_work_dir_placeholder_in_other_paths(self): + cfg = Namespace( + work_dir="./outputs/proj", + job_id="j1", + config=["c.yaml"], + event_log_dir="{work_dir}/my_logs", + checkpoint_dir=None, + partition_dir=None, + ) + result = resolve_job_directories(cfg) + self.assertIn("my_logs", result.event_log_dir) + self.assertNotIn("{work_dir}", result.event_log_dir) + + def test_no_config_uses_default_backup_path(self): + cfg = Namespace( + work_dir="./outputs", + job_id="j2", + config=None, + event_log_dir=None, + checkpoint_dir=None, + partition_dir=None, + ) + result = resolve_job_directories(cfg) + self.assertTrue(result.backed_up_config_path.endswith("config.yaml")) + + def test_missing_job_id_raises(self): + cfg = Namespace( + work_dir="./outputs", + job_id="", + config=None, + event_log_dir=None, + checkpoint_dir=None, + partition_dir=None, + ) + with self.assertRaises(ValueError): + resolve_job_directories(cfg) + + def test_custom_dirs_preserved(self): + """Custom event_log_dir/checkpoint_dir should not be overridden.""" + cfg = Namespace( + work_dir="./outputs", + job_id="j3", + config=["c.yaml"], + event_log_dir="/custom/logs", + checkpoint_dir="/custom/ckpts", + partition_dir="/custom/parts", + ) + result = resolve_job_directories(cfg) + self.assertEqual(result.event_log_dir, "/custom/logs") + self.assertEqual(result.checkpoint_dir, "/custom/ckpts") + self.assertEqual(result.partition_dir, "/custom/parts") + + def test_event_log_file_set(self): + cfg = Namespace( + work_dir="./outputs", + job_id="j4", + config=["c.yaml"], + event_log_dir=None, + checkpoint_dir=None, + partition_dir=None, + ) + result = resolve_job_directories(cfg) + self.assertTrue(result.event_log_file.endswith("events.jsonl")) + self.assertTrue(result.job_summary_file.endswith("job_summary.json")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/executor/test_pipeline_dag.py b/tests/core/executor/test_pipeline_dag.py index a6a819bb3ca..dbd5c82ac7a 100644 --- a/tests/core/executor/test_pipeline_dag.py +++ b/tests/core/executor/test_pipeline_dag.py @@ -190,5 +190,322 @@ def __init__(self): self.assertFalse(is_global_operation(filter_op)) +def _make_node(node_id, operation_name="op", dependencies=None, + partition_id=None, execution_order=0): + """Helper to create a DAG node dict.""" + return { + "node_id": node_id, + "operation_name": operation_name, + "node_type": "operation", + "partition_id": partition_id, + "config": {}, + "dependencies": dependencies or [], + "execution_order": execution_order, + "estimated_duration": 0.0, + "metadata": {}, + "status": DAGNodeStatus.PENDING.value, + "actual_duration": None, + "start_time": None, + "end_time": None, + "error_message": None, + } + + +class PipelineDAGFailureTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.dag = PipelineDAG(self.tmp_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_mark_node_failed_with_start_time(self): + self.dag.nodes["n1"] = _make_node("n1") + self.dag.mark_node_started("n1") + start_time = self.dag.nodes["n1"]["start_time"] + + import time + time.sleep(0.02) + self.dag.mark_node_failed("n1", "something broke") + + node = self.dag.nodes["n1"] + self.assertEqual(node["status"], DAGNodeStatus.FAILED.value) + self.assertEqual(node["error_message"], "something broke") + self.assertGreater(node["actual_duration"], 0) + self.assertEqual(node["start_time"], start_time) + + def test_mark_node_failed_without_start_time(self): + """When a node fails before being started, duration should be ~0.""" + self.dag.nodes["n1"] = _make_node("n1") + self.dag.mark_node_failed("n1", "failed early") + + node = self.dag.nodes["n1"] + self.assertEqual(node["status"], DAGNodeStatus.FAILED.value) + self.assertIsNotNone(node["actual_duration"]) + self.assertLessEqual(node["actual_duration"], 1.0) + + def test_mark_node_failed_nonexistent_node(self): + self.dag.mark_node_failed("nonexistent", "error") + self.assertNotIn("nonexistent", self.dag.nodes) + + +class PipelineDAGReadyNodesTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.dag = PipelineDAG(self.tmp_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_ready_nodes_no_dependencies(self): + self.dag.nodes["a"] = _make_node("a", "op_a") + self.dag.nodes["b"] = _make_node("b", "op_b") + ready = self.dag.get_ready_nodes() + self.assertEqual(set(ready), {"a", "b"}) + + def test_ready_nodes_with_dependencies(self): + self.dag.nodes["a"] = _make_node("a", "op_a") + self.dag.nodes["b"] = _make_node("b", "op_b", dependencies=["a"]) + + ready = self.dag.get_ready_nodes() + self.assertEqual(ready, ["a"]) + + self.dag.mark_node_started("a") + self.dag.mark_node_completed("a") + ready = self.dag.get_ready_nodes() + self.assertEqual(ready, ["b"]) + + def test_ready_nodes_skips_non_pending(self): + self.dag.nodes["a"] = _make_node("a", "op_a") + self.dag.mark_node_started("a") + + ready = self.dag.get_ready_nodes() + self.assertEqual(ready, []) + + def test_ready_nodes_multiple_deps(self): + self.dag.nodes["a"] = _make_node("a", "op_a", execution_order=0) + self.dag.nodes["b"] = _make_node("b", "op_b", execution_order=1) + self.dag.nodes["c"] = _make_node("c", "op_c", + dependencies=["a", "b"], + execution_order=2) + + self.assertNotIn("c", self.dag.get_ready_nodes()) + + self.dag.mark_node_started("a") + self.dag.mark_node_completed("a") + self.assertNotIn("c", self.dag.get_ready_nodes()) + + self.dag.mark_node_started("b") + self.dag.mark_node_completed("b") + self.assertIn("c", self.dag.get_ready_nodes()) + + def test_ready_nodes_dep_failed(self): + self.dag.nodes["a"] = _make_node("a", "op_a") + self.dag.nodes["b"] = _make_node("b", "op_b", dependencies=["a"]) + + self.dag.mark_node_failed("a", "crash") + ready = self.dag.get_ready_nodes() + self.assertNotIn("b", ready) + + +class PipelineDAGStatusTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.dag = PipelineDAG(self.tmp_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_unknown_node_returns_pending(self): + status = self.dag.get_node_status("nonexistent") + self.assertEqual(status, DAGNodeStatus.PENDING) + + def test_status_transitions(self): + self.dag.nodes["n1"] = _make_node("n1") + + self.assertEqual(self.dag.get_node_status("n1"), DAGNodeStatus.PENDING) + + self.dag.mark_node_started("n1") + self.assertEqual(self.dag.get_node_status("n1"), DAGNodeStatus.RUNNING) + + self.dag.mark_node_completed("n1") + self.assertEqual(self.dag.get_node_status("n1"), + DAGNodeStatus.COMPLETED) + + +class PipelineDAGVisualizeTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.dag = PipelineDAG(self.tmp_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_empty_dag(self): + result = self.dag.visualize() + self.assertEqual(result, "Empty DAG") + + def test_visualize_with_partitions(self): + self.dag.nodes["n1"] = _make_node("n1", "clean_mapper", + partition_id=0, execution_order=0) + self.dag.nodes["n2"] = _make_node("n2", "clean_mapper", + partition_id=1, execution_order=1) + + output = self.dag.visualize() + self.assertIn("partition 0", output) + self.assertIn("partition 1", output) + self.assertIn("clean_mapper", output) + + def test_visualize_with_dependencies(self): + self.dag.nodes["a"] = _make_node("a", "step_a", execution_order=0) + self.dag.nodes["b"] = _make_node("b", "step_b", + dependencies=["a"], + execution_order=1) + + output = self.dag.visualize() + self.assertIn("Dependencies:", output) + self.assertIn("step_b <- step_a", output) + + def test_visualize_status_icons(self): + self.dag.nodes["a"] = _make_node("a", "op_a", execution_order=0) + self.dag.mark_node_started("a") + + self.dag.nodes["b"] = _make_node("b", "op_b", execution_order=1) + self.dag.mark_node_started("b") + self.dag.mark_node_completed("b") + + self.dag.nodes["c"] = _make_node("c", "op_c", execution_order=2) + self.dag.mark_node_failed("c", "err") + + output = self.dag.visualize() + self.assertIn("[~]", output) + self.assertIn("[x]", output) + self.assertIn("[!]", output) + + def test_visualize_no_partition(self): + self.dag.nodes["a"] = _make_node("a", "op_a", execution_order=0) + output = self.dag.visualize() + self.assertNotIn("partition", output) + + +class PipelineDAGLoadPlanTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.dag = PipelineDAG(self.tmp_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_load_nonexistent_file(self): + result = self.dag.load_execution_plan("does_not_exist.json") + self.assertFalse(result) + + def test_save_and_load_roundtrip(self): + self.dag.nodes["n1"] = _make_node("n1", "op_a", execution_order=0) + self.dag.nodes["n2"] = _make_node("n2", "op_b", + dependencies=["n1"], + execution_order=1) + + self.dag.save_execution_plan() + + new_dag = PipelineDAG(self.tmp_dir) + loaded = new_dag.load_execution_plan() + self.assertTrue(loaded) + self.assertIn("n1", new_dag.nodes) + self.assertIn("n2", new_dag.nodes) + self.assertEqual(new_dag.nodes["n2"]["dependencies"], ["n1"]) + self.assertEqual(new_dag.nodes["n1"]["status"], + DAGNodeStatus.PENDING.value) + + def test_load_corrupted_file(self): + plan_path = os.path.join(self.tmp_dir, "dag_execution_plan.json") + with open(plan_path, "w") as f: + f.write("not valid json{{{") + + result = self.dag.load_execution_plan() + self.assertFalse(result) + + +class PipelineDAGCompletionTest(DataJuicerTestCaseBase): + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.dag = PipelineDAG(self.tmp_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_mark_completed_auto_duration(self): + import time + self.dag.nodes["n1"] = _make_node("n1") + self.dag.mark_node_started("n1") + time.sleep(0.02) + self.dag.mark_node_completed("n1") + + node = self.dag.nodes["n1"] + self.assertGreater(node["actual_duration"], 0) + + def test_mark_completed_explicit_duration(self): + self.dag.nodes["n1"] = _make_node("n1") + self.dag.mark_node_started("n1") + self.dag.mark_node_completed("n1", duration=42.0) + + self.assertEqual(self.dag.nodes["n1"]["actual_duration"], 42.0) + + def test_mark_completed_without_start(self): + """Complete a node that was never started - duration should be ~0.""" + self.dag.nodes["n1"] = _make_node("n1") + self.dag.mark_node_completed("n1") + + node = self.dag.nodes["n1"] + self.assertEqual(node["status"], DAGNodeStatus.COMPLETED.value) + self.assertLessEqual(node["actual_duration"], 1.0) + + def test_execution_summary_mixed_states(self): + self.dag.nodes["a"] = _make_node("a", execution_order=0) + self.dag.nodes["b"] = _make_node("b", execution_order=1) + self.dag.nodes["c"] = _make_node("c", execution_order=2) + self.dag.nodes["d"] = _make_node("d", execution_order=3) + + self.dag.mark_node_started("a") + self.dag.mark_node_completed("a", duration=1.0) + + self.dag.mark_node_started("b") + self.dag.mark_node_failed("b", "err") + + self.dag.mark_node_started("c") + + summary = self.dag.get_execution_summary() + self.assertEqual(summary["total_nodes"], 4) + self.assertEqual(summary["completed_nodes"], 1) + self.assertEqual(summary["failed_nodes"], 1) + self.assertEqual(summary["running_nodes"], 1) + self.assertEqual(summary["pending_nodes"], 1) + self.assertAlmostEqual(summary["completion_percentage"], 25.0) + + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tests/download/test_download.py b/tests/download/test_download.py index d3e9a15e906..8570744a85a 100644 --- a/tests/download/test_download.py +++ b/tests/download/test_download.py @@ -163,5 +163,73 @@ def test_wikipedia_download(self, mock_download_and_extract_wiki, mock_download_ assert all(field in result.features for field in expected_format.keys()) +class ValidateSnapshotFormatTest(DataJuicerTestCaseBase): + + def test_none_is_valid(self): + from data_juicer.download.downloader import validate_snapshot_format + validate_snapshot_format(None) + + def test_valid_format(self): + from data_juicer.download.downloader import validate_snapshot_format + validate_snapshot_format("2020-50") + validate_snapshot_format("2024-01") + validate_snapshot_format("2024-53") + + def test_invalid_format_no_dash(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError) as ctx: + validate_snapshot_format("202050") + self.assertIn("Invalid snapshot format", str(ctx.exception)) + + def test_invalid_format_extra_parts(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError): + validate_snapshot_format("2020-50-01") + + def test_invalid_format_letters(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError): + validate_snapshot_format("abcd-ef") + + def test_year_too_low(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError) as ctx: + validate_snapshot_format("1999-01") + self.assertIn("Year must be between", str(ctx.exception)) + + def test_year_too_high(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError) as ctx: + validate_snapshot_format("2101-01") + self.assertIn("Year must be between", str(ctx.exception)) + + def test_week_zero(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError) as ctx: + validate_snapshot_format("2020-00") + self.assertIn("Week must be between", str(ctx.exception)) + + def test_week_too_high(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError) as ctx: + validate_snapshot_format("2020-54") + self.assertIn("Week must be between", str(ctx.exception)) + + def test_boundary_valid_year(self): + from data_juicer.download.downloader import validate_snapshot_format + validate_snapshot_format("2000-01") + validate_snapshot_format("2100-53") + + def test_boundary_valid_week(self): + from data_juicer.download.downloader import validate_snapshot_format + validate_snapshot_format("2024-01") + validate_snapshot_format("2024-53") + + def test_empty_string(self): + from data_juicer.download.downloader import validate_snapshot_format + with self.assertRaises(ValueError): + validate_snapshot_format("") + + if __name__ == '__main__': unittest.main() diff --git a/tests/format/test_empty_formatter.py b/tests/format/test_empty_formatter.py index d9400777c5b..a0d24118378 100644 --- a/tests/format/test_empty_formatter.py +++ b/tests/format/test_empty_formatter.py @@ -40,5 +40,49 @@ def filter_fn(sample): self.assertEqual(len(ds), 0) + def test_multiple_feature_keys(self): + """Multiple feature_keys should create columns for each key.""" + keys = ['text', 'meta', 'label'] + ds_len = 5 + formatter = EmptyFormatter(length=ds_len, feature_keys=keys) + ds = formatter.load_dataset() + + self.assertEqual(len(ds), ds_len) + self.assertEqual(sorted(ds.features.keys()), sorted(keys)) + for item in ds: + for key in keys: + self.assertIsNone(item[key]) + + def test_string_feature_keys_converted_to_list(self): + """A single string feature_key should be auto-wrapped into a list.""" + formatter = EmptyFormatter(length=3, feature_keys='text') + self.assertIsInstance(formatter.feature_keys, list) + self.assertEqual(formatter.feature_keys, ['text']) + + ds = formatter.load_dataset() + self.assertEqual(len(ds), 3) + self.assertEqual(list(ds.features.keys()), ['text']) + + def test_null_value_property(self): + """EmptyFormatter.null_value should return None.""" + formatter = EmptyFormatter(length=1, feature_keys=['text']) + self.assertIsNone(formatter.null_value) + + def test_zero_length(self): + """length=0 should produce an empty dataset with correct schema.""" + formatter = EmptyFormatter(length=0, feature_keys=['text']) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 0) + self.assertIn('text', ds.features) + + def test_no_feature_keys(self): + """Empty feature_keys list produces a 0-row dataset (no columns → no rows).""" + formatter = EmptyFormatter(length=3, feature_keys=[]) + ds = formatter.load_dataset() + # Dataset.from_dict({}) yields 0 rows regardless of requested length + self.assertEqual(len(ds), 0) + self.assertEqual(list(ds.features.keys()), []) + + if __name__ == '__main__': unittest.main() diff --git a/tests/format/test_formatter.py b/tests/format/test_formatter.py index 2514fedef34..1bfba774b79 100644 --- a/tests/format/test_formatter.py +++ b/tests/format/test_formatter.py @@ -352,26 +352,6 @@ def test_hetero_meta(self): file_path = os.path.join(cur_dir, 'demo-dataset.jsonl') ds = load_dataset('json', data_files=file_path, split='train') ds = unify_format(ds) - # import datetime - - # the 'None' fields are missing fields after merging - # sample = [{ - # 'text': "Today is Sunday and it's a happy day!", - # 'meta': { - # 'src': 'Arxiv', - # 'date': datetime.datetime(2023, 4, 27, 0, 0), - # 'version': '1.0', - # 'author': None - # } - # }, { - # 'text': 'Do you need a cup of coffee?', - # 'meta': { - # 'src': 'code', - # 'date': None, - # 'version': None, - # 'author': 'xxx' - # } - # }] # test nested and missing field for the following cases: # Fields present in a row are always accessible; fields absent in the raw # data may be filled with None (datasets <=4.4 struct merge) OR simply @@ -577,3 +557,305 @@ def tracking_ntf(*args, **kwargs): os.path.exists(p), f"Temporary file {p} was not cleaned up after load_dataset", ) + + +# --------------------------------------------------------------------------- +# Additional coverage tests for unify_format edge cases and add_suffixes +# --------------------------------------------------------------------------- + +from datasets import Dataset as HFDataset, DatasetDict +from jsonargparse import Namespace as JANamespace +from data_juicer.format.formatter import add_suffixes + + +class UnifyFormatNoneFilterTest(DataJuicerTestCaseBase): + """Test that unify_format filters out samples with None text.""" + + def test_filters_none_text(self): + ds = HFDataset.from_dict({ + "text": ["hello", None, "world", None], + }) + result = unify_format(ds, text_keys=["text"]) + self.assertEqual(len(result), 2) + self.assertEqual(list(result["text"]), ["hello", "world"]) + + def test_keeps_all_non_none(self): + ds = HFDataset.from_dict({"text": ["a", "b", "c"]}) + result = unify_format(ds, text_keys=["text"]) + self.assertEqual(len(result), 3) + + def test_filters_all_none(self): + ds = HFDataset.from_dict({"text": [None, None]}) + result = unify_format(ds, text_keys=["text"]) + self.assertEqual(len(result), 0) + + def test_empty_dataset(self): + ds = HFDataset.from_dict({"text": []}) + result = unify_format(ds, text_keys=["text"]) + self.assertEqual(len(result), 0) + + +class UnifyFormatTextKeysTest(DataJuicerTestCaseBase): + + def test_missing_text_key_raises(self): + ds = HFDataset.from_dict({"content": ["hello"]}) + with self.assertRaises(ValueError) as ctx: + unify_format(ds, text_keys=["text"]) + self.assertIn("no key [text]", str(ctx.exception).lower()) + + def test_string_text_key_converted_to_list(self): + ds = HFDataset.from_dict({"text": ["hello"]}) + result = unify_format(ds, text_keys="text") + self.assertEqual(len(result), 1) + + def test_none_text_keys_skips_filtering(self): + ds = HFDataset.from_dict({"text": ["hello", None]}) + result = unify_format(ds, text_keys=None) + self.assertEqual(len(result), 2) + + def test_empty_text_keys_skips_filtering(self): + ds = HFDataset.from_dict({"text": ["hello", None]}) + result = unify_format(ds, text_keys=[]) + self.assertEqual(len(result), 2) + + +class UnifyFormatDatasetDictTest(DataJuicerTestCaseBase): + + def test_unwraps_single_split_datasetdict(self): + ds = HFDataset.from_dict({"text": ["hello", "world"]}) + dd = DatasetDict({"train": ds}) + result = unify_format(dd, text_keys=["text"]) + self.assertEqual(len(result), 2) + + def test_multiple_splits_raises(self): + ds1 = HFDataset.from_dict({"text": ["a"]}) + ds2 = HFDataset.from_dict({"text": ["b"]}) + dd = DatasetDict({"train": ds1, "test": ds2}) + with self.assertRaises(AssertionError): + unify_format(dd, text_keys=["text"]) + + +class UnifyFormatPathConversionTest(DataJuicerTestCaseBase): + """Test relative-to-absolute path conversion in unify_format.""" + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.ds_dir = os.path.join(self.tmp_dir, "dataset") + os.makedirs(self.ds_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_converts_relative_image_paths(self): + ds = HFDataset.from_dict({ + "text": ["sample1", "sample2"], + "images": [["img1.jpg"], ["img2.jpg"]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + image_key="images", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + for row in result: + for path in row["images"]: + self.assertTrue(os.path.isabs(path)) + self.assertTrue(path.startswith(self.ds_dir)) + + def test_preserves_absolute_paths(self): + abs_path = "/absolute/path/img.jpg" + ds = HFDataset.from_dict({ + "text": ["sample"], + "images": [[abs_path]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + image_key="images", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertEqual(result[0]["images"][0], abs_path) + + def test_no_media_keys_returns_unchanged(self): + ds = HFDataset.from_dict({"text": ["sample"]}) + cfg = JANamespace( + dataset_path=self.ds_dir, + image_key="images", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertEqual(len(result), 1) + + def test_no_global_cfg_warns(self): + ds = HFDataset.from_dict({"text": ["sample"]}) + result = unify_format(ds, text_keys=["text"], global_cfg=None) + self.assertEqual(len(result), 1) + + def test_empty_ds_dir_skips_conversion(self): + ds = HFDataset.from_dict({ + "text": ["sample"], + "images": [["relative/img.jpg"]], + }) + cfg = JANamespace( + dataset_path="/nonexistent/path", + image_key="images", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertEqual(result[0]["images"][0], "relative/img.jpg") + + def test_dict_global_cfg(self): + ds = HFDataset.from_dict({ + "text": ["sample"], + "images": [["img.jpg"]], + }) + cfg = { + "dataset_path": self.ds_dir, + "image_key": "images", + } + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertTrue(os.path.isabs(result[0]["images"][0])) + + def test_dataset_path_is_file(self): + ds_file = os.path.join(self.ds_dir, "data.jsonl") + with open(ds_file, "w") as f: + f.write('{"text": "hello"}\n') + + ds = HFDataset.from_dict({ + "text": ["sample"], + "images": [["img.jpg"]], + }) + cfg = JANamespace( + dataset_path=ds_file, + image_key="images", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertTrue(result[0]["images"][0].startswith(self.ds_dir)) + + +class AddSuffixesTest(DataJuicerTestCaseBase): + + def test_adds_suffix_column(self): + ds1 = HFDataset.from_dict({"text": ["a", "b"]}) + ds2 = HFDataset.from_dict({"text": ["c"]}) + dd = DatasetDict({"json": ds1, "csv": ds2}) + + result = add_suffixes(dd) + self.assertIn("__dj__suffix__", result.column_names) + suffixes = result["__dj__suffix__"] + self.assertIn(".json", suffixes) + self.assertIn(".csv", suffixes) + self.assertEqual(len(result), 3) + + +class UnifyFormatAudioVideoPathTest(DataJuicerTestCaseBase): + """Cover audio_key / video_key relative-to-absolute path conversion + (lines 226-232, 244, 253, 261, 272-302 in formatter.py).""" + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + self.ds_dir = os.path.join(self.tmp_dir, "dataset") + os.makedirs(self.ds_dir) + + def tearDown(self): + import shutil + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def test_converts_relative_audio_paths(self): + ds = HFDataset.from_dict({ + "text": ["s1"], + "audios": [["clip.wav"]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + audio_key="audios", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + audio_path = result[0]["audios"][0] + self.assertTrue(os.path.isabs(audio_path)) + self.assertTrue(audio_path.startswith(self.ds_dir)) + + def test_converts_relative_video_paths(self): + ds = HFDataset.from_dict({ + "text": ["s1"], + "videos": [["clip.mp4"]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + video_key="videos", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + video_path = result[0]["videos"][0] + self.assertTrue(os.path.isabs(video_path)) + self.assertTrue(video_path.startswith(self.ds_dir)) + + def test_mixed_media_keys(self): + """All three media keys (image/audio/video) converted in one call.""" + ds = HFDataset.from_dict({ + "text": ["s1"], + "images": [["img.jpg"]], + "audios": [["clip.wav"]], + "videos": [["clip.mp4"]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + image_key="images", + audio_key="audios", + video_key="videos", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + row = result[0] + for key in ("images", "audios", "videos"): + self.assertTrue(os.path.isabs(row[key][0]), + f"{key} path not absolute: {row[key][0]}") + + def test_custom_audio_key_name(self): + ds = HFDataset.from_dict({ + "text": ["s1"], + "my_audio": [["clip.wav"]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + audio_key="my_audio", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertTrue(os.path.isabs(result[0]["my_audio"][0])) + + def test_custom_video_key_name(self): + ds = HFDataset.from_dict({ + "text": ["s1"], + "my_video": [["clip.mp4"]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + video_key="my_video", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertTrue(os.path.isabs(result[0]["my_video"][0])) + + def test_preserves_absolute_audio_path(self): + abs_path = "/abs/audio.wav" + ds = HFDataset.from_dict({ + "text": ["s1"], + "audios": [[abs_path]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + audio_key="audios", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertEqual(result[0]["audios"][0], abs_path) + + def test_empty_media_list_unchanged(self): + ds = HFDataset.from_dict({ + "text": ["s1"], + "audios": [[]], + }) + cfg = JANamespace( + dataset_path=self.ds_dir, + audio_key="audios", + ) + result = unify_format(ds, text_keys=["text"], global_cfg=cfg) + self.assertEqual(result[0]["audios"], []) + diff --git a/tests/format/test_json_formatter.py b/tests/format/test_json_formatter.py index 16285291889..055e3e6e369 100644 --- a/tests/format/test_json_formatter.py +++ b/tests/format/test_json_formatter.py @@ -113,5 +113,82 @@ def test_jsonl_zst_file(self): self.assertEqual(list(ds.features.keys()), ["text", "meta"]) +class JsonFormatterLenientTest(DataJuicerTestCaseBase): + """Test JsonFormatter's lenient JSONL loading with real files.""" + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp() + + def tearDown(self): + os.environ.pop("DATA_JUICER_JSONL_LENIENT", None) + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def _write_jsonl(self, filename, lines): + import json + path = os.path.join(self.tmp_dir, filename) + with open(path, "w") as f: + for line in lines: + f.write(json.dumps(line) + "\n") + return path + + def test_lenient_env_var(self): + self._write_jsonl("data.jsonl", [ + {"text": "hello"}, + {"text": "world"}, + ]) + os.environ["DATA_JUICER_JSONL_LENIENT"] = "1" + + formatter = JsonFormatter(self.tmp_dir, text_keys=["text"]) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 2) + + def test_lenient_skips_bad_lines(self): + import json + path = os.path.join(self.tmp_dir, "data.jsonl") + with open(path, "w") as f: + f.write('{"text": "good1"}\n') + f.write('bad json line\n') + f.write('{"text": "good2"}\n') + + os.environ["DATA_JUICER_JSONL_LENIENT"] = "true" + formatter = JsonFormatter(self.tmp_dir, text_keys=["text"]) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 2) + texts = sorted(ds["text"]) + self.assertEqual(texts, ["good1", "good2"]) + + def test_lenient_cfg_flag(self): + from jsonargparse import Namespace + self._write_jsonl("data.jsonl", [ + {"text": "test"}, + ]) + cfg = Namespace(load_jsonl_lenient=True) + formatter = JsonFormatter(self.tmp_dir, text_keys=["text"]) + ds = formatter.load_dataset(global_cfg=cfg) + self.assertEqual(len(ds), 1) + + def test_non_lenient_uses_default(self): + self._write_jsonl("data.jsonl", [ + {"text": "normal"}, + ]) + formatter = JsonFormatter(self.tmp_dir, text_keys=["text"]) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 1) + self.assertEqual(ds[0]["text"], "normal") + + def test_lenient_no_jsonl_files_falls_back(self): + import json + path = os.path.join(self.tmp_dir, "data.json") + with open(path, "w") as f: + json.dump([{"text": "from_json"}], f) + + os.environ["DATA_JUICER_JSONL_LENIENT"] = "1" + formatter = JsonFormatter(self.tmp_dir, text_keys=["text"]) + ds = formatter.load_dataset() + self.assertEqual(len(ds), 1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/ops/test_mixins.py b/tests/ops/test_mixins.py new file mode 100644 index 00000000000..42388e3efec --- /dev/null +++ b/tests/ops/test_mixins.py @@ -0,0 +1,595 @@ +import os +import time +import unittest +from unittest.mock import MagicMock, patch, ANY + +from data_juicer.ops.mixins import EventDrivenMixin, NotificationMixin +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +# Concrete classes for testing (mixins need a base) +class EventDrivenTestClass(EventDrivenMixin): + pass + + +class _KwargsAbsorber: + """Base class that absorbs any kwargs so object.__init__ won't complain.""" + def __init__(self, *args, **kwargs): + pass + + +class NotificationTestClass(NotificationMixin, _KwargsAbsorber): + pass + + +class EventDrivenMixinTest(DataJuicerTestCaseBase): + + def test_register_and_trigger_event(self): + obj = EventDrivenTestClass() + results = [] + obj.register_event_handler("test_event", lambda d: results.append(d)) + obj.trigger_event("test_event", {"key": "value"}) + self.assertEqual(results, [{"key": "value"}]) + + def test_register_multiple_handlers(self): + obj = EventDrivenTestClass() + results = [] + obj.register_event_handler("evt", lambda d: results.append("a")) + obj.register_event_handler("evt", lambda d: results.append("b")) + obj.trigger_event("evt", {}) + self.assertEqual(results, ["a", "b"]) + + def test_trigger_unregistered_event_does_nothing(self): + obj = EventDrivenTestClass() + # Should not raise + obj.trigger_event("nonexistent", {"data": 1}) + + def test_start_and_stop_polling(self): + obj = EventDrivenTestClass() + call_count = {"n": 0} + + def poll_func(): + call_count["n"] += 1 + return {"polled": True} + + triggered = [] + obj.register_event_handler("poll_evt", lambda d: triggered.append(d)) + + obj.start_polling("poll_evt", poll_func, interval=0.05) + time.sleep(0.2) + obj.stop_polling("poll_evt") + + self.assertGreater(call_count["n"], 0) + self.assertGreater(len(triggered), 0) + self.assertEqual(triggered[0], {"polled": True}) + self.assertNotIn("poll_evt", obj.polling_threads) + + def test_start_polling_already_running(self): + obj = EventDrivenTestClass() + obj.start_polling("evt", lambda: None, interval=0.5) + thread1 = obj.polling_threads["evt"] + + # Starting again should not create a new thread + obj.start_polling("evt", lambda: None, interval=0.5) + thread2 = obj.polling_threads["evt"] + + self.assertIs(thread1, thread2) + obj.stop_polling("evt") + + def test_polling_handles_exception(self): + obj = EventDrivenTestClass() + call_count = {"n": 0} + + def failing_poll(): + call_count["n"] += 1 + raise ValueError("poll error") + + obj.start_polling("err_evt", failing_poll, interval=0.05) + time.sleep(0.2) + obj.stop_polling("err_evt") + + # Should have been called multiple times despite exceptions + self.assertGreater(call_count["n"], 1) + + def test_polling_returns_none_no_trigger(self): + obj = EventDrivenTestClass() + triggered = [] + obj.register_event_handler("evt", lambda d: triggered.append(d)) + + obj.start_polling("evt", lambda: None, interval=0.05) + time.sleep(0.15) + obj.stop_polling("evt") + + # None return should not trigger events + self.assertEqual(len(triggered), 0) + + def test_stop_all_polling(self): + obj = EventDrivenTestClass() + obj.start_polling("evt1", lambda: None, interval=1) + obj.start_polling("evt2", lambda: None, interval=1) + + self.assertEqual(len(obj.polling_threads), 2) + obj.stop_all_polling() + self.assertEqual(len(obj.polling_threads), 0) + + def test_wait_for_completion_success(self): + counter = {"n": 0} + + def condition(): + counter["n"] += 1 + return counter["n"] >= 3 + + obj = EventDrivenTestClass() + result = obj.wait_for_completion( + condition, timeout=5, poll_interval=0.05 + ) + self.assertTrue(result) + self.assertGreaterEqual(counter["n"], 3) + + def test_wait_for_completion_timeout(self): + obj = EventDrivenTestClass() + with self.assertRaises(TimeoutError) as ctx: + obj.wait_for_completion( + lambda: False, + timeout=0.1, + poll_interval=0.02, + error_message="custom timeout msg", + ) + self.assertIn("custom timeout msg", str(ctx.exception)) + + def test_wait_for_completion_immediate(self): + obj = EventDrivenTestClass() + result = obj.wait_for_completion(lambda: True, timeout=1) + self.assertTrue(result) + + +class NotificationMixinTest(DataJuicerTestCaseBase): + + def _make_obj(self, **config): + return NotificationTestClass(notification_config=config) + + def test_init_disabled_by_default(self): + obj = self._make_obj() + self.assertEqual(obj.notification_config, {}) + + def test_init_enabled(self): + obj = self._make_obj(enabled=True) + self.assertTrue(obj.notification_config["enabled"]) + + def test_send_notification_disabled(self): + obj = self._make_obj(enabled=False) + result = obj.send_notification("hello", notification_type="email") + self.assertTrue(result) # Returns True when disabled + + def test_send_notification_no_config(self): + obj = NotificationTestClass() + result = obj.send_notification("hello", notification_type="email") + self.assertTrue(result) + + def test_send_notification_none_type(self): + obj = self._make_obj(enabled=True) + # notification_type=None should just log and return None + obj.send_notification("hello", notification_type=None) + + def test_send_notification_unsupported_type(self): + obj = self._make_obj(enabled=True) + result = obj.send_notification("hello", notification_type="telegram") + self.assertFalse(result) + + def test_send_notification_channel_disabled(self): + obj = self._make_obj( + enabled=True, + email={"enabled": False}, + ) + result = obj.send_notification("hello", notification_type="email") + self.assertTrue(result) + + def test_send_notification_kwargs_override(self): + obj = self._make_obj( + enabled=True, + email={"smtp_server": "original.com"}, + ) + + mock_handler = MagicMock(return_value=True) + obj.notification_handlers["email"] = mock_handler + + obj.send_notification( + "hello", + notification_type="email", + email={"smtp_server": "override.com"}, + ) + mock_handler.assert_called_once() + + # After the call, the top-level notification_config reference + # should be restored (the finally block restores it). + self.assertTrue(obj.notification_config["enabled"]) + + @patch("smtplib.SMTP_SSL") + def test_send_email_ssl_with_password(self, mock_smtp_ssl): + mock_server = MagicMock() + mock_smtp_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "use_ssl": True, + "username": "user@test.com", + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + "password": "secret123", + }, + ) + result = obj.send_notification("test msg", notification_type="email") + self.assertTrue(result) + mock_smtp_ssl.assert_called_once() + + @patch("smtplib.SMTP") + def test_send_email_starttls_with_password(self, mock_smtp): + mock_server = MagicMock() + mock_smtp.return_value.__enter__ = MagicMock(return_value=mock_server) + mock_smtp.return_value.__exit__ = MagicMock(return_value=False) + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 587, + "use_ssl": False, + "username": "user@test.com", + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + "password": "secret123", + }, + ) + result = obj.send_notification("test msg", notification_type="email") + self.assertTrue(result) + mock_smtp.assert_called_once() + + def test_send_email_missing_server(self): + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "", + "recipients": ["dest@test.com"], + "password": "pw", + }, + ) + result = obj.send_notification("test", notification_type="email") + self.assertFalse(result) + + def test_send_email_missing_credentials(self): + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "recipients": ["dest@test.com"], + # no password, no cert + }, + ) + result = obj.send_notification("test", notification_type="email") + self.assertFalse(result) + + def test_send_email_missing_cert_files(self): + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "recipients": ["dest@test.com"], + "use_cert_auth": True, + # no cert/key files + }, + ) + result = obj.send_notification("test", notification_type="email") + self.assertFalse(result) + + @patch("smtplib.SMTP") + @patch("ssl.create_default_context") + def test_send_email_starttls_with_cert(self, mock_ssl_ctx, mock_smtp): + mock_server = MagicMock() + mock_smtp.return_value.__enter__ = MagicMock(return_value=mock_server) + mock_smtp.return_value.__exit__ = MagicMock(return_value=False) + mock_ctx = MagicMock() + mock_ssl_ctx.return_value = mock_ctx + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 587, + "use_ssl": False, + "use_cert_auth": True, + "client_cert_file": "/fake/cert.pem", + "client_key_file": "/fake/key.pem", + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + }, + ) + result = obj.send_notification("test msg", notification_type="email") + self.assertTrue(result) + mock_ctx.load_cert_chain.assert_called_once_with( + certfile="/fake/cert.pem", keyfile="/fake/key.pem" + ) + + @patch("smtplib.SMTP_SSL") + @patch("ssl.create_default_context") + def test_send_email_ssl_with_cert(self, mock_ssl_ctx, mock_smtp_ssl): + mock_server = MagicMock() + mock_smtp_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) + mock_ctx = MagicMock() + mock_ssl_ctx.return_value = mock_ctx + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "use_ssl": True, + "use_cert_auth": True, + "client_cert_file": "/fake/cert.pem", + "client_key_file": "/fake/key.pem", + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + }, + ) + result = obj.send_notification("test msg", notification_type="email") + self.assertTrue(result) + + @patch.dict( + os.environ, + {"DATA_JUICER_EMAIL_PASSWORD": "env_password"}, + clear=False, + ) + @patch("smtplib.SMTP_SSL") + def test_send_email_password_from_env(self, mock_smtp_ssl): + mock_server = MagicMock() + mock_smtp_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "use_ssl": True, + "username": "user@test.com", + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + # no password in config; should use env var + }, + ) + result = obj.send_notification("test msg", notification_type="email") + self.assertTrue(result) + + def test_send_email_sender_name_formatting(self): + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "sender_email": "user@test.com", + "sender_name": "Test User", + "username": "user@test.com", + "password": "pw", + "recipients": ["dest@test.com"], + }, + ) + # Patch SMTP to avoid real connection but verify the message + with patch("smtplib.SMTP_SSL") as mock_ssl: + mock_server = MagicMock() + mock_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_ssl.return_value.__exit__ = MagicMock(return_value=False) + obj.send_notification("test", notification_type="email") + + # Check the sendmail call has formatted sender + call_args = mock_server.sendmail.call_args + self.assertIn("Test User", call_args[0][0]) + + def test_send_email_exception_returns_false(self): + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "username": "u", + "password": "p", + "sender_email": "u@t.com", + "recipients": ["d@t.com"], + }, + ) + with patch("smtplib.SMTP_SSL", side_effect=Exception("conn fail")): + result = obj.send_notification("test", notification_type="email") + self.assertFalse(result) + + @patch("requests.post") + def test_send_slack_notification_success(self, mock_post): + mock_post.return_value = MagicMock(status_code=200) + + obj = self._make_obj( + enabled=True, + slack={ + "webhook_url": "https://hooks.slack.com/test", + "channel": "#test", + }, + ) + result = obj.send_notification("hello slack", notification_type="slack") + self.assertTrue(result) + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + self.assertIn("hello slack", call_kwargs[1].get("data", call_kwargs[0][1] if len(call_kwargs[0]) > 1 else "")) + + @patch("requests.post") + def test_send_slack_notification_failure(self, mock_post): + mock_post.return_value = MagicMock(status_code=500) + + obj = self._make_obj( + enabled=True, + slack={"webhook_url": "https://hooks.slack.com/test"}, + ) + result = obj.send_notification("hello", notification_type="slack") + self.assertFalse(result) + + def test_send_slack_missing_webhook(self): + obj = self._make_obj(enabled=True, slack={}) + result = obj.send_notification("hello", notification_type="slack") + self.assertFalse(result) + + def test_send_slack_exception_returns_false(self): + obj = self._make_obj( + enabled=True, + slack={"webhook_url": "https://hooks.slack.com/test"}, + ) + with patch("requests.post", side_effect=Exception("network error")): + result = obj.send_notification("hello", notification_type="slack") + self.assertFalse(result) + + @patch("requests.post") + def test_send_dingtalk_notification_success(self, mock_post): + mock_post.return_value = MagicMock( + json=MagicMock(return_value={"errcode": 0}) + ) + + obj = self._make_obj( + enabled=True, + dingtalk={"access_token": "fake_token"}, + ) + result = obj.send_notification( + "hello dingtalk", notification_type="dingtalk" + ) + self.assertTrue(result) + mock_post.assert_called_once() + + @patch("requests.post") + def test_send_dingtalk_with_secret(self, mock_post): + mock_post.return_value = MagicMock( + json=MagicMock(return_value={"errcode": 0}) + ) + + obj = self._make_obj( + enabled=True, + dingtalk={ + "access_token": "fake_token", + "secret": "fake_secret", + }, + ) + result = obj.send_notification("hello", notification_type="dingtalk") + self.assertTrue(result) + + # URL should contain timestamp and sign params + call_url = mock_post.call_args[0][0] + self.assertIn("timestamp=", call_url) + self.assertIn("sign=", call_url) + + def test_send_dingtalk_missing_token(self): + obj = self._make_obj(enabled=True, dingtalk={}) + result = obj.send_notification("hello", notification_type="dingtalk") + self.assertFalse(result) + + def test_send_dingtalk_exception_returns_false(self): + obj = self._make_obj( + enabled=True, + dingtalk={"access_token": "fake_token"}, + ) + with patch("requests.post", side_effect=Exception("fail")): + result = obj.send_notification( + "hello", notification_type="dingtalk" + ) + self.assertFalse(result) + + @patch("requests.post") + def test_send_dingtalk_api_error(self, mock_post): + mock_post.return_value = MagicMock( + json=MagicMock(return_value={"errcode": 310000, "errmsg": "fail"}) + ) + obj = self._make_obj( + enabled=True, + dingtalk={"access_token": "fake_token"}, + ) + result = obj.send_notification("hello", notification_type="dingtalk") + self.assertFalse(result) + + def test_notification_handlers_not_initialized(self): + """Test when notification_handlers is missing.""" + obj = self._make_obj(enabled=True) + del obj.notification_handlers + result = obj.send_notification("hello", notification_type="email") + self.assertFalse(result) + + @patch("smtplib.SMTP_SSL") + def test_send_email_no_include_port(self, mock_smtp_ssl): + mock_server = MagicMock() + mock_smtp_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "use_ssl": True, + "include_port_in_address": False, + "username": "user@test.com", + "sender_email": "user@test.com", + "password": "pw", + "recipients": ["dest@test.com"], + }, + ) + result = obj.send_notification("test", notification_type="email") + self.assertTrue(result) + # Server address should not include port + call_args = mock_smtp_ssl.call_args + self.assertEqual(call_args[0][0], "smtp.test.com") + + @patch.dict( + os.environ, + { + "DATA_JUICER_EMAIL_CERT": "/env/cert.pem", + "DATA_JUICER_EMAIL_KEY": "/env/key.pem", + }, + clear=False, + ) + @patch("smtplib.SMTP_SSL") + @patch("ssl.create_default_context") + def test_send_email_cert_from_env(self, mock_ssl_ctx, mock_smtp_ssl): + mock_server = MagicMock() + mock_smtp_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) + mock_ctx = MagicMock() + mock_ssl_ctx.return_value = mock_ctx + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "use_ssl": True, + "use_cert_auth": True, + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + # cert/key from env, not config + }, + ) + result = obj.send_notification("test", notification_type="email") + self.assertTrue(result) + mock_ctx.load_cert_chain.assert_called_once_with( + certfile="/env/cert.pem", keyfile="/env/key.pem" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index dd6844037df..3cb3deee8cd 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -1,13 +1,24 @@ +import json import os +import tempfile import unittest import regex as re import gzip +import pandas as pd + from data_juicer.utils.file_utils import ( + Sizes, + byte_size_to_size_str, find_files_with_suffix, + get_all_files_paths_under, is_absolute_path, + is_remote_path, add_suffix_to_filename, create_directory_if_not_exists, + expand_outdir_and_mkdir, + read_single_partition, + single_partition_write_with_filename, transfer_filename, copy_data, ) @@ -154,5 +165,298 @@ def test_find_files_with_suffix_gzip(self): self.assertEqual(result[".jsonl.gz"], [gz_path]) +class ByteSizeToSizeStrTest(DataJuicerTestCaseBase): + """Test byte_size_to_size_str: converts byte count to human-readable string.""" + + def test_bytes_range(self): + self.assertEqual(byte_size_to_size_str(0), "0.00 Bytes") + self.assertEqual(byte_size_to_size_str(512), "512.00 Bytes") + self.assertEqual(byte_size_to_size_str(1023), "1023.00 Bytes") + + def test_kib_range(self): + self.assertEqual(byte_size_to_size_str(Sizes.KiB), "1.00 KiB") + self.assertEqual(byte_size_to_size_str(int(1.5 * Sizes.KiB)), "1.50 KiB") + + def test_mib_range(self): + self.assertEqual(byte_size_to_size_str(Sizes.MiB), "1.00 MiB") + self.assertEqual(byte_size_to_size_str(5 * Sizes.MiB), "5.00 MiB") + + def test_gib_range(self): + self.assertEqual(byte_size_to_size_str(Sizes.GiB), "1.00 GiB") + self.assertEqual(byte_size_to_size_str(int(2.5 * Sizes.GiB)), "2.50 GiB") + + def test_tib_range(self): + self.assertEqual(byte_size_to_size_str(Sizes.TiB), "1.00 TiB") + self.assertEqual(byte_size_to_size_str(3 * Sizes.TiB), "3.00 TiB") + + def test_boundary_kib(self): + """Exactly at KiB boundary should show KiB, not Bytes.""" + self.assertEqual(byte_size_to_size_str(1024), "1.00 KiB") + + def test_boundary_mib(self): + """Exactly at MiB boundary should show MiB, not KiB.""" + self.assertEqual(byte_size_to_size_str(1024 * 1024), "1.00 MiB") + + +class IsRemotePathTest(DataJuicerTestCaseBase): + """Test is_remote_path: detects http/https/s3/gs/hdfs URLs.""" + + def test_http(self): + self.assertTrue(is_remote_path("http://example.com/data.json")) + + def test_https(self): + self.assertTrue(is_remote_path("https://example.com/data.json")) + + def test_s3(self): + self.assertTrue(is_remote_path("s3://bucket/key")) + + def test_gs(self): + self.assertTrue(is_remote_path("gs://bucket/key")) + + def test_hdfs(self): + self.assertTrue(is_remote_path("hdfs://cluster/path")) + + def test_local_absolute(self): + self.assertFalse(is_remote_path("/home/user/data.json")) + + def test_local_relative(self): + self.assertFalse(is_remote_path("relative/path.json")) + + def test_remote_path_in_is_absolute(self): + """is_absolute_path should return True for remote paths.""" + self.assertTrue(is_absolute_path("s3://bucket/key")) + self.assertTrue(is_absolute_path("https://example.com/file")) + + +class GetAllFilesPathsUnderTest(DataJuicerTestCaseBase): + """Test get_all_files_paths_under: lists files recursively or flat.""" + + def setUp(self): + super().setUp() + self.root = tempfile.mkdtemp() + # Create structure: + # root/a.txt + # root/b.txt + # root/sub/c.txt + for name in ["a.txt", "b.txt"]: + with open(os.path.join(self.root, name), "w") as f: + f.write(name) + subdir = os.path.join(self.root, "sub") + os.makedirs(subdir) + with open(os.path.join(subdir, "c.txt"), "w") as f: + f.write("c") + + def tearDown(self): + import shutil + shutil.rmtree(self.root, ignore_errors=True) + super().tearDown() + + def test_recurse(self): + result = get_all_files_paths_under(self.root, recurse_subdirectories=True) + basenames = [os.path.basename(p) for p in result] + self.assertIn("a.txt", basenames) + self.assertIn("b.txt", basenames) + self.assertIn("c.txt", basenames) + + def test_no_recurse(self): + result = get_all_files_paths_under(self.root, recurse_subdirectories=False) + basenames = [os.path.basename(p) for p in result] + self.assertIn("a.txt", basenames) + self.assertIn("b.txt", basenames) + # sub/ is a directory entry, c.txt should not appear at top level + self.assertNotIn("c.txt", [os.path.basename(p) for p in result if os.path.isfile(p) and "sub" not in p]) + + def test_sorted_output(self): + result = get_all_files_paths_under(self.root, recurse_subdirectories=True) + self.assertEqual(result, sorted(result)) + + +class SinglePartitionWriteTest(DataJuicerTestCaseBase): + """Test single_partition_write_with_filename: writes DataFrame partitions to disk.""" + + def setUp(self): + super().setUp() + self.outdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.outdir, ignore_errors=True) + super().tearDown() + + def test_write_jsonl(self): + df = pd.DataFrame({ + "text": ["hello", "world"], + "filename": ["part1.jsonl", "part1.jsonl"], + }) + result = single_partition_write_with_filename(df, self.outdir, output_type="jsonl") + # Non-empty partition → returns Series([False]) (empty_partition=False) + self.assertEqual(list(result), [False]) + + out_file = os.path.join(self.outdir, "part1.jsonl") + self.assertTrue(os.path.exists(out_file)) + with open(out_file) as f: + lines = [json.loads(line) for line in f] + texts = [row["text"] for row in lines] + self.assertEqual(sorted(texts), ["hello", "world"]) + # filename column should be dropped by default + for row in lines: + self.assertNotIn("filename", row) + + def test_write_parquet(self): + df = pd.DataFrame({ + "text": ["alpha", "beta"], + "filename": ["data.parquet", "data.parquet"], + }) + result = single_partition_write_with_filename( + df, self.outdir, output_type="parquet") + self.assertEqual(list(result), [False]) + + out_file = os.path.join(self.outdir, "data.parquet") + self.assertTrue(os.path.exists(out_file)) + read_back = pd.read_parquet(out_file) + self.assertEqual(list(read_back["text"]), ["alpha", "beta"]) + + def test_keep_filename_column(self): + df = pd.DataFrame({ + "text": ["keep"], + "filename": ["out.jsonl"], + }) + single_partition_write_with_filename( + df, self.outdir, keep_filename_column=True, output_type="jsonl") + out_file = os.path.join(self.outdir, "out.jsonl") + with open(out_file) as f: + row = json.loads(f.readline()) + self.assertIn("filename", row) + + def test_empty_partition(self): + df = pd.DataFrame({"text": [], "filename": []}) + result = single_partition_write_with_filename(df, self.outdir) + # Empty partition → returns Series([True]) + self.assertEqual(list(result), [True]) + # No file should be written + self.assertEqual(os.listdir(self.outdir), []) + + def test_unknown_output_type_raises(self): + df = pd.DataFrame({ + "text": ["x"], + "filename": ["f.csv"], + }) + with self.assertRaises(ValueError): + single_partition_write_with_filename( + df, self.outdir, output_type="csv") + + def test_multiple_filenames(self): + """Rows with different filenames should be split into separate files.""" + df = pd.DataFrame({ + "text": ["a1", "b1", "a2"], + "filename": ["file_a.jsonl", "file_b.jsonl", "file_a.jsonl"], + }) + single_partition_write_with_filename(df, self.outdir, output_type="jsonl") + self.assertTrue(os.path.exists(os.path.join(self.outdir, "file_a.jsonl"))) + self.assertTrue(os.path.exists(os.path.join(self.outdir, "file_b.jsonl"))) + + +class ReadSinglePartitionTest(DataJuicerTestCaseBase): + """Test read_single_partition: reads jsonl/json/parquet files into DataFrame.""" + + def setUp(self): + super().setUp() + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + super().tearDown() + + def _write_jsonl(self, filename, rows): + path = os.path.join(self.tmpdir, filename) + with open(path, "w") as f: + for row in rows: + f.write(json.dumps(row) + "\n") + return path + + def test_read_jsonl(self): + path = self._write_jsonl("data.jsonl", [ + {"text": "hello", "score": 1}, + {"text": "world", "score": 2}, + ]) + df = read_single_partition([path], filetype="jsonl") + self.assertEqual(len(df), 2) + self.assertIn("text", df.columns) + self.assertIn("score", df.columns) + + def test_read_jsonl_add_filename(self): + path = self._write_jsonl("data.jsonl", [{"text": "row1"}]) + df = read_single_partition([path], filetype="jsonl", add_filename=True) + self.assertIn("filename", df.columns) + self.assertEqual(df["filename"].iloc[0], "data.jsonl") + + def test_read_jsonl_with_columns_filter(self): + path = self._write_jsonl("data.jsonl", [ + {"text": "a", "score": 1, "extra": "x"}, + ]) + df = read_single_partition( + [path], filetype="jsonl", columns=["text", "score"]) + self.assertIn("text", df.columns) + self.assertIn("score", df.columns) + self.assertNotIn("extra", df.columns) + + def test_read_parquet(self): + path = os.path.join(self.tmpdir, "data.parquet") + pd.DataFrame({"text": ["a", "b"], "num": [1, 2]}).to_parquet(path) + df = read_single_partition([path], filetype="parquet") + self.assertEqual(len(df), 2) + self.assertIn("text", df.columns) + + def test_read_parquet_with_columns(self): + path = os.path.join(self.tmpdir, "data.parquet") + pd.DataFrame({"text": ["a"], "num": [1], "extra": ["x"]}).to_parquet(path) + df = read_single_partition( + [path], filetype="parquet", columns=["text"]) + self.assertIn("text", df.columns) + self.assertNotIn("extra", df.columns) + + def test_read_multiple_jsonl_files(self): + p1 = self._write_jsonl("a.jsonl", [{"text": "row1"}]) + p2 = self._write_jsonl("b.jsonl", [{"text": "row2"}]) + df = read_single_partition([p1, p2], filetype="jsonl") + self.assertEqual(len(df), 2) + + def test_unknown_filetype_raises(self): + with self.assertRaises(RuntimeError): + read_single_partition(["dummy.csv"], filetype="csv") + + def test_columns_sorted(self): + """Output columns should be alphabetically sorted.""" + path = self._write_jsonl("data.jsonl", [ + {"z_col": 1, "a_col": 2, "m_col": 3}, + ]) + df = read_single_partition([path], filetype="jsonl") + self.assertEqual(list(df.columns), sorted(df.columns)) + + +class ExpandOutdirAndMkdirTest(DataJuicerTestCaseBase): + """Test expand_outdir_and_mkdir: expands path and creates directory.""" + + def test_creates_and_returns_absolute(self): + tmpdir = tempfile.mkdtemp() + import shutil + shutil.rmtree(tmpdir) + self.assertFalse(os.path.exists(tmpdir)) + + result = expand_outdir_and_mkdir(tmpdir) + self.assertTrue(os.path.exists(result)) + self.assertTrue(os.path.isabs(result)) + + shutil.rmtree(result, ignore_errors=True) + + def test_existing_dir(self): + tmpdir = tempfile.mkdtemp() + result = expand_outdir_and_mkdir(tmpdir) + self.assertTrue(os.path.exists(result)) + import shutil + shutil.rmtree(tmpdir, ignore_errors=True) + + if __name__ == "__main__": unittest.main() diff --git a/uv.lock b/uv.lock index c47201c0775..664cc5d8641 100644 --- a/uv.lock +++ b/uv.lock @@ -1016,18 +1016,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/c5/f6ce561004db45f0b847c2cd9b19c67c6bf348a82018a48cb718be6b58b0/botocore-1.40.61-py3-none-any.whl", hash = "sha256:17ebae412692fd4824f99cde0f08d50126dc97954008e5ba2b522eb049238aa7", size = 14055973, upload-time = "2025-10-28T19:26:42.15Z" }, ] -[[package]] -name = "bs4" -version = "0.0.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "beautifulsoup4" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c9/aa/4acaf814ff901145da37332e05bb510452ebed97bc9602695059dd46ef39/bs4-0.0.2.tar.gz", hash = "sha256:a48685c58f50fe127722417bae83fe6badf500d54b55f7e39ffe43b798653925", size = 698, upload-time = "2024-01-17T18:15:47.371Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/51/bb/bf7aab772a159614954d84aa832c129624ba6c32faa559dfb200a534e50b/bs4-0.0.2-py2.py3-none-any.whl", hash = "sha256:abf8742c0805ef7f662dce4b51cca104cffe52b835238afc169142ab9b3fbccc", size = 1189, upload-time = "2024-01-17T18:15:48.613Z" }, -] - [[package]] name = "build" version = "1.2.2.post1" @@ -6244,7 +6232,7 @@ wheels = [ name = "py-data-juicer" source = { editable = "." } dependencies = [ - { name = "bs4" }, + { name = "beautifulsoup4" }, { name = "datasets" }, { name = "dep-logic" }, { name = "dill" }, @@ -6464,13 +6452,13 @@ requires-dist = [ { name = "audiomentations", marker = "extra == 'audio'" }, { name = "av", marker = "extra == 'all'", specifier = "==13.1.0" }, { name = "av", marker = "extra == 'vision'", specifier = "==13.1.0" }, + { name = "beautifulsoup4" }, { name = "bitarray", marker = "extra == 'all'" }, { name = "bitarray", marker = "extra == 'distributed'" }, { name = "black", marker = "extra == 'all'", specifier = ">=25.1.0" }, { name = "black", marker = "extra == 'dev'", specifier = ">=25.1.0" }, { name = "boto3", marker = "extra == 'all'" }, { name = "boto3", marker = "extra == 'distributed'" }, - { name = "bs4" }, { name = "build", marker = "extra == 'all'" }, { name = "build", marker = "extra == 'dev'" }, { name = "click", marker = "extra == 'all'" }, From 724ce80d367c70dab589845065e11160e4f3616a Mon Sep 17 00:00:00 2001 From: cmgzn Date: Mon, 8 Jun 2026 18:31:15 +0800 Subject: [PATCH 02/13] test: add tests for common_utils, sample, jsonl_lenient_loader - common_utils: check_op_method_param, deprecated decorator (bare/with-reason/with-version/invalid-args) - sample: random_sample (weight/number/upsample/seed/rounding) - jsonl_lenient_loader: zstd decompression, missing file handling, empty lines --- tests/utils/test_common_utils.py | 119 ++++++++++++++++++++++- tests/utils/test_jsonl_lenient_loader.py | 63 ++++++++++++ tests/utils/test_sample.py | 71 ++++++++++++++ 3 files changed, 249 insertions(+), 4 deletions(-) create mode 100644 tests/utils/test_sample.py diff --git a/tests/utils/test_common_utils.py b/tests/utils/test_common_utils.py index c0bc520e730..68d8108b994 100644 --- a/tests/utils/test_common_utils.py +++ b/tests/utils/test_common_utils.py @@ -1,11 +1,17 @@ -import unittest import sys +import unittest +import warnings from data_juicer.utils.common_utils import ( - stats_to_number, dict_to_hash, nested_access, is_string_list, - avg_split_string_list_under_limit, is_float + avg_split_string_list_under_limit, + check_op_method_param, + deprecated, + dict_to_hash, + is_float, + is_string_list, + nested_access, + stats_to_number, ) - from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class CommonUtilsTest(DataJuicerTestCaseBase): @@ -53,5 +59,110 @@ def test_avg_split_string_list_under_limit(self): self.assertEqual(avg_split_string_list_under_limit(str_list, token_nums, max_token_num), expected_result) +class CheckOpMethodParamTest(DataJuicerTestCaseBase): + """Test check_op_method_param: checks if method has named param or **kwargs.""" + + def test_finds_named_param(self): + def example(x, target_param, y): + pass + self.assertTrue(check_op_method_param(example, 'target_param')) + + def test_missing_param_returns_false(self): + def example(x, y): + pass + self.assertFalse(check_op_method_param(example, 'missing')) + + def test_finds_var_keyword(self): + """If method has **kwargs, any param name should return True.""" + def example(x, **kwargs): + pass + self.assertTrue(check_op_method_param(example, 'anything')) + + def test_no_params(self): + def example(): + pass + self.assertFalse(check_op_method_param(example, 'x')) + + def test_self_param_on_method(self): + class Dummy: + def method(self, context): + pass + self.assertTrue(check_op_method_param(Dummy.method, 'context')) + self.assertFalse(check_op_method_param(Dummy.method, 'missing')) + + +class DeprecatedDecoratorTest(DataJuicerTestCaseBase): + """Test deprecated decorator: marks functions as deprecated with warnings.""" + + def test_bare_decorator(self): + """@deprecated without arguments should emit DeprecationWarning.""" + @deprecated + def old_func(): + return 42 + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + result = old_func() + + self.assertEqual(result, 42) + self.assertEqual(len(caught), 1) + self.assertTrue(issubclass(caught[0].category, DeprecationWarning)) + self.assertIn("old_func", str(caught[0].message)) + + def test_with_reason(self): + @deprecated(reason="Use new_func instead") + def old_func(): + return 1 + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + old_func() + + self.assertIn("Use new_func instead", str(caught[0].message)) + + def test_with_version(self): + @deprecated(reason="Outdated", version="2.0") + def old_func(): + return 1 + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + old_func() + + msg = str(caught[0].message) + self.assertIn("Outdated", msg) + self.assertIn("2.0", msg) + + def test_preserves_function_name(self): + @deprecated(reason="old") + def my_special_func(): + pass + self.assertEqual(my_special_func.__name__, "my_special_func") + + def test_invalid_reason_type_raises(self): + with self.assertRaises(TypeError): + @deprecated(reason=123) + def func(): + pass + + def test_invalid_version_type_raises(self): + with self.assertRaises(TypeError): + @deprecated(version=123) + def func(): + pass + + def test_bare_decorator_no_parens_works(self): + """@deprecated without parens should work and return a wrapper.""" + @deprecated + def old_func(): + return "result" + + self.assertEqual(old_func.__name__, "old_func") + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + self.assertEqual(old_func(), "result") + self.assertEqual(len(caught), 1) + + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_jsonl_lenient_loader.py b/tests/utils/test_jsonl_lenient_loader.py index 42d9c6c1040..eabbe38fd2a 100644 --- a/tests/utils/test_jsonl_lenient_loader.py +++ b/tests/utils/test_jsonl_lenient_loader.py @@ -112,5 +112,68 @@ def test_handles_gzip_file(self): self.assertEqual(len(rows), 2) + def test_handles_zstd_file(self): + """Should handle .jsonl.zst files via zstandard decompression.""" + import zstandard as zstd_mod + + jsonl_path = os.path.join(self.tmp_dir, "test.jsonl.zst") + raw = (json.dumps({"z": 1}) + "\n" + json.dumps({"z": 2}) + "\n").encode("utf-8") + cctx = zstd_mod.ZstdCompressor() + with open(jsonl_path, "wb") as f: + f.write(cctx.compress(raw)) + + rows = list( + iter_lenient_jsonl_records( + [(jsonl_path, ".jsonl.zst")], + add_suffix_column=False, + ) + ) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["z"], 1) + self.assertEqual(rows[1]["z"], 2) + + def test_missing_file_skipped(self): + """Missing files should be skipped gracefully, not raise.""" + rows = list( + iter_lenient_jsonl_records( + [("/nonexistent/file.jsonl", ".jsonl")], + add_suffix_column=False, + ) + ) + self.assertEqual(rows, []) + + def test_empty_lines_skipped(self): + """Blank lines in JSONL should be silently skipped.""" + jsonl_path = os.path.join(self.tmp_dir, "blanks.jsonl") + with open(jsonl_path, "w") as f: + f.write("\n\n" + json.dumps({"k": 1}) + "\n\n") + rows = list( + iter_lenient_jsonl_records( + [(jsonl_path, ".jsonl")], + add_suffix_column=False, + ) + ) + self.assertEqual(len(rows), 1) + + def test_zstd_with_suffix_column(self): + """Zstd file with add_suffix_column=True should add suffix.""" + import zstandard as zstd_mod + + jsonl_path = os.path.join(self.tmp_dir, "test.jsonl.zst") + raw = (json.dumps({"x": 1}) + "\n").encode("utf-8") + cctx = zstd_mod.ZstdCompressor() + with open(jsonl_path, "wb") as f: + f.write(cctx.compress(raw)) + + rows = list( + iter_lenient_jsonl_records( + [(jsonl_path, ".jsonl.zst")], + add_suffix_column=True, + ) + ) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][Fields.suffix], ".jsonl.zst") + + if __name__ == "__main__": unittest.main() diff --git a/tests/utils/test_sample.py b/tests/utils/test_sample.py new file mode 100644 index 00000000000..ee5c3696ebc --- /dev/null +++ b/tests/utils/test_sample.py @@ -0,0 +1,71 @@ +import unittest + +from datasets import Dataset + +from data_juicer.utils.sample import random_sample +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class RandomSampleTest(DataJuicerTestCaseBase): + """Test random_sample: subset selection by weight or number.""" + + def _make_dataset(self, num_rows): + return Dataset.from_dict({"text": [f"row{i}" for i in range(num_rows)]}) + + def test_sample_by_weight(self): + ds = self._make_dataset(10) + result = random_sample(ds, weight=0.5) + self.assertEqual(len(result), 5) + + def test_sample_by_number(self): + ds = self._make_dataset(10) + result = random_sample(ds, sample_number=3) + self.assertEqual(len(result), 3) + + def test_number_overrides_weight(self): + """When sample_number > 0, it takes precedence over weight.""" + ds = self._make_dataset(10) + result = random_sample(ds, weight=0.1, sample_number=7) + self.assertEqual(len(result), 7) + + def test_full_dataset_returns_same(self): + """weight=1.0 with full dataset returns the original dataset.""" + ds = self._make_dataset(10) + result = random_sample(ds, weight=1.0) + self.assertIs(result, ds) + + def test_default_seed_is_42(self): + """Two calls without seed should produce identical results.""" + ds = self._make_dataset(20) + r1 = random_sample(ds, weight=0.5) + r2 = random_sample(ds, weight=0.5) + self.assertEqual(r1["text"], r2["text"]) + + def test_different_seeds_differ(self): + ds = self._make_dataset(20) + r1 = random_sample(ds, weight=0.5, seed=1) + r2 = random_sample(ds, weight=0.5, seed=2) + self.assertNotEqual(r1["text"], r2["text"]) + + def test_upsample_repeats(self): + """sample_number > dataset size should repeat rows.""" + ds = self._make_dataset(3) + result = random_sample(ds, sample_number=7) + self.assertEqual(len(result), 7) + + def test_weight_zero(self): + """weight=0 with sample_number=0 → ceil(0)=0 → empty subset.""" + ds = self._make_dataset(10) + result = random_sample(ds, weight=0.0, sample_number=0) + self.assertEqual(len(result), 0) + + def test_fractional_weight_rounds_up(self): + """np.ceil ensures partial rows round up.""" + ds = self._make_dataset(10) + result = random_sample(ds, weight=0.15) + # ceil(10 * 0.15) = ceil(1.5) = 2 + self.assertEqual(len(result), 2) + + +if __name__ == "__main__": + unittest.main() From 192b251f4876ad7fd28602016fb2be8095fd8900 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Mon, 8 Jun 2026 18:32:43 +0800 Subject: [PATCH 03/13] test: add Hasher and update_fingerprint tests for fingerprint_utils - HasherBasicTest: hash_bytes, update/hexdigest, dispatch fallback - UpdateFingerprintTest: deterministic behavior, unhashable transform/args with caching enabled/disabled, empty args --- tests/utils/test_fingerprint_utils.py | 127 +++++++++++++++++++++++++- 1 file changed, 126 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_fingerprint_utils.py b/tests/utils/test_fingerprint_utils.py index f7f34354623..7493a83139c 100644 --- a/tests/utils/test_fingerprint_utils.py +++ b/tests/utils/test_fingerprint_utils.py @@ -1,10 +1,15 @@ import unittest +from unittest.mock import patch import dill from data_juicer.core import NestedDataset from data_juicer.ops.filter.text_length_filter import TextLengthFilter -from data_juicer.utils.fingerprint_utils import Hasher, generate_fingerprint +from data_juicer.utils.fingerprint_utils import ( + Hasher, + generate_fingerprint, + update_fingerprint, +) from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -154,5 +159,125 @@ def run_pipeline(dataset, work_dir): f'files; expected 0 (full cache hit)') +class HasherBasicTest(DataJuicerTestCaseBase): + """Test Hasher class basic operations.""" + + def test_hash_bytes_single(self): + result = Hasher.hash_bytes(b"hello") + self.assertIsInstance(result, str) + self.assertGreater(len(result), 0) + + def test_hash_bytes_list(self): + result = Hasher.hash_bytes([b"hello", b"world"]) + self.assertIsInstance(result, str) + + def test_hash_bytes_deterministic(self): + r1 = Hasher.hash_bytes(b"test") + r2 = Hasher.hash_bytes(b"test") + self.assertEqual(r1, r2) + + def test_hash_bytes_different_input(self): + r1 = Hasher.hash_bytes(b"a") + r2 = Hasher.hash_bytes(b"b") + self.assertNotEqual(r1, r2) + + def test_update_and_hexdigest(self): + h = Hasher() + h.update("hello") + result = h.hexdigest() + self.assertIsInstance(result, str) + self.assertGreater(len(result), 0) + + def test_hash_dispatch_fallback(self): + """Types not in dispatch should use hash_default (dill-based).""" + result = Hasher.hash({"key": "value"}) + self.assertIsInstance(result, str) + + +class UpdateFingerprintTest(DataJuicerTestCaseBase): + """Test update_fingerprint: handles unhashable transforms gracefully.""" + + def test_normal_transform(self): + """Serializable transform + args should produce a deterministic fingerprint.""" + fp = update_fingerprint("base_fp", lambda x: x, {"key": "val"}) + self.assertIsInstance(fp, str) + self.assertGreater(len(fp), 0) + + def test_deterministic_same_inputs(self): + fp1 = update_fingerprint("fp", "transform", {"a": 1}) + fp2 = update_fingerprint("fp", "transform", {"a": 1}) + self.assertEqual(fp1, fp2) + + def test_different_args_different_fingerprint(self): + fp1 = update_fingerprint("fp", "transform", {"a": 1}) + fp2 = update_fingerprint("fp", "transform", {"a": 2}) + self.assertNotEqual(fp1, fp2) + + def test_unhashable_transform_returns_random(self): + """When transform can't be serialized, should return a random fingerprint.""" + from datasets.fingerprint import fingerprint_warnings + # Reset warning state + fingerprint_warnings.pop("update_fingerprint_transform_hash_failed", None) + + # Create an object that dill cannot serialize + class Unpicklable: + def __reduce__(self): + raise TypeError("cannot pickle") + + with patch("data_juicer.utils.fingerprint_utils._CACHING_ENABLED", True): + fp = update_fingerprint("base", Unpicklable(), {}) + + self.assertIsInstance(fp, str) + self.assertGreater(len(fp), 0) + # Should be a random fingerprint, different each time + fingerprint_warnings.pop("update_fingerprint_transform_hash_failed", None) + with patch("data_juicer.utils.fingerprint_utils._CACHING_ENABLED", True): + fp2 = update_fingerprint("base", Unpicklable(), {}) + self.assertNotEqual(fp, fp2) + + def test_unhashable_transform_no_caching(self): + """When caching is disabled and transform unhashable, still returns random fp.""" + from datasets.fingerprint import fingerprint_warnings + fingerprint_warnings.pop("update_fingerprint_transform_hash_failed", None) + + class Unpicklable: + def __reduce__(self): + raise TypeError("cannot pickle") + + with patch("data_juicer.utils.fingerprint_utils._CACHING_ENABLED", False): + fp = update_fingerprint("base", Unpicklable(), {}) + self.assertIsInstance(fp, str) + + def test_unhashable_arg_returns_random(self): + """When a transform_arg can't be serialized, should return random fingerprint.""" + from datasets.fingerprint import fingerprint_warnings + fingerprint_warnings.pop("update_fingerprint_transform_hash_failed", None) + + class Unpicklable: + def __reduce__(self): + raise TypeError("cannot pickle") + + with patch("data_juicer.utils.fingerprint_utils._CACHING_ENABLED", True): + fp = update_fingerprint("base", "good_transform", {"bad_arg": Unpicklable()}) + self.assertIsInstance(fp, str) + + def test_unhashable_arg_no_caching(self): + """When caching disabled and arg unhashable, still returns random fp.""" + from datasets.fingerprint import fingerprint_warnings + fingerprint_warnings.pop("update_fingerprint_transform_hash_failed", None) + + class Unpicklable: + def __reduce__(self): + raise TypeError("cannot pickle") + + with patch("data_juicer.utils.fingerprint_utils._CACHING_ENABLED", False): + fp = update_fingerprint("base", "transform", {"arg": Unpicklable()}) + self.assertIsInstance(fp, str) + + def test_empty_args(self): + fp = update_fingerprint("fp", "transform", {}) + self.assertIsInstance(fp, str) + + if __name__ == '__main__': unittest.main() From 09087e5dd23ef573898a6d8e4e1c1e057036b0bf Mon Sep 17 00:00:00 2001 From: cmgzn Date: Mon, 8 Jun 2026 18:36:36 +0800 Subject: [PATCH 04/13] test: add agent_output_locale tests for zh/en locale branches - normalize_preferred_output_lang: zh variants, en variants, empty, unknown - rubric_reason_language_clause: zh/en branches - llm_filter_free_text_language_appendix: zh/en/empty - agent_insight_system_prompt: zh/en - dialog_detection_output_language_note: intent/topic/sentiment/intensity modes --- tests/utils/test_agent_output_locale.py | 81 +++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/utils/test_agent_output_locale.py b/tests/utils/test_agent_output_locale.py index 9e602836a3f..427f1651189 100644 --- a/tests/utils/test_agent_output_locale.py +++ b/tests/utils/test_agent_output_locale.py @@ -4,10 +4,13 @@ import unittest from data_juicer.utils.agent_output_locale import ( + agent_insight_system_prompt, agent_skill_insight_system_prompt, + dialog_detection_output_language_note, dialog_score_json_instruction, llm_filter_free_text_language_appendix, normalize_preferred_output_lang, + rubric_reason_language_clause, ) from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -37,5 +40,83 @@ def test_skill_insight_prompt_en_concrete_length(self): self.assertIn("Forbidden", s) + def test_normalize_zh_variants(self): + """All Chinese locale variants should normalize to 'zh'.""" + for val in ("zh", "zh-CN", "zh-TW", "zh_cn", "ZH"): + self.assertEqual(normalize_preferred_output_lang(val), "zh", + f"Failed for {val!r}") + + def test_normalize_en_variants(self): + for val in ("en", "EN", "english", "eng", "English"): + self.assertEqual(normalize_preferred_output_lang(val), "en", + f"Failed for {val!r}") + + def test_normalize_empty_string(self): + self.assertEqual(normalize_preferred_output_lang(""), "en") + + def test_normalize_unknown_defaults_en(self): + self.assertEqual(normalize_preferred_output_lang("fr"), "en") + self.assertEqual(normalize_preferred_output_lang("ja"), "en") + + def test_json_instruction_en(self): + s = dialog_score_json_instruction("en") + self.assertIn("score", s) + self.assertIn("JSON", s) + + def test_rubric_reason_zh(self): + s = rubric_reason_language_clause("zh") + self.assertIn("简体中文", s) + + def test_rubric_reason_en(self): + s = rubric_reason_language_clause("en") + self.assertIn("English", s) + + def test_filter_appendix_zh(self): + s = llm_filter_free_text_language_appendix("zh") + self.assertIn("简体中文", s) + self.assertGreater(len(s), 0) + + def test_filter_appendix_en(self): + s = llm_filter_free_text_language_appendix("en") + self.assertIn("English", s) + + def test_filter_appendix_empty_string(self): + self.assertEqual(llm_filter_free_text_language_appendix(""), "") + + def test_agent_insight_prompt_zh(self): + s = agent_insight_system_prompt("zh") + self.assertIn("分析员", s) + + def test_agent_insight_prompt_en(self): + s = agent_insight_system_prompt("en") + self.assertIn("analyst", s.lower()) + + def test_dialog_detection_note_intent_zh(self): + s = dialog_detection_output_language_note("zh", "intent") + self.assertIn("意图分析", s) + self.assertIn("简体中文", s) + + def test_dialog_detection_note_intent_en(self): + s = dialog_detection_output_language_note("en", "intent") + self.assertIn("English", s) + + def test_dialog_detection_note_topic_zh(self): + s = dialog_detection_output_language_note("zh", "topic") + self.assertIn("话题", s) + + def test_dialog_detection_note_sentiment_zh(self): + s = dialog_detection_output_language_note("zh", "sentiment") + self.assertIn("情感", s) + + def test_dialog_detection_note_intensity_zh(self): + s = dialog_detection_output_language_note("zh", "intensity") + self.assertIn("情绪", s) + + def test_dialog_detection_note_unknown_mode(self): + """Unknown mode should return empty string.""" + self.assertEqual(dialog_detection_output_language_note("zh", "unknown"), "") + self.assertEqual(dialog_detection_output_language_note("en", "foobar"), "") + + if __name__ == "__main__": unittest.main() From 1e50dc48e18857f4a45ef4bac4d9c30351b065ab Mon Sep 17 00:00:00 2001 From: cmgzn Date: Mon, 8 Jun 2026 18:44:13 +0800 Subject: [PATCH 05/13] test: add mapper tests for replace_content, clean_email, fix_unicode - replace_content_mapper: None pattern, multi-pattern/repl, mismatched length, raw string strip - clean_email_mapper: custom pattern, no-match, batched process path - fix_unicode_mapper: custom NFKC normalization, invalid raises, empty defaults --- tests/ops/mapper/test_clean_email_mapper.py | 36 ++++++++++ tests/ops/mapper/test_fix_unicode_mapper.py | 19 ++++++ .../ops/mapper/test_replace_content_mapper.py | 65 +++++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/tests/ops/mapper/test_clean_email_mapper.py b/tests/ops/mapper/test_clean_email_mapper.py index 1ff7e389e2a..ee4c62dfe1b 100644 --- a/tests/ops/mapper/test_clean_email_mapper.py +++ b/tests/ops/mapper/test_clean_email_mapper.py @@ -51,5 +51,41 @@ def test_replace_email(self): self._run_clean_email(op, samples) + def test_custom_pattern(self): + """Custom pattern via r-string should have r'...' markers stripped.""" + samples = [{ + 'text': 'Contact: user@example.org for info', + 'target': 'Contact: for info', + }] + op = CleanEmailMapper( + pattern=r"r'[A-Za-z0-9.+_]+@[a-z0-9.+_]+\.[a-z]+'", + repl='', + ) + self._run_clean_email(op, samples) + + def test_no_email_unchanged(self): + samples = [{ + 'text': 'No emails here!', + 'target': 'No emails here!', + }] + op = CleanEmailMapper() + self._run_clean_email(op, samples) + + def test_batched_process(self): + """Ensure process_batched is exercised via generate_dataset + run_single_op.""" + ds_list = [ + {'text': 'hello user@test.com world'}, + {'text': 'clean text here'}, + ] + tgt_list = [ + {'text': 'hello world'}, + {'text': 'clean text here'}, + ] + dataset = self.generate_dataset(ds_list) + op = CleanEmailMapper(batch_size=2) + result = self.run_single_op(dataset, op, ['text']) + self.assertDatasetEqual(result, tgt_list) + + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_fix_unicode_mapper.py b/tests/ops/mapper/test_fix_unicode_mapper.py index 25776a96f91..55e812c0fb7 100644 --- a/tests/ops/mapper/test_fix_unicode_mapper.py +++ b/tests/ops/mapper/test_fix_unicode_mapper.py @@ -49,5 +49,24 @@ def test_good_unicode_text(self): self._run_fix_unicode(samples) + def test_custom_normalization_nfkc(self): + """Custom normalization mode NFKC should work.""" + op = FixUnicodeMapper(normalization='nfkc') + samples = [{'text': 'fi', 'target': 'fi'}] # fi ligature → fi in NFKC + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + for data in dataset: + self.assertEqual(data['text'], data['target']) + + def test_invalid_normalization_raises(self): + with self.assertRaises(ValueError): + FixUnicodeMapper(normalization='INVALID') + + def test_empty_normalization_defaults_nfc(self): + """Empty string normalization should default to NFC.""" + op = FixUnicodeMapper(normalization='') + self.assertEqual(op.normalization, 'NFC') + + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_replace_content_mapper.py b/tests/ops/mapper/test_replace_content_mapper.py index 23ddb34539b..7560408fd7b 100644 --- a/tests/ops/mapper/test_replace_content_mapper.py +++ b/tests/ops/mapper/test_replace_content_mapper.py @@ -60,5 +60,70 @@ def test_regular_digit_pattern_text(self): self._run_helper(op, samples) + def test_none_pattern_returns_unchanged(self): + samples = [ + {'text': 'Hello world', 'target': 'Hello world'}, + ] + op = ReplaceContentMapper(pattern=None) + self._run_helper(op, samples) + + def test_multiple_patterns_with_list_repl(self): + samples = [ + { + 'text': 'foo@bar.com called 123-456', + 'target': ' called ', + }, + ] + op = ReplaceContentMapper( + pattern=[r'[\w]+@[\w]+\.[\w]+', r'\d+-\d+'], + repl=['', ''], + ) + self._run_helper(op, samples) + + def test_multiple_patterns_single_repl(self): + samples = [ + { + 'text': 'aaa bbb ccc', + 'target': 'X X ccc', + }, + ] + op = ReplaceContentMapper( + pattern=['aaa', 'bbb'], + repl='X', + ) + self._run_helper(op, samples) + + def test_mismatched_pattern_repl_length_raises(self): + op = ReplaceContentMapper( + pattern=['a', 'b', 'c'], + repl=['x'], + ) + samples = [{'text': 'a b c'}] + dataset = Dataset.from_list(samples) + with self.assertRaises(ValueError): + dataset.map(op.process, batch_size=2) + + def test_raw_string_pattern_stripped(self): + """Pattern wrapped in r'...' should have the r-string markers removed.""" + samples = [ + { + 'text': 'test 123 end', + 'target': 'test end', + }, + ] + op = ReplaceContentMapper(pattern="r'\\d+'", repl='') + self._run_helper(op, samples) + + def test_empty_repl_removes_match(self): + samples = [ + { + 'text': 'Hello World 123', + 'target': 'Hello World ', + }, + ] + op = ReplaceContentMapper(pattern=r'\d+', repl='') + self._run_helper(op, samples) + + if __name__ == '__main__': unittest.main() From 50ab163d34039399328e54f2b2aad8d8639c448c Mon Sep 17 00:00:00 2001 From: cmgzn Date: Mon, 8 Jun 2026 18:48:50 +0800 Subject: [PATCH 06/13] test: replace load_ops placeholder with real tests - load_ops: single/multiple ops, args passing, _op_cfg stored, order preserved, empty list --- tests/ops/test_load.py | 56 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/tests/ops/test_load.py b/tests/ops/test_load.py index eb5d3dfb350..18960828d00 100644 --- a/tests/ops/test_load.py +++ b/tests/ops/test_load.py @@ -1,12 +1,62 @@ import unittest +from data_juicer.ops.load import load_ops from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class LoadOPsTest(DataJuicerTestCaseBase): - def test_placeholder(self): - # placeholder for test - pass + """Test load_ops: instantiates operators from process config list.""" + + def test_load_single_op(self): + process_list = [ + {"clean_email_mapper": {}}, + ] + ops = load_ops(process_list) + self.assertEqual(len(ops), 1) + self.assertEqual(ops[0]._name, "clean_email_mapper") + + def test_load_multiple_ops(self): + process_list = [ + {"clean_email_mapper": {}}, + {"fix_unicode_mapper": {}}, + ] + ops = load_ops(process_list) + self.assertEqual(len(ops), 2) + self.assertEqual(ops[0]._name, "clean_email_mapper") + self.assertEqual(ops[1]._name, "fix_unicode_mapper") + + def test_op_with_args(self): + process_list = [ + {"text_length_filter": {"min_len": 10, "max_len": 1000}}, + ] + ops = load_ops(process_list) + self.assertEqual(len(ops), 1) + self.assertEqual(ops[0].min_len, 10) + self.assertEqual(ops[0].max_len, 1000) + + def test_op_cfg_stored(self): + """Each op should have its config stored in _op_cfg.""" + cfg = {"clean_email_mapper": {"repl": ""}} + ops = load_ops([cfg]) + self.assertEqual(ops[0]._op_cfg, cfg) + + def test_empty_process_list(self): + ops = load_ops([]) + self.assertEqual(ops, []) + + def test_op_order_preserved(self): + process_list = [ + {"fix_unicode_mapper": {}}, + {"clean_email_mapper": {}}, + {"whitespace_normalization_mapper": {}}, + ] + ops = load_ops(process_list) + names = [op._name for op in ops] + self.assertEqual(names, [ + "fix_unicode_mapper", + "clean_email_mapper", + "whitespace_normalization_mapper", + ]) if __name__ == '__main__': From 449a3b01c4db602a35ec027c68961c647f68e802 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 9 Jun 2026 09:52:55 +0800 Subject: [PATCH 07/13] fix(test): make test_custom_pattern actually test custom pattern behavior The original test used an input ('user@example.org') that matched both the default and custom patterns identically, so the test would pass even without the custom pattern feature. Now uses '{admin}@srv.co' which only matches the custom pattern (with curly braces in char class), with a precondition assertion proving the default pattern does NOT match. --- tests/ops/mapper/test_clean_email_mapper.py | 27 +++++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/ops/mapper/test_clean_email_mapper.py b/tests/ops/mapper/test_clean_email_mapper.py index ee4c62dfe1b..ee6c48eadfe 100644 --- a/tests/ops/mapper/test_clean_email_mapper.py +++ b/tests/ops/mapper/test_clean_email_mapper.py @@ -52,16 +52,33 @@ def test_replace_email(self): def test_custom_pattern(self): - """Custom pattern via r-string should have r'...' markers stripped.""" + """Custom pattern must produce different results from default pattern. + + Default pattern: [A-Za-z0-9.\\-+_]+@[a-z0-9.\\-+_]+\\.[a-z]+ + Custom pattern below also matches {user}@host forms (curly braces). + The input '{admin}@srv.co' does NOT match the default pattern (because + '{' and '}' are not in the default character class), but DOES match + the custom one — proving the custom pattern is actually used. + """ + input_text = 'Contact: {admin}@srv.co for info' + # Verify default pattern does NOT match this input + default_op = CleanEmailMapper(repl='') + ds = Dataset.from_list([{'text': input_text}]) + default_result = ds.map(default_op.process, batch_size=2) + self.assertEqual(default_result[0]['text'], input_text, + "Precondition failed: default pattern should NOT " + "match '{admin}@srv.co'") + + # Now verify custom pattern DOES match samples = [{ - 'text': 'Contact: user@example.org for info', + 'text': input_text, 'target': 'Contact: for info', }] - op = CleanEmailMapper( - pattern=r"r'[A-Za-z0-9.+_]+@[a-z0-9.+_]+\.[a-z]+'", + custom_op = CleanEmailMapper( + pattern=r"r'[A-Za-z0-9.+_{}\-]+@[a-z0-9.+_\-]+\.[a-z]+'", repl='', ) - self._run_clean_email(op, samples) + self._run_clean_email(custom_op, samples) def test_no_email_unchanged(self): samples = [{ From 4e253d13e89534e6ca3fd5d1b6d36b1072c571fc Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 9 Jun 2026 10:03:48 +0800 Subject: [PATCH 08/13] test: enhance test_mixins assertions, clean test_json_formatter debug print, refactor test_download --- tests/download/test_download.py | 116 ++++++++++++---------------- tests/format/test_json_formatter.py | 1 - tests/ops/test_mixins.py | 20 ++++- 3 files changed, 67 insertions(+), 70 deletions(-) diff --git a/tests/download/test_download.py b/tests/download/test_download.py index 8570744a85a..6b6d3761238 100644 --- a/tests/download/test_download.py +++ b/tests/download/test_download.py @@ -11,6 +11,18 @@ ) from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + +# Field schema that download_wikipedia promises to produce. +_WIKI_OUTPUT_FORMAT = { + "text": str, + "title": str, + "id": str, + "url": str, + "language": str, + "source_id": str, + "filename": str, +} + class TestDownload(DataJuicerTestCaseBase): def setUp(self): super().setUp() @@ -68,99 +80,73 @@ def mock_get_response(*args, **kwargs): mock_get.side_effect = mock_get_response urls = get_wikipedia_urls(dump_date=dump_date) - - # Verify returned URLs - assert len(urls) == 3 - assert urls[0] == expected_urls[0] - assert urls[1] == expected_urls[1] - assert urls[2] == expected_urls[2] + self.assertEqual(urls, expected_urls) @patch('data_juicer.download.wikipedia.get_wikipedia_urls') - @patch('data_juicer.download.downloader.download_and_extract') - @patch('data_juicer.download.wikipedia.download_and_extract') # Add this patch too - def test_wikipedia_download(self, mock_download_and_extract_wiki, mock_download_and_extract, mock_get_urls): + @patch('data_juicer.download.wikipedia.download_and_extract') + def test_wikipedia_download(self, mock_download_and_extract, mock_get_urls): + # download_wikipedia promises to: (1) fetch urls for the language/date, + # (2) clip them by url_limit, (3) derive one output path per url, + # (4) hand real downloader/iterator/extractor + the field schema to + # download_and_extract, and (5) return its dataset unchanged. dump_date = "20241101" url_limit = 1 item_limit = 50 - # Mock the URLs returned mock_urls = [ - "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2" + "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2", + "https://dumps.wikimedia.org/enwiki/20241101/enwiki-20241101-pages-articles-multistream2.xml-p41243p151573.bz2", ] mock_get_urls.return_value = mock_urls - # Create expected output paths - output_paths = [ - os.path.join(self.temp_dir, "enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2.jsonl") + expected_output_paths = [ + os.path.join( + self.temp_dir, + "enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2.jsonl", + ) ] - # Create mock dataset - mock_dataset = Dataset.from_dict({ + returned_dataset = Dataset.from_dict({ 'text': [f"Article {i}" for i in range(10)], 'title': [f"Title {i}" for i in range(10)], 'id': [str(i) for i in range(10)], 'url': [f"https://en.wikipedia.org/wiki/Title_{i}" for i in range(10)], 'language': ['en'] * 10, 'source_id': ['enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2'] * 10, - 'filename': ['enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2.jsonl'] * 10 + 'filename': ['enwiki-20241101-pages-articles-multistream1.xml-p1p41242.bz2.jsonl'] * 10, }) + mock_download_and_extract.return_value = returned_dataset - # Set return value for both mocks - mock_download_and_extract.return_value = mock_dataset - mock_download_and_extract_wiki.return_value = mock_dataset - - # Add print statements to debug - print("Before calling download_wikipedia") - - # Run the function result = download_wikipedia( self.temp_dir, dump_date=dump_date, url_limit=url_limit, - item_limit=item_limit + item_limit=item_limit, ) - print("After calling download_wikipedia") - - # Print mock call counts - print(f"mock_download_and_extract.call_count: {mock_download_and_extract.call_count}") - print(f"mock_download_and_extract_wiki.call_count: {mock_download_and_extract_wiki.call_count}") - - # Verify the calls mock_get_urls.assert_called_once_with(language='en', dump_date=dump_date) - - # Try both mocks - if mock_download_and_extract.call_count > 0: - mock = mock_download_and_extract - else: - mock = mock_download_and_extract_wiki - - # Verify download_and_extract was called with correct arguments - mock.assert_called_once() - call_args = mock.call_args[0] - assert call_args[0] == mock_urls[:url_limit] # urls (limited by url_limit) - assert call_args[1] == output_paths # output_paths - assert isinstance(call_args[2], WikipediaDownloader) # downloader - assert isinstance(call_args[3], WikipediaIterator) # iterator - assert isinstance(call_args[4], WikipediaExtractor) # extractor - - # Verify the output format - expected_format = { - 'text': str, - 'title': str, - 'id': str, - 'url': str, - 'language': str, - 'source_id': str, - 'filename': str, - } - assert call_args[5] == expected_format # output_format - - # Verify the result - assert isinstance(result, Dataset) - assert len(result) == 10 - assert all(field in result.features for field in expected_format.keys()) + + mock_download_and_extract.assert_called_once() + call_args = mock_download_and_extract.call_args[0] + # urls clipped to url_limit (only the first of two) + self.assertEqual(call_args[0], mock_urls[:url_limit]) + self.assertEqual(call_args[1], expected_output_paths) + self.assertIsInstance(call_args[2], WikipediaDownloader) + self.assertIsInstance(call_args[3], WikipediaIterator) + self.assertIsInstance(call_args[4], WikipediaExtractor) + self.assertEqual(call_args[5], _WIKI_OUTPUT_FORMAT) + # item_limit is forwarded as a keyword argument + self.assertEqual( + mock_download_and_extract.call_args[1].get('item_limit'), item_limit) + + # the dataset from download_and_extract is returned unchanged + self.assertIs(result, returned_dataset) + self.assertEqual(len(result), 10) + self.assertEqual( + sorted(result.features.keys()), + sorted(_WIKI_OUTPUT_FORMAT.keys()), + ) class ValidateSnapshotFormatTest(DataJuicerTestCaseBase): diff --git a/tests/format/test_json_formatter.py b/tests/format/test_json_formatter.py index 055e3e6e369..583f1a9d58d 100644 --- a/tests/format/test_json_formatter.py +++ b/tests/format/test_json_formatter.py @@ -24,7 +24,6 @@ def setUp(self): self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "structured") self._file = os.path.join(self._path, "demo-dataset.jsonl") - print(self._file) # create compressed variants for testing # create a temp directory to hold generated compressed files self._temp_dir = tempfile.mkdtemp() diff --git a/tests/ops/test_mixins.py b/tests/ops/test_mixins.py index 42388e3efec..a7c82369811 100644 --- a/tests/ops/test_mixins.py +++ b/tests/ops/test_mixins.py @@ -1,7 +1,8 @@ +import json import os import time import unittest -from unittest.mock import MagicMock, patch, ANY +from unittest.mock import MagicMock, patch from data_juicer.ops.mixins import EventDrivenMixin, NotificationMixin from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -168,8 +169,11 @@ def test_send_notification_no_config(self): def test_send_notification_none_type(self): obj = self._make_obj(enabled=True) - # notification_type=None should just log and return None + mock_handler = MagicMock(return_value=True) + obj.notification_handlers = {"email": mock_handler} + # notification_type=None must dispatch to no handler at all obj.send_notification("hello", notification_type=None) + mock_handler.assert_not_called() def test_send_notification_unsupported_type(self): obj = self._make_obj(enabled=True) @@ -181,8 +185,12 @@ def test_send_notification_channel_disabled(self): enabled=True, email={"enabled": False}, ) + mock_handler = MagicMock(return_value=True) + obj.notification_handlers["email"] = mock_handler result = obj.send_notification("hello", notification_type="email") + # disabled channel returns True but must NOT actually send self.assertTrue(result) + mock_handler.assert_not_called() def test_send_notification_kwargs_override(self): obj = self._make_obj( @@ -426,8 +434,12 @@ def test_send_slack_notification_success(self, mock_post): result = obj.send_notification("hello slack", notification_type="slack") self.assertTrue(result) mock_post.assert_called_once() - call_kwargs = mock_post.call_args - self.assertIn("hello slack", call_kwargs[1].get("data", call_kwargs[0][1] if len(call_kwargs[0]) > 1 else "")) + # url is the first positional arg; payload is JSON in the `data` kwarg + self.assertEqual( + mock_post.call_args[0][0], "https://hooks.slack.com/test") + payload = json.loads(mock_post.call_args[1]["data"]) + self.assertEqual(payload["text"], "hello slack") + self.assertEqual(payload["channel"], "#test") @patch("requests.post") def test_send_slack_notification_failure(self, mock_post): From 8af5f5a71c276d240f62ec0222500f8cf371328d Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 9 Jun 2026 10:15:07 +0800 Subject: [PATCH 09/13] test: improve ckpt_utils coverage from 84% to 98% - add checkpoint_exists, empty dataset save, unknown strategy, malformed filenames, ray save/load with event_logger tests --- tests/utils/test_ckpt_utils.py | 194 +++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) diff --git a/tests/utils/test_ckpt_utils.py b/tests/utils/test_ckpt_utils.py index 95582d967bd..0bb3c25a6e2 100644 --- a/tests/utils/test_ckpt_utils.py +++ b/tests/utils/test_ckpt_utils.py @@ -536,6 +536,200 @@ def test_init_disabled_strategy_disables_checkpointing(self): self.assertFalse(mgr.checkpoint_enabled) +class CkptUtilsEdgeCaseTest(DataJuicerTestCaseBase): + """Tests for edge cases in CheckpointManager and base class.""" + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp(prefix='test_ckpt_edge_') + + def tearDown(self): + super().tearDown() + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + + def test_checkpoint_exists_true(self): + """checkpoint_exists returns True for existing path.""" + mgr = CheckpointManager( + self.tmp_dir, original_process_list=[] + ) + existing_file = os.path.join(self.tmp_dir, "somefile") + with open(existing_file, "w") as f: + f.write("data") + self.assertTrue(mgr.checkpoint_exists(existing_file)) + + def test_checkpoint_exists_false(self): + """checkpoint_exists returns False for non-existing path.""" + mgr = CheckpointManager( + self.tmp_dir, original_process_list=[] + ) + self.assertFalse( + mgr.checkpoint_exists(os.path.join(self.tmp_dir, "no_such_file")) + ) + + def test_save_checkpoint_empty_dataset(self): + """Saving an empty dataset logs warning and still writes op record.""" + ckpt_path = os.path.join(self.tmp_dir, "ckpt_empty") + mgr = CheckpointManager(ckpt_path, original_process_list=[]) + mgr.record({"op_a": {}}) + + empty_ds = NestedDataset.from_dict({"text": []}) + result = mgr.save_checkpoint(empty_ds) + + self.assertEqual(result, mgr.ckpt_ds_dir) + # Op record should still be written + self.assertTrue(os.path.exists(mgr.ckpt_op_record)) + with open(mgr.ckpt_op_record) as f: + ops = json.load(f) + self.assertEqual(ops, [{"op_a": {}}]) + + def test_should_checkpoint_unknown_strategy(self): + """Unknown strategy logs warning and defaults to True.""" + mgr = RayCheckpointManager( + ckpt_dir=self.tmp_dir, + checkpoint_enabled=True, + checkpoint_strategy=CheckpointStrategy.EVERY_OP, + ) + # Force an unknown strategy value by bypassing the enum + mgr.checkpoint_strategy = "totally_bogus" + result = mgr.should_checkpoint(0, "op_a") + self.assertTrue(result) + + def test_find_latest_checkpoint_malformed_filenames(self): + """find_latest_checkpoint skips files with unparseable names.""" + mgr = RayCheckpointManager(ckpt_dir=self.tmp_dir) + + # Create a file that starts with checkpoint_op_ and ends with + # _partition_0000.parquet but has a non-integer op index + bad_file = os.path.join( + self.tmp_dir, + "checkpoint_op_XXXX_partition_0000.parquet" + ) + with open(bad_file, "w") as f: + f.write("mock") + + # Also create a valid one to ensure it still finds it + good_file = os.path.join( + self.tmp_dir, + "checkpoint_op_0003_partition_0000.parquet" + ) + with open(good_file, "w") as f: + f.write("mock") + + result = mgr.find_latest_checkpoint(partition_id=0) + self.assertIsNotNone(result) + self.assertEqual(result[0], 3) + + +class RayCheckpointSaveLoadTest(DataJuicerTestCaseBase): + """Tests for RayCheckpointManager save/load with mocked ray.""" + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp(prefix='test_ray_ckpt_sl_') + + def tearDown(self): + super().tearDown() + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + + def test_save_checkpoint_calls_write_parquet(self): + """save_checkpoint extracts .data and writes parquet.""" + mgr = RayCheckpointManager(ckpt_dir=self.tmp_dir) + + mock_ray_data = MagicMock() + mock_dataset = MagicMock() + mock_dataset.data = mock_ray_data + + path = mgr.save_checkpoint( + dataset=mock_dataset, op_idx=1, op_name="filter_a", + ) + + mock_ray_data.write_parquet.assert_called_once_with(path) + self.assertIn("checkpoint_op_0001", path) + + def test_save_checkpoint_with_event_logger(self): + """save_checkpoint logs event when event_logger is provided.""" + mock_logger = MagicMock() + mock_logger._log_event = MagicMock() + mgr = RayCheckpointManager( + ckpt_dir=self.tmp_dir, event_logger=mock_logger, + ) + + mock_dataset = MagicMock() + mock_dataset.data = MagicMock() + + mgr.save_checkpoint( + dataset=mock_dataset, op_idx=2, op_name="clean_op", + ) + + mock_logger._log_event.assert_called_once() + call_kwargs = mock_logger._log_event.call_args[1] + self.assertEqual(call_kwargs["operation_name"], "clean_op") + self.assertEqual(call_kwargs["operation_idx"], 2) + + def test_load_checkpoint_with_event_logger(self): + """load_checkpoint logs event when event_logger is provided.""" + mock_logger = MagicMock() + mock_logger._log_event = MagicMock() + mgr = RayCheckpointManager( + ckpt_dir=self.tmp_dir, event_logger=mock_logger, + ) + + # Create a mock checkpoint file so path exists + ckpt_file = os.path.join( + self.tmp_dir, "checkpoint_op_0005_partition_0000.parquet" + ) + with open(ckpt_file, "w") as f: + f.write("mock") + + mock_ray_ds = MagicMock() + with patch( + "data_juicer.utils.lazy_loader.LazyLoader" + ) as mock_lazy: + mock_ray = MagicMock() + mock_ray.data.read_parquet.return_value = mock_ray_ds + mock_lazy.return_value = mock_ray + + result = mgr.load_checkpoint( + op_idx=5, op_name="my_op", partition_id=0, + ) + + self.assertIs(result, mock_ray_ds) + mock_logger._log_event.assert_called_once() + + def test_load_checkpoint_with_cfg_wraps_in_ray_dataset(self): + """load_checkpoint wraps result in RayDataset when cfg is provided.""" + mgr = RayCheckpointManager(ckpt_dir=self.tmp_dir) + + ckpt_file = os.path.join( + self.tmp_dir, "checkpoint_op_0002_partition_0000.parquet" + ) + with open(ckpt_file, "w") as f: + f.write("mock") + + mock_ray_ds = MagicMock() + mock_cfg = MagicMock() + mock_wrapped = MagicMock() + + with patch( + "data_juicer.utils.lazy_loader.LazyLoader" + ) as mock_lazy, patch( + "data_juicer.core.data.ray_dataset.RayDataset", + return_value=mock_wrapped, + ) as mock_ray_dataset_cls: + mock_ray = MagicMock() + mock_ray.data.read_parquet.return_value = mock_ray_ds + mock_lazy.return_value = mock_ray + + result = mgr.load_checkpoint( + op_idx=2, partition_id=0, cfg=mock_cfg, + ) + + mock_ray_dataset_cls.assert_called_once_with(mock_ray_ds, cfg=mock_cfg) + self.assertIs(result, mock_wrapped) + + class CheckpointStrategyEnumTest(DataJuicerTestCaseBase): """Tests for CheckpointStrategy enum.""" From c70f156821573f01fd60038b2ba5aec4a18da239 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 9 Jun 2026 10:23:32 +0800 Subject: [PATCH 10/13] test: add logger_utils tests for is_notebook, StreamToLoguru, HiddenPrints, redirect_sys_output, make_log_summarization, setup_logger branches --- tests/utils/test_logger_utils.py | 228 +++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) diff --git a/tests/utils/test_logger_utils.py b/tests/utils/test_logger_utils.py index 7c76c4cb2ad..77ddfef4a41 100644 --- a/tests/utils/test_logger_utils.py +++ b/tests/utils/test_logger_utils.py @@ -1,5 +1,7 @@ import os +import sys import unittest +import unittest.mock import jsonlines import regex as re from loguru import logger @@ -114,5 +116,231 @@ def test_make_log_summarization(self): self.assertTrue(os.path.exists(log_fn[0])) +class IsNotebookTest(DataJuicerTestCaseBase): + """Tests for is_notebook() without breaking the global logger.""" + + def test_is_notebook_returns_false_outside_notebook(self): + """In a normal test environment, is_notebook() should return False.""" + from data_juicer.utils.logger_utils import is_notebook + self.assertFalse(is_notebook()) + + def test_is_notebook_returns_false_when_get_ipython_is_none(self): + """When get_ipython is None, is_notebook returns False.""" + from data_juicer.utils import logger_utils + from data_juicer.utils.logger_utils import is_notebook + + original = logger_utils.get_ipython + try: + logger_utils.get_ipython = None + self.assertFalse(is_notebook()) + finally: + logger_utils.get_ipython = original + + def test_is_notebook_exception_returns_false(self): + """When get_ipython() raises, is_notebook returns False.""" + from data_juicer.utils import logger_utils + from data_juicer.utils.logger_utils import is_notebook + + original = logger_utils.get_ipython + + def exploding(): + raise RuntimeError("boom") + + try: + logger_utils.get_ipython = exploding + self.assertFalse(is_notebook()) + finally: + logger_utils.get_ipython = original + + +class StreamToLoguruTest(DataJuicerTestCaseBase): + """Tests for StreamToLoguru without modifying global sys.stdout/stderr.""" + + def test_write_non_caller_module(self): + """Writing from a non-caller module goes through raw info path.""" + from data_juicer.utils.logger_utils import StreamToLoguru + stream = StreamToLoguru(level="INFO", caller_names=("nonexistent_module",)) + # Should not raise + stream.write("hello from test\n") + + def test_getvalue_returns_buffer_content(self): + """getvalue() returns accumulated buffer content.""" + from data_juicer.utils.logger_utils import StreamToLoguru + stream = StreamToLoguru() + stream.write("first ") + stream.write("second") + value = stream.getvalue() + self.assertIn("first ", value) + self.assertIn("second", value) + + def test_flush_does_not_raise(self): + """flush() should not raise.""" + from data_juicer.utils.logger_utils import StreamToLoguru + stream = StreamToLoguru() + stream.flush() # Should not raise + + def test_isatty_returns_false(self): + """isatty() should return False.""" + from data_juicer.utils.logger_utils import StreamToLoguru + stream = StreamToLoguru() + self.assertFalse(stream.isatty()) + + def test_buffer_truncate_on_write(self): + """Buffer is truncated to BUFFER_SIZE after write.""" + from data_juicer.utils.logger_utils import StreamToLoguru + stream = StreamToLoguru() + stream.BUFFER_SIZE = 10 # Small buffer for test + stream.write("a" * 100) + # After truncate(10), the buffer position is at 10 but content + # up to that point is retained + self.assertLessEqual(stream.buffer.tell(), 100) + + +class HiddenPrintsTest(DataJuicerTestCaseBase): + """Tests for the HiddenPrints context manager.""" + + def test_hidden_prints_suppresses_stdout(self): + """HiddenPrints redirects stdout to devnull.""" + from data_juicer.utils.logger_utils import HiddenPrints + import io + + original_stdout = sys.stdout + with HiddenPrints(): + # stdout should be redirected to devnull + self.assertNotEqual(sys.stdout, original_stdout) + print("this should be suppressed") + # After exiting, stdout should be restored + self.assertIs(sys.stdout, original_stdout) + + def test_hidden_prints_restores_stdout_on_exception(self): + """HiddenPrints restores stdout even if exception occurs inside.""" + from data_juicer.utils.logger_utils import HiddenPrints + + original_stdout = sys.stdout + try: + with HiddenPrints(): + raise ValueError("test error") + except ValueError: + pass + self.assertIs(sys.stdout, original_stdout) + + +class GetCallerNameTest(DataJuicerTestCaseBase): + """Tests for get_caller_name().""" + + def test_get_caller_name_depth_zero(self): + """get_caller_name(0) returns the caller's module name.""" + from data_juicer.utils.logger_utils import get_caller_name + name = get_caller_name(depth=0) + # Should be this test module's name + self.assertIn("test_logger_utils", name) + + +class RedirectSysOutputTest(DataJuicerTestCaseBase): + """Tests for redirect_sys_output() — verifies notebook guard.""" + + def test_redirect_noop_in_notebook(self): + """redirect_sys_output does nothing when is_notebook() returns True.""" + from data_juicer.utils import logger_utils + from data_juicer.utils.logger_utils import redirect_sys_output + + original_stdout = sys.stdout + original_stderr = sys.stderr + + with unittest.mock.patch.object( + logger_utils, "is_notebook", return_value=True + ): + redirect_sys_output("INFO") + + # stdout/stderr should be unchanged + self.assertIs(sys.stdout, original_stdout) + self.assertIs(sys.stderr, original_stderr) + + +class MakeLogSummarizationTest(DataJuicerTestCaseBase): + """Tests for make_log_summarization edge cases.""" + + def test_make_log_summarization_no_log_file(self): + """make_log_summarization returns early when no log file exists.""" + with unittest.mock.patch( + "data_juicer.utils.logger_utils.get_log_file_path", + return_value=None, + ): + # Should return None (early exit) without error + result = make_log_summarization() + self.assertIsNone(result) + + +class SetupLoggerIsolatedTest(DataJuicerTestCaseBase): + """Tests for setup_logger branches that don't break the global logger. + + These tests reset LOGGER_SETUP before and after each test to avoid + contaminating other tests. + """ + + def setUp(self): + super().setUp() + self.tmp_dir = os.path.join('tmp', 'test_setup_logger_isolated') + os.makedirs(self.tmp_dir, exist_ok=True) + # Save original state + self._orig_setup = data_juicer.utils.logger_utils.LOGGER_SETUP + self._orig_handlers = dict(logger._core.handlers) + self._orig_stdout = sys.stdout + self._orig_stderr = sys.stderr + + def tearDown(self): + # Restore global state + data_juicer.utils.logger_utils.LOGGER_SETUP = self._orig_setup + sys.stdout = self._orig_stdout + sys.stderr = self._orig_stderr + # Remove any handlers we added + current_ids = set(logger._core.handlers.keys()) + orig_ids = set(self._orig_handlers.keys()) + for handler_id in current_ids - orig_ids: + try: + logger.remove(handler_id) + except ValueError: + pass + if os.path.exists(self.tmp_dir): + os.system(f'rm -rf {self.tmp_dir}') + super().tearDown() + + def test_setup_logger_override_mode_removes_existing_file(self): + """setup_logger in override mode removes an existing log file.""" + log_file = os.path.join(self.tmp_dir, 'log.txt') + with open(log_file, 'w') as f: + f.write("old content") + self.assertTrue(os.path.exists(log_file)) + + data_juicer.utils.logger_utils.LOGGER_SETUP = False + setup_logger( + self.tmp_dir, filename='log.txt', mode='o', redirect=False, + ) + # The old file should have been removed (a new one may be created + # by the file sink, but the old content should be gone) + if os.path.exists(log_file): + with open(log_file) as f: + self.assertNotIn("old content", f.read()) + + def test_setup_logger_redirect_auto(self): + """setup_logger with redirect='auto' resolves based on is_notebook.""" + data_juicer.utils.logger_utils.LOGGER_SETUP = False + # In non-notebook env, redirect='auto' should redirect + setup_logger( + self.tmp_dir, filename='log_auto.txt', redirect='auto', + ) + # LOGGER_SETUP should be True after setup + self.assertTrue(data_juicer.utils.logger_utils.LOGGER_SETUP) + + def test_setup_logger_skips_when_already_setup(self): + """setup_logger is a no-op when LOGGER_SETUP is already True.""" + data_juicer.utils.logger_utils.LOGGER_SETUP = True + handlers_before = set(logger._core.handlers.keys()) + setup_logger(self.tmp_dir, filename='log_skip.txt', redirect=False) + handlers_after = set(logger._core.handlers.keys()) + # No new handlers should be added + self.assertEqual(handlers_before, handlers_after) + + if __name__ == '__main__': unittest.main() From d6cf62ffaa1a35290c0e6f7e177618916d92c6f2 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 9 Jun 2026 15:05:52 +0800 Subject: [PATCH 11/13] test: enhance test assertions and add missing env vars --- tests/core/executor/test_pipeline_dag.py | 8 ++++---- tests/ops/test_mixins.py | 9 ++++++++- tests/utils/test_logger_utils.py | 7 +++---- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/core/executor/test_pipeline_dag.py b/tests/core/executor/test_pipeline_dag.py index dbd5c82ac7a..3b0c9ac014e 100644 --- a/tests/core/executor/test_pipeline_dag.py +++ b/tests/core/executor/test_pipeline_dag.py @@ -239,14 +239,14 @@ def test_mark_node_failed_with_start_time(self): self.assertEqual(node["start_time"], start_time) def test_mark_node_failed_without_start_time(self): - """When a node fails before being started, duration should be ~0.""" + """When a node fails before being started, duration should be 0.""" self.dag.nodes["n1"] = _make_node("n1") self.dag.mark_node_failed("n1", "failed early") node = self.dag.nodes["n1"] self.assertEqual(node["status"], DAGNodeStatus.FAILED.value) self.assertIsNotNone(node["actual_duration"]) - self.assertLessEqual(node["actual_duration"], 1.0) + self.assertEqual(node["actual_duration"], 0.0) def test_mark_node_failed_nonexistent_node(self): self.dag.mark_node_failed("nonexistent", "error") @@ -476,13 +476,13 @@ def test_mark_completed_explicit_duration(self): self.assertEqual(self.dag.nodes["n1"]["actual_duration"], 42.0) def test_mark_completed_without_start(self): - """Complete a node that was never started - duration should be ~0.""" + """Complete a node that was never started - duration should be 0.""" self.dag.nodes["n1"] = _make_node("n1") self.dag.mark_node_completed("n1") node = self.dag.nodes["n1"] self.assertEqual(node["status"], DAGNodeStatus.COMPLETED.value) - self.assertLessEqual(node["actual_duration"], 1.0) + self.assertEqual(node["actual_duration"], 0.0) def test_execution_summary_mixed_states(self): self.dag.nodes["a"] = _make_node("a", execution_order=0) diff --git a/tests/ops/test_mixins.py b/tests/ops/test_mixins.py index a7c82369811..6146b797770 100644 --- a/tests/ops/test_mixins.py +++ b/tests/ops/test_mixins.py @@ -352,7 +352,12 @@ def test_send_email_ssl_with_cert(self, mock_ssl_ctx, mock_smtp_ssl): @patch.dict( os.environ, - {"DATA_JUICER_EMAIL_PASSWORD": "env_password"}, + { + "DATA_JUICER_EMAIL_PASSWORD": "env_password", + "DATA_JUICER_SMTP_TEST_COM_PASSWORD": "", + "DATA_JUICER_EMAIL_CERT": "", + "DATA_JUICER_EMAIL_KEY": "", + }, clear=False, ) @patch("smtplib.SMTP_SSL") @@ -570,6 +575,8 @@ def test_send_email_no_include_port(self, mock_smtp_ssl): { "DATA_JUICER_EMAIL_CERT": "/env/cert.pem", "DATA_JUICER_EMAIL_KEY": "/env/key.pem", + "DATA_JUICER_EMAIL_PASSWORD": "", + "DATA_JUICER_SMTP_TEST_COM_PASSWORD": "", }, clear=False, ) diff --git a/tests/utils/test_logger_utils.py b/tests/utils/test_logger_utils.py index 77ddfef4a41..5937379ebf8 100644 --- a/tests/utils/test_logger_utils.py +++ b/tests/utils/test_logger_utils.py @@ -186,14 +186,13 @@ def test_isatty_returns_false(self): self.assertFalse(stream.isatty()) def test_buffer_truncate_on_write(self): - """Buffer is truncated to BUFFER_SIZE after write.""" + """Buffer content is truncated to BUFFER_SIZE after write.""" from data_juicer.utils.logger_utils import StreamToLoguru stream = StreamToLoguru() stream.BUFFER_SIZE = 10 # Small buffer for test stream.write("a" * 100) - # After truncate(10), the buffer position is at 10 but content - # up to that point is retained - self.assertLessEqual(stream.buffer.tell(), 100) + # truncate(10) keeps only the first 10 characters of content + self.assertEqual(len(stream.buffer.getvalue()), 10) class HiddenPrintsTest(DataJuicerTestCaseBase): From e7785c2e995529340665455205d90571e1af58bb Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 9 Jun 2026 15:40:14 +0800 Subject: [PATCH 12/13] test: refactor email password env var test setup --- tests/ops/test_mixins.py | 118 +++++++++++++++++++-------------------- 1 file changed, 56 insertions(+), 62 deletions(-) diff --git a/tests/ops/test_mixins.py b/tests/ops/test_mixins.py index 6146b797770..99e70ddb622 100644 --- a/tests/ops/test_mixins.py +++ b/tests/ops/test_mixins.py @@ -350,38 +350,35 @@ def test_send_email_ssl_with_cert(self, mock_ssl_ctx, mock_smtp_ssl): result = obj.send_notification("test msg", notification_type="email") self.assertTrue(result) - @patch.dict( - os.environ, - { + def test_send_email_password_from_env(self): + env_vars = { "DATA_JUICER_EMAIL_PASSWORD": "env_password", "DATA_JUICER_SMTP_TEST_COM_PASSWORD": "", "DATA_JUICER_EMAIL_CERT": "", "DATA_JUICER_EMAIL_KEY": "", - }, - clear=False, - ) - @patch("smtplib.SMTP_SSL") - def test_send_email_password_from_env(self, mock_smtp_ssl): - mock_server = MagicMock() - mock_smtp_ssl.return_value.__enter__ = MagicMock( - return_value=mock_server - ) - mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) - - obj = self._make_obj( - enabled=True, - email={ - "smtp_server": "smtp.test.com", - "smtp_port": 465, - "use_ssl": True, - "username": "user@test.com", - "sender_email": "user@test.com", - "recipients": ["dest@test.com"], - # no password in config; should use env var - }, - ) - result = obj.send_notification("test msg", notification_type="email") - self.assertTrue(result) + } + with patch.dict(os.environ, env_vars, clear=False), \ + patch("smtplib.SMTP_SSL") as mock_smtp_ssl: + mock_server = MagicMock() + mock_smtp_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "use_ssl": True, + "username": "user@test.com", + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + # no password in config; should use env var + }, + ) + result = obj.send_notification("test msg", notification_type="email") + self.assertTrue(result) def test_send_email_sender_name_formatting(self): obj = self._make_obj( @@ -570,44 +567,41 @@ def test_send_email_no_include_port(self, mock_smtp_ssl): call_args = mock_smtp_ssl.call_args self.assertEqual(call_args[0][0], "smtp.test.com") - @patch.dict( - os.environ, - { + def test_send_email_cert_from_env(self): + env_vars = { "DATA_JUICER_EMAIL_CERT": "/env/cert.pem", "DATA_JUICER_EMAIL_KEY": "/env/key.pem", "DATA_JUICER_EMAIL_PASSWORD": "", "DATA_JUICER_SMTP_TEST_COM_PASSWORD": "", - }, - clear=False, - ) - @patch("smtplib.SMTP_SSL") - @patch("ssl.create_default_context") - def test_send_email_cert_from_env(self, mock_ssl_ctx, mock_smtp_ssl): - mock_server = MagicMock() - mock_smtp_ssl.return_value.__enter__ = MagicMock( - return_value=mock_server - ) - mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) - mock_ctx = MagicMock() - mock_ssl_ctx.return_value = mock_ctx - - obj = self._make_obj( - enabled=True, - email={ - "smtp_server": "smtp.test.com", - "smtp_port": 465, - "use_ssl": True, - "use_cert_auth": True, - "sender_email": "user@test.com", - "recipients": ["dest@test.com"], - # cert/key from env, not config - }, - ) - result = obj.send_notification("test", notification_type="email") - self.assertTrue(result) - mock_ctx.load_cert_chain.assert_called_once_with( - certfile="/env/cert.pem", keyfile="/env/key.pem" - ) + } + with patch.dict(os.environ, env_vars, clear=False), \ + patch("smtplib.SMTP_SSL") as mock_smtp_ssl, \ + patch("ssl.create_default_context") as mock_ssl_ctx: + mock_server = MagicMock() + mock_smtp_ssl.return_value.__enter__ = MagicMock( + return_value=mock_server + ) + mock_smtp_ssl.return_value.__exit__ = MagicMock(return_value=False) + mock_ctx = MagicMock() + mock_ssl_ctx.return_value = mock_ctx + + obj = self._make_obj( + enabled=True, + email={ + "smtp_server": "smtp.test.com", + "smtp_port": 465, + "use_ssl": True, + "use_cert_auth": True, + "sender_email": "user@test.com", + "recipients": ["dest@test.com"], + # cert/key from env, not config + }, + ) + result = obj.send_notification("test", notification_type="email") + self.assertTrue(result) + mock_ctx.load_cert_chain.assert_called_once_with( + certfile="/env/cert.pem", keyfile="/env/key.pem" + ) if __name__ == "__main__": From 7b2da57796e30fa900fd2cdd112d55d0370badd0 Mon Sep 17 00:00:00 2001 From: cmgzn Date: Tue, 9 Jun 2026 17:47:06 +0800 Subject: [PATCH 13/13] test: add batched API tests for filters, availability_utils, and deduplicator - New: tests/utils/test_availability_utils.py covering _is_package_available and _torch_check_and_set (fixes typo in old filename too) - character_repetition_filter: add direct compute_stats_batched/process_batched tests (57% -> 100%) - alphanumeric_filter: add batched API tests for non-tokenization path (50% -> 82%) - suffix_filter: cover None/str suffixes and reversed_range (87% -> 100%) - word_repetition_filter: add batched API direct call tests (87% -> 91%) - document_line_deduplicator: fix skip_brackets test to actually hit the branch, add compute_hash existing-hash early-return test (84% -> 86%) - Remove misspelled test_availablility_utils.py --- .../test_document_line_deduplicator.py | 25 +++- tests/ops/filter/test_alphanumeric_filter.py | 51 +++++++ .../test_character_repetition_filter.py | 69 ++++++++++ tests/ops/filter/test_suffix_filter.py | 33 +++++ .../ops/filter/test_word_repetition_filter.py | 60 ++++++++ tests/utils/test_availability_utils.py | 130 ++++++++++++++++++ tests/utils/test_availablility_utils.py | 23 ---- 7 files changed, 362 insertions(+), 29 deletions(-) create mode 100644 tests/utils/test_availability_utils.py delete mode 100644 tests/utils/test_availablility_utils.py diff --git a/tests/ops/deduplicator/test_document_line_deduplicator.py b/tests/ops/deduplicator/test_document_line_deduplicator.py index 616a0a56625..b90d6cfd4b0 100644 --- a/tests/ops/deduplicator/test_document_line_deduplicator.py +++ b/tests/ops/deduplicator/test_document_line_deduplicator.py @@ -69,15 +69,16 @@ def test_skip_short_lines(self): def test_skip_brackets(self): """Bracket-only lines are never removed even if high-frequency.""" + # Use multi-char bracket lines to pass min_line_length check ds_list = [ - {"text": "{\nHello\n}"}, - {"text": "{\nWorld\n}"}, - {"text": "{\nFoo\n}"}, + {"text": "{}\nHello\n[]"}, + {"text": "{}\nWorld\n[]"}, + {"text": "{}\nFoo\n[]"}, ] tgt_list = [ - {"text": "{\nHello\n}"}, - {"text": "{\nWorld\n}"}, - {"text": "{\nFoo\n}"}, + {"text": "{}\nHello\n[]"}, + {"text": "{}\nWorld\n[]"}, + {"text": "{}\nFoo\n[]"}, ] dataset = Dataset.from_list(ds_list) op = DocumentLineDeduplicator(frequency_threshold=2) @@ -171,6 +172,18 @@ def test_single_document(self): op = DocumentLineDeduplicator(frequency_threshold=2) self._run_line_dedup(dataset, tgt_list, op) + def test_compute_hash_skips_existing(self): + """compute_hash does not overwrite pre-existing line_hashes.""" + from data_juicer.utils.constant import HashKeys + + ds_list = [ + {"text": "Hello\nWorld", HashKeys.line_hashes: ["pre", "existing"]}, + ] + dataset = Dataset.from_list(ds_list) + op = DocumentLineDeduplicator(frequency_threshold=2) + result = dataset.map(op.compute_hash) + self.assertEqual(result[0][HashKeys.line_hashes], ["pre", "existing"]) + def test_show_num(self): """show_num returns duplicate pair information.""" ds_list = [ diff --git a/tests/ops/filter/test_alphanumeric_filter.py b/tests/ops/filter/test_alphanumeric_filter.py index 3d66189d058..c2832cf8296 100644 --- a/tests/ops/filter/test_alphanumeric_filter.py +++ b/tests/ops/filter/test_alphanumeric_filter.py @@ -69,5 +69,56 @@ def test_token_case(self): self.assertDatasetEqual(result, tgt_list) + def test_compute_stats_batched_no_tokenization(self): + """Directly call compute_stats_batched for non-tokenization path.""" + from data_juicer.utils.constant import Fields, StatsKeys + + op = AlphanumericFilter(tokenization=False, min_ratio=0.25, batch_size=2) + samples = { + 'text': [ + 'hello world 123', # 13 alnum / 15 total = 0.8667 + '!@#$%^&*()', # 0 alnum / 10 total = 0.0 + '', # empty -> 0.0 + ], + Fields.stats: [{}, {}, {}], + } + result = op.compute_stats_batched(samples) + self.assertAlmostEqual( + result[Fields.stats][0][StatsKeys.alnum_ratio], 13 / 15, places=5) + self.assertAlmostEqual( + result[Fields.stats][1][StatsKeys.alnum_ratio], 0.0) + self.assertAlmostEqual( + result[Fields.stats][2][StatsKeys.alnum_ratio], 0.0) + + def test_process_batched_directly(self): + """Directly call process_batched to verify filtering.""" + from data_juicer.utils.constant import Fields, StatsKeys + + op = AlphanumericFilter(tokenization=False, min_ratio=0.25, max_ratio=0.9, batch_size=2) + samples = { + Fields.stats: [ + {StatsKeys.alnum_ratio: 0.8}, # within range -> keep + {StatsKeys.alnum_ratio: 0.1}, # below min -> filtered + {StatsKeys.alnum_ratio: 0.95}, # above max -> filtered + ], + } + keep_flags = list(op.process_batched(samples)) + self.assertEqual(keep_flags, [True, False, False]) + + def test_compute_stats_batched_skips_existing(self): + """Already computed stats are not recomputed.""" + from data_juicer.utils.constant import Fields, StatsKeys + + op = AlphanumericFilter(tokenization=False, min_ratio=0.25, batch_size=2) + samples = { + 'text': ['hello'], + Fields.stats: [{StatsKeys.alnum_ratio: 0.99}], # pre-set + } + result = op.compute_stats_batched(samples) + # Should preserve existing value, not recompute + self.assertAlmostEqual( + result[Fields.stats][0][StatsKeys.alnum_ratio], 0.99) + + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/filter/test_character_repetition_filter.py b/tests/ops/filter/test_character_repetition_filter.py index b122da631de..689a4042dd8 100644 --- a/tests/ops/filter/test_character_repetition_filter.py +++ b/tests/ops/filter/test_character_repetition_filter.py @@ -72,5 +72,74 @@ def test_existing_stats(self): self.assertEqual(dataset_after_compute_stats.to_list(), ds_list) + def test_compute_stats_batched_directly(self): + """Call compute_stats_batched directly to verify stat computation.""" + op = CharacterRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=0.4, batch_size=2) + samples = { + 'text': [ + 'aaaaaaaaaaaaaaaaaaaaaaaaa', # all same char -> high ratio + 'abcdefghijklmnopqrstuvwxyz', # all unique ngrams -> ratio 0 + ], + Fields.stats: [{}, {}], + } + result = op.compute_stats_batched(samples) + # Highly repetitive text should have ratio close to 1.0 + self.assertGreater(result[Fields.stats][0]['char_rep_ratio'], 0.9) + # Unique text should have ratio 0.0 + self.assertAlmostEqual(result[Fields.stats][1]['char_rep_ratio'], 0.0) + + def test_process_batched_directly(self): + """Call process_batched directly to verify filtering logic.""" + op = CharacterRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=0.4, batch_size=2) + samples = { + Fields.stats: [ + {'char_rep_ratio': 0.9}, # exceeds max, filtered out + {'char_rep_ratio': 0.2}, # within range, kept + {'char_rep_ratio': 0.0}, # within range, kept + ], + } + keep_flags = list(op.process_batched(samples)) + self.assertEqual(keep_flags, [False, True, True]) + + def test_compute_stats_batched_short_text(self): + """Text shorter than rep_len produces ratio 0.0.""" + op = CharacterRepetitionFilter(rep_len=10, min_ratio=0.0, max_ratio=0.5, batch_size=2) + samples = { + 'text': ['hi', 'abcde', ''], + Fields.stats: [{}, {}, {}], + } + result = op.compute_stats_batched(samples) + for stat in result[Fields.stats]: + self.assertAlmostEqual(stat['char_rep_ratio'], 0.0) + + def test_compute_stats_batched_skips_existing(self): + """Already computed stats are not overwritten.""" + op = CharacterRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=0.5, batch_size=2) + samples = { + 'text': ['aaaaaaaaaaaaaaaaaaaaaaaaa'], + Fields.stats: [{'char_rep_ratio': 0.42}], # pre-computed + } + result = op.compute_stats_batched(samples) + # Should preserve existing value + self.assertAlmostEqual(result[Fields.stats][0]['char_rep_ratio'], 0.42) + + def test_compute_stats_batched_mixed_repetition(self): + """Verify correct ratio for text with moderate repetition.""" + op = CharacterRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=1.0, batch_size=2) + # "Today is Sund Sund Sund..." has some repeated 5-grams + samples = { + 'text': [ + "Today is Sund Sund Sund Sund Sund Sunda and it's a happy day!", + 'abcdefghijklmnopqrstuvwxyz0123456789', + ], + Fields.stats: [{}, {}], + } + result = op.compute_stats_batched(samples) + # First text has repeated ngrams so ratio > 0 + self.assertGreater(result[Fields.stats][0]['char_rep_ratio'], 0.0) + # Second text has no repetition + self.assertAlmostEqual(result[Fields.stats][1]['char_rep_ratio'], 0.0) + + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/filter/test_suffix_filter.py b/tests/ops/filter/test_suffix_filter.py index fed28594e6c..97a695204e6 100644 --- a/tests/ops/filter/test_suffix_filter.py +++ b/tests/ops/filter/test_suffix_filter.py @@ -83,5 +83,38 @@ def test_none_case(self): self._run_suffix_filter(dataset, tgt_list, op) + def test_string_suffix(self): + """Single string suffix (not list) is handled correctly.""" + ds_list = [ + {'text': 'hello', Fields.suffix: '.txt'}, + {'text': 'world', Fields.suffix: '.py'}, + ] + tgt_list = [{'text': 'hello', Fields.suffix: '.txt'}] + dataset = Dataset.from_list(ds_list) + op = SuffixFilter(suffixes='.txt') + self._run_suffix_filter(dataset, tgt_list, op) + + def test_none_suffixes(self): + """suffixes=None keeps all samples.""" + ds_list = [ + {'text': 'hello', Fields.suffix: '.txt'}, + {'text': 'world', Fields.suffix: '.py'}, + ] + dataset = Dataset.from_list(ds_list) + op = SuffixFilter(suffixes=None) + self._run_suffix_filter(dataset, ds_list, op) + + def test_reversed_range(self): + """reversed_range=True inverts the filter logic.""" + ds_list = [ + {'text': 'hello', Fields.suffix: '.txt'}, + {'text': 'world', Fields.suffix: '.py'}, + ] + tgt_list = [{'text': 'world', Fields.suffix: '.py'}] + dataset = Dataset.from_list(ds_list) + op = SuffixFilter(suffixes=['.txt'], reversed_range=True) + self._run_suffix_filter(dataset, tgt_list, op) + + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/filter/test_word_repetition_filter.py b/tests/ops/filter/test_word_repetition_filter.py index 030189d5637..31206790a2d 100644 --- a/tests/ops/filter/test_word_repetition_filter.py +++ b/tests/ops/filter/test_word_repetition_filter.py @@ -88,5 +88,65 @@ def test_zh_case(self): self._run_word_repetition_filter(dataset, tgt_list, op) + def test_compute_stats_batched_directly(self): + """Directly call compute_stats_batched for non-tokenization path.""" + from data_juicer.utils.constant import Fields, StatsKeys + + op = WordRepetitionFilter(rep_len=2, min_ratio=0.0, max_ratio=0.5, tokenization=False) + # 'hello world' repeated 5 times -> high word 2-gram repetition + text_repetitive = ' '.join(['hello world'] * 5 + ['unique text here']) + text_unique = 'all different words in this sentence no repeats at all' + samples = { + 'text': [text_repetitive, text_unique], + Fields.stats: [{}, {}], + } + result = op.compute_stats_batched(samples) + # Repetitive text should have high ratio + self.assertGreater(result[Fields.stats][0][StatsKeys.word_rep_ratio], 0.5) + # Unique text should have ratio 0.0 + self.assertAlmostEqual(result[Fields.stats][1][StatsKeys.word_rep_ratio], 0.0) + + def test_process_batched_directly(self): + """Directly call process_batched to verify filtering.""" + from data_juicer.utils.constant import Fields, StatsKeys + + op = WordRepetitionFilter(rep_len=2, min_ratio=0.0, max_ratio=0.5, tokenization=False) + samples = { + Fields.stats: [ + {StatsKeys.word_rep_ratio: 0.8}, # exceeds max -> filtered + {StatsKeys.word_rep_ratio: 0.3}, # within range -> kept + {StatsKeys.word_rep_ratio: 0.0}, # within range -> kept + ], + } + keep_flags = list(op.process_batched(samples)) + self.assertEqual(keep_flags, [False, True, True]) + + def test_compute_stats_batched_empty_words(self): + """Text that produces no word n-grams gets ratio 0.0.""" + from data_juicer.utils.constant import Fields, StatsKeys + + op = WordRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=1.0, tokenization=False) + # Very short text with fewer words than rep_len=5 + samples = { + 'text': ['hi there', ''], + Fields.stats: [{}, {}], + } + result = op.compute_stats_batched(samples) + self.assertAlmostEqual(result[Fields.stats][0][StatsKeys.word_rep_ratio], 0.0) + self.assertAlmostEqual(result[Fields.stats][1][StatsKeys.word_rep_ratio], 0.0) + + def test_compute_stats_batched_skips_existing(self): + """Already computed stats are preserved.""" + from data_juicer.utils.constant import Fields, StatsKeys + + op = WordRepetitionFilter(rep_len=2, min_ratio=0.0, max_ratio=1.0, tokenization=False) + samples = { + 'text': ['hello world hello world hello world'], + Fields.stats: [{StatsKeys.word_rep_ratio: 0.42}], + } + result = op.compute_stats_batched(samples) + self.assertAlmostEqual(result[Fields.stats][0][StatsKeys.word_rep_ratio], 0.42) + + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_availability_utils.py b/tests/utils/test_availability_utils.py new file mode 100644 index 00000000000..aa9c23714c0 --- /dev/null +++ b/tests/utils/test_availability_utils.py @@ -0,0 +1,130 @@ +"""Unit tests for data_juicer/utils/availability_utils.py.""" + +import importlib.metadata +import importlib.util +import os +import unittest +from unittest.mock import MagicMock, patch + +from data_juicer.utils.availability_utils import ( + _is_package_available, + _torch_check_and_set, +) +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class IsPackageAvailableTest(DataJuicerTestCaseBase): + """Tests for _is_package_available.""" + + def test_existing_package_returns_true(self): + """A known installed package should return True.""" + result = _is_package_available("datasets") + self.assertTrue(result) + + def test_nonexistent_package_returns_false(self): + """A package that does not exist should return False.""" + result = _is_package_available("nonexistent_package_xyz_12345") + self.assertFalse(result) + + def test_existing_package_with_version(self): + """return_version=True should return (True, version_string).""" + exists, version = _is_package_available("datasets", return_version=True) + self.assertTrue(exists) + self.assertIsInstance(version, str) + self.assertNotEqual(version, "N/A") + + def test_nonexistent_package_with_version(self): + """return_version=True for missing package returns (False, 'N/A').""" + exists, version = _is_package_available( + "nonexistent_package_xyz_12345", return_version=True + ) + self.assertFalse(exists) + self.assertEqual(version, "N/A") + + def test_spec_exists_but_no_metadata(self): + """Package whose spec exists but metadata.version raises PackageNotFoundError.""" + with patch("importlib.util.find_spec", return_value=MagicMock()): + with patch( + "importlib.metadata.version", + side_effect=importlib.metadata.PackageNotFoundError("fake"), + ): + result = _is_package_available("fake_pkg") + self.assertFalse(result) + + def test_spec_exists_but_no_metadata_with_version(self): + """Same as above but with return_version=True.""" + with patch("importlib.util.find_spec", return_value=MagicMock()): + with patch( + "importlib.metadata.version", + side_effect=importlib.metadata.PackageNotFoundError("fake"), + ): + exists, version = _is_package_available("fake_pkg", return_version=True) + self.assertFalse(exists) + self.assertEqual(version, "N/A") + + +class TorchCheckAndSetTest(DataJuicerTestCaseBase): + """Tests for _torch_check_and_set.""" + + def setUp(self): + super().setUp() + # Save and reset global flag + import data_juicer.utils.availability_utils as mod + self._mod = mod + self._original_flag = mod.CHECK_SYSTEM_INFO_ONCE + mod.CHECK_SYSTEM_INFO_ONCE = True + self._original_omp = os.environ.get("OMP_NUM_THREADS") + + def tearDown(self): + self._mod.CHECK_SYSTEM_INFO_ONCE = self._original_flag + if self._original_omp is not None: + os.environ["OMP_NUM_THREADS"] = self._original_omp + elif "OMP_NUM_THREADS" in os.environ: + del os.environ["OMP_NUM_THREADS"] + super().tearDown() + + def test_sets_torch_num_threads(self): + """On a system with torch, _torch_check_and_set calls torch.set_num_threads(1).""" + mock_torch = MagicMock() + with patch.dict("sys.modules", {"torch": mock_torch}): + with patch("importlib.util.find_spec", return_value=MagicMock()): + # Not on Mac py3.8 path + with patch("sys.version_info", (3, 11, 0)): + with patch("platform.system", return_value="Linux"): + _torch_check_and_set() + mock_torch.set_num_threads.assert_called_once_with(1) + + def test_mac_py38_sets_omp_env(self): + """On Mac + Python 3.8, should set OMP_NUM_THREADS=1.""" + mock_torch = MagicMock() + with patch.dict("sys.modules", {"torch": mock_torch}): + with patch("importlib.util.find_spec", return_value=MagicMock()): + with patch( + "data_juicer.utils.availability_utils.sys" + ) as mock_sys: + mock_sys.version_info = (3, 8, 0) + with patch("platform.system", return_value="Darwin"): + self._mod.CHECK_SYSTEM_INFO_ONCE = True + _torch_check_and_set() + self.assertEqual(os.environ.get("OMP_NUM_THREADS"), "1") + self.assertFalse(self._mod.CHECK_SYSTEM_INFO_ONCE) + + def test_no_torch_does_nothing(self): + """When torch is not installed, _torch_check_and_set does nothing.""" + with patch("importlib.util.find_spec", return_value=None): + # Should not raise + _torch_check_and_set() + # Flag unchanged + self.assertTrue(self._mod.CHECK_SYSTEM_INFO_ONCE) + + def test_check_system_info_once_false_skips(self): + """When CHECK_SYSTEM_INFO_ONCE is False, torch is not found, function skips.""" + self._mod.CHECK_SYSTEM_INFO_ONCE = False + with patch("importlib.util.find_spec", return_value=None): + _torch_check_and_set() + # Remains False + self.assertFalse(self._mod.CHECK_SYSTEM_INFO_ONCE) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_availablility_utils.py b/tests/utils/test_availablility_utils.py deleted file mode 100644 index 163e3431d11..00000000000 --- a/tests/utils/test_availablility_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -import unittest - -from data_juicer.utils.availability_utils import _is_package_available -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase - -class AvailabilityUtilsTest(DataJuicerTestCaseBase): - - def test_is_package_available(self): - exist = _is_package_available('fsspec') - self.assertTrue(exist) - exist, version = _is_package_available('fsspec', return_version=True) - self.assertTrue(exist) - self.assertEqual(version, '2023.5.0') - - exist = _is_package_available('non_existing_package') - self.assertFalse(exist) - exist, version = _is_package_available('non_existing_package', return_version=True) - self.assertFalse(exist) - self.assertEqual(version, 'N/A') - - -if __name__ == '__main__': - unittest.main()