diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 9ed7558bf8c..5698ca627cd 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -176,8 +176,16 @@ def init_configs(args: Optional[List[str]] = None, which_entry: object = None, l "--executor_type", type=str, default="default", - choices=["default", "ray", "ray_partitioned"], - help='Type of executor, support "default", "ray", or "ray_partitioned".', + choices=["default", "ray", "ray_partitioned", "elastic_ray"], + help='Type of executor, support "default", "ray", "ray_partitioned", or "elastic_ray".', + ) + parser.add_argument( + "--elastic_juicer", + type=Dict, + default={}, + help="Configuration for ElasticJuicer adaptive scheduling when using elastic_ray executor. " + "Keys: scheduler_preset (conservative/gpu/aggressive), " + "rebalance_interval (float, seconds), config_path (optional YAML path).", ) parser.add_argument( "--dataset_path", diff --git a/data_juicer/core/adapter.py b/data_juicer/core/adapter.py index 17492db6012..16a60f1014a 100644 --- a/data_juicer/core/adapter.py +++ b/data_juicer/core/adapter.py @@ -112,6 +112,11 @@ def adapt_workloads(self, dataset, operators): # calculate batch size for each OP according to the analysis results bs_per_op = self.batch_size_strategy(load_analysis_res, base_bs=probed_batch_size) + # Stash probe results so ProbeAdapter (PR-1) can translate them to + # ElasticJuicer ProfilingStore schema in a subsequent call. + self._last_analysis = load_analysis_res + self._last_probe_batch_sizes = bs_per_op + return bs_per_op @dataset_cache_control(on=True) diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 2d8b198565a..ef353090ce4 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -8,6 +8,7 @@ import ray from jsonargparse import Namespace from loguru import logger +from ray.data import ActorPoolStrategy from ray.data._internal.util import get_compute_strategy from data_juicer.core.data import DJDataset @@ -237,9 +238,16 @@ def process_batch_arrow(table: pyarrow.Table): try: if op.use_ray_actor(): - compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.__class__, + # Repartition right before GPU actor stage to ensure enough data blocks + # Pipeline: Read(streaming) → Repartition → GPU actors + # Note: override_num_blocks cannot be passed to map_batches for actors, + # so we repartition beforehand instead + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + self.data = self.data.repartition(override_num_blocks) + + compute = ActorPoolStrategy(size=op.num_proc) + map_batches_kwargs = dict( fn_args=None, fn_kwargs=None, fn_constructor_args=op._init_args, @@ -251,10 +259,10 @@ def process_batch_arrow(table: pyarrow.Table): batch_format="pyarrow", runtime_env=op.runtime_env, ) + self.data = self.data.map_batches(op.__class__, **map_batches_kwargs) else: compute = get_compute_strategy(op.process, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.process, + map_batches_kwargs = dict( batch_size=batch_size, batch_format="pyarrow", num_cpus=op.num_cpus, @@ -262,6 +270,10 @@ def process_batch_arrow(table: pyarrow.Table): compute=compute, runtime_env=op.runtime_env, ) + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + map_batches_kwargs['override_num_blocks'] = override_num_blocks + self.data = self.data.map_batches(op.process, **map_batches_kwargs) finally: # Restore original process method if tracer and should_trace_op(tracer, op._name) and original_process: @@ -280,9 +292,16 @@ def process_batch_arrow(table: pyarrow.Table): ) cached_columns.add(Fields.stats) if op.use_ray_actor(): - compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.__class__, + # Repartition AFTER CPU preprocessing but BEFORE GPU actor stage + # Pipeline: Read(streaming) → CPU preprocessing(streaming) → Repartition → GPU actors + # This allows Read → CPU to stream freely without pipeline barriers, + # then repartition splits collapsed blocks into enough pieces for GPU utilization + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + self.data = self.data.repartition(override_num_blocks) + + compute = ActorPoolStrategy(size=op.num_proc) + map_batches_kwargs = dict( fn_args=None, fn_kwargs=None, fn_constructor_args=op._init_args, @@ -294,10 +313,10 @@ def process_batch_arrow(table: pyarrow.Table): batch_format="pyarrow", runtime_env=op.runtime_env, ) + self.data = self.data.map_batches(op.__class__, **map_batches_kwargs) else: compute = get_compute_strategy(op.compute_stats, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.compute_stats, + map_batches_kwargs = dict( batch_size=batch_size, batch_format="pyarrow", num_cpus=op.num_cpus, @@ -305,6 +324,10 @@ def process_batch_arrow(table: pyarrow.Table): compute=compute, runtime_env=op.runtime_env, ) + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + map_batches_kwargs['override_num_blocks'] = override_num_blocks + self.data = self.data.map_batches(op.compute_stats, **map_batches_kwargs) if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) # Wrap process method with tracer for sample-level collection diff --git a/data_juicer/core/elasticjuicer/__init__.py b/data_juicer/core/elasticjuicer/__init__.py new file mode 100644 index 00000000000..715b79fbd93 --- /dev/null +++ b/data_juicer/core/elasticjuicer/__init__.py @@ -0,0 +1,33 @@ +""" +ElasticJuicer: Adaptive Resource Scheduling for Data-Juicer + +A system that provides dynamic resource management and OOM prevention for +multimodal data processing pipelines. +""" + +__version__ = "0.1.0" + +# Core ElasticJuicer classes +from .elastic_juicer import ElasticJuicer +from .scheduler.scheduler_config import SchedulerConfig +from .scheduler.tower import Tower +from .scheduler.captain import Captain, CaptainPool +from .scheduler.micro_scheduler import MicroScheduler + + +# Lazy import for tuner (requires ray dependency) +def get_pbt_tuner(): + from .tuner.pbt_tuner import PBTTuner + return PBTTuner + + +__all__ = [ + "profiler", + "ElasticJuicer", + "SchedulerConfig", + "Tower", + "Captain", + "CaptainPool", + "MicroScheduler", + "get_pbt_tuner", +] diff --git a/data_juicer/core/elasticjuicer/elastic_juicer.py b/data_juicer/core/elasticjuicer/elastic_juicer.py new file mode 100644 index 00000000000..bd700a8ee5a --- /dev/null +++ b/data_juicer/core/elasticjuicer/elastic_juicer.py @@ -0,0 +1,780 @@ +""" +ElasticJuicer: Scheduling Facade for Bi-Level Adaptive Scheduling. + +This module provides ElasticJuicer, a facade that manages ALL scheduling +infrastructure: Tower (macro-scheduler), Captains (per-stage micro-schedulers), +MetricsBridge, and Ray named actors for metrics collection and quota distribution. + +Architecture: + ElasticRayExecutor → ElasticJuicer (facade) → Tower, Captain, MetricsBridge, etc. + +The facade lives on the driver (not serializable to Ray actors). Ray workers +(AdaptiveOperator) communicate via Ray named actors (PipelineMetricsCollector, +SharedQuotaStore) created and managed by this facade. + +Usage: + elastic = ElasticJuicer(config=scheduler_config) + elastic.register_stages(stage_configs) + elastic.start() + try: + # ... run pipeline with AdaptiveOperator ... + finally: + elastic.stop() +""" + +import logging +import threading +import time +from dataclasses import asdict +from typing import Any, Callable, Dict, List, Optional + +from .scheduler.scheduler_config import SchedulerConfig +from .scheduler.tower import Tower, ClusterState +from .scheduler.captain import Captain, CaptainConfig, CaptainPool + +logger = logging.getLogger(__name__) + + +def _get_default_cluster_state() -> ClusterState: + """Create a default ClusterState based on current system resources. + + Detects CPU, memory, and GPU resources available on the system. + + Returns: + ClusterState with detected or default resource values. + """ + try: + import psutil + + cpu_count = psutil.cpu_count(logical=True) or 4 + memory_info = psutil.virtual_memory() + total_memory_mb = memory_info.total / (1024 * 1024) + available_memory_mb = memory_info.available / (1024 * 1024) + + # Try to detect GPUs + gpu_count = 0 + try: + import torch + gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + except ImportError: + pass + + return ClusterState( + total_cpu_cores=cpu_count, + total_memory_mb=total_memory_mb, + total_gpu_count=gpu_count, + available_cpu_cores=float(cpu_count), + available_memory_mb=available_memory_mb, + available_gpus=float(gpu_count), + ) + except ImportError: + # Fallback to sensible defaults if psutil not available + return ClusterState( + total_cpu_cores=4, + total_memory_mb=8192.0, + total_gpu_count=0, + available_cpu_cores=4.0, + available_memory_mb=6144.0, + available_gpus=0.0, + ) + + +def _create_pipeline_metrics_collector(): + """Create the PipelineMetricsCollector Ray actor class. + + This actor aggregates metrics from all AdaptiveOperator instances, + providing a centralized view of pipeline performance. + """ + import ray + + @ray.remote + class PipelineMetricsCollector: + """Shared actor to aggregate metrics from all stages.""" + + def __init__(self): + self.stage_metrics: Dict[str, Dict[str, Any]] = {} + + def report( + self, stage_name: str, batch_size: int, latency_ms: float, memory_mb: float + ): + """Report metrics for a batch processed by a stage.""" + import time + if stage_name not in self.stage_metrics: + self.stage_metrics[stage_name] = { + "batch_sizes": [], + "latencies": [], + "memories": [], + "total_samples": 0, + "total_batches": 0, + "start_time": time.time(), + } + m = self.stage_metrics[stage_name] + m["batch_sizes"].append(batch_size) + m["latencies"].append(latency_ms) + m["memories"].append(memory_mb) + m["total_samples"] += batch_size + m["total_batches"] += 1 + m["last_time"] = time.time() + + def get_summary(self) -> Dict[str, Any]: + """Get summary metrics for all stages.""" + result = {} + for stage, data in self.stage_metrics.items(): + bs = data["batch_sizes"] + lat = data["latencies"] + mem = data["memories"] + wall_clock_elapsed = max( + data.get("last_time", 0) - data.get("start_time", 0), 0.001 + ) + result[stage] = { + "total_samples": data["total_samples"], + "total_batches": data["total_batches"], + "min_bs": min(bs) if bs else 0, + "max_bs": max(bs) if bs else 0, + "avg_bs": sum(bs) / len(bs) if bs else 0, + "avg_latency_ms": sum(lat) / len(lat) if lat else 0, + "avg_memory_mb": sum(mem) / len(mem) if mem else 0, + "peak_memory_mb": max(mem) if mem else 0, + "wall_clock_throughput": data["total_samples"] / wall_clock_elapsed, + } + return result + + def reset(self): + """Reset all metrics.""" + self.stage_metrics = {} + + return PipelineMetricsCollector + + +def _create_shared_quota_store(): + """Create the SharedQuotaStore Ray actor class. + + This actor serves as a bridge between driver-side Tower/Captains and + Ray actor-side AdaptiveOperators. The MetricsBridge thread updates + quotas here, and AdaptiveOperators read them periodically. + """ + import ray + + @ray.remote + class SharedQuotaStore: + """Shared store for Tower quotas, readable by AdaptiveOperator actors.""" + + def __init__(self): + self.quotas = {} # {stage_name: {'batch_size': int, 'backpressure': bool, ...}} + + def update_quota(self, stage_name: str, quota_dict: Dict): + """Update quota for a stage.""" + self.quotas[stage_name] = quota_dict + + def get_quota(self, stage_name: str): + """Get quota for a stage.""" + return self.quotas.get(stage_name, None) + + def get_all_quotas(self): + """Get all quotas.""" + return dict(self.quotas) + + return SharedQuotaStore + + +class MetricsBridge(threading.Thread): + """ + Bridge between Ray actor metrics and driver-side Tower/Captains. + + This thread runs on the driver and periodically: + 1. Polls PipelineMetricsCollector for per-stage metrics from actors + 2. Feeds metrics to corresponding Captains (updating their internal state) + 3. Tower's rebalance loop collects from Captains and computes quotas + 4. Reads Captain quotas and pushes to SharedQuotaStore for actors to read + + This bridges the gap between: + - Actor-side: AdaptiveOperator with MicroScheduler reporting to PipelineMetricsCollector + - Driver-side: Tower collecting from Captains and broadcasting quotas + """ + + def __init__( + self, + tower, + captains: Dict[str, Any], + metrics_collector, + quota_store, + interval: float = 2.0, + ): + """ + Initialize MetricsBridge. + + Args: + tower: Tower macro-scheduler instance + captains: Dict mapping stage_name to Captain instance + metrics_collector: PipelineMetricsCollector Ray actor handle + quota_store: SharedQuotaStore Ray actor handle + interval: Bridge cycle interval in seconds + """ + super().__init__(daemon=True, name="MetricsBridge") + self.tower = tower + self.captains = captains + self.metrics_collector = metrics_collector + self.quota_store = quota_store + self.interval = interval + self._running = False + + def run(self): + """Main bridge loop.""" + self._running = True + while self._running: + try: + self._bridge_cycle() + except Exception as e: + logger.debug(f"MetricsBridge cycle error: {e}") + time.sleep(self.interval) + + def stop(self): + """Stop the bridge thread.""" + self._running = False + + def _bridge_cycle(self): + """Execute one bridge cycle.""" + import ray + + # 1. Get metrics from PipelineMetricsCollector (actor-side metrics) + try: + summary = ray.get(self.metrics_collector.get_summary.remote()) + except Exception: + return + + # 2. Feed actor metrics to Captains to update their internal state + for stage_name, captain in self.captains.items(): + stage_data = summary.get(stage_name, {}) + if stage_data: + # Update captain's internal metrics tracking fields + total_batches = stage_data.get("total_batches", 1) + avg_latency_ms = stage_data.get("avg_latency_ms", 0) + total_samples = stage_data.get("total_samples", 0) + + # Calculate throughput: use wall-clock throughput if available + throughput = stage_data.get("wall_clock_throughput", 0) + if throughput <= 0: + # Fallback to old formula + time_sec = (avg_latency_ms * total_batches) / 1000.0 if avg_latency_ms > 0 else 1.0 + throughput = total_samples / max(time_sec, 0.001) + + # Update captain's internal metrics for Tower to collect + captain._recent_throughput = throughput + captain._recent_latency_ms = avg_latency_ms + captain.metrics.throughput = throughput + captain.metrics.avg_latency_ms = avg_latency_ms + + # Update memory utilization on captain + avg_memory_mb = stage_data.get('avg_memory_mb', 0.0) + peak_memory_mb = stage_data.get('peak_memory_mb', 0.0) + if hasattr(captain, '_current_memory_util'): + try: + import psutil + total_mem_mb = psutil.virtual_memory().total / (1024 * 1024) + captain._current_memory_util = (peak_memory_mb / total_mem_mb * 100.0) if total_mem_mb > 0 else 0.0 + except Exception: + pass + if hasattr(captain, 'metrics') and hasattr(captain.metrics, 'memory_utilization'): + captain.metrics.memory_utilization = captain._current_memory_util + + # 3. Push quota updates from Captains to SharedQuotaStore + for stage_name, captain in self.captains.items(): + try: + # Get captain's current state (batch size, backpressure) + current_batch_size = ( + captain.micro_scheduler.controller.current_batch_size + if captain.micro_scheduler + else captain.config.initial_batch_size + ) + + quota_dict = { + "batch_size": current_batch_size, + "backpressure": captain._backpressure_active, + "memory_quota_mb": ( + captain.quota.memory_quota_mb if captain.quota else 0 + ), + } + ray.get(self.quota_store.update_quota.remote(stage_name, quota_dict)) + except Exception: + pass + + +class ElasticJuicer: + """ + Scheduling facade for bi-level adaptive scheduling. + + Manages Tower (macro-scheduler), Captains (per-stage), MetricsBridge, + and Ray named actors for metrics collection and quota distribution. + + The facade lives on the driver (not serializable). Ray workers connect + to the named actors by name to receive quotas and report metrics. + + Usage: + elastic = ElasticJuicer(config=scheduler_config, cluster_state=cluster_state) + elastic.register_stages(stage_configs) + elastic.start() + try: + # ... run pipeline with AdaptiveOperator ... + finally: + elastic.stop() + + Attributes: + config: SchedulerConfig instance. + tower: Tower macro-scheduler instance (None until register_stages called). + is_running: Whether the scheduling system is currently active. + """ + + def __init__( + self, + config: Optional[SchedulerConfig] = None, + config_path: Optional[str] = None, + cluster_state: Optional[ClusterState] = None, + preset: str = 'gpu', + ): + """ + Initialize with config from preset, YAML file, or SchedulerConfig object. + + Args: + config: Pre-existing SchedulerConfig. Takes precedence over other options. + config_path: Path to load config from YAML. Used if config is None. + cluster_state: Optional ClusterState for Tower. If None, auto-detected. + preset: Preset name if no config or config_path provided. + Options: 'conservative', 'gpu', 'aggressive'. Default: 'gpu'. + """ + # Determine SchedulerConfig + if config is not None: + self._config = config + elif config_path is not None: + self._config = SchedulerConfig.from_yaml(config_path) + logger.info(f"Loaded config from {config_path}") + else: + # Use preset + presets = { + 'conservative': SchedulerConfig.conservative, + 'gpu': SchedulerConfig.gpu, + 'aggressive': SchedulerConfig.aggressive, + } + factory = presets.get(preset, SchedulerConfig.gpu) + self._config = factory() + logger.info(f"Using ElasticJuicer {preset} preset config") + + # Detect cluster state + self._cluster_state = cluster_state or _get_default_cluster_state() + + # Components (initialized in register_stages / start) + self._tower: Optional[Tower] = None + self._captains: Dict[str, Captain] = {} + self._captain_pool: Optional[CaptainPool] = None + self._captain_ids: Dict[str, str] = {} # stage_name -> captain_id mapping + self._metrics_bridge: Optional[MetricsBridge] = None + self._metrics_collector = None # Ray named actor + self._quota_store = None # Ray named actor + self._is_running = False + self._stage_names: List[str] = [] + + def register_stages( + self, + stage_configs: List[Dict[str, Any]], + ) -> None: + """ + Register operator stages for scheduling. + + Creates Tower, Captains, and registers all stages. Must be called + before start(). + + Args: + stage_configs: List of dicts with keys: + - 'name': str (stage identifier, e.g. 'stage_0_video_aesthetics_filter') + - 'batch_size': int (initial batch size, optional) + - 'num_gpus': float (GPU requirement, 0 for CPU-only, optional) + - 'num_actors': int (actor pool size, optional) + """ + # Create Tower macro-scheduler + self._tower = Tower( + cluster_state=self._cluster_state, + target_queue_depth=100, + sla_latency_ms=5000.0, + update_interval_sec=self._config.rebalance_interval_sec, + config=self._config, + ) + + self._stage_names = [] + self._captains = {} + self._captain_ids = {} + + for sc in stage_configs: + stage_name = sc['name'] + batch_size = sc.get('batch_size', self._config.initial_batch_size) + num_actors = sc.get('num_actors', 1) + + # Register stage with Tower (returns captain_id) + captain_id = self._tower.register_stage( + stage_name=stage_name, initial_parallelism=num_actors + ) + self._captain_ids[stage_name] = captain_id + + # Create Captain config + captain_config = CaptainConfig( + stage_name=stage_name, + initial_batch_size=batch_size, + enable_micro_scheduler=self._config.enable_auto_adjust, + enable_prediction=self._config.enable_prediction, + ) + + # Create Captain instance + captain = Captain(config=captain_config) + self._captains[stage_name] = captain + + # Register captain with Tower for metrics collection and quota broadcast + self._tower.register_captain(captain_id, captain) + + self._stage_names.append(stage_name) + + logger.info(f"ElasticJuicer: Registered {len(stage_configs)} stages with Tower") + + def start(self) -> None: + """ + Start the scheduling system. + + Creates Ray named actors (PipelineMetricsCollector, SharedQuotaStore), + starts Tower rebalance loop, and starts MetricsBridge thread. + + Must call register_stages() before this method. + """ + if self._is_running: + logger.warning("ElasticJuicer is already running") + return + + if self._tower is None: + raise RuntimeError( + "Must call register_stages() before start(). " + "No stages have been registered." + ) + + import ray + + # Create PipelineMetricsCollector named actor + try: + collector_cls = _create_pipeline_metrics_collector() + self._metrics_collector = collector_cls.options( + name="elastic_pipeline_metrics", get_if_exists=True + ).remote() + logger.info("PipelineMetricsCollector named actor created") + except Exception as e: + logger.warning(f"Failed to create PipelineMetricsCollector: {e}") + self._metrics_collector = None + + # Create SharedQuotaStore named actor + try: + quota_cls = _create_shared_quota_store() + self._quota_store = quota_cls.options( + name="elastic_quota_store", get_if_exists=True + ).remote() + logger.info("SharedQuotaStore named actor created") + except Exception as e: + logger.warning(f"Failed to create SharedQuotaStore: {e}") + self._quota_store = None + + # Start Tower rebalance loop + self._tower.start() + logger.info("Tower rebalance loop started") + + # Start MetricsBridge + if self._captains and self._metrics_collector is not None and self._quota_store is not None: + self._metrics_bridge = MetricsBridge( + tower=self._tower, + captains=self._captains, + metrics_collector=self._metrics_collector, + quota_store=self._quota_store, + interval=self._config.rebalance_interval_sec, + ) + self._metrics_bridge.start() + logger.info("MetricsBridge started - connecting actors ↔ Captains ↔ Tower") + + self._is_running = True + logger.info("ElasticJuicer: Started (Tower + MetricsBridge + Ray actors)") + + def stop(self) -> None: + """ + Stop all scheduling components. + + Stops MetricsBridge thread, Tower rebalance loop, and cleans up + Ray named actors. Safe to call even if not running. + """ + if not self._is_running: + return + + # Stop MetricsBridge + if self._metrics_bridge is not None: + try: + self._metrics_bridge.stop() + self._metrics_bridge.join(timeout=5) + logger.info("MetricsBridge stopped") + except Exception: + pass + self._metrics_bridge = None + + # Stop Tower + if self._tower is not None: + try: + self._tower.stop() + logger.info("Tower rebalance loop stopped") + except Exception: + pass + + # Cleanup Ray actors + import ray + for actor in [self._metrics_collector, self._quota_store]: + if actor is not None: + try: + ray.kill(actor) + except Exception: + pass + self._metrics_collector = None + self._quota_store = None + + self._is_running = False + logger.info("ElasticJuicer: Stopped") + + def get_adaptive_op_config(self, stage_name: str) -> Dict[str, Any]: + """ + Get configuration dict for AdaptiveOperator constructor. + + This is what gets passed to fn_constructor_kwargs in map_batches(). + AdaptiveOperator uses this to connect to the named actors. + + Args: + stage_name: Name of the stage to get config for. + + Returns: + Dict with: stage_name, initial_batch_size, scheduler_config_dict + """ + captain = self._captains.get(stage_name) + batch_size = captain.config.initial_batch_size if captain else self._config.initial_batch_size + + return { + 'stage_name': stage_name, + 'initial_batch_size': batch_size, + 'scheduler_config_dict': asdict(self._config), + } + + def get_metrics_summary(self) -> Dict[str, Any]: + """ + Get per-stage metrics from PipelineMetricsCollector. + + Returns: + Dict mapping stage_name to metrics dict with: + total_samples, total_batches, min_bs, max_bs, avg_bs, + avg_latency_ms, avg_memory_mb, peak_memory_mb + """ + if self._metrics_collector is None: + return {} + try: + import ray + return ray.get(self._metrics_collector.get_summary.remote()) + except Exception: + return {} + + def get_captain_stats(self) -> Dict[str, Any]: + """ + Get per-stage Captain statistics. + + Returns: + Dict mapping stage_name to stats dict with: + throughput, latency_ms, batch_size, backpressure, oom_count + """ + stats = {} + for name, captain in self._captains.items(): + try: + metrics = captain.collect_metrics() + stats[name] = { + 'throughput': getattr(metrics, 'throughput', 0), + 'latency_ms': getattr(metrics, 'avg_latency_ms', 0), + 'batch_size': ( + captain.micro_scheduler.controller.current_batch_size + if captain.micro_scheduler + else captain.config.initial_batch_size + ), + 'backpressure': captain._backpressure_active if hasattr(captain, '_backpressure_active') else False, + 'oom_count': captain._total_oom_count if hasattr(captain, '_total_oom_count') else 0, + } + except Exception: + stats[name] = {} + return stats + + def get_tower_stats(self) -> Dict[str, Any]: + """ + Get Tower global statistics. + + Returns: + Dict with global stats: total_stages, total_parallelism, + sla_compliance_rate, total_requests, sla_violations, etc. + """ + if self._tower is None: + return {} + try: + return self._tower.get_global_stats() + except Exception: + return {} + + def get_status(self) -> Dict[str, Any]: + """ + Get complete system status. + + Returns: + Dict containing: is_running, stages, config_preset, + rebalance_interval, metrics, captains, tower + """ + return { + 'is_running': self._is_running, + 'stages': self._stage_names, + 'rebalance_interval': self._config.rebalance_interval_sec, + 'metrics': self.get_metrics_summary(), + 'captains': self.get_captain_stats(), + 'tower': self.get_tower_stats(), + } + + # ---- Offline tuning ---- + + def run_offline_tuning( + self, + stage_names: Optional[List[str]] = None, + simulation_fn: Optional[Callable[[SchedulerConfig], Dict[str, float]]] = None, + num_samples: int = 8, + max_iterations: int = 50, + export_path: str = "base_config.yaml", + ) -> SchedulerConfig: + """ + Run PBT offline tuning to find optimal SchedulerConfig. + + Uses Population Based Training to optimize PID controller parameters, + memory safety buffers, predictor settings, and per-stage resource + allocation weights. + + Args: + stage_names: Operator stage names to tune allocation weights for. + If None, uses self._stage_names or an empty list. + simulation_fn: Custom simulation function. If None, uses default. + num_samples: PBT population size (number of parallel trials). + max_iterations: Maximum training iterations per trial. + export_path: Path to save the tuned config YAML. + + Returns: + Optimized SchedulerConfig. + + Raises: + ImportError: If Ray Tune is not installed. + """ + from .tuner.pbt_tuner import PBTTuner, PBTTunerConfig + + tuner_config = PBTTunerConfig( + num_samples=num_samples, + max_iterations=max_iterations, + stage_names=stage_names or self._stage_names or [], + ) + + tuner = PBTTuner(config=tuner_config, simulation_fn=simulation_fn) + + logger.info( + f"Starting offline PBT tuning with {num_samples} samples, " + f"{max_iterations} iterations" + ) + + self._config = tuner.tune() + tuner.export_config(self._config, export_path) + + logger.info(f"Offline tuning complete. Config saved to {export_path}") + return self._config + + # ---- Properties ---- + + @property + def config(self) -> SchedulerConfig: + """Get the current SchedulerConfig.""" + return self._config + + @property + def tower(self) -> Optional[Tower]: + """Get the Tower macro-scheduler instance (None if not registered).""" + return self._tower + + @property + def captains(self) -> Dict[str, Captain]: + """Get all Captain instances mapped by stage name.""" + return self._captains + + @property + def is_running(self) -> bool: + """Check if the scheduling system is currently running.""" + return self._is_running + + @property + def stage_names(self) -> List[str]: + """Get list of registered stage names.""" + return self._stage_names.copy() + + # ---- Captain access ---- + + def get_captain(self, stage_name: str) -> Optional[Captain]: + """ + Get the Captain for a specific stage. + + Args: + stage_name: Name of the operator stage. + + Returns: + Captain instance for the stage, or None if not found. + """ + return self._captains.get(stage_name) + + # ---- Config management ---- + + def update_config(self, **kwargs) -> None: + """ + Update configuration parameters dynamically. + + Note: For runtime changes to take effect on Tower and Captains, + the system may need to be restarted. + + Args: + **kwargs: Configuration parameters to update. See SchedulerConfig. + """ + from dataclasses import fields + + valid_fields = {f.name for f in fields(SchedulerConfig)} + + for key, value in kwargs.items(): + if key not in valid_fields: + logger.warning(f"Unknown config parameter ignored: {key}") + continue + setattr(self._config, key, value) + + logger.info(f"Config updated with: {kwargs}") + + def save_config(self, path: str) -> None: + """ + Save current configuration to a YAML file. + + Args: + path: Output file path for the YAML config. + """ + self._config.to_yaml(path) + logger.info(f"Config saved to {path}") + + # ---- Context manager support ---- + + def __enter__(self) -> 'ElasticJuicer': + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + """Context manager exit - ensures graceful shutdown.""" + self.stop() + return False + + def __repr__(self) -> str: + """String representation of ElasticJuicer instance.""" + return ( + f"ElasticJuicer(is_running={self._is_running}, " + f"stages={self._stage_names})" + ) diff --git a/data_juicer/core/elasticjuicer/predictor/__init__.py b/data_juicer/core/elasticjuicer/predictor/__init__.py new file mode 100644 index 00000000000..2faac3c81b1 --- /dev/null +++ b/data_juicer/core/elasticjuicer/predictor/__init__.py @@ -0,0 +1,18 @@ +""" +Memory Prediction Module + +Provides: +- Online learning models for memory prediction +- Feature extraction from data samples +- Prediction with confidence intervals +- Safety margin calculations +""" + +from .memory_predictor import MemoryPredictor, PredictionResult +from .feature_extractor import FeatureExtractor + +__all__ = [ + "MemoryPredictor", + "PredictionResult", + "FeatureExtractor", +] diff --git a/data_juicer/core/elasticjuicer/predictor/feature_extractor.py b/data_juicer/core/elasticjuicer/predictor/feature_extractor.py new file mode 100644 index 00000000000..10438ad87bc --- /dev/null +++ b/data_juicer/core/elasticjuicer/predictor/feature_extractor.py @@ -0,0 +1,294 @@ +""" +Feature Extractor for Memory Prediction + +Extracts relevant features from data samples to predict memory usage: +- Text: length, num_tokens, special_chars +- Image: width, height, channels, format +- Video: resolution, frame_count, fps, duration +- Audio: sample_rate, duration, channels + +Based on Report Section 3.3 - Prediction Model +""" + +from typing import Dict, Any, List, Optional +from dataclasses import dataclass +import re + + +@dataclass +class SampleFeatures: + """Features extracted from a single sample""" + # Common features + batch_size: int = 1 + modality: str = "text" # text, image, video, audio, multimodal + + # Text features + text_length: Optional[int] = None + num_tokens: Optional[int] = None + + # Image features + image_width: Optional[int] = None + image_height: Optional[int] = None + image_channels: Optional[int] = None + num_images: Optional[int] = 0 + + # Video features + video_width: Optional[int] = None + video_height: Optional[int] = None + frame_count: Optional[int] = None + fps: Optional[float] = None + num_videos: Optional[int] = 0 + + # Audio features + audio_sample_rate: Optional[int] = None + audio_duration: Optional[float] = None + num_audios: Optional[int] = 0 + + # Derived features + total_pixels: Optional[int] = None # For images/videos + estimated_size_mb: Optional[float] = None # Rough size estimate + + def to_feature_vector(self) -> List[float]: + """Convert to numerical feature vector for ML models""" + features = [ + float(self.batch_size), + # Text + float(self.text_length or 0), + float(self.num_tokens or 0), + # Image + float(self.image_width or 0), + float(self.image_height or 0), + float(self.image_channels or 0), + float(self.num_images or 0), + # Video + float(self.video_width or 0), + float(self.video_height or 0), + float(self.frame_count or 0), + float(self.fps or 0), + float(self.num_videos or 0), + # Audio + float(self.audio_sample_rate or 0), + float(self.audio_duration or 0), + float(self.num_audios or 0), + # Derived + float(self.total_pixels or 0), + float(self.estimated_size_mb or 0), + ] + return features + + @staticmethod + def feature_names() -> List[str]: + """Get names of features in the vector""" + return [ + 'batch_size', + 'text_length', 'num_tokens', + 'image_width', 'image_height', 'image_channels', 'num_images', + 'video_width', 'video_height', 'frame_count', 'fps', 'num_videos', + 'audio_sample_rate', 'audio_duration', 'num_audios', + 'total_pixels', 'estimated_size_mb' + ] + + +class FeatureExtractor: + """ + Extracts memory-relevant features from Data-Juicer samples. + + Handles different modalities and data formats. + """ + + def __init__(self): + pass + + def extract_from_sample(self, sample: Dict[str, Any]) -> SampleFeatures: + """ + Extract features from a single sample. + + Args: + sample: Data-Juicer sample dictionary + + Returns: + SampleFeatures object + """ + features = SampleFeatures(batch_size=1) + + # Determine modality + has_text = bool('text' in sample and sample['text']) + has_images = bool('images' in sample and sample['images']) + has_videos = bool('videos' in sample and sample['videos']) + has_audios = bool('audios' in sample and sample['audios']) + + modality_count = sum([has_text, has_images, has_videos, has_audios]) + if modality_count > 1: + features.modality = "multimodal" + elif has_text: + features.modality = "text" + elif has_images: + features.modality = "image" + elif has_videos: + features.modality = "video" + elif has_audios: + features.modality = "audio" + + # Extract text features + if has_text: + text = sample['text'] + features.text_length = len(text) + # Simple tokenization (space-based) + features.num_tokens = len(text.split()) + + # Extract image features + if has_images: + images = sample['images'] + features.num_images = len(images) if isinstance(images, list) else 1 + # Try to get image metadata if available + if 'image_metadata' in sample: + meta = sample['image_metadata'] + if isinstance(meta, list) and meta: + meta = meta[0] # Use first image + features.image_width = meta.get('width') + features.image_height = meta.get('height') + features.image_channels = meta.get('channels', 3) + + if features.image_width and features.image_height: + features.total_pixels = features.image_width * features.image_height * features.num_images + + # Extract video features + if has_videos: + videos = sample['videos'] + features.num_videos = len(videos) if isinstance(videos, list) else 1 + # Try to get video metadata + if 'video_metadata' in sample: + meta = sample['video_metadata'] + if isinstance(meta, list) and meta: + meta = meta[0] # Use first video + features.video_width = meta.get('width') + features.video_height = meta.get('height') + features.frame_count = meta.get('frame_count') + features.fps = meta.get('fps') + + if features.video_width and features.video_height and features.frame_count: + features.total_pixels = (features.video_width * features.video_height * + features.frame_count * features.num_videos) + + # Extract audio features + if has_audios: + audios = sample['audios'] + features.num_audios = len(audios) if isinstance(audios, list) else 1 + if 'audio_metadata' in sample: + meta = sample['audio_metadata'] + if isinstance(meta, list) and meta: + meta = meta[0] + features.audio_sample_rate = meta.get('sample_rate') + features.audio_duration = meta.get('duration') + + # Estimate rough size in MB + features.estimated_size_mb = self._estimate_size(features) + + return features + + def extract_from_batch(self, batch: Dict[str, Any]) -> SampleFeatures: + """ + Extract features from a batched sample. + + Args: + batch: Batched data dictionary where values are lists + + Returns: + SampleFeatures object (aggregated) + """ + # Determine batch size + batch_size = 0 + for value in batch.values(): + if isinstance(value, list): + batch_size = len(value) + break + + if batch_size == 0: + # Not a batched format, treat as single + return self.extract_from_sample(batch) + + # Extract features from first sample and scale + first_sample = { + key: values[0] if isinstance(values, list) and values else values + for key, values in batch.items() + } + + features = self.extract_from_sample(first_sample) + features.batch_size = batch_size + + # Scale certain features + if features.estimated_size_mb: + features.estimated_size_mb *= batch_size + + return features + + def _estimate_size(self, features: SampleFeatures) -> float: + """ + Rough estimate of sample size in MB. + + This is a heuristic based on typical data sizes. + """ + size_mb = 0.0 + + # Text: ~1 byte per character + if features.text_length: + size_mb += features.text_length / (1024 * 1024) + + # Images: width * height * channels * bytes_per_pixel (typically 1-4) + if features.total_pixels and features.modality in ['image', 'multimodal']: + # Assume 3 bytes per pixel for RGB + size_mb += (features.total_pixels * 3) / (1024 * 1024) + + # Videos: similar but multiplied by frames + if features.total_pixels and features.modality == 'video': + # Videos in memory are often decoded to raw frames + size_mb += (features.total_pixels * 3) / (1024 * 1024) + + # Audio: sample_rate * duration * channels * bytes_per_sample + if features.audio_sample_rate and features.audio_duration: + # Assume 2 bytes per sample (16-bit), mono or stereo + channels = 2 + bytes_per_sample = 2 + size_mb += (features.audio_sample_rate * features.audio_duration * + channels * bytes_per_sample) / (1024 * 1024) + + return size_mb + + def analyze_batch_variance(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """ + Analyze variance in a batch to detect skew. + + High variance indicates need for dynamic batching. + """ + if not any(isinstance(v, list) for v in batch.values()): + return {'variance': 0, 'requires_dynamic_batching': False} + + # Extract features for each sample in batch + batch_size = len(batch[next(iter(batch))]) + sizes = [] + + for i in range(batch_size): + sample = {k: (v[i] if isinstance(v, list) else v) for k, v in batch.items()} + features = self.extract_from_sample(sample) + if features.estimated_size_mb: + sizes.append(features.estimated_size_mb) + + if not sizes: + return {'variance': 0, 'requires_dynamic_batching': False} + + import numpy as np + variance = float(np.var(sizes)) + mean_size = float(np.mean(sizes)) + coef_variation = variance / mean_size if mean_size > 0 else 0 + + # High coefficient of variation suggests dynamic batching + requires_dynamic = coef_variation > 0.5 + + return { + 'variance': variance, + 'mean_size_mb': mean_size, + 'min_size_mb': float(np.min(sizes)), + 'max_size_mb': float(np.max(sizes)), + 'coef_variation': coef_variation, + 'requires_dynamic_batching': requires_dynamic, + } diff --git a/data_juicer/core/elasticjuicer/predictor/memory_predictor.py b/data_juicer/core/elasticjuicer/predictor/memory_predictor.py new file mode 100644 index 00000000000..81d8e585e1a --- /dev/null +++ b/data_juicer/core/elasticjuicer/predictor/memory_predictor.py @@ -0,0 +1,301 @@ +""" +Memory Predictor with Online Learning + +Predicts memory usage for operators based on sample features. +Uses online learning to adapt to changing data distributions. + +Based on: +- Autothrottle: Online learning for resource prediction +- Report Section 3.3: Prediction Model +""" + +from typing import Optional, List, Tuple +from dataclasses import dataclass +import numpy as np +from collections import deque + +from .feature_extractor import SampleFeatures, FeatureExtractor + + +@dataclass +class PredictionResult: + """Result of memory prediction""" + predicted_memory_mb: float + confidence_lower: float # Lower bound of confidence interval + confidence_upper: float # Upper bound of confidence interval + prediction_error_history: Optional[float] = None # Recent prediction error + + def get_safe_prediction(self, safety_margin: float = 0.9) -> float: + """ + Get conservative prediction with safety margin. + + Uses upper confidence bound to be safe. + """ + return self.confidence_upper / safety_margin + + +class MemoryPredictor: + """ + Online learning model for memory prediction. + + Features: + - Incremental learning from new observations + - Confidence intervals for predictions + - Automatic model retraining + - Handles different operator types + """ + + def __init__( + self, + op_name: str, + window_size: int = 100, + confidence_level: float = 0.95, + min_samples_for_prediction: int = 5, + ): + """ + Initialize memory predictor. + + Args: + op_name: Operator name + window_size: Number of recent samples to keep for online learning + confidence_level: Confidence level for prediction intervals (default 95%) + min_samples_for_prediction: Minimum samples needed before making predictions + """ + self.op_name = op_name + self.window_size = window_size + self.confidence_level = confidence_level + self.min_samples_for_prediction = min_samples_for_prediction + + # Online learning data + self.feature_history = deque(maxlen=window_size) + self.memory_history = deque(maxlen=window_size) + self.error_history = deque(maxlen=window_size) + + # Model parameters (online linear regression) + self.weights: Optional[np.ndarray] = None + self.intercept: float = 0.0 + + # Feature extractor + self.feature_extractor = FeatureExtractor() + + # Statistics + self.total_predictions = 0 + self.total_updates = 0 + + def observe(self, features: SampleFeatures, actual_memory_mb: float): + """ + Observe a new data point and update the model. + + This is the core of online learning - model adapts as new data arrives. + + Args: + features: Sample features + actual_memory_mb: Actual memory used + """ + feature_vec = np.array(features.to_feature_vector()) + + # Store observation + self.feature_history.append(feature_vec) + self.memory_history.append(actual_memory_mb) + self.total_updates += 1 + + # Calculate prediction error if we had a model + if self.weights is not None: + predicted = self._predict_from_vector(feature_vec) + error = abs(predicted - actual_memory_mb) + self.error_history.append(error) + + # Retrain model if we have enough samples + if len(self.feature_history) >= self.min_samples_for_prediction: + self._retrain_model() + + def predict(self, features: SampleFeatures) -> Optional[PredictionResult]: + """ + Predict memory usage for given features. + + Args: + features: Sample features + + Returns: + PredictionResult with prediction and confidence bounds, or None if not enough data + """ + if len(self.feature_history) < self.min_samples_for_prediction: + return None + + if self.weights is None: + return None + + feature_vec = np.array(features.to_feature_vector()) + predicted = self._predict_from_vector(feature_vec) + + # Calculate confidence interval based on recent errors + if self.error_history: + # Use standard deviation of recent errors + std_error = np.std(list(self.error_history)) + # For 95% confidence, use ~2 standard deviations + z_score = 1.96 if self.confidence_level == 0.95 else 2.58 + margin = z_score * std_error + + confidence_lower = max(0, predicted - margin) + confidence_upper = predicted + margin + avg_error = np.mean(list(self.error_history)) + else: + # No error history yet, use conservative estimate + confidence_lower = predicted * 0.8 + confidence_upper = predicted * 1.5 + avg_error = None + + self.total_predictions += 1 + + return PredictionResult( + predicted_memory_mb=predicted, + confidence_lower=confidence_lower, + confidence_upper=confidence_upper, + prediction_error_history=avg_error, + ) + + def predict_batch_memory( + self, + sample_features: SampleFeatures, + target_batch_size: int, + ) -> Optional[PredictionResult]: + """ + Predict memory for a specific batch size. + + Scales the prediction based on batch size. + """ + # Scale features to target batch size + scaled_features = SampleFeatures(**vars(sample_features)) + scale_factor = target_batch_size / sample_features.batch_size + + scaled_features.batch_size = target_batch_size + if scaled_features.estimated_size_mb: + scaled_features.estimated_size_mb *= scale_factor + + return self.predict(scaled_features) + + def recommend_batch_size( + self, + sample_features: SampleFeatures, + available_memory_mb: float, + safety_margin: float = 0.85, + ) -> int: + """ + Recommend safe batch size given available memory. + + Uses binary search to find maximum safe batch size. + + Args: + sample_features: Features of a single sample + available_memory_mb: Available memory in MB + safety_margin: Use this fraction of available memory (default 85%) + + Returns: + Recommended batch size + """ + target_memory = available_memory_mb * safety_margin + + # Binary search for optimal batch size + low, high = 1, 1000 + best_batch_size = 1 + + for _ in range(20): # Max 20 iterations + mid = (low + high) // 2 + prediction = self.predict_batch_memory(sample_features, mid) + + if prediction is None: + # Not enough data, return conservative estimate + return 1 + + predicted_mem = prediction.get_safe_prediction(safety_margin) + + if predicted_mem <= target_memory: + best_batch_size = mid + low = mid + 1 + else: + high = mid - 1 + + return max(1, best_batch_size) + + def _predict_from_vector(self, feature_vec: np.ndarray) -> float: + """Make prediction from feature vector""" + if self.weights is None: + return 0.0 + + prediction = np.dot(feature_vec, self.weights) + self.intercept + return max(0, prediction) # Memory can't be negative + + def _retrain_model(self): + """ + Retrain the model using recent observations. + + Uses online linear regression for efficiency. + """ + if len(self.feature_history) < self.min_samples_for_prediction: + return + + # Convert to arrays + X = np.array(list(self.feature_history)) + y = np.array(list(self.memory_history)) + + try: + # Add regularization to prevent overfitting + lambda_reg = 0.01 + n_features = X.shape[1] + + # Ridge regression: (X^T X + λI)^-1 X^T y + XtX = X.T @ X + Xty = X.T @ y + + # Add regularization + XtX_reg = XtX + lambda_reg * np.eye(n_features) + + # Solve for weights + self.weights = np.linalg.solve(XtX_reg, Xty) + + # Calculate intercept (for better fit) + self.intercept = np.mean(y - X @ self.weights) + + except np.linalg.LinAlgError: + # Singular matrix, fall back to simple mean + self.weights = np.zeros(X.shape[1]) + self.intercept = np.mean(y) + + def get_model_stats(self) -> dict: + """Get statistics about the model""" + stats = { + 'op_name': self.op_name, + 'total_updates': self.total_updates, + 'total_predictions': self.total_predictions, + 'samples_in_window': len(self.feature_history), + 'model_trained': self.weights is not None, + } + + if self.error_history: + stats['avg_prediction_error_mb'] = float(np.mean(list(self.error_history))) + stats['std_prediction_error_mb'] = float(np.std(list(self.error_history))) + + if self.memory_history: + stats['avg_memory_mb'] = float(np.mean(list(self.memory_history))) + stats['peak_memory_mb'] = float(np.max(list(self.memory_history))) + + return stats + + def export_model(self) -> dict: + """Export model parameters for serialization""" + return { + 'op_name': self.op_name, + 'weights': self.weights.tolist() if self.weights is not None else None, + 'intercept': self.intercept, + 'window_size': self.window_size, + 'total_updates': self.total_updates, + 'stats': self.get_model_stats(), + } + + def import_model(self, model_data: dict): + """Import model parameters""" + self.op_name = model_data['op_name'] + if model_data['weights'] is not None: + self.weights = np.array(model_data['weights']) + self.intercept = model_data['intercept'] + self.total_updates = model_data.get('total_updates', 0) diff --git a/data_juicer/core/elasticjuicer/profiler/__init__.py b/data_juicer/core/elasticjuicer/profiler/__init__.py new file mode 100644 index 00000000000..52d7f8aedd3 --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/__init__.py @@ -0,0 +1,20 @@ +""" +Resource Profiling Module + +Provides: +- Lightweight resource monitoring for operators +- Operator Cost Signature (OCS) annotations +- Resource-throughput curve fitting +""" + +from .resource_monitor import ResourceMonitor, MonitoredOp +from .ocs_annotator import OCSAnnotator, OpCostSignature +from .profiling_store import ProfilingStore + +__all__ = [ + "ResourceMonitor", + "MonitoredOp", + "OCSAnnotator", + "OpCostSignature", + "ProfilingStore", +] diff --git a/data_juicer/core/elasticjuicer/profiler/ocs_annotator.py b/data_juicer/core/elasticjuicer/profiler/ocs_annotator.py new file mode 100644 index 00000000000..25621db4c30 --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/ocs_annotator.py @@ -0,0 +1,287 @@ +""" +Operator Cost Signature (OCS) Annotator + +Provides semantic annotations for operators based on Alpa's operator modeling: +- Memory Locality: Device preference (CPU-Strong, GPU-Strong, Balanced) +- Transfer Cost: Data movement overhead (Low, Medium, High) +- Failure Cost: Recovery cost from OOM (Low, Medium, High) +- State-free: Whether operator can be safely retried + +Inspired by: +- Alpa's operator cost modeling +- ExoFlow's failure cost analysis +""" + +from enum import Enum +from dataclasses import dataclass, field +from typing import Dict, Optional, List +import json + + +class MemoryLocality(Enum): + """Device preference for operator execution""" + CPU_STRONG = "cpu_strong" # Strongly prefers CPU (e.g., regex, text filters) + GPU_STRONG = "gpu_strong" # Strongly prefers GPU (e.g., VLM, video decoding) + BALANCED = "balanced" # Can run efficiently on either + MIXED = "mixed" # Benefits from CPU-GPU cooperation + + +class TransferCost(Enum): + """Data movement overhead""" + LOW = "low" # < 1MB per sample (text, metadata) + MEDIUM = "medium" # 1-100MB per sample (images) + HIGH = "high" # > 100MB per sample (videos, large models) + + +class FailureCost(Enum): + """Recovery cost from failure""" + LOW = "low" # Fast retry, no state loss + MEDIUM = "medium" # Moderate retry cost + HIGH = "high" # Expensive recomputation (e.g., long video processing) + + +@dataclass +class OpCostSignature: + """ + Cost signature for an operator. + + This is the core of OCS profiling - semantic annotations that guide scheduling. + """ + op_name: str + op_type: str # filter, mapper, deduplicator, etc. + + # Core OCS attributes (based on Alpa) + memory_locality: MemoryLocality = MemoryLocality.BALANCED + transfer_cost: TransferCost = TransferCost.MEDIUM + failure_cost: FailureCost = FailureCost.MEDIUM + + # State properties (based on ExoFlow) + state_free: bool = True # Can be safely retried without side effects + deterministic: bool = True # Same input always produces same output + + # Resource preferences + preferred_batch_size: Optional[int] = None + min_memory_mb: Optional[float] = None + max_memory_mb: Optional[float] = None + + # Modality tags + handles_text: bool = False + handles_image: bool = False + handles_video: bool = False + handles_audio: bool = False + + # Additional metadata + notes: str = "" + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + 'op_name': self.op_name, + 'op_type': self.op_type, + 'memory_locality': self.memory_locality.value, + 'transfer_cost': self.transfer_cost.value, + 'failure_cost': self.failure_cost.value, + 'state_free': self.state_free, + 'deterministic': self.deterministic, + 'preferred_batch_size': self.preferred_batch_size, + 'min_memory_mb': self.min_memory_mb, + 'max_memory_mb': self.max_memory_mb, + 'handles_text': self.handles_text, + 'handles_image': self.handles_image, + 'handles_video': self.handles_video, + 'handles_audio': self.handles_audio, + 'notes': self.notes, + } + + @classmethod + def from_dict(cls, data: Dict) -> 'OpCostSignature': + """Create from dictionary""" + return cls( + op_name=data['op_name'], + op_type=data['op_type'], + memory_locality=MemoryLocality(data.get('memory_locality', 'balanced')), + transfer_cost=TransferCost(data.get('transfer_cost', 'medium')), + failure_cost=FailureCost(data.get('failure_cost', 'medium')), + state_free=data.get('state_free', True), + deterministic=data.get('deterministic', True), + preferred_batch_size=data.get('preferred_batch_size'), + min_memory_mb=data.get('min_memory_mb'), + max_memory_mb=data.get('max_memory_mb'), + handles_text=data.get('handles_text', False), + handles_image=data.get('handles_image', False), + handles_video=data.get('handles_video', False), + handles_audio=data.get('handles_audio', False), + notes=data.get('notes', ''), + ) + + +class OCSAnnotator: + """ + Annotates operators with cost signatures. + + Provides pre-defined annotations for common Data-Juicer operators + and supports custom annotations. + """ + + def __init__(self): + self.signatures: Dict[str, OpCostSignature] = {} + self._load_default_signatures() + + def _load_default_signatures(self): + """Load default OCS signatures for common Data-Juicer operators""" + + # Text Filters - CPU Strong, Low Transfer, Low Failure + text_filter_ops = [ + 'TextLengthFilter', + 'AlphanumericFilter', + 'CharacterRepetitionFilter', + 'WordRepetitionFilter', + 'SpecialCharactersFilter', + ] + for op in text_filter_ops: + self.signatures[op] = OpCostSignature( + op_name=op, + op_type='filter', + memory_locality=MemoryLocality.CPU_STRONG, + transfer_cost=TransferCost.LOW, + failure_cost=FailureCost.LOW, + state_free=True, + deterministic=True, + handles_text=True, + notes="Lightweight text filter, CPU-bound" + ) + + # Image Operations - GPU Preferred, Medium Transfer + image_ops = [ + 'ImageFaceRatioFilter', + 'ImageAestheticFilter', + 'ImageNSFWFilter', + ] + for op in image_ops: + self.signatures[op] = OpCostSignature( + op_name=op, + op_type='filter', + memory_locality=MemoryLocality.GPU_STRONG, + transfer_cost=TransferCost.MEDIUM, + failure_cost=FailureCost.MEDIUM, + state_free=True, + deterministic=True, + handles_image=True, + notes="Image model inference, GPU-accelerated" + ) + + # Video Operations - GPU Strong, High Transfer, High Failure + video_ops = [ + 'VideoDecoder', + 'VideoCaptioning', + 'VideoActionRecognition', + ] + for op in video_ops: + self.signatures[op] = OpCostSignature( + op_name=op, + op_type='mapper', + memory_locality=MemoryLocality.GPU_STRONG, + transfer_cost=TransferCost.HIGH, + failure_cost=FailureCost.HIGH, + state_free=True, + deterministic=True, + handles_video=True, + notes="Heavy video processing, high memory requirement" + ) + + # Deduplicators - Mixed locality, variable cost + self.signatures['DocumentDeduplicator'] = OpCostSignature( + op_name='DocumentDeduplicator', + op_type='deduplicator', + memory_locality=MemoryLocality.CPU_STRONG, + transfer_cost=TransferCost.LOW, + failure_cost=FailureCost.HIGH, + state_free=False, # Maintains hash index + deterministic=True, + handles_text=True, + notes="Hash-based dedup, stateful index" + ) + + self.signatures['ImageDeduplicator'] = OpCostSignature( + op_name='ImageDeduplicator', + op_type='deduplicator', + memory_locality=MemoryLocality.MIXED, + transfer_cost=TransferCost.MEDIUM, + failure_cost=FailureCost.HIGH, + state_free=False, + deterministic=True, + handles_image=True, + notes="Image hash computation, benefits from GPU" + ) + + def annotate(self, op_name: str, signature: OpCostSignature): + """Add or update OCS signature for an operator""" + self.signatures[op_name] = signature + + def get_signature(self, op_name: str) -> Optional[OpCostSignature]: + """Get OCS signature for an operator""" + return self.signatures.get(op_name) + + def get_all_signatures(self) -> Dict[str, OpCostSignature]: + """Get all registered signatures""" + return dict(self.signatures) + + def export_to_file(self, filepath: str): + """Export signatures to JSON file""" + data = { + name: sig.to_dict() + for name, sig in self.signatures.items() + } + with open(filepath, 'w') as f: + json.dump(data, f, indent=2) + + def import_from_file(self, filepath: str): + """Import signatures from JSON file""" + with open(filepath, 'r') as f: + data = json.load(f) + + for name, sig_dict in data.items(): + self.signatures[name] = OpCostSignature.from_dict(sig_dict) + + def infer_signature(self, op_name: str, op_type: str, **hints) -> OpCostSignature: + """ + Infer OCS signature from operator name and hints. + + This provides a best-effort annotation for unknown operators. + """ + # Default values + locality = MemoryLocality.BALANCED + transfer = TransferCost.MEDIUM + failure = FailureCost.MEDIUM + + # Infer from name patterns + op_lower = op_name.lower() + + if 'video' in op_lower: + locality = MemoryLocality.GPU_STRONG + transfer = TransferCost.HIGH + failure = FailureCost.HIGH + elif 'image' in op_lower: + locality = MemoryLocality.GPU_STRONG + transfer = TransferCost.MEDIUM + elif 'text' in op_lower or 'word' in op_lower or 'character' in op_lower: + locality = MemoryLocality.CPU_STRONG + transfer = TransferCost.LOW + failure = FailureCost.LOW + + # Apply hints if provided + if 'accelerator' in hints and hints['accelerator'] == 'cuda': + locality = MemoryLocality.GPU_STRONG + + return OpCostSignature( + op_name=op_name, + op_type=op_type, + memory_locality=locality, + transfer_cost=transfer, + failure_cost=failure, + handles_text='text' in op_lower, + handles_image='image' in op_lower, + handles_video='video' in op_lower, + handles_audio='audio' in op_lower, + notes="Auto-inferred signature" + ) diff --git a/data_juicer/core/elasticjuicer/profiler/probe_adapter.py b/data_juicer/core/elasticjuicer/profiler/probe_adapter.py new file mode 100644 index 00000000000..f781f1c3d01 --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/probe_adapter.py @@ -0,0 +1,353 @@ +"""Bridge between Stock Data-Juicer Adapter and ElasticJuicer ProfilingStore. + +Translates the system-wide resource measurements produced by +``Adapter.adapt_workloads()`` (in ``data_juicer/core/adapter.py``) into the +per-process schema expected by ElasticJuicer's ProfilingStore +(``ResourceSnapshot`` / ``OpExecutionStats``). + +Background +---------- +DJ Adapter probes a small batch and uses ``psutil`` / ``GPUtil`` to measure +SYSTEM-WIDE resources, then uses a 90% utilization heuristic to size per-OP +batches. ElasticJuicer's ProfilingStore expects PER-PROCESS RSS metrics. + +Several translation steps are lossy: + +========================== ================================ =================== +DJ Adapter field EJ field Translation rule +========================== ================================ =================== +``"CPU util."`` [0,1] ``cpu_percent`` [0,100] ``x * 100`` +``"Used mem."`` system MB ``memory_mb`` process MB direct, confidence=0.5 +``"GPU used mem."`` list ``gpu_memory_mb`` scalar ``[0]`` (warn if len>1) +``"GPU util."`` list[0,1] ``gpu_utilization`` [0,100] ``[0] * 100`` +``"speed"`` ``throughput`` direct +``"time"`` total seconds ``latency_ms`` per-batch ``1000 * bs / speed`` +``"timestamp"`` ``timestamp`` direct +(absent) ``batch_size`` external injection +(absent) ``sample_features`` None (no content data) +========================== ================================ =================== + +Per-op confidence scores are tracked in ``self.confidence_by_op`` out-of-band so +we don't have to extend the ``OpExecutionStats`` dataclass schema in this PR +(that change is owned by PR-2 schema versioning). + +Usage +----- + from data_juicer.core.elasticjuicer.profiler.probe_adapter import ProbeAdapter + from data_juicer.core.elasticjuicer.profiler.profiling_store import ProfilingStore + + store = ProfilingStore('./elastic_juicer_profiles') + bridge = ProbeAdapter(store) + + # In PR-4, `adapter._last_analysis` is added by a small change to adapter.py + # that stashes the probe-results list before `adapt_workloads` returns. + bridge.ingest_probe_results( + probe_results=adapter._last_analysis, + op_names=[op._name for op in ops], + probe_batch_sizes=bs_per_op, + ) + +This module is pure library code in PR-1: no integration with DefaultExecutor +or ElasticRayExecutor. Those wirings are PR-4 and PR-5 respectively. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Tuple + +from .ocs_annotator import MemoryLocality, OCSAnnotator, OpCostSignature +from .profiling_store import ProfilingStore +from .resource_monitor import OpExecutionStats, ResourceSnapshot + +logger = logging.getLogger(__name__) + + +class ProbeAdapter: + """Translate DJ Adapter probe results to ElasticJuicer ProfilingStore schema.""" + + SCHEMA_VERSION = "1.0" + + # Confidence markers in [0, 1]. 1.0 means "trust this number"; lower values + # mean "translation was lossy, downstream consumers should treat as an upper + # bound or otherwise discount". + CONFIDENCE_FULL: float = 1.0 + CONFIDENCE_SYSTEM_MEMORY_AS_PROCESS: float = 0.5 # system MB used as process RSS proxy + CONFIDENCE_FIRST_GPU_ONLY: float = 0.3 # multi-GPU collapsed to GPU[0] + + # Heuristic thresholds for memory locality inference (on GPU util max). + GPU_STRONG_UTIL_THRESHOLD: float = 0.5 + GPU_BALANCED_UTIL_THRESHOLD: float = 0.1 + + def __init__( + self, + store: ProfilingStore, + annotator: Optional[OCSAnnotator] = None, + ) -> None: + self.store = store + self.annotator = annotator or OCSAnnotator() + # Out-of-band confidence dict: {op_name: {field_name: confidence}}. + # Stored on the bridge instance, not in OpExecutionStats, so PR-1 does + # not need to coordinate with PR-2's schema versioning changes. + self.confidence_by_op: Dict[str, Dict[str, float]] = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def ingest_probe_results( + self, + probe_results: List[Dict[str, Any]], + op_names: List[str], + probe_batch_sizes: List[int], + ) -> Dict[str, OpExecutionStats]: + """Translate DJ Adapter probe outputs and persist to ProfilingStore. + + Parameters + ---------- + probe_results + List of probe dicts as produced by ``Adapter.adapt_workloads`` + after ``Monitor.analyze_resource_util_list`` has been applied. + Each dict is expected to contain keys ``"time"``, ``"speed"``, + ``"resource"`` (list of per-tick snapshots) and + ``"resource_analysis"`` (max/min/avg per DYNAMIC_FIELD). + op_names + Parallel list of op names (one per probe dict). + probe_batch_sizes + Parallel list of batch sizes used during probing. + + Returns + ------- + Dict mapping ``op_name`` to the ``OpExecutionStats`` that was written. + Ops that fail to translate are logged but skipped (other ops still + proceed). + + Raises + ------ + ValueError + If the three input lists have mismatched lengths. + """ + if not (len(probe_results) == len(op_names) == len(probe_batch_sizes)): + raise ValueError( + "Length mismatch: " + f"probe_results={len(probe_results)}, " + f"op_names={len(op_names)}, " + f"probe_batch_sizes={len(probe_batch_sizes)}" + ) + + written: Dict[str, OpExecutionStats] = {} + for probe_dict, op_name, bs in zip(probe_results, op_names, probe_batch_sizes): + try: + stats, confidence = self._translate_one(probe_dict, op_name, bs) + self.store.update_execution_stats(op_name, stats) + written[op_name] = stats + self.confidence_by_op[op_name] = confidence + + sig = self._derive_signature(probe_dict, op_name, stats) + if sig is not None: + self.store.update_ocs_signature(op_name, sig) + except Exception as e: + # One bad op should not block the rest. + logger.error( + "Failed to translate probe for op '%s' (batch_size=%s): %s", + op_name, bs, e, exc_info=True, + ) + + try: + self.store.save_all() + logger.info( + "Ingested %d/%d probe results to ProfilingStore " + "(schema_version=%s)", + len(written), len(op_names), self.SCHEMA_VERSION, + ) + except Exception as e: + logger.error("ProfilingStore.save_all() failed after ingest: %s", e, exc_info=True) + + return written + + def get_confidence(self, op_name: str, field: str) -> float: + """Return tracked confidence for ``(op_name, field)``. + + Returns ``CONFIDENCE_FULL`` if the op or field is unknown (the caller + is implicitly trusting the value). + """ + return self.confidence_by_op.get(op_name, {}).get(field, self.CONFIDENCE_FULL) + + # ------------------------------------------------------------------ + # Translation primitives + # ------------------------------------------------------------------ + + def _translate_one( + self, + probe: Dict[str, Any], + op_name: str, + batch_size: int, + ) -> Tuple[OpExecutionStats, Dict[str, float]]: + """Translate one probe dict into ``(OpExecutionStats, confidence map)``.""" + speed = float(probe.get("speed", 0.0) or 0.0) # samples / sec + if speed > 0: + latency_ms_per_batch = (1000.0 / speed) * float(batch_size) + else: + latency_ms_per_batch = 0.0 + + # Per-op confidence dict. GPU-related fields may be downgraded below + # if any sample shows multi-GPU. + confidence: Dict[str, float] = { + "cpu_percent": self.CONFIDENCE_FULL, + "memory_mb": self.CONFIDENCE_SYSTEM_MEMORY_AS_PROCESS, + "gpu_memory_mb": self.CONFIDENCE_FULL, + "gpu_utilization": self.CONFIDENCE_FULL, + "throughput": self.CONFIDENCE_FULL, + "latency_ms": self.CONFIDENCE_FULL if speed > 0 else 0.0, + } + + stats = OpExecutionStats(op_name=op_name) + + raw_snapshots = probe.get("resource") or [] + for raw in raw_snapshots: + gpu_mem, gpu_mem_conf = self._extract_gpu_mem(raw) + gpu_util, gpu_util_conf = self._extract_gpu_util(raw) + + # Op-level confidence drops to the worst per-snapshot confidence. + if gpu_mem_conf < confidence["gpu_memory_mb"]: + confidence["gpu_memory_mb"] = gpu_mem_conf + if gpu_util_conf < confidence["gpu_utilization"]: + confidence["gpu_utilization"] = gpu_util_conf + + snap = ResourceSnapshot( + timestamp=float(raw.get("timestamp", 0.0) or 0.0), + batch_size=int(batch_size), + cpu_percent=float(raw.get("CPU util.", 0.0) or 0.0) * 100.0, + memory_mb=float(raw.get("Used mem.", 0.0) or 0.0), + gpu_memory_mb=gpu_mem, + gpu_utilization=gpu_util, + latency_ms=latency_ms_per_batch, + throughput=speed, + ) + # OpExecutionStats.update appends to snapshots and recomputes + # the aggregate fields (avg/p95/p99/peak). + stats.update(snap) + + return stats, confidence + + @staticmethod + def _extract_gpu_mem(raw: Dict[str, Any]) -> Tuple[Optional[float], float]: + """Extract scalar GPU memory (MB) from a probe snapshot. + + Returns ``(value_mb, confidence)``. ``value_mb`` is ``None`` if the + snapshot has no GPU data. Confidence drops to + ``CONFIDENCE_FIRST_GPU_ONLY`` and a warning is logged if multi-GPU. + """ + gpu_used = raw.get("GPU used mem.") + if gpu_used is None or len(gpu_used) == 0: + return None, ProbeAdapter.CONFIDENCE_FULL + + if len(gpu_used) > 1: + logger.warning( + "Multi-GPU probe detected (%d GPUs); using gpus[0]=%.1f MB. " + "ElasticJuicer currently assumes single-GPU.", + len(gpu_used), float(gpu_used[0]), + ) + return float(gpu_used[0]), ProbeAdapter.CONFIDENCE_FIRST_GPU_ONLY + + return float(gpu_used[0]), ProbeAdapter.CONFIDENCE_FULL + + @staticmethod + def _extract_gpu_util(raw: Dict[str, Any]) -> Tuple[Optional[float], float]: + """Extract scalar GPU utilization (percent) from a probe snapshot. + + DJ Adapter reports ratio in ``[0, 1]``; EJ expects percent ``[0, 100]``. + Multi-GPU is collapsed to ``gpus[0]`` with reduced confidence + (warning is emitted by ``_extract_gpu_mem``; not duplicated here). + """ + gpu_util = raw.get("GPU util.") + if gpu_util is None or len(gpu_util) == 0: + return None, ProbeAdapter.CONFIDENCE_FULL + + confidence = ( + ProbeAdapter.CONFIDENCE_FIRST_GPU_ONLY + if len(gpu_util) > 1 + else ProbeAdapter.CONFIDENCE_FULL + ) + return float(gpu_util[0]) * 100.0, confidence + + # ------------------------------------------------------------------ + # Signature derivation + # ------------------------------------------------------------------ + + def _derive_signature( + self, + probe: Dict[str, Any], + op_name: str, + stats: OpExecutionStats, + ) -> Optional[OpCostSignature]: + """Derive an ``OpCostSignature`` from probe + stats + heuristics. + + Auto-derived (approximately 6 of the 14 signature fields): + + * ``preferred_batch_size`` -- from the probe batch size + * ``min_memory_mb`` / ``max_memory_mb`` -- from stats peaks + * ``memory_locality`` -- from GPU util max in ``resource_analysis`` + * ``handles_{text,image,video,audio}`` -- inherited from the + existing ``OCSAnnotator`` substring heuristics + + The remaining fields (``op_type``, ``transfer_cost``, ``failure_cost``, + ``state_free``, ``deterministic``) default to ``OCSAnnotator``'s + registered or inferred defaults. + """ + existing = self.annotator.get_signature(op_name) + if existing is not None: + sig = existing + else: + sig = self.annotator.infer_signature(op_name, op_type="unknown") + + # Memory bounds from probe stats. + if stats.peak_memory_mb > 0: + sig.max_memory_mb = stats.peak_memory_mb + if stats.avg_memory_mb > 0: + # Conservative lower bound: half the average. + sig.min_memory_mb = stats.avg_memory_mb * 0.5 + + if stats.snapshots: + sig.preferred_batch_size = int(stats.snapshots[0].batch_size) + + # Memory locality from GPU activity during the probe window. + analysis = probe.get("resource_analysis") or {} + gpu_util_stats = analysis.get("GPU util.") or {} + gpu_util_max = 0.0 + if isinstance(gpu_util_stats, dict): + gpu_util_max = float(gpu_util_stats.get("max", 0.0) or 0.0) + + if gpu_util_max > self.GPU_STRONG_UTIL_THRESHOLD: + sig.memory_locality = MemoryLocality.GPU_STRONG + elif gpu_util_max > self.GPU_BALANCED_UTIL_THRESHOLD: + sig.memory_locality = MemoryLocality.BALANCED + # Otherwise leave whatever default OCSAnnotator gave us. + + return sig + + # ------------------------------------------------------------------ + # Convenience helpers + # ------------------------------------------------------------------ + + @staticmethod + def attach_to_default_executor( + executor: Any, + storage_dir: str = "./elastic_juicer_profiles", + ) -> "ProbeAdapter": + """Wire a ProbeAdapter onto a ``DefaultExecutor`` instance. + + Call AFTER ``executor.adapter.adapt_workloads(...)`` returns. PR-4 + adds the call site in ``default_executor.py``; this helper is provided + so unit / integration tests can do the same wiring without touching + the executor source. + + Returns + ------- + The constructed ``ProbeAdapter``, also stashed on + ``executor._probe_bridge`` for later inspection. + """ + store = ProfilingStore(storage_dir=storage_dir) + bridge = ProbeAdapter(store) + executor._probe_bridge = bridge + return bridge diff --git a/data_juicer/core/elasticjuicer/profiler/profiling_store.py b/data_juicer/core/elasticjuicer/profiler/profiling_store.py new file mode 100644 index 00000000000..d30ccd9b418 --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/profiling_store.py @@ -0,0 +1,344 @@ +""" +Profiling Store + +Persistent storage and query interface for: +- Resource-throughput curves +- OCS signatures +- Historical performance data + +Supports online learning and model updating. +""" + +import json +import logging +import pickle +import time as _time +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, asdict +import numpy as np +from scipy.optimize import curve_fit + +from .resource_monitor import OpExecutionStats, ResourceSnapshot +from .ocs_annotator import OpCostSignature + +logger = logging.getLogger(__name__) + + +@dataclass +class ResourceThroughputCurve: + """ + Resource-throughput relationship for an operator. + + Models T(r, b) where: + - T = throughput (samples/sec) + - r = resource allocation (memory, GPU) + - b = batch size + """ + op_name: str + # Curve parameters (fitted from data) + coefficients: Dict[str, float] + # Model type: 'linear', 'polynomial', 'power' + model_type: str = 'linear' + # Goodness of fit + r_squared: float = 0.0 + # Sample count used for fitting + n_samples: int = 0 + + def predict_throughput(self, batch_size: int, memory_mb: float) -> float: + """Predict throughput given batch size and memory""" + if self.model_type == 'linear': + # T = a * batch_size + b * memory + c + a = self.coefficients.get('batch_coef', 0) + b = self.coefficients.get('memory_coef', 0) + c = self.coefficients.get('intercept', 0) + return max(0, a * batch_size + b * memory_mb + c) + + elif self.model_type == 'power': + # T = a * batch_size^b + a = self.coefficients.get('scale', 1) + b = self.coefficients.get('power', 1) + return a * (batch_size ** b) + + return 0.0 + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict) -> 'ResourceThroughputCurve': + """Create from dictionary""" + return cls(**data) + + +class ProfilingStore: + """ + Persistent store for operator profiling data. + + Provides: + - Storage and retrieval of execution stats + - Resource-throughput curve fitting + - Online model updates + - Query interface for schedulers + """ + + SCHEMA_VERSION = "1.0" + + def __init__(self, storage_dir: str = "./elastic_juicer_profiles"): + self.storage_dir = Path(storage_dir) + self.storage_dir.mkdir(parents=True, exist_ok=True) + + # In-memory caches + self.execution_stats: Dict[str, OpExecutionStats] = {} + self.ocs_signatures: Dict[str, OpCostSignature] = {} + self.throughput_curves: Dict[str, ResourceThroughputCurve] = {} + + # Load existing data + self._load_all() + + def _load_all(self): + """Load all stored profiles, checking schema version.""" + # Load execution stats (versioned pickle wrapper) + stats_file = self.storage_dir / "execution_stats.pkl" + if stats_file.exists(): + try: + with open(stats_file, 'rb') as f: + raw = pickle.load(f) + if isinstance(raw, dict) and "meta" in raw and "stats" in raw: + file_version = raw["meta"].get("schema_version") + if file_version != self.SCHEMA_VERSION: + logger.warning( + "ProfilingStore schema version mismatch: " + "file=%s, code=%s. Skipping load.", + file_version, self.SCHEMA_VERSION, + ) + else: + self.execution_stats = raw["stats"] + elif isinstance(raw, dict): + # Legacy format (pre-versioning): dict[str, OpExecutionStats] + logger.warning( + "ProfilingStore data has no schema version (legacy). " + "Loading anyway; will be re-saved with version on next save_all()." + ) + self.execution_stats = raw + else: + logger.warning( + "ProfilingStore data has unexpected type %s. Skipping.", + type(raw).__name__, + ) + except Exception as e: + logger.warning("Failed to load execution_stats.pkl: %s", e) + + # Load OCS signatures + ocs_file = self.storage_dir / "ocs_signatures.json" + if ocs_file.exists(): + with open(ocs_file, 'r') as f: + data = json.load(f) + self.ocs_signatures = { + name: OpCostSignature.from_dict(sig) + for name, sig in data.items() + } + + # Load throughput curves + curves_file = self.storage_dir / "throughput_curves.json" + if curves_file.exists(): + with open(curves_file, 'r') as f: + data = json.load(f) + self.throughput_curves = { + name: ResourceThroughputCurve.from_dict(curve) + for name, curve in data.items() + } + + def save_all(self): + """Persist all profiles to disk (versioned wrapper).""" + # Save execution stats with schema version + stats_file = self.storage_dir / "execution_stats.pkl" + wrapped = { + "meta": { + "schema_version": self.SCHEMA_VERSION, + "saved_at": _time.time(), + }, + "stats": self.execution_stats, + } + with open(stats_file, 'wb') as f: + pickle.dump(wrapped, f) + logger.info( + "ProfilingStore saved %d entries (schema=%s)", + len(self.execution_stats), self.SCHEMA_VERSION, + ) + + # Save OCS signatures + ocs_file = self.storage_dir / "ocs_signatures.json" + with open(ocs_file, 'w') as f: + data = { + name: sig.to_dict() + for name, sig in self.ocs_signatures.items() + } + json.dump(data, f, indent=2) + + # Save throughput curves + curves_file = self.storage_dir / "throughput_curves.json" + with open(curves_file, 'w') as f: + data = { + name: curve.to_dict() + for name, curve in self.throughput_curves.items() + } + json.dump(data, f, indent=2) + + def update_execution_stats(self, op_name: str, stats: OpExecutionStats): + """Update execution statistics for an operator""" + self.execution_stats[op_name] = stats + self._fit_throughput_curve(op_name, stats) + + def update_ocs_signature(self, op_name: str, signature: OpCostSignature): + """Update OCS signature for an operator""" + self.ocs_signatures[op_name] = signature + + def get_execution_stats(self, op_name: str) -> Optional[OpExecutionStats]: + """Get execution statistics for an operator""" + return self.execution_stats.get(op_name) + + def get_ocs_signature(self, op_name: str) -> Optional[OpCostSignature]: + """Get OCS signature for an operator""" + return self.ocs_signatures.get(op_name) + + def get_throughput_curve(self, op_name: str) -> Optional[ResourceThroughputCurve]: + """Get resource-throughput curve for an operator""" + return self.throughput_curves.get(op_name) + + def _fit_throughput_curve(self, op_name: str, stats: OpExecutionStats): + """ + Fit resource-throughput curve from execution statistics. + + Uses online learning approach (inspired by Autothrottle). + """ + if len(stats.snapshots) < 5: + # Not enough data points + return + + # Extract features and target + batch_sizes = np.array([s.batch_size for s in stats.snapshots]) + memories = np.array([s.memory_mb for s in stats.snapshots]) + throughputs = np.array([s.throughput for s in stats.snapshots]) + + # Filter out invalid data + valid_idx = throughputs > 0 + if valid_idx.sum() < 5: + return + + batch_sizes = batch_sizes[valid_idx] + memories = memories[valid_idx] + throughputs = throughputs[valid_idx] + + try: + # Try linear model first: T = a*batch + b*mem + c + X = np.column_stack([batch_sizes, memories, np.ones_like(batch_sizes)]) + coeffs, residuals, _, _ = np.linalg.lstsq(X, throughputs, rcond=None) + + # Calculate R² + ss_res = residuals[0] if len(residuals) > 0 else 0 + ss_tot = np.sum((throughputs - np.mean(throughputs)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + curve = ResourceThroughputCurve( + op_name=op_name, + coefficients={ + 'batch_coef': float(coeffs[0]), + 'memory_coef': float(coeffs[1]), + 'intercept': float(coeffs[2]), + }, + model_type='linear', + r_squared=float(r_squared), + n_samples=len(batch_sizes), + ) + + self.throughput_curves[op_name] = curve + + except Exception as e: + # Fitting failed, use simple mean + pass + + def predict_memory_for_batch(self, op_name: str, batch_size: int) -> Optional[float]: + """ + Predict memory usage for a given batch size. + + Based on historical data with online learning. + """ + stats = self.execution_stats.get(op_name) + if not stats or len(stats.snapshots) < 3: + return None + + # Simple linear regression: memory = a * batch_size + b + batch_sizes = np.array([s.batch_size for s in stats.snapshots]) + memories = np.array([s.memory_mb for s in stats.snapshots]) + + try: + # Fit linear model + coeffs = np.polyfit(batch_sizes, memories, deg=1) + predicted = coeffs[0] * batch_size + coeffs[1] + return float(predicted) + except Exception: + # Fall back to average + return float(np.mean(memories)) + + def get_safe_batch_size(self, op_name: str, available_memory_mb: float, + safety_margin: float = 0.9) -> int: + """ + Recommend safe batch size given available memory. + + Args: + op_name: Operator name + available_memory_mb: Available memory in MB + safety_margin: Use only this fraction of available memory (default 90%) + + Returns: + Recommended batch size + """ + stats = self.execution_stats.get(op_name) + if not stats or len(stats.snapshots) < 3: + return 1 # Conservative default + + # Find batch sizes and their memory usage + batch_sizes = np.array([s.batch_size for s in stats.snapshots]) + memories = np.array([s.memory_mb for s in stats.snapshots]) + + # Calculate memory per sample + mem_per_sample = memories / batch_sizes + avg_mem_per_sample = np.median(mem_per_sample) # Use median for robustness + + # Calculate safe batch size + target_memory = available_memory_mb * safety_margin + safe_batch = int(target_memory / avg_mem_per_sample) + + return max(1, safe_batch) + + def export_report(self, output_file: str): + """Export profiling report as markdown""" + lines = ["# ElasticJuicer Profiling Report\n"] + + lines.append("## Operator Execution Statistics\n") + for op_name, stats in sorted(self.execution_stats.items()): + lines.append(f"### {op_name}\n") + lines.append(f"- Total Samples: {stats.total_samples}") + lines.append(f"- Total Batches: {stats.total_batches}") + lines.append(f"- Avg Latency: {stats.avg_latency_ms:.2f} ms") + lines.append(f"- P95 Latency: {stats.p95_latency_ms:.2f} ms") + lines.append(f"- Avg Throughput: {stats.avg_throughput:.2f} samples/s") + lines.append(f"- Peak Memory: {stats.peak_memory_mb:.2f} MB") + if stats.peak_gpu_memory_mb: + lines.append(f"- Peak GPU Memory: {stats.peak_gpu_memory_mb:.2f} MB") + lines.append("") + + lines.append("\n## OCS Signatures\n") + for op_name, sig in sorted(self.ocs_signatures.items()): + lines.append(f"### {op_name}") + lines.append(f"- Type: {sig.op_type}") + lines.append(f"- Memory Locality: {sig.memory_locality.value}") + lines.append(f"- Transfer Cost: {sig.transfer_cost.value}") + lines.append(f"- Failure Cost: {sig.failure_cost.value}") + lines.append(f"- State Free: {sig.state_free}") + lines.append("") + + with open(output_file, 'w') as f: + f.writelines(line + '\n' for line in lines) diff --git a/data_juicer/core/elasticjuicer/profiler/resource_monitor.py b/data_juicer/core/elasticjuicer/profiler/resource_monitor.py new file mode 100644 index 00000000000..a2dc7cde09a --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/resource_monitor.py @@ -0,0 +1,376 @@ +""" +Resource Monitor - Thin Adapter over Ray Memory Infrastructure + +Operator-level execution measurement (latency, throughput) is unique to +ElasticJuicer. System-level memory polling delegates to Ray event-based +memory monitor when available, falling back to psutil outside Ray. + +Based on Pollux-style agent monitoring. +""" + +import time +import psutil +import threading +from dataclasses import dataclass, field +from typing import Optional, Dict, List, Any, Tuple, Set +from collections import defaultdict +import numpy as np + +try: + import GPUtil + GPU_AVAILABLE = True +except ImportError: + GPU_AVAILABLE = False + + +@dataclass +class SampleFeatures: + """Sample-level content features for cost model conditioning.""" + duration_s: float = 0.0 # Video/audio duration in seconds + frame_count: int = 0 # Video frame count + resolution_hw: Tuple[int, int] = (0, 0) # (height, width) + text_length: int = 0 # Text character/token length + modality: Set[str] = field(default_factory=set) # {'text','image','video','audio'} + + def to_vector(self) -> list: + """Convert to numerical vector for model input.""" + import math + return [ + math.log(self.duration_s + 1), + math.log(self.frame_count + 1), + math.log(self.resolution_hw[0] * self.resolution_hw[1] + 1), + math.log(self.text_length + 1), + 1.0 if 'text' in self.modality else 0.0, + 1.0 if 'image' in self.modality else 0.0, + 1.0 if 'video' in self.modality else 0.0, + 1.0 if 'audio' in self.modality else 0.0, + ] + + +@dataclass +class ResourceSnapshot: + """Single measurement of resource usage""" + timestamp: float + batch_size: int + # CPU metrics + cpu_percent: float + memory_mb: float + # GPU metrics (if available) + gpu_memory_mb: Optional[float] = None + gpu_utilization: Optional[float] = None + # Performance metrics + latency_ms: float = 0.0 + throughput: float = 0.0 # samples/sec + # Sample features (optional) + sample_features: Optional['SampleFeatures'] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + 'timestamp': self.timestamp, + 'batch_size': self.batch_size, + 'cpu_percent': self.cpu_percent, + 'memory_mb': self.memory_mb, + 'gpu_memory_mb': self.gpu_memory_mb, + 'gpu_utilization': self.gpu_utilization, + 'latency_ms': self.latency_ms, + 'throughput': self.throughput, + 'sample_features': self.sample_features, + } + + +@dataclass +class OpExecutionStats: + """Aggregated statistics for an operator""" + op_name: str + total_samples: int = 0 + total_batches: int = 0 + avg_latency_ms: float = 0.0 + p95_latency_ms: float = 0.0 + p99_latency_ms: float = 0.0 + avg_throughput: float = 0.0 + avg_memory_mb: float = 0.0 + peak_memory_mb: float = 0.0 + avg_gpu_memory_mb: Optional[float] = None + peak_gpu_memory_mb: Optional[float] = None + snapshots: List[ResourceSnapshot] = field(default_factory=list) + + def update(self, snapshot: ResourceSnapshot): + """Update statistics with new snapshot""" + self.snapshots.append(snapshot) + self.total_samples += snapshot.batch_size + self.total_batches += 1 + + # Update averages + latencies = [s.latency_ms for s in self.snapshots] + self.avg_latency_ms = np.mean(latencies) + self.p95_latency_ms = np.percentile(latencies, 95) + self.p99_latency_ms = np.percentile(latencies, 99) + + throughputs = [s.throughput for s in self.snapshots if s.throughput > 0] + if throughputs: + self.avg_throughput = np.mean(throughputs) + + memories = [s.memory_mb for s in self.snapshots] + self.avg_memory_mb = np.mean(memories) + self.peak_memory_mb = max(memories) + + if snapshot.gpu_memory_mb is not None: + gpu_mems = [s.gpu_memory_mb for s in self.snapshots if s.gpu_memory_mb is not None] + if gpu_mems: + self.avg_gpu_memory_mb = np.mean(gpu_mems) + self.peak_gpu_memory_mb = max(gpu_mems) + + +class ResourceMonitor: + """ + Lightweight resource monitor for operators. + + Inspired by PolluxAgent - measures resource-throughput curves in real-time. + """ + + # Only actually poll expensive metrics every Nth call + _POLL_INTERVAL = 5 + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.stats_by_op: Dict[str, OpExecutionStats] = defaultdict(OpExecutionStats) + self._lock = threading.Lock() + self.process = psutil.Process() + + # Throttle state for expensive polls + self._poll_count = 0 + self._cached_cpu_percent = 0.0 + self._cached_gpu_stats = None + + # Cache library availability (avoid try/except on every poll) + self._has_ray = False + self._has_gputil = False + self._ray_memory_monitor = None + + try: + import ray + if ray.is_initialized(): + self._has_ray = True + try: + from ray._private.memory_monitor import MemoryMonitor + self._ray_memory_monitor = MemoryMonitor() + except (ImportError, AttributeError): + pass + except ImportError: + pass + + try: + import GPUtil # noqa: F401 + self._has_gputil = True + except ImportError: + pass + + def measure_execution(self, op_name: str, batch_size: int, sample_features: Optional['SampleFeatures'] = None): + """ + Context manager to measure operator execution. + + Usage: + with monitor.measure_execution("my_filter", batch_size=100): + # Process batch + result = op.process(batch) + """ + return ExecutionContext(self, op_name, batch_size, sample_features=sample_features) + + def record_snapshot(self, op_name: str, snapshot: ResourceSnapshot): + """Record a resource snapshot for an operator""" + if not self.enabled: + return + + with self._lock: + if op_name not in self.stats_by_op: + self.stats_by_op[op_name] = OpExecutionStats(op_name=op_name) + self.stats_by_op[op_name].update(snapshot) + + def get_stats(self, op_name: str) -> Optional[OpExecutionStats]: + """Get statistics for a specific operator""" + return self.stats_by_op.get(op_name) + + def get_all_stats(self) -> Dict[str, OpExecutionStats]: + """Get statistics for all operators""" + return dict(self.stats_by_op) + + def clear(self): + """Clear all collected statistics""" + with self._lock: + self.stats_by_op.clear() + + def flush_to_store(self, store: 'ProfilingStore') -> int: + """Transfer all per-op stats to a ProfilingStore. + + Each op's ``OpExecutionStats`` is written via + ``store.update_execution_stats``. The monitor's internal state is NOT + cleared so subsequent calls are idempotent (the store simply overwrites + the same keys with potentially newer data). + + Args: + store: Target :class:`ProfilingStore` instance. + + Returns: + Number of ops flushed. + """ + with self._lock: + snapshot = dict(self.stats_by_op) + + flushed = 0 + for op_name, stats in snapshot.items(): + store.update_execution_stats(op_name, stats) + flushed += 1 + return flushed + + def _get_current_resources(self) -> Dict[str, Any]: + """Get current resource usage. + + Delegates to Ray's memory infrastructure when running inside Ray, + falling back to psutil for standalone execution. + + Expensive metrics (cpu_percent, GPUtil.getGPUs) are throttled and + only refreshed every ``_POLL_INTERVAL`` calls; other calls reuse + cached values. Memory (psutil) is cheap and polled every call. + """ + self._poll_count += 1 + + # Refresh expensive metrics on first call and every Nth call thereafter + should_poll = ( + self._poll_count == 1 + or self._poll_count % self._POLL_INTERVAL == 1 + ) + if should_poll: + self._cached_cpu_percent = self.process.cpu_percent() + if self._has_gputil: + import GPUtil + try: + self._cached_gpu_stats = GPUtil.getGPUs() + except Exception: + self._cached_gpu_stats = None + + cpu_percent = self._cached_cpu_percent + memory_mb = None + gpu_memory_mb = None + gpu_utilization = None + + # Try Ray's memory monitor first (event-based, lower overhead) + if self._has_ray and self._ray_memory_monitor is not None: + try: + system_memory = self._ray_memory_monitor.get_memory_usage() + if system_memory is not None: + memory_mb = system_memory / (1024 * 1024) + except Exception: + pass + + # Fallback to psutil if Ray monitor didn't provide memory (cheap call) + if memory_mb is None: + memory_mb = self.process.memory_info().rss / (1024 * 1024) + + # Use cached GPU stats + gpus = self._cached_gpu_stats + if gpus: + gpu = gpus[0] + gpu_memory_mb = gpu.memoryUsed + gpu_utilization = gpu.load * 100 + + return { + 'cpu_percent': cpu_percent, + 'memory_mb': memory_mb, + 'gpu_memory_mb': gpu_memory_mb, + 'gpu_utilization': gpu_utilization, + } + + +class ExecutionContext: + """Context manager for measuring operator execution""" + + def __init__(self, monitor: ResourceMonitor, op_name: str, batch_size: int, sample_features: Optional['SampleFeatures'] = None): + self.monitor = monitor + self.op_name = op_name + self.batch_size = batch_size + self.sample_features = sample_features + self.start_time = None + + def __enter__(self): + if self.monitor.enabled: + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.monitor.enabled or self.start_time is None: + return + + # Calculate latency + end_time = time.time() + latency_s = end_time - self.start_time + latency_ms = latency_s * 1000 + + # Calculate throughput + throughput = self.batch_size / latency_s if latency_s > 0 else 0 + + # Get resource usage + resources = self.monitor._get_current_resources() + + # Create snapshot + snapshot = ResourceSnapshot( + timestamp=end_time, + batch_size=self.batch_size, + cpu_percent=resources['cpu_percent'], + memory_mb=resources['memory_mb'], + gpu_memory_mb=resources['gpu_memory_mb'], + gpu_utilization=resources['gpu_utilization'], + latency_ms=latency_ms, + throughput=throughput, + sample_features=self.sample_features, + ) + + # Record snapshot + self.monitor.record_snapshot(self.op_name, snapshot) + + +class MonitoredOp: + """ + Wrapper to inject monitoring into Data-Juicer operators. + + Usage: + original_op = SomeFilter(**config) + monitored_op = MonitoredOp(original_op, monitor) + """ + + def __init__(self, operator, monitor: ResourceMonitor): + self.operator = operator + self.monitor = monitor + self.op_name = operator.__class__.__name__ + + def __getattr__(self, name): + """Delegate attribute access to wrapped operator""" + return getattr(self.operator, name) + + def process(self, *args, **kwargs): + """Wrap process method with monitoring""" + # Estimate batch size + batch_size = self._estimate_batch_size(args, kwargs) + + with self.monitor.measure_execution(self.op_name, batch_size): + return self.operator.process(*args, **kwargs) + + def compute_stats(self, *args, **kwargs): + """Wrap compute_stats method with monitoring (for filters)""" + batch_size = self._estimate_batch_size(args, kwargs) + + with self.monitor.measure_execution(f"{self.op_name}_stats", batch_size): + return self.operator.compute_stats(*args, **kwargs) + + def _estimate_batch_size(self, args, kwargs) -> int: + """Estimate batch size from arguments""" + # For single sample: return 1 + # For batched: try to extract from first argument (usually a dict/dataset) + if args: + sample = args[0] + if isinstance(sample, dict): + # Check if it's batched data + for value in sample.values(): + if isinstance(value, list): + return len(value) + return 1 diff --git a/data_juicer/core/elasticjuicer/scheduler/__init__.py b/data_juicer/core/elasticjuicer/scheduler/__init__.py new file mode 100644 index 00000000000..013666352b6 --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/__init__.py @@ -0,0 +1,22 @@ +""" +Scheduler Module + +Provides: +- Micro-Scheduler: JABAS-style PID control for batch size +- Macro-Scheduler: Tower/Captain bi-level architecture +""" + +from .micro_scheduler import MicroScheduler, PIDController, BatchSizeController +from .scheduler_config import SchedulerConfig +from .tower import Tower +from .captain import Captain, CaptainPool + +__all__ = [ + "MicroScheduler", + "PIDController", + "BatchSizeController", + "SchedulerConfig", + "Tower", + "Captain", + "CaptainPool", +] diff --git a/data_juicer/core/elasticjuicer/scheduler/captain.py b/data_juicer/core/elasticjuicer/scheduler/captain.py new file mode 100644 index 00000000000..ce485b05d66 --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/captain.py @@ -0,0 +1,496 @@ +""" +Captain: Local Per-Operator Scheduler for ElasticJuicer + +Based on Autothrottle's bi-level architecture, Captain is the local scheduler +that manages a single operator stage under Tower's global constraints. + +Key responsibilities: +1. Execute Micro-Scheduler (JABAS-style batch size control) within quota +2. Report metrics to Tower +3. Enforce resource quotas from Tower +4. Handle local OOM events and recovery +5. Coordinate with adjacent Captains in pipeline + +References: +- Autothrottle (NSDI 2024): Bi-level control for SLO-targeted microservices +- Report Section 5.1: Tower/Captain architecture +""" + +import time +from dataclasses import dataclass +from typing import Optional, Callable, List, Dict, TYPE_CHECKING +from collections import deque +import psutil + +if TYPE_CHECKING: + from .tower import StageMetrics + +from ..scheduler.micro_scheduler import MicroScheduler, BatchSizeController +from ..scheduler.tower import ResourceQuota, StageMetrics, TopologyMode +from ..profiler.resource_monitor import ResourceMonitor, ResourceSnapshot +from ..predictor.memory_predictor import MemoryPredictor +from ..predictor.feature_extractor import FeatureExtractor + + +@dataclass +class CaptainConfig: + """Configuration for Captain scheduler""" + stage_name: str + initial_batch_size: int = 32 + report_interval_sec: float = 1.0 # How often to report to Tower + quota_check_interval_sec: float = 0.5 # How often to check quota + enable_micro_scheduler: bool = True # Use JABAS-style control + enable_prediction: bool = True # Use memory prediction + emergency_backoff_ratio: float = 0.5 # OOM backoff ratio + + +class Captain: + """ + Local Per-Operator Scheduler (Captain from Autothrottle architecture) + + Captain manages a single operator stage, executing micro-scheduling decisions + (batch size adjustment) within the global constraints set by Tower. + + Key design: + - Tower sets "what to achieve" (target parallelism, resource quota, SLO) + - Captain decides "how to achieve it" (batch size, local optimization) + - Bi-level decoupling enables scalability and autonomy + """ + + def __init__( + self, + config: CaptainConfig, + tower_callback: Optional[Callable[[StageMetrics], None]] = None + ): + """ + Initialize Captain local scheduler + + Args: + config: Captain configuration + tower_callback: Callback function to report metrics to Tower + """ + self.config = config + self.tower_callback = tower_callback + + # Micro-scheduler for batch size control + if config.enable_micro_scheduler: + self.micro_scheduler = MicroScheduler( + initial_batch_size=config.initial_batch_size, + max_batch_size=1024, + min_batch_size=1 + ) + else: + self.micro_scheduler = None + + # Resource monitoring + self.monitor = ResourceMonitor() + + # Memory prediction + if config.enable_prediction: + self.predictor = MemoryPredictor(op_name=config.stage_name) + self.feature_extractor = FeatureExtractor() + else: + self.predictor = None + self.feature_extractor = None + + # Current resource quota from Tower + self.quota: Optional[ResourceQuota] = None + + # Current stage metrics + self.metrics = StageMetrics(stage_name=config.stage_name) + + # Queue simulation + self.queue: deque = deque() + + # Timing + self.last_report_time = time.time() + self.last_quota_check_time = time.time() + + # Processing statistics + self.samples_processed = 0 + self.total_latency_ms = 0.0 + self.latency_history = deque(maxlen=100) + self.throughput_history = deque(maxlen=100) + + # OOM tracking + self.oom_events = 0 + self.last_oom_time = 0.0 + self._total_oom_count: int = 0 # Cumulative OOM count for metrics reporting + + # Backpressure state + self._backpressure_active: bool = False + self._backpressure_slowdown: float = 0.5 # Default slowdown ratio + + # Metrics tracking fields for Tower consumption + self._recent_throughput: float = 0.0 # samples/sec from recent batches + self._recent_latency_ms: float = 0.0 # average latency from recent batches + self._current_cpu_util: float = 0.0 # latest CPU utilization + self._current_memory_util: float = 0.0 # latest memory utilization + self._current_gpu_util: float = 0.0 # latest GPU utilization + + def set_quota(self, quota: ResourceQuota): + """ + Receive resource quota from Tower + + Args: + quota: Resource allocation from Tower. May include a 'backpressure' + attribute to signal upstream throttling. + """ + self.quota = quota + + # Check for backpressure signal from Tower + if hasattr(quota, 'backpressure'): + self._backpressure_active = quota.backpressure + + # Update micro-scheduler constraints if quota changed + if self.micro_scheduler and quota.memory_quota_mb > 0: + # Adjust max batch size based on memory quota + # Rough estimate: 100MB per sample for typical multimodal data + estimated_max_batch = max(1, int(quota.memory_quota_mb / 100)) + self.micro_scheduler.controller.max_batch_size = min( + self.micro_scheduler.controller.max_batch_size, + estimated_max_batch + ) + + def enqueue_samples(self, samples: List): + """ + Add samples to processing queue + + Args: + samples: List of samples to process + """ + for sample in samples: + self.queue.append(sample) + + # Update queue depth metric + self.metrics.queue_depth = len(self.queue) + + def process_batch( + self, + operator_func: Callable, + sample_batch: Optional[List] = None + ) -> Optional[List]: + """ + Process a batch using the operator, with Captain's orchestration + + This is the core execution loop that: + 1. Gets batch size from micro-scheduler + 2. Dequeues samples + 3. Monitors execution + 4. Updates predictor and scheduler + 5. Checks quota compliance + + Args: + operator_func: The actual operator function to execute + sample_batch: Optional pre-formed batch (if None, dequeue from queue) + + Returns: + Processed results or None if queue empty + """ + start_time = time.time() + + # Get current batch size recommendation + if self.micro_scheduler: + current_batch_size = self.micro_scheduler.controller.current_batch_size + else: + current_batch_size = self.config.initial_batch_size + + # Apply backpressure throttling if active + if self._backpressure_active: + # Throttle: reduce effective batch size and add delay + current_batch_size = max(1, int(current_batch_size * self._backpressure_slowdown)) + time.sleep(0.1) # Small delay to reduce pressure on downstream + + # Dequeue samples if not provided + if sample_batch is None: + if len(self.queue) == 0: + return None + + actual_batch_size = min(current_batch_size, len(self.queue)) + sample_batch = [self.queue.popleft() for _ in range(actual_batch_size)] + else: + actual_batch_size = len(sample_batch) + + # Extract features for prediction (if enabled) + predicted_memory_mb = None + if self.predictor and self.feature_extractor and len(sample_batch) > 0: + features = self.feature_extractor.extract_from_sample(sample_batch[0]) + features.batch_size = actual_batch_size # Set batch size + prediction = self.predictor.predict(features) + if prediction: + predicted_memory_mb = prediction.predicted_memory_mb + + # Monitor execution + with self.monitor.measure_execution( + self.config.stage_name, + actual_batch_size + ): + try: + # Execute operator + results = operator_func(sample_batch) + + # Record success + self.samples_processed += actual_batch_size + + except MemoryError as e: + # OOM event - get approximate snapshot + snapshot_approx = ResourceSnapshot( + timestamp=time.time(), + batch_size=actual_batch_size, + cpu_percent=psutil.cpu_percent(), + memory_mb=psutil.virtual_memory().used / (1024 * 1024), + latency_ms=0 + ) + self._handle_oom(actual_batch_size, snapshot_approx) + raise + + # Get recorded stats + op_stats = self.monitor.get_stats(self.config.stage_name) + if op_stats and op_stats.snapshots: + snapshot = op_stats.snapshots[-1] # Get latest snapshot + # Update predictor with actual memory usage + if self.predictor and self.feature_extractor and len(sample_batch) > 0: + features = self.feature_extractor.extract_from_sample(sample_batch[0]) + self.predictor.observe(features, snapshot.memory_mb) + + # Update micro-scheduler + if self.micro_scheduler: + self.micro_scheduler.update( + actual_memory_used=snapshot.memory_mb, + sample_features=None # Already updated predictor above + ) + + # Update metrics + latency_ms = snapshot.latency_ms + self.total_latency_ms += latency_ms + self.latency_history.append(latency_ms) + + throughput = snapshot.throughput + self.throughput_history.append(throughput) + + self.metrics.avg_latency_ms = ( + sum(self.latency_history) / len(self.latency_history) + if self.latency_history else 0 + ) + self.metrics.throughput = ( + sum(self.throughput_history) / len(self.throughput_history) + if self.throughput_history else 0 + ) + self.metrics.cpu_utilization = snapshot.cpu_percent + self.metrics.memory_utilization = ( + (snapshot.memory_mb / self.quota.memory_quota_mb * 100) + if self.quota and self.quota.memory_quota_mb > 0 + else 0 + ) + self.metrics.gpu_utilization = snapshot.gpu_utilization or 0 + self.metrics.queue_depth = len(self.queue) + self.metrics.oom_count = self.oom_events + self.metrics.current_parallelism = 1 # Single-actor for now + + # Update internal metrics tracking fields for collect_metrics() + elapsed_time = time.time() - start_time + processed_count = actual_batch_size + self._recent_throughput = processed_count / elapsed_time if elapsed_time > 0 else 0 + self._recent_latency_ms = elapsed_time * 1000 / processed_count if processed_count > 0 else 0 + self._current_cpu_util = snapshot.cpu_percent + self._current_memory_util = ( + (snapshot.memory_mb / self.quota.memory_quota_mb * 100) + if self.quota and self.quota.memory_quota_mb > 0 + else 0 + ) + self._current_gpu_util = snapshot.gpu_utilization or 0 + + # Check if should report to Tower + current_time = time.time() + if current_time - self.last_report_time >= self.config.report_interval_sec: + self._report_to_tower() + self.last_report_time = current_time + + # Check quota compliance + if current_time - self.last_quota_check_time >= self.config.quota_check_interval_sec: + self._check_quota_compliance() + self.last_quota_check_time = current_time + + return results + + def _handle_oom(self, batch_size: int, snapshot: Optional[ResourceSnapshot]): + """ + Handle OOM event with emergency backoff + + Args: + batch_size: Batch size that caused OOM + snapshot: Resource snapshot at OOM time + """ + self.oom_events += 1 + self._total_oom_count += 1 # Increment cumulative OOM count for metrics + self.last_oom_time = time.time() + + # Emergency backoff + if self.micro_scheduler: + new_batch_size = max(1, batch_size // 2) + self.micro_scheduler.controller.current_batch_size = new_batch_size + self.micro_scheduler.controller.max_batch_size = batch_size + + # Update metrics + self.metrics.oom_count = self.oom_events + + def _report_to_tower(self): + """Report current metrics to Tower""" + if self.tower_callback: + # Update timestamp + self.metrics.last_update = time.time() + + # Send to Tower + self.tower_callback(self.metrics) + + def _check_quota_compliance(self): + """ + Check if current resource usage is within Tower's quota + + If exceeding quota, apply throttling + """ + if not self.quota: + return + + # Check memory quota + current_memory_mb = psutil.virtual_memory().used / (1024 * 1024) + if current_memory_mb > self.quota.memory_quota_mb: + # Exceeding memory quota, reduce batch size + if self.micro_scheduler: + reduction_ratio = self.quota.memory_quota_mb / current_memory_mb + new_batch_size = max( + 1, + int(self.micro_scheduler.controller.current_batch_size * reduction_ratio) + ) + self.micro_scheduler.controller.current_batch_size = new_batch_size + + def _get_available_memory_mb(self) -> float: + """Get available system memory in MB""" + return psutil.virtual_memory().available / (1024 * 1024) + + def get_stats(self) -> dict: + """Get Captain statistics""" + return { + 'stage_name': self.config.stage_name, + 'samples_processed': self.samples_processed, + 'queue_depth': len(self.queue), + 'current_batch_size': ( + self.micro_scheduler.controller.current_batch_size + if self.micro_scheduler else self.config.initial_batch_size + ), + 'avg_latency_ms': self.metrics.avg_latency_ms, + 'avg_throughput': self.metrics.throughput, + 'oom_events': self.oom_events, + 'quota': { + 'target_parallelism': self.quota.target_parallelism if self.quota else 1, + 'memory_quota_mb': self.quota.memory_quota_mb if self.quota else 0, + 'cpu_quota': self.quota.cpu_quota if self.quota else 0, + } if self.quota else None + } + + def collect_metrics(self) -> 'StageMetrics': + """Collect current metrics snapshot for Tower consumption. + + This is the standardized interface for Tower to pull metrics from Captain. + Returns a StageMetrics object containing the current state of this stage. + + Returns: + StageMetrics: A snapshot containing: + - stage_name: Name of this operator stage + - queue_depth: Number of pending samples in queue + - current_parallelism: Current number of actors (1 for single-actor) + - throughput: Recent throughput in samples/sec + - avg_latency_ms: Recent average processing latency + - cpu_utilization: Current CPU utilization percentage + - memory_utilization: Current memory utilization percentage + - gpu_utilization: Current GPU utilization percentage (if applicable) + - oom_count: Cumulative OOM event count + """ + # Use local import to avoid circular dependencies + from .tower import StageMetrics + + return StageMetrics( + stage_name=self.config.stage_name, + queue_depth=len(self.queue), + current_parallelism=self.metrics.current_parallelism, + throughput=self._recent_throughput if self._recent_throughput > 0 else self.metrics.throughput, + avg_latency_ms=self._recent_latency_ms if self._recent_latency_ms > 0 else self.metrics.avg_latency_ms, + cpu_utilization=self._current_cpu_util if self._current_cpu_util > 0 else self.metrics.cpu_utilization, + memory_utilization=self._current_memory_util if self._current_memory_util > 0 else self.metrics.memory_utilization, + gpu_utilization=self._current_gpu_util if self._current_gpu_util > 0 else self.metrics.gpu_utilization, + oom_count=self._total_oom_count, + last_update=time.time() + ) + + +class CaptainPool: + """ + Manages multiple Captains in a pipeline + + Coordinates execution across multiple stages, ensuring data flows + correctly and all Captains report to Tower. + """ + + def __init__(self, tower_callback: Optional[Callable[[StageMetrics], None]] = None): + """ + Initialize Captain pool + + Args: + tower_callback: Shared callback to Tower for all Captains + """ + self.tower_callback = tower_callback + self.captains: dict[str, Captain] = {} + + def add_captain(self, config: CaptainConfig) -> Captain: + """ + Add a new Captain to the pool + + Args: + config: Configuration for the Captain + + Returns: + The created Captain instance + """ + captain = Captain(config, self.tower_callback) + self.captains[config.stage_name] = captain + return captain + + def get_captain(self, stage_name: str) -> Optional[Captain]: + """Get Captain by stage name""" + return self.captains.get(stage_name) + + def set_quotas(self, quotas: dict[str, ResourceQuota]): + """ + Distribute quotas from Tower to all Captains + + Args: + quotas: Dict mapping captain_id to ResourceQuota + """ + for captain_id, quota in quotas.items(): + # Extract stage name from captain_id + stage_name = quota.captain_id.replace('captain_', '').rsplit('_', 1)[0] + + if stage_name in self.captains: + self.captains[stage_name].set_quota(quota) + + def get_all_stats(self) -> dict[str, dict]: + """Get statistics from all Captains""" + return { + name: captain.get_stats() + for name, captain in self.captains.items() + } + + def collect_all_metrics(self) -> Dict[str, 'StageMetrics']: + """Collect metrics from all managed Captains. + + This is the standardized interface for Tower to pull metrics from all + Captains in the pool at once. + + Returns: + Dict[str, StageMetrics]: A dictionary mapping captain stage names + to their current StageMetrics snapshots. + """ + return { + stage_name: captain.collect_metrics() + for stage_name, captain in self.captains.items() + } diff --git a/data_juicer/core/elasticjuicer/scheduler/micro_scheduler.py b/data_juicer/core/elasticjuicer/scheduler/micro_scheduler.py new file mode 100644 index 00000000000..96c53fd7f95 --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/micro_scheduler.py @@ -0,0 +1,550 @@ +""" +Micro-Scheduler with JABAS-style PID Control + +Implements dynamic batch size adjustment based on memory feedback. +Prevents OOM by continuously monitoring memory and adjusting batch sizes. + +Based on: +- JABAS (EuroSys 2025): Adaptive batching for heterogeneous GPUs +- Report Section 4.1: PID Control for Batch Size + +Key Features: +- PID controller for smooth batch size adjustment +- Memory pressure monitoring +- Safety thresholds and fallback strategies +- Integration with MemoryPredictor +""" + +import time +import psutil +from typing import Optional, Dict, Any, Callable +from dataclasses import dataclass +from collections import deque +import numpy as np + +try: + import GPUtil + GPU_AVAILABLE = True +except ImportError: + GPU_AVAILABLE = False + + +@dataclass +class SampleFeatures: + """Sample features for MicroScheduler prediction.""" + batch_size: int = 1 + estimated_memory_mb: float = 0.0 + + +@dataclass +class MemoryState: + """Current memory state""" + timestamp: float + # CPU memory + total_memory_mb: float + used_memory_mb: float + available_memory_mb: float + memory_percent: float + # GPU memory (if available) + gpu_total_mb: Optional[float] = None + gpu_used_mb: Optional[float] = None + gpu_available_mb: Optional[float] = None + gpu_percent: Optional[float] = None + + def get_available_memory(self, use_gpu: bool = False) -> float: + """Get available memory in MB""" + if use_gpu and self.gpu_available_mb is not None: + return self.gpu_available_mb + return self.available_memory_mb + + +class PIDController: + """ + PID (Proportional-Integral-Derivative) Controller. + + Classic control theory algorithm used in JABAS for smooth adjustments. + + Formula: + output(t) = Kp * error(t) + Ki * Σerror + Kd * Δerror + + Where: + - error = setpoint - current_value + - Kp, Ki, Kd are tuning parameters + """ + + def __init__( + self, + kp: float = 1.0, + ki: float = 0.1, + kd: float = 0.05, + setpoint: float = 1000.0, + output_limits: tuple = (1, 1000), + ): + """ + Initialize PID controller. + + Args: + kp: Proportional gain + ki: Integral gain + kd: Derivative gain + setpoint: Target value (e.g., target available memory in MB) + output_limits: (min, max) bounds for output + """ + self.kp = kp + self.ki = ki + self.kd = kd + self.setpoint = setpoint + self.output_limits = output_limits + + # State + self.last_error = 0.0 + self.integral = 0.0 + self.last_time = None + + def update(self, current_value: float, dt: Optional[float] = None) -> float: + """ + Update PID controller with new measurement. + + Args: + current_value: Current measured value + dt: Time delta since last update (optional) + + Returns: + Control output + """ + # Calculate error + error = self.setpoint - current_value + + # Calculate time delta + current_time = time.time() + if self.last_time is None or dt is not None: + dt = dt or 0.1 # Default dt + else: + dt = current_time - self.last_time + self.last_time = current_time + + # Proportional term + p_term = self.kp * error + + # Integral term (with anti-windup) + self.integral += error * dt + # Clamp integral to prevent windup + max_integral = self.output_limits[1] / (self.ki + 1e-6) + self.integral = np.clip(self.integral, -max_integral, max_integral) + i_term = self.ki * self.integral + + # Derivative term + derivative = (error - self.last_error) / (dt + 1e-6) + d_term = self.kd * derivative + + # Calculate output + output = p_term + i_term + d_term + + # Apply output limits + output = np.clip(output, self.output_limits[0], self.output_limits[1]) + + # Update state + self.last_error = error + + return output + + def reset(self): + """Reset controller state""" + self.last_error = 0.0 + self.integral = 0.0 + self.last_time = None + + def set_setpoint(self, setpoint: float): + """Update setpoint""" + self.setpoint = setpoint + + +class BatchSizeController: + """ + Controls batch size using PID feedback based on memory pressure. + + This is the core of the micro-scheduler - it continuously monitors + memory and adjusts batch size to maximize throughput while preventing OOM. + + Strategy (from Report Section 4.1): + B_next = B_curr × (M_target / M_curr) + + Enhanced with PID for smoothness. + """ + + def __init__( + self, + initial_batch_size: int = 1, + min_batch_size: int = 1, + max_batch_size: int = 1000, + target_memory_utilization: float = 0.85, + safety_buffer_mb: float = 1000.0, + use_gpu: bool = False, + enable_prediction: bool = True, + memory_predictor = None, + ): + """ + Initialize batch size controller. + + Args: + initial_batch_size: Starting batch size + min_batch_size: Minimum allowed batch size + max_batch_size: Maximum allowed batch size + target_memory_utilization: Target memory usage (0.0-1.0) + safety_buffer_mb: Safety buffer to keep free (MB) + use_gpu: Monitor GPU memory instead of CPU + enable_prediction: Use MemoryPredictor for proactive adjustment + memory_predictor: MemoryPredictor instance + """ + self.current_batch_size = initial_batch_size + self.min_batch_size = min_batch_size + self.max_batch_size = max_batch_size + self.target_utilization = target_memory_utilization + self.safety_buffer_mb = safety_buffer_mb + self.use_gpu = use_gpu + self.enable_prediction = enable_prediction + self.memory_predictor = memory_predictor + + # PID controller for smooth adjustments + # Setpoint will be dynamically updated based on available memory + self.pid = PIDController( + kp=0.5, # Moderate proportional gain + ki=0.05, # Small integral gain + kd=0.1, # Small derivative gain + setpoint=safety_buffer_mb, + output_limits=(min_batch_size, max_batch_size), + ) + + # History + self.batch_size_history = deque(maxlen=100) + self.memory_history = deque(maxlen=100) + self.oom_events = [] + + # Statistics + self.total_adjustments = 0 + self.increase_count = 0 + self.decrease_count = 0 + + def get_memory_state(self) -> MemoryState: + """Get current memory state""" + # CPU memory + mem = psutil.virtual_memory() + total_mb = mem.total / (1024 * 1024) + used_mb = mem.used / (1024 * 1024) + available_mb = mem.available / (1024 * 1024) + percent = mem.percent + + # GPU memory + gpu_total = None + gpu_used = None + gpu_available = None + gpu_percent = None + + if self.use_gpu and GPU_AVAILABLE: + try: + gpus = GPUtil.getGPUs() + if gpus: + gpu = gpus[0] # Use first GPU + gpu_total = gpu.memoryTotal + gpu_used = gpu.memoryUsed + gpu_available = gpu.memoryFree + gpu_percent = (gpu_used / gpu_total * 100) if gpu_total > 0 else 0 + except Exception: + pass + + return MemoryState( + timestamp=time.time(), + total_memory_mb=total_mb, + used_memory_mb=used_mb, + available_memory_mb=available_mb, + memory_percent=percent, + gpu_total_mb=gpu_total, + gpu_used_mb=gpu_used, + gpu_available_mb=gpu_available, + gpu_percent=gpu_percent, + ) + + def calculate_next_batch_size( + self, + memory_state: MemoryState, + predicted_memory_per_sample: Optional[float] = None, + ) -> int: + """ + Calculate next batch size using PID control and predictions. + + Args: + memory_state: Current memory state + predicted_memory_per_sample: Predicted memory per sample (optional) + + Returns: + Recommended batch size + """ + available_mb = memory_state.get_available_memory(self.use_gpu) + + # Method 1: Direct calculation based on available memory + # More memory available -> larger batch size + if available_mb > self.safety_buffer_mb: + usable_memory = available_mb - self.safety_buffer_mb + # Scale batch size proportionally to usable memory + # Normalize to a reasonable range + memory_based_batch = int((usable_memory / 1000.0) * self.max_batch_size) + memory_based_batch = np.clip(memory_based_batch, self.min_batch_size, self.max_batch_size) + else: + # Below safety buffer, use minimum + memory_based_batch = self.min_batch_size + + # Method 2: Ratio-based adjustment (from JABAS paper) + # B_next = B_curr × (M_target / M_curr) + total_mb = memory_state.total_memory_mb if not self.use_gpu else (memory_state.gpu_total_mb or 1000) + target_used = total_mb * self.target_utilization + current_used = total_mb - available_mb + + if current_used > 0: + ratio = target_used / current_used + # Clamp ratio to prevent extreme changes + ratio = np.clip(ratio, 0.5, 2.0) + ratio_batch_size = int(self.current_batch_size * ratio) + else: + ratio_batch_size = self.current_batch_size + + # Method 3: Prediction-based adjustment (if predictor available) + prediction_batch_size = None + if self.enable_prediction and predicted_memory_per_sample: + # Calculate how many samples can fit + usable_memory = available_mb - self.safety_buffer_mb + if usable_memory > 0 and predicted_memory_per_sample > 0: + prediction_batch_size = int(usable_memory / predicted_memory_per_sample) + + # Combine methods (weighted average) + candidates = [] + weights = [] + + candidates.append(memory_based_batch) + weights.append(0.4) # 40% weight on memory-based + + candidates.append(ratio_batch_size) + weights.append(0.3) # 30% weight on ratio + + if prediction_batch_size is not None: + candidates.append(prediction_batch_size) + weights.append(0.3) # 30% weight on prediction + + # Weighted average + next_batch_size = int(np.average(candidates, weights=weights)) + + # Apply bounds + next_batch_size = np.clip(next_batch_size, self.min_batch_size, self.max_batch_size) + + # Smooth changes (avoid drastic jumps) + max_change = max(1, int(self.current_batch_size * 0.5)) # Max 50% change per step + if abs(next_batch_size - self.current_batch_size) > max_change: + if next_batch_size > self.current_batch_size: + next_batch_size = self.current_batch_size + max_change + else: + next_batch_size = self.current_batch_size - max_change + + return int(next_batch_size) + + def update_batch_size( + self, + actual_memory_used: Optional[float] = None, + predicted_memory_per_sample: Optional[float] = None, + ) -> int: + """ + Update batch size based on current memory state. + + Args: + actual_memory_used: Actual memory used by last batch (for feedback) + predicted_memory_per_sample: Predicted memory per sample + + Returns: + New batch size + """ + # Get current memory state + memory_state = self.get_memory_state() + + # Calculate next batch size + next_batch_size = self.calculate_next_batch_size( + memory_state, + predicted_memory_per_sample, + ) + + # Update statistics + if next_batch_size > self.current_batch_size: + self.increase_count += 1 + elif next_batch_size < self.current_batch_size: + self.decrease_count += 1 + + if next_batch_size != self.current_batch_size: + self.total_adjustments += 1 + + # Update current batch size + old_batch_size = self.current_batch_size + self.current_batch_size = next_batch_size + + # Record history + self.batch_size_history.append({ + 'timestamp': time.time(), + 'old_batch': old_batch_size, + 'new_batch': next_batch_size, + 'available_mb': memory_state.get_available_memory(self.use_gpu), + 'memory_percent': memory_state.memory_percent, + }) + + self.memory_history.append(memory_state) + + return next_batch_size + + def report_oom(self, batch_size: int, memory_mb: float): + """Report an OOM event to adjust strategy""" + self.oom_events.append({ + 'timestamp': time.time(), + 'batch_size': batch_size, + 'memory_mb': memory_mb, + }) + + # Emergency reduction + self.current_batch_size = max(1, batch_size // 2) + self.max_batch_size = batch_size # Don't go higher than OOM point + + # Reset PID to avoid windup + self.pid.reset() + + def get_stats(self) -> Dict[str, Any]: + """Get controller statistics""" + return { + 'current_batch_size': self.current_batch_size, + 'min_batch_size': self.min_batch_size, + 'max_batch_size': self.max_batch_size, + 'total_adjustments': self.total_adjustments, + 'increase_count': self.increase_count, + 'decrease_count': self.decrease_count, + 'oom_events': len(self.oom_events), + 'avg_batch_size': np.mean([h['new_batch'] for h in self.batch_size_history]) if self.batch_size_history else 0, + } + + +class MicroScheduler: + """ + Micro-Scheduler with JABAS-style adaptive batching. + + Orchestrates: + - Memory monitoring + - Batch size control via PID + - Memory prediction integration + - OOM prevention + + Usage: + scheduler = MicroScheduler(memory_predictor=predictor) + + for batch in data_loader: + # Get recommended batch size + batch_size = scheduler.get_batch_size() + + # Process batch + result = process(batch[:batch_size]) + + # Update scheduler with feedback + scheduler.update(actual_memory_used=memory_mb) + """ + + def __init__( + self, + memory_predictor=None, + initial_batch_size: int = 32, + min_batch_size: int = 1, + max_batch_size: int = 1000, + target_memory_utilization: float = 0.85, + safety_buffer_mb: float = 1000.0, + use_gpu: bool = False, + enable_auto_adjust: bool = True, + ): + """ + Initialize micro-scheduler. + + Args: + memory_predictor: MemoryPredictor instance + initial_batch_size: Starting batch size + min_batch_size: Minimum batch size + max_batch_size: Maximum batch size + target_memory_utilization: Target memory usage (0.0-1.0) + safety_buffer_mb: Safety buffer in MB + use_gpu: Monitor GPU memory + enable_auto_adjust: Enable automatic batch size adjustment + """ + self.memory_predictor = memory_predictor + self.enable_auto_adjust = enable_auto_adjust + + # Batch size controller + self.controller = BatchSizeController( + initial_batch_size=initial_batch_size, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + target_memory_utilization=target_memory_utilization, + safety_buffer_mb=safety_buffer_mb, + use_gpu=use_gpu, + enable_prediction=memory_predictor is not None, + memory_predictor=memory_predictor, + ) + + # State + self.iteration = 0 + self.last_prediction = None + + def get_batch_size(self, sample_features=None) -> int: + """ + Get recommended batch size for next iteration. + + Args: + sample_features: Optional sample features for prediction + + Returns: + Recommended batch size + """ + if not self.enable_auto_adjust: + return self.controller.current_batch_size + + # Get memory prediction if available + predicted_per_sample = None + if self.memory_predictor and sample_features: + prediction = self.memory_predictor.predict(sample_features) + if prediction: + self.last_prediction = prediction + # Estimate per-sample memory + predicted_per_sample = prediction.predicted_memory_mb / sample_features.batch_size + + # Update batch size + new_batch_size = self.controller.update_batch_size( + predicted_memory_per_sample=predicted_per_sample, + ) + + self.iteration += 1 + return new_batch_size + + def update(self, actual_memory_used: float, sample_features=None): + """ + Update scheduler with feedback from actual execution. + + Args: + actual_memory_used: Actual memory used in MB + sample_features: Sample features (for predictor update) + """ + # Update memory predictor if available + if self.memory_predictor and sample_features: + self.memory_predictor.observe(sample_features, actual_memory_used) + + def report_oom(self, batch_size: int, memory_mb: float): + """Report OOM event""" + self.controller.report_oom(batch_size, memory_mb) + + def get_stats(self) -> Dict[str, Any]: + """Get scheduler statistics""" + stats = self.controller.get_stats() + stats['iteration'] = self.iteration + if self.last_prediction: + stats['last_prediction'] = { + 'predicted_mb': self.last_prediction.predicted_memory_mb, + 'confidence_lower': self.last_prediction.confidence_lower, + 'confidence_upper': self.last_prediction.confidence_upper, + } + return stats diff --git a/data_juicer/core/elasticjuicer/scheduler/scheduler_config.py b/data_juicer/core/elasticjuicer/scheduler/scheduler_config.py new file mode 100644 index 00000000000..9efd0c4de97 --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/scheduler_config.py @@ -0,0 +1,147 @@ +""" +Scheduler Configuration + +Centralized configuration for micro and macro schedulers. +""" + +from dataclasses import dataclass, field, asdict +from typing import Dict, Optional + +try: + import yaml + YAML_AVAILABLE = True +except ImportError: + YAML_AVAILABLE = False + + +@dataclass +class SchedulerConfig: + """Configuration for ElasticJuicer schedulers""" + + # Batch size control + initial_batch_size: int = 32 + min_batch_size: int = 1 + max_batch_size: int = 1000 + + # Memory management + target_memory_utilization: float = 0.85 # 85% utilization target + safety_buffer_mb: float = 1000.0 # 1GB safety buffer + use_gpu_memory: bool = False + + # PID tuning + pid_kp: float = 0.5 # Proportional gain + pid_ki: float = 0.05 # Integral gain + pid_kd: float = 0.1 # Derivative gain + + # Auto-adjustment + enable_auto_adjust: bool = True + enable_prediction: bool = True + + # Predictor settings + predictor_window_size: int = 100 + predictor_min_samples: int = 5 + predictor_confidence_level: float = 0.95 + + # Safety settings + max_batch_change_ratio: float = 0.5 # Max 50% change per adjustment + oom_backoff_ratio: float = 0.5 # Reduce to 50% on OOM + + # Tower macro-scheduler settings (PBT output) + rebalance_interval_sec: float = 5.0 # Tower macro-scheduler rebalance loop interval in seconds + tower_allocation_weights: Optional[Dict[str, float]] = field(default=None) # Per-stage resource allocation weights from PBT tuning + backpressure_threshold: float = 0.9 # Memory utilization threshold above which backpressure is applied + backpressure_slowdown_ratio: float = 0.5 # Factor to reduce throughput when backpressure is active + + @classmethod + def conservative(cls) -> 'SchedulerConfig': + """Conservative configuration (prioritizes safety)""" + return cls( + target_memory_utilization=0.70, + safety_buffer_mb=2000.0, + max_batch_change_ratio=0.25, + rebalance_interval_sec=10.0, + backpressure_threshold=0.8, + ) + + @classmethod + def aggressive(cls) -> 'SchedulerConfig': + """Aggressive configuration (prioritizes throughput)""" + return cls( + target_memory_utilization=0.95, + safety_buffer_mb=500.0, + max_batch_change_ratio=0.75, + rebalance_interval_sec=2.0, + backpressure_threshold=0.95, + ) + + @classmethod + def gpu(cls) -> 'SchedulerConfig': + """GPU-optimized configuration""" + return cls( + use_gpu_memory=True, + target_memory_utilization=0.90, + safety_buffer_mb=1024.0, # 1GB buffer for GPU + ) + + @classmethod + def from_yaml(cls, path: str) -> 'SchedulerConfig': + """Load config from a YAML file (the output of PBT tuning). + + Args: + path: Path to the YAML configuration file. + + Returns: + SchedulerConfig instance with values from YAML, using defaults for missing fields. + + Raises: + ImportError: If PyYAML is not installed. + FileNotFoundError: If the YAML file does not exist. + """ + if not YAML_AVAILABLE: + raise ImportError( + "PyYAML is required for YAML support. " + "Install it with: pip install pyyaml" + ) + + with open(path, 'r') as f: + data = yaml.safe_load(f) or {} + + # Filter to only include valid fields for SchedulerConfig + valid_fields = {f.name for f in cls.__dataclass_fields__.values()} + filtered_data = {k: v for k, v in data.items() if k in valid_fields} + + return cls(**filtered_data) + + def to_yaml(self, path: str) -> None: + """Export config to YAML file. + + Args: + path: Path to write the YAML configuration file. + + Raises: + ImportError: If PyYAML is not installed. + """ + if not YAML_AVAILABLE: + raise ImportError( + "PyYAML is required for YAML support. " + "Install it with: pip install pyyaml" + ) + + data = asdict(self) + + with open(path, 'w') as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + + def get_stage_weight(self, stage_name: str) -> float: + """Return the allocation weight for a given stage. + + Args: + stage_name: Name of the stage to get weight for. + + Returns: + The allocation weight for the stage. Returns 1.0 if tower_allocation_weights + is None or the stage is not found (equal weight). + """ + if self.tower_allocation_weights is None: + return 1.0 + return self.tower_allocation_weights.get(stage_name, 1.0) diff --git a/data_juicer/core/elasticjuicer/scheduler/tower.py b/data_juicer/core/elasticjuicer/scheduler/tower.py new file mode 100644 index 00000000000..3f255dc77b7 --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/tower.py @@ -0,0 +1,879 @@ +""" +Tower: Global Macro-Scheduler for ElasticJuicer + +Based on Autothrottle's bi-level architecture, Tower is the global resource allocator +that sets performance targets and resource quotas for local Captains. + +Key responsibilities: +1. Monitor global queue depth and cluster resource utilization +2. Set target parallelism for each operator stage +3. Allocate resource budgets to Captains +4. Make topology decisions (co-location vs distributed) +5. Handle global SLA guarantees + +References: +- Autothrottle (NSDI 2024): Bi-level control for SLO-targeted microservices +- Report Section 5.1: Tower/Captain architecture +""" + +import logging +import threading +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from enum import Enum +import numpy as np +from collections import deque + +if TYPE_CHECKING: + from .scheduler_config import SchedulerConfig + + +class TopologyMode(Enum): + """Topology execution mode based on transfer cost and resource availability""" + CO_LOCATION = "co_location" # Operators on same node (high transfer cost) + DISTRIBUTED = "distributed" # Operators on different nodes (high parallelism) + ADAPTIVE = "adaptive" # Let Tower decide based on current state + + +@dataclass +class StageMetrics: + """Performance metrics for an operator stage""" + stage_name: str + queue_depth: int = 0 # Number of pending samples + current_parallelism: int = 1 # Current number of actors + throughput: float = 0.0 # Samples/sec + avg_latency_ms: float = 0.0 # Average processing latency + cpu_utilization: float = 0.0 # % CPU used + memory_utilization: float = 0.0 # % Memory used + gpu_utilization: float = 0.0 # % GPU used (if applicable) + oom_count: int = 0 # Number of OOM events + last_update: float = field(default_factory=time.time) + + +@dataclass +class ResourceQuota: + """Resource allocation quota for a Captain""" + captain_id: str + target_parallelism: int # Target number of actors + cpu_quota: float # CPU cores allocated + memory_quota_mb: float # Memory budget in MB + gpu_quota: float = 0.0 # GPU cores allocated (0-1) + target_throughput: float = 0.0 # Target samples/sec (SLO) + topology_mode: TopologyMode = TopologyMode.ADAPTIVE + backpressure: bool = False # Whether upstream backpressure is active + + +@dataclass +class ClusterState: + """Global cluster resource state""" + total_cpu_cores: int + total_memory_mb: float + total_gpu_count: int + available_cpu_cores: float + available_memory_mb: float + available_gpus: float + timestamp: float = field(default_factory=time.time) + + +class Tower: + """ + Global Macro-Scheduler (Tower from Autothrottle architecture) + + Tower doesn't directly control individual actors' behavior. Instead, it: + 1. Monitors global system state (queue depths, resource utilization) + 2. Sets performance targets and resource quotas for Captains + 3. Makes high-level topology decisions + 4. Ensures cluster-wide SLA guarantees + + The bi-level design (Tower + Captain) solves the single-point bottleneck + problem of centralized schedulers, enabling high-frequency local decisions + under global constraints. + """ + + def __init__( + self, + cluster_state: ClusterState, + target_queue_depth: int = 100, + sla_latency_ms: float = 5000.0, + update_interval_sec: float = 5.0, + history_window: int = 20, + config: Optional['SchedulerConfig'] = None + ): + """ + Initialize Tower global scheduler + + Args: + cluster_state: Initial cluster resource state + target_queue_depth: Target queue depth to maintain + sla_latency_ms: SLA latency target (max allowed latency) + update_interval_sec: How often to recompute resource allocation + history_window: Window size for tracking metrics history + config: Optional SchedulerConfig for rebalance settings + """ + self.cluster = cluster_state + self.target_queue_depth = target_queue_depth + self.sla_latency_ms = sla_latency_ms + self.update_interval = update_interval_sec + self.config = config + + # Track all stages and their metrics + self.stages: Dict[str, StageMetrics] = {} + + # Track resource quotas allocated to each Captain + self.quotas: Dict[str, ResourceQuota] = {} + + # Metrics history for trend analysis + self.metrics_history: Dict[str, deque] = {} + self.history_window = history_window + + # Last allocation time + self.last_allocation_time = time.time() + + # SLA violation tracking + self.sla_violations = 0 + self.total_requests = 0 + + # Captain registry for direct metric collection and quota broadcast + self._captains: Dict[str, Any] = {} + + # Rebalance loop control + self._running: bool = False + self._rebalance_thread: Optional[threading.Thread] = None + + # Rebalance interval from config or fallback to update_interval + self.rebalance_interval: float = ( + config.rebalance_interval_sec if config else update_interval_sec + ) + + # Backpressure threshold from config + self._backpressure_threshold: float = ( + config.backpressure_threshold if config else 0.9 + ) + + # Per-stage backpressure state tracking + self._backpressure_states: Dict[str, bool] = {} + + # Track registration order for upstream detection + self._stage_order: List[str] = [] + + def register_stage(self, stage_name: str, initial_parallelism: int = 1) -> str: + """ + Register a new operator stage with Tower + + Args: + stage_name: Name of the operator stage + initial_parallelism: Initial number of actors + + Returns: + captain_id: Unique ID for the Captain managing this stage + """ + captain_id = f"captain_{stage_name}_{int(time.time())}" + + # Initialize stage metrics + self.stages[stage_name] = StageMetrics( + stage_name=stage_name, + current_parallelism=initial_parallelism + ) + + # Initialize metrics history + self.metrics_history[stage_name] = deque(maxlen=self.history_window) + + # Track stage registration order for upstream detection + if stage_name not in self._stage_order: + self._stage_order.append(stage_name) + + # Initialize backpressure state + self._backpressure_states[stage_name] = False + + # Allocate initial quota + initial_quota = self._compute_initial_quota(stage_name, initial_parallelism) + self.quotas[captain_id] = initial_quota + + return captain_id + + def update_stage_metrics(self, stage_name: str, metrics: StageMetrics): + """ + Update metrics for a stage (called by Captain) + + Args: + stage_name: Name of the stage + metrics: Latest metrics from Captain + """ + if stage_name not in self.stages: + raise ValueError(f"Stage {stage_name} not registered") + + # Update current metrics + self.stages[stage_name] = metrics + + # Add to history + self.metrics_history[stage_name].append({ + 'timestamp': metrics.last_update, + 'queue_depth': metrics.queue_depth, + 'throughput': metrics.throughput, + 'latency_ms': metrics.avg_latency_ms, + 'cpu_util': metrics.cpu_utilization, + 'memory_util': metrics.memory_utilization + }) + + # Track SLA violations + self.total_requests += 1 + if metrics.avg_latency_ms > self.sla_latency_ms: + self.sla_violations += 1 + + def allocate_resources(self) -> Dict[str, ResourceQuota]: + """ + Compute and allocate resource quotas to all Captains + + This is the core global decision-making function. It: + 1. Analyzes global bottlenecks (queue depths, latencies) + 2. Computes target parallelism for each stage + 3. Allocates CPU/GPU/memory budgets + 4. Returns updated quotas for Captains to enforce + + Returns: + Updated resource quotas for all Captains + """ + current_time = time.time() + + # Rate limit allocation updates (avoid thrashing) + if current_time - self.last_allocation_time < self.update_interval: + return self.quotas + + self.last_allocation_time = current_time + + # Identify bottleneck stages + bottlenecks = self._identify_bottlenecks() + + # Compute resource allocation strategy + for captain_id, quota in self.quotas.items(): + stage_name = self._get_stage_from_captain(captain_id) + if stage_name not in self.stages: + continue + + metrics = self.stages[stage_name] + + # Decide target parallelism based on queue depth and throughput + target_parallelism = self._compute_target_parallelism( + metrics, + is_bottleneck=(stage_name in bottlenecks) + ) + + # Allocate resources proportionally + resource_allocation = self._allocate_stage_resources( + stage_name, + target_parallelism + ) + + # Update quota + quota.target_parallelism = target_parallelism + quota.cpu_quota = resource_allocation['cpu'] + quota.memory_quota_mb = resource_allocation['memory_mb'] + quota.gpu_quota = resource_allocation['gpu'] + quota.target_throughput = self._compute_target_throughput(metrics) + quota.topology_mode = self._decide_topology(stage_name, metrics) + + return self.quotas + + def _identify_bottlenecks(self) -> List[str]: + """ + Identify bottleneck stages based on queue depth and latency + + A stage is a bottleneck if: + 1. Queue depth > target_queue_depth + 2. Latency approaching SLA limit + 3. Throughput declining over time + + Returns: + List of bottleneck stage names + """ + bottlenecks = [] + + for stage_name, metrics in self.stages.items(): + # Check queue depth + queue_pressure = metrics.queue_depth > self.target_queue_depth + + # Check latency + latency_pressure = metrics.avg_latency_ms > (self.sla_latency_ms * 0.8) + + # Check throughput trend + throughput_declining = False + if stage_name in self.metrics_history and len(self.metrics_history[stage_name]) >= 3: + recent = list(self.metrics_history[stage_name])[-3:] + throughputs = [m['throughput'] for m in recent] + if len(throughputs) >= 2: + throughput_declining = throughputs[-1] < throughputs[0] * 0.9 + + if queue_pressure or latency_pressure or throughput_declining: + bottlenecks.append(stage_name) + + return bottlenecks + + def _compute_target_parallelism( + self, + metrics: StageMetrics, + is_bottleneck: bool + ) -> int: + """ + Compute target parallelism for a stage + + Strategy: + - If bottleneck: Increase parallelism to drain queue + - If underutilized: Decrease parallelism to free resources + - Consider resource availability + + Args: + metrics: Current stage metrics + is_bottleneck: Whether this stage is a bottleneck + + Returns: + Target parallelism (number of actors) + """ + current = metrics.current_parallelism + + if is_bottleneck: + # Estimate needed parallelism to drain queue + if metrics.throughput > 0: + # Time to process queue at current throughput + queue_drain_time = metrics.queue_depth / metrics.throughput + + # If drain time > SLA, scale up + if queue_drain_time > (self.sla_latency_ms / 1000.0): + scale_factor = min(2.0, queue_drain_time / (self.sla_latency_ms / 1000.0)) + target = int(current * scale_factor) + else: + target = current + 1 # Conservative increase + else: + target = current + 1 # No throughput data, try increasing + else: + # Check if we can scale down (free resources) + if metrics.queue_depth < self.target_queue_depth * 0.5 and current > 1: + target = max(1, current - 1) + else: + target = current # Keep current level + + # Clamp to available resources + max_possible = self._estimate_max_parallelism() + target = min(target, max_possible) + + return max(1, target) # At least 1 actor + + def _allocate_stage_resources( + self, + stage_name: str, + target_parallelism: int + ) -> Dict[str, float]: + """ + Allocate CPU/GPU/memory to a stage based on target parallelism + + Args: + stage_name: Name of the stage + target_parallelism: Target number of actors + + Returns: + Resource allocation dict with 'cpu', 'memory_mb', 'gpu' + """ + # Simple proportional allocation (can be enhanced with OCS annotations) + total_stages = len(self.stages) + + if total_stages == 0: + cpu_share = self.cluster.available_cpu_cores + memory_share = self.cluster.available_memory_mb + gpu_share = self.cluster.available_gpus + else: + # Equal share for now (TODO: weight by OCS cost) + cpu_share = self.cluster.available_cpu_cores / total_stages + memory_share = self.cluster.available_memory_mb / total_stages + gpu_share = self.cluster.available_gpus / total_stages + + return { + 'cpu': cpu_share * target_parallelism, + 'memory_mb': memory_share * target_parallelism, + 'gpu': gpu_share * target_parallelism + } + + def _compute_target_throughput(self, metrics: StageMetrics) -> float: + """ + Compute target throughput to meet SLA + + Args: + metrics: Current stage metrics + + Returns: + Target throughput in samples/sec + """ + # To meet SLA, we need throughput >= queue_depth / (SLA_time - current_latency) + sla_time_sec = self.sla_latency_ms / 1000.0 + current_latency_sec = metrics.avg_latency_ms / 1000.0 + + remaining_time = max(0.1, sla_time_sec - current_latency_sec) + + if metrics.queue_depth > 0: + target = metrics.queue_depth / remaining_time + else: + target = metrics.throughput # Maintain current + + return max(1.0, target) + + def _decide_topology( + self, + stage_name: str, + metrics: StageMetrics + ) -> TopologyMode: + """ + Decide topology mode for operator placement + + Based on Report Section 5.4: + - CO_LOCATION: High transfer cost, sufficient local resources + - DISTRIBUTED: Different resource bottlenecks, ample bandwidth + + Args: + stage_name: Name of the stage + metrics: Current metrics + + Returns: + Topology mode decision + """ + # Check resource pressure + high_cpu = metrics.cpu_utilization > 80 + high_memory = metrics.memory_utilization > 80 + high_gpu = metrics.gpu_utilization > 80 + + # If single resource bottleneck, distribute to specialize + bottleneck_count = sum([high_cpu, high_memory, high_gpu]) + + if bottleneck_count >= 2: + # Multiple bottlenecks on same node -> distribute + return TopologyMode.DISTRIBUTED + elif bottleneck_count == 0: + # No pressure -> co-locate for efficiency + return TopologyMode.CO_LOCATION + else: + # Single bottleneck -> adaptive + return TopologyMode.ADAPTIVE + + def _estimate_max_parallelism(self) -> int: + """ + Estimate maximum parallelism given available resources + + Returns: + Maximum number of actors cluster can support + """ + # Conservative estimate: assume each actor needs 1 CPU + 1GB memory + cpu_limit = int(self.cluster.available_cpu_cores) + memory_limit = int(self.cluster.available_memory_mb / 1024) # 1GB per actor + + return max(1, min(cpu_limit, memory_limit)) + + def _get_stage_from_captain(self, captain_id: str) -> str: + """Extract stage name from captain ID""" + # captain_video_decoder_1234567890 -> video_decoder + parts = captain_id.split('_') + if len(parts) >= 3: + return '_'.join(parts[1:-1]) + return captain_id + + def _compute_initial_quota( + self, + stage_name: str, + parallelism: int + ) -> ResourceQuota: + """Compute initial resource quota for a new stage""" + captain_id = f"captain_{stage_name}_{int(time.time())}" + + # Equal share allocation initially + total_stages = max(1, len(self.stages)) + + return ResourceQuota( + captain_id=captain_id, + target_parallelism=parallelism, + cpu_quota=self.cluster.available_cpu_cores / total_stages, + memory_quota_mb=self.cluster.available_memory_mb / total_stages, + gpu_quota=self.cluster.available_gpus / total_stages, + target_throughput=10.0, # Default + topology_mode=TopologyMode.ADAPTIVE + ) + + def get_sla_compliance_rate(self) -> float: + """ + Calculate SLA compliance rate + + Returns: + Percentage of requests meeting SLA (0-100) + """ + if self.total_requests == 0: + return 100.0 + + return ((self.total_requests - self.sla_violations) / self.total_requests) * 100.0 + + def get_global_stats(self) -> Dict: + """Get global system statistics""" + return { + 'total_stages': len(self.stages), + 'total_parallelism': sum(q.target_parallelism for q in self.quotas.values()), + 'sla_compliance_rate': self.get_sla_compliance_rate(), + 'total_requests': self.total_requests, + 'sla_violations': self.sla_violations, + 'cluster_cpu_util': ( + (self.cluster.total_cpu_cores - self.cluster.available_cpu_cores) / + self.cluster.total_cpu_cores * 100 + ) if self.cluster.total_cpu_cores > 0 else 0, + 'cluster_memory_util': ( + (self.cluster.total_memory_mb - self.cluster.available_memory_mb) / + self.cluster.total_memory_mb * 100 + ) if self.cluster.total_memory_mb > 0 else 0 + } + + # ========== Captain Registry Methods ========== + + def register_captain(self, captain_id: str, captain: Any) -> None: + """Register a Captain instance for direct metric collection and quota broadcast. + + Args: + captain_id: Unique identifier for the captain + captain: Captain instance to register + """ + self._captains[captain_id] = captain + + def unregister_captain(self, captain_id: str) -> None: + """Remove a Captain from the registry. + + Args: + captain_id: Unique identifier of the captain to remove + """ + self._captains.pop(captain_id, None) + + # ========== Rebalance Loop Methods ========== + + def collect_all_metrics(self) -> Dict[str, StageMetrics]: + """Step 1: Collect metrics from all registered Captains. + + For each registered captain, call captain.collect_metrics() if available, + or use captain.metrics directly. Also update internal stage_metrics dict. + + Returns: + Dict mapping stage_name to StageMetrics + """ + collected_metrics: Dict[str, StageMetrics] = {} + + for captain_id, captain in self._captains.items(): + try: + # Try collect_metrics() method first, fall back to metrics attribute + if hasattr(captain, 'collect_metrics'): + metrics = captain.collect_metrics() + elif hasattr(captain, 'metrics'): + metrics = captain.metrics + else: + continue + + if metrics and hasattr(metrics, 'stage_name'): + stage_name = metrics.stage_name + collected_metrics[stage_name] = metrics + + # Update internal stage metrics + if stage_name in self.stages: + self.stages[stage_name] = metrics + except Exception as e: + logger = logging.getLogger(__name__) + logger.warning(f"Failed to collect metrics from captain {captain_id}: {e}") + + # Also include stages that have metrics but no registered captain + for stage_name, metrics in self.stages.items(): + if stage_name not in collected_metrics: + collected_metrics[stage_name] = metrics + + return collected_metrics + + def identify_bottleneck(self, metrics: Dict[str, StageMetrics]) -> Optional[str]: + """Step 2: Identify the bottleneck stage (highest queue depth / lowest throughput). + + Find the single worst bottleneck based on queue depth and throughput. + Uses existing _identify_bottlenecks() logic but returns only the worst one. + + Args: + metrics: Dict of stage metrics + + Returns: + stage_name of worst bottleneck, or None if no bottleneck + """ + if not metrics: + return None + + # Get all bottleneck candidates + bottlenecks = self._identify_bottlenecks() + + if not bottlenecks: + return None + + # Find the worst bottleneck based on score + # Score = (queue_depth / target) + (1 - throughput_ratio) + worst_bottleneck = None + worst_score = -1.0 + + for stage_name in bottlenecks: + if stage_name not in metrics: + continue + + stage_metrics = metrics[stage_name] + + # Calculate bottleneck severity score + queue_ratio = ( + stage_metrics.queue_depth / self.target_queue_depth + if self.target_queue_depth > 0 else 0 + ) + + # Consider throughput relative to what's needed + # Higher queue with lower throughput = worse bottleneck + throughput_factor = 1.0 + if stage_metrics.throughput > 0 and stage_metrics.queue_depth > 0: + # Time to drain queue at current throughput + drain_time = stage_metrics.queue_depth / stage_metrics.throughput + sla_time = self.sla_latency_ms / 1000.0 + throughput_factor = drain_time / sla_time if sla_time > 0 else 1.0 + + score = queue_ratio + throughput_factor + + if score > worst_score: + worst_score = score + worst_bottleneck = stage_name + + return worst_bottleneck + + def reallocate_resources( + self, + bottleneck: Optional[str], + metrics: Dict[str, StageMetrics] + ) -> Dict[str, ResourceQuota]: + """Step 3: Reallocate resources, increasing quota for bottleneck. + + If bottleneck exists, shift resources toward it using tower_allocation_weights + from config (via get_stage_weight). + + Args: + bottleneck: Name of bottleneck stage, or None + metrics: Current stage metrics + + Returns: + Dict mapping captain_id to new ResourceQuota + """ + # Compute weights for each stage + stage_weights: Dict[str, float] = {} + for stage_name in self.stages: + # Base weight from config + if self.config: + base_weight = self.config.get_stage_weight(stage_name) + else: + base_weight = 1.0 + + # Boost weight for bottleneck stage + if stage_name == bottleneck: + base_weight *= 1.5 # 50% boost for bottleneck + + stage_weights[stage_name] = base_weight + + # Normalize weights + total_weight = sum(stage_weights.values()) + if total_weight > 0: + for stage_name in stage_weights: + stage_weights[stage_name] /= total_weight + + # Allocate resources based on weights + for captain_id, quota in self.quotas.items(): + stage_name = self._get_stage_from_captain(captain_id) + if stage_name not in self.stages: + continue + + stage_metrics = metrics.get(stage_name, self.stages[stage_name]) + weight = stage_weights.get(stage_name, 1.0 / max(1, len(self.stages))) + + # Compute target parallelism + target_parallelism = self._compute_target_parallelism( + stage_metrics, + is_bottleneck=(stage_name == bottleneck) + ) + + # Allocate resources proportionally to weight + resource_allocation = { + 'cpu': self.cluster.available_cpu_cores * weight * target_parallelism, + 'memory_mb': self.cluster.available_memory_mb * weight * target_parallelism, + 'gpu': self.cluster.available_gpus * weight * target_parallelism + } + + # Update quota + quota.target_parallelism = target_parallelism + quota.cpu_quota = resource_allocation['cpu'] + quota.memory_quota_mb = resource_allocation['memory_mb'] + quota.gpu_quota = resource_allocation['gpu'] + quota.target_throughput = self._compute_target_throughput(stage_metrics) + quota.topology_mode = self._decide_topology(stage_name, stage_metrics) + + return self.quotas + + def apply_backpressure( + self, + bottleneck: Optional[str], + metrics: Dict[str, StageMetrics] + ) -> None: + """Step 3b: Apply backpressure to upstream stages if needed. + + If bottleneck stage memory_utilization > backpressure_threshold: + Find upstream stages (stages registered before bottleneck) + Set backpressure=True flag on their quotas + + Args: + bottleneck: Name of bottleneck stage, or None + metrics: Current stage metrics + """ + # Reset all backpressure states first + for stage_name in self._backpressure_states: + self._backpressure_states[stage_name] = False + + if not bottleneck or bottleneck not in metrics: + # No bottleneck, clear all backpressure + for quota in self.quotas.values(): + quota.backpressure = False + return + + bottleneck_metrics = metrics[bottleneck] + + # Check if memory utilization exceeds threshold + # memory_utilization is in percentage (0-100), threshold is ratio (0-1) + memory_util_ratio = bottleneck_metrics.memory_utilization / 100.0 + + if memory_util_ratio <= self._backpressure_threshold: + # Below threshold, no backpressure needed + for quota in self.quotas.values(): + quota.backpressure = False + return + + # Find bottleneck position in stage order + try: + bottleneck_idx = self._stage_order.index(bottleneck) + except ValueError: + return + + # Apply backpressure to all upstream stages (before bottleneck) + upstream_stages = set(self._stage_order[:bottleneck_idx]) + + for captain_id, quota in self.quotas.items(): + stage_name = self._get_stage_from_captain(captain_id) + if stage_name in upstream_stages: + quota.backpressure = True + self._backpressure_states[stage_name] = True + else: + quota.backpressure = False + + def broadcast_quotas(self, quotas: Dict[str, ResourceQuota]) -> None: + """Step 4: Send updated quotas to all Captains. + + For each captain in registry, call captain.set_quota(quota). + + Args: + quotas: Dict mapping captain_id to ResourceQuota + """ + logger = logging.getLogger(__name__) + + for captain_id, captain in self._captains.items(): + # Find matching quota + quota = quotas.get(captain_id) + + if quota is None: + # Try to find quota by stage name + for qid, q in quotas.items(): + if self._get_stage_from_captain(qid) == self._get_stage_from_captain(captain_id): + quota = q + break + + if quota is None: + continue + + try: + if hasattr(captain, 'set_quota'): + captain.set_quota(quota) + except Exception as e: + logger.warning(f"Failed to broadcast quota to captain {captain_id}: {e}") + + # ========== Rebalance Loop Lifecycle ========== + + def start(self) -> None: + """Start the rebalance loop in a background thread.""" + if self._running: + return + + self._running = True + self._rebalance_thread = threading.Thread( + target=self._rebalance_loop, + daemon=True, + name="Tower-Rebalance-Loop" + ) + self._rebalance_thread.start() + + def stop(self) -> None: + """Stop the rebalance loop.""" + self._running = False + if self._rebalance_thread and self._rebalance_thread.is_alive(): + self._rebalance_thread.join(timeout=self.rebalance_interval * 2) + self._rebalance_thread = None + + def _rebalance_loop(self) -> None: + """Main rebalance loop - runs periodically. + + Implements the adaptive tower macro-scheduler: + for each rebalance_interval: + 1. Collect metrics from all Captains + 2. Identify bottleneck stage (highest queue / lowest throughput) + 3. Reallocate resources (increase quota for bottleneck, apply backpressure) + 4. Broadcast new quotas to Captains + """ + logger = logging.getLogger(__name__) + logger.info(f"Tower rebalance loop started (interval={self.rebalance_interval}s)") + + while self._running: + try: + # Step 1: Collect metrics from all Captains + metrics = self.collect_all_metrics() + + if metrics: + # Step 2: Identify bottleneck stage + bottleneck = self.identify_bottleneck(metrics) + + # Step 3: Reallocate resources + new_quotas = self.reallocate_resources(bottleneck, metrics) + + # Step 3b: Apply backpressure if needed + self.apply_backpressure(bottleneck, metrics) + + # Step 4: Broadcast new quotas to Captains + self.broadcast_quotas(new_quotas) + + if bottleneck: + logger.info(f"Tower rebalance: bottleneck={bottleneck}") + else: + logger.info("Tower rebalance: no bottleneck detected") + + # Log throughput metrics per stage + for stage_name, stage_metrics in metrics.items(): + logger.info( + f" [{stage_name}] throughput={stage_metrics.throughput:.1f} sps, " + f"latency={stage_metrics.avg_latency_ms:.1f}ms, " + f"queue={stage_metrics.queue_depth}" + ) + + except Exception as e: + logger.error(f"Rebalance loop error: {e}") + + # Wait for next interval + time.sleep(self.rebalance_interval) + + logger.info("Tower rebalance loop stopped") + + # ========== Context Manager Support ========== + + def __enter__(self) -> 'Tower': + """Context manager entry - start the rebalance loop.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit - stop the rebalance loop.""" + self.stop() diff --git a/data_juicer/core/elasticjuicer/tuner/__init__.py b/data_juicer/core/elasticjuicer/tuner/__init__.py new file mode 100644 index 00000000000..7f9ee5c56dd --- /dev/null +++ b/data_juicer/core/elasticjuicer/tuner/__init__.py @@ -0,0 +1,10 @@ +""" +Tuner submodule for ElasticJuicer hyperparameter optimization. + +Provides OFFLINE Ray Tune Population Based Training (PBT) for tuning +scheduling parameters. +""" + +from .pbt_tuner import PBTTuner + +__all__ = ["PBTTuner"] diff --git a/data_juicer/core/elasticjuicer/tuner/pbt_tuner.py b/data_juicer/core/elasticjuicer/tuner/pbt_tuner.py new file mode 100644 index 00000000000..8e304934c5a --- /dev/null +++ b/data_juicer/core/elasticjuicer/tuner/pbt_tuner.py @@ -0,0 +1,407 @@ +""" +OFFLINE Phase: Ray Tune Population Based Training (PBT) +for hyperparameter optimization of ElasticJuicer scheduling parameters. + +Tunes: + - PID controller params (kp, ki, kd) + - Safety buffers (safety_buffer_mb, target_memory_utilization) + - Predictor params (predictor_window_size, predictor_confidence_level) + - Tower allocation weights (per-stage resource proportions) +Output: base_config.yaml (a SchedulerConfig serialized to YAML) + +Usage: + from data_juicer.core.elasticjuicer.tuner import PBTTuner + + config = PBTTunerConfig( + stage_names=["filter", "mapper", "deduplicator"], + num_samples=8, + max_iterations=50, + ) + tuner = PBTTuner(config) + best_config = tuner.tune() + tuner.export_config(best_config, "base_config.yaml") +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Callable, Any +import random +import numpy as np + +# Graceful handling of optional Ray dependency +try: + import ray + from ray import tune + from ray.tune.schedulers import PopulationBasedTraining + RAY_AVAILABLE = True +except ImportError: + RAY_AVAILABLE = False + ray = None + tune = None + PopulationBasedTraining = None + +from ..scheduler.scheduler_config import SchedulerConfig +from ..scheduler.micro_scheduler import MicroScheduler, BatchSizeController + + +@dataclass +class PBTTunerConfig: + """Configuration for PBT-based hyperparameter tuning. + + Attributes: + num_samples: Number of PBT population members (parallel trials). + max_iterations: Maximum training iterations per trial. + perturbation_interval: How often PBT perturbs hyperparameters. + metric: Metric to optimize (e.g., "throughput", "score"). + mode: Optimization mode - "max" to maximize, "min" to minimize. + stage_names: Operator stage names to tune allocation weights for. + resources_per_trial: Resources allocated per trial (cpu, gpu). + grace_period: Minimum iterations before stopping poor trials. + """ + num_samples: int = 8 + max_iterations: int = 50 + perturbation_interval: int = 5 + metric: str = "throughput" + mode: str = "max" + stage_names: List[str] = field(default_factory=list) + resources_per_trial: Dict[str, float] = field( + default_factory=lambda: {"cpu": 2, "gpu": 0} + ) + grace_period: int = 5 + + +class PBTTuner: + """ + Population Based Training (PBT) tuner for ElasticJuicer scheduling parameters. + + This class implements OFFLINE hyperparameter optimization using Ray Tune's PBT + scheduler. It tunes PID controller parameters, memory safety buffers, predictor + settings, and per-stage resource allocation weights. + + The tuning process simulates batch processing with the given configuration and + measures throughput and OOM rates to find optimal parameters. + + Attributes: + config: PBTTunerConfig instance with tuning settings. + simulation_fn: Callable that simulates execution and returns metrics. + + Example: + >>> tuner_config = PBTTunerConfig( + ... stage_names=["filter", "mapper"], + ... num_samples=4, + ... max_iterations=20, + ... ) + >>> tuner = PBTTuner(tuner_config) + >>> best_config = tuner.tune() + >>> tuner.export_config(best_config, "base_config.yaml") + """ + + def __init__( + self, + config: PBTTunerConfig, + simulation_fn: Optional[Callable[[SchedulerConfig], Dict[str, float]]] = None, + ): + """ + Initialize the PBT tuner. + + Args: + config: PBTTunerConfig with tuning parameters. + simulation_fn: Optional callable that takes a SchedulerConfig and returns + a dict with "throughput" and "oom_rate" keys. If None, uses default + simulation that creates a MicroScheduler and simulates batch processing. + + Raises: + ImportError: If Ray is not installed and tune() is called. + """ + self.config = config + self.simulation_fn = simulation_fn or self._default_simulation + + def _get_search_space(self) -> Dict[str, Any]: + """ + Get the Ray Tune search space for hyperparameters. + + Returns: + Dictionary mapping parameter names to Ray Tune search distributions. + + Raises: + ImportError: If Ray Tune is not available. + """ + if not RAY_AVAILABLE: + raise ImportError( + "Ray Tune is required for PBT tuning. " + "Install it with: pip install 'ray[tune]'" + ) + + search_space = { + # PID controller parameters + "pid_kp": tune.uniform(0.1, 2.0), + "pid_ki": tune.uniform(0.01, 0.2), + "pid_kd": tune.uniform(0.01, 0.5), + + # Safety and memory parameters + "safety_buffer_mb": tune.uniform(256, 4096), + "target_memory_utilization": tune.uniform(0.6, 0.95), + + # Predictor parameters + "predictor_window_size": tune.choice([50, 100, 200, 500]), + "predictor_confidence_level": tune.uniform(0.9, 0.99), + } + + # Add per-stage allocation weights + for stage_name in self.config.stage_names: + search_space[f"weight_{stage_name}"] = tune.uniform(0.1, 5.0) + + return search_space + + def _default_simulation(self, scheduler_config: SchedulerConfig) -> Dict[str, float]: + """ + Default simulation function that tests a SchedulerConfig. + + Creates a MicroScheduler with the given PID parameters and simulates + N iterations of batch processing with random memory fluctuations. + Measures simulated throughput and OOM events. + + Args: + scheduler_config: Configuration to evaluate. + + Returns: + Dictionary with "throughput" (samples/sec) and "oom_rate" (0.0-1.0). + """ + # Create MicroScheduler with config parameters + micro_scheduler = MicroScheduler( + memory_predictor=None, + initial_batch_size=scheduler_config.initial_batch_size, + min_batch_size=scheduler_config.min_batch_size, + max_batch_size=scheduler_config.max_batch_size, + target_memory_utilization=scheduler_config.target_memory_utilization, + safety_buffer_mb=scheduler_config.safety_buffer_mb, + use_gpu=scheduler_config.use_gpu_memory, + enable_auto_adjust=scheduler_config.enable_auto_adjust, + ) + + # Override PID parameters in the controller + micro_scheduler.controller.pid.kp = scheduler_config.pid_kp + micro_scheduler.controller.pid.ki = scheduler_config.pid_ki + micro_scheduler.controller.pid.kd = scheduler_config.pid_kd + + # Simulation parameters + num_iterations = 100 + total_samples_processed = 0 + oom_events = 0 + + # Simulated available memory (starts high, fluctuates) + base_memory_mb = 8000.0 # 8GB base + + for i in range(num_iterations): + # Get current batch size from scheduler + batch_size = micro_scheduler.controller.current_batch_size + + # Simulate memory usage per sample (varies randomly) + memory_per_sample = np.random.uniform(5.0, 20.0) # 5-20 MB per sample + + # Add random memory fluctuation (simulates other processes) + memory_fluctuation = np.random.uniform(-500, 500) + + # Calculate simulated memory state + simulated_used_memory = batch_size * memory_per_sample + memory_fluctuation + simulated_available = base_memory_mb - simulated_used_memory + + # Check for simulated OOM + if simulated_available < scheduler_config.safety_buffer_mb * 0.5: + oom_events += 1 + # Report OOM to scheduler + micro_scheduler.controller.report_oom(batch_size, simulated_used_memory) + # Penalize throughput for OOM + total_samples_processed += batch_size // 4 + else: + # Successful batch + total_samples_processed += batch_size + + # Update scheduler (simulates feedback loop) + micro_scheduler.controller.update_batch_size( + predicted_memory_per_sample=memory_per_sample + ) + + # Calculate metrics + throughput = total_samples_processed / num_iterations # samples per iteration + oom_rate = oom_events / num_iterations + + return { + "throughput": throughput, + "oom_rate": oom_rate, + } + + def _trial_config_to_scheduler_config(self, trial_config: Dict) -> SchedulerConfig: + """ + Convert Ray Tune trial config dict to SchedulerConfig. + + Args: + trial_config: Dictionary of hyperparameters from Ray Tune. + + Returns: + SchedulerConfig instance with the trial's hyperparameters. + """ + # Extract tower allocation weights from weight_{stage} keys + tower_weights = {} + for key, value in trial_config.items(): + if key.startswith("weight_"): + stage_name = key[7:] # Remove "weight_" prefix + tower_weights[stage_name] = value + + return SchedulerConfig( + # PID parameters + pid_kp=trial_config.get("pid_kp", 0.5), + pid_ki=trial_config.get("pid_ki", 0.05), + pid_kd=trial_config.get("pid_kd", 0.1), + + # Safety parameters + safety_buffer_mb=trial_config.get("safety_buffer_mb", 1000.0), + target_memory_utilization=trial_config.get("target_memory_utilization", 0.85), + + # Predictor parameters + predictor_window_size=int(trial_config.get("predictor_window_size", 100)), + predictor_confidence_level=trial_config.get("predictor_confidence_level", 0.95), + + # Tower allocation weights + tower_allocation_weights=tower_weights if tower_weights else None, + ) + + def _trainable(self, trial_config: Dict) -> None: + """ + Ray Tune trainable function. + + Converts trial config to SchedulerConfig, runs simulation, + and reports metrics to Ray Tune. + + Args: + trial_config: Dictionary of hyperparameters from Ray Tune. + """ + if not RAY_AVAILABLE: + raise ImportError( + "Ray Tune is required for PBT tuning. " + "Install it with: pip install 'ray[tune]'" + ) + + # Convert trial config to SchedulerConfig + scheduler_config = self._trial_config_to_scheduler_config(trial_config) + + # Run simulation + results = self.simulation_fn(scheduler_config) + + throughput = results.get("throughput", 0.0) + oom_rate = results.get("oom_rate", 1.0) + + # Calculate composite score (higher is better) + # Penalize OOM events heavily + score = throughput * (1.0 - oom_rate) + + # Report metrics to Ray Tune + tune.report( + throughput=throughput, + oom_rate=oom_rate, + score=score, + ) + + def tune(self) -> SchedulerConfig: + """ + Run PBT hyperparameter tuning. + + Sets up Ray Tune with PBT scheduler, runs the tuning process, + and returns the best configuration found. + + Returns: + SchedulerConfig with the best hyperparameters found. + + Raises: + ImportError: If Ray Tune is not installed. + RuntimeError: If tuning fails or no results are found. + """ + if not RAY_AVAILABLE: + raise ImportError( + "Ray Tune is required for PBT tuning. " + "Install it with: pip install 'ray[tune]'" + ) + + # Initialize Ray if not already initialized + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + + # Get search space + search_space = self._get_search_space() + + # Define perturbation bounds for PBT + hyperparam_mutations = { + "pid_kp": tune.uniform(0.1, 2.0), + "pid_ki": tune.uniform(0.01, 0.2), + "pid_kd": tune.uniform(0.01, 0.5), + "safety_buffer_mb": tune.uniform(256, 4096), + "target_memory_utilization": tune.uniform(0.6, 0.95), + "predictor_window_size": [50, 100, 200, 500], + "predictor_confidence_level": tune.uniform(0.9, 0.99), + } + + # Add stage weight mutations + for stage_name in self.config.stage_names: + hyperparam_mutations[f"weight_{stage_name}"] = tune.uniform(0.1, 5.0) + + # Create PBT scheduler + pbt_scheduler = PopulationBasedTraining( + time_attr="training_iteration", + perturbation_interval=self.config.perturbation_interval, + hyperparam_mutations=hyperparam_mutations, + quantile_fraction=0.25, # Top 25% survive + resample_probability=0.25, # 25% chance to resample instead of perturb + ) + + # Run tuning + analysis = tune.run( + self._trainable, + config=search_space, + metric=self.config.metric, + mode=self.config.mode, + num_samples=self.config.num_samples, + scheduler=pbt_scheduler, + resources_per_trial=self.config.resources_per_trial, + stop={"training_iteration": self.config.max_iterations}, + verbose=1, + raise_on_failed_trial=False, + ) + + # Get best trial + best_trial = analysis.get_best_trial( + metric=self.config.metric, + mode=self.config.mode, + ) + + if best_trial is None: + raise RuntimeError( + "PBT tuning failed: no successful trials found. " + "Check simulation function and resource availability." + ) + + # Convert best config to SchedulerConfig + best_config = self._trial_config_to_scheduler_config(best_trial.config) + + return best_config + + def export_config(self, config: SchedulerConfig, path: str = "base_config.yaml") -> None: + """ + Export a SchedulerConfig to a YAML file. + + Args: + config: SchedulerConfig to export. + path: Output file path (default: "base_config.yaml"). + """ + config.to_yaml(path) + + @staticmethod + def load_config(path: str) -> SchedulerConfig: + """ + Load a SchedulerConfig from a YAML file. + + Args: + path: Path to the YAML configuration file. + + Returns: + SchedulerConfig instance loaded from the file. + """ + return SchedulerConfig.from_yaml(path) diff --git a/data_juicer/core/executor/__init__.py b/data_juicer/core/executor/__init__.py index 5073c6760f9..66aa75689c9 100644 --- a/data_juicer/core/executor/__init__.py +++ b/data_juicer/core/executor/__init__.py @@ -3,5 +3,6 @@ from .factory import ExecutorFactory from .ray_executor import RayExecutor from .ray_executor_partitioned import PartitionedRayExecutor +from .elastic_ray_executor import ElasticRayExecutor -__all__ = ["ExecutorBase", "ExecutorFactory", "DefaultExecutor", "RayExecutor", "PartitionedRayExecutor"] +__all__ = ["ExecutorBase", "ExecutorFactory", "DefaultExecutor", "RayExecutor", "PartitionedRayExecutor", "ElasticRayExecutor"] diff --git a/data_juicer/core/executor/default_executor.py b/data_juicer/core/executor/default_executor.py index dd958af1ccd..dc65f6303ee 100644 --- a/data_juicer/core/executor/default_executor.py +++ b/data_juicer/core/executor/default_executor.py @@ -207,6 +207,28 @@ def run( if op.is_batched_op(): op.batch_size = bs_per_op[i] + # Persist probe results to ProfilingStore so ElasticRayExecutor + # can load them as priors on the next run. + if getattr(self.cfg, 'persist_probe_to_profiling_store', True): + try: + from data_juicer.core.elasticjuicer.profiler.probe_adapter import ProbeAdapter + from data_juicer.core.elasticjuicer.profiler.profiling_store import ProfilingStore + + store_dir = getattr( + self.cfg, 'profiling_store_dir', + './elastic_juicer_profiles', + ) + bridge = ProbeAdapter(ProfilingStore(storage_dir=store_dir)) + bridge.ingest_probe_results( + probe_results=self.adapter._last_analysis, + op_names=[op._name for op in ops], + probe_batch_sizes=bs_per_op, + ) + except Exception as e: + logger.warning( + "Failed to persist probe to ProfilingStore: %s", e + ) + # 3. data process with DAG monitoring # - If tracer is open, trace each op after it's processed # - If checkpoint is open, clean the cache files after each process diff --git a/data_juicer/core/executor/elastic_ray_executor.py b/data_juicer/core/executor/elastic_ray_executor.py new file mode 100644 index 00000000000..169bfc7a68d --- /dev/null +++ b/data_juicer/core/executor/elastic_ray_executor.py @@ -0,0 +1,967 @@ +""" +Elastic Ray Executor with ElasticJuicer Adaptive Scheduling + +This module provides ElasticRayExecutor, which integrates ElasticJuicer's +adaptive scheduling (Tower macro-scheduler, Captain micro-scheduler, +MicroScheduler batch size control) into Data-Juicer's standard execution pipeline. + +The bi-level scheduling architecture: +- Tower (macro-scheduler): Global resource allocation and rebalancing +- Captains (micro-schedulers): Per-operator batch size control with PID + +Usage in YAML config: + executor_type: elastic_ray + + elastic_juicer: + scheduler_preset: gpu # conservative/gpu/aggressive + rebalance_interval: 5.0 # Tower rebalance interval in seconds + enable_offline_tuning: false # PBT offline tuning + config_path: null # path to pre-tuned SchedulerConfig YAML +""" + +import os +import shutil +import time +from dataclasses import asdict +from functools import partial +from typing import Any, Dict, List, Optional + +from jsonargparse import Namespace +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.core.data.dataset_builder import DatasetBuilder +from data_juicer.core.executor import ExecutorBase +from data_juicer.core.executor.dag_execution_mixin import DAGExecutionMixin +from data_juicer.core.executor.event_logging_mixin import EventLoggingMixin +from data_juicer.core.ray_exporter import RayExporter +from data_juicer.core.tracer.ray_tracer import RayTracer +from data_juicer.ops import Deduplicator, Filter, Mapper, OPEnvManager, Pipeline, load_ops +from data_juicer.ops.base_op import DEFAULT_BATCH_SIZE, TAGGING_OPS +from data_juicer.ops.op_fusion import fuse_operators +from data_juicer.utils.constant import Fields +from data_juicer.utils.lazy_loader import LazyLoader + +ray = LazyLoader("ray") +pyarrow = LazyLoader("pyarrow") + + +class TempDirManager: + """Context manager for temporary directory cleanup.""" + + def __init__(self, tmp_dir): + self.tmp_dir = tmp_dir + + def __enter__(self): + os.makedirs(self.tmp_dir, exist_ok=True) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if os.path.exists(self.tmp_dir): + logger.info(f"Removing tmp dir {self.tmp_dir} ...") + shutil.rmtree(self.tmp_dir) + + +# Note: _get_preset_config, _get_cluster_state, and _scheduler_config_to_dict +# have been moved to the ElasticJuicer facade in elastic_juicer.py + + +def filter_batch(batch, filter_func): + """Filter batch using filter function.""" + import pyarrow + + mask = pyarrow.array(filter_func(batch.to_pydict())) + return batch.filter(mask) + + +# Note: _create_shared_quota_store and MetricsBridge have been moved to +# the ElasticJuicer facade in elastic_juicer.py + + +class AdaptiveOperator: + """ + Adaptive wrapper for any DJ operator with MicroScheduler for batch sizing. + + This wrapper is used with Ray Data's map_batches() and ActorPoolStrategy + to provide adaptive batch sizing for GPU operators. + + The bi-level scheduling integration: + - MicroScheduler: Local batch size control based on memory feedback + - SharedQuotaStore: Reads Tower quotas for global coordination + """ + + def __init__( + self, + op_class_name: str, + op_kwargs: Dict[str, Any], + stage_name: str, + initial_batch_size: int, + scheduler_config_dict: Dict[str, Any], + ): + """ + Initialize the adaptive operator wrapper. + + Args: + op_class_name: Name of the operator class in OPERATORS registry + op_kwargs: Constructor kwargs for the operator + stage_name: Name for this stage (used in metrics reporting) + initial_batch_size: Initial batch size for MicroScheduler + scheduler_config_dict: SchedulerConfig as dict for MicroScheduler init + """ + # Import and instantiate the actual operator + from data_juicer.ops import OPERATORS + + op_cls = OPERATORS.modules[op_class_name] + self.op = op_cls(**op_kwargs) + self.stage_name = stage_name + self.initial_batch_size = initial_batch_size + + # Create MicroScheduler with config from dict + from data_juicer.core.elasticjuicer.scheduler.micro_scheduler import MicroScheduler + + self.micro_scheduler = MicroScheduler( + initial_batch_size=initial_batch_size, + max_batch_size=scheduler_config_dict.get("max_batch_size", 1000), + min_batch_size=scheduler_config_dict.get("min_batch_size", 1), + target_memory_utilization=scheduler_config_dict.get("target_memory_utilization", 0.85), + safety_buffer_mb=scheduler_config_dict.get("safety_buffer_mb", 1000.0), + use_gpu=scheduler_config_dict.get("use_gpu_memory", False), + ) + + # Try to connect to PipelineMetricsCollector with retry logic + self.metrics_collector = None + import ray + for attempt in range(3): + try: + self.metrics_collector = ray.get_actor("elastic_pipeline_metrics") + logger.info(f"[{self.stage_name}] Connected to PipelineMetricsCollector") + break + except Exception as e: + delay = 0.5 * (attempt + 1) # 0.5s, 1.0s, 1.5s + if attempt < 2: + time.sleep(delay) + else: + logger.warning(f"[{self.stage_name}] Failed to connect to PipelineMetricsCollector after 3 attempts: {e}") + + # Connect to SharedQuotaStore for Tower quotas with retry logic + self.quota_store = None + for attempt in range(3): + try: + self.quota_store = ray.get_actor("elastic_quota_store") + logger.info(f"[{self.stage_name}] Connected to SharedQuotaStore") + break + except Exception as e: + delay = 0.5 * (attempt + 1) # 0.5s, 1.0s, 1.5s + if attempt < 2: + time.sleep(delay) + else: + logger.warning(f"[{self.stage_name}] Failed to connect to SharedQuotaStore after 3 attempts: {e}") + + # Quota check state + self._last_quota_check = 0 + self._quota_check_interval = 2.0 # seconds + self._backpressure = False + + # Statistics + self.batch_sizes_used: List[int] = [] + self.samples_processed = 0 + self.total_latency_ms = 0.0 + + # Store last batch for _update_scheduler feature extraction + self._last_batch = None + + def _check_quota(self): + """Check if Tower has issued new quota for this stage.""" + import time + + now = time.time() + if now - self._last_quota_check < self._quota_check_interval: + return + self._last_quota_check = now + + if self.quota_store is None: + return + + try: + import ray + + quota = ray.get(self.quota_store.get_quota.remote(self.stage_name)) + if quota: + # Apply Tower's batch size recommendation + # Blend Tower recommendation with MicroScheduler's local decision + tower_bs = quota.get("batch_size", None) + if tower_bs and tower_bs > 0: + current = self.micro_scheduler.controller.current_batch_size + # Only blend if Tower recommends increase; don't let Tower pull batch size below initial + if tower_bs >= current: + blended = int(0.7 * current + 0.3 * tower_bs) + new_bs = max(self.initial_batch_size, blended) + else: + # Tower recommends decrease - only allow if significantly lower (backpressure scenario) + if tower_bs < current * 0.5: + new_bs = max(self.initial_batch_size, int(0.7 * current + 0.3 * tower_bs)) + else: + new_bs = current # Keep current, ignore minor decreases + if new_bs != current: + logger.info(f"[{self.stage_name}] Batch size: {current} -> {new_bs} (tower_bs={tower_bs})") + self.micro_scheduler.controller.current_batch_size = new_bs + + # Apply backpressure from Tower + self._backpressure = quota.get("backpressure", False) + if self._backpressure: + logger.info(f"[{self.stage_name}] Backpressure activated") + except Exception as e: + logger.debug(f"[{self.stage_name}] Quota check failed: {e}") + + def _extract_features(self, batch): + """Extract sample features from batch for MicroScheduler.""" + try: + num_rows = batch.num_rows if hasattr(batch, 'num_rows') else len(batch) + batch_memory_bytes = batch.nbytes if hasattr(batch, 'nbytes') else 0 + per_sample_mb = (batch_memory_bytes / max(num_rows, 1)) / (1024 * 1024) + + from data_juicer.core.elasticjuicer.scheduler.micro_scheduler import SampleFeatures + return SampleFeatures( + batch_size=num_rows, + estimated_memory_mb=per_sample_mb * num_rows, + ) + except Exception: + return None + + def __call__(self, batch): + """Process batch with adaptive sub-batching.""" + import time + + import pyarrow as pa + + # Store batch for _update_scheduler feature extraction + self._last_batch = batch + + # Check for Tower quotas periodically + self._check_quota() + + # If backpressure from Tower, add small delay to reduce pressure on downstream + if self._backpressure: + time.sleep(0.1) + + # Get actual memory usage once per __call__ invocation for metrics reporting + try: + import psutil + memory_mb = psutil.virtual_memory().used / (1024 * 1024) + except Exception: + memory_mb = 0.0 + + total_rows = batch.num_rows if hasattr(batch, "num_rows") else len(batch) + + if total_rows == 0: + return batch + + # Get recommended batch size from MicroScheduler (uses PID feedback) + recommended_bs = self.micro_scheduler.get_batch_size(sample_features=self._extract_features(batch)) + recommended_bs = max(1, recommended_bs) + + # If batch is smaller than recommended, process whole batch + if total_rows <= recommended_bs: + self.batch_sizes_used.append(total_rows) + + t0 = time.time() + try: + result = self.op(batch) + elapsed_ms = (time.time() - t0) * 1000 + success = True + except Exception as e: + elapsed_ms = (time.time() - t0) * 1000 + self.micro_scheduler.report_oom(batch_size=total_rows, memory_mb=0.0) + raise + + # Report metrics + if self.metrics_collector: + try: + import ray + + ray.get( + self.metrics_collector.report.remote( + self.stage_name, total_rows, elapsed_ms, memory_mb + ) + ) + except Exception: + pass + + # Update MicroScheduler + self._update_scheduler() + + self.samples_processed += total_rows + self.total_latency_ms += elapsed_ms + + return result + + # Process in sub-batches + offset = 0 + results = [] + + while offset < total_rows: + recommended_bs = self.micro_scheduler.get_batch_size(sample_features=self._extract_features(batch)) + recommended_bs = max(1, min(recommended_bs, total_rows - offset)) + self.batch_sizes_used.append(recommended_bs) + + end = min(offset + recommended_bs, total_rows) + sub_batch = batch.slice(offset, end - offset) + + t0 = time.time() + try: + sub_result = self.op(sub_batch) + elapsed_ms = (time.time() - t0) * 1000 + success = True + except Exception as e: + elapsed_ms = (time.time() - t0) * 1000 + success = False + # On error (e.g. OOM), reduce batch size and retry with smaller batch + self.micro_scheduler.report_oom(batch_size=end - offset, memory_mb=0.0) + if end - offset > 1: + recommended_bs = max(1, recommended_bs // 2) + continue + else: + raise + + results.append(sub_result) + + # Report metrics + if self.metrics_collector: + try: + import ray + + ray.get( + self.metrics_collector.report.remote( + self.stage_name, end - offset, elapsed_ms, memory_mb + ) + ) + except Exception: + pass + + # Update MicroScheduler + self._update_scheduler() + + self.samples_processed += end - offset + self.total_latency_ms += elapsed_ms + offset = end + + if len(results) == 1: + return results[0] + + # Concatenate PyArrow tables + if isinstance(results[0], pa.Table): + return pa.concat_tables(results) + + return results[0] + + def _update_scheduler(self): + """Update MicroScheduler with current memory state.""" + try: + import psutil + + memory_mb = psutil.virtual_memory().used / (1024 * 1024) + sample_features = self._extract_features(self._last_batch) if self._last_batch is not None else None + self.micro_scheduler.update(actual_memory_used=memory_mb, sample_features=sample_features) + except Exception: + pass + + +# Note: _create_pipeline_metrics_collector has been moved to the ElasticJuicer facade +# AdaptiveOperator connects to the named actors by name + + +class ElasticRayExecutor(ExecutorBase, DAGExecutionMixin, EventLoggingMixin): + """ + Ray executor with ElasticJuicer adaptive scheduling. + + Integrates Tower (macro-scheduler), Captain (per-stage micro-scheduler), + and MicroScheduler (per-actor batch adaptation) into Data-Juicer's + standard execution pipeline. + + Features: + - Adaptive batch sizing for GPU operators via MicroScheduler + - PID-controlled memory management to prevent OOM + - Tower-based global resource allocation (optional rebalancing) + - Per-stage metrics collection + + Usage in YAML config: + executor_type: elastic_ray + + elastic_juicer: + scheduler_preset: gpu # conservative/gpu/aggressive + rebalance_interval: 5.0 # Tower rebalance interval in seconds + enable_offline_tuning: false # PBT offline tuning + config_path: null # path to pre-tuned SchedulerConfig YAML + """ + + def __init__(self, cfg: Optional[Namespace] = None): + """ + Initialization method. + + :param cfg: optional config dict. + """ + super().__init__(cfg) + + self.executor_type = "elastic_ray" + self.work_dir = self.cfg.work_dir + + # Initialize EventLoggingMixin for job management and event logging + EventLoggingMixin.__init__(self, cfg) + + # Initialize DAGExecutionMixin for AST/DAG functionality + DAGExecutionMixin.__init__(self) + + # init ray + logger.info("Initializing Ray for ElasticRayExecutor ...") + + from data_juicer.utils.ray_utils import initialize_ray + + initialize_ray(cfg=cfg, force=True) + + self.tmp_dir = os.path.join( + self.work_dir, ".tmp", ray.get_runtime_context().get_job_id() + ) + + # init dataset builder + self.datasetbuilder = DatasetBuilder(self.cfg, executor_type="ray") + + logger.info("Preparing exporter...") + # Prepare export extra args, including S3 credentials if export_path is S3 + export_extra_args = ( + dict(self.cfg.export_extra_args) + if hasattr(self.cfg, "export_extra_args") + else {} + ) + + # If export_path is S3, extract AWS credentials + if self.cfg.export_path.startswith("s3://"): + if ( + hasattr(self.cfg, "export_aws_credentials") + and self.cfg.export_aws_credentials + ): + export_aws_creds = self.cfg.export_aws_credentials + credential_fields = { + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region", + "endpoint_url", + } + for field in credential_fields.intersection(export_aws_creds): + export_extra_args[field] = export_aws_creds[field] + + self.exporter = RayExporter( + self.cfg.export_path, + self.cfg.export_type, + self.cfg.export_shard_size, + keep_stats_in_res_ds=self.cfg.keep_stats_in_res_ds, + keep_hashes_in_res_ds=self.cfg.keep_hashes_in_res_ds, + **export_extra_args, + ) + + # setup tracer + self.tracer = None + self.open_tracer = self.cfg.open_tracer + if self.open_tracer: + logger.info("Preparing tracer...") + self.tracer = RayTracer.remote( + self.work_dir, + self.cfg.op_list_to_trace, + show_num=self.cfg.trace_num, + trace_keys=self.cfg.trace_keys, + ) + + # setup OPEnvManager + self.op_env_manager = None + if self.cfg.min_common_dep_num_to_combine >= 0: + logger.info("Preparing OPEnvManager...") + self.op_env_manager = OPEnvManager( + min_common_dep_num_to_combine=self.cfg.min_common_dep_num_to_combine, + conflict_resolve_strategy=self.cfg.conflict_resolve_strategy, + ) + + # ElasticJuicer facade (initialized in run()) + self._elastic = None # ElasticJuicer facade instance + self._scheduler_config = None + + def _parse_elastic_juicer_config(self): + """Parse elastic_juicer config section and create SchedulerConfig.""" + from data_juicer.core.elasticjuicer.scheduler.scheduler_config import SchedulerConfig + + elastic_cfg = getattr(self.cfg, "elastic_juicer", None) + if elastic_cfg is None: + elastic_cfg = {} + elif hasattr(elastic_cfg, "__dict__"): + # Convert Namespace to dict + elastic_cfg = dict(elastic_cfg) + + scheduler_preset = elastic_cfg.get("scheduler_preset", "gpu") + config_path = elastic_cfg.get("config_path", None) + rebalance_interval = elastic_cfg.get("rebalance_interval", 5.0) + + # Load or create SchedulerConfig + if config_path and os.path.exists(config_path): + self._scheduler_config = SchedulerConfig.from_yaml(config_path) + logger.info(f"Loaded ElasticJuicer config from {config_path}") + else: + # Use preset + presets = { + "conservative": SchedulerConfig.conservative, + "gpu": SchedulerConfig.gpu, + "aggressive": SchedulerConfig.aggressive, + } + factory = presets.get(scheduler_preset, SchedulerConfig.gpu) + self._scheduler_config = factory() + logger.info(f"Using ElasticJuicer {scheduler_preset} preset config") + + # Override rebalance interval if specified + self._scheduler_config.rebalance_interval_sec = rebalance_interval + + return elastic_cfg + + def _init_elasticjuicer_components(self, ops: List): + """Initialize ElasticJuicer facade with all scheduling components. + + Creates stage configs from operators and registers them with the + ElasticJuicer facade, which manages Tower, Captains, and Ray actors. + """ + from data_juicer.core.elasticjuicer import ElasticJuicer + + # Build stage configs from operators + stage_configs = [] + for i, op in enumerate(ops): + stage_configs.append({ + 'name': f"stage_{i}_{op._name}", + 'batch_size': getattr(op, 'batch_size', self._scheduler_config.initial_batch_size), + 'num_gpus': op.num_gpus or 0, + 'num_actors': op.runtime_np() if hasattr(op, 'runtime_np') else (op.num_proc or 1), + }) + + # Create ElasticJuicer facade + self._elastic = ElasticJuicer( + config=self._scheduler_config, + cluster_state=None, # auto-detect + ) + + # Register all stages with the facade + self._elastic.register_stages(stage_configs) + + logger.info( + f"ElasticJuicer facade initialized with {len(ops)} stages") + + def _run_single_op_elastic( + self, + ds, + op, + stage_index: int, + cached_columns: set, + scheduler_config_dict: Dict[str, Any], + ): + """ + Execute a single operator with ElasticJuicer adaptive scheduling. + + For GPU operators: use AdaptiveOperator wrapper with MicroScheduler + For CPU operators: use standard execution (same as RayDataset._run_single_op) + """ + import pyarrow as pa + from ray.data import ActorPoolStrategy + + stage_name = f"stage_{stage_index}_{op._name}" + + # Handle tagging ops - add meta column if needed + if op._name in TAGGING_OPS.modules and Fields.meta not in cached_columns: + + def process_batch_arrow(table: pa.Table): + new_column_data = [{} for _ in range(len(table))] + return table.append_column(Fields.meta, [new_column_data]) + + ds = ds.map_batches( + process_batch_arrow, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE + ) + cached_columns.add(Fields.meta) + + batch_size = getattr(op, "batch_size", 1) if op.is_batched_op() else 1 + + if isinstance(op, Mapper): + if op.use_ray_actor(): + # GPU Mapper: use AdaptiveOperator wrapper + ds = self._apply_adaptive_gpu_op( + ds, op, stage_name, batch_size, scheduler_config_dict + ) + else: + # CPU Mapper: standard execution + from ray.data._internal.util import get_compute_strategy + + num_proc = op.num_proc if op.num_proc and op.num_proc > 0 else None + compute = get_compute_strategy(op.process, concurrency=num_proc) + map_batches_kwargs = dict( + batch_size=batch_size, + batch_format="pyarrow", + num_cpus=op.num_cpus, + num_gpus=op.num_gpus, + compute=compute, + runtime_env=op.runtime_env, + ) + ds = ds.map_batches(op.process, **map_batches_kwargs) + + elif isinstance(op, Filter): + # Ensure stats column exists + if Fields.stats not in cached_columns: + + def process_batch_arrow(table: pa.Table): + new_column_data = [{} for _ in range(len(table))] + return table.append_column(Fields.stats, [new_column_data]) + + ds = ds.map_batches( + process_batch_arrow, + batch_format="pyarrow", + batch_size=DEFAULT_BATCH_SIZE, + ) + cached_columns.add(Fields.stats) + + if op.use_ray_actor(): + # GPU Filter: use AdaptiveOperator wrapper + ds = self._apply_adaptive_gpu_op( + ds, op, stage_name, batch_size, scheduler_config_dict + ) + else: + # CPU Filter: compute_stats then filter + from ray.data._internal.util import get_compute_strategy + + num_proc = op.num_proc if op.num_proc and op.num_proc > 0 else None + compute = get_compute_strategy(op.compute_stats, concurrency=num_proc) + map_batches_kwargs = dict( + batch_size=batch_size, + batch_format="pyarrow", + num_cpus=op.num_cpus, + num_gpus=op.num_gpus, + compute=compute, + runtime_env=op.runtime_env, + ) + ds = ds.map_batches(op.compute_stats, **map_batches_kwargs) + + # Apply filter (for both GPU and CPU filters) + if op.is_batched_op(): + ds = ds.map_batches( + partial(filter_batch, filter_func=op.process), + batch_format="pyarrow", + zero_copy_batch=True, + batch_size=DEFAULT_BATCH_SIZE, + runtime_env=op.runtime_env, + ) + else: + ds = ds.filter(op.process, runtime_env=op.runtime_env) + + elif isinstance(op, (Deduplicator, Pipeline)): + # Global ops: run directly + ds = op.run(ds) + + else: + logger.error( + "ElasticRayExecutor only supports Filter, Mapper, " + "Deduplicator and Pipeline OPs" + ) + raise NotImplementedError + + return ds, cached_columns + + def _apply_adaptive_gpu_op( + self, + ds, + op, + stage_name: str, + batch_size: int, + scheduler_config_dict: Dict[str, Any], + ): + """Apply GPU operator with AdaptiveOperator wrapper.""" + from ray.data import ActorPoolStrategy + + # Repartition for GPU actors + num_actors = op.runtime_np() if hasattr(op, "runtime_np") else (op.num_proc or 1) + override_num_blocks = getattr(op, "override_num_blocks", None) + if override_num_blocks is not None: + ds = ds.repartition(override_num_blocks) + else: + ds = ds.repartition(num_actors * 2) + + # Auto-scale batch size based on available GPU memory + if batch_size <= 4: # Only auto-scale very conservative defaults + try: + import subprocess + result = subprocess.run( + ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'], + capture_output=True, text=True, timeout=5 + ) + if result.returncode == 0: + free_mb = min(int(x.strip()) for x in result.stdout.strip().split('\n') if x.strip()) + # Conservative: assume 1000MB per sample for video ops (frames are memory-hungry) + auto_bs = max(batch_size, min(16, free_mb // 1000)) + if auto_bs > batch_size: + logger.info(f"[{stage_name}] Auto batch_size: {batch_size} -> {auto_bs} (free GPU mem: {free_mb}MB)") + batch_size = auto_bs + except Exception: + pass + + # Extract operator config for AdaptiveOperator + op_kwargs = {} + if hasattr(op, "_op_cfg") and op._op_cfg: + # _op_cfg is like {"op_name": {kwargs}} + op_name, op_args = list(op._op_cfg.items())[0] + op_kwargs = dict(op_args) if op_args else {} + else: + # Fallback: extract from _init_kwargs + if hasattr(op, "_init_kwargs") and op._init_kwargs: + op_kwargs = dict(op._init_kwargs) + + # Get adaptive config from facade (if available) + adaptive_config = {} + if self._elastic is not None: + adaptive_config = self._elastic.get_adaptive_op_config(stage_name) + # Update with auto-scaled batch_size + adaptive_config['initial_batch_size'] = batch_size + else: + # Fallback if facade not initialized + adaptive_config = { + 'stage_name': stage_name, + 'initial_batch_size': batch_size, + 'scheduler_config_dict': scheduler_config_dict, + } + + # Use map_batches with AdaptiveOperator class + ds = ds.map_batches( + AdaptiveOperator, + fn_constructor_kwargs={ + "op_class_name": op._name, + "op_kwargs": op_kwargs, + **adaptive_config, + }, + batch_size=batch_size, + num_cpus=op.num_cpus or 1, + num_gpus=op.num_gpus or 1, + compute=ActorPoolStrategy(size=num_actors), + batch_format="pyarrow", + runtime_env=op.runtime_env, + ) + + return ds + + def run( + self, + load_data_np: Optional[PositiveInt] = None, + skip_export: bool = False, + skip_return: bool = False, + ): + """ + Running the dataset process pipeline with ElasticJuicer adaptive scheduling. + + :param load_data_np: number of workers when loading the dataset. + :param skip_export: whether to skip exporting results to disk + :param skip_return: skip return for API called. + :return: processed dataset. + """ + # 1. Parse ElasticJuicer config + elastic_cfg = self._parse_elastic_juicer_config() + + # 2. Load data + logger.info("Loading dataset with ElasticRayExecutor...") + dataset = self.datasetbuilder.load_dataset(num_proc=load_data_np) + columns = dataset.data.columns() + + # 3. Extract processes + logger.info("Preparing process operators...") + ops = load_ops(self.cfg.process, self.op_env_manager) + + # Initialize DAG execution planning + self._initialize_dag_execution(self.cfg, ops=ops) + + # Log job start with DAG context + dataset_info = {} + if hasattr(self.cfg, "dataset_path") and self.cfg.dataset_path: + dataset_info["dataset_path"] = self.cfg.dataset_path + if hasattr(self.cfg, "dataset") and self.cfg.dataset: + dataset_info["dataset"] = self.cfg.dataset + + job_config = { + **dataset_info, + "work_dir": self.work_dir, + "executor_type": self.executor_type, + "dag_node_count": len(self.pipeline_dag.nodes) if self.pipeline_dag else 0, + "dag_edge_count": len(self.pipeline_dag.edges) if self.pipeline_dag else 0, + "parallel_groups_count": ( + len(self.pipeline_dag.parallel_groups) if self.pipeline_dag else 0 + ), + } + self.log_job_start(job_config, len(ops)) + + if self.cfg.op_fusion: + logger.info( + f"Start OP fusion and reordering with strategy " + f"[{self.cfg.fusion_strategy}]..." + ) + ops = fuse_operators(ops) + + # 4. Detect whether pipeline has GPU operators + gpu_op_count = sum(1 for op in ops if (getattr(op, 'num_gpus', 0) or 0) > 0) + has_gpu_ops = gpu_op_count > 0 + logger.info(f'Pipeline analysis: {len(ops)} total operators, {gpu_op_count} GPU operators') + + # 5. Empty dataset guard + input_rows = dataset.data.count() + if input_rows == 0: + logger.warning('Empty dataset — skipping processing.') + if not skip_export: + self.exporter.export(dataset) + return dataset if not skip_return else None + + with TempDirManager(self.tmp_dir): + tstart = time.time() + start_time = time.time() + + # Pre-execute DAG monitoring + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(ops) + + if not has_gpu_ops: + # CPU-only fallback path: use standard RayDataset.process() + logger.info(f'No GPU operators detected in pipeline ({len(ops)} CPU ops). ' + f'Using standard execution path (ElasticJuicer disabled).') + + # Execute operations using standard dataset.process() like RayExecutor + dataset = dataset.process(ops, tracer=self.tracer) + + # Force materialization to get real execution + logger.info("Materializing dataset to collect real metrics...") + dataset.data = dataset.data.materialize() + + ds = dataset.data + else: + # GPU path: use full ElasticJuicer adaptive scheduling + logger.info(f'GPU operators detected. ' + f'Using ElasticJuicer adaptive scheduling.') + + # Initialize ElasticJuicer facade for GPU path + self._init_elasticjuicer_components(ops) + + # Prepare scheduler config dict for serialization + scheduler_config_dict = asdict(self._scheduler_config) + + logger.info("Processing data with ElasticJuicer adaptive scheduling...") + + # Start ElasticJuicer facade (Tower + MetricsBridge + Ray actors) + if self._elastic is not None: + self._elastic.start() + + try: + # Execute operations with adaptive scheduling + ds = dataset.data + cached_columns = set(ds.columns()) if ds.columns() else set() + + for i, op in enumerate(ops): + try: + ds, cached_columns = self._run_single_op_elastic( + ds, op, i, cached_columns, scheduler_config_dict + ) + except Exception as e: + logger.error(f"Error processing operator {op}: {e}") + if op.runtime_env is not None: + logger.error("Trying to fallback to base runtime environment") + original_runtime_env = op.runtime_env + try: + op.runtime_env = None + ds, cached_columns = self._run_single_op_elastic( + ds, op, i, cached_columns, scheduler_config_dict + ) + finally: + op.runtime_env = original_runtime_env + else: + raise e + + # Materialize dataset + logger.info("Materializing dataset to collect real metrics...") + ds = ds.materialize() + + finally: + # Stop ElasticJuicer facade (handles all cleanup) + if self._elastic is not None: + self._elastic.stop() + + # Update dataset.data for GPU path + dataset.data = ds + + # Get metrics after execution + duration = time.time() - start_time + output_rows = ds.count() + + # Post-execute DAG monitoring + if self.pipeline_dag: + metrics = { + "duration": duration, + "input_rows": input_rows, + "output_rows": output_rows, + } + self._post_execute_operations_with_dag_monitoring(ops, metrics=metrics) + + # Collect and log ElasticJuicer metrics (only for GPU path) + if has_gpu_ops: + self._log_elasticjuicer_metrics(input_rows, output_rows, duration) + + # 6. Data export + if not skip_export: + logger.info("Exporting dataset to disk...") + self.exporter.export(ds, columns=columns) + + tend = time.time() + logger.info(f"All Ops are done in {tend - tstart:.3f}s.") + + # Log job completion + job_duration = time.time() - tstart + self.log_job_complete(job_duration, self.cfg.export_path) + + # Finalize tracer + if self.tracer: + ray.get(self.tracer.finalize_traces.remote()) + + if not skip_return: + return dataset + + def _log_elasticjuicer_metrics( + self, input_rows: int, output_rows: int, duration: float + ): + """Log ElasticJuicer metrics summary using facade.""" + throughput = input_rows / duration if duration > 0 else 0 + + logger.info(f"ElasticJuicer Pipeline Summary:") + logger.info(f" Input rows: {input_rows}, Output rows: {output_rows}") + logger.info(f" Duration: {duration:.1f}s, Throughput: {throughput:.2f} samples/sec") + + if self._elastic is None: + return + + # Get metrics from facade + metrics_summary = self._elastic.get_metrics_summary() + for stage_name, metrics in metrics_summary.items(): + logger.info( + f" {stage_name}: {metrics.get('total_samples', 0)} samples, " + f"bs={metrics.get('min_bs', 0)}→{metrics.get('max_bs', 0)}, " + f"avg_latency={metrics.get('avg_latency_ms', 0):.1f}ms" + ) + + # Get Captain stats from facade + captain_stats = self._elastic.get_captain_stats() + if captain_stats: + logger.info(f" Captain Statistics:") + for stage_name, stats in captain_stats.items(): + if stats: + logger.info( + f" [{stage_name}]: " + f"throughput={stats.get('throughput', 0):.1f} sps, " + f"latency={stats.get('latency_ms', 0):.1f}ms, " + f"batch_size={stats.get('batch_size', 0)}, " + f"backpressure={stats.get('backpressure', False)}, " + f"oom_events={stats.get('oom_count', 0)}" + ) + + # Get Tower stats from facade + tower_stats = self._elastic.get_tower_stats() + if tower_stats: + logger.info(f" Tower stats: {tower_stats}") diff --git a/data_juicer/core/executor/elastic_scheduling_mixin.py b/data_juicer/core/executor/elastic_scheduling_mixin.py new file mode 100644 index 00000000000..c06b3e93626 --- /dev/null +++ b/data_juicer/core/executor/elastic_scheduling_mixin.py @@ -0,0 +1,594 @@ +""" +Elastic Scheduling Mixin for GPU-Adaptive Batch Processing + +Provides reusable ElasticJuicer adaptive scheduling logic that can be mixed +into any Ray-based executor. Extracted from ElasticRayExecutor to allow +both ElasticRayExecutor and PartitionedRayExecutor to share the same +GPU-adaptive scheduling infrastructure. + +The bi-level scheduling architecture: +- Tower (macro-scheduler): Global resource allocation and rebalancing +- Captains (micro-schedulers): Per-operator batch size control with PID +- MicroScheduler: Per-actor adaptive batch sizing within map_batches + +Usage: + class MyExecutor(ExecutorBase, ElasticSchedulingMixin, ...): + def __init__(self, cfg): + ... + ElasticSchedulingMixin.__init__(self) + ... +""" + +import os +import time +from dataclasses import asdict +from functools import partial +from typing import Any, Dict, List, Optional + +from loguru import logger + +from data_juicer.ops import Deduplicator, Filter, Mapper, Pipeline +from data_juicer.ops.base_op import DEFAULT_BATCH_SIZE, TAGGING_OPS +from data_juicer.utils.constant import Fields +from data_juicer.utils.lazy_loader import LazyLoader + +ray = LazyLoader("ray") +pyarrow = LazyLoader("pyarrow") + + +def filter_batch(batch, filter_func): + """Filter batch using filter function.""" + import pyarrow + + mask = pyarrow.array(filter_func(batch.to_pydict())) + return batch.filter(mask) + + +class ElasticSchedulingMixin: + """ + Mixin providing ElasticJuicer adaptive GPU scheduling capabilities. + + Any executor that mixes this in gains: + - Adaptive batch sizing for GPU operators via AdaptiveOperator/MicroScheduler + - PID-controlled memory management to prevent OOM + - Tower-based global resource allocation (optional rebalancing) + - Per-stage metrics collection + + Requirements on the host class: + - Must have ``self.cfg`` (Namespace with config, including elastic_juicer) + - Must be running in a Ray environment + """ + + def _init_elastic_scheduling(self): + """Initialize elastic scheduling state. Call from host __init__.""" + self._elastic = None # ElasticJuicer facade instance + self._scheduler_config = None + self._elastic_started = False + self._profiling_store = None # ProfilingStore for prior loading + flush + + # ------------------------------------------------------------------ + # Config parsing + # ------------------------------------------------------------------ + + def _parse_elastic_juicer_config(self, has_gpu_ops: bool = True): + """Parse elastic_juicer config section and create SchedulerConfig. + + Auto-selects appropriate preset when no explicit config is provided: + - GPU pipelines default to 'gpu' preset + - CPU-only pipelines default to 'cpu_optimized' preset + + Args: + has_gpu_ops: Whether the pipeline contains GPU operators. + Used to select the appropriate default preset. + """ + from data_juicer.core.elasticjuicer.scheduler.scheduler_config import ( + SchedulerConfig, + ) + + elastic_cfg = getattr(self.cfg, "elastic_juicer", None) + if elastic_cfg is None: + elastic_cfg = {} + elif hasattr(elastic_cfg, "__dict__"): + # Convert Namespace to dict + elastic_cfg = vars(elastic_cfg) + + default_preset = "gpu" if has_gpu_ops else "cpu_optimized" + scheduler_preset = elastic_cfg.get("scheduler_preset", default_preset) + config_path = elastic_cfg.get("config_path", None) + rebalance_interval = elastic_cfg.get("rebalance_interval", 5.0) + + # Load or create SchedulerConfig + if config_path and os.path.exists(config_path): + self._scheduler_config = SchedulerConfig.from_yaml(config_path) + logger.info(f"Loaded ElasticJuicer config from {config_path}") + else: + presets = { + "conservative": SchedulerConfig.conservative, + "gpu": SchedulerConfig.gpu, + "gpu_optimized": SchedulerConfig.gpu_optimized, + "cpu_optimized": SchedulerConfig.cpu_optimized, + "aggressive": SchedulerConfig.aggressive, + "memory_constrained": SchedulerConfig.memory_constrained, + } + factory = presets.get(scheduler_preset, SchedulerConfig.gpu) + self._scheduler_config = factory() + logger.info( + f"Using ElasticJuicer '{scheduler_preset}' preset config" + ) + + # Override rebalance interval if specified + self._scheduler_config.rebalance_interval_sec = rebalance_interval + + return elastic_cfg + + # ------------------------------------------------------------------ + # Component initialization and lifecycle + # ------------------------------------------------------------------ + + def _init_elasticjuicer_components(self, ops: List): + """Initialize ElasticJuicer facade with all scheduling components. + + Creates stage configs from operators and registers them with the + ElasticJuicer facade, which manages Tower, Captains, and Ray actors. + + When a ProfilingStore directory from a prior run exists the method + loads per-op priors (peak memory, safe batch size) and uses them + as ``initial_batch_size`` instead of the blind SchedulerConfig default. + """ + from data_juicer.core.elasticjuicer import ElasticJuicer + from data_juicer.core.elasticjuicer.profiler.profiling_store import ( + ProfilingStore, + ) + + # Initialize ProfilingStore (loads priors from previous runs if any) + elastic_cfg = getattr(self.cfg, "elastic_juicer", None) or {} + if hasattr(elastic_cfg, "__dict__"): + elastic_cfg = vars(elastic_cfg) + store_dir = elastic_cfg.get( + "profiling_store_dir", "./elastic_juicer_profiles" + ) + try: + self._profiling_store = ProfilingStore(storage_dir=store_dir) + except Exception as e: + logger.warning("Failed to initialize ProfilingStore: %s", e) + self._profiling_store = None + + # Build stage configs from operators (with prior loading) + stage_configs = [] + for i, op in enumerate(ops): + op_name = f"stage_{i}_{op._name}" + + # Try to load prior from a previous run + prior_bs = None + if self._profiling_store is not None: + prior_stats = self._profiling_store.get_execution_stats( + op._name + ) + if prior_stats and prior_stats.peak_memory_mb > 0: + prior_bs = self._profiling_store.get_safe_batch_size( + op._name, + available_memory_mb=( + self._scheduler_config.safety_buffer_mb * 4 + ), + ) + logger.info( + "Loaded prior for %s: peak_memory=%.1f MB, " + "recommended_bs=%d", + op._name, + prior_stats.peak_memory_mb, + prior_bs, + ) + + # Priority: op.batch_size > prior_bs > scheduler default + initial_bs = ( + getattr(op, "batch_size", None) + or prior_bs + or self._scheduler_config.initial_batch_size + ) + + stage_configs.append( + { + "name": op_name, + "batch_size": initial_bs, + "num_gpus": op.num_gpus or 0, + "num_actors": ( + op.runtime_np() + if hasattr(op, "runtime_np") + else (op.num_proc or 1) + ), + } + ) + + # Create ElasticJuicer facade + self._elastic = ElasticJuicer( + config=self._scheduler_config, + cluster_state=None, # auto-detect + ) + + # Register all stages with the facade + self._elastic.register_stages(stage_configs) + + logger.info( + f"ElasticJuicer facade initialized with {len(ops)} stages" + ) + + def _start_elastic(self): + """Start the ElasticJuicer facade (Tower + MetricsBridge + actors).""" + if self._elastic is not None and not self._elastic_started: + self._elastic.start() + self._elastic_started = True + + def _stop_elastic(self): + """Stop the ElasticJuicer facade, flush stats, and clean up.""" + if self._elastic is not None and self._elastic_started: + # Flush runtime stats to ProfilingStore before stopping + if self._profiling_store is not None: + try: + captains = getattr(self._elastic, "captains", {}) + total_flushed = 0 + for name, captain in captains.items(): + if hasattr(captain, "monitor"): + n = captain.monitor.flush_to_store( + self._profiling_store + ) + total_flushed += n + if total_flushed > 0: + self._profiling_store.save_all() + logger.info( + "Flushed %d op stats to ProfilingStore on stop", + total_flushed, + ) + except Exception as e: + logger.warning( + "Failed to flush stats to ProfilingStore: %s", e + ) + + self._elastic.stop() + self._elastic_started = False + + # ------------------------------------------------------------------ + # GPU operator detection + # ------------------------------------------------------------------ + + @staticmethod + def _has_gpu_ops(ops: List) -> bool: + """Return True if any operator in *ops* requires GPU.""" + return any((getattr(op, "num_gpus", 0) or 0) > 0 for op in ops) + + def _has_elastic_config(self) -> bool: + """Return True if elastic_juicer config is present and non-empty. + + Note: For ElasticRayExecutor, elastic scheduling auto-activates with + sensible defaults even without explicit config. This method is still + used by PartitionedRayExecutor to gate optional elastic activation. + """ + elastic_cfg = getattr(self.cfg, "elastic_juicer", None) + if elastic_cfg is None: + return False + if hasattr(elastic_cfg, "__dict__"): + return bool(vars(elastic_cfg)) + if isinstance(elastic_cfg, dict): + return bool(elastic_cfg) + return False + + def _should_use_elastic(self, ops: List) -> bool: + """Determine whether elastic scheduling should be activated. + + Returns True when GPU operators are present. For CPU-only pipelines, + ElasticRayExecutor uses a lightweight adaptive batching path instead + of the full Tower/Captain hierarchy. + """ + return self._has_gpu_ops(ops) + + # ------------------------------------------------------------------ + # Per-operator elastic execution + # ------------------------------------------------------------------ + + def _run_single_op_elastic( + self, + ds, + op, + stage_index: int, + cached_columns: set, + scheduler_config_dict: Dict[str, Any], + ): + """ + Execute a single operator with ElasticJuicer adaptive scheduling. + + For GPU operators: use AdaptiveOperator wrapper with MicroScheduler + For CPU operators: use standard execution + """ + import pyarrow as pa + from ray.data import ActorPoolStrategy + + stage_name = f"stage_{stage_index}_{op._name}" + + # Handle tagging ops - add meta column if needed + if ( + op._name in TAGGING_OPS.modules + and Fields.meta not in cached_columns + ): + + def process_batch_arrow(table: pa.Table): + new_column_data = [{} for _ in range(len(table))] + return table.append_column(Fields.meta, [new_column_data]) + + ds = ds.map_batches( + process_batch_arrow, + batch_format="pyarrow", + batch_size=DEFAULT_BATCH_SIZE, + ) + cached_columns.add(Fields.meta) + + batch_size = ( + getattr(op, "batch_size", 1) if op.is_batched_op() else 1 + ) + + if isinstance(op, Mapper): + if op.use_ray_actor(): + # GPU Mapper: use AdaptiveOperator wrapper + ds = self._apply_adaptive_gpu_op( + ds, op, stage_name, batch_size, scheduler_config_dict + ) + else: + # CPU Mapper: standard execution + from ray.data._internal.util import get_compute_strategy + + num_proc = ( + op.num_proc if op.num_proc and op.num_proc > 0 else None + ) + compute = get_compute_strategy( + op.process, concurrency=num_proc + ) + map_batches_kwargs = dict( + batch_size=batch_size, + batch_format="pyarrow", + num_cpus=op.num_cpus, + num_gpus=op.num_gpus, + compute=compute, + runtime_env=op.runtime_env, + ) + ds = ds.map_batches(op.process, **map_batches_kwargs) + + elif isinstance(op, Filter): + # Ensure stats column exists + if Fields.stats not in cached_columns: + + def process_batch_arrow(table: pa.Table): + new_column_data = [{} for _ in range(len(table))] + return table.append_column( + Fields.stats, [new_column_data] + ) + + ds = ds.map_batches( + process_batch_arrow, + batch_format="pyarrow", + batch_size=DEFAULT_BATCH_SIZE, + ) + cached_columns.add(Fields.stats) + + if op.use_ray_actor(): + # GPU Filter: use AdaptiveOperator wrapper + ds = self._apply_adaptive_gpu_op( + ds, op, stage_name, batch_size, scheduler_config_dict + ) + else: + # CPU Filter: compute_stats then filter + from ray.data._internal.util import get_compute_strategy + + num_proc = ( + op.num_proc if op.num_proc and op.num_proc > 0 else None + ) + compute = get_compute_strategy( + op.compute_stats, concurrency=num_proc + ) + map_batches_kwargs = dict( + batch_size=batch_size, + batch_format="pyarrow", + num_cpus=op.num_cpus, + num_gpus=op.num_gpus, + compute=compute, + runtime_env=op.runtime_env, + ) + ds = ds.map_batches(op.compute_stats, **map_batches_kwargs) + + # Apply filter (for both GPU and CPU filters) + if op.is_batched_op(): + ds = ds.map_batches( + partial(filter_batch, filter_func=op.process), + batch_format="pyarrow", + zero_copy_batch=True, + batch_size=DEFAULT_BATCH_SIZE, + runtime_env=op.runtime_env, + ) + else: + ds = ds.filter(op.process, runtime_env=op.runtime_env) + + elif isinstance(op, (Deduplicator, Pipeline)): + # Global ops: run directly + ds = op.run(ds) + + else: + logger.error( + "ElasticSchedulingMixin only supports Filter, Mapper, " + "Deduplicator and Pipeline OPs" + ) + raise NotImplementedError + + return ds, cached_columns + + def _apply_adaptive_gpu_op( + self, + ds, + op, + stage_name: str, + batch_size: int, + scheduler_config_dict: Dict[str, Any], + ): + """Apply GPU operator with AdaptiveOperator wrapper.""" + from ray.data import ActorPoolStrategy + + from data_juicer.core.executor.elastic_ray_executor import ( + AdaptiveOperator, + ) + + # Repartition for GPU actors + # Use op.num_proc directly (not runtime_np()) to respect the + # user-configured actor count from the benchmark/config. + # runtime_np() auto-calculates based on system resources and can + # overshoot, requesting more actors than GPUs can support. + num_actors = op.num_proc or 1 + override_num_blocks = getattr(op, "override_num_blocks", None) + if override_num_blocks is not None: + ds = ds.repartition(override_num_blocks) + else: + ds = ds.repartition(num_actors * 2) + + logger.info( + f"[{stage_name}] Actors: {num_actors}, " + f"GPU/actor: {op.num_gpus}, batch_size: {batch_size}" + ) + + # Extract operator config for AdaptiveOperator + op_kwargs = {} + if hasattr(op, "_op_cfg") and op._op_cfg: + # _op_cfg is like {"op_name": {kwargs}} + op_name, op_args = list(op._op_cfg.items())[0] + op_kwargs = dict(op_args) if op_args else {} + else: + # Fallback: extract from _init_kwargs + if hasattr(op, "_init_kwargs") and op._init_kwargs: + op_kwargs = dict(op._init_kwargs) + + # Get adaptive config from facade (if available) + adaptive_config = {} + if self._elastic is not None: + adaptive_config = self._elastic.get_adaptive_op_config(stage_name) + # Update with auto-scaled batch_size + adaptive_config["initial_batch_size"] = batch_size + else: + # Fallback if facade not initialized + adaptive_config = { + "stage_name": stage_name, + "initial_batch_size": batch_size, + "scheduler_config_dict": scheduler_config_dict, + } + + # Use map_batches with AdaptiveOperator class + # Forward custom_operator_paths so that fresh Ray worker actors can + # re-register custom ops before the OPERATORS lookup in + # AdaptiveOperator.__init__. + custom_operator_paths = getattr( + self.cfg, "custom_operator_paths", None + ) + ds = ds.map_batches( + AdaptiveOperator, + fn_constructor_kwargs={ + "op_class_name": op._name, + "op_kwargs": op_kwargs, + "custom_operator_paths": custom_operator_paths, + **adaptive_config, + }, + batch_size=batch_size, + num_cpus=op.num_cpus or 1, + num_gpus=op.num_gpus or 1, + compute=ActorPoolStrategy(size=num_actors), + batch_format="pyarrow", + runtime_env=op.runtime_env, + ) + + return ds + + # ------------------------------------------------------------------ + # Convenience: process a list of ops with elastic scheduling + # ------------------------------------------------------------------ + + def _process_ops_elastic(self, ds, ops: List): + """ + Process a list of operators on a Ray dataset using elastic scheduling. + + This is the elastic equivalent of ``dataset.process(ops)``. + Iterates over each op and dispatches to ``_run_single_op_elastic``. + + Returns: + Materialized Ray dataset after all ops. + """ + scheduler_config_dict = asdict(self._scheduler_config) + cached_columns = set(ds.columns()) if ds.columns() else set() + + for i, op in enumerate(ops): + try: + ds, cached_columns = self._run_single_op_elastic( + ds, op, i, cached_columns, scheduler_config_dict + ) + except Exception as e: + logger.error(f"Error processing operator {op}: {e}") + if op.runtime_env is not None: + logger.error( + "Trying to fallback to base runtime environment" + ) + original_runtime_env = op.runtime_env + try: + op.runtime_env = None + ds, cached_columns = self._run_single_op_elastic( + ds, op, i, cached_columns, scheduler_config_dict + ) + finally: + op.runtime_env = original_runtime_env + else: + raise e + + # Materialize dataset + logger.info("Materializing dataset to collect real metrics...") + ds = ds.materialize() + return ds + + # ------------------------------------------------------------------ + # Metrics logging + # ------------------------------------------------------------------ + + def _log_elasticjuicer_metrics( + self, input_rows: int, output_rows: int, duration: float + ): + """Log ElasticJuicer metrics summary using facade.""" + throughput = input_rows / duration if duration > 0 else 0 + + logger.info("ElasticJuicer Pipeline Summary:") + logger.info( + f" Input rows: {input_rows}, Output rows: {output_rows}" + ) + logger.info( + f" Duration: {duration:.1f}s, Throughput: {throughput:.2f} samples/sec" + ) + + if self._elastic is None: + return + + # Get metrics from facade + metrics_summary = self._elastic.get_metrics_summary() + for stage_name, metrics in metrics_summary.items(): + logger.info( + f" {stage_name}: {metrics.get('total_samples', 0)} samples, " + f"bs={metrics.get('min_bs', 0)}\u2192{metrics.get('max_bs', 0)}, " + f"avg_latency={metrics.get('avg_latency_ms', 0):.1f}ms" + ) + + # Get Captain stats from facade + captain_stats = self._elastic.get_captain_stats() + if captain_stats: + logger.info(" Captain Statistics:") + for stage_name, stats in captain_stats.items(): + if stats: + logger.info( + f" [{stage_name}]: " + f"throughput={stats.get('throughput', 0):.1f} sps, " + f"latency={stats.get('latency_ms', 0):.1f}ms, " + f"batch_size={stats.get('batch_size', 0)}, " + f"backpressure={stats.get('backpressure', False)}, " + f"oom_events={stats.get('oom_count', 0)}" + ) + + # Get Tower stats from facade + tower_stats = self._elastic.get_tower_stats() + if tower_stats: + logger.info(f" Tower stats: {tower_stats}") diff --git a/data_juicer/core/executor/event_logging_mixin.py b/data_juicer/core/executor/event_logging_mixin.py index c994b455ad5..14c0c34e547 100644 --- a/data_juicer/core/executor/event_logging_mixin.py +++ b/data_juicer/core/executor/event_logging_mixin.py @@ -646,6 +646,16 @@ def log_op_complete( "operation_class": operation_name, } + # Ensure input_rows and output_rows are integers (they might be strings from some sources) + try: + input_rows = int(input_rows) if input_rows is not None else None + except (ValueError, TypeError): + input_rows = None + try: + output_rows = int(output_rows) if output_rows is not None else None + except (ValueError, TypeError): + output_rows = None + # Only include row counts and derived metrics if they're meaningful (non-zero or explicitly set) if input_rows is not None and input_rows > 0: metadata["input_rows"] = input_rows diff --git a/data_juicer/core/executor/factory.py b/data_juicer/core/executor/factory.py index d507b0efb11..3d2b1b25630 100644 --- a/data_juicer/core/executor/factory.py +++ b/data_juicer/core/executor/factory.py @@ -1,3 +1,5 @@ +from loguru import logger + from .base import ExecutorBase from .default_executor import DefaultExecutor @@ -8,13 +10,17 @@ def create_executor(executor_type: str) -> ExecutorBase: if executor_type in ("local", "default"): return DefaultExecutor elif executor_type == "ray": - from .ray_executor import RayExecutor - - return RayExecutor + from .elastic_ray_executor import ElasticRayExecutor + logger.info('Using ElasticRayExecutor (adaptive scheduling enabled for GPU operators)') + return ElasticRayExecutor elif executor_type == "ray_partitioned": from .ray_executor_partitioned import PartitionedRayExecutor return PartitionedRayExecutor + elif executor_type == "elastic_ray": + from .elastic_ray_executor import ElasticRayExecutor + + return ElasticRayExecutor # TODO: add nemo support # elif executor_type == "nemo": # return NemoExecutor diff --git a/data_juicer/core/executor/ray_executor_partitioned.py b/data_juicer/core/executor/ray_executor_partitioned.py index ddcb1fc4425..342b012a77b 100644 --- a/data_juicer/core/executor/ray_executor_partitioned.py +++ b/data_juicer/core/executor/ray_executor_partitioned.py @@ -13,6 +13,7 @@ import os import shutil import time +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional @@ -179,9 +180,13 @@ def __init__(self, cfg: Optional[Namespace] = None): checkpoint_cfg = getattr(self.cfg, "checkpoint", None) checkpoint_dir = getattr(self.cfg, "checkpoint_dir", os.path.join(self.work_dir, "checkpoints")) + # Debug: log checkpoint_cfg type and value + logger.info(f"DEBUG: checkpoint_cfg type = {type(checkpoint_cfg)}, value = {checkpoint_cfg}") + if checkpoint_cfg: # Use ConfigAccessor to handle both dict and object configurations checkpoint_enabled = ConfigAccessor.get(checkpoint_cfg, "enabled", True) + logger.info(f"DEBUG: checkpoint_enabled from ConfigAccessor = {checkpoint_enabled}") strategy_str = ConfigAccessor.get(checkpoint_cfg, "strategy", "every_op") checkpoint_n_ops = ConfigAccessor.get(checkpoint_cfg, "n_ops", 1) checkpoint_op_names = ConfigAccessor.get(checkpoint_cfg, "op_names", []) @@ -449,11 +454,21 @@ def _run_impl(self, load_data_np: Optional[PositiveInt] = None, skip_return=Fals # Detect convergence points for global operations convergence_points = self._detect_convergence_points(self.cfg) + # Debug logging for checkpoint status + logger.info(f"DEBUG: checkpoint_enabled = {self.ckpt_manager.checkpoint_enabled}") + logger.info(f"DEBUG: convergence_points = {convergence_points}") + + # Choose processing strategy based on checkpointing and convergence points + # Fast path: when checkpointing is disabled and no convergence points, + # process without manual partitioning to let Ray Data handle parallelism if convergence_points: logger.info(f"Found convergence points at operations: {convergence_points}") final_dataset = self._process_with_convergence(dataset, ops, convergence_points) + elif not self.ckpt_manager.checkpoint_enabled: + logger.info("Checkpointing disabled, using fast path without manual partitioning") + final_dataset = self._process_without_partitioning(dataset, ops) else: - logger.info("No convergence points found, processing with simple partitioning") + logger.info("Checkpointing enabled, processing with partitioning for checkpoint support") final_dataset = self._process_with_simple_partitioning(dataset, ops) # Export final dataset @@ -482,6 +497,42 @@ def cleanup_temp_files(self): else: logger.info("No temporary files found to clean up") + def _process_without_partitioning(self, dataset: RayDataset, ops: List) -> RayDataset: + """ + Process dataset without manual partitioning. + + This is the fast path when checkpointing is disabled. + Ray Data handles parallelism automatically through map_batches with concurrency. + """ + logger.info("Processing without manual partitioning (fast path)...") + + start_time = time.time() + + # Pre-execute DAG monitoring (log operation start events) + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(ops, partition_id=0) + + # Execute operations (lazy evaluation - Ray Data handles parallelism) + processed_dataset = dataset.process(ops) + + # Force materialization only at the end (required for export anyway) + logger.info("Materializing final dataset...") + processed_dataset.data = processed_dataset.data.materialize() + + duration = time.time() - start_time + logger.info(f"Processing completed in {duration:.2f}s") + + # Post-execute DAG monitoring + if self.pipeline_dag: + try: + output_rows = processed_dataset.data.count() + metrics = {"duration": duration, "input_rows": "unknown", "output_rows": output_rows} + except Exception: + metrics = {"duration": duration} + self._post_execute_operations_with_dag_monitoring(ops, partition_id=0, metrics=metrics) + + return processed_dataset + def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): """ Process dataset with real partitioning using Ray Data's split and union. @@ -500,9 +551,10 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): # Process each partition separately with checkpointing logger.info("Processing partitions with checkpointing support...") - processed_partitions = [] + processed_partitions = [None] * len(partitions) - for i, partition in enumerate(partitions): + def process_single_partition(i, partition): + """Helper function to process a single partition.""" logger.info(f"Processing partition {i+1}/{len(partitions)}") # Log partition start event @@ -518,9 +570,6 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): # Apply operations with checkpointing support and DAG monitoring processed_partition = self._process_with_checkpointing(partition_dataset, i, ops) - # Store the processed partition's data - processed_partitions.append(processed_partition.data) - # Log partition completion event self._log_event( event_type=EventType.PARTITION_COMPLETE, @@ -528,6 +577,18 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): partition_id=i, ) + return i, processed_partition.data + + # Process partitions in parallel using ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=len(partitions)) as executor: + futures = { + executor.submit(process_single_partition, i, partition): i + for i, partition in enumerate(partitions) + } + for future in as_completed(futures): + i, result = future.result() + processed_partitions[i] = result + # Merge all processed partitions back into a single dataset logger.info("Merging processed partitions...") if len(processed_partitions) == 1: @@ -1001,4 +1062,4 @@ def _clear_invalid_checkpoints(self) -> None: if os.path.exists(self.ckpt_manager.ckpt_dir): logger.warning(f"Clearing invalid checkpoints in {self.ckpt_manager.ckpt_dir}") shutil.rmtree(self.ckpt_manager.ckpt_dir) - os.makedirs(self.ckpt_manager.ckpt_dir, exist_ok=True) + os.makedirs(self.ckpt_manager.ckpt_dir, exist_ok=True) \ No newline at end of file diff --git a/data_juicer/core/ray_exporter.py b/data_juicer/core/ray_exporter.py index ea4b700ae97..a6b98e62f09 100644 --- a/data_juicer/core/ray_exporter.py +++ b/data_juicer/core/ray_exporter.py @@ -131,12 +131,33 @@ def _export_impl(self, dataset, export_path, columns=None): :param columns: the columns to export. :return: """ + # Debug: Log dataset info before export + logger.info(f"Starting export to: {export_path}") + + # Materialize the dataset first to get accurate count + # Ray dataset needs to be materialized before calling count() or num_blocks() + try: + dataset = dataset.materialize() + logger.info("Dataset materialized successfully") + except Exception as e: + logger.warning(f"Dataset materialize failed (may already be materialized): {e}") + + # Get row count for validation + try: + row_count = dataset.count() + logger.info(f"Dataset row count before export: {row_count}") + if row_count == 0: + logger.warning("Dataset is empty (0 rows)! Export will produce no data files.") + except Exception as e: + logger.warning(f"Could not get dataset row count: {e}") + row_count = None + # Handle empty dataset case - Ray returns None for columns() on empty datasets # Check if dataset is empty by calling columns() regardless of columns parameter cols = dataset.columns() if cols is None: # Empty dataset with unknown schema - create an empty file - logger.warning(f"Dataset is empty, creating empty export file at {export_path}") + logger.warning(f"Dataset is empty (no columns), creating empty export file at {export_path}") os.makedirs(os.path.dirname(export_path) or ".", exist_ok=True) with open(export_path, "w"): pass # Create empty file @@ -182,7 +203,28 @@ def _export_impl(self, dataset, export_path, columns=None): if not export_path.startswith("s3://"): os.makedirs(export_path, exist_ok=True) - return export_method(dataset, export_path, **export_kwargs) + result = export_method(dataset, export_path, **export_kwargs) + + # Post-export verification: check if files were actually written + if not export_path.startswith("s3://"): + if os.path.isdir(export_path): + files = os.listdir(export_path) + if files: + logger.info(f"Export verification: {len(files)} file(s) written to {export_path}") + # Log first few files for debugging + for f in files[:5]: + file_path = os.path.join(export_path, f) + file_size = os.path.getsize(file_path) + logger.info(f" - {f} ({file_size} bytes)") + if len(files) > 5: + logger.info(f" ... and {len(files) - 5} more files") + else: + logger.warning(f"Export verification FAILED: No files written to {export_path}!") + logger.warning("This may indicate the dataset was empty or an export error occurred.") + else: + logger.warning(f"Export path {export_path} is not a directory after export!") + + return result def export(self, dataset, columns=None): """ diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 4ae7d07bc99..6b4de22c399 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -402,6 +402,10 @@ def __init__(self, *args, **kwargs): self.ray_execution_mode = kwargs.get("ray_execution_mode", None) assert self.ray_execution_mode in [None, "actor", "task"] + # Override the number of output blocks for Ray Data map_batches + # (helps prevent Ray block starvation when Ray fuses/coalesces blocks) + self.override_num_blocks = kwargs.get("override_num_blocks", None) + # Local import to avoid logger being serialized in multiprocessing from loguru import logger diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py index 4a7901d98b1..c924ce825ff 100644 --- a/data_juicer/ops/filter/video_aesthetics_filter.py +++ b/data_juicer/ops/filter/video_aesthetics_filter.py @@ -1,3 +1,4 @@ +import os from typing import Optional import numpy as np @@ -123,7 +124,10 @@ def __init__( trust_remote_code=trust_remote_code, ) # the original score predicted by laion-ai's scorer is within [0, 10] - self.need_normalized_by_ten = "shunk031/aesthetics-predictor" in hf_scorer_model + self.need_normalized_by_ten = ( + "shunk031/aesthetics-predictor" in hf_scorer_model + or "aesthetics-predictor" in os.path.basename(hf_scorer_model) + ) self.frame_sampling_method = frame_sampling_method self.frame_num = frame_num diff --git a/tests/core/elasticjuicer/__init__.py b/tests/core/elasticjuicer/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/core/elasticjuicer/test_probe_adapter.py b/tests/core/elasticjuicer/test_probe_adapter.py new file mode 100644 index 00000000000..c4a5452d564 --- /dev/null +++ b/tests/core/elasticjuicer/test_probe_adapter.py @@ -0,0 +1,310 @@ +"""Unit tests for ``ProbeAdapter``. + +Tests the schema translation between DJ Adapter probe outputs and +ElasticJuicer ProfilingStore schema. Pure unit tests -- no real DJ Adapter +or pipeline required. Uses a real ``ProfilingStore`` backed by a temp dir +(so we exercise the actual store contract rather than mocks). + +Run from repo root:: + + python -m unittest tests/core/elasticjuicer/test_probe_adapter.py -v + +or:: + + pytest tests/core/elasticjuicer/test_probe_adapter.py -v +""" + +import logging +import shutil +import tempfile +import unittest +from pathlib import Path +from typing import Any, Dict, List, Optional + +from data_juicer.core.elasticjuicer.profiler.ocs_annotator import ( + MemoryLocality, +) +from data_juicer.core.elasticjuicer.profiler.probe_adapter import ProbeAdapter +from data_juicer.core.elasticjuicer.profiler.profiling_store import ProfilingStore + + +# ---------------------------------------------------------------------- +# Shared probe-dict factory +# ---------------------------------------------------------------------- + +def _make_probe( + *, + cpu_util_ratio: float = 0.85, + used_mem_mb: float = 12000.0, + gpu_used_mb: Optional[List[float]] = None, + gpu_util_ratios: Optional[List[float]] = None, + speed: float = 200.0, + total_time: float = 5.0, + n_snapshots: int = 1, + timestamp: float = 1.0, +) -> Dict[str, Any]: + """Build a DJ-Adapter-shaped probe dict for testing. + + Mirrors the dict shape produced by ``Adapter.execute_and_probe`` + + ``Monitor.analyze_resource_util_list`` (per ``data_juicer/core/monitor.py``). + """ + resource = [] + for i in range(n_snapshots): + snap = { + "timestamp": timestamp + i, + "CPU util.": cpu_util_ratio, + "Used mem.": used_mem_mb, + "GPU used mem.": gpu_used_mb, + "GPU util.": gpu_util_ratios, + } + resource.append(snap) + + analysis: Dict[str, Dict[str, float]] = {} + if cpu_util_ratio is not None: + analysis["CPU util."] = { + "max": cpu_util_ratio, + "min": cpu_util_ratio, + "avg": cpu_util_ratio, + } + if used_mem_mb is not None: + analysis["Used mem."] = { + "max": used_mem_mb, + "min": used_mem_mb * 0.9, + "avg": used_mem_mb * 0.95, + } + if gpu_util_ratios: + gpu_max = max(gpu_util_ratios) + analysis["GPU util."] = { + "max": gpu_max, + "min": gpu_max * 0.5, + "avg": gpu_max * 0.75, + } + + return { + "time": total_time, + "sampling interval": 0.5, + "speed": speed, + "resource": resource, + "resource_analysis": analysis, + } + + +# ---------------------------------------------------------------------- +# Tests +# ---------------------------------------------------------------------- + +class TestProbeAdapter(unittest.TestCase): + def setUp(self) -> None: + self.tmpdir = Path(tempfile.mkdtemp(prefix="ej_probe_adapter_test_")) + self.store = ProfilingStore(storage_dir=str(self.tmpdir)) + self.bridge = ProbeAdapter(self.store) + + def tearDown(self) -> None: + shutil.rmtree(self.tmpdir, ignore_errors=True) + + # -- Translation primitives ------------------------------------------------- + + def test_translate_one_basic(self) -> None: + """Field-by-field translation for a typical CPU-only probe.""" + probe = _make_probe( + cpu_util_ratio=0.85, + used_mem_mb=12000.0, + speed=200.0, + n_snapshots=1, + timestamp=1.0, + ) + + stats, _ = self.bridge._translate_one(probe, "myop", batch_size=1000) + + self.assertEqual(stats.op_name, "myop") + self.assertEqual(len(stats.snapshots), 1) + + snap = stats.snapshots[0] + # Ratio -> percent + self.assertAlmostEqual(snap.cpu_percent, 85.0, places=5) + # System MB copied directly (lossy: marked low-confidence elsewhere) + self.assertAlmostEqual(snap.memory_mb, 12000.0, places=5) + # Externally-injected batch size + self.assertEqual(snap.batch_size, 1000) + # speed -> throughput + self.assertAlmostEqual(snap.throughput, 200.0, places=5) + # 1000 samples / 200 samples-per-sec = 5 s = 5000 ms per batch + self.assertAlmostEqual(snap.latency_ms, 5000.0, places=2) + self.assertEqual(snap.timestamp, 1.0) + # No GPU data + self.assertIsNone(snap.gpu_memory_mb) + self.assertIsNone(snap.gpu_utilization) + + def test_extract_gpu_mem_single_gpu(self) -> None: + """Single-GPU list returns the value with full confidence.""" + raw = {"GPU used mem.": [8192.0]} + val, conf = ProbeAdapter._extract_gpu_mem(raw) + self.assertEqual(val, 8192.0) + self.assertEqual(conf, ProbeAdapter.CONFIDENCE_FULL) + + def test_extract_gpu_mem_no_gpu(self) -> None: + """Missing / None / empty GPU list returns (None, full confidence).""" + for raw in ({}, {"GPU used mem.": None}, {"GPU used mem.": []}): + val, conf = ProbeAdapter._extract_gpu_mem(raw) + self.assertIsNone(val) + self.assertEqual(conf, ProbeAdapter.CONFIDENCE_FULL) + + def test_extract_gpu_mem_multi_gpu_warns(self) -> None: + """Multi-GPU collapses to gpus[0] with reduced confidence + WARNING log.""" + raw = {"GPU used mem.": [8192.0, 4096.0, 2048.0]} + + with self.assertLogs( + logger="data_juicer.core.elasticjuicer.profiler.probe_adapter", + level=logging.WARNING, + ) as captured: + val, conf = ProbeAdapter._extract_gpu_mem(raw) + + self.assertEqual(val, 8192.0) + self.assertEqual(conf, ProbeAdapter.CONFIDENCE_FIRST_GPU_ONLY) + self.assertTrue( + any("Multi-GPU" in line for line in captured.output), + f"expected 'Multi-GPU' substring in log lines, got: {captured.output}", + ) + + def test_extract_gpu_util_percent_conversion(self) -> None: + """GPU util ratio in [0, 1] -> percent in [0, 100].""" + raw = {"GPU util.": [0.7]} + val, conf = ProbeAdapter._extract_gpu_util(raw) + self.assertAlmostEqual(val, 70.0, places=5) + self.assertEqual(conf, ProbeAdapter.CONFIDENCE_FULL) + + def test_speed_zero_safe(self) -> None: + """speed=0 must not raise (no divide-by-zero); latency reported as 0.""" + probe = _make_probe(speed=0.0) + stats, conf = self.bridge._translate_one(probe, "myop", batch_size=100) + self.assertEqual(len(stats.snapshots), 1) + self.assertEqual(stats.snapshots[0].latency_ms, 0.0) + # latency confidence should reflect that we couldn't compute it + self.assertEqual(conf["latency_ms"], 0.0) + + # -- Signature derivation --------------------------------------------------- + + def test_derive_signature_gpu_strong(self) -> None: + """GPU util max > 0.5 -> memory_locality = GPU_STRONG.""" + probe = _make_probe(gpu_used_mb=[4096.0], gpu_util_ratios=[0.85]) + stats, _ = self.bridge._translate_one(probe, "image_op", batch_size=16) + + sig = self.bridge._derive_signature(probe, "image_op", stats) + + self.assertIsNotNone(sig) + self.assertEqual(sig.memory_locality, MemoryLocality.GPU_STRONG) + + def test_derive_signature_max_memory_from_stats(self) -> None: + """max_memory_mb in signature comes from stats.peak_memory_mb.""" + probe = _make_probe(used_mem_mb=8000.0) + stats, _ = self.bridge._translate_one(probe, "myop", batch_size=100) + + sig = self.bridge._derive_signature(probe, "myop", stats) + + self.assertIsNotNone(sig) + self.assertAlmostEqual(sig.max_memory_mb, stats.peak_memory_mb) + # preferred_batch_size set from first snapshot's batch_size + self.assertEqual(sig.preferred_batch_size, 100) + + # -- End-to-end ingest ------------------------------------------------------ + + def test_ingest_writes_to_store(self) -> None: + """Full ingest: probe results -> ProfilingStore has both stats and sig.""" + probes = [_make_probe(speed=100.0), _make_probe(speed=200.0)] + + written = self.bridge.ingest_probe_results( + probe_results=probes, + op_names=["op_a", "op_b"], + probe_batch_sizes=[500, 1000], + ) + + self.assertEqual(set(written.keys()), {"op_a", "op_b"}) + # Stats persisted + self.assertIsNotNone(self.store.get_execution_stats("op_a")) + self.assertIsNotNone(self.store.get_execution_stats("op_b")) + # Signatures persisted + self.assertIsNotNone(self.store.get_ocs_signature("op_a")) + self.assertIsNotNone(self.store.get_ocs_signature("op_b")) + + def test_confidence_marker_system_memory(self) -> None: + """memory_mb is always tracked with reduced confidence (system->process).""" + probe = _make_probe(used_mem_mb=12000.0) + self.bridge.ingest_probe_results([probe], ["op_a"], [100]) + + # System-memory translation is always lossy + self.assertEqual( + self.bridge.get_confidence("op_a", "memory_mb"), + ProbeAdapter.CONFIDENCE_SYSTEM_MEMORY_AS_PROCESS, + ) + # cpu_percent is a direct math operation, full confidence + self.assertEqual( + self.bridge.get_confidence("op_a", "cpu_percent"), + ProbeAdapter.CONFIDENCE_FULL, + ) + + def test_confidence_marker_multi_gpu_downgrade(self) -> None: + """If any snapshot in the probe is multi-GPU, op-level GPU confidence drops.""" + probe = _make_probe( + gpu_used_mb=[8192.0, 4096.0], # two GPUs + gpu_util_ratios=[0.8, 0.6], + ) + # Suppress warning noise emitted by the bridge during ingest. + with self.assertLogs( + logger="data_juicer.core.elasticjuicer.profiler.probe_adapter", + level=logging.WARNING, + ): + self.bridge.ingest_probe_results([probe], ["op_a"], [16]) + + self.assertEqual( + self.bridge.get_confidence("op_a", "gpu_memory_mb"), + ProbeAdapter.CONFIDENCE_FIRST_GPU_ONLY, + ) + self.assertEqual( + self.bridge.get_confidence("op_a", "gpu_utilization"), + ProbeAdapter.CONFIDENCE_FIRST_GPU_ONLY, + ) + + def test_ingest_length_mismatch_raises(self) -> None: + """Mismatched input list lengths must raise ValueError.""" + with self.assertRaises(ValueError): + self.bridge.ingest_probe_results( + probe_results=[_make_probe()], + op_names=["a", "b"], + probe_batch_sizes=[100], + ) + + def test_empty_resource_list_yields_empty_stats(self) -> None: + """Probe with no resource snapshots -> OpExecutionStats with no snapshots.""" + probe = { + "time": 0.0, + "speed": 0.0, + "resource": [], + "resource_analysis": {}, + } + stats, _ = self.bridge._translate_one(probe, "empty_op", batch_size=1) + self.assertEqual(stats.op_name, "empty_op") + self.assertEqual(len(stats.snapshots), 0) + + def test_ingest_continues_after_per_op_failure(self) -> None: + """If one probe dict is malformed, ingest still writes the good ones.""" + bad_probe: Dict[str, Any] = {"resource": object()} # not iterable + good_probe = _make_probe() + + with self.assertLogs( + logger="data_juicer.core.elasticjuicer.profiler.probe_adapter", + level=logging.ERROR, + ): + written = self.bridge.ingest_probe_results( + probe_results=[bad_probe, good_probe], + op_names=["bad", "good"], + probe_batch_sizes=[10, 100], + ) + + self.assertIn("good", written) + self.assertNotIn("bad", written) + self.assertIsNotNone(self.store.get_execution_stats("good")) + self.assertIsNone(self.store.get_execution_stats("bad")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/process_data.py b/tools/process_data.py index 3e959618fba..a42c024da42 100644 --- a/tools/process_data.py +++ b/tools/process_data.py @@ -33,6 +33,12 @@ def main(): ) executor = PartitionedRayExecutor(cfg) + elif cfg.executor_type == "elastic_ray": + from data_juicer.core.executor.elastic_ray_executor import ( + ElasticRayExecutor, + ) + + executor = ElasticRayExecutor(cfg) else: raise ValueError(f"Unsupported executor type: {cfg.executor_type}")