Skip to content

Commit

Permalink
[pjrt] Use the PjRtMemorySpace* version of BufferFromHostLiteral
Browse files Browse the repository at this point in the history
Buffers live in memory spaces and not on devices. The `PjRtDevice` version
of `BufferFromHostLiteral` is deprecated and will be removed once the migration
is complete.

PiperOrigin-RevId: 720889702
  • Loading branch information
superbobry authored and Google-ML-Automation committed Jan 29, 2025
1 parent c6eccfa commit 0d55f37
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
12 changes: 8 additions & 4 deletions xla/examples/axpy/stablehlo_compile_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,20 @@ TEST_F(StableHloAxpyTest, CompileAndExecuteCPUTestProgram) {
std::cerr << "\ty:" << y_literal << std::endl;

PjRtDevice* host_cpu = client->devices()[0];
TF_ASSERT_OK_AND_ASSIGN(PjRtMemorySpace * host_cpu_memory_space,
host_cpu->default_memory_space());

// Transfer our literals to buffers. If we were using a GPU, these buffers
// would correspond to device memory.
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> alpha,
client->BufferFromHostLiteral(alpha_literal, host_cpu));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtBuffer> x,
client->BufferFromHostLiteral(x_literal, host_cpu));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtBuffer> y,
client->BufferFromHostLiteral(y_literal, host_cpu));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> x,
client->BufferFromHostLiteral(x_literal, host_cpu_memory_space));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> y,
client->BufferFromHostLiteral(y_literal, host_cpu_memory_space));

// Do our computation.
TF_ASSERT_OK_AND_ASSIGN(
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/cpu/cpu_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ TEST(TfrtCpuClientTest, BufferFromLiteralInt4) {
TF_ASSERT_OK_AND_ASSIGN(auto literal, xla::MakeFakeLiteral(shape));
TF_ASSERT_OK_AND_ASSIGN(
auto buffer,
client->BufferFromHostLiteral(literal, client->addressable_devices()[0]));
client->BufferFromHostLiteral(literal, client->memory_spaces()[0]));
TF_ASSERT_OK_AND_ASSIGN(auto received_literal, buffer->ToLiteralSync());
EXPECT_THAT(received_literal->data<s4>(),
ElementsAreArray(literal.data<s4>()));
Expand Down
14 changes: 8 additions & 6 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostFullBuffer) {
auto literal = xla::LiteralUtil::CreateR1<float>({41.0f, 42.0f});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> buffer,
client->BufferFromHostLiteral(literal, client->addressable_devices()[0]));
client->BufferFromHostLiteral(literal, client->memory_spaces()[0]));

TF_ASSERT_OK_AND_ASSIGN(int64_t size, buffer->GetOnDeviceSizeInBytes());
void* dst =
Expand All @@ -865,7 +865,7 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostSubBuffer) {

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> buffer,
client->BufferFromHostLiteral(literal, client->addressable_devices()[0]));
client->BufferFromHostLiteral(literal, client->memory_spaces()[0]));
TF_ASSERT_OK_AND_ASSIGN(int64_t size, buffer->GetOnDeviceSizeInBytes());
void* dst =
tsl::port::AlignedMalloc(size, tsl::Allocator::kAllocatorAlignment);
Expand All @@ -884,7 +884,7 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostOutOfRange) {

TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> buffer,
client->BufferFromHostLiteral(literal, client->addressable_devices()[0]));
client->BufferFromHostLiteral(literal, client->memory_spaces()[0]));
TF_ASSERT_OK_AND_ASSIGN(int64_t size, buffer->GetOnDeviceSizeInBytes());
void* dst =
tsl::port::AlignedMalloc(size, tsl::Allocator::kAllocatorAlignment);
Expand All @@ -901,7 +901,7 @@ TEST(StreamExecutorGpuClientTest, CopyRawToHostFuture) {
auto literal = xla::LiteralUtil::CreateR1<float>({41.0f, 42.0f});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> buffer,
client->BufferFromHostLiteral(literal, client->addressable_devices()[0]));
client->BufferFromHostLiteral(literal, client->memory_spaces()[0]));

auto dst_promise = xla::PjRtFuture<void*>::CreatePromise();
xla::PjRtFuture<void*> dst_future(dst_promise);
Expand Down Expand Up @@ -1075,8 +1075,10 @@ TEST(StreamExecutorGpuClientTest, GetAllocatorStatsTest) {

for (auto device : client->addressable_devices()) {
const xla::Literal literal = xla::LiteralUtil::CreateR0<int32_t>(0);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtBuffer> buffer,
client->BufferFromHostLiteral(literal, device));
TF_ASSERT_OK_AND_ASSIGN(auto* memory_space, device->default_memory_space())
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtBuffer> buffer,
client->BufferFromHostLiteral(literal, memory_space));

auto stats = device->GetAllocatorStats();
TF_ASSERT_OK(stats.status());
Expand Down
5 changes: 3 additions & 2 deletions xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,9 @@ class PjRtClient {
platform_name());
}

// TODO(b/277820585): remove BufferFromHostLiteral with PjRtDevice after the
// migration is done.

// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for GetReadyFuture().Await()
// completes on the return value before letting literal go out of scope.
Expand All @@ -920,8 +923,6 @@ class PjRtClient {
return Unimplemented("BufferFromHostLiteral is not implemented.");
}

// TODO(b/277820585): remove BufferFromHostLiteral with PjRtDevice after the
// migration is done.
virtual absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtMemorySpace* memory_space) {
return tsl::errors::Unimplemented(
Expand Down

0 comments on commit 0d55f37

Please sign in to comment.