v2.2.3: fix contiguous, add msg if point vanish
FindDefinition committed Sep 28, 2022
1 parent 1661828 commit 8b52b3a
Showing 7 changed files with 94 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,9 @@
# Changelog
## [2.2.3] - 2022-9-28
### Fixed
- Fix missing `.contiguous()` call for input features.
- Add a debug message when points vanish.

## [2.2.2] - 2022-9-25
### Fixed
- Fix CI problem: main function too long and cause OOM in CI vm.
2 changes: 1 addition & 1 deletion docs/COMMON_PROBLEMS.md
@@ -27,7 +27,7 @@ Your coordinates generate nothing with some conv params. Modify your conv params
Example:

Conv Params:
```spatial shape=[8, 200, 200],ksize=[3, 3, 3],stride=[2, 2, 2],padding=[0, 1, 1],dilation=[1, 1, 1]```
```spatial shape=[8, 200, 200],ksize=3,stride=2,padding=[0, 1, 1],dilation=1```
Coordinates:
```
[[0, 7, 153, 142]]
```
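To check why that coordinate is dropped, here is a small standalone sketch (hypothetical helper, not part of spconv) that computes the output size per dimension and the highest input index any kernel window touches:

```
# standard conv output-size formula
def out_size(i, k, s, p, d):
    return (i + 2 * p - d * (k - 1) - 1) // s + 1

spatial_shape = [8, 200, 200]
ksize, stride, dilation = 3, 2, 1
padding = [0, 1, 1]

for dim, (i, p) in enumerate(zip(spatial_shape, padding)):
    o = out_size(i, ksize, stride, p, dilation)
    # highest input index covered by the last kernel window
    last = (o - 1) * stride - p + dilation * (ksize - 1)
    print(f"dim {dim}: out_size={o}, inputs covered up to index {last}")

# dim 0 (z): out_size=3, covered up to z == 6, so a point at z == 7 vanishes
```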
2 changes: 2 additions & 0 deletions example/libspconv/README.md
@@ -6,6 +6,8 @@ libspconv + pybindings = "core_cc.so" in spconv python package.

Run ```run_build.sh``` to build ```libspconv.so```.

See the [inference code example](main.cu).

## libspconv API

Currently not available, but you can read the Python code to understand how to use the C++ APIs; the spconv Python package and libspconv share the same C++ code.
46 changes: 45 additions & 1 deletion spconv/csrc/sparse/all.py
@@ -1856,7 +1856,28 @@ def get_indice_pairs_implicit_gemm(self):
}}
}}
// tv::ssprint("HASH SIZE", hash_size, num_act_out);
if (num_act_out == 0){{
std::stringstream ss;
ss << R"(Your points vanished here, this usually because you provide
conv params that may ignore some input points. Example:
spatial_shape=[8, 200, 200]
ksize=3
stride=2
padding=[0, 1, 1]
dilation=1
Coordinates=[[0, 7, 153, 142]]
these params will cause ALL points in z == 7 dropped because of padding_z=0.
enlarge your spatial shape or change your conv param to make sure
every input point has a corresponding output point.
Your Conv Params: )" << "\\n";
tv::sstream_print<'\\0'>(ss, " spatial_shape=", input_dims, "\\n");
tv::sstream_print<'\\0'>(ss, " ksize=", ksize, "\\n");
tv::sstream_print<'\\0'>(ss, " stride=", stride, "\\n");
tv::sstream_print<'\\0'>(ss, " padding=", padding, "\\n");
tv::sstream_print<'\\0'>(ss, " dilation=", dilation, "\\n");
tv::ssprint(ss.str());
throw std::runtime_error(ss.str());
}}
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
num_act_out = num_out_act_bound;
}}
@@ -2092,6 +2113,29 @@ def get_indice_pairs(self):
// TODO pytorch unique may be faster?
num_act_out = apply_thrust_unique_to_indice_pairs_uniq(indice_pairs_uniq, thrustalloc, stream_int);
if (num_act_out == 0){{
std::stringstream ss;
ss << R"(Your points vanished here, this usually because you provide
conv params that may ignore some input points. Example:
spatial_shape=[8, 200, 200]
ksize=3
stride=2
padding=[0, 1, 1]
dilation=1
Coordinates=[[0, 7, 153, 142]]
these params will cause ALL points in z == 7 dropped because of padding_z=0.
enlarge your spatial shape or change your conv param to make sure
every input point has a corresponding output point.
Your Conv Params: )" << "\\n";
tv::sstream_print<'\\0'>(ss, " spatial_shape=", input_dims, "\\n");
tv::sstream_print<'\\0'>(ss, " ksize=", ksize, "\\n");
tv::sstream_print<'\\0'>(ss, " stride=", stride, "\\n");
tv::sstream_print<'\\0'>(ss, " padding=", padding, "\\n");
tv::sstream_print<'\\0'>(ss, " dilation=", dilation, "\\n");
tv::ssprint(ss.str());
throw std::runtime_error(ss.str());
}}
bool use_bound_algo = false;
int64_t num_out_bounded = num_act_out;
if (num_out_act_bound > 0 && num_act_out > num_out_act_bound){{
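A note on the doubled braces in the generated code above: all.py emits this C++ source through Python format strings, so literal `{` and `}` must be written as `{{` and `}}`. A minimal sketch of that escaping convention (illustrative only, not spconv's actual codegen machinery):

```
template = """
if (num_act_out == 0){{
    throw std::runtime_error("{msg}");
}}
"""
# '{{' / '}}' render as literal braces; '{msg}' is substituted
print(template.format(msg="points vanished"))
```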
4 changes: 3 additions & 1 deletion spconv/pytorch/conv.py
@@ -327,7 +327,9 @@ def forward(self, input: SparseConvTensor):
out_tensor.spatial_shape = out_spatial_shape
return out_tensor
indice_dict = input.indice_dict.copy()

# only contiguous tensors are supported for now
if not features.is_contiguous():
features = features.contiguous()
algo = self.algo
if self.indice_key is not None:
datas = input.find_indice_pair(self.indice_key)
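For background on the `.contiguous()` fix above, a short PyTorch sketch (illustrative, not spconv code) of how a non-contiguous feature tensor can arise and what the copy does:

```
import torch

feats = torch.randn(1000, 64)
view = feats[:, ::2]        # strided slice: a view, not contiguous
assert not view.is_contiguous()

dense = view.contiguous()   # copies the data into a contiguous buffer
assert dense.is_contiguous()
```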
51 changes: 37 additions & 14 deletions spconv/pytorch/ops.py
@@ -51,6 +51,24 @@
DEBUG_INT64_HASH_K = False
INT32_MAX = SpconvOps.get_int32_max()

_POINT_VANISH_MSG = """Your points vanished here, this usually because you provide
conv params that may ignore some input points. Example:
spatial_shape=[8, 200, 200]
ksize=3
stride=2
padding=[0, 1, 1]
dilation=1
Coordinates=[[0, 7, 153, 142]]
these params will cause ALL points in z == 7 dropped because of padding_z=0.
enlarge your spatial shape or change your conv param to make sure
every input point has a corresponding output point.
Your Conv Params:
spatial_shape={}
ksize={}
stride={}
padding={}
dilation={}"""


def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
ndim = len(input_size)
@@ -239,6 +257,9 @@ def get_indice_pairs(indices: torch.Tensor,
stream_int=stream)
uniq_res = indice_pairs_uniq.unique()
num_act_out = uniq_res.shape[0] - 1
if num_act_out == 0:
msg = _POINT_VANISH_MSG.format(spatial_shape, ksize, stride, padding, dilation)
raise ValueError(msg)
use_bound_algo = False
if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
@@ -611,6 +632,9 @@ def get_indice_pairs_implicit_gemm(
num_act_out = uniq_res.shape[0] - 1
uniq_out_indices_offset_tv = torch_tensor_to_tv(uniq_res)
raw_out_indices_offset_tv = indice_pairs_uniq_tv
if num_act_out == 0:
msg = _POINT_VANISH_MSG.format(spatial_shape, ksize, stride, padding, dilation)
raise ValueError(msg)

if num_out_act_bound > 0 and num_act_out > num_out_act_bound:
num_act_out = num_out_act_bound
@@ -1080,6 +1104,15 @@ def indice_conv_backward(features: torch.Tensor,
algo: ConvAlgo = ConvAlgo.Native,
timer: CUDAKernelTimer = CUDAKernelTimer(False)):
# print(out_bp.mean(), out_bp.max(), out_bp.min())
filters_shape = filters.shape
# TODO handle this in nn.Module to make sure features in backward are contiguous
if not features.is_contiguous():
features = features.contiguous()
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
assert out_bp.is_contiguous()
assert filters.is_contiguous()
assert features.is_contiguous()

if SPCONV_CPP_GEMM and GEMM_CPP is not None:
alloc = TorchAllocator(features.device)
@@ -1110,16 +1143,6 @@
df = alloc.allocated[AllocKeys.DFilters]
return din, df

filters_shape = filters.shape
# TODO handle this in nn.Module to make sure features in backward is contiguous
if not features.is_contiguous():
features = features.contiguous()
if not out_bp.is_contiguous():
out_bp = out_bp.contiguous()
assert out_bp.is_contiguous()
assert filters.is_contiguous()
assert features.is_contiguous()

if not ALL_WEIGHT_IS_KRSC:
kv_dim = 0
is_KC_not_CK = not FILTER_HWIO
@@ -1438,6 +1461,10 @@ def implicit_gemm(features: torch.Tensor,
if bias is not None:
bias_tv = torch_tensor_to_tv(bias)

if not features.is_contiguous():
features = features.contiguous()
assert features.is_contiguous()
assert filters.is_contiguous()

if SPCONV_CPP_GEMM and CONV_CPP is not None:
alloc = TorchAllocator(features.device)
@@ -1477,10 +1504,6 @@
# CONV.stream_synchronize(stream)

# t = time.time()
if not features.is_contiguous():
features = features.contiguous()
assert features.is_contiguous()
assert filters.is_contiguous()

if features.dtype == torch.int8 or features.dtype == torch.qint8:
raise NotImplementedError("work in progress")
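With these checks in place, conv params that drop every active point fail fast with the message above instead of silently returning an empty tensor. A hedged repro sketch, assuming a CUDA build of spconv v2 and the public spconv.pytorch API:

```
import torch
import spconv.pytorch as spconv

indices = torch.tensor([[0, 7, 153, 142]], dtype=torch.int32).cuda()
features = torch.randn(1, 16).cuda()
x = spconv.SparseConvTensor(features, indices,
                            spatial_shape=[8, 200, 200], batch_size=1)

conv = spconv.SparseConv3d(16, 32, kernel_size=3, stride=2,
                           padding=(0, 1, 1)).cuda()
out = conv(x)  # expected: ValueError with the point-vanish message
```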
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
2.2.2
2.2.3
