From 82fd7a8b824cc8afd2788d86b9ce648b301ae476 Mon Sep 17 00:00:00 2001 From: "yan.yan" Date: Wed, 10 Nov 2021 22:27:00 +0800 Subject: [PATCH] v2.1.5: add profile tool and python 3.6 for linux --- .github/workflows/build.yaml | 2 +- .github/workflows/stale.yaml | 5 +- CHANGELOG.md | 9 + README.md | 35 +- docs/BENCHMARK.md | 48 ++ docs/PERFORMANCE_GUIDE.md | 9 +- example/mnist_sparse.py | 116 ++- example/voxel_gen.py | 68 +- format_all.sh | 6 +- pyproject.toml | 2 +- scripts/dev_subm.py | 111 ++- scripts/sort_bench.py | 22 +- setup.py | 8 +- spconv/__init__.py | 8 +- spconv/algo.py | 17 +- spconv/build.py | 17 +- spconv/constants.py | 12 +- spconv/core.py | 684 +++++++++++++----- spconv/core_cc/__init__.pyi | 14 - spconv/core_cc/csrc/__init__.pyi | 14 - spconv/core_cc/csrc/sparse/__init__.pyi | 14 - spconv/core_cc/csrc/sparse/all/ops1d.pyi | 23 +- .../csrc/sparse/all/ops1d/__init__.pyi | 49 -- .../core_cc/csrc/sparse/all/ops1d/p2v_c.pyi | 11 - spconv/core_cc/csrc/sparse/all/ops2d.pyi | 23 +- .../csrc/sparse/all/ops2d/__init__.pyi | 49 -- .../core_cc/csrc/sparse/all/ops2d/p2v_c.pyi | 11 - spconv/core_cc/csrc/sparse/all/ops3d.pyi | 23 +- .../csrc/sparse/all/ops3d/__init__.pyi | 49 -- .../core_cc/csrc/sparse/all/ops3d/p2v_c.pyi | 11 - spconv/core_cc/csrc/sparse/all/ops4d.pyi | 23 +- .../csrc/sparse/all/ops4d/__init__.pyi | 49 -- .../core_cc/csrc/sparse/all/ops4d/p2v_c.pyi | 11 - spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi | 11 +- .../csrc/sparse/all/ops_cpu1d/__init__.pyi | 74 -- .../csrc/sparse/all/ops_cpu1d/p2v_c.pyi | 11 - spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi | 11 +- .../csrc/sparse/all/ops_cpu2d/__init__.pyi | 74 -- .../csrc/sparse/all/ops_cpu2d/p2v_c.pyi | 11 - spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi | 11 +- .../csrc/sparse/all/ops_cpu3d/__init__.pyi | 74 -- .../csrc/sparse/all/ops_cpu3d/p2v_c.pyi | 11 - spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi | 11 +- .../csrc/sparse/all/ops_cpu4d/__init__.pyi | 74 -- .../csrc/sparse/all/ops_cpu4d/p2v_c.pyi | 11 - spconv/core_cc/cumm/__init__.pyi | 14 - spconv/core_cc/cumm/conv/main.pyi | 5 +- spconv/core_cc/cumm/gemm/__init__.pyi | 14 - spconv/core_cc/cumm/gemm/main.pyi | 9 +- spconv/core_cc/cumm/tools/__init__.pyi | 0 spconv/core_cc/cumm/tools/cuda.pyi | 56 ++ spconv/cppconstants.py | 8 +- spconv/csrc/__init__.py | 7 +- spconv/csrc/sparse/__init__.py | 7 +- spconv/csrc/sparse/all.py | 92 ++- spconv/csrc/sparse/cpu_core.py | 29 + spconv/csrc/sparse/devleop/sort_bench.py | 7 +- spconv/csrc/sparse/gather.py | 45 +- spconv/csrc/sparse/indices.py | 278 +++---- spconv/csrc/sparse/maxpool.py | 87 ++- spconv/csrc/sparse/pointops.py | 86 ++- spconv/pytorch/__init__.py | 11 +- spconv/pytorch/constants.py | 10 +- spconv/pytorch/conv.py | 285 ++++---- spconv/pytorch/core.py | 20 +- spconv/pytorch/cppcore.py | 22 +- spconv/pytorch/functional.py | 138 ++-- spconv/pytorch/modules.py | 13 +- spconv/pytorch/ops.py | 446 ++++++------ spconv/pytorch/pool.py | 151 ++-- spconv/pytorch/spatial.py | 6 +- spconv/pytorch/tables.py | 26 +- spconv/pytorch/utils.py | 30 +- spconv/test_utils.py | 6 +- spconv/tools.py | 78 ++ spconv/utils/__init__.py | 35 +- test/aaa.py | 112 --- test/benchmark.py | 98 ++- test/test_conv.py | 63 +- version.txt | 2 +- 80 files changed, 2254 insertions(+), 1979 deletions(-) create mode 100644 docs/BENCHMARK.md delete mode 100644 spconv/core_cc/csrc/sparse/all/ops1d/__init__.pyi delete mode 100644 spconv/core_cc/csrc/sparse/all/ops1d/p2v_c.pyi delete mode 100644 spconv/core_cc/csrc/sparse/all/ops2d/__init__.pyi delete mode 100644 
spconv/core_cc/csrc/sparse/all/ops2d/p2v_c.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops3d/__init__.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops3d/p2v_c.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops4d/__init__.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops4d/p2v_c.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu1d/__init__.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu1d/p2v_c.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu2d/__init__.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu2d/p2v_c.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu3d/__init__.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu3d/p2v_c.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu4d/__init__.pyi
 delete mode 100644 spconv/core_cc/csrc/sparse/all/ops_cpu4d/p2v_c.pyi
 create mode 100644 spconv/core_cc/cumm/tools/__init__.pyi
 create mode 100644 spconv/core_cc/cumm/tools/cuda.pyi
 create mode 100644 spconv/csrc/sparse/cpu_core.py
 create mode 100644 spconv/tools.py
 delete mode 100644 test/aaa.py
diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 340c587..2e05f2b 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -89,7 +89,7 @@ jobs:
     runs-on: ubuntu-20.04
     strategy:
       matrix:
-        python-version: ['3.7', '3.8', '3.9', '3.10'] # this version is only used for upload.
+        python-version: ['3.6', '3.7', '3.8', '3.9', '3.10'] # this version is only used for upload.
         cuda-version: ['102', '111', '113', '114', '']
     steps:
diff --git a/.github/workflows/stale.yaml b/.github/workflows/stale.yaml
index ba5051a..1ae3445 100644
--- a/.github/workflows/stale.yaml
+++ b/.github/workflows/stale.yaml
@@ -14,5 +14,6 @@ jobs:
     steps:
       - uses: actions/stale@v4
         with:
-          stale-issue-message: 'Close stale issues due to inactivity.'
-          stale-pr-message: 'Close stale PRs due to inactivity.'
+          stale-issue-message: 'Mark stale issues due to inactivity.'
+          stale-pr-message: 'Mark stale PRs due to inactivity.'
+          operations-per-run: 300
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e825ddd..7912f04 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,14 @@
 # Changelog

+## [2.1.5] - 2021-11-10
+### Added
+- Add CUDA profile tool
+- Add Python 3.6 support
+### Changed
+- Format all code
+### Removed
+- Remove an unnecessary device sync; this slightly improves performance.
+
 ## [2.1.0] - 2021-10-31
 ### Added
 * add implicit gemm algorithm for all kinds of convolution with kernel volume <= 32. This algorithm is very fast with float16.
diff --git a/README.md b/README.md
index 4c55a84..4e91bab 100644
--- a/README.md
+++ b/README.md
@@ -13,16 +13,36 @@
 See the License for the specific language governing permissions and
 limitations under the License.
-->
-
-[pypi-download]: https://img.shields.io/pypi/dm/spconv-cu114
-[pypi-url]: https://pypi.org/project/spconv-cu114/
-[pypi-image]: https://badge.fury.io/py/spconv-cu114.svg
+[pypi-ver-cpu]: https://img.shields.io/pypi/v/spconv
+[pypi-ver-114]: https://img.shields.io/pypi/v/spconv-cu114
+[pypi-ver-111]: https://img.shields.io/pypi/v/spconv-cu111
+[pypi-ver-113]: https://img.shields.io/pypi/v/spconv-cu113
+[pypi-ver-102]: https://img.shields.io/pypi/v/spconv-cu102
+
+[pypi-url-111]: https://pypi.org/project/spconv-cu111/
+[pypi-download-111]: https://img.shields.io/pypi/dm/spconv-cu111
+[pypi-url-113]: https://pypi.org/project/spconv-cu113/
+[pypi-download-113]: https://img.shields.io/pypi/dm/spconv-cu113
+[pypi-url-102]: https://pypi.org/project/spconv-cu102/
+[pypi-download-102]: https://img.shields.io/pypi/dm/spconv-cu102
+[pypi-url-114]: https://pypi.org/project/spconv-cu114/
+[pypi-download-114]: https://img.shields.io/pypi/dm/spconv-cu114
+[pypi-url-cpu]: https://pypi.org/project/spconv/
+[pypi-download-cpu]: https://img.shields.io/pypi/dm/spconv

 # SpConv: Spatially Sparse Convolution Library

-[![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild) [![PyPI Version][pypi-image]][pypi-url] [![pypi monthly download][pypi-download]][pypi-url]
+[![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild)
+
+| | PyPI Version | Downloads |
+| -------------- |:---------------------:| ---------------------:|
+| CPU (Linux Only) | [![PyPI Version][pypi-ver-cpu]][pypi-url-cpu] | [![pypi monthly download][pypi-download-cpu]][pypi-url-cpu] |
+| CUDA 10.2 | [![PyPI Version][pypi-ver-102]][pypi-url-102] | [![pypi monthly download][pypi-download-102]][pypi-url-102] |
+| CUDA 11.1 | [![PyPI Version][pypi-ver-111]][pypi-url-111] | [![pypi monthly download][pypi-download-111]][pypi-url-111]|
+| CUDA 11.3 (Linux Only) | [![PyPI Version][pypi-ver-113]][pypi-url-113] |[![pypi monthly download][pypi-download-113]][pypi-url-113]|
+| CUDA 11.4 | [![PyPI Version][pypi-ver-114]][pypi-url-114] | [![pypi monthly download][pypi-download-114]][pypi-url-114]|

-```spconv``` is a project that provide heavily-optimized sparse convolution implementation with tensor core support.
+```spconv``` is a project that provides a heavily-optimized sparse convolution implementation with Tensor Core support. Check the [benchmark](docs/BENCHMARK.md) to see how fast spconv 2.x runs.

 [Spconv 1.x code](https://github.com/traveller59/spconv/tree/v1.2.1). We won't provide any support for spconv 1.x since it's deprecated. Use spconv 2.x if possible.
@@ -99,7 +119,10 @@ The c++ code will be built automatically when you change c++ code in project.

 For NVIDIA Embedded Platforms, you need to specify cuda arch before build: ```export CUMM_CUDA_ARCH_LIST="7.2"``` for xavier.

+You need to remove ```cumm``` from the ```requires``` section in pyproject.toml after installing editable ```cumm``` and before installing spconv, due to a pyproject limitation (it can't find an editable-installed ```cumm```).
+
 #### Linux
+
 0. uninstall spconv and cumm installed by pip
 1. install build-essential, install CUDA
 2. ```git clone https://github.com/FindDefinition/cumm```, ```cd ./cumm```, ```pip install -e .```
diff --git a/docs/BENCHMARK.md b/docs/BENCHMARK.md
new file mode 100644
index 0000000..5bd53f9
--- /dev/null
+++ b/docs/BENCHMARK.md
@@ -0,0 +1,48 @@
+
+## Simple Benchmark
+
+### Network Benchmark without batchnorm (F32/F16) on an RTX 3080 Laptop GPU
+
+Network Code: test/benchmark.py
+
+| F32/F16 | Spconv 1.x F32 (1080Ti) | Native| Implicit Gemm | Implicit Gemm Split Mask |
+| -------------- |:---------------------:|---------------------:|---------------------:| ---------------------:|
+| Forward | 43ms | 21.7ms/13.7ms | 23.5ms/11.2ms | 22ms/12.2ms |
+| Backward | 80ms | 41.9ms/25.2ms | 51.0ms/13.8ms | 41.1ms/12.2ms |
+
+### Network Gemm Kernel Benchmark FP16 on an RTX 3080 Laptop GPU
+
+Network Code: test/benchmark.py
+
+The network/input/profile code is the same as in the table above.
+
+This table only profiles **fp16 gemm kernels**, without output tensor create/clear overhead, so it shows the performance upper bound of our algorithm.
+
+| F16 | Native| Implicit Gemm | Implicit Gemm Split Mask |
+| -------------- |:---------------------:|---------------------:| ---------------------:|
+| Forward | 8.0ms | 4.3ms | 4.0ms |
+
+We can see that implicit gemm is very fast: the gemm takes only 4.3ms of the 11.2ms network forward time. We can achieve better performance with TensorRT + pure C++.
+
+**NOTE**
+When you benchmark a network on your laptop, don't forget to close all apps except terminals! Other apps consume GPU resources and make kernels run slower.
+
+
+## Comparison with [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) and [torchsparse](https://github.com/mit-han-lab/torchsparse)
+
+TODO
\ No newline at end of file
diff --git a/docs/PERFORMANCE_GUIDE.md b/docs/PERFORMANCE_GUIDE.md
index a42edb9..cdb0e41 100644
--- a/docs/PERFORMANCE_GUIDE.md
+++ b/docs/PERFORMANCE_GUIDE.md
@@ -25,12 +25,7 @@
 * make sure your channel size is multiple of 8 when using fp16. multiple of 32 is better.
 * spconv 2.x in Windows 10 is 1.5x~2x slower than Linux. use Linux if possible.

-Network Benchmark without batchnorm (F32/F16) in RTX 3080 Laptop GPU
-
-| F32/F16 | Spconv 1.x | Native| Implicit Gemm | Implicit Gemm Split Mask |
-| -------------- |:---------------------:|---------------------:|---------------------:| ---------------------:|
-| Forward | 43ms | 29ms/23ms | 30ms/15ms | 30ms/19ms |
-| Backward | 80ms | 47ms/32ms | 56ms/15ms | 45ms/14ms |
+See [benchmark](BENCHMARK.md) for more performance details of different algorithms.

 ## Algorithm Overview
@@ -57,4 +52,4 @@ In my test, ```Implicit Gemm``` is almost 2x faster than ```Native```.

 TODO

-In my test, ```Implicit Gemm Split Mask``` is slightly faster than ```Implicit Gemm```, but the indice generation is greatly slower, so currently we use ```Implicit Gemm``` by default.
\ No newline at end of file
+In my test, ```Implicit Gemm Split Mask``` is slightly faster than ```Implicit Gemm```, but the indice generation is slower, so currently we use ```Implicit Gemm``` by default.
\ No newline at end of file
diff --git a/example/mnist_sparse.py b/example/mnist_sparse.py
index 5c9a20b..21c2593 100644
--- a/example/mnist_sparse.py
+++ b/example/mnist_sparse.py
@@ -1,11 +1,11 @@
 # Copyright 2021 Yan Yan
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -22,11 +22,12 @@ from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR import contextlib -import torch.cuda.amp +import torch.cuda.amp + @contextlib.contextmanager def identity_ctx(): - yield + yield class Net(nn.Module): @@ -39,14 +40,13 @@ def __init__(self): spconv.SubMConv2d(32, 64, 3, 1), nn.ReLU(), spconv.SparseMaxPool2d(2, 2), - spconv.ToDense(), + spconv.ToDense(), ) self.fc1 = nn.Linear(14 * 14 * 64, 128) self.fc2 = nn.Linear(128, 10) self.dropout1 = nn.Dropout2d(0.25) self.dropout2 = nn.Dropout2d(0.5) - def forward(self, x: torch.Tensor): # x: [N, 28, 28, 1], must be NHWC tensor x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1)) @@ -116,40 +116,72 @@ def test(args, model, device, test_loader): with amp_ctx: output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + test_loss += F.nll_loss( + output, target, reduction='sum').item() # sum up batch loss + pred = output.argmax( + dim=1, + keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. 
* correct / len(test_loader.dataset))) def main(): # Training settings parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', + parser.add_argument('--batch-size', + type=int, + default=64, + metavar='N', help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + parser.add_argument('--test-batch-size', + type=int, + default=1000, + metavar='N', help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=14, metavar='N', + parser.add_argument('--epochs', + type=int, + default=14, + metavar='N', help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + parser.add_argument('--lr', + type=float, + default=1.0, + metavar='LR', help='learning rate (default: 1.0)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + parser.add_argument('--gamma', + type=float, + default=0.7, + metavar='M', help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, + parser.add_argument('--no-cuda', + action='store_true', + default=False, help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', + parser.add_argument('--seed', + type=int, + default=1, + metavar='S', help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - - parser.add_argument('--save-model', action='store_true', default=False, + parser.add_argument( + '--log-interval', + type=int, + default=10, + metavar='N', + help='how many batches to wait before logging training status') + + parser.add_argument('--save-model', + action='store_true', + default=False, help='For Saving the current Model') - parser.add_argument('--fp16', action='store_true', default=False, + parser.add_argument('--fp16', + action='store_true', + default=False, help='For mixed precision training') args = parser.parse_args() @@ -161,20 +193,30 @@ def main(): kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} train_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - # here we remove norm to get sparse tensor with lots of zeros - # transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, **kwargs) + datasets.MNIST( + '../data', + train=True, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + # here we remove norm to get sparse tensor with lots of zeros + # transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.batch_size, + shuffle=True, + **kwargs) test_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), - # here we remove norm to get sparse tensor with lots of zeros - # transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.test_batch_size, shuffle=True, **kwargs) + datasets.MNIST( + '../data', + train=False, + transform=transforms.Compose([ + transforms.ToTensor(), + # here we remove norm to get sparse tensor with lots of zeros + # transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.test_batch_size, + shuffle=True, + **kwargs) model = Net().to(device) optimizer = 
optim.Adadelta(model.parameters(), lr=args.lr) diff --git a/example/voxel_gen.py b/example/voxel_gen.py index c4c85c7..25fe672 100644 --- a/example/voxel_gen.py +++ b/example/voxel_gen.py @@ -1,32 +1,32 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +import numpy as np -from cumm import tensorview as tv +from cumm import tensorview as tv from spconv.utils import Point2VoxelCPU3d from spconv.pytorch.utils import PointToVoxel -import torch +import torch + def main(): # voxel gen source code: spconv/csrc/sparse/pointops.py - gen = Point2VoxelCPU3d( - vsize_xyz=[0.1, 0.1, 0.1], - coors_range_xyz=[-80, -80, -2, 80, 80, 6], - num_point_features=3, - max_num_voxels=5000, - max_num_points_per_voxel=5) + gen = Point2VoxelCPU3d(vsize_xyz=[0.1, 0.1, 0.1], + coors_range_xyz=[-80, -80, -2, 80, 80, 6], + num_point_features=3, + max_num_voxels=5000, + max_num_points_per_voxel=5) pc = np.random.uniform(-10, 10, size=[1000, 3]) pc_tv = tv.from_numpy(pc) @@ -39,20 +39,23 @@ def main(): print("------Raw Voxels-------") print(voxels_np[0]) # run voxel gen and FILL MEAN VALUE to voxel remain - voxels_tv, indices_tv, num_p_in_vx_tv = gen.point_to_voxel_empty_mean(pc_tv) + voxels_tv, indices_tv, num_p_in_vx_tv = gen.point_to_voxel_empty_mean( + pc_tv) voxels_np = voxels_tv.numpy_view() indices_np = indices_tv.numpy_view() num_p_in_vx_np = num_p_in_vx_tv.numpy_view() print("------Voxels with mean filled-------") print(voxels_np[0]) + def main_point_with_features(): # voxel gen source code: spconv/csrc/sparse/pointops.py gen = Point2VoxelCPU3d( - vsize_xyz=[0.1, 0.1, 0.1], - coors_range_xyz=[-80, -80, -2, 80, 80, 6], - num_point_features=4, # here num_point_features must equal to pc.shape[1] - max_num_voxels=5000, + vsize_xyz=[0.1, 0.1, 0.1], + coors_range_xyz=[-80, -80, -2, 80, 80, 6], + num_point_features= + 4, # here num_point_features must equal to pc.shape[1] + max_num_voxels=5000, max_num_points_per_voxel=5) pc = np.random.uniform(-10, 10, size=[1000, 3]) @@ -68,21 +71,22 @@ def main_point_with_features(): print("------Raw Voxels-------") print(voxels_np[0]) # run voxel gen and FILL MEAN VALUE to voxel remain - voxels_tv, indices_tv, num_p_in_vx_tv = gen.point_to_voxel_empty_mean(pc_tv) + voxels_tv, indices_tv, num_p_in_vx_tv = gen.point_to_voxel_empty_mean( + pc_tv) voxels_np = voxels_tv.numpy_view() indices_np = indices_tv.numpy_view() num_p_in_vx_np = num_p_in_vx_tv.numpy_view() print("------Voxels with mean filled-------") print(voxels_np[0]) + def main_pytorch_voxel_gen(): # voxel gen source code: spconv/csrc/sparse/pointops.py - gen = PointToVoxel( - vsize_xyz=[0.1, 0.1, 0.1], - coors_range_xyz=[-80, -80, -2, 80, 80, 6], - num_point_features=3, - max_num_voxels=5000, - max_num_points_per_voxel=5) + gen = PointToVoxel(vsize_xyz=[0.1, 0.1, 0.1], + coors_range_xyz=[-80, -80, -2, 80, 80, 6], + num_point_features=3, + max_num_voxels=5000, + max_num_points_per_voxel=5) pc = np.random.uniform(-10, 10, size=[1000, 3]) pc_th = torch.from_numpy(pc) @@ -100,16 +104,16 
@@ def main_pytorch_voxel_gen(): print("------Voxels with mean filled-------") print(voxels_np[0]) + def main_pytorch_voxel_gen_cuda(): # voxel gen source code: spconv/csrc/sparse/pointops.py device = torch.device("cuda:0") - gen = PointToVoxel( - vsize_xyz=[0.1, 0.1, 0.1], - coors_range_xyz=[-80, -80, -2, 80, 80, 6], - num_point_features=3, - max_num_voxels=5000, - max_num_points_per_voxel=5, - device=device) + gen = PointToVoxel(vsize_xyz=[0.1, 0.1, 0.1], + coors_range_xyz=[-80, -80, -2, 80, 80, 6], + num_point_features=3, + max_num_voxels=5000, + max_num_points_per_voxel=5, + device=device) pc = np.random.uniform(-10, 10, size=[1000, 3]).astype(np.float32) pc_th = torch.from_numpy(pc).to(device) @@ -133,4 +137,4 @@ def main_pytorch_voxel_gen_cuda(): main_point_with_features() main_pytorch_voxel_gen() if torch.cuda.is_available(): - main_pytorch_voxel_gen_cuda() \ No newline at end of file + main_pytorch_voxel_gen_cuda() diff --git a/format_all.sh b/format_all.sh index 480536f..7a5935e 100644 --- a/format_all.sh +++ b/format_all.sh @@ -1,5 +1 @@ -isort -rc --atomic ./spconv && \ -isort -rc --atomic ./test && \ -yapf -i --recursive -vv ./spconv ./test -find ./src -regex '.*\.\(cpp\|hpp\|cc\|cxx\|cu\|cuh\|h\)' | xargs clang-format -i -find ./include -regex '.*\.\(cpp\|hpp\|cc\|cxx\|cu\|cuh\|h\)' | xargs clang-format -i \ No newline at end of file +yapf -i --recursive -vv ./spconv ./test ./example ./scripts diff --git a/pyproject.toml b/pyproject.toml index 5cb3e1c..8c27c2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools>=41.0", "wheel", "pccm>=0.2.21", "cumm>=0.2.1"] +requires = ["setuptools>=41.0", "wheel", "pccm>=0.2.21", "cumm>=0.2.2"] build-backend = "setuptools.build_meta" diff --git a/scripts/dev_subm.py b/scripts/dev_subm.py index 4cb69df..9d98313 100644 --- a/scripts/dev_subm.py +++ b/scripts/dev_subm.py @@ -19,20 +19,21 @@ from cumm.conv.main import ConvMainUnitTest, gen_gemm_kernels from cumm.conv.params import ConvProblem from cumm.gemm import kernel -import os +import os from spconv.core_cc.csrc.sparse.all import SpconvOps from cumm.gemm.codeops import div_up from spconv.constants import PACKAGE_ROOT from spconv.core import ConvAlgo -from spconv.pytorch import ops +from spconv.pytorch import ops from spconv.algo import CONV, BestConvAlgoByProfile from spconv.pytorch.cppcore import torch_tensor_to_tv + def reduce_mask_count(mask: np.ndarray, width: int): mask_length_32 = (div_up(mask.shape[0], width)) * width if mask.shape[0] < mask_length_32: - mask_pad = np.zeros((mask_length_32,), dtype=mask.dtype) + mask_pad = np.zeros((mask_length_32, ), dtype=mask.dtype) mask_pad[:mask.shape[0]] = mask mask = mask_pad mask = mask.reshape(-1, width) @@ -40,16 +41,18 @@ def reduce_mask_count(mask: np.ndarray, width: int): maskr_tv = tv.from_numpy(maskr) return SpconvOps.count_bits(maskr_tv).numpy().sum() * width + def reduce_mask_count_x(mask: np.ndarray, width: int): mask_length_32 = (div_up(mask.shape[0], width)) * width if mask.shape[0] < mask_length_32: - mask_pad = np.zeros((mask_length_32,), dtype=mask.dtype) + mask_pad = np.zeros((mask_length_32, ), dtype=mask.dtype) mask_pad[:mask.shape[0]] = mask mask = mask_pad mask = mask.reshape(-1, width) maskr = np.bitwise_or.reduce(mask, axis=1) return maskr + def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): limit_input_n = 16384 limit_input_n = None @@ -88,8 +91,9 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): stride = [1] * ndim dilation = [1] * 
ndim out_padding = [0] * ndim - out_inds, pair_ref, indice_num_per_loc = ops.get_indice_pairs(indices_th, 1, spatial_shape, ConvAlgo.Native, - ksize, stride, padding, dilation, out_padding, subm) + out_inds, pair_ref, indice_num_per_loc = ops.get_indice_pairs( + indices_th, 1, spatial_shape, ConvAlgo.Native, ksize, stride, padding, + dilation, out_padding, subm) indice_num_per_loc_np = indice_num_per_loc.cpu().numpy() indice_pairs_np = pair_ref.cpu().numpy() algo = ConvAlgo.MaskSplitImplicitGemm @@ -98,8 +102,9 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): else: num_split = 2 for i in range(5): - res = ops.get_indice_pairs_implicit_gemm(indices_th, 1, spatial_shape, algo, - ksize, stride, padding, dilation, out_padding, subm) + res = ops.get_indice_pairs_implicit_gemm(indices_th, 1, spatial_shape, + algo, ksize, stride, padding, + dilation, out_padding, subm) out_inds = res[0] num_inds_per_loc = res[1] pair_fwd = res[2] @@ -115,23 +120,38 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): mask_argsort_fwd_splits = res[6] mask_argsort_bwd_splits = res[7] masks = res[8] - pair_mask_fwd_splits_tv = [ops.torch_tensor_to_tv(t, dtype=tv.uint32) for t in pair_mask_fwd_splits] - valid_location_bitcount = [SpconvOps.count_bits(t) for t in pair_mask_fwd_splits_tv] - valid_location_count = sum([t.cpu().numpy().sum() for t in valid_location_bitcount]) + pair_mask_fwd_splits_tv = [ + ops.torch_tensor_to_tv(t, dtype=tv.uint32) + for t in pair_mask_fwd_splits + ] + valid_location_bitcount = [ + SpconvOps.count_bits(t) for t in pair_mask_fwd_splits_tv + ] + valid_location_count = sum( + [t.cpu().numpy().sum() for t in valid_location_bitcount]) reduce_length = 32 - split_mask_valid_count = sum([reduce_mask_count(t.cpu().numpy(), reduce_length) for t in pair_mask_fwd_splits_tv]) + split_mask_valid_count = sum([ + reduce_mask_count(t.cpu().numpy(), reduce_length) + for t in pair_mask_fwd_splits_tv + ]) if subm: - print("SUBM", valid_location_count, split_mask_valid_count, pair_fwd.numel()) + print("SUBM", valid_location_count, split_mask_valid_count, + pair_fwd.numel()) else: - print("REGULAR", valid_location_count, split_mask_valid_count, pair_fwd.numel()) - # return + print("REGULAR", valid_location_count, split_mask_valid_count, + pair_fwd.numel()) + # return if run_conv: C = 64 K = 64 desps = CONV.desps - mask_output_fwd = torch.zeros([2, div_up(out_inds.shape[0], 32)], dtype=torch.int32, device=indices_th.device) - mask_output_bwd = torch.zeros([2, div_up(indices.dim(0), 32)], dtype=torch.int32, device=indices_th.device) + mask_output_fwd = torch.zeros([2, div_up(out_inds.shape[0], 32)], + dtype=torch.int32, + device=indices_th.device) + mask_output_bwd = torch.zeros([2, div_up(indices.dim(0), 32)], + dtype=torch.int32, + device=indices_th.device) for desp in desps: if desp.algo != GemmAlgo.Simt.value: @@ -140,17 +160,22 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): # continue # if desp.tile_shape ! 
if desp.dtype_a == dtypes.int8.tv_dtype: - inp = np.random.randint(-1, 1, size=[voxels_np.shape[0], C]).astype(np.int8) - weight = np.random.randint(-1, 1, size=[K, *ksize, C]).astype(np.int8) - output = np.random.randint(-1, 1, size=[out_inds.shape[0], K]).astype( - dtypes.get_npdtype_from_tvdtype(desp.dtype_output)) + inp = np.random.randint(-1, 1, size=[voxels_np.shape[0], + C]).astype(np.int8) + weight = np.random.randint(-1, 1, size=[K, *ksize, + C]).astype(np.int8) + output = np.random.randint(-1, 1, size=[ + out_inds.shape[0], K + ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output)) else: - inp = np.random.uniform(-1, 1, size=[voxels_np.shape[0], C]).astype( - dtypes.get_npdtype_from_tvdtype(desp.dtype_input)) + inp = np.random.uniform(-1, 1, size=[ + voxels_np.shape[0], C + ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_input)) weight = np.random.uniform(-1, 1, size=[K, *ksize, C]).astype( dtypes.get_npdtype_from_tvdtype(desp.dtype_weight)) - output = np.random.uniform(-1, 1, size=[out_inds.shape[0], K]).astype( - dtypes.get_npdtype_from_tvdtype(desp.dtype_output)) + output = np.random.uniform(-1, 1, size=[ + out_inds.shape[0], K + ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output)) weight_ref = weight.transpose(1, 2, 3, 0, 4) weight_ref = np.ascontiguousarray(weight_ref).reshape(-1, K, C) if desp.op_type == ConvOpType.kBackwardInput.value: @@ -211,19 +236,19 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): ) else: if desp.op_type == ConvOpType.kForward.value: - indice_pairs = pair_fwd # inp -> out + indice_pairs = pair_fwd # inp -> out mask_ops = pair_mask_fwd_splits mask_argsorts = mask_argsort_fwd_splits mask_output = mask_output_fwd elif desp.op_type == ConvOpType.kBackwardInput.value: - indice_pairs = pair_bwd # out -> inp + indice_pairs = pair_bwd # out -> inp mask_ops = pair_mask_bwd_splits mask_argsorts = mask_argsort_bwd_splits mask_output = mask_output_bwd print([bin(x.item()) for x in masks]) else: - indice_pairs = pair_fwd # inp -> out + indice_pairs = pair_fwd # inp -> out mask_ops = pair_mask_fwd_splits mask_argsorts = mask_argsort_fwd_splits mask_output = mask_output_fwd @@ -255,7 +280,7 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): ) torch.cuda.synchronize() - duration = time.time() - t + duration = time.time() - t if desp.op_type == ConvOpType.kForward.value: output_ref = np.zeros_like(output, dtype=np.float32) # ref algorithm @@ -270,7 +295,9 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): c_inds = indice_pairs_np[1][filter_offset][:nhot] # print(a_inds_cpu[:10]) a = inp[a_inds] - cc = a.astype(np.float32) @ weight_ref[filter_offset].T.astype(np.float32) + cc = a.astype( + np.float32) @ weight_ref[filter_offset].T.astype( + np.float32) output_ref[c_inds] += cc output_cpu = output_tv.cpu().numpy().astype(np.float32) @@ -294,12 +321,18 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): # print(a_inds_cpu[:10]) a = output[a_inds] # NK @ KC - cc = a.astype(np.float32) @ weight_ref[filter_offset].astype(np.float32) + cc = a.astype( + np.float32) @ weight_ref[filter_offset].astype( + np.float32) dinput_ref[c_inds] += cc din_cpu = inp_tv.cpu().numpy() - print("ERROR", np.linalg.norm(din_cpu.reshape(-1) - dinput_ref.reshape(-1))) + print( + "ERROR", + np.linalg.norm( + din_cpu.reshape(-1) - dinput_ref.reshape(-1))) else: - dw_ref = np.zeros_like(weight_ref, dtype=np.float32) # KV, K, C + dw_ref = np.zeros_like(weight_ref, + dtype=np.float32) # KV, K, C for filter_offset 
in range(kv): if subm and filter_offset > kv // 2: nhot = indice_num_per_loc_np[kv - 1 - filter_offset] @@ -310,16 +343,20 @@ def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True): o_inds = indice_pairs_np[1][filter_offset][:nhot] i_inds = indice_pairs_np[0][filter_offset][:nhot] # print(a_inds_cpu[:10]) - out_gather = output[o_inds] # [N, K] - inp_gather = inp[i_inds] # [N, C] + out_gather = output[o_inds] # [N, K] + inp_gather = inp[i_inds] # [N, C] # KN @ NC - dw_res = out_gather.astype(np.float32).T @ inp_gather.astype(np.float32) + dw_res = out_gather.astype( + np.float32).T @ inp_gather.astype(np.float32) dw_ref[filter_offset] = dw_res # print(indice_pairs_np_test[0]) dw_ref_kcrs = dw_ref.transpose(1, 0, 2) dw_cpu = weight_tv.cpu().numpy().reshape(K, np.prod(ksize), C) - print("ERROR", np.linalg.norm(dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1))) + print( + "ERROR", + np.linalg.norm( + dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1))) if __name__ == "__main__": diff --git a/scripts/sort_bench.py b/scripts/sort_bench.py index 0c51dbd..fae436b 100644 --- a/scripts/sort_bench.py +++ b/scripts/sort_bench.py @@ -1,31 +1,32 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -from cumm import tensorview as tv +import numpy as np +from cumm import tensorview as tv from spconv.core_cc.csrc.sparse.all import SpconvOps -import pickle +import pickle import torch -from spconv.pytorch.cppcore import torch_tensor_to_tv +from spconv.pytorch.cppcore import torch_tensor_to_tv + def main(): with open("/home/yy/asd.pkl", "rb") as f: a_th = pickle.load(f) mask_argsort = torch.empty((1, a_th.shape[1]), - dtype=torch.int32, - device=a_th.device) + dtype=torch.int32, + device=a_th.device) a = a_th.cpu().numpy()[0] a_tv = torch_tensor_to_tv(a_th) @@ -34,5 +35,6 @@ def main(): a_tv_1 = a_tv.clone() SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0]) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/setup.py b/setup.py index c6075df..fcbaef8 100644 --- a/setup.py +++ b/setup.py @@ -38,9 +38,9 @@ cuda_ver = cuda_ver.replace(".", "") # 10.2 to 102 RELEASE_NAME += "-cu{}".format(cuda_ver) - deps = ["cumm-cu{}".format(cuda_ver)] + deps = ["cumm-cu{}>=0.2.2".format(cuda_ver)] else: - deps = ["cumm"] + deps = ["cumm>=0.2.2"] @@ -48,11 +48,11 @@ URL = 'https://github.com/traveller59/spconv' EMAIL = 'yanyan.sub@outlook.com' AUTHOR = 'Yan Yan' -REQUIRES_PYTHON = '>=3.7' +REQUIRES_PYTHON = '>=3.6' VERSION = None # What packages are required for this module to be executed? -REQUIRED = ["pccm>=0.2.19", "pybind11>=2.6.0", "fire", "numpy", *deps] +REQUIRED = ["pccm>=0.2.21", "pybind11>=2.6.0", "fire", "numpy", *deps] # What packages are optional? 
EXTRAS = { diff --git a/spconv/__init__.py b/spconv/__init__.py index 569255e..f7c47b0 100644 --- a/spconv/__init__.py +++ b/spconv/__init__.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,4 +16,4 @@ from .core import ConvAlgo, AlgoHint from . import constants -from .__version__ import __version__ \ No newline at end of file +from .__version__ import __version__ diff --git a/spconv/algo.py b/spconv/algo.py index bc316b1..ed34973 100644 --- a/spconv/algo.py +++ b/spconv/algo.py @@ -24,9 +24,10 @@ from typing import Optional import time from threading import Lock -import torch +import contextlib import numpy as np from spconv.core import ConvAlgo, AlgoHint +from spconv.tools import CUDAKernelTimer ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp() ALL_CONV_ALGO_DESPS = ConvMainUnitTest.get_all_conv_algo_desp() @@ -403,7 +404,8 @@ def run_with_tuned_result( alpha: float = 1.0, beta: float = 0.0, gather_data: tv.Tensor = tv.Tensor(), - workspace: tv.Tensor = tv.Tensor()): + workspace: tv.Tensor = tv.Tensor(), + timer: CUDAKernelTimer = CUDAKernelTimer(False)): m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a, trans_b, trans_c, shuffle_type.value, @@ -446,6 +448,9 @@ def run_with_tuned_result( # stream=stream) # GemmMainUnitTest.stream_synchronize(stream) # gather = time.time() - tt + if timer.enable: + assert timer._timer is not None + params.timer = timer._timer GemmMainUnitTest.matmul2(params) # GemmMainUnitTest.stream_synchronize(stream) @@ -678,7 +683,8 @@ def run_with_tuned_result(self, beta: float = 0.0, stream: int = 0, workspace: tv.Tensor = tv.Tensor(), - verbose: bool = False): + verbose: bool = False, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): channel_k = output.dim(1) channel_c = inp.dim(1) # GemmMainUnitTest.stream_synchronize(stream) @@ -709,9 +715,11 @@ def run_with_tuned_result(self, params.mask_filter = mask_filter params.mask_output = mask_output params.reverse_mask = reverse_mask + if timer.enable: + assert timer._timer is not None + params.timer = timer._timer # torch.cuda.synchronize() # t = time.time() - params.workspace = workspace ConvMainUnitTest.implicit_gemm2(params) # torch.cuda.synchronize() @@ -724,6 +732,7 @@ def run_with_tuned_result(self, def stream_synchronize(self, stream: int): return GemmMainUnitTest.stream_synchronize(stream) + GEMM = SimpleGemm(ALL_ALGO_DESPS) CONV = SimpleConv(ALL_CONV_ALGO_DESPS) diff --git a/spconv/build.py b/spconv/build.py index fab1f56..5436384 100644 --- a/spconv/build.py +++ b/spconv/build.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -19,7 +19,8 @@
 from ccimport.compat import InWindows
 from .constants import PACKAGE_NAME, PACKAGE_ROOT, DISABLE_JIT

-if project_is_installed(PACKAGE_NAME) and project_is_editable(PACKAGE_NAME) and not DISABLE_JIT:
+if project_is_installed(PACKAGE_NAME) and project_is_editable(
+        PACKAGE_NAME) and not DISABLE_JIT:
     from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS
     from spconv.core import IMPLGEMM_SIMT_PARAMS, IMPLGEMM_VOLTA_PARAMS, IMPLGEMM_TURING_PARAMS
@@ -27,11 +28,13 @@
     from cumm.conv.main import ConvMainUnitTest
     from spconv.csrc.sparse.all import SpconvOps

-    cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
+    cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS +
+                          SHUFFLE_TURING_PARAMS)
     cu.namespace = "cumm.gemm.main"
-    convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS + IMPLGEMM_TURING_PARAMS)
+    convcu = ConvMainUnitTest(IMPLGEMM_SIMT_PARAMS + IMPLGEMM_VOLTA_PARAMS +
+                              IMPLGEMM_TURING_PARAMS)
     convcu.namespace = "cumm.conv.main"
-    objects_folder = None
+    objects_folder = None
     if InWindows:
         # Windows has a command-line length limit, so we use objects_folder to reduce command size.
         objects_folder = "objects"
diff --git a/spconv/constants.py b/spconv/constants.py
index f923b6c..b7e8c56 100644
--- a/spconv/constants.py
+++ b/spconv/constants.py
@@ -1,11 +1,11 @@
 # Copyright 2021 Yan Yan
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -20,10 +20,10 @@
 PACKAGE_NAME = "spconv"
 PACKAGE_ROOT = Path(__file__).parent.resolve()

-EDITABLE_INSTALLED = project_is_installed(PACKAGE_NAME) and project_is_editable(PACKAGE_NAME)
-
+EDITABLE_INSTALLED = project_is_installed(
+    PACKAGE_NAME) and project_is_editable(PACKAGE_NAME)
 _filter_hwio_env = os.getenv("SPCONV_FILTER_HWIO", "0")
 FILTER_HWIO = _filter_hwio_env == "1"
 DISABLE_JIT = os.getenv("SPCONV_DISABLE_JIT", "0") == "1"
-NDIM_DONT_CARE = 3
\ No newline at end of file
+NDIM_DONT_CARE = 3
diff --git a/spconv/core.py b/spconv/core.py
index 72a735c..d64b452 100644
--- a/spconv/core.py
+++ b/spconv/core.py
@@ -1,11 +1,11 @@
 # Copyright 2021 Yan Yan
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -21,6 +21,7 @@
                         ConvLayoutType, ConvMode, ConvOpType)
 from spconv.constants import NDIM_DONT_CARE

+
 class ConvAlgo(Enum):
     Native = "Native"
     MaskImplicitGemm = "MaskImplicitGemm"
@@ -33,90 +34,70 @@ class AlgoHint(Enum):
     BackwardInput = 0b010
     BackwardWeight = 0b100

+
 # we can't add more kernels here because builds in GitHub Actions are very slow.
# TODO two step build: build gemm kernels first, then bind for every python SHUFFLE_SIMT_PARAMS: List[GemmAlgoParams] = [ - *gen_shuffle_params( - (64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", - 2, kernel.GemmAlgo.SimtDP4A, None), - *gen_shuffle_params( - (128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", - 2, kernel.GemmAlgo.SimtDP4A, None), - *gen_shuffle_params( - (128, 128, 32), - (32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, - kernel.GemmAlgo.SimtDP4A, None), + *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", + 2, kernel.GemmAlgo.SimtDP4A, None), + *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", + 2, kernel.GemmAlgo.SimtDP4A, None), + *gen_shuffle_params((128, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], + "", 2, kernel.GemmAlgo.SimtDP4A, None), *gen_shuffle_params( (128, 128, 32), (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.SimtDP4A, None), - *gen_shuffle_params( - (64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", - 2, kernel.GemmAlgo.SimtDP4A, None), - - *gen_shuffle_params( - (64, 256, 8), - (32, 64, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", + 2, kernel.GemmAlgo.SimtDP4A, None), + *gen_shuffle_params((64, 256, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), # *gen_shuffle_params( # (64, 256, 8), # (64, 32, 8), ["f32,f32,f32,f32,f32"], 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (32, 128, 16), - (32, 32, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (32, 512, 8), - (32, 64, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((32, 128, 16), (32, 32, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((32, 512, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), # *gen_shuffle_params( # (128, 128, 8), # (64, 32, 8), ["f32,f32,f32,f32,f32"], 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (128, 128, 8), - (32, 64, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (64, 128, 8), - (32, 64, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((64, 128, 8), (32, 64, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), # *gen_shuffle_params( # (64, 128, 8), # (64, 32, 8), ["f32,f32,f32,f32,f32"], 2, kernel.GemmAlgo.Simt, None), # *gen_shuffle_params( # (128, 64, 8), # (32, 64, 8), ["f32,f32,f32,f32,f32"], 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (128, 64, 8), - (64, 32, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (64, 64, 8), - (32, 32, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (32, 64, 16), - (32, 32, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (64, 32, 16), - (32, 32, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (32, 32, 32), - (32, 
32, 8), ["f32,f32,f32,f32,f32"], "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((128, 64, 8), (64, 32, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((64, 64, 8), (32, 32, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((32, 64, 16), (32, 32, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((64, 32, 16), (32, 32, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f32,f32,f32,f32,f32"], + "f32,f32,f32,f32,f32", 2, kernel.GemmAlgo.Simt, None), # fall back kernels if mat is misaligned for half # TODO use access-per-vector kernel instead of simt kernel for fallback - *gen_shuffle_params( - (128, 128, 8), - (32, 64, 8), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (32, 64, 32), - (32, 32, 8), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (32, 32, 32), - (32, 32, 8), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((128, 128, 8), (32, 64, 8), ["f16,f16,f16,f16,f16"], + "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((32, 64, 32), (32, 32, 8), ["f16,f16,f16,f16,f16"], + "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((32, 32, 32), (32, 32, 8), ["f16,f16,f16,f16,f16"], + "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), # *gen_shuffle_params( # (64, 64, 16), # (32, 32, 8), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (64, 128, 16), - (32, 64, 8), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), - *gen_shuffle_params( - (64, 64, 8), - (32, 32, 8), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((64, 128, 16), (32, 64, 8), ["f16,f16,f16,f16,f16"], + "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), + *gen_shuffle_params((64, 64, 8), (32, 32, 8), ["f16,f16,f16,f16,f16"], + "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Simt, None), ] SHUFFLE_VOLTA_PARAMS: List[GemmAlgoParams] = [ @@ -145,7 +126,7 @@ class AlgoHint(Enum): (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Volta, TensorOpParams((8, 8, 4))), ] -SHUFFLE_VOLTA_PARAMS = [] +# SHUFFLE_VOLTA_PARAMS = [] SHUFFLE_TURING_PARAMS: List[GemmAlgoParams] = [ *gen_shuffle_params( (64, 64, 32), @@ -183,133 +164,500 @@ class AlgoHint(Enum): (64, 128, 32), (32, 64, 32), ["f16,f16,f16,f16,f16"], "f16,f16,f16,f32,f32", 2, kernel.GemmAlgo.Turing, TensorOpParams((16, 8, 8))), - *gen_shuffle_params( - (64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", - 2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), + *gen_shuffle_params((64, 64, 32), (32, 32, 32), ["s8,s8,s32,s32,s32"], "", + 2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), *gen_shuffle_params( (128, 128, 32), - (32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, - kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), + (32, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, + TensorOpParams((8, 8, 16))), # *gen_shuffle_params( # (128, 128, 32), # (64, 32, 32), ["s8,s8,s8,s32,s32", "s8,s8,s32,s32,s32"], "", 2, # kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), *gen_shuffle_params( (128, 
256, 32), - (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, - kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), + (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, + TensorOpParams((8, 8, 16))), *gen_shuffle_params( (256, 128, 32), - (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, - kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), - *gen_shuffle_params( - (128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", - 2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), - *gen_shuffle_params( - (64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", - 2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), + (64, 64, 32), ["s8,s8,s32,s32,s32"], "", 2, kernel.GemmAlgo.Turing, + TensorOpParams((8, 8, 16))), + *gen_shuffle_params((128, 64, 32), (64, 32, 32), ["s8,s8,s32,s32,s32"], "", + 2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), + *gen_shuffle_params((64, 128, 32), (32, 64, 32), ["s8,s8,s32,s32,s32"], "", + 2, kernel.GemmAlgo.Turing, TensorOpParams((8, 8, 16))), ] - # SHUFFLE_TURING_PARAMS = [] IMPLGEMM_SIMT_PARAMS = [ - *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 16), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 8), (32, 64, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 16), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 32, 32), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - - - *gen_conv_params(ConvFwdAndBwdInput, (64, 256, 8), (32, 64, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 8), (32, 64, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 8), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - - *gen_conv_params(ConvBwdWeight, (32, 128, 16), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 16), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 8), (32, 64, 8), + 
NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 16), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 32, 32), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 256, 8), (32, 64, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 8), (32, 64, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 8), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 32, 16), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (32, 128, 16), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), # *gen_conv_params(ConvBwdWeight, (32, 256, 8), (32, 64, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], # NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (32, 64, 16), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (32, 32, 32), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - - - *gen_conv_params(ConvBwdWeight, (64, 256, 8), (32, 64, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (64, 128, 8), (32, 64, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (64, 64, 8), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (64, 32, 16), (32, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, 
GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - - *gen_conv_params(ConvBwdWeight, (128, 128, 8), (32, 64, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (128, 64, 8), (64, 32, 8), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f32,f32,f32,f32,f32"], - NHWC, NHWC, NHWC, GemmAlgo.Simt, None, mask_sparse=True, increment_k_first=True, access_per_vector=1), - + *gen_conv_params(ConvBwdWeight, (32, 64, 16), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (32, 32, 32), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 256, 8), (32, 64, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 128, 8), (32, 64, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 64, 8), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 32, 16), (32, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (128, 128, 8), (32, 64, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (128, 64, 8), (64, 32, 8), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f32,f32,f32,f32,f32"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Simt, + None, + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), ] IMPLGEMM_VOLTA_PARAMS = [ - *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Volta, TensorOpParams((8, 8, 4)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Volta, TensorOpParams((8, 8, 4)), mask_sparse=True, increment_k_first=True, access_per_vector=0), - *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Volta, TensorOpParams((8, 8, 4)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, 
NHWC, GemmAlgo.Volta, TensorOpParams((8, 8, 4)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - - *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Volta, TensorOpParams((8, 8, 4)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Volta, TensorOpParams((8, 8, 4)), mask_sparse=True, increment_k_first=True, access_per_vector=0), - - *gen_conv_params(ConvBwdWeight, (128, 128, 32), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Volta, TensorOpParams((8, 8, 4)), mask_sparse=True, increment_k_first=True, access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Volta, + TensorOpParams((8, 8, 4)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Volta, + TensorOpParams((8, 8, 4)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=0), + *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Volta, + TensorOpParams((8, 8, 4)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Volta, + TensorOpParams((8, 8, 4)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Volta, + TensorOpParams((8, 8, 4)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Volta, + TensorOpParams((8, 8, 4)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=0), + *gen_conv_params(ConvBwdWeight, (128, 128, 32), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Volta, + TensorOpParams((8, 8, 4)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), ] IMPLGEMM_TURING_PARAMS = [ - *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=0), - *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, 
NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 64), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 64), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - - *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, ["f16,f16,f16,f16,f16"], - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - - *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32", - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=0), - - *gen_conv_params(ConvBwdWeight, (128, 128, 32), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32", - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), - *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 2, "f16,f16,f16,f32,f32", - NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=0), + *gen_conv_params(ConvFwdAndBwdInput, (32, 64, 32), (32, 32, 16), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + 
increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 256, 32), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 32, 64), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (32, 128, 64), (32, 64, 64), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 32), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 128, 64), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvFwdAndBwdInput, (64, 64, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, ["f16,f16,f16,f16,f16"], + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, + "f16,f16,f16,f32,f32", + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=0), + *gen_conv_params(ConvBwdWeight, (128, 128, 32), (32, 64, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, + "f16,f16,f16,f32,f32", + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), + *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), + NDIM_DONT_CARE, + ConvIterAlgo.Optimized, + 2, + "f16,f16,f16,f32,f32", + NHWC, + NHWC, + NHWC, + GemmAlgo.Turing, + TensorOpParams((16, 8, 8)), + mask_sparse=True, + increment_k_first=True, + access_per_vector=1), # *gen_conv_params(ConvBwdWeight, (32, 64, 32), (32, 32, 16), NDIM_DONT_CARE, ConvIterAlgo.Optimized, 
2, "f16,f16,f16,f32,f32", # NHWC, NHWC, NHWC, GemmAlgo.Turing, TensorOpParams((16, 8, 8)), mask_sparse=True, increment_k_first=True, access_per_vector=1), # gen_conv_params(ConvFwdAndBwdInput, ) -] \ No newline at end of file +] diff --git a/spconv/core_cc/__init__.pyi b/spconv/core_cc/__init__.pyi index b8bf5f6..e69de29 100644 --- a/spconv/core_cc/__init__.pyi +++ b/spconv/core_cc/__init__.pyi @@ -1,14 +0,0 @@ -# Copyright 2021 Yan Yan -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/spconv/core_cc/csrc/__init__.pyi b/spconv/core_cc/csrc/__init__.pyi index b8bf5f6..e69de29 100644 --- a/spconv/core_cc/csrc/__init__.pyi +++ b/spconv/core_cc/csrc/__init__.pyi @@ -1,14 +0,0 @@ -# Copyright 2021 Yan Yan -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/spconv/core_cc/csrc/sparse/__init__.pyi b/spconv/core_cc/csrc/sparse/__init__.pyi index b8bf5f6..e69de29 100644 --- a/spconv/core_cc/csrc/sparse/__init__.pyi +++ b/spconv/core_cc/csrc/sparse/__init__.pyi @@ -1,14 +0,0 @@ -# Copyright 2021 Yan Yan -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/spconv/core_cc/csrc/sparse/all/ops1d.pyi b/spconv/core_cc/csrc/sparse/all/ops1d.pyi index b28aa24..03e57ca 100644 --- a/spconv/core_cc/csrc/sparse/all/ops1d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops1d.pyi @@ -19,10 +19,31 @@ class Point2Voxel: max_num_points_per_voxel: """ ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: clear_voxels: + empty_mean: + stream_int: + """ + ... 
+ @staticmethod + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + points: + voxels: + indices: + num_per_voxel: + hashdata: + point_indice_data: + vsize: + grid_size: + grid_stride: + coors_range: + clear_voxels: + empty_mean: + stream_int: """ ... diff --git a/spconv/core_cc/csrc/sparse/all/ops1d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops1d/__init__.pyi deleted file mode 100644 index 03e57ca..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops1d/__init__.pyi +++ /dev/null @@ -1,49 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2Voxel: - hashdata: Tensor - point_indice_data: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - empty_mean: - stream_int: - """ - ... - @staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - hashdata: - point_indice_data: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - empty_mean: - stream_int: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops1d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops1d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops1d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops2d.pyi b/spconv/core_cc/csrc/sparse/all/ops2d.pyi index b28aa24..03e57ca 100644 --- a/spconv/core_cc/csrc/sparse/all/ops2d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops2d.pyi @@ -19,10 +19,31 @@ class Point2Voxel: max_num_points_per_voxel: """ ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: clear_voxels: + empty_mean: + stream_int: + """ + ... 
+ @staticmethod + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + points: + voxels: + indices: + num_per_voxel: + hashdata: + point_indice_data: + vsize: + grid_size: + grid_stride: + coors_range: + clear_voxels: + empty_mean: + stream_int: """ ... diff --git a/spconv/core_cc/csrc/sparse/all/ops2d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops2d/__init__.pyi deleted file mode 100644 index 03e57ca..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops2d/__init__.pyi +++ /dev/null @@ -1,49 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2Voxel: - hashdata: Tensor - point_indice_data: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - empty_mean: - stream_int: - """ - ... - @staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - hashdata: - point_indice_data: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - empty_mean: - stream_int: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops2d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops2d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops2d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops3d.pyi b/spconv/core_cc/csrc/sparse/all/ops3d.pyi index b28aa24..03e57ca 100644 --- a/spconv/core_cc/csrc/sparse/all/ops3d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops3d.pyi @@ -19,10 +19,31 @@ class Point2Voxel: max_num_points_per_voxel: """ ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: clear_voxels: + empty_mean: + stream_int: + """ + ... 
+ @staticmethod + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + points: + voxels: + indices: + num_per_voxel: + hashdata: + point_indice_data: + vsize: + grid_size: + grid_stride: + coors_range: + clear_voxels: + empty_mean: + stream_int: """ ... diff --git a/spconv/core_cc/csrc/sparse/all/ops3d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops3d/__init__.pyi deleted file mode 100644 index 03e57ca..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops3d/__init__.pyi +++ /dev/null @@ -1,49 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2Voxel: - hashdata: Tensor - point_indice_data: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - empty_mean: - stream_int: - """ - ... - @staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - hashdata: - point_indice_data: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - empty_mean: - stream_int: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops3d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops3d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops3d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops4d.pyi b/spconv/core_cc/csrc/sparse/all/ops4d.pyi index b28aa24..03e57ca 100644 --- a/spconv/core_cc/csrc/sparse/all/ops4d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops4d.pyi @@ -19,10 +19,31 @@ class Point2Voxel: max_num_points_per_voxel: """ ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: clear_voxels: + empty_mean: + stream_int: + """ + ... 
+ @staticmethod + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + points: + voxels: + indices: + num_per_voxel: + hashdata: + point_indice_data: + vsize: + grid_size: + grid_stride: + coors_range: + clear_voxels: + empty_mean: + stream_int: """ ... diff --git a/spconv/core_cc/csrc/sparse/all/ops4d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops4d/__init__.pyi deleted file mode 100644 index 03e57ca..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops4d/__init__.pyi +++ /dev/null @@ -1,49 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2Voxel: - hashdata: Tensor - point_indice_data: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... - def point_to_voxel_hash(self, points: Tensor, clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - empty_mean: - stream_int: - """ - ... - @staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - hashdata: - point_indice_data: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - empty_mean: - stream_int: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops4d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops4d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops4d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi index 19e2842..d44e436 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi @@ -9,14 +9,11 @@ class Point2VoxelCPU: @property def grid_size(self) -> List[int]: ... 
@staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> Tuple[List[float], List[int], List[int], List[float]]: + def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: """ Args: vsize_xyz: coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: """ ... def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: @@ -30,7 +27,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,7 +35,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: @@ -47,7 +43,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -55,7 +51,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu1d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu1d/__init__.pyi deleted file mode 100644 index d44e436..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu1d/__init__.pyi +++ /dev/null @@ -1,74 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2VoxelCPU: - densehashdata: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... 
- @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... - @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... - def point_to_voxel(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... - def point_to_voxel_empty_mean(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu1d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu1d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu1d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi index 19e2842..d44e436 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi @@ -9,14 +9,11 @@ class Point2VoxelCPU: @property def grid_size(self) -> List[int]: ... @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> Tuple[List[float], List[int], List[int], List[float]]: + def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: """ Args: vsize_xyz: coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: """ ... def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: @@ -30,7 +27,7 @@ class Point2VoxelCPU: """ ... 
@staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,7 +35,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: @@ -47,7 +43,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -55,7 +51,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu2d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu2d/__init__.pyi deleted file mode 100644 index d44e436..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu2d/__init__.pyi +++ /dev/null @@ -1,74 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2VoxelCPU: - densehashdata: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... - @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... - @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... 
- def point_to_voxel(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... - def point_to_voxel_empty_mean(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu2d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu2d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu2d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi index 19e2842..d44e436 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi @@ -9,14 +9,11 @@ class Point2VoxelCPU: @property def grid_size(self) -> List[int]: ... @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> Tuple[List[float], List[int], List[int], List[float]]: + def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: """ Args: vsize_xyz: coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: """ ... def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: @@ -30,7 +27,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,7 +35,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: @@ -47,7 +43,7 @@ class Point2VoxelCPU: """ ... 
@staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -55,7 +51,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu3d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu3d/__init__.pyi deleted file mode 100644 index d44e436..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu3d/__init__.pyi +++ /dev/null @@ -1,74 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2VoxelCPU: - densehashdata: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... - @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... - @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... - def point_to_voxel(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... - def point_to_voxel_empty_mean(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... 
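Across the four `ops_cpu{1..4}d` stubs the same two changes repeat: `calc_meta_data` drops the three capacity arguments (it only ever needed the voxel geometry, matching `Point2VoxelCommon`), and the `mean_per_voxel` scratch tensor disappears from both static converters. A minimal sketch of driving the regenerated CPU stub directly, assuming `cumm.tensorview.from_numpy` for the host-tensor conversion; the sizes and ranges are illustrative only:

```python
import numpy as np
from cumm import tensorview as tv
# Import path mirrors the .pyi layout above; ops_cpu3d is the 3D instantiation.
from spconv.core_cc.csrc.sparse.all.ops_cpu3d import Point2VoxelCPU

p2v = Point2VoxelCPU(vsize_xyz=[0.1, 0.1, 0.1],
                     coors_range_xyz=[-10.0, -10.0, -2.0, 10.0, 10.0, 2.0],
                     num_point_features=3,
                     max_num_voxels=20000,
                     max_num_points_per_voxel=5)

points = np.random.uniform(-8.0, 8.0, size=[1000, 3]).astype(np.float32)
# Returns (voxels, coordinates, num_points_per_voxel); the class preallocates
# the voxels/indices/num_per_voxel members declared in the stub, so output
# size is bounded by max_num_voxels.
voxels, coords, num_per_voxel = p2v.point_to_voxel(tv.from_numpy(points))

# The slimmed-down static helper: geometry in, (vsize, grid_size,
# grid_stride, coors_range) out.
meta = Point2VoxelCPU.calc_meta_data([0.1, 0.1, 0.1],
                                     [-10.0, -10.0, -2.0, 10.0, 10.0, 2.0])
```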
diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu3d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu3d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu3d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi index 19e2842..d44e436 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi @@ -9,14 +9,11 @@ class Point2VoxelCPU: @property def grid_size(self) -> List[int]: ... @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> Tuple[List[float], List[int], List[int], List[float]]: + def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: """ Args: vsize_xyz: coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: """ ... def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: @@ -30,7 +27,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,7 +35,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: @@ -47,7 +43,7 @@ class Point2VoxelCPU: """ ... 
@staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, mean_per_voxel: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -55,7 +51,6 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: - mean_per_voxel: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu4d/__init__.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu4d/__init__.pyi deleted file mode 100644 index d44e436..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu4d/__init__.pyi +++ /dev/null @@ -1,74 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -from cumm.tensorview import Tensor -class Point2VoxelCPU: - densehashdata: Tensor - voxels: Tensor - indices: Tensor - num_per_voxel: Tensor - @property - def grid_size(self) -> List[int]: ... - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... - def __init__(self, vsize_xyz: List[float], coors_range_xyz: List[float], num_point_features: int, max_num_voxels: int, max_num_points_per_voxel: int) -> None: - """ - Args: - vsize_xyz: - coors_range_xyz: - num_point_features: - max_num_voxels: - max_num_points_per_voxel: - """ - ... - @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... - @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - voxels: - indices: - num_per_voxel: - densehashdata: - vsize: - grid_size: - grid_stride: - coors_range: - clear_voxels: - """ - ... - def point_to_voxel(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... - def point_to_voxel_empty_mean(self, points: Tensor, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - points: - clear_voxels: - """ - ... 
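That closes out the Point2Voxel stub churn: every GPU variant (`ops1d` through `ops4d` above) gains the same `empty_mean` and `stream_int` keywords on `point_to_voxel_hash`, plus a stateless `point_to_voxel_hash_static` that takes the buffers and meta data explicitly. A hedged sketch of the instance API, assuming cumm tensors expose `.cuda()` for the host-to-device copy and that `empty_mean` mirrors the CPU `point_to_voxel_empty_mean` behavior (fill unused point slots of a voxel with the voxel mean rather than zeros):

```python
import numpy as np
from cumm import tensorview as tv
from spconv.core_cc.csrc.sparse.all.ops3d import Point2Voxel

p2v = Point2Voxel(vsize_xyz=[0.1, 0.1, 0.1],
                  coors_range_xyz=[-10.0, -10.0, -2.0, 10.0, 10.0, 2.0],
                  num_point_features=3,
                  max_num_voxels=20000,
                  max_num_points_per_voxel=5)

points = np.random.uniform(-8.0, 8.0, size=[1000, 3]).astype(np.float32)
points_gpu = tv.from_numpy(points).cuda()

# stream_int is an opaque CUDA stream handle; 0 keeps the default stream.
voxels, coords, num_per_voxel = p2v.point_to_voxel_hash(points_gpu,
                                                        clear_voxels=True,
                                                        empty_mean=True,
                                                        stream_int=0)
```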
diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu4d/p2v_c.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu4d/p2v_c.pyi deleted file mode 100644 index 2e75535..0000000 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu4d/p2v_c.pyi +++ /dev/null @@ -1,11 +0,0 @@ -from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pccm.stubs import EnumValue, EnumClassValue -class Point2VoxelCommon: - @staticmethod - def calc_meta_data(vsize_xyz: List[float], coors_range_xyz: List[float]) -> Tuple[List[float], List[int], List[int], List[float]]: - """ - Args: - vsize_xyz: - coors_range_xyz: - """ - ... diff --git a/spconv/core_cc/cumm/__init__.pyi b/spconv/core_cc/cumm/__init__.pyi index b8bf5f6..e69de29 100644 --- a/spconv/core_cc/cumm/__init__.pyi +++ b/spconv/core_cc/cumm/__init__.pyi @@ -1,14 +0,0 @@ -# Copyright 2021 Yan Yan -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/spconv/core_cc/cumm/conv/main.pyi b/spconv/core_cc/cumm/conv/main.pyi index aa1ce42..58c1431 100644 --- a/spconv/core_cc/cumm/conv/main.pyi +++ b/spconv/core_cc/cumm/conv/main.pyi @@ -2,6 +2,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Ty from pccm.stubs import EnumValue, EnumClassValue from ...cumm.gemm.main import GemmAlgoDesp from cumm.tensorview import Tensor +from cumm.tensorview import CUDAKernelTimer class ConvAlgoDesp(GemmAlgoDesp): ndim: int op_type: int @@ -86,17 +87,19 @@ class ConvParams: mask_filter: int reverse_mask: bool verbose: bool + timer: CUDAKernelTimer workspace: Tensor = Tensor() mask: Tensor = Tensor() mask_argsort: Tensor = Tensor() indices: Tensor = Tensor() mask_output: Tensor = Tensor() stream: int - def __init__(self, ndim: int, op_type: int) -> None: + def __init__(self, ndim: int, op_type: int, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None: """ Args: ndim: op_type: + timer: """ ... class ConvMainUnitTest: diff --git a/spconv/core_cc/cumm/gemm/__init__.pyi b/spconv/core_cc/cumm/gemm/__init__.pyi index b8bf5f6..e69de29 100644 --- a/spconv/core_cc/cumm/gemm/__init__.pyi +++ b/spconv/core_cc/cumm/gemm/__init__.pyi @@ -1,14 +0,0 @@ -# Copyright 2021 Yan Yan -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- diff --git a/spconv/core_cc/cumm/gemm/main.pyi b/spconv/core_cc/cumm/gemm/main.pyi index 01ae484..17001b7 100644 --- a/spconv/core_cc/cumm/gemm/main.pyi +++ b/spconv/core_cc/cumm/gemm/main.pyi @@ -1,6 +1,7 @@ from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from pccm.stubs import EnumValue, EnumClassValue from cumm.tensorview import Tensor +from cumm.tensorview import CUDAKernelTimer class GemmAlgoDesp: dtype_a: int dtype_b: int @@ -102,7 +103,13 @@ class GemmParams: alpha: float beta: float stream: int - def __init__(self) -> None: ... + timer: CUDAKernelTimer + def __init__(self, timer: CUDAKernelTimer = CUDAKernelTimer(False)) -> None: + """ + Args: + timer: + """ + ... def check_valid(self) -> None: ... @property def a(self) -> Tensor: ... diff --git a/spconv/core_cc/cumm/tools/__init__.pyi b/spconv/core_cc/cumm/tools/__init__.pyi new file mode 100644 index 0000000..e69de29 diff --git a/spconv/core_cc/cumm/tools/cuda.pyi b/spconv/core_cc/cumm/tools/cuda.pyi new file mode 100644 index 0000000..aa51098 --- /dev/null +++ b/spconv/core_cc/cumm/tools/cuda.pyi @@ -0,0 +1,56 @@ +from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from pccm.stubs import EnumValue, EnumClassValue +class CUDAEvent: + def __init__(self, name: str) -> None: + """ + Args: + name: + """ + ... + def record(self, stream: int = 0) -> None: + """ + Args: + stream: + """ + ... + def sync(self) -> None: ... + @staticmethod + def duration(start: "CUDAEvent", stop: "CUDAEvent") -> float: + """ + Args: + start: + stop: + """ + ... +class CUDAKernelTimer: + enable: bool + def __init__(self, enable: bool = True) -> None: + """ + Args: + enable: + """ + ... + def push(self, name: str) -> None: + """ + Args: + name: + """ + ... + def pop(self) -> None: ... + def record(self, name: str, stream: int = 0) -> None: + """ + Args: + name: + stream: + """ + ... + def insert_pair(self, name: str, start: str, stop: str) -> None: + """ + Args: + name: + start: + stop: + """ + ... + def get_all_pair_duration(self) -> Dict[str, float]: ... + def sync(self) -> None: ... diff --git a/spconv/cppconstants.py b/spconv/cppconstants.py index de319a9..f955c96 100644 --- a/spconv/cppconstants.py +++ b/spconv/cppconstants.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,4 +17,4 @@ if hasattr(_ext, "cumm"): CPU_ONLY_BUILD = False else: - CPU_ONLY_BUILD = True + CPU_ONLY_BUILD = True diff --git a/spconv/csrc/__init__.py b/spconv/csrc/__init__.py index b8bf5f6..84d35de 100644 --- a/spconv/csrc/__init__.py +++ b/spconv/csrc/__init__.py @@ -1,14 +1,13 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - diff --git a/spconv/csrc/sparse/__init__.py b/spconv/csrc/sparse/__init__.py index b8bf5f6..84d35de 100644 --- a/spconv/csrc/sparse/__init__.py +++ b/spconv/csrc/sparse/__init__.py @@ -1,14 +1,13 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/spconv/csrc/sparse/all.py b/spconv/csrc/sparse/all.py index aacf19e..1642d7a 100644 --- a/spconv/csrc/sparse/all.py +++ b/spconv/csrc/sparse/all.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,13 +17,14 @@ from cumm.conv.params import ConvProblem from cumm import dtypes from cumm.constants import CUMM_CPU_ONLY_BUILD -import pccm +import pccm from ccimport import compat from .pointops import Point2Voxel, Point2VoxelCPU from .indices import SparseConvIndicesKernel, CudaCommonKernel, SparseConvIndicesCPU from .maxpool import IndiceMaxPool, IndiceMaxPoolCPU from .gather import GatherCPU + class CustomThrustLib(pccm.Class): def __init__(self): super().__init__() @@ -32,12 +33,15 @@ def __init__(self): if compat.InLinux: self.build_meta.add_cflags("nvcc", "-Xcompiler", "-fno-gnu-unique") + class ThrustCustomAllocatorV2(pccm.Class, pccm.pybind.PybindClassMixin): def __init__(self): super().__init__() self.add_dependency(TensorView) self.add_include("functional", "memory") - self.add_pybind_member("alloc_func", "std::function", pyanno="Callable[[int], int]") + self.add_pybind_member("alloc_func", + "std::function", + pyanno="Callable[[int], int]") self.add_typedef("value_type", "char") @pccm.member_function @@ -54,14 +58,15 @@ def allocate(self): TV_THROW_RT_ERR("set alloc function first."); }} """) - return code + return code @pccm.member_function def deallocate(self): code = pccm.FunctionCode() code.arg("ptr", "char *") code.arg("num_bytes", "size_t") - return code + return code + class SpconvOps(pccm.Class): def __init__(self): @@ -69,28 +74,38 @@ def __init__(self): self.add_dependency(ThrustCustomAllocatorV2) self.ndims = [1, 2, 3, 4] for ndim in self.ndims: - p2v = Point2Voxel(dtypes.float32, ndim) + p2v = Point2Voxel(dtypes.float32, ndim) p2v_cpu = Point2VoxelCPU(dtypes.float32, ndim) - self.add_param_class(f"ops_cpu{ndim}d", p2v_cpu, f"Point2Voxel{ndim}DCPU") + self.add_param_class(f"ops_cpu{ndim}d", p2v_cpu, + f"Point2Voxel{ndim}DCPU") problem = ConvProblem(ndim, ConvOpType.kForward, NHWC, NHWC, NHWC) indices = SparseConvIndicesKernel(problem, dtypes.int32) indices_cpu = SparseConvIndicesCPU(problem, dtypes.int32) - self.add_param_class(f"ops_cpu{ndim}d", 
indices_cpu, f"SpconvIndicesCPU{ndim}D") + self.add_param_class(f"ops_cpu{ndim}d", indices_cpu, + f"SpconvIndicesCPU{ndim}D") # self.add_param_class("ops", indices, "SpconvIndices") if not CUMM_CPU_ONLY_BUILD: self.add_param_class(f"ops{ndim}d", p2v, f"Point2Voxel{ndim}D") - cuda_funcs = [self.generate_subm_conv_inds, - self.generate_conv_inds_stage1, self.generate_conv_inds_stage1_5, self.generate_conv_inds_stage2, self.sort_1d_by_key, - self.generate_conv_inds_mask_stage1, self.generate_conv_inds_mask_stage2] - self.add_impl_only_param_class(cuda_funcs, f"ops{ndim}d", indices, f"SpconvIndices{ndim}D") + cuda_funcs = [ + self.generate_subm_conv_inds, + self.generate_conv_inds_stage1, + self.generate_conv_inds_stage1_5, + self.generate_conv_inds_stage2, self.sort_1d_by_key, + self.generate_conv_inds_mask_stage1, + self.generate_conv_inds_mask_stage2 + ] + self.add_impl_only_param_class(cuda_funcs, f"ops{ndim}d", + indices, + f"SpconvIndices{ndim}D") @pccm.pybind.mark @pccm.cuda.static_function def generate_conv_inds_stage1(self): code = pccm.FunctionCode() code.arg("indices", "tv::Tensor") - code.arg("indice_pairs, indice_pairs_uniq, indice_num_per_loc", "tv::Tensor") + code.arg("indice_pairs, indice_pairs_uniq, indice_num_per_loc", + "tv::Tensor") code.arg("batch_size", "int") code.arg("output_dims, input_dims", f"std::vector") code.arg("ksize, stride, padding, dilation", f"std::vector") @@ -127,7 +142,7 @@ def generate_conv_inds_stage1(self): """) code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") - return code# .ret("int") + return code # .ret("int") @pccm.pybind.mark @pccm.cuda.static_function @@ -201,7 +216,8 @@ def generate_conv_inds_mask_stage1(self): return code.make_invalid() code.arg("indices", "tv::Tensor") - code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc", "tv::Tensor") + code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc", + "tv::Tensor") code.arg("batch_size", "int") code.arg("output_dims, input_dims", f"std::vector") code.arg("ksize, stride, padding, dilation", f"std::vector") @@ -236,7 +252,7 @@ def generate_conv_inds_mask_stage1(self): """) code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") - return code# .ret("int") + return code # .ret("int") @pccm.pybind.mark @pccm.cuda.static_function @@ -245,7 +261,9 @@ def generate_conv_inds_mask_stage2(self): if CUMM_CPU_ONLY_BUILD: return code.make_invalid() code.arg("indices, hashdata", "tv::Tensor") - code.arg("indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", "tv::Tensor") + code.arg( + "indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", + "tv::Tensor") code.arg("mask_fwd, mask_bwd", "tv::Tensor") code.arg("num_out_act", "int") code.arg("batch_size", "int") @@ -294,7 +312,8 @@ def generate_subm_conv_inds(self): code.arg("batch_size", "int") code.arg("input_dims", f"std::vector") code.arg("ksize, dilation", f"std::vector") - code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()", "cumm.tensorview.Tensor = Tensor()") + code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()", + "cumm.tensorview.Tensor = Tensor()") code.arg("backward", "bool", "false") code.arg("stream_int", f"std::uintptr_t", "0", pyanno="int = 0") code.raw(f""" @@ -529,7 +548,10 @@ def sort_1d_by_key(self): if CUMM_CPU_ONLY_BUILD: return code.make_invalid() code.arg("data", "tv::Tensor") - code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()") + code.arg("indices", + "tv::Tensor", + "tv::Tensor()", + pyanno="cumm.tensorview.Tensor = Tensor()") 
code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.code_after_include = f""" template struct SmallOrEqualTo {{ @@ -575,7 +597,10 @@ def sort_1d_by_key_allocator(self): code.arg("data", "tv::Tensor") code.arg("alloc_func", "std::function") - code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()") + code.arg("indices", + "tv::Tensor", + "tv::Tensor()", + pyanno="cumm.tensorview.Tensor = Tensor()") code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.code_after_include = f""" template struct SmallOrEqualTo {{ @@ -613,7 +638,6 @@ def sort_1d_by_key_allocator(self): """) return code.ret("tv::Tensor") - @pccm.pybind.mark @pccm.cuda.static_function def sort_1d_by_key_split(self): @@ -623,7 +647,10 @@ def sort_1d_by_key_split(self): code.arg("data", "tv::Tensor") code.arg("mask", "tv::Tensor") - code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()") + code.arg("indices", + "tv::Tensor", + "tv::Tensor()", + pyanno="cumm.tensorview.Tensor = Tensor()") code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.arg("mask_output", "bool", "false") @@ -678,7 +705,10 @@ def sort_1d_by_key_split_allocator(self): code.arg("mask", "tv::Tensor") - code.arg("indices", "tv::Tensor", "tv::Tensor()", pyanno="cumm.tensorview.Tensor = Tensor()") + code.arg("indices", + "tv::Tensor", + "tv::Tensor()", + pyanno="cumm.tensorview.Tensor = Tensor()") code.arg("stream", "std::uintptr_t", "0", pyanno="int") code.arg("mask_output", "bool", "false") @@ -821,9 +851,10 @@ def calc_point2voxel_meta_data(self): }} """) code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") - return code.ret("std::tuple, std::vector, std::vector, std::vector>") - - + return code.ret( + "std::tuple, std::vector, std::vector, std::vector>" + ) + @pccm.pybind.mark @pccm.static_function def point2voxel_cpu(self): @@ -876,7 +907,8 @@ def point2voxel_cpu(self): def point2voxel_cuda(self): code = pccm.FunctionCode() code.arg("points", "tv::Tensor") - code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", "tv::Tensor") + code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", + "tv::Tensor") code.arg("vsize", f"std::vector") code.arg("grid_size, grid_stride", f"std::vector") code.arg("coors_range", f"std::vector") @@ -914,4 +946,4 @@ def point2voxel_cuda(self): }} """) code.raw(f"""TV_THROW_RT_ERR("unknown ndim", ndim);""") - return code.ret("std::tuple") \ No newline at end of file + return code.ret("std::tuple") diff --git a/spconv/csrc/sparse/cpu_core.py b/spconv/csrc/sparse/cpu_core.py new file mode 100644 index 0000000..45926cb --- /dev/null +++ b/spconv/csrc/sparse/cpu_core.py @@ -0,0 +1,29 @@ +# Copyright 2021 Yan Yan +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pccm +from ccimport import compat +from cumm.common import TensorView + + +class OMPLib(pccm.Class): + def __init__(self): + super().__init__() + self.add_dependency(TensorView) + self.add_include("tensorview/parallel/all.h") + if compat.InWindows: + self.build_meta.add_cflags("cl", "/openmp") + else: + self.build_meta.add_cflags("g++", "-fopenmp") + self.build_meta.add_cflags("clang++", "-fopenmp") diff --git a/spconv/csrc/sparse/devleop/sort_bench.py b/spconv/csrc/sparse/devleop/sort_bench.py index 68f3e6f..4d49fa5 100644 --- a/spconv/csrc/sparse/devleop/sort_bench.py +++ b/spconv/csrc/sparse/devleop/sort_bench.py @@ -1,5 +1,6 @@ -import torch -import time +import torch +import time + def main(): @@ -34,4 +35,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/spconv/csrc/sparse/gather.py b/spconv/csrc/sparse/gather.py index 66ec034..0b374db 100644 --- a/spconv/csrc/sparse/gather.py +++ b/spconv/csrc/sparse/gather.py @@ -1,26 +1,32 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pccm +import pccm from cumm.common import TensorView -from typing import List +from cumm.constants import CUMM_CPU_ONLY_BUILD +from spconv.csrc.sparse.cpu_core import OMPLib +from typing import List + class GatherCPU(pccm.Class): def __init__(self): super().__init__() + if CUMM_CPU_ONLY_BUILD: + self.add_dependency(OMPLib) self.add_dependency(TensorView) - + self.add_include("tensorview/parallel/all.h") + @pccm.static_function def gather(self): code = pccm.FunctionCode() @@ -35,15 +41,16 @@ def gather(self): int channel = in.dim(1); tv::dispatch(out.dtype(), [&](auto I){{ auto indices_data = inds.data_ptr(); - using T = TV_DECLTYPE(I); T *buffer_data = out.data_ptr(); const T *features_data = in.data_ptr(); - for (int i = 0; i < nhot; ++i) {{ - std::memcpy(buffer_data + i * channel, - features_data + indices_data[i] * channel, - sizeof(T) * channel); - }} + tv::kernel_1d(out.device(), nhot, [&](int begin, int end, int step){{ + for (int i = begin; i < end; i += step) {{ + std::memcpy(buffer_data + i * channel, + features_data + indices_data[i] * channel, + sizeof(T) * channel); + }} + }}); }}); """) return code @@ -65,13 +72,15 @@ def scatter_add(self): T *features_data = out.data_ptr(); const T *buf = in.data_ptr(); T *out_ptr = out.data_ptr(); - for (int i = 0; i < nhot; ++i) {{ - buf = buffer_data + i * channel; - out_ptr = features_data + indices_data[i] * channel; - for (int j = 0; j < channel; ++j) {{ - out_ptr[j] = out_ptr[j] + buf[j]; + tv::kernel_1d(out.device(), nhot, [&](int begin, int end, int step){{ + for (int i = begin; i < end; i += step) {{ + buf = buffer_data + i * channel; + out_ptr = features_data + indices_data[i] * channel; + for (int j = 0; j < channel; ++j) {{ + out_ptr[j] = out_ptr[j] + buf[j]; + }} }} - }} + }}); }}); """) return code diff --git a/spconv/csrc/sparse/indices.py b/spconv/csrc/sparse/indices.py index 224c07b..cf5c260 100644 --- a/spconv/csrc/sparse/indices.py +++ 
b/spconv/csrc/sparse/indices.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,13 +16,14 @@ from cumm.conv.bases import ConvEnum from cumm.gemm.core.metaarray import MetaArray, seq from cumm import dtypes -import pccm +import pccm from cumm.gemm.layout import TensorGeneric, to_stride from cumm.common import TensorView, TensorViewHashKernel, TensorViewKernel, ThrustLib from cumm.gemm import codeops -from typing import List +from typing import List from cumm.conv.params import ConvProblem -import numpy as np +import numpy as np + class CudaCommonKernel(pccm.ParameterizedClass): # we need to use PClass instead of Class @@ -31,8 +32,8 @@ class CudaCommonKernel(pccm.ParameterizedClass): def arange_kernel(self): code = pccm.FunctionCode() code.targ("T") - code.arg("data", f"T*") - code.arg("size", f"int") + code.arg("data", f"T*") + code.arg("size", f"int") code.raw(f""" for (int i : tv::KernelLoopX(size)) {{ data[i] = T(i); @@ -44,9 +45,9 @@ def arange_kernel(self): def fill_kernel(self): code = pccm.FunctionCode() code.targ("T") - code.arg("data", f"T*") + code.arg("data", f"T*") code.arg("val", f"T") - code.arg("size", f"int") + code.arg("size", f"int") code.raw(f""" for (int i : tv::KernelLoopX(size)) {{ data[i] = T(val); @@ -66,7 +67,7 @@ def __init__(self, problem: ConvProblem): self.add_param_class("lociter", layout_npq, "LayoutNPQ") self.add_param_class("lociter_rs", layout_rs, "LayoutRS") - self.ndim = problem.ndim + self.ndim = problem.ndim self.add_member("problem_", f"ConvProblem") self.add_member("count_", f"tv::array") self.add_member("layout_npq", f"LayoutNPQ") @@ -82,13 +83,15 @@ def ctor(self): pqs = codeops.unpack("problem.output_dims", range(self.ndim)) rss = codeops.unpack("problem.ksize", range(self.ndim)) - code.ctor_init("layout_npq", f"LayoutNPQ::from_shape({{problem.N, {pqs}}})") + code.ctor_init("layout_npq", + f"LayoutNPQ::from_shape({{problem.N, {pqs}}})") code.ctor_init("layout_rs", f"LayoutRS::from_shape({{{rss}}})") - - return code - @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"], - name="operator++") + return code + + @pccm.member_function(header_only=True, + attrs=["TV_HOST_DEVICE_INLINE"], + name="operator++") def increment(self): code = pccm.FunctionCode() for i in range(self.ndim - 1, -1, -1): @@ -110,8 +113,9 @@ def set_filter_offset(self): """) return code - @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"], - const=True) + @pccm.member_function(header_only=True, + attrs=["TV_HOST_DEVICE_INLINE"], + const=True) def nhw_to_npq(self): code = pccm.FunctionCode() code.arg("nhw_offset", "const int*") @@ -128,8 +132,9 @@ def nhw_to_npq(self): """) return code.ret(f"tv::array") - @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"], - const=True) + @pccm.member_function(header_only=True, + attrs=["TV_HOST_DEVICE_INLINE"], + const=True) def npq_to_nhw(self): code = pccm.FunctionCode() code.arg("npq_offset", "const int*") @@ -144,9 +149,9 @@ def npq_to_nhw(self): """) return code.ret(f"tv::array") - - @pccm.member_function(header_only=True, 
attrs=["TV_HOST_DEVICE_INLINE"], - const=True) + @pccm.member_function(header_only=True, + attrs=["TV_HOST_DEVICE_INLINE"], + const=True) def query_npq(self): code = pccm.FunctionCode() code.arg("nhw_offset", "const int*") @@ -156,22 +161,27 @@ def query_npq(self): auto npq_no_stride = nhw_to_npq(nhw_offset); npq_offset[0] = npq_no_stride[0]; """) - hw_valid = [] # type: List[str] - stride_valid = [] # type: List[str] + hw_valid = [] # type: List[str] + stride_valid = [] # type: List[str] for i in range(self.ndim): - code.raw(f"npq_offset[{i + 1}] = npq_no_stride[{i + 1}] / problem_.stride[{i}];") - hw_valid.append((f"npq_offset[{i + 1}] >= 0 && " - f"npq_offset[{i + 1}] < problem_.output_dims[{i}]")) - stride_valid.append(f"!(npq_no_stride[{i + 1}] % problem_.stride[{i}])") + code.raw( + f"npq_offset[{i + 1}] = npq_no_stride[{i + 1}] / problem_.stride[{i}];" + ) + hw_valid.append( + (f"npq_offset[{i + 1}] >= 0 && " + f"npq_offset[{i + 1}] < problem_.output_dims[{i}]")) + stride_valid.append( + f"!(npq_no_stride[{i + 1}] % problem_.stride[{i}])") code.raw(f""" return npq_no_stride[0] < problem_.N && {' && '.join(hw_valid)} && {' && '.join(stride_valid)}; """) - return code + return code - @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"], - const=True) + @pccm.member_function(header_only=True, + attrs=["TV_HOST_DEVICE_INLINE"], + const=True) def query_npq_no_stride(self): code = pccm.FunctionCode() code.arg("nhw_offset", "const int*") @@ -180,18 +190,20 @@ def query_npq_no_stride(self): code.raw(f""" npq_offset = nhw_to_npq(nhw_offset); """) - hw_valid = [] # type: List[str] + hw_valid = [] # type: List[str] for i in range(self.ndim): - hw_valid.append((f"npq_offset[{i + 1}] >= 0 && " - f"npq_offset[{i + 1}] < problem_.output_dims[{i}]")) + hw_valid.append( + (f"npq_offset[{i + 1}] >= 0 && " + f"npq_offset[{i + 1}] < problem_.output_dims[{i}]")) code.raw(f""" return npq_offset[0] < problem_.N && {' && '.join(hw_valid)}; """) - return code + return code - @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"], - const=True) + @pccm.member_function(header_only=True, + attrs=["TV_HOST_DEVICE_INLINE"], + const=True) def query_nhw(self): code = pccm.FunctionCode() code.arg("npq_offset", "const int*") @@ -200,18 +212,20 @@ def query_nhw(self): code.raw(f""" nhw_offset = npq_to_nhw(npq_offset); """) - hw_valid = [] # type: List[str] + hw_valid = [] # type: List[str] for i in range(self.ndim): - hw_valid.append((f"nhw_offset[{i + 1}] >= 0 && " - f"nhw_offset[{i + 1}] < problem_.input_dims[{i}]")) + hw_valid.append( + (f"nhw_offset[{i + 1}] >= 0 && " + f"nhw_offset[{i + 1}] < problem_.input_dims[{i}]")) code.raw(f""" return nhw_offset[0] < problem_.N && {' && '.join(hw_valid)}; """) - return code + return code - @pccm.member_function(header_only=True, attrs=["TV_HOST_DEVICE_INLINE"], - const=True) + @pccm.member_function(header_only=True, + attrs=["TV_HOST_DEVICE_INLINE"], + const=True) def query_nhw_out(self): code = pccm.FunctionCode() code.arg("npq_offset", "const int*") @@ -220,41 +234,45 @@ def query_nhw_out(self): code.raw(f""" nhw_offset = npq_to_nhw(npq_offset); """) - hw_valid = [] # type: List[str] + hw_valid = [] # type: List[str] for i in range(self.ndim): - hw_valid.append((f"nhw_offset[{i + 1}] >= 0 && " - f"nhw_offset[{i + 1}] < problem_.output_dims[{i}]")) + hw_valid.append( + (f"nhw_offset[{i + 1}] >= 0 && " + f"nhw_offset[{i + 1}] < problem_.output_dims[{i}]")) code.raw(f""" return nhw_offset[0] < problem_.N && {' && '.join(hw_valid)}; """) - 
return code + return code + class SparseConvIndicesKernel(pccm.ParameterizedClass): def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType): super().__init__() - self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel, ThrustLib) + self.add_dependency(TensorView, TensorViewKernel, TensorViewHashKernel, + ThrustLib) self.loc_iter = ConvOutLocIter(problem) self.add_param_class("spinds", self.loc_iter, "ConvLocIter") - self.add_param_class("spinds", problem, "ConvProblem") - self.add_param_class("cudakers", CudaCommonKernel()) + self.add_param_class("spinds", problem, "ConvProblem") + self.add_param_class("cudakers", CudaCommonKernel()) - self.ndim = problem.ndim + self.ndim = problem.ndim self.dtype_indices = dtype_indices self.dtype_indices_uniq = dtype_indices assert dtype_indices == dtypes.int32 or dtype_indices == dtypes.int64 - @pccm.cuda.cuda_global_function def calc_conv_indices_stage1(self): code = pccm.FunctionCode() - code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] + code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] - code.arg("indices_in", f"const int*") # [N, ndim + 1] - code.arg("indice_pairs", f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("indice_pairs_for_uniq", f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("indice_num_per_loc", f"int*") # [kernelProd] + code.arg("indices_in", f"const int*") # [N, ndim + 1] + code.arg("indice_pairs", + f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("indice_pairs_for_uniq", + f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("indice_num_per_loc", f"int*") # [kernelProd] code.arg("num_indices_in", "int") code.arg("indices_pair_size", "int") @@ -288,17 +306,18 @@ def calc_conv_indices_stage1(self): """) return code - @pccm.cuda.cuda_global_function def build_conv_hash_table(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indices_out", f"int*") # [N, ndim + 1] - code.arg("indice_pairs_for_uniq", f"const {self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("indices_out", f"int*") # [N, ndim + 1] + code.arg("indice_pairs_for_uniq", + f"const {self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("layout_npq", f"spinds::LayoutNPQ") # [2, kernelProd, MaxSize] + code.arg("layout_npq", + f"spinds::LayoutNPQ") # [2, kernelProd, MaxSize] code.arg("num_indices", "int") @@ -315,8 +334,8 @@ def build_conv_hash_table(self): def calc_conv_indices_stage2(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indice_pairs_out_part", f"int*") # [2, kernelProd, MaxSize] + code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("indice_pairs_out_part", f"int*") # [2, kernelProd, MaxSize] code.arg("num_indices_in", "int") code.arg("indices_pair_size", "int") # TODO use block instead of filter_offset? 
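The kernels above are the heart of index-pair generation for regular sparse convolution: stage 1 lets every active input site propose, per kernel offset, the linearized output location it would write to; the proposals are de-duplicated to obtain the actual output sites; stage 2 then builds a coordinate-to-row hash table so every pair can be rewritten against compact output row ids. A 1-D restatement of the idea (conceptual only — the real kernels work on linearized N-D coordinates, handle stride/padding bounds, and run on GPU):

    in_coords = [2, 5, 6]          # active input sites (1-D, stride 1, pad 1)
    offsets = [-1, 0, 1]           # kernel offsets, kv = 3

    # Stage 1: every (input row, offset) proposes an output coordinate.
    proposals = [(row, x + k) for row, x in enumerate(in_coords) for k in offsets]

    # Unique pass (the CUDA path sorts and de-duplicates a flat array).
    out_coords = sorted({c for _, c in proposals})

    # Stage 2: hash table coordinate -> compact output row id, then rewrite.
    table = {c: row for row, c in enumerate(out_coords)}
    indice_pairs = [(i, table[c]) for i, c in proposals]
    print(out_coords)     # [1, 2, 3, 4, 5, 6, 7]
    print(indice_pairs)   # (input row, output row) per kernel position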
@@ -338,12 +357,14 @@ def calc_conv_indices_stage2(self): @pccm.cuda.cuda_global_function def calc_conv_indices_stage1_mask(self): code = pccm.FunctionCode() - code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] + code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] - code.arg("indices_in", f"const int*") # [N, ndim + 1] - code.arg("indice_pairs_bwd", f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("indice_pairs_for_uniq", f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("indice_num_per_loc", f"int*") # [kernelProd] + code.arg("indices_in", f"const int*") # [N, ndim + 1] + code.arg("indice_pairs_bwd", + f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("indice_pairs_for_uniq", + f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("indice_num_per_loc", f"int*") # [kernelProd] code.arg("num_indices_in", "int") @@ -381,11 +402,13 @@ def calc_conv_indices_stage1_mask(self): def calc_conv_indices_stage2_mask(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indice_pairs_fwd", f"int*") # [kernelProd, MaxSize], inp -> out - code.arg("indice_pairs_bwd", f"int*") # [kernelProd, MaxSize], out -> inp - code.arg("mask_fwd", f"uint32_t*") # [kernelProd] - code.arg("mask_bwd", f"uint32_t*") # [kernelProd] + code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("indice_pairs_fwd", + f"int*") # [kernelProd, MaxSize], inp -> out + code.arg("indice_pairs_bwd", + f"int*") # [kernelProd, MaxSize], out -> inp + code.arg("mask_fwd", f"uint32_t*") # [kernelProd] + code.arg("mask_bwd", f"uint32_t*") # [kernelProd] code.arg("num_indices_in", "int") code.arg("num_indices_out", "int") @@ -418,8 +441,9 @@ def calc_conv_indices_stage2_mask(self): @pccm.cuda.cuda_global_function def calc_conv_indices_stage2_mask_output(self): code = pccm.FunctionCode() - code.arg("indice_pairs_bwd", f"int*") # [kernelProd, MaxSize], out -> inp - code.arg("mask_bwd", f"uint32_t*") # [kernelProd] + code.arg("indice_pairs_bwd", + f"int*") # [kernelProd, MaxSize], out -> inp + code.arg("mask_bwd", f"uint32_t*") # [kernelProd] code.arg("num_indices_in", "int") code.arg("kv", "int") @@ -441,10 +465,12 @@ def calc_conv_indices_stage2_mask_output(self): def calc_conv_indices_stage2_inference_mask(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indice_pairs_fwd", f"int*") # [kernelProd, MaxSize], inp -> out - code.arg("indice_pairs_bwd", f"int*") # [kernelProd, MaxSize], out -> inp - code.arg("mask_fwd", f"uint32_t*") # [kernelProd] + code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("indice_pairs_fwd", + f"int*") # [kernelProd, MaxSize], inp -> out + code.arg("indice_pairs_bwd", + f"int*") # [kernelProd, MaxSize], out -> inp + code.arg("mask_fwd", f"uint32_t*") # [kernelProd] code.arg("num_indices_in", "int") code.arg("num_indices_out", "int") @@ -469,16 +495,15 @@ def calc_conv_indices_stage2_inference_mask(self): """) return code - @pccm.cuda.cuda_global_function def build_subm_conv_hash_table(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indices_in", f"const int*") # [N, ndim + 1] + code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("indices_in", f"const int*") # [N, ndim + 1] - code.arg("layout_npq", f"spinds::LayoutNPQ") + code.arg("layout_npq", f"spinds::LayoutNPQ") code.arg("num_indices", "int") @@ -493,8 +518,8 @@ def build_subm_conv_hash_table(self): 
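The mask variants above trade the per-offset pair lists for one uint32 per site, with bit k set when kernel position k has a valid partner — consistent with this path being bounded by kernel volume 32. The encoding in isolation (a conceptual sketch; the bit layout is assumed from the uint32_t masks):

    def set_position(mask: int, filter_offset: int) -> int:
        # Mark kernel position `filter_offset` as having a valid pair.
        assert 0 <= filter_offset < 32, "kernel volume must fit in 32 bits"
        return mask | (1 << filter_offset)

    def active_positions(mask: int):
        # Decode which of the kv kernel positions are active for this site.
        return [k for k in range(32) if mask & (1 << k)]

    m = 0
    for k in (0, 4, 26):            # e.g. 3 hits out of a 3x3x3 kernel (kv = 27)
        m = set_position(m, k)
    print(active_positions(m))      # [0, 4, 26]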
@pccm.cuda.cuda_global_function def clean_indices_uniq(self): code = pccm.FunctionCode() - code.arg("indice_pairs_for_uniq", f"{self.dtype_indices}*") - code.arg("size", f"{self.dtype_indices}") + code.arg("indice_pairs_for_uniq", f"{self.dtype_indices}*") + code.arg("size", f"{self.dtype_indices}") code.raw(f""" for ({self.dtype_indices} i : tv::KernelLoopX<{self.dtype_indices}>(size)) {{ indice_pairs_for_uniq[i] = std::numeric_limits<{self.dtype_indices}>::max(); @@ -506,12 +531,13 @@ def clean_indices_uniq(self): def calc_subm_conv_indices(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] - code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] + code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indices_in", f"const int*") # [N, ndim + 1] - code.arg("indice_pairs", f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("indice_num_per_loc", f"int*") # [kernelProd] + code.arg("indices_in", f"const int*") # [N, ndim + 1] + code.arg("indice_pairs", + f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("indice_num_per_loc", f"int*") # [kernelProd] code.arg("num_indices_in", "int") code.arg("indices_pair_size", "int") @@ -552,12 +578,13 @@ def calc_subm_conv_indices(self): def calc_subm_conv_indices_mask(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] - code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] + code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indices_in", f"const int*") # [N, ndim + 1] - code.arg("indice_pairs", f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("mask", f"uint32_t*") # [kernelProd] + code.arg("indices_in", f"const int*") # [N, ndim + 1] + code.arg("indice_pairs", + f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("mask", f"uint32_t*") # [kernelProd] code.arg("num_indices", "int") code.arg("indices_pair_size", "int") @@ -609,13 +636,14 @@ def calc_subm_conv_indices_mask(self): def calc_subm_conv_indices_split_mask(self): code = pccm.FunctionCode() code.targ("TTable") - code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] - code.arg("table", f"TTable") # [N, ndim + 1] + code.arg("loc_iter", f"ConvLocIter") # [N, ndim + 1] + code.arg("table", f"TTable") # [N, ndim + 1] - code.arg("indices_in", f"const int*") # [N, ndim + 1] - code.arg("indice_pairs", f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] - code.arg("mask1", f"uint32_t*") # [kernelProd] - code.arg("mask2", f"uint32_t*") # [kernelProd] + code.arg("indices_in", f"const int*") # [N, ndim + 1] + code.arg("indice_pairs", + f"{self.dtype_indices}*") # [2, kernelProd, MaxSize] + code.arg("mask1", f"uint32_t*") # [kernelProd] + code.arg("mask2", f"uint32_t*") # [kernelProd] code.arg("num_indices", "int") code.arg("indices_pair_size", "int") @@ -665,10 +693,12 @@ def calc_subm_conv_indices_split_mask(self): def generate_conv_inds_stage1(self): code = pccm.FunctionCode() code.arg("indices", "tv::Tensor") - code.arg("indice_pairs, indice_pairs_uniq, indice_num_per_loc", "tv::Tensor") + code.arg("indice_pairs, indice_pairs_uniq, indice_num_per_loc", + "tv::Tensor") code.arg("batch_size", "int") code.arg("output_dims, input_dims", f"tv::array") - code.arg("ksize, stride, padding, dilation", f"tv::array") + code.arg("ksize, stride, padding, dilation", + f"tv::array") code.arg("transposed", f"bool", "false") code.arg("stream_int", 
f"std::uintptr_t", "0") @@ -706,9 +736,7 @@ def generate_conv_inds_stage1(self): // auto num_out_act = new_end - ptr_tr - 1; // return num_out_act; """) - return code# .ret("int") - - + return code # .ret("int") @pccm.cuda.static_function def generate_conv_inds_stage1_5(self): @@ -726,7 +754,6 @@ def generate_conv_inds_stage1_5(self): """) return code.ret("int") - @pccm.cuda.static_function def generate_conv_inds_stage2(self): code = pccm.FunctionCode() @@ -735,7 +762,8 @@ def generate_conv_inds_stage2(self): code.arg("num_out_act", "int") code.arg("batch_size", "int") code.arg("output_dims, input_dims", f"tv::array") - code.arg("ksize, stride, padding, dilation", f"tv::array") + code.arg("ksize, stride, padding, dilation", + f"tv::array") code.arg("transposed", f"bool", "false") code.arg("stream_int", f"std::uintptr_t", "0") code.raw(f""" @@ -783,10 +811,12 @@ def generate_conv_inds_stage2(self): def generate_conv_inds_mask_stage1(self): code = pccm.FunctionCode() code.arg("indices", "tv::Tensor") - code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc", "tv::Tensor") + code.arg("indice_pairs_bwd, indice_pairs_uniq, indice_num_per_loc", + "tv::Tensor") code.arg("batch_size", "int") code.arg("output_dims, input_dims", f"tv::array") - code.arg("ksize, stride, padding, dilation", f"tv::array") + code.arg("ksize, stride, padding, dilation", + f"tv::array") code.arg("transposed", f"bool", "false") code.arg("stream_int", f"std::uintptr_t", "0") @@ -817,21 +847,23 @@ def generate_conv_inds_mask_stage1(self): indice_pairs_bwd.data_ptr<{self.dtype_indices}>(), indice_pairs_uniq.data_ptr<{self.dtype_indices}>(), indice_num_per_loc.data_ptr(), indices.dim(0), kv, transposed); - auto timer = tv::CudaContextTimer<>(); """) - return code# .ret("int") + return code # .ret("int") @pccm.cuda.static_function def generate_conv_inds_stage2_mask(self): code = pccm.FunctionCode() code.arg("indices, hashdata", "tv::Tensor") - code.arg("indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", "tv::Tensor") + code.arg( + "indice_pairs_fwd, indice_pairs_bwd, indice_pairs_uniq, out_inds", + "tv::Tensor") code.arg("mask_fwd, mask_bwd", "tv::Tensor") code.arg("num_out_act", "int") code.arg("batch_size", "int") code.arg("output_dims, input_dims", f"tv::array") - code.arg("ksize, stride, padding, dilation", f"tv::array") + code.arg("ksize, stride, padding, dilation", + f"tv::array") code.arg("transposed", f"bool", "false") code.arg("stream_int", f"std::uintptr_t", "0") code.raw(f""" @@ -903,7 +935,6 @@ def generate_conv_inds_stage2_mask(self): """) return code.ret("int") - @pccm.cuda.static_function def generate_subm_conv_inds(self): code = pccm.FunctionCode() @@ -912,7 +943,8 @@ def generate_subm_conv_inds(self): code.arg("batch_size", "int") code.arg("input_dims", f"tv::array") code.arg("ksize, dilation", f"tv::array") - code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()", "cumm.tensorview.Tensor = Tensor()") + code.arg("indice_pair_mask", "tv::Tensor", "tv::Tensor()", + "cumm.tensorview.Tensor = Tensor()") code.arg("backward", "bool", "false") code.arg("stream_int", f"std::uintptr_t", "0") @@ -993,6 +1025,7 @@ def generate_subm_conv_inds(self): return code.ret("int") + class SparseConvIndicesCPU(pccm.ParameterizedClass): def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType): super().__init__() @@ -1000,9 +1033,9 @@ def __init__(self, problem: ConvProblem, dtype_indices: dtypes.DType): self.add_include("unordered_map") self.loc_iter = ConvOutLocIter(problem) 
self.add_param_class("spinds", self.loc_iter, "ConvLocIter") - self.add_param_class("spinds", problem, "ConvProblem") + self.add_param_class("spinds", problem, "ConvProblem") - self.ndim = problem.ndim + self.ndim = problem.ndim self.dtype_indices = dtype_indices self.dtype_indices_uniq = dtype_indices @@ -1016,7 +1049,7 @@ def generate_subm_conv_inds(self): code.arg("batch_size", "int") code.arg("input_dims", f"tv::array") code.arg("ksize, dilation", f"tv::array") - + code.raw(f""" tv::array stride, padding; for (int i = 0; i < {self.ndim}; ++i){{ @@ -1079,7 +1112,8 @@ def generate_conv_inds(self): code.arg("indice_pairs, out_inds, indice_num_per_loc", "tv::Tensor") code.arg("batch_size", "int") code.arg("output_dims, input_dims", f"tv::array") - code.arg("ksize, stride, padding, dilation", f"tv::array") + code.arg("ksize, stride, padding, dilation", + f"tv::array") code.arg("transposed", f"bool", "false") code.raw(f""" int kv = tv::arrayops::prod(ksize); diff --git a/spconv/csrc/sparse/maxpool.py b/spconv/csrc/sparse/maxpool.py index a06ed38..8c2666f 100644 --- a/spconv/csrc/sparse/maxpool.py +++ b/spconv/csrc/sparse/maxpool.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,15 +16,18 @@ from cumm.conv.bases import ConvEnum from cumm.gemm.core.metaarray import MetaArray, seq from cumm import dtypes -import pccm +import pccm from cumm.gemm.layout import TensorGeneric, to_stride from cumm.common import TensorView, TensorViewHashKernel, TensorViewKernel, ThrustLib, GemmBasic from cumm.gemm import codeops -from typing import List +from typing import List from cumm.conv.params import ConvProblem from cumm.gemm.mask_iters import MaskTileIterator, MaskTileIteratorParams -import numpy as np +import numpy as np from cumm.gemm import (thread_map) +from spconv.csrc.sparse.cpu_core import OMPLib +from cumm.constants import CUMM_CPU_ONLY_BUILD + class IndiceMaxPool(pccm.Class): # TODO optimize this function @@ -32,13 +35,13 @@ def __init__(self): super().__init__() self.add_include("limits") self.add_dependency(TensorViewKernel, TensorView, GemmBasic) - + @pccm.cuda.cuda_global_function def forward_kernel(self): code = pccm.FunctionCode() code.targ("T") - code.arg("out_features", f"T*") + code.arg("out_features", f"T*") code.arg("in_features", f"const T*") code.arg("out_indices", "const int*") code.arg("in_indices", "const int*") @@ -67,7 +70,7 @@ def forward_implicit_gemm_kernel(self): code = pccm.FunctionCode() code.targ("T") - code.arg("out_features", f"T*") + code.arg("out_features", f"T*") code.arg("in_features", f"const T*") code.arg("indices", "const int*") code.arg("num_features", "int") @@ -104,9 +107,9 @@ def forward_implicit_gemm_kernel(self): def backward_kernel(self): code = pccm.FunctionCode() code.targ("T") - code.arg("out_features", f"const T*") + code.arg("out_features", f"const T*") code.arg("in_features", f"const T*") - code.arg("dout_features", f"const T*") + code.arg("dout_features", f"const T*") code.arg("din_features", f"T*") code.arg("out_indices", "const int*") code.arg("in_indices", "const int*") @@ -137,9 +140,9 @@ def 
backward_implicit_gemm_kernel(self): code = pccm.FunctionCode() code.targ("T") - code.arg("out_features", f"const T*") + code.arg("out_features", f"const T*") code.arg("in_features", f"const T*") - code.arg("dout_features", f"const T*") + code.arg("dout_features", f"const T*") code.arg("din_features", f"T*") code.arg("indices_bwd", "const int*") code.arg("num_features", "int") @@ -351,6 +354,9 @@ class IndiceMaxPoolCPU(pccm.Class): def __init__(self): super().__init__() self.add_dependency(TensorView) + if CUMM_CPU_ONLY_BUILD: + self.add_dependency(OMPLib) + self.add_include("tensorview/parallel/all.h") @pccm.static_function def forward(self): @@ -371,20 +377,21 @@ def forward(self): auto in_indices = in_inds.data_ptr(); auto out_indices = out_inds.data_ptr(); - - for (int i = 0; i < nhot; ++i) {{ - int in_idx = in_indices[i]; - int out_idx = out_indices[i]; - auto in_ptr = in_features + in_idx * num_features; - auto out_ptr = out_features + out_idx * num_features; - for (int j = 0; j < num_features; ++j) {{ - auto in = in_ptr[j]; - auto out = out_ptr[j]; - if (in > out){{ - out_ptr[j] = in; + tv::kernel_1d(out.device(), nhot, [&](int begin, int end, int step){{ + for (int i = begin; i < end; i += step) {{ + int in_idx = in_indices[i]; + int out_idx = out_indices[i]; + auto in_ptr = in_features + in_idx * num_features; + auto out_ptr = out_features + out_idx * num_features; + for (int j = 0; j < num_features; ++j) {{ + auto in = in_ptr[j]; + auto out = out_ptr[j]; + if (in > out){{ + out_ptr[j] = in; + }} }} }} - }} + }}); }}); """) return code @@ -412,22 +419,24 @@ def backward(self): auto in_indices = in_inds.data_ptr(); auto out_indices = out_inds.data_ptr(); - - for (int i = 0; i < nhot; ++i) {{ - int in_idx_offset = in_indices[i] * num_features; - int out_idx_offset = out_indices[i] * num_features; - auto in_ptr = in_features + in_idx_offset; - auto out_ptr = out_features + out_idx_offset; - auto din_ptr = din_features + in_idx_offset; - auto dout_ptr = dout_features + out_idx_offset; - for (int j = 0; j < num_features; ++j) {{ - auto in = in_ptr[j]; - auto out = out_ptr[j]; - if (in == out){{ - din_ptr[j] = din_ptr[j] + dout_ptr[j]; + tv::kernel_1d(out.device(), nhot, [&](int begin, int end, int step){{ + for (int i = begin; i < end; i += step) {{ + int in_idx_offset = in_indices[i] * num_features; + int out_idx_offset = out_indices[i] * num_features; + auto in_ptr = in_features + in_idx_offset; + auto out_ptr = out_features + out_idx_offset; + auto din_ptr = din_features + in_idx_offset; + auto dout_ptr = dout_features + out_idx_offset; + for (int j = 0; j < num_features; ++j) {{ + auto in = in_ptr[j]; + auto out = out_ptr[j]; + if (in == out){{ + din_ptr[j] = din_ptr[j] + dout_ptr[j]; + }} }} }} - }} + }}); + }}); """) return code diff --git a/spconv/csrc/sparse/pointops.py b/spconv/csrc/sparse/pointops.py index 6ff18cc..8ff8651 100644 --- a/spconv/csrc/sparse/pointops.py +++ b/spconv/csrc/sparse/pointops.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
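Both GatherCPU and IndiceMaxPoolCPU now push their hot loops through tv::kernel_1d, which partitions [0, nhot) into (begin, end, step) chunks — parallelized with the OpenMP flags contributed by the new OMPLib class on CPU-only builds — without changing what the loops compute. In NumPy terms, the gather/scatter-add pair is simply (a restatement for clarity, not the shipped code):

    import numpy as np

    def gather(features, inds):
        # buffer[i, :] = features[inds[i], :]
        return features[inds]

    def scatter_add(buffer, inds, out):
        # out[inds[i], :] += buffer[i, :]; np.add.at accumulates repeated rows
        np.add.at(out, inds, buffer)

    feats = np.arange(12.0).reshape(4, 3)
    inds = np.array([2, 0, 2])
    buf = gather(feats, inds)
    out = np.zeros_like(feats)
    scatter_add(buf, inds, out)   # row 2 receives two accumulated rows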
@@ -15,27 +15,27 @@ import contextlib from cumm.gemm.core.metaarray import MetaArray, seq from cumm import dtypes -import pccm +import pccm from cumm.gemm.layout import TensorGeneric, to_stride from cumm.common import TensorView, TensorViewHashKernel from cumm.gemm import codeops -from typing import List +from typing import List from cumm.conv.params import ConvProblem -import numpy as np +import numpy as np + class Point2VoxelCommon(pccm.ParameterizedClass): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): super().__init__() self.add_dependency(TensorView) - self.dtype = dtype - self.ndim = ndim + self.dtype = dtype + self.ndim = ndim self.zyx = zyx ret_str = f"std::array" retf_str = f"std::array" retf2_str = f"std::array" self.calc_meta_ret = f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>" - @pccm.pybind.mark @pccm.static_function def calc_meta_data(self): code = pccm.FunctionCode() @@ -80,7 +80,8 @@ def calc_meta_data(self): retf_str = f"std::array" retf2_str = f"std::array" - return code.ret(f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>") + return code.ret( + f"std::tuple<{retf_str}, {ret_str}, {ret_str}, {retf2_str}>") @pccm.static_function def array2tvarray(self): @@ -112,16 +113,21 @@ def tvarray2array(self): """) return code.ret("std::array") + class Point2VoxelKernel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): """this class don't support multi-thread. create p2v for every thread. """ - def __init__(self, dtype: dtypes.DType, ndim: int, layout: TensorGeneric, zyx: bool = True): + def __init__(self, + dtype: dtypes.DType, + ndim: int, + layout: TensorGeneric, + zyx: bool = True): super().__init__() self.add_dependency(TensorView, TensorViewHashKernel) self.add_param_class("layout_ns", layout, "Layout") - self.dtype = dtype - self.ndim = ndim + self.dtype = dtype + self.ndim = ndim self.zyx = zyx @pccm.cuda.cuda_global_function @@ -142,7 +148,7 @@ def build_hash_table(self): point_xyz = f"{self.ndim - 1} - j" if not self.zyx: point_xyz = f"j" - # if zyx, the coors_range and grid_bound is zyx too, + # if zyx, the coors_range and grid_bound is zyx too, # generated indices is zyx. 
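build_hash_table (continuing below) quantizes each point into an integer grid cell, rejects out-of-range points, and — because zyx=True by default — stores coordinates reversed so the generated indices come out zyx. The quantization is assumed here to be the usual floor((p - range_lo) / vsize), sketched in NumPy with the grid bounds computed the way calc_meta_data returns them:

    import numpy as np

    def points_to_voxel_coords(points_xyz, coors_range, vsize):
        lo = np.asarray(coors_range[:3])
        hi = np.asarray(coors_range[3:])
        vs = np.asarray(vsize)
        grid_size = np.floor((hi - lo) / vs).astype(np.int64)   # bound per axis
        coords = np.floor((points_xyz - lo) / vs).astype(np.int64)
        keep = np.all((coords >= 0) & (coords < grid_size), axis=1)
        return coords[keep][:, ::-1]    # flip xyz -> zyx to match the kernels

    pts = np.array([[0.05, 0.22, 0.95], [5.0, 0.0, 0.0]], dtype=np.float32)
    print(points_to_voxel_coords(pts, [0, 0, 0, 1, 1, 1], [0.1, 0.1, 0.1]))
    # -> [[9 2 0]]  (second point is out of range and dropped)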
code.raw(f""" for (int i : tv::KernelLoopX(num_points)){{ @@ -166,7 +172,7 @@ def build_hash_table(self): }} }} """) - return code + return code @pccm.cuda.cuda_global_function def assign_table(self): @@ -190,7 +196,7 @@ def assign_table(self): }} }} """) - return code + return code @pccm.cuda.cuda_global_function def generate_voxel(self): @@ -231,7 +237,7 @@ def generate_voxel(self): }} }} """) - return code + return code @pccm.cuda.cuda_global_function def voxel_empty_fill_mean(self): @@ -263,7 +269,7 @@ def voxel_empty_fill_mean(self): }} }} """) - return code + return code @pccm.cuda.cuda_global_function def limit_num_per_voxel_value(self): @@ -276,7 +282,8 @@ def limit_num_per_voxel_value(self): num_per_voxel[i] = count; }} """) - return code + return code + class Point2Voxel(pccm.ParameterizedClass, pccm.pybind.PybindClassMixin): def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): @@ -286,14 +293,23 @@ def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon") layout = TensorGeneric(ndim, True) self.add_param_class("layout_ns", layout, "Layout") - self.dtype = dtype - self.ndim = ndim + self.dtype = dtype + self.ndim = ndim self.zyx = zyx - cuda_funcs = [self.point_to_voxel_hash, self.point_to_voxel_hash_static] - self.add_impl_only_param_class(cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx)) - - self.add_pybind_member("hashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") - self.add_pybind_member("point_indice_data", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") + cuda_funcs = [ + self.point_to_voxel_hash, self.point_to_voxel_hash_static + ] + self.add_impl_only_param_class( + cuda_funcs, "kernel", Point2VoxelKernel(dtype, ndim, layout, zyx)) + + self.add_pybind_member("hashdata", + "tv::Tensor", + readwrite=False, + pyanno="cumm.tensorview.Tensor") + self.add_pybind_member("point_indice_data", + "tv::Tensor", + readwrite=False, + pyanno="cumm.tensorview.Tensor") self.add_pybind_member("voxels", "tv::Tensor", readwrite=False) self.add_pybind_member("indices", "tv::Tensor", readwrite=False) @@ -357,7 +373,7 @@ def ctor(self): hashdata = tv::zeros({{1}}, tv::custom128, 0); point_indice_data = tv::zeros({{1}}, tv::int64, 0); """) - return code + return code @pccm.pybind.mark @pccm.cuda.member_function @@ -439,13 +455,13 @@ def point_to_voxel_hash(self): """) return code.ret("std::tuple") - @pccm.pybind.mark @pccm.cuda.static_function def point_to_voxel_hash_static(self): code = pccm.FunctionCode() code.arg("points", "tv::Tensor") - code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", "tv::Tensor") + code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", + "tv::Tensor") code.arg("vsize", f"std::array") code.arg("grid_size, grid_stride", f"std::array") code.arg("coors_range", f"std::array") @@ -527,13 +543,16 @@ def __init__(self, dtype: dtypes.DType, ndim: int, zyx: bool = True): self.add_dependency(TensorView) layout = TensorGeneric(ndim, True) self.add_param_class("layout_ns", layout, "Layout") - self.dtype = dtype - self.ndim = ndim + self.dtype = dtype + self.ndim = ndim self.zyx = zyx self.p2v_c = Point2VoxelCommon(dtype, ndim, zyx) self.add_param_class("p2v_c", self.p2v_c, "Point2VoxelCommon") - self.add_pybind_member("densehashdata", "tv::Tensor", readwrite=False, pyanno="cumm.tensorview.Tensor") + self.add_pybind_member("densehashdata", + "tv::Tensor", + readwrite=False, + 
pyanno="cumm.tensorview.Tensor") self.add_pybind_member("voxels", "tv::Tensor", readwrite=False) self.add_pybind_member("indices", "tv::Tensor", readwrite=False) @@ -568,7 +587,6 @@ def calc_meta_data(self): """) return code.ret(self.p2v_c.calc_meta_ret) - @pccm.pybind.mark @pccm.constructor def ctor(self): @@ -613,7 +631,7 @@ def ctor(self): densehashdata_ptr[i] = -1; }} """) - return code + return code def point_to_voxel_static_template(self, mean: bool = False): code = pccm.FunctionCode() diff --git a/spconv/pytorch/__init__.py b/spconv/pytorch/__init__.py index d31dae9..8ec49a3 100644 --- a/spconv/pytorch/__init__.py +++ b/spconv/pytorch/__init__.py @@ -4,13 +4,14 @@ import numpy as np import torch -from spconv.pytorch import ops -from spconv.pytorch.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d, - SparseConvTranspose3d, SparseInverseConv2d, - SparseInverseConv3d, SubMConv2d, SubMConv3d) +from spconv.pytorch import ops, functional +from spconv.pytorch.conv import (SparseConv2d, SparseConv3d, + SparseConvTranspose2d, SparseConvTranspose3d, + SparseInverseConv2d, SparseInverseConv3d, + SubMConv2d, SubMConv3d) from spconv.pytorch.core import SparseConvTensor from spconv.pytorch.identity import Identity -from spconv.pytorch.modules import SparseModule, SparseSequential +from spconv.pytorch.modules import SparseModule, SparseSequential, assign_name_for_sparse_modules from spconv.pytorch.ops import ConvAlgo from spconv.pytorch.pool import SparseMaxPool2d, SparseMaxPool3d from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable diff --git a/spconv/pytorch/constants.py b/spconv/pytorch/constants.py index a0fcbfb..1de0083 100644 --- a/spconv/pytorch/constants.py +++ b/spconv/pytorch/constants.py @@ -1,18 +1,18 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import torch +import torch try: remove_plus = torch.__version__.find("+") remove_dotdev = torch.__version__.find(".dev") @@ -26,4 +26,4 @@ PYTORCH_VERSION = list(map(int, PYTORCH_VERSION.split("."))) except: # for unknown errors, just set a version - PYTORCH_VERSION = [1, 8, 0] \ No newline at end of file + PYTORCH_VERSION = [1, 8, 0] diff --git a/spconv/pytorch/conv.py b/spconv/pytorch/conv.py index f946df8..882eef5 100644 --- a/spconv/pytorch/conv.py +++ b/spconv/pytorch/conv.py @@ -24,12 +24,13 @@ from spconv import pytorch as spconv from spconv.core import ConvAlgo -import spconv.pytorch.functional as Fsp +from spconv.pytorch import functional as Fsp from spconv.pytorch import ops from spconv.cppconstants import CPU_ONLY_BUILD from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData from spconv.pytorch.modules import SparseModule from spconv.constants import FILTER_HWIO +from spconv.utils import nullcontext def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo): @@ -205,6 +206,7 @@ def forward(self, input: SparseConvTensor): self.dilation) else: out_spatial_shape = spatial_shape + # print(self._sparse_unique_name, spatial_shape, out_spatial_shape) # input.update_grid(out_spatial_shape) # t = time.time() out_tensor = input.shadow_copy() @@ -247,158 +249,165 @@ def forward(self, input: SparseConvTensor): out_tensor = out_tensor.replace_feature(features) return out_tensor indice_dict = input.indice_dict.copy() - + algo = self.algo - if self.indice_key is not None : + if self.indice_key is not None: datas = input.find_indice_pair(self.indice_key) if datas is not None: msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key." assert algo == datas.algo, msg # algo = datas.algo - if algo == ConvAlgo.Native: - datas = input.find_indice_pair(self.indice_key) - if datas is not None: - assert isinstance(datas, IndiceData) - if self.inverse: - assert datas is not None and self.indice_key is not None - assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops." - - outids = datas.indices - indice_pairs = datas.indice_pairs - indice_pair_num = datas.indice_pair_num - out_spatial_shape = datas.out_spatial_shape - assert indice_pair_num.shape[0] == np.prod( - self.kernel_size - ), "inverse conv must have same kernel size as its couple conv" - else: - if self.indice_key is not None and datas is not None: - outids = datas.out_indices + profile_ctx = nullcontext() + if input._timer is not None and self._sparse_unique_name: + profile_ctx = input._timer.namespace(self._sparse_unique_name) + with profile_ctx: + if algo == ConvAlgo.Native: + datas = input.find_indice_pair(self.indice_key) + if datas is not None: + assert isinstance(datas, IndiceData) + if self.inverse: + assert datas is not None and self.indice_key is not None + assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops." 
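The restructured forward above only changes control flow at its top: when a timer is attached to the input tensor (SparseConvTensor gains an enable_timer flag later in this patch), the whole layer runs inside a named profiling namespace; otherwise a nullcontext keeps the path a no-op. nullcontext is imported from spconv.utils rather than contextlib, presumably because contextlib.nullcontext only exists from Python 3.7 and this release restores 3.6 support. The pattern in isolation (do_work is a placeholder for the layer body):

    from contextlib import nullcontext   # spconv.utils ships its own fallback

    def forward(x, timer=None, name="conv1"):
        # Real profiling scope when a timer exists, no-op scope otherwise,
        # so the body below is written exactly once.
        ctx = timer.namespace(name) if timer is not None else nullcontext()
        with ctx:
            return do_work(x)            # hypothetical layer body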
+ + outids = datas.indices indice_pairs = datas.indice_pairs indice_pair_num = datas.indice_pair_num + out_spatial_shape = datas.out_spatial_shape + assert indice_pair_num.shape[0] == np.prod( + self.kernel_size + ), "inverse conv must have same kernel size as its couple conv" else: - if input.benchmark: - torch.cuda.synchronize() - t = time.time() - outids, indice_pairs, indice_pair_num = ops.get_indice_pairs( - indices, batch_size, spatial_shape, algo, - self.kernel_size, self.stride, self.padding, - self.dilation, self.output_padding, self.subm, - self.transposed) - if input.benchmark: - torch.cuda.synchronize() - interval = time.time() - t - out_tensor.benchmark_record[ - self.name]["indice_gen_time"].append(interval) - - indice_data = IndiceData(outids, - indices, - indice_pairs, - indice_pair_num, - spatial_shape, - is_subm=self.subm, - algo=algo) - if self.indice_key is not None: - msg = f"your indice key {self.indice_key} already exists in this sparse tensor." - assert self.indice_key not in indice_dict, msg - indice_dict[self.indice_key] = indice_data - if input.benchmark: - torch.cuda.synchronize() - t = time.time() - indice_pairs_calc = indice_pairs - if indice_pairs.device != features.device: - indice_pairs_calc = indice_pairs.to(features.device) - if self.subm: - out_features = Fsp.indice_subm_conv(features, self.weight, - indice_pairs_calc, - indice_pair_num, - outids.shape[0], algo) - else: - if self.inverse: - out_features = Fsp.indice_inverse_conv( + if self.indice_key is not None and datas is not None: + outids = datas.out_indices + indice_pairs = datas.indice_pairs + indice_pair_num = datas.indice_pair_num + else: + if input.benchmark: + torch.cuda.synchronize() + t = time.time() + outids, indice_pairs, indice_pair_num = ops.get_indice_pairs( + indices, batch_size, spatial_shape, algo, + self.kernel_size, self.stride, self.padding, + self.dilation, self.output_padding, self.subm, + self.transposed) + if input.benchmark: + torch.cuda.synchronize() + interval = time.time() - t + out_tensor.benchmark_record[ + self.name]["indice_gen_time"].append(interval) + + indice_data = IndiceData(outids, + indices, + indice_pairs, + indice_pair_num, + spatial_shape, + is_subm=self.subm, + algo=algo) + if self.indice_key is not None: + msg = f"your indice key {self.indice_key} already exists in this sparse tensor." + assert self.indice_key not in indice_dict, msg + indice_dict[self.indice_key] = indice_data + if input.benchmark: + torch.cuda.synchronize() + t = time.time() + indice_pairs_calc = indice_pairs + if indice_pairs.device != features.device: + indice_pairs_calc = indice_pairs.to(features.device) + if self.subm: + out_features = Fsp.indice_subm_conv( features, self.weight, indice_pairs_calc, - indice_pair_num, outids.shape[0], algo) + indice_pair_num, outids.shape[0], algo, input._timer) else: - out_features = Fsp.indice_conv(features, self.weight, - indice_pairs_calc, - indice_pair_num, - outids.shape[0], algo) - - else: - datas = input.find_indice_pair(self.indice_key) - if datas is not None: - assert isinstance(datas, ImplicitGemmIndiceData) - if self.inverse: - assert datas is not None and self.indice_key is not None - assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops." 
- outids = datas.indices - pair_fwd = datas.pair_bwd - pair_bwd = datas.pair_fwd - pair_mask_fwd_splits = datas.pair_mask_bwd_splits - pair_mask_bwd_splits = datas.pair_mask_fwd_splits - mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits - mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits - masks = datas.masks + if self.inverse: + out_features = Fsp.indice_inverse_conv( + features, self.weight, indice_pairs_calc, + indice_pair_num, outids.shape[0], algo) + else: + out_features = Fsp.indice_conv(features, self.weight, + indice_pairs_calc, + indice_pair_num, + outids.shape[0], algo, + input._timer) else: - if self.indice_key is not None and datas is not None: - outids = datas.out_indices - pair_fwd = datas.pair_fwd - pair_bwd = datas.pair_bwd - pair_mask_fwd_splits = datas.pair_mask_fwd_splits - pair_mask_bwd_splits = datas.pair_mask_bwd_splits - mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits - mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits + datas = input.find_indice_pair(self.indice_key) + if datas is not None: + assert isinstance(datas, ImplicitGemmIndiceData) + if self.inverse: + assert datas is not None and self.indice_key is not None + assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops." + outids = datas.indices + pair_fwd = datas.pair_bwd + pair_bwd = datas.pair_fwd + pair_mask_fwd_splits = datas.pair_mask_bwd_splits + pair_mask_bwd_splits = datas.pair_mask_fwd_splits + mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits + mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits masks = datas.masks + else: - res = ops.get_indice_pairs_implicit_gemm( - indices, - batch_size, - spatial_shape, - algo, - ksize=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - out_padding=self.output_padding, - subm=self.subm, - transpose=self.transposed, - is_train=self.training, - alloc=input.thrust_allocator) - outids = res[0] - num_inds_per_loc = res[1] - pair_fwd = res[2] - pair_bwd = res[3] - pair_mask_fwd_splits = res[4] - pair_mask_bwd_splits = res[5] - mask_argsort_fwd_splits = res[6] - mask_argsort_bwd_splits = res[7] - masks = res[8] - if self.indice_key is not None: - indice_data = ImplicitGemmIndiceData( - outids, - indices, - pair_fwd, - pair_bwd, - pair_mask_fwd_splits=pair_mask_fwd_splits, - pair_mask_bwd_splits=pair_mask_bwd_splits, - mask_argsort_fwd_splits=mask_argsort_fwd_splits, - mask_argsort_bwd_splits=mask_argsort_bwd_splits, - masks=masks, - is_subm=self.subm, - out_spatial_shape=out_spatial_shape, - algo=algo) - msg = f"your indice key {self.indice_key} already exists in this sparse tensor." 
- assert self.indice_key not in indice_dict, msg - indice_dict[self.indice_key] = indice_data - if input.benchmark: - torch.cuda.synchronize() - t = time.time() - num_activate_out = outids.shape[0] - out_features = Fsp.implicit_gemm( - features, self.weight, pair_fwd, pair_bwd, - pair_mask_fwd_splits, pair_mask_bwd_splits, - mask_argsort_fwd_splits, mask_argsort_bwd_splits, - num_activate_out, masks, self.training, self.subm) + if self.indice_key is not None and datas is not None: + outids = datas.out_indices + pair_fwd = datas.pair_fwd + pair_bwd = datas.pair_bwd + pair_mask_fwd_splits = datas.pair_mask_fwd_splits + pair_mask_bwd_splits = datas.pair_mask_bwd_splits + mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits + mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits + masks = datas.masks + else: + with input._timer.namespace("gen_pairs"): + res = ops.get_indice_pairs_implicit_gemm( + indices, + batch_size, + spatial_shape, + algo, + ksize=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + out_padding=self.output_padding, + subm=self.subm, + transpose=self.transposed, + is_train=self.training, + alloc=input.thrust_allocator, + timer=input._timer) + outids = res[0] + num_inds_per_loc = res[1] + pair_fwd = res[2] + pair_bwd = res[3] + pair_mask_fwd_splits = res[4] + pair_mask_bwd_splits = res[5] + mask_argsort_fwd_splits = res[6] + mask_argsort_bwd_splits = res[7] + masks = res[8] + if self.indice_key is not None: + indice_data = ImplicitGemmIndiceData( + outids, + indices, + pair_fwd, + pair_bwd, + pair_mask_fwd_splits=pair_mask_fwd_splits, + pair_mask_bwd_splits=pair_mask_bwd_splits, + mask_argsort_fwd_splits=mask_argsort_fwd_splits, + mask_argsort_bwd_splits=mask_argsort_bwd_splits, + masks=masks, + is_subm=self.subm, + out_spatial_shape=out_spatial_shape, + algo=algo) + msg = f"your indice key {self.indice_key} already exists in this sparse tensor." 
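The duplicate-key assert that follows guards the caching scheme both the native and implicit-GEMM branches rely on: layers sharing an indice_key reuse the first layer's precomputed index pairs (and, per the message earlier in this file, must use the same algo), while an inverse conv finds its paired conv through the key and replays its tables with the forward/backward roles swapped, as the pair_fwd = datas.pair_bwd lines above show. In user code the key is just a constructor argument (a sketch using the layer classes exported by spconv.pytorch; channel sizes are arbitrary):

    import spconv.pytorch as spconv

    net = spconv.SparseSequential(
        # Two submanifold convs sharing a key: the second reuses the
        # index pairs generated by the first instead of regenerating them.
        spconv.SubMConv3d(16, 32, 3, indice_key="subm1"),
        spconv.SubMConv3d(32, 32, 3, indice_key="subm1"),
        # A downsampling conv and its inverse, paired through one key.
        spconv.SparseConv3d(32, 64, 3, stride=2, indice_key="down1"),
        spconv.SparseInverseConv3d(64, 32, 3, indice_key="down1"),
    )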
+ assert self.indice_key not in indice_dict, msg + indice_dict[self.indice_key] = indice_data + if input.benchmark: + torch.cuda.synchronize() + t = time.time() + num_activate_out = outids.shape[0] + out_features = Fsp.implicit_gemm( + features, self.weight, pair_fwd, pair_bwd, + pair_mask_fwd_splits, pair_mask_bwd_splits, + mask_argsort_fwd_splits, mask_argsort_bwd_splits, + num_activate_out, masks, self.training, self.subm, + input._timer) if self.bias is not None: out_features += self.bias if input.benchmark: diff --git a/spconv/pytorch/core.py b/spconv/pytorch/core.py index ca7d53e..d6972d0 100644 --- a/spconv/pytorch/core.py +++ b/spconv/pytorch/core.py @@ -19,6 +19,7 @@ from spconv.core import ConvAlgo from spconv.pytorch.constants import PYTORCH_VERSION from spconv.pytorch.ops import ThrustSortAllocator +from spconv.tools import CUDAKernelTimer if PYTORCH_VERSION >= [1, 8, 0]: try: @@ -51,13 +52,14 @@ def __init__(self, out_indices, indices, indice_pairs, indice_pair_num, class ImplicitGemmIndiceData(object): - def __init__(self, out_indices: torch.Tensor, indices: torch.Tensor, pair_fwd: torch.Tensor, - pair_bwd: torch.Tensor, + def __init__(self, out_indices: torch.Tensor, indices: torch.Tensor, + pair_fwd: torch.Tensor, pair_bwd: torch.Tensor, pair_mask_fwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor], - masks: List[np.ndarray], out_spatial_shape, is_subm: bool, algo: ConvAlgo): + masks: List[np.ndarray], out_spatial_shape, is_subm: bool, + algo: ConvAlgo): self.out_indices = out_indices self.indices = indices self.pair_fwd = pair_fwd @@ -99,7 +101,8 @@ def __init__(self, voxel_num: Optional[torch.Tensor] = None, indice_dict: Optional[dict] = None, benchmark: bool = False, - permanent_thrust_allocator: bool = False): + permanent_thrust_allocator: bool = False, + enable_timer: bool = False): """ Args: features: [num_points, num_features] feature tensor @@ -130,9 +133,10 @@ def __init__(self, self.voxel_num = voxel_num # for tensorrt self.benchmark = benchmark self.benchmark_record = {} - self.thrust_allocator: Optional[ThrustSortAllocator] = None + self.thrust_allocator: Optional[ThrustSortAllocator] = None if permanent_thrust_allocator: self.thrust_allocator = ThrustSortAllocator(features.device) + self._timer = CUDAKernelTimer(enable_timer) def replace_feature(self, feature): """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) @@ -144,7 +148,7 @@ def replace_feature(self, feature): new_spt.benchmark = self.benchmark new_spt.benchmark_record = self.benchmark_record new_spt.thrust_allocator = self.thrust_allocator - + new_spt._timer = self._timer return new_spt @property @@ -174,7 +178,8 @@ def from_dense(cls, x: torch.Tensor): def spatial_size(self): return np.prod(self.spatial_shape) - def find_indice_pair(self, key) -> Optional[Union[IndiceData, ImplicitGemmIndiceData]]: + def find_indice_pair( + self, key) -> Optional[Union[IndiceData, ImplicitGemmIndiceData]]: if key is None: return None if key in self.indice_dict: @@ -208,4 +213,5 @@ def shadow_copy(self) -> "SparseConvTensor": self.benchmark) tensor.benchmark_record = self.benchmark_record tensor.thrust_allocator = self.thrust_allocator + tensor._timer = self._timer return tensor diff --git a/spconv/pytorch/cppcore.py b/spconv/pytorch/cppcore.py index 5555645..80d6415 100644 --- a/spconv/pytorch/cppcore.py +++ b/spconv/pytorch/cppcore.py @@ -1,20 +1,21 @@ # 
Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from cumm import tensorview as tv -import torch +from cumm import tensorview as tv +import torch from typing import Optional, List + _TORCH_DTYPE_TO_TV = { torch.float32: tv.float32, torch.float64: tv.float64, @@ -26,10 +27,13 @@ torch.uint8: tv.uint8, } -def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Optional[List[int]] = None): + +def torch_tensor_to_tv(ten: torch.Tensor, + dtype: Optional[int] = None, + shape: Optional[List[int]] = None): assert ten.is_contiguous(), "must be contiguous tensor" ptr = ten.data_ptr() - device = ten.device + device = ten.device if device.type == "cpu": tv_device = -1 elif device.type == "cuda": @@ -42,10 +46,12 @@ def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, shape: Op dtype = _TORCH_DTYPE_TO_TV[ten.dtype] return tv.from_blob(ptr, shape, dtype, tv_device) + def get_current_stream(): return torch.cuda.current_stream().cuda_stream + if __name__ == "__main__": a = torch.rand(2, 2) atv = torch_tensor_to_tv(a) - print(atv.numpy_view()) \ No newline at end of file + print(atv.numpy_view()) diff --git a/spconv/pytorch/functional.py b/spconv/pytorch/functional.py index 1bad9f0..6eb79d3 100644 --- a/spconv/pytorch/functional.py +++ b/spconv/pytorch/functional.py @@ -15,8 +15,9 @@ import torch from torch import nn from torch.autograd import Function - -import spconv.pytorch.ops as ops +from typing import Optional +from spconv.tools import CUDAKernelTimer +from spconv.pytorch import ops import torch.cuda.amp as amp from torch.autograd.function import once_differentiable import numpy as np @@ -27,23 +28,32 @@ class SparseConvFunction(Function): @staticmethod @amp.custom_fwd(cast_inputs=torch.float16) - def forward(ctx, features, filters, indice_pairs, indice_pair_num, - num_activate_out, algo): + def forward(ctx, + features, + filters, + indice_pairs, + indice_pair_num, + num_activate_out, + algo, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.algo = algo + ctx.timer = timer return ops.indice_conv(features, filters, indice_pairs, indice_pair_num, num_activate_out, False, - algo=algo) + algo=algo, + timer=timer) @staticmethod @once_differentiable @amp.custom_bwd def backward(ctx, grad_output): indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors + timer = ctx.timer input_bp, filters_bp = ops.indice_conv_backward(features, filters, @@ -51,18 +61,27 @@ def backward(ctx, grad_output): indice_pairs, indice_pair_num, False, - algo=ctx.algo) + algo=ctx.algo, + timer=timer) - return input_bp, filters_bp, None, None, None, None + return input_bp, filters_bp, None, None, None, None, None class SparseInverseConvFunction(Function): @staticmethod @amp.custom_fwd(cast_inputs=torch.float16) - def forward(ctx, features, filters, indice_pairs, indice_pair_num, - num_activate_out, algo): + def forward(ctx, + features, + filters, + indice_pairs, + indice_pair_num, + 
num_activate_out, + algo, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters) ctx.algo = algo + ctx.timer = timer + return ops.indice_conv(features, filters, indice_pairs, @@ -70,13 +89,16 @@ def forward(ctx, features, filters, indice_pairs, indice_pair_num, num_activate_out, True, False, - algo=algo) + algo=algo, + timer=timer) @staticmethod @once_differentiable @amp.custom_bwd def backward(ctx, grad_output): indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors + timer = ctx.timer + input_bp, filters_bp = ops.indice_conv_backward(features, filters, grad_output, @@ -84,29 +106,40 @@ def backward(ctx, grad_output): indice_pair_num, True, False, - algo=ctx.algo) + algo=ctx.algo, + timer=timer) - return input_bp, filters_bp, None, None, None, None + return input_bp, filters_bp, None, None, None, None, None class SparseImplicitGemmFunction(Function): @staticmethod @amp.custom_fwd(cast_inputs=torch.float16) - def forward(ctx, features: torch.Tensor, filters: torch.Tensor, - pair_fwd: torch.Tensor, pair_bwd: torch.Tensor, + def forward(ctx, + features: torch.Tensor, + filters: torch.Tensor, + pair_fwd: torch.Tensor, + pair_bwd: torch.Tensor, pair_mask_fwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor], - num_activate_out: int, masks: List[np.ndarray], is_train: bool, - is_subm: bool): + num_activate_out: int, + masks: List[np.ndarray], + is_train: bool, + is_subm: bool, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): - out, mask_out, mask_width = ops.implicit_gemm( - features, filters, pair_fwd, pair_mask_fwd_splits, - mask_argsort_fwd_splits, num_activate_out, masks, is_train, is_subm) + out, mask_out, mask_width = ops.implicit_gemm(features, filters, + pair_fwd, + pair_mask_fwd_splits, + mask_argsort_fwd_splits, + num_activate_out, masks, + is_train, is_subm, timer) ctx.save_for_backward(features, filters, pair_fwd, pair_bwd) ctx.mask_width = mask_width ctx.mask_out = mask_out + ctx.timer = timer ctx.pair_mask_fwd_splits = pair_mask_fwd_splits ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits ctx.pair_mask_bwd_splits = pair_mask_bwd_splits @@ -130,30 +163,40 @@ def backward(ctx, grad_output): # num_activate_out = ctx.num_activate_out masks = ctx.masks is_subm = ctx.is_subm - - input_bp, filters_bp = ops.implicit_gemm_backward(features, - filters, - grad_output, - pair_fwd, - pair_bwd, - pair_mask_fwd_splits, - pair_mask_bwd_splits, - mask_argsort_fwd_splits, - mask_argsort_bwd_splits, - mask_output_fwd=mask_out, - masks=masks, - mask_width=mask_width, - is_subm=is_subm) - None_9 = [None] * 10 + timer = ctx.timer + input_bp, filters_bp = ops.implicit_gemm_backward( + features, + filters, + grad_output, + pair_fwd, + pair_bwd, + pair_mask_fwd_splits, + pair_mask_bwd_splits, + mask_argsort_fwd_splits, + mask_argsort_bwd_splits, + mask_output_fwd=mask_out, + masks=masks, + mask_width=mask_width, + is_subm=is_subm, + timer=timer) + None_9 = [None] * 11 return (input_bp, filters_bp, *None_9) + class SubMConvFunction(Function): @staticmethod @amp.custom_fwd(cast_inputs=torch.float16) - def forward(ctx, features, filters, indice_pairs, indice_pair_num, - num_activate_out, algo): + def forward(ctx, + features, + filters, + indice_pairs, + indice_pair_num, + num_activate_out, + algo, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): ctx.save_for_backward(indice_pairs, indice_pair_num, features, 
filters) ctx.algo = algo + ctx.timer = timer return ops.indice_conv(features, filters, indice_pairs, @@ -161,13 +204,16 @@ def forward(ctx, features, filters, indice_pairs, indice_pair_num, num_activate_out, False, True, - algo=algo) + algo=algo, + timer=timer) @staticmethod @once_differentiable @amp.custom_bwd def backward(ctx, grad_output): indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors + timer = ctx.timer + input_bp, filters_bp = ops.indice_conv_backward(features, filters, grad_output, @@ -175,9 +221,10 @@ def backward(ctx, grad_output): indice_pair_num, False, True, - algo=ctx.algo) + algo=ctx.algo, + timer=timer) - return input_bp, filters_bp, None, None, None, None + return input_bp, filters_bp, None, None, None, None, None class SparseMaxPoolFunction(Function): @@ -199,12 +246,14 @@ def backward(ctx, grad_output): indice_pairs, indice_pair_num) return input_bp, None, None, None + class SparseMaxPoolImplicitGemmFunction(Function): @staticmethod @amp.custom_fwd(cast_inputs=torch.float16) - def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor, indice_pairs_bwd: torch.Tensor, - num_activate_out: int): - out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd, num_activate_out) + def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor, + indice_pairs_bwd: torch.Tensor, num_activate_out: int): + out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd, + num_activate_out) ctx.save_for_backward(indice_pairs_bwd, features, out) return out @@ -213,10 +262,11 @@ def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor, indice_ @amp.custom_bwd def backward(ctx, grad_output): indice_pairs_bwd, features, out = ctx.saved_tensors - input_bp = ops.indice_maxpool_implicit_gemm_backward(features, out, grad_output, - indice_pairs_bwd) + input_bp = ops.indice_maxpool_implicit_gemm_backward( + features, out, grad_output, indice_pairs_bwd) return input_bp, None, None, None + indice_conv = SparseConvFunction.apply implicit_gemm = SparseImplicitGemmFunction.apply indice_inverse_conv = SparseInverseConvFunction.apply diff --git a/spconv/pytorch/modules.py b/spconv/pytorch/modules.py index 6e0e8ef..f1038f8 100644 --- a/spconv/pytorch/modules.py +++ b/spconv/pytorch/modules.py @@ -1,18 +1,17 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
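The bookkeeping rule behind the extra trailing Nones in the functional.py hunks above: backward must return exactly one gradient (or None) per argument of forward, so threading the new non-differentiable timer argument through forward adds one None to every backward return. For SparseImplicitGemmFunction that is 13 forward inputs and two real gradients, hence None_9 = [None] * 11. A self-contained sketch of the pattern with a hypothetical op:

import torch
from torch.autograd import Function

class TimedScale(Function):
    # hypothetical op, only to illustrate the gradient/None bookkeeping
    @staticmethod
    def forward(ctx, features, scale, timer=None):
        ctx.scale = scale
        ctx.timer = timer  # non-tensor state rides on ctx, not save_for_backward
        return features * scale

    @staticmethod
    def backward(ctx, grad_output):
        # one slot per forward input: a gradient for features, None for the
        # non-differentiable scale and timer arguments
        return grad_output * ctx.scale, None, None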
- import sys import time from collections import OrderedDict @@ -53,6 +52,7 @@ class SparseModule(nn.Module): def __init__(self, name=None): super().__init__() self.name = name + self._sparse_unique_name = "" class SparseSequential(SparseModule): @@ -143,3 +143,8 @@ def forward(self, input): input = module(input) return input + +def assign_name_for_sparse_modules(module: nn.Module): + for k, n in module.named_modules(): + if isinstance(n, SparseModule): + n._sparse_unique_name = k diff --git a/spconv/pytorch/ops.py b/spconv/pytorch/ops.py index 68a7849..44c1aa6 100644 --- a/spconv/pytorch/ops.py +++ b/spconv/pytorch/ops.py @@ -26,14 +26,19 @@ from spconv.core_cc.csrc.sparse.all import SpconvOps import spconv.core_cc as _ext +from spconv.utils import nullcontext + if hasattr(_ext, "cumm"): + CPU_ONLY_BUILD = False from spconv.algo import GEMM, CONV # , GATHER, SCATTER else: - GEMM = None - CONV = None + CPU_ONLY_BUILD = True + GEMM = None + CONV = None import time from spconv.constants import FILTER_HWIO from cumm.gemm import codeops +from spconv.tools import CUDAKernelTimer DEBUG = False @@ -240,19 +245,21 @@ def get_indice_pairs(indices: torch.Tensor, return out_inds, pair, indice_num_per_loc -def get_indice_pairs_implicit_gemm(indices: torch.Tensor, - batch_size: int, - spatial_shape: List[int], - algo: ConvAlgo, - ksize: List[int], - stride: List[int], - padding: List[int], - dilation: List[int], - out_padding: List[int], - subm: bool = False, - transpose: bool = False, - is_train: bool = True, - alloc: Optional[ThrustSortAllocator] = None): +def get_indice_pairs_implicit_gemm( + indices: torch.Tensor, + batch_size: int, + spatial_shape: List[int], + algo: ConvAlgo, + ksize: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + out_padding: List[int], + subm: bool = False, + transpose: bool = False, + is_train: bool = True, + alloc: Optional[ThrustSortAllocator] = None, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): """ Why return tuple? because pytorch seems don't support custom object in autograd. return: ( @@ -336,18 +343,18 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, out_inds_tv = torch_tensor_to_tv(out_inds) hashdata_tv = torch_tensor_to_tv(hashdata, dtype=tv.custom64) pair_mask_tv = torch_tensor_to_tv(pair_mask, dtype=tv.uint32) - - SpconvOps.generate_subm_conv_inds(inds_tv, - hashdata_tv, - pair_tv, - out_inds_tv, - indice_num_per_loc_tv, - batch_size=batch_size, - input_dims=spatial_shape, - ksize=ksize, - dilation=dilation, - indice_pair_mask=pair_mask_tv, - stream_int=stream) + with timer.record("gen_subm_inds", stream): + SpconvOps.generate_subm_conv_inds(inds_tv, + hashdata_tv, + pair_tv, + out_inds_tv, + indice_num_per_loc_tv, + batch_size=batch_size, + input_dims=spatial_shape, + ksize=ksize, + dilation=dilation, + indice_pair_mask=pair_mask_tv, + stream_int=stream) # torch.cuda.synchronize() # print("SUBM0", time.time() - t) # CONV.stream_synchronize(stream) @@ -358,13 +365,15 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, mask_argsort_tv = torch_tensor_to_tv(mask_argsort) if alloc is None: alloc = ThrustSortAllocator(indices.device) - for j in range(mask_split_count): - # thrust don't provide two-step sort (first step return workspace size) - # so I use this stupid hack to use torch allocator without touch - # pytorch binary (c++). 
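For readers puzzled by the comment above: thrust's sort allocates its temporary storage internally and offers no separate workspace-size query, so ThrustSortAllocator hands the C++ side a Python callback that serves thrust's request from PyTorch's caching allocator instead of a bare cudaMalloc. A conceptual sketch of that bridge, not the actual binding code:

import torch

class TorchThrustAlloc:
    # conceptual stand-in for spconv.pytorch.ops.ThrustSortAllocator
    def __init__(self, device: torch.device):
        self.buf = None
        self.device = device

    def alloc(self, num_bytes: int) -> int:
        # serve thrust's temp-storage request from the torch caching
        # allocator; a raw device pointer is returned while pytorch
        # retains ownership of the memory
        if self.buf is None or self.buf.numel() < num_bytes:
            self.buf = torch.empty([num_bytes],
                                   dtype=torch.uint8,
                                   device=self.device)
        return self.buf.data_ptr()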
- # f**k thrust - SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j], alloc.alloc, - mask_argsort_tv[j], stream) + with timer.record("gen_subm_inds_sort", stream): + for j in range(mask_split_count): + # thrust don't provide two-step sort (first step return workspace size) + # so I use this stupid hack to use torch allocator without touch + # pytorch binary (c++). + # f**k thrust + SpconvOps.sort_1d_by_key_allocator(pair_mask_tv[j], + alloc.alloc, + mask_argsort_tv[j], stream) # CONV.stream_synchronize(stream) pair_mask_in_splits = [pair_mask[i] for i in range(mask_split_count)] mask_argsort_in_splits = [ @@ -391,20 +400,20 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, dtype=indices.dtype, device=indices.device) indice_pairs_uniq_tv = torch_tensor_to_tv(indice_pairs_uniq) - - SpconvOps.generate_conv_inds_mask_stage1(inds_tv, - pair_bwd_tv, - indice_pairs_uniq_tv, - indice_num_per_loc_tv, - batch_size=batch_size, - output_dims=out_shape, - input_dims=spatial_shape, - ksize=ksize, - stride=stride, - padding=padding, - dilation=dilation, - transposed=transpose, - stream_int=stream) + with timer.record("gen_conv_inds_stage1", stream): + SpconvOps.generate_conv_inds_mask_stage1(inds_tv, + pair_bwd_tv, + indice_pairs_uniq_tv, + indice_num_per_loc_tv, + batch_size=batch_size, + output_dims=out_shape, + input_dims=spatial_shape, + ksize=ksize, + stride=stride, + padding=padding, + dilation=dilation, + transposed=transpose, + stream_int=stream) if DEBUG: CONV.stream_synchronize(stream) @@ -452,25 +461,25 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, CONV.stream_synchronize(stream) print("REGU_S2_PREPARE", time.time() - t) t = time.time() - - SpconvOps.generate_conv_inds_mask_stage2(inds_tv, - hashdata_tv, - pair_fwd_tv, - pair_bwd_tv, - uniq_res_tv, - out_inds_tv, - pair_mask_fwd_tv, - pair_mask_bwd_tv, - num_out_act=num_act_out, - batch_size=batch_size, - output_dims=out_shape, - input_dims=spatial_shape, - ksize=ksize, - stride=stride, - padding=padding, - dilation=dilation, - transposed=transpose, - stream_int=stream) + with timer.record("gen_conv_inds_stage2", stream): + SpconvOps.generate_conv_inds_mask_stage2(inds_tv, + hashdata_tv, + pair_fwd_tv, + pair_bwd_tv, + uniq_res_tv, + out_inds_tv, + pair_mask_fwd_tv, + pair_mask_bwd_tv, + num_out_act=num_act_out, + batch_size=batch_size, + output_dims=out_shape, + input_dims=spatial_shape, + ksize=ksize, + stride=stride, + padding=padding, + dilation=dilation, + transposed=transpose, + stream_int=stream) if DEBUG: CONV.stream_synchronize(stream) @@ -492,62 +501,61 @@ def get_indice_pairs_implicit_gemm(indices: torch.Tensor, mask_argsort_bwd_tv = torch_tensor_to_tv(mask_argsort_bwd) if alloc is None: alloc = ThrustSortAllocator(indices.device) - - if is_mask_split: - for j in range(mask_split_count): - mask_tv = tv.from_numpy(masks[j]) - # here we try to ensure only call allocator once. - if not is_train: - SpconvOps.sort_1d_by_key_split_allocator( - pair_mask_fwd_tv[j], alloc.alloc, mask_tv, - mask_argsort_fwd_tv[j], stream) - else: - if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): - SpconvOps.sort_1d_by_key_split_allocator( - pair_mask_bwd_tv[j], alloc.alloc, mask_tv, - mask_argsort_bwd_tv[j], stream) + with timer.record("gen_conv_inds_sort", stream): + if is_mask_split: + for j in range(mask_split_count): + mask_tv = tv.from_numpy(masks[j]) + # here we try to ensure only call allocator once. 
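The sort dispatch that follows encodes two details worth spelling out: at inference only the forward masks need sorting, and during training the wider of the two pair tables is sorted first, apparently so the single cached allocation (the "only call allocator once" intent in the comment above) is already large enough for the second sort. In compressed, hypothetical form:

def sort_order(is_train: bool, bwd_wider_than_fwd: bool):
    # hypothetical summary of the branches below, not library code
    if not is_train:
        return ["fwd"]
    # sort the wider table first so one allocation can serve both sorts
    return ["bwd", "fwd"] if bwd_wider_than_fwd else ["fwd", "bwd"]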
+ if not is_train: SpconvOps.sort_1d_by_key_split_allocator( pair_mask_fwd_tv[j], alloc.alloc, mask_tv, mask_argsort_fwd_tv[j], stream) else: - SpconvOps.sort_1d_by_key_split_allocator( - pair_mask_fwd_tv[j], alloc.alloc, mask_tv, - mask_argsort_fwd_tv[j], stream) - SpconvOps.sort_1d_by_key_split_allocator( - pair_mask_bwd_tv[j], alloc.alloc, mask_tv, - mask_argsort_bwd_tv[j], stream) - - # SpconvOps.sort_1d_by_key_split(pair_mask_fwd_tv[j], mask_tv, - # mask_argsort_fwd_tv[j], stream) - # if is_train: - # SpconvOps.sort_1d_by_key_split(pair_mask_bwd_tv[j], - # mask_tv, - # mask_argsort_bwd_tv[j], - # stream) + if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): + SpconvOps.sort_1d_by_key_split_allocator( + pair_mask_bwd_tv[j], alloc.alloc, mask_tv, + mask_argsort_bwd_tv[j], stream) + SpconvOps.sort_1d_by_key_split_allocator( + pair_mask_fwd_tv[j], alloc.alloc, mask_tv, + mask_argsort_fwd_tv[j], stream) + else: + SpconvOps.sort_1d_by_key_split_allocator( + pair_mask_fwd_tv[j], alloc.alloc, mask_tv, + mask_argsort_fwd_tv[j], stream) + SpconvOps.sort_1d_by_key_split_allocator( + pair_mask_bwd_tv[j], alloc.alloc, mask_tv, + mask_argsort_bwd_tv[j], stream) + + # SpconvOps.sort_1d_by_key_split(pair_mask_fwd_tv[j], mask_tv, + # mask_argsort_fwd_tv[j], stream) + # if is_train: + # SpconvOps.sort_1d_by_key_split(pair_mask_bwd_tv[j], + # mask_tv, + # mask_argsort_bwd_tv[j], + # stream) - else: - # if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): - if not is_train: - SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0], - alloc.alloc, - mask_argsort_fwd_tv[0], stream) else: - if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): - SpconvOps.sort_1d_by_key_allocator(pair_mask_bwd_tv[0], - alloc.alloc, - mask_argsort_bwd_tv[0], - stream) + # if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): + if not is_train: SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0], - alloc.alloc, - mask_argsort_fwd_tv[0], stream) + alloc.alloc, + mask_argsort_fwd_tv[0], + stream) else: - SpconvOps.sort_1d_by_key_allocator(pair_mask_fwd_tv[0], - alloc.alloc, - mask_argsort_fwd_tv[0], stream) - SpconvOps.sort_1d_by_key_allocator(pair_mask_bwd_tv[0], - alloc.alloc, - mask_argsort_bwd_tv[0], - stream) + if pair_mask_bwd_tv.dim(1) > pair_mask_fwd_tv.dim(1): + SpconvOps.sort_1d_by_key_allocator( + pair_mask_bwd_tv[0], alloc.alloc, + mask_argsort_bwd_tv[0], stream) + SpconvOps.sort_1d_by_key_allocator( + pair_mask_fwd_tv[0], alloc.alloc, + mask_argsort_fwd_tv[0], stream) + else: + SpconvOps.sort_1d_by_key_allocator( + pair_mask_fwd_tv[0], alloc.alloc, + mask_argsort_fwd_tv[0], stream) + SpconvOps.sort_1d_by_key_allocator( + pair_mask_bwd_tv[0], alloc.alloc, + mask_argsort_bwd_tv[0], stream) if DEBUG: CONV.stream_synchronize(stream) print("REGU_S2_FINISH", time.time() - t) @@ -587,7 +595,8 @@ def indice_conv(features: torch.Tensor, num_activate_out: int, inverse: bool = False, subm: bool = False, - algo: ConvAlgo = ConvAlgo.Native): + algo: ConvAlgo = ConvAlgo.Native, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): # filters: RSKC # stream = get_current_stream() # CONV.stream_synchronize(stream) @@ -717,38 +726,38 @@ def indice_conv(features: torch.Tensor, stream=stream) # CONV.stream_synchronize(stream) # t = time.time() - - for i, nhot in enumerate(indice_pair_num_cpu): - if subm and i == kv_center: - continue - if subm and i > kv_center: - nhot = indice_pair_num_cpu[kv - i - 1] - if nhot <= 0: - continue - inp_indices = pair_in[i].slice_first_axis(0, nhot) - out_indices = pair_out[i].slice_first_axis(0, 
nhot) - b = filters_tv[i] - # inp @ filter.T, NC @ KC - beta = 1.0 if inited else 0.0 - algo_desp = GEMM.run_with_tuned_result( - tuned_res, - a, - b, - c, - False, - False if FILTER_HWIO else True, - False, - arch=arch, - stream=stream, - shuffle_type=ShuffleStrideType.ShuffleAC, - a_inds=inp_indices, - c_inds=out_indices, - hint=AlgoHint.Fowrard.value, - alpha=1.0, - beta=beta) - - # gather_times += gather_time - inited = True + with timer.record("forward", stream): + for i, nhot in enumerate(indice_pair_num_cpu): + if subm and i == kv_center: + continue + if subm and i > kv_center: + nhot = indice_pair_num_cpu[kv - i - 1] + if nhot <= 0: + continue + inp_indices = pair_in[i].slice_first_axis(0, nhot) + out_indices = pair_out[i].slice_first_axis(0, nhot) + b = filters_tv[i] + # inp @ filter.T, NC @ KC + beta = 1.0 if inited else 0.0 + algo_desp = GEMM.run_with_tuned_result( + tuned_res, + a, + b, + c, + False, + False if FILTER_HWIO else True, + False, + arch=arch, + stream=stream, + shuffle_type=ShuffleStrideType.ShuffleAC, + a_inds=inp_indices, + c_inds=out_indices, + hint=AlgoHint.Fowrard.value, + alpha=1.0, + beta=beta) + + # gather_times += gather_time + inited = True # CONV.stream_synchronize(stream) # print(out_features.mean(), out_features.max(), out_features.min()) @@ -770,7 +779,8 @@ def indice_conv_backward(features: torch.Tensor, indice_pair_num: torch.Tensor, inverse: bool = False, subm: bool = False, - algo: ConvAlgo = ConvAlgo.Native): + algo: ConvAlgo = ConvAlgo.Native, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): # print(out_bp.mean(), out_bp.max(), out_bp.min()) num_activate_out = out_bp.shape[0] @@ -1046,12 +1056,16 @@ def indice_conv_backward(features: torch.Tensor, return (din, dfilters.reshape(filters_shape)) -def implicit_gemm(features: torch.Tensor, filters: torch.Tensor, +def implicit_gemm(features: torch.Tensor, + filters: torch.Tensor, pair_fwd: torch.Tensor, pair_mask_fwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor], - num_activate_out: int, masks: List[np.ndarray], - is_train: bool, is_subm: bool): + num_activate_out: int, + masks: List[np.ndarray], + is_train: bool, + is_subm: bool, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): stream = get_current_stream() # if DEBUG: @@ -1136,24 +1150,25 @@ def implicit_gemm(features: torch.Tensor, filters: torch.Tensor, # CONV.stream_synchronize(stream) # t = time.time() - - for j in range(num_split): - beta = 0 if j == 0 else 1 - CONV.run_with_tuned_result(tune_res, - ConvOpType.kForward, - features_tv, - filters_tv, - out_features_tv, - mask=pair_mask_fwd_split_tvs[j], - mask_argsort=mask_argsort_fwd_split_tvs[j], - mask_output=mask_output_fwd_tvs[j], - indices=pair_fwd_tv, - reverse_mask=False, - mask_filter=masks_ints[j], - mask_width=-1, - beta=beta, - stream=stream, - verbose=False) + with timer.record("implicit_gemm", stream): + for j in range(num_split): + beta = 0 if j == 0 else 1 + CONV.run_with_tuned_result( + tune_res, + ConvOpType.kForward, + features_tv, + filters_tv, + out_features_tv, + mask=pair_mask_fwd_split_tvs[j], + mask_argsort=mask_argsort_fwd_split_tvs[j], + mask_output=mask_output_fwd_tvs[j], + indices=pair_fwd_tv, + reverse_mask=False, + mask_filter=masks_ints[j], + mask_width=-1, + beta=beta, + stream=stream, + verbose=False) # torch.cuda.synchronize() # if DEBUG: @@ -1166,16 +1181,20 @@ def implicit_gemm(features: torch.Tensor, filters: torch.Tensor, return out_features, mask_output_fwd, mask_width -def implicit_gemm_backward(features: torch.Tensor, 
filters: torch.Tensor, - out_bp: torch.Tensor, pair_fwd: torch.Tensor, +def implicit_gemm_backward(features: torch.Tensor, + filters: torch.Tensor, + out_bp: torch.Tensor, + pair_fwd: torch.Tensor, pair_bwd: torch.Tensor, pair_mask_fwd_splits: List[torch.Tensor], pair_mask_bwd_splits: List[torch.Tensor], mask_argsort_fwd_splits: List[torch.Tensor], mask_argsort_bwd_splits: List[torch.Tensor], mask_output_fwd: torch.Tensor, - masks: List[np.ndarray], mask_width: int, - is_subm: bool): + masks: List[np.ndarray], + mask_width: int, + is_subm: bool, + timer: CUDAKernelTimer = CUDAKernelTimer(False)): # print(out_bp.mean(), out_bp.max(), out_bp.min()) if features.dtype == torch.int8 or features.dtype == torch.qint8: raise NotImplementedError("work in progress") @@ -1287,44 +1306,46 @@ def implicit_gemm_backward(features: torch.Tensor, filters: torch.Tensor, dtype=torch.int8, device=features.device) workspace_tv = torch_tensor_to_tv(workspace) - for j in range(num_split): - beta = 0 if j == 0 else 1 - if is_subm: - mask = pair_mask_fwd_split_tvs[j] - mask_argsort = mask_argsort_fwd_split_tvs[j] - else: - mask = pair_mask_bwd_split_tvs[j] - mask_argsort = mask_argsort_bwd_split_tvs[j] - - CONV.run_with_tuned_result(dgrad_tune_res, - ConvOpType.kBackwardInput, - din_tv, - filters_tv, - dout_tv, - mask=mask, - mask_argsort=mask_argsort, - mask_output=tv.Tensor(), - indices=pair_bwd_tv, - reverse_mask=is_subm, - mask_filter=masks[j].item(), - mask_width=-1, - beta=beta, - stream=stream) - CONV.run_with_tuned_result(wgrad_tune_res, - ConvOpType.kBackwardWeight, - features_tv, - dfilters_tv, - dout_tv, - mask=mask_output_fwd_tv[j], - mask_argsort=mask_argsort_fwd_split_tvs[j], - mask_output=tv.Tensor(), - indices=pair_fwd_tv, - reverse_mask=False, - mask_filter=masks[j].item(), - mask_width=mask_width, - beta=beta, - workspace=workspace_tv, - stream=stream) + with timer.record("implicit_gemm_backward", stream): + for j in range(num_split): + beta = 0 if j == 0 else 1 + if is_subm: + mask = pair_mask_fwd_split_tvs[j] + mask_argsort = mask_argsort_fwd_split_tvs[j] + else: + mask = pair_mask_bwd_split_tvs[j] + mask_argsort = mask_argsort_bwd_split_tvs[j] + + CONV.run_with_tuned_result(dgrad_tune_res, + ConvOpType.kBackwardInput, + din_tv, + filters_tv, + dout_tv, + mask=mask, + mask_argsort=mask_argsort, + mask_output=tv.Tensor(), + indices=pair_bwd_tv, + reverse_mask=is_subm, + mask_filter=masks[j].item(), + mask_width=-1, + beta=beta, + stream=stream) + CONV.run_with_tuned_result( + wgrad_tune_res, + ConvOpType.kBackwardWeight, + features_tv, + dfilters_tv, + dout_tv, + mask=mask_output_fwd_tv[j], + mask_argsort=mask_argsort_fwd_split_tvs[j], + mask_output=tv.Tensor(), + indices=pair_fwd_tv, + reverse_mask=False, + mask_filter=masks[j].item(), + mask_width=mask_width, + beta=beta, + workspace=workspace_tv, + stream=stream) return (din, dfilters.reshape(filters_shape)) @@ -1445,4 +1466,3 @@ def indice_maxpool_implicit_gemm_backward(features, out_features, out_bp, out_bp_tv, din_tv, indice_pairs_tv, stream) return din - diff --git a/spconv/pytorch/pool.py b/spconv/pytorch/pool.py index 863cabb..a3bd952 100644 --- a/spconv/pytorch/pool.py +++ b/spconv/pytorch/pool.py @@ -24,11 +24,12 @@ from spconv import pytorch as spconv from spconv.core import ConvAlgo -import spconv.pytorch.functional as Fsp +from spconv.pytorch import functional as Fsp from spconv.pytorch import ops from spconv.pytorch.core import IndiceData, ImplicitGemmIndiceData from spconv.pytorch.modules import SparseModule from 
spconv.cppconstants import CPU_ONLY_BUILD +from spconv.utils import nullcontext class SparseMaxPool(SparseModule): @@ -126,79 +127,87 @@ def forward(self, input): if input.benchmark: torch.cuda.synchronize() t = time.time() - out_padding = [0] * self.ndim + out_padding = [0] * self.ndim indice_dict = input.indice_dict.copy() - if self.algo == ConvAlgo.Native: - outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs( - indices, batch_size, spatial_shape, ConvAlgo.Native, - self.kernel_size, self.stride, self.padding, self.dilation, out_padding, - False) - if input.benchmark: - torch.cuda.synchronize() - interval = time.time() - t - out_tensor.benchmark_record[ - self.name]["indice_gen_time"].append(interval) - t = time.time() + profile_ctx = nullcontext() + if input._timer is not None and self._sparse_unique_name: + profile_ctx = input._timer.namespace(self._sparse_unique_name) + with profile_ctx: + if self.algo == ConvAlgo.Native: + outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs( + indices, batch_size, spatial_shape, ConvAlgo.Native, + self.kernel_size, self.stride, self.padding, self.dilation, + out_padding, False) + if input.benchmark: + torch.cuda.synchronize() + interval = time.time() - t + out_tensor.benchmark_record[ + self.name]["indice_gen_time"].append(interval) + t = time.time() - if self.indice_key is not None: - datas = input.find_indice_pair(self.indice_key) - if datas is None: - indice_data = IndiceData(outids, - indices, - indice_pairs, - indice_pairs_num, - spatial_shape, - is_subm=False, - algo=self.algo) - indice_dict[self.indice_key] = indice_data - else: - raise ValueError(f"indice key {self.indice_key} exists") + if self.indice_key is not None: + datas = input.find_indice_pair(self.indice_key) + if datas is None: + indice_data = IndiceData(outids, + indices, + indice_pairs, + indice_pairs_num, + spatial_shape, + is_subm=False, + algo=self.algo) + indice_dict[self.indice_key] = indice_data + else: + raise ValueError( + f"indice key {self.indice_key} exists") - out_features = Fsp.indice_maxpool(features, - indice_pairs.to(device), - indice_pairs_num.to(device), - outids.shape[0]) - else: - res = ops.get_indice_pairs_implicit_gemm(indices, - batch_size, - spatial_shape, - self.algo, - ksize=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - out_padding=out_padding, - subm=self.subm, - is_train=self.training, - alloc=input.thrust_allocator) - outids = res[0] - num_inds_per_loc = res[1] - pair_fwd = res[2] - pair_bwd = res[3] - pair_mask_fwd_splits = res[4] - pair_mask_bwd_splits = res[5] - mask_argsort_fwd_splits = res[6] - mask_argsort_bwd_splits = res[7] - masks = res[8] - if self.indice_key is not None: - indice_data = ImplicitGemmIndiceData( - outids, - indices, - pair_fwd, - pair_bwd, - pair_mask_fwd_splits=pair_mask_fwd_splits, - pair_mask_bwd_splits=pair_mask_bwd_splits, - mask_argsort_fwd_splits=mask_argsort_fwd_splits, - mask_argsort_bwd_splits=mask_argsort_bwd_splits, - masks=masks, - is_subm=self.subm, - out_spatial_shape=out_spatial_shape, - algo=self.algo) - msg = f"your indice key {self.indice_key} already exists in this sparse tensor." 
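The profiling context introduced above is the per-layer entry point of the new timer: when the input tensor was built with enable_timer=True and assign_name_for_sparse_modules has stamped this layer's _sparse_unique_name, all of its kernel timings nest under that name, and everything degrades to nullcontext() otherwise. A sketch of the same API used directly, assuming a CUDA build (the key format shown is illustrative):

from spconv.tools import CUDAKernelTimer
from spconv.pytorch.cppcore import get_current_stream

timer = CUDAKernelTimer(enable=True)  # silently a no-op on CPU-only builds
stream = get_current_stream()
with timer.namespace("net.0"):
    with timer.record("forward", stream):
        pass  # kernels launched on `stream` here are timed via event pairs
print(timer.get_all_pair_time())  # e.g. {"net.0.forward": <duration>}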
- assert self.indice_key not in indice_dict, msg - indice_dict[self.indice_key] = indice_data - out_features = Fsp.indice_maxpool_implicit_gemm( - features, pair_fwd, pair_bwd, outids.shape[0]) + out_features = Fsp.indice_maxpool(features, + indice_pairs.to(device), + indice_pairs_num.to(device), + outids.shape[0]) + else: + with input._timer.namespace("gen_pairs"): + res = ops.get_indice_pairs_implicit_gemm( + indices, + batch_size, + spatial_shape, + self.algo, + ksize=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + out_padding=out_padding, + subm=self.subm, + is_train=self.training, + alloc=input.thrust_allocator, + timer=input._timer) + outids = res[0] + num_inds_per_loc = res[1] + pair_fwd = res[2] + pair_bwd = res[3] + pair_mask_fwd_splits = res[4] + pair_mask_bwd_splits = res[5] + mask_argsort_fwd_splits = res[6] + mask_argsort_bwd_splits = res[7] + masks = res[8] + if self.indice_key is not None: + indice_data = ImplicitGemmIndiceData( + outids, + indices, + pair_fwd, + pair_bwd, + pair_mask_fwd_splits=pair_mask_fwd_splits, + pair_mask_bwd_splits=pair_mask_bwd_splits, + mask_argsort_fwd_splits=mask_argsort_fwd_splits, + mask_argsort_bwd_splits=mask_argsort_bwd_splits, + masks=masks, + is_subm=self.subm, + out_spatial_shape=out_spatial_shape, + algo=self.algo) + msg = f"your indice key {self.indice_key} already exists in this sparse tensor." + assert self.indice_key not in indice_dict, msg + indice_dict[self.indice_key] = indice_data + out_features = Fsp.indice_maxpool_implicit_gemm( + features, pair_fwd, pair_bwd, outids.shape[0]) if input.benchmark: torch.cuda.synchronize() diff --git a/spconv/pytorch/spatial.py b/spconv/pytorch/spatial.py index ffacbbd..ca11445 100644 --- a/spconv/pytorch/spatial.py +++ b/spconv/pytorch/spatial.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/spconv/pytorch/tables.py b/spconv/pytorch/tables.py index a0a1288..3240bb1 100644 --- a/spconv/pytorch/tables.py +++ b/spconv/pytorch/tables.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -15,18 +15,18 @@ import torch from torch.autograd import Function -import spconv.pytorch as spconv #from torch.nn import Module from spconv.pytorch.modules import SparseModule from spconv.pytorch.core import SparseConvTensor -from typing import List +from typing import List + class JoinTable(SparseModule): # Module): def forward(self, input: List[SparseConvTensor]): - output = spconv.SparseConvTensor( - torch.cat([i.features for i in input], 1), input[0].indices, - input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num, - input[0].indice_dict) + output = SparseConvTensor(torch.cat([i.features for i in input], 1), + input[0].indices, input[0].spatial_shape, + input[0].batch_size, input[0].grid, + input[0].voxel_num, input[0].indice_dict) output.benchmark_record = input[1].benchmark_record output.thrust_allocator = input[1].thrust_allocator return output @@ -37,10 +37,10 @@ def input_spatial_size(self, out_size): class AddTable(SparseModule): # Module): def forward(self, input: List[SparseConvTensor]): - output = spconv.SparseConvTensor( - sum([i.features for i in input]), input[0].indices, - input[0].spatial_shape, input[0].batch_size, input[0].grid, input[0].voxel_num, - input[0].indice_dict) + output = SparseConvTensor(sum([i.features for i in input]), + input[0].indices, input[0].spatial_shape, + input[0].batch_size, input[0].grid, + input[0].voxel_num, input[0].indice_dict) output.benchmark_record = input[1].benchmark_record output.thrust_allocator = input[1].thrust_allocator return output diff --git a/spconv/pytorch/utils.py b/spconv/pytorch/utils.py index c581322..9aa3ce3 100644 --- a/spconv/pytorch/utils.py +++ b/spconv/pytorch/utils.py @@ -82,24 +82,25 @@ def __call__(self, if self.point_indice_data.shape[0] < pc.shape[0]: self.point_indice_data = torch.empty([pc.shape[0]], - dtype=torch.int64, - device=self.device) + dtype=torch.int64, + device=self.device) pc_tv = torch_tensor_to_tv(pc) stream = get_current_stream() voxels_tv = torch_tensor_to_tv(self.voxels) indices_tv = torch_tensor_to_tv(self.indices) num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel) - hashdata_tv = torch_tensor_to_tv(self.hashdata, - dtype=tv.custom128, - shape=[self.hashdata.shape[0]]) - point_indice_data_tv = torch_tensor_to_tv(self.point_indice_data) + hashdata_tv = torch_tensor_to_tv( + self.hashdata, + dtype=tv.custom128, + shape=[self.hashdata.shape[0]]) + point_indice_data_tv = torch_tensor_to_tv( + self.point_indice_data) - res = SpconvOps.point2voxel_cuda(pc_tv, voxels_tv, indices_tv, - num_per_voxel_tv, hashdata_tv, - point_indice_data_tv, self.vsize, - self.grid_size, self.grid_stride, - self.coors_range, empty_mean, - clear_voxels, stream) + res = SpconvOps.point2voxel_cuda( + pc_tv, voxels_tv, indices_tv, num_per_voxel_tv, + hashdata_tv, point_indice_data_tv, self.vsize, + self.grid_size, self.grid_stride, self.coors_range, + empty_mean, clear_voxels, stream) num_voxels = res[0].shape[0] else: pc_tv = torch_tensor_to_tv(pc) @@ -111,8 +112,9 @@ def __call__(self, res = SpconvOps.point2voxel_cpu(pc_tv, voxels_tv, indices_tv, num_per_voxel_tv, hashdata_tv, self.vsize, self.grid_size, - self.grid_stride, self.coors_range, - empty_mean, clear_voxels) + self.grid_stride, + self.coors_range, empty_mean, + clear_voxels) num_voxels = res[0].shape[0] return (self.voxels[:num_voxels], self.indices[:num_voxels], diff --git a/spconv/test_utils.py b/spconv/test_utils.py index 12105fb..97b1467 100644 --- a/spconv/test_utils.py +++ b/spconv/test_utils.py @@ -1,11 +1,11 @@ # 
Copyright 2021 Yan Yan
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/spconv/tools.py b/spconv/tools.py
new file mode 100644
index 0000000..019c23b
--- /dev/null
+++ b/spconv/tools.py
@@ -0,0 +1,78 @@
+# Copyright 2021 Yan Yan
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+from spconv.cppconstants import CPU_ONLY_BUILD
+import contextlib
+from spconv.utils import nullcontext
+if not CPU_ONLY_BUILD:
+    from cumm.tensorview import CUDAKernelTimer as _CUDAKernelTimer
+
+
+class CUDAKernelTimer:
+    def __init__(self, enable: bool = True) -> None:
+        self.enable = enable and not CPU_ONLY_BUILD
+        if self.enable:
+            self._timer = _CUDAKernelTimer(enable)
+        else:
+            self._timer = None
+
+    @contextlib.contextmanager
+    def _namespace(self, name: str):
+        assert self._timer is not None
+        self._timer.push(name)
+        try:
+            yield
+        finally:
+            self._timer.pop()
+
+    @contextlib.contextmanager
+    def _record(self, name: str, stream: int = 0):
+        assert self._timer is not None
+        self._timer.push(name)
+        try:
+            self._timer.insert_pair("", "start", "stop")
+            self._timer.record("start", stream)
+            yield
+            self._timer.record("stop", stream)
+        finally:
+            self._timer.pop()
+
+    def namespace(self, name: str):
+        if self.enable:
+            return self._namespace(name)
+        else:
+            return nullcontext()
+
+    def record(self, name: str, stream: int = 0):
+        if self.enable:
+            return self._record(name, stream)
+        else:
+            return nullcontext()
+
+    def get_all_pair_time(self) -> Dict[str, float]:
+        if self.enable:
+            assert self._timer is not None
+            return self._timer.get_all_pair_duration()
+        else:
+            return {}
+
+    @staticmethod
+    def collect_by_name(name: str, res: Dict[str, float]):
+        filtered_res: Dict[str, float] = {}
+        for k, v in res.items():
+            k_split = k.split(".")
+            if name in k_split:
+                filtered_res[k] = v
+        return filtered_res
diff --git a/spconv/utils/__init__.py b/spconv/utils/__init__.py
index f7114b6..f4bd83e 100644
--- a/spconv/utils/__init__.py
+++ b/spconv/utils/__init__.py
@@ -1,11 +1,11 @@
 # Copyright 2021 Yan Yan
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -13,18 +13,37 @@
 # limitations under the License.
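collect_by_name above filters the flat timing dict by dotted-path component, which is how the benchmark later sums just the forward time. A small usage example with made-up durations:

from spconv.tools import CUDAKernelTimer

res = {
    "net.0.forward": 0.41,  # made-up numbers
    "net.0.gen_pairs.gen_subm_inds": 0.12,
    "net.2.forward": 0.38,
}
fwd = CUDAKernelTimer.collect_by_name("forward", res)
# keeps keys having "forward" as a dotted-path component:
# {"net.0.forward": 0.41, "net.2.forward": 0.38}
print(sum(fwd.values()))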
import numpy as np
-from cumm import tensorview as tv 
+from cumm import tensorview as tv
+from contextlib import AbstractContextManager
+from spconv.cppconstants import CPU_ONLY_BUILD
 from spconv.core_cc.csrc.sparse.all.ops_cpu1d import Point2VoxelCPU as Point2VoxelCPU1d
 from spconv.core_cc.csrc.sparse.all.ops_cpu2d import Point2VoxelCPU as Point2VoxelCPU2d
 from spconv.core_cc.csrc.sparse.all.ops_cpu3d import Point2VoxelCPU as Point2VoxelCPU3d
 from spconv.core_cc.csrc.sparse.all.ops_cpu4d import Point2VoxelCPU as Point2VoxelCPU4d
-import spconv.core_cc.csrc.sparse.all as __all
-IS_CPU_ONLY_BUILD = hasattr(__all, "ops1d")
-
-if IS_CPU_ONLY_BUILD:
+if not CPU_ONLY_BUILD:
     from spconv.core_cc.csrc.sparse.all.ops1d import Point2Voxel as Point2VoxelGPU1d
     from spconv.core_cc.csrc.sparse.all.ops2d import Point2Voxel as Point2VoxelGPU2d
     from spconv.core_cc.csrc.sparse.all.ops3d import Point2Voxel as Point2VoxelGPU3d
     from spconv.core_cc.csrc.sparse.all.ops4d import Point2Voxel as Point2VoxelGPU4d
+
+
+class nullcontext(AbstractContextManager):
+    """Context manager that does no additional processing.
+
+    Used as a stand-in for a normal context manager, when a particular
+    block of code is only sometimes used with a normal context manager:
+
+    cm = optional_cm if condition else nullcontext()
+    with cm:
+        # Perform operation, using optional_cm if condition is True
+    """
+    def __init__(self, enter_result=None):
+        self.enter_result = enter_result
+
+    def __enter__(self):
+        return self.enter_result
+
+    def __exit__(self, *excinfo):
+        pass
diff --git a/test/aaa.py b/test/aaa.py
deleted file mode 100644
index d2cb33f..0000000
--- a/test/aaa.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Copyright 2021 Yan Yan
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
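The nullcontext vendored above is a copy of contextlib.nullcontext, which only exists on Python 3.7+; keeping a local copy in spconv.utils lets the optional-context pattern run on older interpreters as well. For reference:

from spconv.utils import nullcontext

cm = nullcontext()  # stands in wherever a context manager is optional
with cm as entered:
    assert entered is None  # default enter_result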
- -STR = """ -BWG 0.0008761882781982422 -BWG 0.0008311271667480469 -BWG 0.002079486846923828 -BWG 0.002329587936401367 -BWG 0.0025458335876464844 -BWG 0.0026700496673583984 -BWG 0.002583742141723633 -BWG 0.0025262832641601562 -BWG 0.003481149673461914 -BWG 0.003238201141357422 -BWG 0.005095958709716797 -BWG 0.0037899017333984375 -BWG 0.003931283950805664 -BWG 0.003300189971923828 -""" -""" -0.003921985626220703 -0.0049707889556884766 -0.0052530765533447266 -0.0060312747955322266 -0.0036766529083251953 -0.00421142578125 - -0.002129793167114258 -0.0023038387298583984 -0.0013151168823242188 -0.0015285015106201172 -0.0008392333984375 -0.0008127689361572266 -0.0002486705780029297 -0.00030994415283203125 -""" - -STR1 = """ -SUBM 0.0005137920379638672 -F 0.0012662410736083984 -F 0.0016875267028808594 -REGU 0.0009055137634277344 -M 0.0009114742279052734 -SUBM 0.00037789344787597656 -F 0.0020329952239990234 -F 0.001947641372680664 -REGU 0.0009374618530273438 -M 0.00045609474182128906 -SUBM 0.0009856224060058594 -F 0.0009992122650146484 -F 0.0010600090026855469 -REGU 0.0006346702575683594 -M 0.0004057884216308594 -SUBM 0.0006394386291503906 -F 0.0008478164672851562 -F 0.0008838176727294922 -REGU 0.0007183551788330078 -M 0.00025177001953125 -SUBM 0.0009539127349853516 -F 0.0009481906890869141 -F 0.0010502338409423828 -REGU 0.0007147789001464844 -M 0.000274658203125 -SUBM 0.0007004737854003906 -F 0.0009715557098388672 -F 0.0012331008911132812 -REGU 0.0008800029754638672 -M 0.0002167224884033203 -SUBM 0.00045108795166015625 -F 0.0006735324859619141 -F 0.0008375644683837891 -""" -STR2 = """ -F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A0T1688_NS00_C3_01LLL_1 0.0007038116455078125 -F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0007627010345458984 -F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0007650852203369141 -F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0008864402770996094 -F Turing_f16f16f16f16f16tnt_m64n128k32m32n64k32A1T1688_NS00_C3_01LLL_1 0.0004017353057861328 -F Turing_f16f16f16f16f16tnt_m32n128k64m32n32k32A1T1688_NS00_C3_01LLL_1 0.0006165504455566406 -F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0005872249603271484 -F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0006289482116699219 -F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0002968311309814453 -F Turing_f16f16f16f16f16tnt_m64n64k32m32n32k32A1T1688_NS00_C3_01LLL_1 0.0003299713134765625 -F Turing_f16f16f16f16f16tnt_m64n128k64m32n64k32A1T1688_NS00_C3_01LLL_1 0.0002288818359375 -F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0002830028533935547 -F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0001780986785888672 -F Turing_f16f16f16f16f16tnt_m32n64k32m32n32k16A1T1688_NS00_C3_01LLL_1 0.0003058910369873047 -""" -def _handle_lines(s: str): - arr = s.split(" ") - return (arr[0], float(arr[-1])) -from cumm.gemm.codeops import group_by -def print_str(s: str): - - nums = list(map(_handle_lines, s.strip().split("\n"))) - num_dict = group_by(lambda x: x[0], nums) - num_dict_ = {k: sum([vv[1] for vv in v]) for k, v in num_dict.items()} - print(num_dict_) - -print_str(STR1) -print_str(STR2) \ No newline at end of file diff --git a/test/benchmark.py b/test/benchmark.py index d3b2014..d2e02dc 100644 --- a/test/benchmark.py +++ b/test/benchmark.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the 
"License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,10 +19,12 @@ import torch from torch import nn from cumm import tensorview as tv -from spconv.core import ConvAlgo +from spconv.core import ConvAlgo import spconv.pytorch as spconv from spconv.utils import Point2VoxelCPU3d + + def waymo_data(batch_size=1): gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3, 150000, 1) @@ -42,7 +44,7 @@ def waymo_data(batch_size=1): class Net(nn.Module): def __init__(self, shape, algo): super().__init__() - pool_algo = algo + pool_algo = algo # pool_algo = ConvAlgo.Native self.net = spconv.SparseSequential( spconv.SubMConv3d(3, 64, 3, bias=False, indice_key="c0", @@ -68,7 +70,6 @@ def __init__(self, shape, algo): # nn.BatchNorm1d(32), # nn.ReLU(), # spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"), - spconv.SparseMaxPool3d(2, 2, algo=pool_algo), spconv.SubMConv3d(64, 96, @@ -101,7 +102,6 @@ def __init__(self, shape, algo): # nn.BatchNorm1d(128), # nn.ReLU(), # spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"), - spconv.SparseMaxPool3d(2, 2, algo=pool_algo), spconv.SubMConv3d(128, 160, @@ -118,7 +118,6 @@ def __init__(self, shape, algo): # nn.BatchNorm1d(128), # nn.ReLU(), # spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"), - spconv.SparseMaxPool3d(2, 2, algo=pool_algo), spconv.SubMConv3d(160, 192, @@ -136,7 +135,6 @@ def __init__(self, shape, algo): # nn.ReLU(), spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo), # spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"), - spconv.SubMConv3d(192, 224, 3, @@ -174,7 +172,6 @@ def __init__(self, shape, algo): # # nn.ReLU(), # spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo), - ) max_batch_size = 1 # grid (dense map) is used for indice generation. use pre-allocated grid can run faster. 
@@ -183,16 +180,25 @@ def __init__(self, shape, algo): # self.grid = None self.shape = shape - def forward(self, features, coors, batch_size): - x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, - self.grid) + def forward(self, features, coors, batch_size, enable_timer: bool = False): + x = spconv.SparseConvTensor(features, + coors, + self.shape, + batch_size, + self.grid, + enable_timer=enable_timer) return self.net(x) + class Net2(nn.Module): def __init__(self, shape, algo): super().__init__() self.net = spconv.SparseSequential( - spconv.SubMConv3d(3, 128, 3, bias=False, indice_key="c0", + spconv.SubMConv3d(3, + 128, + 3, + bias=False, + indice_key="c0", algo=algo), # spconv.SubMConv3d(32, # 32, @@ -240,20 +246,22 @@ def forward(self, features, coors, batch_size): self.grid) return self.net(x) -import numpy as np -from cumm import tensorview as tv + +import numpy as np +from cumm import tensorview as tv from spconv.core_cc.csrc.sparse.all import SpconvOps -import pickle +import pickle import torch -from spconv.pytorch.cppcore import torch_tensor_to_tv +from spconv.pytorch.cppcore import torch_tensor_to_tv + def sort_bench(): with open("/home/yy/asd.pkl", "rb") as f: a_th = pickle.load(f) mask_argsort = torch.empty((1, a_th.shape[1]), - dtype=torch.int32, - device=a_th.device) + dtype=torch.int32, + device=a_th.device) a = a_th.cpu().numpy()[0] a_tv = torch_tensor_to_tv(a_th) @@ -262,8 +270,9 @@ def sort_bench(): a_tv_1 = a_tv.clone() SpconvOps.sort_1d_by_key(a_tv_1[0], mask_argsort_tv[0]) + def main(): - import pickle + import pickle np.random.seed(50051) torch.manual_seed(50051) # voxels, coors, spatial_shape = waymo_data() @@ -280,24 +289,55 @@ def main(): voxels_th = torch.from_numpy(voxels).to(device).to(dtype) coors_th = torch.from_numpy(coors).to(device).int() voxels_th.requires_grad = True - algo = spconv.ConvAlgo.MaskImplicitGemm + algo = spconv.ConvAlgo.Native + # 3080 Laptop + # MaskImpGemm: 11.2ms + # MaskSplitImpGemm: 12.2ms + # Native: 13.7ms + # F32 + # MaskSplitImpGemm: 22ms + # MaskImplicitGemm: 23.5ms + # Native: 21.7ms + # Pure Gemm + # Native: 6.6ms + # MaskImpGemm: 4.3ms + # MaskSplitImpGemm: 4.0ms + # F16 Bwd + # MaskSplitImpGemm: 12.2ms + # MaskImpGemm: 13.8ms + # Native: 25.2ms + + # F32 Bwd + # Native: 41.9ms + # MaskImpGemm: 51.0ms + # MaskSplitImpGemm: 41.1ms + # algo = None net = Net(spatial_shape, algo).to(device).eval().to(dtype).train() + spconv.assign_name_for_sparse_modules(net) print(coors_th.shape) out = net(voxels_th, coors_th, 1) print(out.spatial_shape) - print(voxels.mean(), voxels.max(), voxels.min()) - dout = np.random.uniform(-0.2, 0.2, - out.features.shape).astype(np.float32) + print(voxels.mean(), voxels.max(), voxels.min()) + dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32) dout_t = torch.from_numpy(dout).to(device).to(dtype) - print(out.spatial_shape, out.features.mean(), out.features.max(), out.features.min()) + print(out.spatial_shape, out.features.mean(), out.features.max(), + out.features.min()) times = [] with torch.no_grad(): for i in range(20): print("------------") torch.cuda.synchronize() t = time.time() - out_nograd = net(voxels_th, coors_th, 1) + out_nograd = net(voxels_th, coors_th, 1, True) + timer = out_nograd._timer + res = timer.collect_by_name("forward", timer.get_all_pair_time()) + res2 = timer.collect_by_name("forward0", timer.get_all_pair_time()) + + print(sum(res.values()) + sum(res2.values())) + # print(timer.get_all_pair_time()) + + # 
print(sum(timer.get_all_pair_time().values())) torch.cuda.synchronize() # sort_bench() times.append(time.time() - t) @@ -313,8 +353,8 @@ def main(): # torch.cuda.synchronize() # times.append(time.time() - t) - # print((net.grid == -1).float().sum(), net.grid.numel()) - # print("spconv time", time.time() - t) + # # # print((net.grid == -1).float().sum(), net.grid.numel()) + # # # print("spconv time", time.time() - t) # print("spconv bw time", np.mean(times[5:])) diff --git a/test/test_conv.py b/test/test_conv.py index 1d5304e..201894c 100644 --- a/test/test_conv.py +++ b/test/test_conv.py @@ -1,11 +1,11 @@ # Copyright 2021 Yan Yan -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -30,6 +30,7 @@ torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + class SparseConv3dTestTorch(nn.Module): def __init__(self, num_layers, @@ -363,7 +364,10 @@ def testSpConv3d(self): strides = [1, 2, 3] paddings = [0, 1, 2] dilations = [1, 2, 3] - algos = [ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, ConvAlgo.MaskSplitImplicitGemm] + algos = [ + ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, + ConvAlgo.MaskSplitImplicitGemm + ] algos = [ConvAlgo.MaskSplitImplicitGemm] for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( @@ -375,8 +379,16 @@ def testSpConv3d(self): device = torch.device(dev) num_points = [1000] * bs dtype = torch.float32 - net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, - d, algo=al).to(device).to(dtype) + net = SparseConv3dTestTorch(1, + 3, + shape, + IC, + OC, + k, + s, + p, + d, + algo=al).to(device).to(dtype) net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d).to(device).to(dtype) @@ -390,27 +402,32 @@ def testSpConv3d(self): indices_t = torch.from_numpy(indices).int().to(device) features_t = torch.from_numpy(features).to(device).to(dtype) features_t.requires_grad = True - features_dense_t = torch.from_numpy(features_dense).to(device).to(dtype) + features_dense_t = torch.from_numpy(features_dense).to(device).to( + dtype) features_dense_t.requires_grad = True if net.algo == ConvAlgo.Native: if FILTER_HWIO: - filters = np.random.uniform(-1, 1, size=[k, k, k, IC, - OC]).astype(np.float32) + filters = np.random.uniform(-1, 1, + size=[k, k, k, IC, + OC]).astype(np.float32) else: - filters = np.random.uniform(-1, 1, size=[k, k, k, OC, - IC]).astype(np.float32) + filters = np.random.uniform(-1, 1, + size=[k, k, k, OC, + IC]).astype(np.float32) filters_t = torch.from_numpy(filters).to(device).to(dtype) if FILTER_HWIO: - net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1, - 2).contiguous() + net_ref.net[0].weight.data[:] = filters_t.permute( + 4, 3, 0, 1, 2).contiguous() else: - net_ref.net[0].weight.data[:] = filters_t.permute(3, 4, 0, 1, - 2).contiguous() + net_ref.net[0].weight.data[:] = filters_t.permute( + 3, 4, 0, 1, 2).contiguous() else: - filters = np.random.uniform(-1, 1, size=[OC, k, k, k, IC]).astype(np.float32) + filters = np.random.uniform(-1, 1, + size=[OC, k, k, k, + IC]).astype(np.float32) filters_t = torch.from_numpy(filters).to(device).to(dtype) - net_ref.net[0].weight.data[:] = filters_t.permute(0, 4, 1, 2, - 3).contiguous() 
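The permute churn in this test is pure layout conversion: sparse filters are stored as (k, k, k, IC, OC) when FILTER_HWIO, (k, k, k, OC, IC) otherwise, or (OC, k, k, k, IC) for implicit gemm, while dense nn.Conv3d weights are (OC, IC, kD, kH, kW). A standalone example of the HWIO case:

import numpy as np
import torch

k, IC, OC = 3, 16, 32
filters = np.random.uniform(-1, 1, size=[k, k, k, IC, OC]).astype(np.float32)
# (k, k, k, IC, OC) sparse layout -> (OC, IC, k, k, k) dense weight layout
dense_w = torch.from_numpy(filters).permute(4, 3, 0, 1, 2).contiguous()
assert dense_w.shape == (OC, IC, k, k, k)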
+ net_ref.net[0].weight.data[:] = filters_t.permute( + 0, 4, 1, 2, 3).contiguous() net.net[0].weight.data[:] = filters_t out_ref = net_ref(features_dense_t) @@ -446,7 +463,6 @@ def testSpConv3d(self): self.assertAllClose(dw, dw_ref, atol=1e-4) self.assertAllClose(din_np, din_sparse_np, atol=1e-4) - def testSpDeConv3d(self): np.random.seed(484) devices = ["cuda:0"] @@ -499,11 +515,11 @@ def testSpDeConv3d(self): filters_t = torch.from_numpy(filters).to(device) print(net_ref.net[0].weight.shape) if FILTER_HWIO: - net_ref.net[0].weight.data[:] = filters_t.permute(3, 4, 0, 1, - 2).contiguous() + net_ref.net[0].weight.data[:] = filters_t.permute( + 3, 4, 0, 1, 2).contiguous() else: - net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1, - 2).contiguous() + net_ref.net[0].weight.data[:] = filters_t.permute( + 4, 3, 0, 1, 2).contiguous() net.net[0].weight.data[:] = filters_t out_ref = net_ref(features_dense_t) out = net(features_t, indices_t, bs).dense() @@ -532,7 +548,6 @@ def testSpDeConv3d(self): dw = dw.transpose(4, 3, 0, 1, 2) self.assertAllClose(dw, dw_ref, atol=1e-4) - def testSpCpConv3d(self): np.random.seed(484) devices = ["cuda:0", "cpu:0"] diff --git a/version.txt b/version.txt index abae0d9..c5864dc 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.1.3 \ No newline at end of file +2.1.5 \ No newline at end of file
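Taken together, the profiling additions in this patch compose as below; Net, voxels_th, and coors_th are the benchmark's own fixtures, everything else is the new public surface:

import torch
import spconv.pytorch as spconv

net = Net(spatial_shape, spconv.ConvAlgo.Native).cuda()
spconv.assign_name_for_sparse_modules(net)  # stamp dotted module names
with torch.no_grad():
    out = net(voxels_th, coors_th, 1, True)  # True -> enable_timer
timer = out._timer
res = timer.collect_by_name("forward", timer.get_all_pair_time())
print(sum(res.values()))  # total time spent in "forward"-tagged kernels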