diff --git a/runtime/backend/backend_options.h b/runtime/backend/backend_options.h index 6c106cc156..f7b0a000d4 100644 --- a/runtime/backend/backend_options.h +++ b/runtime/backend/backend_options.h @@ -6,88 +6,94 @@ * LICENSE file in the root directory of this source tree. */ -#pragma once -#include -#include -#include -#include - -namespace executorch { -namespace runtime { - -// Strongly-typed option key template -template -struct OptionKey { - const char* key; - constexpr explicit OptionKey(const char* k) : key(k) {} -}; - -// Union replaced with std::variant -using OptionValue = std::variant; - -struct BackendOption { - const char* key; // key is the name of the backend option, like num_threads, - // enable_profiling, etc - OptionValue - value; // value is the value of the backend option, like 4, true, etc -}; - -template -class BackendOptions { - public: - // Initialize with zero options - BackendOptions() : size_(0) {} - - // Type-safe setters - template - void set_option(OptionKey key, T value) { - const char* k = key.key; - // Update existing if found - for (size_t i = 0; i < size_; ++i) { - if (strcmp(options_[i].key, k) == 0) { - options_[i].value = value; - return; - } - } - // Add new option if space available - if (size_ < MaxCapacity) { - options_[size_++] = BackendOption{k, value}; - } - } - - // Type-safe getters - template - Error get_option(OptionKey key, T& out) const { - const char* k = key.key; - for (size_t i = 0; i < size_; ++i) { - if (strcmp(options_[i].key, k) == 0) { - if (auto* val = std::get_if(&options_[i].value)) { - out = *val; - return Error::Ok; - } - return Error::InvalidArgument; - } - } - return Error::NotFound; - } - - private: - BackendOption options_[MaxCapacity]{}; // Storage for backend options - size_t size_; // Current number of options -}; - -// Helper functions for creating typed option keys (unchanged) -constexpr OptionKey BoolKey(const char* k) { - return OptionKey(k); -} - -constexpr OptionKey IntKey(const char* k) { - return OptionKey(k); -} - -constexpr OptionKey StrKey(const char* k) { - return OptionKey(k); -} - -} // namespace runtime -} // namespace executorch + #pragma once + #include + #include + #include + #include + #include + #include + + namespace executorch { + namespace runtime { + + // Strongly-typed option key template + template + struct OptionKey { + const char* key; + constexpr explicit OptionKey(const char* k) : key(k) {} + }; + + // Union replaced with std::variant + using OptionValue = std::variant; + + struct BackendOption { + const char* key; // key is the name of the backend option, like num_threads, + // enable_profiling, etc + OptionValue + value; // value is the value of the backend option, like 4, true, etc + }; + + template + class BackendOptions { + public: + // Initialize with zero options + BackendOptions() : size_(0) {} + + // Type-safe setters + template + void set_option(OptionKey key, T value) { + const char* k = key.key; + // Update existing if found + for (size_t i = 0; i < size_; ++i) { + if (strcmp(options_[i].key, k) == 0) { + options_[i].value = value; + return; + } + } + // Add new option if space available + if (size_ < MaxCapacity) { + options_[size_++] = BackendOption{k, value}; + } + } + + // Type-safe getters + template + Error get_option(OptionKey key, T& out) const { + const char* k = key.key; + for (size_t i = 0; i < size_; ++i) { + if (strcmp(options_[i].key, k) == 0) { + if (auto* val = std::get_if(&options_[i].value)) { + out = *val; + return Error::Ok; + } + return Error::InvalidArgument; + } + } + return Error::NotFound; + } + executorch::runtime::ArrayRef view() const { + return executorch::runtime::ArrayRef(options_, size_); + } + + private: + BackendOption options_[MaxCapacity]{}; // Storage for backend options + size_t size_; // Current number of options + }; + + // Helper functions for creating typed option keys (unchanged) + constexpr OptionKey BoolKey(const char* k) { + return OptionKey(k); + } + + constexpr OptionKey IntKey(const char* k) { + return OptionKey(k); + } + + constexpr OptionKey StrKey(const char* k) { + return OptionKey(k); + } + + } // namespace runtime + } // namespace executorch + \ No newline at end of file diff --git a/runtime/backend/interface.h b/runtime/backend/interface.h index 95705d48f9..8643e316e0 100644 --- a/runtime/backend/interface.h +++ b/runtime/backend/interface.h @@ -12,6 +12,8 @@ #include #include +#include +#include #include #include #include @@ -99,6 +101,20 @@ class BackendInterface { DelegateHandle* handle, EValue** args) const = 0; + /** + * Responsible update the backend status, if any. The backend options are passed in + * by users, and the backend can update its internal status based on the options. + * + * @param[in] context Runtime context if any. Currently it's not used. + * @param[in] args A list of BackendOptions passed in by users. + * @retval Error::Ok if successful. + */ + ET_NODISCARD virtual Error update( + BackendUpdateContext& context, + const executorch::runtime::ArrayRef& backend_options) const { + return Error::Ok; + }; + /** * Responsible for destroying a handle, if it's required for some backend. * It may be needed for some backends. For example, resources associated with diff --git a/runtime/backend/test/backend_interface_update_test.cpp b/runtime/backend/test/backend_interface_update_test.cpp new file mode 100644 index 0000000000..20d9100f83 --- /dev/null +++ b/runtime/backend/test/backend_interface_update_test.cpp @@ -0,0 +1,285 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + + #include + #include + + #include + + using namespace ::testing; +using executorch::runtime::BackendInterface; +using executorch::runtime::Result; +using executorch::runtime::DelegateHandle; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::BackendInitContext; +using executorch::runtime::CompileSpec; +using executorch::runtime::ArrayRef; +using executorch::runtime::Error; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::EValue; +using executorch::runtime::BackendUpdateContext; +using executorch::runtime::BackendOption; +using executorch::runtime::BackendOptions; +using executorch::runtime::Backend; +using executorch::runtime::StrKey; +using executorch::runtime::IntKey; +using executorch::runtime::BoolKey; +using executorch::runtime::get_backend_class; +using executorch::runtime::MemoryAllocator; + +class MockBackend : public BackendInterface { + public: + ~MockBackend() override = default; + + bool is_available() const override { return true; } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + init_called = true; + return nullptr; + } + + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle, + EValue** args) const override { + execute_count++; + return Error::Ok; + } + + Error update( + BackendUpdateContext& context, + const executorch::runtime::ArrayRef& backend_options) const override { + update_count++; + int sucess_update = 0; + for (const auto& backend_option : backend_options) { + if (strcmp(backend_option.key, "Backend") == 0) { + if (std::holds_alternative(backend_option.value)) { + // Store the value in our member variable + target_backend = std::get(backend_option.value); + sucess_update++; + } + } else if (strcmp(backend_option.key, "NumberOfThreads") == 0) { + if (std::holds_alternative(backend_option.value)) { + num_threads = std::get(backend_option.value); + sucess_update++; + } + } else if (strcmp(backend_option.key, "Debug") == 0) { + if (std::holds_alternative(backend_option.value)) { + debug = std::get(backend_option.value); + sucess_update++; + } + } + } + if (sucess_update == backend_options.size()) { + return Error::Ok; + } + return Error::InvalidArgument; + } + + // Mutable allows modification in const methods + mutable std::optional target_backend; + mutable int num_threads = 0; + mutable bool debug = false; + + // State tracking + mutable bool init_called = false; + mutable int execute_count = 0; + mutable int update_count = 0; +}; + + class BackendInterfaceUpdateTest : public ::testing::Test { + protected: + + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + mock_backend = std::make_unique(); + // static Error register_success = register_executor_backend(); + } + + std::unique_ptr mock_backend; + BackendOptions<5> options; +}; + +TEST_F(BackendInterfaceUpdateTest, HandlesInvalidOption) { + BackendUpdateContext context; + + // Test invalid key case + BackendOption invalid_option{ + "InvalidKey", + "None" + }; + + Error err = mock_backend->update(context, invalid_option); + EXPECT_EQ(err, Error::InvalidArgument); + +} + + TEST_F(BackendInterfaceUpdateTest, HandlesStringOption) { + BackendUpdateContext context; + options.set_option(StrKey("Backend"), "GPU"); + // // Create a backend option to pass to update + + EXPECT_EQ(mock_backend->target_backend, std::nullopt); + + // Test successful update + Error err = mock_backend->update(context, options.view()); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(mock_backend->target_backend, "GPU"); +} + +TEST_F(BackendInterfaceUpdateTest, HandlesIntOption) { + // Check the default num_threads value is 0 + EXPECT_EQ(mock_backend->debug, false); + // Create a mock context (needs to be defined or mocked) + BackendUpdateContext context; + + int expected_num_threads = 4; + + // Create a backend option to pass to update + options.set_option(IntKey("NumberOfThreads"), expected_num_threads); + + // Test successful update + Error err = mock_backend->update(context, options.view()); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(mock_backend->num_threads, expected_num_threads); +} + +TEST_F(BackendInterfaceUpdateTest, HandlesBoolOption) { + // Check the default num_threads value is 0 + EXPECT_EQ(mock_backend->debug, false); + // Create a mock context (needs to be defined or mocked) + BackendUpdateContext context; + + options.set_option(BoolKey("Debug"), true); + + // Test successful update + Error err = mock_backend->update(context, options.view()); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(mock_backend->debug, true); +} + +TEST_F(BackendInterfaceUpdateTest, HandlesMultipleOptions) { + // Check the default num_threads value is 0 + EXPECT_EQ(mock_backend->debug, false); + // Create a mock context (needs to be defined or mocked) + BackendUpdateContext context; + + options.set_option(BoolKey("Debug"), true); + options.set_option(IntKey("NumberOfThreads"), 4); + options.set_option(StrKey("Backend"), "GPU"); + + // Test successful update + Error err = mock_backend->update(context, options.view()); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(mock_backend->debug, true); + EXPECT_EQ(mock_backend->num_threads, 4); + EXPECT_EQ(mock_backend->target_backend, "GPU"); +} + +TEST_F(BackendInterfaceUpdateTest, UpdateBeforeInit) { + BackendUpdateContext update_context; + MemoryAllocator memory_allocator{MemoryAllocator(0, nullptr)}; + + BackendInitContext init_context(&memory_allocator); + + // Create backend option + options.set_option(StrKey("Backend"), "GPU"); + + // Update before init + Error err = mock_backend->update(update_context, options.view()); + EXPECT_EQ(err, Error::Ok); + + // Now call init + FreeableBuffer* processed = nullptr; // Not used in mock + ArrayRef compile_specs; // Empty + auto handle_or_error = mock_backend->init(init_context, processed, compile_specs); + EXPECT_EQ(handle_or_error.error(), Error::Ok); + + // Verify state + EXPECT_TRUE(mock_backend->init_called); + EXPECT_EQ(mock_backend->update_count, 1); + EXPECT_EQ(mock_backend->execute_count, 0); + ASSERT_TRUE(mock_backend->target_backend.has_value()); + EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "GPU"); +} + +TEST_F(BackendInterfaceUpdateTest, UpdateAfterInitBeforeExecute) { + BackendUpdateContext update_context; + MemoryAllocator init_memory_allocator{MemoryAllocator(0, nullptr)}; + BackendInitContext init_context(&init_memory_allocator); + BackendExecutionContext execute_context; + + // First call init + FreeableBuffer* processed = nullptr; + ArrayRef compile_specs; + auto handle_or_error = mock_backend->init(init_context, processed, compile_specs); + EXPECT_TRUE(handle_or_error.ok()); + + // Verify init called but execute not called + EXPECT_TRUE(mock_backend->init_called); + EXPECT_EQ(mock_backend->execute_count, 0); + + // Now update + options.set_option(StrKey("Backend"), "CPU"); + Error err = mock_backend->update(update_context, options.view()); + EXPECT_EQ(err, Error::Ok); + + // Now execute + DelegateHandle* handle = handle_or_error.get(); + EValue** args = nullptr; // Not used in mock + err = mock_backend->execute(execute_context, handle, args); + EXPECT_EQ(err, Error::Ok); + + // Verify state + EXPECT_EQ(mock_backend->update_count, 1); + EXPECT_EQ(mock_backend->execute_count, 1); + ASSERT_TRUE(mock_backend->target_backend.has_value()); + EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "CPU"); +} + +TEST_F(BackendInterfaceUpdateTest, UpdateBetweenExecutes) { + BackendUpdateContext update_context; + MemoryAllocator init_memory_allocator{MemoryAllocator(0, nullptr)}; + BackendInitContext init_context(&init_memory_allocator); + BackendExecutionContext execute_context; + + // Initialize + FreeableBuffer* processed = nullptr; + ArrayRef compile_specs; + auto handle_or_error = mock_backend->init(init_context, processed, compile_specs); + EXPECT_TRUE(handle_or_error.ok()); + DelegateHandle* handle = handle_or_error.get(); + + // First execute + EValue** args = nullptr; + Error err = mock_backend->execute(execute_context, handle, args); + EXPECT_EQ(err, Error::Ok); + + // Update between executes + options.set_option(StrKey("Backend"), "NPU"); + err = mock_backend->update(update_context, options.view()); + EXPECT_EQ(err, Error::Ok); + + // Second execute + err = mock_backend->execute(execute_context, handle, args); + EXPECT_EQ(err, Error::Ok); + + // Verify state + EXPECT_EQ(mock_backend->update_count, 1); + EXPECT_EQ(mock_backend->execute_count, 2); + ASSERT_TRUE(mock_backend->target_backend.has_value()); + EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "NPU"); +} diff --git a/runtime/backend/test/targets.bzl b/runtime/backend/test/targets.bzl index 97299bbcb3..5430cb17cc 100644 --- a/runtime/backend/test/targets.bzl +++ b/runtime/backend/test/targets.bzl @@ -14,3 +14,12 @@ def define_common_targets(): "//executorch/runtime/backend:interface", ], ) + + runtime.cxx_test( + name = "backend_interface_update_test", + srcs = ["backend_interface_update_test.cpp"], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/backend:interface", + ], + )