@tchan102 tchan102 commented Nov 4, 2025

Add optimization for Join → Repeat when concatenating identical tensors

Description

This PR introduces a graph rewrite optimization in pytensor/tensor/rewriting/basic.py that replaces redundant Join operations with an equivalent and more efficient Repeat operation when all concatenated tensors are identical.

Example:
join(0, x, x, x) → repeat(x, 3, axis=0)
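
Element order matters for this equivalence, so a small check may help. The sketch below uses plain Python lists as stand-ins for 1-D tensors, and its helper names are illustrative only (they are not PyTensor functions). It suggests that concatenating copies of a vector matches tile-style semantics, whereas NumPy-style repeat duplicates each element in place. The two agree, for example, when the joined axis has length 1.

```python
# Plain lists stand in for 1-D tensors; helper names are illustrative only.

def join_copies(x, n):
    """Concatenate n copies of x end to end, like join(0, x, ..., x)."""
    out = []
    for _ in range(n):
        out.extend(x)
    return out


def tile_1d(x, n):
    """Tile-style semantics: the whole sequence is repeated n times."""
    return list(x) * n


def repeat_1d(x, n):
    """NumPy-style repeat semantics: each element is repeated n times in place."""
    return [v for v in x for _ in range(n)]


x = [1, 2]
joined = join_copies(x, 3)    # [1, 2, 1, 2, 1, 2]
tiled = tile_1d(x, 3)         # [1, 2, 1, 2, 1, 2]
repeated = repeat_1d(x, 3)    # [1, 1, 1, 2, 2, 2]

print(joined == tiled)        # True
print(joined == repeated)     # False
print(join_copies([7], 3) == repeat_1d([7], 3))  # True: the axis has length 1
```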

Key additions:

  • Implemented new rewrite function local_join_to_repeat registered under both @register_canonicalize and @register_specialize.
  • Added corresponding test test_local_join_to_repeat to verify correctness, performance, and behavior for vectors and matrices.
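
The matching condition behind such a rewrite can be sketched roughly as follows. This is a toy model and not the PR's implementation: `find_repeated_input` is a hypothetical helper, plain Python objects stand in for graph variables, and comparing inputs by identity is an assumption.

```python
def find_repeated_input(inputs):
    """Return (x, n) if `inputs` holds n >= 2 references to the same object x.

    Returns None otherwise, meaning the rewrite should leave the node alone.
    """
    if len(inputs) < 2:
        return None
    first = inputs[0]
    # Identity check (assumption): every input must be the very same object.
    if all(inp is first for inp in inputs[1:]):
        return first, len(inputs)
    return None


x = object()
y = object()

match = find_repeated_input([x, x, x])     # (x, 3): candidate for the rewrite
no_match = find_repeated_input([x, y, x])  # None: inputs differ
```

Under this model, a test that checks the rewrite is not applied would pass distinct inputs, such as `[x, y, x]`, rather than the same input repeated.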

Related Issue

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

@ricardoV94 ricardoV94 added the graph rewriting and enhancement labels Nov 4, 2025
@ricardoV94
Member

Let's try with @register_canonicalize only

@ricardoV94
Member

Btw, it would be nice to get rid of the symbolic axis in join (and split), if you would like to work on that after this PR. Relevant issue: #1528

@tchan102 tchan102 requested a review from ricardoV94 November 9, 2025 00:47
@tchan102 tchan102 requested a review from ricardoV94 November 11, 2025 20:51
new_s = rewrite_graph(s)
assert equal_computations([new_s], [join(1, mat, mat, mat)])
assert new_s.dtype == s.dtype
# Compare to the expected form (without rewriting expected)
Member

Nitpick: this kind of comment doesn't make sense in isolation. It's the sort of thing LLMs write when you prompt them to change something they did before.

Similar to the comments above with "now" and other qualifiers that don't make sense in the long term.

assert f.maker.fgraph.outputs[0].dtype == config.floatX

# test we don't apply when their is 2 inputs
# test that join with 2 identical inputs now gets optimized to tile
Member

As with the previous test, change it so that it isn't the same input repeated.

Development

Successfully merging this pull request may close these issues.

Rewrite concatenate([x, x]) as repeat(x, 2)
