Skip to content

Commit b5e81ae

Browse files
junwhanahnGoogle-ML-Automation
authored andcommitted
Replace std::shared_ptr<xla::ifrt::LoadedExecutable> with xla::ifrt::LoadedExecutableRef
PiperOrigin-RevId: 755581247
1 parent 35b25c9 commit b5e81ae

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

jaxlib/py_compile_only_client.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class CompileOnlyPyClient : public PyClient {
7070
return client;
7171
}
7272

73-
absl::StatusOr<std::shared_ptr<ifrt::Executable>> CompileUnloaded(
73+
absl::StatusOr<ifrt::ExecutableRef> CompileUnloaded(
7474
absl::string_view mlir_module, ifrt::DeviceListRef executable_devices,
7575
CompileOptions options, std::vector<nb::capsule> host_callbacks) {
7676
if (!host_callbacks.empty()) {
@@ -102,7 +102,7 @@ class CompileOnlyPyClient : public PyClient {
102102
*ifrt_client->topology().description()));
103103
TF_ASSIGN_OR_RETURN(auto ifrt_executable,
104104
ifrt::PjRtExecutable::Create(std::move(executable)));
105-
return std::shared_ptr<ifrt::Executable>(std::move(ifrt_executable));
105+
return ifrt::ExecutableRef(std::move(ifrt_executable));
106106
}
107107

108108
private:

jaxlib/py_executable.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ absl::Status PyShardedToken::Await() {
8484

8585
PyLoadedExecutable::PyLoadedExecutable(
8686
nb_class_ptr<PyClient> client,
87-
std::shared_ptr<ifrt::LoadedExecutable> ifrt_loaded_executable,
87+
ifrt::LoadedExecutableRef ifrt_loaded_executable,
8888
std::optional<nb_traceback> traceback,
8989
std::optional<std::string> fingerprint)
9090
: client_(std::move(client)),

jaxlib/py_executable.h

+6-7
Original file line numberDiff line numberDiff line change
@@ -131,19 +131,18 @@ using ExecuteShardedArg = std::variant<PyArray, std::vector<PyArray>>;
131131
// b) to add Python-specific functionality.
132132
class PyLoadedExecutable {
133133
public:
134-
PyLoadedExecutable(
135-
nb_class_ptr<PyClient> client,
136-
std::shared_ptr<ifrt::LoadedExecutable> ifrt_loaded_executable,
137-
std::optional<nb_traceback> traceback,
138-
std::optional<std::string> fingerprint);
134+
PyLoadedExecutable(nb_class_ptr<PyClient> client,
135+
ifrt::LoadedExecutableRef ifrt_loaded_executable,
136+
std::optional<nb_traceback> traceback,
137+
std::optional<std::string> fingerprint);
139138
~PyLoadedExecutable();
140139

141140
nb_class_ptr<PyClient> client() const { return client_; }
142141
ifrt::LoadedExecutable* ifrt_loaded_executable() const {
143142
return ifrt_loaded_executable_.get();
144143
}
145144

146-
std::shared_ptr<ifrt::LoadedExecutable> shared_ifrt_loaded_executable() {
145+
ifrt::LoadedExecutableRef shared_ifrt_loaded_executable() {
147146
return ifrt_loaded_executable_;
148147
}
149148

@@ -226,7 +225,7 @@ class PyLoadedExecutable {
226225
friend class PyClient;
227226

228227
nb_class_ptr<PyClient> client_;
229-
std::shared_ptr<ifrt::LoadedExecutable> ifrt_loaded_executable_;
228+
ifrt::LoadedExecutableRef ifrt_loaded_executable_;
230229
std::optional<nb_traceback> traceback_;
231230

232231
// Identical executables (i.e. representing the same program) will have the

0 commit comments

Comments
 (0)