feat: integrate V-JEPA 2.1 ViT-L/16 as image encoder#226
Draft
felixmaximilian wants to merge 8 commits into
Draft
feat: integrate V-JEPA 2.1 ViT-L/16 as image encoder#226felixmaximilian wants to merge 8 commits into
felixmaximilian wants to merge 8 commits into
Conversation
Adds VjepaBackbone (nn.Module), a thin wrapper around the vjepa2.1 ViT-L/16 VisionTransformer that exposes the same (*B,C,H,W) → (*B,N,embed_dim) interface as TimmBackbone. Loads the ema_encoder key from the on-disk checkpoint, cleaning DDP prefixes. Extends SelectiveAdamW to handle img_mod_embed and video_mod_embed parameter name suffixes emitted by the VisionTransformer. Adds "app" to deptry's DEP001 ignore list since vjepa2 is a local sys.path dependency. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds five sibling config files mirroring the dinov3 baseline: - raw_vjepa.yaml: model config swapping TimmBackbone for VjepaBackbone (image_embedding_dim 384→1024 via projections.image) - pretrain_vjepa.yaml: experiment config referencing raw_vjepa + pretrain_vjepa callbacks; sets image_embedding_dim=1024, encoder_embedding_dim=384 - pretrain_vjepa.yaml (callbacks): same as pretrain.yaml but omits the TimmBackbone freeze entry — VjepaBackbone is fully trainable during pre-training - policy_finetune_vjepa.yaml: fine-tuning model config loading from a local checkpoint path (vs. W&B artifact in policy_finetune.yaml); jq mutation and PolicyObjective config identical since encoder_embedding_dim is unchanged - finetune_vjepa.yaml: experiment config referencing policy_finetune_vjepa; uses standard finetune.yaml callbacks (path-based freeze already covers VjepaBackbone inside episode_builder) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replaces the hardcoded sys.path.insert + manual checkpoint loading with
torch.hub.load("facebookresearch/vjepa2", "vjepa2_1_vit_large_384",
pretrained=True), which auto-downloads and caches weights on first use.
Removes _VJEPA2_ROOT, _ensure_vjepa2_on_path, _clean_keys, and the
checkpoint_path constructor arg. Updates raw_vjepa.yaml accordingly.
- VjepaBackbone: drop img_size arg (hub fn hardcodes 384, kwarg caused conflict) - raw_vjepa.yaml: patch_pos_embed embedding_dim image→encoder (mask tokens live in 384-dim rmind space, not 1024-dim vjepa space) - raw_vjepa.yaml: CrossAttentionDecoder dim_model + in_features image→encoder (query/context are 384-dim; out_features stays image_embedding_dim to predict 1024-dim vjepa targets) - decoder.py: CrossAttentionDecoderHead reshape output.reshape(..., d) → reshape(..., -1) so output dim can differ from query dim (was latently broken in dinov3 config where both dims were equal)
Per V-JEPA 2 paper §3.1: the encoder is frozen; only the transformer and prediction heads are trained. Adds ModuleFreezer for VjepaBackbone to pretrain_vjepa callbacks (304M → non-trainable, 21M trainable). For end-to-end training use trainer/callbacks=pretrain_vjepa_unfrozen.
… recipe
- WandbLogger now reads entity/project from Hydra config (${wandb.entity},
${wandb.project}) so personal projects work without env-var overrides
- Removed check-git dependency from `just train` (was blocking runs on
dirty trees during active dev)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove agent-specific .gitignore entries (agents/, scripts/, PROMPT.md, plan-*.md, CLAUDE.md, rmind-vjepa2.1-agent-kit.zip) — these were added during automated planning and shouldn't be in version control - Keep .claude/ and vjepa2_1_vitl_dist_vitG_384.pt ignores (still valid) - Track CLAUDE.md — project-level Claude Code instructions for this repo - Track config/_templates/dataset/yaak/train_debug.yaml and config/datamodule/yaak/train_debug.yaml — 3-drive debug split used in smoke-test runs Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Integration of an alternative image encoder: V-JEPA2.1
V-JEPA 2.1 is trained on video with a temporal prediction objective, so its representations capture motion and dynamics whereas DINOv3 is image-only and has no notion of how a scene evolves over time.
Summary
pretrained=True), eliminating the sys.path hack from the earlier draft. Weights are auto-downloaded on first use (~4.8 GB) and cached at
~/.cache/torch/hub/checkpoints/.
finetune_vjepa.yaml (fine-tuning).
unfrozen end-to-end variant is available via trainer/callbacks=pretrain_vjepa_unfrozen.
dim — masked in the dinov3 path where both are 384, but breaks when query dim (384) ≠ target dim (1024).
Architecture
VjepaBackbone (304M params, optionally frozen)
└─ torch.hub vjepa2_1_vit_large_384 → (*B, N_patches=576, 1024)
EpisodeBuilder projections
└─ LayerNorm → ScaleByVectorDim → Linear(1024→384) → (*B, N_patches, 384)
CrossAttentionDecoderHead
└─ predicts future vjepa embeddings (1024-dim targets)
GramAnchoringObjective
└─ loss between predicted (1024) and target (1024) embeddings
Test plan
+trainer.fast_dev_run=1 WANDB_MODE=disabled
https://wandb.ai/yaak/rmind_private/runs/4jxoqeh5?nw=nwusermax_yaak