From bf011c76039fdcd89debcbbfc6e157e0068aa41d Mon Sep 17 00:00:00 2001
From: "yan.yan"
Date: Tue, 23 Nov 2021 22:44:48 +0800
Subject: [PATCH] temp commit

---
 example/voxel_gen.py                         | 120 ++-
 scripts/dev_subm.py                          | 409 ---
 spconv/algo.py                               |   6 +-
 spconv/benchmark/__init__.py                 |   0
 spconv/benchmark/__main__.py                 |   6 +
 spconv/benchmark/basic.py                    | 199 +++++
 spconv/benchmark/core.py                     |  40 +
 spconv/core.py                               |   6 +-
 spconv/core_cc/csrc/sparse/all/__init__.pyi  |   6 +-
 spconv/core_cc/csrc/sparse/all/ops1d.pyi     |   3 +-
 spconv/core_cc/csrc/sparse/all/ops2d.pyi     |   3 +-
 spconv/core_cc/csrc/sparse/all/ops3d.pyi     |   3 +-
 spconv/core_cc/csrc/sparse/all/ops4d.pyi     |   3 +-
 spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi |   6 +-
 spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi |   6 +-
 spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi |   6 +-
 spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi |   6 +-
 spconv/csrc/sparse/all.py                    |  10 +-
 spconv/csrc/sparse/pointops.py               |  90 +-
 spconv/pytorch/constants.py                  |  12 +
 spconv/pytorch/conv.py                       |  18 +-
 spconv/pytorch/core.py                       |   9 +-
 spconv/pytorch/cppcore.py                    |  18 +-
 spconv/pytorch/modules.py                    |   1 +
 spconv/pytorch/ops.py                        |  11 +-
 spconv/pytorch/utils.py                      |  75 +-
 test/benchmark.py                            |  40 +-
 test/test_all_algo.py                        | 663 +++++++++++++++
 test/test_conv.py                            | 848 ++++++------------
 test/test_implgemm.py                        |  15 -
 test/test_multi_impl.py                      | 327 ++++++-
 test/test_native_kernels.py                  |  14 -
 test_before_push.sh                          |  10 +
 version.txt                                  |   2 +-
 34 files changed, 1793 insertions(+), 1198 deletions(-)
 delete mode 100644 scripts/dev_subm.py
 create mode 100644 spconv/benchmark/__init__.py
 create mode 100644 spconv/benchmark/__main__.py
 create mode 100644 spconv/benchmark/basic.py
 create mode 100644 spconv/benchmark/core.py
 create mode 100644 test/test_all_algo.py
 delete mode 100644 test/test_implgemm.py
 delete mode 100644 test/test_native_kernels.py
 create mode 100644 test_before_push.sh

diff --git a/example/voxel_gen.py b/example/voxel_gen.py
index 4b1dd57..504acaf 100644
--- a/example/voxel_gen.py
+++ b/example/voxel_gen.py
@@ -19,6 +19,67 @@
 from spconv.pytorch.utils import PointToVoxel
 import torch
 
 
+def main_pytorch_voxel_gen():
+    np.random.seed(50051)
+    # voxel gen source code: spconv/csrc/sparse/pointops.py
+    gen = PointToVoxel(vsize_xyz=[0.1, 0.1, 0.1],
+                       coors_range_xyz=[-80, -80, -6, 80, 80, 6],
+                       num_point_features=3,
+                       max_num_voxels=5000,
+                       max_num_points_per_voxel=5)
+
+    pc = np.random.uniform(-4, 4, size=[1000, 3])
+    pc_th = torch.from_numpy(pc)
+    voxels_th, indices_th, num_p_in_vx_th = gen(pc_th)
+    voxels_np = voxels_th.numpy()
+    indices_np = indices_th.numpy()
+    num_p_in_vx_np = num_p_in_vx_th.numpy()
+    print(f"------Raw Voxels {voxels_np.shape[0]}-------")
+    print(voxels_np[0])
+    # run voxel gen again, filling the empty point slots of each voxel with that voxel's mean
+    voxels_th, indices_th, num_p_in_vx_th = gen(pc_th, empty_mean=True)
+    voxels_np = voxels_th.numpy()
+    indices_np = indices_th.numpy()
+    num_p_in_vx_np = num_p_in_vx_th.numpy()
+    print("------Voxels with mean filled-------")
+    print(voxels_np[0])
+    voxels_th, indices_th, num_p_in_vx_th, pc_voxel_id = gen.generate_voxel_with_id(pc_th, empty_mean=True)
+    print("------Voxel ids for every point-------")
+    print(pc_voxel_id[:10])
+
+
+def main_pytorch_voxel_gen_cuda():
+    np.random.seed(50051)
+    # voxel gen source code: spconv/csrc/sparse/pointops.py
+    device = torch.device("cuda:0")
+    gen = PointToVoxel(vsize_xyz=[0.1, 0.1, 0.1],
+                       coors_range_xyz=[-80, -80, -6, 80, 80, 6],
+                       num_point_features=3,
+                       max_num_voxels=5000,
+                       max_num_points_per_voxel=5,
+                       device=device)
+
+    pc = np.random.uniform(-4, 4,
+                           size=[1000, 3]).astype(np.float32)
+    pc_th = torch.from_numpy(pc).to(device)
+    voxels_th, indices_th, num_p_in_vx_th = gen(pc_th)
+    voxels_np = voxels_th.cpu().numpy()
+    indices_np = indices_th.cpu().numpy()
+    num_p_in_vx_np = num_p_in_vx_th.cpu().numpy()
+    print(f"------Raw Voxels {voxels_np.shape[0]}-------")
+    print(voxels_np[0])
+    # run voxel gen again, filling the empty point slots of each voxel with that voxel's mean
+    voxels_tv, indices_tv, num_p_in_vx_tv = gen(pc_th, empty_mean=True)
+    voxels_np = voxels_tv.cpu().numpy()
+    indices_np = indices_tv.cpu().numpy()
+    num_p_in_vx_np = num_p_in_vx_tv.cpu().numpy()
+    print("------Voxels with mean filled-------")
+    print(voxels_np[0])
+    voxels_th, indices_th, num_p_in_vx_th, pc_voxel_id = gen.generate_voxel_with_id(pc_th, empty_mean=True)
+    print("------Voxel ids for every point-------")
+    print(pc[:10])
+    print(indices_th[pc_voxel_id[:10]])
+
 
 def main():
     np.random.seed(50051)
@@ -81,58 +142,26 @@ def main_point_with_features():
     print("------Voxels with mean filled-------")
     print(voxels_np[0])
 
-
-def main_pytorch_voxel_gen():
+def main_cuda():
     np.random.seed(50051)
-    # voxel gen source code: spconv/csrc/sparse/pointops.py
-    gen = PointToVoxel(vsize_xyz=[0.1, 0.1, 0.1],
-                       coors_range_xyz=[-80, -80, -2, 80, 80, 6],
-                       num_point_features=3,
-                       max_num_voxels=5000,
-                       max_num_points_per_voxel=5)
-
-    pc = np.random.uniform(-10, 10, size=[1000, 3])
-    pc_th = torch.from_numpy(pc)
-    voxels_th, indices_th, num_p_in_vx_th = gen(pc_th)
-    voxels_np = voxels_th.numpy()
-    indices_np = indices_th.numpy()
-    num_p_in_vx_np = num_p_in_vx_th.numpy()
-    print(f"------Raw Voxels {voxels_np.shape[0]}-------")
-    print(voxels_np[0])
-    # run voxel gen and FILL MEAN VALUE to voxel remain
-    voxels_tv, indices_tv, num_p_in_vx_tv = gen(pc_th, empty_mean=True)
-    voxels_np = voxels_tv.numpy()
-    indices_np = indices_tv.numpy()
-    num_p_in_vx_np = num_p_in_vx_tv.numpy()
-    print("------Voxels with mean filled-------")
-    print(voxels_np[0])
-
+    from spconv.utils import Point2VoxelGPU3d
-def main_pytorch_voxel_gen_cuda():
-    np.random.seed(50051)
     # voxel gen source code: spconv/csrc/sparse/pointops.py
-    device = torch.device("cuda:0")
-    gen = PointToVoxel(vsize_xyz=[0.1, 0.1, 0.1],
-                       coors_range_xyz=[-80, -80, -2, 80, 80, 6],
-                       num_point_features=3,
-                       max_num_voxels=5000,
-                       max_num_points_per_voxel=5,
-                       device=device)
+    gen = Point2VoxelGPU3d(vsize_xyz=[0.1, 0.1, 0.1],
+                           coors_range_xyz=[-80, -80, -2, 80, 80, 6],
+                           num_point_features=3,
+                           max_num_voxels=5000,
+                           max_num_points_per_voxel=5)
 
-    pc = np.random.uniform(-10, 10, size=[1000, 3]).astype(np.float32)
-    pc_th = torch.from_numpy(pc).to(device)
-    voxels_th, indices_th, num_p_in_vx_th = gen(pc_th)
-    voxels_np = voxels_th.cpu().numpy()
-    indices_np = indices_th.cpu().numpy()
-    num_p_in_vx_np = num_p_in_vx_th.cpu().numpy()
-    print(f"------Raw Voxels {voxels_np.shape[0]}-------")
-    print(voxels_np[0])
-    # run voxel gen and FILL MEAN VALUE to voxel remain
-    voxels_tv, indices_tv, num_p_in_vx_tv = gen(pc_th, empty_mean=True)
+    pc = np.random.uniform(-10, 10, size=[100000, 3]).astype(np.float32)
+    pc_tv = tv.from_numpy(pc).cuda()
+    # generate voxels; note that voxels_tv references a persistent buffer in
+    # the generator, so we can't run this from multiple threads.
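+    # (illustrative note, not in the original example: the .cpu().numpy()
+    # calls below copy the data to host, so those copies stay valid; only
+    # voxels_tv itself is invalidated by the next point_to_voxel_hash call.)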
+ voxels_tv, indices_tv, num_p_in_vx_tv = gen.point_to_voxel_hash(pc_tv) voxels_np = voxels_tv.cpu().numpy() indices_np = indices_tv.cpu().numpy() num_p_in_vx_np = num_p_in_vx_tv.cpu().numpy() - print("------Voxels with mean filled-------") + print(f"------CUDA Raw Voxels {voxels_np.shape[0]}-------") print(voxels_np[0]) @@ -141,4 +170,5 @@ def main_pytorch_voxel_gen_cuda(): main_point_with_features() main_pytorch_voxel_gen() if torch.cuda.is_available(): + main_cuda() main_pytorch_voxel_gen_cuda() diff --git a/scripts/dev_subm.py b/scripts/dev_subm.py deleted file mode 100644 index 96c0b43..0000000 --- a/scripts/dev_subm.py +++ /dev/null @@ -1,409 +0,0 @@ -import sys -from pathlib import Path -from typing import Dict, List, Tuple -import pickle -import sys -import time -from pathlib import Path -from cumm.gemm.algospec.core import GemmAlgo - -import numpy as np -import pccm -import torch -import torch.nn.functional as F - -from cumm import dtypes -from cumm import tensorview as tv -from cumm.constants import PACKAGE_ROOT -from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType -from cumm.conv.main import ConvMainUnitTest, gen_gemm_kernels -from cumm.conv.params import ConvProblem -from cumm.gemm import kernel -import os -from spconv.core_cc.csrc.sparse.all import SpconvOps -from cumm.gemm.codeops import div_up -from spconv.constants import PACKAGE_ROOT -from spconv.core import ConvAlgo - -from spconv.pytorch import ops -from spconv.algo import CONV, BestConvAlgoByProfile -from spconv.pytorch.cppcore import torch_tensor_to_tv - - -def reduce_mask_count(mask: np.ndarray, width: int): - mask_length_32 = (div_up(mask.shape[0], width)) * width - if mask.shape[0] < mask_length_32: - mask_pad = np.zeros((mask_length_32, ), dtype=mask.dtype) - mask_pad[:mask.shape[0]] = mask - mask = mask_pad - mask = mask.reshape(-1, width) - maskr = np.bitwise_or.reduce(mask, axis=1) - maskr_tv = tv.from_numpy(maskr) - return SpconvOps.count_bits(maskr_tv).numpy().sum() * width - - -def reduce_mask_count_x(mask: np.ndarray, width: int): - mask_length_32 = (div_up(mask.shape[0], width)) * width - if mask.shape[0] < mask_length_32: - mask_pad = np.zeros((mask_length_32, ), dtype=mask.dtype) - mask_pad[:mask.shape[0]] = mask - mask = mask_pad - mask = mask.reshape(-1, width) - maskr = np.bitwise_or.reduce(mask, axis=1) - return maskr - - -def dev_subm_inds_v2(subm: bool = True, run_conv: bool = True): - limit_input_n = 16384 - limit_input_n = None - np.random.seed(484) - - with (PACKAGE_ROOT.parent / "test/data/test_spconv.pkl").open("rb") as f: - voxels_np, indices_np, spatial_shape = pickle.load(f) - from spconv.test_utils import generate_sparse_data - voxels_np = voxels_np[:limit_input_n] - indices_np = indices_np[:limit_input_n] - - # spatial_shape = [19, 18, 17] - # sparse_dict = generate_sparse_data(spatial_shape, [1024], 128) - - # voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype( - # np.float32) - # indices_np = np.ascontiguousarray( - # sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) - - voxels = tv.from_numpy(voxels_np).cuda() - indices = tv.from_numpy(indices_np).cuda() - indices_th = torch.from_numpy(indices_np).cuda() - print(spatial_shape, indices_np.shape) - ndim = 3 - if subm: - ksize = [3, 3, 3] - kv = np.prod(ksize) - padding = [1] * ndim - stride = [1] * ndim - dilation = [1] * ndim - out_padding = [0] * ndim - else: - ksize = [2, 2, 2] - kv = np.prod(ksize) - padding = [0] * ndim - stride = [1] * ndim - dilation = [1] * ndim - out_padding = [0] * ndim - 
out_inds, pair_ref, indice_num_per_loc = ops.get_indice_pairs( - indices_th, 1, spatial_shape, ConvAlgo.Native, ksize, stride, padding, - dilation, out_padding, subm) - indice_num_per_loc_np = indice_num_per_loc.cpu().numpy() - indice_pairs_np = pair_ref.cpu().numpy() - algo = ConvAlgo.MaskImplicitGemm - if algo == ConvAlgo.MaskImplicitGemm: - num_split = 1 - else: - num_split = 2 - for i in range(5): - res = ops.get_indice_pairs_implicit_gemm(indices_th, 1, spatial_shape, - algo, ksize, stride, padding, - dilation, out_padding, subm) - out_inds = res[0] - num_inds_per_loc = res[1] - pair_fwd = res[2] - pair_fwd_x = pair_fwd.cpu().numpy().reshape(-1) - pair_fwd_x[pair_fwd_x == -1] = 0 - loc_num_np = (pair_fwd_x > 0).reshape(kv, -1).sum(1) - print(loc_num_np) - print(indice_num_per_loc_np) - - pair_bwd = res[3] - pair_mask_fwd_splits = res[4] - - pair_mask_bwd_splits = res[5] - mask_tv = torch_tensor_to_tv(pair_mask_fwd_splits[0], dtype=tv.uint32).cpu().numpy() - bench_reduce_mask(mask_tv) - return - - mask_argsort_fwd_splits = res[6] - mask_argsort_bwd_splits = res[7] - masks = res[8] - pair_mask_fwd_splits_tv = [ - ops.torch_tensor_to_tv(t, dtype=tv.uint32) - for t in pair_mask_fwd_splits - ] - valid_location_bitcount = [ - SpconvOps.count_bits(t) for t in pair_mask_fwd_splits_tv - ] - valid_location_count = sum( - [t.cpu().numpy().sum() for t in valid_location_bitcount]) - reduce_length = 32 - split_mask_valid_count = sum([ - reduce_mask_count(t.cpu().numpy(), reduce_length) - for t in pair_mask_fwd_splits_tv - ]) - if subm: - print("SUBM", valid_location_count, split_mask_valid_count, - pair_fwd.numel()) - else: - print("REGULAR", valid_location_count, split_mask_valid_count, - pair_fwd.numel()) - # return - - if run_conv: - C = 64 - K = 64 - desps = CONV.desps - mask_output_fwd = torch.zeros([2, div_up(out_inds.shape[0], 32)], - dtype=torch.int32, - device=indices_th.device) - mask_output_bwd = torch.zeros([2, div_up(indices.dim(0), 32)], - dtype=torch.int32, - device=indices_th.device) - - for desp in desps: - if desp.algo != GemmAlgo.Simt.value: - continue - # if desp.op_type == ConvOpType.kBackwardWeight.value: - # continue - # if desp.tile_shape ! 
- if desp.dtype_a == dtypes.int8.tv_dtype: - inp = np.random.randint(-1, 1, size=[voxels_np.shape[0], - C]).astype(np.int8) - weight = np.random.randint(-1, 1, size=[K, *ksize, - C]).astype(np.int8) - output = np.random.randint(-1, 1, size=[ - out_inds.shape[0], K - ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output)) - else: - inp = np.random.uniform(-1, 1, size=[ - voxels_np.shape[0], C - ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_input)) - weight = np.random.uniform(-1, 1, size=[K, *ksize, C]).astype( - dtypes.get_npdtype_from_tvdtype(desp.dtype_weight)) - output = np.random.uniform(-1, 1, size=[ - out_inds.shape[0], K - ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output)) - weight_ref = weight.transpose(1, 2, 3, 0, 4) - weight_ref = np.ascontiguousarray(weight_ref).reshape(-1, K, C) - if desp.op_type == ConvOpType.kBackwardInput.value: - inp_tv = tv.zeros(inp.shape, desp.dtype_input, 0) - else: - inp_tv = tv.from_numpy(inp).cuda() - if desp.op_type == ConvOpType.kBackwardWeight.value: - weight_tv = tv.zeros(weight.shape, desp.dtype_weight, 0) - else: - weight_tv = tv.from_numpy(weight).cuda() - # _ = tv.zeros([5000, 10], tv.float32, 0) - if desp.op_type == ConvOpType.kForward.value: - output_tv = tv.zeros(output.shape, desp.dtype_output, 0) - else: - output_tv = tv.from_numpy(output).cuda() - torch.cuda.synchronize() - t = time.time() - spk = 1 - if desp.op_type == ConvOpType.kBackwardWeight.value: - # TODO support splitk parallel - spk = 32 - if subm: - if desp.op_type == ConvOpType.kForward.value: - indice_pairs = pair_fwd - elif desp.op_type == ConvOpType.kBackwardInput.value: - indice_pairs = pair_bwd - else: - indice_pairs = pair_fwd - mask_output = mask_output_fwd - # print([bin(x.item()) for x in masks]) - for j in range(num_split): - beta = 1 if j == 1 else 0 - mask_filter = 0xffffffff - mask_filter = masks[j].item() - - reverse_mask = False - if desp.op_type == ConvOpType.kBackwardWeight.value: - mask_op = mask_output[j] - else: - mask_op = pair_mask_fwd_splits[j] - if desp.op_type == ConvOpType.kBackwardInput.value: - reverse_mask = True - CONV.run_with_tuned_result( - BestConvAlgoByProfile(desp, spk), - desp.op_type, - inp_tv, - weight_tv, - output_tv, - torch_tensor_to_tv(mask_op, dtype=tv.uint32), - torch_tensor_to_tv(mask_argsort_fwd_splits[j]), - torch_tensor_to_tv(mask_output[j], dtype=tv.uint32), - torch_tensor_to_tv(indice_pairs), - reverse_mask, - mask_filter=mask_filter, - mask_width=32, - beta=beta, - verbose=True, - ) - else: - if desp.op_type == ConvOpType.kForward.value: - indice_pairs = pair_fwd # inp -> out - mask_ops = pair_mask_fwd_splits - mask_argsorts = mask_argsort_fwd_splits - mask_output = mask_output_fwd - elif desp.op_type == ConvOpType.kBackwardInput.value: - indice_pairs = pair_bwd # out -> inp - mask_ops = pair_mask_bwd_splits - mask_argsorts = mask_argsort_bwd_splits - mask_output = mask_output_bwd - - print([bin(x.item()) for x in masks]) - else: - indice_pairs = pair_fwd # inp -> out - mask_ops = pair_mask_fwd_splits - mask_argsorts = mask_argsort_fwd_splits - mask_output = mask_output_fwd - - for j in range(2): - beta = 1 if j == 1 else 0 - mask_filter = masks[j].item() - reverse_mask = False - if desp.op_type == ConvOpType.kBackwardWeight.value: - mask_op = mask_output[j] - else: - mask_op = mask_ops[j] - - CONV.run_with_tuned_result( - BestConvAlgoByProfile(desp, spk), - desp.op_type, - inp_tv, - weight_tv, - output_tv, - torch_tensor_to_tv(mask_op, dtype=tv.uint32), - torch_tensor_to_tv(mask_argsorts[j]), - 
torch_tensor_to_tv(mask_output[j], dtype=tv.uint32), - torch_tensor_to_tv(indice_pairs), - reverse_mask, - mask_filter=mask_filter, - mask_width=32, - beta=beta, - verbose=True, - ) - - torch.cuda.synchronize() - duration = time.time() - t - if desp.op_type == ConvOpType.kForward.value: - output_ref = np.zeros_like(output, dtype=np.float32) - # ref algorithm - for filter_offset in range(kv): - if subm and filter_offset > kv // 2: - nhot = indice_num_per_loc_np[kv - 1 - filter_offset] - elif subm and filter_offset == kv // 2: - nhot = voxels.shape[0] - else: - nhot = indice_num_per_loc_np[filter_offset] - a_inds = indice_pairs_np[0][filter_offset][:nhot] - c_inds = indice_pairs_np[1][filter_offset][:nhot] - # print(a_inds_cpu[:10]) - a = inp[a_inds] - cc = a.astype( - np.float32) @ weight_ref[filter_offset].T.astype( - np.float32) - output_ref[c_inds] += cc - - output_cpu = output_tv.cpu().numpy().astype(np.float32) - duration = time.time() - t - my = output_cpu.reshape(-1) - print("ERROR", np.linalg.norm(output_ref.reshape(-1) - my)) - - elif desp.op_type == ConvOpType.kBackwardInput.value: - dinput_ref = np.zeros_like(inp, dtype=np.float32) - # ref algorithm - for filter_offset in range(kv): - if subm and filter_offset > kv // 2: - nhot = indice_num_per_loc_np[kv - 1 - filter_offset] - elif subm and filter_offset == kv // 2: - nhot = voxels.shape[0] - else: - nhot = indice_num_per_loc_np[filter_offset] - a_inds = indice_pairs_np[1][filter_offset][:nhot] - c_inds = indice_pairs_np[0][filter_offset][:nhot] - - # print(a_inds_cpu[:10]) - a = output[a_inds] - # NK @ KC - cc = a.astype( - np.float32) @ weight_ref[filter_offset].astype( - np.float32) - dinput_ref[c_inds] += cc - din_cpu = inp_tv.cpu().numpy() - print( - "ERROR", - np.linalg.norm( - din_cpu.reshape(-1) - dinput_ref.reshape(-1))) - else: - dw_ref = np.zeros_like(weight_ref, - dtype=np.float32) # KV, K, C - for filter_offset in range(kv): - if subm and filter_offset > kv // 2: - nhot = indice_num_per_loc_np[kv - 1 - filter_offset] - elif subm and filter_offset == kv // 2: - nhot = voxels.shape[0] - else: - nhot = indice_num_per_loc_np[filter_offset] - o_inds = indice_pairs_np[1][filter_offset][:nhot] - i_inds = indice_pairs_np[0][filter_offset][:nhot] - # print(a_inds_cpu[:10]) - out_gather = output[o_inds] # [N, K] - inp_gather = inp[i_inds] # [N, C] - # KN @ NC - dw_res = out_gather.astype( - np.float32).T @ inp_gather.astype(np.float32) - dw_ref[filter_offset] = dw_res - # print(indice_pairs_np_test[0]) - dw_ref_kcrs = dw_ref.transpose(1, 0, 2) - dw_cpu = weight_tv.cpu().numpy().reshape(K, np.prod(ksize), C) - - print( - "ERROR", - np.linalg.norm( - dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1))) - -def reverse_bits(a: np.ndarray): - a_unpack = np.unpackbits(a, bitorder="little") - return np.packbits(a_unpack) - -def _count_mask_reduce(masks: np.ndarray): - masks_tv_count = SpconvOps.count_bits(tv.from_numpy(masks)) - masks_tv_count_sum = masks_tv_count.numpy_view().sum() - - reduce_count = reduce_mask_count(masks, 64) - print(masks_tv_count_sum, reduce_count, reduce_count / masks_tv_count_sum) - - -def bench_reduce_mask(masks: np.ndarray, width: int = 27): - # masks = np.random.randint(0, 2000000000, size=[100000], dtype=np.uint32)# & 0xffff - width_mask = np.array(0xffffffff, dtype=np.uint32) << (32 - width) >> (32 - width) - - width_half_mask = np.array(0xffffffff, dtype=np.uint32) >> (32 - width // 2 - 1) - width_half_mask_left = width_half_mask << (width // 2 + 1) - print(bin(width_half_mask)) - masks_sort = masks.copy() - 
masks_sort.sort() - _count_mask_reduce(masks_sort) - masks_sort = masks.copy() & width_half_mask - masks_sort.sort() - _count_mask_reduce(masks_sort) - - # masks.sort() - # masks = masks & 0xffff - - reversed_masks = SpconvOps.reverse_bits(tv.from_numpy(masks)).numpy()# & 0xffff0000 - new_masks = np.concatenate([masks, reversed_masks]) - - np.random.shuffle(new_masks) - new_masks.sort() - _count_mask_reduce(new_masks) - new_masks &= width_half_mask - new_masks.sort() - _count_mask_reduce(new_masks) - - - - -if __name__ == "__main__": - dev_subm_inds_v2() diff --git a/spconv/algo.py b/spconv/algo.py index 7e6c017..403cbcc 100644 --- a/spconv/algo.py +++ b/spconv/algo.py @@ -131,9 +131,9 @@ def get_all_available( # skip volta tensor op since it is very slow in architectures except volta. if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value: continue - lda = a.dim(1) - ldb = b.dim(1) - ldc = c.dim(1) + lda = a.stride[0] + ldb = b.stride[0] + ldc = c.stride[0] if desp.supported_ldx(lda, ldb, ldc): finally_algos.append(desp) return finally_algos diff --git a/spconv/benchmark/__init__.py b/spconv/benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/spconv/benchmark/__main__.py b/spconv/benchmark/__main__.py new file mode 100644 index 0000000..1474cf3 --- /dev/null +++ b/spconv/benchmark/__main__.py @@ -0,0 +1,6 @@ +from .basic import bench_basic + +import fire + +if __name__ == "__main__": + fire.Fire() diff --git a/spconv/benchmark/basic.py b/spconv/benchmark/basic.py new file mode 100644 index 0000000..9aced6b --- /dev/null +++ b/spconv/benchmark/basic.py @@ -0,0 +1,199 @@ +from spconv.benchmark.core import get_voxel_data + + +import time +from pathlib import Path + +import numpy as np +import torch +from torch import nn +from cumm import tensorview as tv +from spconv.core import ConvAlgo +from cumm import dtypes +import spconv.pytorch as spconv +from spconv.test_utils import params_grid + +class Net(nn.Module): + def __init__(self, shape, algo): + super().__init__() + pool_algo = algo + # pool_algo = ConvAlgo.Native + self.net = spconv.SparseSequential( + spconv.SubMConv3d(3, 64, 3, bias=False, indice_key="c0", + algo=algo), + + spconv.SubMConv3d(64, + 64, + 3, + bias=False, + indice_key="c0", + algo=algo), + # nn.BatchNorm1d(32), + # nn.ReLU(), + # spconv.SparseConv3d(64, 64, 2, 2, bias=False, indice_key="m0"), + spconv.SparseMaxPool3d(2, 2, algo=pool_algo), + spconv.SubMConv3d(64, + 96, + 3, + bias=False, + indice_key="c1", + algo=algo), + spconv.SubMConv3d(96, + 96, + 3, + bias=False, + indice_key="c1", + algo=algo), + # nn.BatchNorm1d(64), + # nn.ReLU(), + # spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1"), + spconv.SparseMaxPool3d(2, 2, algo=pool_algo), + spconv.SubMConv3d(96, + 128, + 3, + bias=False, + indice_key="c2", + algo=algo), + spconv.SubMConv3d(128, + 128, + 3, + bias=False, + indice_key="c2", + algo=algo), + # nn.BatchNorm1d(128), + # nn.ReLU(), + # spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"), + spconv.SparseMaxPool3d(2, 2, algo=pool_algo), + spconv.SubMConv3d(128, + 160, + 3, + bias=False, + indice_key="c3", + algo=algo), + spconv.SubMConv3d(160, + 160, + 3, + bias=False, + indice_key="c3", + algo=algo), + # nn.BatchNorm1d(128), + # nn.ReLU(), + # spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"), + spconv.SparseMaxPool3d(2, 2, algo=pool_algo), + spconv.SubMConv3d(160, + 192, + 3, + bias=False, + indice_key="c4", + algo=algo), + spconv.SubMConv3d(192, + 192, + 3, + bias=False, + indice_key="c4", + 
algo=algo), + # nn.BatchNorm1d(128), + # nn.ReLU(), + spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo), + # spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"), + spconv.SubMConv3d(192, + 224, + 3, + bias=False, + indice_key="c5", + algo=algo), + spconv.SubMConv3d(224, + 224, + 3, + bias=False, + indice_key="c5", + algo=algo), + # nn.BatchNorm1d(224), + # nn.ReLU(), + # spconv.SparseConv3d(224, 224, 2, 2, bias=False, indice_key="m5"), + spconv.SparseMaxPool3d(2, 2, indice_key="m5", algo=pool_algo), + spconv.SubMConv3d(224, + 256, + 3, + bias=False, + indice_key="c6", + algo=algo), + spconv.SubMConv3d(256, + 256, + 3, + bias=False, + indice_key="c6", + algo=algo), + + # nn.BatchNorm1d(256), + # nn.ReLU(), + + # spconv.SparseInverseConv3d(256, 128, 2, indice_key="m5", bias=False, algo=algo), + # # # nn.BatchNorm1d(128), + # # # nn.ReLU(), + + # spconv.SparseInverseConv3d(128, 64, 2, indice_key="m4", bias=False, algo=algo), + ) + max_batch_size = 1 + # grid (dense map) is used for indice generation. use pre-allocated grid can run faster. + self.grid = torch.full([max_batch_size, *shape], -1, + dtype=torch.int32).cuda() + # self.grid = None + self.shape = shape + + def forward(self, features, coors, batch_size, enable_timer: bool = False): + x = spconv.SparseConvTensor(features, + coors, + self.shape, + batch_size, + self.grid, + enable_timer=enable_timer) + return self.net(x) + +_DTYPE_TO_TORCH_DTYPE = { + dtypes.float32: torch.float32, + dtypes.float16: torch.float16, +} + +def bench_basic(dtype_str: str): + dtype = dtypes.get_dtype_by_shortcut(dtype_str) + if dtype not in _DTYPE_TO_TORCH_DTYPE: + raise NotImplementedError("only support bench f32 and f16 for now") + torch_dtype = _DTYPE_TO_TORCH_DTYPE[dtype] + algos = [spconv.ConvAlgo.Native, spconv.ConvAlgo.MaskImplicitGemm, spconv.ConvAlgo.MaskSplitImplicitGemm] + (voxels, coors, spatial_shape) = get_voxel_data() + device = torch.device("cuda:0") + + for algo, in params_grid(algos): + voxels_th = torch.from_numpy(voxels).to(device).to(torch_dtype) + coors_th = torch.from_numpy(coors).to(device).int() + voxels_th.requires_grad = True + net = Net(spatial_shape, algo).to(device).train().to(torch_dtype)# .train() + spconv.assign_name_for_sparse_modules(net) + with torch.no_grad(): + out: spconv.SparseConvTensor = net(voxels_th, coors_th, 1) + dout = np.random.uniform(-0.2, 0.2, out.features.shape).astype(np.float32) + dout_t = torch.from_numpy(dout).to(device).to(torch_dtype) + times = [] + with torch.no_grad(): + for i in range(20): + torch.cuda.synchronize() + t = time.time() + out_nograd = net(voxels_th, coors_th, 1, False) + timer = out_nograd._timer + torch.cuda.synchronize() + times.append(time.time() - t) + print(f"basic[{dtype_str}|{algo}|forward]", np.mean(times[10:])) + times = [] + + for i in range(10): + out = net(voxels_th, coors_th, 1) + torch.cuda.synchronize() + t = time.time() + out.features.backward(dout_t) + torch.cuda.synchronize() + times.append(time.time() - t) + print(f"basic[{dtype_str}|{algo}|backward]", np.mean(times[5:])) + +if __name__ == "__main__": + bench_basic("f16") \ No newline at end of file diff --git a/spconv/benchmark/core.py b/spconv/benchmark/core.py new file mode 100644 index 0000000..16a5d03 --- /dev/null +++ b/spconv/benchmark/core.py @@ -0,0 +1,40 @@ +import requests +import fire +import pickle +from io import BytesIO +import numpy as np +from spconv.constants import PACKAGE_ROOT + +RAW_PC_PATH = 
"https://raw.githubusercontent.com/traveller59/spconv/v2.1.10/test/data/test_spconv.pkl" + +def get_voxel_data(): + editable_test_data_path = PACKAGE_ROOT.parent / "test/data/test_spconv.pkl" + if editable_test_data_path.exists(): + with editable_test_data_path.open("rb") as f: + return pickle.load(f) + ff = BytesIO() + with requests.get(RAW_PC_PATH, stream=True) as req: + req.raise_for_status() + for chunk in req.iter_content(chunk_size=8192): + ff.write(chunk) + ff.seek(0) + (voxels, coors, spatial_shape) = pickle.load(ff) + return voxels, coors, spatial_shape + +def get_pc_data(): + editable_test_data_path = PACKAGE_ROOT.parent / "test/data/benchmark-pc.npz" + if editable_test_data_path.exists(): + pc = np.load(str(editable_test_data_path))["pc"] + return pc + ff = BytesIO() + with requests.get(RAW_PC_PATH, stream=True) as req: + req.raise_for_status() + for chunk in req.iter_content(chunk_size=8192): + ff.write(chunk) + ff.seek(0) + pc = np.load(ff)["pc"] + return pc + +if __name__ == "__main__": + pc = get_pc_data() + print(pc[:10]) \ No newline at end of file diff --git a/spconv/core.py b/spconv/core.py index d8a45a1..d3ce7d9 100644 --- a/spconv/core.py +++ b/spconv/core.py @@ -452,7 +452,7 @@ class AlgoHint(Enum): *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, - 2, ["f16,f16,f16,f16,f16"], + 2, ["f16,f16,f16,f32,f32"], NHWC, NHWC, NHWC, @@ -464,7 +464,7 @@ class AlgoHint(Enum): *gen_conv_params(ConvBwdWeight, (64, 64, 32), (32, 32, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, - 2, ["f16,f16,f16,f16,f16"], + 2, ["f16,f16,f16,f32,f32"], NHWC, NHWC, NHWC, @@ -476,7 +476,7 @@ class AlgoHint(Enum): *gen_conv_params(ConvBwdWeight, (128, 128, 32), (32, 64, 32), NDIM_DONT_CARE, ConvIterAlgo.Optimized, - 2, ["f16,f16,f16,f16,f16"], + 2, ["f16,f16,f16,f32,f32"], NHWC, NHWC, NHWC, diff --git a/spconv/core_cc/csrc/sparse/all/__init__.pyi b/spconv/core_cc/csrc/sparse/all/__init__.pyi index 5e8e4e9..1b6e157 100644 --- a/spconv/core_cc/csrc/sparse/all/__init__.pyi +++ b/spconv/core_cc/csrc/sparse/all/__init__.pyi @@ -298,7 +298,7 @@ class SpconvOps: """ ... @staticmethod - def point2voxel_cpu(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], empty_mean: bool = False, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point2voxel_cpu(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, pc_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], empty_mean: bool = False, clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -306,6 +306,7 @@ class SpconvOps: indices: num_per_voxel: densehashdata: + pc_voxel_id: vsize: grid_size: grid_stride: @@ -315,7 +316,7 @@ class SpconvOps: """ ... 
@staticmethod - def point2voxel_cuda(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], empty_mean: bool = False, clear_voxels: bool = True, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + def point2voxel_cuda(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, pc_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], empty_mean: bool = False, clear_voxels: bool = True, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -324,6 +325,7 @@ class SpconvOps: num_per_voxel: hashdata: point_indice_data: + pc_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops1d.pyi b/spconv/core_cc/csrc/sparse/all/ops1d.pyi index 03e57ca..7d9535a 100644 --- a/spconv/core_cc/csrc/sparse/all/ops1d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops1d.pyi @@ -29,7 +29,7 @@ class Point2Voxel: """ ... @staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,6 +38,7 @@ class Point2Voxel: num_per_voxel: hashdata: point_indice_data: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops2d.pyi b/spconv/core_cc/csrc/sparse/all/ops2d.pyi index 03e57ca..7d9535a 100644 --- a/spconv/core_cc/csrc/sparse/all/ops2d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops2d.pyi @@ -29,7 +29,7 @@ class Point2Voxel: """ ... @staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,6 +38,7 @@ class Point2Voxel: num_per_voxel: hashdata: point_indice_data: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops3d.pyi b/spconv/core_cc/csrc/sparse/all/ops3d.pyi index 03e57ca..7d9535a 100644 --- a/spconv/core_cc/csrc/sparse/all/ops3d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops3d.pyi @@ -29,7 +29,7 @@ class Point2Voxel: """ ... 
@staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,6 +38,7 @@ class Point2Voxel: num_per_voxel: hashdata: point_indice_data: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops4d.pyi b/spconv/core_cc/csrc/sparse/all/ops4d.pyi index 03e57ca..7d9535a 100644 --- a/spconv/core_cc/csrc/sparse/all/ops4d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops4d.pyi @@ -29,7 +29,7 @@ class Point2Voxel: """ ... @staticmethod - def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_hash_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, hashdata: Tensor, point_indice_data: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True, empty_mean: bool = False, stream_int: int = 0) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -38,6 +38,7 @@ class Point2Voxel: num_per_voxel: hashdata: point_indice_data: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi index d44e436..2c57e2e 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi @@ -27,7 +27,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -35,6 +35,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: @@ -43,7 +44,7 @@ class Point2VoxelCPU: """ ... 
@staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -51,6 +52,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi index d44e436..2c57e2e 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi @@ -27,7 +27,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -35,6 +35,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: @@ -43,7 +44,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -51,6 +52,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi index d44e436..2c57e2e 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi @@ -27,7 +27,7 @@ class Point2VoxelCPU: """ ... 
@staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -35,6 +35,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: @@ -43,7 +44,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -51,6 +52,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi b/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi index d44e436..2c57e2e 100644 --- a/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi +++ b/spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi @@ -27,7 +27,7 @@ class Point2VoxelCPU: """ ... @staticmethod - def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -35,6 +35,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: @@ -43,7 +44,7 @@ class Point2VoxelCPU: """ ... 
@staticmethod - def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + def point_to_voxel_empty_mean_static(points: Tensor, voxels: Tensor, indices: Tensor, num_per_voxel: Tensor, densehashdata: Tensor, points_voxel_id: Tensor, vsize: List[float], grid_size: List[int], grid_stride: List[int], coors_range: List[float], clear_voxels: bool = True) -> Tuple[Tensor, Tensor, Tensor]: """ Args: points: @@ -51,6 +52,7 @@ class Point2VoxelCPU: indices: num_per_voxel: densehashdata: + points_voxel_id: vsize: grid_size: grid_stride: diff --git a/spconv/csrc/sparse/all.py b/spconv/csrc/sparse/all.py index fc9206d..d9f904d 100644 --- a/spconv/csrc/sparse/all.py +++ b/spconv/csrc/sparse/all.py @@ -920,7 +920,7 @@ def calc_point2voxel_meta_data(self): def point2voxel_cpu(self): code = pccm.FunctionCode() code.arg("points", "tv::Tensor") - code.arg("voxels, indices, num_per_voxel, densehashdata", "tv::Tensor") + code.arg("voxels, indices, num_per_voxel, densehashdata, pc_voxel_id", "tv::Tensor") code.arg("vsize", f"std::vector") code.arg("grid_size, grid_stride", f"std::vector") code.arg("coors_range", f"std::vector") @@ -950,11 +950,11 @@ def point2voxel_cpu(self): }} if (empty_mean){{ return Point2Voxel{ndim}DCPU::point_to_voxel_empty_mean_static(points, voxels, indices, - num_per_voxel, densehashdata, + num_per_voxel, densehashdata, pc_voxel_id, vsize_, grid_size_, grid_stride_, coors_range_, clear_voxels); }} else{{ return Point2Voxel{ndim}DCPU::point_to_voxel_static(points, voxels, indices, - num_per_voxel, densehashdata, + num_per_voxel, densehashdata, pc_voxel_id, vsize_, grid_size_, grid_stride_, coors_range_, clear_voxels); }} }} @@ -967,7 +967,7 @@ def point2voxel_cpu(self): def point2voxel_cuda(self): code = pccm.FunctionCode() code.arg("points", "tv::Tensor") - code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", + code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data, pc_voxel_id", "tv::Tensor") code.arg("vsize", f"std::vector") code.arg("grid_size, grid_stride", f"std::vector") @@ -1000,7 +1000,7 @@ def point2voxel_cuda(self): coors_range_[i + {ndim}] = coors_range[i + {ndim}]; }} return Point2Voxel{ndim}D::point_to_voxel_hash_static(points, voxels, indices, - num_per_voxel, hashdata, point_indice_data, + num_per_voxel, hashdata, point_indice_data, pc_voxel_id, vsize_, grid_size_, grid_stride_, coors_range_, clear_voxels, empty_mean, stream_int); }} diff --git a/spconv/csrc/sparse/pointops.py b/spconv/csrc/sparse/pointops.py index 8ff8651..61feb4d 100644 --- a/spconv/csrc/sparse/pointops.py +++ b/spconv/csrc/sparse/pointops.py @@ -208,6 +208,7 @@ def generate_voxel(self): code.arg("points_indice_data", f"const int64_t*") code.arg("voxels", f"{self.dtype} *") code.arg("num_per_voxel", f"int *") + code.arg("points_voxel_id", f"int64_t*") code.arg("point_stride", f"int") code.arg("max_points_per_voxel", f"int") @@ -219,14 +220,17 @@ def generate_voxel(self): code.arg("grid_stride", f"tv::array") code.arg("num_points", f"int") + # TODO add backward? 
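+        # sketch of what the generated kernel below does (summary, not generated
+        # code): each thread takes one point, looks up its quantized coordinate
+        # in the hash table, atomically bumps that voxel's point counter, copies
+        # the point features in while the voxel still has free slots, and records
+        # the assigned voxel id (or -1 for dropped points) in points_voxel_id.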
code.raw(f""" int voxel_stride0 = point_stride * max_points_per_voxel; for (int i : tv::KernelLoopX(num_points)){{ int64_t prod = points_indice_data[i]; + int voxel_id = -1; if (prod != -1){{ auto voxel_index_pair = table.lookup(prod); if (!voxel_index_pair.empty() && voxel_index_pair.second < max_voxels) {{ + voxel_id = voxel_index_pair.second; int old = atomicAdd(num_per_voxel + voxel_index_pair.second, 1); if (old < max_points_per_voxel) {{ for (int j = 0; j < point_stride; ++j) {{ @@ -235,6 +239,7 @@ def generate_voxel(self): }} }} }} + points_voxel_id[i] = voxel_id; }} """) return code @@ -385,6 +390,7 @@ def point_to_voxel_hash(self): code.arg("stream_int", f"std::uintptr_t", "0") code.raw(f""" + tv::Tensor points_voxel_id = tv::empty({{points.dim(0)}}, tv::int64, 0); int64_t expected_hash_data_num = points.dim(0) * 2; if (hashdata.dim(0) < expected_hash_data_num){{ hashdata = tv::zeros({{expected_hash_data_num}}, tv::custom128, 0); @@ -393,74 +399,18 @@ def point_to_voxel_hash(self): point_indice_data = tv::zeros({{points.dim(0)}}, tv::int64, 0); }} return point_to_voxel_hash_static(points, voxels, indices, num_per_voxel, - hashdata, point_indice_data, Point2VoxelCommon::tvarray2array(vsize), + hashdata, point_indice_data, points_voxel_id, Point2VoxelCommon::tvarray2array(vsize), Point2VoxelCommon::tvarray2array(grid_size), Point2VoxelCommon::tvarray2array(grid_stride), Point2VoxelCommon::tvarray2array(coors_range), clear_voxels, empty_mean, stream_int); """) return code.ret("std::tuple") - code.raw(f""" - - TV_ASSERT_INVALID_ARG(points.ndim() == 2 && points.dim(1) >= {self.ndim}, "error"); - using V = int64_t; - using KeyType = int64_t; - constexpr KeyType kEmptyKey = std::numeric_limits::max(); - if (clear_voxels){{ - voxels.zero_(); - }} - using table_t = - tv::hash::LinearHashTable, - kEmptyKey, false>; - using pair_t = typename table_t::value_type; - // int64_t expected_hash_data_num = int64_t(tv::hash::align_to_power2(points.dim(0) * 2)); - int64_t expected_hash_data_num = points.dim(0) * 2; - - if (hashdata.dim(0) < expected_hash_data_num){{ - hashdata = tv::zeros({{expected_hash_data_num}}, tv::custom128, 0); - }} - if (point_indice_data.dim(0) < points.dim(0)){{ - point_indice_data = tv::zeros({{points.dim(0)}}, tv::int64, 0); - }} - // auto timer = tv::CudaContextTimer<>(); - num_per_voxel.zero_(); - table_t hash = table_t(hashdata.data_ptr(), expected_hash_data_num); - hash.clear(); - // tv::ssprint("clear time", timer.report()); - auto launcher = tv::cuda::Launch(points.dim(0)); - launcher(kernel::build_hash_table, hash, points.data_ptr(), - point_indice_data.data_ptr(), - points.dim(1), vsize, coors_range, grid_size, grid_stride, points.dim(0)); - // tv::ssprint("build_hash_table", timer.report()); - - auto table_launcher = tv::cuda::Launch(hash.size()); - tv::Tensor count = tv::zeros({{1}}, tv::int32, 0); - Layout layout = Layout::from_shape(grid_size); - table_launcher(kernel::assign_table, hash, indices.data_ptr(), - count.data_ptr(), - layout, voxels.dim(0)); - auto count_cpu = count.cpu(); - int count_val = count_cpu.item(); - // tv::ssprint("assign_table", timer.report()); - - launcher(kernel::generate_voxel, hash, points.data_ptr(), - point_indice_data.data_ptr(), voxels.data_ptr<{self.dtype}>(), - num_per_voxel.data_ptr(), points.dim(1), voxels.dim(1), - voxels.dim(0), vsize, coors_range, - grid_size, grid_stride, points.dim(0)); - // tv::ssprint("generate_voxel", timer.report()); - - return std::make_tuple(voxels.slice_first_axis(0, count_val), - 
indices.slice_first_axis(0, count_val), - num_per_voxel.slice_first_axis(0, count_val)); - - """) - return code.ret("std::tuple") @pccm.pybind.mark @pccm.cuda.static_function def point_to_voxel_hash_static(self): code = pccm.FunctionCode() code.arg("points", "tv::Tensor") - code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data", + code.arg("voxels, indices, num_per_voxel, hashdata, point_indice_data, points_voxel_id", "tv::Tensor") code.arg("vsize", f"std::array") code.arg("grid_size, grid_stride", f"std::array") @@ -516,7 +466,7 @@ def point_to_voxel_hash_static(self): launcher(kernel::generate_voxel, hash, points.data_ptr(), point_indice_data.data_ptr(), voxels.data_ptr<{self.dtype}>(), - num_per_voxel.data_ptr(), points.dim(1), voxels.dim(1), + num_per_voxel.data_ptr(), points_voxel_id.data_ptr(), points.dim(1), voxels.dim(1), voxels.dim(0), vsize_tv, coors_range_tv, grid_size_tv, grid_stride_tv, points.dim(0)); // tv::ssprint("generate_voxel", timer.report()); @@ -636,7 +586,7 @@ def ctor(self): def point_to_voxel_static_template(self, mean: bool = False): code = pccm.FunctionCode() code.arg("points", "tv::Tensor") - code.arg("voxels, indices, num_per_voxel, densehashdata", "tv::Tensor") + code.arg("voxels, indices, num_per_voxel, densehashdata, points_voxel_id", "tv::Tensor") code.arg("vsize", f"std::array") code.arg("grid_size, grid_stride", f"std::array") code.arg("coors_range", f"std::array") @@ -653,6 +603,7 @@ def point_to_voxel_static_template(self, mean: bool = False): if (clear_voxels){{ voxels.zero_(); }} + auto points_voxel_id_ptr = points_voxel_id.data_ptr(); int res_voxel_num = 0; int num_features = points.dim(1); auto N = points.dim(0); @@ -680,20 +631,25 @@ def point_to_voxel_static_template(self, mean: bool = False): }} coor[j] = c; }} - if (failed) + if (failed){{ + points_voxel_id_ptr[i] = -1; continue; + }} voxelidx = coor_to_voxelidx_rw({codeops.unpack("coor", range(self.ndim))}); - + if (voxelidx == -1) {{ voxelidx = voxel_num; - if (voxel_num >= max_num_voxels) + if (voxel_num >= max_num_voxels){{ + points_voxel_id_ptr[i] = -1; continue; + }} voxel_num += 1; coor_to_voxelidx_rw({codeops.unpack("coor", range(self.ndim))}) = voxelidx; for (int k = 0; k < {self.ndim}; ++k) {{ coors_rw(voxelidx, k) = coor[k]; }} }} + points_voxel_id_ptr[i] = voxelidx; num = num_points_per_voxel_rw(voxelidx); if (num < max_num_points_per_voxel) {{ // voxel_point_mask_rw(voxelidx, num) = {self.dtype}(1); @@ -781,8 +737,10 @@ def point_to_voxel(self): code.arg("points", "tv::Tensor") code.arg("clear_voxels", "bool", "true") code.raw(f""" + tv::Tensor points_voxel_id = tv::empty({{points.dim(0)}}, tv::int64, -1); + return point_to_voxel_static(points, voxels, indices, num_per_voxel, densehashdata, - tvarray2array(vsize), + points_voxel_id, tvarray2array(vsize), tvarray2array(grid_size), tvarray2array(grid_stride), tvarray2array(coors_range), clear_voxels); """) @@ -795,8 +753,10 @@ def point_to_voxel_empty_mean(self): code.arg("points", "tv::Tensor") code.arg("clear_voxels", "bool", "true") code.raw(f""" + tv::Tensor points_voxel_id = tv::empty({{points.dim(0)}}, tv::int64, -1); + return point_to_voxel_empty_mean_static(points, voxels, indices, num_per_voxel, - densehashdata, tvarray2array(vsize), + densehashdata, points_voxel_id, tvarray2array(vsize), tvarray2array(grid_size), tvarray2array(grid_stride), tvarray2array(coors_range), clear_voxels); """) diff --git a/spconv/pytorch/constants.py b/spconv/pytorch/constants.py index 1de0083..16165e9 100644 --- 
a/spconv/pytorch/constants.py
+++ b/spconv/pytorch/constants.py
@@ -27,3 +27,15 @@
 except:
     # for unknown errors, just set a version
     PYTORCH_VERSION = [1, 8, 0]
+
+
+if PYTORCH_VERSION >= [1, 6, 0]:
+    TORCH_HAS_AMP = True
+else:
+    TORCH_HAS_AMP = False
+
+def is_amp_enabled():
+    if TORCH_HAS_AMP:
+        return torch.is_autocast_enabled()
+    else:
+        return False
\ No newline at end of file
diff --git a/spconv/pytorch/conv.py b/spconv/pytorch/conv.py
index 5b5a601..d3d2308 100644
--- a/spconv/pytorch/conv.py
+++ b/spconv/pytorch/conv.py
@@ -35,6 +35,20 @@
 FILTER_HWIO = False
 
 
+def expand_nd(val: Union[int, List[int], Tuple[int, ...]], ndim: int) -> List[int]:
+    if isinstance(val, int):
+        val = [val] * ndim
+    elif isinstance(val, list):
+        assert len(val) == ndim
+    elif isinstance(val, tuple):
+        assert len(val) == ndim
+        return [*val]
+    else:
+        raise NotImplementedError
+    return val
+
+
 def _calculate_fan_in_and_fan_out_hwio(tensor, algo: ConvAlgo):
     dimensions = tensor.ndimension()
     if dimensions < 2:
@@ -110,7 +124,9 @@ def __init__(self,
         self.out_channels = out_channels
         self.kernel_size = kernel_size
         kv = int(np.prod(kernel_size))
-        self.conv1x1 = kv == 1
+        kv_stride = int(np.prod(stride))
+
+        self.conv1x1 = kv == 1 and kv_stride == 1
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
diff --git a/spconv/pytorch/core.py b/spconv/pytorch/core.py
index 21bbee9..231eab7 100644
--- a/spconv/pytorch/core.py
+++ b/spconv/pytorch/core.py
@@ -104,7 +104,8 @@ def __init__(self,
                  indice_dict: Optional[dict] = None,
                  benchmark: bool = False,
                  permanent_thrust_allocator: bool = False,
-                 enable_timer: bool = False):
+                 enable_timer: bool = False,
+                 force_algo: Optional[ConvAlgo] = None):
         """
         Args:
             features: [num_points, num_features] feature tensor
             indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
             spatial_shape: spatial shape of your sparse data
             batch_size: batch size of your sparse data
             grid: pre-allocated grid tensor. should be used when the volume of spatial shape
             is very large.
             benchmark: whether to enable benchmark. if enabled, all sparse operators will
             be record to SparseConvTensor.
+            enable_timer: if enabled, the run time of all spconv internal ops is recorded in _timer.
+            force_algo: force conv/pool layers to use this algo; should only be used for debugging.
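+                (illustrative: passing force_algo=ConvAlgo.Native here makes
+                every conv/pool layer use the native algorithm regardless of
+                the layer's own algo argument.)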
""" ndim = indices.shape[1] - 1 assert features.ndim == 2 @@ -139,6 +142,7 @@ def __init__(self, if permanent_thrust_allocator: self.thrust_allocator = ThrustSortAllocator(features.device) self._timer = CUDAKernelTimer(enable_timer) + self.force_algo = force_algo def replace_feature(self, feature: torch.Tensor): """we need to replace x.features = F.relu(x.features) with x = x.replace_feature(F.relu(x.features)) @@ -152,6 +156,8 @@ def replace_feature(self, feature: torch.Tensor): new_spt.benchmark_record = self.benchmark_record new_spt.thrust_allocator = self.thrust_allocator new_spt._timer = self._timer + new_spt.force_algo = self.force_algo + return new_spt @property @@ -217,4 +223,5 @@ def shadow_copy(self) -> "SparseConvTensor": tensor.benchmark_record = self.benchmark_record tensor.thrust_allocator = self.thrust_allocator tensor._timer = self._timer + tensor.force_algo = self.force_algo return tensor diff --git a/spconv/pytorch/cppcore.py b/spconv/pytorch/cppcore.py index 5b97fbb..edd731f 100644 --- a/spconv/pytorch/cppcore.py +++ b/spconv/pytorch/cppcore.py @@ -30,7 +30,8 @@ def torch_tensor_to_tv(ten: torch.Tensor, dtype: Optional[int] = None, - shape: Optional[List[int]] = None): + shape: Optional[List[int]] = None, + stride: Optional[List[int]] = None): # assert ten.is_contiguous(), "must be contiguous tensor" ptr = ten.data_ptr() device = ten.device @@ -40,11 +41,20 @@ def torch_tensor_to_tv(ten: torch.Tensor, tv_device = 0 else: raise NotImplementedError - if shape is None: - shape = list(ten.shape) if dtype is None: dtype = _TORCH_DTYPE_TO_TV[ten.dtype] - return tv.from_blob(ptr, shape, list(ten.stride()), dtype, tv_device) + if stride is None: + stride = list(ten.stride()) + if shape is None: + shape = list(ten.shape) + else: + if not ten.is_contiguous(): + msg = "if you provide custom shape for non-contig tensor, stride must not None" + assert stride is not None, msg + else: + # custom shape, if tensor is contiguous, we use from_blob and calc strides + return tv.from_blob(ptr, shape, dtype, tv_device) + return tv.from_blob_strided(ptr, shape, stride, dtype, tv_device) def get_current_stream(): diff --git a/spconv/pytorch/modules.py b/spconv/pytorch/modules.py index f1038f8..80c102f 100644 --- a/spconv/pytorch/modules.py +++ b/spconv/pytorch/modules.py @@ -137,6 +137,7 @@ def forward(self, input): input = module(input) else: if isinstance(input, spconv.SparseConvTensor): + print(input.features.shape) if input.indices.shape[0] != 0: input = input.replace_feature(module(input.features)) else: diff --git a/spconv/pytorch/ops.py b/spconv/pytorch/ops.py index 3bf9011..ab91bf6 100644 --- a/spconv/pytorch/ops.py +++ b/spconv/pytorch/ops.py @@ -1066,7 +1066,7 @@ def indice_conv_backward(features: torch.Tensor, alpha=1.0, beta=beta) - if not FILTER_HWIO: + if is_KC_not_CK: a = out_bp_tv b = features_tv a_inds = out_indices @@ -1376,6 +1376,9 @@ def implicit_gemm_backward(features: torch.Tensor, mask_width=-1, beta=beta, stream=stream) + # for backward weight, beta = 0 because each split + # handle different kernel locations. 
+            # TODO remove D iterator in backward weight kernel
             CONV.run_with_tuned_result(
                 wgrad_tune_res,
                 ConvOpType.kBackwardWeight,
@@ -1389,7 +1392,7 @@
                 reverse_mask=False,
                 mask_filter=masks[j].item(),
                 mask_width=mask_width,
-                beta=beta,
+                beta=0,
                 workspace=workspace_tv,
                 stream=stream)
 
@@ -1403,6 +1406,8 @@ def indice_maxpool(features: torch.Tensor, indice_pairs: torch.Tensor,
     # stream = get_current_stream()
     # CONV.stream_synchronize(stream)
     # t = time.time()
+    if not features.is_contiguous():
+        features = features.contiguous()
 
     out_channel = features.shape[-1]
     out_features = torch.zeros((num_activate_out, out_channel),
@@ -1474,6 +1479,8 @@ def indice_maxpool_implicit_gemm(features: torch.Tensor,
     stream = get_current_stream()
     # CONV.stream_synchronize(stream)
     # t = time.time()
+    if not features.is_contiguous():
+        features = features.contiguous()
 
     out_channel = features.shape[-1]
     out_features = torch.empty((num_activate_out, out_channel),
diff --git a/spconv/pytorch/utils.py b/spconv/pytorch/utils.py
index 9aa3ce3..7527d2f 100644
--- a/spconv/pytorch/utils.py
+++ b/spconv/pytorch/utils.py
@@ -71,36 +71,72 @@ def __call__(self,
                  pc: torch.Tensor,
                  clear_voxels: bool = True,
                  empty_mean: bool = False):
+        """generate voxels/indices/num_point_per_voxel from a point cloud.
+        This function doesn't return pc_voxel_id, for backward compatibility.
+        pc_voxel_id will be added in spconv 2.2.
+        Args:
+            pc: [N, 3+] point cloud.
+            clear_voxels: if True, zero the voxel buffer before filling.
+            empty_mean: if True, fill empty locations of voxels with the mean.
+        Returns:
+            voxels: voxels
+            indices: quantized coords
+            num_per_voxel: number of points in a voxel
+        """
+
+        res = self.generate_voxel_with_id(pc, clear_voxels, empty_mean)
+        return res[0], res[1], res[2]
+
+    def generate_voxel_with_id(self,
+                               pc: torch.Tensor,
+                               clear_voxels: bool = True,
+                               empty_mean: bool = False):
+        """generate voxels/indices/num_point_per_voxel/pc_voxel_id from
+        a point cloud.
+        Args:
+            pc: [N, 3+] point cloud.
+            clear_voxels: if True, zero the voxel buffer before filling.
+            empty_mean: if True, fill empty locations of voxels with the mean.
+        Returns:
+            voxels: voxels
+            indices: quantized coords
+            num_per_voxel: number of points in a voxel
+            pc_voxel_id: voxel id for every point; -1 if the point doesn't belong to any voxel.
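+        Example (illustrative; gen is an already constructed PointToVoxel):
+            voxels, coords, num_per_voxel, pc_voxel_id = \
+                gen.generate_voxel_with_id(pc, empty_mean=True)
+            # points outside coors_range, or dropped by max_num_voxels /
+            # max_num_points_per_voxel, get pc_voxel_id == -1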
+ """ assert pc.device.type == self.device.type, "your pc device is wrong" expected_hash_data_num = pc.shape[0] * 2 with torch.no_grad(): + pc_voxel_id = torch.empty([pc.shape[0]], + dtype=torch.int64, + device=self.device) + pc_voxel_id_tv = torch_tensor_to_tv(pc_voxel_id) + if self.device.type != "cpu": - if self.hashdata.shape[0] < expected_hash_data_num: - self.hashdata = torch.empty([expected_hash_data_num, 2], - dtype=torch.int64, - device=self.device) + hashdata = torch.empty([expected_hash_data_num, 2], + dtype=torch.int64, + device=pc.device) + + point_indice_data = torch.empty([pc.shape[0]], + dtype=torch.int64, + device=pc.device) - if self.point_indice_data.shape[0] < pc.shape[0]: - self.point_indice_data = torch.empty([pc.shape[0]], - dtype=torch.int64, - device=self.device) pc_tv = torch_tensor_to_tv(pc) stream = get_current_stream() voxels_tv = torch_tensor_to_tv(self.voxels) indices_tv = torch_tensor_to_tv(self.indices) num_per_voxel_tv = torch_tensor_to_tv(self.num_per_voxel) hashdata_tv = torch_tensor_to_tv( - self.hashdata, + hashdata, dtype=tv.custom128, - shape=[self.hashdata.shape[0]]) - point_indice_data_tv = torch_tensor_to_tv( - self.point_indice_data) - - res = SpconvOps.point2voxel_cuda( - pc_tv, voxels_tv, indices_tv, num_per_voxel_tv, - hashdata_tv, point_indice_data_tv, self.vsize, - self.grid_size, self.grid_stride, self.coors_range, - empty_mean, clear_voxels, stream) + shape=[hashdata.shape[0]]) + point_indice_data_tv = torch_tensor_to_tv(point_indice_data) + with torch.cuda.device(pc.device): + res = SpconvOps.point2voxel_cuda( + pc_tv, voxels_tv, indices_tv, num_per_voxel_tv, + hashdata_tv, point_indice_data_tv, pc_voxel_id_tv, self.vsize, + self.grid_size, self.grid_stride, self.coors_range, + empty_mean, clear_voxels, stream) num_voxels = res[0].shape[0] else: pc_tv = torch_tensor_to_tv(pc) @@ -111,6 +147,7 @@ def __call__(self, hashdata_tv = torch_tensor_to_tv(self.hashdata, dtype=tv.int32) res = SpconvOps.point2voxel_cpu(pc_tv, voxels_tv, indices_tv, num_per_voxel_tv, hashdata_tv, + pc_voxel_id_tv, self.vsize, self.grid_size, self.grid_stride, self.coors_range, empty_mean, @@ -118,4 +155,4 @@ def __call__(self, num_voxels = res[0].shape[0] return (self.voxels[:num_voxels], self.indices[:num_voxels], - self.num_per_voxel[:num_voxels]) + self.num_per_voxel[:num_voxels], pc_voxel_id) diff --git a/test/benchmark.py b/test/benchmark.py index 6cffa7d..2c3ad95 100644 --- a/test/benchmark.py +++ b/test/benchmark.py @@ -24,7 +24,7 @@ import spconv.pytorch as spconv from spconv.utils import Point2VoxelCPU3d - +# torch.backends.cudnn.enabled = False def waymo_data(batch_size=1): gen = Point2VoxelCPU3d([0.1, 0.1, 0.1], [-80, -80, -2, 80, 80, 6], 3, 150000, 1) @@ -289,7 +289,7 @@ def main(): voxels_th = torch.from_numpy(voxels).to(device).to(dtype) coors_th = torch.from_numpy(coors).to(device).int() voxels_th.requires_grad = True - algo = spconv.ConvAlgo.Native + algo = spconv.ConvAlgo.MaskImplicitGemm # 3080 Laptop # MaskImpGemm: 11.2ms # MaskSplitImpGemm: 12.2ms @@ -324,26 +324,26 @@ def main(): print(out.spatial_shape, out.features.mean(), out.features.max(), out.features.min()) - # times = [] - # with torch.no_grad(): - # for i in range(20): - # print("------------") - # torch.cuda.synchronize() - # t = time.time() - # out_nograd = net(voxels_th, coors_th, 1, False) - # timer = out_nograd._timer - # # res = timer.collect_by_name("forward", timer.get_all_pair_time()) - # # res2 = timer.collect_by_name("forward0", timer.get_all_pair_time()) + times = [] + with 
torch.no_grad():
+        for i in range(20):
+            print("------------")
+            torch.cuda.synchronize()
+            t = time.time()
+            out_nograd = net(voxels_th, coors_th, 1, False)
+            timer = out_nograd._timer
+            # res = timer.collect_by_name("forward", timer.get_all_pair_time())
+            # res2 = timer.collect_by_name("forward0", timer.get_all_pair_time())
 
-    #         # print(sum(res.values()) + sum(res2.values()))
-    #         # print(timer.get_all_pair_time())
+            # print(sum(res.values()) + sum(res2.values()))
+            # print(timer.get_all_pair_time())
 
-    #         # print(sum(timer.get_all_pair_time().values()))
-    #         torch.cuda.synchronize()
-    #         # sort_bench()
-    #         times.append(time.time() - t)
-    #     print("spconv time", np.mean(times[10:]))
-    #     times = []
+            # print(sum(timer.get_all_pair_time().values()))
+            torch.cuda.synchronize()
+            # sort_bench()
+            times.append(time.time() - t)
+        print("spconv time", np.mean(times[10:]))
+        times = []
     # for i in range(10):
     #     out = net(voxels_th, coors_th, 1)
diff --git a/test/test_all_algo.py b/test/test_all_algo.py
new file mode 100644
index 0000000..59e2bc3
--- /dev/null
+++ b/test/test_all_algo.py
@@ -0,0 +1,663 @@
+# Copyright 2021 Yan Yan
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test all gemm/conv kernels.
+We can't test all kernels inside a full network because the auto-tuner will
+only keep one best kernel.
+"""
+
+
+import os
+import pickle
+import sys
+import time
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pccm
+import torch
+import torch.nn.functional as F
+import tqdm
+from cumm import tensorview as tv
+from cumm.conv.bases import NCHW, NHWC, ConvIterAlgo, ConvOpType
+from cumm.gemm.algospec.core import GemmAlgo, ShuffleStrideType
+from cumm.gemm.codeops import div_up
+
+from spconv.core import AlgoHint, ConvAlgo
+from spconv.pytorch.conv import expand_nd
+from spconv.pytorch import ops
+from spconv.algo import CONV, GEMM, BestAlgoByProfile, BestConvAlgoByProfile
+from spconv.pytorch.cppcore import get_current_stream, torch_tensor_to_tv
+from spconv.test_utils import TestCase, generate_sparse_data, params_grid
+from spconv.constants import ALL_WEIGHT_IS_KRSC
+
+assert ALL_WEIGHT_IS_KRSC is True, "we only support KRSC in spconv >= 2.2"
+
+# TODO remove or release this when tf32 op is ready
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+NUMPY_DTYPE_TO_TORCH = {
+    np.float32: torch.float32,
+    np.float16: torch.float16,
+    np.int8: torch.int8,
+}
+
+class SparseConvTester:
+    def __init__(self, algo: ConvAlgo, subm: bool, shape: List[int], bs: int,
+                 dtype: np.dtype, N: int, K: int, C: int,
+                 ksize: int, stride: int, padding: int, dilation: int) -> None:
+        ndim = 3
+        self.shape = shape
+        self.bs = bs
+        self.dtype = dtype
+        self.dtype_th = NUMPY_DTYPE_TO_TORCH[dtype]
+        self.K = K
+        self.C = C
+        self.ksize = expand_nd(ksize, ndim)
+        self.stride = expand_nd(stride, ndim)
+        self.padding = expand_nd(padding, ndim)
+        self.dilation = expand_nd(dilation, ndim)
+        self.N = N
+        self.device = torch.device("cuda:0")
+        op = expand_nd(0, ndim)
+        self.kv: int = int(np.prod(self.ksize))
+        self.num_split = 1 if algo == ConvAlgo.MaskImplicitGemm else 2
+
+        sparse_dict = generate_sparse_data(shape, [1500] * bs, C)
+
+        voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype(
+            np.float32)
+        indices_np = np.ascontiguousarray(
+            sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
+        indices_th = torch.from_numpy(indices_np).to(self.device)
+        out_inds, pair_ref, indice_num_per_loc = ops.get_indice_pairs(
+            indices_th, 1, shape, ConvAlgo.Native, self.ksize, self.stride,
+            self.padding, self.dilation, op, subm)
+        self.indice_num_per_loc_np = indice_num_per_loc.cpu().numpy()
+        self.indice_pairs_np = pair_ref.cpu().numpy()
+        self.pair_native = pair_ref
+        self.indice_num_per_loc = indice_num_per_loc
+        if algo == ConvAlgo.Native:
+            self.out_inds: torch.Tensor = out_inds
+            self.num_inds_per_loc: torch.Tensor = indice_num_per_loc
+            self.pair_fwd: torch.Tensor = torch.Tensor()
+            self.pair_bwd: torch.Tensor = torch.Tensor()
+            self.pair_mask_fwd_splits: List[torch.Tensor] = []
+            self.pair_mask_bwd_splits: List[torch.Tensor] = []
+            self.mask_argsort_fwd_splits: List[torch.Tensor] = []
+            self.mask_argsort_bwd_splits: List[torch.Tensor] = []
+            self.masks = np.array([])
+        else:
+            res = ops.get_indice_pairs_implicit_gemm(indices_th, bs, shape,
+                                                     algo, self.ksize,
+                                                     self.stride, self.padding,
+                                                     self.dilation, op,
+                                                     subm=subm)
+
+            self.out_inds = res[0]
+            self.num_inds_per_loc = res[1]
+            self.pair_fwd = res[2]
+            self.pair_bwd = res[3]
+            self.pair_mask_fwd_splits = res[4]
+            self.pair_mask_bwd_splits = res[5]
+            self.mask_argsort_fwd_splits = res[6]
+            self.mask_argsort_bwd_splits = res[7]
+            self.masks = res[8]
+        self.voxels_np = voxels_np
+        self.indices_np = indices_np
+
+
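+        # layout note (Native pair format, consumed by _get_ref_output below):
+        # indice_pairs_np is [2, kv, num_pairs]: indice_pairs_np[0][k][:nhot]
+        # holds gathered input indices and indice_pairs_np[1][k][:nhot] the
+        # matching output indices for kernel offset k, so a reference conv is
+        # a per-offset gather-mm-scatter: out[o_inds] += inp[i_inds] @ W[k].T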
self.subm = subm + if dtype == np.int8: + self.inp = np.random.randint(-2, 2, size=[voxels_np.shape[0], + C]).astype(np.int8) + self.weight = np.random.randint(-2, 2, size=[K, *self.ksize, + C]).astype(np.int8) + self.output = np.random.randint(-2, 2, size=[ + self.out_inds.shape[0], K + ]).astype(dtype) + else: + self.inp = np.random.uniform(-1, 1, size=[ + voxels_np.shape[0], C + ]).astype(dtype) + self.weight = np.random.uniform(-1, 1, size=[K, *self.ksize, C]).astype(dtype) + self.output = np.random.uniform(-1, 1, size=[ + self.out_inds.shape[0], K + ]).astype(dtype) + self.weight_ref = self.weight.transpose(1, 2, 3, 0, 4) + self.weight_ref = np.ascontiguousarray(self.weight_ref).reshape(-1, K, C) + + self.out_ref, self.din_ref, self.dw_ref = self._get_ref_output() + self.dw_ref = np.ascontiguousarray(self.dw_ref.transpose(1, 0, 2).reshape(K, *self.ksize, C)) + + def _get_ref_output(self): + output_ref = np.zeros_like(self.output, dtype=np.float32) + dinput_ref = np.zeros_like(self.inp, dtype=np.float32) + dw_ref = np.zeros_like(self.weight_ref, + dtype=np.float32) # KV, K, C + + for filter_offset in range(self.kv): + if self.subm and filter_offset > self.kv // 2: + nhot = self.indice_num_per_loc_np[self.kv - 1 - filter_offset] + elif self.subm and filter_offset == self.kv // 2: + nhot = self.voxels_np.shape[0] + else: + nhot = self.indice_num_per_loc_np[filter_offset] + + i_inds = self.indice_pairs_np[0][filter_offset][:nhot] + o_inds = self.indice_pairs_np[1][filter_offset][:nhot] + a = self.inp[i_inds] + cc = a.astype( + np.float32) @ self.weight_ref[filter_offset].T.astype( + np.float32) + output_ref[o_inds] += cc + a = self.output[o_inds] + # NK @ KC + cc = a.astype( + np.float32) @ self.weight_ref[filter_offset].astype( + np.float32) + dinput_ref[i_inds] += cc + out_gather = self.output[o_inds] # [N, K] + inp_gather = self.inp[i_inds] # [N, C] + # KN @ NC + dw_res = out_gather.astype( + np.float32).T @ inp_gather.astype(np.float32) + dw_ref[filter_offset] = dw_res + return output_ref, dinput_ref, dw_ref + + def get_operands(self, op_type: ConvOpType): + zeros_func = tv.zeros if not self.subm else tv.empty + if op_type == ConvOpType.kBackwardInput: + inp_tv = zeros_func(list(self.inp.shape), self.dtype, 0) + else: + inp_tv = tv.from_numpy(self.inp).cuda() + if op_type == ConvOpType.kBackwardWeight: + weight_tv = zeros_func(list(self.weight.shape), self.dtype, 0) + else: + weight_tv = tv.from_numpy(self.weight).cuda() + if op_type == ConvOpType.kForward: + output_tv = zeros_func(list(self.output.shape), self.dtype, 0) + else: + output_tv = tv.from_numpy(self.output).cuda() + return inp_tv, weight_tv, output_tv + + def get_operands_torch(self, op_type: ConvOpType): + zeros_func = torch.zeros if not self.subm else torch.empty + if op_type == ConvOpType.kBackwardInput: + inp_tv = zeros_func(list(self.inp.shape), dtype=self.dtype_th, device=self.device) + else: + inp_tv = torch.from_numpy(self.inp).cuda() + if op_type == ConvOpType.kBackwardWeight: + weight_tv = zeros_func(list(self.weight.shape), dtype=self.dtype_th, device=self.device) + else: + weight_tv = torch.from_numpy(self.weight).cuda() + if op_type == ConvOpType.kForward: + output_tv = zeros_func(list(self.output.shape), dtype=self.dtype_th, device=self.device) + else: + output_tv = torch.from_numpy(self.output).cuda() + return inp_tv, weight_tv, output_tv + +def _test_impgemm_conv_cuda(subm: bool): + ndim = 3 + dtype_to_tol = { + np.float32: (1e-4, 1e-4), + np.float16: (1e-2, 1e-2), + np.int8: (1e-4, 1e-4), + } + device = 
torch.device("cuda:0") + shapes = [[19, 18, 17]] + batchsizes = [1] + dtypes = [np.float32, np.float16] + test_case = TestCase() + in_channels = [32, 47] + out_channels = [32, 48, 62] + if subm: + ksizes = [3] + strides = [1] + paddings = [0] + dilations = [1] + else: + ksizes = [2, 3] + strides = [1, 2, 3] + paddings = [0, 1] + dilations = [1, 2] + algos = [ + ConvAlgo.MaskSplitImplicitGemm, + ConvAlgo.MaskImplicitGemm, + ] + arch = torch.cuda.get_device_capability() + + for shape, bs, C, K, k, s, p, d, algo, dtype in tqdm.tqdm(params_grid( + shapes, batchsizes, in_channels, out_channels, ksizes, + strides, paddings, dilations, algos, dtypes)): + tester = SparseConvTester(algo, subm, shape, bs, dtype, 1500, K, C, k, s, p, d) + atol, rtol = dtype_to_tol[dtype] + mask_width_to_mask_out_fwd: Dict[int, torch.Tensor] = {} + mask_width_to_mask_out_bwd: Dict[int, torch.Tensor] = {} + + op_types = [ConvOpType.kForward, ConvOpType.kBackwardInput] + spk = 1 + for op_type in op_types: + inp_tv, weight_tv, output_tv = tester.get_operands(op_type) + avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, op_type, -1) + for desp in avail_desps: + if not subm: + if op_type == ConvOpType.kForward: + output_tv.zero_() + else: + inp_tv.zero_() + + # this algo must success + mask_width = desp.tile_shape[0] + # if mask_width != 32: + # continue + if mask_width not in mask_width_to_mask_out_fwd: + mask_width_to_mask_out_fwd[mask_width] = torch.zeros([2, div_up(tester.out_inds.shape[0], mask_width)], + dtype=torch.int32, + device=tester.device) + mask_output_fwd = mask_width_to_mask_out_fwd[mask_width] + + if subm: + if desp.op_type == ConvOpType.kForward.value: + indice_pairs = tester.pair_fwd + elif desp.op_type == ConvOpType.kBackwardInput.value: + indice_pairs = tester.pair_bwd + else: + indice_pairs = tester.pair_fwd + mask_output = mask_output_fwd + # print([bin(x.item()) for x in masks]) + for j in range(tester.num_split): + beta = 1 if j == 1 else 0 + mask_filter = tester.masks[j].item() + + reverse_mask = False + if desp.op_type == ConvOpType.kBackwardWeight.value: + mask_op = mask_output[j] + else: + mask_op = tester.pair_mask_fwd_splits[j] + if desp.op_type == ConvOpType.kBackwardInput.value: + reverse_mask = True + mask_output_run = torch_tensor_to_tv(mask_output[j], dtype=tv.uint32) + if desp.op_type == ConvOpType.kBackwardWeight.value: + mask_output_run = tv.Tensor() + CONV.run_with_tuned_result( + BestConvAlgoByProfile(desp, spk), + desp.op_type, + inp_tv, + weight_tv, + output_tv, + torch_tensor_to_tv(mask_op, dtype=tv.uint32), + torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]), + mask_output_run, + torch_tensor_to_tv(indice_pairs), + reverse_mask, + mask_filter=mask_filter, + mask_width=mask_width, + beta=beta, + verbose=False, + ) + else: + if mask_width not in mask_width_to_mask_out_bwd: + mask_width_to_mask_out_bwd[mask_width] = torch.zeros([2, div_up(tester.indices_np.shape[0], mask_width)], + dtype=torch.int32, + device=tester.device) + mask_output_bwd = mask_width_to_mask_out_bwd[mask_width] + + if desp.op_type == ConvOpType.kForward.value: + indice_pairs = tester.pair_fwd # inp -> out + mask_ops = tester.pair_mask_fwd_splits + mask_argsorts = tester.mask_argsort_fwd_splits + mask_output = mask_output_fwd + elif desp.op_type == ConvOpType.kBackwardInput.value: + indice_pairs = tester.pair_bwd # out -> inp + mask_ops = tester.pair_mask_bwd_splits + mask_argsorts = tester.mask_argsort_bwd_splits + mask_output = mask_output_bwd + else: + indice_pairs 
= tester.pair_fwd # inp -> out + mask_ops = tester.pair_mask_fwd_splits + mask_argsorts = tester.mask_argsort_fwd_splits + mask_output = mask_output_fwd + + for j in range(tester.num_split): + beta = 1 if j == 1 else 0 + mask_filter = tester.masks[j].item() + reverse_mask = False + if desp.op_type == ConvOpType.kBackwardWeight.value: + mask_op = mask_output[j] + else: + mask_op = mask_ops[j] + + CONV.run_with_tuned_result( + BestConvAlgoByProfile(desp, spk), + desp.op_type, + inp_tv, + weight_tv, + output_tv, + torch_tensor_to_tv(mask_op, dtype=tv.uint32), + torch_tensor_to_tv(mask_argsorts[j]), + torch_tensor_to_tv(mask_output[j], dtype=tv.uint32), + torch_tensor_to_tv(indice_pairs), + reverse_mask, + mask_filter=mask_filter, + mask_width=mask_width, + beta=beta, + verbose=False, + ) + out_ref = tester.out_ref + din_ref = tester.din_ref + dw_ref = tester.dw_ref + if op_type == ConvOpType.kForward: + out_my = output_tv.cpu().numpy() + if dtype != np.float16: + test_case.assertAllClose(out_ref, out_my, atol=atol, rtol=rtol) + else: + error_norm = np.linalg.norm(out_ref.reshape(-1) - out_my.reshape(-1)) + assert error_norm < 5 + # print(desp, ) + else: + din_my = inp_tv.cpu().numpy() + if dtype != np.float16: + test_case.assertAllClose(din_ref, din_my, atol=atol, rtol=rtol) + else: + error_norm = np.linalg.norm(din_ref.reshape(-1) - din_my.reshape(-1)) + assert error_norm < 10, f"{desp}, {error_norm}, {k}, {s}, {p}, {d}" + inp_tv, weight_tv, output_tv = tester.get_operands(ConvOpType.kBackwardWeight) + + for spk in [1, 4, 16, 64]: + for mask_width, mask_output in mask_width_to_mask_out_fwd.items(): + avail_desps = CONV.get_all_available(inp_tv, weight_tv, output_tv, NHWC, NHWC, NHWC, arch, ConvOpType.kBackwardWeight, mask_width) + for desp in avail_desps: + weight_tv.zero_() + if subm: + indice_pairs = tester.pair_fwd + for j in range(tester.num_split): + beta = 0 + mask_filter = tester.masks[j].item() + mask_op = mask_output[j] + mask_op_tv = torch_tensor_to_tv(mask_op, dtype=tv.uint32) + # mask_op_np = mask_op_tv.cpu().numpy() + # bit_ref = np.bitwise_or.reduce(mask_op_np, axis=0) + # bit_my = mask_filter + CONV.run_with_tuned_result( + BestConvAlgoByProfile(desp, spk), + desp.op_type, + inp_tv, + weight_tv, + output_tv, + mask_op_tv, + torch_tensor_to_tv(tester.mask_argsort_fwd_splits[j]), + tv.Tensor(), + torch_tensor_to_tv(indice_pairs), + reverse_mask=False, + mask_filter=mask_filter, + mask_width=mask_width, + beta=beta, + verbose=False, + ) + else: + indice_pairs = tester.pair_fwd # inp -> out + mask_ops = tester.pair_mask_fwd_splits + mask_argsorts = tester.mask_argsort_fwd_splits + for j in range(tester.num_split): + # beta = 1 if j == 1 else 0 + beta = 0 + mask_filter = tester.masks[j].item() + reverse_mask = False + mask_op = mask_output[j] + + CONV.run_with_tuned_result( + BestConvAlgoByProfile(desp, spk), + desp.op_type, + inp_tv, + weight_tv, + output_tv, + torch_tensor_to_tv(mask_op, dtype=tv.uint32), + torch_tensor_to_tv(mask_argsorts[j]), + torch_tensor_to_tv(mask_output[j], dtype=tv.uint32), + torch_tensor_to_tv(indice_pairs), + reverse_mask, + mask_filter=mask_filter, + mask_width=mask_width, + beta=beta, + verbose=False, + ) + dw_ref = tester.dw_ref + dw_my = weight_tv.cpu().numpy() + if dtype != np.float16: + # print(desp, spk, K, C, mask_width, algo) + test_case.assertAllClose(dw_ref, dw_my, atol=atol, rtol=rtol) + else: + error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1)) + # print(desp, error_norm) + assert error_norm < 5 + +def 
_test_native_conv_cuda(subm: bool): + ndim = 3 + dtype_to_tol = { + np.float32: (1e-4, 1e-4), + np.float16: (1e-2, 1e-2), + np.int8: (1e-4, 1e-4), + } + device = torch.device("cuda:0") + shapes = [[19, 18, 17]] + batchsizes = [1] + dtypes = [np.float32, np.float16] + test_case = TestCase() + in_channels = [32, 47] + out_channels = [32, 48, 62] + if subm: + ksizes = [3, 5] + strides = [1] + paddings = [0] + dilations = [1] + else: + ksizes = [2, 3] + strides = [1, 2, 3] + paddings = [0, 1] + dilations = [1, 2] + arch = torch.cuda.get_device_capability() + stream = get_current_stream() + for shape, bs, C, K, k, s, p, d, dtype in tqdm.tqdm(params_grid( + shapes, batchsizes, in_channels, out_channels, ksizes, + strides, paddings, dilations, dtypes)): + tester = SparseConvTester(ConvAlgo.Native, subm, shape, bs, dtype, 1500, K, C, k, s, p, d) + atol, rtol = dtype_to_tol[dtype] + + kv_center = tester.kv // 2 + kv = tester.kv + pair_in = torch_tensor_to_tv(tester.pair_native)[0] + pair_out = torch_tensor_to_tv(tester.pair_native)[1] + + op_types = [ConvOpType.kForward, ConvOpType.kBackwardInput, ConvOpType.kBackwardWeight] + indice_pair_num_cpu = tester.indice_num_per_loc_np + spk = 1 + + out_ref = tester.out_ref + din_ref = tester.din_ref + dw_ref = tester.dw_ref.reshape(K, -1, C) + + for op_type in op_types: + inp_th, weight_th, output_th = tester.get_operands_torch(op_type) + weight_th = weight_th.view(K, -1, C) + inp_tv = torch_tensor_to_tv(inp_th) + weight_tv = torch_tensor_to_tv(weight_th) + output_tv = torch_tensor_to_tv(output_th) + + if op_type == ConvOpType.kForward: + a = inp_tv + c = output_tv + b = weight_tv.select(1, tester.kv // 2) + + + avail_desps = GEMM.get_all_available(a, b, c, False, True, False, arch, ShuffleStrideType.ShuffleAC) + for desp in avail_desps: + if subm: + torch.mm(inp_th, weight_th[:, tester.kv // 2].T, out=output_th) + else: + output_tv.zero_() + inited = subm + for i, nhot in enumerate(indice_pair_num_cpu): + if subm and i == kv_center: + continue + if subm and i > kv_center: + nhot = indice_pair_num_cpu[kv - i - 1] + if nhot <= 0: + continue + inp_indices = pair_in[i].slice_first_axis(0, nhot) + out_indices = pair_out[i].slice_first_axis(0, nhot) + b = weight_tv.select(1, i) + # inp @ filter.T, NC @ KC + beta = 1.0 if inited else 0.0 + GEMM.run_with_tuned_result( + BestAlgoByProfile(desp, 1), + a, + b, + c, + False, + True, + False, + arch=arch, + stream=stream, + shuffle_type=ShuffleStrideType.ShuffleAC, + a_inds=inp_indices, + c_inds=out_indices, + hint=AlgoHint.Fowrard.value, + alpha=1.0, + beta=beta) + inited = True + out_my = output_tv.cpu().numpy() + if dtype != np.float16: + # error_norm = np.linalg.norm(out_ref.reshape(-1) - out_my.reshape(-1)) + # assert error_norm < 1 + # print(desp, K, C, k, error_norm) + + test_case.assertAllClose(out_ref, out_my, atol=atol, rtol=rtol) + else: + error_norm = np.linalg.norm(out_ref.reshape(-1) - out_my.reshape(-1)) + assert error_norm < 10 + + elif op_type == ConvOpType.kBackwardInput: + a = output_tv + b = weight_tv.select(1, tester.kv // 2) + c = inp_tv + avail_desps = GEMM.get_all_available(a, b, c, False, False, False, arch, ShuffleStrideType.ShuffleAC) + for desp in avail_desps: + if subm: + torch.mm(output_th, weight_th[:, tester.kv // 2], out=inp_th) + else: + inp_tv.zero_() + inited = subm + for i, nhot in enumerate(indice_pair_num_cpu): + if subm and i == kv_center: + continue + if subm and i > kv_center: + nhot = indice_pair_num_cpu[kv - i - 1] + if nhot <= 0: + continue + inp_indices = 
pair_in[i].slice_first_axis(0, nhot) + out_indices = pair_out[i].slice_first_axis(0, nhot) + b = weight_tv.select(1, i) + # inp @ filter.T, NC @ KC + beta = 1.0 if inited else 0.0 + GEMM.run_with_tuned_result( + BestAlgoByProfile(desp, 1), + a, + b, + c, + False, + False, + False, + arch=arch, + stream=stream, + shuffle_type=ShuffleStrideType.ShuffleAC, + a_inds=out_indices, + c_inds=inp_indices, + hint=AlgoHint.Fowrard.value, + alpha=1.0, + beta=beta) + inited = True + din_my = inp_tv.cpu().numpy() + if dtype != np.float16: + # error_norm = np.linalg.norm(din_ref.reshape(-1) - din_my.reshape(-1)) + # print(desp, K, C, k, error_norm) + test_case.assertAllClose(din_ref, din_my, atol=atol, rtol=rtol) + # assert error_norm < 1 + + else: + error_norm = np.linalg.norm(din_ref.reshape(-1) - din_my.reshape(-1)) + assert error_norm < 10 + + else: + a = output_tv + b = inp_tv + c = weight_tv.select(1, tester.kv // 2) + avail_desps = GEMM.get_all_available(a, b, c, True, False, False, arch, ShuffleStrideType.ShuffleAB) + for desp in avail_desps: + inited = subm + weight_tv.zero_() + if subm: + torch.mm(output_th.T, inp_th, out=weight_th[:, kv_center]) + + for i, nhot in enumerate(indice_pair_num_cpu): + if subm and i == kv_center: + continue + if subm and i > kv_center: + nhot = indice_pair_num_cpu[kv - i - 1] + if nhot <= 0: + continue + beta = 1.0 if inited else 0.0 + inp_indices = pair_in[i].slice_first_axis(0, nhot) + out_indices = pair_out[i].slice_first_axis(0, nhot) + a_inds = out_indices + b_inds = inp_indices + + GEMM.run_with_tuned_result(BestAlgoByProfile(desp, 32), + a, + b, + weight_tv.select(1, i), + True, + False, + False, + arch=arch, + stream=stream, + shuffle_type=ShuffleStrideType.ShuffleAB, + a_inds=a_inds, + b_inds=b_inds, + hint=AlgoHint.BackwardWeight.value, + alpha=1.0, + beta=beta) + dw_my = weight_tv.cpu().numpy() + if dtype != np.float16: + error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1)) + assert error_norm < 1 + + # test_case.assertAllClose(dw_ref, dw_my, atol=atol, rtol=rtol) + # print(desp, error_norm) + + else: + error_norm = np.linalg.norm(dw_ref.reshape(-1) - dw_my.reshape(-1)) + # print(desp, error_norm) + assert error_norm < 10 + + +def test_all_algo_unit(): + _test_impgemm_conv_cuda(True) + _test_impgemm_conv_cuda(False) + _test_native_conv_cuda(True) + _test_native_conv_cuda(False) + + +if __name__ == "__main__": + test_all_algo_unit() \ No newline at end of file diff --git a/test/test_conv.py b/test/test_conv.py index d344fde..2b54dba 100644 --- a/test/test_conv.py +++ b/test/test_conv.py @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Compare results between sparse and dense layers: +SparseConvXd +SparseConvTransposeXd +SparseMaxPoolXd +""" + import time import unittest from pathlib import Path @@ -24,13 +30,11 @@ import spconv.pytorch as spconv from spconv.test_utils import TestCase, generate_sparse_data, params_grid from spconv.constants import ALL_WEIGHT_IS_KRSC, FILTER_HWIO -# import sparseconvnet as scn # we must disable tf32 to increase reference precision. 
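+# (TF32 rounds fp32 matmul inputs to a 10-bit mantissa, so on Ampere and newer
+# GPUs a TF32 result can exceed the 1e-4 atol used against the dense reference)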
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False - class SparseConv3dTestTorch(nn.Module): def __init__(self, num_layers, @@ -76,52 +80,6 @@ def forward(self, features, coors, batch_size): self.grid) return self.net(x) # .dense() - -class SubMConv3dTestTorch(nn.Module): - def __init__(self, - num_layers, - ndim, - shape, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - algo=spconv.ConvAlgo.Native): - super().__init__() - layers = [ - spconv.SubMConv3d(in_channels, - out_channels, - kernel_size, - stride, - padding=padding, - dilation=dilation, - bias=False, - algo=algo) - ] - for i in range(1, num_layers): - layers.append( - spconv.SubMConv3d(out_channels, - out_channels, - kernel_size, - stride, - padding=padding, - dilation=dilation, - bias=False, - algo=algo)) - self.net = spconv.SparseSequential(*layers, ) - # self.grid = torch.full([3, *shape], -1, dtype=torch.int32).cuda() - self.grid = None - self.shape = shape - - def forward(self, features, coors, batch_size): - coors = coors.int() # .cpu() - x = spconv.SparseConvTensor(features, coors, self.shape, batch_size, - self.grid) - return self.net(x) # .dense() - - class Conv3dTestTorch(nn.Module): def __init__(self, num_layers, ndim, shape, in_channels, out_channels, kernel_size, stride, padding, dilation): @@ -150,11 +108,11 @@ def __init__(self, num_layers, ndim, shape, in_channels, out_channels, def forward(self, x): return self.net(x) # .dense() - class SparseDeConv3dTestTorch(nn.Module): def __init__(self, num_layers, ndim, shape, in_channels, out_channels, - kernel_size, stride, padding, dilation): + kernel_size, stride, padding, dilation, algo): super().__init__() + self.algo = algo layers = [ spconv.SparseConvTranspose3d(in_channels, out_channels, @@ -162,7 +120,8 @@ def __init__(self, num_layers, ndim, shape, in_channels, out_channels, stride, padding=padding, dilation=dilation, - bias=False) + bias=False, + algo=algo) ] for i in range(1, num_layers): layers.append( @@ -172,7 +131,8 @@ def __init__(self, num_layers, ndim, shape, in_channels, out_channels, stride, padding=padding, dilation=dilation, - bias=False)) + bias=False, + algo=algo)) self.net = spconv.SparseSequential(*layers, ) self.shape = shape @@ -213,14 +173,15 @@ def forward(self, x): class SparseMaxPoolTestTorch(nn.Module): def __init__(self, num_layers, ndim, shape, kernel_size, stride, padding, - dilation): + dilation, algo): super().__init__() + self.algo = algo layers = [ - spconv.SparseMaxPool3d(kernel_size, stride, padding, dilation) + spconv.SparseMaxPool3d(kernel_size, stride, padding, dilation, algo=algo) ] for i in range(1, num_layers): layers.append( - spconv.SparseMaxPool3d(kernel_size, stride, padding, dilation)) + spconv.SparseMaxPool3d(kernel_size, stride, padding, dilation, algo=algo)) self.net = spconv.SparseSequential(*layers, ) self.shape = shape @@ -243,86 +204,6 @@ def __init__(self, num_layers, ndim, shape, kernel_size, stride, padding, def forward(self, x): return self.net(x) # .dense() - -class SubmanifoldConvTestTorch(nn.Module): - def __init__(self, num_layers, ndim, shape, in_channels, out_channels, - kernel_size, stride): - super().__init__() - layers = [ - spconv.SubMConv3d(in_channels, - out_channels, - kernel_size, - bias=False, - indice_key="subm0") - ] - for i in range(1, num_layers): - layers.append( - spconv.SubMConv3d(out_channels, - out_channels, - kernel_size, - bias=False)) - self.net = nn.Sequential(*layers, ) - self.shape = shape - - def forward(self, 
features, coors, batch_size): - coors = coors.int() - x = spconv.SparseConvTensor(features, coors, self.shape, batch_size) - return self.net(x) - - -class SCNCoupleDeConvTest(nn.Module): - def __init__(self, num_layers, ndim, shape, in_channels, out_channels, - kernel_size, stride): - super().__init__() - self.scn_input = scn.InputLayer(ndim, shape, mode=0) - self.net = nn.Sequential( - scn.Convolution(ndim, - in_channels, - out_channels, - kernel_size, - stride, - bias=False), - scn.Deconvolution(ndim, - out_channels, - in_channels, - kernel_size, - stride, - bias=False), - scn.SparseToDense(ndim, in_channels), - ) - - def forward(self, features, coors, batch_size): - coors = coors.long().cpu() - x = self.scn_input((coors, features)) - return self.net(x) - - -class SparseCoupleDeConvTest(nn.Module): - def __init__(self, num_layers, ndim, shape, in_channels, out_channels, - kernel_size, stride): - super().__init__() - self.net = spconv.SparseSequential( - spconv.SparseConv3d(in_channels, - out_channels, - kernel_size, - stride, - indice_key="cp0", - bias=False), - spconv.SparseInverseConv3d(out_channels, - in_channels, - kernel_size, - indice_key="cp0", - bias=False), - ) - self.todense = spconv.ToDense() - self.shape = shape - - def forward(self, features, coors, batch_size): - coors = coors.int() - x = spconv.SparseConvTensor(features, coors, self.shape, batch_size) - return self.todense(self.net(x)) # .dense() - - def gather_nd(params, indices): # this function has a limit that MAX_ADVINDEX_CALC_DIMS=5 ndim = indices.shape[-1] @@ -349,367 +230,147 @@ def scatter_nd(indices, updates, shape): ret[slices] = updates.view(*output_shape) return ret +def test_spconv3d(): + test_case = TestCase() + np.random.seed(484) + torch.manual_seed(48848) + devices = ["cuda:0"] + shapes = [[19, 18, 17]] + batchsizes = [1, 2] -class TestSpConv(TestCase): - def testSpConv3d(self): - np.random.seed(484) - torch.manual_seed(48848) - devices = ["cuda:0"] - shapes = [[19, 18, 17]] - batchsizes = [1, 2] - - in_channels = [32] - out_channels = [32, 48, 64] - ksizes = [2, 3] - strides = [1, 2, 3] - paddings = [0, 1, 2] - dilations = [1, 2, 3] - algos = [ - ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, - ConvAlgo.MaskSplitImplicitGemm - ] - # algos = [ConvAlgo.Native] - - for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( - devices, shapes, batchsizes, in_channels, out_channels, ksizes, - strides, paddings, dilations, algos): - if all([s > 1, d > 1]): - continue # don't support this. 
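+        # (strided + dilated sparse conv is unsupported by the index
+        # generation, so e.g. k=3, s=2, d=2 is skipped while s=2, d=1
+        # and s=1, d=2 are both covered by the grid)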
- # print(dev, shape, bs, IC, OC, k, s, p, d) - device = torch.device(dev) - num_points = [1000] * bs - dtype = torch.float32 - net = SparseConv3dTestTorch(1, - 3, - shape, - IC, - OC, - k, - s, - p, - d, - algo=al).to(device).to(dtype) - net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, - d).to(device).to(dtype) - - sparse_dict = generate_sparse_data(shape, num_points, IC) - - features = np.ascontiguousarray(sparse_dict["features"]).astype( - np.float32) - indices = np.ascontiguousarray( - sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) - features_dense = sparse_dict["features_dense"].astype(np.float32) - indices_t = torch.from_numpy(indices).int().to(device) - features_t = torch.from_numpy(features).to(device).to(dtype) - features_t.requires_grad = True - features_dense_t = torch.from_numpy(features_dense).to(device).to( - dtype) - features_dense_t.requires_grad = True - if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC: - if FILTER_HWIO: - filters = np.random.uniform(-1, 1, - size=[k, k, k, IC, - OC]).astype(np.float32) - else: - filters = np.random.uniform(-1, 1, - size=[k, k, k, OC, - IC]).astype(np.float32) - filters_t = torch.from_numpy(filters).to(device).to(dtype) - if FILTER_HWIO: - net_ref.net[0].weight.data[:] = filters_t.permute( - 4, 3, 0, 1, 2).contiguous() - else: - net_ref.net[0].weight.data[:] = filters_t.permute( - 3, 4, 0, 1, 2).contiguous() - else: - filters = np.random.uniform(-1, 1, - size=[OC, k, k, k, - IC]).astype(np.float32) - filters_t = torch.from_numpy(filters).to(device).to(dtype) - net_ref.net[0].weight.data[:] = filters_t.permute( - 0, 4, 1, 2, 3).contiguous() - - net.net[0].weight.data[:] = filters_t - out_ref = net_ref(features_dense_t) - out = net(features_t, indices_t, bs).dense() - out_np = out.detach().cpu().numpy() - out_ref_np = out_ref.detach().cpu().numpy() - self.assertAllClose(out_np, out_ref_np, atol=1e-4) - - dout = np.random.uniform(-0.2, 0.2, - out_ref.shape).astype(features.dtype) - dout_t = torch.from_numpy(dout).to(device) - out.backward(dout_t) - out_ref.backward(dout_t) - din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, - 1).contiguous() - din_sparse = gather_nd(din_dense, indices_t.long()) - din = features_t.grad.detach() - - din_np = din.cpu().numpy() - din_sparse_np = din_sparse.cpu().numpy() - for layer, layer_ref in zip(net.net, net_ref.net): - dw = layer.weight.grad.detach().cpu().numpy() - dw_ref = layer_ref.weight.grad.detach().cpu().numpy() - if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC: - if FILTER_HWIO: - dw = dw.transpose(4, 3, 0, 1, 2) - else: - dw = dw.transpose(3, 4, 0, 1, 2) - else: - # OHWI -> OIHW - dw = dw.transpose(0, 4, 1, 2, 3) - - self.assertAllClose(dw, dw_ref, atol=1e-4) - self.assertAllClose(din_np, din_sparse_np, atol=1e-4) - - def testSpDeConv3d(self): - np.random.seed(484) - devices = ["cuda:0"] - shapes = [[19, 18, 17]] - batchsizes = [1, 2] - - in_channels = [64] - out_channels = [32, 48, 64] - ksizes = [2, 3] - strides = [2, 3] - paddings = [0, 1, 2] - dilations = [1, 2, 3] - ksizes = [3] - - strides = [1] - paddings = [0] - dilations = [1] - - for dev, shape, bs, IC, OC, k, s, p, d in params_grid( - devices, shapes, batchsizes, in_channels, out_channels, ksizes, - strides, paddings, dilations): - if all([s > 1, d > 1]): - continue # don't support this. 
- device = torch.device(dev) - num_points = [1000] * bs - - sparse_dict = generate_sparse_data(shape, num_points, IC) - - features = np.ascontiguousarray(sparse_dict["features"]).astype( - np.float32) - indices = np.ascontiguousarray( - sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) - features_dense = sparse_dict["features_dense"].astype(np.float32) + in_channels = [32] + out_channels = [32, 48, 64] + ksizes = [2, 3] + strides = [1, 2, 3] + paddings = [0, 1, 2] + dilations = [1, 2, 3] + algos = [ + ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, + ConvAlgo.MaskSplitImplicitGemm + ] + # algos = [ConvAlgo.Native] + + for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( + devices, shapes, batchsizes, in_channels, out_channels, ksizes, + strides, paddings, dilations, algos): + if all([s > 1, d > 1]): + continue # don't support this. + # print(dev, shape, bs, IC, OC, k, s, p, d) + device = torch.device(dev) + num_points = [1500] * bs + dtype = torch.float32 + net = SparseConv3dTestTorch(1, + 3, + shape, + IC, + OC, + k, + s, + p, + d, + algo=al).to(device).to(dtype) + net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, + d).to(device).to(dtype) + + sparse_dict = generate_sparse_data(shape, num_points, IC) + + features = np.ascontiguousarray(sparse_dict["features"]).astype( + np.float32) + indices = np.ascontiguousarray( + sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) + features_dense = sparse_dict["features_dense"].astype(np.float32) + indices_t = torch.from_numpy(indices).int().to(device) + features_t = torch.from_numpy(features).to(device).to(dtype) + features_t.requires_grad = True + features_dense_t = torch.from_numpy(features_dense).to(device).to( + dtype) + features_dense_t.requires_grad = True + if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC: if FILTER_HWIO: - filters = np.random.uniform(0, 1, size=[k, k, k, IC, - OC]).astype(np.float32) + filters = np.random.uniform(-1, 1, + size=[k, k, k, IC, + OC]).astype(np.float32) else: - filters = np.random.uniform(0, 1, size=[k, k, k, OC, - IC]).astype(np.float32) - - indices_t = torch.from_numpy(indices).int().to(device) - features_t = torch.from_numpy(features).to(device) - features_t.requires_grad = True - features_dense_t = torch.from_numpy(features_dense).to(device) - features_dense_t.requires_grad = True - net = SparseDeConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, - d).to(device) - net_ref = DeConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, - d).to(device) - filters_t = torch.from_numpy(filters).to(device) - print(net_ref.net[0].weight.shape) + filters = np.random.uniform(-1, 1, + size=[k, k, k, OC, + IC]).astype(np.float32) + filters_t = torch.from_numpy(filters).to(device).to(dtype) if FILTER_HWIO: net_ref.net[0].weight.data[:] = filters_t.permute( - 3, 4, 0, 1, 2).contiguous() + 4, 3, 0, 1, 2).contiguous() else: net_ref.net[0].weight.data[:] = filters_t.permute( - 4, 3, 0, 1, 2).contiguous() - net.net[0].weight.data[:] = filters_t - out_ref = net_ref(features_dense_t) - out = net(features_t, indices_t, bs).dense() - out_np = out.detach().cpu().numpy() - out_ref_np = out_ref.detach().cpu().numpy() - self.assertAllClose(out_np, out_ref_np, atol=1e-4) - - dout = np.random.uniform(-0.2, 0.2, - out_ref.shape).astype(features.dtype) - dout_t = torch.from_numpy(dout).to(device) - out.backward(dout_t) - out_ref.backward(dout_t) - din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, - 1).contiguous() - din_sparse = gather_nd(din_dense, indices_t.long()) - din = features_t.grad.detach() - din_np = 
din.cpu().numpy() - din_sparse_np = din_sparse.cpu().numpy() - self.assertAllClose(din_np, din_sparse_np, atol=1e-4) - for layer, layer_ref in zip(net.net, net_ref.net): - dw = layer.weight.grad.detach().cpu().numpy() - dw_ref = layer_ref.weight.grad.detach().cpu().numpy() + 3, 4, 0, 1, 2).contiguous() + else: + filters = np.random.uniform(-1, 1, + size=[OC, k, k, k, + IC]).astype(np.float32) + filters_t = torch.from_numpy(filters).to(device).to(dtype) + net_ref.net[0].weight.data[:] = filters_t.permute( + 0, 4, 1, 2, 3).contiguous() + + net.net[0].weight.data[:] = filters_t + out_ref = net_ref(features_dense_t) + out = net(features_t, indices_t, bs).dense() + out_np = out.detach().cpu().numpy() + out_ref_np = out_ref.detach().cpu().numpy() + test_case.assertAllClose(out_np, out_ref_np, atol=1e-4) + + dout = np.random.uniform(-0.2, 0.2, + out_ref.shape).astype(features.dtype) + dout_t = torch.from_numpy(dout).to(device) + out.backward(dout_t) + out_ref.backward(dout_t) + din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, + 1).contiguous() + din_sparse = gather_nd(din_dense, indices_t.long()) + din = features_t.grad.detach() + + din_np = din.cpu().numpy() + din_sparse_np = din_sparse.cpu().numpy() + for layer, layer_ref in zip(net.net, net_ref.net): + dw = layer.weight.grad.detach().cpu().numpy() + dw_ref = layer_ref.weight.grad.detach().cpu().numpy() + if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC: if FILTER_HWIO: - dw = dw.transpose(3, 4, 0, 1, 2) - else: dw = dw.transpose(4, 3, 0, 1, 2) - self.assertAllClose(dw, dw_ref, atol=1e-4) - - def testSpCpConv3d(self): - np.random.seed(484) - devices = ["cuda:0", "cpu:0"] - shapes = [[20, 20, 20]] - batchsizes = [1, 2] - - in_channels = [64] - out_channels = [32, 48, 64] - ksizes = [2] - strides = [2] - paddings = [0, 1, 2] - dilations = [1, 2, 3] - - for dev, shape, bs, IC, OC, k, s in params_grid( - devices, shapes, batchsizes, in_channels, out_channels, ksizes, - strides): - device = torch.device(dev) - num_points = [1000] * bs - - sparse_dict = generate_sparse_data(shape, num_points, IC) - - features = np.ascontiguousarray(sparse_dict["features"]).astype( - np.float32) - indices = np.ascontiguousarray( - sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) - features_dense = sparse_dict["features_dense"].astype(np.float32) - filters = np.random.uniform(0, 1, size=[k, k, k, IC, - OC]).astype(np.float32) - indices_t = torch.from_numpy(indices).int().to(device) - indices_scn_t = torch.from_numpy( - indices[:, [1, 2, 3, 0]]).int().to(device) - features_t = torch.from_numpy(features).to(device) - features_t.requires_grad = True - features_ref_t = torch.from_numpy(features).to(device) - features_ref_t.requires_grad = True - - net_ref = SCNCoupleDeConvTest(1, 3, shape, IC, OC, k, s).to(device) - net = SparseCoupleDeConvTest(1, 3, shape, IC, OC, k, s).to(device) - net_ref.net[0].weight.data[:] = net.net[0].weight.data[:].view( - *net_ref.net[0].weight.shape) - net_ref.net[1].weight.data[:] = net.net[1].weight.data[:].view( - *net_ref.net[1].weight.shape) - out_ref = net_ref(features_ref_t, indices_scn_t, bs) - out = net(features_t, indices_t, bs) - dout = np.random.uniform(-0.2, 0.2, - out_ref.shape).astype(features.dtype) - dout_t = torch.from_numpy(dout).to(device) - out.backward(dout_t) - out_ref.backward(dout_t) - din = features_t.grad.detach() - din_ref = features_ref_t.grad.detach() - din_np = din.cpu().numpy() - din_ref_np = din_ref.cpu().numpy() - self.assertAllClose(din_ref_np, din_np, atol=1e-4) - for layer, 
layer_ref in zip(net.net, net_ref.net): - dw = layer.weight.grad.detach().cpu().numpy() - dw_ref = layer_ref.weight.grad.detach().cpu().view( - *dw.shape).numpy() - self.assertAllClose(dw, dw_ref, atol=1e-4) - - out_np = out.detach().cpu().numpy() - out_ref_np = out_ref.detach().cpu().numpy() - self.assertAllClose(out_np, out_ref_np, atol=1e-4) - - def testSpMaxPool3d(self): - np.random.seed(485) - devices = ["cuda:0"] - shapes = [[19, 18, 17]] - batchsizes = [1, 2] - - in_channels = [64] - out_channels = [64] - ksizes = [2, 3] - strides = [1, 2, 3] - paddings = [0, 1] - dilations = [1, 2, 3] - # ksizes = [2] - # strides = [2] - # paddings = [0] - # dilations = [1] - - for dev, shape, bs, IC, OC, k, s, p, d in params_grid( - devices, shapes, batchsizes, in_channels, out_channels, ksizes, - strides, paddings, dilations): - if all([s > 1, d > 1]): - continue # don't support this. - device = torch.device(dev) - num_points = [1000] * bs - - # when data contains negative, sparse maxpool is not equal to dense maxpool. - sparse_dict = generate_sparse_data(shape, - num_points, - IC, - data_range=[0.1, 1]) - - features = np.ascontiguousarray(sparse_dict["features"]).astype( - np.float32) - indices = np.ascontiguousarray( - sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) - features_dense = sparse_dict["features_dense"].astype(np.float32) - filters = np.random.uniform(0, 1, size=[k, k, k, OC, - IC]).astype(np.float32) - indices_t = torch.from_numpy(indices).int().to(device) - features_t = torch.from_numpy(features).to(device) - features_t.requires_grad = True - features_dense_t = torch.from_numpy(features_dense).to(device) - features_dense_t.requires_grad = True - net = SparseMaxPoolTestTorch(1, 3, shape, k, s, p, d).to(device) - net_ref = MaxPool3dTestTorch(1, 3, shape, k, s, p, d).to(device) - - out_ref = net_ref(features_dense_t) - out = net(features_t, indices_t, bs) - - outids = out.indices - outfeatures = out.features - outids_dev = outids.float() - out_dense = out.dense(channels_first=False) - out = out_dense.permute(0, 4, 1, 2, 3).contiguous() - out_np = out.detach().cpu().numpy() - out_ref_np = out_ref.detach().cpu().numpy() - self.assertAllClose(out_np, out_ref_np, atol=1e-4) - - dout_sparse = np.random.uniform( - -0.2, 0.2, outfeatures.shape).astype(features.dtype) - dout_sparse_t = torch.from_numpy(dout_sparse).to(device) - dout_t = scatter_nd(outids.long(), dout_sparse_t, - list(out_dense.shape)) - dout_t = dout_t.permute(0, 4, 1, 2, 3).contiguous() - out.backward(dout_t) - out_ref.backward(dout_t) - din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, - 1).contiguous() - din_sparse = gather_nd(din_dense, indices_t.long()) - din = features_t.grad.detach() - - din_np = din.cpu().numpy() - din_sparse_np = din_sparse.cpu().numpy() - self.assertAllClose(din_np, din_sparse_np, atol=1e-4) - - -def main(algo=spconv.ConvAlgo.Native, dtype=torch.float32): - # function for develop. 
- np.random.seed(484) - # devices = ["cuda:0"] - devices = ["cuda:0"] - shapes = [[400, 400, 15]] - batchsizes = [2] + else: + dw = dw.transpose(3, 4, 0, 1, 2) + else: + # OHWI -> OIHW + dw = dw.transpose(0, 4, 1, 2, 3) + + test_case.assertAllClose(dw, dw_ref, atol=1e-4) + test_case.assertAllClose(din_np, din_sparse_np, atol=1e-4) - in_channels = [19] - out_channels = [17] - ksizes = [(3, 3, 3)] - strides = [1] - paddings = [0] - dilations = [1] +def test_spdeconv3d(): + test_case = TestCase() - for dev, shape, bs, IC, OC, k, s, p, d in params_grid( + np.random.seed(484) + devices = ["cuda:0"] + shapes = [[19, 18, 17]] + batchsizes = [1, 2] + + in_channels = [64] + out_channels = [32, 48, 64] + ksizes = [2, 3] + strides = [2, 3] + paddings = [0, 1, 2] + dilations = [1, 2, 3] + + algos = [ + ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, + ConvAlgo.MaskSplitImplicitGemm + ] + + for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( devices, shapes, batchsizes, in_channels, out_channels, ksizes, - strides, paddings, dilations): + strides, paddings, dilations, algos): if all([s > 1, d > 1]): - continue + continue # don't support this. device = torch.device(dev) - num_points = [30000] * bs + num_points = [1000] * bs + dtype = torch.float32 sparse_dict = generate_sparse_data(shape, num_points, IC) @@ -718,115 +379,154 @@ def main(algo=spconv.ConvAlgo.Native, dtype=torch.float32): indices = np.ascontiguousarray( sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) features_dense = sparse_dict["features_dense"].astype(np.float32) - indices_t = torch.from_numpy(indices) - filters = np.random.uniform(0, 1, size=[k[0], 1, 1, IC, - OC]).astype(np.float32) - indices_t = torch.from_numpy(indices).int().to(device).to(dtype) - features_t = torch.from_numpy(features).to(device).to(dtype) + net = SparseDeConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, + d, al).to(device) + net_ref = DeConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, + d).to(device) - features_dense_t = torch.from_numpy(features_dense).to(device).to( - dtype) - net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d, - algo=algo).to(device).to(dtype) - net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, - d).to(device).to(dtype) - filters_t = torch.from_numpy(filters).to(device).to(dtype) - net_ref.net[0].weight[:] = filters_t.permute(4, 3, 0, 1, - 2).contiguous() - net.net[0].weight[:] = filters_t + if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC: + if FILTER_HWIO: + filters = np.random.uniform(-1, 1, + size=[k, k, k, IC, + OC]).astype(np.float32) + else: + filters = np.random.uniform(-1, 1, + size=[k, k, k, OC, + IC]).astype(np.float32) + filters_t = torch.from_numpy(filters).to(device).to(dtype) + if FILTER_HWIO: + net_ref.net[0].weight.data[:] = filters_t.permute( + 3, 4, 0, 1, 2).contiguous() + else: + net_ref.net[0].weight.data[:] = filters_t.permute( + 4, 3, 0, 1, 2).contiguous() + else: + filters = np.random.uniform(-1, 1, + size=[OC, k, k, k, + IC]).astype(np.float32) + filters_t = torch.from_numpy(filters).to(device).to(dtype) + net_ref.net[0].weight.data[:] = filters_t.permute( + 4, 0, 1, 2, 3).contiguous() + net.net[0].weight.data[:] = filters_t + + indices_t = torch.from_numpy(indices).int().to(device) + features_t = torch.from_numpy(features).to(device) + features_t.requires_grad = True + features_dense_t = torch.from_numpy(features_dense).to(device) + features_dense_t.requires_grad = True + filters_t = torch.from_numpy(filters).to(device) out_ref = net_ref(features_dense_t) - times = [] - for i in 
range(10): - t = time.time() - out = net(features_t, indices_t, bs) - torch.cuda.synchronize() - times.append(time.time() - t) - # print((net.grid == -1).float().sum(), net.grid.numel()) - # print("spconv time", time.time() - t) - print("spconv time", np.mean(times[2:])) - out = net(features_t, indices_t, bs) - # print(out.indices) - out = out.dense() - out_numpy = out.detach().cpu().numpy() - - print( - np.linalg.norm(out.detach().cpu().numpy() - - out_ref.detach().cpu().numpy())) - print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), - out_numpy.sum()) + out = net(features_t, indices_t, bs).dense() + out_np = out.detach().cpu().numpy() + out_ref_np = out_ref.detach().cpu().numpy() + test_case.assertAllClose(out_np, out_ref_np, atol=1e-4) + + dout = np.random.uniform(-0.2, 0.2, + out_ref.shape).astype(features.dtype) + dout_t = torch.from_numpy(dout).to(device) + out.backward(dout_t) + out_ref.backward(dout_t) + din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, + 1).contiguous() + din_sparse = gather_nd(din_dense, indices_t.long()) + din = features_t.grad.detach() + din_np = din.cpu().numpy() + din_sparse_np = din_sparse.cpu().numpy() + test_case.assertAllClose(din_np, din_sparse_np, atol=1e-4) + for layer, layer_ref in zip(net.net, net_ref.net): + dw = layer.weight.grad.detach().cpu().numpy() + dw_ref = layer_ref.weight.grad.detach().cpu().numpy() + if net.algo == ConvAlgo.Native and not ALL_WEIGHT_IS_KRSC: + if FILTER_HWIO: + dw = dw.transpose(3, 4, 0, 1, 2) + else: + dw = dw.transpose(4, 3, 0, 1, 2) + else: + # OHWI -> OIHW + dw = dw.transpose(4, 0, 1, 2, 3) + test_case.assertAllClose(dw, dw_ref, atol=1e-4) +def test_spmaxpool3d(): + test_case = TestCase() -def main_subm(algo, dtype=torch.float32): - # function for develop. - np.random.seed(484) - torch.manual_seed(50051) - # devices = ["cuda:0"] + np.random.seed(485) devices = ["cuda:0"] - shapes = [[400, 400, 15]] - batchsizes = [2] + shapes = [[19, 18, 17]] + batchsizes = [1, 2] - in_channels = [32] + in_channels = [64] out_channels = [64] - ksizes = [(3, 3, 3)] - strides = [1] - paddings = [1] - dilations = [1] - for dev, shape, bs, IC, OC, k, s, p, d in params_grid( + ksizes = [2, 3] + strides = [1, 2, 3] + paddings = [0, 1] + dilations = [1, 2, 3] + # ksizes = [2] + # strides = [2] + # paddings = [0] + # dilations = [1] + algos = [ + ConvAlgo.Native, ConvAlgo.MaskImplicitGemm, + ConvAlgo.MaskSplitImplicitGemm + ] + + + for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid( devices, shapes, batchsizes, in_channels, out_channels, ksizes, - strides, paddings, dilations): + strides, paddings, dilations, algos): if all([s > 1, d > 1]): - continue + continue # don't support this. device = torch.device(dev) - num_points = [120000] * bs + num_points = [1000] * bs - sparse_dict = generate_sparse_data(shape, num_points, IC) + # when data contains negative, sparse maxpool is not equal to dense maxpool. 
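+        # (example: a dense window covering {x, empty} yields max(x, 0) since
+        # empty sites are zero-filled, while sparse maxpool reduces only over
+        # stored points and yields x; any negative x makes the two differ,
+        # hence data_range starts at 0.1)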
+ sparse_dict = generate_sparse_data(shape, + num_points, + IC, + data_range=[0.1, 1]) features = np.ascontiguousarray(sparse_dict["features"]).astype( np.float32) indices = np.ascontiguousarray( sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32) features_dense = sparse_dict["features_dense"].astype(np.float32) - indices_t = torch.from_numpy(indices) - filters = np.random.uniform(0, 1, size=[k[0], 1, 1, IC, - OC]).astype(np.float32) - indices_t = torch.from_numpy(indices).int().to(device).to(dtype) - features_t = torch.from_numpy(features).to(device).to(dtype) + indices_t = torch.from_numpy(indices).int().to(device) + features_t = torch.from_numpy(features).to(device) + features_t.requires_grad = True + features_dense_t = torch.from_numpy(features_dense).to(device) + features_dense_t.requires_grad = True + net = SparseMaxPoolTestTorch(1, 3, shape, k, s, p, d, al).to(device) + net_ref = MaxPool3dTestTorch(1, 3, shape, k, s, p, d).to(device) - features_dense_t = torch.from_numpy(features_dense).to(device).to( - dtype) - net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d, - algo=algo).to(device).to(dtype) - net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, - d).to(device).to(dtype) - filters_t = torch.from_numpy(filters).to(device).to(dtype) - net_ref.net[0].weight[:] = filters_t.permute(4, 3, 0, 1, - 2).contiguous() - net.net[0].weight[:] = filters_t out_ref = net_ref(features_dense_t) - times = [] - for i in range(20): - t = time.time() - out = net(features_t, indices_t, bs) - torch.cuda.synchronize() - times.append(time.time() - t) - # print((net.grid == -1).float().sum(), net.grid.numel()) - # print("spconv time", time.time() - t) - print("spconv time", np.mean(times[10:])) out = net(features_t, indices_t, bs) - # print(out.indices) - out = out.dense() - out_numpy = out.detach().cpu().numpy() - # print( - # np.linalg.norm(out.detach().cpu().numpy() - - # out_ref.detach().cpu().numpy())) - print(out_numpy.min(), out_numpy.max(), out_numpy.mean(), - out_numpy.sum()) - return out_numpy - - -if __name__ == '__main__': - # main_subm(algo=spconv.ConvAlgo.SparseConvNet, dtype=torch.float32) - # main(algo=spconv.ConvAlgo.SparseConvNet, dtype=torch.float32) - # TestCase().assertAllClose(out_my, out_ref) - # unittest.main() - TestSpConv().testSpConv3d() + + outids = out.indices + outfeatures = out.features + outids_dev = outids.float() + out_dense = out.dense(channels_first=False) + out = out_dense.permute(0, 4, 1, 2, 3).contiguous() + out_np = out.detach().cpu().numpy() + out_ref_np = out_ref.detach().cpu().numpy() + test_case.assertAllClose(out_np, out_ref_np, atol=1e-4) + + dout_sparse = np.random.uniform( + -0.2, 0.2, outfeatures.shape).astype(features.dtype) + dout_sparse_t = torch.from_numpy(dout_sparse).to(device) + dout_t = scatter_nd(outids.long(), dout_sparse_t, + list(out_dense.shape)) + dout_t = dout_t.permute(0, 4, 1, 2, 3).contiguous() + out.backward(dout_t) + out_ref.backward(dout_t) + din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, + 1).contiguous() + din_sparse = gather_nd(din_dense, indices_t.long()) + din = features_t.grad.detach() + + din_np = din.cpu().numpy() + din_sparse_np = din_sparse.cpu().numpy() + test_case.assertAllClose(din_np, din_sparse_np, atol=1e-4) + + + +if __name__ == "__main__": + test_spmaxpool3d() \ No newline at end of file diff --git a/test/test_implgemm.py b/test/test_implgemm.py deleted file mode 100644 index 42024cd..0000000 --- a/test/test_implgemm.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2021 Yan Yan -# -# 
diff --git a/test/test_implgemm.py b/test/test_implgemm.py
deleted file mode 100644
index 42024cd..0000000
--- a/test/test_implgemm.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright 2021 Yan Yan
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from spconv.core_cc.csrc.sparse.all import SpconvOps
diff --git a/test/test_multi_impl.py b/test/test_multi_impl.py
index 37cf98b..20536d9 100644
--- a/test/test_multi_impl.py
+++ b/test/test_multi_impl.py
@@ -12,9 +12,330 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Compare results between different algo:
-CPU: gather-mm-scatter
+"""Compare results between different algos:
+CPU: simple gather-mm-scatter
 Native: Fused gather-mm-scatter
-ImplicitGemm
+ImplicitGemm: mask-based implicit GEMM
 """
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import nn
+from cumm import tensorview as tv
+from spconv.core import ConvAlgo
+
+import spconv.pytorch as spconv
+import pickle
+from spconv.test_utils import generate_sparse_data, params_grid
+
+
+class Net(nn.Module):
+    def __init__(self, shape, algo):
+        super().__init__()
+        pool_algo = algo
+        # pool_algo = ConvAlgo.Native
+        self.net = spconv.SparseSequential(
+            spconv.SubMConv3d(3, 32, 3, bias=False, indice_key="c0",
+                              algo=algo),
+            spconv.SubMConv3d(32,
+                              32,
+                              3,
+                              bias=False,
+                              indice_key="c0",
+                              algo=algo),
+            # # nn.BatchNorm1d(32),
+            # # nn.ReLU(),
+            spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
+                              algo=algo),
+            spconv.SubMConv3d(64,
+                              64,
+                              3,
+                              bias=False,
+                              indice_key="c0",
+                              algo=algo),
+            # nn.BatchNorm1d(32),
+            # # nn.ReLU(),
+            spconv.SparseConv3d(64, 64, 3, 2, 1, bias=False, indice_key="m0", algo=algo),
+            # # spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
+            spconv.SubMConv3d(64,
+                              96,
+                              3,
+                              bias=False,
+                              indice_key="c1",
+                              algo=algo),
+            spconv.SubMConv3d(96,
+                              96,
+                              3,
+                              bias=False,
+                              indice_key="c1",
+                              algo=algo),
+            # nn.BatchNorm1d(64),
+            # nn.ReLU(),
+            spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1", algo=algo),
+            # spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
+            spconv.SubMConv3d(96,
+                              128,
+                              3,
+                              bias=False,
+                              indice_key="c2",
+                              algo=algo),
+            spconv.SubMConv3d(128,
+                              128,
+                              3,
+                              bias=False,
+                              indice_key="c2",
+                              algo=algo),
+            # nn.BatchNorm1d(128),
+            # nn.ReLU(),
+            # spconv.SparseConv3d(128, 128, 2, 2, bias=False, indice_key="m2"),
+            spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
+            spconv.SubMConv3d(128,
+                              160,
+                              3,
+                              bias=False,
+                              indice_key="c3",
+                              algo=algo),
+            spconv.SubMConv3d(160,
+                              160,
+                              3,
+                              bias=False,
+                              indice_key="c3",
+                              algo=algo),
+            # nn.BatchNorm1d(128),
+            # nn.ReLU(),
+            # spconv.SparseConv3d(160, 160, 2, 2, bias=False, indice_key="m3"),
+            spconv.SparseMaxPool3d(2, 2, algo=pool_algo, indice_key="m3"),
+            spconv.SubMConv3d(160,
+                              192,
+                              3,
+                              bias=False,
+                              indice_key="c4",
+                              algo=algo),
+            spconv.SubMConv3d(192,
+                              192,
+                              3,
+                              bias=False,
+                              indice_key="c4",
+                              algo=algo),
+            # nn.BatchNorm1d(128),
+            # nn.ReLU(),
+            spconv.SparseMaxPool3d(2, 2, indice_key="m4", algo=pool_algo),
+            # spconv.SparseConv3d(192, 192, 2, 2, bias=False, indice_key="m4"),
+            spconv.SubMConv3d(192,
+                              224,
+                              3,
+                              bias=False,
+                              indice_key="c5",
+                              algo=algo),
+            spconv.SubMConv3d(224,
+                              224,
+                              3,
+                              bias=False,
+                              indice_key="c5",
+                              algo=algo),
+            # nn.BatchNorm1d(256),
+            # nn.ReLU(),
+
+            spconv.SparseInverseConv3d(224, 128, 2, indice_key="m4", bias=False, algo=algo),
+            # # nn.BatchNorm1d(128),
+            # nn.ReLU(),
+
+            spconv.SparseInverseConv3d(128, 64, 2, indice_key="m3", bias=False, algo=algo),
+        )
+        max_batch_size = 1
+        # grid (dense map) is used for indice generation; a pre-allocated grid runs faster.
+        # self.grid = None
+        self.shape = shape
+
+    def forward(self, features, coors, batch_size):
+        x = spconv.SparseConvTensor(features,
+                                    coors,
+                                    self.shape,
+                                    batch_size)
+        return self.net(x)
+
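Net downsamples through layers tagged with indice_key ("m3", "m4") and then inverts them with SparseInverseConv3d under the same keys: an inverse conv reuses the index metadata recorded by the layer that created its key, so its output recovers exactly the sparsity pattern that layer's input had. A minimal sketch of the pairing, with illustrative channel counts:

    import spconv.pytorch as spconv
    from torch import nn

    class DownUp(nn.Module):
        def __init__(self, shape):
            super().__init__()
            self.shape = shape
            self.net = spconv.SparseSequential(
                # records its output indices under "m0" while striding by 2
                spconv.SparseConv3d(3, 8, 2, 2, bias=False, indice_key="m0"),
                # same key and kernel size: restores the pre-"m0" indices
                spconv.SparseInverseConv3d(8, 3, 2, indice_key="m0", bias=False),
            )

        def forward(self, features, coors, batch_size):
            x = spconv.SparseConvTensor(features, coors, self.shape, batch_size)
            return self.net(x)  # output indices match the input's indices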
+class NetLight(nn.Module):
+    def __init__(self, shape, algo):
+        super().__init__()
+        pool_algo = algo
+        # pool_algo = ConvAlgo.Native
+        self.net = spconv.SparseSequential(
+            spconv.SubMConv3d(3, 32, 3, bias=False, indice_key="c0",
+                              algo=algo),
+            spconv.SubMConv3d(32,
+                              32,
+                              3,
+                              bias=False,
+                              indice_key="c0",
+                              algo=algo),
+            # # nn.BatchNorm1d(32),
+            # # nn.ReLU(),
+            spconv.SubMConv3d(32, 64, 3, bias=False, indice_key="c0",
+                              algo=algo),
+            spconv.SubMConv3d(64,
+                              64,
+                              3,
+                              bias=False,
+                              indice_key="c0",
+                              algo=algo),
+            # nn.BatchNorm1d(32),
+            # # nn.ReLU(),
+            spconv.SparseConv3d(64, 64, 3, 2, 1, bias=False, indice_key="m0", algo=algo),
+            # # spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
+            spconv.SubMConv3d(64,
+                              96,
+                              3,
+                              bias=False,
+                              indice_key="c1",
+                              algo=algo),
+            spconv.SubMConv3d(96,
+                              96,
+                              3,
+                              bias=False,
+                              indice_key="c1",
+                              algo=algo),
+            # nn.BatchNorm1d(64),
+            # nn.ReLU(),
+            spconv.SparseConv3d(96, 96, 2, 2, bias=False, indice_key="m1", algo=algo),
+            # spconv.SparseMaxPool3d(2, 2, algo=pool_algo),
+
+            spconv.SparseInverseConv3d(96, 64, 2, indice_key="m1", bias=False, algo=algo),
+            # # nn.BatchNorm1d(128),
+            # nn.ReLU(),
+
+            spconv.SparseInverseConv3d(64, 32, 3, indice_key="m0", bias=False, algo=algo),
+        )
+        max_batch_size = 1
+        # grid (dense map) is used for indice generation; a pre-allocated grid runs faster.
+        # self.grid = None
+        self.shape = shape
+
+    def forward(self, features, coors, batch_size):
+        x = spconv.SparseConvTensor(features,
+                                    coors,
+                                    self.shape,
+                                    batch_size)
+        return self.net(x)
+
+
+def _test_multi_impl(dtype: torch.dtype):
+    # TODO: remove this workaround when the tf32 op is ready
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.allow_tf32 = False
+
+    np.random.seed(50051)
+    if dtype != torch.float16:
+        with open(Path(__file__).parent / "data" / "test_spconv.pkl", "rb") as f:
+            (voxels, coors, spatial_shape) = pickle.load(f)
+    else:
+        # CPU fp16 is very slow, so we use a small input here.
+        spatial_shape = [19, 18, 17]
+        sparse_dict = generate_sparse_data(spatial_shape, [1500] * 1, 3)
+
+        voxels = np.ascontiguousarray(sparse_dict["features"]).astype(
+            np.float32)
+        coors = np.ascontiguousarray(
+            sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
+
+    device = torch.device("cuda:0")
+    device_cpu = torch.device("cpu:0")
+
+    voxels_th = torch.from_numpy(voxels).to(device_cpu).to(dtype)
+    coors_th = torch.from_numpy(coors).to(device_cpu).int()
+    voxels_th_cuda = torch.from_numpy(voxels).to(device).to(dtype)
+    coors_th_cuda = torch.from_numpy(coors).to(device).int()
+    net_cls = Net
+    if dtype == torch.float16:
+        # CPU fp16 is very slow, so we use a small network here.
+        net_cls = NetLight
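The tf32 toggles at the top of _test_multi_impl matter for the assertions below: on Ampere GPUs PyTorch may run fp32 matmuls in TF32 (reduced mantissa precision), which alone can push CPU-vs-GPU differences past the thresholds. A hedged sketch of scoping those standard PyTorch flags inside a single test, rather than the module-global toggling this patch uses:

    import torch
    from contextlib import contextmanager

    @contextmanager
    def tf32_disabled():
        # save and restore the global flags so one test can't leak state
        old_matmul = torch.backends.cuda.matmul.allow_tf32
        old_cudnn = torch.backends.cudnn.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        try:
            yield
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_matmul
            torch.backends.cudnn.allow_tf32 = old_cudnn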
+    # cpu
+    torch.manual_seed(50051)
+    net_native_cpu = net_cls(spatial_shape, ConvAlgo.Native).to(device_cpu).to(dtype)
+    # gpu_native
+    torch.manual_seed(50051)
+    net_native_gpu = net_cls(spatial_shape, ConvAlgo.Native).to(device).to(dtype)
+
+    torch.manual_seed(50051)
+    net_imp_gpu = net_cls(spatial_shape, ConvAlgo.MaskImplicitGemm).to(device).to(dtype)
+
+    torch.manual_seed(50051)
+    net_simp_gpu = net_cls(spatial_shape, ConvAlgo.MaskSplitImplicitGemm).to(device).to(dtype)
+
+    spconv.assign_name_for_sparse_modules(net_native_cpu)
+    spconv.assign_name_for_sparse_modules(net_native_gpu)
+    spconv.assign_name_for_sparse_modules(net_imp_gpu)
+    spconv.assign_name_for_sparse_modules(net_simp_gpu)
+    with torch.no_grad():
+        out: torch.Tensor = net_native_cpu(voxels_th, coors_th, 1).dense()
+    dout = np.random.uniform(-0.2, 0.2, out.shape).astype(np.float32)
+    dout_t = torch.from_numpy(dout).to(device_cpu).to(dtype)
+    dout_t_cu = torch.from_numpy(dout).to(device).to(dtype)
+
+    out_cpu = net_native_cpu(voxels_th, coors_th, 1).dense()
+    out_cpu.backward(dout_t)
+    out = net_native_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
+
+    out.backward(dout_t_cu)
+    out_imp = net_imp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
+
+    out_imp.backward(dout_t_cu)
+    out_simp = net_simp_gpu(voxels_th_cuda, coors_th_cuda, 1).dense()
+
+    out_simp.backward(dout_t_cu)
+    with torch.no_grad():
+        dense_cpu = out_cpu.cuda()
+        dense_native = out
+        dense_imp = out_imp
+        dense_simp = out_simp
+
+        error_native = torch.linalg.norm(dense_cpu - dense_native).cpu().item()
+        error_imp = torch.linalg.norm(dense_cpu - dense_imp).cpu().item()
+        error_simp = torch.linalg.norm(dense_cpu - dense_simp).cpu().item()
+
+        print("error_native", error_native)
+        print("error_imp", error_imp)
+        print("error_simp", error_simp)
+    if dtype == torch.float32:
+        assert error_native < 0.01
+        assert error_imp < 0.01
+        assert error_simp < 0.01
+    else:
+        assert error_native < 10
+        assert error_imp < 10
+        assert error_simp < 10
+
+    cpu_params = dict(net_native_cpu.named_parameters())
+    native_params = dict(net_native_gpu.named_parameters())
+    imp_params = dict(net_imp_gpu.named_parameters())
+    simp_params = dict(net_simp_gpu.named_parameters())
+
+    for k, cpu_w in cpu_params.items():
+        native_w = native_params[k]
+        imp_w = imp_params[k]
+        simp_w = simp_params[k]
+        cpu_w_grad = cpu_w.grad.detach().cuda()
+        native_w_grad = native_w.grad.detach()
+        imp_w_grad = imp_w.grad.detach()
+        simp_w_grad = simp_w.grad.detach()
+
+        error_native = torch.linalg.norm(native_w_grad - cpu_w_grad).cpu().item()
+        error_imp = torch.linalg.norm(native_w_grad - imp_w_grad).cpu().item()
+        error_simp = torch.linalg.norm(native_w_grad - simp_w_grad).cpu().item()
+        print(k, error_native, error_imp, error_simp)
+        assert error_imp < 1
+        assert error_simp < 1
+
+def test_multi_impl():
+    _test_multi_impl(torch.float32)
+    _test_multi_impl(torch.float16)
+
+
+if __name__ == "__main__":
+    test_multi_impl()
diff --git a/test/test_native_kernels.py b/test/test_native_kernels.py
deleted file mode 100644
index b8bf5f6..0000000
--- a/test/test_native_kernels.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright 2021 Yan Yan
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/test_before_push.sh b/test_before_push.sh
new file mode 100644
index 0000000..f115529
--- /dev/null
+++ b/test_before_push.sh
@@ -0,0 +1,10 @@
+# developers must run this script before pushing or opening a pull request.
+# it covers three parts:
+# 1. unit tests for all gemm/conv kernels
+# 2. comparison tests: compare network fwd/bwd results between CPU, Native and ImplicitGemm
+# 3. f32/f16 train/eval tests based on MNIST and some small datasets
+
+echo "-------------UNIT TEST START--------------"
+pytest ./test
+echo "-------------UNIT TEST END--------------"
+python ./example/mnist_sparse.py --fp16
\ No newline at end of file
diff --git a/version.txt b/version.txt
index 63a1a1c..ccbccc3 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-2.1.9
+2.2.0
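For reference, a Python equivalent of the script above; a sketch only, with both commands taken verbatim from test_before_push.sh:

    import subprocess
    import sys

    import pytest

    # parts 1-2: kernel unit tests and the CPU/Native/ImplicitGemm comparison
    ret = pytest.main(["./test"])
    if ret == 0:
        # part 3: fp16 train/eval smoke test on MNIST
        ret = subprocess.call(
            [sys.executable, "./example/mnist_sparse.py", "--fp16"])
    sys.exit(ret)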