Skip to content
136 changes: 98 additions & 38 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from collections.abc import Callable, Sequence
from copy import copy
from functools import partial
from typing import cast

from pytensor.compile.maker import function
from pytensor.compile.mode import get_mode
from pytensor.compile.rebuild import rebuild_collect_shared
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import DisconnectedType, disconnected_type, grad, pushforward
from pytensor.graph.basic import (
Expand All @@ -23,7 +21,7 @@
from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.replace import graph_replace, rebuild_mutable
from pytensor.graph.traversal import graph_inputs
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.shape import Shape_i
Expand Down Expand Up @@ -101,23 +99,17 @@ def construct_nominal_fgraph(
)
)

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + implicit_shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
(
local_inputs,
local_outputs,
(_clone_d, update_d, update_expr, new_shared_inputs),
) = new
# Rebuild ``outputs`` rooted at the dummy (nominal) inputs. ``rebuild_mutable``
# reconstructs every node, so this works whether the inner graph is mutable or
# still references immutable ``FrozenApply`` nodes (when an inner-graph rewrite
# assembled this op from a frozen inner graph) -- the rewrite need not unfreeze
# just to have construction re-freeze. Shared inputs were gathered above and
# are part of ``replacements``.
local_inputs = dummy_inputs + dummy_implicit_shared_inputs
local_outputs = rebuild_mutable(outputs, replacements)

assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not new_shared_inputs

fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

Expand All @@ -135,7 +127,9 @@ def construct_nominal_fgraph(
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)

return fgraph, implicit_shared_inputs, update_d, update_expr
# Inner graphs never carry shared-variable updates (asserted previously via
# rebuild_collect_shared); the update maps are always empty.
return fgraph, implicit_shared_inputs, {}, {}


class OpFromGraph(Op, HasInnerGraph):
Expand Down Expand Up @@ -338,17 +332,36 @@ def __init__(

self.is_inline = inline

self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inner_fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inputs, outputs
)
self._frozen_fgraph = self.fgraph.freeze()
# Keep only the immutable (frozen) inner graph as ``op.fgraph``; the
# mutable copy is transient, so the canonical inner graph can never be
# mutated in place. Freeze with the default (no dedup): inner graphs may
# carry inplace ops whose destroyed buffers must stay distinct, and
# structural folding would alias them. See ``FunctionGraph.freeze``.
self.fgraph = inner_fgraph.freeze()

if strict and self.shared_inputs:
raise ValueError(
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
)

# `compile_kwargs` used to control how the inner graph was compiled.
# That is now the job of the `optimize_inner_graphs` rewrite (which
# inherits the outer compilation), so they are deprecated AND ignored:
# the inner function is compiled with default settings (see `fn`).
# `on_unused_input` is exempt: tolerating unused inputs is now the
# default behavior, so passing it is a harmless no-op (not warned).
deprecated_kwargs = {k for k in kwargs if k != "on_unused_input"}
if deprecated_kwargs:
warnings.warn(
"Passing `compile_kwargs` to `OpFromGraph` is deprecated and "
"now ignored: the inner graph inherits the outer compilation. "
f"Ignored: {sorted(deprecated_kwargs)}.",
FutureWarning,
)
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
Expand Down Expand Up @@ -415,7 +428,7 @@ def _freeze_override(self, override, make_dummy_args):
if override is None:
return None
if isinstance(override, OpFromGraph):
return override._frozen_fgraph
return override.fgraph

all_inputs, callable_args = make_dummy_args()

Expand Down Expand Up @@ -477,7 +490,7 @@ def __eq__(self, other):
if type(self) is not type(other):
return False
if (
self._frozen_fgraph != other._frozen_fgraph
self.fgraph != other.fgraph
or self.is_inline != other.is_inline
or self.destroy_map != other.destroy_map
or len(self.shared_inputs) != len(other.shared_inputs)
Expand All @@ -501,7 +514,7 @@ def __eq__(self, other):
)

def __hash__(self):
return hash((type(self), self._frozen_fgraph, self.is_inline))
return hash((type(self), self.fgraph, self.is_inline))

def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name
Expand Down Expand Up @@ -560,8 +573,13 @@ def _build_and_cache_lop_op(
except KeyError:
pass

inner_inputs = self.inner_inputs
inner_outputs = self.inner_outputs
# Differentiate a thawed copy of the inner graph so ``grad`` walks
# mutable ``Apply`` nodes rather than the immutable ``FrozenApply`` nodes
# of ``self.fgraph`` (whose tuple inputs/outputs break Ops that
# concatenate them, e.g. ``Blockwise.pullback``).
unfrozen_fgraph = self.fgraph.unfreeze()
inner_inputs = list(unfrozen_fgraph.inputs)
inner_outputs = list(unfrozen_fgraph.outputs)
nin = len(inner_inputs)
nout = len(inner_outputs)
pullback_overrides = self.pullback_overrides
Expand Down Expand Up @@ -684,8 +702,10 @@ def _build_and_cache_rop_op(self):
if self._rop_op_cache is not None:
return self._rop_op_cache

inner_inputs = self.inner_inputs
inner_outputs = self.inner_outputs
# Thaw the inner graph before differentiating (see ``_build_and_cache_lop_op``).
unfrozen_fgraph = self.fgraph.unfreeze()
inner_inputs = list(unfrozen_fgraph.inputs)
inner_outputs = list(unfrozen_fgraph.outputs)
nout = len(inner_outputs)
pushforward_overrides = self.pushforward_overrides

Expand Down Expand Up @@ -816,9 +836,7 @@ def make_node(self, *inputs):

# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
new_inner_outputs = clone_replace(
self.inner_outputs, replace=replace, copy_inputs_over=True
)
new_inner_outputs = rebuild_mutable(self.inner_outputs, replace)

# It's possible that the new shared variable inputs aren't actually
# shared variables. When they aren't we need to add them as new
Expand Down Expand Up @@ -919,25 +937,67 @@ def fn(self):
if getattr(self, "_fn", None) is not None:
return self._fn

kwargs = self.kwargs.copy()
mode = get_mode(kwargs.pop("mode", None)).excluding("symbolic_op_recognition")
self._fn = function(self.inner_inputs, self.inner_outputs, mode=mode, **kwargs)
# `compile_kwargs` are deprecated and ignored; compile the (already
# optimized) inner graph with default settings. Inner graphs commonly
# have unused inputs (e.g. rng, size), which are tolerated by default.
# They may also carry inplace ops baked in by ``optimize_inner_graphs``
# (or built inplace, e.g. ``FusedElemwise``); those only ever destroy
# internal buffers (inputs stay protected), so we accept them.
mode = get_mode(None).excluding("symbolic_op_recognition")
# The canonical inner graph is frozen (immutable); compile an
# ``unfreeze()``d mutable copy. ``function`` re-clones it while wrapping
# the inputs/outputs as In/Out specs; that extra clone is one-time since
# ``_fn`` is cached.
unfrozen_fgraph = self.fgraph.unfreeze()
self._fn = function(
unfrozen_fgraph.inputs,
unfrozen_fgraph.outputs,
mode=mode,
on_unused_input="ignore",
accept_inplace=True,
)
self._fn.trust_input = True

return self._fn

@property
def inner_inputs(self):
return self.fgraph.inputs
# A list (not the frozen tuple) so callers that concatenate inner
# inputs/outputs keep list semantics. Read-only views of the immutable
# graph; manipulating them requires a fresh/unfrozen graph.
return list(self.fgraph.inputs)

@property
def inner_outputs(self):
return self.fgraph.outputs
return list(self.fgraph.outputs)

def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
return res
# The inner graph is immutable (a frozen ``FunctionGraph``), so there is
# nothing to deep-clone -- mirror ``Composite.clone``.
return self

def clone_with_inner_graph(self, inner_fgraph: FunctionGraph) -> OpFromGraph:
"""Return a copy of this op whose inner graph is ``inner_fgraph``.

Used by the ``optimize_inner_graphs`` rewrite to build a new op
carrying an already-optimized inner graph without ever mutating
``self``. The subclass and all properties/overrides are preserved
(via ``copy``); only the inner graph and the state derived from it are
rebuilt.
"""
new = copy(self)
new_fgraph, new.shared_inputs, _, _ = construct_nominal_fgraph(
list(inner_fgraph.inputs), list(inner_fgraph.outputs)
)
new.fgraph = new_fgraph.freeze()
new.input_types = [inp.type for inp in new.fgraph.inputs]
new.output_types = [out.type for out in new.fgraph.outputs]
# Drop caches tied to the previous inner graph.
new._lop_op_cache = {}
new._rop_op_cache = None
new._frozen_lop = None
new._frozen_rop = None
return new

def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
Expand Down
Loading
Loading