Skip to content

Commit 6048a86

Browse files
akuegel authored and Google-ML-Automation committed
[XLA:GPU] Avoid a segfault in StreamAttributeAnnotator
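Currently it is assumed that a GetTupleElement is never the root of a computation. That assumption does not always hold: for example, during autotuning of cuBLAS GEMM calls, a GetTupleElement op can be the root. Such a GetTupleElement has no users, so taking its first user (`users()[0]`) crashed; the pass now skips GetTupleElement users that have no users of their own.

PiperOrigin-RevId: 825932301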
1 parent 316c9c9 commit 6048a86

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

xla/service/gpu/transforms/stream_attribute_annotator.cc

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,9 @@ absl::StatusOr<bool> AnnotateStreamAttributesForUsers(
156156
std::vector<HloInstruction*> all_consumers;
157157
for (auto user : instr->users()) {
158158
if (HloPredicateIsOp<HloOpcode::kGetTupleElement>(user)) {
159+
if (user->user_count() == 0) {
160+
continue;
161+
}
159162
user = user->users()[0];
160163
}
161164
all_consumers.push_back(user);

xla/service/gpu/transforms/stream_attribute_annotator_test.cc

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,28 @@ TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) {
159159
EXPECT_EQ(gpu_config.wait_on_operation_queues()[0], 1);
160160
}
161161

162+
TEST_F(StreamAttributeAnnotatorTest, GTENoUserIsHandled) {
163+
constexpr absl::string_view kHloString = R"(
164+
HloModule ModuleWithAsync
165+
166+
ENTRY entry {
167+
p1_32 = f32[16,32] parameter(0)
168+
p2_32 = f32[32,16] parameter(1)
169+
170+
custom-call.3 = (f32[16,16], s8[1028]{0}) custom-call(p1_32, p2_32), custom_call_target="__cublas$gemm", backend_config={"operation_queue_id":"1","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","grad_x":false,"grad_y":false}}
171+
ROOT get-tuple-element.24 = f32[16,16] get-tuple-element(custom-call.3), index=0
172+
}
173+
)";
174+
175+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
176+
ParseAndReturnVerifiedModule(kHloString));
177+
178+
StreamAttributeAnnotator attr_annotator{device_description()};
179+
bool changed;
180+
TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
181+
EXPECT_FALSE(changed);
182+
}
183+
162184
TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) {
163185
constexpr absl::string_view kHloString = R"(
164186
HloModule ModuleWithFusion

0 commit comments

Comments
 (0)