feat(engine): Engine training API #2556
Draft
HuiyingLi wants to merge 12 commits into
Phase 1 of the tinker-like training API: a typed boundary between
post-training/RL algorithm code and the training-loop internals.
Input half (components.datasets.datum), a data concern, placed there so it can
reuse the canonical collaters instead of forking padding/packing:
- Datum: single-example input contract (input_ids + loss_inputs dict), with
Datum.to_features() emitting the per-example dict the dataset pipeline produces
(labels masked with -100 where weights == 0).
- collate_datums: delegates token fields to the existing default_collater
(padded) / packed_sequence_thd_collater (THD) so the canonical schema
(attention_mask / qkv_format / seq_lens) is reused unchanged; additionally
batches float per-token side-inputs (weights/logprobs/advantages) that those
collaters cannot carry.
Output half (components.training.model_output), a forward concern (touches
logits), kept out of datasets:
- ModelOutput: per-datum output contract (loss, logprobs, entropy, values,
metrics), list fields aligned to input order.
- selected_token_logprobs / compute_entropy / split_per_datum: pure per-token
extraction helpers RL frameworks otherwise reimplement.
No forked collation; import-linter "components must not import each other" kept.
Additive: no training-loop or recipe changes. 19 unit tests.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
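The two contracts in this commit can be illustrated with a minimal plain-Python sketch. It shows how labels could be masked with -100 wherever the per-token weight is 0, and how a flat per-token list could be split back per datum by sequence length. The names `mask_labels` and `split_by_lengths` are hypothetical stand-ins, not the PR's `Datum.to_features()` or `split_per_datum`, which presumably operate on tensors; taking the labels from a separate `target_tokens` sequence is also an assumption.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # label value the commit message uses for masked positions


def mask_labels(target_tokens: Sequence[int], weights: Sequence[float]) -> List[int]:
    """Return labels with IGNORE_INDEX wherever the per-token weight is 0.

    Hypothetical helper: the source of the labels (target_tokens) is assumed.
    """
    if len(target_tokens) != len(weights):
        raise ValueError("target_tokens and weights must have the same length")
    return [tok if w != 0 else IGNORE_INDEX for tok, w in zip(target_tokens, weights)]


def split_by_lengths(flat: Sequence[float], seq_lens: Sequence[int]) -> List[List[float]]:
    """Split a flat per-token sequence into one list per datum, in input order."""
    if sum(seq_lens) != len(flat):
        raise ValueError("seq_lens must sum to the number of flat tokens")
    out, start = [], 0
    for n in seq_lens:
        out.append(list(flat[start:start + n]))
        start += n
    return out
```

For example, `mask_labels([5, 6, 7], [0.0, 1.0, 1.0])` masks only the first position, and `split_by_lengths(values, [2, 3])` returns a two-token list followed by a three-token list.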
…hed)
A small in-process Engine — the home for the tinker-like training API on top of
the Datum/ModelOutput contract. PURELY ADDITIVE: new top-level module
nemo_automodel/engine.py; recipes/llm/train_ft.py and all existing components
are unchanged. Nothing imports the Engine yet — it's opt-in for recipes/RL
frameworks.
Surface (non-PP path):
- Engine.Config + two construction paths: build from config (reuses the LLM
recipe's build_model + optimizer .build/shard — identical objects), or inject
already-built model_parts/optimizers (tests, custom builders).
- forward_backward(batch, loss_fn): accepts a dict, list of microbatch dicts, or
list[Datum]; returns ModelOutput for Datum input, else {"loss","metrics"}.
Owns the microbatch lifecycle (prepare_* hooks, MoE aux-loss scaling,
CP/THD via make_cp_batch_and_ctx, final-backward) so callers don't.
- forward(datums): forward-only, returns per-datum logprobs/entropy.
- optimizer_step()/optim_step(lr=)/zero_grad()/lr_scheduler_step() (list-aware).
- train_mode()/eval_mode(), to(device) offload, export_weights() streaming.
Reuses existing primitives throughout (no forked training math). Pipeline
parallelism in forward/forward_backward is a documented follow-up and raises
clearly. 14 CPU unit tests via the injection path.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…tion
Phase 3: align the loss path to the tinker doc's LossFn contract and move loss
normalization into the Engine (also what verl needs).
- nemo_automodel/loss_fns.py: LossFn = Callable[[ModelOutput, Sequence[Datum]],
Sequence[Tensor]] + built-ins cross_entropy / importance_sampling / ppo and a
BUILTIN_LOSSES registry. Losses return the un-reduced per-token signal; they
carry no distributed/scaling knowledge.
- Engine.forward_backward Datum mode: loss_fn is a built-in name (default
"cross_entropy") or a LossFn callable. The Engine owns the doc's normalization:
it applies loss_inputs["weights"], divides by the GLOBAL token (or sample) count
all-reduced across data ranks, runs the microbatch lifecycle + backward, and
returns a ModelOutput. Matches the recipe's non-PP loss formula (weighted-sum /
global-token-count, cf. MaskedCrossEntropy), so DP scaling stays identical to
the proven path.
- Dict mode (legacy SFT via calculate_loss) is unchanged.
- _build_model_output shared by forward() and the loss path; supports
list[Datum] (one microbatch) and list[list[Datum]] (grad accumulation).
Additive; recipe/components untouched; import-linter kept. 19 new tests (pure
loss math + engine grad-flow + normalization).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
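To make the weighted-sum / global-token-count formula concrete, here is a small plain-Python sketch that simulates two data ranks. The all-reduce is replaced by an ordinary Python sum, and counting tokens as "positions with non-zero weight" is an assumption; the function names are hypothetical and are not the Engine's internals.

```python
import math
from typing import List, Sequence


def weighted_sum(per_token_loss: Sequence[float], weights: Sequence[float]) -> float:
    """Sum of per-token losses after applying the per-token weights."""
    return sum(l * w for l, w in zip(per_token_loss, weights))


def token_count(weights: Sequence[float]) -> int:
    """Assumed definition of the token count: positions with non-zero weight."""
    return sum(1 for w in weights if w != 0)


def normalized_rank_losses(
    rank_losses: List[Sequence[float]], rank_weights: List[Sequence[float]]
) -> List[float]:
    """Per-rank loss = local weighted sum / GLOBAL token count.

    The global count is a plain sum here; in a real run it would come from an
    all-reduce across data-parallel ranks.
    """
    global_count = sum(token_count(w) for w in rank_weights)
    return [weighted_sum(l, w) / global_count for l, w in zip(rank_losses, rank_weights)]


# Two simulated ranks with different numbers of unmasked tokens.
losses = [[1.0, 3.0], [2.0, 4.0, 6.0, 8.0]]
weights = [[1.0, 1.0], [1.0, 1.0, 1.0, 0.0]]
per_rank = normalized_rank_losses(losses, weights)

# Summing the per-rank terms reproduces the mean over all unmasked tokens.
global_mean = (1.0 + 3.0 + 2.0 + 4.0 + 6.0) / 5
assert math.isclose(sum(per_rank), global_mean)
```

The design point this illustrates: dividing by the global count keeps every token equally weighted regardless of which rank it landed on, whereas averaging each rank's local mean would over-weight tokens on the rank with fewer unmasked positions.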
…nd-to-end
Adds the second input door so already-packing frameworks (verl) don't pay an
un-pack/re-pack tax, and proves both doors against verl's data patterns.
- PackedBatch (components/datasets/datum.py): model_inputs (model-ready) +
seq_lens + optional flat targets. The caller owns packing/CP and
normalization; the Engine owns forward + per-datum ModelOutput extraction
(split_per_datum over seq_lens) + microbatch lifecycle + backward.
- Engine.forward_backward dispatches PackedBatch -> _forward_backward_packed,
which accepts a scalar-returning loss closure loss_fn(ModelOutput) ->
Tensor | (Tensor, metrics) (verl's pattern), backward()s it, and surfaces the
closure's metrics.
- End-to-end CPU proof (test_engine_verl_integration.py):
* Datum door: datums_from_verl (the adapter that lives in verl's
AutomodelEngine) un-packs a jagged micro-batch -> list[Datum] -> engine
forward_backward(loss_fn="ppo") with engine-owned normalization.
* Pass-through door: PackedBatch + a verl-style closure
loss(model_output, data, dp_group) -> (loss, metrics), caller-normalized.
Additive; recipe/components untouched; import-linter kept. 3 integration tests.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…isable_adapters
Phase 4: express actor / critic / reference as Engine roles (the doc's
output_head/trainable/lora + disable_adapters), additively.
- Config.trainable (default True): False drops the optimizer/scheduler, freezes
params, and sets eval() — the reference / reward / frozen-critic case.
- Config.output_head ("lm" | "value"): "value" makes forward()/the loss path
emit per-token ModelOutput.values from a model that exposes `.values`
(critic); "lm" keeps logprobs/entropy.
- Config.hooks: list[model -> model | None] applied to each model part after
construction (freeze modules, install a value head, patch). NOTE: these run
post-construction; true pre-DDP/FSDP-wrap hooks need a build_model extension
(documented follow-up).
- Engine.forward(disable_adapters=True) + Engine.disable_adapter() context:
base-model (LoRA-off) forward for reference logprobs without a second engine;
delegates to the model's own disable_adapter ctx, no-op when absent.
_build_model_output now branches on output_head (logits->logprobs vs .values).
Additive; recipe/components untouched; import-linter kept. 10 tests incl. an
actor/critic/ref deliverable test.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The Engine should be mechanism, not policy. "actor/critic/reference" are
algorithm-layer roles: Megatron-core has no such concept, and verl encodes them
in its OWN adapter (a value-head model + an engine subclass + forward-only ref),
not as core-engine config. Phase 4's output_head/trainable leaked that policy
into the Engine; this removes it.
- Remove Config.output_head and Config.trainable.
- _build_model_output now DUCK-TYPES on what the model emits (.values -> values,
.logits -> logprobs/entropy, both -> both) instead of branching on a role enum.
Custom extraction = subclass the method (the verl pattern).
- Reference role = Engine with no optimizer + a caller-frozen model (no flag).
- Critic = a model that emits .values + a custom value loss. Both assembled by
the CALLER.
- Keep hooks (generic model surgery, the seam a value head/freeze uses) and
forward(disable_adapters=) / disable_adapter() (generic PEFT, used by SFT too).
Tests rewritten to assemble actor/critic/ref caller-side with one identical
Engine class. Additive; recipe untouched; import-linter kept.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Closes one of the two verl-integration gaps: training-resume checkpoints.
- CheckpointHandle(path, wait()): wait() is a no-op for sync saves, flushes the
async checkpointer otherwise.
- Engine.save_state(path, *, user_state=None, async_save=False) -> handle: saves
model + optimizer (+scheduler) via Automodel's Checkpointer, plus per-rank
user_state (saved on the rank that provides it). Durable on return unless
async_save. Collective.
- Engine.load_state(path) -> user_state | None: restores model+optimizer state,
returns this rank's user_state.
- Config.checkpoint: optional CheckpointingConfig/dict; default is a DCP sharded
resume config (torch_save, no HF consolidation). Checkpoint = resume, distinct
from export() (HF weights).
- Reuses the recipe's exact Checkpointer API + path convention
(save_model->path/model, save_optimizer->path/optim); adds _tp_rank/_pp_rank for
the Checkpointer build.
Tests (single-process gloo PG on CPU): full model+optimizer DCP resume
round-trip, user_state round-trip, CheckpointHandle, rank helpers.
Additive; recipe/components untouched; import-linter kept.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
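The "no-op for sync saves, flush otherwise" behaviour of the handle can be sketched with the standard library alone. Everything below is a hypothetical stand-in: `CheckpointHandleSketch` and `save_state_sketch` are illustrative names, a thread-pool `Future` substitutes for Automodel's async checkpointer, and `write_fn` substitutes for the actual model/optimizer save.

```python
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class CheckpointHandleSketch:
    """Hypothetical handle: holds the path and, for async saves, a pending Future."""

    path: str
    _pending: Optional[Future] = None

    def wait(self) -> None:
        # Sync saves have nothing pending, so wait() returns immediately.
        if self._pending is not None:
            self._pending.result()  # block until the write finishes (re-raises errors)
            self._pending = None


def save_state_sketch(
    path: str,
    write_fn: Callable[[str], None],
    *,
    async_save: bool = False,
    executor: Optional[Executor] = None,
) -> CheckpointHandleSketch:
    """Write synchronously (durable on return) or hand the write to an executor."""
    if not async_save:
        write_fn(path)
        return CheckpointHandleSketch(path)
    if executor is None:
        raise ValueError("async_save=True requires an executor in this sketch")
    return CheckpointHandleSketch(path, executor.submit(write_fn, path))


written = []
sync_handle = save_state_sketch("ckpt/step_1", written.append)
sync_handle.wait()  # no-op

with ThreadPoolExecutor(max_workers=1) as pool:
    async_handle = save_state_sketch(
        "ckpt/step_2", written.append, async_save=True, executor=pool
    )
    async_handle.wait()  # returns once the background write has completed
```

Callers can treat both cases uniformly: call `wait()` before relying on the checkpoint being on disk, and pay nothing for it when the save was synchronous.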
…t outputs
HuggingFace model outputs subclass OrderedDict, so `getattr(out, "values")`
returns the dict's `.values` method, not None. The duck-typed extraction mistook
it for a value head and crashed (`'builtin_function_or_method' has no attribute
'dim'`). Only treat `.values`/`.logits` as model outputs when they are
torch.Tensor. Caught by a GPU smoke running a real GPT2 through the Engine.
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
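The pitfall is reproducible without PyTorch or transformers. In the sketch below, `FakeTensor` is a stand-in for `torch.Tensor`, `HFStyleOutput` is a toy imitation of an OrderedDict-based output class, and `tensor_field` is a hypothetical helper showing the type-guarded lookup; none of these names come from the PR.

```python
from collections import OrderedDict
from typing import Any, Optional


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


class HFStyleOutput(OrderedDict):
    """Toy output object: a dict subclass whose fields are also attributes."""

    def __init__(self, **fields: Any) -> None:
        super().__init__(fields)
        for name, value in fields.items():
            setattr(self, name, value)


def tensor_field(out: Any, name: str, tensor_type: type = FakeTensor) -> Optional[Any]:
    """Return out.<name> only when it is a tensor; otherwise None."""
    value = getattr(out, name, None)
    return value if isinstance(value, tensor_type) else None


out = HFStyleOutput(logits=FakeTensor())

# Naive duck-typing: "values" is found, but it is dict.values, not a value head.
naive = getattr(out, "values", None)
assert naive is not None and callable(naive)

# Type-guarded lookup: the method is rejected, the real tensor is kept.
assert tensor_field(out, "values") is None
assert tensor_field(out, "logits") is out["logits"]
```

The same guard would apply to any attribute name that collides with a `dict` method (`keys`, `items`, `values`), which is why checking the type is safer than checking for `None`.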
…cum)
verl hands forward_backward_batch the whole batch and micro-batches internally,
so the single-microbatch pass-through fired the grad-accumulation lifecycle once
per call, which is wrong for accumulation. Extend the pass-through door to
accept a list of PackedBatch (microbatches) and run the
prepare_for_grad_accumulation / final-backward / after-first-microbatch
lifecycle across them, mirroring the Datum path's list[list[Datum]] support.
- loss_fn may be one callable (applied to every microbatch) or a list with one
per microbatch (each closing over its own microbatch data, the verl shape).
- Per-datum logprobs/entropy/values are returned concatenated in microbatch
order; the caller restores original order via its own indices.
- forward-only with no loss_fn returns loss=None (regression fix).
- Metrics from per-mb closures are averaged across microbatches.
Validated on GPU with a real GPT2: grads accumulate across microbatches before a
single optimizer step. Additive; recipe untouched; import-linter kept.
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
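Two of the rules above (one callable or one per microbatch; metrics averaged across microbatches) can be sketched in plain Python. The helper names `broadcast_loss_fns` and `average_metrics` are hypothetical, and averaging each key over the microbatches that report it is an assumption about how missing keys are handled.

```python
from typing import Any, Callable, Dict, List, Sequence, Union

LossClosure = Callable[[Any], Any]  # takes a model output, returns a loss (shape assumed)


def broadcast_loss_fns(
    loss_fn: Union[LossClosure, Sequence[LossClosure]], num_microbatches: int
) -> List[LossClosure]:
    """Normalize loss_fn to exactly one closure per microbatch."""
    if callable(loss_fn):
        return [loss_fn] * num_microbatches
    fns = list(loss_fn)
    if len(fns) != num_microbatches:
        raise ValueError(
            f"got {len(fns)} loss functions for {num_microbatches} microbatches"
        )
    return fns


def average_metrics(per_microbatch: Sequence[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric over the microbatches that reported it (assumption)."""
    totals: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for metrics in per_microbatch:
        for key, value in metrics.items():
            totals[key] = totals.get(key, 0.0) + value
            counts[key] = counts.get(key, 0) + 1
    return {key: totals[key] / counts[key] for key in totals}
```

With this shape, `broadcast_loss_fns(fn, 3)` yields the same closure three times, while a list of three closures is passed through unchanged and a list of the wrong length fails early instead of silently skipping a microbatch.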
…rence)
The PackedBatch pass-through assumed a loss_fn always returns a scalar loss. RL
inference (verl compute_old_log_prob / ref logprobs) passes a closure that only
captures per-datum outputs and returns None. Guard the loss accumulation so
forward-only inference works without a loss.
Validated by an 8-GPU GRPO run (rollout -> old_log_prob -> advantage -> actor
update -> weight sync) driving the Engine through the verl automodel adapter.
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
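A minimal sketch of the guard, assuming the closure may return `None`, a bare loss, or a `(loss, metrics)` tuple as the commit messages describe. Plain floats stand in for scalar tensors, and `unpack_closure_result` / `accumulate_losses` are hypothetical names rather than the Engine's own.

```python
from typing import Any, Dict, Iterable, Optional, Tuple


def unpack_closure_result(result: Any) -> Tuple[Optional[Any], Dict[str, Any]]:
    """Normalize a closure's return value to (loss_or_None, metrics)."""
    if result is None:  # forward-only inference: nothing to backpropagate
        return None, {}
    if isinstance(result, tuple):  # (loss, metrics)
        loss, metrics = result
        return loss, dict(metrics or {})
    return result, {}  # bare loss


def accumulate_losses(results: Iterable[Any]) -> Optional[Any]:
    """Sum losses across microbatches, skipping microbatches that returned None."""
    total = None
    for result in results:
        loss, _ = unpack_closure_result(result)
        if loss is None:
            continue
        total = loss if total is None else total + loss
    return total
```

If every microbatch returns `None`, the accumulated loss stays `None`, which matches the forward-only case where the caller only wants the per-datum outputs.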
…RL losses
- Move nemo_automodel/engine.py -> components/training/engine.py, next to its
collaborators (model_output.py, step_scheduler.py, utils.py). Root no longer
holds substantive logic. Update all importers (4 unit tests + the docstring).
- Remove RL-specific losses (ppo, importance_sampling) from loss_fns.py and
BUILTIN_LOSSES. The Engine stays role-agnostic: RL objectives read advantages /
behavior-policy logprobs and are the consumer's concern (verl supplies its own
ppo_loss via the PackedBatch closure). Keep cross_entropy + the LossFn contract.
- Update unit tests: registry assertion, and exercise RL losses as
caller-supplied LossFns through both the Datum door and the PackedBatch
pass-through.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
No real consumer uses an Automodel-side loss registry: verl, slime and nemo-rl
each bring their own loss (the loss IS the RL algorithm), and verl's
megatron/fsdp backends plug their own loss into the framework's forward-backward
the same way our automodel adapter does. The named registry only duplicated
those and risked drifting from Automodel's real loss library (components/loss/).
- Delete nemo_automodel/loss_fns.py (BUILTIN_LOSSES, ppo, importance_sampling,
cross_entropy registry). The original engine design doc's contract was already
`loss_fn: Callable | None` with no registry; this restores it.
- Fold the LossFn type alias into engine.py and keep a single module-level
cross_entropy as the default Datum-mode loss (used when loss_fn is None).
- Drop the string-name resolution; Datum-mode loss_fn is now a Callable | None.
- Delete test_loss_fns.py (the registry it covered is gone; the default
cross_entropy path stays covered by test_engine.py).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Summary
A tinker-like, in-process training API for Automodel: a small `Engine` over a
typed `Datum`/`ModelOutput` contract that lets post-training/RL frameworks drive
forward / forward_backward / optimizer steps without hand-wiring the microbatch
lifecycle, MoE aux-loss scaling, CP/THD batch shaping, or gradient clipping.
Purely additive. New files + a couple of `__init__` exports; `recipes/` and
existing `components/` are untouched. Nothing imports the Engine yet; it is
opt-in for recipes and external consumers.
What's added
- `components/datasets/datum.py`: `Datum` (per-example input: `input_ids` + `loss_inputs`), `PackedBatch` (pass-through door for already-packing frameworks), and `collate_datums` that delegates to the existing canonical collaters (`default_collater` / `packed_sequence_thd_collater`); no forked packing.
- `components/training/model_output.py`: `ModelOutput` + pure per-token extraction helpers (`selected_token_logprobs` / `compute_entropy` / `split_per_datum`).
- `nemo_automodel/engine.py`: `Engine` with config- or injection-construction (reuses the recipe's `build_model` + optimizer build), `forward` / `forward_backward` (Datum / dict / PackedBatch inputs, multi-microbatch grad accumulation), `optimizer_step` / `optim_step(lr=)`, `save_state` / `load_state`, `export_weights`, `to()`, train/eval modes, and generic pre-construction `hooks`.
- `nemo_automodel/loss_fns.py`: `LossFn` contract + `cross_entropy` / `importance_sampling` / `ppo`; the Engine owns RL normalization (global token denom, matching `MaskedCrossEntropy`'s non-PP formula).
Design
knowledge (output extraction is duck-typed on what the model emits). Roles are
assembled by the caller via `hooks` + whether an optimizer is provided,
mirroring how Megatron-core/verl keep roles out of the core engine.
`Datum`/collation live in `datasets`, output extraction in `training`; the
Engine sits at top-level so it may orchestrate both.
Validation
checkpoint round-trip, verl-shaped integration).
`compute_old_log_prob` → GRPO advantage → actor update → vLLM weight sync, two
training steps with full metrics.
Deferred (documented, raise clearly)
Pipeline parallelism in `forward`/`forward_backward`; HF-named `export(path=)`;
`Engine.from_model(name)`; the CPU-side service/client layer.
🤖 Generated with Claude Code