diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000..131b2bb --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,8 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +authors: + - name: "Spconv Contributors" +title: "Spconv: Spatially Sparse Convolution Library" +date-released: 2022-10-12 +url: "https://github.com/traveller59/spconv" +license: Apache-2.0 diff --git a/docs/USAGE.md b/docs/USAGE.md index 24d04c9..671b8d0 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -143,6 +143,38 @@ class ExampleNet(nn.Module): return self.net(x) ``` +#### Common Mistake +* issue [#467](https://github.com/traveller59/spconv/issues/467) +```Python +class WrongNet(nn.Module): + def __init__(self, shape): + super().__init__() + self.Encoder = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=2, indice_key="cp1",algo=algo) + self.Sparse_Conv = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=1,algo=algo) + self.Decoder = spconv.SparseInverseConv3d(channels, channels, kernel_size=3, indice_key="cp1",algo=algo) + + def forward(self, sparse_tensor): + encoded = self.Encoder(sparse_tensor) + s_conv = self.Sparse_Conv(encoded) + return self.Decoder(s_conv).features + +class CorrectNet(nn.Module): + def __init__(self, shape): + super().__init__() + self.Encoder = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=2, indice_key="cp1",algo=algo) + self.Sparse_Conv = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=1, indice_key="cp2",algo=algo) + self.Sparse_Conv_Decoder = spconv.SparseInverseConv3d(channels, channels, kernel_size=3, indice_key="cp2",algo=algo) + self.Decoder = spconv.SparseInverseConv3d(channels, channels, kernel_size=3, indice_key="cp1",algo=algo) + + def forward(self, sparse_tensor): + encoded = self.Encoder(sparse_tensor) + s_conv = self.Sparse_Conv(encoded) + return self.Decoder(self.Sparse_Conv_Decoder(s_conv)).features + +``` + +The ```Sparse_Conv``` in 
```WrongNet``` changes the spatial structure of the output of ```Encoder```, so we can't invert back to the input of ```Encoder``` via ```Decoder``` directly. We need to invert from ```Sparse_Conv.output``` to ```Encoder.output``` via ```Sparse_Conv_Decoder``` first, then invert from ```Encoder.output``` to ```Encoder.input``` via ```Decoder```. + ### Sparse Add In sematic segmentation network, we may use conv1x3, 3x1 and 3x3 in a block, but it's impossible to sum result from these layers because regular add requires same indices. diff --git a/example/fuse_bn_act.py b/example/fuse_bn_act.py index 00eda20..99cca11 100644 --- a/example/fuse_bn_act.py +++ b/example/fuse_bn_act.py @@ -75,7 +75,7 @@ def fuse_act_net(conv, act): fused_conv = copy.deepcopy(conv) if isinstance(act, torch.nn.ReLU): fused_conv.act_type = tv.gemm.Activation.ReLU - if isinstance(act, torch.nn.Sigmoid): + elif isinstance(act, torch.nn.Sigmoid): fused_conv.act_type = tv.gemm.Activation.Sigmoid elif isinstance(act, torch.nn.LeakyReLU): fused_conv.act_type = tv.gemm.Activation.LeakyReLU @@ -385,7 +385,7 @@ def main(): torch.manual_seed(50051) torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False - with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f: + with open(Path(__file__).parent.parent / "test" / "data" / "test_spconv.pkl", "rb") as f: (voxels, coors, spatial_shape) = pickle.load(f) np.random.seed(50051) device = torch.device("cuda:0") diff --git a/spconv/algo.py b/spconv/algo.py index 0651615..72a5831 100644 --- a/spconv/algo.py +++ b/spconv/algo.py @@ -726,8 +726,8 @@ def get_all_available(self, ldw = weight.dim(-1) ldo = out.dim(-1) mask_width_valid = True - - if desp.op_type == ConvOpType.kBackwardWeight.value: + + if desp.op_type.value == ConvOpType.kBackwardWeight.value: assert mask_width > 0 mask_width_valid = mask_width % desp.tile_shape[2] == 0 if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid: diff --git a/spconv/pytorch/conv.py 
b/spconv/pytorch/conv.py index 17280e0..9ff22ef 100644 --- a/spconv/pytorch/conv.py +++ b/spconv/pytorch/conv.py @@ -353,7 +353,8 @@ def forward(self, input: SparseConvTensor): indice_pairs = datas.indice_pairs indice_pair_num = datas.indice_pair_num out_spatial_shape = datas.spatial_shape - assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv" + self._check_inverse_reuse_valid(input, spatial_shape, + datas) else: if self.indice_key is not None and datas is not None: outids = datas.out_indices @@ -466,7 +467,10 @@ def forward(self, input: SparseConvTensor): mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits masks = datas.masks out_spatial_shape = datas.spatial_shape - assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv" + # assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv" + + self._check_inverse_reuse_valid(input, spatial_shape, + datas) else: if self.indice_key is not None and datas is not None: outids = datas.out_indices @@ -602,6 +606,32 @@ def _check_subm_reuse_valid(self, inp: SparseConvTensor, f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}" ) + def _check_inverse_reuse_valid(self, inp: SparseConvTensor, + spatial_shape: List[int], + datas: Union[ImplicitGemmIndiceData, + IndiceData]): + """Validate reuse of cached indice data by an inverse convolution. + + Raises ValueError if this layer's kernel size, or the spatial shape / + number of indices of ``inp``, differ from the output side of ``datas``. + ``spatial_shape`` is kept for caller compatibility; checks read ``inp``. + """ + if self.kernel_size != datas.ksize: + raise ValueError( + f"Inverse with same indice_key must have same kernel" + f" size, expect {datas.ksize}, this layer {self.kernel_size}, " + "please check Inverse Convolution in docs/USAGE.md.") + if inp.spatial_shape != datas.out_spatial_shape: + raise ValueError( + f"Inverse with same indice_key must have same spatial structure (spatial shape)" + f", expect {datas.out_spatial_shape}, input {inp.spatial_shape}, " + "please check Inverse Convolution in docs/USAGE.md.") + if inp.indices.shape[0] != datas.out_indices.shape[0]: + raise ValueError( + f"Inverse with same indice_key 
must have same num of indices" + f", expect {datas.out_indices.shape[0]}, input {inp.indices.shape[0]}, " + "please check Inverse Convolution in docs/USAGE.md." + ) class SparseConv1d(SparseConvolution): def __init__(self, diff --git a/tools/install_windows_cuda.ps1 b/tools/install_windows_cuda.ps1 index 3468ab5..6e631c8 100644 --- a/tools/install_windows_cuda.ps1 +++ b/tools/install_windows_cuda.ps1 @@ -12,6 +12,7 @@ $CUDA_KNOWN_URLS = @{ "11.4" = "https://developer.download.nvidia.com/compute/cuda/11.4.2/network_installers/cuda_11.4.2_win10_network.exe"; "11.5" = "https://developer.download.nvidia.com/compute/cuda/11.5.0/network_installers/cuda_11.5.0_win10_network.exe"; "11.7" = "https://developer.download.nvidia.com/compute/cuda/11.7.1/network_installers/cuda_11.7.1_windows_network.exe"; + "11.8" = "https://developer.download.nvidia.com/compute/cuda/11.8.0/network_installers/cuda_11.8.0_windows_network.exe"; } # cuda_runtime.h is in nvcc <= 10.2, but cudart >= 11.0