Skip to content

Commit

Permalink
fix #524 and small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
FindDefinition committed Oct 18, 2022
1 parent 8b52b3a commit 24df06f
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 6 deletions.
8 changes: 8 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- name: "Spconv Contributors"
title: "Spconv: Spatially Sparse Convolution Library"
date-released: 2022-10-12
url: "https://github.com/traveller59/spconv"
license: Apache-2.0
32 changes: 32 additions & 0 deletions docs/USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,38 @@ class ExampleNet(nn.Module):
return self.net(x)
```

#### Common Mistake
* issue [#467](https://github.com/traveller59/spconv/issues/467)
```Python
class WrongNet(nn.Module):
def __init__(self, shape):
super().__init__()
self.Encoder = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=2, indice_key="cp1",algo=algo)
self.Sparse_Conv = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=1,algo=algo)
self.Decoder = spconv.SparseInverseConv3d(channels, channels, kernel_size=3, indice_key="cp1",algo=algo)

def forward(self, sparse_tensor):
encoded = self.Encoder(sparse_tensor)
s_conv = self.Sparse_Conv(encoded)
return self.Decoder(s_conv).features

class CorrectNet(nn.Module):
def __init__(self, shape):
super().__init__()
self.Encoder = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=2, indice_key="cp1",algo=algo)
self.Sparse_Conv = spconv.SparseConv3d(channels, channels, kernel_size=3, stride=1, indice_key="cp2",algo=algo)
self.Sparse_Conv_Decoder = spconv.SparseInverseConv3d(channels, channels, kernel_size=3, indice_key="cp2",algo=algo)
self.Decoder = spconv.SparseInverseConv3d(channels, channels, kernel_size=3, indice_key="cp1",algo=algo)

def forward(self, sparse_tensor):
encoded = self.Encoder(sparse_tensor)
s_conv = self.Sparse_Conv(encoded)
return self.Decoder(self.Sparse_Conv_Decoder(s_conv)).features

```

The ```Sparse_Conv``` in ```ExampleNet``` Change spatial structure of output of ```Encoder```, so we can't inverse back to input of ```Encoder``` via ```Decoder```, we need to inverse from ```Sparse_Conv.output``` to ```Encoder.output``` via ```Sparse_Conv_Decoder```, then inverse from ```Encoder.output``` to ```Encoder.input``` via ```Decoder```.

### Sparse Add

In sematic segmentation network, we may use conv1x3, 3x1 and 3x3 in a block, but it's impossible to sum result from these layers because regular add requires same indices.
Expand Down
4 changes: 2 additions & 2 deletions example/fuse_bn_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def fuse_act_net(conv, act):
fused_conv = copy.deepcopy(conv)
if isinstance(act, torch.nn.ReLU):
fused_conv.act_type = tv.gemm.Activation.ReLU
if isinstance(act, torch.nn.Sigmoid):
elif isinstance(act, torch.nn.Sigmoid):
fused_conv.act_type = tv.gemm.Activation.Sigmoid
elif isinstance(act, torch.nn.LeakyReLU):
fused_conv.act_type = tv.gemm.Activation.LeakyReLU
Expand Down Expand Up @@ -385,7 +385,7 @@ def main():
torch.manual_seed(50051)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
with open(Path(__file__).parent.parent / "test" / "data" / "test_spconv.pkl", "rb") as f:
(voxels, coors, spatial_shape) = pickle.load(f)
np.random.seed(50051)
device = torch.device("cuda:0")
Expand Down
4 changes: 2 additions & 2 deletions spconv/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,8 @@ def get_all_available(self,
ldw = weight.dim(-1)
ldo = out.dim(-1)
mask_width_valid = True

if desp.op_type == ConvOpType.kBackwardWeight.value:
if desp.op_type.value == ConvOpType.kBackwardWeight.value:
assert mask_width > 0
mask_width_valid = mask_width % desp.tile_shape[2] == 0
if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
Expand Down
28 changes: 26 additions & 2 deletions spconv/pytorch/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def forward(self, input: SparseConvTensor):
indice_pairs = datas.indice_pairs
indice_pair_num = datas.indice_pair_num
out_spatial_shape = datas.spatial_shape
assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
self._check_inverse_reuse_valid(input, spatial_shape,
datas)
else:
if self.indice_key is not None and datas is not None:
outids = datas.out_indices
Expand Down Expand Up @@ -466,7 +467,10 @@ def forward(self, input: SparseConvTensor):
mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
masks = datas.masks
out_spatial_shape = datas.spatial_shape
assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
# assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"

self._check_inverse_reuse_valid(input, spatial_shape,
datas)
else:
if self.indice_key is not None and datas is not None:
outids = datas.out_indices
Expand Down Expand Up @@ -602,6 +606,26 @@ def _check_subm_reuse_valid(self, inp: SparseConvTensor,
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
)

def _check_inverse_reuse_valid(self, inp: SparseConvTensor,
spatial_shape: List[int],
datas: Union[ImplicitGemmIndiceData,
IndiceData]):
if self.kernel_size != datas.ksize:
raise ValueError(
f"Inverse with same indice_key must have same kernel"
f" size, expect {datas.ksize}, this layer {self.kernel_size}, "
"please check Inverse Convolution in docs/USAGE.md.")
if inp.spatial_shape != datas.out_spatial_shape:
raise ValueError(
f"Inverse with same indice_key must have same spatial structure (spatial shape)"
f", expect {datas.spatial_shape}, input {spatial_shape}, "
"please check Inverse Convolution in docs/USAGE.md.")
if inp.indices.shape[0] != datas.out_indices.shape[0]:
raise ValueError(
f"Inverse with same indice_key must have same num of indices"
f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}, "
"please check Inverse Convolution in ."
)

class SparseConv1d(SparseConvolution):
def __init__(self,
Expand Down
1 change: 1 addition & 0 deletions tools/install_windows_cuda.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ $CUDA_KNOWN_URLS = @{
"11.4" = "https://developer.download.nvidia.com/compute/cuda/11.4.2/network_installers/cuda_11.4.2_win10_network.exe";
"11.5" = "https://developer.download.nvidia.com/compute/cuda/11.5.0/network_installers/cuda_11.5.0_win10_network.exe";
"11.7" = "https://developer.download.nvidia.com/compute/cuda/11.7.1/network_installers/cuda_11.7.1_windows_network.exe";
"11.8" = "https://developer.download.nvidia.com/compute/cuda/11.8.0/network_installers/cuda_11.8.0_windows_network.exe";
}

# cuda_runtime.h is in nvcc <= 10.2, but cudart >= 11.0
Expand Down

0 comments on commit 24df06f

Please sign in to comment.