From dac35adc193cf1d02b211af87b84ae8b40d23c7f Mon Sep 17 00:00:00 2001 From: "yan.yan" Date: Sat, 24 Sep 2022 22:50:42 +0800 Subject: [PATCH] add cuda 11.7, remove cuda 11.1 --- .github/workflows/build.yaml | 8 ++-- .gitignore | 4 +- README.md | 28 ++++++++---- docs/COMMON_PROBLEMS.md | 37 ++++++++++++++++ docs/PERFORMANCE_GUIDE.md | 1 + docs/SPCONV_DEVELOP_PLAN.md | 83 ------------------------------------ pyproject.toml | 2 +- setup.py | 11 +++-- spconv/algo.py | 11 +++-- spconv/benchmark/__init__.py | 14 ++++++ spconv/benchmark/__main__.py | 10 +---- spconv/benchmark/basic.py | 43 +++++++++++-------- spconv/benchmark/core.py | 23 ++++++++++ spconv/core.py | 60 +++++++++++++++++--------- spconv/cppconstants.py | 3 +- spconv/pytorch/ops.py | 22 +++++----- test/benchmark.py | 2 +- 17 files changed, 199 insertions(+), 163 deletions(-) create mode 100644 docs/COMMON_PROBLEMS.md delete mode 100644 docs/SPCONV_DEVELOP_PLAN.md diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index ce8f0e5..dbf58af 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -15,8 +15,8 @@ jobs: runs-on: windows-2019 strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - cuda-version: ['10.2', '11.1', '11.4'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11.0-rc.2'] + cuda-version: ['10.2', '11.3', '11.4', '11.7'] steps: - uses: actions/checkout@master - uses: dorny/paths-filter@v2 @@ -115,8 +115,8 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] # this version is only used for upload. - cuda-version: ['102', '111', '113', '114', ''] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11.0-rc.2'] # this version is only used for upload. + cuda-version: ['102', '113', '114', '117', ''] steps: - uses: actions/checkout@master diff --git a/.gitignore b/.gitignore index fec843e..83d30de 100644 --- a/.gitignore +++ b/.gitignore @@ -114,4 +114,6 @@ wheelhouse_tmp example/libspconv/cumm example/libspconv/spconv/include -example/libspconv/spconv/src \ No newline at end of file +example/libspconv/spconv/src + +third_party/boost \ No newline at end of file diff --git a/README.md b/README.md index 2c322b0..d9f2dea 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ [pypi-ver-cpu]: https://img.shields.io/pypi/v/spconv [pypi-ver-114]: https://img.shields.io/pypi/v/spconv-cu114 [pypi-ver-111]: https://img.shields.io/pypi/v/spconv-cu111 +[pypi-ver-117]: https://img.shields.io/pypi/v/spconv-cu117 + [pypi-ver-113]: https://img.shields.io/pypi/v/spconv-cu113 [pypi-ver-120]: https://img.shields.io/pypi/v/spconv-cu120 [pypi-ver-102]: https://img.shields.io/pypi/v/spconv-cu102 @@ -28,6 +30,8 @@ [pypi-download-113]: https://img.shields.io/pypi/dm/spconv-cu113 [pypi-url-114]: https://pypi.org/project/spconv-cu114/ [pypi-download-114]: https://img.shields.io/pypi/dm/spconv-cu114 +[pypi-url-117]: https://pypi.org/project/spconv-cu117/ +[pypi-download-117]: https://img.shields.io/pypi/dm/spconv-cu117 [pypi-url-120]: https://pypi.org/project/spconv-cu120/ [pypi-download-120]: https://img.shields.io/pypi/dm/spconv-cu120 [pypi-url-cpu]: https://pypi.org/project/spconv/ @@ -41,9 +45,9 @@ | -------------- |:---------------------:| ---------------------:| ---------------------:| | CPU (Linux Only) | [![PyPI Version][pypi-ver-cpu]][pypi-url-cpu] | ```pip install spconv``` | [![pypi monthly download][pypi-download-cpu]][pypi-url-cpu] | | CUDA 10.2 | [![PyPI Version][pypi-ver-102]][pypi-url-102] | ```pip install 
spconv-cu102```| [![pypi monthly download][pypi-download-102]][pypi-url-102]| -| CUDA 11.1 | [![PyPI Version][pypi-ver-111]][pypi-url-111] | ```pip install spconv-cu111```| [![pypi monthly download][pypi-download-111]][pypi-url-111]| | CUDA 11.3 (Linux Only) | [![PyPI Version][pypi-ver-113]][pypi-url-113] | ```pip install spconv-cu113```| [![pypi monthly download][pypi-download-113]][pypi-url-113]| | CUDA 11.4 | [![PyPI Version][pypi-ver-114]][pypi-url-114] | ```pip install spconv-cu114```| [![pypi monthly download][pypi-download-114]][pypi-url-114]| +| CUDA 11.7 | [![PyPI Version][pypi-ver-117]][pypi-url-117] | ```pip install spconv-cu117```| [![pypi monthly download][pypi-download-117]][pypi-url-117]| ```spconv``` is a project that provide heavily-optimized sparse convolution implementation with tensor core support. check [benchmark](docs/BENCHMARK.md) to see how fast spconv 2.x runs. @@ -52,15 +56,19 @@ Check [spconv 2.x algorithm introduction](docs/spconv2_algo.pdf) to understand sparse convolution algorithm in spconv 2.x! +## WARNING + +Use spconv >= cu114 if possible. CUDA 11.4 can compile significantly faster kernels in some situations. + ## NEWS * spconv 2.2: ampere feature support (by [EvernightAurora](https://github.com/EvernightAurora)), pure c++ code generation, nvrtc, drop python 3.6 ## Spconv 2.2 vs Spconv 2.1 -* faster fp16 kernels (~5-30%) in ampere GPUs (tested in RTX 3090) -* greatly faster int8 kernels (~1.2x~2.7x) in ampere GPUs (tested in RTX 3090) -* no python 3.6 support +* faster fp16 conv kernels (~5-30%) in ampere GPUs (tested in RTX 3090) +* greatly faster int8 conv kernels (~1.2x~2.7x) in ampere GPUs (tested in RTX 3090) +* drop python 3.6 support * nvrtc support: kernel in old GPUs will be compiled in runtime. * [libspconv](docs/PURE_CPP_BUILD.md): pure c++ build of all spconv ops. see [example](example/libspconv/run_build.sh) * tf32 kernels, faster fp32 training, disabled by default. set ```import spconv as spconv_core; spconv_core.constants.SPCONV_ALLOW_TF32 = True``` to enable them. @@ -84,6 +92,10 @@ Then see [this](docs/USAGE.md). Don't forget to check [performance guide](docs/PERFORMANCE_GUIDE.md). +### Common Solutions for Some Bugs + +See [common problems](docs/COMMON_PROBLEMS.md). + ## Install You need to install python >= 3.7 first to use spconv 2.x. @@ -94,9 +106,9 @@ You need at least CUDA 11.0 to build and run spconv 2.x. We won't offer any supp ### Prebuilt -We offer python 3.7-3.11 and cuda 10.2/11.1/11.3/11.4/12.0 prebuilt binaries for linux (manylinux). +We offer python 3.7-3.11 and cuda 10.2/11.3/11.4/11.7/12.0 prebuilt binaries for linux (manylinux). -We offer python 3.7-3.11 and cuda 10.2/11.1/11.4/12.0 prebuilt binaries for windows 10/11. +We offer python 3.7-3.11 and cuda 10.2/11.4/11.7/12.0 prebuilt binaries for windows 10/11. For Linux users, you need to install pip >= 20.3 first to install prebuilt. @@ -104,12 +116,12 @@ For Linux users, you need to install pip >= 20.3 first to install prebuilt. ```pip install spconv-cu102``` for CUDA 10.2 -```pip install spconv-cu111``` for CUDA 11.1 - ```pip install spconv-cu113``` for CUDA 11.3 (**Linux Only**) ```pip install spconv-cu114``` for CUDA 11.4 +```pip install spconv-cu117``` for CUDA 11.7 + ```pip install spconv-cu120``` for CUDA 12.0 **NOTE** It's safe to have different **minor** cuda version between system and conda (pytorch) in **CUDA >= 11.0** because of [CUDA Minor Version Compatibility](https://docs.nvidia.com/deploy/cuda-compatibility/#minor-version-compatibility). 
For example, you can use spconv-cu114 with anaconda version of pytorch cuda 11.1 in a OS with CUDA 11.2 installed. diff --git a/docs/COMMON_PROBLEMS.md b/docs/COMMON_PROBLEMS.md new file mode 100644 index 0000000..9a78edf --- /dev/null +++ b/docs/COMMON_PROBLEMS.md @@ -0,0 +1,37 @@ + + +# Common Problems + +## the provided PTX was compiled with an unsupported toolchain + +Update your GPU driver or downgrade your spconv/cumm CUDA version. + +## CUDA kernel launch blocks must be positive, but got N= 0 + +Your coordinates produce no output points under the given conv params. Modify your conv params to make sure all input points produce at least one output point. + +Example: + +Conv Params: +```spatial shape=[8, 200, 200],ksize=[3, 3, 3],stride=[2, 2, 2],padding=[0, 1, 1],dilation=[1, 1, 1]``` +Coordinates: +``` +[[0, 7, 153, 142]] +``` + +The convolution along the z axis will drop ALL points with z == 7. Change the z padding to solve this problem. + diff --git a/docs/PERFORMANCE_GUIDE.md b/docs/PERFORMANCE_GUIDE.md index 2acef8c..7869a7d 100644 --- a/docs/PERFORMANCE_GUIDE.md +++ b/docs/PERFORMANCE_GUIDE.md @@ -26,3 +26,4 @@ * spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible. * If you train with float32 and ampere or later GPUs, you can set ```spconv.constants.SPCONV_ALLOW_TF32``` to enable faster fp32 training. See [benchmark](BENCHMARK.md) for more performance details of different algorithms. +* Different CUDA versions of spconv may have different performance. Use the newest CUDA version if possible. For example, spconv-cu117 is faster than spconv-cu114, and spconv-cu114 is faster than spconv-cu111. \ No newline at end of file diff --git a/docs/SPCONV_DEVELOP_PLAN.md b/docs/SPCONV_DEVELOP_PLAN.md deleted file mode 100644 index 00df9eb..0000000 --- a/docs/SPCONV_DEVELOP_PLAN.md +++ /dev/null @@ -1,83 +0,0 @@ - - -## Spconv 2.x Develop Plan - -If someone want to contribute to spconv 2.x, feel free to start new discussion in github, or just email to me. - - -### v2.2 Core Features - -- [ ] TF32 support -- [ ] Make ```ConvAlgo.Native``` runable in KRSC layout and only use this layout in future -- [ ] PyTorch Int8 Support - -### v2.3 Core Features - -- [ ] Move most of function in spconv.pytorch.ops to C++ -- [ ] Ampere multi-stage gemm support -- [ ] Optimize CUDA Kernels for small-channel-size layers. - -### v2.4 Core Features - -- [ ] nvrtc support for gemm/conv kernels -- [ ] C++ only spconv -- [ ] TensorRT support - -### Misc Features need contribution - -- [ ] Test spconv 2.x in [torch-points3d](https://github.com/nicolas-chaulet/torch-points3d) and other frameworks -- [ ] Documents in github Page -- [ ] Better tests - - -### Details - -1. TF32 support - -we only need to add tf32 tensor cores to cumm. not hard. - -2. Make ```ConvAlgo.Native``` runable in KRSC layout - -Add stride arg to gemm kernels, use offset + stride to force gemm kernel use KRSC layout as a "KC" matrix. - -3. PyTorch Int8 Support - -... - -4. Move most of function in spconv.pytorch.ops to C++ - -Pure engieering work. - -5. Ampere multi-stage gemm support - -Not easy, we need to use new pattern to write gemm kernels. - -6. Optimize CUDA Kernels for small-channel-size layers - -modify cumm and make it support small kernels. not hard, but need time. - -7. nvrtc support for gemm/conv kernels - -need to rewrite kernel params in cumm. not easy. - -8. C++ only spconv - -actually code generation is easy, we can finish this easily after move ops to c++. - -9. TensorRT support - -The TensorRT support is the last feature in this plan. 
it needs lots of engieering work and prerequisites, may cost much time. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0cce441..d215ffa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools>=41.0", "wheel", "pccm>=0.2.21", "cumm>=0.2.3"] +requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.3.0"] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 3c4a40e..469036d 100644 --- a/setup.py +++ b/setup.py @@ -163,9 +163,14 @@ def run(self): from spconv.csrc.sparse.convops import GemmTunerSimple, ExternalSpconvMatmul from spconv.csrc.sparse.convops import ConvTunerSimple, ConvGemmOps from spconv.csrc.sparse.inference import InferenceOps - - cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS) - convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS) + all_shuffle = SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS + SHUFFLE_AMPERE_PARAMS + all_imp = (IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + + IMPLGEMM_TURING_PARAMS + IMPLGEMM_AMPERE_PARAMS) + all_shuffle = list(filter(lambda x: not x.is_nvrtc, all_shuffle)) + all_imp = list(filter(lambda x: not x.is_nvrtc, all_imp)) + + cu = GemmMainUnitTest(all_shuffle) + convcu = ConvMainUnitTest(all_imp) convcu.namespace = "cumm.conv.main" cu.namespace = "cumm.gemm.main" diff --git a/spconv/algo.py b/spconv/algo.py index bbf8807..c453a1b 100644 --- a/spconv/algo.py +++ b/spconv/algo.py @@ -40,7 +40,7 @@ from spconv.core import ALL_IMPGEMM_PARAMS, AlgoHint, ConvAlgo, ALL_NATIVE_PARAMS from spconv.core_cc.cumm.conv.main import ConvMainUnitTest from spconv.core_cc.cumm.gemm.main import GemmMainUnitTest -from spconv.cppconstants import COMPILED_CUDA_ARCHS +from spconv.cppconstants import COMPILED_CUDA_GEMM_ARCHS from cumm.tensorview.gemm import NVRTCParams from spconv.tools import CUDAKernelTimer from cumm.gemm.constants import NVRTCConstants, NVRTCMode @@ -337,7 +337,7 @@ def get_all_available( ldb = b.stride[0] ldc = c.stride[0] if desp.supported_ldx(lda, ldb, ldc): - if arch not in COMPILED_CUDA_ARCHS: + if arch not in COMPILED_CUDA_GEMM_ARCHS: desp = desp.copy() desp.is_nvrtc = True if SPCONV_DEBUG_NVRTC_KERNELS: @@ -720,7 +720,7 @@ def get_all_available(self, assert mask_width > 0 mask_width_valid = mask_width % desp.tile_shape[2] == 0 if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: - if arch not in COMPILED_CUDA_ARCHS: + if arch not in COMPILED_CUDA_GEMM_ARCHS: desp = desp.copy() desp.is_nvrtc = True if SPCONV_DEBUG_NVRTC_KERNELS: @@ -822,6 +822,7 @@ def tune_and_cache(self, times: List[float] = [] all_profile_res: List[BestConvAlgoByProfile] = [] + group_by_algo = {} for desp in avail: # for sparse conv, ndim isn't used, so we just provide a constant value. 
params = ConvParams(NDIM_DONT_CARE, ConvOpTypeCpp(op_type.value)) @@ -865,7 +866,9 @@ def tune_and_cache(self, this_times.append(measure.duration) times.append(np.mean(this_times[1:])) spk_speeds.append(times[-1]) - + if desp.algo not in group_by_algo: + group_by_algo[desp.algo] = 10000.0 + group_by_algo[desp.algo] = min(times[-1], group_by_algo[desp.algo]) all_profile_res.append( BestConvAlgoByProfile(desp, arch, splitk=spk)) if not all_profile_res: diff --git a/spconv/benchmark/__init__.py b/spconv/benchmark/__init__.py index e69de29..960eb82 100644 --- a/spconv/benchmark/__init__.py +++ b/spconv/benchmark/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 Yan Yan +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/spconv/benchmark/__main__.py b/spconv/benchmark/__main__.py index 6cac583..0db8a09 100644 --- a/spconv/benchmark/__main__.py +++ b/spconv/benchmark/__main__.py @@ -1,14 +1,6 @@ -from .basic import bench_basic +from .basic import bench_basic, bench_large import fire -def bench_me_basic(dtype_str: str): - from spconv.benchmark.me import bench_me_basic - return bench_me_basic(dtype_str) - -def bench_torchsparse_basic(dtype_str: str): - from spconv.benchmark.thsp import bench_torchsparse_basic - return bench_torchsparse_basic(dtype_str) - if __name__ == "__main__": fire.Fire() diff --git a/spconv/benchmark/basic.py b/spconv/benchmark/basic.py index 52d1c47..9fb95d9 100644 --- a/spconv/benchmark/basic.py +++ b/spconv/benchmark/basic.py @@ -1,4 +1,4 @@ -from spconv.benchmark.core import get_voxel_data +from spconv.benchmark.core import get_voxel_data, get_voxel_data_large import time @@ -12,7 +12,7 @@ from cumm import dtypes import spconv.pytorch as spconv from spconv.test_utils import params_grid - +import spconv as spconv_core class Net(nn.Module): def __init__(self, shape, algo): super().__init__() @@ -150,15 +150,23 @@ def forward(self, features, coors, batch_size, enable_timer: bool = False): dtypes.float16: torch.float16, } -def bench_basic(dtype_str: str): +def bench_basic(dtype_str: str, is_large: bool = False): + assert dtype_str in ["f16", "f32", "tf32"], "only support f16, f32, tf32" + if dtype_str == "tf32": + spconv_core.constants.SPCONV_ALLOW_TF32 = True + dtype_str = "f32" + dtype = dtypes.get_dtype_by_shortcut(dtype_str) if dtype not in _DTYPE_TO_TORCH_DTYPE: raise NotImplementedError("only support bench f32 and f16 for now") torch_dtype = _DTYPE_TO_TORCH_DTYPE[dtype] algos = [spconv.ConvAlgo.Native, spconv.ConvAlgo.MaskImplicitGemm, spconv.ConvAlgo.MaskSplitImplicitGemm] - (voxels, coors, spatial_shape) = get_voxel_data() + if is_large: + (voxels, coors, spatial_shape) = get_voxel_data_large() + else: + (voxels, coors, spatial_shape) = get_voxel_data() + name = "basic-L" if is_large else "basic" device = torch.device("cuda:0") - for algo, in params_grid(algos): voxels_th = torch.from_numpy(voxels).to(device).to(torch_dtype) coors_th = torch.from_numpy(coors).to(device).int() @@ -172,23 +180,22 @@ def bench_basic(dtype_str: str): times = [] with torch.no_grad(): 
for i in range(100): - torch.cuda.synchronize() - t = time.time() - out_nograd = net(voxels_th, coors_th, 1, False) - timer = out_nograd._timer - torch.cuda.synchronize() - times.append(time.time() - t) - print(f"basic[{dtype_str}|{algo}|forward]", np.mean(times[50:])) + with tv.measure_duration() as measure: + out_nograd = net(voxels_th, coors_th, 1, False) + times.append(measure.duration) + print(f"{name}[{dtype_str}|{algo}|forward]", np.mean(times[50:])) times = [] for i in range(50): out = net(voxels_th, coors_th, 1) - torch.cuda.synchronize() - t = time.time() - out.features.backward(dout_t) - torch.cuda.synchronize() - times.append(time.time() - t) - print(f"basic[{dtype_str}|{algo}|backward]", np.mean(times[25:])) + with tv.measure_duration() as measure: + out.features.backward(dout_t) + times.append(measure.duration) + print(f"{name}[{dtype_str}|{algo}|backward]", np.mean(times[25:])) + + +def bench_large(dtype_str: str): + return bench_basic(dtype_str, True) if __name__ == "__main__": bench_basic("f16") \ No newline at end of file diff --git a/spconv/benchmark/core.py b/spconv/benchmark/core.py index 1555c36..822b89d 100644 --- a/spconv/benchmark/core.py +++ b/spconv/benchmark/core.py @@ -4,6 +4,8 @@ from io import BytesIO import numpy as np from spconv.constants import PACKAGE_ROOT +from spconv.utils import Point2VoxelCPU3d, Point2VoxelGPU3d +from cumm import tensorview as tv RAW_TEST_DATA_PATH = "https://raw.githubusercontent.com/traveller59/spconv/v2.1.10/test/data/test_spconv.pkl" RAW_PC_PATH = "https://raw.githubusercontent.com/traveller59/spconv/v2.1.10/test/data/benchmark-pc.npz" @@ -36,6 +38,27 @@ def get_pc_data(): pc = np.load(ff)["pc"] return pc +def get_voxel_data_large(): + pc = get_pc_data() + gen = Point2VoxelGPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3, + 1600000, 1) + + pcs = [pc] + for i in range(7): + pc2 = pc.copy() + pc2[:, 1] += i + 1 + pcs.append(pc2) + + pc = np.concatenate(pcs) + voxels_tv, indices_tv, _ = gen.point_to_voxel_hash(tv.from_numpy(pc).cuda()) + voxels = voxels_tv.cpu().numpy().reshape(-1, 3) + coors = indices_tv.cpu().numpy() + N = coors.shape[0] + # breakpoint() + coors = np.concatenate([np.full([N, 1], 0, coors.dtype), coors], axis=1) + return voxels, coors, gen.grid_size + + if __name__ == "__main__": pc = get_pc_data() print(pc[:10]) \ No newline at end of file diff --git a/spconv/core.py b/spconv/core.py index 6b966b2..712ea4b 100644 --- a/spconv/core.py +++ b/spconv/core.py @@ -634,7 +634,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), NDIM_DONT_CARE, @@ -648,7 +649,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 16)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), NDIM_DONT_CARE, @@ -662,7 +664,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), NDIM_DONT_CARE, @@ -676,7 +679,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 16)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, @@ -690,7 +694,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 
16)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), NDIM_DONT_CARE, @@ -704,7 +709,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64), NDIM_DONT_CARE, @@ -718,7 +724,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64), NDIM_DONT_CARE, @@ -732,7 +739,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64), NDIM_DONT_CARE, @@ -746,7 +754,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128), NDIM_DONT_CARE, @@ -760,7 +769,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), ] IMPLGEMM_TURING_PARAMS = [ @@ -777,7 +787,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 16)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 64), (32, 32, 64), NDIM_DONT_CARE, @@ -791,7 +802,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 64), NDIM_DONT_CARE, @@ -805,7 +817,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), NDIM_DONT_CARE, @@ -819,7 +832,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 16)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 64), (64, 32, 64), NDIM_DONT_CARE, @@ -833,7 +847,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 64, 32), (64, 32, 32), NDIM_DONT_CARE, @@ -847,7 +862,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 16)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 256, 64), (64, 128, 64), NDIM_DONT_CARE, @@ -861,7 +877,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (256, 128, 64), (128, 64, 64), NDIM_DONT_CARE, @@ -875,7 +892,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 128), (64, 64, 128), NDIM_DONT_CARE, @@ -889,7 +907,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, 
increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (128, 128, 64), (64, 64, 64), NDIM_DONT_CARE, @@ -903,7 +922,8 @@ class AlgoHint(Enum): TensorOp((16, 8, 32)), mask_sparse=True, increment_k_first=True, - access_per_vector=1), + access_per_vector=1, + is_nvrtc=True), *gen_conv_params(ConvFwdAndBwdInput, (32, 16, 16), (16, 16, 16), diff --git a/spconv/cppconstants.py b/spconv/cppconstants.py index 7c5ae78..2568a47 100644 --- a/spconv/cppconstants.py +++ b/spconv/cppconstants.py @@ -27,4 +27,5 @@ from spconv.core_cc.cumm.common import CompileInfo HAS_BOOST = BoxOps.has_boost() -COMPILED_CUDA_ARCHS = set(CompileInfo.get_compiled_gemm_cuda_arch()) +COMPILED_CUDA_ARCHS = set(CompileInfo.get_compiled_cuda_arch()) +COMPILED_CUDA_GEMM_ARCHS = set(CompileInfo.get_compiled_gemm_cuda_arch()) diff --git a/spconv/pytorch/ops.py b/spconv/pytorch/ops.py index f5657de..00d3b33 100644 --- a/spconv/pytorch/ops.py +++ b/spconv/pytorch/ops.py @@ -46,6 +46,7 @@ from spconv.constants import FILTER_HWIO, ALL_WEIGHT_IS_KRSC, AllocKeys, SPCONV_USE_DIRECT_TABLE from cumm.gemm import codeops from spconv.tools import CUDAKernelTimer +from spconv import constants DEBUG = False DEBUG_INT64_HASH_K = False @@ -832,7 +833,7 @@ def indice_conv(features: torch.Tensor, indice_pairs_tv, indice_pair_num_tv, arch, num_activate_out, inverse, subm, algo.value, stream, bias_tv, act_alpha, act_beta, act_type, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) out_features = alloc.allocated[AllocKeys.OutFeatures] return out_features if not features.is_cuda: @@ -1013,7 +1014,7 @@ def indice_conv(features: torch.Tensor, beta=0.0, hint=AlgoHint.Fowrard.value, stream=stream, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) # CONV.stream_synchronize(stream) # t = time.time() with timer.record("forward", stream): @@ -1105,7 +1106,7 @@ def indice_conv_backward(features: torch.Tensor, features_tv, filters_tv, out_bp_tv, indice_pairs_tv, indice_pair_num_tv, arch, inverse, subm, algo.value, - stream, use_tf32=SPCONV_ALLOW_TF32) + stream, use_tf32=constants.SPCONV_ALLOW_TF32) din = alloc.allocated[AllocKeys.DIn] df = alloc.allocated[AllocKeys.DFilters] return din, df @@ -1273,7 +1274,7 @@ def indice_conv_backward(features: torch.Tensor, beta=0.0, hint=AlgoHint.BackwardInput.value, stream=stream, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) if is_KC_not_CK: a_wgrad = out_bp_tv b_wgrad = features_tv @@ -1321,7 +1322,7 @@ def indice_conv_backward(features: torch.Tensor, beta=0.0, hint=AlgoHint.BackwardWeight.value, stream=stream, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) # print(tuned_res_wgrad.algo_desp, tuned_res_wgrad.splitk, min_time) # get workspace size for wgrad if is_KC_not_CK: @@ -1467,7 +1468,7 @@ def implicit_gemm(features: torch.Tensor, pair_mask_fwd_splits_tv, mask_argsort_fwd_splits_tv, num_activate_out, mask_tv, arch, is_train, is_subm, stream, timer_cpp, auto_fp32_accum, fp32_accum, bias_tv, act_alpha, act_beta, act_type, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) out_features = alloc.allocated[AllocKeys.OutFeatures] mask_output_fwd = alloc.allocated.get(AllocKeys.MaskOutputFwd, None) if is_train: @@ -1535,7 +1536,8 @@ def implicit_gemm(features: torch.Tensor, mask_filter=masks[0].item(), stream=stream, fp32_accum=fp32_accum, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) + mask_width = 
tune_res.algo_desp.tile_shape[0] if is_train: mask_output_fwd = torch.empty( @@ -1748,7 +1750,7 @@ def implicit_gemm_backward(features: torch.Tensor, mask_argsort_fwd_splits_tv, mask_argsort_bwd_splits_tv, mask_output_fwd_tv, mask_tv, arch, mask_width, is_subm, stream, timer_cpp, auto_fp32_accum, fp32_accum, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) din = alloc.allocated[AllocKeys.DIn] dfilters = alloc.allocated[AllocKeys.DFilters] return din, dfilters @@ -1825,7 +1827,7 @@ def implicit_gemm_backward(features: torch.Tensor, mask_filter=masks[0].item(), stream=stream, fp32_accum=fp32_accum, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) if wgrad_tune_res is None: wgrad_tune_res, _ = CONV.tune_and_cache( ConvOpType.kBackwardWeight, @@ -1844,7 +1846,7 @@ def implicit_gemm_backward(features: torch.Tensor, mask_output=tv.Tensor(), mask_width=mask_width, stream=stream, - use_tf32=SPCONV_ALLOW_TF32) + use_tf32=constants.SPCONV_ALLOW_TF32) workspace_size = CONV.query_workspace_size(wgrad_tune_res.algo_desp, wgrad_tune_res.splitk, ConvOpType.kBackwardWeight, diff --git a/test/benchmark.py b/test/benchmark.py index 32c2bc1..cacbbe9 100644 --- a/test/benchmark.py +++ b/test/benchmark.py @@ -395,7 +395,7 @@ def main(): # voxels, coors, spatial_shape = waymo_data(num_features=3) with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f: (voxels, coors, spatial_shape) = pickle.load(f) - # voxels, coors, spatial_shape = waymo_data_large_debug() + voxels, coors, spatial_shape = waymo_data_large() # breakpoint() print(spatial_shape)
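Note (not part of the diff): the arithmetic behind the new docs/COMMON_PROBLEMS.md example can be checked with the standard convolution output-size formula. The sketch below is illustrative only; `conv_out_size` is a hypothetical helper, not an spconv API.

```python
# Minimal sketch: verifies the docs/COMMON_PROBLEMS.md example arithmetic.
# conv_out_size is a hypothetical helper using the standard conv output-size formula.

def conv_out_size(size: int, ksize: int, stride: int, padding: int, dilation: int) -> int:
    # output extent along one axis for a (sparse) convolution
    return (size + 2 * padding - dilation * (ksize - 1) - 1) // stride + 1

# z axis of the example: spatial shape 8, ksize 3, stride 2, padding 0, dilation 1
out_z = conv_out_size(8, 3, 2, 0, 1)                 # 3 -> output z positions {0, 1, 2}
last_covered_z = (out_z - 1) * 2 - 0 + 1 * (3 - 1)   # 6 -> deepest input z any output reads
# the input coordinate [0, 7, 153, 142] has z == 7 > 6, so it produces no output point;
# with zero active outputs the kernel launch gets 0 blocks, hence the error message.
# changing the z padding from 0 to 1 shifts the last window to z in [5, 7], covering the point.
print(out_z, last_covered_z)
```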