From a92e93a14b52b422844d6daed2459e0b265a3711 Mon Sep 17 00:00:00 2001 From: feng zhao Date: Sat, 5 Jul 2025 06:23:39 +0000 Subject: [PATCH] Fix done message parse Though the document of kernel (https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types) specify the format of netlink message with NLMSG_DONE, it also says that "Note that some implementations may issue custom NLMSG_DONE messages in reply to do action requests. In that case the payload is implementation-specific and may also be absent.". After searching the source code of kernel, we can find that 1. the format specified in the document is obeyed by most generic netlink but some generic netlink like this (https://elixir.bootlin.com/linux/v6.15/source/drivers/net/team/team_core.c#L2494) has no payload, so as a generic lib, we should not suppose the format of DoneMessage. it's sensible to just save the payload. 2. when use NLMSG_DONE as an end of multi messages, there will always be a NLM_F_MULTIPART in the flag and only in this case should we parse it as a DoneMessage, in other occassion like connector netlink (https://elixir.bootlin.com/linux/v6.15/source/drivers/connector/connector.c#L101), we should parse it as a common message. Signed-off-by: feng zhao --- src/done.rs | 117 ++++++++++++++----------------------------------- src/message.rs | 29 +++++------- 2 files changed, 44 insertions(+), 102 deletions(-) diff --git a/src/done.rs b/src/done.rs index 1cba4d6..c8b5e06 100644 --- a/src/done.rs +++ b/src/done.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: MIT -use std::mem::size_of; - use byteorder::{ByteOrder, NativeEndian}; use netlink_packet_utils::DecodeError; @@ -11,99 +9,52 @@ const CODE: Field = 0..4; const EXTENDED_ACK: Rest = 4..; const DONE_HEADER_LEN: usize = EXTENDED_ACK.start; -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] #[non_exhaustive] -pub struct DoneBuffer { - buffer: T, +pub struct DoneMessage { + pub payload: Vec, } -impl> DoneBuffer { - pub fn new(buffer: T) -> DoneBuffer { - DoneBuffer { buffer } - } - - /// Consume the packet, returning the underlying buffer. - pub fn into_inner(self) -> T { - self.buffer - } - - pub fn new_checked(buffer: T) -> Result { - let packet = Self::new(buffer); - packet.check_buffer_length()?; - Ok(packet) - } - - fn check_buffer_length(&self) -> Result<(), DecodeError> { - let len = self.buffer.as_ref().len(); - if len < DONE_HEADER_LEN { - Err(format!( - "invalid DoneBuffer: length is {len} but DoneBuffer are \ - at least {DONE_HEADER_LEN} bytes" - ) - .into()) +impl DoneMessage { + pub fn code(&self) -> Option { + if self.payload.len() < DONE_HEADER_LEN { + None } else { - Ok(()) + Some(NativeEndian::read_i32(&self.payload[CODE])) } } - /// Return the error code - pub fn code(&self) -> i32 { - let data = self.buffer.as_ref(); - NativeEndian::read_i32(&data[CODE]) - } -} - -impl<'a, T: AsRef<[u8]> + ?Sized> DoneBuffer<&'a T> { - /// Return a pointer to the extended ack attributes. - pub fn extended_ack(&self) -> &'a [u8] { - let data = self.buffer.as_ref(); - &data[EXTENDED_ACK] - } -} - -impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> DoneBuffer<&'a mut T> { - /// Return a mutable pointer to the extended ack attributes. - pub fn extended_ack_mut(&mut self) -> &mut [u8] { - let data = self.buffer.as_mut(); - &mut data[EXTENDED_ACK] + pub fn extended_ack(&self) -> Option<&[u8]> { + if self.payload.len() < DONE_HEADER_LEN { + None + } else { + Some(&self.payload[EXTENDED_ACK]) + } } -} -impl + AsMut<[u8]>> DoneBuffer { - /// set the error code field - pub fn set_code(&mut self, value: i32) { - let data = self.buffer.as_mut(); - NativeEndian::write_i32(&mut data[CODE], value) + pub fn new_with_code>(code: i32, extend_ack: &T) -> Self { + let mut payload = vec![0; DONE_HEADER_LEN + extend_ack.as_ref().len()]; + NativeEndian::write_i32(&mut payload, code); + payload[CODE.end..].copy_from_slice(extend_ack.as_ref()); + Self { payload } } } -#[derive(Debug, Default, Clone, PartialEq, Eq)] -#[non_exhaustive] -pub struct DoneMessage { - pub code: i32, - pub extended_ack: Vec, -} - impl Emitable for DoneMessage { fn buffer_len(&self) -> usize { - size_of::() + self.extended_ack.len() + self.payload.len() } fn emit(&self, buffer: &mut [u8]) { - let mut buffer = DoneBuffer::new(buffer); - buffer.set_code(self.code); - buffer - .extended_ack_mut() - .copy_from_slice(&self.extended_ack); + buffer.copy_from_slice(&self.payload); } } -impl> Parseable> for DoneMessage { +impl> Parseable for DoneMessage { type Error = DecodeError; - fn parse(buf: &DoneBuffer<&T>) -> Result { + fn parse(buf: &T) -> Result { Ok(DoneMessage { - code: buf.code(), - extended_ack: buf.extended_ack().to_vec(), + payload: buf.as_ref().to_vec(), }) } } @@ -114,22 +65,18 @@ mod tests { #[test] fn serialize_and_parse() { - let expected = DoneMessage { - code: 5, - extended_ack: vec![1, 2, 3], - }; - + let expected = DoneMessage::new_with_code(5, &[1, 2, 3]); let len = expected.buffer_len(); - assert_eq!(len, size_of::() + expected.extended_ack.len()); + assert_eq!( + len, + size_of::() + expected.extended_ack().unwrap().len() + ); let mut buf = vec![0; len]; expected.emit(&mut buf); - let done_buf = DoneBuffer::new(&buf); - assert_eq!(done_buf.code(), expected.code); - assert_eq!(done_buf.extended_ack(), &expected.extended_ack); - - let got = DoneMessage::parse(&done_buf).unwrap(); - assert_eq!(got, expected); + let got = DoneMessage::parse(&buf); + assert!(got.is_ok()); + assert_eq!(got.unwrap(), expected); } } diff --git a/src/message.rs b/src/message.rs index b3b0541..f6cbe03 100644 --- a/src/message.rs +++ b/src/message.rs @@ -6,9 +6,9 @@ use netlink_packet_utils::DecodeError; use crate::{ payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN}, - DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, - NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload, - NetlinkSerializable, Parseable, + DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer, + NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, + Parseable, NLM_F_MULTIPART, }; /// Represent a netlink message. @@ -103,10 +103,11 @@ where Error(msg) } NLMSG_NOOP => Noop, - NLMSG_DONE => { - let msg = DoneBuffer::new_checked(&bytes) - .and_then(|buf| DoneMessage::parse(&buf))?; - Done(msg) + // only parse message_type of NLMSG_DONE when flag has + // NLM_F_MULTIPART because some special netlink like + // connector use NLMSG_DONE for all the message + NLMSG_DONE if header.flags & NLM_F_MULTIPART == NLM_F_MULTIPART => { + Done(DoneMessage::parse(&bytes)?) } NLMSG_OVERRUN => Overrun(bytes.to_vec()), message_type => match I::deserialize(&header, bytes) { @@ -205,11 +206,9 @@ mod tests { #[test] fn test_done() { - let header = NetlinkHeader::default(); - let done_msg = DoneMessage { - code: 0, - extended_ack: vec![6, 7, 8, 9], - }; + let mut header = NetlinkHeader::default(); + header.flags |= NLM_F_MULTIPART; + let done_msg = DoneMessage::new_with_code(0, &[6, 7, 8, 9]); let mut want = NetlinkMessage::new( header, NetlinkPayload::::Done(done_msg.clone()), @@ -221,16 +220,12 @@ mod tests { len, header.buffer_len() + size_of::() - + done_msg.extended_ack.len() + + done_msg.extended_ack().unwrap().len() ); let mut buf = vec![1; len]; want.emit(&mut buf); - let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]); - assert_eq!(done_buf.code(), done_msg.code); - assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack); - let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap(); assert_eq!(got, want); }