diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index ad8bb43de305..092979e44a56 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -8,7 +8,7 @@ FILTER= BUILD_ONLY=0 RMBUILD=1 LOGFILE=/tmp/pytorch_cpp_test.log -XLA_EXPERIMENTAL="nonzero:masked_select" +XLA_EXPERIMENTAL="nonzero:masked_select:unique" if [ "$DEBUG" == "1" ]; then BUILDTYPE="Debug" diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 58f1ebe1564a..6c6ce7e58d60 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -11818,5 +11818,27 @@ TEST_F(AtenXlaTensorTest, TestCdistForward) { ExpectCounterChanged("xla::_cdist_forward", cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestUnique) { + torch::Tensor a = + torch::randint(0, 10, {10, 10}, torch::TensorOptions(torch::kInt)); + std::tuple b = torch::_unique2( + a, /*sorted=*/true, /*return_indices=*/true, /*return_counts=*/true); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + std::tuple xla_b = + torch::_unique2(xla_a, /*sorted=*/true, /*return_indices=*/true, + /*return_counts=*/true); + AllClose(std::get<0>(b), std::get<0>(xla_b)); + AllClose(std::get<1>(b), torch::_cast_Long(std::get<1>(xla_b))); + AllClose(std::get<2>(b), torch::_cast_Long(std::get<2>(xla_b))); + }); + if (DebugUtil::ExperimentEnabled("unique")) { + // If the unique support is enabled, we must not see any aten:: calls. + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + } + ExpectCounterChanged("xla::_unique2", cpp_test::GetIgnoredCounters()); + ResetCounters(); +} + } // namespace cpp_test } // namespace torch_xla diff --git a/test/run_tests.sh b/test/run_tests.sh index f30818f09ccd..bf3013de5c41 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -76,7 +76,7 @@ function run_xla_hlo_debug { function run_dynamic { echo "Running in DynamicShape mode: $@" - XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" run_test "$@" + XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:unique" run_test "$@" } function run_eager_debug { diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 068954900934..1b08493c4fe8 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -178,6 +178,38 @@ def test_expand_symint_correctness(self): self.assertEqual(t3.shape[0], 2) self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu()) + def test_unique_ir(self): + x = torch.zeros(10, dtype=torch.int, device=xm.xla_device()) + x[0] = 1 + x[1] = 2 + unique_elements, inverse_indices, counts = torch.unique( + x, sorted=True, return_inverse=True, return_counts=True) + self.assertIsInstance(unique_elements.shape[0], torch.SymInt) + self.assertIsInstance(counts.shape[0], torch.SymInt) + self.assertIsInstance(inverse_indices.shape[0], int) + self.assertEqual(str(unique_elements.shape[0]), '<=10') + self.assertEqual(str(counts.shape[0]), '<=10') + self.assertEqual(inverse_indices.shape[0], 10) + self.assertEqual(int(unique_elements.shape[0]), 3) + self.assertEqual(int(counts.shape[0]), 3) + + def test_unique_correctness(self): + + def test_fn(*tensors): + results = [] + for t in tensors: + results += [torch.unique(t, sorted=True)] + results += list(torch.unique(t, sorted=True, return_inverse=True)) + results += list( + torch.unique( + t, sorted=True, return_inverse=True, return_counts=True)) + return results + + self.runAtenTest([torch.randint(4, 10, size=(10,)) for _ in range(2)] + + [torch.randint(4, 10, size=(10, 10)) for _ in range(2)] + + [torch.rand(10, 10) for _ in range(2)] + [torch.ones(1)], + test_fn) + if __name__ == '__main__': assert os.environ['XLA_EXPERIMENTAL'] != '' diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index df76e6dff6ca..55cdf4297c97 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3190,4 +3190,25 @@ at::Tensor XLANativeFunctions::_cdist_forward( bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p)); } +std::tuple XLANativeFunctions::_unique2( + const at::Tensor& self, bool sorted, bool return_inverse, + bool return_counts) { + // Note: sorted, return_inverse, return_counts are always treated as True on + // XLA device. + TORCH_LAZY_FN_COUNTER("xla::"); + // Initially make XLA handled unique() handling experimental, and opt-in. + if (!DebugUtil::ExperimentEnabled("unique")) { + return at::native::call_fallback_fn<&xla_cpu_fallback, + ATEN_OP(_unique2)>::call(self, sorted, + return_inverse, + return_counts); + } + std::tuple res = + tensor_methods::unique2(bridge::GetXlaTensor(self), sorted, + return_inverse, return_counts); + return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)), + bridge::AtenFromXlaTensor(std::get<1>(res)), + bridge::AtenFromXlaTensor(std::get<2>(res))); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/unique2.cpp b/torch_xla/csrc/ops/unique2.cpp new file mode 100644 index 000000000000..f0af99259ef7 --- /dev/null +++ b/torch_xla/csrc/ops/unique2.cpp @@ -0,0 +1,42 @@ +#include "torch_xla/csrc/ops/unique2.h" + +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/xla_lower_util.h" + +namespace torch_xla { +namespace { + +xla::Shape NodeOutputShape(const torch::lazy::Value& input) { + xla::Shape input_shape = GetXlaShape(input); + int64_t num_elements = xla::ShapeUtil::ElementsIn(input_shape); + xla::PrimitiveType indices_type = GetShapeDimensionType(/*device=*/nullptr); + xla::Shape unique_elements_shape = + xla::ShapeUtil::MakeShape(input_shape.element_type(), {num_elements}); + xla::Shape inverse_indices_shape = + xla::ShapeUtil::MakeShape(indices_type, input_shape.dimensions()); + xla::Shape counts_shape = + xla::ShapeUtil::MakeShape(indices_type, {num_elements}); + unique_elements_shape.set_dynamic_dimension(0, true); + counts_shape.set_dynamic_dimension(0, true); + return xla::ShapeUtil::MakeTupleShape( + {unique_elements_shape, inverse_indices_shape, counts_shape}); +} + +} // namespace + +Unique2::Unique2(const torch::lazy::Value& input) + : XlaNode(torch::lazy::OpKind(at::aten::_unique2), {input}, + [&]() { return NodeOutputShape(input); }, + /*num_outputs=*/3) {} + +torch::lazy::NodePtr Unique2::Clone(torch::lazy::OpList operands) const { + return torch::lazy::MakeNode(operands.at(0)); +} + +XlaOpVector Unique2::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + return ReturnOps(BuildUnique2(input), loctx); +} + +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/unique2.h b/torch_xla/csrc/ops/unique2.h new file mode 100644 index 000000000000..480b29ec09b3 --- /dev/null +++ b/torch_xla/csrc/ops/unique2.h @@ -0,0 +1,16 @@ +#pragma once + +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { + +class Unique2 : public XlaNode { + public: + Unique2(const torch::lazy::Value& input); + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; +}; + +} // namespace torch_xla diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 3f450cfe7c65..b9942f94e5e3 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -124,6 +124,7 @@ #include "torch_xla/csrc/ops/topk.h" #include "torch_xla/csrc/ops/triangular_solve.h" #include "torch_xla/csrc/ops/uniform.h" +#include "torch_xla/csrc/ops/unique2.h" #include "torch_xla/csrc/ops/unsqueeze.h" #include "torch_xla/csrc/ops/upsample_bilinear2d.h" #include "torch_xla/csrc/ops/upsample_bilinear2d_backward.h" @@ -2579,6 +2580,18 @@ void uniform_(XLATensorPtr& input, double from, double to) { XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape)); } +std::tuple unique2( + const XLATensorPtr& input, bool sorted, bool return_inverse, + bool return_counts) { + // Note: sorted, return_inverse, return_counts are always treated as True on + // XLA device. + torch::lazy::NodePtr node = + torch::lazy::MakeNode(input->GetIrValue()); + return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)), + input->CreateFrom(torch::lazy::Value(node, 1)), + input->CreateFrom(torch::lazy::Value(node, 2))); +} + XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) { auto input_shape = input->shape(); int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 14e23aeed9e4..b6f5c6cb8e78 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -851,6 +851,10 @@ std::vector unbind(const XLATensorPtr& input, int64_t dim); void uniform_(XLATensorPtr& input, double from, double to); +std::tuple unique2( + const XLATensorPtr& input, bool sorted, bool return_inverse, + bool return_counts); + // Insert a dimension of size one at the specified position. XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 15f55336ecd4..d8103db8223c 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1123,4 +1123,103 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, } } +std::vector BuildUnique2(xla::XlaOp input) { + xla::XlaBuilder* builder = input.builder(); + + xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); + int64_t num_elements = xla::ShapeUtil::ElementsIn(input_shape); + + xla::XlaOp input_flattened = XlaHelpers::Flatten(input); + xla::PrimitiveType indices_type = GetShapeDimensionType(/*device=*/nullptr); + xla::XlaOp indices = xla::Iota(builder, indices_type, num_elements); + + // sort elements and indices + xla::XlaOp sorted = + xla::Sort({input_flattened, indices}, + xla::CreateScalarLtComputation( + {input_shape.element_type(), indices_type}, builder)); + + xla::XlaOp sorted_elements = xla::GetTupleElement(sorted, 0); + xla::XlaOp sorted_indices = xla::GetTupleElement(sorted, 1); + + // calculate adjacent difference + xla::XlaOp right = xla::Slice(sorted_elements, {1}, {num_elements}, {1}); + xla::XlaOp left = xla::Slice(sorted_elements, {0}, {num_elements - 1}, {1}); + xla::XlaOp diff = xla::ConvertElementType(xla::Ne(right, left), indices_type); + xla::XlaOp adjacent_diff = xla::Pad(diff, xla::Zero(builder, indices_type), + xla::MakeEdgePaddingConfig({{1, 0}})); + + // calculate cumulative sum + xla::XlaOp cumsum = xla::ReduceWindowWithGeneralPadding( + adjacent_diff, xla::Zero(builder, indices_type), + XlaHelpers::CreateAddComputation(indices_type), + /*window_dimensions=*/{num_elements}, + /*window_strides=*/{1}, + /*base_dilations=*/{}, /*window_dilations=*/{}, + /*padding=*/{{num_elements - 1, 0}}); + + xla::ScatterDimensionNumbers scatter_dnums; + scatter_dnums.set_index_vector_dim(1); + scatter_dnums.add_inserted_window_dims(0); + scatter_dnums.add_scatter_dims_to_operand_dims(0); + scatter_dnums.add_update_window_dims(1); + + auto select_second_combiner = [](xla::XlaOp a, xla::XlaOp b) -> xla::XlaOp { + return b; + }; + + auto count_combiner = [](xla::XlaOp a, xla::XlaOp b) -> xla::XlaOp { + xla::XlaOp one = + xla::One(a.builder(), XlaHelpers::ShapeOfXlaOp(a).element_type()); + return xla::Add(a, one); + }; + + // 1. calculate unique_elements + xla::XlaOp sorted_elements_2d = + xla::Reshape(sorted_elements, {num_elements, 1}); + xla::XlaOp unique_elements_2d = xla::Scatter( + xla::Zeros(builder, xla::ShapeUtil::MakeShape(input_shape.element_type(), + {num_elements, 1})), + cumsum, sorted_elements_2d, + MakeScatterComputation(select_second_combiner, + input_shape.element_type()), + scatter_dnums, + /*indices_are_sorted=*/true, /*unique_indices=*/false); + xla::XlaOp unique_elements = XlaHelpers::Flatten(unique_elements_2d); + + // 2. calculate inverse_indices + xla::XlaOp cumsum_2d = xla::Reshape(cumsum, {num_elements, 1}); + xla::XlaOp inverse_indices_2d = xla::Scatter( + xla::Zeros(builder, + xla::ShapeUtil::MakeShape(indices_type, {num_elements, 1})), + sorted_indices, cumsum_2d, + MakeScatterComputation(select_second_combiner, indices_type), + scatter_dnums, + /*indices_are_sorted=*/false, /*unique_indices=*/true); + xla::XlaOp inverse_indices = xla::Reshape( + XlaHelpers::Flatten(inverse_indices_2d), input_shape.dimensions()); + + // 3. calculate counts + xla::XlaOp counts_2d = xla::Scatter( + xla::Zeros(builder, + xla::ShapeUtil::MakeShape(indices_type, {num_elements, 1})), + cumsum, cumsum_2d, MakeScatterComputation(count_combiner, indices_type), + scatter_dnums, + /*indices_are_sorted=*/true, /*unique_indices=*/false); + xla::XlaOp counts = XlaHelpers::Flatten(counts_2d); + + // 4. calculate number of unique elements + xla::XlaOp num_unique_elements = + xla::Reduce(adjacent_diff, xla::Zero(builder, indices_type), + XlaHelpers::CreateAddComputation(indices_type), {0}) + + xla::One(builder, indices_type); + + std::vector results = { + /*unique_elements=*/xla::SetDimensionSize(unique_elements, + num_unique_elements, 0), + /*inverse_indices=*/inverse_indices, + /*counts=*/xla::SetDimensionSize(counts, num_unique_elements, 0)}; + return results; +} // namespace torch_xla + } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 460f5729767f..d82908b0c525 100644 --- a/torch_xla/csrc/xla_lower_util.h +++ b/torch_xla/csrc/xla_lower_util.h @@ -137,4 +137,6 @@ xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2, xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p, bool use_hamming, bool use_chebyshev); +std::vector BuildUnique2(xla::XlaOp input); + } // namespace torch_xla diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 1785b93a1166..a9b679eaa320 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -321,6 +321,7 @@ supported: - triangular_solve - unbind.int - uniform_ + - _unique2 - unsqueeze - unsqueeze_ - upsample_bilinear2d