diff --git a/build.sh b/build.sh index 248fea9a..cddaf710 100755 --- a/build.sh +++ b/build.sh @@ -328,6 +328,7 @@ docker run --rm --user "$(id -u):$(id -g)" \ -e USE_TCPX="${USE_TCPX:-0}" \ -e USE_EFA="${USE_EFA:-0}" \ -e USE_IB="${USE_IB:-0}" \ + -e USE_TCP="${USE_TCP:-0}" \ -e MAKE_NORMAL_MODE="${MAKE_NORMAL_MODE:-}" \ -e TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-}" \ -e FUNCTION_DEF="$(declare -f build_rccl_nccl_h build_ccl_rdma build_ccl_efa build_p2p build_ep build_eccl)" \ diff --git a/p2p/Makefile b/p2p/Makefile index 7c200469..d75f7768 100644 --- a/p2p/Makefile +++ b/p2p/Makefile @@ -7,6 +7,9 @@ USE_EFA ?= $(shell echo $${USE_EFA:-0}) # IB optional integration USE_IB ?= $(shell echo $${USE_IB:-0}) +# TCP endpoint integration +USE_TCP ?= $(shell echo $${USE_TCP:-0}) + # Compiler and flags CUDA_HOME ?= /usr/local/cuda CUDA_INC := $(CUDA_HOME)/include @@ -38,6 +41,11 @@ else LIBS2 += -lglog -lgflags -lgtest -lz -lelf -lpthread -libverbs endif +# Add TCP endpoint flag +ifeq ($(USE_TCP),1) + CXXFLAGS += -DUCCL_P2P_USE_TCP +endif + # Python and pybind11 configuration PYTHON ?= python3 PYTHON_CONFIG = $(PYTHON)-config @@ -73,8 +81,13 @@ CXXFLAGS += -DUCCL_P2P_USE_TCPX -I$(NCCL_INC) LDFLAGS += -L$(NCCL_LIB) -lnccl -Wl,-rpath,$(NCCL_LIB) OBJECTS := $(CORE_OBJECTS) else -SOURCES := engine.cc engine_pybind.cc -CORE_OBJECT := engine.o +SOURCES := engine.cc engine_pybind.cc tcp/tcp_endpoint.cc tcp/tcp_worker_pool.cc +CORE_OBJECT := engine.o tcp/tcp_endpoint.o tcp/tcp_worker_pool.o +endif + +ifeq ($(USE_TCPX),1) +OBJECTS := $(CORE_OBJECTS) +else OBJECTS := $(SOURCES:.cc=.o) endif @@ -121,6 +134,11 @@ endif %.o: %.cc $(CXX) $(CXXFLAGS) $(PYBIND11_INCLUDES) -c $< -o $@ +# Compile TCP source files +tcp/%.o: tcp/%.cc + @mkdir -p tcp + $(CXX) $(CXXFLAGS) $(PYBIND11_INCLUDES) -I. -c $< -o $@ + ifeq ($(USE_TCPX),1) install: $(P2P_SHARED_LIB) install -m 755 $(P2P_SHARED_LIB) $(LIBDIR)/ @@ -138,6 +156,7 @@ endif # Clean build artifacts clean: rm -f $(OBJECTS) $(CAPI_OBJECT) $(P2P_SHARED_LIB) + rm -f tcp/*.o ifneq ($(USE_TCPX),1) rm -f $(P2P_PYTHON_EXT) $(RDMA_PLUGIN_LIB) $(RDMA_OBJECTS) endif diff --git a/p2p/Makefile.rocm b/p2p/Makefile.rocm index 8633611b..95c6937f 100644 --- a/p2p/Makefile.rocm +++ b/p2p/Makefile.rocm @@ -5,6 +5,9 @@ USE_EFA ?= $(shell echo $${USE_EFA:-0}) # IB optional integration USE_IB ?= $(shell echo $${USE_IB:-0}) +# TCP endpoint integration +USE_TCP ?= $(shell echo $${USE_TCP:-0}) + # Compiler and flags HIP_HOME?=/opt/rocm HIP_INC := $(HIP_HOME)/include @@ -39,6 +42,11 @@ LIBDIR ?= $(PREFIX)/lib INCDIR ?= $(PREFIX)/include CXXFLAGS += -D__HIP_PLATFORM_AMD__ + +# Add TCP endpoint flag +ifeq ($(USE_TCP),1) + CXXFLAGS += -DUCCL_P2P_USE_TCP +endif LDFLAGS = -L$(HIP_LIB) -lamdhip64 \ -Wl,-rpath,$(HIP_LIB) -L${CONDA_LIB_HOME} -lglog -lgflags -lgtest \ -lz -lelf -libverbs -lpthread @@ -47,8 +55,8 @@ LDFLAGS = -L$(HIP_LIB) -lamdhip64 \ P2P_PYTHON_EXT := p2p$(PYEXT) P2P_SHARED_LIB := libuccl_p2p.so RDMA_PLUGIN_LIB := librdma_plugin.a -SOURCES := engine.cc engine_pybind.cc -OBJECTS := $(SOURCES:.cc=.o) +SOURCES := engine.cc engine_pybind.cc tcp/tcp_endpoint.cc tcp/tcp_worker_pool.cc +OBJECTS := engine.o engine_pybind.o tcp/tcp_endpoint.o tcp/tcp_worker_pool.o CAPI_SOURCE := uccl_engine.cc CAPI_HEADER := uccl_engine.h CAPI_OBJECT := $(CAPI_SOURCE:.cc=.o) @@ -84,6 +92,11 @@ $(RDMA_OBJECTS): %.o: $(RDMA_HOME)/%.cc %.o: %.cc $(CXX) $(CXXFLAGS) $(PYBIND11_INCLUDES) -c $< -o $@ +# Compile TCP source files +tcp/%.o: tcp/%.cc + @mkdir -p tcp + $(CXX) $(CXXFLAGS) $(PYBIND11_INCLUDES) -I. 
-c $< -o $@ + # Install the module install: $(P2P_PYTHON_EXT) $(P2P_SHARED_LIB) @mkdir -p $(INSTALL_DIR) @@ -95,6 +108,7 @@ install: $(P2P_PYTHON_EXT) $(P2P_SHARED_LIB) # Clean build artifacts clean: rm -f $(OBJECTS) $(CAPI_OBJECT) $(P2P_SHARED_LIB) $(P2P_PYTHON_EXT) $(RDMA_PLUGIN_LIB) $(RDMA_OBJECTS) + rm -f tcp/*.o make -C $(RDMA_HOME) -f Makefile.rocm clean -j$(nproc) # Test the module diff --git a/p2p/endpoint_wrapper.h b/p2p/endpoint_wrapper.h index 0a6e1801..05fa352e 100644 --- a/p2p/endpoint_wrapper.h +++ b/p2p/endpoint_wrapper.h @@ -1,6 +1,10 @@ #pragma once #include "engine.h" +#ifdef UCCL_P2P_USE_TCP +#include "tcp/tcp_endpoint.h" +#endif + namespace unified { template @@ -19,6 +23,12 @@ inline void delete_ep(RDMAEndPoint const& s) { else if constexpr (std::is_same_v>) { // shared_ptr: do nothing (shared_ptr handles lifetime) } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + // shared_ptr: do nothing (shared_ptr handles lifetime) + } #endif else { static_assert(always_false::value, @@ -53,6 +63,32 @@ inline int set_request(std::shared_ptr const& obj, Conn* conn, } #endif +#ifdef UCCL_P2P_USE_TCP +inline int tcp_set_request_write(std::shared_ptr const& obj, + Conn* conn, unified::P2PMhandle* local_mh, + void* src, size_t size, + FifoItem const& slot_item, + uccl::ucclRequest* ureq) { + ureq->type = uccl::ReqType::ReqWrite; + ureq->n = conn->uccl_conn_id_.flow_id; + return obj->uccl_write_async( + reinterpret_cast(conn->uccl_conn_id_.context), nullptr, + src, size, slot_item, ureq); +} + +inline int tcp_set_request_read(std::shared_ptr const& obj, + Conn* conn, unified::P2PMhandle* local_mh, + void* dst, size_t size, + FifoItem const& slot_item, + uccl::ucclRequest* ureq) { + ureq->type = uccl::ReqType::ReqRead; + ureq->n = conn->uccl_conn_id_.flow_id; + return obj->uccl_read_async( + reinterpret_cast(conn->uccl_conn_id_.context), nullptr, + dst, size, slot_item, ureq); +} +#endif + inline uccl::ConnID uccl_connect(RDMAEndPoint const& s, int dev, int local_gpuidx, int remote_dev, int remote_gpuidx, std::string remote_ip, @@ -60,8 +96,17 @@ inline uccl::ConnID uccl_connect(RDMAEndPoint const& s, int dev, return std::visit( [dev, local_gpuidx, remote_dev, remote_gpuidx, remote_ip, remote_port](auto&& obj) -> uccl::ConnID { - return obj->uccl_connect(dev, local_gpuidx, remote_dev, remote_gpuidx, - remote_ip, remote_port); + using T = std::decay_t; +#ifdef UCCL_P2P_USE_TCP + if constexpr (std::is_same_v>) { + return obj->uccl_connect(dev, local_gpuidx, remote_dev, remote_gpuidx, + remote_ip, remote_port); + } else +#endif + { + return obj->uccl_connect(dev, local_gpuidx, remote_dev, remote_gpuidx, + remote_ip, remote_port); + } }, s); } @@ -115,6 +160,13 @@ inline bool uccl_regmr(RDMAEndPoint const& s, int dev, void* data, size_t len, return false; } } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + // TCP doesn't need memory registration + mhandle->mhandle_ = nullptr; + } #endif else { static_assert(always_false::value, @@ -157,6 +209,16 @@ inline int uccl_send_async(RDMAEndPoint const& s, Conn* conn, ureq->n = conn->uccl_conn_id_.flow_id; return ureq->engine_idx; } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + ureq->type = uccl::ReqType::ReqTx; + ureq->n = conn->uccl_conn_id_.flow_id; + return obj->uccl_send_async( + reinterpret_cast(conn->uccl_conn_id_.context), + nullptr, data, size, ureq); + } #endif else { static_assert(always_false::value, @@ -190,6 +252,16 @@ inline int 
uccl_recv_async(RDMAEndPoint const& s, Conn* conn, ureq->n = conn->uccl_conn_id_.flow_id; return ureq->engine_idx; } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + ureq->type = uccl::ReqType::ReqRx; + ureq->n = conn->uccl_conn_id_.flow_id; + return obj->uccl_recv_async( + reinterpret_cast(conn->uccl_conn_id_.context), + nullptr, data, size, n, ureq); + } #endif else { static_assert(always_false::value, @@ -220,6 +292,13 @@ inline bool uccl_poll_ureq_once(RDMAEndPoint const& s, return obj->checkRecvComplete_once(ureq->n, ureq->engine_idx); } } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + // TCP operations are blocking, so always complete immediately + return obj->uccl_poll_ureq_once(ureq); + } #endif else { static_assert(always_false::value, @@ -248,6 +327,13 @@ inline int uccl_read_async(RDMAEndPoint const& s, Conn* conn, ureq->type = uccl::ReqType::ReqRead; return set_request(obj, conn, local_mh, dst, size, slot_item, ureq); } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + return tcp_set_request_read(obj, conn, local_mh, dst, size, slot_item, + ureq); + } #endif else { static_assert(always_false::value, @@ -276,6 +362,13 @@ inline int uccl_write_async(RDMAEndPoint const& s, Conn* conn, ureq->type = uccl::ReqType::ReqWrite; return set_request(obj, conn, local_mh, src, size, slot_item, ureq); } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + return tcp_set_request_write(obj, conn, local_mh, src, size, + slot_item, ureq); + } #endif else { static_assert(always_false::value, @@ -312,6 +405,19 @@ inline int prepare_fifo_metadata(RDMAEndPoint const& s, Conn* conn, uccl::serialize_fifo_item(remote_mem_info, out_buf); return 0; } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + // For TCP, just store address and size (no rkeys needed) + FifoItem remote_mem_info; + remote_mem_info.addr = reinterpret_cast(data); + remote_mem_info.size = size; + std::memset(remote_mem_info.padding, 0, + sizeof(remote_mem_info.padding)); + uccl::serialize_fifo_item(remote_mem_info, out_buf); + return 0; + } #endif else { static_assert(always_false::value, @@ -334,6 +440,12 @@ inline void uccl_deregmr(RDMAEndPoint const& s, P2PMhandle* mhandle) { else if constexpr (std::is_same_v>) { obj->uccl_deregmr(mhandle->mr_array); } +#endif +#ifdef UCCL_P2P_USE_TCP + else if constexpr (std::is_same_v>) { + // TCP doesn't need memory deregistration - no-op + } #endif else { static_assert(always_false::value, diff --git a/p2p/engine.cc b/p2p/engine.cc index 2f70dca4..0c52b109 100644 --- a/p2p/engine.cc +++ b/p2p/engine.cc @@ -69,6 +69,13 @@ Endpoint::Endpoint(uint32_t const local_gpu_idx, uint32_t const num_cpus) ep_ = std::shared_ptr( new NICEndpoint(local_gpu_idx_, INVALID_RANK_ID, 0, false)); numa_node_ = 0; +#elif defined(UCCL_P2P_USE_TCP) + ep_ = std::make_shared(local_gpu_idx_, 0); + numa_node_ = 0; + // Initialize GPU to device mapping (use 0 for TCP since no RDMA devices) + for (int i = 0; i < kMaxNumGPUs; i++) { + gpu_to_dev[i] = 0; + } #else ep_ = new uccl::RDMAEndpoint(num_cpus_); @@ -130,6 +137,15 @@ Endpoint::Endpoint(uint32_t const num_cpus) : num_cpus_(num_cpus) { #ifdef UCCL_P2P_USE_NATIVE_RDMA ep_ = std::shared_ptr( new NICEndpoint(local_gpu_idx_, INVALID_RANK_ID, 0, false)); +#elif defined(UCCL_P2P_USE_TCP) + // Initialize the TCP endpoint + ep_ = std::make_shared(local_gpu_idx_, 0); + // Initialize GPU to device mapping (use 0 for TCP since no RDMA devices) + int 
ngpus_detected = 0; + GPU_RT_CHECK(gpuGetDeviceCount(&ngpus_detected)); + for (int i = 0; i < kMaxNumGPUs; i++) { + gpu_to_dev[i] = 0; + } #else // Initialize the RDMA endpoint with lazy creation. ep_ = new uccl::RDMAEndpoint(num_cpus_); @@ -198,14 +214,18 @@ void Endpoint::initialize_engine() { GPU_RT_CHECK(gpuStreamCreateWithFlags(&streams_[i], gpuStreamNonBlocking)); } +#if defined(UCCL_P2P_USE_TCP) + numa_node_ = 0; // TCP doesn't have RDMA devices +#else numa_node_ = uccl::RDMAFactory::get_factory_dev(gpu_to_dev[local_gpu_idx_])->numa_node; +#endif // Initialize the engine based on the GPU index. std::cout << "Lazy creation of engine, GPU index: " << local_gpu_idx_ << std::endl; // Initialize engine by fixed engine offset since we did lazy initialization -#ifndef UCCL_P2P_USE_NATIVE_RDMA +#if !defined(UCCL_P2P_USE_NATIVE_RDMA) && !defined(UCCL_P2P_USE_TCP) unified::initialize_engine_by_dev(ep_, gpu_to_dev[local_gpu_idx_], false); std::cout << "Engine initialized for GPU " << local_gpu_idx_ << std::endl; #endif @@ -581,7 +601,6 @@ bool Endpoint::recv(uint64_t conn_id, uint64_t mr_id, void* data, size_t size) { } if (unified::uccl_poll_ureq_once(ep_, &ureq[i % kMaxInflightChunks])) { // Just mark it as completed, DO NOT increment ureq_finished here. - LOG(INFO) << "chunk recv::::" << i; done[i % kMaxInflightChunks] = true; } } diff --git a/p2p/engine.h b/p2p/engine.h index 8f54df96..6cea4a35 100644 --- a/p2p/engine.h +++ b/p2p/engine.h @@ -115,6 +115,10 @@ using MRArray = RKeyArrayT; #include "rdma/rdma_endpoint.h" #endif +#ifdef UCCL_P2P_USE_TCP +#include "tcp/tcp_endpoint.h" +#endif + namespace unified { struct P2PMhandle { @@ -125,6 +129,9 @@ struct P2PMhandle { #ifdef UCCL_P2P_USE_NATIVE_RDMA using RDMAEndPoint = std::variant>; +#elif defined(UCCL_P2P_USE_TCP) +using RDMAEndPoint = + std::variant>; #else using RDMAEndPoint = std::variant; #endif @@ -155,8 +162,8 @@ using FifoItem = uccl::FifoItem; #endif class Endpoint { -#ifdef UCCL_P2P_USE_NATIVE_RDMA - uint64_t const kChunkSize = 1024 * 1024 * 1024; +#if defined(UCCL_P2P_USE_NATIVE_RDMA) || defined(UCCL_P2P_USE_TCP) + uint64_t const kChunkSize = 1024 * 1024 * 1024; // 1GB for EFA #else uint64_t const kChunkSize = 1024 * 1024; #endif diff --git a/p2p/tcp/tcp_endpoint.cc b/p2p/tcp/tcp_endpoint.cc new file mode 100644 index 00000000..1d3c686b --- /dev/null +++ b/p2p/tcp/tcp_endpoint.cc @@ -0,0 +1,793 @@ +#include "tcp/tcp_endpoint.h" + +namespace tcp { + +uint64_t get_interface_bandwidth(std::string const& ifname) { + std::string speed_path = "/sys/class/net/" + ifname + "/speed"; + std::ifstream speed_file(speed_path); + if (speed_file.is_open()) { + int speed_mbps = 0; + speed_file >> speed_mbps; + if (speed_mbps > 0) { + return static_cast(speed_mbps) * 1000000ULL; + } + } + + int sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock < 0) return 0; + + struct ifreq ifr; + struct ethtool_cmd ecmd; + std::memset(&ifr, 0, sizeof(ifr)); + std::strncpy(ifr.ifr_name, ifname.c_str(), IFNAMSIZ - 1); + ecmd.cmd = ETHTOOL_GSET; + ifr.ifr_data = reinterpret_cast(&ecmd); + + uint64_t bandwidth = 0; + if (ioctl(sock, SIOCETHTOOL, &ifr) >= 0) { + uint32_t speed = ethtool_cmd_speed(&ecmd); + if (speed != UINT32_MAX && speed > 0) { + bandwidth = static_cast(speed) * 1000000ULL; + } + } + close(sock); + return bandwidth; +} + +std::vector parse_tcp_interfaces() { + std::vector interfaces; + + char const* env = std::getenv("UCCL_P2P_TCP_IFNAME"); + if (!env || strlen(env) == 0) { + char ifNames[MAX_IFS * MAX_IF_NAME_SIZE]; + uccl::socketAddress 
ifAddrs[MAX_IFS]; + int nIfs = + uccl::find_interfaces(ifNames, ifAddrs, MAX_IF_NAME_SIZE, MAX_IFS); + + for (int i = 0; i < nIfs; i++) { + std::string name = &ifNames[i * MAX_IF_NAME_SIZE]; + if (name == "lo") continue; + + std::string ip = uccl::get_dev_ip(name.c_str()); + if (ip.empty()) continue; + + uint64_t bw = get_interface_bandwidth(name); + if (bw == 0) bw = 10ULL * 1000 * 1000 * 1000; + + int num_conns = + std::max(1, static_cast(bw / kBandwidthPerConnection)); + interfaces.push_back({name, ip, bw, num_conns}); + break; + } + } else { + std::string env_str(env); + std::stringstream ss(env_str); + std::string ifname; + + while (std::getline(ss, ifname, ',')) { + ifname.erase(0, ifname.find_first_not_of(" \t")); + ifname.erase(ifname.find_last_not_of(" \t") + 1); + + if (ifname.empty()) continue; + + std::string ip = uccl::get_dev_ip(ifname.c_str()); + if (ip.empty()) { + LOG(WARNING) << "TCP: Interface " << ifname + << " has no IP address, skipping"; + continue; + } + + uint64_t bw = get_interface_bandwidth(ifname); + if (bw == 0) bw = 10ULL * 1000 * 1000 * 1000; + + int num_conns = + std::max(1, static_cast(bw / kBandwidthPerConnection)); + interfaces.push_back({ifname, ip, bw, num_conns}); + LOG(INFO) << "TCP: Using interface " << ifname << " (" << ip << "), " + << "bandwidth=" << (bw / 1e9) << " Gbps, " + << "connections=" << num_conns; + } + } + + if (interfaces.empty()) { + LOG(WARNING) << "TCP: No valid interfaces found, using localhost"; + interfaces.push_back({"lo", "127.0.0.1", 10ULL * 1000 * 1000 * 1000, 1}); + } + + return interfaces; +} + +TCPEndpoint::TCPEndpoint(int gpu_index, uint16_t port) + : gpu_index_(gpu_index), + listen_port_(port), + next_conn_id_(0), + next_request_id_(0), + running_(true) { + interfaces_ = parse_tcp_interfaces(); + + total_connections_ = 0; + for (auto const& iface : interfaces_) { + total_connections_ += iface.num_connections; + } + + LOG(INFO) << "TCPEndpoint initialized for GPU " << gpu_index_ << " with " + << interfaces_.size() << " interfaces, " << total_connections_ + << " total connections per peer"; + + thread_pool_ = std::make_unique(); + thread_pool_->start(); + + start_listening(); +} + +TCPEndpoint::~TCPEndpoint() { + running_ = false; + + if (thread_pool_) { + thread_pool_->stop(); + } + + if (listen_fd_ >= 0) { + close(listen_fd_); + listen_fd_ = -1; + } + + // Close per-interface data listen sockets + for (int fd : data_listen_fds_) { + if (fd >= 0) { + close(fd); + } + } + data_listen_fds_.clear(); + + std::unique_lock lock(conn_mutex_); + connection_groups_.clear(); +} + +uccl::ConnID TCPEndpoint::uccl_connect(int dev, int local_gpuidx, + int remote_dev, int remote_gpuidx, + std::string remote_ip, + uint16_t remote_port) { + uint64_t conn_id = next_conn_id_.fetch_add(1, std::memory_order_relaxed); + + LOG(INFO) << "TCPEndpoint::uccl_connect to " << remote_ip << ":" + << remote_port; + + auto group = std::make_shared(); + + NegotiationInfo local_info; + local_info.gpu_index = local_gpuidx; + local_info.num_interfaces = interfaces_.size(); + local_info.total_connections = total_connections_; + local_info.reserved = 0; + + int ctrl_fd = create_tcp_connection(remote_ip, remote_port); + if (ctrl_fd < 0) { + LOG(ERROR) << "Failed to create control connection"; + uccl::ConnID invalid_id; + invalid_id.flow_id = UINT64_MAX; + return invalid_id; + } + + group->ctrl_fd = ctrl_fd; + setup_tcp_socket_options(ctrl_fd); + + uccl::send_message(ctrl_fd, &local_info, sizeof(local_info)); + NegotiationInfo remote_info; + 
uccl::receive_message(ctrl_fd, &remote_info, sizeof(remote_info)); + + for (size_t i = 0; i < interfaces_.size(); i++) { + InterfaceNegotiationInfo iface_info; + std::memset(&iface_info, 0, sizeof(iface_info)); + std::strncpy(iface_info.ip_addr, interfaces_[i].ip_addr.c_str(), + sizeof(iface_info.ip_addr) - 1); + iface_info.num_connections = interfaces_[i].num_connections; + iface_info.data_port = + (i < data_listen_ports_.size()) ? data_listen_ports_[i] : listen_port_; + iface_info.reserved = 0; + uccl::send_message(ctrl_fd, &iface_info, sizeof(iface_info)); + } + + std::vector remote_interfaces( + remote_info.num_interfaces); + for (int i = 0; i < remote_info.num_interfaces; i++) { + uccl::receive_message(ctrl_fd, &remote_interfaces[i], + sizeof(InterfaceNegotiationInfo)); + } + + int actual_connections = std::min(static_cast(total_connections_), + remote_info.total_connections); + actual_connections = std::max(1, actual_connections); + + LOG(INFO) << "Negotiated " << actual_connections << " data connections with " + << remote_info.num_interfaces << " remote interfaces"; + + int created_connections = 0; + size_t local_iface_idx = 0; + size_t remote_iface_idx = 0; + + while (created_connections < actual_connections) { + auto const& local_iface = interfaces_[local_iface_idx % interfaces_.size()]; + auto const& remote_iface = + remote_interfaces[remote_iface_idx % remote_interfaces.size()]; + + std::string remote_iface_ip(remote_iface.ip_addr); + uint16_t remote_data_port = remote_iface.data_port; + + int fd = create_tcp_connection_from_interface( + remote_iface_ip, remote_data_port, local_iface.ip_addr); + if (fd >= 0) { + auto conn = std::make_unique(); + conn->fd = fd; + conn->local_ip = local_iface.ip_addr; + conn->remote_ip = remote_iface_ip; + conn->remote_port = remote_data_port; + setup_tcp_socket_options(fd); + + thread_pool_->assign_data_connection(fd, conn.get(), &group->match_queue); + + group->add_data_connection(std::move(conn)); + created_connections++; + } + + local_iface_idx++; + remote_iface_idx++; + } + + LOG(INFO) << "Created " << group->data_connection_count() + << " data connections to peer"; + + { + std::unique_lock lock(conn_mutex_); + connection_groups_[conn_id] = group; + } + + uccl::ConnID result; + result.flow_id = conn_id; + result.context = reinterpret_cast(conn_id); + return result; +} + +uccl::ConnID TCPEndpoint::uccl_accept(int dev, int listen_fd, int local_gpuidx, + std::string& remote_ip, int* remote_dev, + int* remote_gpuidx) { + struct sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + + // Accept control connection + int ctrl_fd = accept(listen_fd, (struct sockaddr*)&client_addr, &client_len); + if (ctrl_fd < 0) { + LOG(ERROR) << "accept failed: " << strerror(errno); + uccl::ConnID invalid_id; + invalid_id.flow_id = UINT64_MAX; + return invalid_id; + } + + char ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &client_addr.sin_addr, ip_str, INET_ADDRSTRLEN); + remote_ip = ip_str; + + LOG(INFO) << "TCPEndpoint::uccl_accept from " << remote_ip; + + NegotiationInfo local_info; + local_info.gpu_index = local_gpuidx; + local_info.num_interfaces = interfaces_.size(); + local_info.total_connections = total_connections_; + local_info.reserved = 0; + + NegotiationInfo remote_info; + uccl::receive_message(ctrl_fd, &remote_info, sizeof(remote_info)); + uccl::send_message(ctrl_fd, &local_info, sizeof(local_info)); + + std::vector remote_interfaces( + remote_info.num_interfaces); + for (int i = 0; i < remote_info.num_interfaces; i++) { + 
uccl::receive_message(ctrl_fd, &remote_interfaces[i], + sizeof(InterfaceNegotiationInfo)); + } + + for (size_t i = 0; i < interfaces_.size(); i++) { + InterfaceNegotiationInfo iface_info; + std::memset(&iface_info, 0, sizeof(iface_info)); + std::strncpy(iface_info.ip_addr, interfaces_[i].ip_addr.c_str(), + sizeof(iface_info.ip_addr) - 1); + iface_info.num_connections = interfaces_[i].num_connections; + iface_info.data_port = + (i < data_listen_ports_.size()) ? data_listen_ports_[i] : listen_port_; + iface_info.reserved = 0; + uccl::send_message(ctrl_fd, &iface_info, sizeof(iface_info)); + } + + if (remote_gpuidx) *remote_gpuidx = remote_info.gpu_index; + if (remote_dev) *remote_dev = 0; + + int actual_connections = std::min(static_cast(total_connections_), + remote_info.total_connections); + actual_connections = std::max(1, actual_connections); + + uint64_t conn_id = next_conn_id_.fetch_add(1, std::memory_order_relaxed); + + auto group = std::make_shared(); + group->ctrl_fd = ctrl_fd; + setup_tcp_socket_options(ctrl_fd); + + fd_set listen_fds; + FD_ZERO(&listen_fds); + int max_fd = -1; + + for (int data_listen_fd : data_listen_fds_) { + FD_SET(data_listen_fd, &listen_fds); + if (data_listen_fd > max_fd) max_fd = data_listen_fd; + } + + FD_SET(listen_fd, &listen_fds); + if (listen_fd > max_fd) max_fd = listen_fd; + + int accepted_connections = 0; + while (accepted_connections < actual_connections) { + fd_set read_fds = listen_fds; + + struct timeval timeout; + timeout.tv_sec = 30; + timeout.tv_usec = 0; + + int ready = select(max_fd + 1, &read_fds, nullptr, nullptr, &timeout); + if (ready <= 0) { + LOG(ERROR) << "Timeout waiting for data connection"; + break; + } + + int data_fd = -1; + std::string local_iface_ip; + + for (size_t i = 0; i < data_listen_fds_.size(); i++) { + if (FD_ISSET(data_listen_fds_[i], &read_fds)) { + data_fd = accept(data_listen_fds_[i], (struct sockaddr*)&client_addr, + &client_len); + if (data_fd >= 0 && i < interfaces_.size()) { + local_iface_ip = interfaces_[i].ip_addr; + } + break; + } + } + + if (data_fd < 0 && FD_ISSET(listen_fd, &read_fds)) { + data_fd = accept(listen_fd, (struct sockaddr*)&client_addr, &client_len); + } + + if (data_fd < 0) { + LOG(ERROR) << "Failed to accept data connection: " << strerror(errno); + break; + } + + auto conn = std::make_unique(); + conn->fd = data_fd; + conn->local_ip = local_iface_ip; + inet_ntop(AF_INET, &client_addr.sin_addr, ip_str, INET_ADDRSTRLEN); + conn->remote_ip = ip_str; + conn->remote_port = ntohs(client_addr.sin_port); + setup_tcp_socket_options(data_fd); + + thread_pool_->assign_data_connection(data_fd, conn.get(), + &group->match_queue); + + group->add_data_connection(std::move(conn)); + accepted_connections++; + } + + LOG(INFO) << "Accepted " << group->data_connection_count() + << " data connections from peer"; + + { + std::unique_lock lock(conn_mutex_); + connection_groups_[conn_id] = group; + } + + uccl::ConnID result; + result.flow_id = conn_id; + result.context = reinterpret_cast(conn_id); + return result; +} + +int TCPEndpoint::uccl_regmr(uccl::UcclFlow* flow, void* data, size_t len, + int type, struct uccl::Mhandle** mhandle) { + *mhandle = nullptr; + return 0; +} + +int TCPEndpoint::uccl_regmr(void* data, size_t len, MRArray& mr_array) { + return 0; +} + +int TCPEndpoint::uccl_regmr(int dev, void* data, size_t len, int type, + struct uccl::Mhandle** mhandle) { + if (mhandle) *mhandle = nullptr; + return 0; +} + +void TCPEndpoint::uccl_deregmr(struct uccl::Mhandle* mhandle) {} + +void 
TCPEndpoint::uccl_deregmr(MRArray const& mr_array) {} + +int TCPEndpoint::uccl_send_async(uccl::UcclFlow* flow, struct uccl::Mhandle* mh, + void const* data, size_t size, + struct uccl::ucclRequest* ureq) { + uint64_t conn_id = reinterpret_cast(flow); + auto group = get_connection_group(conn_id); + if (!group) return -1; + + auto handle = new TCPAsyncHandle(); + + uint32_t send_seq_id = group->match_queue.get_next_send_seq_id(); + uint32_t request_id = + next_request_id_.fetch_add(1, std::memory_order_relaxed); + handle->request_id = request_id; + + thread_pool_->register_pending_send(size, request_id, &handle->completed, + &handle->success); + + size_t num_chunks = (size + kChunkSize - 1) / kChunkSize; + size_t offset = 0; + + for (size_t i = 0; i < num_chunks; ++i) { + size_t chunk_size = std::min(kChunkSize, size - offset); + bool is_last = (i == num_chunks - 1); + + // Select different connection for each chunk (load balance across workers) + TCPConnection* conn = group->select_data_connection(); + if (!conn) { + delete handle; + return -1; + } + + TCPRequest req; + req.type = TCPRequestType::SEND; + req.ctrl_fd = group->ctrl_fd; + req.data = const_cast(static_cast(data) + offset); + req.size = chunk_size; + req.total_size = size; + req.dest_addr = offset; + req.flags = TCPDataHeader::kFlagNeedsMatch; + if (is_last) { + req.flags |= TCPDataHeader::kFlagLastChunk; + } + req.request_id = request_id; + req.send_seq_id = send_seq_id; + req.conn_group = group.get(); + req.assigned_conn = conn; + + thread_pool_->submit_request(req); + + offset += chunk_size; + } + + if (ureq) { + ureq->engine_idx = reinterpret_cast(handle); + ureq->n = conn_id; + } + + return 0; +} + +int TCPEndpoint::uccl_recv_async(uccl::UcclFlow* flow, + struct uccl::Mhandle** mhandles, void** data, + int* sizes, int n, + struct uccl::ucclRequest* ureq) { + if (n <= 0) return -1; + uint64_t conn_id = reinterpret_cast(flow); + auto group = get_connection_group(conn_id); + if (!group) return -1; + + auto handle = new TCPAsyncHandle(); + uint32_t request_id = + next_request_id_.fetch_add(1, std::memory_order_relaxed); + handle->request_id = request_id; + + thread_pool_->register_pending_recv(reinterpret_cast(data[0]), + sizes[0], request_id, &handle->completed, + &handle->success); + + group->match_queue.push_recv(reinterpret_cast(data[0]), sizes[0], + request_id); + + if (ureq) { + ureq->engine_idx = reinterpret_cast(handle); + ureq->n = conn_id; + } + + return 0; +} + +int TCPEndpoint::uccl_read_async(uccl::UcclFlow* flow, struct uccl::Mhandle* mh, + void* dst, size_t size, + uccl::FifoItem const& slot_item, + uccl::ucclRequest* ureq) { + uint64_t conn_id = reinterpret_cast(flow); + auto group = get_connection_group(conn_id); + if (!group) return -1; + + auto handle = new TCPAsyncHandle(); + uint32_t request_id = + next_request_id_.fetch_add(1, std::memory_order_relaxed); + handle->request_id = request_id; + + thread_pool_->register_pending_recv(reinterpret_cast(dst), size, + request_id, &handle->completed, + &handle->success); + + size_t num_chunks = (size + kChunkSize - 1) / kChunkSize; + size_t offset = 0; + uint64_t base_dst = reinterpret_cast(dst); + uint64_t base_remote = slot_item.addr; + + for (size_t i = 0; i < num_chunks; ++i) { + size_t chunk_size = std::min(kChunkSize, size - offset); + + TCPConnection* conn = group->select_data_connection(); + if (!conn) { + delete handle; + return -1; + } + + TCPRequest req; + req.type = TCPRequestType::READ; + req.ctrl_fd = group->ctrl_fd; + req.data = nullptr; // No data to 
send for READ request + req.size = chunk_size; + req.total_size = size; + req.dest_addr = base_dst + offset; + req.remote_addr = base_remote + offset; + req.completed = nullptr; + req.success = nullptr; + req.request_id = request_id; + req.conn_group = group.get(); + req.assigned_conn = conn; + + thread_pool_->submit_request(req); + + offset += chunk_size; + } + + if (ureq) { + ureq->engine_idx = reinterpret_cast(handle); + ureq->n = conn_id; + } + + return 0; +} + +int TCPEndpoint::uccl_write_async(uccl::UcclFlow* flow, + struct uccl::Mhandle* mh, void* src, + size_t size, uccl::FifoItem const& slot_item, + uccl::ucclRequest* ureq) { + uint64_t conn_id = reinterpret_cast(flow); + auto group = get_connection_group(conn_id); + if (!group) return -1; + + auto handle = new TCPAsyncHandle(); + uint32_t request_id = + next_request_id_.fetch_add(1, std::memory_order_relaxed); + handle->request_id = request_id; + + thread_pool_->register_pending_send(size, request_id, &handle->completed, + &handle->success); + + size_t num_chunks = (size + kChunkSize - 1) / kChunkSize; + size_t offset = 0; + uint64_t base_dest_addr = slot_item.addr; + + for (size_t i = 0; i < num_chunks; ++i) { + size_t chunk_size = std::min(kChunkSize, size - offset); + bool is_last = (i == num_chunks - 1); + + // Select different connection for each chunk (load balance across workers) + TCPConnection* conn = group->select_data_connection(); + if (!conn) { + delete handle; + return -1; + } + + TCPRequest req; + req.type = TCPRequestType::WRITE; + req.ctrl_fd = group->ctrl_fd; + req.data = static_cast(src) + offset; + req.size = chunk_size; + req.total_size = size; + req.dest_addr = base_dest_addr + offset; + req.flags = is_last ? TCPDataHeader::kFlagLastChunk : 0; + req.request_id = request_id; + req.conn_group = group.get(); + req.assigned_conn = conn; + + thread_pool_->submit_request(req); + + offset += chunk_size; + } + + if (ureq) { + ureq->engine_idx = reinterpret_cast(handle); + ureq->n = conn_id; + } + + return 0; +} + +bool TCPEndpoint::uccl_poll_ureq_once(struct uccl::ucclRequest* ureq) { + if (!ureq) return false; + + TCPAsyncHandle* handle = reinterpret_cast(ureq->engine_idx); + if (!handle) return true; + + bool completed = handle->completed.load(std::memory_order_acquire); + if (completed) { + delete handle; + ureq->engine_idx = 0; + } + return completed; +} + +int TCPEndpoint::prepare_fifo_metadata(uccl::UcclFlow* flow, + struct uccl::Mhandle** mhandle, + void const* data, size_t size, + char* out_buf) { + uccl::FifoItem item; + item.addr = reinterpret_cast(data); + item.size = size; + std::memset(item.padding, 0, sizeof(item.padding)); + uccl::serialize_fifo_item(item, out_buf); + return 0; +} + +void TCPEndpoint::start_listening() { + listen_fd_ = socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd_ < 0) { + LOG(ERROR) << "Failed to create listen socket: " << strerror(errno); + return; + } + + int opt = 1; + setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)); + + struct sockaddr_in addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(0); + + if (bind(listen_fd_, (struct sockaddr*)&addr, sizeof(addr)) < 0) { + LOG(ERROR) << "Failed to bind listen socket: " << strerror(errno); + close(listen_fd_); + listen_fd_ = -1; + return; + } + + socklen_t len = sizeof(addr); + getsockname(listen_fd_, (struct sockaddr*)&addr, &len); + listen_port_ = 
ntohs(addr.sin_port); + + if (listen(listen_fd_, 128) < 0) { + LOG(ERROR) << "Failed to listen: " << strerror(errno); + close(listen_fd_); + listen_fd_ = -1; + return; + } + + LOG(INFO) << "TCP control listening on port " << listen_port_; + + for (auto const& iface : interfaces_) { + int data_fd = socket(AF_INET, SOCK_STREAM, 0); + if (data_fd < 0) { + LOG(ERROR) << "Failed to create data listen socket for " << iface.ip_addr; + continue; + } + + setsockopt(data_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + + struct sockaddr_in data_addr; + std::memset(&data_addr, 0, sizeof(data_addr)); + data_addr.sin_family = AF_INET; + inet_pton(AF_INET, iface.ip_addr.c_str(), &data_addr.sin_addr); + data_addr.sin_port = htons(0); + + if (bind(data_fd, (struct sockaddr*)&data_addr, sizeof(data_addr)) < 0) { + LOG(ERROR) << "Failed to bind data listen socket to " << iface.ip_addr + << ": " << strerror(errno); + close(data_fd); + continue; + } + + socklen_t addr_len = sizeof(data_addr); + getsockname(data_fd, (struct sockaddr*)&data_addr, &addr_len); + uint16_t data_port = ntohs(data_addr.sin_port); + + if (listen(data_fd, 128) < 0) { + LOG(ERROR) << "Failed to listen on " << iface.ip_addr; + close(data_fd); + continue; + } + + data_listen_fds_.push_back(data_fd); + data_listen_ports_.push_back(data_port); + LOG(INFO) << "TCP data listening on " << iface.ip_addr << ":" << data_port; + } +} + +int TCPEndpoint::create_tcp_connection(std::string const& remote_ip, + int remote_port) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) return -1; + + struct sockaddr_in remote_addr; + std::memset(&remote_addr, 0, sizeof(remote_addr)); + remote_addr.sin_family = AF_INET; + remote_addr.sin_port = htons(remote_port); + inet_pton(AF_INET, remote_ip.c_str(), &remote_addr.sin_addr); + + int retries = 100; + while (connect(fd, (struct sockaddr*)&remote_addr, sizeof(remote_addr)) < 0) { + if (--retries <= 0) { + LOG(ERROR) << "Failed to connect to " << remote_ip << ":" << remote_port; + close(fd); + return -1; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return fd; +} + +int TCPEndpoint::create_tcp_connection_from_interface( + std::string const& remote_ip, int remote_port, + std::string const& local_ip) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) return -1; + + struct sockaddr_in local_addr; + std::memset(&local_addr, 0, sizeof(local_addr)); + local_addr.sin_family = AF_INET; + local_addr.sin_port = 0; + inet_pton(AF_INET, local_ip.c_str(), &local_addr.sin_addr); + + if (bind(fd, (struct sockaddr*)&local_addr, sizeof(local_addr)) < 0) { + LOG(WARNING) << "Failed to bind to local IP " << local_ip; + } + + struct sockaddr_in remote_addr; + std::memset(&remote_addr, 0, sizeof(remote_addr)); + remote_addr.sin_family = AF_INET; + remote_addr.sin_port = htons(remote_port); + inet_pton(AF_INET, remote_ip.c_str(), &remote_addr.sin_addr); + + int retries = 100; + while (connect(fd, (struct sockaddr*)&remote_addr, sizeof(remote_addr)) < 0) { + if (--retries <= 0) { + close(fd); + return -1; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return fd; +} + +void TCPEndpoint::setup_tcp_socket_options(int fd) { + int opt = 1; + if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) < 0) { + LOG(ERROR) << "Failed to set TCP_NODELAY: " << strerror(errno); + } + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + LOG(ERROR) << "Failed to set SO_REUSEADDR: " << strerror(errno); + } + if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, 
&opt, sizeof(opt)) < 0) { + LOG(ERROR) << "Failed to set SO_REUSEPORT: " << strerror(errno); + } +} + +std::shared_ptr TCPEndpoint::get_connection_group( + uint64_t conn_id) { + std::shared_lock lock(conn_mutex_); + auto it = connection_groups_.find(conn_id); + if (it == connection_groups_.end()) return nullptr; + return it->second; +} + +} // namespace tcp diff --git a/p2p/tcp/tcp_endpoint.h b/p2p/tcp/tcp_endpoint.h new file mode 100644 index 00000000..617a7608 --- /dev/null +++ b/p2p/tcp/tcp_endpoint.h @@ -0,0 +1,142 @@ +#pragma once + +#include "tcp/tcp_worker_pool.h" +#include "util/gpu_rt.h" +#include "util/net.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // For common types like uccl::ConnID, FifoItem, etc. + +namespace tcp { + +static constexpr size_t kChunkSize = 128 * 1024; +static_assert(kChunkSize <= kStagingBufferSize); +static constexpr uint64_t kBandwidthPerConnection = 20ULL * 1000 * 1000 * 1000; + +struct MRArray { + void* dummy = nullptr; +}; + +struct InterfaceInfo { + std::string name; + std::string ip_addr; + uint64_t bandwidth_bps; + int num_connections; +}; + +uint64_t get_interface_bandwidth(std::string const& ifname); +std::vector parse_tcp_interfaces(); + +class TCPEndpoint { + public: + explicit TCPEndpoint(int gpu_index, uint16_t port = 0); + ~TCPEndpoint(); + + int gpuIndex() const { return gpu_index_; } + + uccl::ConnID uccl_connect(int dev, int local_gpuidx, int remote_dev, + int remote_gpuidx, std::string remote_ip, + uint16_t remote_port); + uint16_t get_p2p_listen_port(int dev) { return listen_port_; } + int get_p2p_listen_fd(int dev) { return listen_fd_; } + uccl::ConnID uccl_accept(int dev, int listen_fd, int local_gpuidx, + std::string& remote_ip, int* remote_dev, + int* remote_gpuidx); + + int uccl_regmr(uccl::UcclFlow* flow, void* data, size_t len, int type, + struct uccl::Mhandle** mhandle); + int uccl_regmr(void* data, size_t len, MRArray& mr_array); + int uccl_regmr(int dev, void* data, size_t len, int type, + struct uccl::Mhandle** mhandle); + void uccl_deregmr(struct uccl::Mhandle* mhandle); + void uccl_deregmr(MRArray const& mr_array); + + int uccl_send_async(uccl::UcclFlow* flow, struct uccl::Mhandle* mh, + void const* data, size_t size, + struct uccl::ucclRequest* ureq); + int uccl_recv_async(uccl::UcclFlow* flow, struct uccl::Mhandle** mhandles, + void** data, int* sizes, int n, + struct uccl::ucclRequest* ureq); + int uccl_read_async(uccl::UcclFlow* flow, struct uccl::Mhandle* mh, void* dst, + size_t size, uccl::FifoItem const& slot_item, + uccl::ucclRequest* ureq); + int uccl_write_async(uccl::UcclFlow* flow, struct uccl::Mhandle* mh, + void* src, size_t size, uccl::FifoItem const& slot_item, + uccl::ucclRequest* ureq); + bool uccl_poll_ureq_once(struct uccl::ucclRequest* ureq); + int prepare_fifo_metadata(uccl::UcclFlow* flow, + struct uccl::Mhandle** mhandle, void const* data, + size_t size, char* out_buf); + + int get_best_dev_idx(int gpu_idx) { return 0; } + + bool initialize_engine_by_dev(int dev, bool enable_p2p_listen) { + return true; + } + + void create_unified_p2p_socket() {} + + private: + struct InterfaceNegotiationInfo { + char ip_addr[16]; + int32_t num_connections; + uint16_t data_port; + uint16_t reserved; + }; + + struct NegotiationInfo { + int32_t gpu_index; + int32_t num_interfaces; + int32_t total_connections; + int32_t reserved; + }; + + void start_listening(); + + 
int create_tcp_connection(std::string const& remote_ip, int remote_port); + + int create_tcp_connection_from_interface(std::string const& remote_ip, + int remote_port, + std::string const& local_ip); + + void setup_tcp_socket_options(int fd); + + std::shared_ptr get_connection_group(uint64_t conn_id); + + int gpu_index_; + uint16_t listen_port_; + int listen_fd_ = -1; + std::vector data_listen_fds_; + std::vector data_listen_ports_; + std::atomic next_conn_id_; + std::atomic next_request_id_; + std::atomic running_; + + std::vector interfaces_; + int total_connections_; + + std::unique_ptr thread_pool_; + + mutable std::shared_mutex conn_mutex_; + std::unordered_map> + connection_groups_; +}; + +} // namespace tcp diff --git a/p2p/tcp/tcp_worker_pool.cc b/p2p/tcp/tcp_worker_pool.cc new file mode 100644 index 00000000..c6c452d0 --- /dev/null +++ b/p2p/tcp/tcp_worker_pool.cc @@ -0,0 +1,737 @@ +#include "tcp/tcp_worker_pool.h" + +namespace tcp { + +void RecvMatchQueue::push_recv(uint64_t dest_addr, size_t size, + uint32_t recv_request_id) { + std::lock_guard lock(mutex_); + auto info = std::make_unique(); + info->dest_addr = dest_addr; + info->size = size; + info->recv_request_id = recv_request_id; + info->received.store(0, std::memory_order_relaxed); + pending_recvs_.push_back(std::move(info)); +} + +bool RecvMatchQueue::get_recv_info(uint32_t send_seq_id, + uint64_t* base_dest_addr, + uint32_t* recv_request_id) { + std::lock_guard lock(mutex_); + + auto it = in_progress_.find(send_seq_id); + if (it != in_progress_.end() && it->second) { + *base_dest_addr = it->second->dest_addr; + *recv_request_id = it->second->recv_request_id; + return true; + } + + while (next_seq_to_assign_ <= send_seq_id) { + if (pending_recvs_.empty()) { + return false; + } + in_progress_[next_seq_to_assign_] = std::move(pending_recvs_.front()); + pending_recvs_.pop_front(); + next_seq_to_assign_++; + } + + it = in_progress_.find(send_seq_id); + if (it == in_progress_.end() || !it->second) { + LOG(ERROR) << "RecvMatchQueue::get_recv_info: send_seq_id=" << send_seq_id + << " not found after assignment"; + exit(1); + } + *base_dest_addr = it->second->dest_addr; + *recv_request_id = it->second->recv_request_id; + return true; +} + +void RecvMatchQueue::add_received_bytes(uint32_t send_seq_id, size_t bytes) { + std::lock_guard lock(mutex_); + auto it = in_progress_.find(send_seq_id); + if (it == in_progress_.end()) { + LOG(ERROR) << "RecvMatchQueue::add_received_bytes: send_seq_id=" + << send_seq_id << " not found"; + exit(1); + } + + size_t new_total = + it->second->received.fetch_add(bytes, std::memory_order_relaxed) + bytes; + if (new_total >= it->second->size) { + in_progress_.erase(it); + } +} + +uint32_t RecvMatchQueue::get_next_send_seq_id() { + return next_send_seq_id_.fetch_add(1, std::memory_order_relaxed); +} + +void PendingRecvMap::add(uint64_t dest_addr, size_t size, uint32_t request_id, + std::atomic* completed, + std::atomic* success) { + std::lock_guard lock(mutex_); + pending_recvs_[request_id] = + std::make_unique(size, request_id, completed, success); +} + +bool PendingRecvMap::update_and_check_complete(uint32_t request_id, + size_t chunk_size) { + std::lock_guard lock(mutex_); + auto it = pending_recvs_.find(request_id); + if (it == pending_recvs_.end()) { + return false; // Not found, ignore + } + + PendingTransfer* pr = it->second.get(); + size_t new_received = + pr->transferred_size.fetch_add(chunk_size, std::memory_order_relaxed) + + chunk_size; + + bool is_complete = (new_received >= 
pr->total_size); + + if (is_complete) { + if (pr->success) { + pr->success->store(true, std::memory_order_release); + } + if (pr->completed) { + pr->completed->store(true, std::memory_order_release); + } + pending_recvs_.erase(it); + return true; + } + + return false; +} + +void PendingSendMap::add(size_t size, uint32_t request_id, + std::atomic* completed, + std::atomic* success) { + std::lock_guard lock(mutex_); + pending_sends_[request_id] = + std::make_unique(size, request_id, completed, success); +} + +bool PendingSendMap::update_and_check_complete(uint32_t request_id, + size_t chunk_size) { + std::lock_guard lock(mutex_); + auto it = pending_sends_.find(request_id); + if (it == pending_sends_.end()) { + return false; + } + + PendingTransfer* ps = it->second.get(); + size_t new_sent = + ps->transferred_size.fetch_add(chunk_size, std::memory_order_relaxed) + + chunk_size; + + bool is_complete = (new_sent >= ps->total_size); + + if (is_complete) { + if (ps->success) { + ps->success->store(true, std::memory_order_release); + } + if (ps->completed) { + ps->completed->store(true, std::memory_order_release); + } + pending_sends_.erase(it); + return true; + } + + return false; +} + +bool send_exact(int fd, void const* buf, size_t n) { + char const* ptr = static_cast(buf); + size_t sent = 0; + while (sent < n) { + ssize_t ret = ::send(fd, ptr + sent, n - sent, MSG_NOSIGNAL); + if (ret < 0) { + if (errno == EINTR) continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) continue; + return false; + } + sent += ret; + } + return true; +} + +bool recv_exact(int fd, void* buf, size_t n) { + char* ptr = static_cast(buf); + size_t received = 0; + while (received < n) { + ssize_t ret = ::recv(fd, ptr + received, n - received, 0); + if (ret < 0) { + if (errno == EINTR) continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) continue; + return false; + } + if (ret == 0) return false; + received += ret; + } + return true; +} + +bool send_header_and_data(int fd, void const* header, size_t header_size, + void const* data, size_t data_size) { + struct iovec iov[2]; + iov[0].iov_base = const_cast(header); + iov[0].iov_len = header_size; + iov[1].iov_base = const_cast(data); + iov[1].iov_len = data_size; + + struct msghdr msg; + std::memset(&msg, 0, sizeof(msg)); + msg.msg_iov = iov; + msg.msg_iovlen = 2; + + size_t total = header_size + data_size; + size_t sent = 0; + + while (sent < total) { + size_t offset = sent; + if (offset >= header_size) { + iov[0].iov_base = nullptr; + iov[0].iov_len = 0; + iov[1].iov_base = + static_cast(const_cast(data)) + (offset - header_size); + iov[1].iov_len = data_size - (offset - header_size); + } else { + iov[0].iov_base = static_cast(const_cast(header)) + offset; + iov[0].iov_len = header_size - offset; + } + + ssize_t ret = ::sendmsg(fd, &msg, MSG_NOSIGNAL); + if (ret < 0) { + if (errno == EINTR) continue; + if (errno == EAGAIN || errno == EWOULDBLOCK) continue; + return false; + } + sent += ret; + } + return true; +} + +TCPReceiverWorker::TCPReceiverWorker(uint32_t id, TCPThreadPool* thread_pool) + : worker_id_(id), + running_(false), + epoll_fd_(-1), + thread_pool_(thread_pool), + staging_buffer_(nullptr) { + gpuError_t err = gpuMallocHost(reinterpret_cast(&staging_buffer_), + kStagingBufferSize); + if (err != gpuSuccess) { + LOG(ERROR) << "TCPReceiverWorker " << id + << ": Failed to allocate pinned memory" << err; + } + epoll_fd_ = epoll_create1(0); + if (epoll_fd_ < 0) { + LOG(ERROR) << "TCPReceiverWorker " << id << ": Failed to create epoll"; + } + LOG(INFO) << 
"TCPReceiverWorker " << id << " initialized"; +} + +TCPReceiverWorker::~TCPReceiverWorker() { + stop(); + if (epoll_fd_ >= 0) { + close(epoll_fd_); + } + if (staging_buffer_) { + (void)gpuFreeHost(staging_buffer_); + staging_buffer_ = nullptr; + } +} + +void TCPReceiverWorker::start() { + if (running_) return; + running_ = true; + worker_thread_ = std::thread(&TCPReceiverWorker::worker_loop, this); + LOG(INFO) << "TCPReceiverWorker " << worker_id_ << " started"; +} + +void TCPReceiverWorker::stop() { + if (!running_) return; + running_ = false; + if (worker_thread_.joinable()) { + worker_thread_.join(); + } + LOG(INFO) << "TCPReceiverWorker " << worker_id_ << " stopped"; +} + +bool TCPReceiverWorker::add_data_connection(int fd) { + struct epoll_event ev; + ev.events = EPOLLIN | EPOLLET; // Edge-triggered, read only + ev.data.fd = fd; + + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &ev) < 0) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": epoll_ctl ADD failed for fd=" << fd << " errno=" << errno; + return false; + } + + std::lock_guard lock(mutex_); + data_fds_.insert(fd); + return true; +} + +void TCPReceiverWorker::remove_data_connection(int fd) { + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr); + std::lock_guard lock(mutex_); + data_fds_.erase(fd); +} + +void TCPReceiverWorker::worker_loop() { + std::vector events(kEpollMaxEvents); + + while (running_) { + int n = + epoll_wait(epoll_fd_, events.data(), kEpollMaxEvents, kEpollTimeoutMs); + if (n < 0) { + if (errno == EINTR) continue; + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ << ": epoll_wait error"; + break; + } + + for (int i = 0; i < n; ++i) { + if (events[i].events & EPOLLIN) { + int fd = events[i].data.fd; + if (!process_event(fd)) { + struct epoll_event ev; + ev.events = EPOLLIN | EPOLLET; + ev.data.fd = fd; + epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &ev); + } + } + } + } +} + +bool TCPReceiverWorker::process_event(int fd) { + while (true) { + TCPDataHeader header; + ssize_t peeked = recv(fd, &header, sizeof(header), MSG_PEEK | MSG_DONTWAIT); + if (peeked < static_cast(sizeof(header))) { + return true; + } + + if ((header.flags & TCPDataHeader::kFlagNeedsMatch) && + static_cast(header.msg_type) == + TCPDataMsgType::DATA_CHUNK) { + RecvMatchQueue* match_queue = thread_pool_->get_match_queue(fd); + if (!match_queue) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": no match queue for fd=" << fd; + exit(1); + } + uint32_t send_seq_id = header.request_id; + uint64_t base_dest_addr; + uint32_t recv_request_id; + if (!match_queue->get_recv_info(send_seq_id, &base_dest_addr, + &recv_request_id)) { + return false; + } + } + + if (!recv_exact(fd, &header, sizeof(header))) { + exit(1); + } + + if (static_cast(header.msg_type) == + TCPDataMsgType::READ_REQUEST) { + process_read_request(fd, header); + } else { + process_data_chunk(fd, header); + } + } +} + +void TCPReceiverWorker::process_read_request(int fd, + TCPDataHeader const& header) { + if (header.size > kStagingBufferSize) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ << ": READ size too large"; + return; + } + +#if UCCL_TCP_GPU_MEMCPY + void* src = reinterpret_cast(header.remote_addr); + gpuError_t err = + gpuMemcpy(staging_buffer_, src, header.size, gpuMemcpyDeviceToHost); + if (err != gpuSuccess) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": GPU memcpy failed for READ"; + return; + } +#endif + + TCPDataHeader response; + response.msg_type = static_cast(TCPDataMsgType::DATA_CHUNK); + response.flags = 0; + response.request_id = 
header.request_id; + response.reserved = 0; + response.dest_addr = header.dest_addr; + response.remote_addr = 0; + response.size = header.size; + response.total_size = header.total_size; + + if (!send_exact(fd, &response, sizeof(response))) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": failed to send READ response header"; + return; + } + + if (!send_exact(fd, staging_buffer_, header.size)) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": failed to send READ response data"; + return; + } +} + +void TCPReceiverWorker::process_data_chunk(int fd, + TCPDataHeader const& header) { + if (header.size > kStagingBufferSize) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ << ": chunk too large"; + return; + } + + if (!recv_exact(fd, staging_buffer_, header.size)) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": failed to read chunk data"; + return; + } + + uint64_t actual_dest_addr = header.dest_addr; + uint32_t recv_request_id = header.request_id; + uint32_t send_seq_id = 0; + + RecvMatchQueue* match_queue = nullptr; + if (header.flags & TCPDataHeader::kFlagNeedsMatch) { + match_queue = thread_pool_->get_match_queue(fd); + if (!match_queue) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": no match queue for fd=" << fd; + return; + } + + send_seq_id = header.request_id; + uint64_t base_dest_addr; + if (!match_queue->get_recv_info(send_seq_id, &base_dest_addr, + &recv_request_id)) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ + << ": failed to get recv_info for send_seq_id=" << send_seq_id; + return; + } + + actual_dest_addr = base_dest_addr + header.dest_addr; + } + +#if UCCL_TCP_GPU_MEMCPY + void* gpu_dest = reinterpret_cast(actual_dest_addr); + gpuError_t err = + gpuMemcpy(gpu_dest, staging_buffer_, header.size, gpuMemcpyHostToDevice); + if (err != gpuSuccess) { + LOG(ERROR) << "TCPReceiverWorker " << worker_id_ << ": GPU memcpy failed"; + } +#endif + + thread_pool_->get_pending_recvs()->update_and_check_complete(recv_request_id, + header.size); + + if (header.flags & TCPDataHeader::kFlagNeedsMatch) { + match_queue->add_received_bytes(send_seq_id, header.size); + } +} + +TCPSenderWorker::TCPSenderWorker(uint32_t id, TCPThreadPool* thread_pool) + : worker_id_(id), + running_(false), + request_ring_(nullptr), + staging_buffer_(nullptr), + thread_pool_(thread_pool) { + gpuError_t err = gpuMallocHost(reinterpret_cast(&staging_buffer_), + kStagingBufferSize); + if (err != gpuSuccess) { + LOG(ERROR) << "TCPSenderWorker " << id + << ": Failed to allocate pinned memory" << err; + } + request_ring_ = uccl::create_ring(sizeof(TCPRequest), kRequestRingSize); + LOG(INFO) << "TCPSenderWorker " << id << " initialized"; +} + +TCPSenderWorker::~TCPSenderWorker() { + stop(); + if (request_ring_) { + free(request_ring_); + } + if (staging_buffer_) { + (void)gpuFreeHost(staging_buffer_); + staging_buffer_ = nullptr; + } +} + +void TCPSenderWorker::start() { + if (running_) return; + running_ = true; + worker_thread_ = std::thread(&TCPSenderWorker::worker_loop, this); + LOG(INFO) << "TCPSenderWorker " << worker_id_ << " started"; +} + +void TCPSenderWorker::stop() { + if (!running_) return; + running_ = false; + + TCPRequest shutdown_req; + shutdown_req.type = TCPRequestType::SHUTDOWN; + submit_request(shutdown_req); + + if (worker_thread_.joinable()) { + worker_thread_.join(); + } + LOG(INFO) << "TCPSenderWorker " << worker_id_ << " stopped"; +} + +bool TCPSenderWorker::submit_request(TCPRequest const& req) { + while (jring_sp_enqueue_bulk(request_ring_, 
&req, 1, nullptr) != 1) { + if (!running_) return false; + std::this_thread::yield(); + } + return true; +} + +void TCPSenderWorker::worker_loop() { + while (running_) { + if (!process_requests()) { + std::this_thread::yield(); + } + } +} + +bool TCPSenderWorker::process_requests() { + TCPRequest req; + bool processed_any = false; + + while (jring_sc_dequeue_bulk(request_ring_, &req, 1, nullptr) == 1) { + processed_any = true; + + if (req.type == TCPRequestType::SHUTDOWN) { + return true; + } + + bool success = false; + switch (req.type) { + case TCPRequestType::SEND: + success = do_send(req); + break; + case TCPRequestType::WRITE: + success = do_write(req); + break; + case TCPRequestType::READ: + success = do_read(req); + break; + default: + break; + } + + // For SEND and WRITE, track bytes sent via shared map + if (req.type == TCPRequestType::SEND || req.type == TCPRequestType::WRITE) { + if (success) { + thread_pool_->get_pending_sends()->update_and_check_complete( + req.request_id, req.size); + } + } + } + return processed_any; +} + +bool TCPSenderWorker::do_send(TCPRequest& req) { + if (!req.data || req.size == 0) return false; + + TCPConnection* conn = req.assigned_conn; + if (!conn || !conn->is_valid()) { + LOG(ERROR) << "TCPSenderWorker " << worker_id_ + << ": no valid assigned connection"; + return false; + } + + TCPDataHeader header; + header.msg_type = static_cast(TCPDataMsgType::DATA_CHUNK); + header.flags = req.flags; + header.request_id = req.send_seq_id; + header.reserved = 0; + header.dest_addr = req.dest_addr; + header.remote_addr = 0; + header.size = req.size; + header.total_size = req.total_size; + +#if UCCL_TCP_GPU_MEMCPY + gpuError_t err = + gpuMemcpy(staging_buffer_, static_cast(req.data), req.size, + gpuMemcpyDeviceToHost); + if (err != gpuSuccess) { + return false; + } +#endif + + // Send header + data in one syscall using scatter-gather I/O + if (!send_header_and_data(conn->fd, &header, sizeof(header), staging_buffer_, + req.size)) { + return false; + } + + return true; +} + +bool TCPSenderWorker::do_write(TCPRequest& req) { + if (!req.data || req.size == 0) return false; + + TCPConnection* conn = req.assigned_conn; + if (!conn || !conn->is_valid()) return false; + + TCPDataHeader header; + header.msg_type = static_cast(TCPDataMsgType::DATA_CHUNK); + header.flags = req.flags; + header.request_id = req.request_id; + header.reserved = 0; + header.dest_addr = req.dest_addr; + header.remote_addr = 0; + header.size = req.size; + header.total_size = req.total_size; + +#if UCCL_TCP_GPU_MEMCPY + gpuError_t err = + gpuMemcpy(staging_buffer_, static_cast(req.data), req.size, + gpuMemcpyDeviceToHost); + if (err != gpuSuccess) { + return false; + } +#endif + + // Send header + data in one syscall using scatter-gather I/O + if (!send_header_and_data(conn->fd, &header, sizeof(header), staging_buffer_, + req.size)) { + return false; + } + + return true; +} + +bool TCPSenderWorker::do_read(TCPRequest& req) { + TCPConnection* conn = req.assigned_conn; + if (!conn || !conn->is_valid()) return false; + + TCPDataHeader header; + header.msg_type = static_cast(TCPDataMsgType::READ_REQUEST); + header.flags = 0; + header.request_id = req.request_id; + header.reserved = 0; + header.dest_addr = req.dest_addr; + header.remote_addr = req.remote_addr; + header.size = req.size; + header.total_size = req.total_size; + + if (!send_exact(conn->fd, &header, sizeof(header))) { + return false; + } + + return true; +} + +TCPConnection* TCPConnectionGroup::select_data_connection() { + std::shared_lock 
lock(mutex); + if (data_connections.empty()) return nullptr; + if (data_connections.size() == 1) return data_connections[0].get(); + + size_t idx = round_robin_idx.fetch_add(1, std::memory_order_relaxed) % + data_connections.size(); + return data_connections[idx].get(); +} + +void TCPConnectionGroup::add_data_connection( + std::unique_ptr conn) { + std::unique_lock lock(mutex); + data_connections.push_back(std::move(conn)); +} + +size_t TCPConnectionGroup::data_connection_count() const { + std::shared_lock lock(mutex); + return data_connections.size(); +} + +TCPThreadPool::TCPThreadPool(size_t num_threads) { + if (num_threads == 0) { + num_threads = get_tcp_thread_count(); + } + + size_t num_senders = std::max(size_t{1}, num_threads / 2); + size_t num_receivers = std::max(size_t{1}, num_threads - num_senders); + + sender_workers_.reserve(num_senders); + for (size_t i = 0; i < num_senders; ++i) { + sender_workers_.push_back(std::make_unique(i, this)); + } + + receiver_workers_.reserve(num_receivers); + for (size_t i = 0; i < num_receivers; ++i) { + receiver_workers_.push_back(std::make_unique(i, this)); + } + + LOG(INFO) << "TCPThreadPool created with " << num_senders + << " sender workers + " << num_receivers << " receiver workers"; +} + +TCPThreadPool::~TCPThreadPool() { stop(); } + +void TCPThreadPool::start() { + for (auto& w : sender_workers_) w->start(); + for (auto& w : receiver_workers_) w->start(); +} + +void TCPThreadPool::stop() { + for (auto& w : sender_workers_) w->stop(); + for (auto& w : receiver_workers_) w->stop(); +} + +uint32_t TCPThreadPool::assign_data_connection(int fd, TCPConnection* conn, + RecvMatchQueue* match_queue) { + uint32_t id = next_receiver_.fetch_add(1, std::memory_order_relaxed) % + receiver_workers_.size(); + receiver_workers_[id]->add_data_connection(fd); + conn->receiver_worker_id = id; + + id = next_sender_.fetch_add(1, std::memory_order_relaxed) % + sender_workers_.size(); + conn->sender_worker_id = id; + + if (match_queue) { + std::unique_lock lock(fd_match_queue_mutex_); + fd_to_match_queue_[fd] = match_queue; + } + + return id; +} + +RecvMatchQueue* TCPThreadPool::get_match_queue(int fd) { + std::shared_lock lock(fd_match_queue_mutex_); + auto it = fd_to_match_queue_.find(fd); + if (it == fd_to_match_queue_.end()) { + return nullptr; + } + return it->second; +} + +bool TCPThreadPool::submit_request(TCPRequest const& req) { + uint32_t id = req.assigned_conn->sender_worker_id; + return sender_workers_[id]->submit_request(req); +} + +void TCPThreadPool::register_pending_recv(uint64_t dest_addr, size_t size, + uint32_t request_id, + std::atomic* completed, + std::atomic* success) { + pending_recvs_.add(dest_addr, size, request_id, completed, success); +} + +void TCPThreadPool::register_pending_send(size_t size, uint32_t request_id, + std::atomic* completed, + std::atomic* success) { + pending_sends_.add(size, request_id, completed, success); +} + +} // namespace tcp diff --git a/p2p/tcp/tcp_worker_pool.h b/p2p/tcp/tcp_worker_pool.h new file mode 100644 index 00000000..210b2955 --- /dev/null +++ b/p2p/tcp/tcp_worker_pool.h @@ -0,0 +1,313 @@ +#pragma once + +#include "util/gpu_rt.h" +#include "util/jring.h" +#include "util/util.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tcp { + +#define UCCL_TCP_GPU_MEMCPY 0 + +static constexpr size_t kStagingBufferSize = 16 * 1024 * 1024; +static constexpr size_t 
diff --git a/p2p/tcp/tcp_worker_pool.h b/p2p/tcp/tcp_worker_pool.h
new file mode 100644
index 00000000..210b2955
--- /dev/null
+++ b/p2p/tcp/tcp_worker_pool.h
@@ -0,0 +1,313 @@
+#pragma once
+
+#include "util/gpu_rt.h"
+#include "util/jring.h"
+#include "util/util.h"
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <deque>
+#include <memory>
+#include <mutex>
+#include <shared_mutex>
+#include <string>
+#include <thread>
+#include <unistd.h>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tcp {
+
+#define UCCL_TCP_GPU_MEMCPY 0
+
+static constexpr size_t kStagingBufferSize = 16 * 1024 * 1024;
+static constexpr size_t kDefaultTCPThreads = 10;
+static constexpr size_t kRequestRingSize = 1024;
+static constexpr int kEpollMaxEvents = 64;
+static constexpr int kEpollTimeoutMs = 100;
+
+bool send_exact(int fd, void const* buf, size_t n);
+bool recv_exact(int fd, void* buf, size_t n);
+bool send_header_and_data(int fd, void const* header, size_t header_size,
+                          void const* data, size_t data_size);
+
+inline size_t get_tcp_thread_count() {
+  char const* env = std::getenv("UCCL_P2P_TCP_THREADS");
+  if (env && strlen(env) > 0) {
+    int val = std::atoi(env);
+    if (val > 0) {
+      LOG(INFO) << "TCP: Using " << val
+                << " threads from UCCL_P2P_TCP_THREADS";
+      return static_cast<size_t>(val);
+    }
+  }
+  LOG(INFO) << "TCP: Using default " << kDefaultTCPThreads << " threads";
+  return kDefaultTCPThreads;
+}
+
+enum class TCPRequestType : uint32_t {
+  SEND = 0,
+  WRITE = 1,
+  READ = 3,
+  SHUTDOWN = 255
+};
+
+enum class TCPDataMsgType : uint32_t {
+  DATA_CHUNK = 0,
+  READ_REQUEST = 1,
+};
+
+struct TCPDataHeader {
+  uint32_t msg_type;
+  uint32_t flags;
+  uint32_t request_id;
+  uint32_t reserved;
+  uint64_t dest_addr;
+  uint64_t remote_addr;
+  uint64_t size;
+  uint64_t total_size;
+
+  static constexpr uint32_t kFlagLastChunk = 1;
+  static constexpr uint32_t kFlagNeedsMatch = 2;
+};
+static_assert(sizeof(TCPDataHeader) == 48, "TCPDataHeader size mismatch");
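TCPDataHeader is written to the socket verbatim, so both peers must agree on its exact layout. The static_assert above pins the total size; for illustration only (not part of the patch), the per-field offsets it implies on a typical LP64 toolchain can be checked the same way:

// Layout expected on the wire: four 32-bit fields, then four 64-bit fields.
#include <cstddef>
static_assert(offsetof(tcp::TCPDataHeader, msg_type) == 0, "layout");
static_assert(offsetof(tcp::TCPDataHeader, flags) == 4, "layout");
static_assert(offsetof(tcp::TCPDataHeader, request_id) == 8, "layout");
static_assert(offsetof(tcp::TCPDataHeader, reserved) == 12, "layout");
static_assert(offsetof(tcp::TCPDataHeader, dest_addr) == 16, "layout");
static_assert(offsetof(tcp::TCPDataHeader, remote_addr) == 24, "layout");
static_assert(offsetof(tcp::TCPDataHeader, size) == 32, "layout");
static_assert(offsetof(tcp::TCPDataHeader, total_size) == 40, "layout");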
+
+struct TCPConnection;
+struct TCPConnectionGroup;
+class TCPReceiverWorker;
+class TCPSenderWorker;
+
+struct alignas(64) TCPRequest {
+  TCPRequestType type;
+  int ctrl_fd;
+  void* data;
+  size_t size;
+  size_t total_size;
+  uint64_t dest_addr;
+  uint64_t remote_addr;
+  std::atomic<bool>* completed;
+  std::atomic<bool>* success;
+  uint32_t request_id;
+  uint32_t send_seq_id;
+  uint32_t flags;
+  void* conn_group;
+  TCPConnection* assigned_conn;
+
+  TCPRequest()
+      : type(TCPRequestType::SEND),
+        ctrl_fd(-1),
+        data(nullptr),
+        size(0),
+        total_size(0),
+        dest_addr(0),
+        remote_addr(0),
+        completed(nullptr),
+        success(nullptr),
+        request_id(0),
+        send_seq_id(0),
+        flags(0),
+        conn_group(nullptr),
+        assigned_conn(nullptr) {}
+};
+
+struct TCPConnection {
+  int fd = -1;
+  std::string local_ip;
+  std::string remote_ip;
+  int remote_port = 0;
+  uint32_t sender_worker_id = 0;
+  uint32_t receiver_worker_id = 0;
+
+  TCPConnection() = default;
+
+  TCPConnection(TCPConnection&& other) noexcept
+      : fd(other.fd),
+        local_ip(std::move(other.local_ip)),
+        remote_ip(std::move(other.remote_ip)),
+        remote_port(other.remote_port),
+        sender_worker_id(other.sender_worker_id),
+        receiver_worker_id(other.receiver_worker_id) {
+    other.fd = -1;
+  }
+
+  ~TCPConnection() {
+    if (fd >= 0) {
+      close(fd);
+      fd = -1;
+    }
+  }
+
+  bool is_valid() const { return fd >= 0; }
+};
+
+struct alignas(64) PendingTransfer {
+  size_t total_size;
+  std::atomic<size_t> transferred_size{0};
+  uint32_t request_id;
+  std::atomic<bool>* completed;
+  std::atomic<bool>* success;
+
+  PendingTransfer() = default;
+  PendingTransfer(size_t size, uint32_t req_id, std::atomic<bool>* comp,
+                  std::atomic<bool>* succ)
+      : total_size(size),
+        transferred_size(0),
+        request_id(req_id),
+        completed(comp),
+        success(succ) {}
+};
+
+class PendingRecvMap {
+ public:
+  void add(uint64_t dest_addr, size_t size, uint32_t request_id,
+           std::atomic<bool>* completed, std::atomic<bool>* success);
+  bool update_and_check_complete(uint32_t request_id, size_t chunk_size);
+
+ private:
+  mutable std::mutex mutex_;
+  std::unordered_map<uint32_t, std::unique_ptr<PendingTransfer>>
+      pending_recvs_;
+};
+
+class PendingSendMap {
+ public:
+  void add(size_t size, uint32_t request_id, std::atomic<bool>* completed,
+           std::atomic<bool>* success);
+  bool update_and_check_complete(uint32_t request_id, size_t chunk_size);
+
+ private:
+  mutable std::mutex mutex_;
+  std::unordered_map<uint32_t, std::unique_ptr<PendingTransfer>>
+      pending_sends_;
+};
+
+struct RecvMatchInfo {
+  uint64_t dest_addr;
+  size_t size;
+  uint32_t recv_request_id;
+  std::atomic<size_t> received{0};
+};
+
+class RecvMatchQueue {
+ public:
+  void push_recv(uint64_t dest_addr, size_t size, uint32_t recv_request_id);
+  bool get_recv_info(uint32_t send_seq_id, uint64_t* base_dest_addr,
+                     uint32_t* recv_request_id);
+  void add_received_bytes(uint32_t send_seq_id, size_t bytes);
+  uint32_t get_next_send_seq_id();
+
+ private:
+  mutable std::mutex mutex_;
+  std::deque<std::unique_ptr<RecvMatchInfo>> pending_recvs_;
+  std::unordered_map<uint32_t, std::unique_ptr<RecvMatchInfo>> in_progress_;
+  uint32_t next_seq_to_assign_{0};
+  std::atomic<uint32_t> next_send_seq_id_{0};
+};
+
+class TCPThreadPool;
+
+class TCPReceiverWorker {
+ public:
+  TCPReceiverWorker(uint32_t id, TCPThreadPool* thread_pool);
+  ~TCPReceiverWorker();
+
+  void start();
+  void stop();
+  bool add_data_connection(int fd);
+  void remove_data_connection(int fd);
+  uint32_t id() const { return worker_id_; }
+
+ private:
+  void worker_loop();
+  bool process_event(int fd);
+  void process_read_request(int fd, TCPDataHeader const& header);
+  void process_data_chunk(int fd, TCPDataHeader const& header);
+
+  uint32_t worker_id_;
+  std::atomic<bool> running_;
+  int epoll_fd_;
+  std::thread worker_thread_;
+  TCPThreadPool* thread_pool_;
+  char* staging_buffer_;
+  mutable std::mutex mutex_;
+  std::unordered_set<int> data_fds_;
+};
+
+class TCPSenderWorker {
+ public:
+  TCPSenderWorker(uint32_t id, TCPThreadPool* thread_pool);
+  ~TCPSenderWorker();
+
+  void start();
+  void stop();
+  bool submit_request(TCPRequest const& req);
+  uint32_t id() const { return worker_id_; }
+
+ private:
+  void worker_loop();
+  bool process_requests();
+  bool do_send(TCPRequest& req);
+  bool do_write(TCPRequest& req);
+  bool do_read(TCPRequest& req);
+
+  uint32_t worker_id_;
+  std::atomic<bool> running_;
+  jring_t* request_ring_;
+  char* staging_buffer_;
+  std::thread worker_thread_;
+  TCPThreadPool* thread_pool_;
+};
+
+struct TCPConnectionGroup {
+  int ctrl_fd = -1;
+  std::vector<std::unique_ptr<TCPConnection>> data_connections;
+  std::atomic<size_t> round_robin_idx{0};
+  mutable std::shared_mutex mutex;
+  RecvMatchQueue match_queue;
+
+  TCPConnection* select_data_connection();
+  void add_data_connection(std::unique_ptr<TCPConnection> conn);
+  size_t data_connection_count() const;
+};
+
+class TCPThreadPool {
+ public:
+  explicit TCPThreadPool(size_t num_threads = 0);
+  ~TCPThreadPool();
+
+  void start();
+  void stop();
+  uint32_t assign_data_connection(int fd, TCPConnection* conn,
+                                  RecvMatchQueue* match_queue);
+  bool submit_request(TCPRequest const& req);
+  void register_pending_recv(uint64_t dest_addr, size_t size,
+                             uint32_t request_id, std::atomic<bool>* completed,
+                             std::atomic<bool>* success);
+  void register_pending_send(size_t size, uint32_t request_id,
+                             std::atomic<bool>* completed,
+                             std::atomic<bool>* success);
+  RecvMatchQueue* get_match_queue(int fd);
+  PendingRecvMap* get_pending_recvs() { return &pending_recvs_; }
+  PendingSendMap* get_pending_sends() { return &pending_sends_; }
+
+ private:
+  PendingRecvMap pending_recvs_;
+  PendingSendMap pending_sends_;
+  std::vector<std::unique_ptr<TCPSenderWorker>> sender_workers_;
+  std::vector<std::unique_ptr<TCPReceiverWorker>> receiver_workers_;
+  std::atomic<uint32_t> next_sender_{0};
+  std::atomic<uint32_t> next_receiver_{0};
+  mutable std::shared_mutex fd_match_queue_mutex_;
+  std::unordered_map<int, RecvMatchQueue*> fd_to_match_queue_;
+};
+
+struct TCPAsyncHandle {
+  std::atomic<bool> completed{false};
+  std::atomic<bool> success{false};
+  uint32_t request_id{0};
+};
+
+}  // namespace tcp
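For orientation, the sketch below shows how a caller (e.g. the TCP endpoint added elsewhere in this patch) might drive the pool for a single blocking send. The function name and variables are invented for illustration, and chunking/sequencing of large transfers is omitted; the endpoint's real submission path may differ.

// Hypothetical caller-side flow; not the endpoint's actual code.
#include "tcp/tcp_worker_pool.h"
#include <thread>

bool send_blocking(tcp::TCPThreadPool& pool, tcp::TCPConnectionGroup& group,
                   void* buf, size_t len, uint32_t request_id) {
  tcp::TCPAsyncHandle handle;
  handle.request_id = request_id;

  tcp::TCPRequest req;
  req.type = tcp::TCPRequestType::SEND;
  req.data = buf;
  req.size = len;
  req.total_size = len;
  req.request_id = request_id;
  req.completed = &handle.completed;
  req.success = &handle.success;
  req.assigned_conn = group.select_data_connection();
  if (!req.assigned_conn) return false;

  // Register the expected byte total, then enqueue onto the sender worker
  // that owns the chosen data connection.
  pool.register_pending_send(req.size, req.request_id, &handle.completed,
                             &handle.success);
  if (!pool.submit_request(req)) return false;

  while (!handle.completed.load(std::memory_order_acquire)) {
    std::this_thread::yield();  // a real caller would poll asynchronously
  }
  return handle.success.load(std::memory_order_acquire);
}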