Skip to content

feat: AR rollout loss for ForwardDynamicsPredictionObjective#228

Draft
felixmaximilian wants to merge 1 commit into
feat/vjepa2.1-encoderfrom
feat/ar-rollout-loss
Draft

feat: AR rollout loss for ForwardDynamicsPredictionObjective#228
felixmaximilian wants to merge 1 commit into
feat/vjepa2.1-encoderfrom
feat/ar-rollout-loss

Conversation

@felixmaximilian

Copy link
Copy Markdown
Contributor

Summary

  • Adds a 1-step autoregressive rollout loss to ForwardDynamicsPredictionObjective, mirroring the TF+AR training objective from V-JEPA 2/2.1.
  • feedback_projection: Linear(image_embedding_dim → encoder_embedding_dim) projects each predicted patch independently (preserving spatial structure), concatenated with action_summary as one extra context token before being fed back to the foresight CrossAttentionDecoderHead.
  • TF and AR losses logged separately to wandb (foresight/cam_front_left/tf, .../ar) with no changes to control_transformer.py.
  • Fully configurable via YAML: ar_steps: 0 (default) or feedback_projection: null disables the loss entirely — no structural change to the loss dict.
  • Adds notebooks/ar_sanity_checks.ipynb with rollout error curve (check 1) and PCA feature distribution overlap (check 3).

Test plan

  • Verify config loads with ar_steps: 1 and ar_steps: 0
  • Confirm wandb logs show separate tf/ar metrics when enabled
  • Run notebook against a trained checkpoint to inspect rollout error curve and feature distribution
  • Confirm loss sums correctly through TensorDict.sum(reduce=True) in both modes

🤖 Generated with Claude Code

- notebooks/ar_sanity_checks.ipynb: rollout error curve (check 1) and
  PCA feature distribution overlap (check 3), plus cosine similarity.
  Fully type-annotated; cast objective and annotate metrics to satisfy ty.
- forward_dynamics.py: add type guards at top of _ar_losses so ty can
  narrow Module|None to Module.
- .typos.toml: whitelist arange (numpy array-range) and nd abbreviation,
  both incorrectly flagged by the default typos dictionary.
- pyproject.toml: add scikit-learn to train optional deps and DEP002
  ignore list; map sklearn module name for deptry; uv.lock updated.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant