Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.mlx.linker import MLXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.python.linker import PythonLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker

Expand All @@ -40,7 +41,8 @@
# Mode, it will be used as the key to retrieve the real linker in this
# dictionary
predefined_linkers = {
"py": PerformLinker(), # Use allow_gc PyTensor flag
"py": PythonLinker(), # Pure-Python backend with the python_funcify dispatch
"perform": PerformLinker(), # Per-node reference: runs every Op's perform method
"c": CLinker(), # Don't support gc. so don't check allow_gc
"c|py": OpWiseCLinker(), # Use allow_gc PyTensor flag
"c|py_nogc": OpWiseCLinker(allow_gc=False),
Expand Down Expand Up @@ -476,6 +478,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
RewriteDatabaseQuery(include=["fast_run", "mlx"]),
)

# Predefined mode for the pure-Python backend: graphs are linked with
# `PythonLinker` and rewritten with the "fast_run" rewrite set.
PYTHON = Mode(
PythonLinker(),
RewriteDatabaseQuery(include=["fast_run"]),
)

FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
Expand All @@ -495,6 +502,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
"MLX": MLX,
"PYTHON": PYTHON,
}

_CACHED_RUNTIME_MODES: dict[Any, Mode] = {}
Expand Down
1 change: 1 addition & 0 deletions pytensor/link/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pytensor.link.python.linker import PythonLinker
8 changes: 8 additions & 0 deletions pytensor/link/python/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# isort: off
from pytensor.link.python.dispatch.basic import python_funcify

# Load dispatch specializations
import pytensor.link.python.dispatch.blockwise
import pytensor.link.python.dispatch.linalg

# isort: on
59 changes: 59 additions & 0 deletions pytensor/link/python/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from functools import singledispatch


@singledispatch
def python_funcify(op, node=None, **kwargs):
    """Build a pure-Python callable that evaluates ``op``.

    This is the generic entry of a single-dispatch registry keyed on the type
    of ``op``. Specializations registered for an `Op` type return a callable
    that receives the node's inputs as positional arguments and returns the
    node's output: one value, or a tuple when the node has several outputs.

    Parameters
    ----------
    op
        The `Op` instance to convert.
    node
        The node applying ``op``; unused by this generic implementation.
    **kwargs
        Unused by this generic implementation.

    Raises
    ------
    NotImplementedError
        Always. Reaching this implementation means no specialization is
        registered for ``type(op)``; callers catch the error and build the
        thunk another way.
    """
    message = f"No python_funcify implementation registered for {type(op).__name__}"
    raise NotImplementedError(message)


def make_node_thunk_with_python_dispatch(
    node, storage_map, compute_map, *, fallback, impl
):
    """Create the thunk for ``node``, using `python_funcify` when possible.

    Parameters
    ----------
    node
        The node to build a thunk for; ``node.op`` selects the dispatch.
    storage_map, compute_map
        Mappings from variables to their storage / "computed" cells.
    fallback
        Callable invoked as ``fallback(node, storage_map, compute_map, impl)``
        when `python_funcify` raises `NotImplementedError` for ``node.op``.
    impl
        Forwarded untouched to ``fallback``.

    Returns
    -------
    callable
        Either the dispatched callable wrapped as a thunk operating on
        ``storage_map``/``compute_map``, or whatever ``fallback`` returns.
    """
    try:
        py_impl = python_funcify(node.op, node=node)
    except NotImplementedError:
        # No Python dispatch for this op: defer to the caller-provided builder.
        thunk = fallback(node, storage_map, compute_map, impl)
    else:
        thunk = _wrap_callable_as_thunk(py_impl, node, storage_map, compute_map)
    return thunk


def _wrap_callable_as_thunk(fn, node, storage_map, compute_map):
    """Adapt ``fn`` into a zero-argument thunk bound to ``node``'s storage.

    The one-element storage cells of the node's inputs and outputs are looked
    up once and bound to the thunk as default arguments. Each call of the
    thunk reads the current input values, calls ``fn`` with them positionally,
    writes the result(s) into the output cells and marks every output of the
    node as computed.

    For a single-output node the return value of ``fn`` is stored as is; for
    any other number of outputs ``fn`` must return an iterable whose items are
    stored in output order.

    The returned thunk carries ``lazy = False``.
    """
    in_cells = [storage_map[var] for var in node.inputs]
    computed_flags = [compute_map[var] for var in node.outputs]
    out_cells = [storage_map[var] for var in node.outputs]

    if len(out_cells) == 1:
        (only_out,) = out_cells

        def thunk(fn=fn, inputs=in_cells, out=only_out, cm=computed_flags):
            out[0] = fn(*[cell[0] for cell in inputs])
            cm[0][0] = True

    else:

        def thunk(fn=fn, inputs=in_cells, outs=out_cells, cm=computed_flags):
            results = fn(*[cell[0] for cell in inputs])
            for cell, value in zip(outs, results):
                cell[0] = value
            for flag in cm:
                flag[0] = True

    thunk.lazy = False
    return thunk
17 changes: 17 additions & 0 deletions pytensor/link/python/dispatch/blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.blockwise import Blockwise


@python_funcify.register(Blockwise)
def python_funcify_Blockwise(op, node=None, **kwargs):
    """Vectorize the Python implementation of a `Blockwise`'s core op.

    A dummy core node is created from the inputs of ``node`` and used to
    dispatch `python_funcify` on ``op.core_op``. The resulting core callable is
    wrapped with `numpy.vectorize`, using the op's gufunc ``signature`` and
    the dtypes of the outputs of ``node``.

    Raises
    ------
    NotImplementedError
        Propagated from the core dispatch when the core op has no registered
        implementation, so that the `Blockwise` as a whole is treated as
        unregistered.
    """
    dummy_core_node = op._create_dummy_core_node(
        node.inputs, propagate_unbatched_core_inputs=True
    )
    # Intentionally unguarded: a NotImplementedError raised here must reach
    # the caller, which then falls back for the entire Blockwise.
    core_callable = python_funcify(op.core_op, node=dummy_core_node)

    return np.vectorize(
        core_callable,
        signature=op.signature,
        otypes=[output.type.dtype for output in node.outputs],
    )
4 changes: 4 additions & 0 deletions pytensor/link/python/dispatch/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# isort: off
import pytensor.link.python.dispatch.linalg.decomposition
import pytensor.link.python.dispatch.linalg.solvers
# isort: on
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# isort: off
import pytensor.link.python.dispatch.linalg.decomposition.cholesky
import pytensor.link.python.dispatch.linalg.decomposition.qr
# isort: on
32 changes: 32 additions & 0 deletions pytensor/link/python/dispatch/linalg/decomposition/cholesky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.linalg.decomposition.cholesky import Cholesky


@python_funcify.register(Cholesky)
def python_funcify_Cholesky(op, node=None, **kwargs):
    """Return a callable computing a Cholesky factor with LAPACK ``potrf``.

    Parameters
    ----------
    op
        The `Cholesky` op; ``op.lower`` and ``op.overwrite_a`` are read once,
        when the callable is built.
    node
        The node applying ``op``. Required despite the ``None`` default: the
        dtype of its input selects the ``potrf`` flavor and the dtype of its
        output is used for the empty-input result.
    **kwargs
        Ignored.

    Returns
    -------
    callable
        ``cholesky(x)`` returning the factor of the matrix ``x``. When
        ``potrf`` reports a non-zero ``info`` the returned array is filled
        with NaN instead of raising.
    """
    lower = op.lower
    overwrite_a = op.overwrite_a
    (potrf,) = get_lapack_funcs(("potrf",), dtype=node.inputs[0].type.dtype)
    out_dtype = node.outputs[0].type.dtype

    def cholesky(x):
        if x.size == 0:
            # Allocate with the node's output dtype rather than x's: the
            # non-empty path returns LAPACK's floating dtype, so e.g. an
            # integer input must not produce an integer empty result.
            return np.empty_like(x, dtype=out_dtype)

        # potrf only honors overwrite_a for F-contiguous input; transpose a
        # C-contiguous array to benefit from it.
        c_contiguous_input = overwrite_a and x.flags["C_CONTIGUOUS"]
        if c_contiguous_input:
            # Factor the transpose with the opposite triangle, then transpose
            # the factor back to get the requested triangle of x.
            x = x.T
            factor, info = potrf(x, lower=not lower, overwrite_a=True, clean=True)
            factor = factor.T
        else:
            factor, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)

        if info != 0:
            # Factorization failed: signal it with NaNs.
            factor[...] = np.nan
        return factor

    return cholesky
54 changes: 54 additions & 0 deletions pytensor/link/python/dispatch/linalg/decomposition/qr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.linalg.decomposition.qr import QR


@python_funcify.register(QR)
def python_funcify_QR(op, node=None, **kwargs):
    """Return a callable computing a QR decomposition through LAPACK.

    ``op.mode``, ``op.pivoting``, ``op.overwrite_a`` and the op's
    ``_call_and_get_lwork`` helper are captured when the callable is built.

    Returns
    -------
    callable
        ``qr(x)`` for a matrix ``x``. What it returns depends on ``op.mode``:

        * ``"r"``: ``R``
        * ``"raw"``: ``(factor, tau, R)``
        * any other mode: ``(Q, R)``

        With pivoting enabled the zero-based pivot indices are appended as a
        last item (``"r"`` then returns the tuple ``(R, jpvt)``).
    """
    mode = op.mode
    pivoting = op.pivoting
    overwrite_a = op.overwrite_a
    call_and_get_lwork = op._call_and_get_lwork

    def qr(x):
        n_rows, n_cols = x.shape
        wide = n_rows < n_cols

        # Step 1: factorize, with or without column pivoting.
        if pivoting:
            (geqp3,) = get_lapack_funcs(("geqp3",), (x,))
            factor, jpvt, tau, *_ = call_and_get_lwork(
                geqp3, x, lwork=-1, overwrite_a=overwrite_a
            )
            jpvt -= 1  # geqp3 returns 1-based indices
        else:
            (geqrf,) = get_lapack_funcs(("geqrf",), (x,))
            factor, tau, *_ = call_and_get_lwork(
                geqrf, x, lwork=-1, overwrite_a=overwrite_a
            )

        # Step 2: extract R; only the economic/raw modes on a non-wide input
        # keep just the leading ``n_cols`` rows.
        if wide or mode not in ("economic", "raw"):
            R = np.triu(factor)
        else:
            R = np.triu(factor[:n_cols, :])

        # Step 3: modes that do not need Q are done here.
        if mode == "r":
            if pivoting:
                return R, jpvt
            return R
        if mode == "raw":
            if pivoting:
                return factor, tau, R, jpvt
            return factor, tau, R

        # Step 4: build Q from the reflectors stored in ``factor``.
        (orgqr,) = get_lapack_funcs(("orgqr",), (factor,))
        if wide:
            q_input = factor[:, :n_rows]
        elif mode == "economic":
            q_input = factor
        else:
            # Full mode: embed the factor in a square (n_rows, n_rows) buffer.
            q_input = np.empty((n_rows, n_rows), dtype=factor.dtype.char)
            q_input[:, :n_cols] = factor
        Q, *_ = call_and_get_lwork(orgqr, q_input, tau, lwork=-1, overwrite_a=1)

        if pivoting:
            return Q, R, jpvt
        return Q, R

    return qr
3 changes: 3 additions & 0 deletions pytensor/link/python/dispatch/linalg/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# isort: off
import pytensor.link.python.dispatch.linalg.solvers.triangular
# isort: on
43 changes: 43 additions & 0 deletions pytensor/link/python/dispatch/linalg/solvers/triangular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np
from scipy.linalg import get_lapack_funcs

from pytensor.link.python.dispatch.basic import python_funcify
from pytensor.tensor.linalg.solvers.triangular import SolveTriangular


@python_funcify.register(SolveTriangular)
def python_funcify_SolveTriangular(op, node=None, **kwargs):
    """Return a callable solving a triangular system with LAPACK ``trtrs``.

    Parameters
    ----------
    op
        The `SolveTriangular` op; ``op.lower``, ``op.unit_diagonal`` and
        ``op.overwrite_b`` are read once, when the callable is built.
    node
        The node applying ``op``. Required despite the ``None`` default: the
        dtype of its output selects the ``trtrs`` flavor and is used for the
        empty-input result.
    **kwargs
        Ignored.

    Returns
    -------
    callable
        ``solve_triangular(A, b)`` returning the solution ``x`` of
        ``A @ x = b``. When ``trtrs`` reports a non-zero ``info`` the returned
        array is filled with NaN instead of raising.
    """
    lower = op.lower
    unit_diagonal = op.unit_diagonal
    overwrite_b = op.overwrite_b
    out_dtype = node.outputs[0].type.dtype
    (trtrs,) = get_lapack_funcs(("trtrs",), dtype=out_dtype)

    def solve_triangular(A, b):
        if b.size == 0:
            # Allocate with the output dtype, not b's: the non-empty path
            # returns the dtype trtrs was selected for, which differs from
            # b.dtype when A and b do not share a dtype.
            return np.empty_like(b, dtype=out_dtype)

        if A.flags["F_CONTIGUOUS"]:
            x, info = trtrs(
                A,
                b,
                overwrite_b=overwrite_b,
                lower=lower,
                trans=0,
                unitdiag=unit_diagonal,
            )
        else:
            # trtrs expects Fortran ordering, so solve the transposed system:
            # pass A.T (opposite triangle) and ask for its transpose.
            x, info = trtrs(
                A.T,
                b,
                overwrite_b=overwrite_b,
                lower=not lower,
                trans=1,
                unitdiag=unit_diagonal,
            )

        if info != 0:
            # Solve failed: signal it with NaNs.
            x[...] = np.nan
        return x

    return solve_triangular
52 changes: 52 additions & 0 deletions pytensor/link/python/linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pytensor.link.vm import VMLinker


class PythonLinker(VMLinker):
    """`VMLinker` for the pure-Python backend, driven by `python_funcify`.

    Thunks are created node by node. When `python_funcify` has an
    implementation registered for the node's op, that callable is wrapped into
    the thunk; otherwise thunk creation is delegated to the parent class with
    the requested ``impl``. Ops that build their own (e.g. lazy) thunks are
    therefore left untouched.

    The ``"fusion"`` rewrites are declared incompatible with this linker, next
    to ``"cxx_only"``.
    """

    def __init__(
        self,
        allow_gc=None,
        use_cloop=False,
        callback=None,
        callback_input=None,
        lazy=None,
        schedule=None,
        c_thunks=None,
        allow_partial_eval=None,
    ):
        # ``use_cloop`` and ``c_thunks`` are accepted but ignored: this backend
        # never emits C, so ``False`` is always forwarded for both.
        vm_options = {
            "allow_gc": allow_gc,
            "use_cloop": False,
            "callback": callback,
            "callback_input": callback_input,
            "lazy": lazy,
            "schedule": schedule,
            "c_thunks": False,
            "allow_partial_eval": allow_partial_eval,
        }
        super().__init__(**vm_options)
        # The parent, given ``c_thunks=False``, is expected to set up the
        # Python-only rewrite constraints; fusion is additionally excluded for
        # the numpy backend.
        self.incompatible_rewrites = ("cxx_only", "fusion")

    def _make_node_thunk(self, node, storage_map, compute_map, impl):
        # Imported here rather than at module level.
        from pytensor.link.python.dispatch.basic import (
            make_node_thunk_with_python_dispatch as build_thunk,
        )

        parent_builder = super()._make_node_thunk
        return build_thunk(
            node,
            storage_map,
            compute_map,
            fallback=parent_builder,
            impl=impl,
        )
15 changes: 11 additions & 4 deletions pytensor/link/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,16 @@ def make_vm(
)
return vm

def _make_node_thunk(self, node, storage_map, compute_map, impl):
    """Build and return the thunk that evaluates ``node``.

    This is a hook: subclasses may override it to take over thunk creation
    (for instance by looking up a dispatch registry first) and call this
    implementation as a fallback. The default simply asks the node's op for
    a thunk via ``Op.make_thunk`` with the given ``impl``.
    """
    # An empty ``no_recycling`` list is passed on purpose: no-recycling is
    # handled at each VM.__call__, so forwarding it here would only cause
    # duplicate C code.
    no_recycling = []
    op = node.op
    return op.make_thunk(node, storage_map, compute_map, no_recycling, impl=impl)

def make_all(
self,
profiler=None,
Expand Down Expand Up @@ -1230,11 +1240,8 @@ def make_all(
for node in order:
try:
thunk_start = time.perf_counter()
# no-recycling is done at each VM.__call__ So there is
# no need to cause duplicate c code by passing
# no_recycling here.
thunks.append(
node.op.make_thunk(node, storage_map, compute_map, [], impl=impl)
self._make_node_thunk(node, storage_map, compute_map, impl)
)
linker_make_thunk_time[node] = time.perf_counter() - thunk_start
if not hasattr(thunks[-1], "lazy"):
Expand Down
Loading
Loading