diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml index d61c2dc59..725658d37 100644 --- a/.github/workflows/build-linux.yml +++ b/.github/workflows/build-linux.yml @@ -22,7 +22,7 @@ jobs: - name: make run: | cd build - make -j + sudo make install -j shell: bash - name: start-metadata-server run: | @@ -31,6 +31,11 @@ jobs: go mod tidy && go build -o http-metadata-server . ./http-metadata-server --addr=:8090 & shell: bash + - name: start-mooncake-master + run: | + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib + mooncake_master --enable_gc=false & + shell: bash - name: test run: | cd build @@ -38,3 +43,9 @@ jobs: ldconfig -v || echo "always continue" MC_METADATA_SERVER=http://127.0.0.1:8090/metadata make test -j ARGS="-V" shell: bash + - name: mooncake store python test + run: | + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib + cd mooncake-store/tests + MC_METADATA_SERVER=http://127.0.0.1:8090/metadata python3 test_distributed_object_store.py + shell: bash diff --git a/mooncake-integration/vllm/distributed_object_store.cpp b/mooncake-integration/vllm/distributed_object_store.cpp index e9739b0da..97e6e3170 100644 --- a/mooncake-integration/vllm/distributed_object_store.cpp +++ b/mooncake-integration/vllm/distributed_object_store.cpp @@ -296,7 +296,7 @@ int DistributedObjectStore::put(const std::string &key, if (ret) return ret; ErrorCode error_code = client_->Put(std::string(key), slices, config); freeSlices(slices); - if (error_code != ErrorCode::OK) return 1; + if (error_code != ErrorCode::OK) return toInt(error_code); return 0; } @@ -343,7 +343,7 @@ int DistributedObjectStore::remove(const std::string &key) { return 1; } ErrorCode error_code = client_->Remove(key); - if (error_code != ErrorCode::OK) return 1; + if (error_code != ErrorCode::OK) return toInt(error_code); return 0; } @@ -355,7 +355,7 @@ int DistributedObjectStore::isExist(const std::string &key) { ErrorCode err = client_->IsExist(key); if (err == ErrorCode::OK) return 1; // Yes if (err == ErrorCode::OBJECT_NOT_FOUND) return 0; // No - return -1; // Error + return toInt(err); // Error } int64_t DistributedObjectStore::getSize(const std::string &key) { @@ -368,7 +368,7 @@ int64_t DistributedObjectStore::getSize(const std::string &key) { ErrorCode error_code = client_->Query(key, object_info); if (error_code != ErrorCode::OK) { - return -1; // Error or object doesn't exist + return toInt(error_code); } // Calculate total size from all replicas' handles diff --git a/mooncake-store/include/master_service.h b/mooncake-store/include/master_service.h index 08971534c..3106ebb88 100644 --- a/mooncake-store/include/master_service.h +++ b/mooncake-store/include/master_service.h @@ -81,13 +81,12 @@ class MasterService { // Comparator for GC tasks priority queue struct GCTaskComparator { bool operator()(GCTask* a, GCTask* b) const { - return a->deletion_time > - b->deletion_time; // 最小堆,最早过期的优先级最高 + return a->deletion_time > b->deletion_time; } }; public: - MasterService(); + MasterService(bool enable_gc = true); ~MasterService(); /** @@ -191,6 +190,7 @@ class MasterService { boost::lockfree::queue gc_queue_{kGCQueueSize}; std::thread gc_thread_; std::atomic gc_running_{false}; + bool enable_gc_{true}; // Flag to enable/disable garbage collection static constexpr uint64_t kGCThreadSleepMs = 10; // 10 ms sleep between GC checks diff --git a/mooncake-store/src/master.cpp b/mooncake-store/src/master.cpp index 21347ab6d..2ed8731a1 100644 --- a/mooncake-store/src/master.cpp +++ b/mooncake-store/src/master.cpp @@ -17,6 +17,7 @@ // Define command line flags DEFINE_int32(port, 50051, "Port for master service to listen on"); DEFINE_int32(max_threads, 4, "Maximum number of threads to use"); +DEFINE_bool(enable_gc, true, "Enable garbage collection"); namespace mooncake { @@ -258,8 +259,12 @@ int main(int argc, char* argv[]) { return 1; } - // Create master service instance - auto master_service = std::make_shared(); + // Create master service instance with GC flag + auto master_service = + std::make_shared(FLAGS_enable_gc); + + LOG(INFO) << "Garbage collection: " + << (FLAGS_enable_gc ? "enabled" : "disabled"); // Initialize gRPC server std::string server_address = "0.0.0.0:" + std::to_string(FLAGS_port); diff --git a/mooncake-store/src/master_service.cpp b/mooncake-store/src/master_service.cpp index 238d46887..da38a8bec 100644 --- a/mooncake-store/src/master_service.cpp +++ b/mooncake-store/src/master_service.cpp @@ -66,13 +66,18 @@ ErrorCode BufferAllocatorManager::RemoveSegment( return ErrorCode::OK; } -MasterService::MasterService() +MasterService::MasterService(bool enable_gc) : buffer_allocator_manager_(std::make_shared()), - allocation_strategy_(std::make_shared()) { - // Start the GC thread - gc_running_ = true; - gc_thread_ = std::thread(&MasterService::GCThreadFunc, this); - VLOG(1) << "action=start_gc_thread"; + allocation_strategy_(std::make_shared()), + enable_gc_(enable_gc) { + // Start the GC thread if enabled + if (enable_gc_) { + gc_running_ = true; + gc_thread_ = std::thread(&MasterService::GCThreadFunc, this); + VLOG(1) << "action=start_gc_thread"; + } else { + VLOG(1) << "action=gc_disabled"; + } } MasterService::~MasterService() { @@ -133,7 +138,12 @@ ErrorCode MasterService::GetReplicaList( VLOG(1) << "key=" << key << ", replica_list=" << VectorToString(replica_list); } - MarkForGC(key, 1000); // After 1 second, the object will be removed + + // Only mark for GC if enabled + if (enable_gc_) { + MarkForGC(key, 1000); // After 1 second, the object will be removed + } + return ErrorCode::OK; } diff --git a/mooncake-store/tests/CMakeLists.txt b/mooncake-store/tests/CMakeLists.txt index 6142a51c7..bacdec6f0 100644 --- a/mooncake-store/tests/CMakeLists.txt +++ b/mooncake-store/tests/CMakeLists.txt @@ -1,9 +1,10 @@ add_executable(buffer_allocator_test buffer_allocator_test.cpp) target_link_libraries(buffer_allocator_test PUBLIC cache_allocator cachelib_memory_allocator gtest gtest_main pthread) - +add_test(NAME buffer_allocator_test COMMAND buffer_allocator_test) add_executable(master_service_test master_service_test.cpp) target_link_libraries(master_service_test PUBLIC cache_allocator cachelib_memory_allocator glog gtest gtest_main pthread) +add_test(NAME master_service_test COMMAND master_service_test) add_executable(client_integration_test client_integration_test.cpp) target_link_libraries(client_integration_test PUBLIC @@ -17,6 +18,7 @@ target_link_libraries(client_integration_test PUBLIC gtest_main pthread ) +add_test(NAME client_integration_test COMMAND client_integration_test) add_executable(stress_workload_test stress_workload_test.cpp) target_link_libraries(stress_workload_test PUBLIC diff --git a/mooncake-store/tests/client_integration_test.cpp b/mooncake-store/tests/client_integration_test.cpp index 2602ecab7..4d658a5c5 100644 --- a/mooncake-store/tests/client_integration_test.cpp +++ b/mooncake-store/tests/client_integration_test.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -12,9 +13,11 @@ #include "types.h" #include "utils.h" -static std::string protocol = "tcp"; // Transfer protocol: rdma|tcp -static std::string device_name = - "ibp6s0"; // Device name to use, valid if protocol=rdma +DEFINE_string(protocol, "tcp", "Transfer protocol: rdma|tcp"); +DEFINE_string(device_name, "ibp6s0", + "Device name to use, valid if protocol=rdma"); +DEFINE_string(transfer_engine_metadata_url, "127.0.0.1:2379", + "Metadata connection string for transfer engine"); namespace mooncake { namespace testing { @@ -29,10 +32,16 @@ class ClientIntegrationTest : public ::testing::Test { // google::SetVLOGLevel("*", 1); FLAGS_logtostderr = 1; - if (getenv("PROTOCOL")) protocol = getenv("PROTOCOL"); - if (getenv("DEVICE_NAME")) device_name = getenv("DEVICE_NAME"); - LOG(INFO) << "Protocol: " << protocol - << ", Device name: " << device_name; + // Override flags from environment variables if present + if (getenv("PROTOCOL")) FLAGS_protocol = getenv("PROTOCOL"); + if (getenv("DEVICE_NAME")) FLAGS_device_name = getenv("DEVICE_NAME"); + if (getenv("MC_METADATA_SERVER")) + FLAGS_transfer_engine_metadata_url = getenv("MC_METADATA_SERVER"); + + LOG(INFO) << "Protocol: " << FLAGS_protocol + << ", Device name: " << FLAGS_device_name + << ", Metadata URL: " << FLAGS_transfer_engine_metadata_url; + InitializeClient(); InitializeSegment(); } @@ -44,10 +53,10 @@ class ClientIntegrationTest : public ::testing::Test { } static void InitializeSegment() { - const size_t ram_buffer_size = 1024 * 1024 * 32; // 32MB + const size_t ram_buffer_size = 1024 * 1024 * 1024; // 1GB segment_ptr_ = allocate_buffer_allocator_memory(ram_buffer_size); LOG_ASSERT(segment_ptr_); - ErrorCode rc = client_->MountSegment("localhost:12345", segment_ptr_, + ErrorCode rc = client_->MountSegment("localhost:17812", segment_ptr_, ram_buffer_size); if (rc != ErrorCode::OK) { LOG(ERROR) << "Failed to mount segment: " << toString(rc); @@ -57,12 +66,15 @@ class ClientIntegrationTest : public ::testing::Test { static void InitializeClient() { client_ = std::make_unique(); - void** args = (protocol == "rdma") ? rdma_args(device_name) : nullptr; - ASSERT_EQ(client_->Init("localhost:12345", // Local hostname - "127.0.0.1:2379", // Metadata connection string - protocol, args, - "localhost:50051" // Master server address - ), + void** args = + (FLAGS_protocol == "rdma") ? rdma_args(FLAGS_device_name) : nullptr; + ASSERT_EQ(client_->Init( + "localhost:17812", // Local hostname + FLAGS_transfer_engine_metadata_url, // Metadata + // connection string + FLAGS_protocol, args, + "localhost:50051" // Master server address + ), ErrorCode::OK); client_buffer_allocator_ = std::make_unique(128 * 1024 * 1024); @@ -82,7 +94,7 @@ class ClientIntegrationTest : public ::testing::Test { } static void CleanupSegment() { - if (client_->UnmountSegment("localhost:12345", segment_ptr_) != + if (client_->UnmountSegment("localhost:17812", segment_ptr_) != ErrorCode::OK) { LOG(ERROR) << "Failed to unmount segment"; } @@ -168,7 +180,7 @@ TEST_F(ClientIntegrationTest, RemoveOperation) { } // Test heavy workload operations -TEST_F(ClientIntegrationTest, AllocateTest) { +TEST_F(ClientIntegrationTest, DISABLED_AllocateTest) { const size_t data_size = 1 * 1024 * 1024; // 1MB std::string large_data(data_size, 'A'); // Fill with 'A's const int num_operations = 13; @@ -281,3 +293,14 @@ TEST_F(ClientIntegrationTest, LargeAllocateTest) { } // namespace testing } // namespace mooncake + +int main(int argc, char** argv) { + // Initialize Google's flags library + gflags::ParseCommandLineFlags(&argc, &argv, true); + + // Initialize Google Test + ::testing::InitGoogleTest(&argc, argv); + + // Run all tests + return RUN_ALL_TESTS(); +} diff --git a/mooncake-store/tests/master_service_test.cpp b/mooncake-store/tests/master_service_test.cpp index 9a4f501e3..0712f3406 100644 --- a/mooncake-store/tests/master_service_test.cpp +++ b/mooncake-store/tests/master_service_test.cpp @@ -547,3 +547,8 @@ TEST_F(MasterServiceTest, CleanupStaleHandlesTest) { } } // namespace mooncake::test + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/mooncake-store/tests/test_distributed_object_store.py b/mooncake-store/tests/test_distributed_object_store.py index 4e1a6d2ca..43c888554 100644 --- a/mooncake-store/tests/test_distributed_object_store.py +++ b/mooncake-store/tests/test_distributed_object_store.py @@ -11,7 +11,7 @@ def get_client(store): protocol = os.getenv("PROTOCOL", "tcp") device_name = os.getenv("DEVICE_NAME", "ibp6s0") local_hostname = os.getenv("LOCAL_HOSTNAME", "localhost") - metadata_server = os.getenv("METADATA_ADDR", "127.0.0.1:2379") + metadata_server = os.getenv("MC_METADATA_SERVER", "127.0.0.1:2379") global_segment_size = 3200 * 1024 * 1024 # 3200 MB local_buffer_size = 512 * 1024 * 1024 # 512 MB master_server_address = os.getenv("MASTER_SERVER", "127.0.0.1:50051") @@ -149,6 +149,10 @@ def worker(thread_id): # Wait for all threads to complete get operations get_barrier.wait() + # Remove all keys + for key in thread_keys: + self.assertEqual(self.store.remove(key), 0) + except Exception as e: thread_exceptions.append(f"Thread {thread_id} failed: {str(e)}") @@ -204,6 +208,64 @@ def worker(thread_id): print(f"System Put bandwidth: {total_data_size_gb/put_duration:.2f} GB/sec") print(f"System Get bandwidth: {total_data_size_gb/get_duration:.2f} GB/sec") - + def test_dict_fuzz_e2e(self): + """End-to-end fuzz test comparing distributed store behavior with dict. + Performs ~1000 random operations (put, get, remove) with random value sizes between 1KB and 64MB. + After testing, all keys are removed. + """ + import random + # Local reference dict to simulate expected dict behavior + reference = {} + operations = 1000 + # Use a pool of keys to limit memory consumption + keys_pool = [f"key_{i}" for i in range(100)] + # Track which keys have values assigned to ensure consistency + key_values = {} + # Fuzz record for debugging in case of errors + fuzz_record = [] + try: + for i in range(operations): + op = random.choice(["put", "get", "remove"]) + key = random.choice(keys_pool) + if op == "put": + # If key already exists, use the same value to ensure consistency + if key in key_values: + value = key_values[key] + size = len(value) + else: + size = random.randint(1, 64 * 1024 * 1024) + value = os.urandom(size) + key_values[key] = value + + fuzz_record.append(f"{i}: put {key} [size: {size}]") + error_code = self.store.put(key, value) + if error_code == -200: + # The space is not enough, continue to next operation + continue + elif error_code == 0: + reference[key] = value + else: + raise RuntimeError(f"Put operation failed for key {key}. Error code: {error_code}") + elif op == "get": + fuzz_record.append(f"{i}: get {key}") + retrieved = self.store.get(key) + expected = reference.get(key, b"") + self.assertEqual(retrieved, expected) + elif op == "remove": + fuzz_record.append(f"{i}: remove {key}") + self.store.remove(key) + reference.pop(key, None) + # Also remove from key_values to allow new value if key is reused + key_values.pop(key, None) + except Exception as e: + print(f"Error: {e}") + print('\nFuzz record (operations so far):') + for record in fuzz_record: + print(record) + raise e + # Cleanup: ensure all remaining keys are removed + for key in list(reference.keys()): + self.store.remove(key) + if __name__ == '__main__': unittest.main()