Skip to content

Commit

Permalink
[XLA:GPU] Add comments to remove send/recv validation when we replace it
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729604926
  • Loading branch information
frgossen authored and Google-ML-Automation committed Feb 21, 2025
1 parent eb2b573 commit 90e75b6
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 27 deletions.
2 changes: 1 addition & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ cc_library(
"//xla/hlo/parser:hlo_parser",
"//xla/service/graphcycles",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
Expand Down
29 changes: 28 additions & 1 deletion xla/service/collective_permute_cycle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "xla/service/source_target_pairs.h"

namespace xla {
Expand Down Expand Up @@ -124,7 +126,32 @@ CycleType GetCycleType(const SourceTargetPairs& pairs) {
}

bool HasCycles(const SourceTargetPairs& pairs) {
  // Returns true iff the directed graph induced by `pairs` (an edge
  // source -> target per pair) contains a cycle.
  //
  // Build a successor map for O(1) lookup: next[s] is the target of source s,
  // or -1 if s has no outgoing edge.
  std::vector<int64_t> next;
  for (int64_t i = 0; i < pairs.size(); ++i) {
    int64_t source = pairs[i].source;
    int64_t target = pairs[i].target;
    if (source >= static_cast<int64_t>(next.size())) {
      next.resize(source + 1, -1);
    }
    next[source] = target;
  }

  // Nodes already proven acyclic: every walk through them reaches a dead end.
  absl::flat_hash_set<int64_t> acyclic;

  // Follow the successor chain from every node. A dedicated cursor is used so
  // the outer loop variable is never clobbered: mutating `i` inside the walk
  // would resume the outer loop past unvisited start nodes whenever a walk
  // terminates on an already-acyclic node, silently skipping (and thus
  // missing) cycles among the skipped nodes.
  const int64_t n = next.size();
  for (int64_t i = 0; i < n; ++i) {
    absl::flat_hash_set<int64_t> path;
    int64_t node = i;
    while (node != -1 && !acyclic.contains(node)) {
      // Revisiting a node of the current walk means we closed a loop.
      if (path.contains(node)) return true;
      path.insert(node);
      // Targets >= n have no outgoing edge recorded; treat as a dead end.
      node = node < n ? next[node] : -1;
    }
    // Every node on this walk reached a dead end (or a known-acyclic node).
    acyclic.insert(path.begin(), path.end());
  }

  // No cycles found.
  return false;
}

} // namespace collective_permute_cycle
Expand Down
51 changes: 47 additions & 4 deletions xla/service/collective_permute_cycle_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,8 @@ TEST_F(CollectivePermuteUtilsTest, HasCycles) {
EXPECT_TRUE(HasCycles(fwd4_.cycle));
EXPECT_TRUE(HasCycles(bwd4_.cycle));

EXPECT_FALSE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 2}})))
<< "Lasso 3->2";
EXPECT_FALSE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 1}})))
<< "Lasso 3->1";
EXPECT_TRUE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 2}})));
EXPECT_TRUE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 1}})));

EXPECT_FALSE(HasCycles(SourceTargetPairs({{1, 2}, {2, 3}, {3, 0}})))
<< "Forward only";
Expand Down Expand Up @@ -159,6 +157,51 @@ TEST_F(CollectivePermuteUtilsTest, GetCycleType) {
<< "Lasso 3->1";
}

TEST_F(CollectivePermuteUtilsTest, HasCyclesTwoCycles) {
  // Two disjoint closed rings: 0->1->2->3->0 and 4->5->6->7->4.
  SourceTargetPairs pairs(
      {{0, 1}, {1, 2}, {2, 3}, {3, 0}, {4, 5}, {5, 6}, {6, 7}, {7, 4}});
  EXPECT_TRUE(HasCycles(pairs));
}

TEST_F(CollectivePermuteUtilsTest, HasCyclesOneCycleAndOneAlmostCycle) {
  // Open chain 0->1->2->3 (nothing closes it back to 0) together with the
  // closed ring 4->5->6->7->4; the ring alone makes the graph cyclic.
  SourceTargetPairs pairs(
      {{0, 1}, {1, 2}, {2, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 4}});
  EXPECT_TRUE(HasCycles(pairs));
}

TEST_F(CollectivePermuteUtilsTest, HasCyclesTwoAlmostCycles) {
  // Two open chains: 3->0->1->2 (missing 2->3) and 7->4->5->6 (missing
  // 6->7). Neither closes, so there is no cycle.
  SourceTargetPairs pairs({{0, 1}, {1, 2}, {3, 0}, {4, 5}, {5, 6}, {7, 4}});
  EXPECT_FALSE(HasCycles(pairs));
}

TEST_F(CollectivePermuteUtilsTest, HasCyclesTwoCyclesInterleaved) {
  // A ring over the even nodes (0->2->4->6->0) interleaved with a ring over
  // the odd nodes (1->3->5->7->1).
  SourceTargetPairs pairs(
      {{0, 2}, {2, 4}, {4, 6}, {6, 0}, {1, 3}, {3, 5}, {5, 7}, {7, 1}});
  EXPECT_TRUE(HasCycles(pairs));
}

TEST_F(CollectivePermuteUtilsTest, HasCyclesSimpleCycle) {
  // A single closed ring over all eight nodes: 0->1->...->7->0.
  SourceTargetPairs pairs(
      {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}, {5, 6}, {6, 7}, {7, 0}});
  EXPECT_TRUE(HasCycles(pairs));
}

TEST_F(CollectivePermuteUtilsTest, HasCyclesSimpleAlmostCycle) {
  // The full ring of the previous test minus the 3->4 edge, leaving one open
  // chain 4->5->6->7->0->1->2->3 and therefore no cycle.
  SourceTargetPairs pairs(
      {{0, 1}, {1, 2}, {2, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 0}});
  EXPECT_FALSE(HasCycles(pairs));
}

} // namespace
} // namespace collective_permute_cycle
} // namespace xla
47 changes: 36 additions & 11 deletions xla/service/collective_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,8 @@ absl::StatusOr<HloInstruction*> CloneBackwardChain(
UpdateInstructionChannelId(cloned, next_channel_id);
clone_map[chain_op] = cloned;
if (postprocess_pipelined_ops.has_value()) {
TF_RETURN_IF_ERROR((*postprocess_pipelined_ops)(cloned));
TF_RETURN_IF_ERROR(
(*postprocess_pipelined_ops)(cloned, /*new_while_instr=*/nullptr));
}
last_cloned = cloned;
if (loop_variant_parameter_info != nullptr &&
Expand Down Expand Up @@ -1765,6 +1766,8 @@ absl::Status TransformLoopForward(
TF_RETURN_IF_ERROR(
UpdateControlDependencies(instr, cloned_instr, while_body_to_peeled));
UpdateInstructionChannelId(cloned_instr, next_channel_id);
// TODO(frgossen): Remove this once we have eliminated the need for
// send/recv validation.
TF_RETURN_IF_ERROR(UpdateSendRecvValidation(
cloned_instr, true, CollectivePipeliner::PipeliningDirection::kForward,
loop_analysis));
Expand Down Expand Up @@ -1832,6 +1835,8 @@ absl::Status TransformLoopForward(
loop_computation->parent()->AddEmbeddedComputation(
while_body->CloneWithReplacements(&replacements));
for (HloInstruction* instruction : new_while_body->instructions()) {
// TODO(frgossen): Remove this once we have eliminated the need for
// send/recv validation.
TF_RETURN_IF_ERROR(UpdateSendRecvValidation(
instruction, false, CollectivePipeliner::PipeliningDirection::kForward,
loop_analysis));
Expand Down Expand Up @@ -1910,7 +1915,8 @@ absl::Status TransformLoopForward(
}

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(processed));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(processed, /*new_while_instr=*/nullptr));
}

InstructionMap cloned_map = pipelined_values_map;
Expand All @@ -1922,7 +1928,8 @@ absl::Status TransformLoopForward(
new_operands));
cloned_map[formatting_op] = processed;
if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(processed));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(processed, /*new_while_instr=*/nullptr));
}
}
return processed;
Expand Down Expand Up @@ -2633,9 +2640,10 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis,
static absl::Status TransformLoopBackward(
const WhileLoopAnalysis& loop_analysis, bool insert_non_alias_custom_call,
int64_t level_to_operate_on, bool process_different_sized_ops,
HloPredicate should_process, HloPredicate acceptable_formatting,
HloPredicate acceptable_formatting,
CollectivePipeliner::HloPostprocessor postprocess_peeled,
CollectivePipeliner::HloPostprocessor postprocess_rotated,
CollectivePipeliner::HloPostprocessor postprocess_peeled_trailing_op,
int64_t& next_channel_id,
CollectivePipeliner::HloPostprocessor post_processing_fn) {
// Defining some maps/sets to keep track of instructions duplicated.
Expand Down Expand Up @@ -2739,10 +2747,12 @@ static absl::Status TransformLoopBackward(
post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx]));
TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx],
/*new_while_instr=*/nullptr));
}
if (postprocess_peeled.has_value()) {
TF_RETURN_IF_ERROR(postprocess_peeled.value()(new_init_operands[idx]));
TF_RETURN_IF_ERROR(postprocess_peeled.value()(
new_init_operands[idx], /*new_while_instr=*/nullptr));
}
}
ConstantValue next_loop_iteration =
Expand Down Expand Up @@ -2793,10 +2803,12 @@ static absl::Status TransformLoopBackward(
&loop_variant_parameter_info, post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(cloned_instr));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(cloned_instr, /*new_while_instr=*/nullptr));
}
if (postprocess_rotated.has_value()) {
TF_RETURN_IF_ERROR(postprocess_rotated.value()(cloned_instr));
TF_RETURN_IF_ERROR(postprocess_rotated.value()(
cloned_instr, /*new_while_instr=*/nullptr));
}
} else {
auto new_operands =
Expand Down Expand Up @@ -2852,6 +2864,8 @@ static absl::Status TransformLoopBackward(
new_loop_root,
while_body_replacement_map));
for (HloInstruction* instruction : new_while_body->instructions()) {
// TODO(frgossen): Remove this once we have eliminated the need for
// send/recv validation.
TF_RETURN_IF_ERROR(UpdateSendRecvValidation(
instruction, false, CollectivePipeliner::PipeliningDirection::kBackward,
loop_analysis));
Expand Down Expand Up @@ -2930,9 +2944,18 @@ static absl::Status TransformLoopBackward(
MapNewOperands(instr->operands(), while_body_replacement_map);
HloInstruction* cloned_instr = while_loop->parent()->AddInstruction(
instr->CloneWithNewOperands(instr->shape(), new_operands));

if (postprocess_peeled_trailing_op.has_value()) {
CHECK_NE(new_while_loop, nullptr);
TF_RETURN_IF_ERROR(
postprocess_peeled_trailing_op.value()(cloned_instr, new_while_loop));
}

TF_RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr,
while_body_replacement_map));
UpdateInstructionChannelId(cloned_instr, next_channel_id);
// TODO(frgossen): Remove this once we have eliminated the need for
// send/recv validation.
TF_RETURN_IF_ERROR(UpdateSendRecvValidation(
cloned_instr, true, CollectivePipeliner::PipeliningDirection::kBackward,
loop_analysis));
Expand All @@ -2954,6 +2977,7 @@ static absl::Status TransformLoopBackward(
TF_RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations());
return absl::OkStatus();
}

bool IsForwardSinkIterationFeasible(HloInstruction* while_inst,
int64_t collective_size_threshold) {
for (HloInstruction* inst :
Expand Down Expand Up @@ -3059,9 +3083,10 @@ absl::StatusOr<bool> CollectivePipeliner::RunPipeliner(
CHECK_EQ(config_.pipelining_direction, PipeliningDirection::kBackward);
TF_RETURN_IF_ERROR(TransformLoopBackward(
*loop_analysis, !config_.last_run, config_.level_to_operate_on,
config_.process_different_sized_ops, config_.should_process,
config_.acceptable_formatting, config_.postprocess_backward_peeled_op,
config_.postprocess_backward_rotated_op, next_channel_id,
config_.process_different_sized_ops, config_.acceptable_formatting,
config_.postprocess_backward_peeled_op,
config_.postprocess_backward_rotated_op,
config_.postprocess_backward_peeled_trailing_op, next_channel_id,
config_.postprocess_pipelined_ops));
}
++transformed_loops;
Expand Down
13 changes: 9 additions & 4 deletions xla/service/collective_pipeliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

Expand Down Expand Up @@ -65,10 +66,12 @@ class CollectivePipeliner : public HloModulePass {
kForwardSink,
};

// Postprocessing cloned collective instructions, such as for modifying loop
// iteration related frontend attributes to reflect loop pipelining.
using HloPostprocessor =
std::optional<std::function<absl::Status(HloInstruction* instr)>>;
// Postprocessing cloned collective instructions, such as peeled instructions
// before and after the loop, and rotated instructions. The new while op is
// only passed for the peeled trailing ops when the new while op was already
// created.
using HloPostprocessor = std::optional<std::function<absl::Status(
HloInstruction* instr, HloInstruction* new_while_instr)>>;

struct Config {
int64_t level_to_operate_on = 0;
Expand Down Expand Up @@ -100,8 +103,10 @@ class CollectivePipeliner : public HloModulePass {
// pipelined. The control dependencies will be dropped when the operation is
// pipelined. This is currently only used to support kBackward pipelining.
bool should_allow_control_dependencies = false;
// TODO(frgossen): Consolidate these postprocessing functions.
HloPostprocessor postprocess_backward_peeled_op = std::nullopt;
HloPostprocessor postprocess_backward_rotated_op = std::nullopt;
HloPostprocessor postprocess_backward_peeled_trailing_op = std::nullopt;
// Determines whether a loop invariant instruction can be considered
// in the pipelining chain.
bool should_add_loop_invariant_op_in_chain = false;
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand All @@ -905,6 +906,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand All @@ -928,6 +930,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand Down
26 changes: 22 additions & 4 deletions xla/service/gpu/gpu_p2p_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ absl::Status PostprocessP2PImpl(

// Modifies the loop iteration frontend attribute for the peeled off Send and
// Recv for the first iteration of a loop.
absl::Status PostprocessPeeledP2P(HloInstruction* instr) {
absl::Status PostprocessPeeledP2P(HloInstruction* instr,
HloInstruction* new_while_instr) {
// We only use this to post-process the peeled send/recv before the new loop
// was created.
CHECK(new_while_instr == nullptr);

auto transform_bounds = [&](std::vector<ReplicaGroup>& replica_groups) {
std::vector<std::pair<int64_t, int64_t>> bounds;
bounds.reserve(replica_groups.size());
Expand Down Expand Up @@ -210,7 +215,12 @@ absl::Status PostprocessPeeledP2P(HloInstruction* instr) {

// Modifies the loop iteration frontend attribute for the rotated Send and Recv
// for the remaining iterations in a loop.
absl::Status PostprocessRotatedP2P(HloInstruction* instr) {
absl::Status PostprocessRotatedP2P(HloInstruction* instr,
HloInstruction* new_while_instr) {
// We only use this to post-process the peeled send/recv before the new loop
// was created.
CHECK(new_while_instr == nullptr);

auto transform_bounds = [&](std::vector<ReplicaGroup>& replica_groups) {
std::vector<std::pair<int64_t, int64_t>> bounds;
bounds.reserve(replica_groups.size());
Expand Down Expand Up @@ -471,11 +481,19 @@ absl::StatusOr<bool> GpuP2PPipeliner::Run(

if (enable_partial_send_recv_pipelining_) {
should_process = FullyPipelineRecv;
postprocess_backward_peeled_op = [&](HloInstruction* it) {
postprocess_backward_peeled_op = [&](HloInstruction* it,
HloInstruction* new_while_instr) {
// When post-processing non-trailing peeled send/recv, the new while loop
// was not yet created.
CHECK_EQ(new_while_instr, nullptr);
peeled_send_recvs.push_back(it);
return absl::OkStatus();
};
postprocess_backward_rotated_op = [&](HloInstruction* it) {
postprocess_backward_rotated_op = [&](HloInstruction* it,
HloInstruction* new_while_instr) {
// When post-processing non-trailing peeled send/recv, the new while loop
// was not yet created.
CHECK_EQ(new_while_instr, nullptr);
rotated_send_recvs.push_back(it);
return absl::OkStatus();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ int64_t ComputeSuggestedCombinerThreshold(
return MaxAvailableMemory(module, device_info) - peak_memory_bytes;
}

absl::Status AppendPipelinedInstruction(HloInstruction* instr) {
absl::Status AppendPipelinedInstruction(HloInstruction* instr,
HloInstruction* new_while_instr) {
if (!IsCollective(instr)) {
return absl::OkStatus();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ int64_t ComputeSuggestedCombinerThreshold(
// Adds information that `instr` has been pipelined to the
// `CollectiveBackendInfo`. It is up to the caller to decide when to invoke
// this.
absl::Status AppendPipelinedInstruction(HloInstruction* instr);
absl::Status AppendPipelinedInstruction(HloInstruction* instr,
HloInstruction* new_while_instr);

// Returns true if module contains any pipelined instruction. False otherwise.
bool ContainsPipelinedInstruction(const HloModule& module);
Expand Down

0 comments on commit 90e75b6

Please sign in to comment.