21 changes: 14 additions & 7 deletions configs/recipes/gemma4/README.md
@@ -8,7 +8,7 @@ Configs for Google's Gemma 4 model family. See the [Hugging Face announcement](h
  - [google/gemma-4-E2B-it](https://huggingface.co/google/gemma-4-E2B-it) (~5B) — **LoRA config available**
  - [google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it) (~8B) — **FFT + LoRA configs available**
- Larger (image + text, 256K context)
  - [google/gemma-4-26B-A4B-it](https://huggingface.co/google/gemma-4-26B-A4B-it) (MoE, 27B)
  - [google/gemma-4-26B-A4B-it](https://huggingface.co/google/gemma-4-26B-A4B-it) (MoE, 27B) — **LoRA config available**
  - [google/gemma-4-31B-it](https://huggingface.co/google/gemma-4-31B-it) (dense, 31B)

Gemma 4 requires accepting the model license on Hugging Face before downloading.
@@ -41,12 +41,13 @@ oumi launch up -c oumi://configs/recipes/gemma4/sft/e4b_full/gcp_job.yaml --clus

### LoRA Training

LoRA is scoped to the language-model layers only. Gemma 4's vision/audio towers
use `Gemma4ClippableLinear` wrappers that PEFT cannot adapt, and they share
projection names (`q_proj`, `v_proj`, ...) with the text model. The recipes target
the plain projection names and set `lora_exclude_modules: [".*vision_tower.*",
".*audio_tower.*"]`, which oumi passes to PEFT's `exclude_modules` to keep LoRA off
the towers.
LoRA is scoped to the language-model layers only. Gemma 4's non-text towers use
`Gemma4ClippableLinear` wrappers that PEFT cannot adapt, and they share projection
names (`q_proj`, `v_proj`, ...) with the text model. The recipes target the plain
projection names and set `lora_exclude_modules` to keep LoRA off the towers: the
Efficient (text+image+audio) models exclude `[".*vision_tower.*", ".*audio_tower.*"]`,
and the Larger (image+text) models exclude `[".*vision_tower.*", ".*multi_modal_projector.*"]`.
oumi passes this list to PEFT's `exclude_modules`.
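
As a rough illustration of how these exclusion patterns separate the towers from the text model, here is a minimal Python sketch. The module names are hypothetical placeholders, not the real Gemma 4 module paths, and the helper functions are invented for this example; how oumi and PEFT actually combine and apply the patterns is not shown here. Because every pattern starts and ends with `.*`, a full match and a substring search give the same result for these names.

```python
import re

# Patterns copied from the recipes described above.
EFFICIENT_EXCLUDES = [r".*vision_tower.*", r".*audio_tower.*"]
LARGER_EXCLUDES = [r".*vision_tower.*", r".*multi_modal_projector.*"]


def is_excluded(module_name: str, patterns: list[str]) -> bool:
    """Return True if any exclusion pattern matches the full module name."""
    return any(re.fullmatch(p, module_name) for p in patterns)


def lora_candidates(names: list[str], targets: set[str], excludes: list[str]) -> list[str]:
    """Keep modules whose last path component is a target and that are not excluded."""
    return [
        n for n in names
        if n.split(".")[-1] in targets and not is_excluded(n, excludes)
    ]


# Hypothetical module names, for illustration only.
MODULE_NAMES = [
    "model.language_model.layers.0.self_attn.q_proj",
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",
    "model.audio_tower.layers.0.self_attn.q_proj",
    "model.multi_modal_projector.proj",
]

if __name__ == "__main__":
    for name in lora_candidates(MODULE_NAMES, {"q_proj", "v_proj"}, EFFICIENT_EXCLUDES):
        print(name)
```

The point of the sketch is that the shared projection names alone cannot tell the towers apart from the text model; the path-based exclusion is what keeps the adapters on the language-model layers.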

To launch Gemma 4 E4B LoRA training locally (fits a single A100/H100):

@@ -65,3 +66,9 @@ To launch Gemma 4 E4B LoRA training on a remote GCP A100 cluster:
```shell
oumi launch up -c oumi://configs/recipes/gemma4/sft/e4b_lora/gcp_job.yaml --cluster gemma4-e4b-lora
```

To launch Gemma 4 27B (MoE) LoRA training on a remote GCP 8x A100 cluster:

```shell
oumi launch up -c oumi://configs/recipes/gemma4/sft/27b_lora/gcp_job.yaml --cluster gemma4-27b-lora
```
47 changes: 47 additions & 0 deletions configs/recipes/gemma4/sft/27b_lora/gcp_job.yaml
@@ -0,0 +1,47 @@
# Job config to LoRA tune Gemma 4 27B (Mixture-of-Experts).
#
# Requirements:
# - Set up SkyPilot GCP: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html#setup
# - Gemma license acceptance required: https://huggingface.co/google/gemma-4-26B-A4B-it
# - Log into WandB (`wandb login`) or disable `enable_wandb`
#
# Usage:
# oumi launch up -c oumi://configs/recipes/gemma4/sft/27b_lora/gcp_job.yaml --cluster gemma4-27b-lora
#
# See Also:
# - Documentation: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html
# - Config class: oumi.core.configs.JobConfig
# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/job_config.py
# - Other job configs: configs/**/*job.yaml

name: gemma4-27b-lora

resources:
  cloud: gcp
  accelerators: "A100:8"
  use_spot: false
  disk_size: 500 # Disk size in GBs

working_dir: .

file_mounts:
  ~/.netrc: ~/.netrc # WandB credentials

envs:
  WANDB_PROJECT: oumi-train
  OUMI_RUN_NAME: gemma4-27b.lora

setup: |
  set -e
  pip install uv && uv pip install --system oumi[gpu] hf_transfer

run: |
  set -e # Exit if any command failed.
  source ./configs/examples/misc/sky_init.sh

  set -x
  oumi distributed torchrun -m oumi train \
    -c oumi://configs/recipes/gemma4/sft/27b_lora/train.yaml \
    --training.run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}"

  echo "Node ${SKYPILOT_NODE_RANK} is all done!"
98 changes: 98 additions & 0 deletions configs/recipes/gemma4/sft/27b_lora/train.yaml
@@ -0,0 +1,98 @@
# LoRA SFT config for Gemma 4 27B Instruct (Mixture-of-Experts).
#
# Model highlights:
# - 26.5B total parameters (~4B active per token, "A4B"), MoE multimodal model
# from Google (Gemma 4 Larger series)
# - 256K context length; image + text inputs
# - LoRA is scoped to the text transformer only. The vision tower and the
# multimodal projector use `Gemma4ClippableLinear` wrappers that PEFT cannot
# adapt, and share projection names with the text model — so they are excluded
# via `lora_exclude_modules` below.
# NOTE: on this MoE the standard gate_proj/up_proj/down_proj names do not
# match the (fused) expert MLP modules, so LoRA adapts only the attention
# projections (~9.3M trainable params). Adapting the experts would require
# targeting their actual module names.
#
# Requirements:
# - transformers >= 5.5.4
# - peft (installed automatically with oumi)
# - Gemma license acceptance required: https://huggingface.co/google/gemma-4-26B-A4B-it
# - Log into WandB (`wandb login`) or disable `enable_wandb`
#
# Usage:
# oumi distributed torchrun -m oumi train -c oumi://configs/recipes/gemma4/sft/27b_lora/train.yaml
#
# See Also:
# - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
# - Config class: oumi.core.configs.TrainingConfig
# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
# - Other training configs: configs/**/*train.yaml

model:
  model_name: "google/gemma-4-26B-A4B-it"
  model_max_length: 8192
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  trust_remote_code: false
  enable_liger_kernel: false # Disabled (may conflict with Gemma output format).

data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned" # 51,760 examples

training:
  use_peft: true
  trainer_type: "TRL_SFT"
  save_final_model: true
  num_train_epochs: 1
  per_device_train_batch_size: 1
  # NOTE: gradient accumulation inflates reported loss for Gemma 4 (~4x) due
  # to a missing `accepts_loss_kwargs = False` on Gemma4ForConditionalGeneration.
  # Fixed on transformers main but not yet released as of 5.5.4. See
  # huggingface/transformers#40564 (same bug existed in Gemma 3).
  gradient_accumulation_steps: 8
  max_grad_norm: 1.0

  enable_gradient_checkpointing: true
  gradient_checkpointing_kwargs:
    use_reentrant: false
  compile: false

  optimizer: "adamw_torch_fused"
  learning_rate: 2.0e-04
  lr_scheduler_type: "cosine"
  warmup_ratio: 0.05
  weight_decay: 0.01

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 8
  logging_steps: 5
  empty_device_cache_steps: 50
  output_dir: "output/gemma4_27b.lora"
  include_performance_metrics: true
  enable_wandb: true

peft:
  lora_r: 8
  lora_alpha: 16
  lora_dropout: 0.05
  lora_target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "gate_proj"
    - "up_proj"
    - "down_proj"
  # Keep LoRA on the text transformer only; exclude the vision tower and the
  # multimodal projector (image+text model, no audio tower).
  lora_exclude_modules:
    - ".*vision_tower.*"
    - ".*multi_modal_projector.*"

fsdp:
  enable_fsdp: true
  sharding_strategy: "FULL_SHARD"
  forward_prefetch: true
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "Gemma4TextDecoderLayer"
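
For a quick sanity check of the `training` and `peft` values in this config, the following Python sketch works through the arithmetic they imply. It is only an estimate under stated assumptions: a single node with all 8 GPUs from `accelerators: "A100:8"` in use, no sequence packing or example filtering, standard LoRA scaling of `lora_alpha / lora_r`, and step counts rounded up. The exact rounding depends on the trainer version, and `lora_param_count` is a helper invented for this example, not an oumi or PEFT function.

```python
import math

# Values copied from train.yaml and gcp_job.yaml in this PR.
NUM_EXAMPLES = 51_760      # yahma/alpaca-cleaned
PER_DEVICE_BATCH = 1       # per_device_train_batch_size
GRAD_ACCUM_STEPS = 8       # gradient_accumulation_steps
NUM_GPUS = 8               # "A100:8" (assumption: one node, all GPUs used)
WARMUP_RATIO = 0.05        # warmup_ratio
LORA_R, LORA_ALPHA = 8, 16

# Examples consumed per optimizer step across all GPUs.
effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM_STEPS * NUM_GPUS

# Assumption: each GPU sees an equal shard, and a partial final
# accumulation cycle still produces one optimizer step.
batches_per_gpu = math.ceil(NUM_EXAMPLES / (PER_DEVICE_BATCH * NUM_GPUS))
optimizer_steps = math.ceil(batches_per_gpu / GRAD_ACCUM_STEPS)
warmup_steps = math.ceil(optimizer_steps * WARMUP_RATIO)

# Standard LoRA scales the low-rank update by alpha / r.
lora_scale = LORA_ALPHA / LORA_R


def lora_param_count(d_in: int, d_out: int, r: int = LORA_R) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer.

    The adapter is two matrices, A (r x d_in) and B (d_out x r).
    """
    return r * (d_in + d_out)


if __name__ == "__main__":
    print(f"effective batch size: {effective_batch}")
    print(f"optimizer steps per epoch (approx.): {optimizer_steps}")
    print(f"warmup steps (approx.): {warmup_steps}")
    print(f"LoRA scale: {lora_scale}")
```

Summing `lora_param_count` over the adapted attention projections, using the real layer shapes from the model config, is one way to cross-check the trainable parameter count that PEFT reports.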