Update BiasGelu fusion and related ops #23518

Merged Jan 31, 2025 (3 commits)
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dt><tt>T</tt> : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float or half tensors.</dd>
</dl>
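
For reference, FastGelu computes the tanh approximation of GELU for every element type in the constraint above, which the bert_defs.cc schema further down in this diff sketches as `fastgelu(x) =`; written out as a formula:

```latex
\operatorname{FastGelu}(x) = 0.5\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
```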

4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -912,11 +912,11 @@ Do not modify directly.*
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|DynamicTimeWarping|*in* input:**F**<br> *out* output:**I**|1+|**F** = tensor(float)<br/> **I** = tensor(int32)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -30,6 +30,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)

using namespace ONNX_NAMESPACE;

9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -25,10 +25,14 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);

class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
@@ -154,7 +158,6 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
@@ -234,10 +237,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu)>,
@@ -362,7 +368,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Trilu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul)>,
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/elementwise.h
@@ -66,10 +66,12 @@ class ElementwiseTunableOp : public TunableOp<ElementwiseParams<T>> {
}

ELEMENTWISE_FWD_DECL(FastGeLU, float);
ELEMENTWISE_FWD_DECL(FastGeLU, double);
ELEMENTWISE_FWD_DECL(FastGeLU, half);
ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16);

ELEMENTWISE_FWD_DECL(GeLU, float);
ELEMENTWISE_FWD_DECL(GeLU, double);
ELEMENTWISE_FWD_DECL(GeLU, half);
ELEMENTWISE_FWD_DECL(GeLU, BFloat16);

1 change: 1 addition & 0 deletions — ROCm elementwise FastGeLU kernel instantiations (file path not shown in this capture)
@@ -4,5 +4,6 @@
#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"

ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16);
1 change: 1 addition & 0 deletions — ROCm elementwise GeLU kernel instantiations (file path not shown in this capture)
@@ -3,6 +3,7 @@

#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"

ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16);
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -11,10 +11,13 @@ namespace contrib {
namespace rocm {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
@@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul);
@@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
@@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1490,7 +1490,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(0, "X", "input tensor", "T")
.Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
.Output(0, "Y", "output tensor", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
// fastgelu(x) =
13 changes: 11 additions & 2 deletions onnxruntime/core/optimizer/bias_gelu_fusion.cc
@@ -61,7 +61,10 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}

const Node& next_node = (*next_node_itr);
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||

bool is_onnx_gelu = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {20}, kOnnxDomain);
if (!(is_onnx_gelu ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
@@ -72,14 +75,20 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

bool is_approximate = is_fast_gelu;
if (is_onnx_gelu) {
const ONNX_NAMESPACE::AttributeProto* attribute = graph_utils::GetNodeAttribute(next_node, "approximate");
is_approximate = (attribute != nullptr) && utils::HasString(*attribute) && (attribute->s() == "tanh");
}

if (graph.NodeProducesGraphOutput(node)) {
continue;
}

Node& add_node = node;
Node& gelu_node = const_cast<Node&>(next_node);
std::string op_type = "BiasGelu";
if (is_fast_gelu) op_type = "FastGelu";
if (is_approximate) op_type = "FastGelu";

Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
op_type,
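
In short, the fusion now also recognizes the standard ONNX Gelu (opset 20) following an Add: when the Gelu is the tanh approximation (contrib FastGelu, or ONNX Gelu with approximate="tanh"), the fused node is com.microsoft.FastGelu; otherwise it is com.microsoft.BiasGelu. A minimal Python sketch of that decision, mirroring the C++ above (hypothetical helper, not part of the PR; assumes an onnx NodeProto-like input):

```python
def fused_op_type(gelu_node):
    """Sketch of BiasGeluFusion's op-type choice, not the actual implementation."""
    # com.microsoft.FastGelu already implies the tanh approximation.
    if gelu_node.domain == "com.microsoft" and gelu_node.op_type == "FastGelu":
        return "FastGelu"
    # ONNX Gelu (opset 20) carries an 'approximate' attribute: "none" (default) or "tanh".
    if gelu_node.domain == "" and gelu_node.op_type == "Gelu":
        approximate = next(
            (a.s.decode() for a in gelu_node.attribute if a.name == "approximate"), "none"
        )
        return "FastGelu" if approximate == "tanh" else "BiasGelu"
    # com.microsoft.Gelu (exact GELU) keeps the BiasGelu fusion.
    return "BiasGelu"
```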
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/gelu.cc
@@ -23,6 +23,7 @@ namespace cuda {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)

template <typename T>
@@ -80,6 +81,7 @@ namespace contrib::cuda {

REGISTER_CONTRIB_KERNEL_TYPED(float)
REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(BFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(double)

#undef REGISTER_CONTRIB_KERNEL_TYPED
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
@@ -40,6 +40,7 @@ Status LaunchGeluKernel(

SPECIALIZED_GELU_IMPL(float);
SPECIALIZED_GELU_IMPL(half);
SPECIALIZED_GELU_IMPL(BFloat16);
SPECIALIZED_GELU_IMPL(double);

#undef SPECIALIZED_GELU_IMPL
40 changes: 39 additions & 1 deletion onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -389,7 +389,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
int min_cuda_architecture = 800;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
return;
@@ -440,5 +440,43 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
}
#endif

// CUDA and ROCm only for double type.
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBias_Double) {
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);

int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;

std::vector<double> X = {
0.8, -0.5, 0.0, 1.0,
0.5, 0.2, 0.3, -0.6};

std::vector<double> B = {
-0.5, 0.6, 1.2, 2.1};

std::vector<double> Y = {
0.185371, 0.053983, 1.061703, 3.097373,
0.000000, 0.630432, 1.399572, 1.399572};

std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> bias_dims = {hidden_size};
std::vector<int64_t> output_dims = input_dims;

tester.AddInput<double>("X", input_dims, X);
tester.AddInput<double>("bias", bias_dims, B);
tester.AddOutput<double>("Y", output_dims, Y);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif

} // namespace test
} // namespace onnxruntime
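
FastGelu applies the tanh approximation to X + bias, so the expected values in the double-precision test above can be reproduced independently; a small NumPy reference sketch (assumption: the bias broadcasts over the last dimension, as in the kernel):

```python
import numpy as np

def fastgelu_ref(x):
    # Tanh approximation of GELU, as used by com.microsoft.FastGelu.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

X = np.array([[0.8, -0.5, 0.0, 1.0],
              [0.5, 0.2, 0.3, -0.6]], dtype=np.float64)
B = np.array([-0.5, 0.6, 1.2, 2.1], dtype=np.float64)

print(fastgelu_ref(X + B))
# Agrees with the test's expected Y (0.185371, 0.053983, 1.061703, 3.097373,
# 0.000000, 0.630432, 1.399572, 1.399572) to about six decimal places.
```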
40 changes: 40 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4781,6 +4781,46 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
}

TEST_F(GraphTransformationTests, BiasOnnxGeluTest) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_gelu_fusion.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
}

TEST_F(GraphTransformationTests, BiasOnnxFastGeluTest) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_fast_gelu_fusion.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 0);
}

// BiasGelu allows input switching based on input dimensions.
// This test validates the input edges are plugged correct in the optimized graph.
TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
Binary file modified onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx
Binary file not shown.
37 changes: 37 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py
@@ -29,3 +29,40 @@

model = helper.make_model(graph)
onnx.save(model, r"bias_gelu_fusion.onnx")

graph = helper.make_graph(
[
helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"),
helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu"),
],
"Gelu_Add_Fusion", # name
[ # inputs
helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]),
],
[ # outputs
helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
],
)

model = helper.make_model(graph)
onnx.save(model, r"bias_onnx_gelu_fusion.onnx")


graph = helper.make_graph(
[
helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"),
helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu", approximate="tanh"),
],
"Gelu_Add_Fusion", # name
[ # inputs
helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]),
],
[ # outputs
helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
],
)

model = helper.make_model(graph)
onnx.save(model, r"bias_onnx_fast_gelu_fusion.onnx")
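
A hedged end-to-end sketch (not part of the PR) of how one might confirm the new fusion on the model generated above, assuming an onnxruntime build that includes this change and that helper.make_model picked an opset of 20 or newer for the default domain:

```python
import onnx
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Level-2 ("extended") optimizations include GeluFusion and BiasGeluFusion.
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
sess_options.optimized_model_filepath = "bias_onnx_fast_gelu_fusion_optimized.onnx"
ort.InferenceSession("bias_onnx_fast_gelu_fusion.onnx", sess_options)

optimized = onnx.load("bias_onnx_fast_gelu_fusion_optimized.onnx")
print([(n.domain, n.op_type) for n in optimized.graph.node])
# Expected with this PR: a single ('com.microsoft', 'FastGelu') node in place of
# the Add + Gelu(approximate="tanh") pair.
```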
New binary file (21 lines of raw protobuf, not human-readable): a serialized ONNX model whose graph "Gelu_Add_Fusion" contains an Add node feeding a Gelu node with approximate="tanh" — evidently the bias_onnx_fast_gelu_fusion.onnx test model produced by bias_gelu_gen.py above.
New binary file (20 lines of raw protobuf, not human-readable): a serialized ONNX model whose graph "Gelu_Add_Fusion" contains an Add node feeding a Gelu node without an approximate attribute — evidently the bias_onnx_gelu_fusion.onnx test model produced by bias_gelu_gen.py above.

Unchanged files with check annotations

onnxruntime/core/providers/cpu/text/string_normalizer.cc, line 224 (std::codecvt_utf8<wchar_t> converter_;) — GitHub Actions / Vcpkg: 'codecvt_utf8<wchar_t>' is deprecated [-Wdeprecated-declarations]