Add python mode #2219
Open
jessegrabowski wants to merge 10 commits into pymc-devs:main from jessegrabowski:python-mode
+685
−156
Changes from 7 commits
Commits
71d722b Add overridable per-node thunk hook to VMLinker (jessegrabowski)
bbdd349 Add pure-Python backend (PythonLinker + python_funcify) (jessegrabowski)
4b66251 Add tests for Python backend (jessegrabowski)
ece9045 "py" -> PythonLinker, "perform" -> PerformLinker (jessegrabowski)
644eb6b Add Python backend Blockwise and Cholesky dispatches (jessegrabowski)
df4f69c Add Python backend SolveTriangular dispatch (jessegrabowski)
bc08d09 Add Python backend QR dispatch and fix QR raw output shapes (jessegrabowski)
c2fa96d Route VMLinker py thunks through the python_funcify dispatch (jessegrabowski)
72bbd87 Add whole-graph JIT linker (pyjit), keep py the robust VM (jessegrabowski)
29b972b Default Op.perform to python_funcify; make Cholesky perform-less (jessegrabowski)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from pytensor.link.python.linker import PythonLinker |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # isort: off | ||
| from pytensor.link.python.dispatch.basic import python_funcify | ||
|
|
||
| # Load dispatch specializations | ||
| import pytensor.link.python.dispatch.blockwise | ||
| import pytensor.link.python.dispatch.linalg | ||
|
|
||
| # isort: on |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from functools import singledispatch | ||
|
|
||
|
|
||
| @singledispatch | ||
| def python_funcify(op, node=None, **kwargs): | ||
| """Return a fast pure-Python implementation of ``op`` as a callable. | ||
|
|
||
| The callable takes the node's inputs positionally and returns its output (a | ||
| single value, or a tuple for multi-output nodes). Register a specialization | ||
| to override an `Op`'s ``perform`` with a faster numpy/scipy path on the | ||
| Python backend. | ||
|
|
||
| Unregistered ops raise `NotImplementedError`, signalling the linker to fall | ||
| back to ``perform`` (via ``Op.make_thunk(impl="py")``). | ||
| """ | ||
| raise NotImplementedError( | ||
| f"No python_funcify implementation registered for {type(op).__name__}" | ||
| ) | ||
|
|
||
|
|
||
| def make_node_thunk_with_python_dispatch( | ||
| node, storage_map, compute_map, *, fallback, impl | ||
| ): | ||
| """Build a per-node thunk, preferring a registered `python_funcify` impl. | ||
|
|
||
| When `python_funcify` has a specialization for ``node.op``, its callable is | ||
| wrapped into a thunk that reads inputs from and writes outputs to | ||
| ``storage_map``. Otherwise ``fallback`` (``Op.make_thunk``) is used, which | ||
| covers ``perform`` ops and lazy ops like ``IfElse`` unchanged. | ||
| """ | ||
| try: | ||
| fn = python_funcify(node.op, node=node) | ||
| except NotImplementedError: | ||
| return fallback(node, storage_map, compute_map, impl) | ||
|
|
||
| return _wrap_callable_as_thunk(fn, node, storage_map, compute_map) | ||
|
|
||
|
|
||
| def _wrap_callable_as_thunk(fn, node, storage_map, compute_map): | ||
| input_storage = [storage_map[v] for v in node.inputs] | ||
| output_compute = [compute_map[v] for v in node.outputs] | ||
|
|
||
| if len(node.outputs) == 1: | ||
| [out_storage] = (storage_map[v] for v in node.outputs) | ||
|
|
||
| def thunk(fn=fn, inputs=input_storage, out=out_storage, cm=output_compute): | ||
| out[0] = fn(*(inp[0] for inp in inputs)) | ||
| cm[0][0] = True | ||
| else: | ||
| output_storage = [storage_map[v] for v in node.outputs] | ||
|
|
||
| def thunk(fn=fn, inputs=input_storage, outs=output_storage, cm=output_compute): | ||
| for storage, value in zip(outs, fn(*(inp[0] for inp in inputs))): | ||
| storage[0] = value | ||
| for entry in cm: | ||
| entry[0] = True | ||
|
|
||
| thunk.lazy = False | ||
| return thunk |
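
The dispatch-then-fallback pattern in this file can be sketched without PyTensor. The snippet below is only an illustration under stated assumptions: `Op`, `Add`, `Neg`, `funcify`, and `make_thunk` are hypothetical stand-ins, not the PR's classes, and storage cells are modelled as one-element lists the way `storage_map` entries are used above.

```python
from functools import singledispatch


class Op:
    """Hypothetical stand-in for an Op with a reference implementation."""

    def perform(self, *inputs):
        raise NotImplementedError


class Add(Op):
    def perform(self, a, b):
        return a + b


class Neg(Op):
    def perform(self, a):
        return -a


@singledispatch
def funcify(op):
    # Default case: no fast path is registered for this Op type.
    raise NotImplementedError(
        f"No implementation registered for {type(op).__name__}"
    )


@funcify.register(Add)
def _(op):
    # Registered fast path for Add.
    return lambda a, b: a + b


def make_thunk(op, input_cells, output_cell, computed):
    """Prefer the registered callable; otherwise fall back to op.perform."""
    try:
        fn = funcify(op)
    except NotImplementedError:
        fn = op.perform

    def thunk():
        # Cells are one-element lists, so the thunk reads and writes in place.
        output_cell[0] = fn(*(cell[0] for cell in input_cells))
        computed[0] = True

    return thunk


x, y = [2], [3]
out, done = [None], [False]
make_thunk(Add(), [x, y], out, done)()  # uses the registered fast path
add_result = out[0]

out2, done2 = [None], [False]
make_thunk(Neg(), [x], out2, done2)()  # Neg is unregistered: falls back
neg_result = out2[0]
```

One consequence of signalling "not registered" with `NotImplementedError` is that a registered implementation that raises the same exception while being built also triggers the fallback, which is what the Blockwise dispatch relies on.
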
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| import numpy as np | ||
|
|
||
| from pytensor.link.python.dispatch.basic import python_funcify | ||
| from pytensor.tensor.blockwise import Blockwise | ||
|
|
||
|
|
||
| @python_funcify.register(Blockwise) | ||
| def python_funcify_Blockwise(op, node=None, **kwargs): | ||
| core_node = op._create_dummy_core_node( | ||
| node.inputs, propagate_unbatched_core_inputs=True | ||
| ) | ||
| # Raises NotImplementedError when the core Op has no dispatch, which makes the | ||
| # whole Blockwise fall back to its (vectorized) perform. | ||
| core_fn = python_funcify(op.core_op, node=core_node) | ||
|
|
||
| out_dtypes = [out.type.dtype for out in node.outputs] | ||
| return np.vectorize(core_fn, signature=op.signature, otypes=out_dtypes) |
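
As a rough illustration of what `np.vectorize` with a `signature` does for a core function, the sketch below batches a function that is written for a single matrix. The names `core_cholesky` and `batched_cholesky` and the `"(m,m)->(m,m)"` signature are assumptions chosen for the example; in the PR the signature comes from `op.signature` and the core function from the dispatch registry.

```python
import numpy as np


def core_cholesky(a):
    # Core function: written as if it only ever sees one (m, m) matrix.
    return np.linalg.cholesky(a)


# The signature tells np.vectorize which trailing axes are "core" axes;
# every leading axis is treated as a batch axis and looped over.
batched_cholesky = np.vectorize(
    core_cholesky, signature="(m,m)->(m,m)", otypes=[np.float64]
)

rng = np.random.default_rng(0)
m = rng.normal(size=(2, 3, 4, 4))
# A (2, 3) batch of symmetric positive-definite 4x4 matrices.
batch = m @ np.swapaxes(m, -1, -2) + 4 * np.eye(4)

factors = batched_cholesky(batch)
```

`np.vectorize` is a convenience loop in Python rather than a compiled kernel, so the gain here comes from reusing the core callable, not from removing the per-matrix loop.
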
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| # isort: off | ||
| import pytensor.link.python.dispatch.linalg.decomposition | ||
| import pytensor.link.python.dispatch.linalg.solvers | ||
| # isort: on |
4 changes: 4 additions & 0 deletions
pytensor/link/python/dispatch/linalg/decomposition/__init__.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| # isort: off | ||
| import pytensor.link.python.dispatch.linalg.decomposition.cholesky | ||
| import pytensor.link.python.dispatch.linalg.decomposition.qr | ||
| # isort: on |
32 changes: 32 additions & 0 deletions
pytensor/link/python/dispatch/linalg/decomposition/cholesky.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| import numpy as np | ||
| from scipy.linalg import get_lapack_funcs | ||
|
|
||
| from pytensor.link.python.dispatch.basic import python_funcify | ||
| from pytensor.tensor.linalg.decomposition.cholesky import Cholesky | ||
|
|
||
|
|
||
| @python_funcify.register(Cholesky) | ||
| def python_funcify_Cholesky(op, node=None, **kwargs): | ||
| lower = op.lower | ||
| overwrite_a = op.overwrite_a | ||
| (potrf,) = get_lapack_funcs(("potrf",), dtype=node.inputs[0].type.dtype) | ||
|
|
||
| def cholesky(x): | ||
| if x.size == 0: | ||
| return np.empty_like(x) | ||
|
|
||
| # potrf only honors overwrite_a for F-contiguous input; transpose a | ||
| # C-contiguous array to benefit from it. | ||
| c_contiguous_input = overwrite_a and x.flags["C_CONTIGUOUS"] | ||
| if c_contiguous_input: | ||
| x = x.T | ||
| factor, info = potrf(x, lower=not lower, overwrite_a=True, clean=True) | ||
| factor = factor.T | ||
| else: | ||
| factor, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True) | ||
|
|
||
| if info != 0: | ||
| factor[...] = np.nan | ||
| return factor | ||
|
|
||
| return cholesky |
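
The transpose trick in the C-contiguous branch rests on the input being symmetric: the upper factor of `x.T` is the transpose of the lower factor of `x`. A minimal check of that identity, assuming SciPy is available, might look like the following. Unlike the dispatch above, it passes `overwrite_a=False` so the test matrix is left untouched.

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

rng = np.random.default_rng(0)
m = rng.normal(size=(4, 4))
a = m @ m.T + 4 * np.eye(4)  # symmetric positive definite, C-contiguous

(potrf,) = get_lapack_funcs(("potrf",), dtype=a.dtype)

# Direct route: ask for the lower factor of a.
lower_direct, info_direct = potrf(a, lower=True, overwrite_a=False, clean=True)

# Transposed route: a.T is an F-contiguous view of the same memory.
# Its upper factor U satisfies a = U.T @ U, so U.T is the lower factor.
upper_of_t, info_t = potrf(a.T, lower=False, overwrite_a=False, clean=True)
lower_via_transpose = upper_of_t.T
```

Both routes should agree with `np.linalg.cholesky(a)`; the transposed one only pays off when `overwrite_a` is requested, because the F-contiguous view lets LAPACK work on the buffer without a copy.
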
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| import numpy as np | ||
| from scipy.linalg import get_lapack_funcs | ||
|
|
||
| from pytensor.link.python.dispatch.basic import python_funcify | ||
| from pytensor.tensor.linalg.decomposition.qr import QR | ||
|
|
||
|
|
||
| @python_funcify.register(QR) | ||
| def python_funcify_QR(op, node=None, **kwargs): | ||
| mode = op.mode | ||
| pivoting = op.pivoting | ||
| overwrite_a = op.overwrite_a | ||
| call_and_get_lwork = op._call_and_get_lwork | ||
|
|
||
| def qr(x): | ||
| M, N = x.shape | ||
|
|
||
| if pivoting: | ||
| (geqp3,) = get_lapack_funcs(("geqp3",), (x,)) | ||
| factor, jpvt, tau, *_ = call_and_get_lwork( | ||
| geqp3, x, lwork=-1, overwrite_a=overwrite_a | ||
| ) | ||
| jpvt -= 1 # geqp3 returns 1-based indices | ||
| else: | ||
| (geqrf,) = get_lapack_funcs(("geqrf",), (x,)) | ||
| factor, tau, *_ = call_and_get_lwork( | ||
| geqrf, x, lwork=-1, overwrite_a=overwrite_a | ||
| ) | ||
|
|
||
| if mode not in ("economic", "raw") or M < N: | ||
| R = np.triu(factor) | ||
| else: | ||
| R = np.triu(factor[:N, :]) | ||
|
|
||
| if mode == "r": | ||
| return (R, jpvt) if pivoting else R | ||
| if mode == "raw": | ||
| return (factor, tau, R, jpvt) if pivoting else (factor, tau, R) | ||
|
|
||
| (orgqr,) = get_lapack_funcs(("orgqr",), (factor,)) | ||
| if M < N: | ||
| Q, *_ = call_and_get_lwork( | ||
| orgqr, factor[:, :M], tau, lwork=-1, overwrite_a=1 | ||
| ) | ||
| elif mode == "economic": | ||
| Q, *_ = call_and_get_lwork(orgqr, factor, tau, lwork=-1, overwrite_a=1) | ||
| else: | ||
| square = np.empty((M, M), dtype=factor.dtype.char) | ||
| square[:, :N] = factor | ||
| Q, *_ = call_and_get_lwork(orgqr, square, tau, lwork=-1, overwrite_a=1) | ||
|
|
||
| return (Q, R, jpvt) if pivoting else (Q, R) | ||
|
|
||
| return qr |
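
The `jpvt -= 1` step converts LAPACK's 1-based pivots into 0-based indices usable for NumPy indexing. The contract that results can be checked against SciPy's public `scipy.linalg.qr`, used here in place of the PR's direct `geqp3`/`orgqr` calls; this is a sketch of the expected output shapes and pivot convention, not of the dispatch itself.

```python
import numpy as np
from scipy.linalg import qr

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 3))  # M=5 rows, N=3 columns, so M >= N

# Economic mode with column pivoting returns Q (M, N), R (N, N) and a
# 0-based permutation such that a[:, pivots] == Q @ R.
q, r, pivots = qr(a, mode="economic", pivoting=True)
```

With 0-based pivots, `a[:, pivots]` reorders the columns directly, which is why the raw LAPACK indices cannot be returned unchanged.
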
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # isort: off | ||
| import pytensor.link.python.dispatch.linalg.solvers.triangular | ||
| # isort: on |
43 changes: 43 additions & 0 deletions
pytensor/link/python/dispatch/linalg/solvers/triangular.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| import numpy as np | ||
| from scipy.linalg import get_lapack_funcs | ||
|
|
||
| from pytensor.link.python.dispatch.basic import python_funcify | ||
| from pytensor.tensor.linalg.solvers.triangular import SolveTriangular | ||
|
|
||
|
|
||
| @python_funcify.register(SolveTriangular) | ||
| def python_funcify_SolveTriangular(op, node=None, **kwargs): | ||
| lower = op.lower | ||
| unit_diagonal = op.unit_diagonal | ||
| overwrite_b = op.overwrite_b | ||
| (trtrs,) = get_lapack_funcs(("trtrs",), dtype=node.outputs[0].type.dtype) | ||
|
|
||
| def solve_triangular(A, b): | ||
| if b.size == 0: | ||
| return np.empty_like(b) | ||
|
|
||
| if A.flags["F_CONTIGUOUS"]: | ||
| x, info = trtrs( | ||
| A, | ||
| b, | ||
| overwrite_b=overwrite_b, | ||
| lower=lower, | ||
| trans=0, | ||
| unitdiag=unit_diagonal, | ||
| ) | ||
| else: | ||
| # trtrs expects Fortran ordering, so solve the transposed system. | ||
| x, info = trtrs( | ||
| A.T, | ||
| b, | ||
| overwrite_b=overwrite_b, | ||
| lower=not lower, | ||
| trans=1, | ||
| unitdiag=unit_diagonal, | ||
| ) | ||
|
|
||
| if info != 0: | ||
| x[...] = np.nan | ||
| return x | ||
|
|
||
| return solve_triangular |
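
The non-F-contiguous branch solves the same system through the transpose: if `A` is lower triangular, `A.T` is an F-contiguous upper-triangular view, and `trans=1` asks LAPACK to solve with the transpose of what it was given, which is `A` again. A small check of that equivalence, assuming SciPy is available and using a made-up well-conditioned test matrix:

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

rng = np.random.default_rng(0)
# C-contiguous lower-triangular matrix with a safely non-zero diagonal.
a = np.tril(rng.normal(size=(4, 4))) + 4 * np.eye(4)
b = rng.normal(size=(4, 2))

(trtrs,) = get_lapack_funcs(("trtrs",), dtype=a.dtype)

# Direct route: solve a @ x = b with a treated as lower triangular.
x_direct, info_direct = trtrs(a, b, lower=True, trans=0, unitdiag=False)

# Transposed route: a.T is upper triangular; trans=1 solves (a.T).T @ x = b.
x_via_t, info_t = trtrs(a.T, b, lower=False, trans=1, unitdiag=False)
```

The two calls should return the same solution; the second avoids the copy that f2py would otherwise make to put a C-ordered `A` into Fortran order.
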
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from pytensor.link.vm import VMLinker | ||
|
|
||
|
|
||
| class PythonLinker(VMLinker): | ||
| """A pure-Python `VMLinker` that runs each node through the `python_funcify` registry. | ||
|
|
||
| Per node, a registered `python_funcify` implementation (a fast numpy/scipy | ||
| callable) is wrapped into a thunk; unregistered ops fall back to their | ||
| ``perform`` method via ``Op.make_thunk(impl="py")``. Lazy ops such as | ||
| ``IfElse`` fall through to their own thunks, so the VM still short-circuits | ||
| them. Fusion is excluded because fused ``Composite`` loops run slower than | ||
| vectorized numpy on this backend. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| allow_gc=None, | ||
| use_cloop=False, | ||
| callback=None, | ||
| callback_input=None, | ||
| lazy=None, | ||
| schedule=None, | ||
| c_thunks=None, | ||
| allow_partial_eval=None, | ||
| ): | ||
| # The Python backend never emits C: per-node Python thunks, Python VM. | ||
| super().__init__( | ||
| allow_gc=allow_gc, | ||
| use_cloop=False, | ||
| callback=callback, | ||
| callback_input=callback_input, | ||
| lazy=lazy, | ||
| schedule=schedule, | ||
| c_thunks=False, | ||
| allow_partial_eval=allow_partial_eval, | ||
| ) | ||
| # ``c_thunks=False`` already gives ("minimum_compile", "py_only") / | ||
| # ("cxx_only",); add fusion for the numpy backend. | ||
| self.incompatible_rewrites = ("cxx_only", "fusion") | ||
|
|
||
| def _make_node_thunk(self, node, storage_map, compute_map, impl): | ||
| from pytensor.link.python.dispatch.basic import ( | ||
| make_node_thunk_with_python_dispatch, | ||
| ) | ||
|
|
||
| return make_node_thunk_with_python_dispatch( | ||
| node, | ||
| storage_map, | ||
| compute_map, | ||
| fallback=super()._make_node_thunk, | ||
| impl=impl, | ||
| ) | ||
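
The override of `_make_node_thunk` follows a template-method shape: the base linker always builds thunks through one hook, and the subclass intercepts only the nodes it has a fast path for. The classes below are hypothetical stand-ins written to show that shape, with nodes reduced to strings; they are not the `VMLinker` API.

```python
class BaseLinker:
    """Hypothetical stand-in for a linker exposing a per-node thunk hook."""

    def _make_node_thunk(self, node, impl):
        # Default behaviour: build the node's ordinary thunk.
        return lambda: f"default:{node}:{impl}"

    def make_all(self, nodes, impl="py"):
        # The base class always goes through the hook, so subclasses only
        # need to override that one method.
        return [self._make_node_thunk(node, impl) for node in nodes]


class DispatchingLinker(BaseLinker):
    # Hypothetical registry of fast paths, keyed by node name.
    fast_paths = {"cholesky": lambda: "fast:cholesky"}

    def _make_node_thunk(self, node, impl):
        fast = self.fast_paths.get(node)
        if fast is not None:
            return fast
        # No fast path: defer to the inherited behaviour unchanged.
        return super()._make_node_thunk(node, impl)


thunks = DispatchingLinker().make_all(["cholesky", "ifelse"])
results = [thunk() for thunk in thunks]
```

Keeping the fallback as the inherited method means lazy or unregistered nodes behave exactly as they would under the base linker.
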