@@ -60,6 +60,7 @@ limitations under the License.
60
60
#include " xla/python/ifrt/user_context.h"
61
61
#include " xla/python/nb_numpy.h"
62
62
#include " xla/python/pjrt_ifrt/pjrt_dtype.h"
63
+ #include " xla/python/safe_static_init.h"
63
64
#include " xla/python/types.h"
64
65
#include " xla/shape.h"
65
66
#include " xla/tsl/concurrency/ref_count.h"
@@ -591,9 +592,11 @@ absl::StatusOr<ShardFn> MakeShardFn(nb::handle arg, ifrt::Client* client,
591
592
ifrt::Device* to_device,
592
593
ifrt::MemoryKind to_memory_kind,
593
594
const DevicePutOptions& options) {
594
- static const absl::flat_hash_map<PyObject*,
595
- DevicePutHandler>* const handlers = [] {
596
- auto p = new absl::flat_hash_map<PyObject*, DevicePutHandler>();
595
+ using PyObjectDeviceHandlerMap = absl::flat_hash_map<PyObject*, DevicePutHandler>;
596
+
597
+ auto init_fn = [](){
598
+ std::unique_ptr<PyObjectDeviceHandlerMap> p = std::make_unique<PyObjectDeviceHandlerMap>();
599
+
597
600
const NumpyScalarTypes& dtypes = GetNumpyScalarTypes ();
598
601
// Python scalar types.
599
602
static_assert (sizeof (bool ) == 1 , " Conversion code assumes bool is 1 byte" );
@@ -660,20 +663,20 @@ absl::StatusOr<ShardFn> MakeShardFn(nb::handle arg, ifrt::Client* client,
660
663
static_assert (sizeof (int ) == sizeof (int32_t ),
661
664
" int must be the same size as int32_t" );
662
665
(*p)[dtypes.np_intc .ptr ()] = HandleNumpyScalar<int32_t >;
663
-
664
666
return p;
665
- }();
667
+ };
668
+ const PyObjectDeviceHandlerMap& handlers = xla::SafeStaticInit<PyObjectDeviceHandlerMap>(init_fn);
666
669
667
670
if (arg.type ().ptr () == PyArray::type ().ptr ()) {
668
671
auto array = nb::borrow<PyArray>(arg);
669
672
return HandlePyArray (arg, client, to_device, to_memory_kind, options);
670
673
}
671
674
672
- auto res = handlers-> find (arg.type ().ptr ());
673
- if (res == handlers-> end ()) {
675
+ auto res = handlers. find (arg.type ().ptr ());
676
+ if (res == handlers. end ()) {
674
677
for (auto base_class : arg.type ().attr (" __mro__" )) {
675
- res = handlers-> find (base_class.ptr ());
676
- if (res != handlers-> end ()) {
678
+ res = handlers. find (base_class.ptr ());
679
+ if (res != handlers. end ()) {
677
680
return res->second (arg, client, to_device, to_memory_kind, options);
678
681
}
679
682
}
0 commit comments