Skip to content

Commit 80fc4fa

Browse files
Merge pull request #28387 from vfdev-5:fix-deadlock-makeshardfn-28385
PiperOrigin-RevId: 755565110
2 parents 3299aed + 2c52ae3 commit 80fc4fa

File tree

3 files changed

+19
-26
lines changed

3 files changed

+19
-26
lines changed

jaxlib/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,7 @@ cc_library(
885885
"@xla//xla/python:nb_helpers",
886886
"@xla//xla/python:nb_numpy",
887887
"@xla//xla/python:pprof_profile_builder",
888+
"@xla//xla/python:safe_static_init",
888889
"@xla//xla/python:types",
889890
"@xla//xla/python:version",
890891
"@xla//xla/python/compile_only_ifrt:client",

jaxlib/py_values.cc

+12-9
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ limitations under the License.
6060
#include "xla/python/ifrt/user_context.h"
6161
#include "xla/python/nb_numpy.h"
6262
#include "xla/python/pjrt_ifrt/pjrt_dtype.h"
63+
#include "xla/python/safe_static_init.h"
6364
#include "xla/python/types.h"
6465
#include "xla/shape.h"
6566
#include "xla/tsl/concurrency/ref_count.h"
@@ -591,9 +592,11 @@ absl::StatusOr<ShardFn> MakeShardFn(nb::handle arg, ifrt::Client* client,
591592
ifrt::Device* to_device,
592593
ifrt::MemoryKind to_memory_kind,
593594
const DevicePutOptions& options) {
594-
static const absl::flat_hash_map<PyObject*,
595-
DevicePutHandler>* const handlers = [] {
596-
auto p = new absl::flat_hash_map<PyObject*, DevicePutHandler>();
595+
using PyObjectDeviceHandlerMap = absl::flat_hash_map<PyObject*, DevicePutHandler>;
596+
597+
auto init_fn = [](){
598+
std::unique_ptr<PyObjectDeviceHandlerMap> p = std::make_unique<PyObjectDeviceHandlerMap>();
599+
597600
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();
598601
// Python scalar types.
599602
static_assert(sizeof(bool) == 1, "Conversion code assumes bool is 1 byte");
@@ -660,20 +663,20 @@ absl::StatusOr<ShardFn> MakeShardFn(nb::handle arg, ifrt::Client* client,
660663
static_assert(sizeof(int) == sizeof(int32_t),
661664
"int must be the same size as int32_t");
662665
(*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar<int32_t>;
663-
664666
return p;
665-
}();
667+
};
668+
const PyObjectDeviceHandlerMap& handlers = xla::SafeStaticInit<PyObjectDeviceHandlerMap>(init_fn);
666669

667670
if (arg.type().ptr() == PyArray::type().ptr()) {
668671
auto array = nb::borrow<PyArray>(arg);
669672
return HandlePyArray(arg, client, to_device, to_memory_kind, options);
670673
}
671674

672-
auto res = handlers->find(arg.type().ptr());
673-
if (res == handlers->end()) {
675+
auto res = handlers.find(arg.type().ptr());
676+
if (res == handlers.end()) {
674677
for (auto base_class : arg.type().attr("__mro__")) {
675-
res = handlers->find(base_class.ptr());
676-
if (res != handlers->end()) {
678+
res = handlers.find(base_class.ptr());
679+
if (res != handlers.end()) {
677680
return res->second(arg, client, to_device, to_memory_kind, options);
678681
}
679682
}

jaxlib/sharding.cc

+6-17
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ limitations under the License.
4141
#include "xla/pjrt/status_casters.h"
4242
#include "xla/python/ifrt/device_list.h"
4343
#include "xla/python/nb_numpy.h"
44+
#include "xla/python/safe_static_init.h"
4445
#include "xla/tsl/platform/logging.h"
4546
#include "xla/tsl/platform/statusor.h"
4647
#include "xla/xla_data.pb.h"
@@ -240,24 +241,12 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec,
240241
// TODO(phawkins): this leaks a reference to the check_pspec function.
241242
// A better way to fix this would be to move PartitionSpec and this check into
242243
// C++.
243-
nb::object* check_pspec = []() {
244-
static absl::Mutex mu;
245-
static nb::object* output = nullptr;
246-
{
247-
absl::MutexLock lock(&mu);
248-
if (output) {
249-
return output;
250-
}
251-
}
244+
auto init_fn = [](){
252245
nb::module_ si = nb::module_::import_("jax._src.named_sharding");
253-
nb::object attr = si.attr("check_pspec");
254-
absl::MutexLock lock(&mu);
255-
if (!output) {
256-
output = new nb::object(attr);
257-
}
258-
return output;
259-
}();
260-
(*check_pspec)(mesh_, spec_);
246+
return std::make_unique<nb::object>(si.attr("check_pspec"));
247+
};
248+
nb::object& check_pspec = xla::SafeStaticInit<nb::object>(init_fn);
249+
check_pspec(mesh_, spec_);
261250
}
262251

263252
/*static*/ PyObject* NamedSharding::type_ = nullptr;

0 commit comments

Comments
 (0)