-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsetup.py
160 lines (127 loc) · 5.26 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
import os
import subprocess
import sys
import zipfile
import tempfile
from datetime import datetime
from setuptools import find_packages, setup
from distutils import core
from distutils.core import Distribution
from distutils.errors import DistutilsArgError
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
"""
To make a wheel package:
$ python setup.py sdist bdist_wheel -d $YOUR_TARGET
To make a wheel package with specific python version and specific platform:
$ python setup.py sdist bdist_wheel -d $YOUR_TARGET --python-tag py39 --plat-name=linux_x86_64
"""
# Distribution name used for the built package (overridable via QUARK_WHEEL_NAME).
package_name = os.getenv("QUARK_WHEEL_NAME", "amd-quark")
# Base version string, without any nightly/local suffix. Read through a context
# manager so the file handle is closed deterministically instead of leaked, and
# with an explicit encoding so the result does not depend on the build locale.
with open("quark/version.txt", "r", encoding="utf-8") as _version_file:
    _version_txt = _version_file.read().strip()
# Build-flavour flags: enabled only when the variable equals "true" (any case).
is_nightly = os.getenv('QUARK_NIGHTLY', 'false').lower() == 'true'
is_release = os.getenv('QUARK_RELEASE', 'false').lower() == 'true'
class CustomBdistWheel(_bdist_wheel):
    """``bdist_wheel`` variant that appends custom-op library rows to RECORD.

    After the regular wheel is built, hash-less RECORD rows (``path,,``) are
    added for the custom-op libraries under
    ``quark/onnx/operators/custom_ops/lib``.
    NOTE(review): presumably these libraries are produced after installation
    rather than shipped in the wheel, and listing them lets ``pip uninstall``
    remove them -- confirm.
    """

    # Wheel-internal directory of the custom-op libraries. RECORD paths must use
    # "/" regardless of the build OS, so this is never joined with os.sep.
    _LIBRARY_DIR = "quark/onnx/operators/custom_ops/lib"
    # Library file names registered in RECORD.
    _LIBRARY_FILES = (
        "libcustom_ops.so",
        "libcustom_ops_gpu.so",
        "custom_ops.dll",
        "custom_ops.pyd",
    )

    def run(self):
        """Build the wheel, then add the custom-op library rows to its RECORD."""
        super().run()
        wheel_file = self._find_built_wheel(os.path.abspath(self.dist_dir))
        # Zip member names always use "/" (os.path.join would give "\\" on Windows).
        record_arcname = f"{self.wheel_dist_name}.dist-info/RECORD"
        extra_rows = [f"{self._LIBRARY_DIR}/{name},," for name in self._LIBRARY_FILES]
        self._append_record_rows(wheel_file, record_arcname, extra_rows)

    def _find_built_wheel(self, dist_dir):
        """Return the path of the wheel built by this command in *dist_dir*.

        Picking an arbitrary ``.whl`` is unsafe when *dist_dir* also holds
        wheels from earlier builds, so the file name is first derived from the
        distribution name and the wheel tags; if that file is absent, the most
        recently modified wheel for this distribution is used.

        Raises:
            FileNotFoundError: If no wheel for this distribution is present.
        """
        impl_tag, abi_tag, plat_tag = self.get_tag()
        expected = os.path.join(
            dist_dir, f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}.whl"
        )
        if os.path.isfile(expected):
            return expected
        prefix = f"{self.wheel_dist_name}-"
        candidates = [
            os.path.join(dist_dir, name)
            for name in os.listdir(dist_dir)
            if name.startswith(prefix) and name.endswith(".whl")
        ]
        if not candidates:
            raise FileNotFoundError(f"no wheel starting with {prefix!r} found in {dist_dir!r}")
        return max(candidates, key=os.path.getmtime)

    @staticmethod
    def _append_record_rows(wheel_file, record_arcname, rows):
        """Rewrite *wheel_file* so that its RECORD member also lists *rows*.

        Writing RECORD again with ``ZipFile(..., 'a')`` would leave two members
        with the same name in the archive. Instead every member is copied into
        a new archive (same order, timestamps, attributes and compression
        method) with RECORD's content replaced, and the new archive then
        replaces the original via ``os.replace``. Rows whose path is already
        listed in RECORD are not added a second time.

        Raises:
            FileNotFoundError: If *record_arcname* is not a member of the wheel.
        """
        tmp_file = wheel_file + ".tmp"
        try:
            with zipfile.ZipFile(wheel_file, "r") as src, zipfile.ZipFile(tmp_file, "w") as dst:
                found = False
                for info in src.infolist():
                    data = src.read(info)
                    if info.filename == record_arcname:
                        found = True
                        text = data.decode("utf-8")
                        listed = {line.split(",", 1)[0] for line in text.splitlines()}
                        new_rows = [row for row in rows if row.split(",", 1)[0] not in listed]
                        if text and not text.endswith("\n"):
                            text += "\n"
                        data = (text + "".join(f"{row}\n" for row in new_rows)).encode("utf-8")
                    # Passing the source ZipInfo keeps the member's metadata and
                    # compression method; size and CRC are recomputed.
                    dst.writestr(info, data)
                if not found:
                    raise FileNotFoundError(f"{record_arcname!r} not found in {wheel_file!r}")
            os.replace(tmp_file, wheel_file)
        except BaseException:
            # Never leave a half-written archive next to the real wheel.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise
def os_path_join(*args, **kwargs):
    """Join path components with ``os.path.join`` and normalize the result.

    All arguments are forwarded unchanged to ``os.path.join``; the joined path
    is then passed through ``os.path.normpath`` (collapsing ``.``/``..`` and
    redundant separators).
    """
    return os.path.normpath(os.path.join(*args, **kwargs))
def os_path_exists(path):
    """Report whether *path* exists, checking its ``os.path.normpath`` form.

    Because normalization happens first, lexical ``..`` segments are resolved
    before the filesystem is queried.
    """
    return os.path.exists(os.path.normpath(path))
def os_path_dirname(path):
    """Return the normalized parent directory of the normalized *path*.

    *path* is normalized before ``os.path.dirname`` is applied (so a trailing
    separator does not yield the path itself), and the parent is normalized
    again (so an empty parent becomes ``"."``).
    """
    parent = os.path.dirname(os.path.normpath(path))
    return os.path.normpath(parent)
def os_path_abspath(path):
    """Return the normalized absolute form of *path*.

    The input is normalized, made absolute with ``os.path.abspath`` and the
    result normalized once more.
    """
    absolute = os.path.abspath(os.path.normpath(path))
    return os.path.normpath(absolute)
def read_requirements(path="requirements.txt"):
    """Return the requirement specifiers listed in a pip requirements file.

    Surrounding whitespace is stripped, and blank lines and ``#`` comments are
    dropped. Comments may fill a whole line or trail after `` #``. The result
    can be passed directly as ``install_requires``.

    Args:
        path (str, optional): Requirements file to read; relative paths are
            resolved against the current working directory. Defaults to
            ``"requirements.txt"``.

    Returns:
        list[str]: One requirement string per non-empty, non-comment line.
    """
    requirements = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # " #" starts an inline comment (pip's rule); a "#" not preceded by
            # whitespace may be part of a URL fragment and is kept.
            spec = line.partition(" #")[0].strip()
            if spec and not spec.startswith("#"):
                requirements.append(spec)
    return requirements
def build_config_setup():
    """Return the ``cmdclass`` mapping handed to ``setup()``.

    Only the ``bdist_wheel`` command is overridden, with ``CustomBdistWheel``.
    """
    return {"bdist_wheel": CustomBdistWheel}
def get_git_hash():
    """Return the abbreviated hash of the current git ``HEAD``, or ``"unknown"``.

    Runs ``git rev-parse --short HEAD`` in the current working directory.

    Returns:
        str: The short commit hash, or ``"unknown"`` when it cannot be
        determined. That happens when the directory is not a git checkout
        (e.g. an extracted sdist) or when no ``git`` executable is installed.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            # Keep git's "fatal: not a git repository" noise out of build logs.
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed (not a work tree, no commits).
        # OSError (FileNotFoundError): the git executable is not available.
        return "unknown"
    return output.decode("ascii").strip()
def get_version(is_nightly=False, is_release=False, version_txt=None):
    """Return the version of the Quark package.

    Release builds use the base version unchanged. Nightly builds append a
    ``.devYYYYMMDD`` suffix and carry no git hash; the date comes from the
    build machine's local clock. All other builds append a ``+<git-hash>``
    local-version suffix for debugging purposes.

    Args:
        is_nightly (bool, optional): Whether the build is a nightly build. Defaults to False.
        is_release (bool, optional): Whether the build is a release build. Defaults to False.
        version_txt (str, optional): Base version to decorate. Defaults to the
            module-level ``_version_txt`` (the content of ``quark/version.txt``).

    Returns:
        str: The version of the Quark package.

    Raises:
        ValueError: If both ``is_nightly`` and ``is_release`` are true. This is an
            explicit raise rather than an ``assert`` so it survives ``python -O``.
    """
    if is_nightly and is_release:
        raise ValueError("Quark build cannot be both nightly and release at the same time!")
    base = _version_txt if version_txt is None else version_txt
    if is_release:
        return f"{base}"
    if is_nightly:
        return f"{base}.dev{datetime.now().strftime('%Y%m%d')}"
    return f"{base}+{get_git_hash()}"
if __name__ == '__main__':
    # Parse and validate the command line before touching the source tree, so a
    # bad invocation exits with the distutils usage text.
    dist = Distribution()
    dist.script_name, dist.script_args = sys.argv[0], sys.argv[1:]
    try:
        is_valid_args = dist.parse_command_line()
    except DistutilsArgError as msg:
        raise SystemExit(f"{core.gen_usage(dist.script_name)}\nerror:{msg}")
    if not is_valid_args:
        # parse_command_line() returns a false value when no commands should be
        # run (for example when help was requested).
        sys.exit()

    cmdclass = build_config_setup()
    install_requires = read_requirements()
    cwd = os_path_dirname(os_path_abspath(__file__))
    sha = get_git_hash()

    # Regenerate quark/version.py next to this script with the base version,
    # the git hash and the release flag.
    version_path = os_path_join(cwd, "quark", "version.py")
    version_lines = (
        f"__version__ = '{_version_txt}'\n",
        f"git_version = '{sha}'\n",
        f"is_release = {is_release}\n",
    )
    with open(version_path, "w") as f:
        f.writelines(version_lines)

    setup_kwargs = dict(
        name=package_name,
        version=get_version(is_nightly, is_release),
        description="The deep learning model compression toolkit.",
        author="Advanced Micro Devices, Inc.",
        author_email='[email protected]',
        license="MIT",
        packages=find_packages(include=['quark', 'quark.*']),  # Only include folder 'quark'
        include_package_data=True,
        cmdclass=cmdclass,
        install_requires=install_requires,
        python_requires='>=3.9.0,<3.13',
    )
    setup(**setup_kwargs)