
refactor: streamline LoRA parameter restoration for JAX/NNX models #4067

Open
RexBearIU wants to merge 1 commit into main from jackyf/lora-nnx-restore-refactor

@RexBearIU

Collaborator

Description

Refactors the LoRA parameter restoration utility (restore_lora_from_path) to accept JAX/NNX model wrapper classes directly and merge the restored parameters into them.
Standardizing on the NNX state tree representation removes the need to unpack intermediate state containers, which makes the trainer code safer and easier to maintain.

Key Changes:

  • LoRA Utilities (src/maxtext/utils/lora_utils.py): Overhauled restore_lora_from_path to directly query and merge parameters of NNX wrapper
    classes.
  • SFT Trainer (src/maxtext/trainers/post_train/sft/train_sft.py): Simplified parameter setup sequence to cleanly map restored LoRA states
    directly onto the model.
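
For readers unfamiliar with the pattern, the merge step described above can be sketched roughly as follows. This is only an illustration in plain Python: nested dicts stand in for the NNX state tree, and `merge_lora_params`, the key names, and the strict handling of unexpected keys are all hypothetical. It does not show the actual signature or behavior of restore_lora_from_path.

```python
from typing import Any, Mapping

# Hypothetical stand-in: a nested dict plays the role of the NNX state tree.
ParamTree = dict[str, Any]


def merge_lora_params(model_state: Mapping[str, Any],
                      lora_state: Mapping[str, Any]) -> ParamTree:
  """Return a new tree: model_state with leaves from lora_state overlaid.

  Neither input is mutated. Entries present only in model_state (for example
  frozen base weights) are carried over unchanged.
  """
  merged = dict(model_state)
  for key, restored in lora_state.items():
    # Assumption: a restored entry with no counterpart in the model is an error.
    if key not in model_state:
      raise KeyError(f"restored LoRA state has unexpected entry: {key!r}")
    current = model_state[key]
    if isinstance(current, Mapping) and isinstance(restored, Mapping):
      # Both sides are subtrees, so recurse.
      merged[key] = merge_lora_params(current, restored)
    elif isinstance(current, Mapping) or isinstance(restored, Mapping):
      raise TypeError(f"tree structure mismatch at {key!r}")
    else:
      # Both sides are leaves, so the restored value wins.
      merged[key] = restored
  return merged


# Illustrative trees; the key names are made up for this sketch.
model_state = {
    "decoder": {
        "attention": {"kernel": [1.0, 2.0], "lora_a": [0.0], "lora_b": [0.0]},
    },
}
lora_state = {
    "decoder": {"attention": {"lora_a": [0.5], "lora_b": [0.25]}},
}

merged = merge_lora_params(model_state, lora_state)
```

Overlaying only the restored leaves keeps the base weights untouched, which is the property a LoRA restore needs. Whether the real utility rejects or ignores unmatched keys is not stated in this PR.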

Tests

  • Verified via unit testing for LoRA restore functions:
python3 -m pytest tests/post_training/unit/lora_utils_test.py 

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RexBearIU RexBearIU changed the title feat(lora): Refactor restore_lora_from_path to accept JAX/NNX models … refactor: streamline LoRA parameter restoration for JAX/NNX models Jun 4, 2026
@RexBearIU RexBearIU force-pushed the jackyf/lora-nnx-restore-refactor branch 2 times, most recently from 3360657 to 0639ae2 Compare June 4, 2026 14:42
@codecov

codecov Bot commented Jun 4, 2026


Codecov Report

❌ Patch coverage is 70.58824% with 10 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/lora_utils.py | 75.00% | 5 Missing and 3 partials ⚠️ |
| src/maxtext/trainers/post_train/sft/train_sft.py | 0.00% | 2 Missing ⚠️ |


@RexBearIU RexBearIU force-pushed the jackyf/lora-nnx-restore-refactor branch 3 times, most recently from 4d51a1c to 019d93d Compare June 8, 2026 08:35
@RexBearIU RexBearIU force-pushed the jackyf/lora-nnx-restore-refactor branch 2 times, most recently from a0f3ab9 to 2b0c32a Compare June 25, 2026 10:28
@RexBearIU RexBearIU force-pushed the jackyf/lora-nnx-restore-refactor branch from 2b0c32a to 936b8b7 Compare June 25, 2026 14:53
@RexBearIU RexBearIU force-pushed the jackyf/lora-nnx-restore-refactor branch 4 times, most recently from 7b8865f to 2316d38 Compare June 30, 2026 10:33
@RexBearIU RexBearIU force-pushed the jackyf/lora-nnx-restore-refactor branch from 2316d38 to 303587b Compare June 30, 2026 10:41