Conversation

@ptrendx
Member

@ptrendx ptrendx commented Dec 1, 2025

Description

This PR includes a few performance optimizations targeting CPU overhead. The code, perf numbers, etc. are WIP. The code gets kind of ugly though :-(.

For the prepare_forward changes I did not touch attention (@cyanguwa FYI) since it has multiple exit points from the forward and I was worried that I would miss something there - it would be great if we could refactor that part first to have a single return statement instead.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ptrendx
Member Author

ptrendx commented Dec 1, 2025

/te-ci pytorch

Comment on lines +644 to +645
def fast_set_attr(self, name: str, value: Any) -> None:
    self.__dict__[name] = value
Collaborator

I assume we are separating out this function so we can manually avoid overheads from __setattr__ and dict? Doing some benchmarking:

  • dict read: 9 ns
  • dict write: 13 ns
  • dict in: 9 ns
  • dict.get: 14 ns
  • Function call: 9 ns
  • Class attr read: 3 ns
  • Class attr write: 5 ns
  • Class custom getattr: 101 ns
  • Class custom setattr: 134 ns
Benchmarking script

I ran the following on a GB200 node. For the dict times, I subtracted out the overhead from list reads. For the class getattr/setattr times, I subtracted out the overhead from range.

import contextlib
import time

class Timer:
    """Measure time interval."""

    def __init__(self) -> None:
        self._start = None
        self._end = None

    def time(self) -> float:
	"""CPU time interval in seconds."""
        return self._end - self._start

    @contextlib.contextmanager
    def context(self):
        """Context manager to capture time interval."""
        self._start = time.perf_counter()
        yield
        self._end = time.perf_counter()

def main() -> None:

    # Options
    iters = 1024 * 1024

    # Timer
    timer = Timer()

    # Dummy data
    str_list = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit"]
    str_list = [str_list[i % len(str_list)] for i in range(iters)]
    str_dict = {s: len(s) for s in str_list}
    class PlainClass:
        def __init__(self) -> None:
            self.attr = 1
    class CustomGetattrSetattrClass:
        def __init__(self) -> None:
            self.attr = 1
        def __getattribute__(self, name):
            return super().__getattribute__(name)
        def __setattr__(self, name, val):
            super().__setattr__(name, val)

    # Timer overhead
    with timer.context():
        pass
    print(f"Timer overhead: {timer.time() * 1e9 / iters} ns/iter")

    # Range loop
    with timer.context():
        for _ in range(iters):
            pass
    print(f"Range loop: {timer.time() * 1e9 / iters} ns/iter")

    # List loop
    with timer.context():
        for _ in str_list:
            pass
    print(f"List loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, j in enumerate(range(iters)):
            pass
    print(f"Range+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty list+enumerate loop
    with timer.context():
        for i, s in enumerate(str_list):
            pass
    print(f"List+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # List reads
    with timer.context():
        for i in range(iters):
            str_list[i]
    print(f"List reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict reads
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]]
    print(f"Dict reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict get
    with timer.context():
        for i in range(iters):
            str_dict.get(str_list[i], None)
    print(f"Dict gets: {timer.time() * 1e9 / iters} ns/iter")

    # Dict writes
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]] = i
    print(f"Dict writes: {timer.time() * 1e9 / iters} ns/iter")

    # Dict membership
    with timer.context():
        for i in range(iters):
            str_list[i] in str_dict
    print(f"Dict membership: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    def func() -> None:
        pass
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Function call: {timer.time() * 1e9 / iters} ns/iter")

    # Lambda call
    func = lambda: None
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Lambda call: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr read
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class attr read: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr write
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class attr write: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            getattr(myobj, "attr", None)
    print(f"getattr: {timer.time() * 1e9 / iters} ns/iter")

    # setattr
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            setattr(myobj, "attr", i)
    print(f"setattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom getattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class custom getattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom setattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class custom setattr: {timer.time() * 1e9 / iters} ns/iter")

if __name__ == "__main__":
    main()

How much perf difference do you observe from fast_set_attr? I could see how it could save us ~1 us of overhead, but it would be good to make sure before making the code messier.

Member Author

I don't want to comment too much on the perf results yet, since so far they all come from my machine rather than a real cluster, but that anecdotal evidence shows a small test that just runs a BF16 Linear layer forward for many iterations going from 9.2 s to 7.7 s with the proposed code changes. fast_set_attr alone brought it to ~8.4 s.
I will test it properly and report the timings in the description of the PR.
Now, about introducing the separate function: this is ultimately an optimization you came up with at some point, so there already was machinery to avoid the expensive Module.__setattr__ for some attributes. The problem I see is discoverability - if people do not study that code very carefully, they will not realize that they should not just do self.something = something. Therefore I think we should go a more explicit route and have the TE module's __setattr__ error out with a message telling people to either use fast_set_attr for the things we are sure are just small values (using __dict__ directly has some problems, BTW, since it e.g. bypasses properties and the like), or use a new function - let's call it just set_attr - for anything that needs the full machinery.
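
To make that concrete, a rough sketch of the split I have in mind (class and method names are just illustrative, and a real version would need an escape hatch for attributes set during __init__ and by nn.Module itself):

from typing import Any

import torch


class TEModuleSketch(torch.nn.Module):
    """Illustrative only: the proposed explicit attribute-setting API."""

    def fast_set_attr(self, name: str, value: Any) -> None:
        # Plain __dict__ write: skips nn.Module.__setattr__ bookkeeping.
        # Only safe for small plain values (never Parameters, Modules,
        # buffers, or anything backed by a property).
        self.__dict__[name] = value

    def set_attr(self, name: str, value: Any) -> None:
        # Full nn.Module machinery for parameters, submodules, buffers, etc.
        super().__setattr__(name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        # Error out so plain `self.something = something` cannot silently
        # reintroduce the per-call overhead on hot paths.
        # (A real implementation would whitelist init-time attributes.)
        raise AttributeError(
            f"Set '{name}' via fast_set_attr() for small values or set_attr() "
            "when the full nn.Module machinery is needed."
        )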

Collaborator

@timmoon10 timmoon10 Dec 2, 2025

I'd prefer not to ban self.something = something. I think readability and safety are more important for non-performance-critical things like initialization and checkpointing. It would be better to make this function an advanced internal implementation with a name like _fast_setattr.

Member Author

How would we then make sure that this does not resurface in the future?

Comment on lines 1076 to 1078
# with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if _nvtx_enabled():
    torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward")
Collaborator

nvtx_range_push/nvtx_range_pop does the same thing and is slightly cleaner:

def nvtx_range_push(msg: str) -> None:

def nvtx_range_pop(msg: Optional[str] = None) -> None:

Suggested change
- # with get_nvtx_range_context(self.__class__.__name__ + " forward"):
- if _nvtx_enabled():
-     torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward")
+ nvtx_range_push(self.__class__.__name__ + " forward")
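
For reference, a minimal sketch of what those two helpers presumably look like, assuming they are thin wrappers over torch.cuda.nvtx gated by an enable check (the flag below is a stand-in for whatever _nvtx_enabled() actually does):

from typing import Optional

import torch

_nvtx_ranges_enabled = False  # stand-in for the real enable check


def nvtx_range_push(msg: str) -> None:
    """Push an NVTX range if NVTX profiling is enabled."""
    if _nvtx_ranges_enabled:
        torch.cuda.nvtx.range_push(msg)


def nvtx_range_pop(msg: Optional[str] = None) -> None:
    """Pop the most recent NVTX range if NVTX profiling is enabled.

    The optional msg is only accepted here for call-site symmetry/debugging.
    """
    if _nvtx_ranges_enabled:
        torch.cuda.nvtx.range_pop()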

Member Author

Thanks, will change it :-).

Comment on lines 1028 to +1040
def prepare_forward(
    self,
    inp: torch.Tensor,
    num_gemms: int = 1,
    allow_non_contiguous: bool = False,
    allow_different_data_and_param_types: bool = False,
) -> Generator[torch.Tensor, None, None]:
    """Checks and prep for FWD.
    The context manager is needed because there isn't a way for a module to know
    if it's the last FP8 module in the forward autocast. It is useful
    to setup the forward aggregated amax reduction for every module
    just in case. The autocast exit will pick up the most recent one.
    """
    self.allow_different_data_and_param_types = allow_different_data_and_param_types
    self.forwarded_at_least_once = True
) -> torch.Tensor:
Collaborator

It sucks that we are replacing a context with manual pre- and post-function calls. However, I've found that an empty contextmanager takes ~900 ns, so cutting that overhead is worth it.
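
For reference, a quick standalone check (not part of the PR) that reproduces the overhead in question by timing an empty @contextlib.contextmanager per enter/exit:

import contextlib
import time


@contextlib.contextmanager
def empty_ctx():
    # No setup or teardown: this measures only the contextmanager machinery.
    yield


iters = 1_000_000

start = time.perf_counter()
for _ in range(iters):
    with empty_ctx():
        pass
elapsed = time.perf_counter() - start

print(f"Empty contextmanager: {elapsed * 1e9 / iters:.0f} ns/iter")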

Member Author

Yes, it is quite expensive unfortunately.
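
For readers following along, a rough sketch of the shape of the change (everything below is a toy stand-in, not the PR's actual API): the contextmanager wrapping the forward prep is replaced by an explicit pre-call and post-call.

import contextlib


class ToyModule:
    """Toy stand-in (not TE code) showing the two shapes of the forward prep."""

    @contextlib.contextmanager
    def prepare_forward_ctx(self, inp):
        # Old shape: setup, yield the prepared input, run teardown on exit.
        prepared = inp
        yield prepared
        # ... teardown that used to run on context exit ...

    def prepare_forward(self, inp):
        # New shape: the pre-call just returns the prepared input.
        return inp

    def end_forward(self):
        # New shape: an explicit post-call replaces the context exit.
        pass


mod = ToyModule()

# Old: pays the generator/contextmanager machinery on every forward.
with mod.prepare_forward_ctx(42) as x:
    out = x + 1

# New: two plain method calls (~9 ns each per the numbers above) instead of
# the ~900 ns contextmanager enter/exit.
x = mod.prepare_forward(42)
try:
    out = x + 1
finally:
    mod.end_forward()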

@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from 5eefe3e to 1c7d896 on December 2, 2025 at 22:45