-
Notifications
You must be signed in to change notification settings - Fork 194
Optimize CAReduce of Join by pushing reduction through concatenation #2130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
900489a
23220c4
dcfe621
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,7 +31,7 @@ | |
| from pytensor.graph.traversal import ancestors | ||
| from pytensor.printing import debugprint, pprint | ||
| from pytensor.scalar import PolyGamma, Psi, TriGamma | ||
| from pytensor.tensor.basic import Alloc, constant, join, second, switch | ||
| from pytensor.tensor.basic import Alloc, Join, constant, join, second, switch | ||
| from pytensor.tensor.blas import Dot22, Gemv | ||
| from pytensor.tensor.blas_c import CGemv | ||
| from pytensor.tensor.blockwise import Blockwise | ||
|
|
@@ -3500,12 +3500,12 @@ def test_type(self): | |
| topo = f.maker.fgraph.toposort() | ||
| assert not isinstance(topo[-1].op, Elemwise) | ||
|
|
||
| # This case could be rewritten | ||
| # Join axis is included in reduction axes | ||
| A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) | ||
| f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode) | ||
| np.testing.assert_allclose(f(), [2, 4, 6, 8, 10]) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert not isinstance(topo[-1].op, Elemwise) | ||
| assert isinstance(topo[-1].op, Elemwise) | ||
|
|
||
| A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) | ||
| f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode) | ||
|
|
@@ -3559,6 +3559,95 @@ def test_non_ds_inputs(self): | |
| expected_out = add(exp(x), log(x)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| def test_careduce_join_sum_2(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check the kind of approach to writing rewrite tests we're trying to settle on: #2103 |
||
| """Sum(concat(a, b), axis=None) -> Add(Sum(a), Sum(b)) with 2 inputs""" | ||
| x, y = vectors("xy") | ||
| xv, yv = ( | ||
| np.random.rand(100).astype(config.floatX), | ||
| np.random.rand(200).astype(config.floatX), | ||
| ) | ||
| out = pt_sum(pt.concatenate([x, y]), axis=None) | ||
| f = function([x, y], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]))) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert not any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_sum_3(self): | ||
| """Sum(concat(a, b, c), axis=None) -> Add(Sum(a), Sum(b), Sum(c)) with 3 inputs (variadic combine)""" | ||
| x, y, z = vectors("xyz") | ||
| xv = np.random.rand(100).astype(config.floatX) | ||
| yv = np.random.rand(150).astype(config.floatX) | ||
| zv = np.random.rand(200).astype(config.floatX) | ||
| out = pt_sum(pt.concatenate([x, y, z]), axis=None) | ||
| f = function([x, y, z], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv, zv), np.sum(np.concatenate([xv, yv, zv]))) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert not any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_max_2(self): | ||
| """Max(concat(a, b), axis=None) -> Maximum(Max(a), Max(b)) with 2 inputs (binary combine)""" | ||
| x, y = vectors("xy") | ||
| xv, yv = ( | ||
| np.random.rand(100).astype(config.floatX), | ||
| np.random.rand(200).astype(config.floatX), | ||
| ) | ||
| out = pt_max(pt.concatenate([x, y]), axis=None) | ||
| f = function([x, y], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv), np.max(np.concatenate([xv, yv]))) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert not any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_sum_specific_axis(self): | ||
| """Sum(concat(mat_a, mat_b), axis=0) -> Add(Sum(mat_a, axis=0), Sum(mat_b, axis=0)) | ||
|
|
||
| join_axis=0 is included in the reduction axes, so the rewrite applies. | ||
| """ | ||
| x, y = matrices("xy") | ||
| xv = np.array([[1, 2], [3, 4]], dtype=config.floatX) | ||
| yv = np.array([[5, 6]], dtype=config.floatX) | ||
| out = pt_sum(pt.concatenate([x, y], axis=0), axis=0) | ||
| f = function([x, y], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]), axis=0)) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert not any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_max_3_not_applied(self): | ||
| """Max(concat(a, b, c), axis=None) should NOT trigger (binary-only combine can't take >2) | ||
|
|
||
| Elemwise{maximum} is a binary op with nin=2, so it can't combine | ||
| three individual Max results at once like Add/Mul can. | ||
| """ | ||
| x, y, z = vectors("xyz") | ||
| out = pt_max(pt.concatenate([x, y, z]), axis=None) | ||
| f = function([x, y, z], out, mode=self.mode) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_not_applied_axis_excludes_join(self): | ||
| """Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)""" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for follow up we can still optimize, it's still generally better to reduce before joining even if the join is still needed. Just need to change the axis then. My comment here is to make the docstring not so authoritative that sounds like this would be a problem. Mention it as not currently supported instead |
||
| x, y = matrices("xy") | ||
| out = pt_sum(pt.concatenate([x, y], axis=0), axis=1) | ||
| f = function([x, y], out, mode=self.mode) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_not_applied_empty_axis(self): | ||
| """Sum(concat(a, b), axis=[]) should NOT trigger (empty reduction)""" | ||
| x, y = vectors("xy") | ||
| out = pt_sum(pt.concatenate([x, y], axis=0), axis=[]) | ||
| f = function([x, y], out, mode=self.mode) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_not_applied_dynamic_axis(self): | ||
| """Non-constant join axis should NOT trigger""" | ||
| axis = iscalar("axis") | ||
| x, y = vectors("xy") | ||
| out = pt_sum(join(axis, x, y), axis=None) | ||
| f = function([axis, x, y], out, mode=self.mode) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
|
|
||
| def test_local_useless_adds(): | ||
| default_mode = get_default_mode() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.