Skip to content

Commit fcbdbd6

Browse files
pearu authored and pytorchmergebot committed
Fix silent nnz overflow for large sparse compressed tensors. (pytorch#102523)
Fixes pytorch#102520. Pull Request resolved: pytorch#102523. Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch
1 parent 77f9701 commit fcbdbd6

File tree

3 files changed

+19
-6
lines changed

3 files changed

+19
-6
lines changed

aten/src/ATen/SparseCsrTensorImpl.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
5454
const Tensor& values() const {
5555
return values_;
5656
}
57-
int nnz() {
57+
int64_t nnz() {
5858
return col_indices_.size(-1);
5959
}
6060

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
213213
DimVector compressed_indices_batchsize = DimVector(compressed_indices.sizes().slice(0, batch_ndim));
214214
DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim));
215215
DimVector values_batchsize = DimVector(values.sizes().slice(0, batch_ndim));
216-
const int values_nnz = values.size(batch_ndim);
216+
const int64_t values_nnz = values.size(batch_ndim);
217217
DimVector values_blocksize = DimVector(values.sizes().slice(batch_ndim + 1, block_ndim));
218218
DimVector values_densesize = DimVector(values.sizes().slice(batch_ndim + 1 + block_ndim, dense_ndim));
219219
TORCH_CHECK(
@@ -229,9 +229,9 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
229229
") must be divisible with blocksize[", i, "] (=", blocksize[i],
230230
") as defined by values shape");
231231
}
232-
const int nrows = size[batch_ndim] / blocksize[0];
233-
const int ncols = size[batch_ndim + 1] / blocksize[1];
234-
int compressed_dim_size, plain_dim_size;
232+
const int64_t nrows = size[batch_ndim] / blocksize[0];
233+
const int64_t ncols = size[batch_ndim + 1] / blocksize[1];
234+
int64_t compressed_dim_size, plain_dim_size;
235235
std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
236236
[&] { return std::make_tuple(nrows, ncols); },
237237
[&] { return std::make_tuple(ncols, nrows); });

test/test_sparse_csr.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,8 @@
1212
load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU)
1313
from torch.testing._internal.common_device_type import \
1414
(ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
15-
precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan)
15+
precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
16+
largeTensorTest)
1617
from torch.testing._internal.common_methods_invocations import \
1718
(op_db, sparse_csr_unary_ufuncs, ReductionOpInfo)
1819
from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CUDA
@@ -930,6 +931,18 @@ def test_csr_is_contiguous(self):
930931
with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have is_contiguous"):
931932
a.is_contiguous()
932933

934+
@onlyCPU
935+
@largeTensorTest("20GB", "cpu")
936+
def test_csr_nnz(self):
937+
# Tests the limits of the number of specified elements in CSR tensors, see gh-102520.
938+
for nnz in [0, 2**31]:
939+
rows, cols = 1, max(nnz, 1)
940+
crow_indices = torch.tensor([0, nnz], dtype=torch.int64)
941+
col_indices = torch.arange(nnz, dtype=torch.int64)
942+
values = torch.ones(nnz, dtype=torch.int8)
943+
a = torch.sparse_csr_tensor(crow_indices, col_indices, values, (rows, cols))
944+
self.assertEqual(a._nnz(), nnz)
945+
933946
def test_csr_double_to_sparse_csr(self):
934947
a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
935948
a.to_sparse_csr().to_sparse_csr()

0 commit comments

Comments
 (0)