Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -240,28 +240,11 @@ jobs:
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
env:
PYTHON_VERSION: 3.12
- name: Download previous benchmark data
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ./cache
key: ${{ runner.os }}-benchmark
- name: Run benchmarks
shell: micromamba-shell {0}
run: |
export PYTENSOR_FLAGS=warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
python -m pytest --runslow --benchmark-only --benchmark-json output.json
- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@52576c92bccf6ac60c8223ec7eb2565637cae9ba # v1.22.1
with:
name: Python Benchmark with pytest-benchmark
tool: "pytest"
output-file-path: output.json
external-data-json-path: ./cache/benchmark-data.json
alert-threshold: "200%"
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: false
fail-on-alert: false
auto-push: false
python -m pytest --runslow --benchmark-only

all-checks:
if: ${{ always() }}
Expand Down
106 changes: 92 additions & 14 deletions pytensor/tensor/rewriting/subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Eye,
Join,
MakeVector,
Nonzero,
alloc,
arange,
as_tensor,
Expand Down Expand Up @@ -244,6 +245,77 @@ def _index_provably_not_larger(idx, val_static_dim, fgraph=None) -> bool:
return bool(np.prod(idx_static_shape) < val_static_dim)


def _constants_jointly_unique(consts) -> bool:
"""Whether stacked constant indices have no duplicate coordinate tuples.

The stacked ``np.unique`` can be expensive on large indices, so the result
is cached on the first constant's tag. Uniqueness is a property of the whole
group, and a constant may belong to several groups (constants are shared
across the graph), so the cache is keyed by the group's identities rather
than a single flag.
"""
key = tuple(id(c) for c in consts)
cache = getattr(consts[0].tag, "jointly_unique_indices", None)
if cache is None:
cache = consts[0].tag.jointly_unique_indices = {}
if key not in cache:
datas = [np.asarray(c.data) for c in consts]
# A coordinate axis that mixes positive and negative values may alias
# (``0`` and ``-dim`` are the same position), so distinctness of the raw
# values no longer proves distinctness of the coordinates.
if any((data >= 0).any() and (data < 0).any() for data in datas):
cache[key] = False
else:
coords = np.broadcast_arrays(*datas)
stacked = np.stack([coord.ravel() for coord in coords])
cache[key] = bool(np.unique(stacked, axis=1).shape[1] == stacked.shape[1])
return bool(cache[key])


def _indices_provably_not_larger(idxs_and_dims, fgraph) -> bool:
"""Whether advanced-indexing some consecutive axes selects no more elements
than those axes already hold, so lifting a Subtensor through the indexing
can't increase computation.

``idxs_and_dims`` pairs each advanced index (``ndim > 0``) with the static
size of the axis it indexes.
"""
if not idxs_and_dims:
return True

idxs = [idx for idx, _ in idxs_and_dims]
dims = [dim for _, dim in idxs_and_dims]
idx_shapes = [idx.type.shape for idx in idxs]

# With static shapes the result size is known exactly, so just compare it
# against the number of elements the indexed axes hold.
if all(d is not None for d in dims) and all(
None not in shape for shape in idx_shapes
):
return bool(np.prod(np.broadcast_shapes(*idx_shapes)) <= np.prod(dims))

# Otherwise fall back to proving the indices are duplicate-free, which on its
# own bounds the result by the axes' size, even when the sizes are unknown:
# - each index repeats no position on its own axis, or
if all(_index_provably_not_larger(idx, dim, fgraph) for idx, dim in idxs_and_dims):
return True
if len(idxs) > 1:
# - the indices are all the coordinates of one Nonzero, distinct by
# construction (e.g. symbolic tril_indices), or
owners = {idx.owner for idx in idxs}
if (
len(owners) == 1
and (owner := next(iter(owners))) is not None
and isinstance(owner.op, Nonzero)
and set(idxs) == set(owner.outputs)
):
return True
# - the constant coordinate tuples have no duplicates.
if all(isinstance(idx, Constant) for idx in idxs):
return _constants_jointly_unique(idxs)
return False


@register_canonicalize
@register_stabilize
@register_specialize
Expand Down Expand Up @@ -345,17 +417,19 @@ def local_subtensor_of_batch_dims(fgraph, node):
if _non_consecutive_adv_indexing(idx_tuple):
return None

# Skip when lifting would expand a gather past a non-broadcast input's size.
# Skip when indexing each input would select more elements than it holds,
# making the lifted Elemwise do more work. The advanced indices are weighed
# together, over the consecutive axes they jointly index.
for inp in elem.owner.inputs:
for axis, idx in enumerate(idx_tuple):
if axis >= inp.type.ndim:
break
if not isinstance(idx, TensorVariable) or idx.type.ndim == 0:
continue
if inp.type.broadcastable[axis]:
continue
if not _index_provably_not_larger(idx, inp.type.shape[axis], fgraph):
return None
adv_indices = [
(idx, inp.type.shape[axis])
for axis, idx in enumerate(idx_tuple[: inp.type.ndim])
if isinstance(idx, TensorVariable)
and idx.type.ndim > 0
and not inp.type.broadcastable[axis]
]
if not _indices_provably_not_larger(adv_indices, fgraph):
return None

batch_ndim = (
elem.owner.op.batch_ndim(elem.owner)
Expand Down Expand Up @@ -742,11 +816,15 @@ def lift_subtensor_through_alloc(fgraph, node):

# Indices on Alloc-added dims don't reach val; the rest line up with val's dims.
val_indexer = indices[n_added_dims:]
dangerous_index_reaches_val = any(
not val.type.broadcastable[axis]
# Per-axis check; doesn't account for net effect across all axes.
and not _index_provably_not_larger(idx, val.type.shape[axis], fgraph)
val_adv_indices = [
(idx, val.type.shape[axis])
for axis, idx in enumerate(val_indexer)
if isinstance(idx, TensorVariable)
and idx.type.ndim > 0
and not val.type.broadcastable[axis]
]
dangerous_index_reaches_val = not _indices_provably_not_larger(
val_adv_indices, fgraph
)

# On broadcast val dims the index is neutralized (advanced indices dropped,
Expand Down
21 changes: 12 additions & 9 deletions tests/tensor/rewriting/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2753,8 +2753,9 @@ def test_cholesky_unconstrain_grad(exp_before_materialize):

packed = pt.vector("packed")
if exp_before_materialize:
# We test the same optimized result regardless of whether
# the diagonals are updated before or after materialization
# Same ``L`` two ways: exponentiate the diagonal in the packed vector
# before scattering, or (else branch) scatter first and exponentiate the
# matrix diagonal. Equivalent, but optimize to different graphs under BlasOpt.
packed_diag_indices = pt.arange(n + 1).cumsum()[1:] - 1
log_diag = packed[packed_diag_indices]
packed_update = packed[packed_diag_indices].set(pt.exp(log_diag))
Expand All @@ -2778,7 +2779,6 @@ def test_cholesky_unconstrain_grad(exp_before_materialize):

mode = get_default_mode().excluding("fuse_indexed_into_elemwise")
f = function([packed], [loss, grad], mode=mode)
f.dprint(print_shape=True)

idx_types = (
Subtensor,
Expand All @@ -2790,13 +2790,16 @@ def test_cholesky_unconstrain_grad(exp_before_materialize):
ExtractDiag,
)
n_idx = sum(1 for n in f.maker.fgraph.toposort() if isinstance(n.op, idx_types))
# The ``BlasOpt`` rewrites lower ``L @ L.T`` to ``Gemm``; the gradient then
# fuses the diagonal-gradient term into a ``Gemm`` operand, materializing one
# extra set-subtensor. A linker that cannot use them lists ``BlasOpt`` in
# ``incompatible_rewrites`` (e.g. the numba linker), keeping the plain ``Dot``
# lowering with that term as a vector. Both lowerings are correct.
# Post-materialization, the log-det gradient is a diagonal matrix added to
# ``L@L.T`` at the matrix level; BlasOpt's GemmOptimizer fuses ``add(dot, C)``
# into one Gemm, materializing that diagonal as one extra set-subtensor (7 ops).
# Pre-materialization keeps the term in the packed vector's index space, so
# there's no matrix-level add to fuse (6 ops). Without Gemm (BlasOpt in the
# linker's incompatible_rewrites, e.g. numba) the term never reaches the matrix
# and both collapse to 6. All lowerings are correct.
blas_rewrites_run = "BlasOpt" not in f.maker.mode.linker.incompatible_rewrites
assert n_idx == (7 if blas_rewrites_run else 6)
expected_n_idx = 7 if (blas_rewrites_run and not exp_before_materialize) else 6
assert n_idx == expected_n_idx

x = np.array([1.0, 0.5, 2.0, 0.3, 0.1, 1.5])
# Expected values were computed once by running ``f(x)``.
Expand Down
45 changes: 45 additions & 0 deletions tests/tensor/rewriting/test_subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,36 @@ def test_elemwise_adv_index_assumed_unique_lifts(self):
)
result.assert_graph(x[idx] + y[idx])

def test_elemwise_jointly_unique_adv_indices_lift(self):
"""A group of adv indices that each repeat but pair up to distinct
coordinates (tril_indices) can't select more elements than the indexed
axes hold, so it lifts."""
# Symbolic indices: the outputs of a single Nonzero.
n = pt.scalar("n", dtype="int64")
x = pt.matrix("x")
rows, cols = pt.tril_indices(n)
# ``local_add_canonizer`` simplifies the ``n - 0`` inside ``tril_indices``,
# which then lets the duplicate ``arange`` merge -- noise for this test.
result = RewriteTester(
[n, x], [pt.exp(x)[rows, cols]], exclude=("local_add_canonizer",)
)
result.assert_graph(pt.exp(x[rows, cols]))
result.assert_eval(3, np.arange(9.0).reshape(3, 3))

# Constant indices, static array shape: proved through the exact size.
x = pt.matrix("x", shape=(5, 5))
rows, cols = (pt.constant(i) for i in np.tril_indices(5))
result = RewriteTester([x], [pt.exp(x)[rows, cols]])
result.assert_graph(pt.exp(x[rows, cols]))
result.assert_eval(np.arange(25.0).reshape(5, 5))

# Constant indices, unknown array shape: proved through joint uniqueness.
x = pt.matrix("x")
rows, cols = (pt.constant(i) for i in np.tril_indices(5))
result = RewriteTester([x], [pt.exp(x)[rows, cols]])
result.assert_graph(pt.exp(x[rows, cols]))
result.assert_eval(np.arange(25.0).reshape(5, 5))

def test_blockwise(self):
class CoreTestOp(Op):
itypes = [dvector, dvector]
Expand Down Expand Up @@ -755,6 +785,21 @@ def test_const_idx_with_duplicates_bails(self):
rewritten = rewrite_graph(out, **self.rewrite_kw)
assert_equal_computations([rewritten], [out], strict_dtype=False)

def test_jointly_unique_adv_indices_lift(self):
"""Indices that each repeat but pair up to distinct coordinates
(tril_indices) don't enlarge val, so the read lifts through Alloc."""
val = pt.matrix("val", shape=(5, 5))
rows, cols = (pt.constant(i) for i in np.tril_indices(5))

result = RewriteTester(
[val],
[pt.alloc(val, 5, 5)[rows, cols]],
include=("ShapeOpt", "canonicalize", "specialize"),
exclude=("local_replace_AdvancedSubtensor",),
)
result.assert_graph(val[rows, cols], strict_dtype=False)
result.assert_eval(np.arange(25.0).reshape(5, 5))

def test_negative_step_idx_to_slice(self):
"""Negative-step constant arange ``[7, 5, 3, 1]`` rewrites to ``x[7::-2]``."""
x = pt.vector("x", shape=(10,))
Expand Down
Loading