diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index b75232a68d..2f4e1dd00c 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -58,7 +58,7 @@ from pytensor.tensor.random.rewriting import local_subtensor_rv_lift from pytensor.tensor.rewriting.basic import register_canonicalize from pytensor.tensor.rewriting.math import local_exp_over_1_plus_exp -from pytensor.tensor.rewriting.shape import ShapeFeature +from pytensor.tensor.rewriting.shape import ShapeOptimizer from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -149,6 +149,17 @@ def remove_DiracDelta(fgraph, node): position=0, ) +# Attach ShapeFeature after `pre_lower_xtensor` (registered at position 0.1 in +# `pymc/dims/__init__.py`) so it never tracks shapes through xtensor wrappers, +# which would leak references to the underlying `RandomVariable`s into shape +# sub-graphs that no `lower_xtensor` round-trip can cancel. +logprob_rewrites_db.register( + "ShapeOpt", + ShapeOptimizer(), + "basic", + position=0.11, +) + # Introduce sigmoid. We do it before canonicalization so that useless mul are removed next logprob_rewrites_db.register( "local_exp_over_1_plus_exp", @@ -245,15 +256,11 @@ def construct_ir_fgraph( ------- A `FunctionGraph` of the measurable IR. """ - # We add `ShapeFeature` because it will get rid of references to the old - # `RandomVariable`s that have been lifted; otherwise, it will be difficult - # to give good warnings when an unaccounted for `RandomVariable` is encountered fgraph = FunctionGraph( outputs=list(rv_values.keys()), clone=True, copy_orphans=False, copy_inputs=False, - features=[ShapeFeature()], ) # Replace valued RVs by ValuedVar Ops so that rewrites are aware of conditioning points