From 340a43456ec627cf610edaf57e7a63ce2e5176b9 Mon Sep 17 00:00:00 2001
From: Clement Chan <iclementine@outlook.com>
Date: Mon, 30 Dec 2024 11:26:24 +0800
Subject: [PATCH 1/3] add a cpp extension

---
 setup.py                          | 63 +++++++++++++++++++++++++++++++
 src/flag_gems/utils/type_utils.py |  3 ++
 2 files changed, 66 insertions(+)

diff --git a/setup.py b/setup.py
index e87ebb877..1c140ea11 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,16 @@
+import glob
+import os
 from typing import List, Optional, Sequence
 
+import torch
 from pip._internal.metadata import get_default_environment
 from setuptools import find_packages, setup
+from torch.utils.cpp_extension import (
+    CUDA_HOME,
+    BuildExtension,
+    CppExtension,
+    CUDAExtension,
+)
 
 # ----------------------------- check triton -----------------------------
 # NOTE: this is used to check whether pytorch-triton or triton is installed. The
@@ -62,6 +71,58 @@ def _is_package_installed(package_name: str) -> bool:
     or "triton"
 )
 
+
+# --------------------------- flag_gems c extension ----------------------
+library_name = "flag_gems"
+
+
+def get_extensions():
+    debug_mode = os.getenv("DEBUG", "0") == "1"
+    use_cuda = os.getenv("USE_CUDA", "1") == "1"
+    if debug_mode:
+        print("Compiling in debug mode")
+
+    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
+    extension = CUDAExtension if use_cuda else CppExtension
+
+    extra_link_args = []
+    extra_compile_args = {
+        "cxx": [
+            "-O3" if not debug_mode else "-O0",
+            "-fdiagnostics-color=always",
+        ],
+        "nvcc": [
+            "-O3" if not debug_mode else "-O0",
+        ],
+    }
+    if debug_mode:
+        extra_compile_args["cxx"].append("-g")
+        extra_compile_args["nvcc"].append("-g")
+        extra_link_args.extend(["-O0", "-g"])
+
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    src_dir = os.path.join(this_dir, "src")
+    extensions_dir = os.path.join(src_dir, library_name, "csrc")
+    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+
+    # extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    # cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+
+    # if use_cuda:
+    #     sources += cuda_sources
+
+    ext_modules = [
+        extension(
+            f"{library_name}._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+        )
+    ]
+
+    return ext_modules
+
+
 # ----------------------------- Setup -----------------------------
 setup(
     name="flag_gems",
@@ -104,5 +165,7 @@ def _is_package_installed(package_name: str) -> bool:
     package_data={
         "flag_gems.runtime": ["**/*.yaml"],
     },
+    ext_modules=get_extensions(),
     setup_requires=["setuptools"],
+    cmdclass={"build_ext": BuildExtension},
 )
diff --git a/src/flag_gems/utils/type_utils.py b/src/flag_gems/utils/type_utils.py
index baff2bd0e..de44adde3 100644
--- a/src/flag_gems/utils/type_utils.py
+++ b/src/flag_gems/utils/type_utils.py
@@ -1,7 +1,10 @@
+import functools
+
 import torch
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes
 
 
+@functools.lru_cache(maxsize=None)
 def type_promotion(*args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND):
     computation_dtype, result_dtype = elementwise_dtypes(
         *args,
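
Note on PATCH 1/3: the functools.lru_cache added to type_promotion memoizes repeated
promotion lookups. A minimal standalone sketch of what the wrapped helper computes,
using only the torch APIs the patch itself imports; the f16/f32 pair is chosen purely
for illustration and is not part of the patch:

    import torch
    from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes

    a = torch.empty(1, dtype=torch.float16)
    b = torch.empty(1, dtype=torch.float32)

    # Under DEFAULT promotion, f16 combined with f32 computes and stores in f32.
    computation_dtype, result_dtype = elementwise_dtypes(
        a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    print(computation_dtype, result_dtype)  # torch.float32 torch.float32

One design note worth double-checking: functools.lru_cache keys torch.Tensor arguments
by object identity, so the cache only hits when the very same tensor objects recur, and
cached entries keep those tensors alive.
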
From 3905d0ef7d49ae9a4193714efb2f3852ee45547f Mon Sep 17 00:00:00 2001
From: Clement Chan <iclementine@outlook.com>
Date: Mon, 30 Dec 2024 14:41:06 +0800
Subject: [PATCH 2/3] only use cpp extension

---
 setup.py | 23 ++---------------------
 1 file changed, 2 insertions(+), 21 deletions(-)

diff --git a/setup.py b/setup.py
index 1c140ea11..31f7cd9a5 100644
--- a/setup.py
+++ b/setup.py
@@ -2,15 +2,9 @@
 import os
 from typing import List, Optional, Sequence
 
-import torch
 from pip._internal.metadata import get_default_environment
 from setuptools import find_packages, setup
-from torch.utils.cpp_extension import (
-    CUDA_HOME,
-    BuildExtension,
-    CppExtension,
-    CUDAExtension,
-)
+from torch.utils.cpp_extension import BuildExtension, CppExtension
 
 # ----------------------------- check triton -----------------------------
 # NOTE: this is used to check whether pytorch-triton or triton is installed. The
@@ -78,26 +72,19 @@ def _is_package_installed(package_name: str) -> bool:
 
 def get_extensions():
     debug_mode = os.getenv("DEBUG", "0") == "1"
-    use_cuda = os.getenv("USE_CUDA", "1") == "1"
     if debug_mode:
         print("Compiling in debug mode")
 
-    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
-    extension = CUDAExtension if use_cuda else CppExtension
-
+    extension = CppExtension
     extra_link_args = []
     extra_compile_args = {
         "cxx": [
             "-O3" if not debug_mode else "-O0",
             "-fdiagnostics-color=always",
         ],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-        ],
     }
     if debug_mode:
         extra_compile_args["cxx"].append("-g")
-        extra_compile_args["nvcc"].append("-g")
         extra_link_args.extend(["-O0", "-g"])
 
     this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -105,12 +92,6 @@ def get_extensions():
     extensions_dir = os.path.join(src_dir, library_name, "csrc")
     sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
 
-    # extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    # cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
-
-    # if use_cuda:
-    #     sources += cuda_sources
-
     ext_modules = [
         extension(
             f"{library_name}._C",

From ffbbd8245e60c4d7eae35ff5ae1e301f093bf6ce Mon Sep 17 00:00:00 2001
From: Bowen12992 <zhangbluestars@gmail.com>
Date: Mon, 30 Dec 2024 10:08:26 +0000
Subject: [PATCH 3/3] nothing

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 31f7cd9a5..20fd7cb5a 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
 # ----------------------------- check triton -----------------------------
 # NOTE: this is used to check whether pytorch-triton or triton is installed. The
 # name of the package to import is the same, but the names in the package manager
-# are different. So we check it in this way
+# are different. So we check it in this way:
 # 1. If triton is installed via pytorch-triton, then it is the version that is
 #    depended on by pytorch. Upgrading it may break torch. Be aware of the risk!
 # 2. If triton is installed via torch, then maybe you are aware that you are using
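
Note on the series: after PATCH 2/3 the extension is a plain CppExtension, so a regular
"pip install ." (or "python setup.py build_ext --inplace" during development, the usual
workflow for torch extensions) compiles src/flag_gems/csrc/*.cpp into the flag_gems._C
module, with DEBUG=1 switching the cxx flags from -O3 to -O0 -g. A hypothetical smoke
test for the built package; the module name flag_gems._C comes from the patches, and
everything else here is illustrative:

    import importlib

    try:
        _C = importlib.import_module("flag_gems._C")
        print("C extension loaded from", _C.__file__)
    except ImportError as exc:
        print("C extension not available:", exc)
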