Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions tests/rl/test_rl_colocate_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,18 @@ def _make_trainer(self, agent_loop_manager, *, total_train_steps: int = 1, sync_
)

trainer.rollout_controller = SimpleNamespace(
ensure_workers_healthy_before_training=SimpleNamespace(
remote=MagicMock(return_value="rollout_ready_for_training")
check_and_shutdown_inactive_workers=SimpleNamespace(
remote=MagicMock(return_value="rollout_inactive_workers_shutdown")
),
offload=SimpleNamespace(remote=MagicMock(return_value="rollout_offloaded")),
restart_inactive_workers=SimpleNamespace(remote=MagicMock(return_value="rollout_restarted")),
onload_weights=SimpleNamespace(remote=MagicMock(return_value="weights_loaded")),
onload_kvcache=SimpleNamespace(remote=MagicMock(return_value="kvcache_loaded")),
)
trainer.train_controller = SimpleNamespace(
onload=MagicMock(return_value="train_onloaded"),
offload=MagicMock(return_value="train_offloaded"),
update_weights=MagicMock(return_value="weights_updated"),
fit=MagicMock(
return_value=[
{
Expand Down Expand Up @@ -220,6 +225,32 @@ async def _produce_empty(batch_size, train_step, **kwargs):
trainer.train_controller.fit.assert_not_called()
self.assertEqual(trainer._cur_step, 0)

def test_fit_does_not_onload_train_when_rollout_training_barrier_fails(self):
# Verify that colocated training must pass the rollout phase-switch barrier
# before entering training; if the barrier fails, training must not be onloaded.
async def _produce_batch(batch_size, train_step, *, model_step):
return ProduceBatchResult(
rollout_states=[[SimpleNamespace(message_uid=train_step, uid=train_step)]]
)

trainer = self._make_trainer(SimpleNamespace(produce_batch=_produce_batch))
# Make the rollout controller's inactive-worker check raise, so the barrier fails.
trainer.rollout_controller.check_and_shutdown_inactive_workers.remote.side_effect = RuntimeError(
"inactive rollout workers after recovery"
)

# asyncio_run delegates to the real asyncio.run, and ray.get returns its argument
# unchanged, so the mocked ``remote`` return values pass straight through.
with (
patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run),
patch("xtuner.v1.train.rl_trainer.ray.get", side_effect=lambda obj, timeout=None: obj),
):
with self.assertRaisesRegex(RuntimeError, "inactive rollout workers"):
trainer.fit()

# The check was issued exactly once (with no arguments); rollout offload, train
# onload and train fit were never reached, and the step counter stayed at 0.
trainer.rollout_controller.check_and_shutdown_inactive_workers.remote.assert_called_once_with()
trainer.rollout_controller.offload.remote.assert_not_called()
trainer.train_controller.onload.assert_not_called()
trainer.train_controller.fit.assert_not_called()
self.assertEqual(trainer._cur_step, 0)

def test_fit_uses_sync_interval_and_passes_rollout_model_step(self):
# 验证 rollout 看到的是按 sync interval 推进后的 model_step。
produce_calls = []
Expand Down
4 changes: 2 additions & 2 deletions tests/rl/test_rl_trainer_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def __init__(self):
self.pause_generation = _RemoteMethod(async_result=True)
self.continue_generation = _RemoteMethod(async_result=True)
self.offload = _RemoteMethod(return_value="rollout_offloaded")
self.ensure_workers_healthy_before_training = _RemoteMethod(return_value="rollout_ready_for_training")
self.recover_failed_workers = _RemoteMethod(return_value="rollout_recovered")
self.check_and_shutdown_inactive_workers = _RemoteMethod(return_value="rollout_inactive_workers_shutdown")
self.restart_inactive_workers = _RemoteMethod(return_value="rollout_restarted")
self.onload_weights = _RemoteMethod(return_value="weights_loaded")
self.onload_kvcache = _RemoteMethod(return_value="kvcache_loaded")
self.get_rollout_metadata = _RemoteMethod(return_value={"server_url_dict": {}})
Expand Down
102 changes: 44 additions & 58 deletions tests/rl/test_rollout_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
本文件合并旧的 test_rollout_worker.py 和 test_rollout_utils.py 中不依赖真实模型/后端的测试:
- SGLangWorker pause/continue 对 abort flag 和 server request 的控制。
- RolloutWorker abort、abort request timeout 和 in-flight request 取消语义。
- RolloutHealthChecker 对 inactive/unhealthy worker 的清理逻辑
- RolloutHealthManager 对 inactive/unhealthy worker 的生命周期标记逻辑
- PartialRolloutHandler 拼接 routed_experts 后释放旧 Ray ObjectRef 的逻辑。

旧 test_rollout_utils.py 中的 TestRolloutControllerRecover 需要真实 Ray controller / lmdeploy backend,
Expand All @@ -19,28 +19,28 @@
import httpx

from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status
from xtuner.v1.rl.rollout.controller import RolloutController
from xtuner.v1.rl.rollout.controller import RolloutController, WorkerInfo
from xtuner.v1.rl.rollout.health_manager import RolloutHealthManager
from xtuner.v1.rl.rollout.sglang import SGLangWorker
from xtuner.v1.rl.rollout.utils import PartialRolloutHandler, RolloutHealthChecker
from xtuner.v1.rl.rollout.utils import PartialRolloutHandler, WorkerLifecycleState
from xtuner.v1.rl.rollout.worker import RolloutWorker
from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult


class _FakeRemoteMethod:
def __init__(self, name, call_log):
self.name = name
self.call_log = call_log
class _FakeAsyncRemoteMethod:
def __init__(self, result):
self.result = result
self.calls = []

def remote(self):
self.call_log.append((self.name, "remote"))
return self.name
self.calls.append(())

async def _result():
if isinstance(self.result, Exception):
raise self.result
return self.result

class _FakeWorker:
def __init__(self):
self.call_log = []
self.offload = _FakeRemoteMethod("offload", self.call_log)
self.shutdown = _FakeRemoteMethod("shutdown", self.call_log)
return _result()


class _FakeRolloutRouter:
Expand Down Expand Up @@ -391,58 +391,44 @@ async def run_case(case_name, safe_post_result, safe_handle_response, expected_m
await asyncio.gather(*(run_case(*case) for case in cases))


class TestRolloutHealthChecker(unittest.TestCase):
def _build_checker(self, workers_info):
config = SimpleNamespace(health_check_interval_seconds=10, health_check_failure_threshold=1)
return RolloutHealthChecker(config, workers_info)
class TestRolloutHealthManager(unittest.TestCase):
def _build_manager(self, workers_info, *, failure_threshold=1):
config = SimpleNamespace(health_check_interval_seconds=10, health_check_failure_threshold=failure_threshold)
return RolloutHealthManager(config, workers_info, worker_infos_lock=threading.RLock())

def test_shutdown_runs_when_offload_fails(self):
# worker 健康检查失败且 offload 也失败时,health checker 应 shutdown 并标记 inactive。
worker = _FakeWorker()
workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=True)}
checker = self._build_checker(workers_info)
def test_marks_worker_inactive_after_consecutive_health_failures(self):
actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False))
worker_info = WorkerInfo(actor=actor, url="http://worker-0")
workers_info = {0: worker_info}
manager = self._build_manager(workers_info, failure_threshold=2)

async def unhealthy_worker(*args, **kwargs):
return False
manager._check_and_deactivate_failed_worker_groups()

def ray_get(ref, timeout=None):
worker.call_log.append((ref, "get"))
if ref == "offload":
raise RuntimeError("offload failed")
return None
self.assertTrue(worker_info.is_active())
self.assertEqual(actor.check_health.calls, [()])

with (
patch("xtuner.v1.rl.rollout.utils.check_worker_health", side_effect=unhealthy_worker),
patch("xtuner.v1.rl.rollout.utils.ray.get", side_effect=ray_get),
):
checker.run_once()

self.assertFalse(workers_info[0].is_active)
self.assertEqual(
worker.call_log,
[
("offload", "remote"),
("offload", "get"),
("shutdown", "remote"),
("shutdown", "get"),
],
)
manager._check_and_deactivate_failed_worker_groups()

self.assertFalse(worker_info.is_active())
self.assertEqual(worker_info.lifecycle_state, WorkerLifecycleState.INACTIVE)
self.assertEqual(actor.check_health.calls, [(), ()])

def test_inactive_worker_is_not_cleaned_up_again(self):
# 已 inactive 的 worker 不再重复健康检查、offload 或 shutdown。
worker = _FakeWorker()
workers_info = {0: SimpleNamespace(actor=worker, url="http://worker-0", is_active=False)}
checker = self._build_checker(workers_info)
# 已 inactive 的 worker 不再重复健康检查。
actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True))
workers_info = {
0: WorkerInfo(
actor=actor,
url="http://worker-0",
lifecycle_state=WorkerLifecycleState.INACTIVE,
)
}
manager = self._build_manager(workers_info)

with (
patch("xtuner.v1.rl.rollout.utils.check_worker_health") as check_worker_health_mock,
patch("xtuner.v1.rl.rollout.utils.ray.get") as ray_get_mock,
):
checker.run_once()
checked_count = manager._check_and_deactivate_failed_worker_groups()

check_worker_health_mock.assert_not_called()
ray_get_mock.assert_not_called()
self.assertEqual(worker.call_log, [])
self.assertEqual(checked_count, 0)
self.assertEqual(actor.check_health.calls, [])


class TestPartialRolloutHandler(unittest.IsolatedAsyncioTestCase):
Expand Down
Loading
Loading