From 25853122aa4c7714625fe16698c6493226b77c07 Mon Sep 17 00:00:00 2001 From: liuziqian <2949547669@qq.com> Date: Fri, 10 Jan 2025 20:38:09 +0800 Subject: [PATCH 1/2] implement shm --- dependencies.sh | 4 +- doc/zh/transfer-engine.md | 6 +- .../example/http-metadata-server/go.mod | 2 +- .../include/transfer_engine_c.h | 5 + .../transport/cxl_transport/cxl_transport.h | 9 + .../transport/shm_transport/shm_transport.h | 98 +++++ .../src/multi_transport.cpp | 3 + .../src/transfer_engine.cpp | 1 + .../src/transfer_metadata.cpp | 26 +- .../src/transport/CMakeLists.txt | 3 + .../transport/cxl_transport/CMakeLists.txt | 2 +- .../transport/cxl_transport/cxl_transport.cpp | 197 +++++++++- .../transport/shm_transport/CMakeLists.txt | 4 + .../transport/shm_transport/shm_transport.cpp | 300 +++++++++++++++ .../transport/tcp_transport/CMakeLists.txt | 4 +- mooncake-transfer-engine/tests/CMakeLists.txt | 4 + .../tests/cxl_transport_test.cpp | 341 +++++++++++++++++ .../tests/shm_transport_test.cpp | 351 ++++++++++++++++++ .../tests/tcp_transport_test.cpp | 4 +- 19 files changed, 1341 insertions(+), 23 deletions(-) create mode 100644 mooncake-transfer-engine/include/transport/shm_transport/shm_transport.h create mode 100644 mooncake-transfer-engine/src/transport/shm_transport/CMakeLists.txt create mode 100644 mooncake-transfer-engine/src/transport/shm_transport/shm_transport.cpp create mode 100644 mooncake-transfer-engine/tests/cxl_transport_test.cpp create mode 100644 mooncake-transfer-engine/tests/shm_transport_test.cpp diff --git a/dependencies.sh b/dependencies.sh index 46b9048..01ba78c 100755 --- a/dependencies.sh +++ b/dependencies.sh @@ -55,7 +55,7 @@ cmake .. make -j$(nproc) && sudo make install echo "*** Download and installing [golang-1.22] ***" -wget https://go.dev/dl/go1.22.linux-amd64.tar.gz -sudo tar -C /usr/local -xzf go1.22.linux-amd64.tar.gz +wget https://go.dev/dl/go1.22.10.linux-amd64.tar.gz +sudo tar -C /usr/local -xzf go1.22.10.linux-amd64.tar.gz echo "*** Dependencies Installed! ***" diff --git a/doc/zh/transfer-engine.md b/doc/zh/transfer-engine.md index 19e404b..b6c5192 100644 --- a/doc/zh/transfer-engine.md +++ b/doc/zh/transfer-engine.md @@ -79,7 +79,7 @@ Transfer Engine 使用SIEVE算法来管理端点的逐出。如果由于链路 例如,可使用如下命令行启动 `etcd` 服务: ```bash # This is 10.0.0.1 - etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://10.0.0.1:2379 + etcd --listen-client-urls http://127.0.0.1:2379 --advertise-client-urls http://127.0.0.1:2379 ``` 1.2. **启动 `http` 作为 `metadata` 服务** @@ -95,8 +95,8 @@ Transfer Engine 使用SIEVE算法来管理端点的逐出。如果由于链路 ```bash # This is 10.0.0.2 ./transfer_engine_bench --mode=target \ - --metadata_server=etcd://10.0.0.1:2379 \ - --local_server_name=10.0.0.2:12345 \ + --metadata_server=etcd://127.0.0.1:2379 \ + --local_server_name=127.0.0.1:12345 \ --device_name=erdma_0 ``` 各个参数的含义如下: diff --git a/mooncake-transfer-engine/example/http-metadata-server/go.mod b/mooncake-transfer-engine/example/http-metadata-server/go.mod index 5deeec0..39027b6 100644 --- a/mooncake-transfer-engine/example/http-metadata-server/go.mod +++ b/mooncake-transfer-engine/example/http-metadata-server/go.mod @@ -1,6 +1,6 @@ module github.com/kvcache-ai/Mooncake/mooncake-transfer-engine/example/http-metadata-server -go 1.22.9 +go 1.23 require github.com/gin-gonic/gin v1.10.0 diff --git a/mooncake-transfer-engine/include/transfer_engine_c.h b/mooncake-transfer-engine/include/transfer_engine_c.h index bd3a942..9077131 100644 --- a/mooncake-transfer-engine/include/transfer_engine_c.h +++ b/mooncake-transfer-engine/include/transfer_engine_c.h @@ -70,6 +70,11 @@ struct segment_desc { uint64_t port; // maybe more needed for mount } nvmeof; + struct { + void *addr; + uint16_t size; + const char *location; + }shm; } desc_; }; diff --git a/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h b/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h index 4cc6f09..ac38a8e 100644 --- a/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h +++ b/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h @@ -44,6 +44,10 @@ class CxlTransport : public Transport { ~CxlTransport(); + void createSharedMem(void* addr, size_t size, const std::string& location); + + void delete_shared_memory(void *addr); + BatchID allocateBatchID(size_t batch_size) override; int submitTransfer(BatchID batch_id, @@ -55,6 +59,7 @@ class CxlTransport : public Transport { int freeBatchID(BatchID batch_id) override; private: + int install(std::string &local_server_name, std::shared_ptr meta, void **args) override; @@ -77,6 +82,10 @@ class CxlTransport : public Transport { } const char *getName() const override { return "cxl"; } + + std::unordered_map SharedMem_map_; + + void startSlice(Slice *slice); }; } // namespace mooncake diff --git a/mooncake-transfer-engine/include/transport/shm_transport/shm_transport.h b/mooncake-transfer-engine/include/transport/shm_transport/shm_transport.h new file mode 100644 index 0000000..fe0b8e1 --- /dev/null +++ b/mooncake-transfer-engine/include/transport/shm_transport/shm_transport.h @@ -0,0 +1,98 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SHM_TRANSPORT_H_ +#define SHM_TRANSPORT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "transfer_metadata.h" +#include "transport/transport.h" + +namespace mooncake { +class TransferMetadata; + +class ShmTransport : public Transport { + public: + using BufferDesc = TransferMetadata::BufferDesc; + using SegmentDesc = TransferMetadata::SegmentDesc; + using HandShakeDesc = TransferMetadata::HandShakeDesc; + + public: + ShmTransport(); + + ~ShmTransport(); + + void createSharedMem(void* addr, size_t size, const std::string& location); + + void delete_shared_memory(void *addr); + + BatchID allocateBatchID(size_t batch_size) override; + + int submitTransfer(BatchID batch_id, + const std::vector &entries) override; + + int getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) override; + + int freeBatchID(BatchID batch_id) override; + + int submitTransferTask( + const std::vector &request_list, + const std::vector &task_list); + + private: + + int install(std::string &local_server_name, + std::shared_ptr meta, void **args) override; + + int registerLocalMemory(void *addr, size_t length, const std::string &location, + bool remote_accessible, + bool update_metadata) override; + + int unregisterLocalMemory(void *addr, + bool update_metadata = false) override; + + int registerLocalMemoryBatch( + const std::vector &buffer_list, + const std::string &location) override { + return 0; + } + + int unregisterLocalMemoryBatch( + const std::vector &addr_list) override { + return 0; + } + + const char *getName() const override { return "shm"; } + + std::unordered_map SharedMem_map_; + + void startSlice(Slice *slice); + + int allocateLocalSegmentID(); +}; +} // namespace mooncake + +#endif \ No newline at end of file diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp index f2c1cc3..5680ac3 100644 --- a/mooncake-transfer-engine/src/multi_transport.cpp +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -16,6 +16,7 @@ #include "transport/rdma_transport/rdma_transport.h" #include "transport/tcp_transport/tcp_transport.h" +#include "transport/shm_transport/shm_transport.h" #include "transport/transport.h" #ifdef USE_CUDA #include "transport/nvmeof_transport/nvmeof_transport.h" @@ -130,6 +131,8 @@ Transport *MultiTransport::installTransport(const std::string &proto, transport = new RdmaTransport(); } else if (std::string(proto) == "tcp") { transport = new TcpTransport(); + } else if (std::string(proto) == "shm") { + transport = new ShmTransport(); } #ifdef USE_CUDA else if (std::string(proto) == "nvmeof") { diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 6b3768b..8fc2870 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -91,6 +91,7 @@ int TransferEngine::registerLocalMemory(void *addr, size_t length, return ERR_ADDRESS_OVERLAPPED; } for (auto transport : multi_transports_->listTransports()) { + printf("transport = %p\n", transport); int ret = transport->registerLocalMemory( addr, length, location, remote_accessible, update_metadata); if (ret < 0) return ret; diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index eafe88f..5aaabf8 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -122,6 +122,16 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, buffersJSON.append(bufferJSON); } segmentJSON["buffers"] = buffersJSON; + } else if (segmentJSON["protocol"] == "shm") { + Json::Value buffersJSON(Json::arrayValue); + for (const auto &buffer : desc.buffers) { + Json::Value bufferJSON; + bufferJSON["name"] = buffer.name; + bufferJSON["addr"] = static_cast(buffer.addr); + bufferJSON["length"] = static_cast(buffer.length); + buffersJSON.append(bufferJSON); + } + segmentJSON["buffers"] = buffersJSON; } else { LOG(ERROR) << "Unsupported segment descriptor for register, name " << desc.name << " protocol " << desc.protocol; @@ -222,6 +232,19 @@ std::shared_ptr TransferMetadata::getSegmentDesc( } desc->nvmeof_buffers.push_back(buffer); } + } else if (desc->protocol == "shm") { + for (const auto &bufferJSON : segmentJSON["buffers"]) { + BufferDesc buffer; + buffer.name = bufferJSON["name"].asString(); + buffer.addr = bufferJSON["addr"].asUInt64(); + buffer.length = bufferJSON["length"].asUInt64(); + if (buffer.name.empty() || !buffer.addr || !buffer.length) { + LOG(WARNING) << "Corrupted segment descriptor, name " + << segment_name << " protocol " << desc->protocol; + return nullptr; + } + desc->buffers.push_back(buffer); + } } else { LOG(ERROR) << "Unsupported segment descriptor, name " << segment_name << " protocol " << desc->protocol; @@ -317,9 +340,10 @@ int TransferMetadata::addLocalSegment(SegmentID segment_id, return 0; } -int TransferMetadata::addLocalMemoryBuffer(const BufferDesc &buffer_desc, +int __attribute__((noinline,optimize(0))) TransferMetadata::addLocalMemoryBuffer(const BufferDesc &buffer_desc, bool update_metadata) { { + RWSpinlock::WriteGuard guard(segment_lock_); auto new_segment_desc = std::make_shared(); auto &segment_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID]; diff --git a/mooncake-transfer-engine/src/transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/CMakeLists.txt index 5dec80c..96c7c73 100644 --- a/mooncake-transfer-engine/src/transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/CMakeLists.txt @@ -6,6 +6,9 @@ add_library(transport OBJECT ${XPORT_SOURCES} $) add_subdirectory(tcp_transport) target_sources(transport PUBLIC $) +add_subdirectory(shm_transport) +target_sources(transport PUBLIC $) + if (USE_CUDA) add_subdirectory(nvmeof_transport) target_sources(transport PUBLIC $) diff --git a/mooncake-transfer-engine/src/transport/cxl_transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/cxl_transport/CMakeLists.txt index eef9051..ea28e98 100644 --- a/mooncake-transfer-engine/src/transport/cxl_transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/cxl_transport/CMakeLists.txt @@ -1,4 +1,4 @@ file(GLOB CXL_SOURCES "*.cpp") add_library(cxl_transport OBJECT ${CXL_SOURCES}) -# target_link_libraries(rdma_transport PUBLIC transport) \ No newline at end of file +# target_link_libraries(cxl_transport PUBLIC transport) \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp b/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp index 284f88e..725e8be 100644 --- a/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp +++ b/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp @@ -29,43 +29,216 @@ #include "transfer_metadata.h" #include "transport/transport.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +//create shared mem +//one batch[taskid] -> one shared mem + namespace mooncake { CxlTransport::CxlTransport() { // TODO + std::cout<<"hhh"<{{nullptr, nullptr}}; + //metadata_ = std::make_shared(); // 在构造函数中初始化 metadata_ +} + +CxlTransport::~CxlTransport() { +#ifdef CONFIG_USE_BATCH_DESC_SET + for (auto &entry : batch_desc_set_) delete entry.second; + batch_desc_set_.clear(); +#endif + metadata_->removeSegmentDesc(local_server_name_); + batch_desc_set_.clear(); + + // Clean up batches + SharedMem_map_.clear(); +} + +void CxlTransport::createSharedMem(void* addr, size_t size, const std::string& location) { + // 使用 shm_open 创建或打开一个共享内存对象 + int shm_fd = shm_open(location.c_str(), O_CREAT | O_RDWR, 0666); + if (shm_fd == -1) { + perror("shm_open"); + return; + } + + // 设置共享内存的大小 + if (ftruncate(shm_fd, size) == -1) { + perror("ftruncate"); + close(shm_fd); + return; + } + + // 使用 mmap 将共享内存对象映射到进程的地址空间 + int prot = PROT_READ | PROT_WRITE; + int flags = MAP_SHARED | MAP_FIXED; + void *mapped = mmap(addr, size, prot, flags, shm_fd, 0); + if (mapped == MAP_FAILED) { + perror("mmap"); + close(shm_fd); + return; + } + + // 检查映射地址是否与期望的地址一致 + if (mapped != addr) { + fprintf(stderr, "mmap did not map at the desired address.\n"); + munmap(mapped, size); + close(shm_fd); + return; + } + + // 关闭文件描述符,mmap 仍然保持映射 + close(shm_fd); + + return; } -CxlTransport::~CxlTransport() {} + +void CxlTransport::delete_shared_memory(void *addr) { + size_t length = SharedMem_map_[addr]; + if (addr == NULL) { + fprintf(stderr, "Invalid arguments.\n"); + return; + } + // 解除映射 + if (munmap(addr, length) == -1) { + perror("munmap"); + // 处理错误 + } +} CxlTransport::BatchID CxlTransport::allocateBatchID(size_t batch_size) { + //auto cxl_desc = new BatchDesc(); auto batch_id = Transport::allocateBatchID(batch_size); + auto &batch_desc = *((BatchDesc *)(batch_id)); + batch_desc.batch_size = batch_size; + batch_desc.task_list.resize(batch_size); + batch_desc.id = batch_id; + //batch_desc.context = cxl_desc; + return batch_id; } int CxlTransport::getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + const size_t task_count = batch_desc.task_list.size(); + if (task_id >= task_count) return ERR_INVALID_ARGUMENT; + auto &task = batch_desc.task_list[task_id]; + status.transferred_bytes = task.transferred_bytes; + uint64_t success_slice_count = task.success_slice_count; + uint64_t failed_slice_count = task.failed_slice_count; + if (success_slice_count + failed_slice_count == + (uint64_t)task.slices.size()) { + if (failed_slice_count) + status.s = TransferStatusEnum::FAILED; + else + status.s = TransferStatusEnum::COMPLETED; + task.is_finished = true; + } else { + status.s = TransferStatusEnum::WAITING; + } return 0; } -int CxlTransport::submitTransfer(BatchID batch_id, - const std::vector &entries) { +int CxlTransport::submitTransfer(BatchID batch_id, const std::vector& entries) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + + // Check if adding new entries would exceed the batch size + if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) + return -1; + + size_t task_id = batch_desc.task_list.size(); + batch_desc.task_list.resize(task_id + entries.size()); + for (const auto& request : entries) { + TransferTask &task = batch_desc.task_list[task_id]; + auto target_id = request.target_id; + ++task_id; + task.total_bytes = request.length; + auto slice = new Slice(); + slice->source_addr = (char *)request.source; + slice->length = request.length; + slice->opcode = request.opcode; + slice->cxl.remote_addr = (void*)request.target_offset; + slice->cxl.remote_offset = request.length; + slice->status = Slice::PENDING; + startSlice(slice); + } + return 0; } -int CxlTransport::freeBatchID(BatchID batch_id) { return 0; } +int CxlTransport::freeBatchID(BatchID batch_id) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + //auto &nvmeof_desc = *((NVMeoFBatchDesc *)(batch_desc.context)); + //int desc_idx = nvmeof_desc.desc_idx_; + int rc = Transport::freeBatchID(batch_id); + if (rc < 0) { + return -1; + } + return 0; } -int CxlTransport::install(std::string &local_server_name, - std::shared_ptr meta, void **args) { - return 0; +int CxlTransport::install(std::string& local_server_name, std::shared_ptr meta, void** args) { + // Initialize control shared memory + return Transport::install(local_server_name, meta, args); + //return 0; } -int CxlTransport::registerLocalMemory(void *addr, size_t length, - const string &location, - bool remote_accessible, - bool update_metadata) { +// Assuming metadata_ is a shared pointer to TransferMetadata +// and is initialized properly in the constructor. + +int CxlTransport::registerLocalMemory(void* addr, size_t length, const std::string& location, bool remote_accessible, bool update_metadata) { + // Generate a unique name for the shared memory segment + SharedMem_map_[addr] = length; + //create linux shared memory + createSharedMem(addr, length, location); + // Create a BufferDesc and add it to the metadata + BufferDesc buffer_desc; + buffer_desc.addr = (uint64_t)addr; + buffer_desc.length = length; + buffer_desc.name = location; + int ret = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); + if (ret != 0) { + // Remove the shared memory segment if registration fails + //deleteSharedMem(addr); + + return ret; + } return 0; } -int CxlTransport::unregisterLocalMemory(void *addr, bool update_metadata) { +int CxlTransport::unregisterLocalMemory(void* addr, bool update_metadata) { + // Generate the unique name for the shared memory segment + // Remove the buffer's entry from the metadata + delete_shared_memory(addr); + int ret = metadata_->removeLocalMemoryBuffer(addr, update_metadata); + return 0; } + +void CxlTransport::startSlice(Slice *slice) { + slice->task->is_finished = false; + //slice->task->status = TransferTask::RUNNING; + //slice->task->submit_time = std::chrono::high_resolution_clock::now(); + slice->task->total_bytes += slice->length; + if(slice->opcode == TransferRequest::WRITE) { + memcpy((void*)slice->cxl.remote_addr, slice->source_addr, slice->length); + } + else { + memcpy(slice->source_addr, (void*)slice->cxl.remote_addr, slice->length); + } + // if (slice->task->status == TransferStatusEnum::COMPLETED) + // slice->markSuccess(); + // else + // slice->markFailed(); +} + } // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/shm_transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/shm_transport/CMakeLists.txt new file mode 100644 index 0000000..563d203 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/shm_transport/CMakeLists.txt @@ -0,0 +1,4 @@ +file(GLOB SHM_SOURCES "*.cpp") + +add_library(shm_transport OBJECT ${SHM_SOURCES}) +# target_link_libraries(shm_transport PUBLIC transport) \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/shm_transport/shm_transport.cpp b/mooncake-transfer-engine/src/transport/shm_transport/shm_transport.cpp new file mode 100644 index 0000000..e9abfeb --- /dev/null +++ b/mooncake-transfer-engine/src/transport/shm_transport/shm_transport.cpp @@ -0,0 +1,300 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transport/shm_transport/shm_transport.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "transfer_engine.h" +#include "transfer_metadata.h" +#include "transport/transport.h" + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +//create shared mem +//one batch[taskid] -> one shared mem + +namespace mooncake { +ShmTransport::ShmTransport() { + // TODO + //SharedMem_map_ = std::unordered_map{{nullptr, nullptr}}; + //metadata_ = std::make_shared(); // 在构造函数中初始化 metadata_ +} + +ShmTransport::~ShmTransport() { +#ifdef CONFIG_USE_BATCH_DESC_SET + for (auto &entry : batch_desc_set_) delete entry.second; + batch_desc_set_.clear(); +#endif + metadata_->removeSegmentDesc(local_server_name_); + batch_desc_set_.clear(); + + // Clean up batches + SharedMem_map_.clear(); +} + +void ShmTransport::createSharedMem(void* addr, size_t size, const std::string& location) { + // 使用 shm_open 创建或打开一个共享内存对象 + int shm_fd = shm_open(location.c_str(), O_CREAT | O_RDWR, 0666); + if (shm_fd == -1) { + perror("shm_open"); + return; + } + + // 设置共享内存的大小 + if (ftruncate(shm_fd, size) == -1) { + perror("ftruncate"); + close(shm_fd); + return; + } + + // 使用 mmap 将共享内存对象映射到进程的地址空间 + int prot = PROT_READ | PROT_WRITE; + int flags = MAP_SHARED | MAP_FIXED; + void *mapped = mmap(addr, size, prot, flags, shm_fd, 0); + if (mapped == MAP_FAILED) { + perror("mmap"); + close(shm_fd); + return; + } + + // 检查映射地址是否与期望的地址一致 + if (mapped != addr) { + fprintf(stderr, "mmap did not map at the desired address.\n"); + munmap(mapped, size); + close(shm_fd); + return; + } + + // 关闭文件描述符,mmap 仍然保持映射 + close(shm_fd); + + return; +} + + + +void ShmTransport::delete_shared_memory(void *addr) { + size_t length = SharedMem_map_[addr]; + if (addr == NULL) { + fprintf(stderr, "Invalid arguments.\n"); + return; + } + // 解除映射 + if (munmap(addr, length) == -1) { + perror("munmap"); + // 处理错误 + } +} +ShmTransport::BatchID ShmTransport::allocateBatchID(size_t batch_size) { + //auto shm_desc = new BatchDesc(); + auto batch_id = Transport::allocateBatchID(batch_size); + auto &batch_desc = *((BatchDesc *)(batch_id)); + batch_desc.batch_size = batch_size; + batch_desc.task_list.resize(batch_size); + batch_desc.id = batch_id; + //batch_desc.context = shm_desc; + + return batch_id; +} + +int ShmTransport::getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + const size_t task_count = batch_desc.task_list.size(); + if (task_id >= task_count) return ERR_INVALID_ARGUMENT; + auto &task = batch_desc.task_list[task_id]; + status.transferred_bytes = task.transferred_bytes; + uint64_t success_slice_count = task.success_slice_count; + uint64_t failed_slice_count = task.failed_slice_count; + if (success_slice_count + failed_slice_count == + (uint64_t)task.slices.size()) { + if (failed_slice_count) + status.s = TransferStatusEnum::FAILED; + else + status.s = TransferStatusEnum::COMPLETED; + task.is_finished = true; + } else { + status.s = TransferStatusEnum::WAITING; + } + return 0; +} + +int ShmTransport::submitTransfer(BatchID batch_id, const std::vector& entries) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + + // Check if adding new entries would exceed the batch size + if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size){ + LOG(ERROR) << "TcpTransport: Exceed the limitation of current batch's " + "capacity"; + return ERR_TOO_MANY_REQUESTS; + } + + size_t task_id = batch_desc.task_list.size(); + batch_desc.task_list.resize(task_id + entries.size()); + for (const auto& request : entries) { + TransferTask &task = batch_desc.task_list[task_id]; + auto target_id = request.target_id; + ++task_id; + task.total_bytes = request.length; + auto slice = new Slice(); + slice->source_addr = (char *)request.source; + slice->length = request.length; + slice->opcode = request.opcode; + slice->local.dest_addr = (void*)request.target_offset; + slice->status = Slice::PENDING; + task.slices.push_back(slice); + startSlice(slice); + } + + return 0; +} + +int ShmTransport::submitTransferTask( + const std::vector &request_list, + const std::vector &task_list) { + for (size_t index = 0; index < request_list.size(); ++index) { + auto &request = *request_list[index]; + auto &task = *task_list[index]; + task.total_bytes = request.length; + auto slice = new Slice(); + slice->source_addr = (char *)request.source; + slice->length = request.length; + slice->opcode = request.opcode; + slice->local.dest_addr = (void*)request.target_offset; + slice->task = &task; + slice->target_id = request.target_id; + slice->status = Slice::PENDING; + task.slices.push_back(slice); + startSlice(slice); + } + return 0; +} + +int ShmTransport::freeBatchID(BatchID batch_id) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + //auto &nvmeof_desc = *((NVMeoFBatchDesc *)(batch_desc.context)); + //int desc_idx = nvmeof_desc.desc_idx_; + int rc = Transport::freeBatchID(batch_id); + if (rc < 0) { + return -1; + } + return 0; } + +int ShmTransport::allocateLocalSegmentID() { + auto desc = std::make_shared(); + if (!desc) return ERR_MEMORY; + desc->name = local_server_name_; + desc->protocol = "shm"; + metadata_->addLocalSegment(LOCAL_SEGMENT_ID, local_server_name_, + std::move(desc)); + return 0; +} + +int ShmTransport::install(std::string& local_server_name, std::shared_ptr meta, void** args) { + // Initialize control shared memory + + metadata_ = meta; + local_server_name_ = local_server_name; + + int ret = allocateLocalSegmentID(); + if (ret) { + LOG(ERROR) << "ShmTransport: cannot allocate local segment"; + return -1; + } + + ret = metadata_->updateLocalSegmentDesc(); + if (ret) { + LOG(ERROR) << "ShmTransport: cannot publish segments, " + "check the availability of metadata storage"; + return -1; + } + return 0; + //return 0; +} + +// Assuming metadata_ is a shared pointer to TransferMetadata +// and is initialized properly in the constructor. + +int ShmTransport::registerLocalMemory(void* addr, size_t length, const std::string& location, bool remote_accessible, bool update_metadata) { + // Generate a unique name for the shared memory segment + //SharedMem_map_[addr] = length; + //create linux shared memory + printf("addr %p",addr); + printf("length %d",length); + printf("location %s",location.c_str()); + printf("remote_accessible %d",remote_accessible); + printf("update_metadata %d",update_metadata); + createSharedMem(addr, length, local_server_name_); + // Create a BufferDesc and add it to the metadata + asm volatile ("nop":::"memory"); + BufferDesc buffer_desc; + asm volatile ("nop %0":: "r"(&buffer_desc) :"memory"); + buffer_desc.addr = (uint64_t)addr; + buffer_desc.length = length; + buffer_desc.name = local_server_name_; + printf("%d",update_metadata); + int ret = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); + if (ret != 0) { + // Remove the shared memory segment if registration fails + //deleteSharedMem(addr); + + return ret; + } + return 0; +} + +int ShmTransport::unregisterLocalMemory(void* addr, bool update_metadata) { + // Generate the unique name for the shared memory segment + // Remove the buffer's entry from the metadata + delete_shared_memory(addr); + int ret = metadata_->removeLocalMemoryBuffer(addr, update_metadata); + + return 0; +} + +void ShmTransport::startSlice(Slice *slice) { + slice->task->is_finished = false; + //slice->task->status = TransferTask::RUNNING; + //slice->task->submit_time = std::chrono::high_resolution_clock::now(); + slice->task->total_bytes += slice->length; + if(slice->opcode == TransferRequest::WRITE) { + memcpy((void*)slice->local.dest_addr, slice->source_addr, slice->length); + } + else { + memcpy(slice->source_addr, (void*)slice->local.dest_addr, slice->length); + } + //if (slice->task->status == TransferStatusEnum::COMPLETED) + slice->markSuccess(); + +} + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/tcp_transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/tcp_transport/CMakeLists.txt index 1fd83ea..2c58cd5 100644 --- a/mooncake-transfer-engine/src/transport/tcp_transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/tcp_transport/CMakeLists.txt @@ -1,3 +1,3 @@ -file(GLOB CXL_SOURCES "*.cpp") +file(GLOB TCP_SOURCES "*.cpp") -add_library(tcp_transport OBJECT ${CXL_SOURCES}) +add_library(tcp_transport OBJECT ${TCP_SOURCES}) diff --git a/mooncake-transfer-engine/tests/CMakeLists.txt b/mooncake-transfer-engine/tests/CMakeLists.txt index 3611ba5..2e54f8a 100644 --- a/mooncake-transfer-engine/tests/CMakeLists.txt +++ b/mooncake-transfer-engine/tests/CMakeLists.txt @@ -21,3 +21,7 @@ add_test(NAME transfer_metadata_test COMMAND transfer_metadata_test) add_executable(topology_test topology_test.cpp) target_link_libraries(topology_test PUBLIC transfer_engine gtest gtest_main) add_test(NAME topology_test COMMAND topology_test) + +add_executable(shm_transport_test shm_transport_test.cpp) +target_link_libraries(shm_transport_test PUBLIC transfer_engine gtest gtest_main) +add_test(NAME shm_transport_test COMMAND shm_transport_test) diff --git a/mooncake-transfer-engine/tests/cxl_transport_test.cpp b/mooncake-transfer-engine/tests/cxl_transport_test.cpp new file mode 100644 index 0000000..6aab4dc --- /dev/null +++ b/mooncake-transfer-engine/tests/cxl_transport_test.cpp @@ -0,0 +1,341 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef USE_CUDA +#include +#include +#include + +#include + +static void checkCudaError(cudaError_t result, const char *message) { + if (result != cudaSuccess) { + LOG(ERROR) << message << " (Error code: " << result << " - " + << cudaGetErrorString(result) << ")" << std::endl; + exit(EXIT_FAILURE); + } +} + +#endif + +#include "transfer_engine.h" +#include "transport/transport.h" + +#ifdef USE_CUDA +DEFINE_int32(gpu_id, 0, "GPU ID to use"); +#endif + +using namespace mooncake; + +namespace mooncake { + +class CXLTransportTest : public ::testing::Test { + public: + protected: + void SetUp() override { + google::InitGoogleLogging("CXLTransportTest"); + FLAGS_logtostderr = 1; + + const char *env = std::getenv("MC_METADATA_SERVER"); + if (env) + metadata_server = env; + else + metadata_server = metadata_server; + LOG(INFO) << "metadata_server: " << metadata_server; + + env = std::getenv("MC_LOCAL_SERVER_NAME"); + if (env) + local_server_name = env; + else + local_server_name = "127.0.0.2:12345"; + LOG(INFO) << "local_server_name: " << local_server_name; + } + + void TearDown() override { + // 清理 glog + google::ShutdownGoogleLogging(); + } + + std::string metadata_server; + std::string local_server_name; +}; + +static void *allocateMemoryPool(size_t size, int socket_id, + bool from_vram = false) { + return numa_alloc_onnode(size, socket_id); +} + +TEST_F(CXLTransportTest, GetTcpTest) { + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + auto rc = engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + LOG_ASSERT(rc == 0); + Transport *xport = nullptr; + xport = engine->installTransport("cxl", nullptr); + LOG_ASSERT(xport != nullptr); +} + +TEST_F(CXLTransportTest, Writetest) { + const size_t kDataLength = 4096000; + void *addr = nullptr; + const size_t ram_buffer_size = 1ull << 30; + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + auto rc = engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + LOG_ASSERT(rc == 0); + Transport *xport = nullptr; + xport = engine->installTransport("cxl", nullptr); + LOG_ASSERT(xport != nullptr); + + addr = allocateMemoryPool(ram_buffer_size, 0, false); + rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0"); + LOG_ASSERT(!rc); + + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + auto segment_id = engine->openSegment(local_server_name); + TransferRequest entry; + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); + uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); +} + +TEST_F(CXLTransportTest, WriteAndReadtest) { + const size_t kDataLength = 4096000; + void *addr = nullptr; + const size_t ram_buffer_size = 1ull << 30; + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + Transport *xport = nullptr; + xport = engine->installTransport("c x l", nullptr); + LOG_ASSERT(xport != nullptr); + + addr = allocateMemoryPool(ram_buffer_size, 0, false); + int rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0"); + LOG_ASSERT(!rc); + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + + auto segment_id = engine->openSegment(local_server_name); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); + uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); + } + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + + TransferRequest entry; + entry.opcode = TransferRequest::READ; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr) + kDataLength; + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(!ret); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + } + ret = engine->freeBatchID(batch_id); + LOG_ASSERT(!ret); + } + LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, + kDataLength)); +} + +TEST_F(CXLTransportTest, WriteAndRead2test) { + const size_t kDataLength = 4096000; + void *addr = nullptr; + const size_t ram_buffer_size = 1ull << 30; + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + Transport *xport = nullptr; + xport = engine->installTransport("cxl", nullptr); + LOG_ASSERT(xport != nullptr); + + addr = allocateMemoryPool(ram_buffer_size, 0, false); + int rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0"); + LOG_ASSERT(!rc); + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + + auto segment_id = engine->openSegment(local_server_name); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); + uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); + } + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::READ; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr) + kDataLength; + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(!ret); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + } + ret = engine->freeBatchID(batch_id); + LOG_ASSERT(!ret); + } + LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, + kDataLength)); + + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); + } + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::READ; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr) + kDataLength; + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(!ret); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + } + ret = engine->freeBatchID(batch_id); + LOG_ASSERT(!ret); + } + LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, + kDataLength)); +} + +} // namespace mooncake + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/mooncake-transfer-engine/tests/shm_transport_test.cpp b/mooncake-transfer-engine/tests/shm_transport_test.cpp new file mode 100644 index 0000000..f348555 --- /dev/null +++ b/mooncake-transfer-engine/tests/shm_transport_test.cpp @@ -0,0 +1,351 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef USE_CUDA +#include +#include +#include + +#include + +static void checkCudaError(cudaError_t result, const char *message) { + if (result != cudaSuccess) { + LOG(ERROR) << message << " (Error code: " << result << " - " + << cudaGetErrorString(result) << ")" << std::endl; + exit(EXIT_FAILURE); + } +} + +#endif + +#include "transfer_engine.h" +#include "transport/transport.h" + +#ifdef USE_CUDA +DEFINE_int32(gpu_id, 0, "GPU ID to use"); +#endif + +using namespace mooncake; + +namespace mooncake { + +class SHMTransportTest : public ::testing::Test { + public: + protected: + void SetUp() override { + google::InitGoogleLogging("SHMTransportTest"); + FLAGS_logtostderr = 1; + + const char *env = std::getenv("MC_METADATA_SERVER"); + if (env) + metadata_server = env; + else + metadata_server = metadata_server; + + metadata_server = "127.0.0.1:2379"; + LOG(INFO) << "metadata_server: " << metadata_server; + + env = std::getenv("MC_LOCAL_SERVER_NAME"); + if (env) + local_server_name = env; + else + local_server_name = "127.0.0.1:2488"; + + + LOG(INFO) << "local_server_name: " << local_server_name; + } + + void TearDown() override { + // 清理 glog + google::ShutdownGoogleLogging(); + } + + std::string metadata_server; + std::string local_server_name; +}; + +static void *allocateMemoryPool(size_t size, int socket_id, + bool from_vram = false) { + return numa_alloc_onnode(size, socket_id); +} + + + + + +TEST_F(SHMTransportTest, GetTcpTest) { + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + auto rc = engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + LOG_ASSERT(rc == 0); + Transport *xport = nullptr; + xport = engine->installTransport("shm", nullptr); + LOG_ASSERT(xport != nullptr); +} + +TEST_F(SHMTransportTest, Writetest) { + const size_t kDataLength = 4096000; + void *addr = nullptr; + const size_t ram_buffer_size = 1ull << 30; + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + auto rc = engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + LOG_ASSERT(rc == 0); + Transport *xport = nullptr; + + xport = engine->installTransport("shm", nullptr); + LOG_ASSERT(xport != nullptr); + + + addr = allocateMemoryPool(ram_buffer_size, 0, false); + rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0"); + LOG_ASSERT(!rc); + + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + auto segment_id = engine->openSegment(local_server_name); + TransferRequest entry; + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); + uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); +} + +TEST_F(SHMTransportTest, WriteAndReadtest) { + const size_t kDataLength = 4096000; + void *addr = nullptr; + const size_t ram_buffer_size = 1ull << 30; + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + Transport *xport = nullptr; + xport = engine->installTransport("shm", nullptr); + LOG_ASSERT(xport != nullptr); + + addr = allocateMemoryPool(ram_buffer_size, 0, false); + int rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0"); + LOG_ASSERT(!rc); + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + + auto segment_id = engine->openSegment(local_server_name); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); + uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); + } + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + + TransferRequest entry; + entry.opcode = TransferRequest::READ; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr) + kDataLength; + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(!ret); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + } + ret = engine->freeBatchID(batch_id); + LOG_ASSERT(!ret); + } + LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, + kDataLength)); +} + +TEST_F(SHMTransportTest, WriteAndRead2test) { + const size_t kDataLength = 4096000; + void *addr = nullptr; + const size_t ram_buffer_size = 1ull << 30; + auto engine = std::make_unique(); + auto hostname_port = parseHostNameWithPort(local_server_name); + engine->init(metadata_server, local_server_name, + hostname_port.first.c_str(), hostname_port.second); + Transport *xport = nullptr; + xport = engine->installTransport("shm", nullptr); + LOG_ASSERT(xport != nullptr); + + addr = allocateMemoryPool(ram_buffer_size, 0, false); + int rc = engine->registerLocalMemory(addr, ram_buffer_size, "cpu:0"); + LOG_ASSERT(!rc); + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + + auto segment_id = engine->openSegment(local_server_name); + auto segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); + uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); + } + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::READ; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr) + kDataLength; + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(!ret); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + } + ret = engine->freeBatchID(batch_id); + LOG_ASSERT(!ret); + } + LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, + kDataLength)); + + for (size_t offset = 0; offset < kDataLength; ++offset) + *((char *)(addr) + offset) = 'a' + lrand48() % 26; + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::WRITE; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr); + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + ASSERT_EQ(ret, 0); + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + } + ret = engine->freeBatchID(batch_id); + ASSERT_EQ(ret, 0); + } + + { + auto batch_id = engine->allocateBatchID(1); + int ret = 0; + TransferRequest entry; + entry.opcode = TransferRequest::READ; + entry.length = kDataLength; + entry.source = (uint8_t *)(addr) + kDataLength; + entry.target_id = segment_id; + entry.target_offset = remote_base; + ret = engine->submitTransfer(batch_id, {entry}); + LOG_ASSERT(!ret); + bool completed = false; + TransferStatus status; + while (!completed) { + int ret = engine->getTransferStatus(batch_id, 0, status); + LOG_ASSERT(!ret); + if (status.s == TransferStatusEnum::COMPLETED) completed = true; + LOG_ASSERT(status.s != TransferStatusEnum::FAILED); + } + ret = engine->freeBatchID(batch_id); + LOG_ASSERT(!ret); + } + LOG_ASSERT(0 == memcmp((uint8_t *)(addr), (uint8_t *)(addr) + kDataLength, + kDataLength)); +} + +} // namespace mooncake + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/mooncake-transfer-engine/tests/tcp_transport_test.cpp b/mooncake-transfer-engine/tests/tcp_transport_test.cpp index 862d88c..58a9b06 100644 --- a/mooncake-transfer-engine/tests/tcp_transport_test.cpp +++ b/mooncake-transfer-engine/tests/tcp_transport_test.cpp @@ -62,13 +62,15 @@ class TCPTransportTest : public ::testing::Test { metadata_server = env; else metadata_server = metadata_server; + + metadata_server = "127.0.0.1:2379"; LOG(INFO) << "metadata_server: " << metadata_server; env = std::getenv("MC_LOCAL_SERVER_NAME"); if (env) local_server_name = env; else - local_server_name = "127.0.0.2:12345"; + local_server_name = "127.0.0.1:2222"; LOG(INFO) << "local_server_name: " << local_server_name; } From 9ba688446812dd39e95681e7b5a797bdb196e774 Mon Sep 17 00:00:00 2001 From: liuziqian <2949547669@qq.com> Date: Mon, 13 Jan 2025 19:14:04 +0800 Subject: [PATCH 2/2] fix the shm_transport_test.cpp hard code bug --- mooncake-transfer-engine/tests/shm_transport_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mooncake-transfer-engine/tests/shm_transport_test.cpp b/mooncake-transfer-engine/tests/shm_transport_test.cpp index f348555..0e98f02 100644 --- a/mooncake-transfer-engine/tests/shm_transport_test.cpp +++ b/mooncake-transfer-engine/tests/shm_transport_test.cpp @@ -63,7 +63,6 @@ class SHMTransportTest : public ::testing::Test { else metadata_server = metadata_server; - metadata_server = "127.0.0.1:2379"; LOG(INFO) << "metadata_server: " << metadata_server; env = std::getenv("MC_LOCAL_SERVER_NAME");