diff --git a/fma/constants.py b/fma/constants.py
index d69f8cf44..4e5975dbe 100644
--- a/fma/constants.py
+++ b/fma/constants.py
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from enum import Enum
+
 import torch
 
 from .math import get_powers_of_2
@@ -27,3 +29,9 @@
     torch.int64: 8,
     torch.uint64: 8,
 }
+
+
+class Kernel(Enum):
+    """Kernel identifiers; each member's value is the member's own name."""
+
+    fused_residual_add_rmsnorm = "fused_residual_add_rmsnorm"
diff --git a/fma/inductor.py b/fma/inductor.py
index f7aa1b239..e31aa1244 100644
--- a/fma/inductor.py
+++ b/fma/inductor.py
@@ -2,9 +2,63 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
+from functools import partial
+
 import torch
+from torch._inductor.fx_passes.joint_graph import patterns
+from torch._inductor.pattern_matcher import fwd_only, joint_fwd_bwd, register_replacement
+
+from .constants import Kernel
+from .kernel_backend import KernelBackend
+from .ops import fused_residual_add_rmsnorm
 
 
 def init_inductor(cache_size_limit: int) -> None:
     torch._dynamo.config.cache_size_limit = cache_size_limit
     torch._dynamo.config.accumulated_cache_size_limit = 1024
+
+
+# Maps each kernel to a (search function, replacement function, example-inputs factory) triple.
+# The factory is called with no arguments by `enable_kernels`; `None` means that no example
+# inputs have been defined for the kernel yet.
+_REPLACEMENT_PATTERNS = {
+    Kernel.fused_residual_add_rmsnorm: (
+        partial(fused_residual_add_rmsnorm, kernel_backend=KernelBackend.torch),
+        partial(fused_residual_add_rmsnorm, kernel_backend=KernelBackend.triton),
+        None,
+    )
+}
+
+
+def enable_kernels(kernels: list[Kernel]) -> None:
+    """Register inductor pattern replacements for the given kernels.
+
+    Each kernel's search function is registered against its replacement function once per trace
+    function (`joint_fwd_bwd` and `fwd_only`), using `patterns` from the joint-graph passes.
+
+    Args:
+        kernels (list[Kernel]): kernels to enable; each must be a key of `_REPLACEMENT_PATTERNS`.
+
+    Raises:
+        KeyError: if a kernel has no entry in `_REPLACEMENT_PATTERNS`.
+        NotImplementedError: if a kernel's entry has no example-inputs factory.
+    """
+
+    for kernel in kernels:
+        search_function, replacement_function, example_inputs_function = _REPLACEMENT_PATTERNS[kernel]
+
+        # the registry may hold `None` in place of the factory; calling it would fail with an
+        # opaque "'NoneType' object is not callable" TypeError, so report the real problem instead
+        if example_inputs_function is None:
+            raise NotImplementedError(f"example inputs are not defined for kernel ({kernel})")
+
+        for trace_function in (joint_fwd_bwd, fwd_only):
+            register_replacement(
+                search_fn=search_function,
+                replace_fn=replacement_function,
+                example_inputs=example_inputs_function(),
+                trace_fn=trace_function,
+                pass_dicts=patterns,
+            )
diff --git a/fma/kernel_backend.py b/fma/kernel_backend.py
index 62d7d0bfe..22408dbdb 100644
--- a/fma/kernel_backend.py
+++ b/fma/kernel_backend.py
@@ -2,7 +2,6 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
-from contextlib import contextmanager
 from enum import Enum
 
 from .cutotune import CutoTuneParameter