Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,6 +1964,54 @@ def local_reduce_join(fgraph, node):
return [ret]


@register_specialize
@register_canonicalize
@register_uncanonicalize # Needed for Min which is formed from Neg(Max(Neg))
@node_rewriter([CAReduce])
def local_careduce_join(fgraph, node):
r"""CAReduce(Join(axis, \*tensors), axis=ax) -> Elemwise{scalar.op}(\*[CAReduce(axis=ax, t) for t in tensors])

When the reduction axis includes the join axis (or reduces all elements),
this avoids creating the concatenated intermediate array.

For >2 joined inputs, only scalar ops with variadic support
(Add, Mul) are rewritten, since Elemwise can't combine >2
Comment thread
williambdean marked this conversation as resolved.
Outdated
binary-only ops (e.g. Maximum, Minimum) at once.

"""
[joined_out] = node.inputs
if joined_out.owner is None or not isinstance(joined_out.owner.op, Join):
Comment thread
williambdean marked this conversation as resolved.
Outdated
return None

join_axis_tensor, *joined_inputs = joined_out.owner.inputs

if len(joined_inputs) < 2:
return None

if not isinstance(join_axis_tensor, Constant):
return None

if len(joined_inputs) > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
return None

join_axis = int(join_axis_tensor.data)
if join_axis < 0:
join_axis += joined_out.type.ndim

reduce_op = node.op
if reduce_op.axis is not None and join_axis not in reduce_op.axis:
return None

reduced = [reduce_op.clone(axis=reduce_op.axis)(inp) for inp in joined_inputs]
Comment thread
williambdean marked this conversation as resolved.
Outdated
ret = Elemwise(reduce_op.scalar_op)(*reduced)

if ret.dtype != node.outputs[0].dtype:
return None

copy_stack_trace(node.outputs[0], ret)
return [ret]


@register_infer_shape
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
@register_useless("local_cut_useless_reduce")
Expand Down
95 changes: 92 additions & 3 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pytensor.graph.traversal import ancestors
from pytensor.printing import debugprint, pprint
from pytensor.scalar import PolyGamma, Psi, TriGamma
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.basic import Alloc, Join, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
Expand Down Expand Up @@ -3500,12 +3500,12 @@ def test_type(self):
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)

# This case could be rewritten
# Join axis is included in reduction axes
A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
assert isinstance(topo[-1].op, Elemwise)

A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
Expand Down Expand Up @@ -3559,6 +3559,95 @@ def test_non_ds_inputs(self):
expected_out = add(exp(x), log(x))
assert equal_computations([rewritten_out], [expected_out])

def test_careduce_join_sum_2(self):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check the kind of approach to writing rewrite tests we're trying to settle on: #2103

"""Sum(concat(a, b), axis=None) -> Add(Sum(a), Sum(b)) with 2 inputs"""
x, y = vectors("xy")
xv, yv = (
np.random.rand(100).astype(config.floatX),
np.random.rand(200).astype(config.floatX),
)
out = pt_sum(pt.concatenate([x, y]), axis=None)
f = function([x, y], out, mode=self.mode)
np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv])))
topo = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_sum_3(self):
"""Sum(concat(a, b, c), axis=None) -> Add(Sum(a), Sum(b), Sum(c)) with 3 inputs (variadic combine)"""
x, y, z = vectors("xyz")
xv = np.random.rand(100).astype(config.floatX)
yv = np.random.rand(150).astype(config.floatX)
zv = np.random.rand(200).astype(config.floatX)
out = pt_sum(pt.concatenate([x, y, z]), axis=None)
f = function([x, y, z], out, mode=self.mode)
np.testing.assert_allclose(f(xv, yv, zv), np.sum(np.concatenate([xv, yv, zv])))
topo = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_max_2(self):
"""Max(concat(a, b), axis=None) -> Maximum(Max(a), Max(b)) with 2 inputs (binary combine)"""
x, y = vectors("xy")
xv, yv = (
np.random.rand(100).astype(config.floatX),
np.random.rand(200).astype(config.floatX),
)
out = pt_max(pt.concatenate([x, y]), axis=None)
f = function([x, y], out, mode=self.mode)
np.testing.assert_allclose(f(xv, yv), np.max(np.concatenate([xv, yv])))
topo = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_sum_specific_axis(self):
"""Sum(concat(mat_a, mat_b), axis=0) -> Add(Sum(mat_a, axis=0), Sum(mat_b, axis=0))

join_axis=0 is included in the reduction axes, so the rewrite applies.
"""
x, y = matrices("xy")
xv = np.array([[1, 2], [3, 4]], dtype=config.floatX)
yv = np.array([[5, 6]], dtype=config.floatX)
out = pt_sum(pt.concatenate([x, y], axis=0), axis=0)
f = function([x, y], out, mode=self.mode)
np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]), axis=0))
topo = f.maker.fgraph.toposort()
assert not any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_max_3_not_applied(self):
"""Max(concat(a, b, c), axis=None) should NOT trigger (binary-only combine can't take >2)

Elemwise{maximum} is a binary op with nin=2, so it can't combine
three individual Max results at once like Add/Mul can.
"""
x, y, z = vectors("xyz")
out = pt_max(pt.concatenate([x, y, z]), axis=None)
f = function([x, y, z], out, mode=self.mode)
topo = f.maker.fgraph.toposort()
assert any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_not_applied_axis_excludes_join(self):
"""Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)"""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for follow up we can still optimize, it's still generally better to reduce before joining even if the join is still needed. Just need to change the axis then. My comment here is to make the docstring not so authoritative that sounds like this would be a problem. Mention it as not currently supported instead

x, y = matrices("xy")
out = pt_sum(pt.concatenate([x, y], axis=0), axis=1)
f = function([x, y], out, mode=self.mode)
topo = f.maker.fgraph.toposort()
assert any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_not_applied_empty_axis(self):
"""Sum(concat(a, b), axis=[]) should NOT trigger (empty reduction)"""
x, y = vectors("xy")
out = pt_sum(pt.concatenate([x, y], axis=0), axis=[])
f = function([x, y], out, mode=self.mode)
topo = f.maker.fgraph.toposort()
assert any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_not_applied_dynamic_axis(self):
"""Non-constant join axis should NOT trigger"""
axis = iscalar("axis")
x, y = vectors("xy")
out = pt_sum(join(axis, x, y), axis=None)
f = function([axis, x, y], out, mode=self.mode)
topo = f.maker.fgraph.toposort()
assert any(isinstance(n.op, Join) for n in topo)


def test_local_useless_adds():
default_mode = get_default_mode()
Expand Down
Loading