-
Notifications
You must be signed in to change notification settings - Fork 150
Rewrite concatenate([x, x]) as repeat(x, 2) #1714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Let's try with |
|
Btw would be nice to get rid of this join (and split) symbolic axis if you would like to work on that after this PR. relevant issue: #1528 |
| new_s = rewrite_graph(s) | ||
| assert equal_computations([new_s], [join(1, mat, mat, mat)]) | ||
| assert new_s.dtype == s.dtype | ||
| # Compare to the expected form (without rewriting expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: these kind of comments don't make sense in isolation. It's the sort of stuff LLMs do when you prompt to change something they did before.
Similar to th4 comments above with "now" and other qualifiers that don't make sense in the long term
| assert f.maker.fgraph.outputs[0].dtype == config.floatX | ||
|
|
||
| # test we don't apply when their is 2 inputs | ||
| # test that join with 2 identical inputs now gets optimized to tile |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Like the other previous test change it so it isn't the same input repeated
Add optimization for Join → Repeat when concatenating identical tensors
Description
This PR introduces a graph rewrite optimization in pytensor/tensor/rewriting/basic.py that replaces redundant Join operations with an equivalent and more efficient Repeat operation when all concatenated tensors are identical.
Example:
join(0, x, x, x) → repeat(x, 3, axis=0)
Key additions:
Related Issue
Checklist
Type of change