From 2ad77d663e6543ca74692e69e5b7f0d430baa4ff Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Mon, 24 Feb 2025 14:36:36 -0800 Subject: [PATCH] Build absl::string_view(data, length) (instead of StringRef::str) explicitly since the llvm::StringRef to absl::string_view converter is not (always?) available on Android. END_PUBLIC PiperOrigin-RevId: 730600175 --- xla/service/spmd/shardy/round_trip_common/BUILD | 1 + .../round_trip_common/import_backend_func_calls.cc | 3 ++- .../round_trip_common/import_sdy_custom_calls.cc | 6 ++++-- .../shardy/sdy_round_trip/remove_size_one_axes.cc | 4 +++- .../spmd/shardy/sdy_round_trip/shard_map_import.cc | 5 ++--- xla/service/spmd/shardy/utils.h | 12 ++++++++---- 6 files changed, 20 insertions(+), 11 deletions(-) diff --git a/xla/service/spmd/shardy/round_trip_common/BUILD b/xla/service/spmd/shardy/round_trip_common/BUILD index 1a1cf9cda6684..26b4313eb5f3f 100644 --- a/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/xla/service/spmd/shardy/round_trip_common/BUILD @@ -22,6 +22,7 @@ cc_library( "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc b/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc index b2c0e517e7430..8c8d472ffb0d9 100644 --- a/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc +++ b/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc @@ -72,7 +72,8 @@ class BackendFuncCallPattern : public OpConversionPattern { FuncOp func = symbolTable.lookup(adaptor.getCallee()); CHECK(func) << "Failed to lookup function: " - << absl::string_view(adaptor.getCallee()); + << absl::string_view(adaptor.getCallee().data(), + adaptor.getCallee().size()); mlir::SmallVector namedCompAttrs; llvm::copy_if(callOp->getDiscardableAttrs(), std::back_inserter(namedCompAttrs), diff --git a/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc b/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc index 4a36c2ba3b158..c224ddc444fd1 100644 --- a/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc +++ b/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -60,9 +61,10 @@ mlir::LogicalResult rewriteShardingCustomCall( std::vector unspecDims; if (std::optional backendConfig = op.getBackendConfig()) { + StringRef configStr = + mlir::dyn_cast(*backendConfig).getValue(); CHECK_OK(xla::sharding_op_util::ParseAttributes( - mlir::dyn_cast(*backendConfig).getValue(), - &unspecDims)); + absl::string_view(configStr.data(), configStr.size()), &unspecDims)); } if (op->getNumResults() != 1) { diff --git a/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc b/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc index bee62bff1a360..c73f5c63a1689 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc @@ -75,7 +75,9 @@ MeshAttr removeSizeOneAxes(MeshAttr mesh) { TensorShardingAttr removeSizeOneAxes(TensorShardingAttr sharding, const SymbolTable& symbolTable) { MeshAttr mesh = sharding.getMesh(symbolTable); - CHECK(mesh) << "unknown mesh: " << absl::string_view(sharding.getMeshName()); + CHECK(mesh) << "unknown mesh: " + << absl::string_view(sharding.getMeshName().data(), + sharding.getMeshName().size()); auto isNotSizeOne = [&](AxisRefAttr axis) { return axis.getSize(mesh) != 1; }; diff --git a/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc index 1d547c7842a7b..1d195565b6cd5 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -77,8 +77,7 @@ class ManualComputationPattern : public OpConversionPattern { mlir::LogicalResult matchAndRewrite( CallOp callOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override { - if (!absl::StrContains(callOp.getCallee(), - kManualComputationBodyFuncName)) { + if (!callOp.getCallee().contains(kManualComputationBodyFuncName)) { return mlir::failure(); } @@ -159,7 +158,7 @@ class SdyRoundTripShardMapImportPass MLIRContext& context = getContext(); mlir::ConversionTarget target(context); target.addDynamicallyLegalOp([](CallOp op) { - return !absl::StrContains(op.getCallee(), kManualComputationBodyFuncName); + return !op.getCallee().contains(kManualComputationBodyFuncName); }); target.addLegalOp(); mlir::RewritePatternSet patterns(&context); diff --git a/xla/service/spmd/shardy/utils.h b/xla/service/spmd/shardy/utils.h index 7975a55599d64..7e7f2af813cb5 100644 --- a/xla/service/spmd/shardy/utils.h +++ b/xla/service/spmd/shardy/utils.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -80,13 +81,16 @@ template AttrTy parseStringAttr(mlir::DictionaryAttr dictAttr, llvm::StringRef attrName) { if (mlir::Attribute stringAttr = dictAttr.get(attrName)) { - std::string value; + std::string unescapedValue; std::string error; - CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), - &value, &error)) + llvm::StringRef escapedValue = + mlir::cast(stringAttr).getValue(); + CHECK(absl::CUnescape( + absl::string_view(escapedValue.data(), escapedValue.size()), + &unescapedValue, &error)) << error; return mlir::cast( - mlir::parseAttribute(value, stringAttr.getContext())); + mlir::parseAttribute(unescapedValue, stringAttr.getContext())); } return nullptr; }