diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index c52bd8192d..c3abdf60f4 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -97,6 +97,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils.globals import HF_IDS +from maxtext.utils.lora_utils import sync_lora_metadata flags.DEFINE_bool( @@ -451,6 +452,9 @@ def main(argv: Sequence[str]) -> None: if not load_parameters_path and not lora_restore_path: raise ValueError("Either load_parameters_path or lora_restore_path must be specified.") + if lora_restore_path: + sync_lora_metadata(config) + # Load Maxtext checkpoint using Orbax (now smart enough to load both if present) max_logging.log("\nLoading Orbax checkpoint(s)...") start = time.time() diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index c778f92bf9..6f3cda25ea 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -1142,6 +1142,12 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total)) save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save) + custom_metadata = None + if config and hasattr(config, "lora") and config.lora: + lora_rank = getattr(config.lora, "lora_rank", 0) + if lora_rank > 0 and hasattr(config.lora, "model_dump"): + custom_metadata = {"lora": config.lora.model_dump()} + match (checkpoint_manager, config, data_iterator): case (checkpoint_manager, _, _) if isinstance( checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) @@ -1149,4 +1155,6 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= replicator_error_handler(config) return checkpoint_manager.save(step, 
def sync_lora_metadata(config: pyconfig.HyperParameters) -> None:
  """Syncs LoRA rank/alpha in `config` from the restore checkpoint's custom metadata.

  Reads the Orbax metadata stored at `config.lora.lora_restore_path` and looks
  up the `"lora"` entry of its custom metadata. For both `lora_rank` and
  `lora_alpha`:

  * a configured value of 0 / 0.0 is treated as "unset" and is overwritten with
    the value recorded in the checkpoint;
  * a configured non-zero value must equal the checkpoint value, otherwise the
    run is failed with a `ValueError`.

  A key that is absent from the `"lora"` entry falls back to the configured
  value, so it can never trigger a mismatch. When the restore path is unset,
  the metadata cannot be read, or it carries no `"lora"` entry, `config` is
  left untouched.

  Args:
    config: Configuration whose `lora.lora_rank` and `lora.lora_alpha` are
      validated and, where applicable, updated in place.

  Raises:
    ValueError: If a non-zero configured `lora_rank` or `lora_alpha` differs
      from the value recorded in the checkpoint metadata.
  """
  lora_restore_path = config.lora.lora_restore_path
  if not lora_restore_path:
    return

  lora_meta = None
  checkpoint_dir = epath.Path(lora_restore_path)
  try:
    ckptr = ocp.StandardCheckpointer()
    try:
      metadata = ckptr.metadata(checkpoint_dir)
      custom_metadata = metadata.custom_metadata or {}
      lora_meta = custom_metadata.get("lora")
    finally:
      # Always close the checkpointer created for this one-off read. The
      # "lora" entry is extracted before closing so that a failure inside
      # close() cannot discard metadata that was already read successfully.
      ckptr.close()
  except Exception as e:  # pylint: disable=broad-except
    # Deliberately best-effort: a checkpoint without readable metadata is
    # restored using the configured values only.
    max_logging.log(f"Warning: Failed to load LoRA metadata: {e}")

  if not lora_meta:
    return

  # Missing keys default to the configured value, which makes the checks below
  # pass trivially for that key.
  meta_rank = lora_meta.get("lora_rank", config.lora.lora_rank)
  meta_alpha = lora_meta.get("lora_alpha", config.lora.lora_alpha)

  # Check lora_rank: 0 means "not configured", anything else must match.
  if config.lora.lora_rank not in (0, meta_rank):
    raise ValueError(
        f"Configured lora_rank ({config.lora.lora_rank}) does not match "
        f"checkpoint metadata lora_rank ({meta_rank}) at {checkpoint_dir}."
    )
  # Check lora_alpha: 0.0 means "not configured", anything else must match.
  if config.lora.lora_alpha not in (0.0, meta_alpha):
    raise ValueError(
        f"Configured lora_alpha ({config.lora.lora_alpha}) does not match "
        f"checkpoint metadata lora_alpha ({meta_alpha}) at {checkpoint_dir}."
    )

  config.lora.lora_rank = meta_rank
  config.lora.lora_alpha = meta_alpha
  max_logging.log(
      f"Synced LoRA parameters from Orbax metadata at {checkpoint_dir}: "
      f"rank={config.lora.lora_rank}, alpha={config.lora.lora_alpha}"
  )
initialize_pydantic to get nested models as objects (attribute access) return pyconfig.initialize_pydantic( [sys.argv[0], get_test_config_path()], - **_BASE_CONFIG, - **overrides, + **config_dict, ) @@ -121,7 +127,12 @@ def test_build_lora_provider(self): with mock.patch("qwix.LoraProvider") as mock_provider: lora_utils._build_lora_provider(mock_config) mock_provider.assert_called_once_with( - module_path="custom/path", rank=8, alpha=16.0, dropout=0.0, weight_qtype="int8", tile_size=32 + module_path="custom/path", + rank=8, + alpha=16.0, + dropout=0.0, + weight_qtype="int8", + tile_size=32, ) def test_prepare_dummy_inputs(self): @@ -173,7 +184,13 @@ def test_apply_lora_to_model_adapters_loaded(self): # If we skip Qwix, it should stay False. self.assertFalse(lora_utils.is_lora_enabled(result)) - def _run_apply_lora_test(self, scan_layers: bool, weight_qtype=None, tile_size=None, mock_multihost: bool = False): + def _run_apply_lora_test( + self, + scan_layers: bool, + weight_qtype=None, + tile_size=None, + mock_multihost: bool = False, + ): """Helper to run LoRA application test with/without scanned layers and optional QLoRA.""" # Passing nested dict as 'lora' kwarg to _make_config cfg = _make_config( @@ -246,7 +263,12 @@ def test_apply_lora_multihost_mock(self): def test_restore_lora_from_path(self): """Test restoration of LoRA parameters from a path.""" cfg = _make_config( - lora={"enable_lora": True, "lora_restore_path": "some/path", "lora_rank": 4, "lora_alpha": 8.0}, + lora={ + "enable_lora": True, + "lora_restore_path": "some/path", + "lora_rank": 4, + "lora_alpha": 8.0, + }, scan_layers=False, ) model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) @@ -271,6 +293,135 @@ def test_restore_lora_from_path(self): self.assertTrue(kwargs["args"].partial_restore) mock_update.assert_called_once() + def test_sync_lora_metadata_default_syncs(self): + """Test that default lora rank/alpha are successfully synced 
from checkpoint metadata.""" + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": "dummy/path", + "lora_rank": 0, + "lora_alpha": 0.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + lora_utils.sync_lora_metadata(cfg) + self.assertEqual(cfg.lora.lora_rank, 32) + self.assertEqual(cfg.lora.lora_alpha, 64.0) + + def test_sync_lora_metadata_matching_passes(self): + """Test that matching non-default parameters pass without errors.""" + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": "dummy/path", + "lora_rank": 32, + "lora_alpha": 64.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + # Should not raise ValueError + lora_utils.sync_lora_metadata(cfg) + self.assertEqual(cfg.lora.lora_rank, 32) + self.assertEqual(cfg.lora.lora_alpha, 64.0) + + def test_sync_lora_metadata_rank_mismatch_fails(self): + """Test that configured rank mismatching checkpoint metadata rank raises ValueError.""" + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": "dummy/path", + "lora_rank": 8, + "lora_alpha": 64.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + with self.assertRaisesRegex(ValueError, "Configured lora_rank .* does not match"): + lora_utils.sync_lora_metadata(cfg) + + def test_sync_lora_metadata_alpha_mismatch_fails(self): + """Test that configured alpha mismatching checkpoint metadata alpha raises ValueError.""" + cfg = _make_config( + lora={ + "enable_lora": True, + 
"lora_restore_path": "dummy/path", + "lora_rank": 32, + "lora_alpha": 16.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + with self.assertRaisesRegex(ValueError, "Configured lora_alpha .* does not match"): + lora_utils.sync_lora_metadata(cfg) + + def test_save_checkpoint_passes_metadata(self): + """Test that save_checkpoint correctly generates and passes custom lora metadata to CheckpointManager.""" + cfg = _make_config( + lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}, + enable_checkpointing=True, + ) + mock_manager = mock.MagicMock() + mock_state = mock.MagicMock() + + with mock.patch("jax.block_until_ready"): + checkpointing.save_checkpoint(mock_manager, step=10, state=mock_state, config=cfg) + mock_manager.save.assert_called_once() + _, kwargs = mock_manager.save.call_args + self.assertIn("custom_metadata", kwargs) + self.assertEqual(kwargs["custom_metadata"], {"lora": cfg.lora.model_dump()}) + + def test_save_and_restore_metadata_integration(self): + """Integration test checking that Orbax CheckpointManager writes and reads custom LoRA metadata.""" + + cfg_save = _make_config( + lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}, + enable_checkpointing=True, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + manager = checkpointing.create_orbax_checkpoint_manager( + tmpdir, + enable_checkpointing=True, + use_async=False, + save_interval_steps=1, + use_ocdbt=False, + use_zarr3=False, + ) + + # Use save_checkpoint wrapper with a simple state + dummy_state = {"weight": jnp.array([1.0, 2.0])} + checkpointing.save_checkpoint(manager, step=0, state=dummy_state, config=cfg_save) + manager.wait_until_finished() + + # Now verify that the saved checkpoint contains metadata on disk + checkpoint_dir = epath.Path(tmpdir) / "0" + self.assertTrue((checkpoint_dir / 
"_CHECKPOINT_METADATA").exists()) + + # Restore using sync_lora_metadata on a config with default rank/alpha + cfg_restore = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": str(checkpoint_dir), + "lora_rank": 0, + "lora_alpha": 0.0, + } + ) + lora_utils.sync_lora_metadata(cfg_restore) + + # Verify values were successfully synced back + self.assertEqual(cfg_restore.lora.lora_rank, 8) + self.assertEqual(cfg_restore.lora.lora_alpha, 16.0) + def test_gemma4_lora_path_matching(self): """Test that the Gemma4 LoRA regex correctly matches all expected parameter paths.""" mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) @@ -309,10 +460,6 @@ def test_gemma4_lora_path_matching(self): "decoder/layers_remainder/layers/0/mlp/shared_experts/wi_0/kernel", "decoder/layers_remainder/layers/0/mlp/shared_experts/wi_1/kernel", "decoder/layers_remainder/layers/0/mlp/shared_experts/wo/kernel", - # No scanned_blocks/layers_remainder prefix (e.g. fallback or direct structure) - "decoder/layers/0/self_attention/query/kernel", - "decoder/layers/0/mlp/wi_0/kernel", - "decoder/layers/layers/0/mlp/shared_experts/wi_0/kernel", ] for path in matching_paths: diff --git a/tests/unit/hf_checkpoint_conversion_test.py b/tests/unit/hf_checkpoint_conversion_test.py index 6451a50fd5..a920e0a762 100644 --- a/tests/unit/hf_checkpoint_conversion_test.py +++ b/tests/unit/hf_checkpoint_conversion_test.py @@ -367,6 +367,7 @@ def test_get_maxtext_model_info(self): "hidden_size_per_layer_input=128", "vocab_size_per_layer_input=256", "vocab_size=256", + "skip_jax_distributed_system=True", ], override_model_config=True, ) @@ -417,7 +418,7 @@ def test_recursive_update(self): @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer") @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path") @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices") - def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, 
mock_path, mock_checkpointer_cls): + def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, _mock_path, mock_checkpointer_cls): # Mock jax devices mock_jax_devices.return_value = [MagicMock()]