Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.globals import HF_IDS
from maxtext.utils.lora_utils import sync_lora_metadata


flags.DEFINE_bool(
Expand Down Expand Up @@ -451,6 +452,9 @@ def main(argv: Sequence[str]) -> None:
if not load_parameters_path and not lora_restore_path:
raise ValueError("Either load_parameters_path or lora_restore_path must be specified.")

if lora_restore_path:
sync_lora_metadata(config)

# Load Maxtext checkpoint using Orbax (now smart enough to load both if present)
max_logging.log("\nLoading Orbax checkpoint(s)...")
start = time.time()
Expand Down
10 changes: 9 additions & 1 deletion src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,11 +1142,19 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

custom_metadata = None
if config and hasattr(config, "lora") and config.lora:
lora_rank = getattr(config.lora, "lora_rank", 0)
if lora_rank > 0 and hasattr(config.lora, "model_dump"):
custom_metadata = {"lora": config.lora.model_dump()}

match (checkpoint_manager, config, data_iterator):
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
replicator_error_handler(config)
return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
case _:
return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force)
return checkpoint_manager.save(
step, args=Composite(**save_args_composite), force=force, custom_metadata=custom_metadata
)
51 changes: 51 additions & 0 deletions src/maxtext/utils/lora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import re
from typing import Any, Optional

from etils import epath
from flax import nnx, linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
Expand Down Expand Up @@ -515,6 +516,52 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
)


def sync_lora_metadata(config: pyconfig.HyperParameters) -> None:
"""Syncs LoRA parameters (rank, alpha) from the checkpoint sidecar metadata if present.

If configuration values are set to non-default values (i.e. rank > 0 or alpha > 0.0)
and differ from the checkpoint metadata values, we raise a ValueError to fail the run.
If they are at default values, we sync them from the checkpoint.
"""
lora_restore_path = config.lora.lora_restore_path
if not lora_restore_path:
return

lora_meta = None
checkpoint_dir = epath.Path(lora_restore_path)
try:
ckptr = ocp.StandardCheckpointer()
metadata = ckptr.metadata(checkpoint_dir)
custom_metadata = metadata.custom_metadata or {}
lora_meta = custom_metadata.get("lora")
except Exception as e: # pylint: disable=broad-except
max_logging.log(f"Warning: Failed to load LoRA metadata: {e}")

if lora_meta:
meta_rank = lora_meta.get("lora_rank", config.lora.lora_rank)
meta_alpha = lora_meta.get("lora_alpha", config.lora.lora_alpha)

# Check lora_rank
if config.lora.lora_rank not in (0, meta_rank):
raise ValueError(
f"Configured lora_rank ({config.lora.lora_rank}) does not match "
f"checkpoint metadata lora_rank ({meta_rank}) at {checkpoint_dir}."
)
# Check lora_alpha
if config.lora.lora_alpha not in (0.0, meta_alpha):
raise ValueError(
f"Configured lora_alpha ({config.lora.lora_alpha}) does not match "
f"checkpoint metadata lora_alpha ({meta_alpha}) at {checkpoint_dir}."
)

config.lora.lora_rank = meta_rank
config.lora.lora_alpha = meta_alpha
max_logging.log(
f"Synced LoRA parameters from Orbax metadata at {checkpoint_dir}: "
f"rank={config.lora.lora_rank}, alpha={config.lora.lora_alpha}"
)


def apply_lora_to_model(
model: nnx.Module,
mesh: Optional[jax.sharding.Mesh],
Expand Down Expand Up @@ -585,6 +632,8 @@ def _safe_reshard(var, sharding_spec):
def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
"""Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run."""
lora_restore_path = mt_config.lora.lora_restore_path
if not lora_restore_path:
return trainer

train_steps = getattr(trainer, "train_steps", 0)
if train_steps > 0:
Expand All @@ -601,6 +650,8 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) ->
f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules."
)

sync_lora_metadata(mt_config)

abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam)

target_for_restore = jax.tree.map(
Expand Down
165 changes: 156 additions & 9 deletions tests/post_training/unit/lora_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
"""Tests for Qwix LoRA utils in lora_utils.py"""
import re
import sys
import tempfile
import unittest
from unittest import mock

from etils import epath
import jax
import jax.numpy as jnp
import optax
import pytest
from flax import nnx
Expand All @@ -26,6 +30,7 @@
pytestmark = [pytest.mark.post_training]

# Now safe to do top-level imports
from maxtext.common import checkpointing
from tunix.sft import peft_trainer
from maxtext.utils import lora_utils
from maxtext.utils import model_creation_utils
Expand Down Expand Up @@ -59,11 +64,12 @@

def _make_config(**overrides):
"""Return a MaxTextConfig object suitable for unit tests."""
config_dict = _BASE_CONFIG.copy()
config_dict.update(overrides)
# Use initialize_pydantic to get nested models as objects (attribute access)
return pyconfig.initialize_pydantic(
[sys.argv[0], get_test_config_path()],
**_BASE_CONFIG,
**overrides,
**config_dict,
)


Expand Down Expand Up @@ -121,7 +127,12 @@ def test_build_lora_provider(self):
with mock.patch("qwix.LoraProvider") as mock_provider:
lora_utils._build_lora_provider(mock_config)
mock_provider.assert_called_once_with(
module_path="custom/path", rank=8, alpha=16.0, dropout=0.0, weight_qtype="int8", tile_size=32
module_path="custom/path",
rank=8,
alpha=16.0,
dropout=0.0,
weight_qtype="int8",
tile_size=32,
)

def test_prepare_dummy_inputs(self):
Expand Down Expand Up @@ -173,7 +184,13 @@ def test_apply_lora_to_model_adapters_loaded(self):
# If we skip Qwix, it should stay False.
self.assertFalse(lora_utils.is_lora_enabled(result))

def _run_apply_lora_test(self, scan_layers: bool, weight_qtype=None, tile_size=None, mock_multihost: bool = False):
def _run_apply_lora_test(
self,
scan_layers: bool,
weight_qtype=None,
tile_size=None,
mock_multihost: bool = False,
):
"""Helper to run LoRA application test with/without scanned layers and optional QLoRA."""
# Passing nested dict as 'lora' kwarg to _make_config
cfg = _make_config(
Expand Down Expand Up @@ -246,7 +263,12 @@ def test_apply_lora_multihost_mock(self):
def test_restore_lora_from_path(self):
"""Test restoration of LoRA parameters from a path."""
cfg = _make_config(
lora={"enable_lora": True, "lora_restore_path": "some/path", "lora_rank": 4, "lora_alpha": 8.0},
lora={
"enable_lora": True,
"lora_restore_path": "some/path",
"lora_rank": 4,
"lora_alpha": 8.0,
},
scan_layers=False,
)
model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
Expand All @@ -271,6 +293,135 @@ def test_restore_lora_from_path(self):
self.assertTrue(kwargs["args"].partial_restore)
mock_update.assert_called_once()

def test_sync_lora_metadata_default_syncs(self):
"""Test that default lora rank/alpha are successfully synced from checkpoint metadata."""
cfg = _make_config(
lora={
"enable_lora": True,
"lora_restore_path": "dummy/path",
"lora_rank": 0,
"lora_alpha": 0.0,
}
)
mock_metadata = mock.MagicMock()
mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}

with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata):
lora_utils.sync_lora_metadata(cfg)
self.assertEqual(cfg.lora.lora_rank, 32)
self.assertEqual(cfg.lora.lora_alpha, 64.0)

def test_sync_lora_metadata_matching_passes(self):
"""Test that matching non-default parameters pass without errors."""
cfg = _make_config(
lora={
"enable_lora": True,
"lora_restore_path": "dummy/path",
"lora_rank": 32,
"lora_alpha": 64.0,
}
)
mock_metadata = mock.MagicMock()
mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}

with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata):
# Should not raise ValueError
lora_utils.sync_lora_metadata(cfg)
self.assertEqual(cfg.lora.lora_rank, 32)
self.assertEqual(cfg.lora.lora_alpha, 64.0)

def test_sync_lora_metadata_rank_mismatch_fails(self):
"""Test that configured rank mismatching checkpoint metadata rank raises ValueError."""
cfg = _make_config(
lora={
"enable_lora": True,
"lora_restore_path": "dummy/path",
"lora_rank": 8,
"lora_alpha": 64.0,
}
)
mock_metadata = mock.MagicMock()
mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}

with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata):
with self.assertRaisesRegex(ValueError, "Configured lora_rank .* does not match"):
lora_utils.sync_lora_metadata(cfg)

def test_sync_lora_metadata_alpha_mismatch_fails(self):
"""Test that configured alpha mismatching checkpoint metadata alpha raises ValueError."""
cfg = _make_config(
lora={
"enable_lora": True,
"lora_restore_path": "dummy/path",
"lora_rank": 32,
"lora_alpha": 16.0,
}
)
mock_metadata = mock.MagicMock()
mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}

with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata):
with self.assertRaisesRegex(ValueError, "Configured lora_alpha .* does not match"):
lora_utils.sync_lora_metadata(cfg)

def test_save_checkpoint_passes_metadata(self):
"""Test that save_checkpoint correctly generates and passes custom lora metadata to CheckpointManager."""
cfg = _make_config(
lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0},
enable_checkpointing=True,
)
mock_manager = mock.MagicMock()
mock_state = mock.MagicMock()

with mock.patch("jax.block_until_ready"):
checkpointing.save_checkpoint(mock_manager, step=10, state=mock_state, config=cfg)
mock_manager.save.assert_called_once()
_, kwargs = mock_manager.save.call_args
self.assertIn("custom_metadata", kwargs)
self.assertEqual(kwargs["custom_metadata"], {"lora": cfg.lora.model_dump()})

def test_save_and_restore_metadata_integration(self):
"""Integration test checking that Orbax CheckpointManager writes and reads custom LoRA metadata."""

cfg_save = _make_config(
lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0},
enable_checkpointing=True,
)

with tempfile.TemporaryDirectory() as tmpdir:
manager = checkpointing.create_orbax_checkpoint_manager(
tmpdir,
enable_checkpointing=True,
use_async=False,
save_interval_steps=1,
use_ocdbt=False,
use_zarr3=False,
)

# Use save_checkpoint wrapper with a simple state
dummy_state = {"weight": jnp.array([1.0, 2.0])}
checkpointing.save_checkpoint(manager, step=0, state=dummy_state, config=cfg_save)
manager.wait_until_finished()

# Now verify that the saved checkpoint contains metadata on disk
checkpoint_dir = epath.Path(tmpdir) / "0"
self.assertTrue((checkpoint_dir / "_CHECKPOINT_METADATA").exists())

# Restore using sync_lora_metadata on a config with default rank/alpha
cfg_restore = _make_config(
lora={
"enable_lora": True,
"lora_restore_path": str(checkpoint_dir),
"lora_rank": 0,
"lora_alpha": 0.0,
}
)
lora_utils.sync_lora_metadata(cfg_restore)

# Verify values were successfully synced back
self.assertEqual(cfg_restore.lora.lora_rank, 8)
self.assertEqual(cfg_restore.lora.lora_alpha, 16.0)

def test_gemma4_lora_path_matching(self):
"""Test that the Gemma4 LoRA regex correctly matches all expected parameter paths."""
mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
Expand Down Expand Up @@ -309,10 +460,6 @@ def test_gemma4_lora_path_matching(self):
"decoder/layers_remainder/layers/0/mlp/shared_experts/wi_0/kernel",
"decoder/layers_remainder/layers/0/mlp/shared_experts/wi_1/kernel",
"decoder/layers_remainder/layers/0/mlp/shared_experts/wo/kernel",
# No scanned_blocks/layers_remainder prefix (e.g. fallback or direct structure)
"decoder/layers/0/self_attention/query/kernel",
"decoder/layers/0/mlp/wi_0/kernel",
"decoder/layers/layers/0/mlp/shared_experts/wi_0/kernel",
]

for path in matching_paths:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/hf_checkpoint_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def test_get_maxtext_model_info(self):
"hidden_size_per_layer_input=128",
"vocab_size_per_layer_input=256",
"vocab_size=256",
"skip_jax_distributed_system=True",
],
override_model_config=True,
)
Expand Down Expand Up @@ -417,7 +418,7 @@ def test_recursive_update(self):
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer")
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path")
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices")
def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls):
def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, _mock_path, mock_checkpointer_cls):

# Mock jax devices
mock_jax_devices.return_value = [MagicMock()]
Expand Down
Loading