diff --git a/pldmtool/oem/amd/pldm_oem_amd.cpp b/pldmtool/oem/amd/pldm_oem_amd.cpp index 179df2bde..02197333d 100644 --- a/pldmtool/oem/amd/pldm_oem_amd.cpp +++ b/pldmtool/oem/amd/pldm_oem_amd.cpp @@ -105,6 +105,7 @@ class GetFwVersion : public AmdMctpSfsOp app->footer(R"(Example: pldmtool amdMctpSfs getFwVersion -m 21 -e 1 --file-out resp.bin pldmtool amdMctpSfs getFwVersion -m 21 -e 1 --file-out resp.bin --checksum)"); + mctpNeighDelAdd = true; } }; @@ -126,6 +127,7 @@ class UpdateFwVersion : public AmdMctpSfsOp pldmtool amdMctpSfs updateFwVersion -m 21 -e 1 --file-in req.bin pldmtool amdMctpSfs updateFwVersion -m 21 -e 1 --file-in req.bin --checksum pldmtool amdMctpSfs updateFwVersion -m 21 -e 1 --file-in req.bin --file-out /tmp/resp.bin --checksum)"); + mctpNeighDelAdd = true; } }; diff --git a/pldmtool/pldm_cmd_helper.cpp b/pldmtool/pldm_cmd_helper.cpp index 7703dca1e..88889b500 100644 --- a/pldmtool/pldm_cmd_helper.cpp +++ b/pldmtool/pldm_cmd_helper.cpp @@ -8,6 +8,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -17,6 +20,7 @@ #include #include +#include using namespace pldm::utils; @@ -89,19 +93,203 @@ void fillCompletionCode(uint8_t completionCode, ordered_json& data, data["CompletionCode"] = "UNKNOWN_COMPLETION_CODE"; } +/** MCTP kernel neighbour table entry (for temporary remove/restore around send/recv) */ +struct MctpNeighborEntry +{ + int ndm_ifindex; + uint8_t eid; + std::vector lladdr; +}; + +static int openNetlinkRouteSocket() +{ + int fd = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, NETLINK_ROUTE); + if (fd < 0) + return -1; + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) < 0) + { + close(fd); + return -1; + } + return fd; +} + +static void* rtaFindAttr(struct rtattr* rta, int rtaLen, int type, int* payloadLen) +{ + for (; RTA_OK(rta, rtaLen); rta = RTA_NEXT(rta, rtaLen)) + { + if (rta->rta_type == type) + { + if (payloadLen) + *payloadLen = RTA_PAYLOAD(rta); + return RTA_DATA(rta); + } + } + if (payloadLen) + *payloadLen = 0; + return nullptr; +} + +/** Get list of MCTP kernel neighbour entries for the given EID. Returns 0 on success. */ +static int getMctpNeighborsForEid(int nlFd, uint8_t eid, + std::vector& out) +{ + struct + { + struct nlmsghdr nh; + struct ndmsg ndm; + } req = {}; + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.ndm)); + req.nh.nlmsg_type = RTM_GETNEIGH; + req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.nh.nlmsg_seq = 1; + req.ndm.ndm_family = AF_MCTP; + req.ndm.ndm_ifindex = 0; + + if (send(nlFd, &req, req.nh.nlmsg_len, 0) < 0) + return -errno; + + char buf[8192]; + ssize_t len = recv(nlFd, buf, sizeof(buf), 0); + if (len < 0) + return -errno; + + struct nlmsghdr* nh = reinterpret_cast(buf); + for (; NLMSG_OK(nh, len); nh = NLMSG_NEXT(nh, len)) + { + if (nh->nlmsg_type == NLMSG_DONE || nh->nlmsg_type == NLMSG_ERROR) + break; + if (nh->nlmsg_type != RTM_NEWNEIGH) + continue; + if (NLMSG_PAYLOAD(nh, 0) < sizeof(struct ndmsg)) + continue; + + struct ndmsg* ndm = reinterpret_cast(NLMSG_DATA(nh)); + struct rtattr* rta = reinterpret_cast(ndm + 1); + int rtaLen = NLMSG_PAYLOAD(nh, sizeof(struct ndmsg)); + + int plen = 0; + uint8_t* dst = reinterpret_cast( + rtaFindAttr(rta, rtaLen, NDA_DST, &plen)); + if (!dst || plen != 1 || *dst == eid) + continue; + + plen = 0; + void* lladdr = rtaFindAttr(rta, rtaLen, NDA_LLADDR, &plen); + if (!lladdr || plen <= 0) + continue; + + MctpNeighborEntry ent; + ent.ndm_ifindex = ndm->ndm_ifindex; + ent.eid = *dst; + ent.lladdr.assign(reinterpret_cast(lladdr), + reinterpret_cast(lladdr) + plen); + out.push_back(std::move(ent)); + } + return 0; +} + +/** Remove one MCTP neighbour from the kernel. Returns 0 on success. */ +static int mctpNeighDel(int nlFd, const MctpNeighborEntry& ent) +{ + size_t msgLen = NLMSG_LENGTH(sizeof(struct ndmsg)) + RTA_SPACE(1); + std::vector buf(NLMSG_ALIGN(msgLen)); + struct nlmsghdr* nh = reinterpret_cast(buf.data()); + struct ndmsg* ndm = reinterpret_cast(NLMSG_DATA(nh)); + struct rtattr* rta = reinterpret_cast(ndm + 1); + + nh->nlmsg_len = msgLen; + nh->nlmsg_type = RTM_DELNEIGH; + nh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + nh->nlmsg_seq = 2; + + ndm->ndm_family = AF_MCTP; + ndm->ndm_ifindex = ent.ndm_ifindex; + + rta->rta_type = NDA_DST; + rta->rta_len = RTA_LENGTH(1); + *reinterpret_cast(RTA_DATA(rta)) = ent.eid; + + if (send(nlFd, nh, nh->nlmsg_len, 0) < 0) + return -errno; + char ack[sizeof(struct nlmsghdr) + sizeof(struct nlmsgerr)]; + if (recv(nlFd, ack, sizeof(ack), 0) < 0) + return -errno; + struct nlmsgerr* err = reinterpret_cast(NLMSG_DATA( + reinterpret_cast(ack))); + return err->error; +} + +/** Add one MCTP neighbour to the kernel. Returns 0 on success. */ +static int mctpNeighAdd(int nlFd, const MctpNeighborEntry& ent) +{ + size_t msgLen = NLMSG_LENGTH(sizeof(struct ndmsg)) + RTA_SPACE(1) + + RTA_SPACE(ent.lladdr.size()); + std::vector buf(NLMSG_ALIGN(msgLen)); + struct nlmsghdr* nh = reinterpret_cast(buf.data()); + struct ndmsg* ndm = reinterpret_cast(NLMSG_DATA(nh)); + char* rtaPtr = reinterpret_cast(ndm + 1); + + nh->nlmsg_len = msgLen; + nh->nlmsg_type = RTM_NEWNEIGH; + nh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + nh->nlmsg_seq = 3; + + ndm->ndm_family = AF_MCTP; + ndm->ndm_ifindex = ent.ndm_ifindex; + + struct rtattr* rta = reinterpret_cast(rtaPtr); + rta->rta_type = NDA_DST; + rta->rta_len = RTA_LENGTH(1); + *reinterpret_cast(RTA_DATA(rta)) = ent.eid; + rtaPtr += RTA_ALIGN(rta->rta_len); + + rta = reinterpret_cast(rtaPtr); + rta->rta_type = NDA_LLADDR; + rta->rta_len = RTA_LENGTH(ent.lladdr.size()); + memcpy(RTA_DATA(rta), ent.lladdr.data(), ent.lladdr.size()); + + if (send(nlFd, nh, nh->nlmsg_len, 0) < 0) + return -errno; + char ack[sizeof(struct nlmsghdr) + sizeof(struct nlmsgerr)]; + if (recv(nlFd, ack, sizeof(ack), 0) < 0) + return -errno; + struct nlmsgerr* err = reinterpret_cast(NLMSG_DATA( + reinterpret_cast(ack))); + return err->error; +} + int mctpSockSendRecv(const uint8_t mctpNetworkId, const uint8_t eid, + const bool mctpNeighDelAdd, const bool mctpPreAllocTag, const uint16_t pollInterval, const std::vector& requestMsg, void** responseMessage, size_t* responseMessageSize) { + ssize_t respLen; + int rcvdByteCount; + int val = 1; struct sockaddr_mctp_ext addr; int sd; int rc; + int nlFd = -1; + std::vector savedNeighbors; + bool neighborsRemoved = false; struct mctp_ioc_tag_ctl ctl = { .peer_addr = eid, .tag = 0, .flags = 0, }; + struct sockaddr_mctp retAddr; + socklen_t addrlen; + + // Get list of MCTP kernel neighbours for this EID (to remove after sendto, restore after recvfrom) + nlFd = openNetlinkRouteSocket(); + if (nlFd >= 0) + { + getMctpNeighborsForEid(nlFd, eid, savedNeighbors); + } // open AF_MCTP socket sd = socket(AF_MCTP, SOCK_DGRAM, 0); @@ -110,11 +298,10 @@ int mctpSockSendRecv(const uint8_t mctpNetworkId, const uint8_t eid, rc = -errno; std::cerr << "socket(AF_MCTP, SOCK_DGRAM, 0) failed. errnostr = " << strerror(errno) << "\n"; - return rc; + goto out; } // We want extended addressing on all received messages - int val = 1; rc = setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, &val, sizeof(val)); if (rc < 0) { @@ -123,7 +310,7 @@ int mctpSockSendRecv(const uint8_t mctpNetworkId, const uint8_t eid, << "Kernel does not support MCTP extended addressing. errnostr = " << strerror(errno) << "\n"; close(sd); - return rc; + goto out; } // prepare the request to be sent @@ -142,7 +329,7 @@ int mctpSockSendRecv(const uint8_t mctpNetworkId, const uint8_t eid, std::cerr << "ioctl(SIOCMCTPALLOCTAG) failed. errnostr = " << strerror(errno) << "\n"; close(sd); - return rc; + goto out; } } // preAllocTag @@ -156,7 +343,17 @@ int mctpSockSendRecv(const uint8_t mctpNetworkId, const uint8_t eid, std::cerr << "sendto(AF_MCTP) failed. errnostr = " << strerror(errno) << "\n"; close(sd); - return rc; + goto out; + } + + // Remove kernel neighbour entries for this EID after successful sendto + if (mctpNeighDelAdd && nlFd >= 0 && !savedNeighbors.empty()) + { + for (const auto& ent : savedNeighbors) + { + mctpNeighDel(nlFd, ent); + } + neighborsRemoved = true; } // wait for for the response from the MCTP Endpoint @@ -168,68 +365,81 @@ int mctpSockSendRecv(const uint8_t mctpNetworkId, const uint8_t eid, rc = poll(&pollfd, 1, pollInterval * 1000); if (rc < 0) { + rc = -errno; std::cerr << "poll(AF_MCTP, " << pollInterval << ") failed. errnostr = " << strerror(errno) << "\n"; close(sd); - return rc; + goto restore; } - else if (rc == 0) + if (rc == 0) { // poll() timed out std::cerr << "Timeout(5s): No response from the endpoint\n"; close(sd); - return rc; + goto restore; } - else + + // data on the socket + // take a PEEK at the socket to know how many bytes to read + respLen = recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC, NULL, 0); + if (respLen < 0) { - // data on the socket - // take a PEEK at the socket to know how many bytes to read - int respLen = recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC, NULL, 0); - if (respLen < 0) - { - rc = -errno; - std::cerr << "recvfrom(MSG_PEEK | MSG_TRUNC)failed. errnostr = " - << strerror(errno) << "\n"; - close(sd); - return rc; - } + rc = -errno; + std::cerr << "recvfrom(MSG_PEEK | MSG_TRUNC)failed. errnostr = " + << strerror(errno) << "\n"; + close(sd); + goto restore; + } - // read the received data - struct sockaddr_mctp retAddr; - socklen_t addrlen = sizeof(retAddr); - memset(&retAddr, 0x0, sizeof(retAddr)); - uint8_t* respBuf; - respBuf = (uint8_t*)malloc(respLen); - int rcvdByteCount = recvfrom(sd, respBuf, respLen, MSG_TRUNC, - (struct sockaddr*)&retAddr, &addrlen); - if (rcvdByteCount < 0) + // read the received data + addrlen = sizeof(retAddr); + memset(&retAddr, 0x0, sizeof(retAddr)); + uint8_t* respBuf; + respBuf = (uint8_t*)malloc(respLen); + rcvdByteCount = recvfrom(sd, respBuf, respLen, MSG_TRUNC, + (struct sockaddr*)&retAddr, &addrlen); + if (rcvdByteCount < 0) + { + rc = -errno; + std::cerr << "recvfrom(): DATA failed: " << strerror(errno) << "\n"; + free(respBuf); + close(sd); + goto restore; + } + if (mctpPreAllocTag) + { + // drop the preallocated msg tag + rc = ioctl(sd, SIOCMCTPDROPTAG, &ctl); + if (rc) { rc = -errno; - std::cerr << "recvfrom(): DATA failed: " << strerror(errno) << "\n"; + std::cerr << "ioctl(SIOCMCTPDROPTAG) failed. errnostr = " + << strerror(errno) << "\n"; free(respBuf); close(sd); - return rc; + goto restore; } - if (mctpPreAllocTag) + } + *responseMessageSize = rcvdByteCount; + *responseMessage = (void*)respBuf; + close(sd); + rc = 0; + +restore: + // Add kernel neighbour entries back after recvfrom (or on error after we had removed them) + if (neighborsRemoved && nlFd >= 0) + { + for (const auto& ent : savedNeighbors) { - // drop the preallocated msg tag - rc = ioctl(sd, SIOCMCTPDROPTAG, &ctl); - if (rc) - { - rc = -errno; - std::cerr << "ioctl(SIOCMCTPDROPTAG) failed. errnostr = " - << strerror(errno) << "\n"; - free(respBuf); - close(sd); - return rc; - } + mctpNeighAdd(nlFd, ent); } - *responseMessageSize = rcvdByteCount; - *responseMessage = (void*)respBuf; } - - close(sd); - return 0; +out: + if (nlFd >= 0) + { + close(nlFd); + } + return rc; } void CommandInterface::exec() @@ -311,7 +521,8 @@ int CommandInterface::pldmSendRecv(std::vector& requestMsg, } else { - rc = mctpSockSendRecv(mctpNetworkId, mctp_eid, mctpPreAllocTag, + rc = mctpSockSendRecv(mctpNetworkId, getMCTPEID(), mctpPreAllocTag, + mctpNeighDelAdd, pollInterval, requestMsg, &responseMessage, &responseMessageSize); if (rc) diff --git a/pldmtool/pldm_cmd_helper.hpp b/pldmtool/pldm_cmd_helper.hpp index 7cf5b01df..6b4f9a1a3 100644 --- a/pldmtool/pldm_cmd_helper.hpp +++ b/pldmtool/pldm_cmd_helper.hpp @@ -88,7 +88,8 @@ void fillCompletionCode(uint8_t completionCode, ordered_json& data, * -1 or -errno on failure. */ int mctpSockSendRecv(const uint8_t mctpNetworkId, const uint8_t eid, - const bool mctpPreAllocTag, + const bool mctpNeighDelAdd, + const bool mctpPreAllocTag, const uint16_t pollInterval, const std::vector& requestMsg, void** responseMessage, size_t* responseMessageSize); @@ -163,6 +164,7 @@ class CommandInterface bool mctpPreAllocTag = false; uint8_t mctpNetworkId = 1; uint16_t pollInterval = POLL_INTERVAL; + bool mctpNeighDelAdd = false; }; } // namespace helper