@@ -60,6 +60,7 @@ limitations under the License.
6060#include " xla/python/ifrt/user_context.h"
6161#include " xla/python/nb_numpy.h"
6262#include " xla/python/pjrt_ifrt/pjrt_dtype.h"
63+ #include " xla/python/safe_static_init.h"
6364#include " xla/python/types.h"
6465#include " xla/shape.h"
6566#include " xla/tsl/concurrency/ref_count.h"
@@ -591,9 +592,11 @@ absl::StatusOr<ShardFn> MakeShardFn(nb::handle arg, ifrt::Client* client,
591592 ifrt::Device* to_device,
592593 ifrt::MemoryKind to_memory_kind,
593594 const DevicePutOptions& options) {
594- static const absl::flat_hash_map<PyObject*,
595- DevicePutHandler>* const handlers = [] {
596- auto p = new absl::flat_hash_map<PyObject*, DevicePutHandler>();
595+ using PyObjectDeviceHandlerMap = absl::flat_hash_map<PyObject*, DevicePutHandler>;
596+
597+ auto init_fn = [](){
598+ std::unique_ptr<PyObjectDeviceHandlerMap> p = std::make_unique<PyObjectDeviceHandlerMap>();
599+
597600 const NumpyScalarTypes& dtypes = GetNumpyScalarTypes ();
598601 // Python scalar types.
599602 static_assert (sizeof (bool ) == 1 , " Conversion code assumes bool is 1 byte" );
@@ -660,20 +663,20 @@ absl::StatusOr<ShardFn> MakeShardFn(nb::handle arg, ifrt::Client* client,
660663 static_assert (sizeof (int ) == sizeof (int32_t ),
661664 " int must be the same size as int32_t" );
662665 (*p)[dtypes.np_intc .ptr ()] = HandleNumpyScalar<int32_t >;
663-
664666 return p;
665- }();
667+ };
668+ const PyObjectDeviceHandlerMap& handlers = xla::SafeStaticInit<PyObjectDeviceHandlerMap>(init_fn);
666669
667670 if (arg.type ().ptr () == PyArray::type ().ptr ()) {
668671 auto array = nb::borrow<PyArray>(arg);
669672 return HandlePyArray (arg, client, to_device, to_memory_kind, options);
670673 }
671674
672- auto res = handlers-> find (arg.type ().ptr ());
673- if (res == handlers-> end ()) {
675+ auto res = handlers. find (arg.type ().ptr ());
676+ if (res == handlers. end ()) {
674677 for (auto base_class : arg.type ().attr (" __mro__" )) {
675- res = handlers-> find (base_class.ptr ());
676- if (res != handlers-> end ()) {
678+ res = handlers. find (base_class.ptr ());
679+ if (res != handlers. end ()) {
677680 return res->second (arg, client, to_device, to_memory_kind, options);
678681 }
679682 }
0 commit comments