diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2b7e27eb6a..9927994ae2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -240,28 +240,11 @@ jobs: python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"' env: PYTHON_VERSION: 3.12 - - name: Download previous benchmark data - uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - with: - path: ./cache - key: ${{ runner.os }}-benchmark - name: Run benchmarks shell: micromamba-shell {0} run: | export PYTENSOR_FLAGS=warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe - python -m pytest --runslow --benchmark-only --benchmark-json output.json - - name: Store benchmark result - uses: benchmark-action/github-action-benchmark@52576c92bccf6ac60c8223ec7eb2565637cae9ba # v1.22.1 - with: - name: Python Benchmark with pytest-benchmark - tool: "pytest" - output-file-path: output.json - external-data-json-path: ./cache/benchmark-data.json - alert-threshold: "200%" - github-token: ${{ secrets.GITHUB_TOKEN }} - comment-on-alert: false - fail-on-alert: false - auto-push: false + python -m pytest --runslow --benchmark-only all-checks: if: ${{ always() }} diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index d3a4bf5bf6..0f242577a9 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -27,6 +27,7 @@ Eye, Join, MakeVector, + Nonzero, alloc, arange, as_tensor, @@ -244,6 +245,77 @@ def _index_provably_not_larger(idx, val_static_dim, fgraph=None) -> bool: return bool(np.prod(idx_static_shape) < val_static_dim) +def _constants_jointly_unique(consts) -> bool: + """Whether stacked constant indices have no duplicate coordinate tuples. + + The stacked ``np.unique`` can be expensive on large indices, so the result + is cached on the first constant's tag. 
Uniqueness is a property of the whole + group, and a constant may belong to several groups (constants are shared + across the graph), so the cache is keyed by the group's identities rather + than a single flag. + """ + key = tuple(id(c) for c in consts) + cache = getattr(consts[0].tag, "jointly_unique_indices", None) + if cache is None: + cache = consts[0].tag.jointly_unique_indices = {} + if key not in cache: + datas = [np.asarray(c.data) for c in consts] + # A coordinate axis that mixes positive and negative values may alias + # (``0`` and ``-dim`` are the same position), so distinctness of the raw + # values no longer proves distinctness of the coordinates. + if any((data >= 0).any() and (data < 0).any() for data in datas): + cache[key] = False + else: + coords = np.broadcast_arrays(*datas) + stacked = np.stack([coord.ravel() for coord in coords]) + cache[key] = bool(np.unique(stacked, axis=1).shape[1] == stacked.shape[1]) + return bool(cache[key]) + + +def _indices_provably_not_larger(idxs_and_dims, fgraph) -> bool: + """Whether advanced-indexing some consecutive axes selects no more elements + than those axes already hold, so lifting a Subtensor through the indexing + can't increase computation. + + ``idxs_and_dims`` pairs each advanced index (``ndim > 0``) with the static + size of the axis it indexes. + """ + if not idxs_and_dims: + return True + + idxs = [idx for idx, _ in idxs_and_dims] + dims = [dim for _, dim in idxs_and_dims] + idx_shapes = [idx.type.shape for idx in idxs] + + # With static shapes the result size is known exactly, so just compare it + # against the number of elements the indexed axes hold. 
+ if all(d is not None for d in dims) and all( + None not in shape for shape in idx_shapes + ): + return bool(np.prod(np.broadcast_shapes(*idx_shapes)) <= np.prod(dims)) + + # Otherwise fall back to proving the indices are duplicate-free, which on its + # own bounds the result by the axes' size, even when the sizes are unknown: + # - each index repeats no position on its own axis, or + if all(_index_provably_not_larger(idx, dim, fgraph) for idx, dim in idxs_and_dims): + return True + if len(idxs) > 1: + # - the indices are all the coordinates of one Nonzero, distinct by + # construction (e.g. symbolic tril_indices), or + owners = {idx.owner for idx in idxs} + if ( + len(owners) == 1 + and (owner := next(iter(owners))) is not None + and isinstance(owner.op, Nonzero) + and set(idxs) == set(owner.outputs) + ): + return True + # - the constant coordinate tuples have no duplicates. + if all(isinstance(idx, Constant) for idx in idxs): + return _constants_jointly_unique(idxs) + return False + + @register_canonicalize @register_stabilize @register_specialize @@ -345,17 +417,19 @@ def local_subtensor_of_batch_dims(fgraph, node): if _non_consecutive_adv_indexing(idx_tuple): return None - # Skip when lifting would expand a gather past a non-broadcast input's size. + # Skip unless indexing every input provably selects no more elements than it + # holds; otherwise the lifted Elemwise may do more work. The advanced indices + # are weighed together, over the consecutive axes they jointly index. 
for inp in elem.owner.inputs: - for axis, idx in enumerate(idx_tuple): - if axis >= inp.type.ndim: - break - if not isinstance(idx, TensorVariable) or idx.type.ndim == 0: - continue - if inp.type.broadcastable[axis]: - continue - if not _index_provably_not_larger(idx, inp.type.shape[axis], fgraph): - return None + adv_indices = [ + (idx, inp.type.shape[axis]) + for axis, idx in enumerate(idx_tuple[: inp.type.ndim]) + if isinstance(idx, TensorVariable) + and idx.type.ndim > 0 + and not inp.type.broadcastable[axis] + ] + if not _indices_provably_not_larger(adv_indices, fgraph): + return None batch_ndim = ( elem.owner.op.batch_ndim(elem.owner) @@ -742,11 +816,15 @@ def lift_subtensor_through_alloc(fgraph, node): # Indices on Alloc-added dims don't reach val; the rest line up with val's dims. val_indexer = indices[n_added_dims:] - dangerous_index_reaches_val = any( - not val.type.broadcastable[axis] - # Per-axis check; doesn't account for net effect across all axes. - and not _index_provably_not_larger(idx, val.type.shape[axis], fgraph) + val_adv_indices = [ + (idx, val.type.shape[axis]) for axis, idx in enumerate(val_indexer) + if isinstance(idx, TensorVariable) + and idx.type.ndim > 0 + and not val.type.broadcastable[axis] + ] + dangerous_index_reaches_val = not _indices_provably_not_larger( + val_adv_indices, fgraph ) # On broadcast val dims the index is neutralized (advanced indices dropped, diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 1c1cbb7bb0..3de57d4e05 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -2753,8 +2753,9 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): packed = pt.vector("packed") if exp_before_materialize: - # We test the same optimized result regardless of whether - # the diagonals are updated before or after materialization + # Same ``L`` two ways: exponentiate the diagonal in the packed vector + # before scattering, or 
(else branch) scatter first and exponentiate the + # matrix diagonal. Equivalent, but optimize to different graphs under BlasOpt. packed_diag_indices = pt.arange(n + 1).cumsum()[1:] - 1 log_diag = packed[packed_diag_indices] packed_update = packed[packed_diag_indices].set(pt.exp(log_diag)) @@ -2778,7 +2779,6 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): mode = get_default_mode().excluding("fuse_indexed_into_elemwise") f = function([packed], [loss, grad], mode=mode) - f.dprint(print_shape=True) idx_types = ( Subtensor, @@ -2790,13 +2790,16 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): ExtractDiag, ) n_idx = sum(1 for n in f.maker.fgraph.toposort() if isinstance(n.op, idx_types)) - # The ``BlasOpt`` rewrites lower ``L @ L.T`` to ``Gemm``; the gradient then - # fuses the diagonal-gradient term into a ``Gemm`` operand, materializing one - # extra set-subtensor. A linker that cannot use them lists ``BlasOpt`` in - # ``incompatible_rewrites`` (e.g. the numba linker), keeping the plain ``Dot`` - # lowering with that term as a vector. Both lowerings are correct. + # Post-materialization, the log-det gradient is a diagonal matrix added to + # ``L@L.T`` at the matrix level; BlasOpt's GemmOptimizer fuses ``add(dot, C)`` + # into one Gemm, materializing that diagonal as one extra set-subtensor (7 ops). + # Pre-materialization keeps the term in the packed vector's index space, so + # there's no matrix-level add to fuse (6 ops). Without Gemm (BlasOpt in the + # linker's incompatible_rewrites, e.g. numba) the term never reaches the matrix + # and both collapse to 6. All lowerings are correct. blas_rewrites_run = "BlasOpt" not in f.maker.mode.linker.incompatible_rewrites - assert n_idx == (7 if blas_rewrites_run else 6) + expected_n_idx = 7 if (blas_rewrites_run and not exp_before_materialize) else 6 + assert n_idx == expected_n_idx x = np.array([1.0, 0.5, 2.0, 0.3, 0.1, 1.5]) # Expected values were computed once by running ``f(x)``. 
diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index fabe875e38..a4210059ac 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -234,6 +234,36 @@ def test_elemwise_adv_index_assumed_unique_lifts(self): ) result.assert_graph(x[idx] + y[idx]) + def test_elemwise_jointly_unique_adv_indices_lift(self): + """A group of adv indices that each repeat but pair up to distinct + coordinates (tril_indices) can't select more elements than the indexed + axes hold, so it lifts.""" + # Symbolic indices: the outputs of a single Nonzero. + n = pt.scalar("n", dtype="int64") + x = pt.matrix("x") + rows, cols = pt.tril_indices(n) + # ``local_add_canonizer`` simplifies the ``n - 0`` inside ``tril_indices``, + # which then lets the duplicate ``arange`` merge -- noise for this test. + result = RewriteTester( + [n, x], [pt.exp(x)[rows, cols]], exclude=("local_add_canonizer",) + ) + result.assert_graph(pt.exp(x[rows, cols])) + result.assert_eval(3, np.arange(9.0).reshape(3, 3)) + + # Constant indices, static array shape: proved through the exact size. + x = pt.matrix("x", shape=(5, 5)) + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + result = RewriteTester([x], [pt.exp(x)[rows, cols]]) + result.assert_graph(pt.exp(x[rows, cols])) + result.assert_eval(np.arange(25.0).reshape(5, 5)) + + # Constant indices, unknown array shape: proved through joint uniqueness. 
+ x = pt.matrix("x") + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + result = RewriteTester([x], [pt.exp(x)[rows, cols]]) + result.assert_graph(pt.exp(x[rows, cols])) + result.assert_eval(np.arange(25.0).reshape(5, 5)) + def test_blockwise(self): class CoreTestOp(Op): itypes = [dvector, dvector] @@ -755,6 +785,21 @@ def test_const_idx_with_duplicates_bails(self): rewritten = rewrite_graph(out, **self.rewrite_kw) assert_equal_computations([rewritten], [out], strict_dtype=False) + def test_jointly_unique_adv_indices_lift(self): + """Indices that each repeat but pair up to distinct coordinates + (tril_indices) don't enlarge val, so the read lifts through Alloc.""" + val = pt.matrix("val", shape=(5, 5)) + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + + result = RewriteTester( + [val], + [pt.alloc(val, 5, 5)[rows, cols]], + include=("ShapeOpt", "canonicalize", "specialize"), + exclude=("local_replace_AdvancedSubtensor",), + ) + result.assert_graph(val[rows, cols], strict_dtype=False) + result.assert_eval(np.arange(25.0).reshape(5, 5)) + def test_negative_step_idx_to_slice(self): """Negative-step constant arange ``[7, 5, 3, 1]`` rewrites to ``x[7::-2]``.""" x = pt.vector("x", shape=(10,))