Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TransferEngine] Implement Linux Shared Memory Transport #issue48 #77

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ cmake ..
make -j$(nproc) && sudo make install

echo "*** Download and installing [golang-1.22] ***"
wget https://go.dev/dl/go1.22.linux-amd64.tar.gz
sudo tar -C /usr/local -xzf go1.22.linux-amd64.tar.gz
wget https://go.dev/dl/go1.22.10.linux-amd64.tar.gz
sudo tar -C /usr/local -xzf go1.22.10.linux-amd64.tar.gz

echo "*** Dependencies Installed! ***"
6 changes: 3 additions & 3 deletions doc/zh/transfer-engine.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Transfer Engine 使用SIEVE算法来管理端点的逐出。如果由于链路
例如,可使用如下命令行启动 `etcd` 服务:
```bash
# This is 10.0.0.1
etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://10.0.0.1:2379
etcd --listen-client-urls http://127.0.0.1:2379 --advertise-client-urls http://127.0.0.1:2379
```

1.2. **启动 `http` 作为 `metadata` 服务**
Expand All @@ -95,8 +95,8 @@ Transfer Engine 使用SIEVE算法来管理端点的逐出。如果由于链路
```bash
# This is 10.0.0.2
./transfer_engine_bench --mode=target \
--metadata_server=etcd://10.0.0.1:2379 \
--local_server_name=10.0.0.2:12345 \
--metadata_server=etcd://127.0.0.1:2379 \
--local_server_name=127.0.0.1:12345 \
--device_name=erdma_0
```
各个参数的含义如下:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/kvcache-ai/Mooncake/mooncake-transfer-engine/example/http-metadata-server

go 1.22.9
go 1.23

require github.com/gin-gonic/gin v1.10.0

Expand Down
5 changes: 5 additions & 0 deletions mooncake-transfer-engine/include/transfer_engine_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ struct segment_desc {
uint64_t port;
// maybe more needed for mount
} nvmeof;
struct {
void *addr;
uint16_t size;
const char *location;
}shm;
} desc_;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class CxlTransport : public Transport {

~CxlTransport();

void createSharedMem(void* addr, size_t size, const std::string& location);

void delete_shared_memory(void *addr);

BatchID allocateBatchID(size_t batch_size) override;

int submitTransfer(BatchID batch_id,
Expand All @@ -55,6 +59,7 @@ class CxlTransport : public Transport {
int freeBatchID(BatchID batch_id) override;

private:

int install(std::string &local_server_name,
std::shared_ptr<TransferMetadata> meta, void **args) override;

Expand All @@ -77,6 +82,10 @@ class CxlTransport : public Transport {
}

const char *getName() const override { return "cxl"; }

std::unordered_map<void*,size_t> SharedMem_map_;

void startSlice(Slice *slice);
};
} // namespace mooncake

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2024 KVCache.AI
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SHM_TRANSPORT_H_
#define SHM_TRANSPORT_H_

#include <infiniband/verbs.h>

#include <atomic>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "transfer_metadata.h"
#include "transport/transport.h"

namespace mooncake {
class TransferMetadata;

class ShmTransport : public Transport {
public:
using BufferDesc = TransferMetadata::BufferDesc;
using SegmentDesc = TransferMetadata::SegmentDesc;
using HandShakeDesc = TransferMetadata::HandShakeDesc;

public:
ShmTransport();

~ShmTransport();

void createSharedMem(void* addr, size_t size, const std::string& location);

void delete_shared_memory(void *addr);

BatchID allocateBatchID(size_t batch_size) override;

int submitTransfer(BatchID batch_id,
const std::vector<TransferRequest> &entries) override;

int getTransferStatus(BatchID batch_id, size_t task_id,
TransferStatus &status) override;

int freeBatchID(BatchID batch_id) override;

int submitTransferTask(
const std::vector<TransferRequest *> &request_list,
const std::vector<TransferTask *> &task_list);

private:

int install(std::string &local_server_name,
std::shared_ptr<TransferMetadata> meta, void **args) override;

int registerLocalMemory(void *addr, size_t length, const std::string &location,
bool remote_accessible,
bool update_metadata) override;

int unregisterLocalMemory(void *addr,
bool update_metadata = false) override;

int registerLocalMemoryBatch(
const std::vector<Transport::BufferEntry> &buffer_list,
const std::string &location) override {
return 0;
}

int unregisterLocalMemoryBatch(
const std::vector<void *> &addr_list) override {
return 0;
}

const char *getName() const override { return "shm"; }

std::unordered_map<void*,size_t> SharedMem_map_;

void startSlice(Slice *slice);

int allocateLocalSegmentID();
};
} // namespace mooncake

#endif
3 changes: 3 additions & 0 deletions mooncake-transfer-engine/src/multi_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "transport/rdma_transport/rdma_transport.h"
#include "transport/tcp_transport/tcp_transport.h"
#include "transport/shm_transport/shm_transport.h"
#include "transport/transport.h"
#ifdef USE_CUDA
#include "transport/nvmeof_transport/nvmeof_transport.h"
Expand Down Expand Up @@ -130,6 +131,8 @@ Transport *MultiTransport::installTransport(const std::string &proto,
transport = new RdmaTransport();
} else if (std::string(proto) == "tcp") {
transport = new TcpTransport();
} else if (std::string(proto) == "shm") {
transport = new ShmTransport();
}
#ifdef USE_CUDA
else if (std::string(proto) == "nvmeof") {
Expand Down
1 change: 1 addition & 0 deletions mooncake-transfer-engine/src/transfer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ int TransferEngine::registerLocalMemory(void *addr, size_t length,
return ERR_ADDRESS_OVERLAPPED;
}
for (auto transport : multi_transports_->listTransports()) {
printf("transport = %p\n", transport);
int ret = transport->registerLocalMemory(
addr, length, location, remote_accessible, update_metadata);
if (ret < 0) return ret;
Expand Down
26 changes: 25 additions & 1 deletion mooncake-transfer-engine/src/transfer_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name,
buffersJSON.append(bufferJSON);
}
segmentJSON["buffers"] = buffersJSON;
} else if (segmentJSON["protocol"] == "shm") {
Json::Value buffersJSON(Json::arrayValue);
for (const auto &buffer : desc.buffers) {
Json::Value bufferJSON;
bufferJSON["name"] = buffer.name;
bufferJSON["addr"] = static_cast<Json::UInt64>(buffer.addr);
bufferJSON["length"] = static_cast<Json::UInt64>(buffer.length);
buffersJSON.append(bufferJSON);
}
segmentJSON["buffers"] = buffersJSON;
} else {
LOG(ERROR) << "Unsupported segment descriptor for register, name "
<< desc.name << " protocol " << desc.protocol;
Expand Down Expand Up @@ -222,6 +232,19 @@ std::shared_ptr<TransferMetadata::SegmentDesc> TransferMetadata::getSegmentDesc(
}
desc->nvmeof_buffers.push_back(buffer);
}
} else if (desc->protocol == "shm") {
for (const auto &bufferJSON : segmentJSON["buffers"]) {
BufferDesc buffer;
buffer.name = bufferJSON["name"].asString();
buffer.addr = bufferJSON["addr"].asUInt64();
buffer.length = bufferJSON["length"].asUInt64();
if (buffer.name.empty() || !buffer.addr || !buffer.length) {
LOG(WARNING) << "Corrupted segment descriptor, name "
<< segment_name << " protocol " << desc->protocol;
return nullptr;
}
desc->buffers.push_back(buffer);
}
} else {
LOG(ERROR) << "Unsupported segment descriptor, name " << segment_name
<< " protocol " << desc->protocol;
Expand Down Expand Up @@ -317,9 +340,10 @@ int TransferMetadata::addLocalSegment(SegmentID segment_id,
return 0;
}

int TransferMetadata::addLocalMemoryBuffer(const BufferDesc &buffer_desc,
int __attribute__((noinline,optimize(0))) TransferMetadata::addLocalMemoryBuffer(const BufferDesc &buffer_desc,
bool update_metadata) {
{

RWSpinlock::WriteGuard guard(segment_lock_);
auto new_segment_desc = std::make_shared<SegmentDesc>();
auto &segment_desc = segment_id_to_desc_map_[LOCAL_SEGMENT_ID];
Expand Down
3 changes: 3 additions & 0 deletions mooncake-transfer-engine/src/transport/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ add_library(transport OBJECT ${XPORT_SOURCES} $<TARGET_OBJECTS:rdma_transport>)
add_subdirectory(tcp_transport)
target_sources(transport PUBLIC $<TARGET_OBJECTS:tcp_transport>)

add_subdirectory(shm_transport)
target_sources(transport PUBLIC $<TARGET_OBJECTS:shm_transport>)

if (USE_CUDA)
add_subdirectory(nvmeof_transport)
target_sources(transport PUBLIC $<TARGET_OBJECTS:nvmeof_transport>)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
file(GLOB CXL_SOURCES "*.cpp")

add_library(cxl_transport OBJECT ${CXL_SOURCES})
# target_link_libraries(rdma_transport PUBLIC transport)
# target_link_libraries(cxl_transport PUBLIC transport)
Loading
Loading