diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py
index b17e318570..a97f02bdb6 100644
--- a/src/maxtext/utils/vocabulary_tiling.py
+++ b/src/maxtext/utils/vocabulary_tiling.py
@@ -349,6 +349,15 @@ def _reshape(inputs, out_shape, out_sharding):
   # custom_vjp + lax.scan boundary, which fails for tied embeddings.
   graphdef, head_params, other_params, rest = nnx.split(model, _is_output_head_param_path, nnx.Param, ...)
 
+  # all gather only the embedding table
+  head_params = all_gather_over_fsdp(
+      head_params,
+      nnx.get_partition_spec(head_params),
+      model.mesh,
+      config.logical_axis_rules,
+      config.shard_mode,
+  )
+
   def _logits_for_chunk(chunk_head_params, chunk_other_params, chunk_rest, hidden_chunk):
     local_model = nnx.merge(graphdef, chunk_head_params, chunk_other_params, chunk_rest, copy=True)
     chunk_logits = local_model.logits_from_hidden_states_for_vocab_tiling(hidden_chunk, deterministic, model_mode)