diff --git a/include/infiniop.h b/include/infiniop.h index d51b8d92e..ce7239729 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -14,6 +14,7 @@ #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rope.h" +#include "infiniop/ops/sigmoid.h" #include "infiniop/ops/sub.h" #include "infiniop/ops/swiglu.h" #include "infiniop/tensor_descriptor.h" diff --git a/include/infiniop/ops/sigmoid.h b/include/infiniop/ops/sigmoid.h new file mode 100644 index 000000000..4fa0f6604 --- /dev/null +++ b/include/infiniop/ops/sigmoid.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_SIGMOID_API_H__ +#define __INFINIOP_SIGMOID_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopSigmoidDescriptor_t; + +__C __export infiniStatus_t infiniopCreateSigmoidDescriptor(infiniopHandle_t handle, + infiniopSigmoidDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__C __export infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopSigmoid(infiniopSigmoidDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc); + +#endif diff --git a/scripts/python_test.py b/scripts/python_test.py index eb2d4319e..9086b5708 100644 --- a/scripts/python_test.py +++ b/scripts/python_test.py @@ -22,6 +22,7 @@ def run_tests(args): "rearrange.py", "rms_norm.py", "rope.py", + "sigmoid.py", "sub.py", "swiglu.py", ]: diff --git a/src/infiniop-test/include/ops.hpp b/src/infiniop-test/include/ops.hpp index 3820f7cfd..d543b4cb3 100644 --- a/src/infiniop-test/include/ops.hpp +++ b/src/infiniop-test/include/ops.hpp @@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add) DECLARE_INFINIOP_TEST(causal_softmax) DECLARE_INFINIOP_TEST(rearrange) DECLARE_INFINIOP_TEST(sub) +DECLARE_INFINIOP_TEST(sigmoid) #define REGISTER_INFINIOP_TEST(name) \ { \ @@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub) REGISTER_INFINIOP_TEST(causal_softmax) \ REGISTER_INFINIOP_TEST(rearrange) \ REGISTER_INFINIOP_TEST(sub) \ + REGISTER_INFINIOP_TEST(sigmoid) \ } namespace infiniop_test { diff --git a/src/infiniop-test/src/ops/sigmoid.cpp b/src/infiniop-test/src/ops/sigmoid.cpp new file mode 100644 index 000000000..bb3a0f70a --- /dev/null +++ b/src/infiniop-test/src/ops/sigmoid.cpp @@ -0,0 +1,103 @@ +#include "ops.hpp" +#include "utils.hpp" +#include +#include +#include + +namespace infiniop_test::sigmoid { +struct Test::Attributes { + std::shared_ptr x; + std::shared_ptr y; + std::shared_ptr ans; +}; + +std::shared_ptr Test::build( + std::unordered_map> attributes, + std::unordered_map> tensors, + double rtol, double atol) { + auto test = std::shared_ptr(new Test(rtol, atol)); + test->_attributes = new Attributes(); + if (tensors.find("x") == tensors.end() + || tensors.find("y") == tensors.end() + || tensors.find("ans") == tensors.end()) { + throw std::runtime_error("Invalid Test"); + } + + test->_attributes->x = tensors["x"]; + test->_attributes->y = tensors["y"]; + test->_attributes->ans = tensors["ans"]; + + return test; +} + +std::shared_ptr Test::run( + infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) { + infiniopSigmoidDescriptor_t op_desc; + auto x = _attributes->x->to(device, device_id); + auto y = _attributes->y->to(device, device_id); + CHECK_OR(infiniopCreateSigmoidDescriptor(handle, &op_desc, + y->desc(), + x->desc()), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor.")); + size_t workspace_size; + CHECK_OR(infiniopGetSigmoidWorkspaceSize(op_desc, &workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size.")); + void *workspace; + CHECK_OR(infinirtMalloc(&workspace, workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace.")); + CHECK_OR(infiniopSigmoid(op_desc, workspace, workspace_size, + y->data(), + x->data(), + nullptr), + return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution.")); + + try { + allClose(y, _attributes->ans, _rtol, _atol); + } catch (const std::exception &e) { + return TEST_FAILED(RESULT_INCORRECT, e.what()); + } + + double elapsed_time = 0.; + + elapsed_time = benchmark( + [=]() { + infiniopSigmoid( + op_desc, workspace, workspace_size, + y->data(), + x->data(), + nullptr); + }, + warm_ups, iterations); + + infiniopDestroySigmoidDescriptor(op_desc); + infinirtFree(workspace); + return TEST_PASSED(elapsed_time); +} + +std::vector Test::attribute_names() { + return {}; +} + +std::vector Test::tensor_names() { + return {"x", "y", "ans"}; +} + +std::vector Test::output_names() { + return {"y"}; +} + +std::string Test::toString() const { + std::ostringstream oss; + oss << op_name() << std::endl; + oss << "- x: " << _attributes->x->info() << std::endl; + oss << "- y: " << _attributes->y->info() << std::endl; + oss << std::scientific << std::setprecision(2); + oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl; + return oss.str(); +} + +Test::~Test() { + delete _attributes; +} + +} // namespace infiniop_test::sigmoid diff --git a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc new file mode 100644 index 000000000..c335bba60 --- /dev/null +++ b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc @@ -0,0 +1,51 @@ +#include "sigmoid_cpu.h" + +namespace op::sigmoid::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::sigmoid::cpu diff --git a/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h new file mode 100644 index 000000000..6ab7eaeb9 --- /dev/null +++ b/src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h @@ -0,0 +1,19 @@ +#ifndef __SIGMOID_CPU_H__ +#define __SIGMOID_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(sigmoid, cpu, cpu) + +namespace op::sigmoid::cpu { +typedef struct SigmoidOp { +public: + static constexpr size_t num_inputs = 1; + template + T operator()(const T &x) const { + return T(1) / (T(1) + std::exp(-x)); + } +} SigmoidOp; +} // namespace op::sigmoid::cpu + +#endif // __SIGMOID_CPU_H__ diff --git a/src/infiniop/ops/sigmoid/cuda/kernel.cuh b/src/infiniop/ops/sigmoid/cuda/kernel.cuh new file mode 100644 index 000000000..1ea7c2a02 --- /dev/null +++ b/src/infiniop/ops/sigmoid/cuda/kernel.cuh @@ -0,0 +1,34 @@ +#ifndef __SIDMOID_CUDA_H__ +#define __SIDMOID_CUDA_H__ + +#include "../../../elementwise/cuda/elementwise_cuda.cuh" +#include +#include + +namespace op::sigmoid::cuda { +typedef struct SigmoidOp { +public: + static constexpr size_t num_inputs = 1; + template + __device__ __forceinline__ T operator()(const T &x) const { + // sigmoid(x) = 1 / (1 + exp(-x)) + if constexpr (std::is_same_v) { + half2 denominator = __hadd2(make_half2(1, 1), h2exp(__hneg2(x))); + return h2rcp(denominator); + } else if constexpr (std::is_same_v) { + half denominator = __hadd(__float2half(1.0f), hexp(__hneg(x))); + return hrcp(denominator); + } else if constexpr (std::is_same_v) { + __nv_bfloat16 denominator = __float2bfloat16(__fadd_rn(1.0f, __expf(__bfloat162float(-x)))); + return __float2bfloat16(1.0f) / denominator; + } else if constexpr (std::is_same_v) { + float denominator = __fadd_rn(1.0f, __expf(-x)); + return __frcp_rn(denominator); + } else { // double + return 1.0 / (1.0 + exp(-x)); + } + } +} SigmoidOp; +} // namespace op::sigmoid::cuda + +#endif // __SIDMOID_CUDA_H__ diff --git a/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu new file mode 100644 index 000000000..f5dfe9fdd --- /dev/null +++ b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu @@ -0,0 +1,58 @@ +#include "../cuda/kernel.cuh" +#include "sigmoid_nvidia.cuh" + +namespace op::sigmoid::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::SigmoidOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::SigmoidOp, __nv_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::SigmoidOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::SigmoidOp, double>(_info, workspace, output, inputs, stream); + + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::sigmoid::nvidia diff --git a/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh new file mode 100644 index 000000000..53dd0a1fa --- /dev/null +++ b/src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __SIGMOID_CUDA_API_H__ +#define __SIGMOID_CUDA_API_H__ + +#include "../../../elementwise/cuda/elementwise_cuda_api.cuh" + +ELEMENTWISE_DESCRIPTOR(sigmoid, nvidia, cuda) + +#endif // __SIGMOID_CUDA_API_H__ diff --git a/src/infiniop/ops/sigmoid/operator.cc b/src/infiniop/ops/sigmoid/operator.cc new file mode 100644 index 000000000..3f2f95067 --- /dev/null +++ b/src/infiniop/ops/sigmoid/operator.cc @@ -0,0 +1,115 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/sigmoid.h" + +#ifdef ENABLE_CPU_API +#include "cpu/sigmoid_cpu.h" +#endif +#ifdef ENABLE_NVIDIA_API +#include "nvidia/sigmoid_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateSigmoidDescriptor( + infiniopHandle_t handle, + infiniopSigmoidDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::sigmoid::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + {x_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopSigmoid( + infiniopSigmoidDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, {x}, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t +infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/test/infiniop-test/README.md b/test/infiniop-test/README.md index 85e889e42..40dc7e36d 100644 --- a/test/infiniop-test/README.md +++ b/test/infiniop-test/README.md @@ -17,7 +17,7 @@ xmake build infiniop-test 在`/test/infiniop-test/`目录执行矩阵乘测例生成脚本,执行结束以后会在`/test/infiniop-test/`目录生成`gemm.gguf`测例文件。 ```bash -cd /test/infiniop-test/ +cd ./test/infiniop-test/ python -m test_generate.testcases.gemm ``` diff --git a/test/infiniop-test/test_generate/testcases/sigmoid.py b/test/infiniop-test/test_generate/testcases/sigmoid.py new file mode 100644 index 000000000..f622a4d6e --- /dev/null +++ b/test/infiniop-test/test_generate/testcases/sigmoid.py @@ -0,0 +1,136 @@ +import numpy as np +from numpy.lib.stride_tricks import as_strided +import gguf +from typing import List + +from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor + + +def sigmoid( + x: np.ndarray, +): + return 1 / (1 + np.exp(-x)) + + +def random_tensor(shape, dtype): + rate = 1e-3 + var = 0.5 * rate + return rate * np.random.rand(*shape).astype(dtype) - var + + +def process_tensors(a, b, stride_a=None, stride_b=None): + def normalize_stride(tensor, stride): + if stride: + slices = tuple(slice(0, 1) if s == 0 else slice(None) for s in stride) + return tensor[slices] + else: + return tensor + + a_unique = normalize_stride(a, stride_a) + b_unique = normalize_stride(b, stride_b) + return a_unique, b_unique + + +def process_tensor(a, stride_a=None): + def normalize_stride(tensor, stride): + if stride: + slices = tuple(slice(0, 1) if s == 0 else slice(None) for s in stride) + return tensor[slices] + else: + return tensor + + a_unique = normalize_stride(a, stride_a) + return a_unique + + +class SigmoidTestCase(InfiniopTestCase): + def __init__( + self, + x: np.ndarray, + shape_x: List[int] | None, + stride_x: List[int] | None, + y: np.ndarray, + shape_y: List[int] | None, + stride_y: List[int] | None, + ): + super().__init__("sigmoid") + self.x = x + self.shape_x = shape_x + self.stride_x = stride_x + + self.y = y + self.shape_y = shape_y + self.stride_y = stride_y + + def write_test(self, test_writer: "InfiniopTestWriter"): + super().write_test(test_writer) + + if self.shape_x is not None: + test_writer.add_array(test_writer.gguf_key("x.shape"), self.shape_x) + if self.shape_y is not None: + test_writer.add_array(test_writer.gguf_key("y.shape"), self.shape_y) + + if self.stride_x is not None: + test_writer.add_array(test_writer.gguf_key("x.strides"), gguf_strides(*self.stride_x)) + + test_writer.add_array( + test_writer.gguf_key("y.strides"), + gguf_strides(*self.stride_y if self.stride_y is not None else contiguous_gguf_strides(self.shape_y)) + ) + + test_writer.add_tensor( + test_writer.gguf_key("x"), self.x, raw_dtype=np_dtype_to_ggml(self.x.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype) + ) + + input_x = self.x.astype(np.float64) + if (self.stride_x is not None) and (0 in self.stride_x): + typesize = np.dtype(input_x.dtype).itemsize + new_strides_bytes = tuple(x * typesize for x in self.stride_x) + input_x = as_strided(x=input_x, shape=self.shape_x, strides=new_strides_bytes) + + ans = sigmoid(input_x) + + test_writer.add_tensor( + test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64 + ) + + +if __name__ == '__main__': + test_writer = InfiniopTestWriter("sigmoid.gguf") + + test_cases = [] + _TEST_CASES_ = [ + # shape, x_stride, y_stride + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4), (0, 1), None), + ((13, 4, 4), None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), None), + ((16, 5632), None, None), + ((16, 5632), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)), + ] + _TENSOR_DTYPES_ = [np.float16, np.float32] + + for dtype in _TENSOR_DTYPES_: + for shape, stride_x, stride_y in _TEST_CASES_: + x = np.random.rand(*shape).astype(dtype) + y = np.empty(tuple(0 for _ in shape), dtype=dtype) + + x = process_zero_stride_tensor(x, stride_x) + test_case = SigmoidTestCase(x=x, + shape_x=shape, + stride_x=stride_x, + y=y, + shape_y=shape, + stride_y=stride_y) + + test_cases.append(test_case) + + test_writer.add_tests(test_cases) + test_writer.save() diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index e92e77105..62febf5c2 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -489,3 +489,35 @@ def conv_(lib): lib.infiniopDestroyConvDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def sigmoid_(lib): + lib.infiniopCreateSigmoidDescriptor.restype = c_int32 + lib.infiniopCreateSigmoidDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetSigmoidWorkspaceSize.restype = c_int32 + lib.infiniopGetSigmoidWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopSigmoid.restype = c_int32 + lib.infiniopSigmoid.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroySigmoidDescriptor.restype = c_int32 + lib.infiniopDestroySigmoidDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/sigmoid.py b/test/infiniop/sigmoid.py new file mode 100644 index 000000000..b8f896fa1 --- /dev/null +++ b/test/infiniop/sigmoid.py @@ -0,0 +1,171 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from enum import Enum, auto + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES_ = [ + # shape, x_stride, y_stride + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4), (0, 1), (0, 1)), + ((13, 4, 4), None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (4, 0, 1)), + ((16, 5632), None, None), + ((16, 5632), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)), + ((4, 4, 56320), None, None), +] + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_X = auto() + +# Inplace options applied for each test case in _TEST_CASES_ +_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE_X, +] + +# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_ +_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _TEST_CASES_ + for inplace_item in _INPLACE +] + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +def torch_sigmoid(y, x): + torch.sigmoid(x, out=y) + +def test( + handle, + device, + shape, + x_stride=None, + y_stride=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=torch.float16, + sync=None, +): + x = TestTensor(shape, x_stride, dtype, device) + if inplace == Inplace.INPLACE_X: + if x_stride != y_stride: + return + y = x + else: + y = TestTensor(shape, y_stride, dtype, device, mode="ones") + + if y.is_broadcast(): + return + + print( + f"Testing Sigmoid on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} " + f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" + ) + + torch_sigmoid(y.torch_tensor(), x.torch_tensor()) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateSigmoidDescriptor( + handle, + ctypes.byref(descriptor), + y.descriptor, + x.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [x, y]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetSigmoidWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, y.device) + + def lib_sigmoid(): + check_error( + LIBINFINIOP.infiniopSigmoid( + descriptor, + workspace.data(), + workspace.size(), + y.data(), + x.data(), + None, + ) + ) + + lib_sigmoid() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch_sigmoid(y.torch_tensor(), x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_sigmoid(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(LIBINFINIOP.infiniopDestroySigmoidDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92m Test passed! \033[0m")