Goal
We should implement a new non-user-facing Op that is introduced during specialization to fuse Composite and CAReduce Ops.
Terminology
Scalar Op: an Op that performs a basic operation on scalar inputs (add, exp, sin, ...).
Elemwise Op: an Op that extends a Scalar Op to tensor inputs (similar to a NumPy ufunc). This is what users usually build graphs in terms of when they write tensor.add, tensor.exp, ...
Composite Op: a scalar Op that performs multiple pure scalar operations on a set of scalar inputs, "fusing" them into a single pass. Composite Ops can also be turned into Elemwise Ops, which leads us to...
FusionRewrite: a rewrite during the specialization phase that replaces multiple Elemwise operations with a single Elemwise Composite, so as to avoid iterating over each intermediate output.
CAReduce: an Op that performs a Commutative Associative Reduce operation on tensor inputs. It has a core binary Scalar Op such as Add or And, which is applied sequentially to the inputs along certain axes until those are reduced.
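For intuition, a sum is the CAReduce whose core binary Scalar Op is Add; a minimal NumPy sketch of the sequential application (illustrative only, not PyTensor code):

import numpy as np

xs = np.array([1.0, 2.0, 3.0])
acc = 0.0  # identity element of Add
for x in xs:  # apply the binary scalar op sequentially
    acc = acc + x
assert acc == xs.sum()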
Description
Users don't usually work with Composites directly. Instead, we have a FusionRewrite, introduced during the specialization phase, that replaces sequences of simple Elemwise Ops with a single large Elemwise Composite. You can see it here: pytensor/pytensor/tensor/rewriting/elemwise.py, line 600 at 4730d0c.
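For example (a minimal sketch using the public PyTensor API; the exact debug-print output may differ across versions):

import pytensor
import pytensor.tensor as pt

xs = pt.vector("xs")
out = pt.exp(pt.sin(pt.cos(pt.log(xs))))

# Compiling triggers the specialization rewrites, including the FusionRewrite.
fn = pytensor.function([xs], out)

# The debug print shows a single Elemwise Composite instead of four Elemwise Ops.
pytensor.dprint(fn)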
We would like to follow up the FusionRewrite with a new rewrite that fuses Composites with CAReduce operations. So if we have a graph like the one above with a final reduction, say out.sum(), we want to perform a single pass through xs. The evaluation pseudo-code in Python would look something like:
from numpy import exp, sin, cos, log

reduction_out = 0  # Initial value of the Add CAReduce
for x in xs.ravel():
    reduction_out += exp(sin(cos(log(x))))
This requires either extending CAReduce or creating a new Op that represents this new type of composite operation. Ideally it would support Composites with multiple reduced and non-reduced outputs and multiple types of reduction (say, And in one output and Add in another); a sketch follows below. I am not sure about axis, so we might only apply it to the axis=None case for now, where all axes are reduced.
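As a hedged illustration of the mixed-output case (xs, the chain of scalar ops, and the t > 0 condition are made up for this example), a single pass can fill a non-reduced elemwise output while updating an Add-reduced and an And-reduced output at the same time:

import numpy as np

xs = np.random.rand(3, 4)
add_out = 0.0    # identity of Add
and_out = True   # identity of And
elem_out = np.empty_like(xs)
flat = elem_out.ravel()  # view into elem_out

for i, x in enumerate(xs.ravel()):
    t = np.exp(np.sin(np.cos(np.log(x))))
    add_out += t                    # Add-reduced output
    and_out = and_out and (t > 0)   # And-reduced output
    flat[i] = t                     # non-reduced (elemwise) output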
As we want to discontinue the C backend, I think we should focus on implementing this for the Numba backend alone (and Python, for debugging purposes). The JAX backend probably performs optimizations of this kind itself, and we are not working at such a low level there anyway, so we wouldn't do anything about JAX either.
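To make the Numba target concrete, here is a hand-written version of the kind of single-pass loop we would want the backend to generate (a sketch assuming Numba is installed; PyTensor does not emit this today):

import numba
import numpy as np

@numba.njit
def fused_sum(xs):
    acc = 0.0
    for x in xs.ravel():
        acc += np.exp(np.sin(np.cos(np.log(x))))
    return acc

fused_sum(np.random.rand(10, 10))  # compiles on first call, then runs in one pass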
Steps
Implement a new Op that performs Elemwise + partial CAReduce operations on certain outputs. This Op must contain all the meta-information needed to create the right evaluation function. It need not have gradients nor C or JAX implementations (see the sketch after this list).
Implement the Python perform method (for debugging purposes; arguably lower priority, but it has proved useful for debugging and understanding other complex Ops like Scan).
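A hypothetical Python-level skeleton of such an Op, including its debugging perform method (all names and the reduce_specs encoding are illustrative, not the real PyTensor Op API):

import operator
import numpy as np

class CompositeCAReduce:
    # Illustrative stand-in for the proposed Op.
    def __init__(self, scalar_fn, reduce_specs):
        # scalar_fn: fused scalar function returning one value per output.
        # reduce_specs: one entry per output; None for a plain elemwise
        # output, or a (binary_op, identity) pair, e.g. (operator.add, 0.0).
        self.scalar_fn = scalar_fn
        self.reduce_specs = reduce_specs

    def perform(self, x):
        flat = x.ravel()
        outs = [spec[1] if spec is not None else np.empty_like(flat)
                for spec in self.reduce_specs]
        for j, xv in enumerate(flat):  # single pass over the input
            vals = self.scalar_fn(xv)
            for i, spec in enumerate(self.reduce_specs):
                if spec is None:
                    outs[i][j] = vals[i]                 # elemwise output
                else:
                    outs[i] = spec[0](outs[i], vals[i])  # reduction step
        return [o if s is not None else o.reshape(x.shape)
                for o, s in zip(outs, self.reduce_specs)]

# Example: Add-reduce exp(sin(cos(log(x)))) while also keeping elemwise cos(x).
op = CompositeCAReduce(
    lambda x: (np.exp(np.sin(np.cos(np.log(x)))), np.cos(x)),
    [(operator.add, 0.0), None],
)
reduced, elemwise = op.perform(np.random.rand(3, 4))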
Relevant links
… the ScalarOp of CAReduce. This however won't work for multiple mixed outputs.
Elemwise graphs that have multiple outputs and clients #121 extended the FusionRewrite to handle multiple-output Composites.
Optimize Sums of MakeVectors and Joins #59 is ever more relevant here, as it would increase the number of Elemwise+Reduction graphs that could be optimized, by removing unnecessary shape-related operations and increasing the chance that they happen immediately below a Composite operation.