Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add linspace op #478

Merged
merged 6 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions benchmark/test_tensor_constructor_perf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import random

import pytest
import torch
Expand Down Expand Up @@ -56,6 +57,18 @@ def arange_input_fn(shape, dtype, device):
},


def linspace_input_fn(shape, dtype, device):
limit = torch.finfo(dtype).max - 1
num = int(min(limit, math.prod(shape)))
yield {
"start": 0,
"end": num,
"steps": random.randint(1, num),
"dtype": dtype,
"device": device,
},


# Define operations and their corresponding input functions
tensor_constructor_operations = [
# generic tensor constructor
Expand All @@ -75,6 +88,8 @@ def arange_input_fn(shape, dtype, device):
("full_like", torch.full_like, full_like_input_fn),
# arange
("arange", torch.arange, arange_input_fn),
# linspace
("linspace", torch.linspace, linspace_input_fn),
]


Expand Down
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("zeros_like", zeros_like, Autograd.disable),
("ones_like", ones_like, Autograd.disable),
("full_like", full_like, Autograd.disable),
("linspace", linspace, Autograd.disable),
("resolve_neg", resolve_neg, Autograd.disable),
("resolve_conj", resolve_conj, Autograd.disable),
("normal.Tensor_float", normal_tensor_float, Autograd.disable),
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .isnan import isnan
from .layernorm import layer_norm
from .le import le, le_scalar
from .linspace import linspace
from .log_sigmoid import log_sigmoid
from .log_softmax import log_softmax
from .logical_and import logical_and
Expand Down Expand Up @@ -181,6 +182,7 @@
"zeros",
"ones",
"full",
"linspace",
"native_dropout",
"erf",
"embedding",
Expand Down
61 changes: 61 additions & 0 deletions src/flag_gems/ops/linspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry
from ..utils import triton_lang_extension as tle


@libentry()
@triton.jit
def linspace_kernel(
out_ptr,
out_stride0,
start,
mid,
end,
step_size,
steps,
BLOCK_SIZE: tl.constexpr,
):
pid = tle.program_id(0)
idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = idx < steps
fw_mask = idx < mid
fw_values = start + (step_size * idx)
bd_values = end - step_size * (steps - idx - 1)

out_val = tl.where(fw_mask, fw_values, bd_values)
tl.store(out_ptr + idx * out_stride0, out_val, mask=mask)


def linspace(
start, end, steps, *, dtype=None, layout=None, device=None, pin_memory=None
) -> torch.Tensor:
logging.debug("GEMS LINSPACE")
assert steps >= 1, "steps must be >= 1"

out = torch.empty(
steps,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
)
if steps == 1:
return torch.fill(out, start)
else:
if isinstance(start, torch.Tensor):
start = start.item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

picking the item from tensor and then passing it to kernel function might cause unnecessary costing of time.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to write 4 kernels for the case where start and end are tensors?

if isinstance(end, torch.Tensor):
end = end.item()
mid = steps // 2
step_size = (float(end) - float(start)) / (steps - 1)
BLOCK_SIZE = 128
grid = (triton.cdiv(steps, BLOCK_SIZE),)
linspace_kernel[grid](
out, out.stride(0), start, mid, end, step_size, steps, BLOCK_SIZE=BLOCK_SIZE
)
return out
35 changes: 35 additions & 0 deletions tests/test_special_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,41 @@ def test_arange(start, step, end, dtype, device, pin_memory):
gems_assert_equal(res_out, ref_out)


@pytest.mark.linspace
@pytest.mark.parametrize("start", [0, 2, 4])
@pytest.mark.parametrize("end", [256, 2048, 4096])
@pytest.mark.parametrize("steps", [1, 256, 512])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES + [None])
@pytest.mark.parametrize("device", [device, None])
@pytest.mark.parametrize("pin_memory", [False, None])
def test_linspace(start, end, steps, dtype, device, pin_memory):
if TO_CPU:
return
ref_out = torch.linspace(
start,
end,
steps,
dtype=dtype,
layout=None,
device=device,
pin_memory=pin_memory,
)
with flag_gems.use_gems():
res_out = torch.linspace(
start,
end,
steps,
dtype=dtype,
layout=None,
device=device,
pin_memory=pin_memory,
)
if dtype in [torch.float16, torch.bfloat16]:
gems_assert_close(res_out, ref_out, dtype=dtype)
else:
gems_assert_equal(res_out, ref_out)


@pytest.mark.skipif(flag_gems.device == "musa", reason="AssertionError")
@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX")
@pytest.mark.isin
Expand Down