Skip to content

feat: integrate V-JEPA 2.1 ViT-L/16 as image encoder#226

Draft
felixmaximilian wants to merge 8 commits into
mainfrom
feat/vjepa2.1-encoder
Draft

feat: integrate V-JEPA 2.1 ViT-L/16 as image encoder#226
felixmaximilian wants to merge 8 commits into
mainfrom
feat/vjepa2.1-encoder

Conversation

@felixmaximilian

@felixmaximilian felixmaximilian commented May 13, 2026

Copy link
Copy Markdown
Contributor

Integration of an alternative image encoder: V-JEPA2.1
V-JEPA 2.1 is trained on video with a temporal prediction objective, so its representations capture motion and dynamics whereas DINOv3 is image-only and has no notion of how a scene evolves over time.

Summary

  • Adds VjepaBackbone — wraps the V-JEPA 2.1 ViT-L/16 encoder loaded via torch.hub.load('facebookresearch/vjepa2', 'vjepa2_1_vit_large_384',
    pretrained=True), eliminating the sys.path hack from the earlier draft. Weights are auto-downloaded on first use (~4.8 GB) and cached at
    ~/.cache/torch/hub/checkpoints/.
  • New Hydra configs mirror the dinov3 pattern: config/model/yaak/control_transformer/raw_vjepa.yaml (pre-training), policy_finetune_vjepa.yaml +
    finetune_vjepa.yaml (fine-tuning).
  • Pre-training freezes VjepaBackbone by default (mirrors V-JEPA 2 §3.1 — freeze the video encoder, learn a new action-conditioned predictor). An
    unfrozen end-to-end variant is available via trainer/callbacks=pretrain_vjepa_unfrozen.
  • SelectiveAdamW extended with vjepa-specific parameter group detection.
  • Fixed a latent CrossAttentionDecoderHead reshape bug (reshape(b, t, sq, d) → reshape(b, t, sq, -1)) where d was assumed equal to the output
    dim — masked in the dinov3 path where both are 384, but breaks when query dim (384) ≠ target dim (1024).
  • WandbLogger now reads entity/project from Hydra config so personal projects work without env-var overrides.

Architecture

VjepaBackbone (304M params, optionally frozen)
└─ torch.hub vjepa2_1_vit_large_384 → (*B, N_patches=576, 1024)
EpisodeBuilder projections
└─ LayerNorm → ScaleByVectorDim → Linear(1024→384) → (*B, N_patches, 384)
CrossAttentionDecoderHead
└─ predicts future vjepa embeddings (1024-dim targets)
GramAnchoringObjective
└─ loss between predicted (1024) and target (1024) embeddings

Test plan

  • Config loads: just train-debug trainer/callbacks=pretrain_vjepa datamodule.train.batch_size=2 datamodule.val.batch_size=2
    +trainer.fast_dev_run=1 WANDB_MODE=disabled
  • End-to-end (unfrozen): same with trainer/callbacks=pretrain_vjepa_unfrozen
  • Fine-tuning config loads: just train-debug experiment=yaak/control_transformer/finetune_vjepa ...
  • Verify VjepaBackbone params have requires_grad=False when pretrain_vjepa callbacks active

https://wandb.ai/yaak/rmind_private/runs/4jxoqeh5?nw=nwusermax_yaak

felixmaximilian and others added 8 commits May 12, 2026 23:14
Adds VjepaBackbone (nn.Module), a thin wrapper around the vjepa2.1 ViT-L/16
VisionTransformer that exposes the same (*B,C,H,W) → (*B,N,embed_dim) interface
as TimmBackbone. Loads the ema_encoder key from the on-disk checkpoint, cleaning
DDP prefixes. Extends SelectiveAdamW to handle img_mod_embed and video_mod_embed
parameter name suffixes emitted by the VisionTransformer. Adds "app" to
deptry's DEP001 ignore list since vjepa2 is a local sys.path dependency.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds five sibling config files mirroring the dinov3 baseline:
- raw_vjepa.yaml: model config swapping TimmBackbone for VjepaBackbone
  (image_embedding_dim 384→1024 via projections.image)
- pretrain_vjepa.yaml: experiment config referencing raw_vjepa + pretrain_vjepa
  callbacks; sets image_embedding_dim=1024, encoder_embedding_dim=384
- pretrain_vjepa.yaml (callbacks): same as pretrain.yaml but omits the
  TimmBackbone freeze entry — VjepaBackbone is fully trainable during pre-training
- policy_finetune_vjepa.yaml: fine-tuning model config loading from a local
  checkpoint path (vs. W&B artifact in policy_finetune.yaml); jq mutation
  and PolicyObjective config identical since encoder_embedding_dim is unchanged
- finetune_vjepa.yaml: experiment config referencing policy_finetune_vjepa;
  uses standard finetune.yaml callbacks (path-based freeze already covers
  VjepaBackbone inside episode_builder)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replaces the hardcoded sys.path.insert + manual checkpoint loading with
torch.hub.load("facebookresearch/vjepa2", "vjepa2_1_vit_large_384",
pretrained=True), which auto-downloads and caches weights on first use.
Removes _VJEPA2_ROOT, _ensure_vjepa2_on_path, _clean_keys, and the
checkpoint_path constructor arg. Updates raw_vjepa.yaml accordingly.
- VjepaBackbone: drop img_size arg (hub fn hardcodes 384, kwarg caused conflict)
- raw_vjepa.yaml: patch_pos_embed embedding_dim image→encoder (mask tokens
  live in 384-dim rmind space, not 1024-dim vjepa space)
- raw_vjepa.yaml: CrossAttentionDecoder dim_model + in_features image→encoder
  (query/context are 384-dim; out_features stays image_embedding_dim to
  predict 1024-dim vjepa targets)
- decoder.py: CrossAttentionDecoderHead reshape output.reshape(..., d) →
  reshape(..., -1) so output dim can differ from query dim (was latently
  broken in dinov3 config where both dims were equal)
Per V-JEPA 2 paper §3.1: the encoder is frozen; only the transformer and
prediction heads are trained. Adds ModuleFreezer for VjepaBackbone to
pretrain_vjepa callbacks (304M → non-trainable, 21M trainable).

For end-to-end training use trainer/callbacks=pretrain_vjepa_unfrozen.
… recipe

- WandbLogger now reads entity/project from Hydra config (${wandb.entity},
  ${wandb.project}) so personal projects work without env-var overrides
- Removed check-git dependency from `just train` (was blocking runs on
  dirty trees during active dev)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove agent-specific .gitignore entries (agents/, scripts/, PROMPT.md,
  plan-*.md, CLAUDE.md, rmind-vjepa2.1-agent-kit.zip) — these were added
  during automated planning and shouldn't be in version control
- Keep .claude/ and vjepa2_1_vitl_dist_vitG_384.pt ignores (still valid)
- Track CLAUDE.md — project-level Claude Code instructions for this repo
- Track config/_templates/dataset/yaak/train_debug.yaml and
  config/datamodule/yaak/train_debug.yaml — 3-drive debug split used in
  smoke-test runs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@felixmaximilian felixmaximilian marked this pull request as draft May 22, 2026 18:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant