
Break MaxandArgmax Op into separate TensorMax Op and Argmax Op #731

Merged
merged 4 commits into from
Jun 13, 2024

Conversation

Dhruvanshu-Joshi
Member

Description

The MaxandArgmax Op calculates both the maximum and the argmax together. With this PR, we aim to have separate Ops for the two operations.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@Dhruvanshu-Joshi Dhruvanshu-Joshi force-pushed the remove_MaxArgmax branch 2 times, most recently from 1d9a484 to a278272 Compare May 13, 2024 17:59
@Dhruvanshu-Joshi Dhruvanshu-Joshi marked this pull request as ready for review May 13, 2024 18:01
@Dhruvanshu-Joshi
Member Author

The failing tests are due to the uint64 data type, which is highlighted in #770. So for this to be ready, should I just remove the uint64 test for now and open another issue to add support back once #770 is solved?

@ricardoV94
Member

> The failing tests are due to the uint64 data type, which is highlighted in #770. So for this to be ready, should I just remove the uint64 test for now and open another issue to add support back once #770 is solved?

You can mark the test with pytest.mark.xfail. There are a couple of examples in the codebase.
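A minimal sketch of that suggestion (the test name, reason string, and placeholder body here are hypothetical; the real uint64 test lives in the pytensor test suite):

```python
import pytest

# pytest.mark.xfail records an expected failure instead of deleting the
# test outright; it can be removed once the blocking issue is resolved.
@pytest.mark.xfail(reason="uint64 support pending, tracked in #770")
def test_max_argmax_uint64():
    data = [1, 5, 3]  # placeholder standing in for a uint64 tensor
    assert max(data) == 5
    assert data.index(max(data)) == 1
```

With `strict=False` (the default), the test is reported as XFAIL if it fails and XPASS if it unexpectedly passes, so nothing breaks CI in the meantime.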


codecov bot commented May 22, 2024

Codecov Report

Attention: Patch coverage is 81.69014% with 13 lines in your changes missing coverage. Please review.

Project coverage is 80.87%. Comparing base (15b90be) to head (d953a0d).
Report is 227 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/math.py 78.57% 6 Missing and 6 partials ⚠️
pytensor/link/jax/dispatch/nlinalg.py 92.30% 1 Missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #731      +/-   ##
==========================================
+ Coverage   80.83%   80.87%   +0.03%     
==========================================
  Files         162      163       +1     
  Lines       46862    46847      -15     
  Branches    11465    11463       -2     
==========================================
+ Hits        37881    37887       +6     
- Misses       6733     6747      +14     
+ Partials     2248     2213      -35     
Files with missing lines Coverage Δ
pytensor/ifelse.py 51.70% <ø> (ø)
pytensor/link/numba/dispatch/elemwise.py 91.85% <100.00%> (+3.13%) ⬆️
pytensor/tensor/rewriting/uncanonicalize.py 96.63% <100.00%> (+0.42%) ⬆️
pytensor/link/jax/dispatch/nlinalg.py 90.47% <92.30%> (+0.73%) ⬆️
pytensor/tensor/math.py 90.43% <78.57%> (+0.77%) ⬆️

... and 37 files with indirect coverage changes

@ricardoV94 ricardoV94 (Member) left a comment

Almost there, just some cleanup needed

@ricardoV94 ricardoV94 (Member) left a comment

Small notes. Reverting the whitespace changes is useful, because git diff will then only show the files that have meaningful changes from this PR.

Comment on lines 110 to 116

def __getattr__(name):
    if name == "MaxandArgmax":
        warnings.warn(
            "The class `MaxandArgmax` has been deprecated. "
            "Call `Max` and `Argmax` separately as an alternative.",
            FutureWarning,
        )

Member

This has to raise some sort of error because MaxAndArgmax no longer exists; it can't simply be a warning. I think it's better to just remove it.

@Dhruvanshu-Joshi Dhruvanshu-Joshi (Member Author) Jun 13, 2024

Should we raise an AttributeError here?

Member

Whatever error would be raised if this helper were not here; AttributeError sounds about right, but confirm by trying without the special code.

Member

The helper also needs to raise the standard error for attributes other than MaxandArgmax.

Member Author

Confirmed that an AttributeError is raised.

Member

This is how we implement it in pymc for reference: https://github.com/pymc-devs/pymc/blob/f44071bdda363f743548187d3e124a027adfdb77/pymc/distributions/transforms.py#L54-L67

Although, according to the PEP, it should probably have a !r in the standard AttributeError message: https://peps.python.org/pep-0562/#rationale

Member Author

In our case, we won't have to give a warning, right?

Member

Yes, in our case we just raise a more informative AttributeError instead of the default one.

Member

In your last change you made this useless; it's the same as if we didn't implement __getattr__ at all. The point of adding it is to have a custom message for the deprecated Op, but the default behavior for everything else.

Member Author

Rectified this

@ricardoV94
Member

Almost there, looks great. Just 3 unresolved comments above

@@ -1404,6 +1422,11 @@ def test_bool(self):
        assert np.all(i)


def test_MaxAndArgmax_deprecated():
    with pytest.raises(AttributeError):
Member

Add the specific message we expect when trying to access MaxAndArgmax.

Suggested change
-    with pytest.raises(AttributeError):
+    with pytest.raises(AttributeError, match=...):

Member Author

Added this.
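A sketch of what the tightened test could look like (the stand-in module class and the expected message are hypothetical, not the exact strings in the PR):

```python
import pytest

# Stand-in object mimicking a module whose MaxAndArgmax attribute was removed.
class _FakeMathModule:
    def __getattr__(self, name):
        raise AttributeError(
            f"module 'pytensor.tensor.math' has no attribute {name!r}. "
            "Use `Max` and `Argmax` separately instead."
        )

def test_MaxAndArgmax_deprecated():
    fake_math = _FakeMathModule()
    # match= is a regex searched against str(excinfo.value), so the test
    # now also pins down the custom deprecation message.
    with pytest.raises(AttributeError, match="Use `Max` and `Argmax`"):
        fake_math.MaxAndArgmax
```

Because `match=` does a regex search rather than an exact comparison, a distinctive substring of the message is enough to make the assertion specific.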

@ricardoV94 ricardoV94 (Member) left a comment

Awesome!

@ricardoV94
Member

Some tests are failing; it seems there are still some internal imports of MaxAndArgmax.

@Dhruvanshu-Joshi
Member Author

> Some tests are failing; it seems there are still some internal imports of MaxAndArgmax.

My bad, I rectified this. JAX tests are skipped locally, so I could not test them anywhere other than in CI. Hopefully they'll pass now.

@ricardoV94 ricardoV94 merged commit e6e6d69 into pymc-devs:main Jun 13, 2024
55 of 56 checks passed
@ricardoV94
Member

Thanks @Dhruvanshu-Joshi !

@ricardoV94 ricardoV94 mentioned this pull request Jul 8, 2024
Successfully merging this pull request may close these issues.

Remove MaxAndArgmax Op