
[Backgammon] legal_action_mask: scatter once instead of a (26, 156) intermediate #1323

Open
gweber wants to merge 2 commits into
sotetsuk:mainfrom
gweber:backgammon-legal-mask-scatter
@gweber gweber commented Jun 10, 2026

What

_legal_action_mask_for_valid_single_dice computed the per-die legal mask like this:

def _is_legal(idx):
    action = idx * 6 + die
    m = jnp.zeros(26 * 6, dtype=jnp.bool_)      # a full 156-wide zero vector ...
    return m.at[action].set(_is_action_legal(board, action))   # ... to set a single bit
legal = jax.vmap(_is_legal)(jnp.arange(26)).any(axis=0)

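i.e. it materialized a (26, 26*6) intermediate (one zero vector per src, each with a single bit set) and OR-reduced it. This PR instead computes the 26 src legalities as a (26,) vector and scatters them into the mask in one shot. Because the 26 indices idx * 6 + die are distinct, the scatter writes each slot at most once and yields the same mask: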

actions = jnp.arange(26, dtype=jnp.int32) * 6 + die
legal = jax.vmap(lambda a: _is_action_legal(board, a))(actions)
return jnp.zeros(26 * 6, dtype=jnp.bool_).at[actions].set(legal)
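The two snippets above can be compared side by side in a standalone sketch. The predicate toy_is_action_legal and the random board below are hypothetical stand-ins, not the real _is_action_legal or a real backgammon position; only the shape of the two mask constructions follows the PR.

```python
import jax
import jax.numpy as jnp

NUM_SRC = 26   # source positions, as in the PR
NUM_DICE = 6   # die faces

def toy_is_action_legal(board, action):
    # Hypothetical stand-in for _is_action_legal (NOT the real rule):
    # an action is "legal" when its source point holds an own checker.
    src = action // NUM_DICE
    return board[src] > 0

def mask_via_one_hot(board, die):
    # Previous shape of the computation: one 156-wide vector per src,
    # each with a single bit set, then OR-reduced over the src axis.
    def _is_legal(idx):
        action = idx * NUM_DICE + die
        m = jnp.zeros(NUM_SRC * NUM_DICE, dtype=jnp.bool_)
        return m.at[action].set(toy_is_action_legal(board, action))
    return jax.vmap(_is_legal)(jnp.arange(NUM_SRC)).any(axis=0)

def mask_via_scatter(board, die):
    # New shape: a (26,) legality vector scattered into the mask once.
    # The indices are distinct, so no two writes target the same slot.
    actions = jnp.arange(NUM_SRC, dtype=jnp.int32) * NUM_DICE + die
    legal = jax.vmap(lambda a: toy_is_action_legal(board, a))(actions)
    return jnp.zeros(NUM_SRC * NUM_DICE, dtype=jnp.bool_).at[actions].set(legal)

# Random toy board: positive entries play the role of own checkers.
board = jax.random.randint(jax.random.PRNGKey(0), (NUM_SRC,), -3, 4)
same = all(
    bool(jnp.array_equal(mask_via_one_hot(board, d), mask_via_scatter(board, d)))
    for d in range(NUM_DICE)
)
```

In this sketch same should be True for every die, since both functions set exactly the slots idx * 6 + die and leave the rest False.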

Behaviour & perf

Pure refactor, no behavioural change. Outputs match the previous implementation across 400 boards × 6 dice, and the full _legal_action_mask matches over 4 dice-sets (0 mismatches).

The full _legal_action_mask, which is called on every step (_init, _update_by_action, _change_turn), is ~2× faster (0.61 → 0.31 ms, B=256, CPU).

tests/test_backgammon.py: 20 passed. The one test_api failure is pre-existing and reproduces unchanged on main; it is a JAX-version issue unrelated to this change.
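A check of this kind could be structured roughly as follows. This is a minimal sketch, not the script used for the numbers above: old_mask and new_mask are toy stand-ins with a made-up legality rule, and count_mismatches and time_ms are hypothetical helper names. It only illustrates batching the comparison with jax.vmap and timing a jitted, batched call with block_until_ready so that asynchronous dispatch does not distort the measurement.

```python
import time
import jax
import jax.numpy as jnp

N = 26 * 6

def old_mask(board, die):
    # Toy stand-in for the previous implementation: build a (26, 156)
    # one-hot intermediate and OR-reduce it.
    actions = jnp.arange(26) * 6 + die
    rows = jnp.arange(N)[None, :] == actions[:, None]
    return (rows & (board > 0)[:, None]).any(axis=0)

def new_mask(board, die):
    # Toy stand-in for the new implementation: scatter a (26,) vector once.
    actions = jnp.arange(26) * 6 + die
    return jnp.zeros(N, dtype=jnp.bool_).at[actions].set(board > 0)

def count_mismatches(f_old, f_new, boards, dice):
    # Evaluate both functions on every (board, die) pair and count
    # the mask entries on which they disagree.
    run = lambda f: jax.vmap(lambda b: jax.vmap(lambda d: f(b, d))(dice))(boards)
    return int(jnp.sum(run(f_old) != run(f_new)))

def time_ms(f, boards, die, repeats=20):
    # Mean wall-clock time per batched call, in milliseconds.
    batched = jax.jit(jax.vmap(f, in_axes=(0, None)))
    batched(boards, die).block_until_ready()  # compile outside the timed loop
    start = time.perf_counter()
    for _ in range(repeats):
        batched(boards, die).block_until_ready()
    return (time.perf_counter() - start) / repeats * 1e3

boards = jax.random.randint(jax.random.PRNGKey(0), (256, 26), -3, 4)  # B=256 toy boards
dice = jnp.arange(6)
mismatches = count_mismatches(old_mask, new_mask, boards, dice)
```

Calling time_ms(old_mask, boards, 0) and time_ms(new_mask, boards, 0) then gives a rough per-call comparison; absolute numbers depend on the device and JAX version, and the toy functions will not reproduce the figures quoted above.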

gweber added 2 commits June 10, 2026 20:04
…ntermediate

_legal_action_mask_for_valid_single_dice built 26 full (26*6,) zero vectors
— one per src — set a single bit in each, then OR-reduced them. Compute the
26 src legalities as a (26,) vector and scatter them in one shot.

Behaviour identical; ~2x faster. Verified: outputs match the previous
implementation across 400 boards x 6 dice and the full mask over 4 dice-sets
(0 mismatches); ~1.99x faster full _legal_action_mask (B=256, CPU).
…e per-action vmap

_is_to_off_legal recomputed _rear_distance(board) and _is_all_on_home_board(board)
for every candidate action, though both depend only on the board. Compute them
once per board in the legal-mask leaf and thread them in. _is_action_legal keeps
its 2-arg form (computes them itself when omitted), so external callers and tests
are unaffected.

~1.12x faster full _legal_action_mask on top of the previous scatter change.
Identical output across 300 boards x 5 dice-sets; tests/test_backgammon.py passes.
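
The second commit's pattern, computing board-only quantities once and threading them into the per-action check while keeping the 2-arg call working, can be sketched as below. All functions here are hypothetical toys: toy_rear_distance, toy_all_home, and the bearing-off rule are simplified stand-ins, not the real _rear_distance, _is_all_on_home_board, or _is_to_off_legal.

```python
import jax
import jax.numpy as jnp

def toy_rear_distance(board):
    # Toy stand-in: index of the rearmost point holding an own checker.
    return jnp.max(jnp.where(board > 0, jnp.arange(board.shape[0]), -1))

def toy_all_home(board):
    # Toy stand-in: no own checkers beyond index 5.
    return jnp.all(board[6:] <= 0)

def toy_is_action_legal(board, action, rear=None, all_home=None):
    # The None checks run in Python at trace time, so the 2-arg form
    # still works: the invariants are computed here when omitted.
    if rear is None:
        rear = toy_rear_distance(board)
    if all_home is None:
        all_home = toy_all_home(board)
    src, die = action // 6, action % 6
    # Made-up rule for illustration only.
    return all_home & (board[src] > 0) & (die + 1 >= rear)

def mask_for_die(board, die):
    # Compute the board-only values once, then pass them to every action.
    rear = toy_rear_distance(board)
    all_home = toy_all_home(board)
    actions = jnp.arange(26, dtype=jnp.int32) * 6 + die
    legal = jax.vmap(lambda a: toy_is_action_legal(board, a, rear, all_home))(actions)
    return jnp.zeros(26 * 6, dtype=jnp.bool_).at[actions].set(legal)

board = jnp.zeros(26, dtype=jnp.int32).at[jnp.array([1, 3])].set(2)
actions = jnp.arange(26, dtype=jnp.int32) * 6 + 2
two_arg = jax.vmap(lambda a: toy_is_action_legal(board, a))(actions)
threaded = mask_for_die(board, 2)[actions]
```

Here two_arg and threaded should agree element-wise, which mirrors the claim that external callers using the 2-arg form are unaffected.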