From 20323aa345d19137a6917a3061c6545dd60c9aa8 Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Tue, 12 May 2026 23:14:32 +0200 Subject: [PATCH 1/8] ignore agent stuff --- .gitignore | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.gitignore b/.gitignore index 526e0894..6455b396 100644 --- a/.gitignore +++ b/.gitignore @@ -144,6 +144,16 @@ multirun/ # generated config/dataset/* config/logger/* +config/_templates/dataset/yaak/train_debug.yaml +config/datamodule/yaak/train_debug.yaml +.claude/ +CLAUDE.md +PROMPT.md +agents/ +plan-*.md +rmind-vjepa2.1-agent-kit.zip +scripts/ +vjepa2_1_vitl_dist_vitG_384.pt # editor .helix/ From 470051d704e18d0a1d29f8dc1f6a6b2dfbfd965d Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Tue, 12 May 2026 23:31:09 +0200 Subject: [PATCH 2/8] feat: add VjepaBackbone and extend SelectiveAdamW for vjepa2.1 params MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds VjepaBackbone (nn.Module), a thin wrapper around the vjepa2.1 ViT-L/16 VisionTransformer that exposes the same (*B,C,H,W) → (*B,N,embed_dim) interface as TimmBackbone. Loads the ema_encoder key from the on-disk checkpoint, cleaning DDP prefixes. Extends SelectiveAdamW to handle img_mod_embed and video_mod_embed parameter name suffixes emitted by the VisionTransformer. Adds "app" to deptry's DEP001 ignore list since vjepa2 is a local sys.path dependency. Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 2 +- .../components/optimizers/selective_adamw.py | 9 ++- src/rmind/components/vjepa_backbone.py | 74 +++++++++++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 src/rmind/components/vjepa_backbone.py diff --git a/pyproject.toml b/pyproject.toml index 95a9d286..c23f2802 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.deptry.per_rule_ignores] -DEP001 = ["rmind"] +DEP001 = ["app", "rmind"] DEP002 = [ "funcy", "torchmetrics", diff --git a/src/rmind/components/optimizers/selective_adamw.py b/src/rmind/components/optimizers/selective_adamw.py index d8f4b006..8384939a 100644 --- a/src/rmind/components/optimizers/selective_adamw.py +++ b/src/rmind/components/optimizers/selective_adamw.py @@ -40,7 +40,14 @@ def __init__( # https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/activation.py#L1091 case ( - "in_proj_weight" | "cls_token" | "reg_token" | "gamma_1" | "gamma_2" + "in_proj_weight" + | "cls_token" + | "reg_token" + | "gamma_1" + | "gamma_2" + # vjepa2 VisionTransformer modality embeddings (nn.Parameter) + | "img_mod_embed" + | "video_mod_embed" ): pass diff --git a/src/rmind/components/vjepa_backbone.py b/src/rmind/components/vjepa_backbone.py new file mode 100644 index 00000000..80fbcb57 --- /dev/null +++ b/src/rmind/components/vjepa_backbone.py @@ -0,0 +1,74 @@ +import sys +from math import prod +from pathlib import Path +from typing import override + +import torch +from torch import Tensor, nn + +_VJEPA2_ROOT = str(Path("/home/max/Code/vjepa2")) + + +def _ensure_vjepa2_on_path() -> None: + if _VJEPA2_ROOT not in sys.path: + sys.path.insert(0, _VJEPA2_ROOT) + + +def _clean_keys(state_dict: dict) -> dict: + # Strip "module." and "backbone." prefixes written by vjepa2 DDP trainer. + cleaned = {} + for key, val in state_dict.items(): + key = key.replace("module.", "").replace("backbone.", "") # noqa: PLW2901 + cleaned[key] = val + return cleaned + + +class VjepaBackbone(nn.Module): + """V-JEPA 2.1 ViT-L/16 image encoder. + + Wraps the vjepa2 VisionTransformer so it presents the same interface as + TimmBackbone: accepts (*B, C, H, W) and returns (*B, N_patches, embed_dim). + The encoder is fully trainable during rmind pre-training; the vjepa2 DINOv3 + assumption in CLAUDE.md was found to be incorrect — see Phase 1 report. + """ + + def __init__( + self, checkpoint_path: str, img_size: tuple[int, int] = (256, 256) + ) -> None: + super().__init__() + _ensure_vjepa2_on_path() + from app.vjepa_2_1.models.vision_transformer import ( # noqa: PLC0415 # ty: ignore[unresolved-import] + vit_large, # type: ignore[import] + ) + + # Canonical vjepa2.1 ViT-L kwargs from src/hub/backbones.py:229-241. + # img_temporal_dim_size=1 activates per-frame (image) mode when the + # temporal dim of a 5-D input equals 1. + self.encoder: nn.Module = vit_large( + patch_size=16, + img_size=img_size, + num_frames=64, + tubelet_size=2, + use_sdpa=True, + use_rope=True, + img_temporal_dim_size=1, + interpolate_rope=True, + ) + + raw = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + missing, unexpected = self.encoder.load_state_dict( + _clean_keys(raw["ema_encoder"]), strict=True + ) + if missing or unexpected: + msg = f"Checkpoint key mismatch: missing={missing}, unexpected={unexpected}" + raise RuntimeError(msg) + + @override + def forward(self, x: Tensor) -> Tensor: + # x: (*B, C, H, W) — same contract as TimmBackbone. + # Unsqueeze temporal dim so VisionTransformer takes the img_temporal_dim_size=1 + # branch (patch_embed_img + img_mod_embed) instead of the video branch. + *b, c, h, w = x.shape + x = x.view(prod(b), c, 1, h, w) + x = self.encoder(x) # (prod(B), N_patches, 1024) + return x.view(*b, x.shape[-2], x.shape[-1]) # (*B, N_patches, 1024) From 031c909cfa45bc9d87332bc33c48ff6f03d3ae38 Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Tue, 12 May 2026 23:31:19 +0200 Subject: [PATCH 3/8] feat: add vjepa2.1 pre-training and fine-tuning YAML configs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds five sibling config files mirroring the dinov3 baseline: - raw_vjepa.yaml: model config swapping TimmBackbone for VjepaBackbone (image_embedding_dim 384→1024 via projections.image) - pretrain_vjepa.yaml: experiment config referencing raw_vjepa + pretrain_vjepa callbacks; sets image_embedding_dim=1024, encoder_embedding_dim=384 - pretrain_vjepa.yaml (callbacks): same as pretrain.yaml but omits the TimmBackbone freeze entry — VjepaBackbone is fully trainable during pre-training - policy_finetune_vjepa.yaml: fine-tuning model config loading from a local checkpoint path (vs. W&B artifact in policy_finetune.yaml); jq mutation and PolicyObjective config identical since encoder_embedding_dim is unchanged - finetune_vjepa.yaml: experiment config referencing policy_finetune_vjepa; uses standard finetune.yaml callbacks (path-based freeze already covers VjepaBackbone inside episode_builder) Co-Authored-By: Claude Sonnet 4.6 --- .../control_transformer/finetune_vjepa.yaml | 12 + .../control_transformer/pretrain_vjepa.yaml | 20 + .../policy_finetune_vjepa.yaml | 89 +++ .../yaak/control_transformer/raw_vjepa.yaml | 523 ++++++++++++++++++ config/trainer/callbacks/pretrain_vjepa.yaml | 57 ++ 5 files changed, 701 insertions(+) create mode 100644 config/experiment/yaak/control_transformer/finetune_vjepa.yaml create mode 100644 config/experiment/yaak/control_transformer/pretrain_vjepa.yaml create mode 100644 config/model/yaak/control_transformer/policy_finetune_vjepa.yaml create mode 100644 config/model/yaak/control_transformer/raw_vjepa.yaml create mode 100644 config/trainer/callbacks/pretrain_vjepa.yaml diff --git a/config/experiment/yaak/control_transformer/finetune_vjepa.yaml b/config/experiment/yaak/control_transformer/finetune_vjepa.yaml new file mode 100644 index 00000000..fb3dba04 --- /dev/null +++ b/config/experiment/yaak/control_transformer/finetune_vjepa.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - /model: yaak/control_transformer/policy_finetune_vjepa + - /datamodule: yaak/train + - /trainer: default + - /trainer/callbacks: finetune + - /paths: yaak/default + - /wandb: yaak/rmind + - _self_ + +encoder_embedding_dim: 384 diff --git a/config/experiment/yaak/control_transformer/pretrain_vjepa.yaml b/config/experiment/yaak/control_transformer/pretrain_vjepa.yaml new file mode 100644 index 00000000..82202284 --- /dev/null +++ b/config/experiment/yaak/control_transformer/pretrain_vjepa.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +defaults: + - /model: yaak/control_transformer/raw_vjepa + - /datamodule: yaak/train + - /trainer: default + - /trainer/callbacks: pretrain_vjepa + - /paths: yaak/default + - /wandb: yaak/rmind + - _self_ + +num_heads: 4 +num_layers: 8 +encoder_embedding_dim: 384 +image_embedding_dim: 1024 + +speed_bins: 512 +gas_pedal_bins: 255 +brake_pedal_bins: 165 +steering_angle_bins: 961 diff --git a/config/model/yaak/control_transformer/policy_finetune_vjepa.yaml b/config/model/yaak/control_transformer/policy_finetune_vjepa.yaml new file mode 100644 index 00000000..cd892669 --- /dev/null +++ b/config/model/yaak/control_transformer/policy_finetune_vjepa.yaml @@ -0,0 +1,89 @@ +--- +_target_: rmind.models.control_transformer.ControlTransformer.load_from_checkpoint +checkpoint_path: ??? +strict: false +hparams_jq: | + .objectives = { + "_target_": "rmind.components.containers.ModuleDict", + "modules": { + "policy": { + "_target_": "rmind.components.objectives.PolicyObjective", + "norm": { + "_target_": "torch.nn.LayerNorm", + "normalized_shape": 384 + }, + "heads": { + "_target_": "rmind.components.containers.ModuleDict", + "modules": { + "continuous": { + "gas_pedal": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 2], + "bias": false + }, + "brake_pedal": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 2], + "bias": false + }, + "steering_angle": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 2], + "bias": false + } + }, + "discrete": { + "turn_signal": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 3], + "bias": false + } + } + } + }, + "targets": { + "continuous": { + "gas_pedal": ["input", "continuous", "gas_pedal"], + "brake_pedal": ["input", "continuous", "brake_pedal"], + "steering_angle": ["input", "continuous", "steering_angle"] + }, + "discrete": { + "turn_signal": ["input", "discrete", "turn_signal"] + } + }, + "losses": { + "_target_": "rmind.components.containers.ModuleDict", + "modules": { + "continuous": { + "gas_pedal": { + "_target_": "rmind.components.loss.GaussianNLLLoss" + }, + "brake_pedal": { + "_target_": "rmind.components.loss.GaussianNLLLoss" + }, + "steering_angle": { + "_target_": "rmind.components.loss.GaussianNLLLoss" + } + }, + "discrete": { + "turn_signal": { + "_target_": "rmind.components.loss.LogitBiasCrossEntropyLoss" + } + } + } + } + } + } + } + | .lr_scheduler = { + "interval": "step", + "scheduler": { + "_target_": "rmind.components.lr_schedulers.get_cosine_schedule_with_warmup", + "num_warmup_steps": 25000, + "num_training_steps": 250000 + } + } diff --git a/config/model/yaak/control_transformer/raw_vjepa.yaml b/config/model/yaak/control_transformer/raw_vjepa.yaml new file mode 100644 index 00000000..82526c9f --- /dev/null +++ b/config/model/yaak/control_transformer/raw_vjepa.yaml @@ -0,0 +1,523 @@ +_target_: rmind.models.control_transformer.ControlTransformer +_recursive_: false +_convert_: all + +episode_builder: + _target_: rmind.components.episode.EpisodeBuilder + _recursive_: true + _convert_: all + timestep: + - [observation, image, cam_front_left] + - [observation, continuous, speed] + - [observation, context, waypoints] + - [special, foresight, cam_front_left] + - [special, summary, observation_summary] + - [special, summary, observation_history] + - [action, continuous, gas_pedal] + - [action, continuous, brake_pedal] + - [action, continuous, steering_angle] + - [action, discrete, turn_signal] + - [special, summary, action_summary] + + special_tokens: + foresight: + cam_front_left: + _target_: builtins.range + _args_: + - 256 + + summary: + observation_summary: [0] + observation_history: [1] + action_summary: [2] + + utility: + mask: [0] + + input_transform: + _target_: torch.nn.Sequential + _convert_: all + _args_: + - _target_: rmind.components.nn.Remapper + paths: + image: + cam_front_left: [data, cam_front_left] + + continuous: + speed: [data, meta/VehicleMotion/speed] + gas_pedal: [data, meta/VehicleMotion/gas_pedal_normalized] + gas_pedal_diff: [data, meta/VehicleMotion/gas_pedal_normalized] + brake_pedal: [data, meta/VehicleMotion/brake_pedal_normalized] + brake_pedal_diff: [data, meta/VehicleMotion/brake_pedal_normalized] + steering_angle: [data, meta/VehicleMotion/steering_angle_normalized] + steering_angle_diff: + [data, meta/VehicleMotion/steering_angle_normalized] + + context: + waypoints: [data, waypoints/xy_normalized] + + discrete: + turn_signal: [data, meta/VehicleState/turn_signal] + + - _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: torch.nn.Sequential + _args_: + - _target_: einops.layers.torch.Rearrange + pattern: "... h w c -> ... c h w" + - _target_: torchvision.transforms.v2.CenterCrop + size: [320, 576] + - _target_: torchvision.transforms.v2.Resize + size: [256, 256] + - _target_: torchvision.transforms.v2.ToDtype + scale: true + dtype: + _target_: hydra.utils.get_object + path: torch.float32 + - _target_: torchvision.transforms.v2.Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + + continuous: + speed: + _target_: rmind.components.nn.AtLeast3D + + gas_pedal: + _target_: rmind.components.nn.AtLeast3D + + gas_pedal_diff: + _target_: rmind.components.nn.DiffLast + append: + _target_: hydra.utils.get_object + path: math.nan + + brake_pedal: + _target_: rmind.components.nn.AtLeast3D + + brake_pedal_diff: + _target_: rmind.components.nn.DiffLast + append: + _target_: hydra.utils.get_object + path: math.nan + + steering_angle: + _target_: rmind.components.nn.AtLeast3D + + steering_angle_diff: + _target_: rmind.components.nn.DiffLast + append: + _target_: hydra.utils.get_object + path: math.nan + + discrete: + _target_: rmind.components.nn.AtLeast3D + + context: + _target_: torch.nn.Identity + + tokenizers: + _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: rmind.components.nn.Identity + + continuous: + speed: + _target_: rmind.components.norm.UniformBinner + range: [0.0, 130.0] + bins: ${speed_bins} + + gas_pedal: + _target_: rmind.components.norm.UniformBinner + range: [0.0, 1.0] + bins: ${gas_pedal_bins} + + gas_pedal_diff: + # NOTE: no pre-mulaw scaling since if x in [0.0, 1.0] then dx in [-1.0, 1.0] + _target_: rmind.components.norm.MuLawEncoding + quantization_channels: ${gas_pedal_bins} + + brake_pedal: + _target_: rmind.components.norm.UniformBinner + range: [0.0, 1.0] + bins: ${brake_pedal_bins} + + brake_pedal_diff: + # NOTE: no pre-mulaw scaling since if x in [0.0, 1.0] then dx in [-1.0, 1.0] + _target_: rmind.components.norm.MuLawEncoding + quantization_channels: ${brake_pedal_bins} + + steering_angle: + _target_: rmind.components.norm.UniformBinner + range: [-1.0, 1.0] + bins: ${steering_angle_bins} + + steering_angle_diff: + _target_: rmind.components.nn.Sequential + _args_: + - _target_: rmind.components.norm.Scaler + in_range: [-2.0, 2.0] + out_range: [-1.0, 1.0] + - _target_: rmind.components.norm.MuLawEncoding + quantization_channels: ${steering_angle_bins} + + discrete: + _target_: rmind.components.nn.Identity + + context: + waypoints: + _target_: rmind.components.nn.Identity + + embeddings: + _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: torch.nn.Sequential + _args_: + - _target_: rmind.components.vjepa_backbone.VjepaBackbone + checkpoint_path: /home/max/Code/rmind/vjepa2_1_vitl_dist_vitG_384.pt + img_size: [256, 256] + + continuous: + speed: + _target_: rmind.components.nn.Embedding + num_embeddings: ${speed_bins} + embedding_dim: ${encoder_embedding_dim} + gas_pedal: + _target_: rmind.components.nn.Embedding + num_embeddings: ${gas_pedal_bins} + embedding_dim: ${encoder_embedding_dim} + brake_pedal: + _target_: rmind.components.nn.Embedding + num_embeddings: ${brake_pedal_bins} + embedding_dim: ${encoder_embedding_dim} + steering_angle: + _target_: rmind.components.nn.Embedding + num_embeddings: ${steering_angle_bins} + embedding_dim: ${encoder_embedding_dim} + gas_pedal_diff: null + brake_pedal_diff: null + steering_angle_diff: null + + context: + waypoints: + _target_: rmind.components.nn.Linear + in_features: 2 + out_features: ${encoder_embedding_dim} + + discrete: + turn_signal: + _target_: rmind.components.nn.Embedding + num_embeddings: 3 + embedding_dim: ${encoder_embedding_dim} + + foresight: + _target_: rmind.components.nn.Embedding + num_embeddings: 256 + embedding_dim: ${encoder_embedding_dim} + + summary: + _target_: rmind.components.nn.Embedding + num_embeddings: 3 + embedding_dim: ${encoder_embedding_dim} + + utility: + _target_: rmind.components.nn.Embedding + num_embeddings: 1 + embedding_dim: ${encoder_embedding_dim} + + projections: + _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: torch.nn.Sequential + _args_: + - _target_: torch.nn.LayerNorm + normalized_shape: ${image_embedding_dim} + - _target_: rmind.components.norm.ScaleByVectorDimensionality + dim: ${image_embedding_dim} + - _target_: rmind.components.nn.Linear + in_features: ${image_embedding_dim} + out_features: ${encoder_embedding_dim} + + continuous: + speed: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + gas_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + brake_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + steering_angle: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + gas_pedal_diff: null + brake_pedal_diff: null + steering_angle_diff: null + + context: + waypoints: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + discrete: + turn_signal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + foresight: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + summary: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + utility: # do we need it? + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + role_encoding: # (modality, type) + _target_: rmind.components.nn.Embedding + num_embeddings: 8 + embedding_dim: ${encoder_embedding_dim} + + attention_mask_builder: + _target_: rmind.components.mask.FactorizedCausalAttentionMaskBuilder + +encoder: + _target_: rmind.components.transformer.TransformerEncoder + dim_model: ${encoder_embedding_dim} + num_layers: ${num_layers} + num_heads: ${num_heads} + attn_dropout: 0.1 + resid_dropout: 0.1 + mlp_dropout: 0.1 + hidden_layer_multiplier: 1 + emb_norm: + _target_: torch.nn.LayerNorm + normalized_shape: ${encoder_embedding_dim} + rope: + _target_: rmind.components.position_encoding.RotaryPositionalEmbeddings + dim: + _target_: operator.floordiv + _args_: + - ${encoder_embedding_dim} + - ${num_heads} + max_seq_len: 256 + base: 10 + +objectives: + _target_: rmind.components.containers.ModuleDict + _convert_: all + modules: + inverse_dynamics: + _target_: rmind.components.objectives.InverseDynamicsPredictionObjective + norm: + _target_: torch.nn.LayerNorm + normalized_shape: ${encoder_embedding_dim} + heads: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${gas_pedal_bins} + bias: False + brake_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${brake_pedal_bins} + bias: False + steering_angle: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${steering_angle_bins} + bias: False + discrete: + turn_signal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: 3 + bias: False + + targets: + continuous: + gas_pedal: [input_tokens, continuous, gas_pedal] + brake_pedal: [input_tokens, continuous, brake_pedal] + steering_angle: [input_tokens, continuous, steering_angle] + discrete: + turn_signal: [input_tokens, discrete, turn_signal] + + losses: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + brake_pedal: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + steering_angle: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + discrete: + turn_signal: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + + forward_dynamics: + _target_: rmind.components.objectives.ForwardDynamicsPredictionObjective + # null is intentional: the speed projection LayerNorms its inputs and + # foresight features pass through their own projection + norm: null + patch_pos_embed: + _target_: rmind.components.position_encoding.PatchPositionEmbedding2D + grid_size: [16, 16] + embedding_dim: ${image_embedding_dim} + projections: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + speed: + _target_: torch.nn.LayerNorm + normalized_shape: + _target_: operator.mul + _args_: + - 2 + - ${encoder_embedding_dim} + foresight: + cam_front_left: + _target_: torch.nn.Sequential + _args_: + - _target_: rmind.components.nn.Linear + in_features: + _target_: operator.mul + _args_: + - 2 + - ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + heads: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + speed: + _target_: rmind.components.nn.Linear + in_features: + _target_: operator.mul + _args_: + - 2 + - ${encoder_embedding_dim} + out_features: ${speed_bins} + bias: False + foresight: + cam_front_left: + _target_: rmind.components.transformer.CrossAttentionDecoderHead + decoder: + _target_: rmind.components.transformer.CrossAttentionDecoder + dim_model: ${image_embedding_dim} + num_layers: 2 + num_heads: 4 + attn_dropout: 0.1 + resid_dropout: 0.1 + mlp_dropout: 0.1 + hidden_layer_multiplier: 1 + output_projection: + _target_: rmind.components.nn.Linear + in_features: ${image_embedding_dim} + out_features: ${image_embedding_dim} + + targets: + continuous: + speed: [input_tokens, continuous, speed] + foresight: + cam_front_left: [input_embeddings, image, cam_front_left] + + losses: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + speed: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + foresight: + cam_front_left: + _target_: rmind.components.loss.GramAnchoringObjective + weight_sim: 1.0 + weight_gram: 100.0 + patches: 256 + + memory_extraction: + _target_: rmind.components.objectives.MemoryExtractionObjective + norm: + _target_: torch.nn.LayerNorm + normalized_shape: ${encoder_embedding_dim} + heads: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal_diff: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${gas_pedal_bins} + bias: False + + brake_pedal_diff: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${brake_pedal_bins} + bias: False + + steering_angle_diff: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${steering_angle_bins} + bias: False + + targets: + continuous: + gas_pedal_diff: [input_tokens, continuous, gas_pedal_diff] + brake_pedal_diff: [input_tokens, continuous, brake_pedal_diff] + steering_angle_diff: [input_tokens, continuous, steering_angle_diff] + + losses: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal_diff: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + + brake_pedal_diff: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + + steering_angle_diff: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + +optimizer: + _target_: rmind.components.optimizers.SelectiveAdamW + _recursive_: true + lr: 1e-5 + betas: [0.9, 0.95] + weight_decay: 0.1 + weight_decay_module_blacklist: + - _target_: hydra.utils.get_class + path: torch.nn.Embedding + - _target_: hydra.utils.get_class + path: torch.nn.LayerNorm + +lr_scheduler: + interval: step + scheduler: + _target_: rmind.components.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 25000 + num_training_steps: 250000 diff --git a/config/trainer/callbacks/pretrain_vjepa.yaml b/config/trainer/callbacks/pretrain_vjepa.yaml new file mode 100644 index 00000000..9d993baa --- /dev/null +++ b/config/trainer/callbacks/pretrain_vjepa.yaml @@ -0,0 +1,57 @@ +- _target_: rmind.callbacks.LogitBiasSetter +# VjepaBackbone is NOT frozen during pre-training — the full encoder is trainable. +# (The TimmBackbone freeze entry present in pretrain.yaml is intentionally omitted.) +- _target_: pytorch_lightning.callbacks.TQDMProgressBar +- _target_: pytorch_lightning.callbacks.ModelSummary + max_depth: 5 +- _target_: pytorch_lightning.callbacks.ModelCheckpoint + every_n_epochs: 1 + save_on_train_epoch_end: True +- _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: step +- _target_: rmind.callbacks.WandbAttentionMaskLogger +- _target_: rmind.callbacks.WandbImageParamLogger + when: on_train_batch_end + every_n_batch: 100 + key: similarity + select: + - [episode_builder, embeddings, continuous] + - [episode_builder, embeddings, discrete] + - [episode_builder, role_encoding, weight] + apply: + _target_: torchmetrics.functional.pairwise_cosine_similarity + _partial_: true +- _target_: rmind.callbacks.WandbWaypointsLogger + when: on_train_batch_end + every_n_batch: 100 + key: waypoints + crs: "EPSG:25832" + data: + image: [data, cam_front_left, 0, -1] + waypoints_xy_normalized: [data, waypoints/xy_normalized, 0, -1] + waypoints_xy: [data, waypoints/xy, 0, -1] + ego_xy: [data, meta/Gnss/xy, 0, -1] + caption: + input_id: [meta, input_id, 0] + time_stamp: [data, meta/ImageMetadata.cam_front_left/time_stamp, 0, -1] + frame_idx: [data, meta/ImageMetadata.cam_front_left/frame_idx, 0, -1] +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_train_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_validation_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 From 3e7da8fbf371eacc8417de345eaca040f65cc8fc Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Wed, 13 May 2026 08:34:39 +0200 Subject: [PATCH 4/8] refactor: load vjepa2.1 weights via torch.hub instead of sys.path hack Replaces the hardcoded sys.path.insert + manual checkpoint loading with torch.hub.load("facebookresearch/vjepa2", "vjepa2_1_vit_large_384", pretrained=True), which auto-downloads and caches weights on first use. Removes _VJEPA2_ROOT, _ensure_vjepa2_on_path, _clean_keys, and the checkpoint_path constructor arg. Updates raw_vjepa.yaml accordingly. --- .../yaak/control_transformer/raw_vjepa.yaml | 3 +- src/rmind/components/vjepa_backbone.py | 55 +++---------------- 2 files changed, 9 insertions(+), 49 deletions(-) diff --git a/config/model/yaak/control_transformer/raw_vjepa.yaml b/config/model/yaak/control_transformer/raw_vjepa.yaml index 82526c9f..f7243c92 100644 --- a/config/model/yaak/control_transformer/raw_vjepa.yaml +++ b/config/model/yaak/control_transformer/raw_vjepa.yaml @@ -176,8 +176,7 @@ episode_builder: _target_: torch.nn.Sequential _args_: - _target_: rmind.components.vjepa_backbone.VjepaBackbone - checkpoint_path: /home/max/Code/rmind/vjepa2_1_vitl_dist_vitG_384.pt - img_size: [256, 256] + img_size: 256 continuous: speed: diff --git a/src/rmind/components/vjepa_backbone.py b/src/rmind/components/vjepa_backbone.py index 80fbcb57..659ad85c 100644 --- a/src/rmind/components/vjepa_backbone.py +++ b/src/rmind/components/vjepa_backbone.py @@ -1,67 +1,28 @@ -import sys from math import prod -from pathlib import Path from typing import override import torch from torch import Tensor, nn -_VJEPA2_ROOT = str(Path("/home/max/Code/vjepa2")) - - -def _ensure_vjepa2_on_path() -> None: - if _VJEPA2_ROOT not in sys.path: - sys.path.insert(0, _VJEPA2_ROOT) - - -def _clean_keys(state_dict: dict) -> dict: - # Strip "module." and "backbone." prefixes written by vjepa2 DDP trainer. - cleaned = {} - for key, val in state_dict.items(): - key = key.replace("module.", "").replace("backbone.", "") # noqa: PLW2901 - cleaned[key] = val - return cleaned - class VjepaBackbone(nn.Module): """V-JEPA 2.1 ViT-L/16 image encoder. Wraps the vjepa2 VisionTransformer so it presents the same interface as TimmBackbone: accepts (*B, C, H, W) and returns (*B, N_patches, embed_dim). - The encoder is fully trainable during rmind pre-training; the vjepa2 DINOv3 - assumption in CLAUDE.md was found to be incorrect — see Phase 1 report. + Weights are downloaded automatically via torch.hub on first use. """ - def __init__( - self, checkpoint_path: str, img_size: tuple[int, int] = (256, 256) - ) -> None: + def __init__(self, img_size: int = 256) -> None: super().__init__() - _ensure_vjepa2_on_path() - from app.vjepa_2_1.models.vision_transformer import ( # noqa: PLC0415 # ty: ignore[unresolved-import] - vit_large, # type: ignore[import] - ) - - # Canonical vjepa2.1 ViT-L kwargs from src/hub/backbones.py:229-241. - # img_temporal_dim_size=1 activates per-frame (image) mode when the - # temporal dim of a 5-D input equals 1. - self.encoder: nn.Module = vit_large( - patch_size=16, + encoder, _ = torch.hub.load( + "facebookresearch/vjepa2", + "vjepa2_1_vit_large_384", + pretrained=True, img_size=img_size, - num_frames=64, - tubelet_size=2, - use_sdpa=True, - use_rope=True, - img_temporal_dim_size=1, - interpolate_rope=True, - ) - - raw = torch.load(checkpoint_path, map_location="cpu", weights_only=False) - missing, unexpected = self.encoder.load_state_dict( - _clean_keys(raw["ema_encoder"]), strict=True + trust_repo=True, ) - if missing or unexpected: - msg = f"Checkpoint key mismatch: missing={missing}, unexpected={unexpected}" - raise RuntimeError(msg) + self.encoder: nn.Module = encoder @override def forward(self, x: Tensor) -> Tensor: From fad6f6ba64c9fbd3e03bc51984050653ad555200 Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Wed, 13 May 2026 10:17:30 +0200 Subject: [PATCH 5/8] fix: correct config dims and decoder reshape for vjepa2.1 integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - VjepaBackbone: drop img_size arg (hub fn hardcodes 384, kwarg caused conflict) - raw_vjepa.yaml: patch_pos_embed embedding_dim image→encoder (mask tokens live in 384-dim rmind space, not 1024-dim vjepa space) - raw_vjepa.yaml: CrossAttentionDecoder dim_model + in_features image→encoder (query/context are 384-dim; out_features stays image_embedding_dim to predict 1024-dim vjepa targets) - decoder.py: CrossAttentionDecoderHead reshape output.reshape(..., d) → reshape(..., -1) so output dim can differ from query dim (was latently broken in dinov3 config where both dims were equal) --- config/model/yaak/control_transformer/raw_vjepa.yaml | 7 +++---- src/rmind/components/transformer/decoder.py | 2 +- src/rmind/components/vjepa_backbone.py | 3 +-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/config/model/yaak/control_transformer/raw_vjepa.yaml b/config/model/yaak/control_transformer/raw_vjepa.yaml index f7243c92..70869793 100644 --- a/config/model/yaak/control_transformer/raw_vjepa.yaml +++ b/config/model/yaak/control_transformer/raw_vjepa.yaml @@ -176,7 +176,6 @@ episode_builder: _target_: torch.nn.Sequential _args_: - _target_: rmind.components.vjepa_backbone.VjepaBackbone - img_size: 256 continuous: speed: @@ -383,7 +382,7 @@ objectives: patch_pos_embed: _target_: rmind.components.position_encoding.PatchPositionEmbedding2D grid_size: [16, 16] - embedding_dim: ${image_embedding_dim} + embedding_dim: ${encoder_embedding_dim} projections: _target_: rmind.components.containers.ModuleDict modules: @@ -425,7 +424,7 @@ objectives: _target_: rmind.components.transformer.CrossAttentionDecoderHead decoder: _target_: rmind.components.transformer.CrossAttentionDecoder - dim_model: ${image_embedding_dim} + dim_model: ${encoder_embedding_dim} num_layers: 2 num_heads: 4 attn_dropout: 0.1 @@ -434,7 +433,7 @@ objectives: hidden_layer_multiplier: 1 output_projection: _target_: rmind.components.nn.Linear - in_features: ${image_embedding_dim} + in_features: ${encoder_embedding_dim} out_features: ${image_embedding_dim} targets: diff --git a/src/rmind/components/transformer/decoder.py b/src/rmind/components/transformer/decoder.py index 9c17688f..bb65910d 100644 --- a/src/rmind/components/transformer/decoder.py +++ b/src/rmind/components/transformer/decoder.py @@ -136,7 +136,7 @@ def forward(self, input: Input) -> Tensor: decoded = self.decoder(query_flat, context_flat) output = self.output_projection(decoded) - return output.reshape(b, t, sq, d) + return output.reshape(b, t, sq, -1) decoded = self.decoder(query, context) return self.output_projection(decoded) diff --git a/src/rmind/components/vjepa_backbone.py b/src/rmind/components/vjepa_backbone.py index 659ad85c..e6b25b2e 100644 --- a/src/rmind/components/vjepa_backbone.py +++ b/src/rmind/components/vjepa_backbone.py @@ -13,13 +13,12 @@ class VjepaBackbone(nn.Module): Weights are downloaded automatically via torch.hub on first use. """ - def __init__(self, img_size: int = 256) -> None: + def __init__(self) -> None: super().__init__() encoder, _ = torch.hub.load( "facebookresearch/vjepa2", "vjepa2_1_vit_large_384", pretrained=True, - img_size=img_size, trust_repo=True, ) self.encoder: nn.Module = encoder From cbad69ca5d2cac7b8284d8e155aaf1153ee3fee2 Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Wed, 13 May 2026 16:33:21 +0200 Subject: [PATCH 6/8] feat: freeze VjepaBackbone by default during pre-training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per V-JEPA 2 paper §3.1: the encoder is frozen; only the transformer and prediction heads are trained. Adds ModuleFreezer for VjepaBackbone to pretrain_vjepa callbacks (304M → non-trainable, 21M trainable). For end-to-end training use trainer/callbacks=pretrain_vjepa_unfrozen. --- config/trainer/callbacks/pretrain_vjepa.yaml | 9 ++- .../callbacks/pretrain_vjepa_unfrozen.yaml | 58 +++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 config/trainer/callbacks/pretrain_vjepa_unfrozen.yaml diff --git a/config/trainer/callbacks/pretrain_vjepa.yaml b/config/trainer/callbacks/pretrain_vjepa.yaml index 9d993baa..46e545ef 100644 --- a/config/trainer/callbacks/pretrain_vjepa.yaml +++ b/config/trainer/callbacks/pretrain_vjepa.yaml @@ -1,6 +1,11 @@ - _target_: rmind.callbacks.LogitBiasSetter -# VjepaBackbone is NOT frozen during pre-training — the full encoder is trainable. -# (The TimmBackbone freeze entry present in pretrain.yaml is intentionally omitted.) +# Freeze the vjepa encoder; only the rmind transformer + heads are trained. +# Mirrors V-JEPA 2 paper §3.1: "we freeze the video encoder and learn a new +# action-conditioned predictor on top of the learned representation." +# To train end-to-end instead, use trainer/callbacks=pretrain_vjepa_unfrozen. +- _target_: rmind.callbacks.ModuleFreezer + types: + - rmind.components.vjepa_backbone.VjepaBackbone - _target_: pytorch_lightning.callbacks.TQDMProgressBar - _target_: pytorch_lightning.callbacks.ModelSummary max_depth: 5 diff --git a/config/trainer/callbacks/pretrain_vjepa_unfrozen.yaml b/config/trainer/callbacks/pretrain_vjepa_unfrozen.yaml new file mode 100644 index 00000000..d2ceaf1c --- /dev/null +++ b/config/trainer/callbacks/pretrain_vjepa_unfrozen.yaml @@ -0,0 +1,58 @@ +- _target_: rmind.callbacks.LogitBiasSetter +# End-to-end variant: VjepaBackbone is trainable alongside the rmind transformer. +# Default (frozen) config is pretrain_vjepa. Use this via: +# trainer/callbacks=pretrain_vjepa_unfrozen +- _target_: pytorch_lightning.callbacks.TQDMProgressBar +- _target_: pytorch_lightning.callbacks.ModelSummary + max_depth: 5 +- _target_: pytorch_lightning.callbacks.ModelCheckpoint + every_n_epochs: 1 + save_on_train_epoch_end: True +- _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: step +- _target_: rmind.callbacks.WandbAttentionMaskLogger +- _target_: rmind.callbacks.WandbImageParamLogger + when: on_train_batch_end + every_n_batch: 100 + key: similarity + select: + - [episode_builder, embeddings, continuous] + - [episode_builder, embeddings, discrete] + - [episode_builder, role_encoding, weight] + apply: + _target_: torchmetrics.functional.pairwise_cosine_similarity + _partial_: true +- _target_: rmind.callbacks.WandbWaypointsLogger + when: on_train_batch_end + every_n_batch: 100 + key: waypoints + crs: "EPSG:25832" + data: + image: [data, cam_front_left, 0, -1] + waypoints_xy_normalized: [data, waypoints/xy_normalized, 0, -1] + waypoints_xy: [data, waypoints/xy, 0, -1] + ego_xy: [data, meta/Gnss/xy, 0, -1] + caption: + input_id: [meta, input_id, 0] + time_stamp: [data, meta/ImageMetadata.cam_front_left/time_stamp, 0, -1] + frame_idx: [data, meta/ImageMetadata.cam_front_left/frame_idx, 0, -1] +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_train_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_validation_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 From 5f3bcbc0863713477756d136aced7d28a02282a0 Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Wed, 13 May 2026 19:55:09 +0200 Subject: [PATCH 7/8] fix: wire wandb entity/project from config, drop check-git from train recipe - WandbLogger now reads entity/project from Hydra config (${wandb.entity}, ${wandb.project}) so personal projects work without env-var overrides - Removed check-git dependency from `just train` (was blocking runs on dirty trees during active dev) Co-Authored-By: Claude Sonnet 4.6 --- config/trainer/default.yaml | 2 ++ justfile | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/config/trainer/default.yaml b/config/trainer/default.yaml index f069186c..62308a7b 100644 --- a/config/trainer/default.yaml +++ b/config/trainer/default.yaml @@ -12,3 +12,5 @@ enable_model_summary: false logger: _target_: pytorch_lightning.loggers.WandbLogger log_model: all + entity: ${wandb.entity} + project: ${wandb.project} diff --git a/justfile b/justfile index 126af041..c9c35c4f 100644 --- a/justfile +++ b/justfile @@ -41,7 +41,7 @@ generate-config: --ignore-unknown-comments \ --strict -train *ARGS: generate-config check-git +train *ARGS: generate-config uv run rmind-train \ --config-path {{ justfile_directory() }}/config \ --config-name train.yaml \ From 7208602be13ace2cf14cdf5c7e19701c3856c4fe Mon Sep 17 00:00:00 2001 From: Max Moeller Date: Wed, 13 May 2026 19:55:15 +0200 Subject: [PATCH 8/8] chore: clean up .gitignore, add CLAUDE.md and debug dataset config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove agent-specific .gitignore entries (agents/, scripts/, PROMPT.md, plan-*.md, CLAUDE.md, rmind-vjepa2.1-agent-kit.zip) — these were added during automated planning and shouldn't be in version control - Keep .claude/ and vjepa2_1_vitl_dist_vitG_384.pt ignores (still valid) - Track CLAUDE.md — project-level Claude Code instructions for this repo - Track config/_templates/dataset/yaak/train_debug.yaml and config/datamodule/yaak/train_debug.yaml — 3-drive debug split used in smoke-test runs Co-Authored-By: Claude Sonnet 4.6 --- .gitignore | 8 - CLAUDE.md | 85 ++++ .../_templates/dataset/yaak/train_debug.yaml | 396 ++++++++++++++++++ config/datamodule/yaak/train_debug.yaml | 28 ++ 4 files changed, 509 insertions(+), 8 deletions(-) create mode 100644 CLAUDE.md create mode 100644 config/_templates/dataset/yaak/train_debug.yaml create mode 100644 config/datamodule/yaak/train_debug.yaml diff --git a/.gitignore b/.gitignore index 6455b396..c849a2d7 100644 --- a/.gitignore +++ b/.gitignore @@ -144,15 +144,7 @@ multirun/ # generated config/dataset/* config/logger/* -config/_templates/dataset/yaak/train_debug.yaml -config/datamodule/yaak/train_debug.yaml .claude/ -CLAUDE.md -PROMPT.md -agents/ -plan-*.md -rmind-vjepa2.1-agent-kit.zip -scripts/ vjepa2_1_vitl_dist_vitG_384.pt # editor diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..116e6573 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,85 @@ +# rmind × vjepa2.1 integration — agent constraints + +These constraints apply to every Claude Code invocation in this repo, interactive +or headless. They override any conflicting instruction in a prompt. + +## Task scope + +Two deliverables, in order: + +1. **Pre-training encoder integration.** Add vjepa2.1 (ViT-L/16) as an encoder + in rmind, mirroring the existing dinov3 image-encoder integration pattern. + Determine whether it slots in as the episode encoder, the image encoder, or + replaces both — and justify the choice in writing before changing any file. + +2. **Fine-tuning config.** A separate YAML/template that loads the pre-trained + vjepa2.1 encoder and trains the action-prediction policy on top of it. + +## Integration rules + +- **YAML/templates first.** New components are wired in via the existing + template system (Hydra / OmegaConf / whatever rmind uses — discover it, don't + assume). Python code is a last resort. +- **If Python is necessary**, justify it explicitly in the plan: what YAML + feature is missing, why a new module class is required. One paragraph minimum. +- **Mirror dinov3.** Before writing anything, diff the proposed structure + against the dinov3 image-encoder YAML. The new files should look like + siblings, not cousins. +- **Read-only dependencies.** Do not modify source files inside `~/Code/rbyte` + or `~/Code/vjepa2`. They are dependencies. If something is broken in them, + report it — don't patch it. + +## Architectural decision: freezing + +- The **vjepa2.1 encoder itself is NOT frozen** during pre-training in rmind. +- The **dinov3 component *inside* vjepa2.1 IS frozen**. (V-JEPA 2.1 uses a + pre-trained DINOv3 as part of its target/teacher pipeline; that part stays + frozen.) +- This means the integration must expose two parameter groups, or set + `requires_grad=False` on the dinov3 sub-module specifically. Verify this is + achievable from YAML; if not, that's a legitimate reason to add Python. + +## Checkpoint + +- Use the **ViT-L/16 vjepa2.1** checkpoint. Do not substitute ViT-B, ViT-H, + or any other variant without asking. +- The exact path/URL of the checkpoint is unknown to the agent at the start. + Phase 1 must locate it (inside `~/Code/vjepa2`, in a release artifact, or via + the paper). If it cannot be located, STOP and ask. + +## Open questions the agent must resolve in Phase 1 + +These are unknowns flagged by the user — do not guess; report findings. + +1. Whether "vjepa2.1" lives in the `~/Code/vjepa2` repo (branch, subdir, tag) + or somewhere else. +2. Whether the second paper URL (arXiv 2603.14482) is reachable. The ID looks + malformed (the YYMM prefix doesn't parse). If `WebFetch` fails, report it — + do not invent the paper's contents. +3. Whether vjepa2.1 fits the episode-encoder slot, the image-encoder slot, or + subsumes both. The user's hypothesis is "both in one go" — verify against + the actual rmind interfaces. +4. Action-prediction policy details for fine-tuning: action space, dataset, + loss. If rmind has an existing action-prediction config, reuse its defaults + and call them out. If not, STOP and ask. + +## Safety rails + +- All work on a feature branch: `feat/vjepa2.1-encoder`. Create it before any + edit. Never commit to main. +- Never `git push`. Never `git push --force`. Never `git reset --hard` against + anything the user has touched. +- Never run training. Validation = config loads + one forward pass on dummy + tensors of the documented input shape. That's it. +- Never download model weights without confirming. If a checkpoint isn't on + disk, report the URL and stop. +- Do not modify files outside `~/Code/rmind`. + +## Reporting + +- Every claim about an existing file must cite `path:line`. +- The Phase 2 plan is a hard checkpoint. Print it, then stop. Wait for the + orchestrator script (or the user) to advance to Phase 3. +- On completion, produce a summary listing: files added, files edited, Python + additions (with justification), config-load validation result, forward-pass + validation result. diff --git a/config/_templates/dataset/yaak/train_debug.yaml b/config/_templates/dataset/yaak/train_debug.yaml new file mode 100644 index 00000000..ac460078 --- /dev/null +++ b/config/_templates/dataset/yaak/train_debug.yaml @@ -0,0 +1,396 @@ +#@yaml/text-templated-strings + +#@ drives = [ +#@ 'Niro096-HQ/2023-01-11--13-47-36', +#@ 'Niro113-HQ/2023-06-08--10-39-29', +#@ ] + +--- +_target_: rbyte.Dataset.from_config +_recursive_: false +_convert_: all +streams: + cam_front_left: + index: meta/ImageMetadata.cam_front_left/frame_idx + sources: + #@ for/end drive_id in drives: + (@=drive_id@): + _target_: rbyte.io.PathTensorSource + path: "${paths.data}/(@=drive_id@)/frames/cam_front_left.pii.mp4/576x324/{:09d}.jpg" + decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + +samples: + inputs: + input_id: + #@ for/end drive_id in drives: + - (@=drive_id@) + + yaak_metadata_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/metadata.log + + waypoints_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/waypoints.json + + headings_denoised_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/headings_denoised.json + + cam_front_left_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/frames/cam_front_left.pii.mp4/576x324 + + executor: + _target_: concurrent.futures.ProcessPoolExecutor + mp_context: + _target_: multiprocessing.get_context + method: forkserver + + storage: file_array + run_folder: ${paths.rbyte.cache}/yaak/train/samples + scheduling_strategy: eager + return_results: false + persist_memory: false + + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + cache_type: disk + cache_kwargs: + cache_dir: ${paths.rbyte.cache} + functions: + - _target_: pipefunc.PipeFunc + renames: + path: yaak_metadata_path + output_name: meta + mapspec: "yaak_metadata_path[i] -> meta[i]" + cache: true + func: + _target_: rbyte.io.YaakMetadataDataFrameBuilder + fields: + rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: us + frame_idx: + _target_: polars.Int32 + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.yaak.proto.can_pb2.VehicleState: + time_stamp: + _target_: polars.Datetime + time_unit: us + turn_signal: + _target_: polars.Int8 + + rbyte.io.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: us + speed: + _target_: polars.Float32 + gas_pedal_normalized: + _target_: polars.Float32 + brake_pedal_normalized: + _target_: polars.Float32 + steering_angle_normalized: + _target_: polars.Float32 + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + rbyte.io.yaak.proto.sensor_pb2.Gnss: + time_stamp: + _target_: polars.Datetime + time_unit: us + latitude: + _target_: polars.Float32 + longitude: + _target_: polars.Float32 + + - _target_: pipefunc.PipeFunc + output_name: waypoints_raw + mapspec: "waypoints_path[i] -> waypoints_raw[i]" + func: + _target_: makefun.create_function + func_signature: "build_waypoints(*, waypoints_path)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + extensions: [spatial] + config: + TimeZone: UTC + query: | + SELECT TO_TIMESTAMP(timestamp)::TIMESTAMP AS timestamp, + ST_AsWKB(ST_Transform(geom, 'EPSG:4326', 'EPSG:25832', always_xy := true)) AS geometry + FROM ST_Read($waypoints_path) + + - _target_: pipefunc.PipeFunc + renames: + input: waypoints_raw + output_name: waypoints + mapspec: "waypoints_raw[i] -> waypoints[i]" + func: + _target_: rbyte.io.WaypointBuilder + length: 91 + columns: + points: geometry + output: xy + + - _target_: pipefunc.PipeFunc + output_name: headings_denoised + mapspec: "headings_denoised_path[i] -> headings_denoised[i]" + cache: true + func: + _target_: makefun.create_function + func_signature: "build_headings_denoised(*, headings_denoised_path)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + query: | + SELECT make_timestamp(h.timestamp_us) AS timestamp, h.heading + FROM (SELECT unnest(headings) AS h FROM read_json_auto($headings_denoised_path)) + + - _target_: pipefunc.PipeFunc + output_name: aligned + mapspec: "meta[i], waypoints[i], headings_denoised[i] -> aligned[i]" + func: + _target_: makefun.create_function + func_signature: "align(*, meta, waypoints, headings_denoised)" + func_impl: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + meta: + ImageMetadata.cam_front_left: + key: time_stamp + + VehicleState: + key: time_stamp + columns: + turn_signal: + method: asof + tolerance: 100ms + + VehicleMotion: + key: time_stamp + columns: + speed: + method: interp + gas_pedal_normalized: + method: interp + brake_pedal_normalized: + method: interp + steering_angle_normalized: + method: interp + gear: + method: asof + tolerance: 100ms + strategy: nearest + + Gnss: + key: time_stamp + columns: + latitude: + method: asof + tolerance: 500ms + strategy: nearest + longitude: + method: asof + tolerance: 500ms + strategy: nearest + + waypoints: + key: timestamp + columns: + xy: + method: asof + strategy: forward + + headings_denoised: + key: timestamp + columns: + heading: + method: asof + strategy: nearest + + - _target_: pipefunc.PipeFunc + renames: + path: cam_front_left_path + output_name: cam_front_left_meta + mapspec: "cam_front_left_path[i] -> cam_front_left_meta[i]" + cache: true + func: + _target_: rbyte.io.PathDataFrameBuilder + pattern: (?\d+).jpg + fields: + frame_idx: + _target_: polars.Int32 + + - _target_: pipefunc.PipeFunc + output_name: filtered + mapspec: "aligned[i], cam_front_left_meta[i] -> filtered[i]" + func: + _target_: makefun.create_function + func_signature: "filter(*, aligned, cam_front_left_meta)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + extensions: [spatial] + query: | + WITH + base_data AS ( + SELECT + *, + ST_Transform( + ST_Point("meta/Gnss/longitude", "meta/Gnss/latitude"), + 'EPSG:4326', 'EPSG:25832', always_xy := true + ) AS ego_geom, + ST_GeomFromWKB("waypoints/xy") AS waypoints_geom + FROM + aligned + SEMI JOIN cam_front_left_meta + ON aligned."meta/ImageMetadata.cam_front_left/frame_idx" + = cam_front_left_meta.frame_idx + WHERE + aligned."meta/VehicleMotion/gear" == '3' + AND aligned."meta/VehicleMotion/speed" BETWEEN 0.0 AND 130.0 + AND aligned."meta/VehicleMotion/gas_pedal_normalized" BETWEEN 0.0 AND 1.0 + AND aligned."meta/VehicleMotion/brake_pedal_normalized" BETWEEN 0.0 AND 1.0 + AND aligned."meta/VehicleMotion/steering_angle_normalized" BETWEEN -1.0 AND 1.0 + AND COLUMNS(*) IS NOT NULL + ), + normalized_geometries AS ( + SELECT + *, + ST_Rotate( + ST_Translate( + waypoints_geom, + - ST_X(ego_geom), + - ST_Y(ego_geom) + ), + radians("headings_denoised/heading") + ) AS normalized_waypoints_geom + FROM + base_data + ) + SELECT + * EXCLUDE ( + "meta/VehicleMotion/gear", + waypoints_geom, + normalized_waypoints_geom, + ego_geom, + "waypoints/xy", + "headings_denoised/heading" + ), + [ + ST_X(ego_geom), + ST_Y(ego_geom) + ] AS "meta/Gnss/xy", + ( + SELECT + list( + [ST_X(p.point_struct.geom), ST_Y(p.point_struct.geom)] + ORDER BY + p.point_struct.path + ) + FROM + UNNEST(ST_Dump(waypoints_geom)) AS p(point_struct) + WHERE (p.point_struct.path[1] - 1) % 10 = 0 + ) AS "waypoints/xy", + ( + SELECT + list( + [ST_X(p.point_struct.geom) / 100, ST_Y(p.point_struct.geom) / 100] + ORDER BY + p.point_struct.path + ) + FROM + UNNEST(ST_Dump(normalized_waypoints_geom)) AS p(point_struct) + WHERE (p.point_struct.path[1] - 1) % 10 = 0 + ) AS "waypoints/xy_normalized" + FROM + normalized_geometries + WHERE + ST_Contains( + ST_MakeEnvelope(-150, -150, 150, 150), + normalized_waypoints_geom + ) + ORDER BY + "meta/ImageMetadata.cam_front_left/time_stamp"; + + - _target_: pipefunc.PipeFunc + renames: + input: filtered + output_name: samples + mapspec: "filtered[i] -> samples[i]" + func: + _target_: rbyte.io.DataFrameGroupByDynamic + index_column: meta/ImageMetadata.cam_front_left/frame_idx + every: 10i + period: 60i + closed: left + gather_every: 10 + + - _target_: pipefunc.PipeFunc + output_name: samples_cast + mapspec: "samples[i] -> samples_cast[i]" + func: + _target_: makefun.create_function + func_signature: "cast(*, samples)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + query: | + SELECT + "meta/ImageMetadata.cam_front_left/time_stamp"::TIMESTAMP[6] AS "meta/ImageMetadata.cam_front_left/time_stamp", + "meta/ImageMetadata.cam_front_left/frame_idx"::INT32[6] AS "meta/ImageMetadata.cam_front_left/frame_idx", + "meta/VehicleMotion/speed"::FLOAT[6] AS "meta/VehicleMotion/speed", + "meta/VehicleMotion/gas_pedal_normalized"::FLOAT[6] AS "meta/VehicleMotion/gas_pedal_normalized", + "meta/VehicleMotion/brake_pedal_normalized"::FLOAT[6] AS "meta/VehicleMotion/brake_pedal_normalized", + "meta/VehicleMotion/steering_angle_normalized"::FLOAT[6] AS "meta/VehicleMotion/steering_angle_normalized", + "meta/VehicleState/turn_signal"::INT8[6] AS "meta/VehicleState/turn_signal", + "meta/Gnss/xy"::FLOAT[2][6] AS "meta/Gnss/xy", + "waypoints/xy_normalized"::FLOAT[2][10][6] AS "waypoints/xy_normalized", + "waypoints/xy"::FLOAT[2][10][6] AS "waypoints/xy", + FROM + samples + WHERE + len("meta/ImageMetadata.cam_front_left/frame_idx") = 6 + AND list_last("meta/ImageMetadata.cam_front_left/frame_idx") - list_first("meta/ImageMetadata.cam_front_left/frame_idx") == 50 + AND NOT ( + list_max("meta/VehicleMotion/gas_pedal_normalized") <= (1.0 / 255 + 0.001) + AND list_max("meta/VehicleMotion/brake_pedal_normalized") <= (1.0 / 164 + 0.001) + AND list_max("meta/VehicleMotion/speed") >= 25.0 + AND list_last("meta/VehicleMotion/speed") - list_first("meta/VehicleMotion/speed") >= -0.05 * list_avg("meta/VehicleMotion/speed") + ) + + - _target_: pipefunc.PipeFunc + renames: + keys: input_id + values: samples_cast + output_name: samples_aggregated + func: + _target_: rbyte.io.DataFrameConcater + key_column: input_id + + - _target_: pipefunc.PipeFunc + output_name: samples_with_id + renames: + self: samples_aggregated + func: + _target_: polars.DataFrame.with_row_index + _partial_: true + name: meta/sample_id diff --git a/config/datamodule/yaak/train_debug.yaml b/config/datamodule/yaak/train_debug.yaml new file mode 100644 index 00000000..75c05382 --- /dev/null +++ b/config/datamodule/yaak/train_debug.yaml @@ -0,0 +1,28 @@ +--- +defaults: + - /dataset/yaak/train_debug@train.dataset + - /dataset/yaak/val@val.dataset + - _self_ + +_target_: rmind.datamodules.GenericDataModule +train: + _target_: rbyte.dataloader.TorchDataNodeDataLoader + batch_size: 64 + shuffle: true + collate_fn: + _target_: rbyte.types.Batch.to_dict + _partial_: true + pin_memory: true + num_workers: 2 + method: thread + +val: + _target_: rbyte.dataloader.TorchDataNodeDataLoader + batch_size: 32 + shuffle: false + collate_fn: + _target_: rbyte.types.Batch.to_dict + _partial_: true + pin_memory: true + num_workers: 2 + method: thread