feat(lora): save/restore LoRA config in checkpoint metadata #4269

Open
RexBearIU wants to merge 1 commit into main from jackyf/lora-ckpt-metadata

@RexBearIU RexBearIU commented Jun 25, 2026

Collaborator

Description

This PR implements native serialization of LoRA configuration parameters (lora_rank, lora_alpha) in standard Orbax _CHECKPOINT_METADATA files, and automatically restores them during checkpoint-to-Hugging Face conversion.

Why is this change being made?

Previously, users had to manually supply matching lora.lora_rank and lora.lora_alpha parameters when converting MaxText checkpoints to Hugging Face format. Storing them in Orbax metadata removes that manual step and the mismatches it could cause (resolves @igorts-git's request in #3970).

Key Implementation Details

  • Serialization: In save_checkpoint (checkpointing.py), we save the active config.lora block under the "lora" key in Orbax's custom_metadata when a LoRA rank is specified.
  • Restoration: In main (to_huggingface.py), sync_lora_metadata reads the custom metadata from lora_restore_path via ocp.StandardCheckpointer and overrides active config parameters during conversion.
  • Fail-Fast Safety: Scoped strictly to the conversion path to ensure SFT training paths remain strict and fail fast on any configuration mismatches.
  • Test Import Refactoring: Refactored hf_checkpoint_conversion_test.py to replace dynamically loaded inline imports with top-level imports, and removed the json import, since the JSON string is now written directly.
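
The serialization and restoration bullets above can be sketched roughly as follows. This is a minimal illustration and not the PR's code: `LoraSettings`, `build_custom_metadata`, and `apply_lora_metadata` are hypothetical names, and a plain dict stands in for the custom metadata that Orbax would persist.

```python
from dataclasses import asdict, dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class LoraSettings:
  """Hypothetical stand-in for the config.lora block."""

  lora_rank: int = 0
  lora_alpha: float = 0.0


def build_custom_metadata(lora: LoraSettings) -> Optional[dict[str, Any]]:
  """Returns {"lora": {...}} when a LoRA rank is set, otherwise None."""
  if lora.lora_rank > 0:
    return {"lora": asdict(lora)}
  return None


def apply_lora_metadata(
    lora: LoraSettings, custom_metadata: Optional[dict[str, Any]]
) -> LoraSettings:
  """Overrides rank/alpha with the values stored under the "lora" key, if any."""
  saved = (custom_metadata or {}).get("lora")
  if not saved:
    return lora  # Nothing was saved; keep the active settings.
  return replace(
      lora,
      lora_rank=saved.get("lora_rank", lora.lora_rank),
      lora_alpha=saved.get("lora_alpha", lora.lora_alpha),
  )


# Save side: a rank-16 run produces metadata; restore side: defaults get overridden.
saved_meta = build_custom_metadata(LoraSettings(lora_rank=16, lora_alpha=32.0))
restored = apply_lora_metadata(LoraSettings(), saved_meta)
```

In this sketch a run without LoRA (rank 0) produces no metadata at all, so the override step leaves the active settings untouched.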

BUGS: #3970

Tests

We verified the implementation by running both the full test suite and the individual unit tests:

  1. Added/Updated Unit Tests:
    • SyncLoRAMetadataTest in tests/unit/hf_checkpoint_conversion_test.py to verify the auto-resolving mechanism during Hugging Face conversion.
  2. Command to run:
    python tests/unit/hf_checkpoint_conversion_test.py
    All tests pass successfully.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 25, 2026


Codecov Report

❌ Patch coverage is 75.00000% with 9 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/lora_utils.py | 77.77% | 4 Missing and 2 partials ⚠️ |
| ...rc/maxtext/checkpoint_conversion/to_huggingface.py | 0.00% | 3 Missing ⚠️ |

@shralex shralex left a comment

Collaborator

Thanks Jackie! A significant thing missing in this PR is using the metadata file on the checkpoint restore path.

@RexBearIU RexBearIU changed the title feat(lora): serialize and load lora_config.json sidecar metadata feat(lora): save and auto-restore LoRA rank/alpha using native Orbax custom_metadata Jun 25, 2026
@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from 187905b to cd17578 Compare June 25, 2026 15:13
RexBearIU commented Jun 25, 2026

Collaborator Author

Hi @shralex, thank you for the feedback!

I have fully addressed your comments with the following changes:

  1. Checkpoint Restore Auto-Sync: Implemented automatic LoRA rank and alpha syncing from the Orbax native _CHECKPOINT_METADATA file's custom_metadata on the training/SFT restore path (restore_lora_from_path in lora_utils.py). Now, training/SFT runs resuming or restoring from a LoRA checkpoint will automatically detect, sync, and apply the correct LoRA rank and alpha parameters from the saved checkpoint metadata.
  2. Unified Native Orbax Metadata: Switched from creating and loading a custom lora_config.json to using Orbax's native custom_metadata dictionary inside _CHECKPOINT_METADATA. This follows standard checkpointing conventions and avoids introducing a custom, out-of-band config file.
  3. Path Resilience: Enhanced metadata resolution to support paths pointing to either the step directory directly (e.g., .../checkpoints/1000/) or to any nested parameter subfolders (e.g., .../checkpoints/1000/items/), resolving parent paths gracefully.
  4. Expanded Unit Tests & Linting: Added and modified tests (SyncLoRAMetadataTest and SyncLoRAMetadataTrainingTest in both test suites) covering both the conversion and the training/SFT-side auto-restore flows. Verified that the code passes all pre-commit formatting and lint checks and that the tests pass.

Please let me know if you would like any other enhancements!
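
The path resolution described in item 3 could look roughly like the sketch below. It is only an illustration under stated assumptions: `find_metadata_file` and `read_lora_metadata` are hypothetical helpers, the sketch reads the file directly with `json` instead of going through Orbax, and it assumes the metadata file is JSON with a top-level "custom_metadata" key.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Optional

METADATA_FILENAME = "_CHECKPOINT_METADATA"


def find_metadata_file(path: str, max_levels: int = 2) -> Optional[Path]:
  """Looks for the metadata file in `path`, then in its parent directories."""
  current = Path(path)
  for _ in range(max_levels):
    candidate = current / METADATA_FILENAME
    if candidate.is_file():
      return candidate
    if current.parent == current:  # Reached the filesystem root.
      break
    current = current.parent
  return None


def read_lora_metadata(path: str) -> dict[str, Any]:
  """Returns the saved "lora" block, or an empty dict if none is found."""
  metadata_file = find_metadata_file(path)
  if metadata_file is None:
    return {}
  content = json.loads(metadata_file.read_text())
  # Assumed layout: {"custom_metadata": {"lora": {...}}}.
  return (content.get("custom_metadata") or {}).get("lora") or {}


# Demo: the same metadata is found from the step directory and from a subfolder.
with tempfile.TemporaryDirectory() as root:
  step_dir = Path(root) / "checkpoints" / "1000"
  items_dir = step_dir / "items"
  items_dir.mkdir(parents=True)
  payload = {"custom_metadata": {"lora": {"lora_rank": 8, "lora_alpha": 16.0}}}
  (step_dir / METADATA_FILENAME).write_text(json.dumps(payload))
  from_step = read_lora_metadata(str(step_dir))
  from_items = read_lora_metadata(str(items_dir))
  from_elsewhere = read_lora_metadata(root)
```

Limiting the upward search to a small number of levels keeps the lookup from picking up an unrelated metadata file higher in the directory tree.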

Comment thread src/maxtext/checkpoint_conversion/to_huggingface.py Outdated
@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from cd17578 to 1b15640 Compare June 25, 2026 16:02
@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from 1b15640 to ae44adc Compare June 25, 2026 16:11
Comment thread tests/unit/hf_checkpoint_conversion_test.py Outdated
@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch 3 times, most recently from 69c78a7 to a701719 Compare June 26, 2026 02:50
@RexBearIU RexBearIU changed the title feat(lora): save and auto-restore LoRA rank/alpha using native Orbax custom_metadata feat(lora): save/restore LoRA config in checkpoint metadata Jun 26, 2026
@xibinliu xibinliu force-pushed the jackyf/lora-ckpt-metadata branch from a701719 to 07c5e19 Compare June 26, 2026 16:42
@xibinliu

Collaborator

> Thanks Jackie! A significant thing missing in this PR is using the metadata file on the checkpoint restore path.

Added the logic to reuse the metadata on checkpoint restore.

@xibinliu xibinliu force-pushed the jackyf/lora-ckpt-metadata branch 3 times, most recently from 5940e65 to 9bc253e Compare June 26, 2026 23:29
Comment thread src/maxtext/utils/lora_utils.py Outdated
)
return trainer

sync_lora_metadata(mt_config)

Collaborator

Let's move this down to after we have verified that LoRA is enabled.

Collaborator Author

Done! I have moved sync_lora_metadata(config) down in to_huggingface.py so that it is called after we verify that LoRA is indeed enabled in the model configuration.


def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
  """Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run."""
  lora_restore_path = mt_config.lora.lora_restore_path

Collaborator

Can we add a check here:

if not lora_restore_path:
  return trainer  # No restore requested; exit cleanly without error

(Otherwise we're relying on callers to call this function only when this path is set.)

Collaborator Author

Done! Added the guard check at the beginning of restore_lora_from_path so it returns the trainer early and exits cleanly if lora_restore_path is not set.

@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from 9bc253e to 0f6248b Compare June 29, 2026 09:46

@shralex shralex left a comment

Collaborator

This version reverts Xibin's previous version where sync_lora_metadata was in lora_utils. We should move it back there and use it not just on checkpoint conversion but also before model creation.

Comment thread src/maxtext/common/checkpointing.py Outdated
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

custom_metadata = None
if config and config.lora.lora_rank > 0:

Collaborator

Let's check that config contains "lora" before accessing config.lora.lora_rank:

if config and hasattr(config, "lora") and config.lora:
  lora_rank = getattr(config.lora, "lora_rank", 0)
  if lora_rank > 0 and hasattr(config.lora, "model_dump"):
    custom_metadata = {"lora": config.lora.model_dump()}

Collaborator Author

Done! Added checks to ensure config has the lora attribute and is not None before attempting to access lora_rank or model_dump.
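
A self-contained version of the defensive check discussed in this thread might look like the sketch below. `FakeLoraConfig` is a hypothetical stand-in for a pydantic-style config block exposing `model_dump()`, and `lora_custom_metadata` is an illustrative helper name, not a function from the PR.

```python
from types import SimpleNamespace
from typing import Any, Optional


class FakeLoraConfig:
  """Hypothetical stand-in for a config block that exposes model_dump()."""

  def __init__(self, lora_rank: int, lora_alpha: float):
    self.lora_rank = lora_rank
    self.lora_alpha = lora_alpha

  def model_dump(self) -> dict[str, Any]:
    return {"lora_rank": self.lora_rank, "lora_alpha": self.lora_alpha}


def lora_custom_metadata(config: Any) -> Optional[dict[str, Any]]:
  """Builds {"lora": {...}} only when every required attribute is present."""
  lora = getattr(config, "lora", None)  # Also covers config being None.
  if not lora:
    return None
  if getattr(lora, "lora_rank", 0) > 0 and hasattr(lora, "model_dump"):
    return {"lora": lora.model_dump()}
  return None


with_lora = lora_custom_metadata(SimpleNamespace(lora=FakeLoraConfig(8, 16.0)))
without_block = lora_custom_metadata(SimpleNamespace())
rank_zero = lora_custom_metadata(SimpleNamespace(lora=FakeLoraConfig(0, 16.0)))
no_config = lora_custom_metadata(None)
```

Only the first call yields metadata; the other three return None, so a config without a usable LoRA block never raises during save.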

Comment thread src/maxtext/common/checkpointing.py Outdated
replicator_error_handler(config)
return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
return checkpoint_manager.save(
step, args=Composite(state=checkpoint_args), force=force, custom_metadata=custom_metadata

Collaborator

EmergencyCheckpointManager and EmergencyReplicatorCheckpointManager do not accept a custom metadata argument. Let's leave this argument out here and open a bug to add this support.

Collaborator Author

Done! Omitted passing the custom_metadata argument when calling .save() on EmergencyCheckpointManager or EmergencyReplicatorCheckpointManager.

@RexBearIU RexBearIU Jun 30, 2026

Collaborator Author

I've filed bug b/529671188 for the Orbax team to add custom metadata support to EmergencyCheckpointManager and EmergencyReplicatorCheckpointManager.
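
The resulting save behavior can be sketched as below. Both manager classes are hypothetical stand-ins written for this example (they are not the Orbax classes), and `save_step` is an illustrative helper: it forwards `custom_metadata` only to a manager whose `save()` accepts it.

```python
from typing import Any, Optional


class StandardManager:
  """Hypothetical manager whose save() accepts custom_metadata."""

  def save(self, step: int, *, force: bool = False,
           custom_metadata: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    return {"step": step, "force": force, "custom_metadata": custom_metadata}


class EmergencyManager:
  """Hypothetical manager whose save() has no custom_metadata parameter."""

  def save(self, step: int, *, force: bool = False) -> dict[str, Any]:
    return {"step": step, "force": force}


def save_step(manager: Any, step: int,
              custom_metadata: Optional[dict[str, Any]] = None,
              force: bool = False) -> dict[str, Any]:
  """Passes custom_metadata only to managers that support it."""
  if isinstance(manager, EmergencyManager):
    # Passing the keyword here would raise TypeError, so it is omitted.
    return manager.save(step, force=force)
  return manager.save(step, force=force, custom_metadata=custom_metadata)


lora_meta = {"lora": {"lora_rank": 8, "lora_alpha": 16.0}}
standard_result = save_step(StandardManager(), 100, lora_meta)
emergency_result = save_step(EmergencyManager(), 100, lora_meta)
```

Under these assumptions, checkpoints written through the emergency path carry no LoRA metadata, so the rank and alpha would still have to be supplied by the user when restoring from them.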

@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch 3 times, most recently from d55b90d to ffe10de Compare June 30, 2026 08:33
RexBearIU commented Jun 30, 2026

Collaborator Author

> This version reverts Xibin's previous version where sync_lora_metadata was in lora_utils. We should move it back there and use it not just on checkpoint conversion but also before model creation.

Done. Moved sync_lora_metadata back to lora_utils.py (running during checkpoint restore) with a clean, formatting-free diff.

@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from ffe10de to 2649217 Compare June 30, 2026 10:07
Co-authored-by: Xibin Liu <xibin@google.com>