Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.rewriting.math import local_exp_over_1_plus_exp
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.rewriting.shape import ShapeOptimizer
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -149,6 +149,17 @@ def remove_DiracDelta(fgraph, node):
position=0,
)

# Attach ShapeFeature after `pre_lower_xtensor` (registered at position 0.1 in
# `pymc/dims/__init__.py`) so it never tracks shapes through xtensor wrappers,
# which would leak references to the underlying `RandomVariable`s into shape
# sub-graphs that no `lower_xtensor` round-trip can cancel.
logprob_rewrites_db.register(
"ShapeOpt",
ShapeOptimizer(),
"basic",
position=0.11,
)

# Introduce sigmoid. We do it before canonicalization so that useless mul are removed next
logprob_rewrites_db.register(
"local_exp_over_1_plus_exp",
Expand Down Expand Up @@ -245,15 +256,11 @@ def construct_ir_fgraph(
-------
A `FunctionGraph` of the measurable IR.
"""
# We add `ShapeFeature` because it will get rid of references to the old
# `RandomVariable`s that have been lifted; otherwise, it will be difficult
# to give good warnings when an unaccounted for `RandomVariable` is encountered
fgraph = FunctionGraph(
outputs=list(rv_values.keys()),
clone=True,
copy_orphans=False,
copy_inputs=False,
features=[ShapeFeature()],
)

# Replace valued RVs by ValuedVar Ops so that rewrites are aware of conditioning points
Expand Down
Loading