Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions xtuner/v1/engine/train_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,20 @@ def async_save_dcp(
# temporary .incomplete directory and commit it only after every rank's
# async_save future has completed.
incomplete_dir = weights_dir.with_name(f"{weights_dir.name}.incomplete")
if weights_dir.exists():
# Check existence only on rank 0 and broadcast the result so all ranks
# raise (or continue) together. Without this, NFS cache inconsistencies
# could cause some ranks to raise while others proceed to the barrier,
# resulting in a deadlock.
dir_exists = torch.tensor(int(weights_dir.exists() if dist.get_rank() == 0 else 0), dtype=torch.int32)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we only check if the directory exists on rank0?

@VincentCheungKokomo VincentCheungKokomo May 30, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already only calls exists() on rank 0 (via weights_dir.exists() if dist.get_rank() == 0 else 0) — other ranks simply use 0.
The next line dist.broadcast(dir_exists, src=0, group=async_checkpoint_pg) then syncs rank 0's result to all ranks. This way all
ranks get a consistent decision and either all raise or all proceed together, avoiding a deadlock where NFS cache inconsistencies
cause some ranks to raise while others continue to the barrier.

dist.broadcast(dir_exists, src=0, group=async_checkpoint_pg)
if dir_exists.item():
raise FileExistsError(f"Checkpoint directory already exists: {weights_dir}")
if dist.get_rank() == 0:
if incomplete_dir.exists():
shutil.rmtree(incomplete_dir)
incomplete_dir.mkdir(parents=True, exist_ok=True)
# Ensure rank 0 finishes rmtree+mkdir before any rank proceeds to write.
dist.barrier(group=async_checkpoint_pg)

# XtunerCacheWriter.stage() creates its staging cache directly in POSIX
# shared memory (/dev/shm). PyTorch's ForkingPickler detects
Expand Down Expand Up @@ -440,7 +448,20 @@ def commit_async_save() -> None:
dcp_future.result()
break
except BaseException as exc:
if attempt == max_daemon_init_attempts or not self._is_async_checkpoint_daemon_init_error(exc):
# Use all_reduce(MAX) so all ranks agree on whether this is
# a retryable daemon-init error. Without this, ranks that
# see a non-retryable error would raise (skipping the
# barrier) while other ranks wait at the barrier, causing a
# deadlock.
is_retryable = attempt < max_daemon_init_attempts and self._is_async_checkpoint_daemon_init_error(
exc
)
# 0 = retryable, 1 = fatal; MAX means any fatal rank wins.
decision = torch.tensor(0 if is_retryable else 1, dtype=torch.int32)
dist.all_reduce(decision, op=dist.ReduceOp.MAX, group=async_checkpoint_pg)
is_fatal = bool(decision.item())

if is_fatal:
elapsed = time.time() - t0
logger.error(f"[DCP async_save for {weights_dir}] failed after {elapsed:.2f}s: {exc}")
logger.error(traceback.format_exc())
Expand Down Expand Up @@ -492,6 +513,51 @@ def _build_async_storage_writer(self, weights_dir: Path, *, save_optimizer: bool
storage_writer.state_dict_cache = self._async_state_dict_cache
return storage_writer

@classmethod
def warmup_async_save_dcp(cls, work_dir: Path) -> None:
"""Warm up async DCP save infrastructure with a tiny dummy state dict.

This triggers the full async save path — including daemon subprocess
spawn and its internal init_process_group — so that errors like port
conflicts (EADDRINUSE) surface before any real training begins.

Args:
work_dir (Path): Working directory for temporary preflight files.
"""
preflight_dir = work_dir / ".preflight_dcp"
weights_dir = preflight_dir / "weights"

if dist.get_rank() == 0:
if preflight_dir.exists():
shutil.rmtree(preflight_dir)
weights_dir.mkdir(parents=True, exist_ok=True)
dist.barrier()

dummy_state_dict = {"_preflight": torch.zeros(1)}

try:
async_save_kwargs: dict[str, Any] = {}
state_dict_saver = importlib.import_module("torch.distributed.checkpoint.state_dict_saver")
async_checkpointer_type = getattr(state_dict_saver, "AsyncCheckpointerType", None)
if async_checkpointer_type is not None:
async_save_kwargs["async_checkpointer_type"] = async_checkpointer_type.PROCESS

future = cast(Any, dcp.async_save)(
dummy_state_dict,
checkpoint_id=weights_dir,
**async_save_kwargs,
)
future.result(timeout=300)
except Exception as e:
raise RuntimeError(
f"DCP warmup save failed. This usually indicates a port conflict "
f"or process group initialization issue. Error: {e}"
) from e
finally:
if dist.get_rank() == 0 and preflight_dir.exists():
shutil.rmtree(preflight_dir, ignore_errors=True)
dist.barrier()

def destroy_async_checkpoint_pg(self) -> None:
"""Destroy the dedicated gloo process group used for async
checkpoint."""
Expand Down
65 changes: 58 additions & 7 deletions xtuner/v1/patch/xtuner_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StreamTransformExtension,
)
from torch.distributed.checkpoint.filesystem import FileSystem
from torch.distributed.checkpoint.staging import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.staging import _copy_state_dict
from torch.distributed.checkpoint.storage import (
WriteResult,
)
Expand All @@ -25,6 +25,60 @@
logger = logging.getLogger(__name__)


def _create_coalesced_shm_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
"""Create a CPU state dict backed by coalesced shared-memory buffers.

Instead of creating one shared-memory file per tensor (which leads to
thousands of fds and triggers ``received 0 items of ancdata`` when the
daemon subprocess tries to receive them all), this function groups tensors
by dtype, allocates a single large shared-memory tensor per dtype, and
returns views into that buffer.

Args:
state_dict (dict[str, Any]): The source state dict (tensors can be on
any device).

Returns:
dict[str, Any]: A new state dict with the same keys, where every tensor
is a view into a dtype-coalesced shared-memory buffer.
"""
# Collect tensor metadata grouped by dtype
dtype_groups: dict[torch.dtype, list[tuple[str, torch.Size]]] = {}
for key, val in state_dict.items():
if isinstance(val, torch.Tensor) and val.numel() > 0:
dtype_groups.setdefault(val.dtype, []).append((key, val.size()))

# Allocate one coalesced buffer per dtype in shared memory
dtype_buffers: dict[torch.dtype, torch.Tensor] = {}
dtype_offsets: dict[torch.dtype, int] = {}
for dtype, items in dtype_groups.items():
total_numel = sum(size.numel() for _, size in items)
buf = torch.empty(total_numel, dtype=dtype)
buf.share_memory_()
dtype_buffers[dtype] = buf
dtype_offsets[dtype] = 0

# Build the output state dict with views into coalesced buffers
result: dict[str, Any] = {}
for key, val in state_dict.items():
if isinstance(val, torch.Tensor) and val.numel() > 0:
dtype = val.dtype
offset = dtype_offsets[dtype]
numel = val.numel()
view = dtype_buffers[dtype][offset : offset + numel].view(val.size())
dtype_offsets[dtype] = offset + numel
result[key] = view
elif isinstance(val, torch.Tensor):
# Zero-numel tensors: just create a shared empty tensor
t = torch.zeros_like(val, device="cpu")
t.share_memory_()
result[key] = t
else:
result[key] = val

return result


# PyTorch 2.7+ introduced _extensions parameter for FileSystemWriter
_TORCH_DCP_FSWRITER_HAS_EXTENSIONS = version.parse(torch.__version__) >= version.parse("2.7.0")

Expand Down Expand Up @@ -194,16 +248,13 @@ def stage(self, state_dict: dict[str, Any]) -> dict[str, Any]:
self.per_thread_copy_ahead = 0

if not self.cache_staged_state_dict:
staged_state_dict = _create_cpu_state_dict(state_dict, share_memory=True)
staged_state_dict = _create_coalesced_shm_state_dict(state_dict)
return _copy_state_dict(state_dict, staged_state_dict, type_check=self.type_check)

if self.state_dict_cache is None:
if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
logger.info("[DCP async_save] creating shared-memory staged cache")
self.state_dict_cache = _create_cpu_state_dict(
state_dict,
share_memory=True,
)
logger.info("[DCP async_save] creating shared-memory staged cache (coalesced)")
self.state_dict_cache = _create_coalesced_shm_state_dict(state_dict)

return _copy_state_dict(state_dict, self.state_dict_cache, type_check=self.type_check)

Expand Down
21 changes: 20 additions & 1 deletion xtuner/v1/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,8 @@ def __init__(

self._metrics_recorder = self._maybe_init_model_metrics_recorder(internal_metrics_cfg)

self._preflight_async_checkpoint()

@classmethod
def from_config(cls, config: TrainerConfig) -> Self:
"""Create a Trainer instance from a TrainerConfig.
Expand Down Expand Up @@ -865,6 +867,7 @@ def fit(self):
ckpt_saved = self._maybe_save(is_snapshot=False)
if not ckpt_saved:
_ = self._maybe_save(is_snapshot=True)
self._check_async_save_health()

time_before_get_data = time.time()

Expand Down Expand Up @@ -1177,6 +1180,22 @@ def _maybe_check_health(self):
raise RuntimeError("Health check failed, exit training")
log_rank0.info(f"Health check passed at step {self.cur_step}")

def _preflight_async_checkpoint(self) -> None:
"""Warm up async DCP save to surface daemon init errors early."""
if not self._async_checkpoint:
return
log_rank0.info("Preflight: warming up async DCP save infrastructure...")
TrainEngine.warmup_async_save_dcp(work_dir=self.work_dir)
log_rank0.info("Preflight: async DCP save infrastructure verified OK.")

def _check_async_save_health(self) -> None:
"""Non-blocking check: if any pending async save has failed, raise immediately."""
if self._pending_checkpoint is not None and self._pending_checkpoint.done():
exc = self._pending_checkpoint.exception()
if exc is not None:
self._pending_checkpoint = None
raise RuntimeError(f"Async DCP checkpoint failed in background: {exc}") from exc

def _wait_for_pending_checkpoint(self, timeout: int = 3000) -> None:
if self._pending_checkpoint is None:
return
Expand Down Expand Up @@ -1226,7 +1245,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:

# Save model and optimizer
future: Future | None = None
if self._async_checkpoint and not is_snapshot:
if self._async_checkpoint:
future = self._engine.async_save_dcp(weights_dir=weights_path)
else:
self._engine.save_dcp(weights_dir=weights_path)
Expand Down
Loading