diff --git a/xla/service/while_loop_simplifier.cc b/xla/service/while_loop_simplifier.cc index f2e2f9676277d..9a11a3b2d430c 100644 --- a/xla/service/while_loop_simplifier.cc +++ b/xla/service/while_loop_simplifier.cc @@ -151,7 +151,7 @@ static absl::StatusOr RemoveDeadTupleIndices( HloInstruction* while_op, absl::flat_hash_set& used_tuple_indices, std::optional> dead_to_surviving_index = std::nullopt) { - auto copy_original_value = + auto copy_remaining_original_arrays = [&](const HloInstruction* src_instruction, HloInstruction* dest_instruction, const absl::flat_hash_map& old_to_new_tuple_idx) { @@ -305,8 +305,9 @@ static absl::StatusOr RemoveDeadTupleIndices( CopyFrontendAttributes(while_op, new_while_op); CopyMetadata(while_op, new_while_op); - copy_original_value(while_init, new_while_init, old_to_new_tuple_idx); - copy_original_value(while_op, new_while_op, old_to_new_tuple_idx); + copy_remaining_original_arrays(while_init, new_while_init, + old_to_new_tuple_idx); + copy_remaining_original_arrays(while_op, new_while_op, old_to_new_tuple_idx); // Create a tuple op that recreates the output of the old while op. That is, // we transform to @@ -1193,6 +1194,20 @@ static std::vector GetFlatTupleElems( } static absl::StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { + auto flatten_original_value = [&](HloInstruction* old_instr, + HloInstruction* new_instr) { + if (old_instr->original_value()) { + auto new_original_value = + std::make_shared(new_instr->shape()); + int64_t i = 0; + for (auto& [shape_index, original_array] : + old_instr->original_value()->tree().leaves()) { + *new_original_value->mutable_tree()->mutable_element({i++}) = + original_array; + } + new_instr->set_original_value(new_original_value); + } + }; HloModule* module = while_op->GetModule(); HloComputation* computation = while_op->parent(); auto* while_init = while_op->mutable_operand(0); @@ -1294,6 +1309,9 @@ static absl::StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { for (auto& instr : new_instrs) { computation->AddInstruction(std::move(instr)); } + + flatten_original_value(while_init, new_while_op->mutable_operand(0)); + flatten_original_value(while_op, new_while_op); return true; } diff --git a/xla/service/while_loop_simplifier_test.cc b/xla/service/while_loop_simplifier_test.cc index 53b701ec86294..5b1efe3d32b42 100644 --- a/xla/service/while_loop_simplifier_test.cc +++ b/xla/service/while_loop_simplifier_test.cc @@ -1535,5 +1535,51 @@ ENTRY %main (arg.0: f32[3], arg.1: f32[2]) -> (f32[3], f32[2], f32[2], f32[3]) { R"(({"arg.0"}, {"arg.1"}, {"constant.0"}))"); } +TEST_F(WhileLoopSimplifierTest, FlattenNestedTupleWithOriginalValue) { + const std::string hlo_string = R"( + HloModule Test + Body { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ta = (s32[1]) get-tuple-element(param), index=0 + a = s32[1] get-tuple-element(ta), index=0 + a.1 = s32[1] add(a, a) + tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1 + ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + } + Cond { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + a = s32[1] constant({0}) + b = s32[2] constant({0,1}) + c = s32[3] constant({0,1,2}) + d = s32[4] constant({0,1,2,3}) + ta = (s32[1]) tuple(a) + td = (s32[4]) tuple(d) + tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td) + init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd), origin={(({"a"}), ( + {"b"}, {"c"}, ({"d"})))} + ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init), + condition=Cond, body=Body, origin={(({"while.116" {0}}), ( + {"while.116" {1}}, {"while.116" {2}}, ({"while.116" {3}})))} + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopSimplifier().Run(module.get())); + EXPECT_TRUE(changed); + HloInstruction* while_instr = FindFirstWhile(module.get()); + ASSERT_NE(while_instr->original_value(), nullptr); + EXPECT_EQ( + while_instr->original_value()->ToString(), + R"(({"while.116" {0}}, {"while.116" {1}}, {"while.116" {2}}, {"while.116" {3}}))"); + HloInstruction* while_init = while_instr->while_init(); + ASSERT_NE(while_init->original_value(), nullptr); + EXPECT_EQ(while_init->original_value()->ToString(), + R"(({"a"}, {"b"}, {"c"}, {"d"}))"); +} + } // namespace } // namespace xla