Lower aten::_unique2 #4661
base: master
Changes from all commits
torch_xla/csrc/aten_xla_type.cpp
@@ -3190,4 +3190,25 @@ at::Tensor XLANativeFunctions::_cdist_forward(
      bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_unique2(
    const at::Tensor& self, bool sorted, bool return_inverse,
    bool return_counts) {
  // Note: sorted, return_inverse, and return_counts are always treated as
  // true on the XLA device.
  TORCH_LAZY_FN_COUNTER("xla::");
  // The XLA handling of unique() is initially experimental and opt-in; by
  // default, fall back to the CPU implementation.
  if (!DebugUtil::ExperimentEnabled("unique")) {
    return at::native::call_fallback_fn<&xla_cpu_fallback,
                                        ATEN_OP(_unique2)>::call(self, sorted,
                                                                 return_inverse,
                                                                 return_counts);
  }
  std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> res =
      tensor_methods::unique2(bridge::GetXlaTensor(self), sorted,
                              return_inverse, return_counts);
  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)),
                         bridge::AtenFromXlaTensor(std::get<1>(res)),
                         bridge::AtenFromXlaTensor(std::get<2>(res)));
}

}  // namespace torch_xla

Reviewer comment (on the _unique2 signature): FYI: I don't think this PR will be impacted, but just in case you need to develop dynamic models, you'd need to develop on …
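For context, here is a minimal Python usage sketch of the opt-in path. The XLA_EXPERIMENTAL environment variable name is an assumption (it is how other opt-in ops are typically gated), not something stated in this diff:

import os
os.environ["XLA_EXPERIMENTAL"] = "unique"  # assumed opt-in flag name

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.tensor([3, 1, 2, 1, 3], device=device)

# With the experiment disabled this call routes through the CPU fallback;
# with it enabled it goes through the new Unique2 lowering.
values, inverse, counts = torch.unique(
    t, sorted=True, return_inverse=True, return_counts=True)
print(values)   # tensor([1, 2, 3], ...)
print(inverse)  # tensor([2, 0, 1, 0, 2], ...)
print(counts)   # tensor([2, 1, 2], ...)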
torch_xla/csrc/ops/unique2.cpp (new file)
@@ -0,0 +1,42 @@
#include "torch_xla/csrc/ops/unique2.h" | ||
|
||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/tensor_util.h" | ||
#include "torch_xla/csrc/xla_lower_util.h" | ||
|
||
namespace torch_xla { | ||
namespace { | ||
|
||
xla::Shape NodeOutputShape(const torch::lazy::Value& input) { | ||
xla::Shape input_shape = GetXlaShape(input); | ||
int64_t num_elements = xla::ShapeUtil::ElementsIn(input_shape); | ||
xla::PrimitiveType indices_type = GetShapeDimensionType(/*device=*/nullptr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question here, do we want the indices type to be S32 or S64? Though returning S32 does not break any tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would leave most of the op in S32 to avoid extra cost lol. I think eventually we want to make |
||
xla::Shape unique_elements_shape = | ||
xla::ShapeUtil::MakeShape(input_shape.element_type(), {num_elements}); | ||
xla::Shape inverse_indices_shape = | ||
xla::ShapeUtil::MakeShape(indices_type, input_shape.dimensions()); | ||
xla::Shape counts_shape = | ||
xla::ShapeUtil::MakeShape(indices_type, {num_elements}); | ||
unique_elements_shape.set_dynamic_dimension(0, true); | ||
counts_shape.set_dynamic_dimension(0, true); | ||
return xla::ShapeUtil::MakeTupleShape( | ||
{unique_elements_shape, inverse_indices_shape, counts_shape}); | ||
} | ||
|
||
} // namespace | ||
Unique2::Unique2(const torch::lazy::Value& input)
    : XlaNode(torch::lazy::OpKind(at::aten::_unique2), {input},
              [&]() { return NodeOutputShape(input); },
              /*num_outputs=*/3) {}

torch::lazy::NodePtr Unique2::Clone(torch::lazy::OpList operands) const {
  return torch::lazy::MakeNode<Unique2>(operands.at(0));
}

XlaOpVector Unique2::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  return ReturnOps(BuildUnique2(input), loctx);
}

}  // namespace torch_xla
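To make the shape contract above concrete, here is an illustrative plain-Python rendering of what NodeOutputShape computes (the function name is hypothetical, not a torch_xla API):

def unique2_output_shapes(input_shape):
    """Upper-bound output shapes for _unique2 on an input of input_shape."""
    num_elements = 1
    for dim in input_shape:
        num_elements *= dim
    # Dimension 0 of both is dynamic: the true size (the number of unique
    # values) is only known at run time, bounded above by num_elements.
    unique_elements = [num_elements]
    counts = [num_elements]
    # Inverse indices are static: one index per input element.
    inverse_indices = list(input_shape)
    return unique_elements, inverse_indices, counts

# A 2x3 input gives upper-bound shapes ([6], [2, 3], [6]).
print(unique2_output_shapes([2, 3]))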
torch_xla/csrc/ops/unique2.h (new file)
@@ -0,0 +1,16 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class Unique2 : public XlaNode {
 public:
  Unique2(const torch::lazy::Value& input);

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;
};

}  // namespace torch_xla

Reviewer comment (on Lower): nit: mind adding a ToString()?

Reply: It appears a ToString() override is only needed when the op has input types other than lazy::Value, so I guess the default should be fine.
xla_native_functions.yaml
@@ -321,6 +321,7 @@ supported:
  - triangular_solve
  - unbind.int
  - uniform_
  - _unique2
  - unsqueeze
  - unsqueeze_
  - upsample_bilinear2d

Reviewer comment (on - _unique2): I wonder how …
Reviewer comment: Can we also check that the result of unique is dynamic?
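A sketch of the kind of check the reviewer asks for: verify that the unique output reflects the actual number of unique elements rather than the padded upper bound. This assumes the XLA_EXPERIMENTAL flag name used earlier, and how dynamic sizes surface to Python may differ in practice:

import os
os.environ["XLA_EXPERIMENTAL"] = "unique"  # assumed opt-in flag name

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
values, _, counts = torch.unique(
    t, sorted=True, return_inverse=True, return_counts=True)
xm.mark_step()

# If the dynamic dimension is honored, both sizes are 3 (the number of
# unique values), not the static upper bound of 6 (the input element count).
assert values.numel() == 3
assert counts.numel() == 3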