diff --git a/tests/rl/test_rl_colocate_trainer.py b/tests/rl/test_rl_colocate_trainer.py index 5adcb7320..cadaf81a4 100644 --- a/tests/rl/test_rl_colocate_trainer.py +++ b/tests/rl/test_rl_colocate_trainer.py @@ -52,7 +52,6 @@ def __init__(self, uid: int): self.extra_fields = {} self.response_model_steps = [] - class _FakeSampler: def __init__(self): self._next_id = 0 @@ -148,13 +147,18 @@ def _make_trainer(self, agent_loop_manager, *, total_train_steps: int = 1, sync_ ) trainer.rollout_controller = SimpleNamespace( - ensure_workers_healthy_before_training=SimpleNamespace( - remote=MagicMock(return_value="rollout_ready_for_training") + check_and_shutdown_inactive_workers=SimpleNamespace( + remote=MagicMock(return_value="rollout_inactive_workers_shutdown") ), offload=SimpleNamespace(remote=MagicMock(return_value="rollout_offloaded")), + restart_inactive_workers=SimpleNamespace(remote=MagicMock(return_value="rollout_restarted")), + onload_weights=SimpleNamespace(remote=MagicMock(return_value="weights_loaded")), + onload_kvcache=SimpleNamespace(remote=MagicMock(return_value="kvcache_loaded")), ) trainer.train_controller = SimpleNamespace( onload=MagicMock(return_value="train_onloaded"), + offload=MagicMock(return_value="train_offloaded"), + update_weights=MagicMock(return_value="weights_updated"), fit=MagicMock( return_value=[ { @@ -220,15 +224,37 @@ async def _produce_empty(batch_size, train_step, **kwargs): trainer.train_controller.fit.assert_not_called() self.assertEqual(trainer._cur_step, 0) + def test_fit_does_not_onload_train_when_rollout_training_barrier_fails(self): + # 验证共卡训练进入训练前必须先通过 rollout phase-switch barrier; + # 失败时不能 onload 训练。 + async def _produce_batch(batch_size, train_step, *, model_step): + return ProduceBatchResult(rollout_states=[[_FakeRolloutState(train_step)]]) + + trainer = self._make_trainer(SimpleNamespace(produce_batch=_produce_batch)) + trainer.rollout_controller.check_and_shutdown_inactive_workers.remote.side_effect = RuntimeError( + 
"inactive rollout workers after recovery" + ) + + with ( + patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run), + patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj), + ): + with self.assertRaisesRegex(RuntimeError, "inactive rollout workers"): + trainer.fit() + + trainer.rollout_controller.check_and_shutdown_inactive_workers.remote.assert_called_once_with() + trainer.rollout_controller.offload.remote.assert_not_called() + trainer.train_controller.onload.assert_not_called() + trainer.train_controller.fit.assert_not_called() + self.assertEqual(trainer._cur_step, 0) + def test_fit_uses_sync_interval_and_passes_rollout_model_step(self): # 验证 rollout 看到的是按 sync interval 推进后的 model_step。 produce_calls = [] async def _produce_batch(batch_size, train_step, *, model_step): produce_calls.append((batch_size, train_step, model_step)) - return ProduceBatchResult( - rollout_states=[[SimpleNamespace(group_id=train_step, rollout_id=train_step)]] - ) + return ProduceBatchResult(rollout_states=[[_FakeRolloutState(train_step)]]) trainer = self._make_trainer( SimpleNamespace(produce_batch=_produce_batch), diff --git a/tests/rl/test_rl_disaggregated_trainer.py b/tests/rl/test_rl_disaggregated_trainer.py index 4f01e7d06..b6a604cac 100644 --- a/tests/rl/test_rl_disaggregated_trainer.py +++ b/tests/rl/test_rl_disaggregated_trainer.py @@ -146,7 +146,10 @@ def _make_trainer(self, agent_loop_manager): update_weights=MagicMock(return_value="update"), ) trainer.rollout_controller = SimpleNamespace( - recover_failed_workers=SimpleNamespace(remote=MagicMock(return_value="recover")), + check_and_shutdown_inactive_workers=SimpleNamespace( + remote=MagicMock(return_value="rollout_inactive_workers_shutdown") + ), + restart_inactive_workers=SimpleNamespace(remote=MagicMock(return_value="rollout_restarted")), pause_generation=SimpleNamespace(remote=MagicMock(return_value="pause")), 
continue_generation=SimpleNamespace(remote=MagicMock(return_value="continue")), onload_weights=SimpleNamespace(remote=MagicMock(return_value="onload_weights")), diff --git a/tests/rl/test_rl_trainer_checkpoint.py b/tests/rl/test_rl_trainer_checkpoint.py index 8b4474ac9..0aad6774b 100644 --- a/tests/rl/test_rl_trainer_checkpoint.py +++ b/tests/rl/test_rl_trainer_checkpoint.py @@ -90,8 +90,8 @@ def __init__(self): self.pause_generation = _RemoteMethod(async_result=True) self.continue_generation = _RemoteMethod(async_result=True) self.offload = _RemoteMethod(return_value="rollout_offloaded") - self.ensure_workers_healthy_before_training = _RemoteMethod(return_value="rollout_ready_for_training") - self.recover_failed_workers = _RemoteMethod(return_value="rollout_recovered") + self.check_and_shutdown_inactive_workers = _RemoteMethod(return_value="rollout_inactive_workers_shutdown") + self.restart_inactive_workers = _RemoteMethod(return_value="rollout_restarted") self.onload_weights = _RemoteMethod(return_value="weights_loaded") self.onload_kvcache = _RemoteMethod(return_value="kvcache_loaded") self.get_rollout_metadata = _RemoteMethod(return_value={"server_url_dict": {}}) @@ -204,6 +204,7 @@ def build_rollout_controller(rollout_cfg, placement_group): return controller with ( + patch("ray.get", side_effect=lambda obj, timeout=None: obj), patch("xtuner.v1.rl.utils.ray_accelerator_worker.ray.is_initialized", return_value=True), patch( "xtuner.v1.rl.utils.ray_accelerator_worker.ray.available_resources", @@ -217,6 +218,12 @@ def build_rollout_controller(rollout_cfg, placement_group): patch("xtuner.v1.train.rl_trainer.BaseRLTrainer._release_trace_store", return_value=None), patch.object(WorkerConfig, "build", autospec=True, side_effect=build_train_controller), patch.object(RolloutConfig, "build", autospec=True, side_effect=build_rollout_controller), + patch.object( + RolloutConfig, + "get_controller_generate_concurrency", + autospec=True, + side_effect=lambda rollout_cfg, 
placement_group: rollout_cfg.generate_concurrency_per_instance, + ), ): yield runtime @@ -321,6 +328,7 @@ def _build_colocate_config( auto_resume=auto_resume, checkpoint_interval=1, checkpoint_maxkeep=None, + checkpoint_no_save_replay_buffer=True, hf_interval=-1, seed=42, exp_tracker="jsonl", @@ -361,6 +369,7 @@ def _build_disaggregated_config( auto_resume=auto_resume, checkpoint_interval=1, checkpoint_maxkeep=None, + checkpoint_no_save_replay_buffer=True, hf_interval=-1, seed=42, exp_tracker="jsonl", diff --git a/tests/rl/test_rollout_logic.py b/tests/rl/test_rollout_logic.py index a99271179..c9c45901e 100644 --- a/tests/rl/test_rollout_logic.py +++ b/tests/rl/test_rollout_logic.py @@ -3,7 +3,7 @@ 本文件合并旧的 test_rollout_worker.py 和 test_rollout_utils.py 中不依赖真实模型/后端的测试: - SGLangWorker pause/continue 对 abort flag 和 server request 的控制。 - RolloutWorker abort、abort request timeout 和 in-flight request 取消语义。 -- RolloutHealthChecker 对 inactive/unhealthy worker 的清理逻辑。 +- RolloutHealthManager 对 inactive/unhealthy worker 的生命周期标记逻辑。 - PartialRolloutHandler 拼接 routed_experts 后释放旧 Ray ObjectRef 的逻辑。 旧 test_rollout_utils.py 中的 TestRolloutControllerRecover 需要真实 Ray controller / lmdeploy backend, @@ -19,28 +19,28 @@ import httpx from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status -from xtuner.v1.rl.rollout.controller import RolloutController +from xtuner.v1.rl.rollout.controller import RolloutController, WorkerInfo +from xtuner.v1.rl.rollout.health_manager import RolloutHealthManager from xtuner.v1.rl.rollout.sglang import SGLangWorker -from xtuner.v1.rl.rollout.utils import PartialRolloutHandler, RolloutHealthChecker +from xtuner.v1.rl.rollout.utils import PartialRolloutHandler, WorkerLifecycleState from xtuner.v1.rl.rollout.worker import RolloutWorker from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult -class _FakeRemoteMethod: - def __init__(self, name, call_log): - self.name = name - self.call_log = call_log +class 
_FakeAsyncRemoteMethod: + def __init__(self, result): + self.result = result + self.calls = [] def remote(self): - self.call_log.append((self.name, "remote")) - return self.name + self.calls.append(()) + async def _result(): + if isinstance(self.result, Exception): + raise self.result + return self.result -class _FakeWorker: - def __init__(self): - self.call_log = [] - self.offload = _FakeRemoteMethod("offload", self.call_log) - self.shutdown = _FakeRemoteMethod("shutdown", self.call_log) + return _result() class _FakeRolloutRouter: @@ -391,58 +391,44 @@ async def run_case(case_name, safe_post_result, safe_handle_response, expected_m await asyncio.gather(*(run_case(*case) for case in cases)) -class TestRolloutHealthChecker(unittest.TestCase): - def _build_checker(self, workers_info): - config = SimpleNamespace(health_check_interval_seconds=10, health_check_failure_threshold=1) - return RolloutHealthChecker(config, workers_info) +class TestRolloutHealthManager(unittest.TestCase): + def _build_manager(self, workers_info, *, failure_threshold=1): + config = SimpleNamespace(health_check_interval_seconds=10, health_check_failure_threshold=failure_threshold) + return RolloutHealthManager(config, workers_info, worker_infos_lock=threading.RLock()) - def test_shutdown_runs_when_offload_fails(self): - # worker 健康检查失败且 offload 也失败时,health checker 应 shutdown 并标记 inactive。 - worker = _FakeWorker() - workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=True)} - checker = self._build_checker(workers_info) + def test_marks_worker_inactive_after_consecutive_health_failures(self): + actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) + worker_info = WorkerInfo(actor=actor, url="http://worker-0") + workers_info = {0: worker_info} + manager = self._build_manager(workers_info, failure_threshold=2) - async def unhealthy_worker(*args, **kwargs): - return False + manager._check_and_deactivate_failed_worker_groups() - def ray_get(ref, 
timeout=None): - worker.call_log.append((ref, "get")) - if ref == "offload": - raise RuntimeError("offload failed") - return None + self.assertTrue(worker_info.is_active()) + self.assertEqual(actor.check_health.calls, [()]) - with ( - patch("xtuner.v1.rl.rollout.utils.check_worker_health", side_effect=unhealthy_worker), - patch("xtuner.v1.rl.rollout.utils.ray.get", side_effect=ray_get), - ): - checker.run_once() - - self.assertFalse(workers_info[0].is_active) - self.assertEqual( - worker.call_log, - [ - ("offload", "remote"), - ("offload", "get"), - ("shutdown", "remote"), - ("shutdown", "get"), - ], - ) + manager._check_and_deactivate_failed_worker_groups() + + self.assertFalse(worker_info.is_active()) + self.assertEqual(worker_info.lifecycle_state, WorkerLifecycleState.INACTIVE) + self.assertEqual(actor.check_health.calls, [(), ()]) def test_inactive_worker_is_not_cleaned_up_again(self): - # 已 inactive 的 worker 不再重复健康检查、offload 或 shutdown。 - worker = _FakeWorker() - workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=False)} - checker = self._build_checker(workers_info) + # 已 inactive 的 worker 不再重复健康检查。 + actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) + workers_info = { + 0: WorkerInfo( + actor=actor, + url="http://worker-0", + lifecycle_state=WorkerLifecycleState.INACTIVE, + ) + } + manager = self._build_manager(workers_info) - with ( - patch("xtuner.v1.rl.rollout.utils.check_worker_health") as check_worker_health_mock, - patch("xtuner.v1.rl.rollout.utils.ray.get") as ray_get_mock, - ): - checker.run_once() + checked_count = manager._check_and_deactivate_failed_worker_groups() - check_worker_health_mock.assert_not_called() - ray_get_mock.assert_not_called() - self.assertEqual(worker.call_log, []) + self.assertEqual(checked_count, 0) + self.assertEqual(actor.check_health.calls, []) class TestPartialRolloutHandler(unittest.IsolatedAsyncioTestCase): diff --git a/xtuner/v1/rl/rollout/controller.py 
b/xtuner/v1/rl/rollout/controller.py index 7542974d9..b245e0ccc 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -1,9 +1,7 @@ import asyncio -import math -import os import threading from dataclasses import dataclass -from typing import Any, Dict, List, Optional, TypeAlias, TypedDict +from typing import Dict, List, TypeAlias, TypedDict from uuid import uuid4 import ray @@ -15,43 +13,72 @@ from xtuner.v1.rl.utils import AutoAcceleratorWorkers from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger +from .health_manager import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthManager from .parser.factory import build_reasoning_parser, build_tool_call_parser from .parser.reasoning_parser import ReasoningParser from .parser.tool_parser import ToolCallParser -from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter +from .utils import SessionRouter, WorkerLifecycleState from .worker import ( ROLLOUT_CONCURRENCY_GROUP_GENERATE, RolloutConfig, RolloutWorker, + get_rollout_worker_base_cls, ) -@dataclass +@dataclass(init=False) class WorkerInfo: - """A data class to hold all state information for a single worker.""" + """Controller-owned state record for one rollout server process.""" actor: RolloutWorker url: str session_url: str | None = None - is_active: bool = True + lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE + lifecycle_group_ranks: tuple[int, ...] = () + # True only for server processes that receive rollout requests. Non-entrypoint + # server processes are still tracked for health, recovery, and shutdown. + is_request_entrypoint: bool = True + + def __init__( + self, + actor: RolloutWorker, + url: str, + session_url: str | None = None, + lifecycle_state: WorkerLifecycleState | str | None = None, + lifecycle_group_ranks: tuple[int, ...] 
= (), + is_request_entrypoint: bool = True, + ): + self.actor = actor + self.url = url + self.session_url = session_url + self.lifecycle_group_ranks = lifecycle_group_ranks + self.is_request_entrypoint = is_request_entrypoint + if lifecycle_state is not None: + self.lifecycle_state = WorkerLifecycleState(lifecycle_state) + else: + self.lifecycle_state = WorkerLifecycleState.ACTIVE + + def is_active(self) -> bool: + return self.lifecycle_state is WorkerLifecycleState.ACTIVE class RolloutWorkerMetadata(TypedDict): """Metadata for rollout workers and their configuration. This data structure encapsulates all necessary information about the rollout worker infrastructure, including - engine topology, server addresses, and worker status. Used for communication between training processes and rollout + engine mesh, server addresses, and worker status. Used for communication between training processes and rollout workers. """ + # TODO(@duanyanhui): combine server_url_dict, worker_server_urls_status, worker_session_url_dict, and worker_session_urls_status into a single dict keyed by rank or URL to avoid potential inconsistencies. # 推理引擎的拓扑结构,每个子列表代表一个推理引擎包含的所有 worker ranks # 例如:[[0, 1, 2, 3], [4, 5, 6, 7]] 表示有 2 个推理引擎,每个引擎包含 4 个 workers # 用于确定分布式推理的并行组划分 engine_rank_mesh_array: List[List[int]] # worker rank 到服务器 URL 的映射字典,用于训练进程与 rollout workers 通信 - # 键:worker 的 rank ID(字符串形式的整数) - # 值:对应的服务器地址列表(通常每个 rank 对应一个 URL) + # 键:worker 的 rank ID + # 值:对应的 request entrypoint 服务器地址 server_url_dict: Dict[int, str] # Rollout 配置对象,包含推理引擎的所有配置参数 @@ -76,8 +103,11 @@ class RolloutWorkerMetadata(TypedDict): # Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller # state; passing a normal Python object would serialize a separate copy into each actor. class RolloutController: - """Controller for managing and coordinating multiple RolloutWorker - actors.""" + """Control-plane entrypoint for rollout traffic and worker startup. 
+ + The controller creates workers, routes generate requests, and broadcasts training lifecycle commands. Health state + transitions and worker recovery belong to RolloutHealthManager. + """ def __init__( self, @@ -94,23 +124,24 @@ def __init__( self.config = infer_config self.num_gpus_per_engine = self.config.num_gpus_per_engine self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") - self.engine_rank_mesh_array: List[List[int]] = [] - self.worker_server_urls_map: dict[int, str] = {} - self.rank2info: dict[int, WorkerInfo] = {} - self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group) - self.num_active_workers = len(self.rank2info) + self.engine_rank_mesh_array: List[List[int]] + self.server_process_rank2info: dict[int, WorkerInfo] + self.engine_rank_mesh_array, self.server_process_rank2info = self._init_workers(placement_group) + # Cache the exact controller concurrency chosen at build time so + # downstream components observe the same limit as the Ray actor. + self._generate_concurrency = self.config.get_controller_generate_concurrency(placement_group) self.worker_info_lock = threading.RLock() # The timeout for the environment to wait for the rollout controller's response. # This should be longer than the controller's internal timeout (`rollout_timeout`) # to account for potential queuing delays and other overheads. 
self.timeout_multiplier = 2.0 - self.router = SessionRouter(self.rank2info, worker_infos_lock=self.worker_info_lock) - self.health_checker = RolloutHealthChecker( + self.router = SessionRouter(self.server_process_rank2info, worker_infos_lock=self.worker_info_lock) + self.health_manager = RolloutHealthManager( config=self.config, - workers_info=self.rank2info, + workers_info=self.server_process_rank2info, worker_infos_lock=self.worker_info_lock, ) - self.health_checker.start() + self.health_manager.start() self._tool_call_parser, self._reasoning_parser = self._build_output_parsers() def get_rollout_metadata(self) -> RolloutWorkerMetadata: @@ -120,17 +151,29 @@ def get_rollout_metadata(self) -> RolloutWorkerMetadata: dict: A dictionary containing the engine mesh list, server URL dictionary, and the rollout configuration. """ + worker_session_url_dict: dict[int, str] = {} + worker_session_urls_status: dict[str, bool] = {} with self.worker_info_lock: - worker_server_urls_status = {info.url: info.is_active for info in self.rank2info.values()} - worker_session_url_dict = { - rank: info.session_url for rank, info in self.rank2info.items() if info.session_url is not None + request_entrypoint_rank2info = { + rank: info for rank, info in self.server_process_rank2info.items() if info.is_request_entrypoint } - worker_session_urls_status = { - info.session_url: info.is_active for info in self.rank2info.values() if info.session_url is not None + worker_server_urls_map = {rank: info.url for rank, info in request_entrypoint_rank2info.items()} + worker_server_urls_status = {info.url: info.is_active() for info in request_entrypoint_rank2info.values()} + active_session_urls_by_rank = { + rank: info.session_url for rank, info in request_entrypoint_rank2info.items() } + for rank, info in request_entrypoint_rank2info.items(): + if info.session_url is None: + continue + worker_session_url_dict[rank] = info.session_url + worker_session_urls_status[info.session_url] = info.is_active() + 
self.logger.info(f"Rollout worker server URLs: {worker_server_urls_map}") + self.logger.info(f"Rollout worker session server URLs: {active_session_urls_by_rank}") + + # TODO(@duanyanhui): provide an unified structure that combines server URLs and session URLs rollout_metadata: RolloutWorkerMetadata = { "engine_rank_mesh_array": self.engine_rank_mesh_array, - "server_url_dict": self.worker_server_urls_map, + "server_url_dict": worker_server_urls_map, "rollout_config": self.config, "worker_server_urls_status": worker_server_urls_status, "worker_session_url_dict": worker_session_url_dict, @@ -152,25 +195,8 @@ def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser return tool_call_parser, reasoning_parser - def get_ready_status(self) -> tuple[bool, dict[str, Any]]: - with self.worker_info_lock: - active_workers = sum(1 for info in self.rank2info.values() if info.is_active) - total_workers = len(self.rank2info) - return active_workers > 0, { - "active_workers": active_workers, - "total_workers": total_workers, - } - def get_generate_concurrency(self) -> int: - assert self.config.rollout_max_batch_size_per_instance is not None, ( - "rollout_max_batch_size_per_instance must be set before building AgentLoop." 
- ) - concurrency_per_worker = math.ceil( - self.config.rollout_max_batch_size_per_instance * self.config.allow_over_concurrency_ratio - ) - with self.worker_info_lock: - active_workers = sum(1 for info in self.rank2info.values() if info.is_active) - return active_workers * concurrency_per_worker + return self._generate_concurrency @ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE) async def generate(self, rollout_state: RolloutState) -> RolloutState: @@ -224,15 +250,21 @@ def _apply_output_parsers(self, rollout_state: RolloutState) -> None: def set_enable_partial_rollout(self, enable: bool) -> None: """Propagate enable_partial_rollout flag to all active workers.""" - with self.worker_info_lock: - active_actors = [info.actor for info in self.rank2info.values() if info.is_active] - ray.get([actor.set_enable_partial_rollout.remote(enable) for actor in active_actors]) # type: ignore[attr-defined] + active_workers = self.health_manager.snapshot_active_workers() + ray.get( + [ + worker.actor.set_enable_partial_rollout.remote(enable) # type: ignore[attr-defined] + for worker in active_workers + ] + ) def pause_generation(self): - self.health_checker.pause() - with self.worker_info_lock: - active_workers = [info for info in self.rank2info.values() if info.is_active] - futures = [info.actor.pause_generation.remote() for info in active_workers] # type: ignore[attr-defined] + self.health_manager.pause() + active_workers = self.health_manager.snapshot_active_workers() + futures = [ + worker.actor.pause_generation.remote() # type: ignore[attr-defined] + for worker in active_workers + ] try: results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT) except Exception: @@ -240,62 +272,25 @@ def pause_generation(self): f"RolloutController pause_generation failed for {len(active_workers)} active workers." 
) raise - succeeded_worker_urls = [info.url for info, result in zip(active_workers, results) if result is not False] - failed_worker_urls = [info.url for info, result in zip(active_workers, results) if result is False] + succeeded_worker_urls = [worker.url for worker, result in zip(active_workers, results) if result is not False] + failed_worker_urls = [worker.url for worker, result in zip(active_workers, results) if result is False] if succeeded_worker_urls: self.logger.info(f"Abort request sent successfully: count={len(succeeded_worker_urls)}") if failed_worker_urls: self.logger.warning(f"Abort request failed: worker_urls={failed_worker_urls}") - def ensure_workers_healthy_before_training(self): - """Ensure rollout workers are healthy before colocated training - onloads.""" - health_checker_was_paused = self.health_checker.is_paused() - if not health_checker_was_paused: - self.health_checker.pause() - try: - with self.worker_info_lock: - workers = {rank: (info.actor, info.url, info.is_active) for rank, info in self.rank2info.items()} - - for rank, (actor, url, was_active) in workers.items(): - try: - is_healthy = ray.get(actor.check_health.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - except Exception as e: - is_healthy = False - self.logger.warning(f"Final health check raised for rollout worker {rank} at {url}: {e}.") - - if not is_healthy: - self.logger.warning(f"Final health check failed for rollout worker {rank} at {url}.") - - with self.worker_info_lock: - info = self.rank2info[rank] - info.is_active = bool(is_healthy) - - if is_healthy and not was_active: - self.logger.info(f"Mark rollout worker {rank} active after final health check: url={url}") - elif not is_healthy and was_active: - self.logger.warning( - f"Mark rollout worker {rank} inactive because final health check failed before training: url={url}" - ) + async def check_and_shutdown_inactive_workers(self): + """Run a fail-fast health barrier and shut down failed groups so 
+ training can reuse shared rollout resources.""" + await asyncio.to_thread(self.health_manager.check_and_shutdown_inactive_workers) - self._recover_failed_workers() - with self.worker_info_lock: - inactive_workers = [ - f"rank={rank}, url={info.url}" for rank, info in self.rank2info.items() if not info.is_active - ] - if inactive_workers: - raise RuntimeError( - "inactive rollout workers before training: " - + ", ".join(inactive_workers) - + ". Refusing to onload training workers because rollout GPU memory may still be held." - ) - finally: - if not health_checker_was_paused: - self.health_checker.resume() + async def restart_inactive_workers(self): + """Restart inactive groups before a sync-step weight update.""" + await asyncio.to_thread(self.health_manager.restart_inactive_workers) def continue_generation(self): - self.health_checker.resume() self._broadcast_to_active_workers("continue_generation") + self.health_manager.resume() def offload(self): self._broadcast_to_active_workers("offload") @@ -311,237 +306,124 @@ def onload_kvcache(self): self._broadcast_to_active_workers("onload_kvcache") def shutdown(self): - """Shuts down all active rollout workers. - - Args: - block (bool): Whether to block until the operation completes. - """ - self.health_checker.stop() - self._broadcast_to_active_workers("shutdown", stop_session_server=True) - - def _recover_failed_workers(self) -> None: - """Recover inactive workers before training while keeping health checks - paused.""" + """Shut down all rollout workers tracked by the controller.""" + self.health_manager.stop() with self.worker_info_lock: - failed_workers = [info for info in self.rank2info.values() if not info.is_active] - - if not failed_workers: - self.logger.info("No failed workers detected during recovery.") - return - - self.logger.warning(f"Detected {len(failed_workers)} failed workers. 
Initiating recovery process.") - for worker in failed_workers: - if self._restart_failed_workers(worker.actor, expected_url=worker.url): - with self.worker_info_lock: - rank = self._get_rank_by_actor(worker.actor) - if rank is not None: - self.rank2info[rank].is_active = True - - def _restart_failed_workers(self, worker: RolloutWorker, expected_url: str) -> bool: - try: - # 先保证把老的worker关掉 - ray.get(worker.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - # 保证新的worker启动在之前的端口上,否则权重更新会出错 - _, url = ray.get(worker.init.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - assert url == expected_url, f"Worker restarted with unexpected URL: expected {expected_url}, got {url}." - _, session_url = ray.get(worker.get_session_server_info.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - is_healthy = ray.get(worker.check_health.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - - if is_healthy: - self.logger.info(f"Successfully restarted worker {worker} with URL {url}.") - with self.worker_info_lock: - rank = self._get_rank_by_actor(worker) - if rank is not None: - self.rank2info[rank].url = url - self.rank2info[rank].session_url = session_url - self.worker_server_urls_map[rank] = url - return True - else: - self.logger.error(f"Worker {worker} is still unhealthy after restart.") - return False - except AssertionError: - raise - except Exception as e: - self.logger.error(f"Failed to restart worker: {e}") - return False - - def _update_dist_init_addr(self, nodes_per_engine, server_urls_per_engine, dist_init_addrs, tp_size): - """Update the distributed initialization addresses for workers. - - This is used to group workers that belong to the same inference engine. - - Args: - nodes_per_engine (int): The number of nodes per inference engine. - server_urls_per_engine (int): The number of server urls per inference engine. - dist_init_addrs (list): The list of initial addresses. 
- tp_size (int): The tensor parallel size. - - Returns: - list: The updated list of distributed initialization addresses. - """ - # lmdeploy pytorch ep: server_urls_per_engine > 1 - # sglang cross node engine: nodes_per_engine > 1 - assert server_urls_per_engine == 1 or nodes_per_engine == 1 - if nodes_per_engine > 1: - index = list(range(0, self.num_active_workers + 1, tp_size)) + [self.num_active_workers] - for i in range(1, len(index)): - dist_init_addrs[index[i - 1] : index[i]] = [dist_init_addrs[index[i - 1]]] * (index[i] - index[i - 1]) - if server_urls_per_engine > 1: - activate_servers = len(dist_init_addrs) - for i in range(0, activate_servers, server_urls_per_engine): - dist_init_addrs[i : i + server_urls_per_engine] = [dist_init_addrs[i]] * server_urls_per_engine - return dist_init_addrs + actors = [info.actor for info in self.server_process_rank2info.values()] + ray.get( + [actor.shutdown.remote(stop_session_server=True) for actor in actors], # type: ignore[attr-defined] + timeout=ROLLOUT_RAY_GET_TIMEOUT, + ) def _broadcast_to_active_workers(self, method_name: str, **kwargs): - """Helper function to call a method on all active workers. - - Args: - method_name (str): The name of the method to call. - block (bool): Whether to block until the call completes. - - Returns: - A list of futures if `block` is False, otherwise a list of results. 
- """ - futures = [] - with self.worker_info_lock: - active_actors = [info.actor for info in self.rank2info.values() if info.is_active] - futures = [getattr(actor, method_name).remote(**kwargs) for actor in active_actors] - results = ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT) - return results + workers = self.health_manager.snapshot_active_workers() + futures = [getattr(worker.actor, method_name).remote(**kwargs) for worker in workers] + return ray.get(futures, timeout=ROLLOUT_RAY_GET_TIMEOUT) - def _get_worker_cls(self): - if os.environ.get("XTUNER_USE_LMDEPLOY") == "1": - from .lmdeploy import LMDeployWorker - - worker_cls = LMDeployWorker - elif os.environ.get("XTUNER_USE_VLLM") == "1": - from .vllm import vLLMWorker - - worker_cls = vLLMWorker - elif os.environ.get("XTUNER_USE_SGLANG") == "1": - from .sglang import SGLangWorker - - worker_cls = SGLangWorker - else: - raise NotImplementedError( - "Rollout backend is not supported." - "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM" - " or XTUNER_USE_SGLANG environment variable." - ) + def _build_remote_worker_cls(self, worker_base_cls): assert self.config.rollout_max_batch_size_per_instance is not None, ( "rollout_max_batch_size_per_instance must be set before building RolloutWorker." ) worker_generate_max_concurrency = max( 1000, # Ray async actor default max_concurrency. - math.ceil(self.config.rollout_max_batch_size_per_instance * self.config.allow_over_concurrency_ratio), + self.config.generate_concurrency_per_instance, ) return ray.remote( concurrency_groups={ ROLLOUT_CONCURRENCY_GROUP_GENERATE: worker_generate_max_concurrency, }, - )(worker_cls) - - def _get_rank_by_actor(self, actor: RolloutWorker) -> Optional[int]: - """Get rank by actor object. - - Args: - actor: The RolloutWorker actor object. - - Returns: - The rank of the worker, or None if not found. 
- """ - for rank, info in self.rank2info.items(): - if info.actor == actor: - return rank - return None - - def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map): - """Update the list of active rollout workers and their server URLs. - - When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with - tp_rank=0 in each engine is responsible for receiving input data. Other tp_ranks do not accept input. - Therefore, this function updates active_rollout_workers and worker_server_urls_map to keep only the tp_rank=0 - workers and their corresponding URLs. - """ - if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node: - return active_rollout_workers, worker_server_urls_map - else: - active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node - active_rank = list(worker_server_urls_map.keys())[::active_worker_interval] - active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] - return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) + )(worker_base_cls) - def _init_workers(self, placement_group: PlacementGroup): + def _init_workers(self, placement_group: PlacementGroup) -> tuple[List[List[int]], dict[int, WorkerInfo]]: """Initializes and configures the pool of RolloutWorker actors. - This method creates workers from the placement group, configures distributed - inference engines by grouping workers, where each group forms a tensor-parallel - inference engine. It determines the `active_workers` to act as the head of each - engine, constructs the `engine_rank_mesh_array` to define engine topology, - acquires necessary distributed communication ports, and finally launches servers - on the `active_workers` to get their addresses. 
+ This method follows the same high-level flow as the legacy implementation: + create workers, initialize worker-local ports, build engine groups, + select workers that launch rollout servers, launch servers, and + expose request-entrypoint server URLs to rollout traffic. Returns: - Tuple[List, Dict]: A tuple where the first element is - `engine_rank_mesh_array`, a list of lists containing the ranks of workers - in each engine, and the second element is `worker_server_urls_map`, - a dictionary mapping the rank of each active worker to its - corresponding server URL. + A tuple of `engine_rank_mesh_array` and all server-process workers + for lifecycle management. """ - # Create workers from placement group + worker_base_cls = get_rollout_worker_base_cls(self.config) + worker_cls = self._build_remote_worker_cls(worker_base_cls) + + # Create workers from placement group. workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( - self._get_worker_cls(), self.config, placement_group + worker_cls, self.config, placement_group ) - active_servers_count, nodes_per_engine = self.config.get_active_servers_count(len(workers)) - interval = len(workers) // active_servers_count - active_rollout_workers = workers[::interval] - server_urls_per_engine = self.config.server_urls_per_engine - - set_bundle_idxs_objectref = [] - engine_rank_mesh_array = [] - activate_worker_idx = 0 - for active_worker in active_rollout_workers: - head_rank, _ = rank_bundle_idx_list[activate_worker_idx] - engine_workers_meta = rank_bundle_idx_list[head_rank : head_rank + interval] - engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx) - set_bundle_idxs_objectref.append(active_worker._set_engine_bundle_idxs.remote(engine_bundle_idxs)) # type: ignore[attr-defined] - engine_rank_mesh_array.append([meta[0] for meta in engine_workers_meta]) - activate_worker_idx += interval - ray.get(set_bundle_idxs_objectref) - # set engine mesh list for each worker - 
ray.get( - [worker._set_engine_rank_mesh_array.remote(engine_rank_mesh_array) for worker in active_rollout_workers] - ) # type: ignore[attr-defined] - # init dist_init_addr for each worker according to parallel settings - init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined] - dist_init_addrs = self._update_dist_init_addr( - nodes_per_engine, server_urls_per_engine, init_dist_init_addrs, self.num_gpus_per_engine + rank_to_actor = {rank: worker for (rank, _), worker in zip(rank_bundle_idx_list, workers)} + + # Reserve worker-local ports for all actors first. build_engine_launch_specs + # uses the returned addresses to bind each ServerProcessSpec to its + # logical engine rendezvous address; only server-process owners call init(). + rank_to_dist_init_addr = { + rank: dist_init_addr + for (rank, _), dist_init_addr in zip( + rank_bundle_idx_list, + ray.get([worker.init_dist_port.remote() for worker in workers]), # type: ignore[attr-defined] + ) + } + + # Build engine groups and server-process specs from the rank/bundle mapping. + engine_launch_specs = worker_base_cls.build_engine_launch_specs( + self.config, + rank_bundle_idx_list, + rank_to_dist_init_addr, ) - # launch rollout servers - init_results = ray.get( - [worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)] + # Keep the public metadata mesh compatible with origin/main. Backends + # may expose a different update-weight mesh than their internal launch + # topology, e.g. LMDeploy EP has one logical engine but one public entry + # per request-serving EP rank. + engine_rank_mesh_array = worker_base_cls.build_metadata_engine_rank_mesh_array(engine_launch_specs) + + # Launch every server process described by the backend-specific specs. 
+ server_rank_to_url = dict( + ray.get( + [ + rank_to_actor[server_process.worker_rank].init.remote( # type: ignore[attr-defined] + engine_launch_spec=engine_spec, + ) + for engine_spec in engine_launch_specs + for server_process in engine_spec.server_processes + ] + ) ) - worker_server_urls_map = dict(init_results) # rank -> url - worker_session_url_dict = dict( - ray.get([worker.get_session_server_info.remote() for worker in active_rollout_workers]) + session_url_by_rank = dict( + ray.get( + [ + ( + rank_to_actor[server_process.worker_rank].get_session_server_info.remote() # type: ignore[attr-defined] + ) + for engine_spec in engine_launch_specs + for server_process in engine_spec.server_processes + ] + ) ) - active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map( - active_rollout_workers, worker_server_urls_map + + server_process_workers_info: dict[int, WorkerInfo] = {} + for engine_spec in engine_launch_specs: + for server_process in engine_spec.server_processes: + rank = server_process.worker_rank + url = server_rank_to_url[rank] + session_url = session_url_by_rank.get(rank) + if server_process.accepts_rollout_requests and session_url is None: + raise RuntimeError(f"Rollout worker rank={rank} did not return session server URL during init.") + server_process_workers_info[rank] = WorkerInfo( + actor=rank_to_actor[rank], + url=url, + session_url=session_url, + lifecycle_group_ranks=engine_spec.server_worker_ranks, + is_request_entrypoint=server_process.accepts_rollout_requests, + ) + + self.logger.info( + f"Rollout server-process worker URLs: {[info.url for info in server_process_workers_info.values()]}" ) - active_ranks = list(worker_server_urls_map.keys()) - worker_session_url_dict = {rank: worker_session_url_dict[rank] for rank in active_ranks} - workers_info = {} - for i in range(len(active_rollout_workers)): - rank = list(worker_server_urls_map.keys())[i] - url = worker_server_urls_map[rank] - workers_info[rank] = 
WorkerInfo( - actor=active_rollout_workers[i], - url=url, - session_url=worker_session_url_dict[rank], - ) - self.logger.info(f"Rollout worker server URLs: {[info.url for info in workers_info.values()]}") - self.logger.info(f"Rollout worker session server URLs: {[info.session_url for info in workers_info.values()]}") - return engine_rank_mesh_array, worker_server_urls_map, workers_info + lifecycle_groups = sorted({info.lifecycle_group_ranks for info in server_process_workers_info.values()}) + self.logger.info(f"Rollout worker lifecycle groups: {lifecycle_groups}") + return engine_rank_mesh_array, server_process_workers_info RayRolloutController = ray.remote(RolloutController) diff --git a/xtuner/v1/rl/rollout/health_manager.py b/xtuner/v1/rl/rollout/health_manager.py new file mode 100644 index 000000000..1a9354344 --- /dev/null +++ b/xtuner/v1/rl/rollout/health_manager.py @@ -0,0 +1,615 @@ +import asyncio +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import ray + +from xtuner.v1.utils import get_logger + +from .utils import WorkerLifecycleState + + +if TYPE_CHECKING: + from .controller import WorkerInfo + from .worker import RolloutConfig, RolloutWorker + +ROLLOUT_RAY_GET_TIMEOUT = int(os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", str(5 * 3600))) # default 5 hours +ROLLOUT_RECOVERY_MAX_PARALLEL_GROUPS = 4 +HEALTH_MANAGER_STOP_JOIN_TIMEOUT = 30.0 +logger = get_logger() + +__all__ = [ + "HEALTH_MANAGER_STOP_JOIN_TIMEOUT", + "ROLLOUT_RAY_GET_TIMEOUT", + "RolloutHealthManager", + "WorkerSnapshot", +] + + +@dataclass(frozen=True) +class WorkerSnapshot: + rank: int + actor: "RolloutWorker" + url: str + session_url: str | None + lifecycle_state: WorkerLifecycleState + active: bool + lifecycle_group_ranks: tuple[int, ...] 
+ is_request_entrypoint: bool + + +@dataclass(frozen=True) +class _WorkerGroupSnapshot: + ranks: tuple[int, ...] + workers: tuple[WorkerSnapshot, ...] + + +class RolloutHealthManager: + """Own worker health state and recovery after controller startup. + + RolloutController creates workers, launches them the first time, and routes requests. RolloutHealthManager only + reads that WorkerInfo table, updates lifecycle_state, runs health checks, and restarts failed lifecycle groups. + Worker actors still own backend-specific server start/stop/probe/generate. + """ + + def __init__( + self, + config: "RolloutConfig", + workers_info: dict[int, "WorkerInfo"], + worker_infos_lock: Optional[threading.RLock] = None, + ): + self._workers_info = workers_info + self._worker_infos_lock = worker_infos_lock or threading.RLock() + self._check_interval = config.health_check_interval_seconds + self._check_timeout_seconds = config.health_check_timeout_seconds + self._check_failure_threshold = config.health_check_failure_threshold + self._stop_event: Optional[threading.Event] = None + self._pause_event: Optional[threading.Event] = None + self._thread: Optional[threading.Thread] = None + self._operation_lock = threading.Lock() + self._worker_health_failure_counts: dict[int, int] = {} + self._stopped = False + + def start(self) -> None: + health_thread_alive = self._thread is not None and self._thread.is_alive() + if health_thread_alive: + return + + self._stopped = False + self._stop_event = threading.Event() + self._pause_event = threading.Event() + self._pause_event.set() + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() + logger.info("RolloutHealthManager started.") + + def stop(self) -> None: + thread = self._thread + if not thread: + return + + assert self._stop_event is not None + self._stopped = True + self._stop_event.set() + if self._pause_event: + self._pause_event.clear() + + thread.join(timeout=HEALTH_MANAGER_STOP_JOIN_TIMEOUT) + if 
thread.is_alive(): + logger.warning( + f"RolloutHealthManager stop timed out after {HEALTH_MANAGER_STOP_JOIN_TIMEOUT}s; " + "health thread is still exiting." + ) + return + + self._thread = None + self._stop_event = None + self._pause_event = None + logger.info("RolloutHealthManager stopped.") + + def pause(self) -> None: + if self._pause_event is None: + return + self._pause_event.set() + logger.info("RolloutHealthManager paused.") + + def resume(self) -> None: + if self._pause_event is None: + return + self._pause_event.clear() + logger.info("RolloutHealthManager resumed.") + + def _is_paused(self) -> bool: + return self._pause_event is None or self._pause_event.is_set() + + def _is_stopping(self) -> bool: + """Return whether the health manager is stopping or already stopped.""" + return self._stopped or (self._stop_event is not None and self._stop_event.is_set()) + + @contextmanager + def _background_health_checks_paused(self): + was_paused = self._is_paused() + if not was_paused: + self.pause() + try: + yield + finally: + if not was_paused: + self.resume() + + def snapshot_workers(self) -> dict[int, WorkerSnapshot]: + """Return immutable worker state for callers that only need to query + rollout workers.""" + with self._worker_infos_lock: + return { + rank: WorkerSnapshot( + rank=rank, + actor=info.actor, + url=info.url, + session_url=getattr(info, "session_url", None), + lifecycle_state=info.lifecycle_state, + active=info.is_active(), + lifecycle_group_ranks=tuple(getattr(info, "lifecycle_group_ranks", ()) or ()), + is_request_entrypoint=bool(getattr(info, "is_request_entrypoint", True)), + ) + for rank, info in self._workers_info.items() + } + + def snapshot_active_workers(self) -> list[WorkerSnapshot]: + """Return active worker snapshots.""" + return [worker for worker in self.snapshot_workers().values() if worker.active] + + def restart_inactive_workers(self) -> None: + """Synchronously restart inactive groups before the next sync-step + weight update.""" + 
with self._background_health_checks_paused(): + with self._operation_lock: + worker_groups = self._snapshot_worker_groups() + failed_groups = [ + group for group in worker_groups.values() if any(not worker.active for worker in group.workers) + ] + if not failed_groups: + logger.info("No failed rollout workers detected during recovery.") + return + + sorted_failed_groups = sorted(failed_groups, key=lambda group: group.ranks) + for group in sorted_failed_groups: + failed_ranks = sorted(worker.rank for worker in group.workers if not worker.active) + logger.warning( + f"Detected failed rollout worker ranks={failed_ranks}; restart_group_ranks={group.ranks}." + ) + self._set_group_lifecycle_state(group.ranks, WorkerLifecycleState.RECOVERING) + + if self._abort_restart_recovery_if_stopping(sorted_failed_groups): + return + + logger.info( + f"Restarting rollout worker groups in parallel: " + f"group_ranks={[group.ranks for group in sorted_failed_groups]}, " + f"max_parallel_groups={ROLLOUT_RECOVERY_MAX_PARALLEL_GROUPS}." 
+ ) + group_recovery_results: dict[tuple[int, ...], bool] = {} + max_workers = min(len(sorted_failed_groups), max(1, ROLLOUT_RECOVERY_MAX_PARALLEL_GROUPS)) + with ThreadPoolExecutor( + max_workers=max_workers, + thread_name_prefix="rollout-recovery", + ) as pool: + future_to_group = { + pool.submit( + self._restart_worker_group, + group, + ): group + for group in sorted_failed_groups + } + for future in as_completed(future_to_group): + group = future_to_group[future] + try: + group_recovery_results[group.ranks] = future.result() + except Exception: + logger.exception(f"Failed to restart rollout worker group ranks={group.ranks}.") + group_recovery_results[group.ranks] = False + + if self._abort_restart_recovery_if_stopping( + sorted_failed_groups, + group_recovery_results=group_recovery_results, + ): + return + + failed_recovery_groups: list[_WorkerGroupSnapshot] = [] + for group in sorted_failed_groups: + is_recovered = group_recovery_results.get(group.ranks, False) + self._set_group_lifecycle_state( + group.ranks, + WorkerLifecycleState.ACTIVE if is_recovered else WorkerLifecycleState.INACTIVE, + ) + if not is_recovered: + failed_recovery_groups.append(group) + inactive_workers = [ + f"rank={worker.rank}, url={worker.url}" + for worker in self.snapshot_workers().values() + if not worker.active + ] + if inactive_workers: + logger.error("inactive rollout workers before sync-step weight update: " + ", ".join(inactive_workers)) + if failed_recovery_groups: + logger.error( + "Failed to restart rollout worker groups; training can continue with remaining active rollout " + "workers and skip inactive groups during rollout-side operations: " + + "; ".join( + f"ranks={group.ranks}, workers=[" + + ", ".join(f"rank={worker.rank}, url={worker.url}" for worker in group.workers) + + "]" + for group in failed_recovery_groups + ) + ) + + def check_and_shutdown_inactive_workers(self) -> None: + """Fail-fast health-check active workers, mark failures inactive, and + shut down every 
non-active group so shared resources can be reused by + training.""" + with self._background_health_checks_paused(): + self._check_and_deactivate_failed_worker_groups(fail_fast=True) + with self._operation_lock: + worker_groups = self._snapshot_worker_groups() + inactive_groups = [ + group + for group in worker_groups.values() + if any(worker.lifecycle_state is not WorkerLifecycleState.ACTIVE for worker in group.workers) + ] + + if not inactive_groups: + logger.info("No failed rollout workers detected during shutdown barrier.") + return + + failed_shutdown_groups: list[_WorkerGroupSnapshot] = [] + for group in sorted(inactive_groups, key=lambda group: group.ranks): + self._set_group_lifecycle_state(group.ranks, WorkerLifecycleState.INACTIVE) + is_shutdown = self._shutdown_worker_group(group, wait_server_down=True, best_effort=False) + if not is_shutdown: + failed_shutdown_groups.append(group) + logger.error( + "failed to shut down inactive rollout workers before training: " + + ", ".join(f"rank={worker.rank}, url={worker.url}" for worker in group.workers) + ) + if failed_shutdown_groups: + logger.error( + "Failed to shut down inactive rollout worker groups; training can continue with remaining " + "active rollout workers and failed groups stay inactive for rollout-side operations: " + + "; ".join( + f"ranks={group.ranks}, workers=[" + + ", ".join(f"rank={worker.rank}, url={worker.url}" for worker in group.workers) + + "]" + for group in failed_shutdown_groups + ) + ) + + def run_once(self) -> None: + logger.debug("RolloutHealthManager running health checks for all workers.") + checked_active_count = self._check_and_deactivate_failed_worker_groups() + if self.snapshot_active_workers() or self._is_stopping(): + return + + if checked_active_count == 0: + logger.error("No active rollout workers before health check. All rollout workers are inactive.") + else: + logger.error("All rollout workers failed after health check. 
All rollout workers are inactive.") + # TODO(duanyanhui): Propagate this fatal rollout-dead state to the + # trainer and abort training immediately instead of only logging here. + + def _check_and_deactivate_failed_worker_groups(self, *, fail_fast: bool = False) -> int: + """Health-check active workers and mark any failed lifecycle group + inactive.""" + if self._check_failure_threshold <= 0 and not fail_fast: + logger.debug("Rollout worker periodic health check is disabled.") + return 0 + + with self._operation_lock: + workers_to_check = self.snapshot_active_workers() + + if not workers_to_check: + return 0 + + check_results = self._check_workers_health(workers_to_check, fail_fast=fail_fast) + + failed_groups = { + worker.lifecycle_group_ranks or (worker.rank,) + for worker, is_healthy in zip(workers_to_check, check_results) + if not is_healthy + } + for group_ranks in sorted(failed_groups): + logger.warning(f"Rollout worker group ranks={group_ranks} failed health check. Marking as inactive.") + + if failed_groups: + with self._operation_lock: + if not self._is_stopping(): + active_groups = { + worker.lifecycle_group_ranks or (worker.rank,) for worker in self.snapshot_active_workers() + } + for group_ranks in failed_groups & active_groups: + self._set_group_lifecycle_state(group_ranks, WorkerLifecycleState.INACTIVE) + + return len(workers_to_check) + + def _check_workers_health(self, workers_to_check: list[WorkerSnapshot], *, fail_fast: bool = False) -> list[bool]: + """Run periodic check_health probes concurrently.""" + if self._check_failure_threshold <= 0 and not fail_fast: + return [True for _ in workers_to_check] + + async def check_one_worker(worker: WorkerSnapshot) -> bool: + if worker.actor is None or not worker.active: + logger.warning("Worker has no actor reference or is marked inactive.") + return False + try: + is_healthy = await asyncio.wait_for( + worker.actor.check_health.remote(), # type: ignore[attr-defined] + timeout=self._check_timeout_seconds, 
+ ) + except Exception as e: + logger.error(f"Exception during check_health for worker {worker.rank} at {worker.url}: {e}.") + return False + if not is_healthy: + logger.warning(f"check_health failed for worker {worker.rank} at {worker.url}.") + return bool(is_healthy) + + async def check_workers(workers: list[WorkerSnapshot]) -> list[bool]: + return await asyncio.gather(*(check_one_worker(worker) for worker in workers)) + + check_results = asyncio.run(check_workers(workers_to_check)) + keep_active_by_rank: dict[int, bool] = {} + with self._operation_lock: + for worker, is_healthy in zip(workers_to_check, check_results): + if is_healthy: + self._worker_health_failure_counts.pop(worker.rank, None) + keep_active_by_rank[worker.rank] = True + else: + failure_count = self._worker_health_failure_counts.get(worker.rank, 0) + 1 + self._worker_health_failure_counts[worker.rank] = failure_count + if fail_fast: + logger.warning( + f"Worker {worker.rank} failed explicit health check and will be marked inactive " + f"immediately: failure_count={failure_count}." + ) + keep_active_by_rank[worker.rank] = False + continue + if failure_count >= self._check_failure_threshold: + logger.warning( + f"Worker {worker.rank} reached health check failure threshold: " + f"{failure_count}/{self._check_failure_threshold}." + ) + keep_active_by_rank[worker.rank] = False + else: + logger.warning( + f"Worker {worker.rank} health check failed but remains active: " + f"{failure_count}/{self._check_failure_threshold}." 
+ ) + keep_active_by_rank[worker.rank] = True + + return [keep_active_by_rank[worker.rank] for worker in workers_to_check] + + def _run_loop(self) -> None: + assert self._stop_event is not None and self._pause_event is not None + logger.info("RolloutHealthManager loop started.") + + while not self._stop_event.is_set(): + while self._pause_event.is_set() and not self._stop_event.is_set(): + self._stop_event.wait(timeout=0.5) + + if self._stop_event.is_set(): + break + + if self._stop_event.wait(self._check_interval): + break + + if self._pause_event.is_set() or self._stop_event.is_set(): + continue + + try: + self.run_once() + except RuntimeError: + if self._is_stopping(): + break + logger.exception("RolloutHealthManager run_once failed.") + except Exception: + logger.exception("RolloutHealthManager run_once failed.") + + def _snapshot_worker_groups(self) -> dict[tuple[int, ...], _WorkerGroupSnapshot]: + """Group worker snapshots by lifecycle group so recovery operates on + whole groups.""" + workers_snapshot = self.snapshot_workers() + grouped_workers: dict[tuple[int, ...], list[WorkerSnapshot]] = {} + for worker in workers_snapshot.values(): + group_ranks = worker.lifecycle_group_ranks or (worker.rank,) + grouped_workers.setdefault(group_ranks, []).append(worker) + + return { + group_ranks: _WorkerGroupSnapshot( + ranks=group_ranks, + workers=tuple(sorted(workers, key=lambda worker: worker.rank)), + ) + for group_ranks, workers in grouped_workers.items() + } + + def _set_group_lifecycle_state( + self, + group_ranks: tuple[int, ...], + lifecycle_state: WorkerLifecycleState, + ) -> None: + """Update rollout worker lifecycle for every known rank in one + group.""" + with self._worker_infos_lock: + for rank in group_ranks: + if rank in self._workers_info: + self._workers_info[rank].lifecycle_state = lifecycle_state + if lifecycle_state is WorkerLifecycleState.ACTIVE: + self._worker_health_failure_counts.pop(rank, None) + + def _shutdown_worker_group( + self, + group: 
_WorkerGroupSnapshot, + *, + wait_server_down: bool, + best_effort: bool, + ) -> bool: + """Shutdown every worker in one group and aggregate per-worker shutdown + results.""" + max_wait_attempts = 60 + retry_interval_seconds = 5.0 + shutdown_succeeded = True + for worker in group.workers: + worker_shutdown_succeeded = True + try: + ray.get(worker.actor.shutdown.remote(), timeout=60) # type: ignore[attr-defined] + except Exception as e: + worker_shutdown_succeeded = False + log = logger.warning if best_effort else logger.error + log(f"Shutdown failed for rollout worker rank={worker.rank}, url={worker.url}: {e}") + + if worker_shutdown_succeeded and wait_server_down: + server_down = False + for attempt in range(1, max_wait_attempts + 1): + try: + is_healthy = ray.get(worker.actor.check_health.remote(), timeout=self._check_timeout_seconds) # type: ignore[attr-defined] + except Exception: + server_down = True + break + if not is_healthy: + server_down = True + break + if attempt < max_wait_attempts: + logger.warning( + f"Rollout worker rank={worker.rank} server still responds after shutdown " + f"attempt={attempt}/{max_wait_attempts}, url={worker.url}." + ) + time.sleep(retry_interval_seconds) + if not server_down: + logger.error( + f"Rollout worker rank={worker.rank} server did not stop after shutdown: url={worker.url}." 
+ ) + worker_shutdown_succeeded = False + + if not worker_shutdown_succeeded: + shutdown_succeeded = False + return best_effort or shutdown_succeeded + + def _abort_restart_recovery_if_stopping( + self, + sorted_failed_groups: list[_WorkerGroupSnapshot], + *, + group_recovery_results: dict[tuple[int, ...], bool] | None = None, + ) -> bool: + if not self._is_stopping(): + return False + + for group in sorted_failed_groups: + is_recovered = False + if group_recovery_results is not None: + is_recovered = group_recovery_results.get(group.ranks, False) + if is_recovered: + self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) + self._set_group_lifecycle_state(group.ranks, WorkerLifecycleState.INACTIVE) + return True + + def _restart_worker_group( + self, + group: _WorkerGroupSnapshot, + ) -> bool: + """Shutdown, restart with empty-init, and health-check one complete + worker group.""" + if not group.workers or len(group.workers) != len(group.ranks): + logger.error(f"Cannot restart incomplete rollout worker group: ranks={group.ranks}.") + return False + if self._is_stopping(): + return False + + if not self._shutdown_worker_group(group, wait_server_down=True, best_effort=False): + return False + if self._is_stopping(): + return False + + try: + ray.get( + [ + worker.actor.set_skip_load_weights.remote(True) # type: ignore[attr-defined] + for worker in group.workers + ], + timeout=ROLLOUT_RAY_GET_TIMEOUT, + ) + init_results = ray.get( + [ + # init() reuses the immutable launch spec cached on each actor + # during controller startup, including placement bundles and dist addr. 
+ worker.actor.init.remote() # type: ignore[attr-defined] + for worker in group.workers + ], + timeout=ROLLOUT_RAY_GET_TIMEOUT, + ) + if self._is_stopping(): + self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) + return False + if len(init_results) != len(group.workers): + logger.error( + f"Restarted rollout worker group ranks={group.ranks} returned {len(init_results)} init results, " + f"expected {len(group.workers)}." + ) + self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) + return False + + for worker, init_result in zip(group.workers, init_results): + init_rank, init_url = init_result + if init_rank != worker.rank or init_url != worker.url: + logger.error( + f"Rollout worker restart returned unexpected endpoint: rank={worker.rank}, " + f"init_rank={init_rank}, expected_url={worker.url}, init_url={init_url}." + ) + self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) + return False + + health_results = ray.get( + [worker.actor.check_health.remote() for worker in group.workers], # type: ignore[attr-defined] + timeout=self._check_timeout_seconds, + ) + if self._is_stopping(): + self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) + return False + unhealthy_ranks = [ + worker.rank for worker, is_healthy in zip(group.workers, health_results) if not is_healthy + ] + if unhealthy_ranks: + logger.error( + f"Restarted rollout worker group ranks={group.ranks} has unhealthy ranks={unhealthy_ranks}." + ) + self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) + return False + + # Newly restarted workers should return to the same offloaded/sleep + # baseline as the other colocated rollout workers before the sync + # path wakes weights/KV back up. 
+ ray.get( + [worker.actor.offload.remote() for worker in group.workers], # type: ignore[attr-defined] + timeout=ROLLOUT_RAY_GET_TIMEOUT, + ) + + logger.info(f"Successfully restarted rollout worker group ranks={group.ranks}.") + return True + except Exception as e: + logger.error(f"Failed to restart rollout worker group ranks={group.ranks}: {e}") + self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) + return False + finally: + try: + ray.get( + [ + worker.actor.restore_skip_load_weights.remote() # type: ignore[attr-defined] + for worker in group.workers + ], + timeout=ROLLOUT_RAY_GET_TIMEOUT, + ) + except Exception: + logger.exception( + f"Failed to restore rollout worker skip_load_weights after restart: group_ranks={group.ranks}." + ) diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index fd24f8f64..7b9c2fa96 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -1,6 +1,5 @@ import os from argparse import Namespace -from itertools import chain from typing import Any, Dict, List import numpy as np @@ -11,7 +10,7 @@ from transformers import AutoTokenizer from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams -from .worker import RolloutConfig, RolloutWorker +from .worker import EngineLaunchSpec, EngineLaunchSpecs, RolloutConfig, RolloutWorker, ServerProcessSpec SHARED_STORE = "shared_store" @@ -80,6 +79,120 @@ def __init__( self.enable_return_routed_experts = self.config.enable_return_routed_experts self.lmdeploy_actor = None + @classmethod + def build_engine_launch_specs( + cls, + config: RolloutConfig, + rank_bundle_idx_list: list[tuple[int, int]], + rank_to_dist_init_addr: dict[int, str] | None = None, + ) -> EngineLaunchSpecs: + """Build LMDeploy server launch layout. + + LMDeploy EP starts one request-serving server per EP rank. + + Example with expert_parallel_size=2: + rank_bundle_idx_list is [(0, 0), (1, 1), (2, 2), (3, 3)]. 
+ rank identifies the rollout worker; bundle idx identifies the Ray + placement-group bundle that owns the GPU resource. + + If rank_to_dist_init_addr is: + {0: "addr0", 1: "addr1", 2: "addr2", 3: "addr3"} + + The launch specs are: + EngineLaunchSpec( + engine_ranks=(0, 1), + server_processes=( + ServerProcessSpec( + worker_rank=0, + placement_group_bundle_idxs=(0,), + dist_init_addr="addr0", + ), + ServerProcessSpec( + worker_rank=1, + placement_group_bundle_idxs=(1,), + dist_init_addr="addr0", + ), + ), + ) + EngineLaunchSpec( + engine_ranks=(2, 3), + server_processes=( + ServerProcessSpec( + worker_rank=2, + placement_group_bundle_idxs=(2,), + dist_init_addr="addr2", + ), + ServerProcessSpec( + worker_rank=3, + placement_group_bundle_idxs=(3,), + dist_init_addr="addr2", + ), + ), + ) + + Each EP rank launches a server process, so server_worker_ranks is the + same as engine_ranks, and every server accepts rollout requests. + """ + if config.expert_parallel_size <= 1: + return RolloutWorker.build_engine_launch_specs( + config, + rank_bundle_idx_list, + rank_to_dist_init_addr, + ) + + ep_size = config.expert_parallel_size + num_workers = len(rank_bundle_idx_list) + if num_workers % ep_size != 0: + raise ValueError(f"num_rollout_workers={num_workers} must be divisible by expert_parallel_size={ep_size}.") + + engine_launch_specs: list[EngineLaunchSpec] = [] + for engine_start in range(0, num_workers, ep_size): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + ep_size] + engine_ranks = tuple(rank for rank, _ in engine_meta) + engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] + # LMDeploy EP launches one server process for each EP rank. Each + # server owns exactly one placement-group bundle, and every server + # can be used as a rollout request entrypoint. 
+ engine_launch_specs.append( + EngineLaunchSpec( + engine_ranks=engine_ranks, + server_processes=tuple( + ServerProcessSpec( + worker_rank=server_rank, + placement_group_bundle_idxs=(bundle_idx,), + dist_init_addr=engine_dist_init_addr, + ) + for server_rank, bundle_idx in engine_meta + ), + ) + ) + return cls.validate_engine_launch_specs( + tuple(engine_launch_specs), + known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), + ) + + @classmethod + def build_metadata_engine_rank_mesh_array( + cls, + engine_launch_specs: EngineLaunchSpecs, + ) -> list[list[int]]: + """Keep LMDeploy EP metadata compatible with origin/main. + + Pure EP uses one request-serving server per EP rank. The logical engine topology is still stored in + EngineLaunchSpec.engine_ranks for dp_rank and lifecycle operations, but update_weighter expects the public + metadata mesh to contain one single-rank entry per request server. + """ + metadata_engine_rank_mesh_array: list[list[int]] = [] + for engine_spec in engine_launch_specs: + request_entrypoint_servers = engine_spec.request_entrypoint_servers + if len(request_entrypoint_servers) > 1: + metadata_engine_rank_mesh_array.extend( + [server_process.worker_rank] for server_process in request_entrypoint_servers + ) + else: + metadata_engine_rank_mesh_array.append(list(engine_spec.engine_ranks)) + return metadata_engine_rank_mesh_array + def offload(self): """Offloads the model weights and KV cache.""" return self._sleep(level=2) @@ -275,30 +388,9 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: # currently only support ep > 1 and tp == 1 / ep == 1 and tp > 1 assert ep_size == 1 or tp_size == 1 if ep_size > 1: - dp_rank_found = False - # In the case of pure expert parallelism, each worker from all ranks serve url. - # `engine_rank_mesh_array` would miss the ep_size information in inner list, - # Therefore, we need to regroup them into `engine_rank_mesh_array_for_ep`. 
- # For example, ep_size = 2, work_size = 8: - # engine_rank_mesh_array = [[0],[1],[2],[3],[4],[5],[6],[7]] -> - # engine_rank_mesh_array_for_ep = [[0,1],[2,3],[4,5],[6,7]] - engine_rank_mesh_array_for_ep = [ - list(chain.from_iterable(self.engine_rank_mesh_array[i : i + ep_size])) - for i in range(0, len(self.engine_rank_mesh_array), ep_size) - ] - # dp_rank is the index of self.rank in the inner list of rank mesh array. - # For example, ep_size = 2, work_size = 8: - # engine_rank_mesh_array_for_ep = [[0,1],[2,3],[4,5],[6,7]] - # rank 3 is in [2, 3], dp_rank = [2, 3].index(3) = 1 - for engine_rank_mesh in engine_rank_mesh_array_for_ep: - if self.rank in engine_rank_mesh: - dp_rank = engine_rank_mesh.index(self.rank) - dp_rank_found = True - break - assert dp_rank_found, ( - f"self.rank: {self.rank} should be found in " - f"engine_rank_mesh_array_for_ep: {engine_rank_mesh_array_for_ep}" - ) + engine_launch_spec = self.engine_launch_spec + assert engine_launch_spec is not None + dp_rank = engine_launch_spec.engine_ranks.index(self.rank) backend_config = ( PytorchEngineConfig( diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index 4bd674b89..047fa2d5a 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -11,7 +11,13 @@ from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.utils import XTUNER_DETERMINISTIC -from .worker import RolloutConfig, RolloutWorker +from .worker import ( + EngineLaunchSpec, + EngineLaunchSpecs, + RolloutConfig, + RolloutWorker, + ServerProcessSpec, +) class SGLangWorker(RolloutWorker): @@ -28,7 +34,8 @@ def __init__( from sglang.srt.entrypoints.http_server import launch_server self.server_func = launch_server - self.endpoints["health_generate"] = "health" + self.endpoints["health"] = "health" + self.endpoints["health_generate"] = "health_generate" self.endpoints["generate"] = "generate" self.endpoints["v1/chat/completions"] = "v1/chat/completions" self.tokenizer = 
AutoTokenizer.from_pretrained(self.config.model_path, trust_remote_code=True) @@ -41,6 +48,99 @@ def __init__( self.model_name = self.config.model_name self.enable_return_routed_experts = self.config.enable_return_routed_experts + @classmethod + def build_engine_launch_specs( + cls, + config: RolloutConfig, + rank_bundle_idx_list: list[tuple[int, int]], + rank_to_dist_init_addr: dict[int, str] | None = None, + ) -> EngineLaunchSpecs: + """Build SGLang server launch layout. + + SGLang starts one server per node in a logical engine. Only node 0 is + used as the rollout request entrypoint. + + Example with expert_parallel_size=16 and gpus_per_node=8: + rank_bundle_idx_list is: + [(0, 0), (1, 1), ..., (15, 15)] + + If rank_to_dist_init_addr is: + {0: "addr0", 1: "addr1", ..., 15: "addr15"} + + The launch spec is: + EngineLaunchSpec( + engine_ranks=(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15), + server_processes=( + ServerProcessSpec( + worker_rank=0, + placement_group_bundle_idxs=(0, 1, 2, 3, 4, 5, 6, 7), + dist_init_addr="addr0", + accepts_rollout_requests=True, + node_rank=0, + nnodes=2, + ), + ServerProcessSpec( + worker_rank=8, + placement_group_bundle_idxs=(8, 9, 10, 11, 12, 13, 14, 15), + dist_init_addr="addr0", + accepts_rollout_requests=False, + node_rank=1, + nnodes=2, + ), + ), + ) + + SGLang starts one server per node, so server_worker_ranks is (0, 8). + Only the node-0 server accepts rollout requests. + """ + num_workers = len(rank_bundle_idx_list) + num_gpus_per_engine = cls._get_num_gpus_per_engine(config) + if num_workers % num_gpus_per_engine != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." + ) + if num_gpus_per_engine > config.gpus_per_node and num_gpus_per_engine % config.gpus_per_node != 0: + raise ValueError( + "SGLang cross-node rollout requires num_gpus_per_engine to be divisible by gpus_per_node." 
+ ) + + nnodes = max(1, num_gpus_per_engine // config.gpus_per_node) + engine_launch_specs: list[EngineLaunchSpec] = [] + for engine_start in range(0, num_workers, num_gpus_per_engine): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] + engine_ranks = tuple(rank for rank, _ in engine_meta) + engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) + # SGLang cross-node launch starts one server process per node. The + # first rank of each node owns that node's bundles, while only node + # 0 is exposed as the rollout request entrypoint. + server_ranks = engine_ranks[:: config.gpus_per_node] + engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[server_ranks[0]] + server_processes: list[ServerProcessSpec] = [] + for node_rank, server_rank in enumerate(server_ranks): + node_bundle_start = node_rank * config.gpus_per_node + node_bundle_end = node_bundle_start + config.gpus_per_node + server_processes.append( + ServerProcessSpec( + worker_rank=server_rank, + placement_group_bundle_idxs=engine_bundle_idxs[node_bundle_start:node_bundle_end], + dist_init_addr=engine_dist_init_addr, + accepts_rollout_requests=node_rank == 0, + node_rank=node_rank, + nnodes=nnodes, + ) + ) + engine_launch_specs.append( + EngineLaunchSpec( + engine_ranks=engine_ranks, + server_processes=tuple(server_processes), + ) + ) + return cls.validate_engine_launch_specs( + tuple(engine_launch_specs), + known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), + ) + def _get_request_payload(self, rollout_state: RolloutState) -> dict: sample_params = rollout_state.sample_params payload: dict[str, Any] = {"model": self.model_name} @@ -144,6 +244,17 @@ def _make_request(self, endpoint: str, payload=None): response.raise_for_status() return response.json() + def check_health(self) -> bool: + try: + response = requests.get( + f"{self.server_url}/{self.endpoints['health']}", + 
timeout=self.config.health_check_timeout_seconds, + ) + return response.status_code == 200 + except requests.RequestException as e: + self.logger.error(f"Health check failed for server {self.server_url}: {e}") + return False + def flush_cache(self): """Flush the cache of the server.""" # TODO: 支持 tp @@ -227,8 +338,13 @@ def _transform_rollout_config_to_server_configs(self): ) tp_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.tensor_parallel_size ep_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.expert_parallel_size - nnodes = max(1, num_gpus_per_engine // self.config.gpus_per_node) - node_rank = self.rank // self.config.gpus_per_node if nnodes > 1 else 0 + server_process_spec = self._get_current_server_process_spec() + nnodes = ( + server_process_spec.nnodes + if server_process_spec is not None + else max(1, num_gpus_per_engine // self.config.gpus_per_node) + ) + node_rank = server_process_spec.node_rank if server_process_spec is not None else 0 assigned_gpu_id = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) # SGLang 0.5.10 默认启用的 Piecewise CUDA Graph 在启动 warmup compile 阶段会报错。sglang的文档提到这个功能还是实验功能,可能还不太稳定(https://sgl-project-sglang-93.mintlify.app/optimization/cuda-graph#bug-report)。暂时先通过disable_piecewise_cuda_graph=True关掉改功能 diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index 9fef8306d..c6d32cde3 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -1,8 +1,8 @@ import asyncio -import os import threading import time from collections import OrderedDict +from enum import Enum from itertools import cycle from typing import TYPE_CHECKING, Any, Optional @@ -17,16 +17,40 @@ if TYPE_CHECKING: from .controller import WorkerInfo - from .worker import RolloutConfig, RolloutWorker -ROLLOUT_RAY_GET_TIMEOUT = int(os.getenv("XTUNER_ROLLOUT_RAY_GET_TIMEOUT", str(5 * 3600))) # default 5 hours logger = get_logger() +__all__ = 
[ + "format_response_body_preview", + "PartialRolloutHandler", + "SessionRouter", + "WorkerLifecycleState", +] + + +def format_response_body_preview(response: Any, limit: int = 512) -> str: + try: + body = response.text + except Exception as e: # pragma: no cover - response.text access normally does not fail + return f"[failed to read response body: {e!r}]" + if len(body) <= limit: + return repr(body) + return f"{body[:limit]!r}...(truncated, total_len={len(body)})" + + +class WorkerLifecycleState(str, Enum): + # Can serve rollout generation and control requests. + ACTIVE = "active" + # Not serving rollout requests; the rollout server may still hold resources. + INACTIVE = "inactive" + # Temporarily owned by recovery shutdown/init/check_health. + RECOVERING = "recovering" + class SessionRouter: def __init__( self, - worker_infos: dict[int, "WorkerInfo"], # worker: worker_status + worker_infos: dict[int, "WorkerInfo"], worker_infos_lock: Optional[threading.RLock] = None, max_sessions: int = 10000, max_idle_seconds: Optional[float] = 3600.0, @@ -70,12 +94,12 @@ def _choose_next_active_worker(self) -> tuple[int, Any]: rank = next(self._worker_cycler) if self._worker_infos_lock is None: info = self._worker_infos[rank] - if info and info.is_active: + if info and info.is_active() and info.is_request_entrypoint: return rank, info.actor else: with self._worker_infos_lock: info = self._worker_infos[rank] - if info and info.is_active: + if info and info.is_active() and info.is_request_entrypoint: return rank, info.actor return -1, None @@ -90,7 +114,7 @@ async def get_worker(self, session_id: int) -> Optional[Any]: else: with self._worker_infos_lock: info = self._worker_infos.get(worker_rank) - if info and info.is_active: + if info and info.is_active() and info.is_request_entrypoint: self._map[session_id] = (worker_rank, self._now()) return info.actor @@ -102,157 +126,6 @@ async def get_worker(self, session_id: int) -> Optional[Any]: return worker -class RolloutHealthChecker: - def __init__( - self, - config:
"RolloutConfig", - workers_info: dict[int, "WorkerInfo"], - worker_infos_lock: Optional[threading.RLock] = None, - ): - self._workers_info = workers_info - self._worker_infos_lock = worker_infos_lock - self._check_interval = config.health_check_interval_seconds - self._check_failure_threshold = config.health_check_failure_threshold - self._stop_event: Optional[threading.Event] = None - self._pause_event: Optional[threading.Event] = None - self._thread: Optional[threading.Thread] = None - - def start(self) -> None: - if self._thread and self._thread.is_alive(): - return - - self._stop_event = threading.Event() - self._pause_event = threading.Event() - self._pause_event.set() # 启动时设置为暂停状态,开始generation后再调用restart方法恢复 - self._thread = threading.Thread(target=self._run_loop, daemon=True) - self._thread.start() - logger.info("RolloutHealthChecker started.") - - def stop(self) -> None: - if not self._thread: - return - - assert self._stop_event is not None - self._stop_event.set() - if self._pause_event: - self._pause_event.clear() - self._thread.join(timeout=5) - self._thread = None - self._stop_event = None - self._pause_event = None - logger.info("RolloutHealthChecker stopped.") - - def pause(self) -> None: - if self._pause_event is None: - return - self._pause_event.set() - logger.info("RolloutHealthChecker paused.") - - def is_paused(self) -> bool: - return self._pause_event is None or self._pause_event.is_set() - - def resume(self) -> None: - if self._pause_event is None: - return - self._pause_event.clear() - logger.info("RolloutHealthChecker restarted.") - - def run_once(self) -> None: - logger.debug("RolloutHealthChecker running health checks for all workers.") - if self._worker_infos_lock is None: - workers_snapshot = { - rank: (info.actor, info.url, info.is_active) for rank, info in self._workers_info.items() - } - else: - with self._worker_infos_lock: - workers_snapshot = { - rank: (info.actor, info.url, info.is_active) for rank, info in 
self._workers_info.items() - } - - workers_to_check = [ - (rank, actor, url, is_active) for rank, (actor, url, is_active) in workers_snapshot.items() if is_active - ] - if not workers_to_check: - return - - tasks = [ - check_worker_health(actor, rank, url, is_active, self._check_failure_threshold) - for rank, actor, url, is_active in workers_to_check - ] - - async def _run_checks() -> list[bool]: - return await asyncio.gather(*tasks) - - check_results = asyncio.run(_run_checks()) - inactive_workers = [] - for (rank, _, _, _), is_healthy in zip(workers_to_check, check_results): - if not is_healthy: - logger.warning(f"Worker {rank} failed health check. Marking as inactive.") - if self._worker_infos_lock is None: - self._workers_info[rank].is_active = False - inactive_worker = self._workers_info[rank].actor - else: - with self._worker_infos_lock: - self._workers_info[rank].is_active = False - inactive_worker = self._workers_info[rank].actor - if inactive_worker is None: - logger.error(f"[RolloutHealthChecker] Worker {rank} has no actor reference. 
Skipping shutdown.") - continue - inactive_workers.append((rank, inactive_worker)) - else: - logger.debug(f"[RolloutHealthChecker] Worker {rank} passed health check.") - - for rank, inactive_worker in inactive_workers: - try: - ray.get(inactive_worker.offload.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - except Exception as e: - logger.error(f"Exception while offloading worker {rank}: {e}") - - try: - ray.get(inactive_worker.shutdown.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT) # type: ignore[attr-defined] - except Exception as e: - logger.error(f"Exception while shutting down worker {rank}: {e}") - - def _run_loop(self) -> None: - assert self._stop_event is not None and self._pause_event is not None - logger.info("RolloutHealthChecker loop started.") - - while not self._stop_event.is_set(): - while self._pause_event.is_set() and not self._stop_event.is_set(): - self._stop_event.wait(timeout=0.5) - - if self._stop_event.is_set(): - break - - if not self._pause_event.is_set() and not self._stop_event.is_set(): - self.run_once() - - if self._stop_event.wait(self._check_interval): - break - - -async def check_worker_health( - worker: "RolloutWorker", rank: int, url: str, is_active: bool, failure_threshold: int = 3 -) -> bool: - if worker is None or not is_active: - logger.warning("Worker has no actor reference or is marked inactive.") - return False - failing_count = 0 - while failing_count < failure_threshold: - try: - health_status = await worker.check_health.remote() # type: ignore[attr-defined] - if health_status: - return True - failing_count += 1 - logger.warning(f"Health check failed for worker {rank} at {url}. Failure count: {failing_count}") - except Exception as e: - failing_count += 1 - logger.error( - f"Exception during health check for worker {rank} at {url}: {e}. 
Failure count: {failing_count}" - ) - return False - - async def _resolve_routed_experts(routed_experts: np.ndarray | RayObjectRef) -> np.ndarray: if isinstance(routed_experts, RayObjectRef): routed_experts_value = await routed_experts diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index 706be0e71..fd674ac79 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -8,15 +8,16 @@ import time import traceback from abc import abstractmethod +from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, TypeAlias, Union, cast import httpx import ray import requests # type: ignore[import-untyped] from cyclopts import Group, Parameter from packaging.version import Version -from pydantic import BaseModel, ConfigDict, PrivateAttr +from pydantic import BaseModel, ConfigDict from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from typing_extensions import Annotated @@ -38,8 +39,9 @@ from xtuner.v1.utils import get_logger from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult +from .health_manager import ROLLOUT_RAY_GET_TIMEOUT from .session_server import SessionServerActor -from .utils import ROLLOUT_RAY_GET_TIMEOUT, PartialRolloutHandler +from .utils import PartialRolloutHandler if TYPE_CHECKING: @@ -50,6 +52,79 @@ ROLLOUT_CONCURRENCY_GROUP_GENERATE = "generate" +@dataclass(frozen=True) +class ServerProcessSpec: + """How to start one rollout server process.""" + + # Worker rank that owns this server process. + worker_rank: int + # Placement-group bundle indexes assigned to this server process. + placement_group_bundle_idxs: tuple[int, ...] + # Distributed init address used by every server process in the same engine. + # Filled after init_dist_port initializes worker-local ports. 
+ dist_init_addr: str | None = None + # Whether this server is exposed as a rollout request entrypoint. Some + # backends launch extra server processes that must participate in + # lifecycle/health operations but must not be added to worker_server_urls_map + # or receive normal rollout traffic. + accepts_rollout_requests: bool = True + # Node index of this server inside a multi-node logical engine. + node_rank: int = 0 + # Number of nodes used by this logical engine. + nnodes: int = 1 + + +@dataclass(frozen=True) +class EngineLaunchSpec: + """How to launch rollout servers for one logical inference engine.""" + + # All worker ranks that form this logical inference engine. + engine_ranks: tuple[int, ...] + # Server processes required by this engine. + server_processes: tuple[ServerProcessSpec, ...] + + @property + def server_worker_ranks(self) -> tuple[int, ...]: + return tuple(server.worker_rank for server in self.server_processes) + + @property + def request_entrypoint_servers(self) -> tuple[ServerProcessSpec, ...]: + return tuple(server for server in self.server_processes if server.accepts_rollout_requests) + + @property + def request_entrypoint_worker_ranks(self) -> tuple[int, ...]: + return tuple(server.worker_rank for server in self.request_entrypoint_servers) + + @property + def placement_group_bundle_idxs(self) -> tuple[int, ...]: + return tuple( + bundle_idx for server in self.server_processes for bundle_idx in server.placement_group_bundle_idxs + ) + + +EngineLaunchSpecs: TypeAlias = tuple[EngineLaunchSpec, ...] 
+ + +def get_rollout_worker_base_cls(config: "RolloutConfig") -> type["RolloutWorker"]: + if config.rollout_backend == "lmdeploy": + from .lmdeploy import LMDeployWorker + + return LMDeployWorker + elif config.rollout_backend == "vllm": + from .vllm import vLLMWorker + + return vLLMWorker + elif config.rollout_backend == "sglang": + from .sglang import SGLangWorker + + return SGLangWorker + else: + raise NotImplementedError( + f"Rollout backend is not supported: {config.rollout_backend}. " + "Please set XTUNER_USE_LMDEPLOY or XTUNER_USE_VLLM or XTUNER_USE_SGLANG environment variable." + ) + + class RolloutConfig(BaseModel): """Rollout worker configuration for XTuner. @@ -311,6 +386,21 @@ class RolloutConfig(BaseModel): help="Interval in seconds between rollout worker health checks.", ), ] = 30.0 + # LMDeploy /health returns an EngineHealthMonitor snapshot. The monitor's + # backend probe timeout defaults to 10s and its poll interval defaults to + # 12s, so XTuner's HTTP read timeout needs to be longer than 10s to avoid + # turning a slow but informative /health response into a client-side + # timeout. + health_check_timeout_seconds: Annotated[ + float, + Parameter( + group=infer_group, + help=( + "HTTP timeout in seconds for rollout worker health check requests. " + "The default is longer than LMDeploy's 10s backend health probe timeout." 
+ ), + ), + ] = 15.0 health_check_failure_threshold: Annotated[ int, Parameter( @@ -318,7 +408,6 @@ class RolloutConfig(BaseModel): help="Number of consecutive health check failures required before marking a worker inactive.", ), ] = 3 - _logged_server_urls_per_engine: bool = PrivateAttr(default=False) @property def rollout_backend(self) -> str: @@ -335,51 +424,26 @@ def rollout_backend(self) -> str: ) return backend - @property - def server_urls_per_engine(self) -> int: - # server_urls_per_engine is introduced for lmdeploy ep settings - # for now only lmdeploy pytorch backend with ep > 1 requires multiple server urls per engine - if self.rollout_backend == "lmdeploy" and self.expert_parallel_size > 1: - # when expert parallelism is used, lmdeploy requires `expert_parallel_size` server instances per engine - if not self._logged_server_urls_per_engine: - self._logged_server_urls_per_engine = True - get_logger().info( - f"Setting server_urls_per_engine={self.expert_parallel_size} due to expert parallelism in LMDeploy." - ) - return self.expert_parallel_size - else: - return 1 - @property def num_gpus_per_engine(self) -> int: return self.expert_parallel_size if self.expert_parallel_size > 1 else self.tensor_parallel_size - def get_active_servers_count(self, num_rollout_workers: int) -> tuple[int, int]: - """Calculate the number of active servers and nodes per engine.""" - # NOTE: Since different inference engines have different launch methods, - # the number of nodes contained in each engine is not consistent. - # For example, sglang requires starting an inference engine for each node, - # while lmdeploy and vllm do not. Therefore, calculate active servers from the rollout config. 
- nodes_per_engine = ( - 1 - if self.rollout_cross_node_comm or self.num_gpus_per_engine < self.gpus_per_node - else self.num_gpus_per_engine // self.gpus_per_node - ) - active_servers_count = max( - 1, - int((num_rollout_workers // self.num_gpus_per_engine) * nodes_per_engine * self.server_urls_per_engine), + @property + def generate_concurrency_per_instance(self) -> int: + assert self.rollout_max_batch_size_per_instance is not None, ( + "rollout_max_batch_size_per_instance must be set before computing generate concurrency." ) - return active_servers_count, nodes_per_engine + return math.ceil(self.rollout_max_batch_size_per_instance * self.allow_over_concurrency_ratio) def get_controller_generate_concurrency(self, placement_group: "PlacementGroup") -> int: - active_worker_count, _ = self.get_active_servers_count(len(placement_group.bundle_specs)) - assert self.rollout_max_batch_size_per_instance is not None, ( - "rollout_max_batch_size_per_instance must be set before building RolloutController." + worker_base_cls = get_rollout_worker_base_cls(self) + sorted_bundle_idxs, _, _, _ = AutoAcceleratorWorkers.get_spmd_info(placement_group) + rank_bundle_idx_list = [(rank, bundle_idx) for rank, bundle_idx in enumerate(sorted_bundle_idxs)] + engine_launch_specs = worker_base_cls.build_engine_launch_specs(self, rank_bundle_idx_list) + request_entrypoint_count = sum( + len(engine_spec.request_entrypoint_servers) for engine_spec in engine_launch_specs ) - concurrency_per_worker = math.ceil( - self.rollout_max_batch_size_per_instance * self.allow_over_concurrency_ratio - ) - generate_max_concurrency = active_worker_count * concurrency_per_worker + generate_max_concurrency = request_entrypoint_count * self.generate_concurrency_per_instance return generate_max_concurrency def model_post_init(self, __context: Any) -> None: @@ -484,6 +548,7 @@ def __init__( Defaults to "GPU". 
""" self.config = config + self._default_skip_load_weights = config.skip_load_weights self.rank = rank self.master_addr = master_addr # ray master self.master_port = master_port @@ -491,12 +556,10 @@ def __init__( self.accelerator = accelerator self.server_func: Callable self.endpoints: dict[str, str] = dict() - self.engine_rank_mesh_array: list[list[int]] + self.engine_rank_mesh_array: list[list[int]] = [] + self.engine_launch_spec: EngineLaunchSpec | None = None # http_concurrency is calculated based on the max batch size per engine and the total number of engines - assert config.rollout_max_batch_size_per_instance, ( - "rollout_max_batch_size_per_instance must be set in RolloutConfig" - ) - http_concurrency = math.ceil(config.rollout_max_batch_size_per_instance * config.allow_over_concurrency_ratio) + http_concurrency = config.generate_concurrency_per_instance limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) self.server_task = None @@ -519,27 +582,207 @@ def __init__( self.partial_rollout_handler = PartialRolloutHandler() self.enable_partial_rollout: bool = False + @staticmethod + def _get_num_gpus_per_engine(config: RolloutConfig) -> int: + return config.num_gpus_per_engine + + @classmethod + def validate_engine_launch_specs( + cls, + engine_launch_specs: EngineLaunchSpecs, + *, + known_worker_ranks: tuple[int, ...] 
| None = None, + ) -> EngineLaunchSpecs: + """Validate backend launch layout before the controller launches + servers.""" + if not engine_launch_specs: + raise ValueError("engine_launch_specs must define at least one engine.") + + known_worker_rank_set = set(known_worker_ranks) if known_worker_ranks is not None else None + seen_engine_ranks: set[int] = set() + seen_server_ranks: set[int] = set() + seen_bundle_idxs: set[int] = set() + for engine_index, engine_spec in enumerate(engine_launch_specs): + if not engine_spec.engine_ranks: + raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one engine rank.") + engine_rank_set = set(engine_spec.engine_ranks) + if len(engine_rank_set) != len(engine_spec.engine_ranks): + raise ValueError( + f"EngineLaunchSpec[{engine_index}] has duplicate engine ranks: {engine_spec.engine_ranks}." + ) + if known_worker_rank_set is not None: + unknown_engine_ranks = sorted( + rank for rank in engine_spec.engine_ranks if rank not in known_worker_rank_set + ) + if unknown_engine_ranks: + raise ValueError( + f"EngineLaunchSpec[{engine_index}] references unknown engine ranks: {unknown_engine_ranks}." + ) + duplicated_engine_ranks = sorted(rank for rank in engine_spec.engine_ranks if rank in seen_engine_ranks) + if duplicated_engine_ranks: + raise ValueError( + f"EngineLaunchSpec[{engine_index}] engine ranks appear in more than one engine: " + f"{duplicated_engine_ranks}." + ) + seen_engine_ranks.update(engine_spec.engine_ranks) + + if not engine_spec.server_processes: + raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one server process.") + + for server_process in engine_spec.server_processes: + server_rank = server_process.worker_rank + if server_rank not in engine_rank_set: + raise ValueError( + f"EngineLaunchSpec[{engine_index}] server worker_rank={server_rank} " + f"must be part of engine_ranks={engine_spec.engine_ranks}." 
+ ) + if server_rank in seen_server_ranks: + raise ValueError(f"Server worker_rank={server_rank} appears in more than one server process.") + seen_server_ranks.add(server_rank) + + if not server_process.placement_group_bundle_idxs: + raise ValueError(f"Server worker_rank={server_rank} must own at least one placement-group bundle.") + if len(set(server_process.placement_group_bundle_idxs)) != len( + server_process.placement_group_bundle_idxs + ): + raise ValueError( + f"Server worker_rank={server_rank} has duplicate placement-group bundles: " + f"{server_process.placement_group_bundle_idxs}." + ) + duplicated_bundle_idxs = sorted( + bundle_idx + for bundle_idx in server_process.placement_group_bundle_idxs + if bundle_idx in seen_bundle_idxs + ) + if duplicated_bundle_idxs: + raise ValueError( + f"Placement-group bundles are assigned to multiple server processes: {duplicated_bundle_idxs}." + ) + seen_bundle_idxs.update(server_process.placement_group_bundle_idxs) + + if server_process.nnodes < 1: + raise ValueError(f"Server worker_rank={server_rank} must have nnodes >= 1.") + if server_process.node_rank < 0 or server_process.node_rank >= server_process.nnodes: + raise ValueError( + f"Server worker_rank={server_rank} has invalid node_rank={server_process.node_rank} " + f"for nnodes={server_process.nnodes}." + ) + + if not engine_spec.request_entrypoint_servers: + raise ValueError(f"EngineLaunchSpec[{engine_index}] must expose at least one request entrypoint.") + + if known_worker_rank_set is not None: + missing_engine_ranks = sorted(known_worker_rank_set - seen_engine_ranks) + if missing_engine_ranks: + raise ValueError( + f"EngineLaunchSpecs do not cover known worker ranks in engine_ranks: {missing_engine_ranks}." 
+ ) + + return engine_launch_specs + + @classmethod + def build_engine_launch_specs( + cls, + config: RolloutConfig, + rank_bundle_idx_list: list[tuple[int, int]], + rank_to_dist_init_addr: dict[int, str] | None = None, + ) -> EngineLaunchSpecs: + """Build default launch spec: one request-serving server per engine.""" + num_gpus_per_engine = cls._get_num_gpus_per_engine(config) + num_workers = len(rank_bundle_idx_list) + if num_workers % num_gpus_per_engine != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." + ) + + engine_launch_specs: list[EngineLaunchSpec] = [] + for engine_start in range(0, num_workers, num_gpus_per_engine): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] + engine_ranks = tuple(rank for rank, _ in engine_meta) + engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) + engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] + engine_launch_specs.append( + EngineLaunchSpec( + engine_ranks=engine_ranks, + server_processes=( + ServerProcessSpec( + worker_rank=engine_ranks[0], + placement_group_bundle_idxs=engine_bundle_idxs, + dist_init_addr=engine_dist_init_addr, + ), + ), + ) + ) + return cls.validate_engine_launch_specs( + tuple(engine_launch_specs), + known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), + ) + + @classmethod + def build_metadata_engine_rank_mesh_array( + cls, + engine_launch_specs: EngineLaunchSpecs, + ) -> list[list[int]]: + """Build the public engine mesh returned in rollout metadata. + + By default, the public metadata mesh matches the logical engine topology. Backends with multiple request + servers per logical engine can override this to preserve their legacy update-weight mesh semantics. 
+ """ + return [list(engine_spec.engine_ranks) for engine_spec in engine_launch_specs] + + def _get_current_server_process_spec( + self, + engine_launch_spec: EngineLaunchSpec | None = None, + ) -> ServerProcessSpec | None: + engine_launch_spec = engine_launch_spec or self.engine_launch_spec + if engine_launch_spec is None: + return None + + for server_process_spec in engine_launch_spec.server_processes: + if server_process_spec.worker_rank == self.rank: + return server_process_spec + raise RuntimeError( + f"Engine launch spec does not include rollout worker rank={self.rank} " + f"in server_worker_ranks={engine_launch_spec.server_worker_ranks}." + ) + def set_enable_partial_rollout(self, enable: bool) -> None: self.enable_partial_rollout = enable - def init(self, dist_init_addr: str | None = None) -> tuple[int, str]: + def init( + self, + *, + engine_launch_spec: EngineLaunchSpec | None = None, + ) -> tuple[int, str]: """Initialize the worker and launch the server. - Args: - dist_init_addr (str): The distributed initialization address. - If not provided, the one generated by `init_dist_port` is used. - Returns: Tuple[int, str]: A tuple containing the worker's rank and its server URL. """ - if dist_init_addr is not None: - self.dist_init_addr = dist_init_addr + if engine_launch_spec is not None: + # Initial controller startup passes the immutable launch spec and caches + # it on the actor. Recovery calls init() without arguments after + # shutdown, intentionally reusing this cached placement/dist layout. 
+ self.engine_launch_spec = engine_launch_spec + server_process_spec = cast( + ServerProcessSpec, + self._get_current_server_process_spec(engine_launch_spec), + ) + self.engine_bundle_idxs = list(server_process_spec.placement_group_bundle_idxs) + if server_process_spec.dist_init_addr is not None: + self.dist_init_addr = server_process_spec.dist_init_addr self.receive_abort_request.clear() self._launch_server() self._start_session_server() return (self.rank, self.server_url) + def set_skip_load_weights(self, skip_load_weights: bool) -> None: + self.config = self.config.model_copy(update={"skip_load_weights": skip_load_weights}) + + def restore_skip_load_weights(self) -> None: + self.config = self.config.model_copy(update={"skip_load_weights": self._default_skip_load_weights}) + def init_dist_port(self) -> str: """Initialize distributed communication ports. @@ -570,6 +813,13 @@ def shutdown(self, *, stop_session_server: bool = False): server_task = self.server_task self._request_server_terminate() ray.cancel(server_task, force=True, recursive=True) + try: + ray.get(server_task, timeout=60) + except ray.exceptions.GetTimeoutError: + self.logger.warning(f"Worker {self.rank} server task did not stop within shutdown timeout.") + raise + except Exception as e: + self.logger.debug(f"Worker {self.rank} server task stopped after shutdown: {e}") self.server_task = None return @@ -665,10 +915,26 @@ def check_health(self) -> bool: "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {self.config.api_key}", } + health_url = f"{self.server_url}/{self.endpoints['health_generate']}" response = requests.get( - f"{self.server_url}/{self.endpoints['health_generate']}", headers=headers, timeout=5.0 + health_url, + headers=headers, + timeout=self.config.health_check_timeout_seconds, ) - return response.status_code == 200 + if response.status_code == 200: + return True + health_message = "" + try: + payload = response.json() + if isinstance(payload, dict) and 
payload.get("message"): + health_message = f", message={payload['message']!r}" + except ValueError: + pass + self.logger.warning( + f"Health check returned non-200 for server {health_url}: " + f"status_code={response.status_code}{health_message}" + ) + return False except requests.RequestException as e: self.logger.error(f"Health check failed for server {self.server_url}: {e}") return False @@ -1224,20 +1490,6 @@ def _check_infer_engine_version(self, return_token_ids: bool): ) self.check_flag = False - def _set_engine_rank_mesh_array(self, engine_rank_mesh_array: list[list[int]]): - self.engine_rank_mesh_array = engine_rank_mesh_array - - def _set_engine_bundle_idxs(self, engine_bundle_idxs: list[int]): - """Set the bundle indices for the inference engine. - - This is used by some backends (like LMDeploy with Ray executor) to - know which bundles in the placement group belong to this engine. - - Args: - engine_bundle_idxs (list[int]): A list of bundle indices. - """ - self.engine_bundle_idxs = engine_bundle_idxs - @abstractmethod def _get_request_payload(self, rollout_state: RolloutState) -> dict: """Abstract method to create a generation request. 
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 142ddbc56..faff64e4a 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -574,7 +574,6 @@ def _init_common(self, cfg: BaseRLTrainerConfig, *, meta_path: str, logger_tag: self._init_train_state(cfg) self._init_train_worker_config(cfg, log_dir) self._init_rollout_config(cfg, log_dir) - self._ensure_rollout_http_concurrency(cfg) self._init_runtime_flags(cfg) self._advantage_estimator = cfg.advantage_estimator_config.build() self._cpu_resource_manager: CPUResourceManager | None = None @@ -660,18 +659,21 @@ def _init_rollout_config(self, cfg: BaseRLTrainerConfig, log_dir: Path) -> None: ) self._rollout_config = cfg.rollout_config - def _ensure_rollout_http_concurrency(self, cfg: BaseRLTrainerConfig) -> None: + def _ensure_rollout_http_concurrency( + self, + cfg: BaseRLTrainerConfig, + rollout_pg, + ) -> None: rollout_max_batch_size = cfg.rollout_config.rollout_max_batch_size_per_instance if rollout_max_batch_size is None or rollout_max_batch_size <= 0: return - if isinstance(cfg, RLDisaggregatedTrainerConfig): - rollout_worker_count = cfg.rollout_resources.num_workers - elif isinstance(cfg, RLColocateTrainerConfig): - rollout_worker_count = cfg.resources.num_workers - else: - rollout_worker_count = 1 - active_rollout_worker_count, _ = cfg.rollout_config.get_active_servers_count(rollout_worker_count) + current_http_concurrency = math.ceil(rollout_max_batch_size * cfg.rollout_config.allow_over_concurrency_ratio) + if current_http_concurrency <= 0: + return + + total_generate_concurrency = cfg.rollout_config.get_controller_generate_concurrency(rollout_pg) + active_rollout_worker_count = total_generate_concurrency // current_http_concurrency if active_rollout_worker_count <= 0: return @@ -690,7 +692,6 @@ def _ensure_rollout_http_concurrency(self, cfg: BaseRLTrainerConfig) -> None: ) required_http_concurrency = math.ceil(scheduled_http_requests / 
active_rollout_worker_count) - current_http_concurrency = math.ceil(rollout_max_batch_size * cfg.rollout_config.allow_over_concurrency_ratio) if current_http_concurrency >= required_http_concurrency: return @@ -929,7 +930,7 @@ def _train_one_batch( # 共卡训练前切换资源:检查 rollout -> offload rollout -> onload train。 if offload_rollout_before_train: ray.get( - self.rollout_controller.ensure_workers_healthy_before_training.remote(), + self.rollout_controller.check_and_shutdown_inactive_workers.remote(), timeout=RL_TRAINER_RAY_GET_TIMEOUT, ) ray.get(self.rollout_controller.offload.remote(), timeout=RL_TRAINER_RAY_GET_TIMEOUT) @@ -1589,6 +1590,7 @@ def __init__(self, cfg: RLColocateTrainerConfig): self._cpu_resource_manager = CPUResourceManager(self._pg) self._cpu_resource_manager.log_initial_snapshot() set_cpu_resource_manager(self._cpu_resource_manager) + self._ensure_rollout_http_concurrency(cfg, self._pg) if self._debug_rollout: if self._rollout_config.skip_load_weights: @@ -1773,8 +1775,18 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool timer_name = "sync_weight" if should_sync_weights else "switch_to_rollout" with timer(timer_name, step_timer_dict): if should_sync_weights: - bind_train_rollout(train_controller=self.train_controller, rollout_controller=self.rollout_controller) - ray.get(self.rollout_controller.onload_weights.remote(), timeout=RL_TRAINER_RAY_GET_TIMEOUT) + ray.get( + self.rollout_controller.restart_inactive_workers.remote(), + timeout=RL_TRAINER_RAY_GET_TIMEOUT, + ) + bind_train_rollout( + train_controller=self.train_controller, + rollout_controller=self.rollout_controller, + ) + ray.get( + self.rollout_controller.onload_weights.remote(), + timeout=RL_TRAINER_RAY_GET_TIMEOUT, + ) self.train_controller.update_weights() self.logger.info("Rollout workers update weights successfully in colocate mode") self.train_controller.offload(target="model") @@ -1801,6 +1813,7 @@ def __init__(self, cfg: RLDisaggregatedTrainerConfig): 
self._cpu_resource_manager = CPUResourceManager([self._train_pg, self._rollout_pg]) self._cpu_resource_manager.log_initial_snapshot() set_cpu_resource_manager(self._cpu_resource_manager) + self._ensure_rollout_http_concurrency(cfg, self._rollout_pg) self.train_controller = self._train_worker_cfg.build(self._train_pg) self.rollout_controller = self._rollout_config.build(self._rollout_pg) if _trainer_config_needs_routed_api_proxy(cfg):