Skip to content

Subtensor lift: reason about advanced indices jointly when gating#2251

Open
ricardoV94 wants to merge 3 commits into
pymc-devs:mainfrom
ricardoV94:proper_idx_gating
Open

Subtensor lift: reason about advanced indices jointly when gating#2251
ricardoV94 wants to merge 3 commits into
pymc-devs:mainfrom
ricardoV94:proper_idx_gating

Conversation

@ricardoV94

@ricardoV94 ricardoV94 commented Jun 21, 2026

Copy link
Copy Markdown
Member

We had heuristics that work per-axis, but we should reason about the advanced indices together since they interact. E.g., tril_indices, would fail the per-axis checks, as each index (row, col) is larger than the respective axis and contains duplicates. But jointly they index less than the full array and don't have duplicates, so we want to lift it through Elemwise and the like.

Shows up in pymc-devs/pymc#8297

@ricardoV94 ricardoV94 force-pushed the proper_idx_gating branch 3 times, most recently from 483c1b9 to f63da68 Compare June 25, 2026 12:38
@ricardoV94 ricardoV94 marked this pull request as ready for review June 25, 2026 13:06
@ricardoV94 ricardoV94 force-pushed the proper_idx_gating branch 2 times, most recently from fe2c69f to f7ceb7d Compare June 27, 2026 10:56
and idx.type.ndim > 0
and not inp.type.broadcastable[axis]
]
if not _indices_provably_not_larger(adv_indices, fgraph):

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this take the variable for the static shape

continue
if not _index_provably_not_larger(idx, inp.type.shape[axis], fgraph):
return None
adv_indices = [

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be called adv_indices and input_static_dim

Generalize the duplicate-free reasoning used to gate the advanced-index
write rewrites from a per-axis check to a joint one over the whole advanced
index group, sharing the logic across subtensor.py and subtensor_lift.py.

- _index_provably_unique: per-axis uniqueness, now also proving single-signed
  aranges and views (Reshape/DimShuffle) that preserve the value multiset.
- _indices_jointly_unique: distinct joint coordinate tuples via all-axes
  uniqueness, a single Nonzero (e.g. tril_indices), or jointly-unique
  constants where no single axis is unique on its own.
- _indices_provably_not_larger: bound a gather by the indexed axes' size
  using static shapes, per-axis bounds, or joint uniqueness.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant