[pyTorch] CPU performance optimizations #2439
/te-ci pytorch
```python
def fast_set_attr(self, name: str, value: Any) -> None:
    self.__dict__[name] = value
```
I assume we are separating out this function so we can manually avoid the overheads from __setattr__ and dict operations? Doing some benchmarking:
- Dict read: 9 ns
- Dict write: 13 ns
- Dict in: 9 ns
- dict.get: 14 ns
- Function call: 9 ns
- Class attr read: 3 ns
- Class attr write: 5 ns
- Class custom getattr: 101 ns
- Class custom setattr: 134 ns
Benchmarking script
I ran the following on a GB200 node. For the dict times, I subtracted out the overhead from list reads. For the class getattr/setattr times, I subtracted out the overhead from range.
```python
import contextlib
import time


class Timer:
    """Measure time interval."""

    def __init__(self) -> None:
        self._start = None
        self._end = None

    def time(self) -> float:
        """CPU time interval in seconds."""
        return self._end - self._start

    @contextlib.contextmanager
    def context(self):
        """Context manager to capture time interval."""
        self._start = time.perf_counter()
        yield
        self._end = time.perf_counter()


def main() -> None:

    # Options
    iters = 1024 * 1024

    # Timer
    timer = Timer()

    # Dummy data
    str_list = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit"]
    str_list = [str_list[i % len(str_list)] for i in range(iters)]
    str_dict = {s: len(s) for s in str_list}

    class PlainClass:
        def __init__(self) -> None:
            self.attr = 1

    class CustomGetattrSetattrClass:
        def __init__(self) -> None:
            self.attr = 1

        def __getattribute__(self, name):
            return super().__getattribute__(name)

        def __setattr__(self, name, val):
            super().__setattr__(name, val)

    # Timer overhead
    with timer.context():
        pass
    print(f"Timer overhead: {timer.time() * 1e9 / iters} ns/iter")

    # Range loop
    with timer.context():
        for _ in range(iters):
            pass
    print(f"Range loop: {timer.time() * 1e9 / iters} ns/iter")

    # List loop
    with timer.context():
        for _ in str_list:
            pass
    print(f"List loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, j in enumerate(range(iters)):
            pass
    print(f"Range+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty list+enumerate loop
    with timer.context():
        for i, s in enumerate(str_list):
            pass
    print(f"List+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # List reads
    with timer.context():
        for i in range(iters):
            str_list[i]
    print(f"List reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict reads
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]]
    print(f"Dict reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict get
    with timer.context():
        for i in range(iters):
            str_dict.get(str_list[i], None)
    print(f"Dict gets: {timer.time() * 1e9 / iters} ns/iter")

    # Dict writes
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]] = i
    print(f"Dict writes: {timer.time() * 1e9 / iters} ns/iter")

    # Dict membership
    with timer.context():
        for i in range(iters):
            str_list[i] in str_dict
    print(f"Dict membership: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    def func() -> None:
        pass

    with timer.context():
        for _ in range(iters):
            func()
    print(f"Function call: {timer.time() * 1e9 / iters} ns/iter")

    # Lambda call
    func = lambda: None
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Lambda call: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr read
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class attr read: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr write
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class attr write: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            getattr(myobj, "attr", None)
    print(f"getattr: {timer.time() * 1e9 / iters} ns/iter")

    # setattr
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            setattr(myobj, "attr", i)
    print(f"setattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom getattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class custom getattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom setattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class custom setattr: {timer.time() * 1e9 / iters} ns/iter")


if __name__ == "__main__":
    main()
```

How much perf difference do you observe from fast_set_attr? I could see how it could save us ~1 us of overhead, but it would be good to make sure before making the code messier.
I don't want to comment too much on the perf results yet, since so far they all come from my machine rather than a real cluster. That said, the anecdotal evidence is that a small test that just runs BF16 Linear layer forward for many iterations goes from 9.2 s to 7.7 s with the proposed code changes. fast_set_attr alone brought it down to ~8.4 s.
I will test it properly and report the timings in the description of the PR.
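For reference, a sketch of the kind of micro-benchmark described above; the layer sizes, iteration count, and synchronization points are assumptions for illustration, not the exact test that produced those numbers:

```python
import time

import torch
import transformer_engine.pytorch as te

# Hypothetical reproduction of the "BF16 Linear forward" loop; shapes and
# iteration count are made up for illustration.
layer = te.Linear(4096, 4096, params_dtype=torch.bfloat16, device="cuda")
inp = torch.randn(512, 4096, dtype=torch.bfloat16, device="cuda")

# Warmup so one-time initialization does not pollute the measurement
for _ in range(100):
    layer(inp)
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(100_000):
    layer(inp)
torch.cuda.synchronize()
print(f"Total: {time.perf_counter() - start:.1f} s")
```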
Now, about introducing the separate function: since this is ultimately an optimization you came up with at some point, the machinery to skip the expensive Module.__setattr__ for some parameters already existed. The problem I see is discoverability - if people do not study that code very carefully, they will not realize that they should not just do self.something = something. Therefore I think we should go a more explicit route: have the __setattr__ of the TE module error out with a message telling the user to either use fast_set_attr for the things we are sure are just small values (writing to __dict__ directly has its own problems, by the way - it bypasses properties, for example), or a new function, let's call it just set_attr, for anything where we need the full machinery. A rough sketch of that split follows below.
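To make the idea concrete, here is a minimal sketch of the proposed split; the method names follow the discussion, but the error message and API shape are placeholders, not actual TE code:

```python
from typing import Any

import torch


class TransformerEngineBaseModule(torch.nn.Module):
    """Illustrative sketch of the proposal above, not actual TE code."""

    def fast_set_attr(self, name: str, value: Any) -> None:
        # Cheap path for plain values; bypasses properties, hooks, and
        # the parameter/buffer/submodule bookkeeping of nn.Module
        self.__dict__[name] = value

    def set_attr(self, name: str, value: Any) -> None:
        # Full nn.Module machinery for parameters, buffers, and submodules
        super().__setattr__(name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        raise AttributeError(
            f"Do not assign '{name}' directly on TE modules; use fast_set_attr() "
            "for small plain values or set_attr() when the full nn.Module "
            "machinery is needed."
        )
```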
I'd prefer not to ban self.something = something. I think readability and safety are more important for non-performance-critical things like initialization and checkpointing. It would be better to make this function an advanced internal implementation with a name like _fast_setattr.
How would we then make sure that this does not resurface in the future?
```python
# with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if _nvtx_enabled():
    torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward")
```
nvtx_range_push/nvtx_range_pop does the same thing and is slightly cleaner:
```python
def nvtx_range_push(msg: str) -> None: ...
def nvtx_range_pop(msg: Optional[str] = None) -> None: ...
```
```diff
-# with get_nvtx_range_context(self.__class__.__name__ + " forward"):
-if _nvtx_enabled():
-    torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward")
+nvtx_range_push(self.__class__.__name__ + " forward")
```
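For context, a minimal sketch of what such helpers presumably look like; the env-var name used for the enablement check is an assumption here, and the real TE implementation may differ:

```python
import os
from typing import Optional

import torch


def _nvtx_enabled() -> bool:
    # Stand-in for TE's actual check (env-var name assumed for illustration)
    return bool(int(os.getenv("NVTE_ENABLE_NVTX", "0")))


def nvtx_range_push(msg: str) -> None:
    """Push an NVTX range if NVTX profiling is enabled."""
    if _nvtx_enabled():
        torch.cuda.nvtx.range_push(msg)


def nvtx_range_pop(msg: Optional[str] = None) -> None:
    """Pop the innermost NVTX range; msg only documents the call site."""
    if _nvtx_enabled():
        torch.cuda.nvtx.range_pop()
```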
Thanks, will change it :-).
```diff
 def prepare_forward(
     self,
     inp: torch.Tensor,
     num_gemms: int = 1,
     allow_non_contiguous: bool = False,
     allow_different_data_and_param_types: bool = False,
-) -> Generator[torch.Tensor, None, None]:
+) -> torch.Tensor:
     """Checks and prep for FWD.
     The context manager is needed because there isn't a way for a module to know
     if it's the last FP8 module in the forward autocast. It is useful
     to setup the forward aggregated amax reduction for every module
     just in case. The autocast exit will pick up the most recent one.
     """
     self.allow_different_data_and_param_types = allow_different_data_and_param_types
     self.forwarded_at_least_once = True
```
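The return-type change reflects dropping the generator-based context manager in favor of a plain function. Hypothetically, call sites change along these lines (_forward_impl and end_forward are illustrative names, not the actual API):

```python
# Before: generator-based context manager, paying contextmanager
# overhead on every forward
with self.prepare_forward(inp, num_gemms=1) as inp:
    out = self._forward_impl(inp)  # illustrative name

# After: prepare_forward returns the prepared tensor directly, and the
# cleanup that used to run after `yield` becomes an explicit call
inp = self.prepare_forward(inp, num_gemms=1)
out = self._forward_impl(inp)  # illustrative name
self.end_forward()  # illustrative name for the post-forward bookkeeping
```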
It sucks that we are replacing a context manager with manual pre- and post-function calls. However, I've found that an empty contextmanager takes ~900 ns, so cutting that overhead is worth it.
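That figure is easy to sanity-check with a stdlib-only micro-benchmark along these lines (exact numbers vary by machine and Python version):

```python
import contextlib
import time


@contextlib.contextmanager
def empty_context():
    yield


iters = 1_000_000
start = time.perf_counter()
for _ in range(iters):
    with empty_context():
        pass
elapsed = time.perf_counter() - start
print(f"Empty contextmanager: {elapsed * 1e9 / iters:.0f} ns/iter")
```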
Yes, it is quite expensive unfortunately.
Description
This PR includes a few performance optimizations targeting CPU overhead. The code, perf numbers, etc. are WIP. The code gets kind of ugly though :-(.
For the prepare_forward changes I did not touch attention (@cyanguwa FYI), since it has multiple exit points from the forward and I was worried that I would miss something there - it would be great if we could refactor that part first to have a single return statement instead.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: