Skip to content

Add argument to fix a random seed when generating random arguments for HLO runner. Also add OutputFormat so that literal dumps can be saved as a pb file. #21566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ xla_test(
"//xla:xla_proto_cc",
"//xla/hlo/testlib:filecheck",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:hlo_proto_cc",
"//xla/tsl/lib/core:status_test_util",
Expand Down
73 changes: 62 additions & 11 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,37 @@ std::string AbslUnparseFlag(InputFormat input_format) {
}
}

bool AbslParseFlag(absl::string_view text, OutputFormat* output_format,
std::string* error) {
if (text == "text") {
*output_format = OutputFormat::kText;
return true;
}
if (text == "proto_binary") {
*output_format = OutputFormat::kProtoBinary;
return true;
}
if (text == "proto_text") {
*output_format = OutputFormat::kProtoText;
return true;
}
*error = "unknown value for enumeration";
return false;
}

std::string AbslUnparseFlag(OutputFormat output_format) {
switch (output_format) {
case OutputFormat::kText:
return "text";
case OutputFormat::kProtoBinary:
return "proto_binary";
case OutputFormat::kProtoText:
return "proto_text";
default:
return absl::StrCat(output_format);
}
}

bool AbslParseFlag(absl::string_view text,
FunctionalHloRunner::ModuleArgumentMode* argument_mode,
std::string* error) {
Expand Down Expand Up @@ -442,7 +473,7 @@ FunctionalHloRunner::CreateExecutableBuildOptionsFromExecutionOptions(

absl::Status FunctionalHloRunner::DumpOutput(
const FunctionalHloRunner::PerDeviceLiteralVecType& output,
absl::string_view dump_output_to, int task_id) {
absl::string_view dump_output_to, int task_id, OutputFormat output_format) {
std::vector<std::string> output_path_vec =
absl::StrSplit(dump_output_to, '.');
std::string suffix = output_path_vec.back();
Expand All @@ -458,12 +489,30 @@ absl::Status FunctionalHloRunner::DumpOutput(
for (int literal_id = 0; literal_id < literal_vec.size(); ++literal_id) {
output_path_vec[literal_id_index] = absl::StrCat("literal_", literal_id);
std::string literal_path = absl::StrJoin(output_path_vec, ".");
CHECK_EQ(suffix, std::string("txt"));
absl::Status write_status =
tsl::WriteStringToFile(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToString());
if (!write_status.ok()) {
return write_status;
switch (output_format) {
case OutputFormat::kText: {
CHECK_EQ(suffix, std::string("txt"));
absl::Status write_status =
tsl::WriteStringToFile(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToString());
if (!write_status.ok()) {
return write_status;
}
} break;
case OutputFormat::kProtoBinary: {
CHECK_EQ(suffix, std::string("pb"));
TF_RETURN_IF_ERROR(
tsl::WriteBinaryProto(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToProto()));
break;
}
case OutputFormat::kProtoText: {
CHECK_EQ(suffix, std::string("pbtxt"));
TF_RETURN_IF_ERROR(
tsl::WriteTextProto(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToProto()));
break;
}
}
}
}
Expand Down Expand Up @@ -545,7 +594,8 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client,
const RunningOptions& running_options,
absl::string_view hlo_text,
InputFormat input_format,
const PerDeviceLiteralVecType& arguments) {
const PerDeviceLiteralVecType& arguments,
std::minstd_rand0* engine) {
// We only support SPMD as of now, i.e., all devices are supposed
// to execute the same HLO module.
// Currently there is no mechanism to map the loaded arguments to
Expand Down Expand Up @@ -577,7 +627,7 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client,

return CompileAndRun(
client, debug_options, preproc_options, compile_options, running_options,
hlo_module_and_arguments.hlo_module.get(), loaded_arguments);
hlo_module_and_arguments.hlo_module.get(), loaded_arguments, engine);
}

absl::Status FunctionalHloRunner::LoadAndCompile(
Expand Down Expand Up @@ -723,12 +773,13 @@ FunctionalHloRunner::CompileAndRun(PjRtClient& client,
const CompileOptions& compile_options,
const RunningOptions& running_options,
HloModule* hlo_module,
const PerDeviceLiteralVecType& arguments) {
const PerDeviceLiteralVecType& arguments,
std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
Compile(client, hlo_module, debug_options,
preproc_options, compile_options));

return Run(client, executable.get(), arguments, running_options);
return Run(client, executable.get(), arguments, running_options, engine);
}

namespace {
Expand Down
19 changes: 16 additions & 3 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ enum class InputFormat {
// in conjunction with xla_dump_as_text.
};

enum class OutputFormat : std::uint8_t {
kText, // Text format returned by Literal::ToString().
kProtoBinary, // Protobuf binary format of an xla::LiteralProto message.
kProtoText, // Protobuf text format of an xla::LiteralProto message.
};

// Interface for profiler plugins. If being set in RunningOptions, profiling
// session will be created for the last run of the HLO module.
class ProfilerInterface {
Expand Down Expand Up @@ -134,6 +140,10 @@ bool AbslParseFlag(absl::string_view text, InputFormat* input_format,
std::string* error);
std::string AbslUnparseFlag(InputFormat input_format);

bool AbslParseFlag(absl::string_view text, OutputFormat* output_format,
std::string* error);
std::string AbslUnparseFlag(OutputFormat output_format);

// FunctionalHloRunner takes an HLO module as input and runs the HLO module
// on a single or multiple hosts with various options (e.g. SPMD). The HLO
// module can be pre- or post-optimizations.
Expand Down Expand Up @@ -346,7 +356,8 @@ class FunctionalHloRunner {
const PreprocessingOptions& preproc_options,
const CompileOptions& compile_options,
const RunningOptions& running_options, absl::string_view hlo_text,
InputFormat input_format, const PerDeviceLiteralVecType& arguments = {});
InputFormat input_format, const PerDeviceLiteralVecType& arguments = {},
std::minstd_rand0* engine = nullptr);

// Loads and compiles an HLO for debugging purposes.
//
Expand All @@ -368,7 +379,8 @@ class FunctionalHloRunner {
const PreprocessingOptions& preproc_options,
const CompileOptions& compile_options,
const RunningOptions& running_options, HloModule* hlo_module,
const PerDeviceLiteralVecType& arguments = {});
const PerDeviceLiteralVecType& arguments = {},
std::minstd_rand0* engine = nullptr);

// Compiles the HLO module.
static absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
Expand Down Expand Up @@ -429,7 +441,8 @@ class FunctionalHloRunner {

static absl::Status DumpOutput(
const FunctionalHloRunner::PerDeviceLiteralVecType& output,
absl::string_view dump_output_to, int task_id);
absl::string_view dump_output_to, int task_id,
OutputFormat output_format = OutputFormat::kText);

private:
// Calculates the requested number of replicas and partitions.
Expand Down
19 changes: 19 additions & 0 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cstdlib>
#include <memory>
#include <random>
#include <string>
#include <vector>

Expand All @@ -31,6 +32,7 @@ limitations under the License.
#include "xla/debug_options_flags.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/service/hlo.pb.h"
#include "xla/status_macros.h"
Expand Down Expand Up @@ -633,6 +635,23 @@ TEST_F(FunctionalHloRunnerTest, ReadHloUnoptimizedSnapshot) {
hlo_module_and_arguments_from_binary.arguments.size());
}

TEST_F(FunctionalHloRunnerTest, FixFakeArguments) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
GetPjRtClient());

// Options corresponding to --num_replicas=1 --num_partitions=1
xla::DebugOptions debug_options;
FunctionalHloRunner::PreprocessingOptions preproc_options;
CompileOptions compile_options;
FunctionalHloRunner::RunningOptions running_options;

std::minstd_rand0 engine(42);
TF_EXPECT_OK(FunctionalHloRunner::LoadAndRun(
*client, debug_options, preproc_options, compile_options, running_options,
{GetHloPath("single_device.hlo")}, InputFormat::kText,
/*arguments=*/{}, /*engine=*/&engine));
}

} // namespace
} // namespace xla

Expand Down
Loading