Skip to content

Commit 76994d4

Browse files
rmazpytorchmergebot
authored andcommitted
[pytorch] add experimental TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT (pytorch#150537)
Summary: Add an experimental feature to defer pytorch library initialization cost to post startup. As noted this feature is not thread safe, it requires the client to maintain thread safety at library load time. Reviewed By: zou3519 Differential Revision: D71917841 Pull Request resolved: pytorch#150537 Approved by: https://github.com/zou3519
1 parent 9e55dae commit 76994d4

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

aten/src/ATen/core/library.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ void Library::reset() {
5858

5959
#define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")"
6060

61+
#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT
62+
namespace detail {
63+
std::vector<TorchLibraryInit*> torch_library_initializers;
64+
} // namespace detail
65+
void initialize_torch_libraries() {
66+
for (auto* initializer : detail::torch_library_initializers) {
67+
initializer->initialize();
68+
}
69+
detail::torch_library_initializers.clear();
70+
}
71+
#endif
72+
6173
Library::Library(Kind kind, std::string ns, std::optional<c10::DispatchKey> k, const char* file, uint32_t line)
6274
: kind_(kind)
6375
, ns_(ns == "_" ? std::nullopt : std::make_optional(std::move(ns)))

torch/csrc/jit/mobile/import.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <torch/csrc/jit/serialization/import_export_functions.h>
2323
#include <torch/csrc/jit/serialization/import_read.h>
2424
#include <torch/custom_class.h>
25+
#include <torch/library.h>
2526
#include <optional>
2627
#include <string>
2728
#include <vector>
@@ -646,6 +647,9 @@ mobile::Module _load_for_mobile(
646647
std::optional<at::Device> device,
647648
ExtraFilesMap& extra_files,
648649
uint64_t module_load_options) {
650+
#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT
651+
torch::initialize_torch_libraries();
652+
#endif
649653
auto observer = torch::observerConfig().getModuleObserver();
650654
if (observer) {
651655
extra_files.insert(std::make_pair("model_path", filename));

torch/library.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,8 +884,42 @@ class TORCH_API Library final {
884884
at::OperatorName _parseNameForLib(const char* name_str) const;
885885
};
886886

887+
#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT
888+
void initialize_torch_libraries();
889+
#endif
890+
887891
namespace detail {
888892

893+
#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT
894+
extern std::vector<TorchLibraryInit*> torch_library_initializers;
895+
class TorchLibraryInit final {
896+
private:
897+
using InitFn = void(Library&);
898+
Library::Kind kind;
899+
InitFn* init_function;
900+
const char* ns;
901+
std::optional<c10::DispatchKey> key;
902+
const char* file;
903+
uint32_t line;
904+
std::unique_ptr<Library> lib = nullptr;
905+
906+
public:
907+
TorchLibraryInit(
908+
Library::Kind kind,
909+
InitFn* fn,
910+
const char* ns,
911+
std::optional<c10::DispatchKey> k,
912+
const char* file,
913+
uint32_t line) : kind(kind), init_function(fn), ns(ns), key(k), file(file), line(line) {
914+
torch_library_initializers.push_back(this);
915+
}
916+
917+
void initialize() {
918+
lib = std::unique_ptr<Library>(new Library(kind, ns, key, file, line));
919+
init_function(*lib);
920+
}
921+
};
922+
#else
889923
class TorchLibraryInit final {
890924
private:
891925
using InitFn = void(Library&);
@@ -903,6 +937,7 @@ class TorchLibraryInit final {
903937
fn(lib_);
904938
}
905939
};
940+
#endif
906941

907942
} // namespace detail
908943

0 commit comments

Comments
 (0)