Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU] Support postprocessing of peeled ops in the trailing while loop iteration #22962

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions xla/service/collective_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,8 @@ absl::StatusOr<HloInstruction*> CloneBackwardChain(
}
clone_map[chain_op] = cloned;
if (postprocess_pipelined_ops.has_value()) {
TF_RETURN_IF_ERROR((*postprocess_pipelined_ops)(cloned));
TF_RETURN_IF_ERROR(
(*postprocess_pipelined_ops)(cloned, /*new_while_instr=*/nullptr));
}
last_cloned = cloned;
if (loop_variant_parameter_info != nullptr &&
Expand Down Expand Up @@ -1941,7 +1942,8 @@ absl::Status TransformLoopForward(
}

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(processed));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(processed, /*new_while_instr=*/nullptr));
}

InstructionMap cloned_map = pipelined_values_map;
Expand All @@ -1957,7 +1959,8 @@ absl::Status TransformLoopForward(
}
cloned_map[formatting_op] = processed;
if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(processed));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(processed, /*new_while_instr=*/nullptr));
}
}
return processed;
Expand Down Expand Up @@ -2672,6 +2675,7 @@ static absl::Status TransformLoopBackward(
HloPredicate should_process, HloPredicate acceptable_formatting,
CollectivePipeliner::HloPostprocessor postprocess_peeled,
CollectivePipeliner::HloPostprocessor postprocess_rotated,
CollectivePipeliner::HloPostprocessor postprocess_peeled_trailing_op,
int64_t& next_channel_id,
CollectivePipeliner::HloPostprocessor post_processing_fn) {
// Defining some maps/sets to keep track of instructions duplicated.
Expand Down Expand Up @@ -2777,10 +2781,12 @@ static absl::Status TransformLoopBackward(
/*loop_variant_parameter_info=*/nullptr, post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx]));
TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx],
/*new_while_instr=*/nullptr));
}
if (postprocess_peeled.has_value()) {
TF_RETURN_IF_ERROR(postprocess_peeled.value()(new_init_operands[idx]));
TF_RETURN_IF_ERROR(postprocess_peeled.value()(
new_init_operands[idx], /*new_while_instr=*/nullptr));
}
}
ConstantValue next_loop_iteration =
Expand Down Expand Up @@ -2835,10 +2841,12 @@ static absl::Status TransformLoopBackward(
post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(cloned_instr));
TF_RETURN_IF_ERROR(
(*post_processing_fn)(cloned_instr, /*new_while_instr=*/nullptr));
}
if (postprocess_rotated.has_value()) {
TF_RETURN_IF_ERROR(postprocess_rotated.value()(cloned_instr));
TF_RETURN_IF_ERROR(postprocess_rotated.value()(
cloned_instr, /*new_while_instr=*/nullptr));
}
} else {
auto new_operands =
Expand Down Expand Up @@ -2972,6 +2980,13 @@ static absl::Status TransformLoopBackward(
MapNewOperands(instr->operands(), while_body_replacement_map);
HloInstruction* cloned_instr = while_loop->parent()->AddInstruction(
instr->CloneWithNewOperands(instr->shape(), new_operands));

if (postprocess_peeled_trailing_op.has_value()) {
CHECK_NE(new_while_loop, nullptr);
TF_RETURN_IF_ERROR(
postprocess_peeled_trailing_op.value()(cloned_instr, new_while_loop));
}

TF_RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr,
while_body_replacement_map));
UpdateInstructionChannelId(cloned_instr, next_channel_id);
Expand All @@ -2998,6 +3013,7 @@ static absl::Status TransformLoopBackward(
TF_RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations());
return absl::OkStatus();
}

bool IsForwardSinkIterationFeasible(HloInstruction* while_inst,
int64_t collective_size_threshold) {
for (HloInstruction* inst :
Expand Down Expand Up @@ -3105,7 +3121,8 @@ absl::StatusOr<bool> CollectivePipeliner::RunPipeliner(
*loop_analysis, !config_.last_run, config_.level_to_operate_on,
config_.process_different_sized_ops, config_.should_process,
config_.acceptable_formatting, config_.postprocess_backward_peeled_op,
config_.postprocess_backward_rotated_op, next_channel_id,
config_.postprocess_backward_rotated_op,
config_.postprocess_backward_peeled_trailing_op, next_channel_id,
config_.postprocess_pipelined_ops));
}
++transformed_loops;
Expand Down
13 changes: 9 additions & 4 deletions xla/service/collective_pipeliner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

Expand Down Expand Up @@ -65,10 +66,12 @@ class CollectivePipeliner : public HloModulePass {
kForwardSink,
};

// Postprocessing cloned collective instructions, such as for modifying loop
// iteration related frontend attributes to reflect loop pipelining.
using HloPostprocessor =
std::optional<std::function<absl::Status(HloInstruction* instr)>>;
// Postprocessing cloned collective instructions, such as peeled instructions
// before and after the loop, and rotated instructions. `new_while_instr` is
// only non-null for the peeled trailing ops, which are postprocessed after
// the new while op was already created; it is nullptr in all other cases.
using HloPostprocessor = std::optional<std::function<absl::Status(
HloInstruction* instr, HloInstruction* new_while_instr)>>;

struct Config {
int64_t level_to_operate_on = 0;
Expand Down Expand Up @@ -100,8 +103,10 @@ class CollectivePipeliner : public HloModulePass {
// pipelined. The control dependencies will be dropped when the operation is
// pipelined. This is currently only used to support kBackward pipelining.
bool should_allow_control_dependencies = false;
// TODO(frgossen): Consolidate these postprocessing functions.
HloPostprocessor postprocess_backward_peeled_op = std::nullopt;
HloPostprocessor postprocess_backward_rotated_op = std::nullopt;
HloPostprocessor postprocess_backward_peeled_trailing_op = std::nullopt;
// Determines whether a loop invariant instruction can be considered
// in the pipelining chain.
bool should_add_loop_invariant_op_in_chain = false;
Expand Down
14 changes: 11 additions & 3 deletions xla/service/collective_pipeliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ absl::StatusOr<bool> RunOptimizer(
std::nullopt,
CollectivePipeliner::HloPostprocessor postprocess_backward_rotated =
std::nullopt,
CollectivePipeliner::HloPostprocessor postprocess_backward_peeled_trailing =
std::nullopt,
bool should_add_loop_invariant_op_in_chain = false,
int64_t collective_size_threshold_to_stop_sinking = INT64_MAX) {
CollectivePipeliner::Config config = {
Expand All @@ -101,7 +103,8 @@ absl::StatusOr<bool> RunOptimizer(
/*reuse_pipelined_op_buffer=*/reuse_pipelined_op_buffer,
should_allow_loop_variant_parameter_in_chain,
/*should_allow_control_dependencies=*/false, postprocess_backward_peeled,
postprocess_backward_rotated, should_add_loop_invariant_op_in_chain,
postprocess_backward_rotated, postprocess_backward_peeled_trailing,
should_add_loop_invariant_op_in_chain,
/*postprocess_pipelined_ops=*/std::nullopt,
collective_size_threshold_to_stop_sinking};
HloPassPipeline pass("optimizer");
Expand Down Expand Up @@ -2790,13 +2793,15 @@ TEST_F(CollectivePipelinerTest,
};
const char* kAttr = "_xla_other_attr";
// Mutate an existing attribute.
auto postprocess_peeled = [&](HloInstruction* instr) {
auto postprocess_peeled = [&](HloInstruction* instr,
HloInstruction* new_while_instr) {
xla::FrontendAttributes attributes = instr->frontend_attributes();
(*attributes.mutable_map())[kAttr] = "1";
instr->set_frontend_attributes(attributes);
return absl::OkStatus();
};
auto postprocess_rotated = [&](HloInstruction* instr) {
auto postprocess_rotated = [&](HloInstruction* instr,
HloInstruction* new_while_instr) {
xla::FrontendAttributes attributes = instr->frontend_attributes();
(*attributes.mutable_map())[kAttr] = "2";
instr->set_frontend_attributes(attributes);
Expand Down Expand Up @@ -3172,6 +3177,7 @@ ENTRY entry {
/*should_allow_loop_variant_parameter_in_chain=*/HloPredicateTrue,
/*postprocess_backward_peeled=*/std::nullopt,
/*postprocess_backward_rotated=*/std::nullopt,
/*postprocess_backward_peeled_trailing=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true)
.value());
XLA_VLOG_LINES(1, module->ToString());
Expand Down Expand Up @@ -3202,6 +3208,7 @@ ENTRY entry {
/*should_allow_loop_variant_parameter_in_chain=*/HloPredicateTrue,
/*postprocess_backward_peeled=*/std::nullopt,
/*postprocess_backward_rotated=*/std::nullopt,
/*postprocess_backward_peeled_trailing=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false)
.value());
}
Expand Down Expand Up @@ -3593,6 +3600,7 @@ ENTRY entry {
/*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse,
/*postprocess_backward_peeled=*/std::nullopt,
/*postprocess_backward_rotated=*/std::nullopt,
/*postprocess_backward_peeled_trailing=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*collective_size_threshold_to_stop_sinking=*/1024)
.value());
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand All @@ -905,6 +906,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand All @@ -928,6 +930,7 @@ absl::Status RunCollectiveOptimizationPasses(
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/false,
/*postprocess_pipelined_ops=*/AppendPipelinedInstruction,
};
Expand Down
26 changes: 22 additions & 4 deletions xla/service/gpu/gpu_p2p_pipeliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ absl::Status PostprocessP2PImpl(

// Modifies the loop iteration frontend attribute for the peeled off Send and
// Recv for the first iteration of a loop.
absl::Status PostprocessPeeledP2P(HloInstruction* instr) {
absl::Status PostprocessPeeledP2P(HloInstruction* instr,
HloInstruction* new_while_instr) {
// We only use this to post-process the peeled send/recv before the new loop
// was created.
CHECK(new_while_instr == nullptr);

auto transform_bounds = [&](std::vector<ReplicaGroup>& replica_groups) {
std::vector<std::pair<int64_t, int64_t>> bounds;
bounds.reserve(replica_groups.size());
Expand Down Expand Up @@ -210,7 +215,12 @@ absl::Status PostprocessPeeledP2P(HloInstruction* instr) {

// Modifies the loop iteration frontend attribute for the rotated Send and Recv
// for the remaining iterations in a loop.
absl::Status PostprocessRotatedP2P(HloInstruction* instr) {
absl::Status PostprocessRotatedP2P(HloInstruction* instr,
HloInstruction* new_while_instr) {
// We only use this to post-process the rotated send/recv before the new loop
// was created.
CHECK(new_while_instr == nullptr);

auto transform_bounds = [&](std::vector<ReplicaGroup>& replica_groups) {
std::vector<std::pair<int64_t, int64_t>> bounds;
bounds.reserve(replica_groups.size());
Expand Down Expand Up @@ -471,11 +481,19 @@ absl::StatusOr<bool> GpuP2PPipeliner::Run(

if (enable_partial_send_recv_pipelining_) {
should_process = FullyPipelineRecv;
postprocess_backward_peeled_op = [&](HloInstruction* it) {
postprocess_backward_peeled_op = [&](HloInstruction* it,
HloInstruction* new_while_instr) {
// When post-processing non-trailing peeled send/recv, the new while loop
// was not yet created.
CHECK_EQ(new_while_instr, nullptr);
peeled_send_recvs.push_back(it);
return absl::OkStatus();
};
postprocess_backward_rotated_op = [&](HloInstruction* it) {
postprocess_backward_rotated_op = [&](HloInstruction* it,
HloInstruction* new_while_instr) {
// When post-processing rotated send/recv, the new while loop was not yet
// created.
CHECK_EQ(new_while_instr, nullptr);
rotated_send_recvs.push_back(it);
return absl::OkStatus();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ int64_t ComputeSuggestedCombinerThreshold(
return MaxAvailableMemory(module, device_info) - peak_memory_bytes;
}

absl::Status AppendPipelinedInstruction(HloInstruction* instr) {
absl::Status AppendPipelinedInstruction(HloInstruction* instr,
HloInstruction* new_while_instr) {
if (!IsCollective(instr)) {
return absl::OkStatus();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ int64_t ComputeSuggestedCombinerThreshold(
// Adds information that `instr` has been pipelined to the
// `CollectiveBackendInfo`. It is up to the caller to decide when to invoke
// this.
absl::Status AppendPipelinedInstruction(HloInstruction* instr);
absl::Status AppendPipelinedInstruction(HloInstruction* instr,
HloInstruction* new_while_instr);

// Returns true if module contains any pipelined instruction. False otherwise.
bool ContainsPipelinedInstruction(const HloModule& module);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ TEST_F(CollectiveCombinerUtilsTest,
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
};
config.postprocess_pipelined_ops = AppendPipelinedInstruction;
Expand Down Expand Up @@ -434,6 +435,7 @@ TEST_F(CollectiveCombinerUtilsTest,
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*postprocess_backward_peeled_trailing_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
};
config.postprocess_pipelined_ops = AppendPipelinedInstruction;
Expand Down
Loading