forked from ROCm/aiter
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
128 lines (109 loc) · 4.04 KB
/
setup.py
File metadata and controls
128 lines (109 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
import os
import sys
import shutil
from setuptools import setup
# !!!!!!!!!!!!!!!! never import aiter
# from aiter.jit import core
this_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, f"{this_dir}/aiter/")
from jit import core
from jit.utils.cpp_extension import (
BuildExtension,
IS_HIP_EXTENSION,
)
# Location of the composable_kernel (CK) sources. The CK_DIR environment
# variable overrides the default path under this checkout's 3rdparty/.
ck_dir = os.environ.get("CK_DIR", f"{this_dir}/3rdparty/composable_kernel")

PACKAGE_NAME = "aiter"

# BUILD_TARGET selects the backend:
#   "auto" (default) -> follow IS_HIP_EXTENSION
#   "rocm" / "cuda"  -> force the choice
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")

if BUILD_TARGET == "auto":
    IS_ROCM = bool(IS_HIP_EXTENSION)
elif BUILD_TARGET == "cuda":
    IS_ROCM = False
elif BUILD_TARGET == "rocm":
    IS_ROCM = True
else:
    # Any other value used to leave IS_ROCM undefined, which surfaced later
    # as a confusing NameError. Fail here with an actionable message instead.
    raise ValueError(
        f"Unsupported BUILD_TARGET {BUILD_TARGET!r}; "
        'expected one of "auto", "cuda" or "rocm"'
    )

FORCE_CXX11_ABI = False
if IS_ROCM:
    # Use an explicit raise rather than `assert`: assert statements are
    # removed when Python runs with -O, which would silently skip this check.
    if not os.path.exists(ck_dir):
        raise FileNotFoundError(
            'CK is needed by aiter, please make sure clone by "git clone --recursive https://github.com/ROCm/aiter.git" or "git submodule sync ; git submodule update --init --recursive"'
        )

    # PREBUILD_KERNELS=1 compiles the kernels into a single "aiter_" module
    # while setup.py runs; any other value (or unset) skips this step.
    if int(os.environ.get("PREBUILD_KERNELS", 0)) == 1:
        exclude_ops = ["libmha_fwd", "libmha_bwd"]
        # NOTE(review): the keyword is spelled `exclue` here; presumably it
        # matches the parameter name of core.get_args_of_build - confirm
        # before "fixing" the spelling.
        all_opts_args_build = core.get_args_of_build("all", exclue=exclude_ops)

        # remove pybind, because there are already duplicates in rocm_opt
        all_opts_args_build["srcs"] = [
            src for src in all_opts_args_build["srcs"] if "pybind.cu" not in src
        ]

        prebuild_define = ["-DPREBUILD_KERNELS"]
        core.build_module(
            md_name="aiter_",
            srcs=all_opts_args_build["srcs"] + [f"{this_dir}/csrc"],
            flags_extra_cc=all_opts_args_build["flags_extra_cc"] + prebuild_define,
            flags_extra_hip=all_opts_args_build["flags_extra_hip"] + prebuild_define,
            blob_gen_cmd=all_opts_args_build["blob_gen_cmd"],
            extra_include=all_opts_args_build["extra_include"],
            extra_ldflags=None,
            verbose=False,
            is_python_module=True,
            is_standalone=False,
            torch_exclude=False,
        )
else:
    raise NotImplementedError("Only ROCM is supported")
# Start from a clean staging directory: drop any "aiter_meta" tree left over
# from an earlier run (isdir() is already False for a missing path).
if os.path.isdir("aiter_meta"):
    shutil.rmtree("aiter_meta")

## copy "3rdparty", "hsa", "csrc" into "aiter_meta"
for _staged_dir in ("3rdparty", "hsa", "csrc"):
    shutil.copytree(_staged_dir, f"aiter_meta/{_staged_dir}")
class NinjaBuildExtension(BuildExtension):
    """BuildExtension that adjusts the ``MAX_JOBS`` environment variable.

    On construction, ``MAX_JOBS`` is (re)computed from the CPU count and the
    currently available memory whenever its current value (treated as 1 when
    unset) is below 80% of the CPU count. The result is written back to
    ``os.environ`` before the parent initializer runs.
    """

    def __init__(self, *args, **kwargs) -> None:
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to one core instead of failing on `None * 0.8`.
        cpu_total = os.cpu_count() or 1

        # calculate the maximum allowed NUM_JOBS based on cores (80% of them)
        max_num_jobs_cores = max(1, cpu_total * 0.8)

        # NOTE(review): an explicitly set MAX_JOBS that is smaller than the
        # core-based limit is overwritten below - confirm this is intended.
        if int(os.environ.get("MAX_JOBS", "1")) < max_num_jobs_cores:
            # Imported lazily: psutil is only needed on this path.
            import psutil

            # calculate the maximum allowed NUM_JOBS based on free memory
            free_memory_gb = psutil.virtual_memory().available / (
                1024**3
            )  # free memory in GB
            # each JOB peak memory cost is ~8-9GB when threads = 4
            max_num_jobs_memory = int(free_memory_gb / 9)

            # pick lower value of jobs based on cores vs memory metric to
            # minimize oom and swap usage during compilation
            max_jobs = int(max(1, min(max_num_jobs_cores, max_num_jobs_memory)))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)
# Run the packaging step, then always remove the temporary "aiter_meta"
# staging tree. try/finally ensures the cleanup also happens when setup()
# fails (it may raise SystemExit), so no stale copy is left in the checkout.
try:
    setup(
        name=PACKAGE_NAME,
        use_scm_version=True,
        packages=["aiter_meta", "aiter"],
        include_package_data=True,
        package_data={
            "": ["*"],
        },
        classifiers=[
            "Programming Language :: Python :: 3",
            # NOTE(review): this file's SPDX header says MIT, while the
            # classifier below says BSD - confirm which license is correct.
            "License :: OSI Approved :: BSD License",
            "Operating System :: Unix",
        ],
        # ext_modules=ext_modules,
        cmdclass={"build_ext": NinjaBuildExtension},
        python_requires=">=3.8",
        install_requires=[
            "pybind11",
            # "ninja",
            "pandas",
            "einops",
        ],
        setup_requires=[
            "packaging",
            "psutil",
            "ninja",
            "setuptools_scm",
        ],
    )
finally:
    if os.path.exists("aiter_meta") and os.path.isdir("aiter_meta"):
        shutil.rmtree("aiter_meta")