
Fuse Composite Ops with CAReduce Ops in Numba/C backend #224

Open
ricardoV94 opened this issue Feb 20, 2023 · 0 comments
Labels: help wanted, numba, performance

Comments


ricardoV94 commented Feb 20, 2023

Goal

We should implement a new non-user-facing Op, introduced during the specialization phase, that fuses Composite and CAReduce Ops.

Terminology

  • Scalar Op: An Op that performs a basic operation on scalar inputs (add, exp, sin, ...)
  • Elemwise Op: An Op that extends a Scalar Op to tensor inputs (similar to a NumPy ufunc). This is what users usually build graphs in terms of when they write tensor.add, tensor.exp, ...
  • Composite Op: A Scalar Op that performs multiple pure scalar operations on a set of scalar inputs. It "fuses" them into a single pass. Composite Ops can also be turned into Elemwise Ops, which leads us to...
  • FusionRewrite: A rewrite during the specialization phase that replaces multiple Elemwise operations with a single Elemwise Composite, so as to avoid iterating over each intermediate output.
  • CAReduce: An Op that performs a Commutative Associative Reduce operation on tensor inputs. It has a core binary Scalar Op such as Add or And, which is applied sequentially to the inputs along certain axes until those are reduced.
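For concreteness, the sequential-application semantics of a CAReduce with core Scalar Op Add can be sketched in plain Python/NumPy for the axis=None case (careduce_add is a hypothetical name for illustration, not PyTensor API):

```python
import numpy as np

def careduce_add(x):
    """Sketch of a CAReduce whose core binary Scalar Op is Add, for the
    axis=None case: fold the binary op sequentially over all elements."""
    acc = 0.0  # identity element of Add
    for value in np.asarray(x).ravel():
        acc = acc + value  # apply the core binary Scalar Op
    return acc

print(careduce_add([[1.0, 2.0], [3.0, 4.0]]))  # same result as np.sum
```

Swapping the core Scalar Op (e.g. to And, with identity True) yields the other CAReduce variants.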

Description

Users don't usually work with Composites directly. Instead, the FusionRewrite applied during the specialization phase replaces sequences of simple Elemwise Ops with a single large Elemwise Composite:

class FusionOptimizer(GraphRewriter):

You can see it here:

import pytensor
import pytensor.tensor as pt

xs = pt.vector("xs")
out = pt.exp(pt.sin(pt.cos(pt.log(xs))))
print("Before:")
pytensor.dprint(out)

f_out = pytensor.function([xs], out)
print("\nAfter:")
pytensor.dprint(f_out)
Before:
Elemwise{exp,no_inplace} [id A]
 |Elemwise{sin,no_inplace} [id B]
   |Elemwise{cos,no_inplace} [id C]
     |Elemwise{log,no_inplace} [id D]
       |xs [id E]

After:
Elemwise{Composite} [id A] 0
 |xs [id B]

Inner graphs:
Elemwise{Composite} [id A]
 >exp [id C]
 > |sin [id D]
 >   |cos [id E]
 >     |log [id F]
 >       |<float64> [id G]

We would like to follow up the FusionRewrite with a new rewrite that fuses Composites with CAReduce operations. So if we have a graph like this:

xs = pt.vector("xs")
elemwise_out = pt.exp(pt.sin(pt.cos(pt.log(xs))))
reduction_out = pt.sum(elemwise_out)

We want to perform a single pass through xs. The evaluation pseudo-code in Python would look something like:

reduction_out = 0  # Initial value of the Add CAReduce
for x in xs.ravel():
  reduction_out += exp(sin(cos(log(x))))

This requires either extending CAReduce or creating a new Op that represents this new type of composite operation. Ideally it would support Composites with multiple reduced and non-reduced outputs, and multiple types of reduction (say, And in one output and Add in another). I am not sure about axis support... so we might only apply it to the axis=None case for now, where all axes are reduced.
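A Python sketch of what evaluating such a multi-output fused Op could look like, with an Add reduction and an And reduction sharing a single pass over the input (the function name and the inner scalar graph are illustrative, not an existing API):

```python
import math

def fused_composite_careduce(xs):
    """Hypothetical evaluation of a fused Composite + CAReduce Op with
    two reduced outputs: an Add reduction of exp(x) and an And reduction
    of (x > 0), both computed in one pass over the input."""
    add_acc = 0.0   # identity element of Add
    and_acc = True  # identity element of And
    for x in xs:
        y = math.exp(x)                # shared Composite scalar graph
        add_acc += y                   # Add CAReduce on the first output
        and_acc = and_acc and (x > 0)  # And CAReduce on the second output
    return add_acc, and_acc

print(fused_composite_careduce([1.0, -2.0, 3.0]))
```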

As we want to discontinue the C backend, I think we should focus on implementing this for the Numba backend alone (and Python, for debugging purposes). The JAX backend probably performs optimizations of this kind itself, and we are not working at such a low level there anyway, so we wouldn't do anything about JAX either.

Steps

  1. Implement a new Op that performs Elemwise + partial CAReduce operations on certain outputs. This Op must contain all the meta-information needed to create the right evaluation function. It need not have gradients nor C or JAX implementations.
  2. Implement the Python perform method (for debugging purposes; arguably lower priority, but it has proved useful for debugging and understanding other complex Ops like Scan)
  3. Implement the Numba perform method (some foundational work was done in "A more low level implementation of vectorize in numba" #92, see below)
  4. Implement a Numba-only rewrite that replaces Composite + CAReduce with the new Op
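Step 3 could build on a jitted loop along these lines. This is only a sketch, hard-coded for the example graph above (exp(sin(cos(log(x)))) reduced with Add over all axes); a real implementation would generate the kernel from the fused Op's meta-information (the Composite's inner scalar graph and the CAReduce's core Scalar Op). The njit fallback is only so the sketch also runs without Numba installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(func):
        return func

# Hypothetical kernel for sum(exp(sin(cos(log(x))))): one pass over the
# input, no intermediate arrays materialized.
@njit
def fused_sum(xs):
    acc = 0.0  # identity element of the Add CAReduce
    for x in xs.ravel():
        acc += np.exp(np.sin(np.cos(np.log(x))))
    return acc

xs = np.linspace(0.5, 2.0, 10)
print(fused_sum(xs))
```

The unfused graph would allocate four intermediate arrays; the fused kernel keeps everything in scalar registers inside the loop.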

Relevant links

ricardoV94 added the numba, help wanted, and performance labels Feb 20, 2023
ricardoV94 pinned this issue Feb 20, 2023
ricardoV94 changed the title from "Fuse Composite Ops with CAReduce Ops in Numba backend" to "Fuse Composite Ops with CAReduce Ops in Numba/C backend" Jun 30, 2023
ricardoV94 unpinned this issue Dec 7, 2023