From 340a43456ec627cf610edaf57e7a63ce2e5176b9 Mon Sep 17 00:00:00 2001
From: Clement Chan <iclementine@outlook.com>
Date: Mon, 30 Dec 2024 11:26:24 +0800
Subject: [PATCH 1/3] add a cpp extension

---
 setup.py                          | 63 +++++++++++++++++++++++++++++++
 src/flag_gems/utils/type_utils.py |  3 ++
 2 files changed, 66 insertions(+)

diff --git a/setup.py b/setup.py
index e87ebb877..1c140ea11 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,16 @@
+import glob
+import os
 from typing import List, Optional, Sequence
 
+import torch
 from pip._internal.metadata import get_default_environment
 from setuptools import find_packages, setup
+from torch.utils.cpp_extension import (
+    CUDA_HOME,
+    BuildExtension,
+    CppExtension,
+    CUDAExtension,
+)
 
 # ----------------------------- check triton -----------------------------
 # NOTE: this is used to check whether pytorch-triton or triton is installed. Since
@@ -62,6 +71,58 @@ def _is_package_installed(package_name: str) -> bool:
     or "triton"
 )
 
+
+# --------------------------- flagems c extension ------------------------
+library_name = "flag_gems"
+
+
+def get_extensions():
+    debug_mode = os.getenv("DEBUG", "0") == "1"
+    use_cuda = os.getenv("USE_CUDA", "1") == "1"
+    if debug_mode:
+        print("Compiling in debug mode")
+
+    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
+    extension = CUDAExtension if use_cuda else CppExtension
+
+    extra_link_args = []
+    extra_compile_args = {
+        "cxx": [
+            "-O3" if not debug_mode else "-O0",
+            "-fdiagnostics-color=always",
+        ],
+        "nvcc": [
+            "-O3" if not debug_mode else "-O0",
+        ],
+    }
+    if debug_mode:
+        extra_compile_args["cxx"].append("-g")
+        extra_compile_args["nvcc"].append("-g")
+        extra_link_args.extend(["-O0", "-g"])
+
+    this_dir = os.path.dirname(os.path.curdir)
+    src_dir = os.path.join(this_dir, "src")
+    extensions_dir = os.path.join(src_dir, library_name, "csrc")
+    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+
+    # extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    # cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+
+    # if use_cuda:
+    #     sources += cuda_sources
+
+    ext_modules = [
+        extension(
+            f"{library_name}._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+        )
+    ]
+
+    return ext_modules
+
+
 # ----------------------------- Setup -----------------------------
 setup(
     name="flag_gems",
@@ -104,5 +165,7 @@ def _is_package_installed(package_name: str) -> bool:
     package_data={
         "flag_gems.runtime": ["**/*.yaml"],
     },
+    ext_modules=get_extensions(),
     setup_requires=["setuptools"],
+    cmdclass={"build_ext": BuildExtension},
 )
diff --git a/src/flag_gems/utils/type_utils.py b/src/flag_gems/utils/type_utils.py
index baff2bd0e..de44adde3 100644
--- a/src/flag_gems/utils/type_utils.py
+++ b/src/flag_gems/utils/type_utils.py
@@ -1,7 +1,10 @@
+import functools
+
 import torch
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes
 
 
+@functools.lru_cache(maxsize=None)
 def type_promotion(*args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND):
     computation_dtype, result_dtype = elementwise_dtypes(
         *args,

From 3905d0ef7d49ae9a4193714efb2f3852ee45547f Mon Sep 17 00:00:00 2001
From: Clement Chan <iclementine@outlook.com>
Date: Mon, 30 Dec 2024 14:41:06 +0800
Subject: [PATCH 2/3] only use cpp extension

---
 setup.py | 23 ++---------------------
 1 file changed, 2 insertions(+), 21 deletions(-)

diff --git a/setup.py b/setup.py
index 1c140ea11..31f7cd9a5 100644
--- a/setup.py
+++ b/setup.py
@@ -2,15 +2,9 @@
 import os
 from typing import List, Optional, Sequence
 
-import torch
 from pip._internal.metadata import get_default_environment
 from setuptools import find_packages, setup
-from torch.utils.cpp_extension import (
-    CUDA_HOME,
-    BuildExtension,
-    CppExtension,
-    CUDAExtension,
-)
+from torch.utils.cpp_extension import BuildExtension, CppExtension
 
 # ----------------------------- check triton -----------------------------
 # NOTE: this is used to check whether pytorch-triton or triton is installed. Since
@@ -78,26 +72,19 @@ def _is_package_installed(package_name: str) -> bool:
 
 def get_extensions():
     debug_mode = os.getenv("DEBUG", "0") == "1"
-    use_cuda = os.getenv("USE_CUDA", "1") == "1"
     if debug_mode:
         print("Compiling in debug mode")
 
-    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
-    extension = CUDAExtension if use_cuda else CppExtension
-
+    extension = CppExtension
     extra_link_args = []
     extra_compile_args = {
         "cxx": [
             "-O3" if not debug_mode else "-O0",
             "-fdiagnostics-color=always",
         ],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-        ],
     }
     if debug_mode:
         extra_compile_args["cxx"].append("-g")
-        extra_compile_args["nvcc"].append("-g")
         extra_link_args.extend(["-O0", "-g"])
 
     this_dir = os.path.dirname(os.path.curdir)
@@ -105,12 +92,6 @@ def get_extensions():
     extensions_dir = os.path.join(src_dir, library_name, "csrc")
     sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
 
-    # extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    # cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
-
-    # if use_cuda:
-    #     sources += cuda_sources
-
     ext_modules = [
         extension(
             f"{library_name}._C",

From ffbbd8245e60c4d7eae35ff5ae1e301f093bf6ce Mon Sep 17 00:00:00 2001
From: Bowen12992 <zhangbluestars@gmail.com>
Date: Mon, 30 Dec 2024 10:08:26 +0000
Subject: [PATCH 3/3] fix punctuation in setup.py triton-check comment

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 31f7cd9a5..20fd7cb5a 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
 # ----------------------------- check triton -----------------------------
 # NOTE: this is used to check whether pytorch-triton or triton is installed. Since
 # the name for the package to be import is the name, but the names in package manager
-# are different. So we check it in this way
+# are different. So we check it in this way:
 # 1. If the triton that is installed via pytorch-triton, then it is the version that is
 # dependended by pytorch. Upgrading it may break torch. Be aware of the risk!
 # 2. If the triton is installed via torch, then maybe you are aware that you are using