Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rope.h"
#include "infiniop/ops/sigmoid.h"
#include "infiniop/ops/sub.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/tensor_descriptor.h"
Expand Down
24 changes: 24 additions & 0 deletions include/infiniop/ops/sigmoid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef __INFINIOP_SIGMOID_API_H__
#define __INFINIOP_SIGMOID_API_H__

#include "../operator_descriptor.h"

/* Handle type for a sigmoid operator descriptor (pointer to InfiniopDescriptor). */
typedef struct InfiniopDescriptor *infiniopSigmoidDescriptor_t;

/*
 * Creates a sigmoid operator descriptor.
 *
 * handle   - library handle the descriptor is created for.
 * desc_ptr - receives the descriptor handle (callers pass the address of an
 *            infiniopSigmoidDescriptor_t and use it afterwards).
 * y        - tensor descriptor of the output.
 * x        - tensor descriptor of the input.
 *
 * Returns an infiniStatus_t status code.
 * NOTE(review): the exact shape/dtype constraints are enforced by the
 * individual backends, not by this header - confirm per device.
 */
__C __export infiniStatus_t infiniopCreateSigmoidDescriptor(infiniopHandle_t handle,
infiniopSigmoidDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x);

/*
 * Queries the workspace size for `desc` and stores it in `*size`.
 * NOTE(review): presumably in bytes - confirm against the implementation.
 */
__C __export infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t desc, size_t *size);

/*
 * Runs the sigmoid operator described by `desc`.
 *
 * workspace      - scratch buffer supplied by the caller.
 * workspace_size - size of `workspace`; see infiniopGetSigmoidWorkspaceSize.
 * y              - output data pointer.
 * x              - input data pointer (read-only).
 * stream         - backend stream handle; may be passed as a null pointer
 *                  (NOTE(review): presumably selects a default stream - confirm).
 */
__C __export infiniStatus_t infiniopSigmoid(infiniopSigmoidDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);

/* Destroys a descriptor obtained from infiniopCreateSigmoidDescriptor. */
__C __export infiniStatus_t infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc);

#endif
1 change: 1 addition & 0 deletions scripts/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def run_tests(args):
"rearrange.py",
"rms_norm.py",
"rope.py",
"sigmoid.py",
"sub.py",
"swiglu.py",
]:
Expand Down
2 changes: 2 additions & 0 deletions src/infiniop-test/include/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add)
DECLARE_INFINIOP_TEST(causal_softmax)
DECLARE_INFINIOP_TEST(rearrange)
DECLARE_INFINIOP_TEST(sub)
DECLARE_INFINIOP_TEST(sigmoid)

#define REGISTER_INFINIOP_TEST(name) \
{ \
Expand Down Expand Up @@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub)
REGISTER_INFINIOP_TEST(causal_softmax) \
REGISTER_INFINIOP_TEST(rearrange) \
REGISTER_INFINIOP_TEST(sub) \
REGISTER_INFINIOP_TEST(sigmoid) \
}

namespace infiniop_test {
Expand Down
103 changes: 103 additions & 0 deletions src/infiniop-test/src/ops/sigmoid.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::sigmoid {
// Tensors that make up one sigmoid test case.
struct Test::Attributes {
// Input tensor handed to the operator.
std::shared_ptr<Tensor> x;
// Output tensor the operator writes into.
std::shared_ptr<Tensor> y;
// Expected result that the computed output is compared against.
std::shared_ptr<Tensor> ans;
};

std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();
if (tensors.find("x") == tensors.end()
|| tensors.find("y") == tensors.end()
|| tensors.find("ans") == tensors.end()) {
throw std::runtime_error("Invalid Test");
}

test->_attributes->x = tensors["x"];
test->_attributes->y = tensors["y"];
test->_attributes->ans = tensors["ans"];

return test;
}

// Runs the sigmoid operator once, checks the result, then benchmarks it.
//
// handle     - operator library handle used to create the descriptor.
// device     - device the input/output tensors are moved to.
// device_id  - index of that device.
// warm_ups   - warm-up count forwarded to benchmark().
// iterations - iteration count forwarded to benchmark().
//
// Returns TEST_FAILED(...) when descriptor creation, the workspace query,
// the workspace allocation, the execution or the comparison against `ans`
// fails; otherwise TEST_PASSED with the time reported by benchmark().
std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
    auto x = _attributes->x->to(device, device_id);
    auto y = _attributes->y->to(device, device_id);

    infiniopSigmoidDescriptor_t op_desc = nullptr;
    void *workspace = nullptr;

    // Scope guard: destroys the descriptor and frees the workspace on every
    // exit path, including the early failure returns below. It only holds
    // references, so the benchmark lambda's by-value capture never copies it.
    struct Cleanup {
        infiniopSigmoidDescriptor_t &desc;
        void *&buffer;
        Cleanup(infiniopSigmoidDescriptor_t &d, void *&b) : desc(d), buffer(b) {}
        Cleanup(const Cleanup &) = delete;
        Cleanup &operator=(const Cleanup &) = delete;
        ~Cleanup() {
            if (desc != nullptr) {
                infiniopDestroySigmoidDescriptor(desc);
            }
            if (buffer != nullptr) {
                infinirtFree(buffer);
            }
        }
    };
    Cleanup cleanup(op_desc, workspace);

    CHECK_OR(infiniopCreateSigmoidDescriptor(handle, &op_desc,
                                             y->desc(),
                                             x->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));

    size_t workspace_size = 0;
    CHECK_OR(infiniopGetSigmoidWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));

    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));

    // Single run whose output is validated against the expected tensor.
    CHECK_OR(infiniopSigmoid(op_desc, workspace, workspace_size,
                             y->data(),
                             x->data(),
                             nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

    try {
        allClose(y, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    double elapsed_time = 0.;

    // The status of the timed calls is intentionally not checked; the same
    // call already succeeded above.
    elapsed_time = benchmark(
        [=]() {
            infiniopSigmoid(
                op_desc, workspace, workspace_size,
                y->data(),
                x->data(),
                nullptr);
        },
        warm_ups, iterations);

    return TEST_PASSED(elapsed_time);
}

// Names of the scalar attributes a test case must provide: none for sigmoid.
std::vector<std::string> Test::attribute_names() {
    return std::vector<std::string>();
}

// Names of the tensors a test case must provide: input, output and expected result.
std::vector<std::string> Test::tensor_names() {
    std::vector<std::string> names;
    names.reserve(3);
    names.emplace_back("x");
    names.emplace_back("y");
    names.emplace_back("ans");
    return names;
}

// Names of the tensors the operator writes: only "y".
std::vector<std::string> Test::output_names() {
    return std::vector<std::string>(1, std::string("y"));
}

// Renders a multi-line description of the test: operator name, the info()
// of the input and output tensors, and the tolerances in scientific
// notation with two digits of precision.
std::string Test::toString() const {
    std::ostringstream out;
    out << op_name() << '\n'
        << "- x: " << _attributes->x->info() << '\n'
        << "- y: " << _attributes->y->info() << '\n'
        << std::scientific << std::setprecision(2)
        << "- rtol=" << _rtol << ", atol=" << _atol << '\n';
    return out.str();
}

// Releases the Attributes object that Test::build allocates with `new`.
Test::~Test() {
delete _attributes;
}

} // namespace infiniop_test::sigmoid
51 changes: 51 additions & 0 deletions src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "sigmoid_cpu.h"

namespace op::sigmoid::cpu {

// Defaulted destructor: no custom teardown is written for the CPU descriptor here.
Descriptor::~Descriptor() = default;

// Creates the CPU sigmoid descriptor.
//
// handle_        - generic handle, reinterpreted as device::cpu::Handle.
// desc_ptr       - not referenced by name in this body.
//                  NOTE(review): presumably written by
//                  CREATE_ELEMENTWISE_CPU_DESCRIPTOR - confirm in the macro.
// out_desc       - descriptor of the output tensor; its dtype is the one checked.
// input_desc_vec - input descriptors; only element 0 is read here, and
//                  `.at(0)` throws std::out_of_range if the vector is empty.
//
// Returns INFINI_STATUS_SUCCESS when the end of the body is reached.
// NOTE(review): CHECK_DTYPE / CHECK_SAME_SHAPE presumably return an error
// status early on mismatch - confirm against their definitions.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = out_desc->dtype();

const auto &x_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &x_shape = x_desc->shape();

// Only F16, F32, F64 and BF16 are listed as accepted element types.
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
// Output and input shapes are compared with each other.
CHECK_SAME_SHAPE(y_shape, x_shape);

// Create the CPU elementwise descriptor from the handle, dtype and tensor descriptors.
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

return INFINI_STATUS_SUCCESS;
}

// Applies sigmoid element-wise on the CPU.
//
// workspace / workspace_size - accepted for interface parity; not used here.
// output - destination buffer.
// inputs - input buffers, forwarded unchanged to the elementwise implementation.
// stream - forwarded unchanged to the elementwise implementation.
//
// Dispatches on the stored dtype to the matching element type and returns
// the status of that call, or INFINI_STATUS_BAD_TENSOR_DTYPE for any dtype
// outside F16 / F32 / F64 / BF16.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    // Every case returns, so nothing may follow the switch.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<SigmoidOp, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<SigmoidOp, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<SigmoidOp, double>(_info, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<SigmoidOp, bf16_t>(_info, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::sigmoid::cpu
19 changes: 19 additions & 0 deletions src/infiniop/ops/sigmoid/cpu/sigmoid_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef __SIGMOID_CPU_H__
#define __SIGMOID_CPU_H__

#include "../../../elementwise/cpu/elementwise_cpu.h"

// NOTE(review): presumably declares op::sigmoid::cpu::Descriptor via the shared
// elementwise macro - confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR(sigmoid, cpu, cpu)

namespace op::sigmoid::cpu {
// Element-wise logistic sigmoid functor: sigmoid(x) = 1 / (1 + exp(-x)).
struct SigmoidOp {
    // Number of input operands consumed per output element.
    static constexpr size_t num_inputs = 1;

    // Evaluates sigmoid for a single value, with the arithmetic done in T.
    template <typename T>
    T operator()(const T &x) const {
        const T one = T(1);
        return one / (one + std::exp(-x));
    }
};
} // namespace op::sigmoid::cpu

#endif // __SIGMOID_CPU_H__
34 changes: 34 additions & 0 deletions src/infiniop/ops/sigmoid/cuda/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef __SIGMOID_CUDA_H__
#define __SIGMOID_CUDA_H__

#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

namespace op::sigmoid::cuda {
// Device functor computing sigmoid(x) = 1 / (1 + exp(-x)) for one element.
//
// half / half2 use the half-precision intrinsics; bfloat16 and float are
// evaluated in single precision with the fast __expf intrinsic; any other
// type (double in practice) uses plain double arithmetic.
typedef struct SigmoidOp {
public:
    // Number of input operands consumed per output element.
    static constexpr size_t num_inputs = 1;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &x) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Build the constant with an explicit conversion instead of
            // relying on an implicit int -> half conversion.
            const half2 denominator = __hadd2(__float2half2_rn(1.0f), h2exp(__hneg2(x)));
            return h2rcp(denominator);
        } else if constexpr (std::is_same_v<T, half>) {
            const half denominator = __hadd(__float2half(1.0f), hexp(__hneg(x)));
            return hrcp(denominator);
        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
            // Compute entirely in float and round once at the end. This avoids
            // the bfloat16 operator overloads (negation, division) and the
            // extra rounding of the denominator to bfloat16.
            const float denominator = __fadd_rn(1.0f, __expf(-__bfloat162float(x)));
            return __float2bfloat16(__frcp_rn(denominator));
        } else if constexpr (std::is_same_v<T, float>) {
            const float denominator = __fadd_rn(1.0f, __expf(-x));
            return __frcp_rn(denominator);
        } else { // double
            return 1.0 / (1.0 + exp(-x));
        }
    }
} SigmoidOp;
} // namespace op::sigmoid::cuda

#endif // __SIGMOID_CUDA_H__
58 changes: 58 additions & 0 deletions src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "../cuda/kernel.cuh"
#include "sigmoid_nvidia.cuh"

namespace op::sigmoid::nvidia {

// Defaulted destructor: no custom teardown is written for the NVIDIA descriptor here.
Descriptor::~Descriptor() = default;

// Creates the NVIDIA sigmoid descriptor.
//
// handle_        - generic handle, reinterpreted as device::cuda::Handle.
// desc_ptr       - not referenced by name in this body.
//                  NOTE(review): presumably written by
//                  CREATE_ELEMENTWISE_CUDA_DESCRIPTOR - confirm in the macro.
// out_desc       - descriptor of the output tensor; its dtype is the one checked.
// input_desc_vec - input descriptors; only element 0 is read here, and
//                  `.at(0)` throws std::out_of_range if the vector is empty.
//
// Returns INFINI_STATUS_SUCCESS when the end of the body is reached.
// NOTE(review): CHECK_DTYPE / CHECK_SAME_SHAPE presumably return an error
// status early on mismatch - confirm against their definitions.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto dtype = out_desc->dtype();

const auto &x_desc = input_desc_vec.at(0);
const auto &y_shape = out_desc->shape();
const auto &x_shape = x_desc->shape();

// Only F16, F32, F64 and BF16 are listed as accepted element types.
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);

// Output and input shapes are compared with each other.
CHECK_SAME_SHAPE(y_shape, x_shape);

// Create the CUDA elementwise descriptor from the handle, dtype and tensor descriptors.
// NOTE(review): no trailing semicolon here, unlike the CPU variant - the macro
// presumably supplies its own statement terminator.
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

return INFINI_STATUS_SUCCESS;
}

// Applies sigmoid element-wise on an NVIDIA device.
//
// workspace      - caller-provided scratch buffer, forwarded to the elementwise implementation.
// workspace_size - size of `workspace`; must be at least the descriptor's stored workspace size.
// output         - destination buffer.
// inputs         - input buffers, forwarded unchanged.
// stream         - forwarded unchanged to the elementwise implementation.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when the workspace is too
// small, INFINI_STATUS_BAD_TENSOR_DTYPE for a dtype outside
// F16 / BF16 / F32 / F64, otherwise the status of the dispatched call.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Every case returns, so nothing may follow the switch. 256 is the first
    // template argument of the elementwise call (NOTE(review): presumably the
    // block size - confirm).
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::SigmoidOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::SigmoidOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::SigmoidOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::SigmoidOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::sigmoid::nvidia
8 changes: 8 additions & 0 deletions src/infiniop/ops/sigmoid/nvidia/sigmoid_nvidia.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __SIGMOID_CUDA_API_H__
#define __SIGMOID_CUDA_API_H__

#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"

// NOTE(review): presumably declares op::sigmoid::nvidia::Descriptor via the shared
// elementwise macro - confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR(sigmoid, nvidia, cuda)

#endif // __SIGMOID_CUDA_API_H__
Loading