
[Iluvatar] fix kron scatter randperm quantile #500

Merged 1 commit on Mar 27, 2025
4 changes: 2 additions & 2 deletions src/flag_gems/ops/randperm.py
@@ -215,7 +215,7 @@ def radix_sortbykey_scatter_kernel(
+ ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
+ bin_id,
partial_counter,
cache_modifier=".wt",
cache_modifier=".cg",
)
bin_offset = p * (bins + 1) + bin_id
prefix_offsets = tl.load(
@@ -242,7 +242,7 @@ def radix_sortbykey_scatter_kernel(
+ ((portion_id * passes + p) * max_tiles_per_portion + pid0) * bins
+ bin_id,
global_counter,
cache_modifier=".wt",
cache_modifier=".cg",
)
inc_bucket_offset = prefix_offsets.to(tl.int64) + inc_sum.to(tl.int64)
if last_block and portion_id < num_portions - 1:
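Note on the change above: `.wt` (write-through) and `.cg` (cache at the global/L2 level, bypassing L1) are both valid `cache_modifier` values for `tl.store` in Triton; the swap presumably works around the Iluvatar backend mishandling `.wt`, while `.cg` still keeps the counter writes visible to other blocks reading the same addresses. A minimal sketch of the argument in isolation — a hypothetical kernel, not the actual radix-sort scatter kernel:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def store_counters(ptr, val, n, BLOCK: tl.constexpr):
    # Store with ".cg" so the writes bypass L1 and land in L2,
    # where other blocks polling the same addresses can see them.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(ptr + offs, val, mask=mask, cache_modifier=".cg")


counters = torch.zeros(1024, device="cuda", dtype=torch.int32)
store_counters[(4,)](counters, 1, counters.numel(), BLOCK=256)
```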
2 changes: 1 addition & 1 deletion src/flag_gems/ops/scatter.py
@@ -36,7 +36,7 @@ def generate_scatter_kernel(

code.writeline("def heur_block(args):")
with code.indent():
code.writeline("if(flag_gems.vendor_name=='metax'):")
code.writeline("if(flag_gems.vendor_name in ['metax', 'iluvatar']):")
with code.indent():
code.writeline("return 256")
code.writeline("return 128")
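Since `generate_scatter_kernel` emits Python source at runtime, the change above edits the generated heuristic rather than a static function. Reconstructed from the `writeline` calls (not copied from the repo), the emitted code should now read roughly:

```python
def heur_block(args):
    # Block-size heuristic for the generated scatter kernel:
    # metax and iluvatar use 256, every other vendor keeps 128.
    if flag_gems.vendor_name in ["metax", "iluvatar"]:
        return 256
    return 128
```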
2 changes: 1 addition & 1 deletion src/flag_gems/runtime/backend/_iluvatar/__init__.py
@@ -4,6 +4,6 @@
vendor_name="iluvatar", device_name="cuda", device_query_cmd="ixsmi"
)

CUSTOMIZED_UNUSED_OPS = ("scatter", "quantile", "randperm", "mv")
CUSTOMIZED_UNUSED_OPS = ()

__all__ = ["*"]
9 changes: 7 additions & 2 deletions src/flag_gems/runtime/backend/_iluvatar/ops/__init__.py
@@ -1,5 +1,10 @@
from .bmm import bmm
-from .div import div_mode, floor_divide, remainder, true_divide
+from .div import div_mode, div_mode_
from .mm import mm

__all__ = ["bmm", "mm", "div_mode", "floor_divide", "remainder", "true_divide"]
__all__ = [
"bmm",
"mm",
"div_mode",
"div_mode_",
]
46 changes: 45 additions & 1 deletion src/flag_gems/runtime/backend/_iluvatar/ops/div.py
@@ -43,6 +43,14 @@ def true_divide(A, B):
return torch.tensor(A / B)


+def true_divide_(A, B):
+    logging.debug("GEMS TRUE_DIVIDE_")
+    if isinstance(B, torch.Tensor):
+        return true_div_func(A, B, out0=A)
+    else:
+        return true_div_func_tensor_scalar(A, B, out0=A)


@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
@@ -62,7 +70,7 @@ def trunc_div_func_scalar_tensor(x, y):


def trunc_divide(A, B):
logging.debug("GEMS TRUNC_DIVIDE iluvatar")
logging.debug("GEMS TRUNC_DIVIDE")
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
return trunc_div_func(A, B)
elif isinstance(A, torch.Tensor):
@@ -74,6 +82,14 @@ def trunc_divide(A, B):
return torch.tensor(A / B)


+def trunc_divide_(A, B):
+    logging.debug("GEMS TRUNC_DIVIDE_")
+    if isinstance(B, torch.Tensor):
+        return trunc_div_func(A, B, out0=A)
+    else:
+        return trunc_div_func_tensor_scalar(A, B, out0=A)


@triton.jit
def _int_floordiv(x, y):
# TODO: request Triton to add an integer remainder builtin
@@ -167,6 +183,14 @@ def floor_divide(A, B):
return torch.tensor(A // B)


+def floor_divide_(A, B):
+    logging.debug("GEMS FLOOR_DIVIDE_")
+    if isinstance(B, torch.Tensor):
+        return floor_div_func(A, B, out0=A)
+    else:
+        return floor_div_func_tensor_scalar(A, B, out0=A)


def div_mode(A, B, rounding_mode=None):
if rounding_mode is None:
return true_divide(A, B)
@@ -179,6 +203,18 @@ def div_mode(A, B, rounding_mode=None):
raise ValueError(msg)


+def div_mode_(A, B, rounding_mode=None):
+    if rounding_mode is None:
+        return true_divide_(A, B)
+    elif rounding_mode == "trunc":
+        return trunc_divide_(A, B)
+    elif rounding_mode == "floor":
+        return floor_divide_(A, B)
+    else:
+        msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
+        raise ValueError(msg)


@triton.jit
def _remainder(x, y):
r = x % y
@@ -216,3 +252,11 @@ def remainder(A, B):
else:
# Both scalar
return torch.tensor(A % B)


+def remainder_(A, B):
+    logging.debug("GEMS REMAINDER_")
+    if isinstance(B, torch.Tensor):
+        return rem_tt(A, B, out0=A)
+    else:
+        return rem_ts(A, B, out0=A)
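All of the new `*_` variants reuse the existing pointwise kernels and route the output to the first operand via `out0=A`, mirroring PyTorch's in-place semantics. A usage sketch, assuming the usual `flag_gems.use_gems()` patching context and that these functions are wired up to the corresponding ATen in-place overloads:

```python
import torch
import flag_gems

a = torch.randn(1024, device=flag_gems.device)
b = torch.randn(1024, device=flag_gems.device)

with flag_gems.use_gems():
    a.div_(b)                          # true division in place
    a.div_(b, rounding_mode="floor")   # div_mode_ -> floor_divide_
    a.remainder_(b)                    # remainder in place
```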
18 changes: 18 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/tune_configs.yaml
@@ -3415,3 +3415,21 @@ batch_norm:
- 8
- 16
- 32
+kron:
+- gen: true
+  param_map:
+    META:
+      BLOCK_M: block_m
+      BLOCK_N: block_n
+    num_warps: warps
+  block_m:
+  - 1
+  - 2
+  - 4
+  - 8
+  block_n:
+  - 1024
+  - 2048
+  warps:
+  - 4
+  - 8
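Assuming FlagGems' config loader expands a `gen: true` entry into the Cartesian product of the listed values, mapped through `param_map` onto the kernel's meta-parameters, the block above would yield 16 autotuning candidates for `kron`, roughly equivalent to:

```python
import triton

# 4 block_m values x 2 block_n values x 2 warp counts = 16 configs.
configs = [
    triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_warps=w)
    for bm in (1, 2, 4, 8)
    for bn in (1024, 2048)
    for w in (4, 8)
]
```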
2 changes: 1 addition & 1 deletion tests/test_binary_pointwise_ops.py
@@ -620,7 +620,7 @@ def test_accuracy_trunc_div_(shape, dtype):

inp1 = torch.randn(shape, dtype=dtype, device="cpu").to(flag_gems.device)
inp2 = torch.randn(shape, dtype=dtype, device="cpu").to(flag_gems.device)
upcast = True if flag_gems.vendor_name not in ["kunlunxin"] else False
upcast = True if flag_gems.vendor_name not in ["kunlunxin", "iluvatar"] else False
ref_inp1 = to_reference(inp1, upcast)
ref_inp2 = to_reference(inp2, upcast)

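The test change only opts Iluvatar (like kunlunxin) out of the float64 upcast for the reference computation, presumably due to limited fp64 support on that hardware. A hypothetical re-implementation of what the `to_reference` helper is assumed to do with the flag — not the repo's code:

```python
import torch

def to_reference(inp: torch.Tensor, upcast: bool) -> torch.Tensor:
    # CPU copy for the reference computation; widen to float64 only
    # when requested, so the device result is checked against a
    # higher-precision baseline on vendors that can afford it.
    ref = inp.detach().to("cpu")
    return ref.to(torch.float64) if upcast else ref
```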