Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxtext/trainers/post_train/sft/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
# Provide rules context so 'norm' is translated to mesh axes during maybe_restore
with nn_partitioning.axis_rules(mt_config.logical_axis_rules):
trainer = MaxTextPeftTrainer(model, optimizer, tunix_config)
if mt_config.lora.lora_restore_path:
trainer = lora_utils.restore_lora_from_path(trainer, mt_config)
if mt_config.lora.lora_restore_path and trainer.train_steps == 0:
lora_utils.restore_lora_from_path(trainer.model, mt_config)
trainer.with_training_hooks(training_hooks)
trainer.with_data_hooks(data_hooks)
trainer = use_maxtext_loss_function(trainer, mt_config)
Expand Down
82 changes: 56 additions & 26 deletions src/maxtext/utils/lora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import os
import re
from typing import Any, Optional
from typing import Optional

from flax import nnx, linen as nn
from flax.linen import partitioning as nn_partitioning
Expand Down Expand Up @@ -465,10 +465,9 @@ def _build_lora_provider(mt_config: pyconfig.HyperParameters) -> qwix.LoraProvid
return qwix.LoraProvider(**lora_kwargs)


def _prepare_dummy_inputs() -> tuple[jnp.ndarray, jnp.ndarray]:
def _prepare_dummy_inputs(dummy_bs: int = 1) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Builds dummy decoder inputs used to materialize LoRA parameters."""
# Keep LoRA warmup as small as possible to minimize compile/memory overhead.
dummy_bs = 1
seq_len = 1
decoder_input_tokens = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32)
decoder_positions = jnp.zeros((dummy_bs, seq_len), dtype=jnp.int32)
Expand All @@ -483,30 +482,50 @@ def is_lora_enabled(model: nnx.Module) -> bool:
return False


def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperParameters):
def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperParameters) -> None:
"""Validates that LoRA is active or that target modules were matched."""

if is_lora_enabled(lora_model):
wrapped_modules = set()
for path, value in nnx.iter_graph(lora_model):
if isinstance(value, nnx.LoRAParam):
if len(path) > 1:
parent_path = "/".join(str(p) for p in path[:-1])
wrapped_modules.add(parent_path)

if wrapped_modules:
wrapped_modules = sorted(list(wrapped_modules))
max_logging.log(
f"LoRA configured: module_path='{_get_lora_module_path(mt_config)}' successfully matched "
f"{len(wrapped_modules)} target submodules."
)
preview_limit = 20
preview_modules = wrapped_modules[:preview_limit]
max_logging.log(f"Sample matched submodules ({len(preview_modules)} of {len(wrapped_modules)}): {preview_modules}")
else:
max_logging.log("LoRA is enabled. (Detailed submodules match report skipped due to mock model or empty state)")
return

lora_module_path = _get_lora_module_path(mt_config)
compiled_module_path = re.compile(lora_module_path)
matched_module_paths = []
sample_module_paths = []

matched_module_paths = []
for path, _ in nnx.iter_modules(lora_model):
module_path = "/".join(str(p) for p in path)
if len(sample_module_paths) < 100:
sample_module_paths.append(module_path)
if compiled_module_path.search(module_path):
if module_path and compiled_module_path.search(module_path):
matched_module_paths.append(module_path)

if not matched_module_paths:
max_logging.log(
f"LoRA module_path='{lora_module_path}' did not match any weights. " f"Sample module paths: {sample_module_paths}"
)
max_logging.log(f"Error: LoRA module_path='{lora_module_path}' did not match any weights.")
raise ValueError("LoRA enabled but no LoRA parameters found in decoder/model state.")

# Simplify matched paths by replacing numeric layer indices with "*" to avoid redundant output
simplified_matches = sorted(
{"/".join("*" if p.isdigit() else p for p in path.split("/")) for path in matched_module_paths}
)
max_logging.log(f"LoRA target verification: successfully matched {len(matched_module_paths)} modules.")
max_logging.log(f"Matched submodule patterns: {simplified_matches}")

raise ValueError(
"LoRA module path matched target modules, but nnx.LoRAParam is still "
"missing. For Tunix PeftTrainer, LoRA params must be materialized before "
Expand All @@ -533,8 +552,12 @@ def apply_lora_to_model(

lora_provider = _build_lora_provider(mt_config)

dp_size = 1
if mesh is not None and "data" in mesh.shape:
dp_size = mesh.shape["data"]

model_rngs = getattr(model.decoder, "rngs", None)
decoder_input_tokens, decoder_positions = _prepare_dummy_inputs()
decoder_input_tokens, decoder_positions = _prepare_dummy_inputs(dummy_bs=dp_size)

lora_model = qwix.apply_lora_to_model(
model,
Expand Down Expand Up @@ -582,26 +605,33 @@ def _safe_reshard(var, sharding_spec):
return lora_model


def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
"""Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run."""
lora_restore_path = mt_config.lora.lora_restore_path
def restore_lora_from_path(model: nnx.Module, mt_config: pyconfig.HyperParameters) -> nnx.Module:
"""Restores LoRA parameter weights from an external Orbax checkpoint.

train_steps = getattr(trainer, "train_steps", 0)
if train_steps > 0:
max_logging.log(
f"PeftTrainer restored current run at step {train_steps}; " f"ignoring lora_restore_path '{lora_restore_path}'."
)
return trainer
This function performs the restore in-place on the model's parameters and
returns the model with the restored weights applied.

Args:
model: The JAX/Flax NNX model (nnx.Module).
mt_config: The HyperParameters config containing the lora configuration.

Returns:
The model with the restored LoRA weights applied in-place.

Raises:
ValueError: If LoRA is not enabled on the model, but a restore path is set.
"""
lora_restore_path = mt_config.lora.lora_restore_path

if not is_lora_enabled(trainer.model):
if not is_lora_enabled(model):
lora_module_path = _get_lora_module_path(mt_config)
if not mt_config.lora.enable_lora:
raise ValueError(
"lora_restore_path is set but LoRA is not enabled on the model. "
f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules."
)

abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam)
abstract_lora_params = nnx.state(model, nnx.LoRAParam)

target_for_restore = jax.tree.map(
lambda v: {"value": v.value},
Expand Down Expand Up @@ -657,9 +687,9 @@ def _map_to_state(path, variable):
is_leaf=lambda n: isinstance(n, nnx.Variable),
)

nnx.update(trainer.model, abstract_lora_params)
nnx.update(model, abstract_lora_params)
max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.")
return trainer
return model


# NNX-shaped LoRA helpers.
Expand Down
39 changes: 25 additions & 14 deletions tests/post_training/unit/lora_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,26 +130,41 @@ def test_prepare_dummy_inputs(self):
self.assertEqual(tokens.shape, (1, 1))
self.assertEqual(positions.shape, (1, 1))

def test_verify_lora_parameters_enabled(self):
"""Test verification of LoRA parameters when enabled."""
def test_verify_lora_parameters_success(self):
"""Test verification of LoRA parameters with matches and enabled LoRA."""
mock_model = mock.MagicMock()
mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
mock_config.lora = mock.MagicMock()
mock_config.lora.lora_module_path = ".*mlp/wi_0.*"

mock_param = nnx.LoRAParam(0.0)
mock_graph_entries = [
(("decoder", "layers", 0, "mlp", "wi_0", "lora_a"), mock_param),
]

# Note: we use our local is_lora_enabled now
with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=True):
# Should not raise
with (
mock.patch("maxtext.utils.lora_utils.nnx.iter_graph", return_value=mock_graph_entries),
mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=True),
mock.patch("maxtext.utils.max_logging.log") as mock_log,
):
lora_utils._verify_lora_parameters(mock_model, mock_config)

# Should log the successful match pattern summary
log_calls = [call[0][0] for call in mock_log.call_args_list]
self.assertTrue(any("successfully matched" in msg for msg in log_calls))
self.assertTrue(any("Sample matched submodules" in msg for msg in log_calls))

def test_verify_lora_parameters_not_enabled_no_match(self):
"""Test verification fails when LoRA parameters are expected but not found."""
"""Test verification fails with ValueError when no modules match at all."""
mock_model = mock.MagicMock()
mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
mock_config.lora = mock.MagicMock()
mock_config.model_name = "llama"
mock_config.lora.lora_module_path = "non_existent"

with mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=False):
mock_model.iter_modules.return_value = []
with (
mock.patch("flax.nnx.iter_modules", return_value=[]),
mock.patch("maxtext.utils.lora_utils.is_lora_enabled", return_value=False),
):
with self.assertRaisesRegex(ValueError, "no LoRA parameters found"):
lora_utils._verify_lora_parameters(mock_model, mock_config)

Expand Down Expand Up @@ -252,15 +267,11 @@ def test_restore_lora_from_path(self):
model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
model = lora_utils.apply_lora_to_model(model, None, cfg)

trainer = mock.MagicMock()
trainer.model = model
trainer.train_steps = 0

restored_state = nnx.state(model, nnx.LoRAParam)

with mock.patch("orbax.checkpoint.PyTreeCheckpointer.restore", return_value=restored_state) as mock_restore:
with mock.patch("flax.nnx.update") as mock_update:
lora_utils.restore_lora_from_path(trainer, cfg)
lora_utils.restore_lora_from_path(model, cfg)
mock_restore.assert_called_once()
args, kwargs = mock_restore.call_args
self.assertEqual(args[0], "some/path")
Expand Down
8 changes: 3 additions & 5 deletions tests/utils/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,7 @@ def main(config, test_args): # pylint: disable=W0621
if config.lora.enable_lora:
model = lora_utils.apply_lora_to_model(model, mesh, config)
if config.lora.lora_restore_path:
mock_trainer = type("MockTrainer", (), {"model": model, "train_steps": 0})
lora_utils.restore_lora_from_path(mock_trainer, config)
lora_utils.restore_lora_from_path(model, config)
state = None
else:
model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
Expand Down Expand Up @@ -432,7 +431,7 @@ def main(config, test_args): # pylint: disable=W0621
max_logging.log(f"Loading HF model with dtype: {torch_dtype} (derived from config.dtype: {config.dtype})")

hf_model = AutoModelForCausalLM.from_pretrained(
test_args.hf_model_path, dtype=torch_dtype, token=hf_token, trust_remote_code=test_args.trust_remote_code
test_args.hf_model_path, torch_dtype=torch_dtype, token=hf_token, trust_remote_code=test_args.trust_remote_code
)
hf_lora_path = config.hf_lora_adapter_path
if hf_lora_path:
Expand Down Expand Up @@ -469,8 +468,7 @@ def main(config, test_args): # pylint: disable=W0621
if config.lora.enable_lora:
maxtext_model = lora_utils.apply_lora_to_model(maxtext_model, mesh, config)
if config.lora.lora_restore_path:
mock_trainer = type("MockTrainer", (), {"model": maxtext_model, "train_steps": 0})
lora_utils.restore_lora_from_path(mock_trainer, config)
lora_utils.restore_lora_from_path(maxtext_model, config)
maxtext_state = None
else:
maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
Expand Down
Loading