Skip to content

Lower aten::_unique2 #4661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/cpp/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ FILTER=
BUILD_ONLY=0
RMBUILD=1
LOGFILE=/tmp/pytorch_cpp_test.log
XLA_EXPERIMENTAL="nonzero:masked_select"
XLA_EXPERIMENTAL="nonzero:masked_select:unique"

if [ "$DEBUG" == "1" ]; then
BUILDTYPE="Debug"
Expand Down
22 changes: 22 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11818,5 +11818,27 @@ TEST_F(AtenXlaTensorTest, TestCdistForward) {
ExpectCounterChanged("xla::_cdist_forward", cpp_test::GetIgnoredCounters());
}

// Runs torch::_unique2 on every test device and compares its three outputs
// (unique values, inverse indices, counts) with the result computed on the
// original tensor, then checks the op counters.
TEST_F(AtenXlaTensorTest, TestUnique) {
  using UniqueOutputs = std::tuple<at::Tensor, at::Tensor, at::Tensor>;

  const torch::Tensor input =
      torch::randint(0, 10, {10, 10}, torch::TensorOptions(torch::kInt));
  const UniqueOutputs expected =
      torch::_unique2(input, /*sorted=*/true, /*return_inverse=*/true,
                      /*return_counts=*/true);

  ForEachDevice([&](const torch::Device& device) {
    const torch::Tensor xla_input = CopyToDevice(input, device);
    const UniqueOutputs actual =
        torch::_unique2(xla_input, /*sorted=*/true, /*return_inverse=*/true,
                        /*return_counts=*/true);
    // Unique values are compared as-is.
    AllClose(std::get<0>(expected), std::get<0>(actual));
    // Inverse indices and counts are cast to Long before the comparison.
    AllClose(std::get<1>(expected), torch::_cast_Long(std::get<1>(actual)));
    AllClose(std::get<2>(expected), torch::_cast_Long(std::get<2>(actual)));
  });

  // With the "unique" experiment on, no aten:: counter may have moved.
  const bool unique_experiment_on = DebugUtil::ExperimentEnabled("unique");
  if (unique_experiment_on) {
    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  }
  // The xla::_unique2 counter is expected to change in either mode.
  ExpectCounterChanged("xla::_unique2", cpp_test::GetIgnoredCounters());
  ResetCounters();
}

} // namespace cpp_test
} // namespace torch_xla
2 changes: 1 addition & 1 deletion test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function run_xla_hlo_debug {

# Runs a test through run_test with the dynamic-shape experimental features
# enabled via the XLA_EXPERIMENTAL environment variable.
# Arguments: forwarded unchanged to run_test.
# NOTE(review): the two run_test lines below look like the removed/added pair
# of a one-line diff (the second adds ":unique"); confirm only one is intended
# to remain, otherwise the test is executed twice.
function run_dynamic {
echo "Running in DynamicShape mode: $@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" run_test "$@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:unique" run_test "$@"
}

function run_eager_debug {
Expand Down
32 changes: 32 additions & 0 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,38 @@ def test_expand_symint_correctness(self):
self.assertEqual(t3.shape[0], 2)
self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu())

def test_unique_ir(self):
  """Checks the static/dynamic sizes of torch.unique outputs on an XLA tensor.

  The input has 10 elements holding the three distinct values 0, 1 and 2.
  """
  data = torch.zeros(10, dtype=torch.int, device=xm.xla_device())
  data[0] = 1
  data[1] = 2
  uniques, inverse, occurrences = torch.unique(
      data, sorted=True, return_inverse=True, return_counts=True)
  # Outputs whose leading dimension is expected to be a SymInt.
  dynamic_outputs = (uniques, occurrences)

  # Type of the leading dimension: SymInt for uniques/counts, int for inverse.
  for out in dynamic_outputs:
    self.assertIsInstance(out.shape[0], torch.SymInt)
  self.assertIsInstance(inverse.shape[0], int)

  # String form of the SymInt dimensions, and the static inverse size.
  for out in dynamic_outputs:
    self.assertEqual(str(out.shape[0]), '<=10')
  self.assertEqual(inverse.shape[0], 10)

  # Converting the SymInt dimensions to int yields the number of unique values.
  for out in dynamic_outputs:
    self.assertEqual(int(out.shape[0]), 3)

def test_unique_correctness(self):

def test_fn(*tensors):
results = []
for t in tensors:
results += [torch.unique(t, sorted=True)]
results += list(torch.unique(t, sorted=True, return_inverse=True))
results += list(
torch.unique(
t, sorted=True, return_inverse=True, return_counts=True))
return results

self.runAtenTest([torch.randint(4, 10, size=(10,)) for _ in range(2)] +
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also check that the result of unique is dynamic?

[torch.randint(4, 10, size=(10, 10)) for _ in range(2)] +
[torch.rand(10, 10) for _ in range(2)] + [torch.ones(1)],
test_fn)


if __name__ == '__main__':
assert os.environ['XLA_EXPERIMENTAL'] != ''
Expand Down
21 changes: 21 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3190,4 +3190,25 @@ at::Tensor XLANativeFunctions::_cdist_forward(
bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_unique2(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: I don't think this PR will be impacted, but just in case you need to develop dynamic models, you'd need to develop on the functionization_dynamic_shape branch, which has the functionalization feature enabled.

const at::Tensor& self, bool sorted, bool return_inverse,
bool return_counts) {
// Note: sorted, return_inverse, return_counts are always treated as True on
// XLA device.
TORCH_LAZY_FN_COUNTER("xla::");
// Initially make XLA handled unique() handling experimental, and opt-in.
if (!DebugUtil::ExperimentEnabled("unique")) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
ATEN_OP(_unique2)>::call(self, sorted,
return_inverse,
return_counts);
}
std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> res =
tensor_methods::unique2(bridge::GetXlaTensor(self), sorted,
return_inverse, return_counts);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)),
bridge::AtenFromXlaTensor(std::get<1>(res)),
bridge::AtenFromXlaTensor(std::get<2>(res)));
}

} // namespace torch_xla
42 changes: 42 additions & 0 deletions torch_xla/csrc/ops/unique2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "torch_xla/csrc/ops/unique2.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {
namespace {

xla::Shape NodeOutputShape(const torch::lazy::Value& input) {
xla::Shape input_shape = GetXlaShape(input);
int64_t num_elements = xla::ShapeUtil::ElementsIn(input_shape);
xla::PrimitiveType indices_type = GetShapeDimensionType(/*device=*/nullptr);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here, do we want the indices type to be S32 or S64? Though returning S32 does not break any tests.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would leave most of the op in S32 to avoid extra cost lol. I think eventually we want to make setDimensionSize and getDimensionSize handle S64, which I don't think is super complicated, but someone needs to work on it...

xla::Shape unique_elements_shape =
xla::ShapeUtil::MakeShape(input_shape.element_type(), {num_elements});
xla::Shape inverse_indices_shape =
xla::ShapeUtil::MakeShape(indices_type, input_shape.dimensions());
xla::Shape counts_shape =
xla::ShapeUtil::MakeShape(indices_type, {num_elements});
unique_elements_shape.set_dynamic_dimension(0, true);
counts_shape.set_dynamic_dimension(0, true);
return xla::ShapeUtil::MakeTupleShape(
{unique_elements_shape, inverse_indices_shape, counts_shape});
}

} // namespace

// Creates the aten::_unique2 IR node over `input`. The node has three outputs
// and its shape is supplied by NodeOutputShape(input).
Unique2::Unique2(const torch::lazy::Value& input)
    : XlaNode(
          torch::lazy::OpKind(at::aten::_unique2), {input},
          [&input]() -> xla::Shape { return NodeOutputShape(input); },
          /*num_outputs=*/3) {}

// Returns a new Unique2 node built from the first entry of `operands`.
torch::lazy::NodePtr Unique2::Clone(torch::lazy::OpList operands) const {
  const torch::lazy::Value& new_input = operands.at(0);
  return torch::lazy::MakeNode<Unique2>(new_input);
}

// Lowers the node: fetches the XLA op of operand 0, expands it with
// BuildUnique2 and hands the resulting ops to ReturnOps.
XlaOpVector Unique2::Lower(LoweringContext* loctx) const {
  const xla::XlaOp input_op = loctx->GetOutputOp(operand(0));
  const auto lowered_outputs = BuildUnique2(input_op);
  return ReturnOps(lowered_outputs, loctx);
}

} // namespace torch_xla
16 changes: 16 additions & 0 deletions torch_xla/csrc/ops/unique2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class Unique2 : public XlaNode {
public:
Unique2(const torch::lazy::Value& input);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: mind adding a ToString()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears ToString() override is only needed when the op has input types other than lazy::Value. So I guess the default should be fine.

};

} // namespace torch_xla
13 changes: 13 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/triangular_solve.h"
#include "torch_xla/csrc/ops/uniform.h"
#include "torch_xla/csrc/ops/unique2.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
#include "torch_xla/csrc/ops/upsample_bilinear2d.h"
#include "torch_xla/csrc/ops/upsample_bilinear2d_backward.h"
Expand Down Expand Up @@ -2579,6 +2580,18 @@ void uniform_(XLATensorPtr& input, double from, double to) {
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
}

std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> unique2(
const XLATensorPtr& input, bool sorted, bool return_inverse,
bool return_counts) {
// Note: sorted, return_inverse, return_counts are always treated as True on
// XLA device.
torch::lazy::NodePtr node =
torch::lazy::MakeNode<Unique2>(input->GetIrValue());
return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
input->CreateFrom(torch::lazy::Value(node, 1)),
input->CreateFrom(torch::lazy::Value(node, 2)));
}

XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) {
auto input_shape = input->shape();
int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex(
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,10 @@ std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim);

void uniform_(XLATensorPtr& input, double from, double to);

// Returns the (unique_elements, inverse_indices, counts) tensors for `input`.
// Note: sorted, return_inverse and return_counts are always treated as true on
// XLA device.
std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> unique2(
const XLATensorPtr& input, bool sorted, bool return_inverse,
bool return_counts);

// Insert a dimension of size one at the specified position.
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim);

Expand Down
99 changes: 99 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,4 +1123,103 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p,
}
}

// Builds the XLA computation for unique over the flattened `input`.
//
// Returns three ops, in order:
//   0. unique_elements: 1-D, element type of `input`, upper bound
//      num_elements, dimension 0 set dynamically to the number of unique
//      values.
//   1. inverse_indices: same dimensions as `input`; for each input element,
//      the position of its value within unique_elements.
//   2. counts: 1-D, upper bound num_elements, dimension 0 set dynamically to
//      the number of unique values; occurrences of each unique value.
//
// Algorithm: sort the flattened values together with their original
// positions, mark every sorted position whose value differs from its left
// neighbour, and take the running sum of those marks. The running sum is the
// slot of each sorted element in the unique output; scatters keyed by it (or
// by the original position) produce the three results.
std::vector<xla::XlaOp> BuildUnique2(xla::XlaOp input) {
  xla::XlaBuilder* builder = input.builder();

  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
  int64_t num_elements = xla::ShapeUtil::ElementsIn(input_shape);
  xla::PrimitiveType indices_type = GetShapeDimensionType(/*device=*/nullptr);

  // Empty input: the slices and the reduce window below would get invalid
  // bounds (start 1 > limit 0, limit -1, zero-sized window) and the unique
  // count would come out as 1. Return empty outputs with a dynamic size of 0.
  if (num_elements == 0) {
    xla::XlaOp zero_size = xla::Zero(builder, indices_type);
    xla::XlaOp empty_elements = xla::Zeros(
        builder, xla::ShapeUtil::MakeShape(input_shape.element_type(), {0}));
    xla::XlaOp empty_inverse = xla::Zeros(
        builder,
        xla::ShapeUtil::MakeShape(indices_type, input_shape.dimensions()));
    xla::XlaOp empty_counts =
        xla::Zeros(builder, xla::ShapeUtil::MakeShape(indices_type, {0}));
    return {
        /*unique_elements=*/xla::SetDimensionSize(empty_elements, zero_size, 0),
        /*inverse_indices=*/empty_inverse,
        /*counts=*/xla::SetDimensionSize(empty_counts, zero_size, 0)};
  }

  xla::XlaOp input_flattened = XlaHelpers::Flatten(input);
  xla::XlaOp indices = xla::Iota(builder, indices_type, num_elements);

  // sort elements and indices
  xla::XlaOp sorted =
      xla::Sort({input_flattened, indices},
                xla::CreateScalarLtComputation(
                    {input_shape.element_type(), indices_type}, builder));

  xla::XlaOp sorted_elements = xla::GetTupleElement(sorted, 0);
  xla::XlaOp sorted_indices = xla::GetTupleElement(sorted, 1);

  // calculate adjacent difference: 1 where a sorted element differs from its
  // left neighbour, 0 otherwise; position 0 is padded with 0.
  xla::XlaOp right = xla::Slice(sorted_elements, {1}, {num_elements}, {1});
  xla::XlaOp left = xla::Slice(sorted_elements, {0}, {num_elements - 1}, {1});
  xla::XlaOp diff = xla::ConvertElementType(xla::Ne(right, left), indices_type);
  xla::XlaOp adjacent_diff = xla::Pad(diff, xla::Zero(builder, indices_type),
                                      xla::MakeEdgePaddingConfig({{1, 0}}));

  // calculate cumulative sum: a full-length window with left-only padding
  // makes output i the sum of adjacent_diff[0..i].
  xla::XlaOp cumsum = xla::ReduceWindowWithGeneralPadding(
      adjacent_diff, xla::Zero(builder, indices_type),
      XlaHelpers::CreateAddComputation(indices_type),
      /*window_dimensions=*/{num_elements},
      /*window_strides=*/{1},
      /*base_dilations=*/{}, /*window_dilations=*/{},
      /*padding=*/{{num_elements - 1, 0}});

  // All scatters below write single rows of a {num_elements, 1} operand.
  xla::ScatterDimensionNumbers scatter_dnums;
  scatter_dnums.set_index_vector_dim(1);
  scatter_dnums.add_inserted_window_dims(0);
  scatter_dnums.add_scatter_dims_to_operand_dims(0);
  scatter_dnums.add_update_window_dims(1);

  // Overwrites the current value with the update.
  auto select_second_combiner = [](xla::XlaOp a, xla::XlaOp b) -> xla::XlaOp {
    return b;
  };

  // Ignores the update value and increments the current value by one.
  auto count_combiner = [](xla::XlaOp a, xla::XlaOp b) -> xla::XlaOp {
    xla::XlaOp one =
        xla::One(a.builder(), XlaHelpers::ShapeOfXlaOp(a).element_type());
    return xla::Add(a, one);
  };

  // 1. calculate unique_elements: unique[cumsum[i]] = sorted_elements[i]
  xla::XlaOp sorted_elements_2d =
      xla::Reshape(sorted_elements, {num_elements, 1});
  xla::XlaOp unique_elements_2d = xla::Scatter(
      xla::Zeros(builder, xla::ShapeUtil::MakeShape(input_shape.element_type(),
                                                    {num_elements, 1})),
      cumsum, sorted_elements_2d,
      MakeScatterComputation(select_second_combiner,
                             input_shape.element_type()),
      scatter_dnums,
      /*indices_are_sorted=*/true, /*unique_indices=*/false);
  xla::XlaOp unique_elements = XlaHelpers::Flatten(unique_elements_2d);

  // 2. calculate inverse_indices: inverse[sorted_indices[i]] = cumsum[i]
  xla::XlaOp cumsum_2d = xla::Reshape(cumsum, {num_elements, 1});
  xla::XlaOp inverse_indices_2d = xla::Scatter(
      xla::Zeros(builder,
                 xla::ShapeUtil::MakeShape(indices_type, {num_elements, 1})),
      sorted_indices, cumsum_2d,
      MakeScatterComputation(select_second_combiner, indices_type),
      scatter_dnums,
      /*indices_are_sorted=*/false, /*unique_indices=*/true);
  xla::XlaOp inverse_indices = xla::Reshape(
      XlaHelpers::Flatten(inverse_indices_2d), input_shape.dimensions());

  // 3. calculate counts: counts[cumsum[i]] += 1 (the update values passed to
  // the scatter are ignored by count_combiner).
  xla::XlaOp counts_2d = xla::Scatter(
      xla::Zeros(builder,
                 xla::ShapeUtil::MakeShape(indices_type, {num_elements, 1})),
      cumsum, cumsum_2d, MakeScatterComputation(count_combiner, indices_type),
      scatter_dnums,
      /*indices_are_sorted=*/true, /*unique_indices=*/false);
  xla::XlaOp counts = XlaHelpers::Flatten(counts_2d);

  // 4. calculate number of unique elements: number of value changes plus one.
  xla::XlaOp num_unique_elements =
      xla::Reduce(adjacent_diff, xla::Zero(builder, indices_type),
                  XlaHelpers::CreateAddComputation(indices_type), {0}) +
      xla::One(builder, indices_type);

  std::vector<xla::XlaOp> results = {
      /*unique_elements=*/xla::SetDimensionSize(unique_elements,
                                                num_unique_elements, 0),
      /*inverse_indices=*/inverse_indices,
      /*counts=*/xla::SetDimensionSize(counts, num_unique_elements, 0)};
  return results;
}

} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/xla_lower_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,6 @@ xla::XlaOp BuildAddcmul(xla::XlaOp input, xla::XlaOp t1, xla::XlaOp t2,
xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p,
bool use_hamming, bool use_chebyshev);

// Builds the ops computing unique over `input`. Returns, in order,
// {unique_elements, inverse_indices, counts}; unique_elements and counts have
// dimension 0 set dynamically to the number of unique elements.
std::vector<xla::XlaOp> BuildUnique2(xla::XlaOp input);

} // namespace torch_xla
1 change: 1 addition & 0 deletions xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ supported:
- triangular_solve
- unbind.int
- uniform_
- _unique2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder how _unique2 differs from unique

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _unique function does not support return_counts. It looks like it's not accessible to normal users, because torch.unique is dispatched to either unique_dim or _unique2 here.

- unsqueeze
- unsqueeze_
- upsample_bilinear2d
Expand Down