-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathsetup.py
More file actions
72 lines (66 loc) · 2.38 KB
/
setup.py
File metadata and controls
72 lines (66 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import subprocess
import setuptools
import importlib
import importlib.resources
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
# Replace the module-level COMMON_NVCC_FLAGS list in torch.utils.cpp_extension
# with an empty list before any extension is configured or built.
# NOTE(review): presumably this keeps PyTorch's default NVCC flags from being
# appended to this extension's device compiles - confirm against the PyTorch
# version in use, since this overrides library state rather than a public option.
torch.utils.cpp_extension.COMMON_NVCC_FLAGS = []
if __name__ == '__main__':
    # Translation units that make up the extension: one C++ binding file
    # plus the CUDA sources.
    sources = [
        'csrc/torch_interface.cpp',
        'csrc/exchange.cu',
        'csrc/all_reduce_ring_standard.cu',
        'csrc/all_reduce_ring_simple.cu',
        'csrc/all_reduce_tree.cu',
        'csrc/all_reduce_oneshot.cu',
        'csrc/all_reduce_twoshot.cu',
        'csrc/all_reduce_double_ring.cu',
        'csrc/custom_all_reduce.cu',
    ]

    # NVSHMEM header and library locations. Both fall back to fixed system
    # paths and can be overridden with the NVSHMEM_INC / NVSHMEM_LIB
    # environment variables.
    nvshmem_inc = os.environ.get('NVSHMEM_INC', "/usr/include/nvshmem_12")
    nvshmem_lib = os.environ.get('NVSHMEM_LIB', "/usr/lib/x86_64-linux-gnu/nvshmem/12")
    nvshmem_host_lib = 'libnvshmem_host.so'

    include_dirs = ['csrc/', nvshmem_inc]
    library_dirs = [nvshmem_lib]

    # Host compiler flags: optimise and silence a handful of warning classes.
    host_flags = [
        '-O3',
        '-Wno-deprecated-declarations',
        '-Wno-unused-variable',
        '-Wno-sign-compare',
        '-Wno-reorder',
        '-Wno-attributes',
    ]
    # Device compiler flags. '-rdc=true' requests relocatable device code,
    # which is why separate device-link flags are supplied below.
    device_flags = ['-O3', '-Xcompiler', '-O3', '-rdc=true']
    # Flags for the device-link step, resolved against the NVSHMEM device library.
    device_link_flags = ['-dlink', f'-L{nvshmem_lib}', '-lnvshmem_device']

    # Put them together
    extra_compile_args = {
        'cxx': host_flags,
        'nvcc': device_flags,
        'nvcc_dlink': device_link_flags,
    }
    # Final link: NVSHMEM host shared library, the static device library,
    # and an rpath entry pointing at the NVSHMEM library directory.
    extra_link_args = [
        f'-l:{nvshmem_host_lib}',
        '-l:libnvshmem_device.a',
        f'-Wl,-rpath,{nvshmem_lib}',
    ]

    # Summary
    print('Build summary:')
    print(f' > Sources: {sources}')
    print(f' > Includes: {include_dirs}')
    print(f' > Libraries: {library_dirs}')
    print(f' > Compilation flags: {extra_compile_args}')
    print(f' > Link flags: {extra_link_args}')
    print(f' > NVSHMEM lib: {nvshmem_lib}')
    print(f' > NVSHMEM include: {nvshmem_inc}')
    print()

    # Only the top-level 'penny' package is matched by the include pattern.
    penny_packages = setuptools.find_packages(include=['penny'])

    # The compiled module is named 'penny_cpp' and also links against 'cuda'.
    penny_extension = CUDAExtension(
        name='penny_cpp',
        include_dirs=include_dirs,
        library_dirs=library_dirs,
        sources=sources,
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
        libraries=["cuda"],
    )

    setuptools.setup(
        name='penny',
        version='0.0.1',
        packages=penny_packages,
        ext_modules=[penny_extension],
        cmdclass={'build_ext': BuildExtension},
    )