From f582ec34f0bf5a916f0aeba9ac6410c27f75ec8a Mon Sep 17 00:00:00 2001 From: "yan.yan" Date: Thu, 23 Mar 2023 13:50:54 +0800 Subject: [PATCH] v2.3.4: global pool, generative model support --- CHANGELOG.md | 4 ++ docs/USAGE.md | 14 +++- spconv/core_cc/csrc/sparse/all/__init__.pyi | 10 +++ spconv/csrc/sparse/all.py | 26 ++++++++ spconv/csrc/sparse/indices.py | 9 +-- spconv/csrc/sparse/maxpool.py | 64 ++++++++++++++++++ spconv/pytorch/__init__.py | 3 +- spconv/pytorch/ops.py | 14 ++++ spconv/pytorch/pool.py | 36 ++++++++++ test/bench_build_time.py | 18 +++++ test/test_conv.py | 73 ++++++++++++++++++++- version.txt | 2 +- 12 files changed, 265 insertions(+), 8 deletions(-) create mode 100644 test/bench_build_time.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c50cbff..09ab85f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## [2.3.4] - 2023-03-23 +### Added +- Add SparseGlobalMaxPool and SparseGlobalAvgPool for training only. libspconv don't support it. + ## [2.3.3] - 2023-02-02 ### Fixed - Fix int8 nvrtc error when use prebuilt diff --git a/docs/USAGE.md b/docs/USAGE.md index 671b8d0..f02711f 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -31,9 +31,11 @@ from spconv.pytorch.hash import HashTable | ```spconv.SparseConv3d``` | Downsample | ```nn.Conv3d``` | Use ```indice_key``` to save data for inverse | | ```spconv.SubMConv3d``` | Convolution | N/A | Use ```indice_key``` to save data for reuse | | ```spconv.SparseInverseConv3d``` | Upsample | N/A | Use pre-saved ```indice_key``` to upsample | -| ```spconv.SparseConvTranspose3d``` | Upsample (don't use this)| ```nn.ConvTranspose3d``` | VERY SLOW and CAN'T RECOVER ORIGIN POINT CLOUD | +| ```spconv.SparseConvTranspose3d``` | Upsample (for generative model)| ```nn.ConvTranspose3d``` | VERY SLOW and CAN'T RECOVER ORIGIN POINT CLOUD | | ```spconv.SparseMaxPool3d``` | Downsample | ```nn.MaxPool3d``` | Use ```indice_key``` to save data for inverse | | ```spconv.SparseSequential``` | Container | ```nn.Sequential``` | support layers above and ```nn.ReLU, nn.BatchNorm, ...```| +| ```spconv.SparseGlobalMaxPool``` | global pool | N/A | return dense tensor instead of SparseConvTensor| +| ```spconv.SparseGlobalAvgPool``` | global pool | N/A | return dense tensor instead of SparseConvTensor| | Functional APIs | Usage | @@ -143,6 +145,16 @@ class ExampleNet(nn.Module): return self.net(x) ``` +### How To Use SparseConvTranspose + +```SparseConvTranspose``` (standard upsampling) should only be used in generative model. You need to use a classifier to check if a output coordicates is empty, then set batch indices (or xyz) of that sparse tensor to a negative number: + +```Python +spt.indices[empty_mask, 0] = -1 +``` + +In next sparse convolution, invalid coordinates will be removed until you perform next ```spt.indices[empty_mask, 0] = -1```. + #### Common Mistake * issue [#467](https://github.com/traveller59/spconv/issues/467) ```Python diff --git a/spconv/core_cc/csrc/sparse/all/__init__.pyi b/spconv/core_cc/csrc/sparse/all/__init__.pyi index f12b69d..b3987a5 100644 --- a/spconv/core_cc/csrc/sparse/all/__init__.pyi +++ b/spconv/core_cc/csrc/sparse/all/__init__.pyi @@ -296,6 +296,16 @@ class SpconvOps: """ ... @staticmethod + def global_pool_rearrange(out_indices: Tensor, coords: Tensor, counts: Tensor, stream: int = 0) -> None: + """ + Args: + out_indices: + coords: + counts: + stream: + """ + ... + @staticmethod def maxpool_implicit_gemm_forward(out: Tensor, inp: Tensor, inds: Tensor, stream: int = 0) -> None: """ Args: diff --git a/spconv/csrc/sparse/all.py b/spconv/csrc/sparse/all.py index 28dd68b..5199740 100644 --- a/spconv/csrc/sparse/all.py +++ b/spconv/csrc/sparse/all.py @@ -782,6 +782,32 @@ def indice_maxpool_backward(self): """) return code + @pccm.pybind.mark + @_STATIC_FUNCTION + def global_pool_rearrange(self): + code = pccm.FunctionCode() + code.arg("out_indices, coords, counts", "tv::Tensor") + code.arg("stream", "std::uintptr_t", "0", pyanno="int") + code.add_dependency(IndiceMaxPoolCPU) + if not CUMM_CPU_ONLY_BUILD: + code.add_dependency(IndiceMaxPool) + code.raw(f""" + if (out_indices.is_cpu()){{ + IndiceMaxPoolCPU::global_pool_rearrange(out_indices, coords, counts); + }} + """) + if not CUMM_CPU_ONLY_BUILD: + with code.else_(): + code.raw(f""" + IndiceMaxPool::global_pool_rearrange(out_indices, coords, counts, stream); + """) + else: + code.raw(f""" + TV_THROW_RT_ERR("not implemented in cpu-only spconv!!! ") + """) + return code + + @pccm.pybind.mark @_STATIC_FUNCTION def maxpool_implicit_gemm_forward(self): diff --git a/spconv/csrc/sparse/indices.py b/spconv/csrc/sparse/indices.py index 51380cb..f693086 100644 --- a/spconv/csrc/sparse/indices.py +++ b/spconv/csrc/sparse/indices.py @@ -195,7 +195,8 @@ def query_npq(self): stride_valid.append( f"!(npq_no_stride[{i + 1}] % problem_.stride[{i}])") code.raw(f""" - return npq_no_stride[0] < problem_.N && + return (npq_no_stride[0] < problem_.N) && + (npq_no_stride[0] >= 0) && {' && '.join(hw_valid)} && {' && '.join(stride_valid)}; """) @@ -218,7 +219,7 @@ def query_npq_no_stride(self): (f"npq_offset[{i + 1}] >= 0 && " f"npq_offset[{i + 1}] < problem_.output_dims[{i}]")) code.raw(f""" - return npq_offset[0] < problem_.N && + return (npq_offset[0] < problem_.N) && (npq_offset[0] >= 0) && {' && '.join(hw_valid)}; """) return code @@ -240,7 +241,7 @@ def query_nhw(self): (f"nhw_offset[{i + 1}] >= 0 && " f"nhw_offset[{i + 1}] < problem_.input_dims[{i}]")) code.raw(f""" - return nhw_offset[0] < problem_.N && + return (nhw_offset[0] < problem_.N) && (nhw_offset[0] >= 0) && {' && '.join(hw_valid)}; """) return code @@ -262,7 +263,7 @@ def query_nhw_out(self): (f"nhw_offset[{i + 1}] >= 0 && " f"nhw_offset[{i + 1}] < problem_.output_dims[{i}]")) code.raw(f""" - return nhw_offset[0] < problem_.N && + return (nhw_offset[0] < problem_.N) && (nhw_offset[0] >= 0) && {' && '.join(hw_valid)}; """) return code diff --git a/spconv/csrc/sparse/maxpool.py b/spconv/csrc/sparse/maxpool.py index 38f86be..bc7165f 100644 --- a/spconv/csrc/sparse/maxpool.py +++ b/spconv/csrc/sparse/maxpool.py @@ -298,6 +298,46 @@ def backward_avgpool_implicit_gemm_kernel(self): }} """) return code + + @pccm.cuda.cuda_global_function + def global_pool_rearrange_kernel(self): + code = pccm.FunctionCode() + code.arg("out_indices", "int*") + code.arg("coords", "const int*") + code.arg("counts", "int*") + code.arg("num_indices", "int") + + code.arg("indices_stride", "int") + + code.raw(f""" + for (int i : tv::KernelLoopX(num_indices)) {{ + int batch_idx = coords[i * indices_stride]; + if (batch_idx >= 0){{ + auto old = atomicAdd(counts + batch_idx, 1); + out_indices[batch_idx * num_indices + old] = i; + }} + }} + """) + return code + + @pccm.cuda.static_function + def global_pool_rearrange(self): + code = pccm.FunctionCode() + code.arg("out_indices", "tv::Tensor") + code.arg("coords", "tv::Tensor") + code.arg("counts", "tv::Tensor") + code.arg("stream", "std::uintptr_t", "0") + + code.raw(f""" + auto nhot = coords.dim(0); + auto cudastream = reinterpret_cast(stream); + tv::cuda::Launch launcher = tv::cuda::Launch(nhot, cudastream); + launcher(global_pool_rearrange_kernel, out_indices.data_ptr(), + coords.data_ptr(), counts.data_ptr(), nhot, + coords.stride(0)); + TV_CHECK_CUDA_ERR_V2("global_pool_feature_rearrange failed!!!"); + """) + return code @pccm.cuda.static_function def forward(self): @@ -555,6 +595,30 @@ def __init__(self): self.add_dependency(OMPLib) self.add_include("tensorview/parallel/all.h") + @pccm.static_function + def global_pool_rearrange(self): + code = pccm.FunctionCode() + code.arg("out_indices", "tv::Tensor") + code.arg("coords", "tv::Tensor") + code.arg("counts", "tv::Tensor") + + code.raw(f""" + auto nhot = coords.dim(0); + auto out_ptr = out_indices.data_ptr(); + auto coord_ptr = coords.data_ptr(); + auto count_ptr = counts.data_ptr(); + int indices_stride = coords.stride(0); + for (int i = 0; i < nhot; ++i){{ + int batch_idx = coord_ptr[0]; + if (batch_idx >= 0){{ + out_ptr[batch_idx * nhot + (count_ptr[batch_idx]++)] = i; + }} + coord_ptr += indices_stride; + }} + """) + return code + + @pccm.static_function def forward(self): code = pccm.FunctionCode() diff --git a/spconv/pytorch/__init__.py b/spconv/pytorch/__init__.py index 096c7d8..1d4207b 100644 --- a/spconv/pytorch/__init__.py +++ b/spconv/pytorch/__init__.py @@ -21,7 +21,8 @@ from spconv.pytorch.pool import (SparseMaxPool1d, SparseMaxPool2d, SparseMaxPool3d, SparseMaxPool4d, SparseAvgPool1d, SparseAvgPool2d, - SparseAvgPool3d) + SparseAvgPool3d, SparseGlobalMaxPool, + SparseGlobalAvgPool) from spconv.pytorch.tables import AddTable, ConcatTable, JoinTable diff --git a/spconv/pytorch/ops.py b/spconv/pytorch/ops.py index 17f82fe..61eb1b1 100644 --- a/spconv/pytorch/ops.py +++ b/spconv/pytorch/ops.py @@ -2103,3 +2103,17 @@ def maximum_value_int_(ten: torch.Tensor, value: int): else: assert not ten.is_cuda SpconvOps.maximum_value_int(torch_tensor_to_tv(ten), value, stream) + + +def global_pool_rearrange(coords: torch.Tensor, batch_size: int): + is_cpu = not coords.is_cuda + stream = 0 + if not is_cpu: + stream = get_current_stream() + + out_indices = torch.empty((batch_size, coords.shape[0]), dtype=torch.int32, device=coords.device) + counts = torch.zeros((batch_size, ), dtype=torch.int32, device=coords.device) + + SpconvOps.global_pool_rearrange(torch_tensor_to_tv(out_indices), torch_tensor_to_tv(coords), + torch_tensor_to_tv(counts), stream ) + return out_indices, counts \ No newline at end of file diff --git a/spconv/pytorch/pool.py b/spconv/pytorch/pool.py index 2d51740..2384507 100644 --- a/spconv/pytorch/pool.py +++ b/spconv/pytorch/pool.py @@ -248,6 +248,42 @@ def forward(self, input: spconv.SparseConvTensor): out_tensor.spatial_shape = out_spatial_shape return out_tensor +class SparseGlobalMaxOrAvgPool(SparseModule): + """TODO: deploy not supported. this implementation support + backward natively. for deploy, we should use single kernel with + smem based reduce. + """ + def __init__(self, is_mean: bool, name=None): + super(SparseGlobalMaxOrAvgPool, self).__init__(name=name) + self.is_mean = is_mean + + def forward(self, input: spconv.SparseConvTensor): + is_int8 = input.is_quantized + assert not is_int8, "not implemented" + assert isinstance(input, spconv.SparseConvTensor) + out_indices, counts = ops.global_pool_rearrange(input.indices, input.batch_size) + counts_cpu = counts.cpu() + + counts_cpu_np = counts_cpu.numpy() + res_features_list: List[torch.Tensor] = [] + for i in range(input.batch_size): + real_inds = out_indices[i, :counts_cpu_np[i]] + real_features = input.features[real_inds] + if self.is_mean: + real_features_reduced = torch.mean(real_features, dim=0)[0] + else: + real_features_reduced = torch.max(real_features, dim=0)[0] + res_features_list.append(real_features_reduced) + res = torch.stack(res_features_list) + return res + +class SparseGlobalAvgPool(SparseGlobalMaxOrAvgPool): + def __init__(self, name=None): + super(SparseGlobalAvgPool, self).__init__(is_mean=True, name=name) + +class SparseGlobalMaxPool(SparseGlobalMaxOrAvgPool): + def __init__(self, name=None): + super(SparseGlobalMaxPool, self).__init__(is_mean=False, name=name) class SparseAvgPool(SparseModule): def __init__(self, diff --git a/test/bench_build_time.py b/test/bench_build_time.py new file mode 100644 index 0000000..2f9370a --- /dev/null +++ b/test/bench_build_time.py @@ -0,0 +1,18 @@ +import cProfile +import re +import pstats + +import io + +pr = cProfile.Profile() +pr.enable() + +import spconv +pr.disable() +s = io.StringIO() +ps = pstats.Stats(pr, stream=s).sort_stats('tottime') +ps.print_stats() + +with open('test.txt', 'w') as f: + f.write(s.getvalue()) + diff --git a/test/test_conv.py b/test/test_conv.py index d19ae5e..64b6fd7 100644 --- a/test/test_conv.py +++ b/test/test_conv.py @@ -190,6 +190,20 @@ def forward(self, features, coors, batch_size): x = spconv.SparseConvTensor(features, coors, self.shape, batch_size) return self.net(x) # .dense() +class SparseGlobalMaxPoolTestTorch(nn.Module): + def __init__(self, shape): + super().__init__() + layers = [ + spconv.SparseGlobalMaxPool() + ] + self.net = spconv.SparseSequential(*layers, ) + self.shape = shape + + def forward(self, features, coors, batch_size): + coors = coors.int() + x = spconv.SparseConvTensor(features, coors, self.shape, batch_size) + return self.net(x) # .dense() + class MaxPool3dTestTorch(nn.Module): def __init__(self, num_layers, ndim, shape, kernel_size, stride, padding, @@ -526,6 +540,63 @@ def test_spmaxpool3d(): test_case.assertAllClose(din_np, din_sparse_np, atol=1e-4) +def test_spglobalmaxpool3d(): + test_case = TestCase() + + np.random.seed(485) + devices = ["cpu:0", "cuda:0"] + shapes = [[19, 18, 17]] + batchsizes = [1, 2] + + channels = [64] + # ksizes = [2] + # strides = [2] + # paddings = [0] + # dilations = [1] + + + for dev, shape, bs, C in params_grid( + devices, shapes, batchsizes, channels): + device = torch.device(dev) + num_points = [1000] * bs + # when data contains negative, sparse maxpool is not equal to dense maxpool. + sparse_dict = generate_sparse_data(shape, + num_points, + C, + data_range=[0.1, 0.4]) + + features = np.ascontiguousarray(sparse_dict["features"]).astype( + np.float32) + indices = np.ascontiguousarray( + sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) + features_dense = sparse_dict["features_dense"].astype(np.float32) + indices_t = torch.from_numpy(indices).int().to(device) + features_t = torch.from_numpy(features).to(device) + features_t.requires_grad = True + features_dense_t = torch.from_numpy(features_dense).to(device) + features_dense_t.requires_grad = True + net = SparseGlobalMaxPoolTestTorch(shape).to(device) + net_ref = MaxPool3dTestTorch(1, 3, shape, shape, shape, 0, 1).to(device) + + out_ref = net_ref(features_dense_t) + out = net(features_t, indices_t, bs) + out_dense = out + out_np = out.detach().cpu().numpy() + out_ref_np = out_ref.detach().cpu().numpy() + test_case.assertAllClose(out_np.reshape(-1), out_ref_np.reshape(-1), atol=1e-4) + + dout = np.random.uniform( + -0.2, 0.2, out_dense.shape).astype(features.dtype) + dout_t = torch.from_numpy(dout).to(device).view(bs, C, 1, 1, 1) + out.backward(dout_t.reshape(bs, C)) + out_ref.backward(dout_t) + din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, + 1).contiguous() + din_sparse = gather_nd(din_dense, indices_t.long()) + din = features_t.grad.detach() + din_np = din.cpu().numpy() + din_sparse_np = din_sparse.cpu().numpy() + test_case.assertAllClose(din_np, din_sparse_np, atol=1e-4) if __name__ == "__main__": - test_spmaxpool3d() + test_spglobalmaxpool3d() diff --git a/version.txt b/version.txt index 0bee604..3f684d2 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.3.3 +2.3.4