Skip to content

Commit 2a034d5

Browse files
authored
Cherry-pick: Reduce Python and Nuget GPU package size (microsoft#26002) (microsoft#26087)
Reduce Python and Nuget GPU package size (microsoft#26002)
[CUDA] Add build flag onnxruntime_USE_FPA_INTB_GEMM (microsoft#25802)
1 parent d1cfdf0 commit 2a034d5

30 files changed

+101
-34
lines changed

cmake/CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,8 @@ option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)
9898

9999
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
100100
option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
101-
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
101+
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
102+
cmake_dependent_option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" ON "onnxruntime_USE_CUDA" OFF)
102103

103104
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
104105
option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
@@ -694,6 +695,7 @@ if (onnxruntime_USE_CUDA)
694695
set(onnxruntime_USE_FLASH_ATTENTION OFF)
695696
set(onnxruntime_USE_LEAN_ATTENTION OFF)
696697
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
698+
set(onnxruntime_USE_FPA_INTB_GEMM OFF)
697699
endif()
698700

699701
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
@@ -706,6 +708,11 @@ if (onnxruntime_USE_CUDA)
706708
set(onnxruntime_USE_FLASH_ATTENTION OFF)
707709
endif()
708710

711+
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
712+
message( STATUS "FpA IntB Gemm unsupported for CUDA compiler version < 12.0")
713+
set(onnxruntime_USE_FPA_INTB_GEMM OFF)
714+
endif()
715+
709716
if (WIN32)
710717
message( STATUS "Lean Attention unsupported in Windows")
711718
set(onnxruntime_USE_LEAN_ATTENTION OFF)
@@ -734,6 +741,11 @@ if (onnxruntime_USE_CUDA)
734741
message( STATUS "Enable memory efficient attention for CUDA EP")
735742
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
736743
endif()
744+
745+
if (onnxruntime_USE_FPA_INTB_GEMM)
746+
message( STATUS "Enable FpA IntB Gemm for CUDA EP")
747+
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FPA_INTB_GEMM=1)
748+
endif()
737749
endif()
738750

739751
if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA))

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightO
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t,
2424
} // namespace cutlass_kernels
2525
} // namespace kernels
2626
} // namespace onnxruntime::llm
27+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<half, cutlass::uint4b_t, cutlass::Weight
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantO
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scaleonly.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
#if USE_FPA_INTB_GEMM
1717
#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h"
1818

1919
namespace onnxruntime::llm {
@@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<half, uint8_t, cutlass::WeightOnlyQuantO
2323
} // namespace cutlass_kernels
2424
} // namespace kernels
2525
} // namespace onnxruntime::llm
26+
#endif

onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11

2+
#if USE_FPA_INTB_GEMM
23
#ifndef EXCLUDE_SM_90
34
#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
45

@@ -515,3 +516,4 @@ __nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::
515516
} // namespace kernels
516517
} // namespace onnxruntime::llm
517518
#endif // EXCLUDE_SM_90
519+
#endif // USE_FPA_INTB_GEMM

0 commit comments

Comments
 (0)