Make Zobrist hashing Apple Metal (jax-metal) compatible#1317

Open
gweber wants to merge 3 commits into
sotetsuk:mainfrom
gweber:metal-compat

@gweber gweber commented Jun 10, 2026


Problem

chess and go fail to run on the Apple Metal backend (jax-metal):

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: .../games/chess.py:408:13:
  error: failed to legalize operation 'mhlo.reduce'
    hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,))

The Metal XLA backend cannot legalize a lax.reduce whose reducer is bitwise_xor
(reduce-sum/max are fine; the custom bitwise-xor reduction is not). This affects the
Zobrist hashing in both chess (2 sites) and go (1 site) — the only lax.reduce
bitwise-xor uses in the codebase.

Fix

Add pgx._src.utils.xor_reduce, which computes the identical XOR-reduction via per-bit
parity built from sum + shifts (all Metal-supported), and use it in chess/go Zobrist
hashing. Numerically identical on every backend (verified bit-for-bit vs the original
lax.reduce on random uint32 arrays); no behaviour change on CPU/GPU/TPU.
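The bit-parity idea can be sketched as follows. This is a minimal illustration, not the code in the PR: the name `xor_reduce_bitparity_sketch` and its `axis` parameter are hypothetical, and the input is assumed to be a uint32 array. It relies on the fact that bit k of an XOR over many values is the parity (sum mod 2) of bit k across those values.

```python
import jax.numpy as jnp
import numpy as np
from jax import lax


def xor_reduce_bitparity_sketch(x, axis=0):
    """XOR-reduce a uint32 array along `axis` using only shifts, masks and sums.

    Hypothetical helper for illustration; not the pgx implementation.
    """
    x = x.astype(jnp.uint32)
    axis = axis % x.ndim  # normalise negative axes before adding the bit axis
    shifts = jnp.arange(32, dtype=jnp.uint32)
    # Expand every value into its 32 bits: shape x.shape + (32,).
    bits = (x[..., None] >> shifts) & jnp.uint32(1)
    # Count the set bits per position along `axis`; the low bit is the XOR.
    parity = jnp.sum(bits, axis=axis, dtype=jnp.uint32) & jnp.uint32(1)
    # Reassemble the word; the shifted bits are disjoint, so sum acts as OR.
    return jnp.sum(parity << shifts, axis=-1, dtype=jnp.uint32)


# Compare against the native reduction on random keys (runs on CPU/GPU/TPU).
keys = jnp.asarray(
    np.random.default_rng(0).integers(0, 2**32, size=(64, 2), dtype=np.uint32)
)
native = lax.reduce(keys, jnp.uint32(0), lax.bitwise_xor, (0,))
assert bool(jnp.all(native == xor_reduce_bitparity_sketch(keys)))
```

The intermediate `bits` array is 32 times larger than the input, which is the cost of avoiding the custom reducer.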

Result

chess and go init/step and full (mctx Gumbel) self-play now run on Apple Silicon GPUs.
Existing chess legality was cross-checked against python-chess across 200+ positions
(castling/en-passant/promotions) with the change applied.

lax.reduce with bitwise_xor fails to legalize on the Metal XLA backend
(UNIMPLEMENTED: mhlo.reduce). Replace the XOR-reduction in chess and go
Zobrist hashing with an equivalent bit-parity computation built from sum/
shifts (utils.xor_reduce), numerically identical on every backend. This
unblocks chess and go self-play on Apple Silicon GPUs.

Keep the per-bit-parity path (which makes the reduction legalize on jax-metal) only on
the Metal backend, and use the native bitwise_xor reduction on CPU/CUDA/TPU. The bit-parity
fallback expands every value into its bits, which is ~15x slower on CPU and ~100x on CUDA;
since this runs once per env step for Zobrist hashing it measurably slowed chess/go on the
non-Metal backends. Backend is resolved at trace time, so jitted code pays no runtime cost.
Hashes are bit-identical on all backends.

gweber commented Jun 12, 2026


Updated so the per-bit-parity path is used only on Metal, with the native bitwise_xor reduction on CPU/CUDA/TPU (backend resolved at trace time, no runtime cost).

The bit-parity fallback expands every value into its bits, which measurably slows Zobrist hashing on the non-Metal backends since it runs once per env step. Full env.step benchmarks (B=2048):

| env | backend | before | after | speedup |
|---|---|---|---|---|
| go_19x19 | CPU | 69.0 ms | 38.2 ms | 1.81× |
| chess | CPU | 41.0 ms | 36.9 ms | 1.11× |

Hashes are bit-identical on all backends and the go/chess counting + scoring tests still pass. This keeps Metal compatibility without regressing the backends most people run on.

Harden the backend check by replacing the default_backend() string match with a device signal
(platform + device_kind contains 'metal'/'apple'), so a CUDA device reporting platform 'gpu' is
never misread as Metal, and a Metal device reporting platform 'gpu' is still caught by its Apple
device_kind. Add a PGX_XOR_REDUCE=native|parity escape hatch; fall back to the parity path
if devices can't be introspected. No behavior change on CPU/CUDA/Metal (verified
native==parity==ground-truth).
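The detection logic described in this commit might look roughly like the sketch below. The function names are hypothetical (the PR's own helper is referred to elsewhere as `_use_native_xor`, whose body is not shown here), and the handling of an unrecognised `PGX_XOR_REDUCE` value is an assumption. `platform` and `device_kind` are attributes of JAX device objects.

```python
import os


def looks_like_metal(platform: str, device_kind: str) -> bool:
    """True if either device string mentions Metal or Apple (case-insensitive)."""
    text = f"{platform} {device_kind}".lower()
    return "metal" in text or "apple" in text


def use_native_xor_sketch() -> bool:
    """Decide whether the native bitwise_xor reduction should be used."""
    # Escape hatch: PGX_XOR_REDUCE=native|parity overrides detection.
    override = os.environ.get("PGX_XOR_REDUCE", "").strip().lower()
    if override == "native":
        return True
    if override == "parity":
        return False
    # Assumption: any other value is ignored and detection proceeds.
    try:
        import jax  # imported here so an import failure also takes the fallback

        devices = jax.devices()
    except Exception:
        # Devices cannot be introspected: take the parity path, which is
        # expected to work on every backend.
        return False
    return not any(looks_like_metal(d.platform, d.device_kind) for d in devices)
```

Falling back to parity on failure trades speed for safety: the parity path is slower but is the one expected to compile everywhere.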
gweber added a commit to gweber/pgx that referenced this pull request Jun 12, 2026
…ty split)

Adopt the same implementation as PR sotetsuk#1317 (metal-compat b4d97c6) so mushin and the upstream
branch share ONE xor_reduce and won't conflict when metal-compat lands. Functionally
identical to the prior mushin version (native on CUDA/CPU, parity only on Metal, signal-based
detection + PGX_XOR_REDUCE override); only factors the Metal fallback into a helper.
gweber added a commit to gweber/pgx that referenced this pull request Jun 12, 2026
…es + docs

Brings the Mac lineage (CPU-perf pass +55-70% Go, tiered Bloom-PSK [Bloom pre-filter + exact
recent window], narrower int8/int16 dtypes, segment_sum chain accumulation, chain-stats cache,
chess king-danger prune / occupancy-bitboard / float16 obs, backgammon legal-mask opts, fork
documentation, benchmark harness) together with this branch's Layer-Go (board-wide legal mask,
hash-based superko) + backend-aware xor_reduce.

Only conflict was utils.py: kept the canonical native-dispatch xor_reduce (PR sotetsuk#1317:
_xor_reduce_bitparity + _use_native_xor + PGX_XOR_REDUCE override) AND added the Mac's
bloom_insert/bloom_query helpers (used by the tiered Go PSK). Dropped the Mac's superseded
bit-parity-only xor_reduce. Verified: go_19x19/go_9x9/chess vmapped steps run.