[Backgammon] legal_action_mask: scatter once instead of a (26, 156) intermediate #1323
Open
gweber wants to merge 2 commits into
…ntermediate
_legal_action_mask_for_valid_single_dice built 26 full (26*6,) zero vectors (one per src), set a single bit in each, then OR-reduced them. Compute the 26 src legalities as a (26,) vector and scatter them in one shot. Behaviour identical; ~2x faster. Verified: outputs match the previous implementation across 400 boards x 6 dice and the full mask over 4 dice-sets (0 mismatches); ~1.99x faster full _legal_action_mask (B=256, CPU).
…e per-action vmap
_is_to_off_legal recomputed _rear_distance(board) and _is_all_on_home_board(board) for every candidate action, though both depend only on the board. Compute them once per board in the legal-mask leaf and thread them in. _is_action_legal keeps its 2-arg form (it computes them itself when they are omitted), so external callers and tests are unaffected. ~1.12x faster full _legal_action_mask on top of the previous scatter change. Identical output across 300 boards x 5 dice-sets; tests/test_backgammon.py passes.
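The calling convention described in the second commit (compute a board-only quantity once, pass it in, and fall back to computing it when the argument is omitted) might look roughly like the sketch below. Everything here is a hypothetical stand-in: `rear_distance`, `is_action_legal`, the one-dimensional `board`, and the toy legality rule are illustrative and are not the repository's code. The sketch only shows the API shape and checks that both call paths agree; it makes no claim about the speedup.

```python
import jax
import jax.numpy as jnp


def rear_distance(board):
    # Hypothetical board-only quantity: index of the rearmost occupied point.
    return jnp.max(jnp.where(board > 0, jnp.arange(board.shape[0]), 0))


def is_action_legal(board, action, rear=None):
    # Keeps the 2-arg form: the board-only value is computed here when omitted.
    if rear is None:
        rear = rear_distance(board)
    # Toy rule (assumption): occupied and within 2 points of the rearmost checker.
    return (board[action] > 0) & (rear - action <= 2)


def mask_two_arg(board):
    # Each candidate action calls the 2-arg form.
    actions = jnp.arange(board.shape[0])
    return jax.vmap(lambda a: is_action_legal(board, a))(actions)


def mask_hoisted(board):
    # The board-only value is computed once and threaded into every call.
    actions = jnp.arange(board.shape[0])
    rear = rear_distance(board)
    return jax.vmap(lambda a: is_action_legal(board, a, rear))(actions)


board = jnp.array([0, 2, 0, 1, 3, 0])
same = bool(jnp.all(mask_two_arg(board) == mask_hoisted(board)))
```

Because the extra argument defaults to `None`, existing two-argument callers keep working unchanged, which is the compatibility property the commit message relies on.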
What
_legal_action_mask_for_valid_single_dice previously computed the per-die legal mask by materializing a (26, 26*6) intermediate (one zero vector per src, each with a single bit set) and OR-reducing it. This change computes the 26 src legalities as a (26,) vector and scatters them into the mask in one shot.
Behaviour & perf
Pure refactor, no behavioural change. Verified outputs match the previous implementation across 400 boards × 6 dice and the full _legal_action_mask over 4 dice-sets (0 mismatches). The full _legal_action_mask, called on every step (_init, _update_by_action, _change_turn), is ~2× faster (0.61 → 0.31 ms, B=256, CPU).
tests/test_backgammon.py: 20 passed (the one test_api failure is pre-existing and reproduces unchanged on main; it is a JAX-version issue unrelated to this change).
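For readers without the diff at hand, the two strategies described under What, and the kind of equivalence check reported here, can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in the PR: `is_action_legal` is a hypothetical stand-in predicate, the board is a plain `(26,)` integer vector, the action index is taken to be `src * 6 + die`, and `count_mismatches` is an illustrative helper.

```python
import jax
import jax.numpy as jnp

NUM_SRC = 26
NUM_DICE = 6


def is_action_legal(board, action):
    # Hypothetical stand-in for the real check: legal iff the src holds a checker.
    return board[action // NUM_DICE] > 0


def mask_or_reduce(board, die):
    # Old shape of the computation: one (156,) row per src, then OR over rows.
    def one_hot_row(src):
        action = src * NUM_DICE + die
        row = jnp.zeros(NUM_SRC * NUM_DICE, dtype=jnp.bool_)
        return row.at[action].set(is_action_legal(board, action))

    return jax.vmap(one_hot_row)(jnp.arange(NUM_SRC)).any(axis=0)


def mask_scatter(board, die):
    # New shape: a (26,) legality vector scattered into the mask in one update.
    actions = jnp.arange(NUM_SRC) * NUM_DICE + die
    legal = jax.vmap(lambda a: is_action_legal(board, a))(actions)
    return jnp.zeros(NUM_SRC * NUM_DICE, dtype=jnp.bool_).at[actions].set(legal)


def count_mismatches(key, num_boards=8):
    # Compare both variants over random boards and all six dice values.
    boards = jax.random.randint(key, (num_boards, NUM_SRC), -3, 4)
    dice = jnp.arange(NUM_DICE)

    def run(fn):
        return jax.vmap(lambda b: jax.vmap(lambda d: fn(b, d))(dice))(boards)

    return int(jnp.sum(run(mask_or_reduce) != run(mask_scatter)))


mismatches = count_mismatches(jax.random.PRNGKey(0))
```

The scatter is safe here because the 26 target indices for a fixed die are distinct, so no two updates write to the same position.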