Skip to content

Commit 9589b73

Browse files
loislotensorflower-gardener
authored andcommitted
Refactor: Use macros to define unary operations and use this list in all the places where we need them all together.
The refactoring uses the well known approach for the compilers. The positive parts are: 1) we significantly reduce the amount of the boilerplate. 2) we force engineers to use the same names for the ops in all the places. 3) we avoid the errors when someone misses to add a special handling for a new op. This change introduces `UNARI_OPS_WITHOUT_ACCURACY` and `UNARY_OPS_WITH_ACCURACY` macros in `hlo_opcode.h` to list unary operations. These lists are in use in many places where we work with the unary ops the identical way. The common cases are: define a case statement in a switch for every unary op. define the function definition/declaration for every unary op. etc. FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#33278 from openxla:dependabot/pip/xla/backends/cpu/benchmarks/e2e/gemma2/keras/keras-3.12.0 b37d94a32428d62ed3e73765f4e7b61bc6ed8549 PiperOrigin-RevId: 794210068
1 parent fd85062 commit 9589b73

33 files changed

+709
-818
lines changed

tensorflow/lite/core/model_building.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class [[nodiscard]] Buffer {
104104
template <TfLiteType kType, class T>
105105
void Assign(Buffer b, std::vector<int> shape, const std::vector<T>& data,
106106
Quantization quantization) {
107-
using Storage = TfLiteTypeToType<kType>::Type;
107+
using Storage = typename TfLiteTypeToType<kType>::Type;
108108
std::unique_ptr<Storage[]> buffer_data(new Storage[data.size()]);
109109
std::copy(begin(data), end(data), buffer_data.get());
110110
Assign(

third_party/xla/xla/backends/cpu/benchmarks/e2e/gemma2/keras/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
keras==3.11.3
1+
keras==3.12.0
22
keras_nlp==0.18.1
33
tensorflow==2.18.0
44
jax==0.4.38

third_party/xla/xla/backends/gpu/runtime/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ cc_library(
701701
"//xla:executable_run_options",
702702
"//xla:shape_util",
703703
"//xla:util",
704+
"//xla/ffi",
704705
"//xla/ffi:attribute_map",
705706
"//xla/ffi:call_frame",
706707
"//xla/ffi:execution_context",
@@ -720,6 +721,7 @@ cc_library(
720721
"//xla/tsl/platform:errors",
721722
"//xla/tsl/platform:statusor",
722723
"@com_google_absl//absl/algorithm:container",
724+
"@com_google_absl//absl/base:nullability",
723725
"@com_google_absl//absl/container:inlined_vector",
724726
"@com_google_absl//absl/log",
725727
"@com_google_absl//absl/log:check",
@@ -748,6 +750,7 @@ xla_test(
748750
"//xla/service:executable",
749751
"//xla/service:platform_util",
750752
"//xla/service/gpu:buffer_allocations",
753+
"//xla/service/gpu:resource_requests",
751754
"//xla/stream_executor:platform",
752755
"//xla/stream_executor:platform_manager",
753756
"//xla/stream_executor:stream",

third_party/xla/xla/backends/gpu/runtime/custom_call_thunk.cc

Lines changed: 156 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ limitations under the License.
2222
#include <optional>
2323
#include <string>
2424
#include <utility>
25+
#include <variant>
2526
#include <vector>
2627

2728
#include "absl/algorithm/container.h"
29+
#include "absl/base/nullability.h"
2830
#include "absl/container/inlined_vector.h"
29-
#include "absl/log/check.h"
3031
#include "absl/log/log.h"
3132
#include "absl/memory/memory.h"
3233
#include "absl/status/status.h"
@@ -41,9 +42,11 @@ limitations under the License.
4142
#include "xla/ffi/attribute_map.h"
4243
#include "xla/ffi/call_frame.h"
4344
#include "xla/ffi/execution_state.h"
45+
#include "xla/ffi/ffi.h"
4446
#include "xla/ffi/ffi_api.h"
4547
#include "xla/hlo/ir/hlo_computation.h"
4648
#include "xla/primitive_util.h"
49+
#include "xla/runtime/object_pool.h"
4750
#include "xla/service/buffer_assignment.h"
4851
#include "xla/service/custom_call_status.h"
4952
#include "xla/service/custom_call_status_internal.h"
@@ -250,6 +253,44 @@ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
250253
std::move(attributes), std::move(execution_state), called_computation));
251254
}
252255

256+
absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
257+
ThunkInfo thunk_info, std::string target_name, OwnedHandlerBundle bundle,
258+
std::vector<std::optional<ShapedSlice>> operands,
259+
std::vector<std::optional<ShapedSlice>> results,
260+
xla::ffi::AttributesMap attributes,
261+
const HloComputation* called_computation) {
262+
if (!bundle.execute) {
263+
return absl::InvalidArgumentError(
264+
"Execute handler is required for a CustomCallThunk");
265+
}
266+
267+
auto execution_state = std::make_unique<ffi::ExecutionState>();
268+
// Initialize FFI handler state if it has an instantiate callback.
269+
if (bundle.instantiate) {
270+
// At FFI handler instantiation time, we don't have any arguments or
271+
// results or access to the underlying device (stream, etc.)
272+
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
273+
274+
CallFrameBuilder::AttributesBuilder attrs;
275+
attrs.Append(attributes);
276+
277+
builder.AddAttributes(attrs.Build());
278+
CallFrame call_frame = builder.Build();
279+
280+
CallOptions options;
281+
options.execution_state = execution_state.get();
282+
TF_RETURN_IF_ERROR(Call(*bundle.instantiate, call_frame, options,
283+
xla::ffi::ExecutionStage::kInstantiate));
284+
}
285+
286+
TF_ASSIGN_OR_RETURN(CallFrame call_frame,
287+
BuildCallFramePrototype(operands, results, attributes));
288+
return absl::WrapUnique(new CustomCallThunk(
289+
thunk_info, std::move(target_name), std::move(bundle),
290+
std::move(operands), std::move(results), std::move(call_frame),
291+
std::move(attributes), std::move(execution_state), called_computation));
292+
}
293+
253294
CustomCallThunk::CustomCallThunk(
254295
ThunkInfo thunk_info, std::string target_name,
255296
std::vector<std::optional<ShapedSlice>> operands,
@@ -266,7 +307,7 @@ CustomCallThunk::CustomCallThunk(
266307

267308
CustomCallThunk::CustomCallThunk(
268309
ThunkInfo thunk_info, std::string target_name,
269-
XLA_FFI_Handler_Bundle bundle,
310+
std::variant<XLA_FFI_Handler_Bundle, OwnedHandlerBundle> bundle,
270311
std::vector<std::optional<ShapedSlice>> operands,
271312
std::vector<std::optional<ShapedSlice>> results, CallFrame call_frame,
272313
ffi::AttributesMap attributes,
@@ -317,18 +358,9 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
317358
return absl::OkStatus();
318359
}
319360

320-
absl::Status CustomCallThunk::ExecuteFfiHandler(
321-
RunId run_id, XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
322-
se::Stream* stream, const ffi::ExecutionContext* execution_context,
323-
const BufferAllocations* buffer_allocations) {
324-
if (handler == nullptr) {
325-
return absl::InternalError("FFI execute handler is not set");
326-
}
327-
if (stage != XLA_FFI_ExecutionStage_PREPARE &&
328-
!(buffer_allocations && stream)) {
329-
return absl::InternalError("buffer allocations and stream are required");
330-
}
331-
361+
absl::StatusOr<ObjectPool<CallFrame>::BorrowedObject>
362+
CustomCallThunk::BuildCallFrame(
363+
const BufferAllocations* absl_nullable buffer_allocations) {
332364
auto device_memory = [&](BufferAllocation::Slice slice) {
333365
return buffer_allocations ? buffer_allocations->GetDeviceAddress(slice)
334366
: se::DeviceMemoryBase{};
@@ -360,58 +392,142 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(
360392
// device memory addresses.
361393
TF_ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate());
362394
TF_RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results));
395+
return call_frame;
396+
}
363397

398+
CallOptions CustomCallThunk::BuildCallOptions(
399+
RunId run_id, se::Stream* absl_nullable stream,
400+
const BufferAllocations* absl_nullable buffer_allocations,
401+
const ffi::ExecutionContext* absl_nonnull execution_context) {
364402
int32_t device_ordinal = -1;
365403
se::DeviceMemoryAllocator* allocator = nullptr;
366-
if (stage != XLA_FFI_ExecutionStage_PREPARE) {
404+
if (buffer_allocations != nullptr) {
367405
device_ordinal = buffer_allocations->device_ordinal();
368406
allocator = buffer_allocations->memory_allocator();
369407
}
370408

371-
CallOptions options = {run_id,
372-
device_ordinal,
373-
CallOptions::GpuOptions{stream, allocator},
374-
called_computation_,
375-
execution_context,
376-
execution_state_.get()};
409+
return CallOptions{run_id,
410+
device_ordinal,
411+
CallOptions::GpuOptions{stream, allocator},
412+
called_computation_,
413+
execution_context,
414+
execution_state_.get()};
415+
}
416+
417+
absl::Status CustomCallThunk::ExecuteFfiHandler(
418+
RunId run_id, XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
419+
se::Stream* stream, const ffi::ExecutionContext* execution_context,
420+
const BufferAllocations* buffer_allocations) {
421+
if (handler == nullptr) {
422+
return absl::InternalError("FFI execute handler is not set");
423+
}
424+
if (stage != XLA_FFI_ExecutionStage_PREPARE &&
425+
!(buffer_allocations && stream)) {
426+
return absl::InternalError("buffer allocations and stream are required");
427+
}
428+
429+
TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations));
430+
CallOptions options =
431+
BuildCallOptions(run_id, stream, buffer_allocations, execution_context);
432+
return Call(handler, *call_frame, options, stage);
433+
}
434+
435+
absl::Status CustomCallThunk::ExecuteFfiHandler(
436+
RunId run_id, xla::ffi::Ffi& handler, xla::ffi::ExecutionStage stage,
437+
se::Stream* stream, const ffi::ExecutionContext* execution_context,
438+
const BufferAllocations* buffer_allocations) {
439+
if (stage != xla::ffi::ExecutionStage::kPrepare &&
440+
!(buffer_allocations && stream)) {
441+
return absl::InternalError("buffer allocations and stream are required");
442+
}
443+
444+
TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations));
445+
CallOptions options =
446+
BuildCallOptions(run_id, stream, buffer_allocations, execution_context);
377447
return Call(handler, *call_frame, options, stage);
378448
}
379449

380450
absl::Status CustomCallThunk::Prepare(
381451
const PrepareParams& params, ResourceRequestsInterface& resource_requests) {
382-
if (!bundle_ || !bundle_->prepare) {
383-
return absl::OkStatus();
452+
if (bundle_.has_value()) {
453+
const RunId run_id =
454+
params.collective_params ? params.collective_params->run_id : RunId{-1};
455+
456+
if (const auto* c_bundle =
457+
std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value());
458+
c_bundle && c_bundle->prepare) {
459+
return ExecuteFfiHandler(run_id, c_bundle->prepare,
460+
XLA_FFI_ExecutionStage_PREPARE,
461+
/*stream=*/nullptr,
462+
/*execution_context=*/nullptr,
463+
/*buffer_allocations=*/nullptr);
464+
}
465+
if (const auto* owned_bundle =
466+
std::get_if<OwnedHandlerBundle>(&bundle_.value());
467+
owned_bundle && owned_bundle->prepare) {
468+
return ExecuteFfiHandler(run_id, *owned_bundle->prepare,
469+
xla::ffi::ExecutionStage::kPrepare,
470+
/*stream=*/nullptr,
471+
/*execution_context=*/nullptr,
472+
/*buffer_allocations=*/nullptr);
473+
}
384474
}
385475

386-
return ExecuteFfiHandler(
387-
params.collective_params ? params.collective_params->run_id : RunId{-1},
388-
bundle_->prepare, XLA_FFI_ExecutionStage_PREPARE,
389-
/*stream=*/nullptr,
390-
/*execution_context=*/nullptr,
391-
/*buffer_allocations=*/nullptr);
476+
return absl::OkStatus();
392477
}
393478

394479
absl::Status CustomCallThunk::Initialize(const InitializeParams& params) {
395-
if (!bundle_ || !bundle_->initialize) {
396-
return absl::OkStatus();
480+
if (bundle_.has_value()) {
481+
const RunId run_id =
482+
params.collective_params ? params.collective_params->run_id : RunId{-1};
483+
484+
if (const auto* c_bundle =
485+
std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value());
486+
c_bundle && c_bundle->initialize) {
487+
return ExecuteFfiHandler(run_id, *c_bundle->initialize,
488+
XLA_FFI_ExecutionStage_INITIALIZE, params.stream,
489+
params.ffi_execution_context,
490+
params.buffer_allocations);
491+
}
492+
if (const auto* owned_bundle =
493+
std::get_if<OwnedHandlerBundle>(&bundle_.value());
494+
owned_bundle && owned_bundle->initialize) {
495+
return ExecuteFfiHandler(run_id, *owned_bundle->initialize,
496+
xla::ffi::ExecutionStage::kInitialize,
497+
params.stream, params.ffi_execution_context,
498+
params.buffer_allocations);
499+
}
397500
}
398-
399-
return ExecuteFfiHandler(
400-
params.collective_params ? params.collective_params->run_id : RunId{-1},
401-
bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE, params.stream,
402-
params.ffi_execution_context, params.buffer_allocations);
501+
return absl::OkStatus();
403502
}
404503

405504
absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
406505
TF_ASSIGN_OR_RETURN(
407506
se::Stream * stream,
408507
GetStreamForExecution(Thunk::execution_stream_id(), params));
508+
409509
if (bundle_.has_value()) {
410-
return ExecuteFfiHandler(
411-
params.collective_params ? params.collective_params->run_id : RunId{-1},
412-
bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, stream,
413-
params.ffi_execution_context, params.buffer_allocations);
510+
const RunId run_id =
511+
params.collective_params ? params.collective_params->run_id : RunId{-1};
512+
if (const auto* c_bundle =
513+
std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value());
514+
c_bundle) {
515+
return ExecuteFfiHandler(
516+
run_id, c_bundle->execute, XLA_FFI_ExecutionStage_EXECUTE, stream,
517+
params.ffi_execution_context, params.buffer_allocations);
518+
}
519+
if (const auto* owned_bundle =
520+
std::get_if<OwnedHandlerBundle>(&bundle_.value());
521+
owned_bundle) {
522+
if (!owned_bundle->execute) {
523+
return absl::InternalError("FFI execute handler is not set");
524+
}
525+
return ExecuteFfiHandler(
526+
run_id, *owned_bundle->execute, xla::ffi::ExecutionStage::kExecute,
527+
stream, params.ffi_execution_context, params.buffer_allocations);
528+
}
414529
}
530+
415531
return ExecuteCustomCall(params);
416532
}
417533

0 commit comments

Comments
 (0)