@@ -22,11 +22,12 @@ limitations under the License.
2222#include < optional>
2323#include < string>
2424#include < utility>
25+ #include < variant>
2526#include < vector>
2627
2728#include " absl/algorithm/container.h"
29+ #include " absl/base/nullability.h"
2830#include " absl/container/inlined_vector.h"
29- #include " absl/log/check.h"
3031#include " absl/log/log.h"
3132#include " absl/memory/memory.h"
3233#include " absl/status/status.h"
@@ -41,9 +42,11 @@ limitations under the License.
4142#include " xla/ffi/attribute_map.h"
4243#include " xla/ffi/call_frame.h"
4344#include " xla/ffi/execution_state.h"
45+ #include " xla/ffi/ffi.h"
4446#include " xla/ffi/ffi_api.h"
4547#include " xla/hlo/ir/hlo_computation.h"
4648#include " xla/primitive_util.h"
49+ #include " xla/runtime/object_pool.h"
4750#include " xla/service/buffer_assignment.h"
4851#include " xla/service/custom_call_status.h"
4952#include " xla/service/custom_call_status_internal.h"
@@ -250,6 +253,44 @@ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
250253 std::move (attributes), std::move (execution_state), called_computation));
251254}
252255
256+ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create (
257+ ThunkInfo thunk_info, std::string target_name, OwnedHandlerBundle bundle,
258+ std::vector<std::optional<ShapedSlice>> operands,
259+ std::vector<std::optional<ShapedSlice>> results,
260+ xla::ffi::AttributesMap attributes,
261+ const HloComputation* called_computation) {
262+ if (!bundle.execute ) {
263+ return absl::InvalidArgumentError (
264+ " Execute handler is required for a CustomCallThunk" );
265+ }
266+
267+ auto execution_state = std::make_unique<ffi::ExecutionState>();
268+ // Initialize FFI handler state if it has an instantiate callback.
269+ if (bundle.instantiate ) {
270+ // At FFI handler instantiation time, we don't have any arguments or
271+ // results or access to the underlying device (stream, etc.)
272+ CallFrameBuilder builder (/* num_args=*/ 0 , /* num_rets=*/ 0 );
273+
274+ CallFrameBuilder::AttributesBuilder attrs;
275+ attrs.Append (attributes);
276+
277+ builder.AddAttributes (attrs.Build ());
278+ CallFrame call_frame = builder.Build ();
279+
280+ CallOptions options;
281+ options.execution_state = execution_state.get ();
282+ TF_RETURN_IF_ERROR (Call (*bundle.instantiate , call_frame, options,
283+ xla::ffi::ExecutionStage::kInstantiate ));
284+ }
285+
286+ TF_ASSIGN_OR_RETURN (CallFrame call_frame,
287+ BuildCallFramePrototype (operands, results, attributes));
288+ return absl::WrapUnique (new CustomCallThunk (
289+ thunk_info, std::move (target_name), std::move (bundle),
290+ std::move (operands), std::move (results), std::move (call_frame),
291+ std::move (attributes), std::move (execution_state), called_computation));
292+ }
293+
253294CustomCallThunk::CustomCallThunk (
254295 ThunkInfo thunk_info, std::string target_name,
255296 std::vector<std::optional<ShapedSlice>> operands,
@@ -266,7 +307,7 @@ CustomCallThunk::CustomCallThunk(
266307
267308CustomCallThunk::CustomCallThunk (
268309 ThunkInfo thunk_info, std::string target_name,
269- XLA_FFI_Handler_Bundle bundle,
310+ std::variant< XLA_FFI_Handler_Bundle, OwnedHandlerBundle> bundle,
270311 std::vector<std::optional<ShapedSlice>> operands,
271312 std::vector<std::optional<ShapedSlice>> results, CallFrame call_frame,
272313 ffi::AttributesMap attributes,
@@ -317,18 +358,9 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
317358 return absl::OkStatus ();
318359}
319360
320- absl::Status CustomCallThunk::ExecuteFfiHandler (
321- RunId run_id, XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
322- se::Stream* stream, const ffi::ExecutionContext* execution_context,
323- const BufferAllocations* buffer_allocations) {
324- if (handler == nullptr ) {
325- return absl::InternalError (" FFI execute handler is not set" );
326- }
327- if (stage != XLA_FFI_ExecutionStage_PREPARE &&
328- !(buffer_allocations && stream)) {
329- return absl::InternalError (" buffer allocations and stream are required" );
330- }
331-
361+ absl::StatusOr<ObjectPool<CallFrame>::BorrowedObject>
362+ CustomCallThunk::BuildCallFrame (
363+ const BufferAllocations* absl_nullable buffer_allocations) {
332364 auto device_memory = [&](BufferAllocation::Slice slice) {
333365 return buffer_allocations ? buffer_allocations->GetDeviceAddress (slice)
334366 : se::DeviceMemoryBase{};
@@ -360,58 +392,142 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(
360392 // device memory addresses.
361393 TF_ASSIGN_OR_RETURN (auto call_frame, call_frames_->GetOrCreate ());
362394 TF_RETURN_IF_ERROR (call_frame->UpdateWithBuffers (arguments, results));
395+ return call_frame;
396+ }
363397
398+ CallOptions CustomCallThunk::BuildCallOptions (
399+ RunId run_id, se::Stream* absl_nullable stream,
400+ const BufferAllocations* absl_nullable buffer_allocations,
401+ const ffi::ExecutionContext* absl_nonnull execution_context) {
364402 int32_t device_ordinal = -1 ;
365403 se::DeviceMemoryAllocator* allocator = nullptr ;
366- if (stage != XLA_FFI_ExecutionStage_PREPARE ) {
404+ if (buffer_allocations != nullptr ) {
367405 device_ordinal = buffer_allocations->device_ordinal ();
368406 allocator = buffer_allocations->memory_allocator ();
369407 }
370408
371- CallOptions options = {run_id,
372- device_ordinal,
373- CallOptions::GpuOptions{stream, allocator},
374- called_computation_,
375- execution_context,
376- execution_state_.get ()};
409+ return CallOptions{run_id,
410+ device_ordinal,
411+ CallOptions::GpuOptions{stream, allocator},
412+ called_computation_,
413+ execution_context,
414+ execution_state_.get ()};
415+ }
416+
417+ absl::Status CustomCallThunk::ExecuteFfiHandler (
418+ RunId run_id, XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
419+ se::Stream* stream, const ffi::ExecutionContext* execution_context,
420+ const BufferAllocations* buffer_allocations) {
421+ if (handler == nullptr ) {
422+ return absl::InternalError (" FFI execute handler is not set" );
423+ }
424+ if (stage != XLA_FFI_ExecutionStage_PREPARE &&
425+ !(buffer_allocations && stream)) {
426+ return absl::InternalError (" buffer allocations and stream are required" );
427+ }
428+
429+ TF_ASSIGN_OR_RETURN (auto call_frame, BuildCallFrame (buffer_allocations));
430+ CallOptions options =
431+ BuildCallOptions (run_id, stream, buffer_allocations, execution_context);
432+ return Call (handler, *call_frame, options, stage);
433+ }
434+
435+ absl::Status CustomCallThunk::ExecuteFfiHandler (
436+ RunId run_id, xla::ffi::Ffi& handler, xla::ffi::ExecutionStage stage,
437+ se::Stream* stream, const ffi::ExecutionContext* execution_context,
438+ const BufferAllocations* buffer_allocations) {
439+ if (stage != xla::ffi::ExecutionStage::kPrepare &&
440+ !(buffer_allocations && stream)) {
441+ return absl::InternalError (" buffer allocations and stream are required" );
442+ }
443+
444+ TF_ASSIGN_OR_RETURN (auto call_frame, BuildCallFrame (buffer_allocations));
445+ CallOptions options =
446+ BuildCallOptions (run_id, stream, buffer_allocations, execution_context);
377447 return Call (handler, *call_frame, options, stage);
378448}
379449
380450absl::Status CustomCallThunk::Prepare (
381451 const PrepareParams& params, ResourceRequestsInterface& resource_requests) {
382- if (!bundle_ || !bundle_->prepare ) {
383- return absl::OkStatus ();
452+ if (bundle_.has_value ()) {
453+ const RunId run_id =
454+ params.collective_params ? params.collective_params ->run_id : RunId{-1 };
455+
456+ if (const auto * c_bundle =
457+ std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value ());
458+ c_bundle && c_bundle->prepare ) {
459+ return ExecuteFfiHandler (run_id, c_bundle->prepare ,
460+ XLA_FFI_ExecutionStage_PREPARE,
461+ /* stream=*/ nullptr ,
462+ /* execution_context=*/ nullptr ,
463+ /* buffer_allocations=*/ nullptr );
464+ }
465+ if (const auto * owned_bundle =
466+ std::get_if<OwnedHandlerBundle>(&bundle_.value ());
467+ owned_bundle && owned_bundle->prepare ) {
468+ return ExecuteFfiHandler (run_id, *owned_bundle->prepare ,
469+ xla::ffi::ExecutionStage::kPrepare ,
470+ /* stream=*/ nullptr ,
471+ /* execution_context=*/ nullptr ,
472+ /* buffer_allocations=*/ nullptr );
473+ }
384474 }
385475
386- return ExecuteFfiHandler (
387- params.collective_params ? params.collective_params ->run_id : RunId{-1 },
388- bundle_->prepare , XLA_FFI_ExecutionStage_PREPARE,
389- /* stream=*/ nullptr ,
390- /* execution_context=*/ nullptr ,
391- /* buffer_allocations=*/ nullptr );
476+ return absl::OkStatus ();
392477}
393478
394479absl::Status CustomCallThunk::Initialize (const InitializeParams& params) {
395- if (!bundle_ || !bundle_->initialize ) {
396- return absl::OkStatus ();
480+ if (bundle_.has_value ()) {
481+ const RunId run_id =
482+ params.collective_params ? params.collective_params ->run_id : RunId{-1 };
483+
484+ if (const auto * c_bundle =
485+ std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value ());
486+ c_bundle && c_bundle->initialize ) {
487+ return ExecuteFfiHandler (run_id, *c_bundle->initialize ,
488+ XLA_FFI_ExecutionStage_INITIALIZE, params.stream ,
489+ params.ffi_execution_context ,
490+ params.buffer_allocations );
491+ }
492+ if (const auto * owned_bundle =
493+ std::get_if<OwnedHandlerBundle>(&bundle_.value ());
494+ owned_bundle && owned_bundle->initialize ) {
495+ return ExecuteFfiHandler (run_id, *owned_bundle->initialize ,
496+ xla::ffi::ExecutionStage::kInitialize ,
497+ params.stream , params.ffi_execution_context ,
498+ params.buffer_allocations );
499+ }
397500 }
398-
399- return ExecuteFfiHandler (
400- params.collective_params ? params.collective_params ->run_id : RunId{-1 },
401- bundle_->initialize , XLA_FFI_ExecutionStage_INITIALIZE, params.stream ,
402- params.ffi_execution_context , params.buffer_allocations );
501+ return absl::OkStatus ();
403502}
404503
405504absl::Status CustomCallThunk::ExecuteOnStream (const ExecuteParams& params) {
406505 TF_ASSIGN_OR_RETURN (
407506 se::Stream * stream,
408507 GetStreamForExecution (Thunk::execution_stream_id (), params));
508+
409509 if (bundle_.has_value ()) {
410- return ExecuteFfiHandler (
411- params.collective_params ? params.collective_params ->run_id : RunId{-1 },
412- bundle_->execute , XLA_FFI_ExecutionStage_EXECUTE, stream,
413- params.ffi_execution_context , params.buffer_allocations );
510+ const RunId run_id =
511+ params.collective_params ? params.collective_params ->run_id : RunId{-1 };
512+ if (const auto * c_bundle =
513+ std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value ());
514+ c_bundle) {
515+ return ExecuteFfiHandler (
516+ run_id, c_bundle->execute , XLA_FFI_ExecutionStage_EXECUTE, stream,
517+ params.ffi_execution_context , params.buffer_allocations );
518+ }
519+ if (const auto * owned_bundle =
520+ std::get_if<OwnedHandlerBundle>(&bundle_.value ());
521+ owned_bundle) {
522+ if (!owned_bundle->execute ) {
523+ return absl::InternalError (" FFI execute handler is not set" );
524+ }
525+ return ExecuteFfiHandler (
526+ run_id, *owned_bundle->execute , xla::ffi::ExecutionStage::kExecute ,
527+ stream, params.ffi_execution_context , params.buffer_allocations );
528+ }
414529 }
530+
415531 return ExecuteCustomCall (params);
416532}
417533
0 commit comments