forked from coreylammie/MemTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
65 lines (59 loc) · 1.99 KB
/
setup.py
File metadata and controls
65 lines (59 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from setuptools import setup, find_packages
import torch
# Release number of the package; create_version_py() records it unchanged
# for CUDA builds and with a '-cpu' suffix for CPU-only builds.
# CUDA is the build switch: True also compiles the CUDA quantization
# extension and names the distribution 'memtorch'; False builds only the
# CPU extension as 'memtorch-cpu'.
version, CUDA = '1.1.0', False
def create_version_py(version, CUDA, path='memtorch/version.py'):
    """Write a ``version.py`` module defining ``__version__`` for this build.

    CUDA builds record ``version`` unchanged; CPU-only builds append a
    ``-cpu`` suffix to it.

    Args:
        version: Release number to record, e.g. ``'1.1.0'``.
        CUDA: True for a CUDA build (no suffix), False for a CPU-only
            build (``-cpu`` suffix).
        path: Destination of the generated module. Defaults to
            ``'memtorch/version.py'``, resolved relative to the current
            working directory.
    """
    suffix = '' if CUDA else '-cpu'
    version_string = "__version__ = '{}{}'\n".format(version, suffix)
    # The context manager closes the file even if the write raises.
    with open(path, 'w', encoding='utf-8') as file:
        file.write(version_string)
create_version_py(version, CUDA)

# BuildExtension and CppExtension are needed by both build variants.
from torch.utils.cpp_extension import BuildExtension, CppExtension

if CUDA:
    from torch.utils.cpp_extension import CUDAExtension
    ext_modules = [
        # CUDA quantization kernels. The header directory is passed as
        # ``include_dirs`` (a list): ``extra_include_paths`` is a parameter
        # of ``torch.utils.cpp_extension.load`` only; CUDAExtension forwards
        # unknown keywords to setuptools.Extension, which ignores them with
        # an "Unknown Extension options" warning.
        CUDAExtension('cuda_quantization', [
            'memtorch/cu/quantize/quant_cuda.cpp',
            'memtorch/cu/quantize/quant.cu',
        ], include_dirs=['memtorch/cu/quantize']),
    ]
    name = 'memtorch'
else:
    ext_modules = []
    name = 'memtorch-cpu'

# The CPU quantization extension is built for every variant; it follows the
# CUDA extension in the list when both are present.
ext_modules.append(
    CppExtension('quantization', ['memtorch/cpp/quantize/quant.cpp']))
# Runtime dependencies. 'scikit-learn' is the real PyPI distribution name;
# the 'sklearn' placeholder package on PyPI is deprecated and refuses to
# install, which would break `pip install` of this package.
install_requires = [
    'numpy',
    'pandas',
    'scipy',
    'scikit-learn',
    'torch>=1.2.0',
    'torchvision',
    'matplotlib',
    'seaborn',
    'ipython',
]

if __name__ == '__main__':
    setup(name=name,
          version=version,
          description='A Simulation Framework for Memristive Deep Learning Systems',
          long_description='A Simulation Framework for Memristive Deep Learning Systems which integrates directly with the well-known PyTorch Machine Learning (ML) library',
          url='https://github.com/coreylammie/MemTorch',
          license='GPL-3.0',
          author='Corey Lammie',
          author_email='coreylammie@jcu.edu.au',
          ext_modules=ext_modules,
          # torch's BuildExtension supplies the compiler flags needed to
          # build the C++/CUDA extensions declared in ext_modules.
          cmdclass={
              'build_ext': BuildExtension
          },
          packages=find_packages(),
          install_requires=install_requires,
          # NOTE(review): package data is only included for CUDA builds —
          # confirm this coupling to the CUDA flag is intended.
          include_package_data=CUDA,
          python_requires='>=3.6'
          )