Refine rollout worker health check and recovery lifecycle #1877
YanhuiDua wants to merge 16 commits into
Conversation
9cb357d to a766fdb
@claude review
Claude:
## Summary
Well-structured refactor that introduces
## Issues
Critical
Warning
Nit
ProduceBatchResult impact: Not affected; no changes to batch accounting, status conversion, or reward/filter paths.
RoutedExperts impact: Not affected; no changes to routed-experts plumbing, object ref ownership, or extra_fields handling.
Verdict: REQUEST_CHANGES - the
@claude review
        }
        for future in as_completed(future_to_group):
            group = future_to_group[future]
            try:
                group_recovery_results[group.ranks] = future.result()
            except Exception:
                logger.exception(f"Failed to restart rollout worker group ranks={group.ranks}.")
                group_recovery_results[group.ranks] = False

        if self._is_stopping():
            for group in sorted_failed_groups:
Claude: Critical: run_once() still only triggers recover_failed_workers() when all workers are down after the health check. If a single lifecycle group fails, it's marked INACTIVE by _check_active_workers_and_mark_failed_groups() but no recovery or cleanup is attempted — the failed group's GPU memory remains allocated indefinitely until the next training barrier (ensure_workers_healthy_before_training).
The old RolloutHealthChecker.run_once() at least offloaded and shut down failed workers immediately. The new code should either:
- Attempt recovery or at least offload/shutdown for failed groups during generation, or
- Document this as an intentional design decision with a comment explaining why partial-group recovery during generation is unsafe.
Without either, a long-running generation phase can leak GPU memory from dead workers.
ProduceBatchResult impact: Not directly affected, but reduced worker pool capacity can affect group_gen_* timing if generation load concentrates on fewer workers.
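A minimal sketch of the first option, assuming the health manager can release one group at a time. `handle_failed_groups`, `release_group`, and `recover_all` are hypothetical names used only for illustration; they are not functions from this PR.

```python
from typing import Callable, Iterable

Ranks = tuple[int, ...]


def handle_failed_groups(
    failed_groups: Iterable[Ranks],
    active_group_count: int,
    release_group: Callable[[Ranks], None],
    recover_all: Callable[[], None],
) -> str:
    """Decide what to do with lifecycle groups that just failed a health check."""
    failed = sorted(set(failed_groups))
    if not failed:
        return "healthy"
    if active_group_count == 0:
        # Behaviour described above: full recovery only when no worker is left.
        recover_all()
        return "recover_all"
    # Partial failure: offload / shut down now (assumed hook) rather than
    # holding GPU memory until the next training barrier.
    for ranks in failed:
        release_group(ranks)
    return "released"
```

Here `active_group_count` is the number of groups still active after the failed ones were marked INACTIVE, so the all-down case keeps its current path and only the partial case gains a cleanup step.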
        assert self._stop_event is not None
        self._stopped = True
        self._stop_event.set()
        if self._pause_event:
Claude: Nit: When join() times out, the method returns early without clearing self._thread, self._stop_event, or self._pause_event. This means a subsequent start() call will see self._thread as non-None and is_alive() still True, so it short-circuits without creating a new thread. The stale references could also confuse is_paused() or other state queries.
Consider at least logging that the manager is in a degraded state, or setting self._stopped = True so _is_stopping() guards subsequent operations.
        if failed_groups and not self._is_stopping():
            with self._operation_lock:
                current_workers_snapshot = self.snapshot_workers()
                active_groups = {
                    worker.lifecycle_group_ranks or (worker.rank,)
                    for worker in current_workers_snapshot.values()
                    if worker.active
                }
                failed_groups = failed_groups & active_groups
                for group_ranks in failed_groups:
                    self._set_group_lifecycle_state(group_ranks, WorkerLifecycleState.INACTIVE)

        return len(workers_to_check)

    def _check_workers_health(self, workers_to_check: list[WorkerSnapshot]) -> list[bool]:
        """Run periodic check_health probes concurrently."""
        if self._check_failure_threshold <= 0:
Claude: Warning: Periodic health checks now use only check_health (basic HTTP liveness) for all workers, including request-entrypoint workers. The previous implementation (in the initial patch 2 version) used check_health_generate for request entrypoints, which validates the inference engine is actually functional.
This means a worker whose HTTP server is alive but whose inference engine is broken (e.g., CUDA OOM, stuck model load) will pass periodic health checks indefinitely. The generate-readiness verification now only happens during recovery restart, not during ongoing monitoring.
If this is intentional (to avoid the cost of generate probes every check interval), consider documenting why — otherwise this is a regression in failure detection coverage.
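One possible middle ground, sketched under the assumption that the generate probe is too expensive to run every interval: keep the cheap liveness probe as the default and run the deeper probe on request entrypoints only every N rounds. `choose_probe` and `deep_every` are hypothetical names; the returned strings simply reuse the two probe names mentioned above.

```python
def choose_probe(is_entrypoint: bool, round_index: int, deep_every: int = 5) -> str:
    """Pick which health probe to run for one worker in one check round."""
    # Only request entrypoints serve generate traffic, so only they get the deep probe.
    if is_entrypoint and deep_every > 0 and round_index % deep_every == 0:
        return "check_health_generate"
    # Everything else, and every other round, uses the cheap liveness probe.
    return "check_health"
```

Setting `deep_every=0` disables the deep probe entirely, which matches the behaviour the comment above describes for the current patch.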
Claude:
## Re-review Summary (after patches 4-5)
Good progress since the first review. The
Remaining Issues
Critical
Warning
Nit
Previous Issues: Fixed
ProduceBatchResult impact: not affected; no changes to batch accounting, status conversion, or reward/filter paths.
RoutedExperts impact: not affected; no changes to routed-experts plumbing, object ref ownership, or extra_fields handling.
Verdict: REQUEST_CHANGES - the
be11ead to 50f26e4
d31b0d9 to bf1b669
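For the intermittent health checker failures observed in experiments, the following fixes were made:
For the 503 Service Unavailable errors seen in experiments, the root cause is not yet clear. LMDeploy provides an error message, and XTuner now prints it. We will keep watching the error output; the most likely cause is a timeout.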
152f497 to 1170ec0
            self.engine_rank_mesh_array,
            self.worker_server_urls_map,
            self.rank2info,
        ) = self._init_workers(placement_group)
Return self.server_process_rank2info from _init_workers as well, so it stays consistent with the other fields.
        # may expose a different update-weight mesh than their internal launch
        # topology, e.g. LMDeploy EP has one logical engine but one public entry
        # per request-serving EP rank.
        engine_rank_mesh_array = worker_base_cls.build_metadata_engine_rank_mesh_array(engine_launch_specs)
rename to entrypoint_ranks_array
        self.engine_rank_mesh_array: List[List[int]] = []
        self.worker_server_urls_map: dict[int, str] = {}
        # Active rollout workers keep the origin/main meaning: workers that accept rollout requests.
        self.rank2info: dict[int, WorkerInfo] = {}
rename to entrypoint_server_rank2info, so it's easy to see its difference with self.server_process_rank2info
        inactive_workers = [
            f"rank={rank}, url={info.url}" for rank, info in self.rank2info.items() if not info.is_active
        ]
        await asyncio.to_thread(self.health_manager.ensure_workers_healthy_before_training)
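This reuses the periodic health-check logic, but _check_workers_health() keeps treating a failed worker as active until its failure count reaches the threshold: health_manager.py:333. The default threshold is 3, so when the first final health check before training fails, recover_failed_workers() sees no failed group, the controller sees no inactive worker, and the colocated switch path at rl_trainer.py:931 simply carries on. The old implementation marked a worker inactive and recovered it after a single failed pre-training check.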
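A minimal sketch of one way to separate the two cases, assuming the periodic loop keeps its consecutive-failure threshold while the pre-training check fails fast. `FailureCounter` and the `strict` flag are hypothetical and not part of the PR.

```python
from collections import defaultdict


class FailureCounter:
    """Track consecutive health-check failures per worker rank."""

    def __init__(self, threshold: int = 3) -> None:
        self._threshold = threshold
        self._consecutive: dict[int, int] = defaultdict(int)

    def record(self, rank: int, healthy: bool, *, strict: bool = False) -> bool:
        """Return True when `rank` should be marked inactive."""
        if healthy:
            # A successful probe resets the streak.
            self._consecutive[rank] = 0
            return False
        self._consecutive[rank] += 1
        # strict=True models the pre-training check: one failure is enough.
        limit = 1 if strict else self._threshold
        return self._consecutive[rank] >= limit
```

With this split, the background loop would call `record(rank, ok)` and the pre-training barrier would call `record(rank, ok, strict=True)`, which restores the old single-failure behaviour for that path only.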
@claude review
4ce2380 to a9da73f
@claude review
        thread.join(timeout=HEALTH_MANAGER_STOP_JOIN_TIMEOUT)
        if thread.is_alive():
            logger.warning(
                f"RolloutHealthManager stop timed out after {HEALTH_MANAGER_STOP_JOIN_TIMEOUT}s; "
                "health thread is still exiting."
            )
            return

        self._thread = None
Claude: Nit: stop() returns early on timeout without clearing stale references.
When thread.is_alive() is true after the timeout, this return leaves self._thread, self._stop_event, and self._pause_event set to their current values. If start() is subsequently called, the is_alive() check at line 82 handles the dead-thread case correctly. However, calling stop() a second time would re-join the same (now potentially dead) thread — harmless but confusing.
Consider clearing the references even on timeout, or setting self._thread = None so subsequent calls are no-ops:
Current:
        thread.join(timeout=HEALTH_MANAGER_STOP_JOIN_TIMEOUT)
        if thread.is_alive():
            logger.warning(
                f"RolloutHealthManager stop timed out after {HEALTH_MANAGER_STOP_JOIN_TIMEOUT}s; "
                "health thread is still exiting."
            )
            return
        self._thread = None
Suggested:
        thread.join(timeout=HEALTH_MANAGER_STOP_JOIN_TIMEOUT)
        if thread.is_alive():
            logger.warning(
                f"RolloutHealthManager stop timed out after {HEALTH_MANAGER_STOP_JOIN_TIMEOUT}s; "
                "health thread is still exiting."
            )
        self._thread = None
        self._stop_event = None
        self._pause_event = None
        health_results = ray.get(
            [worker.actor.check_health.remote() for worker in group.workers],  # type: ignore[attr-defined]
            timeout=ROLLOUT_RAY_GET_TIMEOUT,
        )
        if self._is_stopping():
            self._shutdown_worker_group(group, wait_server_down=False, best_effort=True)
            return False
        unhealthy_ranks = [
            worker.rank for worker, is_healthy in zip(group.workers, health_results) if not is_healthy
        ]
        if unhealthy_ranks:
            logger.error(
                f"Restarted rollout worker group ranks={group.ranks} has unhealthy ranks={unhealthy_ranks}."
            )
            self._shutdown_worker_group(group, wait_server_down=False, best_effort=True)
            return False
Claude: Warning: Post-restart workers are no longer verified for generate-readiness.
check_health_generate() and _check_worker_group_generate_ready_after_restart() were both removed in this PR. After restart, only check_health() (HTTP liveness via the health endpoint) verifies the worker before marking it ACTIVE.
While check_health() does hit the backend's health/generate endpoint, it doesn't verify end-to-end inference capability (model loaded, CUDA contexts warmed). The removed check_health_generate() sent an actual mini-inference request (e.g., max_tokens=1 chat completion for LMDeploy) to confirm the engine is fully functional.
Given the PR's mention of "503 service unavailable" errors being observed in experiments, a newly-restarted server that passes the liveness check but hasn't finished internal initialization could produce these exact errors.
If removing the generate-readiness probe is intentional (to avoid the timeout issues mentioned in the PR comments), consider:
- Adding it back as a non-blocking post-restart validation that logs but doesn't block ACTIVE status
- Or documenting that the first few requests after recovery may fail and will be retried by the client
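The first option could look roughly like the sketch below, assuming the generate probe is exposed as a plain callable that returns a bool. `validate_in_background` and `probe` are hypothetical names; the sketch only logs and never changes worker state.

```python
import logging
import threading
from typing import Callable

logger = logging.getLogger(__name__)


def validate_in_background(rank: int, probe: Callable[[], bool]) -> threading.Thread:
    """Run a generate-readiness probe without blocking the ACTIVE transition."""

    def _run() -> None:
        try:
            ok = probe()
        except Exception:
            # A raising probe is reported but must not take down the caller.
            logger.exception("generate probe raised for rank=%s", rank)
            return
        if not ok:
            logger.warning("rank=%s passed liveness but failed the generate probe", rank)

    thread = threading.Thread(target=_run, name=f"generate-probe-{rank}", daemon=True)
    thread.start()
    return thread
```

The returned thread is handed back so a caller can join it in tests; in the recovery path it would simply be left to finish on its own.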
Overview
The direction of the PR is right: it moves the rollout server launch layout and health recovery out of
flowchart LR
C[RolloutController] --> W[WorkerInfo mutable table]
R[SessionRouter] --> W
H[RolloutHealthManager] --> W
M[get_rollout_metadata] --> W
C --> T[EngineLaunchSpec]
T --> L[server launch]
T --> E[request entrypoint]
T --> G[lifecycle group]
T --> U[update-weight metadata]
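Main issues
Suggested improvements
The core idea is to split "topology", "worker state table", "session routing policy", and "health check / lifecycle recovery" into four more clearly separated Modules. The external Interface still keeps
flowchart LR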
C[RolloutController] --> F[get_rollout_topology_builder]
F --> B[BackendTopologyBuilder Adapter]
B --> T[RolloutTopologyPlan]
C --> W[RolloutWorker actors]
C --> R[RolloutWorkerRegistry]
T --> R
Router[SessionRouter] --> R
H[RolloutHealthManager] --> R
H --> Loop[background health loop]
Trainer[Trainer / UpdateWeighter] --> M[TrainingMetadata]
R --> M
1. RolloutTopologyPlan
Responsible for hiding backend topology differences:
    worker_cls = get_rollout_worker_base_cls(config)
    topology_builder = get_rollout_topology_builder(config)
Keep
Key Interface:
    class RolloutTopologyPlan:
        def server_slots(self) -> tuple[ServerLaunchSlot, ...]: ...
        def slot_for_rank(self, rank: int) -> ServerLaunchSlot: ...
        def lifecycle_group_for_rank(self, rank: int) -> LifecycleGroup: ...
        def training_engine_mesh(self) -> list[list[int]]: ...
Caller usage:
    topology = topology_builder.build_topology(
        config=config,
        rank_bundle_idx_list=rank_bundle_idx_list,
        rank_to_dist_init_addr=rank_to_dist_init_addr,
    )
    for slot in topology.server_slots():
        actor = workers[slot.worker_rank]
        server_url = actor.init.remote(slot)
        registry.register_started_server(
            rank=slot.worker_rank,
            actor=actor,
            server_url=server_url,
            session_url=actor.get_session_server_info.remote(),
        )
Here
2. BackendTopologyBuilder Adapter
Each backend implements only the topology translation; this kind of logic no longer hangs on
flowchart LR
F[get_rollout_topology_builder] --> D[DefaultTopologyBuilder]
F --> LM[LMDeployTopologyBuilder]
F --> SG[SGLangTopologyBuilder]
D --> P[RolloutTopologyPlan]
LM --> P
SG --> P
LMDeploy key rules:
SGLang key rules:
Key Interface:
    class BackendTopologyBuilder(Protocol):
        def build_topology(
            self,
            *,
            config: RolloutConfig,
            rank_bundle_idx_list: list[tuple[int, int]],
            rank_to_dist_init_addr: dict[int, str],
        ) -> RolloutTopologyPlan: ...

    def get_rollout_topology_builder(config: RolloutConfig) -> BackendTopologyBuilder:
        if config.rollout_backend == "lmdeploy":
            return LMDeployTopologyBuilder()
        if config.rollout_backend == "sglang":
            return SGLangTopologyBuilder()
        if config.rollout_backend == "vllm":
            return DefaultTopologyBuilder()
        raise NotImplementedError(config.rollout_backend)
LMDeploy EP key pseudocode:
    for engine_meta in chunks(rank_bundle_idx_list, config.expert_parallel_size):
        engine_ranks = tuple(rank for rank, _ in engine_meta)
        dist_init_addr = rank_to_dist_init_addr[engine_ranks[0]]
        for server_rank, bundle_idx in engine_meta:
            # One server process, and one public entrypoint, per EP rank.
            server_slots[server_rank] = ServerLaunchSlot(
                worker_rank=server_rank,
                placement_group_bundle_idxs=(bundle_idx,),
                dist_init_addr=dist_init_addr,
                accepts_rollout_requests=True,
            )
            training_engine_mesh.append([server_rank])
        # The whole EP engine fails and recovers as one lifecycle group.
        lifecycle_groups.append(LifecycleGroup(ranks=engine_ranks))
SGLang cross-node key pseudocode:
    for engine_meta in chunks(rank_bundle_idx_list, config.num_gpus_per_engine):
        engine_ranks = tuple(rank for rank, _ in engine_meta)
        engine_bundles = tuple(bundle_idx for _, bundle_idx in engine_meta)
        # One server process per node: take the first rank on each node.
        server_ranks = engine_ranks[:: config.gpus_per_node]
        for node_rank, server_rank in enumerate(server_ranks):
            server_slots[server_rank] = ServerLaunchSlot(
                worker_rank=server_rank,
                placement_group_bundle_idxs=node_bundles(engine_bundles, node_rank),
                dist_init_addr=rank_to_dist_init_addr[server_ranks[0]],
                # Only the node-0 server accepts rollout requests.
                accepts_rollout_requests=node_rank == 0,
            )
        lifecycle_groups.append(LifecycleGroup(ranks=tuple(server_ranks)))
        training_engine_mesh.append(list(engine_ranks))
This way backend differences are concentrated inside the Adapter,
3. RolloutWorkerRegistry
Responsible for hiding the worker state table:
Key Interface:
    class RolloutWorkerRegistry:
        def register_started_server(
            self,
            *,
            rank: int,
            actor: RolloutWorker,
            server_url: str,
            session_url: str | None,
        ) -> None: ...
        def active_entrypoints(self) -> tuple[WorkerSnapshot, ...]: ...
        def active_entrypoint_by_rank(self, rank: int) -> WorkerSnapshot | None: ...
        def snapshot_active_workers(self) -> tuple[WorkerSnapshot, ...]: ...
        def mark_unhealthy_ranks(self, ranks: set[int]) -> tuple[LifecycleGroup, ...]: ...
        def recovery_plan(self) -> tuple[tuple[LifecycleGroup, tuple[WorkerSnapshot, ...]], ...]: ...
        def complete_recovery(self, group: LifecycleGroup, *, recovered: bool) -> None: ...
        def training_metadata_snapshot(self) -> TrainingMetadata: ...
Caller usage:
    # The router does not read the worker dict; it only asks the registry for the currently available entrypoints.
    candidates = registry.active_entrypoints()
    # The health manager does not modify WorkerInfo fields; it only submits probe results.
    registry.mark_unhealthy_ranks(failed_ranks)
    # The trainer does not consume several status dicts; it takes one consistent snapshot.
    metadata = registry.training_metadata_snapshot()
    train_controller.update_rollout_info(metadata)
4. SessionRouter
Key Interface and usage:
    class SessionRouter:
        def __init__(self, registry: RolloutWorkerRegistry) -> None:
            self._registry = registry
            self._session_to_rank: dict[int, int] = {}
            self._round_robin_cursor = 0

        async def get_worker(self, session_id: int) -> RolloutWorker | None:
            # Sticky routing: reuse the session's previous worker while it is still active.
            sticky_rank = self._session_to_rank.get(session_id)
            if sticky_rank is not None:
                sticky_worker = self._registry.active_entrypoint_by_rank(sticky_rank)
                if sticky_worker is not None:
                    return sticky_worker.actor
            candidates = self._registry.active_entrypoints()
            if not candidates:
                self._session_to_rank.pop(session_id, None)
                return None
            # Otherwise fall back to round-robin over the active entrypoints.
            worker = candidates[self._round_robin_cursor % len(candidates)]
            self._round_robin_cursor += 1
            self._session_to_rank[session_id] = worker.rank
            return worker.actor

    # Caller side. The enclosing class is cut off in the source; self._router suggests the controller.
    async def generate(self, rollout_state: RolloutState) -> RolloutState:
        actor = await self._router.get_worker(rollout_state.session_id)
        if actor is None:
            rollout_state.status = Status.FAILED
            rollout_state.error_msg = "No active rollout worker available."
            return rollout_state
        return await actor.generate.remote(rollout_state=rollout_state)
This way the registry remains the state-table Module and the router is the routing-policy Module. The Controller knows nothing about sticky / round-robin routing, nor about the details of entrypoint filtering.
5. RolloutHealthManager
Key Interface:
    class RolloutHealthManager:
        def start_background_checks(self) -> None: ...
        def stop_background_checks(self) -> None: ...
        def pause_background_checks(self) -> None: ...
        def resume_background_checks(self) -> None: ...
        def run_periodic_health_check(self) -> None: ...
        def shutdown_failed_groups_before_train(self) -> None: ...
        def restart_inactive_groups(self) -> None: ...
The Controller starts the background checks after all servers have been registered with the registry:
    for slot in topology.server_slots():
        ...
        registry.register_started_server(...)
    health_manager.start_background_checks()
Background check key pseudocode:
    def _run_health_loop(self) -> None:
        while not self._stop_event.wait(self._health_check_interval_seconds):
            if self._pause_event.is_set():
                continue
            self.run_periodic_health_check()

    def run_periodic_health_check(self) -> None:
        failed_ranks = set()
        for worker in self._registry.snapshot_active_workers():
            # .remote() returns an ObjectRef, which is always truthy; resolve it first.
            if not ray.get(worker.actor.check_health.remote()):
                failed_ranks.add(worker.rank)
        self._registry.mark_unhealthy_ranks(failed_ranks)
During an explicit phase switch / recovery, the health manager pauses the background health loop:
Recovery key pseudocode:
    def restart_inactive_groups(self) -> None:
        self.pause_background_checks()
        try:
            for group, workers in self._registry.recovery_plan():
                recovered = self._restart_group(workers)
                self._registry.complete_recovery(group, recovered=recovered)
        finally:
            self.resume_background_checks()
The current pseudocode still expresses server-process-level recovery: it reuses the existing Ray actor and calls
Conclusion
For this PR, continuing to add fields and helpers to the existing classes is not recommended. First deepen the topology, worker registry, session router, and health manager Modules so that callers depend only on small Interfaces, while the internals handle backend topologies and legacy metadata compatibility. The change then becomes more cohesive, and it is easier to write Good Tests.
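Overview
This PR does two main refactors:
- EngineLaunchSpec/ServerProcessSpec.
- the responsibility boundaries of RolloutController, RolloutHealthManager, and RolloutWorker. It also adds support for restarting every worker in an EP group when one worker in that group fails. In addition, when RolloutHealthManager detects that all workers have failed, it restarts all of them immediately.
Why the Health Manager refactor depends on the Server Launch Spec refactor
Rollout health recovery cannot rely only on knowing "which workers are still alive". It must also know how each rollout server was launched and at what granularity to recover after a failure.
In scenarios such as LMDeploy EP and SGLang cross-node, a logical engine and a server process are not in a simple one-to-one relationship:
The Health Manager therefore depends on the structured information provided by the Server Launch Spec:
- engine_ranks: which worker ranks make up one logical engine.
- server_processes: which rollout server processes this engine actually launched.
- server_worker_ranks: which worker ranks own a server process and must take part in lifecycle management.
- accepts_rollout_requests: which servers are request entrypoints that can receive generate requests.
- dist_init_addr/placement_group_bundle_idxs: worker recovery reuses the original launch layout, so the server address and resource binding do not change after a restart.
So the first commit makes the server launch layout explicit. Only then can the second commit use this structured information to move health checks, state transitions, group recovery, and request routing out of the controller.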
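To make the field list above concrete, here is a rough sketch of how such a spec could be modelled. The class and field names come from the description, but the exact layout is an assumption: `worker_rank` and the `entrypoint_ranks` property are hypothetical, and the real classes in the PR may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ServerProcessSpec:
    worker_rank: int  # assumed: the rank that owns this server process
    placement_group_bundle_idxs: tuple[int, ...]
    dist_init_addr: str
    accepts_rollout_requests: bool


@dataclass(frozen=True)
class EngineLaunchSpec:
    engine_ranks: tuple[int, ...]
    server_processes: tuple[ServerProcessSpec, ...]

    @property
    def server_worker_ranks(self) -> tuple[int, ...]:
        """Ranks that own a server process and take part in lifecycle management."""
        return tuple(p.worker_rank for p in self.server_processes)

    @property
    def entrypoint_ranks(self) -> tuple[int, ...]:
        """Ranks whose server can receive generate requests (hypothetical helper)."""
        return tuple(p.worker_rank for p in self.server_processes if p.accepts_rollout_requests)


# Example: one cross-node engine over 4 ranks, one server per node,
# with only the first node accepting requests. The address is made up.
spec = EngineLaunchSpec(
    engine_ranks=(0, 1, 2, 3),
    server_processes=(
        ServerProcessSpec(0, (0, 1), "10.0.0.1:29500", True),
        ServerProcessSpec(2, (2, 3), "10.0.0.1:29500", False),
    ),
)
```

Because the spec is frozen and cached at init, a restart can replay the same placement and dist-init address, which is the property the description relies on.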
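Main changes
Server Launch Spec refactor
- ServerProcessSpec and EngineLaunchSpec, which explicitly describe which server processes each inference engine should launch.
- RolloutController._init_workers: build the launch spec first, then launch servers from the spec.
- into build_engine_launch_specs.
- Cache the launch spec at init, and reuse the original placement / dist-init layout when restarting during later recovery.
Rollout Health Manager refactor
- RolloutHealthManager, which is solely responsible for worker health checks, state transitions, and failure recovery.
- RolloutController keeps only worker creation, request routing, and training lifecycle control.
- RolloutWorker remains responsible for backend-specific server start/stop, health probes, and generate behavior.
- WorkerLifecycleState, which clearly distinguishes:
  - ACTIVE
  - INACTIVE
  - RECOVERING
- SessionRouter and the generate readiness check target only the servers that actually receive requests.
Note: this PR includes the functionality of PR1873: #1873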
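The three lifecycle states can be pictured as a small state machine. The sketch below is illustrative only: the state names come from the description, while the enum values and the allowed transitions are assumptions rather than the PR's actual rules.

```python
from enum import Enum


class WorkerLifecycleState(Enum):
    ACTIVE = "active"
    INACTIVE = "inactive"
    RECOVERING = "recovering"


# Assumed transitions: a worker fails, is picked up for recovery,
# and then either comes back or is marked inactive again.
_ALLOWED = {
    WorkerLifecycleState.ACTIVE: {WorkerLifecycleState.INACTIVE},
    WorkerLifecycleState.INACTIVE: {WorkerLifecycleState.RECOVERING},
    WorkerLifecycleState.RECOVERING: {
        WorkerLifecycleState.ACTIVE,
        WorkerLifecycleState.INACTIVE,
    },
}


def transition(current: WorkerLifecycleState, target: WorkerLifecycleState) -> WorkerLifecycleState:
    """Return `target` if the move is allowed, otherwise raise ValueError."""
    if target not in _ALLOWED[current]:
        raise ValueError(f"illegal transition {current.name} -> {target.name}")
    return target
```

Routing then only needs to ask whether a worker is ACTIVE, while recovery code owns the INACTIVE and RECOVERING states.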
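Testing
Verified that an MoE model trains normally with EP/TP, and that when the rank 2 worker dies, the whole EP/TP group recovers normally. When all workers die, all workers restart immediately. One remaining issue: when all workers die, some data is lost. Fixing that requires changes to the data flow, so this PR does not address it.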