From cccfc67ca8d05849c358d755615ee4e2152e97e8 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Mon, 29 Jun 2026 17:05:44 -0700 Subject: [PATCH] Add support for On-The-Fly Dynamic SafeTensors loading. PiperOrigin-RevId: 940112733 --- .../checkpoint_conversion/to_maxtext.py | 3 +- .../utils/load_dynamic.py | 371 ++++++++++++++++++ .../utils/tensor_handling.py | 197 ++++++++++ .../checkpoint_conversion/utils/utils.py | 135 ++++++- src/maxtext/common/checkpointing.py | 19 +- src/maxtext/configs/types.py | 7 +- src/maxtext/utils/maxtext_utils.py | 1 + tests/unit/checkpointing_test.py | 309 +++++++++++++++ tests/unit/configs_value_test.py | 40 +- 9 files changed, 1061 insertions(+), 21 deletions(-) create mode 100644 src/maxtext/checkpoint_conversion/utils/load_dynamic.py create mode 100644 src/maxtext/checkpoint_conversion/utils/tensor_handling.py create mode 100644 tests/unit/checkpointing_test.py diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 4245201b4e..2553469359 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -67,7 +67,8 @@ from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys +from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns +from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys from maxtext.inference.inference_utils import str2bool from maxtext.layers import quantizations from maxtext.models import models diff --git a/src/maxtext/checkpoint_conversion/utils/load_dynamic.py b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py new file mode 100644 index 0000000000..2032cf1c41 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py @@ -0,0 +1,371 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dynamic loading of HuggingFace checkpoints during training/eval workloads directly in the target format. + +This module allows loading HuggingFace checkpoints (in Safetensors format) +directly during MaxText training or evaluation runs, performing on-the-fly sharded +restore and CPU/TPU transformations. This avoids offline pre-conversion steps +and prevents host OOM. + +Usage: + To load Hugging Face checkpoints directly, configure the following flags: + 1. `source_checkpoint_layout`: Set to `"safetensors_dynamic"`. + 2. `load_parameters_path`: Set to the source path of the Hugging Face checkpoint. + +Examples: + A. Load from a Google Cloud Storage (GCS) directory containing `.safetensors`: + ``` + python3 maxtext/trainers/pre_train/train.py \ + maxtext/configs/base.yml \ + run_name=my_run \ + model_name=llama3.1-8b \ + source_checkpoint_layout="safetensors_dynamic" \ + load_parameters_path="gs://my-bucket/path/to/safetensors_directory/" + ``` + + B. Load directly from the Hugging Face Hub (automatically cached to GCS): + ``` + python3 maxtext/trainers/pre_train/train.py \ + maxtext/configs/base.yml \ + run_name=my_run \ + model_name=llama3.1-8b \ + source_checkpoint_layout="safetensors_dynamic" \ + load_parameters_path="hf://meta-llama/Meta-Llama-3-8B" \ + hf_access_token="" \ + base_output_directory="gs://my-bucket/output/" + ``` + + C. Load from Hugging Face Hub using automatic model_name resolution: + ``` + python3 maxtext/trainers/pre_train/train.py \ + maxtext/configs/base.yml \ + run_name=my_run \ + model_name=llama3.1-8b \ + source_checkpoint_layout="safetensors_dynamic" \ + load_parameters_path="" \ + hf_access_token="" \ + base_output_directory="gs://my-bucket/output/" + ``` + +Note: + - Hugging Face weights from HF Hub are cached to `base_output_directory`. + - When loading from Hugging Face Hub, `base_output_directory` must start with + "gs://" and `hf_access_token` is required if downloading gated models. +""" + +import concurrent.futures +import multiprocessing +import os +import random +import time + +from flax import nnx +import flax.traverse_util +from google.cloud import storage +import huggingface_hub +import jax +from maxtext.checkpoint_conversion.utils import hf_model_configs +from maxtext.checkpoint_conversion.utils import param_mapping +from maxtext.checkpoint_conversion.utils import tensor_handling +from maxtext.utils import gcs_utils +from maxtext.utils import globals as maxtext_globals +from maxtext.utils import max_logging +from orbax.checkpoint import v1 as ocp_v1 +from orbax.checkpoint._src.arrays import sharding as sharding_utils + + +HF_MODEL_CONFIGS = hf_model_configs.HF_MODEL_CONFIGS +get_hf_loading_function = tensor_handling.get_hf_loading_function + + +def build_gcs_cache_worker(fpath, gcs_cache_dir, hf_access_token): + """Caches a file from Hugging Face to a GCS bucket cache directory. + + Args: + fpath: The full remote file path on the Hugging Face virtual file system + (e.g., "meta-llama/Meta-Llama-3-8B/model-00001-of-00004.safetensors"). + gcs_cache_dir: The destination directory in GCS. + hf_access_token: The access token for Hugging Face. + """ + fs = huggingface_hub.HfFileSystem(token=hf_access_token) + time.sleep(random.uniform(0.0, 5.0)) + + bucket_name, blob_prefix = gcs_utils.parse_gcs_bucket_and_prefix(gcs_cache_dir) + blob_name = os.path.join(blob_prefix, os.path.basename(fpath)) + + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + if blob.exists(): + max_logging.log(f"[Worker] Cache hit for {os.path.basename(fpath)}.") + return + + t0 = time.time() + max_retries = 5 + for attempt in range(max_retries): + try: + with fs.open(fpath, "rb") as remote_f: + blob.chunk_size = 1024 * 1024 * 32 # 32MB chunks + blob.upload_from_file(remote_f, client=storage_client) + print( + f"[Worker] Cached {os.path.basename(fpath)} in" f" {time.time() - t0:.1f}s", + flush=True, + ) + break + except Exception as e: # pylint: disable=broad-exception-caught + if attempt < max_retries - 1: + max_logging.log( + f"Error fetching {fpath} to GCS: {e}. Retrying in 15 seconds..." f" (Attempt {attempt+1}/{max_retries})" + ) + time.sleep(15) + else: + max_logging.log(f"Failed to fetch {fpath} to GCS after {max_retries} attempts.") + raise + + +def get_hf_config_and_mappings(maxtext_config): + """Gets HF config and parameter mapping based on the MaxText config.""" + model_key = maxtext_config.model_name + if "-Instruct" in model_key: + model_key = model_key.replace("-Instruct", "") + hf_config_obj = HF_MODEL_CONFIGS[model_key] + hf_config_dict = hf_config_obj.to_dict() + + param_map_mt_to_hf = param_mapping.PARAM_MAPPING[model_key]( + hf_config_dict, maxtext_config, scan_layers=maxtext_config.scan_layers + ) + hook_fn_map_mt = param_mapping.HOOK_FNS[model_key]( + hf_config_dict, maxtext_config, scan_layers=maxtext_config.scan_layers, saving_to_hf=False + ) + return param_map_mt_to_hf, hook_fn_map_mt + + +def load_sharded_hf_state(path): + """Loads HF state with maximal sharding across TPU mesh to avoid host OOM. + + Args: + path: A directory path (either local or GCS starting with gs://) containing + the .safetensors files (e.g., "gs://my-bucket/hf_cache/model_id" or + "/path/to/safetensors_directory/"). If a Hugging Face Hub ID was used, + it should already be cached/downloaded to GCS before calling this + function. + + Returns: + The loaded Hugging Face state dictionary mapping parameter names to + JAX arrays. + """ + t0 = time.time() + context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS) + with context: + metadata = ocp_v1.pytree_metadata(path) + simple_abstract_state = metadata.metadata + + # Distributed Sharded Download: Tell JAX to shard the HF Safetensors download + # across the entire TPU mesh to avoid Host OOM. + current_global_devices = jax.devices() + shardings = sharding_utils.construct_maximal_shardings(simple_abstract_state, devices=current_global_devices) + + def combine_sharding(sds, single_sharding): + return jax.ShapeDtypeStruct(shape=sds.shape, dtype=sds.dtype, sharding=single_sharding) + + sharded_abstract_state = jax.tree.map(combine_sharding, simple_abstract_state, shardings) + + max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS Download)...") + hf_state = ocp_v1.load_pytree(path, sharded_abstract_state) + max_logging.log(f"load_sharded_hf_state took {time.time() - t0:.2f}s") + return hf_state + + +def transform_hf_state_to_mt_state(hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config): + """Transforms HF state into MaxText state by applying param mappings and mathematical hooks.""" + t0 = time.time() + + def tensor_getter(key): + return hf_state.pop(key) + + flat_target = flax.traverse_util.flatten_dict(target_tree, sep=".") + flat_restored = flat_target.copy() + + mapped_count = 0 + keys_missed = [] + max_logging.log("Starting fast in-memory Distributed Transformations...") + + for mt_key, hf_source in param_map_mt_to_hf.items(): + mt_name = mt_key.replace("params-", "").replace("-", ".") + + # Determine the correct key in flat_target + check_name = mt_name + if check_name not in flat_target: + if f"params.{mt_name}" in flat_target: + check_name = f"params.{mt_name}" + elif mt_key.replace("-", ".") in flat_target: + check_name = mt_key.replace("-", ".") + + if check_name not in flat_target: + keys_missed.append(mt_name) + continue + + target_leaf = flat_target[check_name] + hook_fn = hook_fn_map_mt.get(mt_key) + + load_fn = get_hf_loading_function( + hf_source, + tensor_getter, + hook_fn, + target_leaf, + maxtext_config, + ) + + # Execute transformation and assign to flat_restored + t_layer = time.time() + flat_restored[check_name] = load_fn() + + max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s") + mapped_count += 1 + + if mapped_count == 0: + max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}") + max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}") + + max_logging.log(f"Successfully mapped {mapped_count} parameters.") + restored_params = flax.traverse_util.unflatten_dict(flat_restored, sep=".") + + if "params" in restored_params: + restored_params = restored_params["params"] + + max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s") + + return {"params": restored_params} + + +def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config): + """Main entry point to dynamically build and load safetensors into MaxText format. + + Splits execution into: + 1. Deriving Mappings + 2. Loading Sharded arrays directly to TPUs + 3. Processing the transformations natively on TPUs + """ + if maxtext_config is None: + raise ValueError("maxtext_config must be provided for safetensors_dynamic loading.") + + model_name = maxtext_config.model_name + if "-Instruct" in model_name: + model_name = model_name.replace("-Instruct", "") + + if not path: + if model_name not in maxtext_globals.HF_IDS: + raise ValueError(f"Unsupported model name for automatic HF repo resolution: {model_name}.") + path = maxtext_globals.HF_IDS[model_name] + + if path.startswith("hf://"): + path = path[5:] + + if not path.startswith("gs://") and not os.path.isdir(path): + fs = huggingface_hub.HfFileSystem(token=maxtext_config.hf_access_token) + repo_id = path + + files = fs.glob(f"{repo_id}/*.safetensors") + + host_id = jax.process_index() + + if hasattr(maxtext_config, "base_output_directory") and maxtext_config.base_output_directory.startswith("gs://"): + gcs_cache_dir = f"{maxtext_config.base_output_directory}/hf_cache/{repo_id.replace('/', '_')}" + path = gcs_cache_dir + + # Only Host 0 downloads to the shared GCS cache + if host_id == 0: + max_logging.log("Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS" f" Cache: {gcs_cache_dir}") + t_gcs_start = time.time() + + # List existing blobs to avoid spawning processes for already cached + # files + storage_client = storage.Client() + bucket_name = gcs_cache_dir.replace("gs://", "").split("/", maxsplit=1)[0] + blob_prefix = ( + gcs_cache_dir.replace("gs://", "").split("/", maxsplit=1)[1] + if "/" in gcs_cache_dir.replace("gs://", "") + else "" + ) + + existing_blobs = {blob.name for blob in storage_client.list_blobs(bucket_name, prefix=blob_prefix)} + + files_to_download = [] + for fpath in files: + expected_blob_name = os.path.join(blob_prefix, os.path.basename(fpath)) + if expected_blob_name not in existing_blobs: + files_to_download.append(fpath) + + if files_to_download: + with concurrent.futures.ProcessPoolExecutor( + max_workers=32, mp_context=multiprocessing.get_context("spawn") + ) as executor: + futures = [ + executor.submit( + build_gcs_cache_worker, + fpath, + gcs_cache_dir, + maxtext_config.hf_access_token, + ) + for fpath in files_to_download + ] + + while futures: + done, futures = concurrent.futures.wait(futures, timeout=10) + + # Raise any exceptions if a worker failed + for f in done: + f.result() + + t_gcs_end = time.time() + max_logging.log( + f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s." + f" Downloaded {len(files_to_download)} missing files." + ) + + # Global barrier: all hosts wait for Host 0 to finish downloading to the + # shared GCS bucket + max_logging.log(f"Host {host_id} waiting for GCS cache at {gcs_cache_dir} to be" " populated by Host 0...") + jax.experimental.multihost_utils.sync_global_devices("dynamic_hf_download_complete") + max_logging.log(f"Host {host_id} detected GCS cache is ready!") + + else: + raise ValueError("base_output_directory with gs:// prefix is required for " "huggingface downloads.") + + t_total = time.time() + param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config) + max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s") + + target_tree = ( + abstract_unboxed_pre_state.to_pure_dict() + if isinstance(abstract_unboxed_pre_state, nnx.State) + else abstract_unboxed_pre_state.params + ) + + t1 = time.time() + hf_state = load_sharded_hf_state(path) + max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s") + + t2 = time.time() + # Transform Hugging Face weight tensors on-the-fly into MaxText format + # in-memory. This is done in-memory on each host, sharded across the mesh. + restored_params = transform_hf_state_to_mt_state( + hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config + ) + max_logging.log(f"[3/3] CPU Transformations completed in {time.time() - t2:.2f}s") + max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s") + + return None, restored_params diff --git a/src/maxtext/checkpoint_conversion/utils/tensor_handling.py b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py new file mode 100644 index 0000000000..508697624b --- /dev/null +++ b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py @@ -0,0 +1,197 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tensor handling utility functions for checkpoint conversion.""" + +from functools import partial +from typing import Any, Callable, List +import jax +import jax.numpy as np + + +def apply_hook_fns(weight, target_shape, hook_fns): + """Apply hook functions, essential for to_maxtext and to_huggingface""" + # If hook is unsepecified, use identity + if hook_fns is None: + return weight + if not isinstance(hook_fns, list): + hook_fns = [hook_fns] + # Apply a list of hooks, be careful of order + for hook_fn in hook_fns: + weight = hook_fn(weight, target_shape) + return weight + + +def _binary_chunked_stack(tensors: List[np.ndarray], axis: int) -> np.ndarray: + """Stacks JAX arrays along axis by binary division to limit memory usage from JAX compiler.""" + if not tensors: + raise ValueError("Cannot stack empty list of tensors.") + if len(tensors) == 1: + return np.expand_dims(tensors[0], axis=axis) + if len(tensors) == 2: + return np.stack(tensors, axis=axis) + + mid = len(tensors) // 2 + left = _binary_chunked_stack(tensors[:mid], axis=axis) + right = _binary_chunked_stack(tensors[mid:], axis=axis) + return np.concatenate([left, right], axis=axis) + + +def _build_multi_axis_stacked_tensor( + hf_source_keys: List[List[str]], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_leaf: Any, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers) directly in place on device.""" + if hasattr(target_leaf, "sharding"): + target_shape = target_leaf.shape + target_sharding = target_leaf.sharding + target_dtype = target_leaf.dtype + else: + target_shape = target_leaf + target_sharding = None + target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32 + + mt_slice_shape = target_shape[2:] + + # Pre-derive the compatible sharding specs to avoid rank mismatches + if target_sharding is not None and hasattr(target_sharding, "spec"): + # Target shape is (experts, layers, ...) -> slice is from index 2 onwards + spec_list = list(target_sharding.spec)[2:] + slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list)) + # Stacking layer shards + layer_spec_list = list(target_sharding.spec)[1:] + layer_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*layer_spec_list)) + else: + slice_sharding = target_sharding + layer_sharding = target_sharding + + all_expert_tensors = [] + # Outer loop iterates through experts + for layer_keys_for_expert in hf_source_keys: + layer_tensors_for_expert = [] + # Inner loop iterates through layers for the current expert + for hf_key_single in layer_keys_for_expert: + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + + if target_sharding is not None: + processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding) + layer_tensors_for_expert.append(processed_hf_tensor) + + expert_tensor = _binary_chunked_stack(layer_tensors_for_expert, axis=0) + if target_sharding is not None: + expert_tensor = jax.device_put(expert_tensor, layer_sharding) + all_expert_tensors.append(expert_tensor) + + stacked_array = _binary_chunked_stack(all_expert_tensors, axis=0).astype(target_dtype) + if target_sharding is not None: + stacked_array = jax.device_put(stacked_array, target_sharding) + return stacked_array + + +def _build_single_axis_stacked_tensor( + hf_source_keys: List[str], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_leaf: Any, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along a single axis directly in place on device.""" + if hasattr(target_leaf, "sharding"): + target_shape = target_leaf.shape + target_sharding = target_leaf.sharding + target_dtype = target_leaf.dtype + else: + target_shape = target_leaf + target_sharding = None + target_dtype = target_leaf.dtype if hasattr(target_leaf, "dtype") else np.float32 + + if config.scan_layers: + # If it's a standard scanned layer, we use the configured param_scan_axis. + axis_to_stack = config.param_scan_axis + else: + # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0). + axis_to_stack = 0 + + # The hook function needs the shape of an individual slice, not the full stacked tensor. + # We calculate it by removing the stacking dimension from the final target shape. + mt_slice_shape_list = list(target_shape) + del mt_slice_shape_list[axis_to_stack] + mt_slice_shape = tuple(mt_slice_shape_list) + + if target_sharding is not None and hasattr(target_sharding, "spec"): + spec_list = list(target_sharding.spec) + del spec_list[axis_to_stack] + slice_sharding = jax.sharding.NamedSharding(target_sharding.mesh, jax.sharding.PartitionSpec(*spec_list)) + else: + slice_sharding = target_sharding + + tensors_to_stack = [] + for hf_key_single in hf_source_keys: + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + + if target_sharding is not None: + processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding) + tensors_to_stack.append(processed_hf_tensor) + + stacked_array = _binary_chunked_stack(tensors_to_stack, axis=axis_to_stack).astype(target_dtype) + if target_sharding is not None: + stacked_array = jax.device_put(stacked_array, target_sharding) + return stacked_array + + +def get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_leaf, config): + """Determine the loading function for HF keys.""" + if not isinstance(hf_source_keys_or_key, list): + # Case 1: Single hf key (str) + def _loader(getter, key, leaf, hook): + if hasattr(leaf, "sharding"): + array = apply_hook_fns(getter(key), leaf.shape, hook) + return jax.device_put(array, device=leaf.sharding) + else: + return apply_hook_fns(getter(key), leaf, hook) + + return partial( + _loader, + tensor_getter, + hf_source_keys_or_key, + mt_target_leaf, + hook_fn, + ) + # Stacked mapping + elif not isinstance(hf_source_keys_or_key[0], list): + # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) + return partial( + _build_single_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_leaf, + config, + ) + else: + # isinstance(hf_source_keys_or_key[0], list) + # Case 4: Multi-Axis Stacked hf keys (nested list) + return partial( + _build_multi_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_leaf, + config, + ) diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 625197504b..0ee520a9aa 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -23,7 +23,8 @@ import time import json from concurrent.futures import ThreadPoolExecutor -from typing import Any +from functools import partial +from typing import Any, Callable, List from tqdm import tqdm import resource import numpy as np @@ -1246,3 +1247,135 @@ def save_weights_to_checkpoint( checkpoint_manager.wait_until_finished() max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min") + + +def _build_multi_axis_stacked_tensor( + hf_source_keys: List[List[str]], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_shape: tuple, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers). + + This function handles the complex case for scanned MoE layers, producing a tensor + with the shape (num_experts, num_layers, ...). + + Args: + hf_source_keys: A nested (2D) list of Hugging Face parameter names. + Outer list iterates experts, inner list iterates layers. + tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array). + hook_fns: The hook function(s) to apply to each individual weight. + target_shape: The final shape of the target MaxText tensor. + config: The MaxText pyconfig object. + + Returns: + The final, assembled NumPy array for the MaxText parameter. + """ + all_expert_tensors = [] + # The hook function needs the shape of an individual slice, not the full stacked tensor. + # For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:] + mt_slice_shape = target_shape[2:] + + # Outer loop iterates through experts + for layer_keys_for_expert in hf_source_keys: + layer_tensors_for_expert = [] + # Inner loop iterates through layers for the current expert + for hf_key_single in layer_keys_for_expert: + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + layer_tensors_for_expert.append(processed_hf_tensor) + all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0)) + return np.stack(all_expert_tensors, axis=0) + + +def _build_single_axis_stacked_tensor( + hf_source_keys: List[str], + tensor_getter_fn: Callable[[str], np.ndarray], + hook_fns: Any, + target_shape: tuple, + config, +) -> np.ndarray: + """Builds a MaxText tensor by stacking HF weights along a single axis. + + This function handles both standard scanned layers (e.g., attention) and + unscanned MoE layers (which are stacked along the expert axis). + + Args: + hf_source_keys: A 1D list of Hugging Face parameter names. + tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array). + hook_fns: The hook function(s) to apply to each individual weight. + target_shape: The final shape of the target MaxText tensor. + config: The MaxText pyconfig object. + + Returns: + The final, assembled NumPy array for the MaxText parameter. + """ + tensors_to_stack = [] + + if config.scan_layers: + # If it's a standard scanned layer, we use the configured param_scan_axis. + axis_to_stack = config.param_scan_axis + else: + # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0). + axis_to_stack = 0 + + # The hook function needs the shape of an individual slice, not the full stacked tensor. + # We calculate it by removing the stacking dimension from the final target shape. + mt_slice_shape_list = list(target_shape) + del mt_slice_shape_list[axis_to_stack] + mt_slice_shape = tuple(mt_slice_shape_list) + + for hf_key_single in hf_source_keys: + hf_tensor_numpy = tensor_getter_fn(hf_key_single) + processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) + tensors_to_stack.append(processed_hf_tensor) + + # Stack all processed tensors along the determined axis. + return np.stack(tensors_to_stack, axis=axis_to_stack) + + +def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config): + """Determine the loading function for HF keys. + HF keys can take four forms: + Case 1: Unscanned (single string) + Case 2: Scanned (list of strings) + Case 3: Unscanned with expert stacking (list of strings) + Case 4: Scanned with expert stacking (nested list of strings) + """ + load_fn = None + if not isinstance(hf_source_keys_or_key, list): + # Case 1: Single hf key (str) + def _loader(getter, key, shape, hook): + return apply_hook_fns(getter(key), shape, hook) + + load_fn = partial( + _loader, + tensor_getter, + hf_source_keys_or_key, + mt_target_shape_or_shapes, + hook_fn, + ) + # Stacked mapping + elif not isinstance(hf_source_keys_or_key[0], list): + # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) + load_fn = partial( + _build_single_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) + else: + # isinstance(hf_source_keys_or_key[0], list) + # Case 4: Multi-Axis Stacked hf keys (nested list) + load_fn = partial( + _build_multi_axis_stacked_tensor, + hf_source_keys_or_key, + tensor_getter, + hook_fn, + mt_target_shape_or_shapes, + config, + ) + return load_fn diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 2f0d8c3a49..4e29d29581 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -34,6 +34,8 @@ from maxtext.utils import max_logging from maxtext.utils import gcs_utils from maxtext.utils import elastic_utils +from maxtext.checkpoint_conversion.utils.load_dynamic import load_safetensors_dynamic_state + import numpy as np import orbax.checkpoint as ocp from orbax.checkpoint import v1 as ocp_v1 @@ -764,6 +766,7 @@ def load_state_if_possible( checkpoint_conversion_fn=None, source_checkpoint_layout="orbax", expansion_factor_real_data: int = -1, + maxtext_config: Any | None = None, ): """Loads TrainState as possible from the inputs. @@ -897,7 +900,12 @@ def map_to_pspec(data): _assert_no_shaped_dtype_struct(restored) return (restored, None) - if load_parameters_from_path != "": + if source_checkpoint_layout == "safetensors_dynamic": + path = load_parameters_from_path or load_full_state_from_path + max_logging.log(f"Dynamic On-the-Fly Formatting: Loading SafeTensors from {path}") + + return load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config) + elif load_parameters_from_path != "": if isinstance(abstract_unboxed_pre_state, nnx.State): _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) else: @@ -953,13 +961,18 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute- def load_params_from_path( - load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True + load_parameters_from_path, + abstract_unboxed_params, + checkpoint_storage_concurrent_gb, + use_ocdbt=True, + use_zarr3=True, ): """Load decode params from checkpoint at specified path.""" assert load_parameters_from_path, "load_parameters_from_path is not defined." max_logging.log(f"restoring params from {load_parameters_from_path}") - # NNX target: the on-disk checkpoint is in Linen layout; reshape it into the NNX params state. + # NNX target: the on-disk checkpoint is in Linen layout; reshape it into the + # NNX params state. if isinstance(abstract_unboxed_params, nnx.State): return _load_linen_params_into_nnx( load_parameters_from_path, diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 0f28aefa81..b0f407eff3 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -348,7 +348,7 @@ class Checkpointing(BaseModel): save_quantized_params_path: PathStr = Field("", description="Path to save params quantized on the fly.") enable_orbax_v1: bool = Field(False, description="Bool flag for enabling Orbax v1.") checkpoint_conversion_fn: None | str = Field(None, description="Function for processing loaded checkpoint dict.") - source_checkpoint_layout: Literal["orbax", "safetensors"] = Field( + source_checkpoint_layout: Literal["orbax", "safetensors", "safetensors_dynamic"] = Field( "orbax", description="The layout of the source checkpoint to load." ) save_checkpoint_on_completion: bool = Field( @@ -3041,6 +3041,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # I. RUN ALL CROSS-FIELD VALIDATIONS if self.load_parameters_path and self.load_full_state_path: raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.") + if self.source_checkpoint_layout == "safetensors_dynamic" and self.enable_single_controller: + raise ValueError( + "`source_checkpoint_layout='safetensors_dynamic'` is not supported" + " on the Pathways backend (`enable_single_controller=True`)." + ) if self.elastic_enabled and not self.enable_single_controller: raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).") if self.colocated_python_data_input and not self.enable_single_controller: diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index f81528daf6..3451ec824d 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1491,6 +1491,7 @@ def setup_initial_state( checkpoint_conversion_fn=config.checkpoint_conversion_fn, source_checkpoint_layout=config.source_checkpoint_layout, expansion_factor_real_data=config.expansion_factor_real_data, + maxtext_config=config, ) if restored: diff --git a/tests/unit/checkpointing_test.py b/tests/unit/checkpointing_test.py new file mode 100644 index 0000000000..4b60e199e3 --- /dev/null +++ b/tests/unit/checkpointing_test.py @@ -0,0 +1,309 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the checkpointing components in checkpoint_conversion.""" + +from unittest import mock +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +from flax.training import train_state +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from maxtext.checkpoint_conversion.utils import load_dynamic +from maxtext.checkpoint_conversion.utils.tensor_handling import ( + _binary_chunked_stack, + get_hf_loading_function, +) +from maxtext.common import checkpointing +import numpy as np +import optax +import os +import safetensors.numpy + + +class BinaryChunkedStackTest(parameterized.TestCase): + """Tests for the `_binary_chunked_stack` function.""" + + def test_binary_chunked_stack(self): + # Test stacking 1, 2, 3, 5, 8, and 12 tensors + shapes = [(1,), (2, 3), (4, 5, 6)] + for shape in shapes: + for num_tensors in [1, 2, 3, 5, 8, 12]: + key = jax.random.PRNGKey(0) + tensors = [jax.random.normal(jax.random.fold_in(key, i), shape) for i in range(num_tensors)] + + # Test along various axes + for axis in range(-len(shape) - 1, len(shape) + 1): + expected = jnp.stack(tensors, axis=axis) + actual = _binary_chunked_stack(tensors, axis) + np.testing.assert_allclose(actual, expected) + + +class TensorHandlingTest(parameterized.TestCase): + """Tests for the tensor handling loader functions.""" + + def setUp(self): + super().setUp() + self.mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",)) + self.sharding_rank4 = NamedSharding(self.mesh, PartitionSpec("x", None, None, None)) + self.sharding_rank3 = NamedSharding(self.mesh, PartitionSpec("x", None, None)) + + def test_get_hf_loading_function_case_2_3_single_axis(self): + # Tests Case 2/3 and lines 179 and gets loader for single axis stacked + class MockConfig: + + def __init__(self): + self.scan_layers = True + self.param_scan_axis = 0 + + config = MockConfig() + + target_leaf = jax.ShapeDtypeStruct( + shape=(2, 4, 4), + dtype=np.float32, + sharding=self.sharding_rank3, + ) + + hf_keys = ["layer_0.weight", "layer_1.weight"] + + tensors = { + "layer_0.weight": np.ones((4, 4), dtype=np.float32) * 10, + "layer_1.weight": np.ones((4, 4), dtype=np.float32) * 20, + } + + def getter_fn(key): + return tensors[key] + + hook_fn = None + + loader_fn = get_hf_loading_function(hf_keys, getter_fn, hook_fn, target_leaf, config) + + result = loader_fn() + + self.assertEqual(result.shape, (2, 4, 4)) + np.testing.assert_allclose(result[0], tensors["layer_0.weight"]) + np.testing.assert_allclose(result[1], tensors["layer_1.weight"]) + + def test_get_hf_loading_function_case_4_multi_axis(self): + # Tests Case 4, line 190, 73, and gets loader for multi-axis stacked + class MockConfig: + + def __init__(self): + self.scan_layers = True + self.param_scan_axis = 0 + + config = MockConfig() + + target_leaf = jax.ShapeDtypeStruct( + shape=(2, 2, 4, 4), + dtype=np.float32, + sharding=self.sharding_rank4, + ) + + hf_keys = [ + ["expert_0.layer_0.weight", "expert_0.layer_1.weight"], + ["expert_1.layer_0.weight", "expert_1.layer_1.weight"], + ] + + tensors = { + "expert_0.layer_0.weight": np.ones((4, 4), dtype=np.float32) * 11, + "expert_0.layer_1.weight": np.ones((4, 4), dtype=np.float32) * 12, + "expert_1.layer_0.weight": np.ones((4, 4), dtype=np.float32) * 21, + "expert_1.layer_1.weight": np.ones((4, 4), dtype=np.float32) * 22, + } + + def getter_fn(key): + return tensors[key] + + hook_fn = None + + loader_fn = get_hf_loading_function(hf_keys, getter_fn, hook_fn, target_leaf, config) + + result = loader_fn() + + self.assertEqual(result.shape, (2, 2, 4, 4)) + np.testing.assert_allclose(result[0, 0], tensors["expert_0.layer_0.weight"]) + np.testing.assert_allclose(result[0, 1], tensors["expert_0.layer_1.weight"]) + np.testing.assert_allclose(result[1, 0], tensors["expert_1.layer_0.weight"]) + np.testing.assert_allclose(result[1, 1], tensors["expert_1.layer_1.weight"]) + + +class LoadDynamicTest(parameterized.TestCase): + """Tests for cache downloads and dynamic loading of safetensors.""" + + @mock.patch("huggingface_hub.HfFileSystem") + @mock.patch("google.cloud.storage.Client") + def test_build_gcs_cache_worker_cache_hit(self, mock_storage_client, mock_hf_fs): + mock_client_instance = mock_storage_client.return_value + mock_bucket = mock_client_instance.bucket.return_value + mock_blob = mock_bucket.blob.return_value + mock_blob.exists.return_value = True + + load_dynamic.build_gcs_cache_worker("some_repo/model.safetensors", "gs://my-bucket/cache", "token") + mock_blob.exists.assert_called_once() + mock_blob.upload_from_file.assert_not_called() + + @mock.patch("huggingface_hub.HfFileSystem") + @mock.patch("google.cloud.storage.Client") + def test_build_gcs_cache_worker_cache_miss_success(self, mock_storage_client, mock_hf_fs): + mock_fs_instance = mock_hf_fs.return_value + mock_remote_file = mock.MagicMock() + mock_fs_instance.open.return_value.__enter__.return_value = mock_remote_file + + mock_client_instance = mock_storage_client.return_value + mock_bucket = mock_client_instance.bucket.return_value + mock_blob = mock_bucket.blob.return_value + mock_blob.exists.return_value = False + + load_dynamic.build_gcs_cache_worker("some_repo/model.safetensors", "gs://my-bucket/cache", "token") + mock_blob.exists.assert_called_once() + mock_blob.upload_from_file.assert_called_once_with(mock_remote_file, client=mock_client_instance) + + @mock.patch("huggingface_hub.HfFileSystem") + @mock.patch("google.cloud.storage.Client") + def test_build_gcs_cache_worker_retry_and_fail(self, mock_storage_client, mock_hf_fs): + mock_fs_instance = mock_hf_fs.return_value + mock_fs_instance.open.side_effect = Exception("Download failed") + + mock_client_instance = mock_storage_client.return_value + mock_bucket = mock_client_instance.bucket.return_value + mock_blob = mock_bucket.blob.return_value + mock_blob.exists.return_value = False + + with mock.patch("time.sleep"): + with self.assertRaises(Exception): + load_dynamic.build_gcs_cache_worker("some_repo/model.safetensors", "gs://my-bucket/cache", "token") + + @mock.patch.object(load_dynamic.huggingface_hub, "HfFileSystem") + @mock.patch.object(load_dynamic.storage, "Client") + @mock.patch.object(load_dynamic, "load_sharded_hf_state") + @mock.patch.object(load_dynamic, "transform_hf_state_to_mt_state") + @mock.patch("jax.process_index", return_value=0) + @mock.patch("jax.experimental.multihost_utils.sync_global_devices") + def test_load_safetensors_dynamic_from_hf_hub( + self, + mock_sync, + mock_process_index, + mock_transform, + mock_load_sharded, + mock_storage_client, + mock_hf_fs, + ): + mock_fs_instance = mock_hf_fs.return_value + mock_fs_instance.glob.return_value = ["repo/meta-llama/model.safetensors"] + + mock_client_instance = mock_storage_client.return_value + mock_blob = mock.MagicMock() + mock_blob.name = "hf_cache/repo_meta-llama/model.safetensors" + mock_client_instance.list_blobs.return_value = [mock_blob] + + mock_load_sharded.return_value = {} + mock_transform.return_value = {"params": {}} + + class MockConfig: + + def __init__(self): + self.model_name = "llama3.1-8b" + self.base_output_directory = "gs://dummy-bucket" + self.scan_layers = True + self.param_scan_axis = 0 + self.hf_access_token = "dummy_token" + + config = MockConfig() + + class DummyAbstractState: + + def __init__(self): + self.params = {} + + abstract_state = DummyAbstractState() + + path = "repo/meta-llama" + dummy_ret_val, loaded_vars = load_dynamic.load_safetensors_dynamic_state(path, abstract_state, config) + + self.assertIsNone(dummy_ret_val) + self.assertEqual(loaded_vars, {"params": {}}) + mock_hf_fs.assert_called_once_with(token="dummy_token") + mock_sync.assert_called_once_with("dynamic_hf_download_complete") + + +class SourceCheckpointLoadingTest(parameterized.TestCase): + """Tests for the `load_state_if_possible` function with safetensors_dynamic layout.""" + + def setUp(self): + super().setUp() + self.mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",)) + self.sharding = NamedSharding(self.mesh, PartitionSpec()) + + self.tmp_dir = epath.Path(self.create_tempdir().full_path) + self.safetensors_ckpt_dir = self.tmp_dir / "hf_safetensors" + self.safetensors_ckpt_dir.mkdir(parents=True, exist_ok=True) + self.safetensors_ckpt_path = self.safetensors_ckpt_dir / "model.safetensors" + + def test_load_safetensors_dynamic_single_key(self): + if os.getenv("JAX_PLATFORMS") == "proxy": + self.skipTest("SafetensorsLayout is not supported on Pathways backend.") + # Save a single key (embedding weight) to a safetensors file + dummy_weight = np.arange(1024, dtype=np.float32).reshape(256, 4) + safetensors.numpy.save_file({"model.embed_tokens.weight": dummy_weight}, str(self.safetensors_ckpt_path)) + + # Setup mock config + class MockConfig: + + def __init__(self): + self.model_name = "llama3.1-8b" + self.base_output_directory = "gs://dummy-bucket" + self.scan_layers = True + self.param_scan_axis = 0 + self.hf_access_token = None + + config = MockConfig() + + # Target abstract state matching llama2 embeddings shape + target_state = { + "params": { + "token_embedder": { + "embedding": jax.ShapeDtypeStruct(shape=(256, 4), dtype=np.float32, sharding=self.sharding) + } + } + } + abstract_state = train_state.TrainState.create( + apply_fn=lambda x: x, params=target_state["params"], tx=optax.identity() + ) + + # Load using checkpointing framework dynamically + loaded_data, loaded_vars = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path=str(self.safetensors_ckpt_dir), + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=1, + abstract_unboxed_pre_state=abstract_state, + enable_orbax_v1=True, + source_checkpoint_layout="safetensors_dynamic", + maxtext_config=config, + ) + + self.assertIsNone(loaded_data) + self.assertIsNotNone(loaded_vars) + + # Assert values match + loaded_weight = loaded_vars["params"]["token_embedder"]["embedding"] + np.testing.assert_allclose(loaded_weight, dummy_weight) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/configs_value_test.py b/tests/unit/configs_value_test.py index d28001f7c8..0450e6671b 100644 --- a/tests/unit/configs_value_test.py +++ b/tests/unit/configs_value_test.py @@ -15,21 +15,19 @@ """Tests for the new pydantic-based configuration system.""" import os -import unittest -from unittest.mock import patch, MagicMock - -import pydantic +import unittest.mock +from absl.testing import absltest from maxtext.configs import pyconfig -from maxtext.configs.pyconfig import initialize_pydantic from maxtext.configs import types -from maxtext.utils.globals import MAXTEXT_REPO_ROOT +from maxtext.utils import globals as maxtext_globals +import pydantic -# Path to the base.yml config. This assumes that `pytest` is run from the project root. -_BASE_CONFIG_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml") +# Path to the base.yml config. +_BASE_CONFIG_PATH = os.path.join(maxtext_globals.MAXTEXT_CONFIGS_DIR, "base.yml") -class ConfigTest(unittest.TestCase): +class ConfigTest(absltest.TestCase): """Tests for the new pydantic-based configuration system.""" def test_basic_config_loading(self): @@ -77,8 +75,8 @@ def test_derived_values(self): "gradient_accumulation_steps=2", ] # Mock jax.devices() to be deterministic - mock_devices = [MagicMock(slice_index=0) for _ in range(8)] - with patch("jax.devices", return_value=mock_devices): + mock_devices = [unittest.mock.MagicMock(slice_index=0) for _ in range(8)] + with unittest.mock.patch("jax.devices", return_value=mock_devices): config = pyconfig.initialize(argv) # global_parameter_scale=4 -> emb_scale=1, num_head_scale=1, mlp_dim_scale=1, layer_scale=0 @@ -98,14 +96,14 @@ def test_validation_error(self): with self.assertRaises(pydantic.ValidationError): pyconfig.initialize(argv) - @patch.dict(os.environ, {pyconfig.yaml_key_to_env_key("steps"): "123"}) + @unittest.mock.patch.dict(os.environ, {pyconfig.yaml_key_to_env_key("steps"): "123"}) def test_env_override(self): """Tests that environment variables override YAML values.""" argv = ["", _BASE_CONFIG_PATH, "run_name=test"] config = pyconfig.initialize(argv) self.assertEqual(config.steps, 123) - @patch.dict(os.environ, {pyconfig.yaml_key_to_env_key("steps"): "123"}) + @unittest.mock.patch.dict(os.environ, {pyconfig.yaml_key_to_env_key("steps"): "123"}) def test_cli_overrides_env_is_disallowed(self): """Tests that CLI arguments overriding environment variables fails.""" argv = ["", _BASE_CONFIG_PATH, "run_name=test", "steps=456"] @@ -129,7 +127,7 @@ def test_llama3_tokenizer_correction(self): def test_initialize_pydantic_bad_keys(self): """Test that `pydantic.ValidationError` is raised on keys not in MaxTextConfig""" with self.assertRaises(ValueError): - initialize_pydantic( + pyconfig.initialize_pydantic( [ "", _BASE_CONFIG_PATH, @@ -172,6 +170,18 @@ def test_gmm_v2_requires_tokamax_gmm(self): with self.assertRaises(pydantic.ValidationError): pyconfig.initialize(argv) + def test_safetensors_dynamic_disallows_single_controller(self): + """Tests that source_checkpoint_layout=safetensors_dynamic disallows enable_single_controller=True.""" + argv = [ + "", + _BASE_CONFIG_PATH, + "run_name=test", + "source_checkpoint_layout=safetensors_dynamic", + "enable_single_controller=true", + ] + with self.assertRaises(pydantic.ValidationError): + pyconfig.initialize(argv) + if __name__ == "__main__": - unittest.main() + absltest.main()