v2.3.4: global pool, generative model support
FindDefinition committed Mar 23, 2023
1 parent 004effb commit f582ec3
Showing 12 changed files with 265 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
# Changelog

## [2.3.4] - 2023-03-23
### Added
- Add SparseGlobalMaxPool and SparseGlobalAvgPool for training only; libspconv doesn't support them.

## [2.3.3] - 2023-02-02
### Fixed
- Fix int8 nvrtc error when using prebuilt
14 changes: 13 additions & 1 deletion docs/USAGE.md
@@ -31,9 +31,11 @@ from spconv.pytorch.hash import HashTable
| ```spconv.SparseConv3d``` | Downsample | ```nn.Conv3d``` | Use ```indice_key``` to save data for inverse |
| ```spconv.SubMConv3d``` | Convolution | N/A | Use ```indice_key``` to save data for reuse |
| ```spconv.SparseInverseConv3d``` | Upsample | N/A | Use pre-saved ```indice_key``` to upsample |
| ```spconv.SparseConvTranspose3d``` | Upsample (don't use this)| ```nn.ConvTranspose3d``` | VERY SLOW and CAN'T RECOVER ORIGIN POINT CLOUD |
| ```spconv.SparseConvTranspose3d``` | Upsample (for generative models)| ```nn.ConvTranspose3d``` | VERY SLOW and CANNOT RECOVER THE ORIGINAL POINT CLOUD |
| ```spconv.SparseMaxPool3d``` | Downsample | ```nn.MaxPool3d``` | Use ```indice_key``` to save data for inverse |
| ```spconv.SparseSequential``` | Container | ```nn.Sequential``` | support layers above and ```nn.ReLU, nn.BatchNorm, ...```|
| ```spconv.SparseGlobalMaxPool``` | Global pool | N/A | Returns a dense tensor instead of a SparseConvTensor (see the sketch below) |
| ```spconv.SparseGlobalAvgPool``` | Global pool | N/A | Returns a dense tensor instead of a SparseConvTensor |
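
Both global pooling layers consume a ```SparseConvTensor``` and return a dense ```[batch_size, num_features]``` tensor, so they are typically the last layer of a sparse backbone. A minimal sketch (the backbone layer and channel sizes here are illustrative, not prescribed by this release):

```Python
import torch.nn as nn
import spconv.pytorch as spconv

class GlobalPoolNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = spconv.SparseSequential(
            spconv.SubMConv3d(3, 32, 3, indice_key="subm0"),
            nn.ReLU(),
            spconv.SparseGlobalMaxPool(),  # output: dense [batch_size, 32] tensor
        )

    def forward(self, features, coords, spatial_shape, batch_size):
        x = spconv.SparseConvTensor(features, coords, spatial_shape, batch_size)
        return self.net(x)
```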


| Functional APIs | Usage |
@@ -143,6 +145,16 @@ class ExampleNet(nn.Module):
return self.net(x)
```

### How To Use SparseConvTranspose

```SparseConvTranspose``` (standard transposed-convolution upsampling) should only be used in generative models. You need a classifier to predict whether each output coordinate is empty, then set the batch index (or xyz) of those coordinates in the sparse tensor to a negative number:

```Python
spt.indices[empty_mask, 0] = -1
```

The invalidated coordinates are removed in the next sparse convolution, so you need to apply ```spt.indices[empty_mask, 0] = -1``` again after each subsequent upsampling step.
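
A minimal sketch of this pattern for one upsampling block (the occupancy head, threshold, and kernel sizes below are illustrative assumptions, not a fixed API):

```Python
import torch
import spconv.pytorch as spconv

class GenerativeUpsampleBlock(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.up = spconv.SparseConvTranspose3d(channels, channels, 2, stride=2)
        # classifier that predicts whether each generated output coordinate is empty
        self.occupancy = spconv.SubMConv3d(channels, 1, 3)

    def forward(self, spt: spconv.SparseConvTensor):
        spt = self.up(spt)
        occ_logit = self.occupancy(spt).features.squeeze(-1)
        empty_mask = occ_logit.sigmoid() < 0.5
        # invalidate empty coordinates; they are dropped by the next sparse convolution
        spt.indices[empty_mask, 0] = -1
        return spt
```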

#### Common Mistake
* issue [#467](https://github.com/traveller59/spconv/issues/467)
```Python
10 changes: 10 additions & 0 deletions spconv/core_cc/csrc/sparse/all/__init__.pyi
@@ -296,6 +296,16 @@ class SpconvOps:
"""
...
@staticmethod
def global_pool_rearrange(out_indices: Tensor, coords: Tensor, counts: Tensor, stream: int = 0) -> None:
"""
Args:
out_indices:
coords:
counts:
stream:
"""
...
@staticmethod
def maxpool_implicit_gemm_forward(out: Tensor, inp: Tensor, inds: Tensor, stream: int = 0) -> None:
"""
Args:
26 changes: 26 additions & 0 deletions spconv/csrc/sparse/all.py
@@ -782,6 +782,32 @@ def indice_maxpool_backward(self):
""")
return code

@pccm.pybind.mark
@_STATIC_FUNCTION
def global_pool_rearrange(self):
code = pccm.FunctionCode()
code.arg("out_indices, coords, counts", "tv::Tensor")
code.arg("stream", "std::uintptr_t", "0", pyanno="int")
code.add_dependency(IndiceMaxPoolCPU)
if not CUMM_CPU_ONLY_BUILD:
code.add_dependency(IndiceMaxPool)
code.raw(f"""
if (out_indices.is_cpu()){{
IndiceMaxPoolCPU::global_pool_rearrange(out_indices, coords, counts);
}}
""")
if not CUMM_CPU_ONLY_BUILD:
with code.else_():
code.raw(f"""
IndiceMaxPool::global_pool_rearrange(out_indices, coords, counts, stream);
""")
else:
code.raw(f"""
TV_THROW_RT_ERR("not implemented in cpu-only spconv!!! ")
""")
return code


@pccm.pybind.mark
@_STATIC_FUNCTION
def maxpool_implicit_gemm_forward(self):
9 changes: 5 additions & 4 deletions spconv/csrc/sparse/indices.py
@@ -195,7 +195,8 @@ def query_npq(self):
stride_valid.append(
f"!(npq_no_stride[{i + 1}] % problem_.stride[{i}])")
code.raw(f"""
return npq_no_stride[0] < problem_.N &&
return (npq_no_stride[0] < problem_.N) &&
(npq_no_stride[0] >= 0) &&
{' && '.join(hw_valid)} &&
{' && '.join(stride_valid)};
""")
@@ -218,7 +219,7 @@ def query_npq_no_stride(self):
(f"npq_offset[{i + 1}] >= 0 && "
f"npq_offset[{i + 1}] < problem_.output_dims[{i}]"))
code.raw(f"""
return npq_offset[0] < problem_.N &&
return (npq_offset[0] < problem_.N) && (npq_offset[0] >= 0) &&
{' && '.join(hw_valid)};
""")
return code
@@ -240,7 +241,7 @@ def query_nhw(self):
(f"nhw_offset[{i + 1}] >= 0 && "
f"nhw_offset[{i + 1}] < problem_.input_dims[{i}]"))
code.raw(f"""
return nhw_offset[0] < problem_.N &&
return (nhw_offset[0] < problem_.N) && (nhw_offset[0] >= 0) &&
{' && '.join(hw_valid)};
""")
return code
@@ -262,7 +263,7 @@ def query_nhw_out(self):
(f"nhw_offset[{i + 1}] >= 0 && "
f"nhw_offset[{i + 1}] < problem_.output_dims[{i}]"))
code.raw(f"""
return nhw_offset[0] < problem_.N &&
return (nhw_offset[0] < problem_.N) && (nhw_offset[0] >= 0) &&
{' && '.join(hw_valid)};
""")
return code
64 changes: 64 additions & 0 deletions spconv/csrc/sparse/maxpool.py
@@ -298,6 +298,46 @@ def backward_avgpool_implicit_gemm_kernel(self):
        }}
        """)
        return code

    @pccm.cuda.cuda_global_function
    def global_pool_rearrange_kernel(self):
        code = pccm.FunctionCode()
        code.arg("out_indices", "int*")
        code.arg("coords", "const int*")
        code.arg("counts", "int*")
        code.arg("num_indices", "int")

        code.arg("indices_stride", "int")

        code.raw(f"""
        for (int i : tv::KernelLoopX<int>(num_indices)) {{
            // column 0 of each coords row is the batch index; negative values
            // mark invalidated coordinates (see docs/USAGE.md) and are skipped.
            int batch_idx = coords[i * indices_stride];
            if (batch_idx >= 0){{
                auto old = atomicAdd(counts + batch_idx, 1);
                out_indices[batch_idx * num_indices + old] = i;
            }}
        }}
        """)
        return code

    @pccm.cuda.static_function
    def global_pool_rearrange(self):
        code = pccm.FunctionCode()
        code.arg("out_indices", "tv::Tensor")
        code.arg("coords", "tv::Tensor")
        code.arg("counts", "tv::Tensor")
        code.arg("stream", "std::uintptr_t", "0")

        code.raw(f"""
        auto nhot = coords.dim(0);
        auto cudastream = reinterpret_cast<cudaStream_t>(stream);
        tv::cuda::Launch launcher = tv::cuda::Launch(nhot, cudastream);
        launcher(global_pool_rearrange_kernel, out_indices.data_ptr<int>(),
                 coords.data_ptr<const int>(), counts.data_ptr<int>(), nhot,
                 coords.stride(0));
        TV_CHECK_CUDA_ERR_V2("global_pool_rearrange failed!!!");
        """)
        return code

    @pccm.cuda.static_function
    def forward(self):
@@ -555,6 +595,30 @@ def __init__(self):
        self.add_dependency(OMPLib)
        self.add_include("tensorview/parallel/all.h")

    @pccm.static_function
    def global_pool_rearrange(self):
        code = pccm.FunctionCode()
        code.arg("out_indices", "tv::Tensor")
        code.arg("coords", "tv::Tensor")
        code.arg("counts", "tv::Tensor")

        code.raw(f"""
        auto nhot = coords.dim(0);
        auto out_ptr = out_indices.data_ptr<int>();
        auto coord_ptr = coords.data_ptr<const int>();
        auto count_ptr = counts.data_ptr<int>();
        int indices_stride = coords.stride(0);
        for (int i = 0; i < nhot; ++i){{
            // column 0 of the current row is the batch index; negative values
            // mark invalidated coordinates and are skipped. coord_ptr advances
            // by the row stride each iteration.
            int batch_idx = coord_ptr[0];
            if (batch_idx >= 0){{
                out_ptr[batch_idx * nhot + (count_ptr[batch_idx]++)] = i;
            }}
            coord_ptr += indices_stride;
        }}
        """)
        return code


    @pccm.static_function
    def forward(self):
        code = pccm.FunctionCode()
3 changes: 2 additions & 1 deletion spconv/pytorch/__init__.py
@@ -21,7 +21,8 @@
from spconv.pytorch.pool import (SparseMaxPool1d, SparseMaxPool2d,
                                 SparseMaxPool3d, SparseMaxPool4d,
                                 SparseAvgPool1d, SparseAvgPool2d,
                                 SparseAvgPool3d)
                                 SparseAvgPool3d, SparseGlobalMaxPool,
                                 SparseGlobalAvgPool)
from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable


14 changes: 14 additions & 0 deletions spconv/pytorch/ops.py
@@ -2103,3 +2103,17 @@ def maximum_value_int_(ten: torch.Tensor, value: int):
    else:
        assert not ten.is_cuda
        SpconvOps.maximum_value_int(torch_tensor_to_tv(ten), value, stream)


def global_pool_rearrange(coords: torch.Tensor, batch_size: int):
    """Group the rows of a sparse tensor's indices by batch.

    Returns:
        out_indices: [batch_size, num_indices] int32 tensor; row b holds the row
            indices of ``coords`` that belong to batch b.
        counts: [batch_size] int32 tensor with the number of valid rows per batch
            (rows whose batch index is negative are skipped).
    """
    is_cpu = not coords.is_cuda
    stream = 0
    if not is_cpu:
        stream = get_current_stream()

    out_indices = torch.empty((batch_size, coords.shape[0]),
                              dtype=torch.int32,
                              device=coords.device)
    counts = torch.zeros((batch_size, ), dtype=torch.int32, device=coords.device)

    SpconvOps.global_pool_rearrange(torch_tensor_to_tv(out_indices),
                                    torch_tensor_to_tv(coords),
                                    torch_tensor_to_tv(counts), stream)
    return out_indices, counts
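
# For reference, the rearrangement above is equivalent to this pure-Python
# sketch (a reference only, not called anywhere; on CUDA the ordering of
# indices within a batch is unspecified because output slots are filled
# atomically):
#
# def global_pool_rearrange_reference(coords, batch_size):
#     num_indices = coords.shape[0]
#     out_indices = torch.empty((batch_size, num_indices), dtype=torch.int32)
#     counts = torch.zeros((batch_size, ), dtype=torch.int32)
#     for i in range(num_indices):
#         b = int(coords[i, 0])  # column 0 is the batch index; negative = invalid
#         if b >= 0:
#             out_indices[b, counts[b]] = i
#             counts[b] += 1
#     return out_indices, counts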
36 changes: 36 additions & 0 deletions spconv/pytorch/pool.py
@@ -248,6 +248,42 @@ def forward(self, input: spconv.SparseConvTensor):
        out_tensor.spatial_shape = out_spatial_shape
        return out_tensor

class SparseGlobalMaxOrAvgPool(SparseModule):
    """TODO: deployment is not supported yet. This implementation supports
    backward natively. For deployment we should use a single kernel with a
    shared-memory (smem) based reduction.
    """
    def __init__(self, is_mean: bool, name=None):
        super(SparseGlobalMaxOrAvgPool, self).__init__(name=name)
        self.is_mean = is_mean

    def forward(self, input: spconv.SparseConvTensor):
        is_int8 = input.is_quantized
        assert not is_int8, "not implemented"
        assert isinstance(input, spconv.SparseConvTensor)
        out_indices, counts = ops.global_pool_rearrange(input.indices, input.batch_size)
        counts_cpu = counts.cpu()

        counts_cpu_np = counts_cpu.numpy()
        res_features_list: List[torch.Tensor] = []
        for i in range(input.batch_size):
            real_inds = out_indices[i, :counts_cpu_np[i]]
            real_features = input.features[real_inds]
            if self.is_mean:
                real_features_reduced = torch.mean(real_features, dim=0)
            else:
                real_features_reduced = torch.max(real_features, dim=0)[0]
            res_features_list.append(real_features_reduced)
        res = torch.stack(res_features_list)
        return res

class SparseGlobalAvgPool(SparseGlobalMaxOrAvgPool):
    def __init__(self, name=None):
        super(SparseGlobalAvgPool, self).__init__(is_mean=True, name=name)

class SparseGlobalMaxPool(SparseGlobalMaxOrAvgPool):
    def __init__(self, name=None):
        super(SparseGlobalMaxPool, self).__init__(is_mean=False, name=name)

class SparseAvgPool(SparseModule):
    def __init__(self,
18 changes: 18 additions & 0 deletions test/bench_build_time.py
@@ -0,0 +1,18 @@
import cProfile
import re
import pstats

import io

pr = cProfile.Profile()
pr.enable()

import spconv
pr.disable()
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).sort_stats('tottime')
ps.print_stats()

with open('test.txt', 'w') as f:
    f.write(s.getvalue())

73 changes: 72 additions & 1 deletion test/test_conv.py
@@ -190,6 +190,20 @@ def forward(self, features, coors, batch_size):
        x = spconv.SparseConvTensor(features, coors, self.shape, batch_size)
        return self.net(x) # .dense()

class SparseGlobalMaxPoolTestTorch(nn.Module):
    def __init__(self, shape):
        super().__init__()
        layers = [
            spconv.SparseGlobalMaxPool()
        ]
        self.net = spconv.SparseSequential(*layers, )
        self.shape = shape

    def forward(self, features, coors, batch_size):
        coors = coors.int()
        x = spconv.SparseConvTensor(features, coors, self.shape, batch_size)
        return self.net(x) # .dense()


class MaxPool3dTestTorch(nn.Module):
    def __init__(self, num_layers, ndim, shape, kernel_size, stride, padding,
@@ -526,6 +540,63 @@ def test_spmaxpool3d():
    test_case.assertAllClose(din_np, din_sparse_np, atol=1e-4)


def test_spglobalmaxpool3d():
    test_case = TestCase()

    np.random.seed(485)
    devices = ["cpu:0", "cuda:0"]
    shapes = [[19, 18, 17]]
    batchsizes = [1, 2]

    channels = [64]
    # ksizes = [2]
    # strides = [2]
    # paddings = [0]
    # dilations = [1]


    for dev, shape, bs, C in params_grid(
            devices, shapes, batchsizes, channels):
        device = torch.device(dev)
        num_points = [1000] * bs
        # when data contains negative values, sparse maxpool is not equal to dense maxpool.
        sparse_dict = generate_sparse_data(shape,
                                           num_points,
                                           C,
                                           data_range=[0.1, 0.4])

        features = np.ascontiguousarray(sparse_dict["features"]).astype(
            np.float32)
        indices = np.ascontiguousarray(
            sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
        features_dense = sparse_dict["features_dense"].astype(np.float32)
        indices_t = torch.from_numpy(indices).int().to(device)
        features_t = torch.from_numpy(features).to(device)
        features_t.requires_grad = True
        features_dense_t = torch.from_numpy(features_dense).to(device)
        features_dense_t.requires_grad = True
        net = SparseGlobalMaxPoolTestTorch(shape).to(device)
        net_ref = MaxPool3dTestTorch(1, 3, shape, shape, shape, 0, 1).to(device)

        out_ref = net_ref(features_dense_t)
        out = net(features_t, indices_t, bs)
        out_dense = out
        out_np = out.detach().cpu().numpy()
        out_ref_np = out_ref.detach().cpu().numpy()
        test_case.assertAllClose(out_np.reshape(-1), out_ref_np.reshape(-1), atol=1e-4)

        dout = np.random.uniform(
            -0.2, 0.2, out_dense.shape).astype(features.dtype)
        dout_t = torch.from_numpy(dout).to(device).view(bs, C, 1, 1, 1)
        out.backward(dout_t.reshape(bs, C))
        out_ref.backward(dout_t)
        din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4,
                                                           1).contiguous()
        din_sparse = gather_nd(din_dense, indices_t.long())
        din = features_t.grad.detach()
        din_np = din.cpu().numpy()
        din_sparse_np = din_sparse.cpu().numpy()
        test_case.assertAllClose(din_np, din_sparse_np, atol=1e-4)

if __name__ == "__main__":
    test_spmaxpool3d()
    test_spglobalmaxpool3d()
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
2.3.3
2.3.4
