Merged
Changes from all commits — 42 commits, all titled "cleanup", by mayank31398 on Oct 21, 2025:
2dde384, ed43ec2, fb3fb84, 7bcb57a, 722f554, 54730bb, 9f3c8fd, 316d290, 2de9a80, 68e2774, 3b01156, 69cfbcf, 47147b0, 5004ce8, 28038a2, 8f94ef6, c1fb0d0, 34ff646, 0c6eec3, 3420f3b, 9f33a05, 8a5f372, 4a59708, 6c7edac, 18a2390, ec3c61a, 90051d3, f08bacc, 47f6fb3, bd9c8e4, 1ea250e, 5f79f22, b06b247, a468f16, e930594, 59b5c27, 51fc1df, 7cfa190, e505a16, 1d13d77, 8c5eb62, 792e524
2 changes: 1 addition & 1 deletion fma/__init__.py
@@ -10,7 +10,7 @@
     get_cartesian_product_cutotune_configs,
     get_cutotune_cache,
 )
-from .enums import KernelBackend
+from .enums import KernelBackend, force_kernel_backend
 from .functional import (
     bmm,
     continuous_count,
42 changes: 36 additions & 6 deletions fma/enums.py
@@ -2,20 +2,50 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
+from contextlib import contextmanager
 from enum import Enum
 
-from .cutotune import CutoTuneParameter
+import torch
+
+
+_IS_ROCM_AVAILABLE = torch.version.hip is not None
+_FORCED_KERNEL_BACKEND = None
+
+
+@contextmanager
+def force_kernel_backend(kernel_backend: KernelBackend):
+    global _FORCED_KERNEL_BACKEND
+
+    original_value = _FORCED_KERNEL_BACKEND
+    _FORCED_KERNEL_BACKEND = kernel_backend
+
+    yield
+
+    _FORCED_KERNEL_BACKEND = original_value
 
 
 class KernelBackend(Enum):
     cuda = "cuda"
-    torch = "torch"
+    rocm = "rocm"
+    tpu = "pallas"
+    # for triton compatible accelerators
     triton = "triton"
+    torch = "torch"
 
+    @staticmethod
+    def get_kernel_backend_from_device(x: torch.Tensor) -> KernelBackend:
+        global _FORCED_KERNEL_BACKEND
 
-def is_cuda_kernel_backend_allowed(kernel_backend: KernelBackend) -> bool:
-    return isinstance(kernel_backend, CutoTuneParameter) or kernel_backend in [None, KernelBackend.cuda]
+        if _FORCED_KERNEL_BACKEND is not None:
+            return _FORCED_KERNEL_BACKEND
+
+        device_type = x.device.type
 
-def is_triton_kernel_backend_allowed(kernel_backend: KernelBackend) -> bool:
-    return isinstance(kernel_backend, CutoTuneParameter) or kernel_backend in [None, KernelBackend.triton]
+        if device_type == "cuda":
+            return KernelBackend.rocm if _IS_ROCM_AVAILABLE else KernelBackend.cuda
+        elif device_type == "xla":
+            return KernelBackend.tpu
+        else:
+            return KernelBackend.triton
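Taken together, force_kernel_backend and KernelBackend.get_kernel_backend_from_device replace the per-call kernel_backend arguments that the rest of this PR removes. A minimal usage sketch, assuming the package is importable as fma (per the fma/__init__.py change above) and that a CUDA device is available:

import torch

from fma import KernelBackend, force_kernel_backend

x = torch.randn(4, 8, device="cuda")

# the backend is resolved from the tensor's device:
# "cuda" -> KernelBackend.cuda (or KernelBackend.rocm on ROCm builds),
# "xla"  -> KernelBackend.tpu, anything else -> KernelBackend.triton
backend = KernelBackend.get_kernel_backend_from_device(x)

# temporarily override the device-based resolution, e.g. to fall back to
# the reference torch implementations
with force_kernel_backend(KernelBackend.torch):
    assert KernelBackend.get_kernel_backend_from_device(x) == KernelBackend.torch

# the previous behaviour is restored once the context manager exits
assert KernelBackend.get_kernel_backend_from_device(x) == backend

Since the override lives in a module-level global, it applies to every call made while the context manager is active, not just to calls involving x.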
9 changes: 3 additions & 6 deletions fma/functional/bmm/__init__.py
@@ -4,7 +4,6 @@
 
 import torch
 
-from ...cutotune import CutoTuneParameter
 from ...enums import KernelBackend
 from .triton_implementation import bmm_triton
 
@@ -17,8 +16,6 @@ def bmm(
     is_B_transposed: bool = False,
     alpha: float = 1,
     beta: float = 1,
-    *,
-    kernel_backend: KernelBackend | CutoTuneParameter = KernelBackend.triton,
 ) -> torch.Tensor:
     """computes `alpha` * (`A` @ `B`) + `beta` * `C`
 
@@ -30,8 +27,6 @@
         is_B_transposed (bool, optional): whether B has shape N x K. Defaults to False.
         alpha (float, optional): alpha. Defaults to 1.
         beta (float, optional): beta. Defaults to 1.
-        kernel_backend (KernelBackend | CutoTuneParameter, optional): kernel backend to prioritize.
-            Defaults to KernelBackend.triton.
 
     Raises:
         ValueError: if unexpected `kernel_backend` is passed
@@ -56,6 +51,8 @@
         assert C is not None
         assert C.size() == (L, M, N)
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(A)
+
     if kernel_backend == KernelBackend.torch:
         if is_A_transposed:
             A = A.transpose(1, 2)
@@ -69,7 +66,7 @@
             D = alpha * D
         else:
             D = torch.baddbmm(C, A, B, alpha=alpha, beta=beta)
-    elif kernel_backend == KernelBackend.triton:
+    elif kernel_backend in [KernelBackend.cuda, KernelBackend.triton]:
         D = torch.empty(L, M, N, dtype=A.dtype, device=A.device)
 
         bmm_triton(
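For callers, the visible change is that bmm no longer accepts kernel_backend; the backend is inferred from A's device. A hedged migration sketch (the positional order A, B, C is assumed from the docstring; shapes are illustrative only):

import torch

from fma import bmm

L, M, K, N = 2, 16, 32, 8
A = torch.randn(L, M, K, device="cuda")
B = torch.randn(L, K, N, device="cuda")
C = torch.randn(L, M, N, device="cuda")

# before this PR:
# D = bmm(A, B, C, alpha=2.0, beta=1.0, kernel_backend=KernelBackend.triton)

# after this PR: the CUDA/Triton kernel path is selected because A lives on a
# CUDA device; computes 2.0 * (A @ B) + 1.0 * C
D = bmm(A, B, C, alpha=2.0, beta=1.0)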
10 changes: 4 additions & 6 deletions fma/functional/continuous_count/__init__.py
@@ -9,18 +9,14 @@
 
 
 @torch.no_grad()
-def continuous_count(
-    x: torch.Tensor, size: int, *, kernel_backend: KernelBackend = KernelBackend.cuda
-) -> torch.Tensor:
+def continuous_count(x: torch.Tensor, size: int) -> torch.Tensor:
     """counts the number of occurances of the values [0, 1, ..., `size`) in the input tensor (`size` is excluded).
     NOTE: the user is responsible for ensuring that the values lie in the valid range, any values outside this
     range are ignored and not counted.
 
     Args:
         x (torch.Tensor): input tensor
         size (int): values [0, 1, ..., `size`) are counted (`size` is excluded)
-        kernel_backend (KernelBackend, optional): kernel backend to prioritize.
-            Defaults to KernelBackend.cuda.
 
     Returns:
         torch.Tensor: output tensor
@@ -32,9 +28,11 @@ def continuous_count(x: torch.Tensor, size: int) -> torch.Tensor:
     assert x.dim() == 1, "x should be 1-dimensional"
     assert x.dtype in [torch.int32, torch.long]
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(x)
+
     if kernel_backend == KernelBackend.torch:
         output = x.bincount(minlength=size).to(torch.uint32)
-    elif kernel_backend == KernelBackend.cuda:
+    elif kernel_backend in [KernelBackend.cuda, KernelBackend.triton]:
         output = torch.empty(size, dtype=torch.uint32, device=x.device)
         continuous_count_cuda(x=x, output=output, E=size, THREAD_BLOCK_CLUSTER_SIZE=1, BLOCK_SIZE=1024)
     else:
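A short usage sketch for the updated continuous_count (exported from the package root per fma/__init__.py); the backend now follows x.device, so this assumes a CUDA tensor:

import torch

from fma import continuous_count

# count occurrences of the values 0..3; anything outside [0, 4) is ignored
x = torch.tensor([0, 1, 1, 3, 3, 3, 7], dtype=torch.int32, device="cuda")
counts = continuous_count(x, size=4)
# counts is a uint32 tensor: tensor([1, 2, 0, 3]) for this input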
12 changes: 4 additions & 8 deletions fma/functional/cross_entropy/__init__.py
@@ -35,12 +35,7 @@ def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor | None]:
 
 
 def cross_entropy(
-    x: torch.Tensor,
-    labels: torch.Tensor,
-    reduction: str = "mean",
-    logits_multiplier: float | None = None,
-    *,
-    kernel_backend: KernelBackend = KernelBackend.triton,
+    x: torch.Tensor, labels: torch.Tensor, reduction: str = "mean", logits_multiplier: float | None = None
 ) -> torch.Tensor:
     """compute cross entropy loss
 
@@ -50,8 +45,6 @@ def cross_entropy(
         reduction (str, optional): reduction should be either sum or mean. Defaults to "mean".
         logits_multiplier (float | None, optional): logits multiplier pre-multiplies logits, None implies 1.
             Defaults to None.
-        kernel_backend (KernelBackend, optional): kernel backend to prioritize.
-            Defaults to KernelBackend.triton.
 
     Returns:
         torch.Tensor: loss
@@ -64,6 +57,8 @@
         labels.size(0) == get_num_elements_and_hidden_size(x)[0]
     ), "x and labels have different number of elements along batch dimension"
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(x)
+
     if kernel_backend == KernelBackend.torch:
         x = x.float()
 
@@ -72,6 +67,7 @@
 
         x = F.cross_entropy(x, labels, reduction=reduction)
     else:
+        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]
        x = _CrossEntropy.apply(x, labels, reduction, logits_multiplier)
 
     return x
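Because the backend override is now global, the reference torch path can still be exercised even though the kernel_backend argument is gone. A sketch, assuming cross_entropy is re-exported at the package root like bmm and continuous_count (adjust the import if it is not):

import torch

from fma import KernelBackend, cross_entropy, force_kernel_backend

logits = torch.randn(32, 1000, device="cuda")
labels = torch.randint(0, 1000, (32,), device="cuda")

# default: the CUDA/Triton kernel is picked from logits.device
loss = cross_entropy(logits, labels, reduction="mean")

# reference implementation (F.cross_entropy) via the global override
with force_kernel_backend(KernelBackend.torch):
    loss_ref = cross_entropy(logits, labels, reduction="mean")

torch.testing.assert_close(loss, loss_ref, rtol=1e-4, atol=1e-4)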
12 changes: 4 additions & 8 deletions fma/functional/fused_linear_cross_entropy.py
@@ -83,7 +83,6 @@ def fused_linear_cross_entropy(
     labels: torch.Tensor,
     reduction: str = "mean",
     logits_multiplier: float | None = None,
-    kernel_backend: KernelBackend | CutoTuneParameter = KernelBackend.triton,
 ) -> torch.Tensor:
     """compute cross entropy loss without materializing the full output logits matrix
 
@@ -105,16 +104,13 @@
     assert x.size(0) == labels.size(0), "x and labels have different number of elements along dim 0"
     assert x.size(-1) == weight.size(-1)
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(x)
+
     if kernel_backend == KernelBackend.torch:
         x = F.linear(x, weight)
-        x = cross_entropy(
-            x=x,
-            labels=labels,
-            reduction=reduction,
-            logits_multiplier=logits_multiplier,
-            kernel_backend=kernel_backend,
-        )
+        x = cross_entropy(x=x, labels=labels, reduction=reduction, logits_multiplier=logits_multiplier)
     else:
+        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]
         x = _FusedLinearCrossEntropy.apply(x, weight, labels, reduction, logits_multiplier)
 
     return x
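The point of this function is to avoid materializing the full tokens x vocab logits matrix on the kernel path. A hedged call sketch; the leading x and weight parameters are assumed from the function body (they are not shown in this hunk), and the import path may need adjusting:

import torch

from fma import fused_linear_cross_entropy  # assumed re-export from the package root

num_tokens, hidden_size, vocab_size = 4096, 2048, 32000
x = torch.randn(num_tokens, hidden_size, device="cuda")       # hidden states
weight = torch.randn(vocab_size, hidden_size, device="cuda")  # LM head weight
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

# equivalent to cross_entropy(F.linear(x, weight), labels) on the torch path,
# but without materializing the (num_tokens, vocab_size) logits on the kernel path
loss = fused_linear_cross_entropy(x, weight, labels, reduction="mean")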
5 changes: 3 additions & 2 deletions fma/functional/fused_residual_add_rmsnorm/__init__.py
@@ -109,7 +109,6 @@ def fused_residual_add_rmsnorm(
     multiplier: float | None = None,
     memory_efficient: bool = False,
     deterministic: bool = False,
-    kernel_backend: KernelBackend | CutoTuneParameter = KernelBackend.triton,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """fused residual add RMSNorm computation
 
@@ -132,6 +131,8 @@
     assert weight.size(-1) == x.size(-1), "hidden size for x and weight tensor is different"
     assert weight.type() == x.type(), "tensors weight and y should have same dtype"
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(x)
+
     if kernel_backend == KernelBackend.torch:
         if multiplier not in [None, 1]:
             x = x * multiplier
@@ -142,7 +143,7 @@
 
         x = F.rms_norm(x, normalized_shape=(x.size(-1),), weight=weight, eps=eps)
     else:
-        assert kernel_backend == KernelBackend.triton
+        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]
         increment_counter(fused_residual_add_rmsnorm)
 
         is_flat = x.dim() == 1
4 changes: 4 additions & 0 deletions fma/functional/grouped_gemm/__init__.py
@@ -4,6 +4,7 @@
 
 import torch
 
+from ...enums import KernelBackend
 from .cuda_implementation import grouped_gemm_cuda
 
 
@@ -24,6 +25,9 @@ def grouped_gemm(
     assert beta == 0
     assert C is None
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(A)
+    assert kernel_backend == KernelBackend.cuda
+
     output = torch.empty(*output_shape, device=A.device, dtype=A.dtype)
 
     grouped_gemm_cuda(
7 changes: 4 additions & 3 deletions fma/functional/gru/__init__.py
@@ -4,7 +4,6 @@
 
 import torch
 
-from ...cutotune import CutoTuneParameter
 from ...enums import KernelBackend
 from ...torch_math import clip_gradients, sigmoid, tanh
 from ...utils import empty_like_contiguous, zeros_like_contiguous
@@ -138,8 +137,6 @@ def gru(
     gradient_clipping: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
     max_seqlen: torch.Tensor | int | None = None,
-    *,
-    kernel_backend: KernelBackend | CutoTuneParameter = KernelBackend.triton,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """computes multihead RNN: tanh(`input_state` @ `weight` + `input`)
 
@@ -164,6 +161,8 @@
     N, H = input.size()[-2:]
     assert weight.size() == (N, H, H)
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(input)
+
     if gradient_clipping is not None and gradient_clipping < 0:
         gradient_clipping = -gradient_clipping
 
@@ -252,6 +251,8 @@
             output[offset_unfinished] = new_state
             input_state[unfinished] = new_state
     else:
+        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]
+
         output = _GRU.apply(
             input,
             weight,
7 changes: 0 additions & 7 deletions fma/functional/rmsnorm.py
@@ -4,8 +4,6 @@
 
 import torch
 
-from ..cutotune import CutoTuneParameter
-from ..enums import KernelBackend
 from .fused_residual_add_rmsnorm import fused_residual_add_rmsnorm
 
 
@@ -15,8 +13,6 @@ def rmsnorm(
     eps: float | None,
     memory_efficient: bool = False,
     deterministic: bool = False,
-    *,
-    kernel_backend: KernelBackend | CutoTuneParameter = KernelBackend.triton,
 ) -> torch.Tensor:
     """RMSNorm computation
 
@@ -27,8 +23,6 @@
         memory_efficient (bool, optional): memory efficient = False caches RMSNorm's denominator in the forward.
            Defaults to False.
         deterministic (bool, optional): whether to use deterministic backward. Defaults to False.
-        kernel_backend (KernelBackend | CutoTuneParameter, optional): kernel backend to prioritize.
-            Defaults to KernelBackend.triton.
 
     Returns:
         torch.Tensor: output tensor
@@ -42,7 +36,6 @@
         multiplier=None,
         memory_efficient=memory_efficient,
         deterministic=deterministic,
-        kernel_backend=kernel_backend,
     )
 
     return x
6 changes: 3 additions & 3 deletions fma/functional/rnn/__init__.py
@@ -78,8 +78,6 @@ def rnn(
     gradient_clipping: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
     max_seqlen: torch.Tensor | int | None = None,
-    *,
-    kernel_backend: KernelBackend = KernelBackend.triton,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """computes multihead RNN recurrent update over the sequence length: tanh(`input_state` @ `weight` + `input`)
 
@@ -93,7 +91,6 @@
             implies no clipping. Defaults to None.
         cu_seqlens (torch.Tensor | None, optional): cumulative sequence length (must contain 0 as first element). Defaults to None.
         max_seqlen (torch.Tensor | int | None, optional): max sequence length in the batch. Defaults to None.
-        kernel_backend (KernelBackend, optional): kernel backend to prioritize. Defaults to KernelBackend.triton.
 
     Returns:
         tuple[torch.Tensor, torch.Tensor]: output tensor of shape (B, S, N, H) and output state tensor of shape (B, N, H)
@@ -105,6 +102,8 @@
     N, H = input.size()[-2:]
     assert weight.size() == (N, H, H)
 
+    kernel_backend = KernelBackend.get_kernel_backend_from_device(input)
+
     if gradient_clipping is not None and gradient_clipping < 0:
         gradient_clipping = -gradient_clipping
 
@@ -171,6 +170,7 @@
             output[offset_unfinished] = new_state
             input_state[unfinished] = new_state
     else:
+        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]
         output = _RNN.apply(input, weight, input_state, gradient_clipping, cu_seqlens, max_seqlen)
 
     output_state = output[:, -1] if cu_seqlens is None else output[cu_seqlens[1:] - 1]
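A hedged call sketch for the updated rnn, including the packed variable-length path via cu_seqlens/max_seqlen described in the docstring. The initial-state parameter is omitted and assumed to default to None (its position is not shown in this hunk), and the import path may need adjusting:

import torch

from fma import rnn  # assumed re-export from the package root

B, S, N, H = 4, 128, 8, 64  # batch, sequence length, heads, head size

# fixed-length batch: the backend is inferred from input.device
input = torch.randn(B, S, N, H, device="cuda")
weight = torch.randn(N, H, H, device="cuda")
output, output_state = rnn(input=input, weight=weight)

# packed variable-length batch: sequences of length 100, 80 and 128 packed along
# the first dimension, with cumulative lengths starting at 0 as the docstring requires
seqlens = torch.tensor([100, 80, 128], device="cuda")
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
packed = torch.randn(int(seqlens.sum()), N, H, device="cuda")
output, output_state = rnn(
    input=packed, weight=weight, cu_seqlens=cu_seqlens, max_seqlen=int(seqlens.max())
)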