diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 4e185e93a1..9002caa16c 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -262,8 +262,8 @@ def setup_trainer_state(mt_config, goodput_recorder=None): # Provide rules context so 'norm' is translated to mesh axes during maybe_restore with nn_partitioning.axis_rules(mt_config.logical_axis_rules): trainer = MaxTextPeftTrainer(model, optimizer, tunix_config) - if mt_config.lora.lora_restore_path: - trainer = lora_utils.restore_lora_from_path(trainer, mt_config) + if mt_config.lora.lora_restore_path and trainer.train_steps == 0: + lora_utils.restore_lora_from_path(trainer.model, mt_config) trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 6b4410f209..57fe38dc2a 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -19,7 +19,7 @@ import json import os import re -from typing import Any, Optional +from typing import Optional from flax import nnx, linen as nn from flax.linen import partitioning as nn_partitioning @@ -465,10 +465,9 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid return qwix.LoraProvider(**lora_kwargs) -def _prepare_dummy_inputs() -> tuple[jnp.ndarray, jnp.ndarray]: +def _prepare_dummy_inputs(dummy_bs: int = 1) -> tuple[jnp.ndarray, jnp.ndarray]: """Builds dummy decoder inputs used to materialize LoRA parameters.""" # Keep LoRA warmup as small as possible to minimize compile/memory overhead. - dummy_bs = 1 seq_len = 1 decoder_input_tokens = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32) decoder_positions = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32) @@ -483,30 +482,50 @@ def is_lora_enabled(model: nnx.Module) -> bool: return False -def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperParameters): +def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperParameters) -> None: """Validates that LoRA is active or that target modules were matched.""" if is_lora_enabled(lora_model): + wrapped_modules = set() + for path, value in nnx.iter_graph(lora_model): + if isinstance(value, nnx.LoRAParam): + if len(path) > 1: + parent_path = "/".join(str(p) for p in path[:-1]) + wrapped_modules.add(parent_path) + + if wrapped_modules: + wrapped_modules = sorted(list(wrapped_modules)) + max_logging.log( + f"LoRA configured: module_path='{_get_lora_module_path(mt_config)}' successfully matched " + f"{len(wrapped_modules)} target submodules." + ) + preview_limit = 20 + preview_modules = wrapped_modules[:preview_limit] + max_logging.log(f"Sample matched submodules ({len(preview_modules)} of {len(wrapped_modules)}): {preview_modules}") + else: + max_logging.log("LoRA is enabled. (Detailed submodules match report skipped due to mock model or empty state)") return lora_module_path = _get_lora_module_path(mt_config) compiled_module_path = re.compile(lora_module_path) - matched_module_paths = [] - sample_module_paths = [] + matched_module_paths = [] for path, _ in nnx.iter_modules(lora_model): module_path = "/".join(str(p) for p in path) - if len(sample_module_paths) < 100: - sample_module_paths.append(module_path) - if compiled_module_path.search(module_path): + if module_path and compiled_module_path.search(module_path): matched_module_paths.append(module_path) if not matched_module_paths: - max_logging.log( - f"LoRA module_path='{lora_module_path}' did not match any weights. " f"Sample module paths: {sample_module_paths}" - ) + max_logging.log(f"Error: LoRA module_path='{lora_module_path}' did not match any weights.") raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.") + # Simplify matched paths by replacing numeric layer indices with "*" to avoid redundant output + simplified_matches = sorted( + {"/".join("*" if p.isdigit() else p for p in path.split("/")) for path in matched_module_paths} + ) + max_logging.log(f"LoRA target verification: successfully matched {len(matched_module_paths)} modules.") + max_logging.log(f"Matched submodule patterns: {simplified_matches}") + raise ValueError( "LoRA module path matched target modules, but nnx.LoRAParam is still " "missing. For Tunix PeftTrainer, LoRA params must be materialized before " @@ -533,8 +552,12 @@ def apply_lora_to_model( lora_provider = _build_lora_provider(mt_config) + dp_size = 1 + if mesh is not None and "data" in mesh.shape: + dp_size = mesh.shape["data"] + model_rngs = getattr(model.decoder, "rngs", None) - decoder_input_tokens, decoder_positions = _prepare_dummy_inputs() + decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(dummy_bs=dp_size) lora_model = qwix.apply_lora_to_model( model, @@ -582,18 +605,25 @@ def _safe_reshard(var, sharding_spec): return lora_model -def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any: - """Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run.""" - lora_restore_path = mt_config.lora.lora_restore_path +def restore_lora_from_path(model: nnx.Module, mt_config: pyconfig.HyperParameters) -> nnx.Module: + """Restores LoRA parameter weights from an external Orbax checkpoint. - train_steps = getattr(trainer, "train_steps", 0) - if train_steps > 0: - max_logging.log( - f"PeftTrainer restored current run at step {train_steps}; " f"ignoring lora_restore_path '{lora_restore_path}'." - ) - return trainer + This function performs the restore in-place on the model's parameters and + returns the model with the restored weights applied. + + Args: + model: The JAX/Flax NNX model (nnx.Module). + mt_config: The HyperParameters config containing the lora configuration. + + Returns: + The model with the restored LoRA weights applied in-place. + + Raises: + ValueError: If LoRA is not enabled on the model, but a restore path is set. + """ + lora_restore_path = mt_config.lora.lora_restore_path - if not is_lora_enabled(trainer.model): + if not is_lora_enabled(model): lora_module_path = _get_lora_module_path(mt_config) if not mt_config.lora.enable_lora: raise ValueError( @@ -601,7 +631,7 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules." ) - abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam) + abstract_lora_params = nnx.state(model, nnx.LoRAParam) target_for_restore = jax.tree.map( lambda v: {"value": v.value}, @@ -657,9 +687,9 @@ def _map_to_state(path, variable): is_leaf=lambda n: isinstance(n, nnx.Variable), ) - nnx.update(trainer.model, abstract_lora_params) + nnx.update(model, abstract_lora_params) max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.") - return trainer + return model # NNX-shaped LoRA helpers. diff --git a/tests/post_training/unit/lora_utils_test.py b/tests/post_training/unit/lora_utils_test.py index b0f229875d..2b61019bc8 100644 --- a/tests/post_training/unit/lora_utils_test.py +++ b/tests/post_training/unit/lora_utils_test.py @@ -130,26 +130,41 @@ def test_prepare_dummy_inputs(self): self.assertEqual(tokens.shape, (1, 1)) self.assertEqual(positions.shape, (1, 1)) - def test_verify_lora_parameters_enabled(self): - """Test verification of LoRA parameters when enabled.""" + def test_verify_lora_parameters_success(self): + """Test verification of LoRA parameters with matches and enabled LoRA.""" mock_model = mock.MagicMock() mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) + mock_config.lora = mock.MagicMock() + mock_config.lora.lora_module_path = ".*mlp/wi_0.*" + + mock_param = nnx.LoRAParam(0.0) + mock_graph_entries = [ + (("decoder", "layers", 0, "mlp", "wi_0", "lora_a"), mock_param), + ] - # Note: we use our local is_lora_enabled now - with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=True): - # Should not raise + with ( + mock.patch("maxtext.utils.lora_utils.nnx.iter_graph", return_value=mock_graph_entries), + mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=True), + mock.patch("maxtext.utils.max_logging.log") as mock_log, + ): lora_utils._verify_lora_parameters(mock_model, mock_config) + # Should log the successful match pattern summary + log_calls = [call[0][0] for call in mock_log.call_args_list] + self.assertTrue(any("successfully matched" in msg for msg in log_calls)) + self.assertTrue(any("Sample matched submodules" in msg for msg in log_calls)) + def test_verify_lora_parameters_not_enabled_no_match(self): - """Test verification fails when LoRA parameters are expected but not found.""" + """Test verification fails with ValueError when no modules match at all.""" mock_model = mock.MagicMock() mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) mock_config.lora = mock.MagicMock() - mock_config.model_name = "llama" mock_config.lora.lora_module_path = "non_existent" - with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=False): - mock_model.iter_modules.return_value = [] + with ( + mock.patch("flax.nnx.iter_modules", return_value=[]), + mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=False), + ): with self.assertRaisesRegex(ValueError, "no LoRA parameters found"): lora_utils._verify_lora_parameters(mock_model, mock_config) @@ -252,15 +267,11 @@ def test_restore_lora_from_path(self): model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) model = lora_utils.apply_lora_to_model(model, None, cfg) - trainer = mock.MagicMock() - trainer.model = model - trainer.train_steps = 0 - restored_state = nnx.state(model, nnx.LoRAParam) with mock.patch("orbax.checkpoint.PyTreeCheckpointer.restore", return_value=restored_state) as mock_restore: with mock.patch("flax.nnx.update") as mock_update: - lora_utils.restore_lora_from_path(trainer, cfg) + lora_utils.restore_lora_from_path(model, cfg) mock_restore.assert_called_once() args, kwargs = mock_restore.call_args self.assertEqual(args[0], "some/path") diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index 92adb4e921..bdbee338ee 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -261,8 +261,7 @@ def main(config, test_args): # pylint: disable=W0621 if config.lora.enable_lora: model = lora_utils.apply_lora_to_model(model, mesh, config) if config.lora.lora_restore_path: - mock_trainer = type("MockTrainer", (), {"model": model, "train_steps": 0}) - lora_utils.restore_lora_from_path(mock_trainer, config) + lora_utils.restore_lora_from_path(model, config) state = None else: model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) @@ -432,7 +431,7 @@ def main(config, test_args): # pylint: disable=W0621 max_logging.log(f"Loading HF model with dtype: {torch_dtype} (derived from config.dtype: {config.dtype})") hf_model = AutoModelForCausalLM.from_pretrained( - test_args.hf_model_path, dtype=torch_dtype, token=hf_token, trust_remote_code=test_args.trust_remote_code + test_args.hf_model_path, torch_dtype=torch_dtype, token=hf_token, trust_remote_code=test_args.trust_remote_code ) hf_lora_path = config.hf_lora_adapter_path if hf_lora_path: @@ -469,8 +468,7 @@ def main(config, test_args): # pylint: disable=W0621 if config.lora.enable_lora: maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config) if config.lora.lora_restore_path: - mock_trainer = type("MockTrainer", (), {"model": maxtext_model, "train_steps": 0}) - lora_utils.restore_lora_from_path(mock_trainer, config) + lora_utils.restore_lora_from_path(maxtext_model, config) maxtext_state = None else: maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)