feat(engine): Engine training API#2556

Draft
HuiyingLi wants to merge 12 commits into main from huiyingl/feat/tinker-datum
@HuiyingLi

Summary

A tinker-like, in-process training API for Automodel — a small Engine over a
typed Datum/ModelOutput contract that lets post-training/RL frameworks drive
forward / forward_backward / optimizer steps without hand-wiring the microbatch
lifecycle, MoE aux-loss scaling, CP/THD batch shaping, or gradient clipping.

Purely additive. New files + a couple of __init__ exports; recipes/ and
existing components/ are untouched. Nothing imports the Engine yet — it is
opt-in for recipes and external consumers.

What's added

  • components/datasets/datum.py: Datum (per-example input: input_ids +
    loss_inputs), PackedBatch (pass-through door for already-packing
    frameworks), and collate_datums, which delegates to the existing canonical
    collaters (default_collater / packed_sequence_thd_collater) rather than
    forking the packing logic.
  • components/training/model_output.py: ModelOutput + pure per-token
    extraction helpers (selected_token_logprobs / compute_entropy /
    split_per_datum).
  • nemo_automodel/engine.py: Engine, with config- or injection-construction
    (reuses the recipe's build_model + optimizer build), forward /
    forward_backward (Datum / dict / PackedBatch inputs, multi-microbatch grad
    accumulation), optimizer_step / optim_step(lr=), save_state /
    load_state, export_weights, to(), train/eval modes, and generic
    pre-construction hooks.
  • nemo_automodel/loss_fns.py: LossFn contract +
    cross_entropy/importance_sampling/ppo; the Engine owns RL normalization
    (global token denom, matching MaskedCrossEntropy's non-PP formula).

Design

  • The Engine is mechanism, not policy: it has no actor/critic/reference
    knowledge (output extraction is duck-typed on what the model emits). Roles are
    assembled by the caller via hooks + whether an optimizer is provided —
    mirroring how Megatron-core/verl keep roles out of the core engine.
  • Components stay independent (import-linter contract kept): Datum/collation
    live in datasets, output extraction in training; the Engine sits at
    top-level so it may orchestrate both.

Validation

  • 66 CPU unit tests (contract, loss math, engine forward/backward, roles,
    checkpoint round-trip, verl-shaped integration).
  • GPU: real GPT2 forward_backward + multi-microbatch grad accumulation.
  • End-to-end 8-GPU GRPO via a verl adapter built on this Engine: rollout →
    compute_old_log_prob → GRPO advantage → actor update → vLLM weight sync, two
    training steps with full metrics.

Deferred (documented; each raises a clear error)

Pipeline parallelism in forward/forward_backward; HF-named export(path=);
Engine.from_model(name); the CPU-side service/client layer.

Note: branch is based on an earlier main; the diff above is the merge-base
(additive) view. Happy to rebase before un-drafting.

🤖 Generated with Claude Code

HuiyingLi and others added 10 commits June 14, 2026 08:07
Phase 1 of the tinker-like training API: a typed boundary between
post-training/RL algorithm code and the training-loop internals.

Input half (components.datasets.datum) — a data concern, placed there so it can
reuse the canonical collaters instead of forking padding/packing:
- Datum: single-example input contract (input_ids + loss_inputs dict), with
  Datum.to_features() emitting the per-example dict the dataset pipeline
  produces (labels masked with -100 where weights == 0).
- collate_datums: delegates token fields to the existing default_collater
  (padded) / packed_sequence_thd_collater (THD) so the canonical schema
  (attention_mask / qkv_format / seq_lens) is reused unchanged; additionally
  batches float per-token side-inputs (weights/logprobs/advantages) that those
  collaters cannot carry.
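
The label-masking rule described for Datum.to_features() can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: the class below is a hypothetical stand-in, and the loss_inputs key "target_tokens" and the list-based field types are assumptions (the commit message only names input_ids, loss_inputs, and weights).

```python
from dataclasses import dataclass, field

IGNORE_INDEX = -100  # label value that the loss skips


@dataclass
class Datum:
    """Hypothetical stand-in: token ids plus a dict of per-token loss inputs."""

    input_ids: list
    loss_inputs: dict = field(default_factory=dict)

    def to_features(self) -> dict:
        # "target_tokens" is an assumed key name, used here for illustration only.
        targets = self.loss_inputs["target_tokens"]
        weights = self.loss_inputs.get("weights", [1.0] * len(targets))
        if not (len(self.input_ids) == len(targets) == len(weights)):
            raise ValueError("input_ids, target_tokens and weights must have equal length")
        # Positions with zero weight do not contribute to the loss, so mask them.
        labels = [t if w != 0 else IGNORE_INDEX for t, w in zip(targets, weights)]
        return {"input_ids": list(self.input_ids), "labels": labels}


datum = Datum(
    input_ids=[11, 12, 13, 14],
    loss_inputs={"target_tokens": [12, 13, 14, 15], "weights": [0.0, 0.0, 1.0, 1.0]},
)
features = datum.to_features()
```

Here the first two positions (for example, prompt tokens) carry zero weight, so only the last two labels survive masking.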

Output half (components.training.model_output) — a forward concern (touches
logits), kept out of datasets:
- ModelOutput: per-datum output contract (loss, logprobs, entropy, values,
  metrics), list fields aligned to input order.
- selected_token_logprobs / compute_entropy / split_per_datum: pure per-token
  extraction helpers RL frameworks otherwise reimplement.
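
To make the three helpers concrete, here is a pure-Python sketch of the math they are described as performing. The function names mirror the helpers listed above, but the signatures and the list-of-floats types are assumptions; the real helpers presumably operate on tensors.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def selected_token_logprobs(logits, targets):
    """Log-probability of each target token under the logits at its position."""
    return [log_softmax(row)[t] for row, t in zip(logits, targets)]


def compute_entropy(logits):
    """Per-position entropy of the predicted distribution, in nats."""
    entropies = []
    for row in logits:
        logp = log_softmax(row)
        entropies.append(-sum(math.exp(v) * v for v in logp))
    return entropies


def split_per_datum(flat, seq_lens):
    """Cut a flat per-token list back into one list per datum."""
    if sum(seq_lens) != len(flat):
        raise ValueError("seq_lens must sum to the number of tokens")
    out, start = [], 0
    for n in seq_lens:
        out.append(flat[start:start + n])
        start += n
    return out


# Three token positions over a two-token vocabulary, packed from two datums.
logits = [[0.0, 0.0], [0.0, 0.0], [2.0, 0.0]]
targets = [0, 1, 0]
logprobs = selected_token_logprobs(logits, targets)
entropy = compute_entropy(logits)
per_datum = split_per_datum(logprobs, seq_lens=[2, 1])
```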

No forked collation; import-linter "components must not import each other" kept.
Additive: no training-loop or recipe changes. 19 unit tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…hed)

A small in-process Engine — the home for the tinker-like training API on top of
the Datum/ModelOutput contract. PURELY ADDITIVE: new top-level module
nemo_automodel/engine.py; recipes/llm/train_ft.py and all existing components
are unchanged. Nothing imports the Engine yet — it's opt-in for recipes/RL
frameworks.

Surface (non-PP path):
- Engine.Config + two construction paths: build from config (reuses the LLM
  recipe's build_model + optimizer .build/shard — identical objects), or inject
  already-built model_parts/optimizers (tests, custom builders).
- forward_backward(batch, loss_fn): accepts a dict, list of microbatch dicts, or
  list[Datum]; returns ModelOutput for Datum input, else {"loss","metrics"}.
  Owns the microbatch lifecycle (prepare_* hooks, MoE aux-loss scaling,
  CP/THD via make_cp_batch_and_ctx, final-backward) so callers don't.
- forward(datums): forward-only, returns per-datum logprobs/entropy.
- optimizer_step()/optim_step(lr=)/zero_grad()/lr_scheduler_step() (list-aware).
- train_mode()/eval_mode(), to(device) offload, export_weights() streaming.

Reuses existing primitives throughout (no forked training math). Pipeline
parallelism in forward/forward_backward is a documented follow-up and raises
clearly. 14 CPU unit tests via the injection path.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…tion

Phase 3: align the loss path to the tinker doc's LossFn contract and move
loss normalization into the Engine (also what verl needs).

- nemo_automodel/loss_fns.py: LossFn = Callable[[ModelOutput, Sequence[Datum]],
  Sequence[Tensor]] + built-ins cross_entropy / importance_sampling / ppo and a
  BUILTIN_LOSSES registry. Losses return the un-reduced per-token signal; they
  carry no distributed/scaling knowledge.
- Engine.forward_backward Datum mode: loss_fn is a built-in name (default
  "cross_entropy") or a LossFn callable. The Engine owns the doc's
  normalization — applies loss_inputs["weights"], divides by the GLOBAL token
  (or sample) count all-reduced across data ranks, runs the microbatch
  lifecycle + backward, and returns a ModelOutput. Matches the recipe's non-PP
  loss formula (weighted-sum / global-token-count, cf. MaskedCrossEntropy), so
  DP scaling stays identical to the proven path.
- Dict mode (legacy SFT via calculate_loss) is unchanged.
- _build_model_output shared by forward() and the loss path; supports
  list[Datum] (one microbatch) and list[list[Datum]] (grad accumulation).
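
The weighted-sum / global-token-count normalization can be illustrated with two simulated data-parallel ranks. This is a sketch under stated assumptions: the all-reduce is replaced by a plain Python sum, "token count" is taken to mean the number of tokens with non-zero weight, and all helper names are hypothetical.

```python
def local_weighted_sum(per_token_loss, weights):
    """Sum of weight * loss over this rank's tokens."""
    return sum(l * w for l, w in zip(per_token_loss, weights))


def count_loss_tokens(weights):
    """Tokens that contribute to the loss (assumed: non-zero weight)."""
    return sum(1 for w in weights if w != 0)


def all_reduce_sum(values_per_rank):
    """Stand-in for a SUM all-reduce across data-parallel ranks."""
    return sum(values_per_rank)


# Per-token losses and weights held by two simulated ranks.
rank_losses = [[2.0, 4.0, 6.0], [1.0]]
rank_weights = [[1.0, 1.0, 0.0], [1.0]]

# Every rank divides by the same global denominator.
global_tokens = all_reduce_sum([count_loss_tokens(w) for w in rank_weights])
rank_loss = [
    local_weighted_sum(l, w) / global_tokens
    for l, w in zip(rank_losses, rank_weights)
]
global_loss = all_reduce_sum(rank_loss)

# For contrast: averaging per-rank means weights each rank equally,
# regardless of how many loss tokens it holds.
per_rank_mean = [
    local_weighted_sum(l, w) / count_loss_tokens(w)
    for l, w in zip(rank_losses, rank_weights)
]
mean_of_means = sum(per_rank_mean) / len(per_rank_mean)
```

With a global denominator, each token contributes equally no matter which rank holds it; the mean-of-means figure differs whenever ranks hold different numbers of loss tokens.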

Additive; recipe/components untouched; import-linter kept. 19 new tests
(pure loss math + engine grad-flow + normalization).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…nd-to-end

Adds the second input door so already-packing frameworks (verl) don't pay an
un-pack/re-pack tax, and proves both doors against verl's data patterns.

- PackedBatch (components/datasets/datum.py): model_inputs (model-ready) +
  seq_lens + optional flat targets. The caller owns packing/CP and
  normalization; the Engine owns forward + per-datum ModelOutput extraction
  (split_per_datum over seq_lens) + microbatch lifecycle + backward.
- Engine.forward_backward dispatches PackedBatch -> _forward_backward_packed,
  which accepts a scalar-returning loss closure loss_fn(ModelOutput) ->
  Tensor | (Tensor, metrics) (verl's pattern), backward()s it, and surfaces the
  closure's metrics.
- End-to-end CPU proof (test_engine_verl_integration.py):
  * Datum door: datums_from_verl (the adapter that lives in verl's
    AutomodelEngine) un-packs a jagged micro-batch -> list[Datum] -> engine
    forward_backward(loss_fn="ppo") with engine-owned normalization.
  * Pass-through door: PackedBatch + a verl-style closure
    loss(model_output, data, dp_group) -> (loss, metrics), caller-normalized.
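
A scalar-returning closure that may hand back either a loss or a (loss, metrics) pair needs a small normalization step on the Engine side. The sketch below shows one way to do that; run_loss_closure is a hypothetical helper, floats stand in for tensors, and the None branch reflects the forward-only case that a later commit in this PR describes.

```python
def run_loss_closure(loss_fn, model_output):
    """Normalize a closure's return value to a (loss, metrics) pair.

    Accepts closures returning a loss, a (loss, metrics) tuple, or None.
    """
    if loss_fn is None:
        return None, {}
    result = loss_fn(model_output)
    if result is None:
        # Forward-only use: the closure only captured outputs.
        return None, {}
    if isinstance(result, tuple):
        loss, metrics = result
        return loss, dict(metrics)
    return result, {}


# A dict stands in for ModelOutput; values are per-datum logprobs.
output = {"logprobs": [[-0.1, -0.2], [-0.3]]}


def verl_style_loss(out):
    """Caller-normalized loss: mean negative logprob over all tokens."""
    flat = [v for seq in out["logprobs"] for v in seq]
    loss = -sum(flat) / len(flat)
    return loss, {"pg_loss": loss}


loss, metrics = run_loss_closure(verl_style_loss, output)
scalar_only, no_metrics = run_loss_closure(lambda out: 1.5, output)
inference_loss, inference_metrics = run_loss_closure(lambda out: None, output)
```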

Additive; recipe/components untouched; import-linter kept. 3 integration tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…isable_adapters

Phase 4: express actor / critic / reference as Engine roles (the doc's
output_head/trainable/lora + disable_adapters), additively.

- Config.trainable (default True): False drops the optimizer/scheduler, freezes
  params, and sets eval() — the reference / reward / frozen-critic case.
- Config.output_head ("lm" | "value"): "value" makes forward()/the loss path
  emit per-token ModelOutput.values from a model that exposes `.values`
  (critic); "lm" keeps logprobs/entropy.
- Config.hooks: list[model -> model | None] applied to each model part after
  construction (freeze modules, install a value head, patch). NOTE: these run
  post-construction; true pre-DDP/FSDP-wrap hooks need a build_model extension
  (documented follow-up).
- Engine.forward(disable_adapters=True) + Engine.disable_adapter() context:
  base-model (LoRA-off) forward for reference logprobs without a second engine;
  delegates to the model's own disable_adapter ctx, no-op when absent.
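
The hook contract (model -> model | None) and the "no-op when absent" adapter context can both be sketched without any framework. Everything below is illustrative: apply_hooks, disable_adapter_ctx, and the toy model classes are hypothetical names, not the Engine's actual code.

```python
from contextlib import contextmanager, nullcontext


def apply_hooks(model_parts, hooks):
    """Run every hook on every model part.

    A hook may mutate the model in place (returning None) or return a replacement.
    """
    result = []
    for part in model_parts:
        for hook in hooks:
            replaced = hook(part)
            if replaced is not None:
                part = replaced
        result.append(part)
    return result


def disable_adapter_ctx(model):
    """Use the model's own disable_adapter context if it has one, else do nothing."""
    factory = getattr(model, "disable_adapter", None)
    return factory() if callable(factory) else nullcontext()


class PlainToy:
    """Toy model with no adapters."""

    def __init__(self):
        self.frozen = False


class LoraToy(PlainToy):
    """Toy model exposing a PEFT-style disable_adapter context manager."""

    def __init__(self):
        super().__init__()
        self.adapters_enabled = True

    @contextmanager
    def disable_adapter(self):
        self.adapters_enabled = False
        try:
            yield
        finally:
            self.adapters_enabled = True


class WithValueHead:
    """Toy wrapper standing in for 'install a value head'."""

    def __init__(self, base):
        self.base = base


def freeze(model):
    model.frozen = True  # in-place hook: returns None


def add_value_head(model):
    return WithValueHead(model)  # replacing hook: returns a new object


parts = apply_hooks([PlainToy()], [freeze, add_value_head])

lora = LoraToy()
with disable_adapter_ctx(lora):
    enabled_inside = lora.adapters_enabled
enabled_after = lora.adapters_enabled

with disable_adapter_ctx(PlainToy()):
    plain_ok = True  # no disable_adapter attribute, so this is a no-op context
```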

_build_model_output now branches on output_head (logits->logprobs vs .values).
Additive; recipe/components untouched; import-linter kept. 10 tests incl. an
actor/critic/ref deliverable test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The Engine should be mechanism, not policy. "actor/critic/reference" are
algorithm-layer roles — Megatron-core has no such concept, and verl encodes
them in its OWN adapter (a value-head model + an engine subclass + forward-only
ref), not as core-engine config. Phase 4's output_head/trainable leaked that
policy into the Engine; this removes it.

- Remove Config.output_head and Config.trainable.
- _build_model_output now DUCK-TYPES on what the model emits (.values -> values,
  .logits -> logprobs/entropy, both -> both) instead of branching on a role
  enum. Custom extraction = subclass the method (the verl pattern).
- Reference role = Engine with no optimizer + a caller-frozen model (no flag).
- Critic = a model that emits .values + a custom value loss. Both assembled by
  the CALLER.
- Keep hooks (generic model surgery — the seam a value head/freeze uses) and
  forward(disable_adapters=) / disable_adapter() (generic PEFT, used by SFT too).

Tests rewritten to assemble actor/critic/ref caller-side with one identical
Engine class. Additive; recipe untouched; import-linter kept.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Closes one of the two verl-integration gaps: training-resume checkpoints.

- CheckpointHandle(path, wait()): wait() is a no-op for sync saves, flushes the
  async checkpointer otherwise.
- Engine.save_state(path, *, user_state=None, async_save=False) -> handle:
  saves model + optimizer (+scheduler) via Automodel's Checkpointer, plus
  per-rank user_state (saved on the rank that provides it). Durable on return
  unless async_save. Collective.
- Engine.load_state(path) -> user_state | None: restores model+optimizer state,
  returns this rank's user_state.
- Config.checkpoint: optional CheckpointingConfig/dict; default is a DCP sharded
  resume config (torch_save, no HF consolidation) — checkpoint = resume,
  distinct from export() (HF weights).
- Reuses the recipe's exact Checkpointer API + path convention
  (save_model->path/model, save_optimizer->path/optim); adds _tp_rank/_pp_rank
  for the Checkpointer build.
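
The handle semantics (durable on return for sync saves, durable after wait() for async saves) can be sketched as follows. This is only an illustration of the contract: a thread pool stands in for Automodel's async checkpointer, whose API is not shown in this PR, and write_marker, the pending field, and this save_state signature are hypothetical.

```python
import os
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointHandle:
    path: str
    pending: Optional[Future] = None  # set only for async saves

    def wait(self) -> None:
        """No-op for sync saves; blocks until the background write finishes otherwise."""
        if self.pending is not None:
            self.pending.result()  # re-raises any error from the background write
            self.pending = None


def save_state(path, write_fn, *, async_save=False, executor=None):
    """Write a checkpoint and return a handle for it."""
    if async_save:
        if executor is None:
            raise ValueError("async_save=True needs an executor")
        return CheckpointHandle(path, executor.submit(write_fn, path))
    write_fn(path)
    return CheckpointHandle(path)


def write_marker(path):
    """Toy writer: creates path/model.bin in place of real model and optimizer state."""
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "model.bin"), "w") as f:
        f.write("weights")


root = tempfile.mkdtemp()

sync_handle = save_state(os.path.join(root, "step_1"), write_marker)
sync_done_on_return = os.path.exists(os.path.join(sync_handle.path, "model.bin"))

with ThreadPoolExecutor(max_workers=1) as pool:
    async_handle = save_state(
        os.path.join(root, "step_2"), write_marker, async_save=True, executor=pool
    )
    async_handle.wait()
    async_done_after_wait = os.path.exists(os.path.join(async_handle.path, "model.bin"))
```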

Tests (single-process gloo PG on CPU): full model+optimizer DCP resume
round-trip, user_state round-trip, CheckpointHandle, rank helpers.
Additive; recipe/components untouched; import-linter kept.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…t outputs

HuggingFace model outputs subclass OrderedDict, so `getattr(out, "values")`
returns the dict's `.values` method, not None — the duck-typed extraction
mistook it for a value head and crashed (`'builtin_function_or_method' has no
attribute 'dim'`). Only treat `.values`/`.logits` as model outputs when they are
torch.Tensor. Caught by a GPU smoke running a real GPT2 through the Engine.
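
The failure mode is easy to reproduce without PyTorch or transformers. In this sketch, FakeTensor stands in for torch.Tensor and DictOutput is a hypothetical imitation of a dict-subclass model output; extract shows the type-checked variant of the duck-typed lookup.

```python
from collections import OrderedDict


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


class DictOutput(OrderedDict):
    """Imitates a dict-subclass model output that also exposes fields as attributes."""

    def __init__(self, **fields):
        super().__init__(fields)
        self.__dict__.update(fields)


def extract(out, tensor_type=FakeTensor):
    """Treat .values / .logits as model outputs only when they are tensors."""
    values = getattr(out, "values", None)
    logits = getattr(out, "logits", None)
    return {
        "values": values if isinstance(values, tensor_type) else None,
        "logits": logits if isinstance(logits, tensor_type) else None,
    }


# A language-model output: it has logits but no value head.
lm_out = DictOutput(logits=FakeTensor())

# The naive lookup finds the dict's bound .values method, not None.
naive_values = getattr(lm_out, "values", None)

lm_fields = extract(lm_out)

# A critic-style output that really does carry a values tensor.
critic_out = DictOutput(values=FakeTensor())
critic_fields = extract(critic_out)
```

The isinstance check rejects the inherited dict method while still accepting a genuine values field.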

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…cum)

verl hands forward_backward_batch the whole batch and micro-batches internally,
so the single-microbatch pass-through fired the grad-accumulation lifecycle
once per call — wrong for accumulation. Extend the pass-through door to accept a
list of PackedBatch (microbatches) and run the prepare_for_grad_accumulation /
final-backward / after-first-microbatch lifecycle across them, mirroring the
Datum path's list[list[Datum]] support.

- loss_fn may be one callable (applied to every microbatch) or a list with one
  per microbatch (each closing over its own microbatch data — the verl shape).
- Per-datum logprobs/entropy/values are returned concatenated in microbatch
  order; the caller restores original order via its own indices.
- forward-only with no loss_fn returns loss=None (regression fix).
- Metrics from per-mb closures are averaged across microbatches.
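
The bullet points above describe three behaviors: broadcasting a single loss_fn or accepting one per microbatch, concatenating outputs in microbatch order, and averaging metrics. A framework-free sketch is below. All function names are hypothetical, floats stand in for tensors, summing the per-microbatch losses into the reported total is an assumption, and so is averaging each metric over the microbatches that report it.

```python
def per_microbatch_loss_fns(loss_fn, num_microbatches):
    """Broadcast one callable (or None) to all microbatches, or validate a list."""
    if loss_fn is None or callable(loss_fn):
        return [loss_fn] * num_microbatches
    fns = list(loss_fn)
    if len(fns) != num_microbatches:
        raise ValueError("need exactly one loss_fn per microbatch")
    return fns


def average_metrics(per_mb_metrics):
    """Average each metric key over the microbatches that reported it."""
    totals, counts = {}, {}
    for metrics in per_mb_metrics:
        for key, value in metrics.items():
            totals[key] = totals.get(key, 0.0) + float(value)
            counts[key] = counts.get(key, 0) + 1
    return {key: totals[key] / counts[key] for key in totals}


def run_microbatches(microbatches, loss_fn):
    fns = per_microbatch_loss_fns(loss_fn, len(microbatches))
    total_loss, all_metrics, outputs = None, [], []
    for mb, fn in zip(microbatches, fns):
        outputs.extend(mb)  # per-datum outputs, concatenated in microbatch order
        result = fn(mb) if fn is not None else None
        if result is None:
            continue  # forward-only: nothing to accumulate
        loss, metrics = result
        total_loss = loss if total_loss is None else total_loss + loss
        all_metrics.append(metrics)
    return total_loss, average_metrics(all_metrics), outputs


microbatches = [[1.0, 2.0], [3.0]]


def shared_fn(mb):
    return sum(mb), {"mean": sum(mb) / len(mb)}


loss, metrics, outputs = run_microbatches(microbatches, shared_fn)
fwd_loss, fwd_metrics, fwd_outputs = run_microbatches(microbatches, None)
```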

Validated on GPU with a real GPT2: grads accumulate across microbatches before a
single optimizer step. Additive; recipe untouched; import-linter kept.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…rence)

The PackedBatch pass-through assumed a loss_fn always returns a scalar loss.
RL inference (verl compute_old_log_prob / ref logprobs) passes a closure that
only captures per-datum outputs and returns None — guard the loss accumulation
so forward-only inference works without a loss. Validated by an 8-GPU GRPO run
(rollout -> old_log_prob -> advantage -> actor update -> weight sync) driving
the Engine through the verl automodel adapter.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi changed the title from "feat(engine): tinker-like Engine + Datum/ModelOutput training API" to "feat(engine): Engine training API" on Jun 14, 2026
HuiyingLi and others added 2 commits June 15, 2026 07:36
…RL losses

- Move nemo_automodel/engine.py -> components/training/engine.py, next to its
  collaborators (model_output.py, step_scheduler.py, utils.py). Root no longer
  holds substantive logic. Update all importers (4 unit tests + the docstring).
- Remove RL-specific losses (ppo, importance_sampling) from loss_fns.py and
  BUILTIN_LOSSES. The Engine stays role-agnostic: RL objectives read advantages /
  behavior-policy logprobs and are the consumer's concern (verl supplies its own
  ppo_loss via the PackedBatch closure). Keep cross_entropy + the LossFn contract.
- Update unit tests: registry assertion, and exercise RL losses as caller-supplied
  LossFns through both the Datum door and the PackedBatch pass-through.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
No real consumer uses an Automodel-side loss registry: verl, slime and nemo-rl
each bring their own loss (the loss IS the RL algorithm), and verl's megatron/
fsdp backends plug their own loss into the framework's forward-backward the same
way our automodel adapter does. The named registry only duplicated those and
risked drifting from Automodel's real loss library (components/loss/).

- Delete nemo_automodel/loss_fns.py (BUILTIN_LOSSES, ppo, importance_sampling,
  cross_entropy registry). The original engine design doc's contract was already
  `loss_fn: Callable | None` with no registry; this restores it.
- Fold the LossFn type alias into engine.py and keep a single module-level
  cross_entropy as the default Datum-mode loss (used when loss_fn is None).
- Drop the string-name resolution; Datum-mode loss_fn is now a Callable | None.
- Delete test_loss_fns.py (the registry it covered is gone; the default
  cross_entropy path stays covered by test_engine.py).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>