Skip to content

Commit e4605f7

Browse files
Merge pull request #293 from YdrMaster/distinct-cuda
issue291 合并 cuda 代码
2 parents 5025ebe + eac2b0c commit e4605f7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+462
-576
lines changed

src/infiniop/devices/cuda/cuda_kernel_common.cuh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#define INFINIOP_CUDA_KERNEL __global__ void
55
#endif
66

7+
#include <cuda_bf16.h>
8+
#include <cuda_fp16.h>
9+
710
// Possible maximum number of threads per block for CUDA architectures
811
// Used for picking correct kernel launch configuration
912
#define CUDA_BLOCK_SIZE_4096 4096
@@ -12,8 +15,10 @@
1215

1316
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
1417

15-
namespace device::cuda {
18+
using cuda_bfloat16 = nv_bfloat16;
19+
using cuda_bfloat162 = nv_bfloat162;
1620

21+
namespace device::cuda {
1722
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
1823
__forceinline__ __device__ __host__ size_t
1924
indexToReducedOffset(
@@ -45,8 +50,6 @@ indexToOffset(
4550
}
4651
} // namespace device::cuda
4752

48-
#ifdef ENABLE_NVIDIA_API
49-
#include <cuda_fp16.h>
5053
__forceinline__ __device__ float
5154
exp_(const float val) {
5255
return expf(val);
@@ -73,4 +76,3 @@ __forceinline__ __device__ __nv_bfloat16
7376
exp_(const __nv_bfloat16 x) {
7477
return hexp(x);
7578
}
76-
#endif

src/infiniop/devices/maca/maca_kernel_common.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
#define INFINIOP_MACA_KERNEL __global__ void
2+
23
// Possible maximum number of threads per block for MACA architectures
34
// Used for picking correct kernel launch configuration
45
#define MACA_BLOCK_SIZE_1024 1024
56
#define MACA_BLOCK_SIZE_512 512
67

78
#define CHECK_MACA(API) CHECK_INTERNAL(API, hcSuccess)
89

10+
using cuda_bfloat16 = hpcc_bfloat16;
11+
using cuda_bfloat162 = hpcc_bfloat162;
12+
913
namespace device::maca {
1014

1115
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
@@ -39,16 +43,14 @@ indexToOffset(
3943
}
4044
} // namespace device::maca
4145

42-
#ifdef ENABLE_MACA_API
43-
#include <maca_fp16.h>
4446
__forceinline__ __device__ float
4547
exp_(const float val) {
4648
return expf(val);
4749
}
4850

4951
__forceinline__ __device__ long double
5052
exp_(const long double val) {
51-
return expl(val);
53+
return exp(val);
5254
}
5355

5456
__forceinline__ __device__ double
@@ -61,8 +63,7 @@ exp_(const __half x) {
6163
return hexp(x);
6264
}
6365

64-
__forceinline__ __device__ __hpcc_bfloat16;
65-
exp_(const __hpcc_bfloat16; x) {
66+
__forceinline__ __device__ __hpcc_bfloat16
67+
exp_(const __hpcc_bfloat16 x) {
6668
return hexp(x);
6769
}
68-
#endif

src/infiniop/elementwise/elementwise.h

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,45 @@
1212
#include <numeric>
1313
#include <vector>
1414

15-
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
16-
\
17-
namespace op::OP::NAMESPACE { \
18-
class Descriptor final : public InfiniopDescriptor { \
19-
infiniDtype_t _dtype; \
20-
op::elementwise::ElementwiseInfo _info; \
21-
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
22-
size_t _workspace_size; \
23-
\
24-
Descriptor( \
25-
infiniDtype_t dtype, \
26-
op::elementwise::ElementwiseInfo info, \
27-
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
28-
size_t workspace_size, \
29-
infiniDevice_t device_type, \
30-
int device_id) \
31-
: InfiniopDescriptor{device_type, device_id}, \
32-
_dtype(dtype), \
33-
_info(std::move(info)), \
34-
_device_info(std::move(device_info)), \
35-
_workspace_size(workspace_size) {} \
36-
\
37-
public: \
38-
~Descriptor(); \
39-
\
40-
size_t workspaceSize() const { return _workspace_size; } \
41-
\
42-
static infiniStatus_t create( \
43-
infiniopHandle_t handle, \
44-
Descriptor **desc_ptr, \
45-
infiniopTensorDescriptor_t output_desc, \
46-
std::vector<infiniopTensorDescriptor_t> input_descs); \
47-
\
48-
infiniStatus_t calculate( \
49-
void *workspace, size_t workspace_size, \
50-
void *output, \
51-
std::vector<const void *> inputs, \
52-
void *stream) const; \
53-
}; \
15+
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE, KERNEL_COMMON) \
16+
\
17+
namespace op::OP::NAMESPACE { \
18+
class Descriptor final : public InfiniopDescriptor { \
19+
infiniDtype_t _dtype; \
20+
op::elementwise::ElementwiseInfo _info; \
21+
std::unique_ptr<op::elementwise::KERNEL_COMMON::DeviceImpl> _device_info; \
22+
size_t _workspace_size; \
23+
\
24+
Descriptor( \
25+
infiniDtype_t dtype, \
26+
op::elementwise::ElementwiseInfo info, \
27+
op::elementwise::KERNEL_COMMON::DeviceImpl *device_info, \
28+
size_t workspace_size, \
29+
infiniDevice_t device_type, \
30+
int device_id) \
31+
: InfiniopDescriptor{device_type, device_id}, \
32+
_dtype(dtype), \
33+
_info(std::move(info)), \
34+
_device_info(std::move(device_info)), \
35+
_workspace_size(workspace_size) {} \
36+
\
37+
public: \
38+
~Descriptor(); \
39+
\
40+
size_t workspaceSize() const { return _workspace_size; } \
41+
\
42+
static infiniStatus_t create( \
43+
infiniopHandle_t handle, \
44+
Descriptor **desc_ptr, \
45+
infiniopTensorDescriptor_t output_desc, \
46+
std::vector<infiniopTensorDescriptor_t> input_descs); \
47+
\
48+
infiniStatus_t calculate( \
49+
void *workspace, size_t workspace_size, \
50+
void *output, \
51+
std::vector<const void *> inputs, \
52+
void *stream) const; \
53+
}; \
5454
}
5555

5656
namespace op::elementwise {

src/infiniop/ops/add/cpu/add_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include "../../../elementwise/cpu/elementwise_cpu.h"
55

6-
ELEMENTWISE_DESCRIPTOR(add, cpu)
6+
ELEMENTWISE_DESCRIPTOR(add, cpu, cpu)
77

88
namespace op::add::cpu {
99
typedef struct AddOp {

src/infiniop/ops/add/cuda/add_cuda.cu renamed to src/infiniop/ops/add/nvidia/add_nvidia.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#include "add_cuda.cuh"
2-
#include "add_cuda_internal.cuh"
1+
#include "../cuda/kernel.cuh"
2+
#include "add_nvidia.cuh"
33

4-
namespace op::add::cuda {
4+
namespace op::add::nvidia {
55

66
Descriptor::~Descriptor() = default;
77

@@ -43,17 +43,17 @@ infiniStatus_t Descriptor::calculate(
4343

4444
switch (_dtype) {
4545
case INFINI_DTYPE_F16:
46-
return _device_info->calculate<256, AddOp, half>(_info, workspace, output, inputs, stream);
46+
return _device_info->calculate<256, cuda::AddOp, half>(_info, workspace, output, inputs, stream);
4747
case INFINI_DTYPE_BF16:
48-
return _device_info->calculate<256, AddOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
48+
return _device_info->calculate<256, cuda::AddOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
4949
case INFINI_DTYPE_F32:
50-
return _device_info->calculate<256, AddOp, float>(_info, workspace, output, inputs, stream);
50+
return _device_info->calculate<256, cuda::AddOp, float>(_info, workspace, output, inputs, stream);
5151
case INFINI_DTYPE_F64:
52-
return _device_info->calculate<256, AddOp, double>(_info, workspace, output, inputs, stream);
52+
return _device_info->calculate<256, cuda::AddOp, double>(_info, workspace, output, inputs, stream);
5353
default:
5454
return INFINI_STATUS_BAD_TENSOR_DTYPE;
5555
}
5656

5757
return INFINI_STATUS_SUCCESS;
5858
}
59-
} // namespace op::add::cuda
59+
} // namespace op::add::nvidia

src/infiniop/ops/add/cuda/add_cuda.cuh renamed to src/infiniop/ops/add/nvidia/add_nvidia.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33

44
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
55

6-
ELEMENTWISE_DESCRIPTOR(add, cuda)
6+
ELEMENTWISE_DESCRIPTOR(add, nvidia, cuda)
77

88
#endif // __ADD_CUDA_API_H__

src/infiniop/ops/add/operator.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "cpu/add_cpu.h"
77
#endif
88
#ifdef ENABLE_NVIDIA_API
9-
#include "cuda/add_cuda.cuh"
9+
#include "nvidia/add_nvidia.cuh"
1010
#endif
1111

1212
__C infiniStatus_t infiniopCreateAddDescriptor(
@@ -31,7 +31,7 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
3131
CREATE(INFINI_DEVICE_CPU, cpu);
3232
#endif
3333
#ifdef ENABLE_NVIDIA_API
34-
CREATE(INFINI_DEVICE_NVIDIA, cuda);
34+
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
3535
#endif
3636

3737
default:
@@ -46,14 +46,14 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
4646
#define GET(CASE, NAMESPACE) \
4747
case CASE: \
4848
*size = reinterpret_cast<op::add::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
49-
return INFINI_STATUS_SUCCESS;
49+
return INFINI_STATUS_SUCCESS
5050

5151
switch (desc->device_type) {
5252
#ifdef ENABLE_CPU_API
53-
GET(INFINI_DEVICE_CPU, cpu)
53+
GET(INFINI_DEVICE_CPU, cpu);
5454
#endif
5555
#ifdef ENABLE_NVIDIA_API
56-
GET(INFINI_DEVICE_NVIDIA, cuda)
56+
GET(INFINI_DEVICE_NVIDIA, nvidia);
5757
#endif
5858
default:
5959
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -83,7 +83,7 @@ __C infiniStatus_t infiniopAdd(
8383
CALCULATE(INFINI_DEVICE_CPU, cpu);
8484
#endif
8585
#ifdef ENABLE_NVIDIA_API
86-
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
86+
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
8787
#endif
8888

8989
default:
@@ -99,15 +99,15 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
9999
#define DELETE(CASE, NAMESPACE) \
100100
case CASE: \
101101
delete reinterpret_cast<const op::add::NAMESPACE::Descriptor *>(desc); \
102-
return INFINI_STATUS_SUCCESS;
102+
return INFINI_STATUS_SUCCESS
103103

104104
switch (desc->device_type) {
105105

106106
#ifdef ENABLE_CPU_API
107107
DELETE(INFINI_DEVICE_CPU, cpu);
108108
#endif
109109
#ifdef ENABLE_NVIDIA_API
110-
DELETE(INFINI_DEVICE_NVIDIA, cuda);
110+
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
111111
#endif
112112

113113
default:

src/infiniop/ops/causal_softmax/cuda/causal_softmax_cuda.cuh

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/infiniop/ops/causal_softmax/cuda/causal_softmax_kernel.cuh renamed to src/infiniop/ops/causal_softmax/cuda/kernel.cuh

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
#ifndef __CAUSAL_SOFTMAX_KERNEL_CUH__
1+
#ifndef __CAUSAL_SOFTMAX_KERNEL_CUH__
22
#define __CAUSAL_SOFTMAX_KERNEL_CUH__
33

4-
#include "../../../devices/cuda/cuda_kernel_common.cuh"
5-
#include "../../../reduce/cuda/reduce.cuh"
6-
74
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
8-
INFINIOP_CUDA_KERNEL causalSoftmax(
5+
__device__ void causalSoftmaxKernel(
96
Tdata *y_, const Tdata *x_,
107
size_t batch, size_t height, size_t width,
118
ptrdiff_t y_stride_b, ptrdiff_t y_stride_h,
@@ -32,11 +29,11 @@ INFINIOP_CUDA_KERNEL causalSoftmax(
3229
// 2 | * * * ... * * * |
3330
// height: 3 col_id->
3431
if (width + blockIdx.x >= threadIdx.x + height) {
35-
#ifdef ENABLE_NVIDIA_API
36-
y[col] = exp_(x[col] - max_);
37-
#else
38-
y[col] = exp(x[col] - max_);
39-
#endif
32+
if constexpr (std::is_same_v<Tdata, half> || std::is_same_v<Tdata, cuda_bfloat16>) {
33+
y[col] = hexp(x[col] - max_);
34+
} else {
35+
y[col] = exp(x[col] - max_);
36+
}
4037
} else {
4138
y[col] = Tdata(0);
4239
}

0 commit comments

Comments
 (0)