Commit

v2.1.5: add profile tool and python 3.6 for linux

FindDefinition committed Nov 10, 2021
1 parent f31eee3 commit 82fd7a8
Showing 80 changed files with 2,254 additions and 1,979 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
@@ -89,7 +89,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10'] # this version is only used for upload.
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10'] # this version is only used for upload.
cuda-version: ['102', '111', '113', '114', '']

steps:
5 changes: 3 additions & 2 deletions .github/workflows/stale.yaml
@@ -14,5 +14,6 @@ jobs:
steps:
- uses: actions/stale@v4
with:
stale-issue-message: 'Close stale issues due to inactivity.'
stale-pr-message: 'Close stale PRs due to inactivity.'
stale-issue-message: 'Mark stale issues due to inactivity.'
stale-pr-message: 'Mark stale PRs due to inactivity.'
operations-per-run: 300
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,14 @@
# Changelog

## [2.1.5] - 2021-11-10
### Added
- Add CUDA profile tool
- Add Python 3.6 support
### Changed
- Format all code
### Removed
- Remove an unnecessary device sync, slightly improving performance.

## [2.1.0] - 2021-10-31
### Added
* Add implicit gemm algorithm for all kinds of convolution with kernel volume <= 32. This algorithm is very fast with float16.
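The changelog's "CUDA profile tool" ships elsewhere in this commit, and its API isn't visible in this diff. As a generic stand-in, ```torch.profiler``` can surface the CUDA kernels a spconv layer launches — a minimal sketch, not spconv's own tool:

```python
# Generic CUDA-kernel profiling sketch (a stand-in, NOT the profile tool
# added in this commit).
import torch
from torch.profiler import profile, ProfilerActivity

def profile_cuda_kernels(module, inp, steps=5):
    # record CPU ops and the CUDA kernels they launch
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(steps):
            module(inp)
    # print the hottest kernels by accumulated CUDA time
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```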
35 changes: 29 additions & 6 deletions README.md
@@ -13,16 +13,36 @@
See the License for the specific language governing permissions and
limitations under the License.
-->

[pypi-download]: https://img.shields.io/pypi/dm/spconv-cu114
[pypi-url]: https://pypi.org/project/spconv-cu114/
[pypi-image]: https://badge.fury.io/py/spconv-cu114.svg
[pypi-ver-cpu]: https://img.shields.io/pypi/v/spconv
[pypi-ver-114]: https://img.shields.io/pypi/v/spconv-cu114
[pypi-ver-111]: https://img.shields.io/pypi/v/spconv-cu111
[pypi-ver-113]: https://img.shields.io/pypi/v/spconv-cu113
[pypi-ver-102]: https://img.shields.io/pypi/v/spconv-cu102

[pypi-url-111]: https://pypi.org/project/spconv-cu111/
[pypi-download-111]: https://img.shields.io/pypi/dm/spconv-cu111
[pypi-url-113]: https://pypi.org/project/spconv-cu113/
[pypi-download-113]: https://img.shields.io/pypi/dm/spconv-cu113
[pypi-url-102]: https://pypi.org/project/spconv-cu102/
[pypi-download-102]: https://img.shields.io/pypi/dm/spconv-cu102
[pypi-url-114]: https://pypi.org/project/spconv-cu114/
[pypi-download-114]: https://img.shields.io/pypi/dm/spconv-cu114
[pypi-url-cpu]: https://pypi.org/project/spconv/
[pypi-download-cpu]: https://img.shields.io/pypi/dm/spconv

# SpConv: Spatially Sparse Convolution Library
[![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild) [![PyPI Version][pypi-image]][pypi-url] [![pypi monthly download][pypi-download]][pypi-url]
[![Build Status](https://github.com/traveller59/spconv/workflows/build/badge.svg)](https://github.com/traveller59/spconv/actions?query=workflow%3Abuild)

| | PyPi Version | Downloads |
| -------------- |:---------------------:| ---------------------:|
| CPU (Linux Only) | [![PyPI Version][pypi-ver-cpu]][pypi-url-cpu] | [![pypi monthly download][pypi-download-cpu]][pypi-url-cpu] |
| CUDA 10.2 | [![PyPI Version][pypi-ver-102]][pypi-url-102] | [![pypi monthly download][pypi-download-102]][pypi-url-102] |
| CUDA 11.1 | [![PyPI Version][pypi-ver-111]][pypi-url-111] | [![pypi monthly download][pypi-download-111]][pypi-url-111]|
| CUDA 11.3 (Linux Only) | [![PyPI Version][pypi-ver-113]][pypi-url-113] |[![pypi monthly download][pypi-download-113]][pypi-url-113]|
| CUDA 11.4 | [![PyPI Version][pypi-ver-114]][pypi-url-114] | [![pypi monthly download][pypi-download-114]][pypi-url-114]|


```spconv``` is a project that provide heavily-optimized sparse convolution implementation with tensor core support.
```spconv``` is a project that provides a heavily-optimized sparse convolution implementation with tensor core support. Check [benchmark](docs/BENCHMARK.md) to see how fast spconv 2.x runs.

[Spconv 1.x code](https://github.com/traveller59/spconv/tree/v1.2.1). We won't provide any support for spconv 1.x since it's deprecated. Use spconv 2.x if possible. <!--remove this message in spconv 2.2-->

@@ -99,7 +119,10 @@ The c++ code will be built automatically when you change c++ code in project.

For NVIDIA embedded platforms, you need to specify the CUDA arch before building: ```export CUMM_CUDA_ARCH_LIST="7.2"``` for Xavier.

You need to remove ```cumm``` from the ```requires``` section in pyproject.toml after installing editable ```cumm``` and before installing spconv, due to a pyproject limitation (it can't find an editable-installed ```cumm```).

#### Linux

0. Uninstall spconv and cumm installed by pip
1. Install build-essential and CUDA
2. ```git clone https://github.com/FindDefinition/cumm```, ```cd ./cumm```, ```pip install -e .```
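pip can silently resolve a wheel instead of your editable checkout, so a quick sanity check of which ```cumm``` Python will actually import before building spconv may help (a sketch, not part of the repo):

```python
# Confirm the editable cumm checkout (not a pip wheel) is the one
# Python resolves before building spconv from source.
import importlib.util

for pkg in ("cumm", "spconv"):
    spec = importlib.util.find_spec(pkg)
    print(pkg, "->", spec.origin if spec else "not installed")
```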
48 changes: 48 additions & 0 deletions docs/BENCHMARK.md
@@ -0,0 +1,48 @@
<!--
Copyright 2021 Yan Yan
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

## Simple Benchmark

### Network Benchmark without batchnorm (F32/F16) on an RTX 3080 Laptop GPU

Network Code: test/benchmark.py

| F32/F16 | Spconv 1.x F32 (1080Ti) | Native| Implicit Gemm | Implicit Gemm Split Mask |
| -------------- |:---------------------:|---------------------:|---------------------:| ---------------------:|
| Forward | 43ms | 21.7ms/13.7ms | 23.5ms/11.2ms | 22ms/12.2ms |
| Backward | 80ms | 41.9ms/25.2ms | 51.0ms/13.8ms | 41.1ms/12.2ms |

### Network Gemm Kernel Benchmark FP16 on an RTX 3080 Laptop GPU

Network Code: test/benchmark.py

The network/input/profile code is the same as in the table above.

This table only profiles **fp16 gemm kernels**, without output-tensor create/clear overhead; it shows the performance upper bound of our algorithm.

| F16 | Native| Implicit Gemm | Implicit Gemm Split Mask |
| -------------- |:---------------------:|---------------------:| ---------------------:|
| Forward | 8.0ms | 4.3ms | 4.0ms |

We can see that implicit gemm is very fast: gemm takes only 4.3ms of the 11.2ms network forward. We can achieve better performance with TensorRT + pure C++.

**NOTE**
When you want to benchmark a network on your laptop, don't forget to close all apps except terminals! Other apps consume GPU resources and make kernels run slower.
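For a quick reproduction of this kind of measurement, here is a minimal forward-timing sketch with CUDA events (assumptions: the spconv 2.x import path ```spconv.pytorch``` and made-up layer shapes; the numbers above come from test/benchmark.py):

```python
# Forward-timing sketch with CUDA events (an illustration, not
# test/benchmark.py itself; shapes and channel counts are made up).
import torch
import spconv.pytorch as spconv

conv = spconv.SubMConv2d(64, 64, 3, 1).cuda().half()
x = torch.randn(4, 128, 128, 64, device="cuda", dtype=torch.half)
x_sp = spconv.SparseConvTensor.from_dense(x)  # NHWC dense -> sparse

for _ in range(10):  # warm-up: kernel selection and caching
    conv(x_sp)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    conv(x_sp)
end.record()
torch.cuda.synchronize()  # wait for all kernels before reading timers
print(f"mean forward: {start.elapsed_time(end) / 100:.3f} ms")
```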


## Comparison with [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) and [torchsparse](https://github.com/mit-han-lab/torchsparse)

TODO
9 changes: 2 additions & 7 deletions docs/PERFORMANCE_GUIDE.md
@@ -25,12 +25,7 @@
* make sure your channel size is a multiple of 8 when using fp16; a multiple of 32 is better (a quick padding helper is sketched below).
* spconv 2.x on Windows 10 is 1.5x~2x slower than on Linux. Use Linux if possible.
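A tiny helper for the channel-size rule above (a convenience sketch; spconv does not ship this):

```python
# Round a channel count up to the next multiple: 8 satisfies fp16
# eligibility, 32 gives the best tensor-core utilization.
def pad_channels(channels: int, multiple: int = 32) -> int:
    return ((channels + multiple - 1) // multiple) * multiple

assert pad_channels(48) == 64
assert pad_channels(33, multiple=8) == 40
```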

Network Benchmark without batchnorm (F32/F16) in RTX 3080 Laptop GPU

| F32/F16 | Spconv 1.x | Native| Implicit Gemm | Implicit Gemm Split Mask |
| -------------- |:---------------------:|---------------------:|---------------------:| ---------------------:|
| Forward | 43ms | 29ms/23ms | 30ms/15ms | 30ms/19ms |
| Backward | 80ms | 47ms/32ms | 56ms/15ms | 45ms/14ms |
See [benchmark](BENCHMARK.md) for more performance details of different algorithms.

## Algorithm Overview

@@ -57,4 +52,4 @@ In my test, ```Implicit Gemm``` is almost 2x faster than ```Native```.

TODO

In my test, ```Implicit Gemm Split Mask``` is slightly faster than ```Implicit Gemm```, but the indice generation is greatly slower, so currently we use ```Implicit Gemm``` by default.
In my test, ```Implicit Gemm Split Mask``` is slightly faster than ```Implicit Gemm```, but the indice generation is slower, so currently we use ```Implicit Gemm``` by default.
116 changes: 79 additions & 37 deletions example/mnist_sparse.py
@@ -1,11 +1,11 @@
# Copyright 2021 Yan Yan
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -22,11 +22,12 @@
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import contextlib
import torch.cuda.amp
import torch.cuda.amp


@contextlib.contextmanager
def identity_ctx():
yield
yield


class Net(nn.Module):
@@ -39,14 +40,13 @@ def __init__(self):
spconv.SubMConv2d(32, 64, 3, 1),
nn.ReLU(),
spconv.SparseMaxPool2d(2, 2),
spconv.ToDense(),
spconv.ToDense(),
)
self.fc1 = nn.Linear(14 * 14 * 64, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)


def forward(self, x: torch.Tensor):
# x: [N, 28, 28, 1], must be NHWC tensor
x_sp = spconv.SparseConvTensor.from_dense(x.reshape(-1, 28, 28, 1))
@@ -116,40 +116,72 @@ def test(args, model, device, test_loader):
with amp_ctx:

output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
test_loss += F.nll_loss(
output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(
dim=1,
keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
print(
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))


def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
parser.add_argument('--batch-size',
type=int,
default=64,
metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
parser.add_argument('--test-batch-size',
type=int,
default=1000,
metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
parser.add_argument('--epochs',
type=int,
default=14,
metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
parser.add_argument('--lr',
type=float,
default=1.0,
metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
parser.add_argument('--gamma',
type=float,
default=0.7,
metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
parser.add_argument('--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
parser.add_argument('--seed',
type=int,
default=1,
metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')

parser.add_argument('--save-model', action='store_true', default=False,
parser.add_argument(
'--log-interval',
type=int,
default=10,
metavar='N',
help='how many batches to wait before logging training status')

parser.add_argument('--save-model',
action='store_true',
default=False,
help='For Saving the current Model')
parser.add_argument('--fp16', action='store_true', default=False,
parser.add_argument('--fp16',
action='store_true',
default=False,
help='For mixed precision training')

args = parser.parse_args()
@@ -161,20 +193,30 @@ def main():

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
datasets.MNIST(
'../data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size,
shuffle=True,
**kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
datasets.MNIST(
'../data',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
# here we remove norm to get sparse tensor with lots of zeros
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
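The example's ```--fp16``` flag selects the ```amp_ctx``` wrapped around forward passes; a condensed sketch of that pattern (the full wiring is in example/mnist_sparse.py; ```make_amp_ctx``` is a hypothetical helper name):

```python
# Condensed sketch of the example's mixed-precision toggle
# (make_amp_ctx is a hypothetical name, not from the file).
import contextlib
import torch

@contextlib.contextmanager
def identity_ctx():
    yield

def make_amp_ctx(fp16: bool):
    # autocast runs eligible CUDA ops in float16; identity_ctx is a no-op
    return torch.cuda.amp.autocast() if fp16 else identity_ctx()

# usage: with make_amp_ctx(args.fp16): output = model(data)
```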