diff --git a/.gitignore b/.gitignore index 526e0894..c849a2d7 100644 --- a/.gitignore +++ b/.gitignore @@ -144,6 +144,8 @@ multirun/ # generated config/dataset/* config/logger/* +.claude/ +vjepa2_1_vitl_dist_vitG_384.pt # editor .helix/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..116e6573 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,85 @@ +# rmind × vjepa2.1 integration — agent constraints + +These constraints apply to every Claude Code invocation in this repo, interactive +or headless. They override any conflicting instruction in a prompt. + +## Task scope + +Two deliverables, in order: + +1. **Pre-training encoder integration.** Add vjepa2.1 (ViT-L/16) as an encoder + in rmind, mirroring the existing dinov3 image-encoder integration pattern. + Determine whether it slots in as the episode encoder, the image encoder, or + replaces both — and justify the choice in writing before changing any file. + +2. **Fine-tuning config.** A separate YAML/template that loads the pre-trained + vjepa2.1 encoder and trains the action-prediction policy on top of it. + +## Integration rules + +- **YAML/templates first.** New components are wired in via the existing + template system (Hydra / OmegaConf / whatever rmind uses — discover it, don't + assume). Python code is a last resort. +- **If Python is necessary**, justify it explicitly in the plan: what YAML + feature is missing, why a new module class is required. One paragraph minimum. +- **Mirror dinov3.** Before writing anything, diff the proposed structure + against the dinov3 image-encoder YAML. The new files should look like + siblings, not cousins. +- **Read-only dependencies.** Do not modify source files inside `~/Code/rbyte` + or `~/Code/vjepa2`. They are dependencies. If something is broken in them, + report it — don't patch it. + +## Architectural decision: freezing + +- The **vjepa2.1 encoder itself is NOT frozen** during pre-training in rmind. +- The **dinov3 component *inside* vjepa2.1 IS frozen**. (V-JEPA 2.1 uses a + pre-trained DINOv3 as part of its target/teacher pipeline; that part stays + frozen.) +- This means the integration must expose two parameter groups, or set + `requires_grad=False` on the dinov3 sub-module specifically. Verify this is + achievable from YAML; if not, that's a legitimate reason to add Python. + +## Checkpoint + +- Use the **ViT-L/16 vjepa2.1** checkpoint. Do not substitute ViT-B, ViT-H, + or any other variant without asking. +- The exact path/URL of the checkpoint is unknown to the agent at the start. + Phase 1 must locate it (inside `~/Code/vjepa2`, in a release artifact, or via + the paper). If it cannot be located, STOP and ask. + +## Open questions the agent must resolve in Phase 1 + +These are unknowns flagged by the user — do not guess; report findings. + +1. Whether "vjepa2.1" lives in the `~/Code/vjepa2` repo (branch, subdir, tag) + or somewhere else. +2. Whether the second paper URL (arXiv 2603.14482) is reachable. The ID looks + malformed (the YYMM prefix doesn't parse). If `WebFetch` fails, report it — + do not invent the paper's contents. +3. Whether vjepa2.1 fits the episode-encoder slot, the image-encoder slot, or + subsumes both. The user's hypothesis is "both in one go" — verify against + the actual rmind interfaces. +4. Action-prediction policy details for fine-tuning: action space, dataset, + loss. If rmind has an existing action-prediction config, reuse its defaults + and call them out. If not, STOP and ask. + +## Safety rails + +- All work on a feature branch: `feat/vjepa2.1-encoder`. Create it before any + edit. Never commit to main. +- Never `git push`. Never `git push --force`. Never `git reset --hard` against + anything the user has touched. +- Never run training. Validation = config loads + one forward pass on dummy + tensors of the documented input shape. That's it. +- Never download model weights without confirming. If a checkpoint isn't on + disk, report the URL and stop. +- Do not modify files outside `~/Code/rmind`. + +## Reporting + +- Every claim about an existing file must cite `path:line`. +- The Phase 2 plan is a hard checkpoint. Print it, then stop. Wait for the + orchestrator script (or the user) to advance to Phase 3. +- On completion, produce a summary listing: files added, files edited, Python + additions (with justification), config-load validation result, forward-pass + validation result. diff --git a/config/_templates/dataset/yaak/train_debug.yaml b/config/_templates/dataset/yaak/train_debug.yaml new file mode 100644 index 00000000..ac460078 --- /dev/null +++ b/config/_templates/dataset/yaak/train_debug.yaml @@ -0,0 +1,396 @@ +#@yaml/text-templated-strings + +#@ drives = [ +#@ 'Niro096-HQ/2023-01-11--13-47-36', +#@ 'Niro113-HQ/2023-06-08--10-39-29', +#@ ] + +--- +_target_: rbyte.Dataset.from_config +_recursive_: false +_convert_: all +streams: + cam_front_left: + index: meta/ImageMetadata.cam_front_left/frame_idx + sources: + #@ for/end drive_id in drives: + (@=drive_id@): + _target_: rbyte.io.PathTensorSource + path: "${paths.data}/(@=drive_id@)/frames/cam_front_left.pii.mp4/576x324/{:09d}.jpg" + decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + +samples: + inputs: + input_id: + #@ for/end drive_id in drives: + - (@=drive_id@) + + yaak_metadata_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/metadata.log + + waypoints_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/waypoints.json + + headings_denoised_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/headings_denoised.json + + cam_front_left_path: + #@ for/end drive_id in drives: + - ${paths.data}/(@=drive_id@)/frames/cam_front_left.pii.mp4/576x324 + + executor: + _target_: concurrent.futures.ProcessPoolExecutor + mp_context: + _target_: multiprocessing.get_context + method: forkserver + + storage: file_array + run_folder: ${paths.rbyte.cache}/yaak/train/samples + scheduling_strategy: eager + return_results: false + persist_memory: false + + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + cache_type: disk + cache_kwargs: + cache_dir: ${paths.rbyte.cache} + functions: + - _target_: pipefunc.PipeFunc + renames: + path: yaak_metadata_path + output_name: meta + mapspec: "yaak_metadata_path[i] -> meta[i]" + cache: true + func: + _target_: rbyte.io.YaakMetadataDataFrameBuilder + fields: + rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: us + frame_idx: + _target_: polars.Int32 + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.yaak.proto.can_pb2.VehicleState: + time_stamp: + _target_: polars.Datetime + time_unit: us + turn_signal: + _target_: polars.Int8 + + rbyte.io.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: us + speed: + _target_: polars.Float32 + gas_pedal_normalized: + _target_: polars.Float32 + brake_pedal_normalized: + _target_: polars.Float32 + steering_angle_normalized: + _target_: polars.Float32 + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + rbyte.io.yaak.proto.sensor_pb2.Gnss: + time_stamp: + _target_: polars.Datetime + time_unit: us + latitude: + _target_: polars.Float32 + longitude: + _target_: polars.Float32 + + - _target_: pipefunc.PipeFunc + output_name: waypoints_raw + mapspec: "waypoints_path[i] -> waypoints_raw[i]" + func: + _target_: makefun.create_function + func_signature: "build_waypoints(*, waypoints_path)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + extensions: [spatial] + config: + TimeZone: UTC + query: | + SELECT TO_TIMESTAMP(timestamp)::TIMESTAMP AS timestamp, + ST_AsWKB(ST_Transform(geom, 'EPSG:4326', 'EPSG:25832', always_xy := true)) AS geometry + FROM ST_Read($waypoints_path) + + - _target_: pipefunc.PipeFunc + renames: + input: waypoints_raw + output_name: waypoints + mapspec: "waypoints_raw[i] -> waypoints[i]" + func: + _target_: rbyte.io.WaypointBuilder + length: 91 + columns: + points: geometry + output: xy + + - _target_: pipefunc.PipeFunc + output_name: headings_denoised + mapspec: "headings_denoised_path[i] -> headings_denoised[i]" + cache: true + func: + _target_: makefun.create_function + func_signature: "build_headings_denoised(*, headings_denoised_path)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + query: | + SELECT make_timestamp(h.timestamp_us) AS timestamp, h.heading + FROM (SELECT unnest(headings) AS h FROM read_json_auto($headings_denoised_path)) + + - _target_: pipefunc.PipeFunc + output_name: aligned + mapspec: "meta[i], waypoints[i], headings_denoised[i] -> aligned[i]" + func: + _target_: makefun.create_function + func_signature: "align(*, meta, waypoints, headings_denoised)" + func_impl: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + meta: + ImageMetadata.cam_front_left: + key: time_stamp + + VehicleState: + key: time_stamp + columns: + turn_signal: + method: asof + tolerance: 100ms + + VehicleMotion: + key: time_stamp + columns: + speed: + method: interp + gas_pedal_normalized: + method: interp + brake_pedal_normalized: + method: interp + steering_angle_normalized: + method: interp + gear: + method: asof + tolerance: 100ms + strategy: nearest + + Gnss: + key: time_stamp + columns: + latitude: + method: asof + tolerance: 500ms + strategy: nearest + longitude: + method: asof + tolerance: 500ms + strategy: nearest + + waypoints: + key: timestamp + columns: + xy: + method: asof + strategy: forward + + headings_denoised: + key: timestamp + columns: + heading: + method: asof + strategy: nearest + + - _target_: pipefunc.PipeFunc + renames: + path: cam_front_left_path + output_name: cam_front_left_meta + mapspec: "cam_front_left_path[i] -> cam_front_left_meta[i]" + cache: true + func: + _target_: rbyte.io.PathDataFrameBuilder + pattern: (?\d+).jpg + fields: + frame_idx: + _target_: polars.Int32 + + - _target_: pipefunc.PipeFunc + output_name: filtered + mapspec: "aligned[i], cam_front_left_meta[i] -> filtered[i]" + func: + _target_: makefun.create_function + func_signature: "filter(*, aligned, cam_front_left_meta)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + extensions: [spatial] + query: | + WITH + base_data AS ( + SELECT + *, + ST_Transform( + ST_Point("meta/Gnss/longitude", "meta/Gnss/latitude"), + 'EPSG:4326', 'EPSG:25832', always_xy := true + ) AS ego_geom, + ST_GeomFromWKB("waypoints/xy") AS waypoints_geom + FROM + aligned + SEMI JOIN cam_front_left_meta + ON aligned."meta/ImageMetadata.cam_front_left/frame_idx" + = cam_front_left_meta.frame_idx + WHERE + aligned."meta/VehicleMotion/gear" == '3' + AND aligned."meta/VehicleMotion/speed" BETWEEN 0.0 AND 130.0 + AND aligned."meta/VehicleMotion/gas_pedal_normalized" BETWEEN 0.0 AND 1.0 + AND aligned."meta/VehicleMotion/brake_pedal_normalized" BETWEEN 0.0 AND 1.0 + AND aligned."meta/VehicleMotion/steering_angle_normalized" BETWEEN -1.0 AND 1.0 + AND COLUMNS(*) IS NOT NULL + ), + normalized_geometries AS ( + SELECT + *, + ST_Rotate( + ST_Translate( + waypoints_geom, + - ST_X(ego_geom), + - ST_Y(ego_geom) + ), + radians("headings_denoised/heading") + ) AS normalized_waypoints_geom + FROM + base_data + ) + SELECT + * EXCLUDE ( + "meta/VehicleMotion/gear", + waypoints_geom, + normalized_waypoints_geom, + ego_geom, + "waypoints/xy", + "headings_denoised/heading" + ), + [ + ST_X(ego_geom), + ST_Y(ego_geom) + ] AS "meta/Gnss/xy", + ( + SELECT + list( + [ST_X(p.point_struct.geom), ST_Y(p.point_struct.geom)] + ORDER BY + p.point_struct.path + ) + FROM + UNNEST(ST_Dump(waypoints_geom)) AS p(point_struct) + WHERE (p.point_struct.path[1] - 1) % 10 = 0 + ) AS "waypoints/xy", + ( + SELECT + list( + [ST_X(p.point_struct.geom) / 100, ST_Y(p.point_struct.geom) / 100] + ORDER BY + p.point_struct.path + ) + FROM + UNNEST(ST_Dump(normalized_waypoints_geom)) AS p(point_struct) + WHERE (p.point_struct.path[1] - 1) % 10 = 0 + ) AS "waypoints/xy_normalized" + FROM + normalized_geometries + WHERE + ST_Contains( + ST_MakeEnvelope(-150, -150, 150, 150), + normalized_waypoints_geom + ) + ORDER BY + "meta/ImageMetadata.cam_front_left/time_stamp"; + + - _target_: pipefunc.PipeFunc + renames: + input: filtered + output_name: samples + mapspec: "filtered[i] -> samples[i]" + func: + _target_: rbyte.io.DataFrameGroupByDynamic + index_column: meta/ImageMetadata.cam_front_left/frame_idx + every: 10i + period: 60i + closed: left + gather_every: 10 + + - _target_: pipefunc.PipeFunc + output_name: samples_cast + mapspec: "samples[i] -> samples_cast[i]" + func: + _target_: makefun.create_function + func_signature: "cast(*, samples)" + func_impl: + _target_: rbyte.io.DuckDBDataFrameQuery + query: | + SELECT + "meta/ImageMetadata.cam_front_left/time_stamp"::TIMESTAMP[6] AS "meta/ImageMetadata.cam_front_left/time_stamp", + "meta/ImageMetadata.cam_front_left/frame_idx"::INT32[6] AS "meta/ImageMetadata.cam_front_left/frame_idx", + "meta/VehicleMotion/speed"::FLOAT[6] AS "meta/VehicleMotion/speed", + "meta/VehicleMotion/gas_pedal_normalized"::FLOAT[6] AS "meta/VehicleMotion/gas_pedal_normalized", + "meta/VehicleMotion/brake_pedal_normalized"::FLOAT[6] AS "meta/VehicleMotion/brake_pedal_normalized", + "meta/VehicleMotion/steering_angle_normalized"::FLOAT[6] AS "meta/VehicleMotion/steering_angle_normalized", + "meta/VehicleState/turn_signal"::INT8[6] AS "meta/VehicleState/turn_signal", + "meta/Gnss/xy"::FLOAT[2][6] AS "meta/Gnss/xy", + "waypoints/xy_normalized"::FLOAT[2][10][6] AS "waypoints/xy_normalized", + "waypoints/xy"::FLOAT[2][10][6] AS "waypoints/xy", + FROM + samples + WHERE + len("meta/ImageMetadata.cam_front_left/frame_idx") = 6 + AND list_last("meta/ImageMetadata.cam_front_left/frame_idx") - list_first("meta/ImageMetadata.cam_front_left/frame_idx") == 50 + AND NOT ( + list_max("meta/VehicleMotion/gas_pedal_normalized") <= (1.0 / 255 + 0.001) + AND list_max("meta/VehicleMotion/brake_pedal_normalized") <= (1.0 / 164 + 0.001) + AND list_max("meta/VehicleMotion/speed") >= 25.0 + AND list_last("meta/VehicleMotion/speed") - list_first("meta/VehicleMotion/speed") >= -0.05 * list_avg("meta/VehicleMotion/speed") + ) + + - _target_: pipefunc.PipeFunc + renames: + keys: input_id + values: samples_cast + output_name: samples_aggregated + func: + _target_: rbyte.io.DataFrameConcater + key_column: input_id + + - _target_: pipefunc.PipeFunc + output_name: samples_with_id + renames: + self: samples_aggregated + func: + _target_: polars.DataFrame.with_row_index + _partial_: true + name: meta/sample_id diff --git a/config/datamodule/yaak/train_debug.yaml b/config/datamodule/yaak/train_debug.yaml new file mode 100644 index 00000000..75c05382 --- /dev/null +++ b/config/datamodule/yaak/train_debug.yaml @@ -0,0 +1,28 @@ +--- +defaults: + - /dataset/yaak/train_debug@train.dataset + - /dataset/yaak/val@val.dataset + - _self_ + +_target_: rmind.datamodules.GenericDataModule +train: + _target_: rbyte.dataloader.TorchDataNodeDataLoader + batch_size: 64 + shuffle: true + collate_fn: + _target_: rbyte.types.Batch.to_dict + _partial_: true + pin_memory: true + num_workers: 2 + method: thread + +val: + _target_: rbyte.dataloader.TorchDataNodeDataLoader + batch_size: 32 + shuffle: false + collate_fn: + _target_: rbyte.types.Batch.to_dict + _partial_: true + pin_memory: true + num_workers: 2 + method: thread diff --git a/config/experiment/yaak/control_transformer/finetune_vjepa.yaml b/config/experiment/yaak/control_transformer/finetune_vjepa.yaml new file mode 100644 index 00000000..fb3dba04 --- /dev/null +++ b/config/experiment/yaak/control_transformer/finetune_vjepa.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - /model: yaak/control_transformer/policy_finetune_vjepa + - /datamodule: yaak/train + - /trainer: default + - /trainer/callbacks: finetune + - /paths: yaak/default + - /wandb: yaak/rmind + - _self_ + +encoder_embedding_dim: 384 diff --git a/config/experiment/yaak/control_transformer/pretrain_vjepa.yaml b/config/experiment/yaak/control_transformer/pretrain_vjepa.yaml new file mode 100644 index 00000000..82202284 --- /dev/null +++ b/config/experiment/yaak/control_transformer/pretrain_vjepa.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +defaults: + - /model: yaak/control_transformer/raw_vjepa + - /datamodule: yaak/train + - /trainer: default + - /trainer/callbacks: pretrain_vjepa + - /paths: yaak/default + - /wandb: yaak/rmind + - _self_ + +num_heads: 4 +num_layers: 8 +encoder_embedding_dim: 384 +image_embedding_dim: 1024 + +speed_bins: 512 +gas_pedal_bins: 255 +brake_pedal_bins: 165 +steering_angle_bins: 961 diff --git a/config/model/yaak/control_transformer/policy_finetune_vjepa.yaml b/config/model/yaak/control_transformer/policy_finetune_vjepa.yaml new file mode 100644 index 00000000..cd892669 --- /dev/null +++ b/config/model/yaak/control_transformer/policy_finetune_vjepa.yaml @@ -0,0 +1,89 @@ +--- +_target_: rmind.models.control_transformer.ControlTransformer.load_from_checkpoint +checkpoint_path: ??? +strict: false +hparams_jq: | + .objectives = { + "_target_": "rmind.components.containers.ModuleDict", + "modules": { + "policy": { + "_target_": "rmind.components.objectives.PolicyObjective", + "norm": { + "_target_": "torch.nn.LayerNorm", + "normalized_shape": 384 + }, + "heads": { + "_target_": "rmind.components.containers.ModuleDict", + "modules": { + "continuous": { + "gas_pedal": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 2], + "bias": false + }, + "brake_pedal": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 2], + "bias": false + }, + "steering_angle": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 2], + "bias": false + } + }, + "discrete": { + "turn_signal": { + "_target_": "torchvision.ops.MLP", + "in_channels": 1152, + "hidden_channels": [384, 3], + "bias": false + } + } + } + }, + "targets": { + "continuous": { + "gas_pedal": ["input", "continuous", "gas_pedal"], + "brake_pedal": ["input", "continuous", "brake_pedal"], + "steering_angle": ["input", "continuous", "steering_angle"] + }, + "discrete": { + "turn_signal": ["input", "discrete", "turn_signal"] + } + }, + "losses": { + "_target_": "rmind.components.containers.ModuleDict", + "modules": { + "continuous": { + "gas_pedal": { + "_target_": "rmind.components.loss.GaussianNLLLoss" + }, + "brake_pedal": { + "_target_": "rmind.components.loss.GaussianNLLLoss" + }, + "steering_angle": { + "_target_": "rmind.components.loss.GaussianNLLLoss" + } + }, + "discrete": { + "turn_signal": { + "_target_": "rmind.components.loss.LogitBiasCrossEntropyLoss" + } + } + } + } + } + } + } + | .lr_scheduler = { + "interval": "step", + "scheduler": { + "_target_": "rmind.components.lr_schedulers.get_cosine_schedule_with_warmup", + "num_warmup_steps": 25000, + "num_training_steps": 250000 + } + } diff --git a/config/model/yaak/control_transformer/raw_vjepa.yaml b/config/model/yaak/control_transformer/raw_vjepa.yaml new file mode 100644 index 00000000..70869793 --- /dev/null +++ b/config/model/yaak/control_transformer/raw_vjepa.yaml @@ -0,0 +1,521 @@ +_target_: rmind.models.control_transformer.ControlTransformer +_recursive_: false +_convert_: all + +episode_builder: + _target_: rmind.components.episode.EpisodeBuilder + _recursive_: true + _convert_: all + timestep: + - [observation, image, cam_front_left] + - [observation, continuous, speed] + - [observation, context, waypoints] + - [special, foresight, cam_front_left] + - [special, summary, observation_summary] + - [special, summary, observation_history] + - [action, continuous, gas_pedal] + - [action, continuous, brake_pedal] + - [action, continuous, steering_angle] + - [action, discrete, turn_signal] + - [special, summary, action_summary] + + special_tokens: + foresight: + cam_front_left: + _target_: builtins.range + _args_: + - 256 + + summary: + observation_summary: [0] + observation_history: [1] + action_summary: [2] + + utility: + mask: [0] + + input_transform: + _target_: torch.nn.Sequential + _convert_: all + _args_: + - _target_: rmind.components.nn.Remapper + paths: + image: + cam_front_left: [data, cam_front_left] + + continuous: + speed: [data, meta/VehicleMotion/speed] + gas_pedal: [data, meta/VehicleMotion/gas_pedal_normalized] + gas_pedal_diff: [data, meta/VehicleMotion/gas_pedal_normalized] + brake_pedal: [data, meta/VehicleMotion/brake_pedal_normalized] + brake_pedal_diff: [data, meta/VehicleMotion/brake_pedal_normalized] + steering_angle: [data, meta/VehicleMotion/steering_angle_normalized] + steering_angle_diff: + [data, meta/VehicleMotion/steering_angle_normalized] + + context: + waypoints: [data, waypoints/xy_normalized] + + discrete: + turn_signal: [data, meta/VehicleState/turn_signal] + + - _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: torch.nn.Sequential + _args_: + - _target_: einops.layers.torch.Rearrange + pattern: "... h w c -> ... c h w" + - _target_: torchvision.transforms.v2.CenterCrop + size: [320, 576] + - _target_: torchvision.transforms.v2.Resize + size: [256, 256] + - _target_: torchvision.transforms.v2.ToDtype + scale: true + dtype: + _target_: hydra.utils.get_object + path: torch.float32 + - _target_: torchvision.transforms.v2.Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + + continuous: + speed: + _target_: rmind.components.nn.AtLeast3D + + gas_pedal: + _target_: rmind.components.nn.AtLeast3D + + gas_pedal_diff: + _target_: rmind.components.nn.DiffLast + append: + _target_: hydra.utils.get_object + path: math.nan + + brake_pedal: + _target_: rmind.components.nn.AtLeast3D + + brake_pedal_diff: + _target_: rmind.components.nn.DiffLast + append: + _target_: hydra.utils.get_object + path: math.nan + + steering_angle: + _target_: rmind.components.nn.AtLeast3D + + steering_angle_diff: + _target_: rmind.components.nn.DiffLast + append: + _target_: hydra.utils.get_object + path: math.nan + + discrete: + _target_: rmind.components.nn.AtLeast3D + + context: + _target_: torch.nn.Identity + + tokenizers: + _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: rmind.components.nn.Identity + + continuous: + speed: + _target_: rmind.components.norm.UniformBinner + range: [0.0, 130.0] + bins: ${speed_bins} + + gas_pedal: + _target_: rmind.components.norm.UniformBinner + range: [0.0, 1.0] + bins: ${gas_pedal_bins} + + gas_pedal_diff: + # NOTE: no pre-mulaw scaling since if x in [0.0, 1.0] then dx in [-1.0, 1.0] + _target_: rmind.components.norm.MuLawEncoding + quantization_channels: ${gas_pedal_bins} + + brake_pedal: + _target_: rmind.components.norm.UniformBinner + range: [0.0, 1.0] + bins: ${brake_pedal_bins} + + brake_pedal_diff: + # NOTE: no pre-mulaw scaling since if x in [0.0, 1.0] then dx in [-1.0, 1.0] + _target_: rmind.components.norm.MuLawEncoding + quantization_channels: ${brake_pedal_bins} + + steering_angle: + _target_: rmind.components.norm.UniformBinner + range: [-1.0, 1.0] + bins: ${steering_angle_bins} + + steering_angle_diff: + _target_: rmind.components.nn.Sequential + _args_: + - _target_: rmind.components.norm.Scaler + in_range: [-2.0, 2.0] + out_range: [-1.0, 1.0] + - _target_: rmind.components.norm.MuLawEncoding + quantization_channels: ${steering_angle_bins} + + discrete: + _target_: rmind.components.nn.Identity + + context: + waypoints: + _target_: rmind.components.nn.Identity + + embeddings: + _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: torch.nn.Sequential + _args_: + - _target_: rmind.components.vjepa_backbone.VjepaBackbone + + continuous: + speed: + _target_: rmind.components.nn.Embedding + num_embeddings: ${speed_bins} + embedding_dim: ${encoder_embedding_dim} + gas_pedal: + _target_: rmind.components.nn.Embedding + num_embeddings: ${gas_pedal_bins} + embedding_dim: ${encoder_embedding_dim} + brake_pedal: + _target_: rmind.components.nn.Embedding + num_embeddings: ${brake_pedal_bins} + embedding_dim: ${encoder_embedding_dim} + steering_angle: + _target_: rmind.components.nn.Embedding + num_embeddings: ${steering_angle_bins} + embedding_dim: ${encoder_embedding_dim} + gas_pedal_diff: null + brake_pedal_diff: null + steering_angle_diff: null + + context: + waypoints: + _target_: rmind.components.nn.Linear + in_features: 2 + out_features: ${encoder_embedding_dim} + + discrete: + turn_signal: + _target_: rmind.components.nn.Embedding + num_embeddings: 3 + embedding_dim: ${encoder_embedding_dim} + + foresight: + _target_: rmind.components.nn.Embedding + num_embeddings: 256 + embedding_dim: ${encoder_embedding_dim} + + summary: + _target_: rmind.components.nn.Embedding + num_embeddings: 3 + embedding_dim: ${encoder_embedding_dim} + + utility: + _target_: rmind.components.nn.Embedding + num_embeddings: 1 + embedding_dim: ${encoder_embedding_dim} + + projections: + _target_: rmind.components.containers.ModuleDict + modules: + image: + _target_: torch.nn.Sequential + _args_: + - _target_: torch.nn.LayerNorm + normalized_shape: ${image_embedding_dim} + - _target_: rmind.components.norm.ScaleByVectorDimensionality + dim: ${image_embedding_dim} + - _target_: rmind.components.nn.Linear + in_features: ${image_embedding_dim} + out_features: ${encoder_embedding_dim} + + continuous: + speed: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + gas_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + brake_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + steering_angle: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + gas_pedal_diff: null + brake_pedal_diff: null + steering_angle_diff: null + + context: + waypoints: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + discrete: + turn_signal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + foresight: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + summary: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + utility: # do we need it? + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + role_encoding: # (modality, type) + _target_: rmind.components.nn.Embedding + num_embeddings: 8 + embedding_dim: ${encoder_embedding_dim} + + attention_mask_builder: + _target_: rmind.components.mask.FactorizedCausalAttentionMaskBuilder + +encoder: + _target_: rmind.components.transformer.TransformerEncoder + dim_model: ${encoder_embedding_dim} + num_layers: ${num_layers} + num_heads: ${num_heads} + attn_dropout: 0.1 + resid_dropout: 0.1 + mlp_dropout: 0.1 + hidden_layer_multiplier: 1 + emb_norm: + _target_: torch.nn.LayerNorm + normalized_shape: ${encoder_embedding_dim} + rope: + _target_: rmind.components.position_encoding.RotaryPositionalEmbeddings + dim: + _target_: operator.floordiv + _args_: + - ${encoder_embedding_dim} + - ${num_heads} + max_seq_len: 256 + base: 10 + +objectives: + _target_: rmind.components.containers.ModuleDict + _convert_: all + modules: + inverse_dynamics: + _target_: rmind.components.objectives.InverseDynamicsPredictionObjective + norm: + _target_: torch.nn.LayerNorm + normalized_shape: ${encoder_embedding_dim} + heads: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${gas_pedal_bins} + bias: False + brake_pedal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${brake_pedal_bins} + bias: False + steering_angle: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${steering_angle_bins} + bias: False + discrete: + turn_signal: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: 3 + bias: False + + targets: + continuous: + gas_pedal: [input_tokens, continuous, gas_pedal] + brake_pedal: [input_tokens, continuous, brake_pedal] + steering_angle: [input_tokens, continuous, steering_angle] + discrete: + turn_signal: [input_tokens, discrete, turn_signal] + + losses: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + brake_pedal: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + steering_angle: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + discrete: + turn_signal: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + + forward_dynamics: + _target_: rmind.components.objectives.ForwardDynamicsPredictionObjective + # null is intentional: the speed projection LayerNorms its inputs and + # foresight features pass through their own projection + norm: null + patch_pos_embed: + _target_: rmind.components.position_encoding.PatchPositionEmbedding2D + grid_size: [16, 16] + embedding_dim: ${encoder_embedding_dim} + projections: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + speed: + _target_: torch.nn.LayerNorm + normalized_shape: + _target_: operator.mul + _args_: + - 2 + - ${encoder_embedding_dim} + foresight: + cam_front_left: + _target_: torch.nn.Sequential + _args_: + - _target_: rmind.components.nn.Linear + in_features: + _target_: operator.mul + _args_: + - 2 + - ${encoder_embedding_dim} + out_features: ${encoder_embedding_dim} + + heads: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + speed: + _target_: rmind.components.nn.Linear + in_features: + _target_: operator.mul + _args_: + - 2 + - ${encoder_embedding_dim} + out_features: ${speed_bins} + bias: False + foresight: + cam_front_left: + _target_: rmind.components.transformer.CrossAttentionDecoderHead + decoder: + _target_: rmind.components.transformer.CrossAttentionDecoder + dim_model: ${encoder_embedding_dim} + num_layers: 2 + num_heads: 4 + attn_dropout: 0.1 + resid_dropout: 0.1 + mlp_dropout: 0.1 + hidden_layer_multiplier: 1 + output_projection: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${image_embedding_dim} + + targets: + continuous: + speed: [input_tokens, continuous, speed] + foresight: + cam_front_left: [input_embeddings, image, cam_front_left] + + losses: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + speed: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + foresight: + cam_front_left: + _target_: rmind.components.loss.GramAnchoringObjective + weight_sim: 1.0 + weight_gram: 100.0 + patches: 256 + + memory_extraction: + _target_: rmind.components.objectives.MemoryExtractionObjective + norm: + _target_: torch.nn.LayerNorm + normalized_shape: ${encoder_embedding_dim} + heads: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal_diff: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${gas_pedal_bins} + bias: False + + brake_pedal_diff: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${brake_pedal_bins} + bias: False + + steering_angle_diff: + _target_: rmind.components.nn.Linear + in_features: ${encoder_embedding_dim} + out_features: ${steering_angle_bins} + bias: False + + targets: + continuous: + gas_pedal_diff: [input_tokens, continuous, gas_pedal_diff] + brake_pedal_diff: [input_tokens, continuous, brake_pedal_diff] + steering_angle_diff: [input_tokens, continuous, steering_angle_diff] + + losses: + _target_: rmind.components.containers.ModuleDict + modules: + continuous: + gas_pedal_diff: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + + brake_pedal_diff: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + + steering_angle_diff: + _target_: rmind.components.loss.LogitBiasCrossEntropyLoss + +optimizer: + _target_: rmind.components.optimizers.SelectiveAdamW + _recursive_: true + lr: 1e-5 + betas: [0.9, 0.95] + weight_decay: 0.1 + weight_decay_module_blacklist: + - _target_: hydra.utils.get_class + path: torch.nn.Embedding + - _target_: hydra.utils.get_class + path: torch.nn.LayerNorm + +lr_scheduler: + interval: step + scheduler: + _target_: rmind.components.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 25000 + num_training_steps: 250000 diff --git a/config/trainer/callbacks/pretrain_vjepa.yaml b/config/trainer/callbacks/pretrain_vjepa.yaml new file mode 100644 index 00000000..46e545ef --- /dev/null +++ b/config/trainer/callbacks/pretrain_vjepa.yaml @@ -0,0 +1,62 @@ +- _target_: rmind.callbacks.LogitBiasSetter +# Freeze the vjepa encoder; only the rmind transformer + heads are trained. +# Mirrors V-JEPA 2 paper §3.1: "we freeze the video encoder and learn a new +# action-conditioned predictor on top of the learned representation." +# To train end-to-end instead, use trainer/callbacks=pretrain_vjepa_unfrozen. +- _target_: rmind.callbacks.ModuleFreezer + types: + - rmind.components.vjepa_backbone.VjepaBackbone +- _target_: pytorch_lightning.callbacks.TQDMProgressBar +- _target_: pytorch_lightning.callbacks.ModelSummary + max_depth: 5 +- _target_: pytorch_lightning.callbacks.ModelCheckpoint + every_n_epochs: 1 + save_on_train_epoch_end: True +- _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: step +- _target_: rmind.callbacks.WandbAttentionMaskLogger +- _target_: rmind.callbacks.WandbImageParamLogger + when: on_train_batch_end + every_n_batch: 100 + key: similarity + select: + - [episode_builder, embeddings, continuous] + - [episode_builder, embeddings, discrete] + - [episode_builder, role_encoding, weight] + apply: + _target_: torchmetrics.functional.pairwise_cosine_similarity + _partial_: true +- _target_: rmind.callbacks.WandbWaypointsLogger + when: on_train_batch_end + every_n_batch: 100 + key: waypoints + crs: "EPSG:25832" + data: + image: [data, cam_front_left, 0, -1] + waypoints_xy_normalized: [data, waypoints/xy_normalized, 0, -1] + waypoints_xy: [data, waypoints/xy, 0, -1] + ego_xy: [data, meta/Gnss/xy, 0, -1] + caption: + input_id: [meta, input_id, 0] + time_stamp: [data, meta/ImageMetadata.cam_front_left/time_stamp, 0, -1] + frame_idx: [data, meta/ImageMetadata.cam_front_left/frame_idx, 0, -1] +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_train_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_validation_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 diff --git a/config/trainer/callbacks/pretrain_vjepa_unfrozen.yaml b/config/trainer/callbacks/pretrain_vjepa_unfrozen.yaml new file mode 100644 index 00000000..d2ceaf1c --- /dev/null +++ b/config/trainer/callbacks/pretrain_vjepa_unfrozen.yaml @@ -0,0 +1,58 @@ +- _target_: rmind.callbacks.LogitBiasSetter +# End-to-end variant: VjepaBackbone is trainable alongside the rmind transformer. +# Default (frozen) config is pretrain_vjepa. Use this via: +# trainer/callbacks=pretrain_vjepa_unfrozen +- _target_: pytorch_lightning.callbacks.TQDMProgressBar +- _target_: pytorch_lightning.callbacks.ModelSummary + max_depth: 5 +- _target_: pytorch_lightning.callbacks.ModelCheckpoint + every_n_epochs: 1 + save_on_train_epoch_end: True +- _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: step +- _target_: rmind.callbacks.WandbAttentionMaskLogger +- _target_: rmind.callbacks.WandbImageParamLogger + when: on_train_batch_end + every_n_batch: 100 + key: similarity + select: + - [episode_builder, embeddings, continuous] + - [episode_builder, embeddings, discrete] + - [episode_builder, role_encoding, weight] + apply: + _target_: torchmetrics.functional.pairwise_cosine_similarity + _partial_: true +- _target_: rmind.callbacks.WandbWaypointsLogger + when: on_train_batch_end + every_n_batch: 100 + key: waypoints + crs: "EPSG:25832" + data: + image: [data, cam_front_left, 0, -1] + waypoints_xy_normalized: [data, waypoints/xy_normalized, 0, -1] + waypoints_xy: [data, waypoints/xy, 0, -1] + ego_xy: [data, meta/Gnss/xy, 0, -1] + caption: + input_id: [meta, input_id, 0] + time_stamp: [data, meta/ImageMetadata.cam_front_left/time_stamp, 0, -1] + frame_idx: [data, meta/ImageMetadata.cam_front_left/frame_idx, 0, -1] +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_train_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 +- _target_: rmind.callbacks.WandbPatchSimilarityLogger + when: on_validation_batch_end + key: patch_similarities + image_sources: + cam_front_left: [foresight, cam_front_left] + embeddings_predict: [forward_dynamics, _artifacts, last_embeddings] + embeddings_target: [forward_dynamics, _artifacts, last_targets] + sample_id_path: [data, meta/sample_id] + every_n_sample: 1000 + patch_grid_size: 16 diff --git a/config/trainer/default.yaml b/config/trainer/default.yaml index f069186c..62308a7b 100644 --- a/config/trainer/default.yaml +++ b/config/trainer/default.yaml @@ -12,3 +12,5 @@ enable_model_summary: false logger: _target_: pytorch_lightning.loggers.WandbLogger log_model: all + entity: ${wandb.entity} + project: ${wandb.project} diff --git a/justfile b/justfile index 126af041..c9c35c4f 100644 --- a/justfile +++ b/justfile @@ -41,7 +41,7 @@ generate-config: --ignore-unknown-comments \ --strict -train *ARGS: generate-config check-git +train *ARGS: generate-config uv run rmind-train \ --config-path {{ justfile_directory() }}/config \ --config-name train.yaml \ diff --git a/pyproject.toml b/pyproject.toml index 95a9d286..c23f2802 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.deptry.per_rule_ignores] -DEP001 = ["rmind"] +DEP001 = ["app", "rmind"] DEP002 = [ "funcy", "torchmetrics", diff --git a/src/rmind/components/optimizers/selective_adamw.py b/src/rmind/components/optimizers/selective_adamw.py index d8f4b006..8384939a 100644 --- a/src/rmind/components/optimizers/selective_adamw.py +++ b/src/rmind/components/optimizers/selective_adamw.py @@ -40,7 +40,14 @@ def __init__( # https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/activation.py#L1091 case ( - "in_proj_weight" | "cls_token" | "reg_token" | "gamma_1" | "gamma_2" + "in_proj_weight" + | "cls_token" + | "reg_token" + | "gamma_1" + | "gamma_2" + # vjepa2 VisionTransformer modality embeddings (nn.Parameter) + | "img_mod_embed" + | "video_mod_embed" ): pass diff --git a/src/rmind/components/transformer/decoder.py b/src/rmind/components/transformer/decoder.py index 9c17688f..bb65910d 100644 --- a/src/rmind/components/transformer/decoder.py +++ b/src/rmind/components/transformer/decoder.py @@ -136,7 +136,7 @@ def forward(self, input: Input) -> Tensor: decoded = self.decoder(query_flat, context_flat) output = self.output_projection(decoded) - return output.reshape(b, t, sq, d) + return output.reshape(b, t, sq, -1) decoded = self.decoder(query, context) return self.output_projection(decoded) diff --git a/src/rmind/components/vjepa_backbone.py b/src/rmind/components/vjepa_backbone.py new file mode 100644 index 00000000..e6b25b2e --- /dev/null +++ b/src/rmind/components/vjepa_backbone.py @@ -0,0 +1,34 @@ +from math import prod +from typing import override + +import torch +from torch import Tensor, nn + + +class VjepaBackbone(nn.Module): + """V-JEPA 2.1 ViT-L/16 image encoder. + + Wraps the vjepa2 VisionTransformer so it presents the same interface as + TimmBackbone: accepts (*B, C, H, W) and returns (*B, N_patches, embed_dim). + Weights are downloaded automatically via torch.hub on first use. + """ + + def __init__(self) -> None: + super().__init__() + encoder, _ = torch.hub.load( + "facebookresearch/vjepa2", + "vjepa2_1_vit_large_384", + pretrained=True, + trust_repo=True, + ) + self.encoder: nn.Module = encoder + + @override + def forward(self, x: Tensor) -> Tensor: + # x: (*B, C, H, W) — same contract as TimmBackbone. + # Unsqueeze temporal dim so VisionTransformer takes the img_temporal_dim_size=1 + # branch (patch_embed_img + img_mod_embed) instead of the video branch. + *b, c, h, w = x.shape + x = x.view(prod(b), c, 1, h, w) + x = self.encoder(x) # (prod(B), N_patches, 1024) + return x.view(*b, x.shape[-2], x.shape[-1]) # (*B, N_patches, 1024)