Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
eeb9f8d
cleanup
mayank31398 Mar 9, 2025
493a090
cleanup
mayank31398 Mar 9, 2025
dbfb464
cleanup
mayank31398 Mar 9, 2025
f36505d
cleanup
mayank31398 Mar 9, 2025
d8a4554
cleanup
mayank31398 Mar 9, 2025
6e2c333
cleanup
mayank31398 Mar 9, 2025
a7f196d
cleanup
mayank31398 Mar 9, 2025
ec679b6
cleanup
mayank31398 Mar 10, 2025
99b393a
cleanup
mayank31398 Mar 10, 2025
30d3282
cleanup
mayank31398 Mar 10, 2025
43c4757
cleanup
mayank31398 Mar 10, 2025
d9c2edf
cleanup
mayank31398 Mar 10, 2025
ec182dc
cleanup
mayank31398 Mar 10, 2025
cc29efa
cleanup
mayank31398 Mar 10, 2025
33ae092
cleanup
mayank31398 Mar 10, 2025
4a3afd3
cleanup
mayank31398 Mar 10, 2025
0a599d5
cleanup
mayank31398 Mar 10, 2025
9ad8c07
cleanup
mayank31398 Mar 10, 2025
6598c86
cleanup
mayank31398 Mar 11, 2025
321c6ed
cleanup
mayank31398 Mar 11, 2025
30e6ab0
cleanup
mayank31398 Mar 11, 2025
a9b6d9b
cleanup
mayank31398 Mar 11, 2025
3f7f621
cleanup
mayank31398 Mar 11, 2025
d4db2cb
cleanup
mayank31398 Mar 11, 2025
5469340
cleanup
mayank31398 Mar 11, 2025
af57a4e
cleanup
mayank31398 Mar 11, 2025
b19f8c4
cleanup
mayank31398 Mar 11, 2025
80ed796
cleanup
mayank31398 Mar 11, 2025
a5f46dc
cleanup
mayank31398 Mar 11, 2025
c13bdd0
cleanup
mayank31398 Mar 11, 2025
9b1e315
cleanup
mayank31398 Mar 12, 2025
981f5a8
cleanup
mayank31398 Mar 12, 2025
807ee28
cleanup
mayank31398 Mar 12, 2025
b6acce7
cleanup
mayank31398 Mar 12, 2025
7a887bb
cleanup
mayank31398 Mar 12, 2025
f8926de
cleanup
mayank31398 Mar 12, 2025
fdf58ba
cleanup
mayank31398 Mar 13, 2025
3c38c2b
cleanup
mayank31398 Mar 13, 2025
0bbf2d1
cleanup
mayank31398 Mar 13, 2025
0daac79
cleanup
mayank31398 Mar 13, 2025
0a5fd5a
cleanup
mayank31398 Mar 13, 2025
862d66c
cleanup
mayank31398 Mar 13, 2025
65caadf
cleanup
mayank31398 Mar 13, 2025
ee6b70a
cleanup
mayank31398 Mar 13, 2025
b071432
cleanup
mayank31398 Mar 13, 2025
6ea8471
cleanup
mayank31398 Mar 13, 2025
72df3aa
cleanup
mayank31398 Mar 13, 2025
4042721
cleanup
mayank31398 Mar 13, 2025
a674828
cleanup
mayank31398 Mar 13, 2025
c9011f5
cleanup
mayank31398 Mar 13, 2025
f447331
cleanup
mayank31398 Mar 13, 2025
3bc245b
cleanup
mayank31398 Mar 13, 2025
f5cd1a4
cleanup
mayank31398 Mar 13, 2025
fa34ba0
cleanup
mayank31398 Mar 13, 2025
0283885
cleanup
mayank31398 Mar 13, 2025
e582e55
cleanup
mayank31398 Mar 13, 2025
5d142bb
cleanup
mayank31398 Mar 13, 2025
58e9391
cleanup
mayank31398 Mar 13, 2025
9cdbb5e
cleanup
mayank31398 Mar 13, 2025
4ce0719
cleanup
mayank31398 Mar 13, 2025
005f934
cleanup
mayank31398 Mar 13, 2025
833f07b
cleanup
mayank31398 Mar 13, 2025
a2245df
cleanup
mayank31398 Mar 13, 2025
1d3233c
cleanup
mayank31398 Mar 13, 2025
bb017fc
cleanup
mayank31398 Mar 14, 2025
4e7b529
cleanup
mayank31398 Mar 14, 2025
3958676
cleanup
mayank31398 Mar 14, 2025
3523e74
add inductor pass
mayank31398 Sep 12, 2025
923dd5f
add inductor pass
mayank31398 Sep 12, 2025
2fc036c
add inductor pass
mayank31398 Sep 12, 2025
f2f90a0
add inductor pass
mayank31398 Sep 12, 2025
fb473a1
add inductor pass
mayank31398 Sep 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fma/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from enum import Enum

import torch

from .math import get_powers_of_2
Expand All @@ -27,3 +29,7 @@
torch.int64: 8,
torch.uint64: 8,
}


class Kernel(Enum):
    """Enumeration of kernel identifiers.

    Each member's value is the same string as the member's name.
    """

    # Identifier for the `fused_residual_add_rmsnorm` kernel.
    fused_residual_add_rmsnorm = "fused_residual_add_rmsnorm"
33 changes: 33 additions & 0 deletions fma/inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,42 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

from functools import partial

import torch
from torch._inductor.fx_passes.joint_graph import patterns
from torch._inductor.pattern_matcher import fwd_only, joint_fwd_bwd, register_replacement

from .constants import Kernel
from .kernel_backend import KernelBackend
from .ops import fused_residual_add_rmsnorm


def init_inductor(cache_size_limit: int) -> None:
torch._dynamo.config.cache_size_limit = cache_size_limit
torch._dynamo.config.accumulated_cache_size_limit = 1024


# Maps each Kernel to the 3-tuple that enable_kernels unpacks:
#   (search function, replacement function, example-inputs factory).
# For the entry below, the search and replacement functions are the same op bound to
# different kernel backends: torch for the search side, triton for the replacement side.
# NOTE(review): the example-inputs slot is None here, but enable_kernels calls that slot to
# build the example inputs, so this entry cannot be registered as-is — TODO supply a factory.
_REPLACEMENT_PATTERNS = {
    Kernel.fused_residual_add_rmsnorm: (
        partial(fused_residual_add_rmsnorm, kernel_backend=KernelBackend.torch),
        partial(fused_residual_add_rmsnorm, kernel_backend=KernelBackend.triton),
        None,
    )
}


def enable_kernels(kernels: list[Kernel]) -> None:
    """Register inductor pattern replacements for the requested kernels.

    For each kernel, the (search function, replacement function, example-inputs factory)
    triple is looked up in ``_REPLACEMENT_PATTERNS`` and passed to ``register_replacement``
    once per trace function (``joint_fwd_bwd`` and ``fwd_only``), targeting the joint-graph
    ``patterns`` pass.

    Args:
        kernels: kernels whose replacement patterns should be registered. An empty list is a
            no-op.

    Raises:
        KeyError: if a kernel has no entry in ``_REPLACEMENT_PATTERNS``.
        NotImplementedError: if a kernel's entry has no example-inputs factory.
    """
    for kernel in kernels:
        search_function, replacement_function, example_inputs_function = _REPLACEMENT_PATTERNS[kernel]

        # An entry may carry None in the factory slot. Calling it would fail with an opaque
        # "'NoneType' object is not callable", so fail with an actionable message before
        # anything is registered for this kernel.
        if example_inputs_function is None:
            raise NotImplementedError(
                f"no example-inputs factory is registered for {kernel}; "
                "add one to _REPLACEMENT_PATTERNS before enabling this kernel"
            )

        # Register once for the joint forward/backward trace and once for forward-only.
        for trace_function in (joint_fwd_bwd, fwd_only):
            register_replacement(
                search_fn=search_function,
                replace_fn=replacement_function,
                example_inputs=example_inputs_function(),
                trace_fn=trace_function,
                pass_dicts=patterns,
            )
1 change: 0 additions & 1 deletion fma/kernel_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from contextlib import contextmanager
from enum import Enum

from .cutotune import CutoTuneParameter
Expand Down