From 1e487105265b5ea500483fb8ef9132705ad96a10 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Wed, 8 Apr 2026 21:25:10 -0700 Subject: [PATCH 01/21] Migrate to draft-16 protocol messages and parameters --- moq-transport/src/coding/kvp.rs | 217 ++++++++++--- moq-transport/src/coding/mod.rs | 2 + moq-transport/src/coding/track_extensions.rs | 196 ++++++++++++ moq-transport/src/message/dynamic_groups.rs | 131 ++++++++ moq-transport/src/message/fetch_ok.rs | 9 + moq-transport/src/message/mod.rs | 59 ++-- moq-transport/src/message/namespace.rs | 61 ++++ moq-transport/src/message/parameters.rs | 47 +++ moq-transport/src/message/publish.rs | 111 ++----- .../src/message/publish_namespace_done.rs | 41 +++ moq-transport/src/message/publish_ok.rs | 172 +--------- moq-transport/src/message/publisher.rs | 8 +- moq-transport/src/message/request_error.rs | 87 +++++ moq-transport/src/message/request_ok.rs | 45 +++ moq-transport/src/message/subscribe.rs | 154 +-------- .../src/message/subscribe_namespace.rs | 23 +- moq-transport/src/message/subscribe_ok.rs | 78 ++--- moq-transport/src/message/subscribe_update.rs | 76 ++--- moq-transport/src/message/subscriber.rs | 6 +- moq-transport/src/setup/auth_token.rs | 298 ++++++++++++++++++ moq-transport/src/setup/client.rs | 31 +- moq-transport/src/setup/mod.rs | 4 +- moq-transport/src/setup/param_types.rs | 4 + moq-transport/src/setup/server.rs | 28 +- moq-transport/src/setup/version.rs | 6 + 25 files changed, 1269 insertions(+), 625 deletions(-) create mode 100644 moq-transport/src/coding/track_extensions.rs create mode 100644 moq-transport/src/message/dynamic_groups.rs create mode 100644 moq-transport/src/message/namespace.rs create mode 100644 moq-transport/src/message/parameters.rs create mode 100644 moq-transport/src/message/publish_namespace_done.rs create mode 100644 moq-transport/src/message/request_error.rs create mode 100644 moq-transport/src/message/request_ok.rs create mode 100644 moq-transport/src/setup/auth_token.rs diff 
--git a/moq-transport/src/coding/kvp.rs b/moq-transport/src/coding/kvp.rs index 2ed9caa9..065f5d39 100644 --- a/moq-transport/src/coding/kvp.rs +++ b/moq-transport/src/coding/kvp.rs @@ -48,13 +48,46 @@ impl KeyValuePair { value: Value::BytesValue(value), } } -} -impl Decode for KeyValuePair { - fn decode(r: &mut R) -> Result { - let key = u64::decode(r)?; + /// Validate that the key parity matches the value type. + /// Even keys => IntValue, Odd keys => BytesValue. + fn validate_key_parity(&self) -> Result<(), EncodeError> { + match &self.value { + Value::IntValue(_) => { + if !self.key.is_multiple_of(2) { + return Err(EncodeError::InvalidValue); + } + } + Value::BytesValue(_) => { + if self.key.is_multiple_of(2) { + return Err(EncodeError::InvalidValue); + } + } + } + Ok(()) + } - if key % 2 == 0 { + /// Encode only the value portion of this KVP (not the key/delta). + /// The caller is responsible for encoding the key or delta type. + pub(crate) fn encode_value(&self, w: &mut W) -> Result<(), EncodeError> { + self.validate_key_parity()?; + match &self.value { + Value::IntValue(v) => { + (*v).encode(w)?; + } + Value::BytesValue(v) => { + v.len().encode(w)?; + Self::encode_remaining(w, v.len())?; + w.put_slice(v); + } + } + Ok(()) + } + + /// Decode only the value portion of a KVP given the absolute key. + /// The caller has already decoded the key/delta and resolved the absolute key. + pub(crate) fn decode_value(key: u64, r: &mut R) -> Result { + if key.is_multiple_of(2) { // VarInt variant let value = u64::decode(r)?; log::trace!("[KVP] Decoded even key={}, value={}", key, value); @@ -81,30 +114,22 @@ impl Decode for KeyValuePair { } } +/// Legacy Decode for KeyValuePair — reads absolute key from wire. +/// Used only by ExtensionHeaders which reads KVPs from a bounded byte slice. 
+impl Decode for KeyValuePair { + fn decode(r: &mut R) -> Result { + let key = u64::decode(r)?; + Self::decode_value(key, r) + } +} + +/// Legacy Encode for KeyValuePair — writes absolute key to wire. +/// Used only by ExtensionHeaders which writes KVPs into a temporary buffer. impl Encode for KeyValuePair { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - match &self.value { - Value::IntValue(v) => { - // key must be even for IntValue - if !self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); - } - self.key.encode(w)?; - (*v).encode(w)?; - Ok(()) - } - Value::BytesValue(v) => { - // key must be odd for BytesValue - if self.key.is_multiple_of(2) { - return Err(EncodeError::InvalidValue); - } - self.key.encode(w)?; - v.len().encode(w)?; - Self::encode_remaining(w, v.len())?; - w.put_slice(v); - Ok(()) - } - } + self.validate_key_parity()?; + self.key.encode(w)?; + self.encode_value(w) } } @@ -116,7 +141,10 @@ impl fmt::Debug for KeyValuePair { /// A collection of KeyValuePair entries, where the number of key-value-pairs are encoded/decoded first. /// This structure is appropriate for Control message parameters. -/// Since duplicate parameters are allowed for unknown parameters, we don't do duplicate checking here. +/// +/// Per draft-16 Section 1.4.2, Key-Value-Pairs use delta-encoded Type fields: +/// each Type is encoded as a delta from the previous Type (or from 0 for the first). +/// Entries are sorted by key (Type) in ascending order for encoding. 
#[derive(Default, Clone, Eq, PartialEq)] pub struct KeyValuePairs(pub Vec); @@ -150,16 +178,49 @@ impl KeyValuePairs { pub fn get(&self, key: u64) -> Option<&KeyValuePair> { self.0.iter().find(|k| k.key == key) } + + /// Get an integer value by key, returning None if not found or if the value is not an integer + pub fn get_intvalue(&self, key: u64) -> Option { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(v) => Some(*v), + Value::BytesValue(_) => None, + }) + } + + /// Get a bytes value by key, returning None if not found or if the value is not bytes + pub fn get_bytesvalue(&self, key: u64) -> Option<&Vec> { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(_) => None, + Value::BytesValue(v) => Some(v), + }) + } } impl Decode for KeyValuePairs { - fn decode(mut r: &mut R) -> Result { + /// Decode Key-Value-Pairs with delta-encoded Type fields (draft-16 Section 1.4.2). + fn decode(r: &mut R) -> Result { let mut kvps = Vec::new(); let count = u64::decode(r)?; + let mut prev_key: u64 = 0; + for _ in 0..count { - let kvp = KeyValuePair::decode(&mut r)?; + // Read delta type + let delta = u64::decode(r)?; + + // Reconstruct absolute key: prev_key + delta + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[KVP] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; kvps.push(kvp); + prev_key = key; } Ok(KeyValuePairs(kvps)) @@ -167,11 +228,32 @@ impl Decode for KeyValuePairs { } impl Encode for KeyValuePairs { + /// Encode Key-Value-Pairs with delta-encoded Type fields (draft-16 Section 1.4.2). + /// Entries are sorted by key in ascending order before encoding. 
fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.0.len().encode(w)?; - for kvp in &self.0 { - kvp.encode(w)?; + // Sort by key for delta encoding (Types must be in ascending order) + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + let mut prev_key: u64 = 0; + for kvp in sorted { + // Compute and encode the delta + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[KVP] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(w)?; + + // Encode the value (without the key) + kvp.encode_value(w)?; + + prev_key = kvp.key; } Ok(()) @@ -243,9 +325,10 @@ mod tests { } #[test] - fn encode_decode_keyvaluepairs() { + fn encode_decode_keyvaluepairs_single() { let mut buf = BytesMut::new(); + // Single entry: key=1 (odd, bytes). Delta from 0 = 1. let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); @@ -253,21 +336,79 @@ mod tests { buf.to_vec(), vec![ 0x01, // 1 KeyValuePair - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + // Delta=1 (from 0), then length=5, then data + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, ] ); let decoded = KeyValuePairs::decode(&mut buf).unwrap(); assert_eq!(decoded, kvps); + } + #[test] + fn encode_decode_keyvaluepairs_multiple() { + let mut buf = BytesMut::new(); + + // Multiple entries inserted out of order — encoding should sort by key. 
+ // Keys: 0 (even, int), 1 (odd, bytes), 100 (even, int) let mut kvps = KeyValuePairs::new(); kvps.set_intvalue(0, 0); kvps.set_intvalue(100, 100); kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); kvps.encode(&mut buf).unwrap(); - let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair count - assert_eq!(14, buf_vec.len()); // 14 bytes total - assert_eq!(3, buf_vec[0]); // 3 KeyValuePairs + + #[rustfmt::skip] + let expected = vec![ + 0x03, // 3 KeyValuePairs + // Entry 1: key=0 (delta=0 from 0), even, int value=0 + 0x00, 0x00, + // Entry 2: key=1 (delta=1 from 0), odd, bytes len=5 + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, + // Entry 3: key=100 (delta=99 from 1), even, int value=100 + 0x40, 0x63, 0x40, 0x64, + ]; + assert_eq!(buf.to_vec(), expected); + + // Decode and verify — decoded entries will be in sorted order + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + // Build expected sorted kvps for comparison + let mut expected_kvps = KeyValuePairs::new(); + expected_kvps.set_intvalue(0, 0); + expected_kvps.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + expected_kvps.set_intvalue(100, 100); + assert_eq!(decoded, expected_kvps); + } + + #[test] + fn encode_decode_keyvaluepairs_roundtrip_sorted() { + let mut buf = BytesMut::new(); + + // Insert in sorted order — should roundtrip exactly + let mut kvps = KeyValuePairs::new(); + kvps.set_intvalue(2, 42); + kvps.set_intvalue(4, 100); + kvps.encode(&mut buf).unwrap(); + + #[rustfmt::skip] + let expected = vec![ + 0x02, // 2 KeyValuePairs + // Entry 1: key=2 (delta=2), int value=42 + 0x02, 0x2a, + // Entry 2: key=4 (delta=2 from 2), int value=100 + 0x02, 0x40, 0x64, + ]; + assert_eq!(buf.to_vec(), expected); + + let decoded = KeyValuePairs::decode(&mut buf).unwrap(); + assert_eq!(decoded, kvps); + } + + #[test] + fn encode_decode_keyvaluepairs_empty() { + let mut buf = BytesMut::new(); + + let kvps = KeyValuePairs::new(); + kvps.encode(&mut buf).unwrap(); + 
assert_eq!(buf.to_vec(), vec![0x00]); // count=0 let decoded = KeyValuePairs::decode(&mut buf).unwrap(); assert_eq!(decoded, kvps); } diff --git a/moq-transport/src/coding/mod.rs b/moq-transport/src/coding/mod.rs index 13e97be1..71cc4148 100644 --- a/moq-transport/src/coding/mod.rs +++ b/moq-transport/src/coding/mod.rs @@ -6,6 +6,7 @@ mod integer; mod kvp; mod location; mod string; +mod track_extensions; mod track_namespace; mod tuple; mod varint; @@ -16,6 +17,7 @@ pub use encode::*; pub use hex_dump::*; pub use kvp::*; pub use location::*; +pub use track_extensions::*; pub use track_namespace::*; pub use tuple::*; pub use varint::*; diff --git a/moq-transport/src/coding/track_extensions.rs b/moq-transport/src/coding/track_extensions.rs new file mode 100644 index 00000000..3da50c55 --- /dev/null +++ b/moq-transport/src/coding/track_extensions.rs @@ -0,0 +1,196 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePair, Value}; +use std::fmt; + +/// A collection of KeyValuePair entries for Track Extensions. +/// Per draft-16 Section 9.10, Track Extensions are encoded WITHOUT a count or length prefix. +/// They are simply a sequence of delta-encoded key-value pairs until end of message. +/// +/// This differs from: +/// - KeyValuePairs: has a count prefix +/// - ExtensionHeaders: has a byte-length prefix +#[derive(Default, Clone, Eq, PartialEq)] +pub struct TrackExtensions(pub Vec); + +impl TrackExtensions { + pub fn new() -> Self { + Self::default() + } + + /// Insert or replace a KeyValuePair with the same key. 
+ pub fn set(&mut self, kvp: KeyValuePair) { + if let Some(existing) = self.0.iter_mut().find(|k| k.key == kvp.key) { + *existing = kvp; + } else { + self.0.push(kvp); + } + } + + pub fn set_intvalue(&mut self, key: u64, value: u64) { + self.set(KeyValuePair::new_int(key, value)); + } + + pub fn set_bytesvalue(&mut self, key: u64, value: Vec) { + self.set(KeyValuePair::new_bytes(key, value)); + } + + pub fn has(&self, key: u64) -> bool { + self.0.iter().any(|k| k.key == key) + } + + pub fn get(&self, key: u64) -> Option<&KeyValuePair> { + self.0.iter().find(|k| k.key == key) + } + + /// Get an integer value by key, returning None if not found or if the value is not an integer + pub fn get_intvalue(&self, key: u64) -> Option { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(v) => Some(*v), + Value::BytesValue(_) => None, + }) + } + + /// Get a bytes value by key, returning None if not found or if the value is not bytes + pub fn get_bytesvalue(&self, key: u64) -> Option<&Vec> { + self.get(key).and_then(|kvp| match &kvp.value { + Value::IntValue(_) => None, + Value::BytesValue(v) => Some(v), + }) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Decode for TrackExtensions { + /// Decode Track Extensions - reads delta-encoded key-value pairs until end of buffer. + /// Per draft-16, Track Extensions have NO count or length prefix. 
+ fn decode(r: &mut R) -> Result { + let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + + // Read until buffer is exhausted + while r.has_remaining() { + // Read delta type + let delta = u64::decode(r)?; + + // Reconstruct absolute key: prev_key + delta + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[TrackExt] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; + kvps.push(kvp); + prev_key = key; + } + + Ok(TrackExtensions(kvps)) + } +} + +impl Encode for TrackExtensions { + /// Encode Track Extensions - writes delta-encoded key-value pairs WITHOUT any prefix. + /// Entries are sorted by key in ascending order before encoding. + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + // Sort by key for delta encoding (Types must be in ascending order) + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + let mut prev_key: u64 = 0; + for kvp in sorted { + // Compute and encode the delta + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[TrackExt] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(w)?; + + // Encode the value (without the key) + kvp.encode_value(w)?; + + prev_key = kvp.key; + } + + Ok(()) + } +} + +impl fmt::Debug for TrackExtensions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{{ ")?; + for (i, kv) in self.0.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", kv)?; + } + write!(f, " }}") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode_empty() { + let mut buf = BytesMut::new(); + + let ext = TrackExtensions::new(); + ext.encode(&mut buf).unwrap(); + // Empty TrackExtensions produces NO bytes (no prefix!) 
+ let expected: Vec = vec![]; + assert_eq!(buf.to_vec(), expected); + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } + + #[test] + fn encode_decode_single() { + let mut buf = BytesMut::new(); + + let mut ext = TrackExtensions::new(); + ext.set_intvalue(2, 42); // key=2 (even), value=42 + ext.encode(&mut buf).unwrap(); + + // Expected: delta=2, value=42 (no count or length prefix!) + assert_eq!(buf.to_vec(), vec![0x02, 0x2a]); + + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } + + #[test] + fn encode_decode_multiple() { + let mut buf = BytesMut::new(); + + let mut ext = TrackExtensions::new(); + ext.set_intvalue(0, 0); + ext.set_intvalue(2, 100); + ext.encode(&mut buf).unwrap(); + + // Expected: + // Entry 1: delta=0, value=0 + // Entry 2: delta=2 (from 0), value=100 + // No count prefix! + #[rustfmt::skip] + let expected = vec![ + 0x00, 0x00, // delta=0, value=0 + 0x02, 0x40, 0x64, // delta=2, value=100 (varint) + ]; + assert_eq!(buf.to_vec(), expected); + + let decoded = TrackExtensions::decode(&mut buf).unwrap(); + assert_eq!(decoded, ext); + } +} diff --git a/moq-transport/src/message/dynamic_groups.rs b/moq-transport/src/message/dynamic_groups.rs new file mode 100644 index 00000000..4a5b7708 --- /dev/null +++ b/moq-transport/src/message/dynamic_groups.rs @@ -0,0 +1,131 @@ +//! Dynamic Groups support for MOQT. +//! +//! This module provides helper functions for working with Dynamic Groups parameters +//! as defined in the MOQT specification. Dynamic Groups allow subscribers to request +//! publishers to create new groups on demand. + +use crate::coding::KeyValuePairs; +use crate::message::ParameterType; + +/// Helper trait for Dynamic Groups parameter operations on KeyValuePairs. 
+pub trait DynamicGroupsExt { + /// Check if dynamic groups are enabled/supported + fn has_dynamic_groups(&self) -> bool; + + /// Get the dynamic groups value (if present) + fn get_dynamic_groups(&self) -> Option; + + /// Enable dynamic groups support + fn set_dynamic_groups(&mut self, value: u64); + + /// Check if a new group request is present + fn has_new_group_request(&self) -> bool; + + /// Get the new group request value (if present) + fn get_new_group_request(&self) -> Option; + + /// Request a new group from the publisher + fn set_new_group_request(&mut self, value: u64); +} + +impl DynamicGroupsExt for KeyValuePairs { + fn has_dynamic_groups(&self) -> bool { + self.has(ParameterType::DynamicGroups.into()) + } + + fn get_dynamic_groups(&self) -> Option { + self.get_intvalue(ParameterType::DynamicGroups.into()) + } + + fn set_dynamic_groups(&mut self, value: u64) { + self.set_intvalue(ParameterType::DynamicGroups.into(), value); + } + + fn has_new_group_request(&self) -> bool { + self.has(ParameterType::NewGroupRequest.into()) + } + + fn get_new_group_request(&self) -> Option { + self.get_intvalue(ParameterType::NewGroupRequest.into()) + } + + fn set_new_group_request(&mut self, value: u64) { + self.set_intvalue(ParameterType::NewGroupRequest.into(), value); + } +} + +/// Dynamic Groups configuration for a track +#[derive(Clone, Debug, Default)] +pub struct DynamicGroupsConfig { + /// Whether dynamic groups are enabled for this track + pub enabled: bool, + /// The current pending new group request (if any) + pub pending_request: Option, +} + +impl DynamicGroupsConfig { + /// Create a new configuration with dynamic groups disabled + pub fn new() -> Self { + Self::default() + } + + /// Create a new configuration with dynamic groups enabled + pub fn enabled() -> Self { + Self { + enabled: true, + pending_request: None, + } + } + + /// Request a new group with the given request ID + pub fn request_new_group(&mut self, request_id: u64) { + self.pending_request = 
Some(request_id); + } + + /// Clear the pending request (after it has been processed) + pub fn clear_pending_request(&mut self) { + self.pending_request = None; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dynamic_groups_ext() { + let mut params = KeyValuePairs::new(); + + // Initially no dynamic groups + assert!(!params.has_dynamic_groups()); + assert_eq!(params.get_dynamic_groups(), None); + + // Enable dynamic groups + params.set_dynamic_groups(1); + assert!(params.has_dynamic_groups()); + assert_eq!(params.get_dynamic_groups(), Some(1)); + + // New group request + assert!(!params.has_new_group_request()); + params.set_new_group_request(42); + assert!(params.has_new_group_request()); + assert_eq!(params.get_new_group_request(), Some(42)); + } + + #[test] + fn test_dynamic_groups_config() { + let config = DynamicGroupsConfig::new(); + assert!(!config.enabled); + assert!(config.pending_request.is_none()); + + let config = DynamicGroupsConfig::enabled(); + assert!(config.enabled); + + let mut config = DynamicGroupsConfig::enabled(); + config.request_new_group(123); + assert_eq!(config.pending_request, Some(123)); + + config.clear_pending_request(); + assert!(config.pending_request.is_none()); + } +} diff --git a/moq-transport/src/message/fetch_ok.rs b/moq-transport/src/message/fetch_ok.rs index c8db3c04..f3b721ed 100644 --- a/moq-transport/src/message/fetch_ok.rs +++ b/moq-transport/src/message/fetch_ok.rs @@ -1,4 +1,5 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; +use crate::data::ExtensionHeaders; use crate::message::GroupOrder; /// A publisher sends a FETCH_OK control message in response to successful fetches. 
@@ -18,6 +19,9 @@ pub struct FetchOk { /// Optional parameters pub params: KeyValuePairs, + + /// Track extensions + pub track_extensions: ExtensionHeaders, } impl Decode for FetchOk { @@ -33,6 +37,7 @@ impl Decode for FetchOk { let end_of_track = bool::decode(r)?; let end_location = Location::decode(r)?; let params = KeyValuePairs::decode(r)?; + let track_extensions = ExtensionHeaders::decode(r)?; Ok(Self { id, @@ -40,6 +45,7 @@ impl Decode for FetchOk { end_of_track, end_location, params, + track_extensions, }) } } @@ -57,6 +63,7 @@ impl Encode for FetchOk { self.end_of_track.encode(w)?; self.end_location.encode(w)?; self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -81,6 +88,7 @@ mod tests { end_of_track: true, end_location: Location::new(2, 3), params: kvps.clone(), + track_extensions: Default::default(), }; msg.encode(&mut buf).unwrap(); let decoded = FetchOk::decode(&mut buf).unwrap(); @@ -97,6 +105,7 @@ mod tests { end_of_track: true, end_location: Location::new(2, 3), params: Default::default(), + track_extensions: Default::default(), }; let encoded = msg.encode(&mut buf); assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); diff --git a/moq-transport/src/message/mod.rs b/moq-transport/src/message/mod.rs index 267368c7..246474a8 100644 --- a/moq-transport/src/message/mod.rs +++ b/moq-transport/src/message/mod.rs @@ -5,73 +5,65 @@ //! The only exception are OBJECT "messages", which are sent over dedicated QUIC streams. //! 
+mod dynamic_groups; mod fetch; mod fetch_cancel; -mod fetch_error; mod fetch_ok; mod fetch_type; mod filter_type; mod go_away; mod group_order; mod max_request_id; -mod pubilsh_namespace_done; +mod namespace; +mod parameters; mod publish; mod publish_done; -mod publish_error; mod publish_namespace; mod publish_namespace_cancel; -mod publish_namespace_error; -mod publish_namespace_ok; +mod publish_namespace_done; mod publish_ok; mod publisher; +mod request_error; +mod request_ok; mod requests_blocked; mod subscribe; -mod subscribe_error; mod subscribe_namespace; -mod subscribe_namespace_error; -mod subscribe_namespace_ok; mod subscribe_ok; mod subscribe_update; mod subscriber; mod track_status; -mod track_status_error; mod track_status_ok; mod unsubscribe; -mod unsubscribe_namespace; +pub use dynamic_groups::*; pub use fetch::*; pub use fetch_cancel::*; -pub use fetch_error::*; pub use fetch_ok::*; pub use fetch_type::*; pub use filter_type::*; pub use go_away::*; pub use group_order::*; pub use max_request_id::*; -pub use pubilsh_namespace_done::*; +pub use namespace::*; +pub use parameters::*; pub use publish::*; pub use publish_done::*; -pub use publish_error::*; pub use publish_namespace::*; pub use publish_namespace_cancel::*; -pub use publish_namespace_error::*; -pub use publish_namespace_ok::*; +pub use publish_namespace_done::*; pub use publish_ok::*; pub use publisher::*; +pub use request_error::*; +pub use request_ok::*; pub use requests_blocked::*; pub use subscribe::*; -pub use subscribe_error::*; pub use subscribe_namespace::*; -pub use subscribe_namespace_error::*; -pub use subscribe_namespace_ok::*; pub use subscribe_ok::*; pub use subscribe_update::*; pub use subscriber::*; pub use track_status::*; -pub use track_status_error::*; pub use track_status_ok::*; pub use unsubscribe::*; -pub use unsubscribe_namespace::*; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use std::fmt; @@ -89,13 +81,18 @@ macro_rules! 
message_types { impl Decode for Message { fn decode(r: &mut R) -> Result { let t = u64::decode(r)?; - let _len = u16::decode(r)?; + let len = u16::decode(r)? as usize; - // TODO: Check the length of the message. + // Read exactly len bytes into a sub-buffer to properly handle Track Extensions + if r.remaining() < len { + return Err(DecodeError::More(len - r.remaining())); + } + let payload = r.copy_to_bytes(len); + let mut payload_reader = std::io::Cursor::new(payload); match t { $($val => { - let msg = $name::decode(r)?; + let msg = $name::decode(&mut payload_reader)?; Ok(Self::$name(msg)) })* _ => Err(DecodeError::InvalidMessage(t)), @@ -185,40 +182,36 @@ message_types! { Unsubscribe = 0xa, // SUBSCRIBE family, sent by publisher SubscribeOk = 0x4, - SubscribeError = 0x5, // ANNOUNCE family, sent by publisher PublishNamespace = 0x6, PublishNamespaceDone = 0x9, // ANNOUNCE family, sent by subscriber - PublishNamespaceOk = 0x7, - PublishNamespaceError = 0x8, + RequestOk = 0x7, PublishNamespaceCancel = 0xc, + // NAMESPACE family, sent by relay to subscriber (draft-16) + Namespace = 0x8, + // TRACK_STATUS family, sent by subscriber TrackStatus = 0xd, // TRACK_STATUS family, sent by publisher TrackStatusOk = 0xe, - TrackStatusError = 0xf, // NAMESPACE family, sent by subscriber SubscribeNamespace = 0x11, - UnsubscribeNamespace = 0x14, - // NAMESPACE family, sent by publisher - SubscribeNamespaceOk = 0x12, - SubscribeNamespaceError = 0x13, // FETCH family, sent by subscriber Fetch = 0x16, FetchCancel = 0x17, // FETCH family, sent by publisher FetchOk = 0x18, - FetchError = 0x19, // PUBLISH family, sent by publisher Publish = 0x1d, PublishDone = 0xb, // PUBLISH family, sent by subscriber PublishOk = 0x1e, - PublishError = 0x1f, + + RequestError = 0x5, } diff --git a/moq-transport/src/message/namespace.rs b/moq-transport/src/message/namespace.rs new file mode 100644 index 00000000..978d1a1d --- /dev/null +++ b/moq-transport/src/message/namespace.rs @@ -0,0 +1,61 @@ +use 
crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; + +/// NAMESPACE message (draft-16) +/// +/// Sent by relay to subscriber to announce a namespace matching their SUBSCRIBE_NAMESPACE. +/// This is different from PUBLISH_NAMESPACE which is sent by publisher to relay. +/// +/// Wire format: 0x08 +#[derive(Clone, Debug)] +pub struct Namespace { + /// Request ID (from the SUBSCRIBE_NAMESPACE) + pub id: u64, + /// The namespace being announced + pub track_namespace: TrackNamespace, + /// Optional parameters + pub params: KeyValuePairs, +} + +impl Decode for Namespace { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let track_namespace = TrackNamespace::decode(r)?; + let params = KeyValuePairs::decode(r)?; + + Ok(Self { + id, + track_namespace, + params, + }) + } +} + +impl Encode for Namespace { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.track_namespace.encode(w)?; + self.params.encode(w)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_namespace_encode_decode() { + let msg = Namespace { + id: 42, + track_namespace: TrackNamespace::from_utf8_path("live/room1"), + params: KeyValuePairs::new(), + }; + + let mut buf = Vec::new(); + msg.encode(&mut buf).unwrap(); + + let decoded = Namespace::decode(&mut buf.as_slice()).unwrap(); + assert_eq!(decoded.id, 42); + assert_eq!(decoded.track_namespace.to_utf8_path(), "live/room1"); + } +} diff --git a/moq-transport/src/message/parameters.rs b/moq-transport/src/message/parameters.rs new file mode 100644 index 00000000..b61d9e37 --- /dev/null +++ b/moq-transport/src/message/parameters.rs @@ -0,0 +1,47 @@ +/// Version-Specific Message Parameter Types +/// Used in SUBSCRIBE, SUBSCRIBE_OK, PUBLISH, FETCH, REQUEST_UPDATE, etc. 
+#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u64)] +pub enum ParameterType { + /// Used in: REQUEST_OK, PUBLISH, PUBLISH_OK, SUBSCRIBE, SUBSCRIBE_OK, REQUEST_UPDATE + DeliveryTimeout = 0x02, + /// Used in: CLIENT_SETUP, SERVER_SETUP, PUBLISH, SUBSCRIBE, REQUEST_UPDATE, + /// SUBSCRIBE_NAMESPACE, PUBLISH_NAMESPACE, TRACK_STATUS, FETCH + AuthorizationToken = 0x03, + /// Used in: PUBLISH, SUBSCRIBE_OK, FETCH_OK, REQUEST_OK + MaxCacheDuration = 0x04, + /// Used in: SUBSCRIBE_OK, PUBLISH, PUBLISH_OK + Expires = 0x08, + /// Used in: SUBSCRIBE_OK, PUBLISH, REQUEST_OK + LargestObject = 0x09, + /// Used in: SUBSCRIBE_OK, PUBLISH + PublisherPriority = 0x0E, + /// Used in: SUBSCRIBE, REQUEST_UPDATE, PUBLISH, PUBLISH_OK, SUBSCRIBE_NAMESPACE + Forward = 0x10, + /// Used in: SUBSCRIBE, FETCH, REQUEST_UPDATE, PUBLISH_OK + SubscriberPriority = 0x20, + /// Used in: SUBSCRIBE, PUBLISH_OK, REQUEST_UPDATE (renamed to SubscriptionLocationFilter per PR #1518) + SubscriptionFilter = 0x21, + /// Used in: SUBSCRIBE, SUBSCRIBE_OK, REQUEST_OK, PUBLISH, PUBLISH_OK, FETCH + GroupOrder = 0x22, + /// Used in: SUBSCRIBE, FETCH - Filter by subgroup ID ranges (PR #1518) + SubgroupFilter = 0x25, + /// Used in: SUBSCRIBE, FETCH - Filter by object ID ranges (PR #1518) + ObjectFilter = 0x26, + /// Used in: SUBSCRIBE, FETCH - Filter by priority ranges (PR #1518) + PriorityFilter = 0x27, + /// Used in: SUBSCRIBE, FETCH - Filter by property value ranges (PR #1518) + PropertyFilter = 0x28, + /// Used in: SUBSCRIBE_NAMESPACE - Track filter for top-N selection (PR #1518) + TrackFilter = 0x29, + /// Used in: PUBLISH, SUBSCRIBE_OK + DynamicGroups = 0x30, + /// Used in: PUBLISH_OK, SUBSCRIBE, REQUEST_UPDATE + NewGroupRequest = 0x32, +} + +impl From for u64 { + fn from(value: ParameterType) -> Self { + value as u64 + } +} diff --git a/moq-transport/src/message/publish.rs b/moq-transport/src/message/publish.rs index feea3639..467b51cd 100644 --- a/moq-transport/src/message/publish.rs +++ 
b/moq-transport/src/message/publish.rs @@ -1,9 +1,10 @@ -use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, -}; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; +use crate::data::ExtensionHeaders; /// Sent by publisher to initiate a subscription to a track. +/// +/// Draft-16: Fields like group_order, content_exists, largest_location, forward +/// have been moved to Parameters (Section 9.2.2). #[derive(Clone, Debug, Eq, PartialEq)] pub struct Publish { /// The publish request ID @@ -14,14 +15,11 @@ pub struct Publish { pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) pub track_alias: u64, - pub group_order: GroupOrder, - pub content_exists: bool, - // The largest object available for this track, if content exists. - pub largest_location: Option, - pub forward: bool, - - /// Optional parameters + /// Optional parameters (may contain Forward, GroupOrder, LargestObject, PublisherPriority, etc.) pub params: KeyValuePairs, + + /// Track extensions + pub track_extensions: ExtensionHeaders, } impl Decode for Publish { @@ -32,31 +30,18 @@ impl Decode for Publish { let track_name = String::decode(r)?; let track_alias = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message, so validate it now so we can return a protocol error. 
- if group_order == GroupOrder::Publisher { - return Err(DecodeError::InvalidGroupOrder); - } - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; - let forward = bool::decode(r)?; - let params = KeyValuePairs::decode(r)?; + // Track Extensions use remaining bytes (no length prefix per draft-16) + let track_extensions = ExtensionHeaders::decode_remaining_bytes(r)?; + Ok(Self { id, track_namespace, track_name, track_alias, - group_order, - content_exists, - largest_location, - forward, params, + track_extensions, }) } } @@ -69,22 +54,8 @@ impl Encode for Publish { self.track_name.encode(w)?; self.track_alias.encode(w)?; - // GroupOrder enum has Publisher in it, but it's not allowed to be used in this - // publish message. - if self.group_order == GroupOrder::Publisher { - return Err(EncodeError::InvalidValue); - } - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } - self.forward.encode(w)?; self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -99,37 +70,16 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // Content exists = true - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - forward: true, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = Publish::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // Content exists = false let msg = Publish { id: 
12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: false, - largest_location: None, - forward: true, params: kvps.clone(), + track_extensions: Default::default(), }; msg.encode(&mut buf).unwrap(); let decoded = Publish::decode(&mut buf).unwrap(); @@ -137,7 +87,7 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_no_params() { let mut buf = BytesMut::new(); let msg = Publish { @@ -145,32 +95,11 @@ mod tests { track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), track_alias: 212, - group_order: GroupOrder::Ascending, - content_exists: true, - largest_location: None, - forward: true, params: Default::default(), + track_extensions: Default::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } - - #[test] - fn encode_bad_group_order() { - let mut buf = BytesMut::new(); - - let msg = Publish { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - track_alias: 212, - group_order: GroupOrder::Publisher, - content_exists: false, - largest_location: None, - forward: true, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::InvalidValue)); + msg.encode(&mut buf).unwrap(); + let decoded = Publish::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } } diff --git a/moq-transport/src/message/publish_namespace_done.rs b/moq-transport/src/message/publish_namespace_done.rs new file mode 100644 index 00000000..4540ab47 --- /dev/null +++ b/moq-transport/src/message/publish_namespace_done.rs @@ -0,0 +1,41 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; + +/// Sent by the publisher to terminate a 
PUBLISH_NAMESPACE. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct PublishNamespaceDone { + pub track_namespace: TrackNamespace, +} + +impl Decode for PublishNamespaceDone { + fn decode(r: &mut R) -> Result { + let track_namespace = TrackNamespace::decode(r)?; + + Ok(Self { track_namespace }) + } +} + +impl Encode for PublishNamespaceDone { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.track_namespace.encode(w)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + + let msg = PublishNamespaceDone { + track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), + }; + msg.encode(&mut buf).unwrap(); + let decoded = PublishNamespaceDone::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } +} diff --git a/moq-transport/src/message/publish_ok.rs b/moq-transport/src/message/publish_ok.rs index 5564c7ac..e376c89b 100644 --- a/moq-transport/src/message/publish_ok.rs +++ b/moq-transport/src/message/publish_ok.rs @@ -1,110 +1,30 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::FilterType; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; -/// Sent by the subscriber to request all future objects for the given track. +/// Sent by the subscriber to acknowledge a PUBLISH message and establish a subscription. /// -/// Objects will use the provided ID instead of the full track name, to save bytes. +/// Draft-16: All subscription properties (forward, subscriber_priority, group_order, +/// filter_type, etc.) are now in Parameters (Section 9.2.2). #[derive(Clone, Debug, Eq, PartialEq)] pub struct PublishOk { /// The request ID of the Publish this message is replying to. 
pub id: u64, - /// Forward Flag - pub forward: bool, - - /// Subscriber Priority - pub subscriber_priority: u8, - - /// The order the subscription will be delivered in - pub group_order: GroupOrder, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. - pub end_group_id: Option, - - /// Optional parameters + /// Parameters (may contain Forward, SubscriberPriority, GroupOrder, SubscriptionFilter, etc.) pub params: KeyValuePairs, } impl Decode for PublishOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let forward = bool::decode(r)?; - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } - let params = KeyValuePairs::decode(r)?; - Ok(Self { - id, - forward, - subscriber_priority, - group_order, - filter_type, - start_location, - end_group_id, - params, - }) + Ok(Self { id, params }) } } impl Encode for PublishOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.forward.encode(w)?; - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore 
end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -120,49 +40,11 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); - // FilterType = NextGroupStart - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteStart - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: kvps.clone(), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - - // FilterType = AbsoluteRange let msg = PublishOk { id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -171,49 +53,15 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_no_params() { let mut buf = BytesMut::new(); - // FilterType = AbsoluteStart - missing start_location - let msg = PublishOk { - id: 12345, - forward: 
true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = PublishOk { - id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id let msg = PublishOk { id: 12345, - forward: true, - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, params: Default::default(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + let decoded = PublishOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } } diff --git a/moq-transport/src/message/publisher.rs b/moq-transport/src/message/publisher.rs index 6cdf0750..61700289 100644 --- a/moq-transport/src/message/publisher.rs +++ b/moq-transport/src/message/publisher.rs @@ -48,14 +48,12 @@ macro_rules! publisher_msgs { publisher_msgs! 
{ PublishNamespace, PublishNamespaceDone, + Namespace, Publish, PublishDone, SubscribeOk, - SubscribeError, TrackStatusOk, - TrackStatusError, FetchOk, - FetchError, - SubscribeNamespaceOk, - SubscribeNamespaceError, + RequestOk, + RequestError, } diff --git a/moq-transport/src/message/request_error.rs b/moq-transport/src/message/request_error.rs new file mode 100644 index 00000000..fce02c6f --- /dev/null +++ b/moq-transport/src/message/request_error.rs @@ -0,0 +1,87 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; + +/// REQUEST_ERROR message (draft-16 Section 9.8). +/// +/// Sent in response to any request (SUBSCRIBE, FETCH, PUBLISH, etc.) to indicate failure. +#[derive(Clone, Debug)] +pub struct RequestError { + pub id: u64, + + /// An error code identifying the failure reason. + pub error_code: u64, + + /// Minimum time in milliseconds before the request SHOULD be sent again, plus one. + /// A value of 0 means the request SHOULD NOT be retried. + /// A value of 1 means the request can be retried immediately. + pub retry_interval: u64, + + /// An optional, human-readable reason. 
+ pub reason_phrase: ReasonPhrase, +} + +impl Decode for RequestError { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let error_code = u64::decode(r)?; + let retry_interval = u64::decode(r)?; + let reason_phrase = ReasonPhrase::decode(r)?; + + Ok(Self { + id, + error_code, + retry_interval, + reason_phrase, + }) + } +} + +impl Encode for RequestError { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.error_code.encode(w)?; + self.retry_interval.encode(w)?; + self.reason_phrase.encode(w)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + + let msg = RequestError { + id: 42, + error_code: 0x1, + retry_interval: 5000, + reason_phrase: ReasonPhrase("unauthorized".to_string()), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.error_code, msg.error_code); + assert_eq!(decoded.retry_interval, msg.retry_interval); + } + + #[test] + fn encode_decode_no_retry() { + let mut buf = BytesMut::new(); + + let msg = RequestError { + id: 10, + error_code: 0x0, + retry_interval: 0, + reason_phrase: ReasonPhrase("internal error".to_string()), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestError::decode(&mut buf).unwrap(); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.error_code, msg.error_code); + assert_eq!(decoded.retry_interval, 0); + } +} diff --git a/moq-transport/src/message/request_ok.rs b/moq-transport/src/message/request_ok.rs new file mode 100644 index 00000000..9ceb8879 --- /dev/null +++ b/moq-transport/src/message/request_ok.rs @@ -0,0 +1,45 @@ +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; + +/// Request Ok +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RequestOk { + /// The SubscribeNamespace/PublishNamespace request ID this message is replying to.
+ pub id: u64, + + /// Optional parameters + pub params: KeyValuePairs, +} + +impl Decode for RequestOk { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let params = KeyValuePairs::decode(r)?; + Ok(Self { id, params }) + } +} + +impl Encode for RequestOk { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.params.encode(w) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BytesMut; + + #[test] + fn encode_decode() { + let mut buf = BytesMut::new(); + + let msg = RequestOk { + id: 12345, + params: KeyValuePairs::new(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = RequestOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } +} diff --git a/moq-transport/src/message/subscribe.rs b/moq-transport/src/message/subscribe.rs index 828ccb92..e2a3d0bd 100644 --- a/moq-transport/src/message/subscribe.rs +++ b/moq-transport/src/message/subscribe.rs @@ -1,8 +1,4 @@ -use crate::coding::{ - Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location, TrackNamespace, -}; -use crate::message::FilterType; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackNamespace}; /// Sent by the subscriber to request all future objects for the given track. /// @@ -16,22 +12,9 @@ pub struct Subscribe { pub track_namespace: TrackNamespace, pub track_name: String, // TODO SLG - consider making a FullTrackName base struct (total size limit of 4096) - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. 
- pub end_group_id: Option, - /// Optional parameters + /// NOTE(itzmanish): since the forward and other fields are moved to parameters + /// we need to validate them on publisher logic pub params: KeyValuePairs, } @@ -42,41 +25,12 @@ impl Decode for Subscribe { let track_namespace = TrackNamespace::decode(r)?; let track_name = String::decode(r)?; - let subscriber_priority = u8::decode(r)?; - let group_order = GroupOrder::decode(r)?; - - let forward = bool::decode(r)?; - - let filter_type = FilterType::decode(r)?; - let start_location: Option; - let end_group_id: Option; - match filter_type { - FilterType::AbsoluteStart => { - start_location = Some(Location::decode(r)?); - end_group_id = None; - } - FilterType::AbsoluteRange => { - start_location = Some(Location::decode(r)?); - end_group_id = Some(u64::decode(r)?); - } - _ => { - start_location = None; - end_group_id = None; - } - } - let params = KeyValuePairs::decode(r)?; Ok(Self { id, track_namespace, track_name, - subscriber_priority, - group_order, - forward, - filter_type, - start_location, - end_group_id, params, }) } @@ -88,37 +42,6 @@ impl Encode for Subscribe { self.track_namespace.encode(w)?; self.track_name.encode(w)?; - - self.subscriber_priority.encode(w)?; - self.group_order.encode(w)?; - - self.forward.encode(w)?; - - self.filter_type.encode(w)?; - match self.filter_type { - FilterType::AbsoluteStart => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - // Just ignore end_group_id if it happens to be set - } - FilterType::AbsoluteRange => { - if let Some(start) = &self.start_location { - start.encode(w)?; - } else { - return Err(EncodeError::MissingField("StartLocation".to_string())); - } - if let Some(end) = self.end_group_id { - end.encode(w)?; - } else { - return Err(EncodeError::MissingField("EndGroupId".to_string())); - } - } - _ => {} - } - self.params.encode(w)?; Ok(()) @@ -143,12 +66,6 @@ 
mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::NextGroupStart, - start_location: None, - end_group_id: None, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -160,12 +77,6 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); @@ -177,69 +88,10 @@ mod tests { id: 12345, track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: Some(23456), params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = Subscribe::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } - - #[test] - fn encode_missing_fields() { - let mut buf = BytesMut::new(); - - // FilterType = AbsoluteStart - missing start_location - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteStart, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing start_location - let msg = Subscribe { - id: 12345, - track_namespace: 
TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: None, - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - - // FilterType = AbsoluteRange - missing end_group_id - let msg = Subscribe { - id: 12345, - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - track_name: "audiotrack".to_string(), - subscriber_priority: 127, - group_order: GroupOrder::Publisher, - forward: true, - filter_type: FilterType::AbsoluteRange, - start_location: Some(Location::new(12345, 67890)), - end_group_id: None, - params: Default::default(), - }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - } } diff --git a/moq-transport/src/message/subscribe_namespace.rs b/moq-transport/src/message/subscribe_namespace.rs index ba292d69..92662036 100644 --- a/moq-transport/src/message/subscribe_namespace.rs +++ b/moq-transport/src/message/subscribe_namespace.rs @@ -9,19 +9,36 @@ pub struct SubscribeNamespace { /// The track namespace prefix pub track_namespace_prefix: TrackNamespace, + /// The Forward value that new subscriptions resulting from this SUBSCRIBE_NAMESPACE will have + pub forward: u8, + /// Optional parameters pub params: KeyValuePairs, } +impl SubscribeNamespace { + /// Creates a new SubscribeNamespace message. 
+ pub fn new(id: u64, track_namespace_prefix: TrackNamespace, forward: u8) -> Self { + Self { + id, + track_namespace_prefix, + forward, + params: KeyValuePairs::new(), + } + } +} + impl Decode for SubscribeNamespace { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; let track_namespace_prefix = TrackNamespace::decode(r)?; + let forward = u8::decode(r)?; let params = KeyValuePairs::decode(r)?; Ok(Self { id, track_namespace_prefix, + forward, params, }) } @@ -31,6 +48,7 @@ impl Encode for SubscribeNamespace { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_namespace_prefix.encode(w)?; + self.forward.encode(w)?; self.params.encode(w)?; Ok(()) @@ -52,11 +70,14 @@ mod tests { let msg = SubscribeNamespace { id: 12345, + forward: 0, track_namespace_prefix: TrackNamespace::from_utf8_path("path/prefix"), params: kvps, }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeNamespace::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); + assert_eq!(decoded.id, msg.id); + assert_eq!(decoded.forward, msg.forward); + assert_eq!(decoded.track_namespace_prefix, msg.track_namespace_prefix); } } diff --git a/moq-transport/src/message/subscribe_ok.rs b/moq-transport/src/message/subscribe_ok.rs index 97bb3aa7..c87f952b 100644 --- a/moq-transport/src/message/subscribe_ok.rs +++ b/moq-transport/src/message/subscribe_ok.rs @@ -1,5 +1,4 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; -use crate::message::GroupOrder; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, TrackExtensions}; /// Sent by the publisher to accept a Subscribe. #[derive(Clone, Debug, Eq, PartialEq)] @@ -10,42 +9,26 @@ pub struct SubscribeOk { /// The identifier used for this track in Subgroups or Datagrams. pub track_alias: u64, - /// The time in milliseconds after which the subscription is not longer valid. 
- pub expires: u64, - - /// Order groups will be delivered in - pub group_order: GroupOrder, - - /// If content_exists, then largest_location is the location of the largest - /// object available for this track - pub content_exists: bool, - pub largest_location: Option, // Only provided if content_exists is 1/true - - /// Subscribe Parameters + /// Subscribe Parameters (has count prefix per spec) pub params: KeyValuePairs, + + /// Track extensions (NO prefix per draft-16 Section 9.10 - reads until end of message) + pub track_extensions: TrackExtensions, } impl Decode for SubscribeOk { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; let track_alias = u64::decode(r)?; - let expires = u64::decode(r)?; - let group_order = GroupOrder::decode(r)?; - let content_exists = bool::decode(r)?; - let largest_location = match content_exists { - true => Some(Location::decode(r)?), - false => None, - }; let params = KeyValuePairs::decode(r)?; + // Track extensions have NO prefix - read until end of message + let track_extensions = TrackExtensions::decode(r)?; Ok(Self { id, track_alias, - expires, - group_order, - content_exists, - largest_location, params, + track_extensions, }) } } @@ -54,17 +37,8 @@ impl Encode for SubscribeOk { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; self.track_alias.encode(w)?; - self.expires.encode(w)?; - self.group_order.encode(w)?; - self.content_exists.encode(w)?; - if self.content_exists { - if let Some(largest) = &self.largest_location { - largest.encode(w)?; - } else { - return Err(EncodeError::MissingField("LargestLocation".to_string())); - } - } self.params.encode(w)?; + self.track_extensions.encode(w)?; Ok(()) } @@ -83,14 +57,15 @@ mod tests { let mut kvps = KeyValuePairs::new(); kvps.set_bytesvalue(123, vec![0x00, 0x01, 0x02, 0x03]); + // Track extensions (no prefix) + let mut ext = TrackExtensions::new(); + ext.set_intvalue(2, 42); + let msg = SubscribeOk { id: 12345, track_alias: 100, - expires: 
3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: Some(Location::new(2, 3)), - params: kvps.clone(), + params: kvps, + track_extensions: ext, }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeOk::decode(&mut buf).unwrap(); @@ -98,19 +73,22 @@ mod tests { } #[test] - fn encode_missing_fields() { + fn encode_decode_empty_extensions() { let mut buf = BytesMut::new(); let msg = SubscribeOk { - id: 12345, - track_alias: 100, - expires: 3600, - group_order: GroupOrder::Publisher, - content_exists: true, - largest_location: None, - params: Default::default(), + id: 0, + track_alias: 0, + params: KeyValuePairs::new(), + track_extensions: TrackExtensions::new(), }; - let encoded = msg.encode(&mut buf); - assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); + msg.encode(&mut buf).unwrap(); + // Expected: id=0 (1 byte), track_alias=0 (1 byte), params_count=0 (1 byte), NO track_extensions bytes + assert_eq!(buf.to_vec(), vec![0x00, 0x00, 0x00]); + let decoded = SubscribeOk::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } + + // Note: encode_missing_fields test removed — content_exists was removed + // from the struct in draft-16; no fields to validate at encode time. } diff --git a/moq-transport/src/message/subscribe_update.rs b/moq-transport/src/message/subscribe_update.rs index 3bf20e23..895378d9 100644 --- a/moq-transport/src/message/subscribe_update.rs +++ b/moq-transport/src/message/subscribe_update.rs @@ -1,53 +1,31 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs, Location}; +use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; -/// Sent by the subscriber to request all future objects for the given track. +/// REQUEST_UPDATE message (draft-16 Section 9.11). /// -/// Objects will use the provided ID instead of the full track name, to save bytes. +/// Sent to modify an existing request (SUBSCRIBE, PUBLISH, FETCH, etc.). 
+/// Parameters previously set that are not present in the update remain unchanged. #[derive(Clone, Debug, Eq, PartialEq)] pub struct SubscribeUpdate { - /// The request ID of this request + /// The request ID of this REQUEST_UPDATE pub id: u64, - /// The request ID of the SUBSCRIBE this message is updating. - pub subscription_request_id: u64, + /// The request ID of the existing request this message is updating. + pub existing_request_id: u64, - /// The starting location - pub start_location: Location, - /// The end Group ID, plus 1. A value of 0 means the subscription is open-ended. - pub end_group_id: u64, - - /// Subscriber Priority - pub subscriber_priority: u8, - - /// Forward Flag - pub forward: bool, - - /// Optional parameters + /// Parameters to update (draft-16 Section 9.2.2). + /// Parameters not present remain unchanged from the original request. pub params: KeyValuePairs, } impl Decode for SubscribeUpdate { fn decode(r: &mut R) -> Result { let id = u64::decode(r)?; - - let subscription_request_id = u64::decode(r)?; - - let start_location = Location::decode(r)?; - let end_group_id = u64::decode(r)?; - - let subscriber_priority = u8::decode(r)?; - - let forward = bool::decode(r)?; - + let existing_request_id = u64::decode(r)?; let params = KeyValuePairs::decode(r)?; Ok(Self { id, - subscription_request_id, - start_location, - end_group_id, - subscriber_priority, - forward, + existing_request_id, params, }) } @@ -56,16 +34,7 @@ impl Decode for SubscribeUpdate { impl Encode for SubscribeUpdate { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { self.id.encode(w)?; - - self.subscription_request_id.encode(w)?; - - self.start_location.encode(w)?; - self.end_group_id.encode(w)?; - - self.subscriber_priority.encode(w)?; - - self.forward.encode(w)?; - + self.existing_request_id.encode(w)?; self.params.encode(w)?; Ok(()) @@ -81,21 +50,30 @@ mod tests { fn encode_decode() { let mut buf = BytesMut::new(); - // One parameter for testing let mut kvps = 
KeyValuePairs::new(); kvps.set_intvalue(124, 456); let msg = SubscribeUpdate { id: 1000, - subscription_request_id: 924, - start_location: Location::new(1, 1), - end_group_id: 100000, - subscriber_priority: 127, - forward: true, + existing_request_id: 924, params: kvps.clone(), }; msg.encode(&mut buf).unwrap(); let decoded = SubscribeUpdate::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); } + + #[test] + fn encode_decode_empty_params() { + let mut buf = BytesMut::new(); + + let msg = SubscribeUpdate { + id: 5, + existing_request_id: 3, + params: KeyValuePairs::new(), + }; + msg.encode(&mut buf).unwrap(); + let decoded = SubscribeUpdate::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + } } diff --git a/moq-transport/src/message/subscriber.rs b/moq-transport/src/message/subscriber.rs index 0a11fb9e..3c433149 100644 --- a/moq-transport/src/message/subscriber.rs +++ b/moq-transport/src/message/subscriber.rs @@ -53,10 +53,8 @@ subscriber_msgs! { FetchCancel, TrackStatus, SubscribeNamespace, - UnsubscribeNamespace, PublishNamespaceCancel, - PublishNamespaceOk, - PublishNamespaceError, + RequestOk, PublishOk, - PublishError, + RequestError, } diff --git a/moq-transport/src/setup/auth_token.rs b/moq-transport/src/setup/auth_token.rs new file mode 100644 index 00000000..a1b22be5 --- /dev/null +++ b/moq-transport/src/setup/auth_token.rs @@ -0,0 +1,298 @@ +//! Authorization Token support for MOQT. +//! +//! This module provides support for authorization tokens as defined in the MOQT specification. +//! Tokens can be sent inline or referenced by alias to avoid retransmission of large tokens. + +use std::collections::HashMap; + +/// Authorization Token Types +/// +/// Defines how an authorization token is transmitted in messages. 
+#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u8)] +pub enum AuthTokenType { + /// No authorization token present + None = 0x0, + /// Authorization token sent inline + Inline = 0x1, + /// Authorization token referenced by alias + Alias = 0x2, + /// Authorization token cached with new alias + Store = 0x3, + /// Use previously stored token (DELETE is not allowed in CLIENT_SETUP) + UseAlias = 0x4, +} + +impl TryFrom for AuthTokenType { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 0x0 => Ok(Self::None), + 0x1 => Ok(Self::Inline), + 0x2 => Ok(Self::Alias), + 0x3 => Ok(Self::Store), + 0x4 => Ok(Self::UseAlias), + _ => Err(()), + } + } +} + +/// An authorization token value +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct AuthToken { + /// The raw token bytes + pub token: Vec, + /// Optional alias for caching + pub alias: Option, +} + +impl AuthToken { + /// Create a new authorization token + pub fn new(token: Vec) -> Self { + Self { token, alias: None } + } + + /// Create a new authorization token with an alias for caching + pub fn with_alias(token: Vec, alias: u64) -> Self { + Self { + token, + alias: Some(alias), + } + } + + /// Check if the token is empty + pub fn is_empty(&self) -> bool { + self.token.is_empty() + } +} + +/// Authorization Token Cache +/// +/// Stores authorization tokens by their alias for efficient re-use across multiple messages. +/// The cache enforces a maximum size limit as negotiated during setup. 
+#[derive(Debug)]
+pub struct AuthTokenCache {
+    /// Maximum number of tokens that can be cached
+    max_size: usize,
+    /// Cached tokens by alias
+    tokens: HashMap<u64, Vec<u8>>,
+    /// Next available alias (for server-assigned aliases)
+    next_alias: u64,
+}
+
+impl Default for AuthTokenCache {
+    fn default() -> Self {
+        Self::new(0)
+    }
+}
+
+impl AuthTokenCache {
+    /// Create a new auth token cache with the specified maximum size
+    pub fn new(max_size: usize) -> Self {
+        Self {
+            max_size,
+            tokens: HashMap::new(),
+            next_alias: 0,
+        }
+    }
+
+    /// Get the maximum cache size
+    pub fn max_size(&self) -> usize {
+        self.max_size
+    }
+
+    /// Set the maximum cache size (typically from setup negotiation)
+    pub fn set_max_size(&mut self, max_size: usize) {
+        self.max_size = max_size;
+    }
+
+    /// Get the current number of cached tokens
+    pub fn len(&self) -> usize {
+        self.tokens.len()
+    }
+
+    /// Check if the cache is empty
+    pub fn is_empty(&self) -> bool {
+        self.tokens.is_empty()
+    }
+
+    /// Check if the cache is at capacity
+    pub fn is_full(&self) -> bool {
+        self.tokens.len() >= self.max_size
+    }
+
+    /// Store a token with the given alias
+    ///
+    /// Returns an error if:
+    /// - The cache is at capacity
+    /// - The alias is already in use
+    pub fn store(&mut self, alias: u64, token: Vec<u8>) -> Result<(), AuthTokenCacheError> {
+        if self.max_size == 0 {
+            return Err(AuthTokenCacheError::CacheDisabled);
+        }
+        if self.tokens.len() >= self.max_size {
+            return Err(AuthTokenCacheError::CacheOverflow);
+        }
+        if self.tokens.contains_key(&alias) {
+            return Err(AuthTokenCacheError::DuplicateAlias(alias));
+        }
+        self.tokens.insert(alias, token);
+        Ok(())
+    }
+
+    /// Store a token with an auto-generated alias
+    ///
+    /// Returns the assigned alias, or an error if the cache is full
+    pub fn store_with_auto_alias(&mut self, token: Vec<u8>) -> Result<u64, AuthTokenCacheError> {
+        if self.max_size == 0 {
+            return Err(AuthTokenCacheError::CacheDisabled);
+        }
+        if self.tokens.len() >= self.max_size {
+            return Err(AuthTokenCacheError::CacheOverflow);
+        }
+
+        // Find next available alias
+        while self.tokens.contains_key(&self.next_alias) {
+            self.next_alias = self.next_alias.wrapping_add(1);
+        }
+
+        let alias = self.next_alias;
+        self.tokens.insert(alias, token);
+        self.next_alias = self.next_alias.wrapping_add(1);
+
+        Ok(alias)
+    }
+
+    /// Get a token by its alias
+    pub fn get(&self, alias: u64) -> Option<&Vec<u8>> {
+        self.tokens.get(&alias)
+    }
+
+    /// Remove a token by its alias
+    pub fn remove(&mut self, alias: u64) -> Option<Vec<u8>> {
+        self.tokens.remove(&alias)
+    }
+
+    /// Clear all cached tokens
+    pub fn clear(&mut self) {
+        self.tokens.clear();
+    }
+}
+
+/// Errors that can occur when working with the auth token cache
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum AuthTokenCacheError {
+    /// The cache is disabled (max_size is 0)
+    CacheDisabled,
+    /// The cache is full and cannot accept more tokens
+    CacheOverflow,
+    /// The alias is already in use
+    DuplicateAlias(u64),
+    /// The alias was not found in the cache
+    UnknownAlias(u64),
+    /// The token is malformed
+    MalformedToken,
+    /// The token has expired
+    ExpiredToken,
+}
+
+impl std::fmt::Display for AuthTokenCacheError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::CacheDisabled => write!(f, "authorization token cache is disabled"),
+            Self::CacheOverflow => write!(f, "authorization token cache is full"),
+            Self::DuplicateAlias(alias) => {
+                write!(f, "duplicate authorization token alias: {}", alias)
+            }
+            Self::UnknownAlias(alias) => {
+                write!(f, "unknown authorization token alias: {}", alias)
+            }
+            Self::MalformedToken => write!(f, "malformed authorization token"),
+            Self::ExpiredToken => write!(f, "expired authorization token"),
+        }
+    }
+}
+
+impl std::error::Error for AuthTokenCacheError {}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_auth_token_type_conversion() {
+        assert_eq!(AuthTokenType::try_from(0u8), Ok(AuthTokenType::None));
+        
assert_eq!(AuthTokenType::try_from(1u8), Ok(AuthTokenType::Inline)); + assert_eq!(AuthTokenType::try_from(2u8), Ok(AuthTokenType::Alias)); + assert_eq!(AuthTokenType::try_from(3u8), Ok(AuthTokenType::Store)); + assert_eq!(AuthTokenType::try_from(4u8), Ok(AuthTokenType::UseAlias)); + assert!(AuthTokenType::try_from(5u8).is_err()); + } + + #[test] + fn test_auth_token() { + let token = AuthToken::new(vec![1, 2, 3, 4]); + assert!(!token.is_empty()); + assert!(token.alias.is_none()); + + let token_with_alias = AuthToken::with_alias(vec![5, 6, 7, 8], 42); + assert_eq!(token_with_alias.alias, Some(42)); + + let empty_token = AuthToken::default(); + assert!(empty_token.is_empty()); + } + + #[test] + fn test_auth_token_cache() { + let mut cache = AuthTokenCache::new(3); + assert_eq!(cache.max_size(), 3); + assert!(cache.is_empty()); + + // Store tokens + cache.store(1, vec![1, 2, 3]).unwrap(); + cache.store(2, vec![4, 5, 6]).unwrap(); + assert_eq!(cache.len(), 2); + assert!(!cache.is_full()); + + // Get token + assert_eq!(cache.get(1), Some(&vec![1, 2, 3])); + assert_eq!(cache.get(2), Some(&vec![4, 5, 6])); + assert_eq!(cache.get(3), None); + + // Store with auto-alias + let alias = cache.store_with_auto_alias(vec![7, 8, 9]).unwrap(); + assert!(cache.is_full()); + + // Cache overflow + assert_eq!( + cache.store(99, vec![10, 11]), + Err(AuthTokenCacheError::CacheOverflow) + ); + + // Duplicate alias + cache.remove(alias); + assert_eq!( + cache.store(1, vec![10, 11]), + Err(AuthTokenCacheError::DuplicateAlias(1)) + ); + + // Remove and clear + assert!(cache.remove(1).is_some()); + cache.clear(); + assert!(cache.is_empty()); + } + + #[test] + fn test_auth_token_cache_disabled() { + let mut cache = AuthTokenCache::new(0); + assert_eq!( + cache.store(1, vec![1, 2, 3]), + Err(AuthTokenCacheError::CacheDisabled) + ); + assert_eq!( + cache.store_with_auto_alias(vec![1, 2, 3]), + Err(AuthTokenCacheError::CacheDisabled) + ); + } +} diff --git a/moq-transport/src/setup/client.rs 
b/moq-transport/src/setup/client.rs index edefb32e..2d5b7f7e 100644 --- a/moq-transport/src/setup/client.rs +++ b/moq-transport/src/setup/client.rs @@ -1,14 +1,9 @@ -use super::Versions; use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; /// Sent by the client to setup the session. -/// This CLIENT_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x20 vs 0x40 for versions <= 10. +/// Draft-16: version negotiation uses ALPN; no Versions field in CLIENT_SETUP. #[derive(Debug)] pub struct Client { - /// The list of supported versions in preferred order. - pub versions: Versions, - /// Setup Parameters, ie: PATH, MAX_REQUEST_ID, /// MAX_AUTH_TOKEN_CACHE_SIZE, AUTHORIZATION_TOKEN, etc. pub params: KeyValuePairs, @@ -26,10 +21,9 @@ impl Decode for Client { let _len = u16::decode(r)?; // TODO: Check the length of the message. - let versions = Versions::decode(r)?; let params = KeyValuePairs::decode(r)?; - Ok(Self { versions, params }) + Ok(Self { params }) } } @@ -45,7 +39,6 @@ impl Encode for Client { // write the length later, to avoid the copy of the message bytes? 
let mut buf = Vec::new(); - self.versions.encode(&mut buf).unwrap(); self.params.encode(&mut buf).unwrap(); // Make sure buf.len() <= u16::MAX @@ -66,7 +59,7 @@ impl Encode for Client { #[cfg(test)] mod tests { use super::*; - use crate::setup::{ParameterType, Version}; + use crate::setup::ParameterType; use bytes::BytesMut; #[test] @@ -76,26 +69,22 @@ mod tests { let mut params = KeyValuePairs::default(); params.set_bytesvalue(ParameterType::Path.into(), "testpath".as_bytes().to_vec()); - let client = Client { - versions: [Version::DRAFT_13].into(), - params, - }; + let client = Client { params }; client.encode(&mut buf).unwrap(); + // Draft-16: no Versions field, just Type + Length + Parameters #[rustfmt::skip] assert_eq!( buf.to_vec(), vec![ - 0x20, // Type - 0x00, 0x14, // Length - 0x01, // 1 Version - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0D, // Version DRAFT_13 (0xff00000D) - 0x01, // 1 Param - 0x01, 0x08, 0x74, 0x65, 0x73, 0x74, 0x70, 0x61, 0x74, 0x68, // Key=1 (Path), Value="testpath" + 0x20, // Type (CLIENT_SETUP) + 0x00, 0x0b, // Length = 11 bytes + 0x01, // 1 Parameter (count) + // Delta=1 (Path), Length=8, "testpath" + 0x01, 0x08, 0x74, 0x65, 0x73, 0x74, 0x70, 0x61, 0x74, 0x68, ] ); let decoded = Client::decode(&mut buf).unwrap(); - assert_eq!(decoded.versions, client.versions); assert_eq!(decoded.params, client.params); } } diff --git a/moq-transport/src/setup/mod.rs b/moq-transport/src/setup/mod.rs index 44e3664b..1098ed7a 100644 --- a/moq-transport/src/setup/mod.rs +++ b/moq-transport/src/setup/mod.rs @@ -4,14 +4,16 @@ //! The client sends the [Client] message and the server responds with the [Server] message. //! Both sides negotate the [Version] and [Role]. 
+mod auth_token; mod client; mod param_types; mod server; mod version; +pub use auth_token::*; pub use client::*; pub use param_types::*; pub use server::*; pub use version::*; -pub const ALPN: &[u8] = b"moq-00"; +pub const ALPN: &[u8] = b"moqt-16"; diff --git a/moq-transport/src/setup/param_types.rs b/moq-transport/src/setup/param_types.rs index 2f4e9862..65f731e5 100644 --- a/moq-transport/src/setup/param_types.rs +++ b/moq-transport/src/setup/param_types.rs @@ -7,7 +7,11 @@ pub enum ParameterType { AuthorizationToken = 0x3, MaxAuthTokenCacheSize = 0x4, Authority = 0x5, + /// Maximum number of Range pairs allowed per subscription/fetch (PR #1518) + MaxFilterRanges = 0x6, MOQTImplementation = 0x7, + /// Maximum value for MaxTracksSelected parameter in TRACK_FILTER (PR #1518) + MaxTracksSelected = 0x8, } impl From for u64 { diff --git a/moq-transport/src/setup/server.rs b/moq-transport/src/setup/server.rs index 3fae91f8..7880228b 100644 --- a/moq-transport/src/setup/server.rs +++ b/moq-transport/src/setup/server.rs @@ -1,14 +1,9 @@ -use super::Version; use crate::coding::{Decode, DecodeError, Encode, EncodeError, KeyValuePairs}; /// Sent by the server in response to a client setup. -/// This SERVER_SETUP message is used by moq-transport draft versions 11 and later. -/// Id = 0x21 vs 0x41 for versions <= 10. +/// Draft-16: version negotiation uses ALPN; no Versions field in SERVER_SETUP. #[derive(Debug)] pub struct Server { - /// The list of supported versions in preferred order. - pub version: Version, - /// Setup Parameters, ie: MAX_REQUEST_ID, MAX_AUTH_TOKEN_CACHE_SIZE, /// AUTHORIZATION_TOKEN, etc. pub params: KeyValuePairs, @@ -26,10 +21,9 @@ impl Decode for Server { let _len = u16::decode(r)?; // TODO: Check the length of the message. 
- let version = Version::decode(r)?; let params = KeyValuePairs::decode(r)?; - Ok(Self { version, params }) + Ok(Self { params }) } } @@ -44,7 +38,6 @@ impl Encode for Server { // write the length later, to avoid the copy of the message bytes? let mut buf = Vec::new(); - self.version.encode(&mut buf).unwrap(); self.params.encode(&mut buf).unwrap(); // Make sure buf.len() <= u16::MAX @@ -75,27 +68,24 @@ mod tests { let mut params = KeyValuePairs::default(); params.set_intvalue(ParameterType::MaxRequestId.into(), 1000); - let server = Server { - version: Version::DRAFT_14, - params, - }; + let server = Server { params }; server.encode(&mut buf).unwrap(); + // Draft-16: no Versions field, just Type + Length + Parameters #[rustfmt::skip] assert_eq!( buf.to_vec(), vec![ - 0x21, // Type - 0x00, 0x0c, // Length - 0xC0, 0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x0E, // Version DRAFT_14 (0xff00000E) - 0x01, // 1 Param - 0x02, 0x43, 0xe8, // Key=2 (MaxRequestId), Value=1000 + 0x21, // Type (SERVER_SETUP) + 0x00, 0x04, // Length = 4 bytes + 0x01, // 1 Parameter (count) + // Delta=2 (MaxRequestId), Value=1000 + 0x02, 0x43, 0xe8, ] ); let decoded = Server::decode(&mut buf).unwrap(); - assert_eq!(decoded.version, server.version); assert_eq!(decoded.params, server.params); } } diff --git a/moq-transport/src/setup/version.rs b/moq-transport/src/setup/version.rs index fa896e7d..2fb41ae4 100644 --- a/moq-transport/src/setup/version.rs +++ b/moq-transport/src/setup/version.rs @@ -23,6 +23,12 @@ impl Version { /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-14.html pub const DRAFT_14: Version = Version(0xff00000e); + + /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-15.html + pub const DRAFT_15: Version = Version(0xff00000f); + + /// https://www.ietf.org/archive/id/draft-ietf-moq-transport-16.html + pub const DRAFT_16: Version = Version(0xff000010); } impl From for Version { From 6103ab72319680c2bbef8eecf66edf442112f654 Mon Sep 17 00:00:00 2001 From: Suhas 
Nandakumar Date: Wed, 8 Apr 2026 21:25:14 -0700 Subject: [PATCH 02/21] Update data plane and session layer for draft-16 --- moq-transport/src/data/datagram.rs | 357 +++++++--- moq-transport/src/data/extension_headers.rs | 127 +++- moq-transport/src/data/extension_types.rs | 38 ++ moq-transport/src/data/header.rs | 96 ++- moq-transport/src/data/mod.rs | 2 + moq-transport/src/data/subgroup.rs | 51 +- moq-transport/src/mlog/events.rs | 158 +++-- moq-transport/src/session/error.rs | 14 +- moq-transport/src/session/mod.rs | 197 ++++-- .../src/session/publish_namespace.rs | 157 +++++ .../src/session/publish_namespace_received.rs | 116 ++++ moq-transport/src/session/publish_received.rs | 282 ++++++++ moq-transport/src/session/published.rs | 621 ++++++++++++++++++ moq-transport/src/session/publisher.rs | 365 +++++----- moq-transport/src/session/reader.rs | 6 +- moq-transport/src/session/subscribe.rs | 66 +- .../src/session/subscribe_namespace.rs | 143 ++++ .../session/subscribe_namespace_received.rs | 148 +++++ moq-transport/src/session/subscribed.rs | 94 ++- moq-transport/src/session/subscriber.rs | 462 ++++++++++--- .../src/session/track_status_requested.rs | 3 +- 21 files changed, 2884 insertions(+), 619 deletions(-) create mode 100644 moq-transport/src/data/extension_types.rs create mode 100644 moq-transport/src/session/publish_namespace.rs create mode 100644 moq-transport/src/session/publish_namespace_received.rs create mode 100644 moq-transport/src/session/publish_received.rs create mode 100644 moq-transport/src/session/published.rs create mode 100644 moq-transport/src/session/subscribe_namespace.rs create mode 100644 moq-transport/src/session/subscribe_namespace_received.rs diff --git a/moq-transport/src/data/datagram.rs b/moq-transport/src/data/datagram.rs index 7a521319..61d36ce1 100644 --- a/moq-transport/src/data/datagram.rs +++ b/moq-transport/src/data/datagram.rs @@ -3,6 +3,7 @@ use crate::data::{ExtensionHeaders, ObjectStatus}; #[derive(Clone, Copy, Debug, 
Eq, PartialEq)] pub enum DatagramType { + // Payload types with Priority Present (0x00-0x07) ObjectIdPayload = 0x00, ObjectIdPayloadExt = 0x01, ObjectIdPayloadEndOfGroup = 0x02, @@ -11,13 +12,125 @@ pub enum DatagramType { PayloadExt = 0x05, PayloadEndOfGroup = 0x06, PayloadExtEndOfGroup = 0x07, + // Payload types with Priority Not Present (0x08-0x0F) + ObjectIdPayloadNoPriority = 0x08, + ObjectIdPayloadExtNoPriority = 0x09, + ObjectIdPayloadEndOfGroupNoPriority = 0x0a, + ObjectIdPayloadExtEndOfGroupNoPriority = 0x0b, + PayloadNoPriority = 0x0c, + PayloadExtNoPriority = 0x0d, + PayloadEndOfGroupNoPriority = 0x0e, + PayloadExtEndOfGroupNoPriority = 0x0f, + // Status types with Priority Present (0x20-0x25) ObjectIdStatus = 0x20, ObjectIdStatusExt = 0x21, + Status = 0x24, + StatusExt = 0x25, + // Status types with Priority Not Present (0x28-0x2D) + ObjectIdStatusNoPriority = 0x28, + ObjectIdStatusExtNoPriority = 0x29, + StatusNoPriority = 0x2c, + StatusExtNoPriority = 0x2d, +} + +impl DatagramType { + /// Returns true if this datagram type has the Object ID field present + pub fn has_object_id(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayload + | DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadNoPriority + | DatagramType::ObjectIdPayloadExtNoPriority + | DatagramType::ObjectIdPayloadEndOfGroupNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::ObjectIdStatusNoPriority + | DatagramType::ObjectIdStatusExtNoPriority + ) + } + + /// Returns true if this datagram type has the Publisher Priority field present + pub fn has_priority(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayload + | DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | 
DatagramType::Payload + | DatagramType::PayloadExt + | DatagramType::PayloadEndOfGroup + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::Status + | DatagramType::StatusExt + ) + } + + /// Returns true if this datagram type has extension headers + pub fn has_extensions(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayloadExt + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::PayloadExt + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadExtNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::PayloadExtNoPriority + | DatagramType::PayloadExtEndOfGroupNoPriority + | DatagramType::ObjectIdStatusExt + | DatagramType::StatusExt + | DatagramType::ObjectIdStatusExtNoPriority + | DatagramType::StatusExtNoPriority + ) + } + + /// Returns true if this is a status datagram (no payload) + pub fn is_status(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdStatus + | DatagramType::ObjectIdStatusExt + | DatagramType::Status + | DatagramType::StatusExt + | DatagramType::ObjectIdStatusNoPriority + | DatagramType::ObjectIdStatusExtNoPriority + | DatagramType::StatusNoPriority + | DatagramType::StatusExtNoPriority + ) + } + + /// Returns true if this is a payload datagram + pub fn is_payload(&self) -> bool { + !self.is_status() + } + + /// Returns true if this datagram type indicates end of group + pub fn is_end_of_group(&self) -> bool { + matches!( + *self, + DatagramType::ObjectIdPayloadEndOfGroup + | DatagramType::ObjectIdPayloadExtEndOfGroup + | DatagramType::PayloadEndOfGroup + | DatagramType::PayloadExtEndOfGroup + | DatagramType::ObjectIdPayloadEndOfGroupNoPriority + | DatagramType::ObjectIdPayloadExtEndOfGroupNoPriority + | DatagramType::PayloadEndOfGroupNoPriority + | DatagramType::PayloadExtEndOfGroupNoPriority + ) + } } impl Decode for DatagramType { fn decode(r: &mut B) -> Result { match u64::decode(r)? 
{ + // Payload types with Priority Present (0x00-0x07) 0x00 => Ok(Self::ObjectIdPayload), 0x01 => Ok(Self::ObjectIdPayloadExt), 0x02 => Ok(Self::ObjectIdPayloadEndOfGroup), @@ -26,8 +139,25 @@ impl Decode for DatagramType { 0x05 => Ok(Self::PayloadExt), 0x06 => Ok(Self::PayloadEndOfGroup), 0x07 => Ok(Self::PayloadExtEndOfGroup), + // Payload types with Priority Not Present (0x08-0x0F) + 0x08 => Ok(Self::ObjectIdPayloadNoPriority), + 0x09 => Ok(Self::ObjectIdPayloadExtNoPriority), + 0x0a => Ok(Self::ObjectIdPayloadEndOfGroupNoPriority), + 0x0b => Ok(Self::ObjectIdPayloadExtEndOfGroupNoPriority), + 0x0c => Ok(Self::PayloadNoPriority), + 0x0d => Ok(Self::PayloadExtNoPriority), + 0x0e => Ok(Self::PayloadEndOfGroupNoPriority), + 0x0f => Ok(Self::PayloadExtEndOfGroupNoPriority), + // Status types with Priority Present (0x20-0x25) 0x20 => Ok(Self::ObjectIdStatus), 0x21 => Ok(Self::ObjectIdStatusExt), + 0x24 => Ok(Self::Status), + 0x25 => Ok(Self::StatusExt), + // Status types with Priority Not Present (0x28-0x2D) + 0x28 => Ok(Self::ObjectIdStatusNoPriority), + 0x29 => Ok(Self::ObjectIdStatusExtNoPriority), + 0x2c => Ok(Self::StatusNoPriority), + 0x2d => Ok(Self::StatusExtNoPriority), _ => Err(DecodeError::InvalidDatagramType), } } @@ -56,9 +186,10 @@ pub struct Datagram { pub object_id: Option, /// Publisher priority, where **smaller** values are sent first. - pub publisher_priority: u8, + /// Optional when using NoPriority datagram types (0x08-0x0F, 0x28-0x2D). + pub publisher_priority: Option, - /// Optional extension headers if type is 0x1 (NoEndOfGroupWithExtensions) or 0x3 (EndofGroupWithExtensions) + /// Optional extension headers for types with extensions pub extension_headers: Option, /// The Object Status. 
@@ -75,47 +206,38 @@ impl Decode for Datagram { let group_id = u64::decode(r)?; // Decode Object Id if required - let object_id = match datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::ObjectIdStatus - | DatagramType::ObjectIdStatusExt => Some(u64::decode(r)?), - _ => None, + let object_id = if datagram_type.has_object_id() { + Some(u64::decode(r)?) + } else { + None }; - let publisher_priority = u8::decode(r)?; + // Decode Publisher Priority if required + let publisher_priority = if datagram_type.has_priority() { + Some(u8::decode(r)?) + } else { + None + }; // Decode Extension Headers if required - let extension_headers = match datagram_type { - DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::PayloadExt - | DatagramType::PayloadExtEndOfGroup - | DatagramType::ObjectIdStatusExt => Some(ExtensionHeaders::decode(r)?), - _ => None, + let extension_headers = if datagram_type.has_extensions() { + Some(ExtensionHeaders::decode(r)?) + } else { + None }; - // Decode Status if required - let status = match datagram_type { - DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { - Some(ObjectStatus::decode(r)?) - } - _ => None, + // Decode Status if required (for status datagram types) + let status = if datagram_type.is_status() { + Some(ObjectStatus::decode(r)?) 
+ } else { + None }; - // Decode Payload if required - let payload = match datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::Payload - | DatagramType::PayloadExt - | DatagramType::PayloadEndOfGroup - | DatagramType::PayloadExtEndOfGroup => Some(r.copy_to_bytes(r.remaining())), - _ => None, + // Decode Payload if required (for payload datagram types) + let payload = if datagram_type.is_payload() { + Some(r.copy_to_bytes(r.remaining())) + } else { + None }; Ok(Self { @@ -138,70 +260,49 @@ impl Encode for Datagram { self.group_id.encode(w)?; // Encode Object Id if required - match self.datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::ObjectIdStatus - | DatagramType::ObjectIdStatusExt => { - if let Some(object_id) = &self.object_id { - object_id.encode(w)?; - } else { - return Err(EncodeError::MissingField("ObjectId".to_string())); - } + if self.datagram_type.has_object_id() { + if let Some(object_id) = &self.object_id { + object_id.encode(w)?; + } else { + return Err(EncodeError::MissingField("ObjectId".to_string())); } - _ => {} - }; + } - self.publisher_priority.encode(w)?; + // Encode Publisher Priority if required + if self.datagram_type.has_priority() { + if let Some(publisher_priority) = &self.publisher_priority { + publisher_priority.encode(w)?; + } else { + return Err(EncodeError::MissingField("PublisherPriority".to_string())); + } + } // Encode Extension Headers if required - match self.datagram_type { - DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::PayloadExt - | DatagramType::PayloadExtEndOfGroup - | DatagramType::ObjectIdStatusExt => { - if let Some(extension_headers) = &self.extension_headers { - 
extension_headers.encode(w)?; - } else { - return Err(EncodeError::MissingField("ExtensionHeaders".to_string())); - } + if self.datagram_type.has_extensions() { + if let Some(extension_headers) = &self.extension_headers { + extension_headers.encode(w)?; + } else { + return Err(EncodeError::MissingField("ExtensionHeaders".to_string())); } - _ => {} - }; + } - // Decode Status if required - match self.datagram_type { - DatagramType::ObjectIdStatus | DatagramType::ObjectIdStatusExt => { - if let Some(status) = &self.status { - status.encode(w)?; - } else { - return Err(EncodeError::MissingField("Status".to_string())); - } + // Encode Status if required (for status datagram types) + if self.datagram_type.is_status() { + if let Some(status) = &self.status { + status.encode(w)?; + } else { + return Err(EncodeError::MissingField("Status".to_string())); } - _ => {} } - // Decode Payload if required - match self.datagram_type { - DatagramType::ObjectIdPayload - | DatagramType::ObjectIdPayloadExt - | DatagramType::ObjectIdPayloadEndOfGroup - | DatagramType::ObjectIdPayloadExtEndOfGroup - | DatagramType::Payload - | DatagramType::PayloadExt - | DatagramType::PayloadEndOfGroup - | DatagramType::PayloadExtEndOfGroup => { - if let Some(payload) = &self.payload { - Self::encode_remaining(w, payload.len())?; - w.put_slice(payload); - } else { - return Err(EncodeError::MissingField("Payload".to_string())); - } + // Encode Payload if required (for payload datagram types) + if self.datagram_type.is_payload() { + if let Some(payload) = &self.payload { + Self::encode_remaining(w, payload.len())?; + w.put_slice(payload); + } else { + return Err(EncodeError::MissingField("Payload".to_string())); } - _ => {} } Ok(()) @@ -293,7 +394,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -310,7 +411,7 @@ mod tests { track_alias: 12, 
group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -327,7 +428,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -344,7 +445,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -361,7 +462,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -378,7 +479,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -395,7 +496,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -412,7 +513,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, payload: Some(Bytes::from("payload")), @@ -429,7 +530,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -446,7 +547,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: Some(ext_hdrs.clone()), status: None, 
payload: Some(Bytes::from("payload")), @@ -456,6 +557,40 @@ mod tests { assert_eq!(19, buf.len()); let decoded = Datagram::decode(&mut buf).unwrap(); assert_eq!(decoded, msg); + + // DatagramType = ObjectIdPayloadNoPriority (no priority field) + let msg = Datagram { + datagram_type: DatagramType::ObjectIdPayloadNoPriority, + track_alias: 12, + group_id: 10, + object_id: Some(1234), + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + msg.encode(&mut buf).unwrap(); + // Length should be: Type(1)+Alias(1)+GroupId(1)+ObjectId(2)+Payload(7) = 12 (no priority) + assert_eq!(12, buf.len()); + let decoded = Datagram::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); + + // DatagramType = PayloadNoPriority (no priority field, no object id) + let msg = Datagram { + datagram_type: DatagramType::PayloadNoPriority, + track_alias: 12, + group_id: 10, + object_id: None, + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + msg.encode(&mut buf).unwrap(); + // Length should be: Type(1)+Alias(1)+GroupId(1)+Payload(7) = 10 (no priority, no object id) + assert_eq!(10, buf.len()); + let decoded = Datagram::decode(&mut buf).unwrap(); + assert_eq!(decoded, msg); } #[test] @@ -468,7 +603,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -482,7 +617,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: Some(Bytes::from("payload")), @@ -496,7 +631,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: Some(ObjectStatus::EndOfTrack), payload: None, @@ -510,7 +645,7 @@ 
mod tests { track_alias: 12, group_id: 10, object_id: None, - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: None, @@ -524,7 +659,7 @@ mod tests { track_alias: 12, group_id: 10, object_id: Some(1234), - publisher_priority: 127, + publisher_priority: Some(127), extension_headers: None, status: None, payload: None, @@ -532,6 +667,18 @@ mod tests { let encoded = msg.encode(&mut buf); assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); - // TODO SLG - add tests + // DatagramType = ObjectIdPayload - missing priority (priority is required for this type) + let msg = Datagram { + datagram_type: DatagramType::ObjectIdPayload, + track_alias: 12, + group_id: 10, + object_id: Some(1234), + publisher_priority: None, + extension_headers: None, + status: None, + payload: Some(Bytes::from("payload")), + }; + let encoded = msg.encode(&mut buf); + assert!(matches!(encoded.unwrap_err(), EncodeError::MissingField(_))); } } diff --git a/moq-transport/src/data/extension_headers.rs b/moq-transport/src/data/extension_headers.rs index f6dba873..22191548 100644 --- a/moq-transport/src/data/extension_headers.rs +++ b/moq-transport/src/data/extension_headers.rs @@ -4,6 +4,8 @@ use std::fmt; /// A collection of KeyValuePair entries, where the length in bytes of key-value-pairs are encoded/decoded first. /// This structure is appropriate for Data plane extension headers. +/// +/// Per draft-16 Section 1.4.2, Key-Value-Pairs use delta-encoded Type fields. /// Since duplicate parameters are allowed for unknown extension headers, we don't do duplicate checking here. #[derive(Default, Clone, Eq, PartialEq)] pub struct ExtensionHeaders(pub Vec); @@ -44,7 +46,40 @@ impl ExtensionHeaders { } } +impl ExtensionHeaders { + /// Decode extension headers from remaining bytes (no length prefix). + /// Used for Track Extensions in PUBLISH where the length is implicit from the message. 
+ pub fn decode_remaining_bytes(r: &mut R) -> Result { + if !r.has_remaining() { + return Ok(ExtensionHeaders::new()); + } + + let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + + while r.has_remaining() { + // Read delta type and reconstruct absolute key + let delta = u64::decode(r)?; + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[ExtHdr] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, r)?; + kvps.push(kvp); + prev_key = key; + } + + Ok(ExtensionHeaders(kvps)) + } +} + impl Decode for ExtensionHeaders { + /// Decode extension headers with delta-encoded Type fields (draft-16 Section 1.4.2). fn decode(r: &mut R) -> Result { // Read total byte length of the encoded kvps // Note: this is the difference between KeyValuePairs and ExtensionHeaders. @@ -65,9 +100,23 @@ impl Decode for ExtensionHeaders { let mut kvps_bytes = bytes::Bytes::from(buf); let mut kvps = Vec::new(); + let mut prev_key: u64 = 0; + while kvps_bytes.has_remaining() { - let kvp = KeyValuePair::decode(&mut kvps_bytes)?; + // Read delta type and reconstruct absolute key + let delta = u64::decode(&mut kvps_bytes)?; + let key = prev_key.checked_add(delta).ok_or_else(|| { + log::error!( + "[ExtHdr] Delta type overflow: prev_key={}, delta={}", + prev_key, + delta + ); + DecodeError::BoundsExceeded(crate::coding::BoundsExceeded) + })?; + + let kvp = KeyValuePair::decode_value(key, &mut kvps_bytes)?; kvps.push(kvp); + prev_key = key; } Ok(ExtensionHeaders(kvps)) @@ -75,14 +124,31 @@ impl Decode for ExtensionHeaders { } impl Encode for ExtensionHeaders { + /// Encode extension headers with delta-encoded Type fields (draft-16 Section 1.4.2). + /// Entries are sorted by key in ascending order before encoding. 
fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - // Encode all KeyValuePair entries into a temporary buffer to compute total byte length + // Sort by key for delta encoding + let mut sorted: Vec<&KeyValuePair> = self.0.iter().collect(); + sorted.sort_by_key(|kvp| kvp.key); + + // Encode all entries into a temporary buffer to compute total byte length let mut tmp = bytes::BytesMut::new(); - for kvp in &self.0 { - kvp.encode(&mut tmp)?; + let mut prev_key: u64 = 0; + for kvp in sorted { + let delta = kvp.key.checked_sub(prev_key).ok_or_else(|| { + log::error!( + "[ExtHdr] Keys not sortable: prev_key={}, current_key={}", + prev_key, + kvp.key + ); + EncodeError::InvalidValue + })?; + delta.encode(&mut tmp)?; + kvp.encode_value(&mut tmp)?; + prev_key = kvp.key; } - // Write total byte length (u64) followed by the encoded bytes + // Write total byte length followed by the encoded bytes (tmp.len() as u64).encode(w)?; w.put_slice(&tmp); @@ -109,9 +175,10 @@ mod tests { use bytes::BytesMut; #[test] - fn encode_decode_extension_headers() { + fn encode_decode_extension_headers_single() { let mut buf = BytesMut::new(); + // Single entry: key=1. Delta from 0 = 1. let mut ext_hdrs = ExtensionHeaders::new(); ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); ext_hdrs.encode(&mut buf).unwrap(); @@ -119,21 +186,55 @@ mod tests { buf.to_vec(), vec![ 0x07, // 7 bytes total length - 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, // Key=1, Value=[1,2,3,4,5] + // Delta=1, length=5, data + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, ] ); let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); assert_eq!(decoded, ext_hdrs); + } + + #[test] + fn encode_decode_extension_headers_multiple() { + let mut buf = BytesMut::new(); + // Multiple entries inserted out of order — encoding sorts by key. 
+ // Keys: 0 (even, int), 1 (odd, bytes), 100 (even, int) let mut ext_hdrs = ExtensionHeaders::new(); - ext_hdrs.set_intvalue(0, 0); // 2 bytes - ext_hdrs.set_intvalue(100, 100); // 4 bytes - ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); // 1 byte key, 1 byte length, 5 bytes data = 7 bytes + ext_hdrs.set_intvalue(0, 0); + ext_hdrs.set_intvalue(100, 100); + ext_hdrs.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); ext_hdrs.encode(&mut buf).unwrap(); let buf_vec = buf.to_vec(); - // Validate the encoded length and the KeyValuePair's length. - assert_eq!(14, buf_vec.len()); // 14 bytes total (length + 3 kvps) - assert_eq!(13, buf_vec[0]); // 13 bytes for the 3 KeyValuePairs data + + #[rustfmt::skip] + let expected = vec![ + 0x0d, // 13 bytes total length for the KVP data + // Entry 1: key=0 (delta=0), even, int value=0 + 0x00, 0x00, + // Entry 2: key=1 (delta=1), odd, bytes len=5 + 0x01, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05, + // Entry 3: key=100 (delta=99), even, int value=100 + 0x40, 0x63, 0x40, 0x64, + ]; + assert_eq!(buf_vec, expected); + + // Decode and verify — decoded entries will be in sorted order + let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); + let mut expected_ext = ExtensionHeaders::new(); + expected_ext.set_intvalue(0, 0); + expected_ext.set_bytesvalue(1, vec![0x01, 0x02, 0x03, 0x04, 0x05]); + expected_ext.set_intvalue(100, 100); + assert_eq!(decoded, expected_ext); + } + + #[test] + fn encode_decode_extension_headers_empty() { + let mut buf = BytesMut::new(); + + let ext_hdrs = ExtensionHeaders::new(); + ext_hdrs.encode(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), vec![0x00]); // length=0 let decoded = ExtensionHeaders::decode(&mut buf).unwrap(); assert_eq!(decoded, ext_hdrs); } diff --git a/moq-transport/src/data/extension_types.rs b/moq-transport/src/data/extension_types.rs new file mode 100644 index 00000000..7c83de7a --- /dev/null +++ b/moq-transport/src/data/extension_types.rs @@ -0,0 +1,38 @@ +//! 
Known extension header type constants for the MOQT data plane. +//! +//! These extension headers can be attached to objects in subgroups, datagrams, and fetch streams. +//! See the MOQT specification for detailed semantics of each extension type. + +/// Immutable Extensions (0xB) +/// +/// A container extension header that wraps other extension headers that MUST NOT +/// be modified by relays or intermediaries. The contents of this extension header +/// should be preserved exactly as received when forwarding objects. +pub const IMMUTABLE_EXTENSIONS: u64 = 0xB; + +/// Prior Group ID Gap (0x3C) +/// +/// Indicates that one or more groups prior to this one are missing or unavailable. +/// The value is an integer indicating the number of missing prior groups. +/// This is used to signal discontinuities in the group sequence to subscribers. +pub const PRIOR_GROUP_ID_GAP: u64 = 0x3C; + +/// Prior Object ID Gap (0x3E) +/// +/// Indicates that one or more objects prior to this one within the same group/subgroup +/// are missing or unavailable. The value is an integer indicating the number of missing +/// prior objects. This is used to signal discontinuities in the object sequence. 
+pub const PRIOR_OBJECT_ID_GAP: u64 = 0x3E; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extension_type_values() { + // Verify the spec-defined values + assert_eq!(IMMUTABLE_EXTENSIONS, 0xB); + assert_eq!(PRIOR_GROUP_ID_GAP, 0x3C); + assert_eq!(PRIOR_OBJECT_ID_GAP, 0x3E); + } +} diff --git a/moq-transport/src/data/header.rs b/moq-transport/src/data/header.rs index b8274b4a..2f83f3c7 100644 --- a/moq-transport/src/data/header.rs +++ b/moq-transport/src/data/header.rs @@ -6,6 +6,7 @@ use std::fmt; #[repr(u64)] #[derive(Copy, Debug, Clone, Eq, PartialEq)] pub enum StreamHeaderType { + // Priority Present variants (0x10-0x1D) SubgroupZeroId = 0x10, SubgroupZeroIdExt = 0x11, SubgroupFirstObjectId = 0x12, @@ -18,13 +19,27 @@ pub enum StreamHeaderType { SubgroupFirstObjectIdExtEndOfGroup = 0x1b, SubgroupIdEndOfGroup = 0x1c, SubgroupIdExtEndOfGroup = 0x1d, + // Priority Not Present variants (0x30-0x3D) + SubgroupZeroIdNoPriority = 0x30, + SubgroupZeroIdExtNoPriority = 0x31, + SubgroupFirstObjectIdNoPriority = 0x32, + SubgroupFirstObjectIdExtNoPriority = 0x33, + SubgroupIdNoPriority = 0x34, + SubgroupIdExtNoPriority = 0x35, + SubgroupZeroIdEndOfGroupNoPriority = 0x38, + SubgroupZeroIdExtEndOfGroupNoPriority = 0x39, + SubgroupFirstObjectIdEndOfGroupNoPriority = 0x3a, + SubgroupFirstObjectIdExtEndOfGroupNoPriority = 0x3b, + SubgroupIdEndOfGroupNoPriority = 0x3c, + SubgroupIdExtEndOfGroupNoPriority = 0x3d, + // Fetch Fetch = 0x5, } impl StreamHeaderType { pub fn is_subgroup(&self) -> bool { let header_type = *self as u64; - (0x10..=0x1d).contains(&header_type) + (0x10..=0x1d).contains(&header_type) || (0x30..=0x3d).contains(&header_type) } pub fn is_fetch(&self) -> bool { @@ -40,6 +55,12 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupZeroIdExtEndOfGroup | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupZeroIdExtNoPriority + | 
StreamHeaderType::SubgroupFirstObjectIdExtNoPriority + | StreamHeaderType::SubgroupIdExtNoPriority + | StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority | StreamHeaderType::Fetch ) } @@ -51,6 +72,37 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupIdExt | StreamHeaderType::SubgroupIdEndOfGroup | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupIdNoPriority + | StreamHeaderType::SubgroupIdExtNoPriority + | StreamHeaderType::SubgroupIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority + ) + } + + pub fn has_priority(&self) -> bool { + let header_type = *self as u64; + // Priority Present variants are 0x10-0x1D + // Priority Not Present variants are 0x30-0x3D + (0x10..=0x1d).contains(&header_type) + } + + /// Returns true if this header type signals end-of-group when the stream ends. + /// For these types, the relay should write an EndOfGroup marker when the stream completes. 
+ pub fn signals_end_of_group(&self) -> bool { + matches!( + *self, + StreamHeaderType::SubgroupZeroIdEndOfGroup + | StreamHeaderType::SubgroupZeroIdExtEndOfGroup + | StreamHeaderType::SubgroupFirstObjectIdEndOfGroup + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup + | StreamHeaderType::SubgroupIdEndOfGroup + | StreamHeaderType::SubgroupIdExtEndOfGroup + | StreamHeaderType::SubgroupZeroIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdEndOfGroupNoPriority + | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority ) } } @@ -83,6 +135,7 @@ impl Decode for StreamHeaderType { ); let header_type = match type_value { + // Priority Present variants (0x10-0x1D) 0x10_u64 => Ok(Self::SubgroupZeroId), 0x11_u64 => Ok(Self::SubgroupZeroIdExt), 0x12_u64 => Ok(Self::SubgroupFirstObjectId), @@ -95,6 +148,20 @@ impl Decode for StreamHeaderType { 0x1b_u64 => Ok(Self::SubgroupFirstObjectIdExtEndOfGroup), 0x1c_u64 => Ok(Self::SubgroupIdEndOfGroup), 0x1d_u64 => Ok(Self::SubgroupIdExtEndOfGroup), + // Priority Not Present variants (0x30-0x3D) + 0x30_u64 => Ok(Self::SubgroupZeroIdNoPriority), + 0x31_u64 => Ok(Self::SubgroupZeroIdExtNoPriority), + 0x32_u64 => Ok(Self::SubgroupFirstObjectIdNoPriority), + 0x33_u64 => Ok(Self::SubgroupFirstObjectIdExtNoPriority), + 0x34_u64 => Ok(Self::SubgroupIdNoPriority), + 0x35_u64 => Ok(Self::SubgroupIdExtNoPriority), + 0x38_u64 => Ok(Self::SubgroupZeroIdEndOfGroupNoPriority), + 0x39_u64 => Ok(Self::SubgroupZeroIdExtEndOfGroupNoPriority), + 0x3a_u64 => Ok(Self::SubgroupFirstObjectIdEndOfGroupNoPriority), + 0x3b_u64 => Ok(Self::SubgroupFirstObjectIdExtEndOfGroupNoPriority), + 0x3c_u64 => Ok(Self::SubgroupIdEndOfGroupNoPriority), + 0x3d_u64 => Ok(Self::SubgroupIdExtEndOfGroupNoPriority), + // Fetch 0x05_u64 => Ok(Self::Fetch), _ => { log::error!( @@ -290,7 
+357,31 @@ mod tests { track_alias: 10, group_id: 0, subgroup_id: Some(1), - publisher_priority: 100, + publisher_priority: Some(100), + }), + fetch_header: None, + }; + sh.encode(&mut buf).unwrap(); + let decoded = StreamHeader::decode(&mut buf).unwrap(); + assert_eq!(decoded, sh); + assert!(sh.header_type.is_subgroup()); + assert!(!sh.header_type.is_fetch()); + assert!(sh.header_type.has_subgroup_id()); + } + + #[test] + fn encode_decode_stream_header_no_priority() { + let mut buf = BytesMut::new(); + + // Test a NoPriority subgroup header type + let sh = StreamHeader { + header_type: StreamHeaderType::SubgroupIdNoPriority, + subgroup_header: Some(SubgroupHeader { + header_type: StreamHeaderType::SubgroupIdNoPriority, + track_alias: 10, + group_id: 0, + subgroup_id: Some(1), + publisher_priority: None, }), fetch_header: None, }; @@ -300,5 +391,6 @@ mod tests { assert!(sh.header_type.is_subgroup()); assert!(!sh.header_type.is_fetch()); assert!(sh.header_type.has_subgroup_id()); + assert!(!sh.header_type.has_priority()); } } diff --git a/moq-transport/src/data/mod.rs b/moq-transport/src/data/mod.rs index d76ba871..0d0025ab 100644 --- a/moq-transport/src/data/mod.rs +++ b/moq-transport/src/data/mod.rs @@ -1,5 +1,6 @@ mod datagram; mod extension_headers; +mod extension_types; mod fetch; mod header; mod object_status; @@ -7,6 +8,7 @@ mod subgroup; pub use datagram::*; pub use extension_headers::*; +pub use extension_types::*; pub use fetch::*; pub use header::*; pub use object_status::*; diff --git a/moq-transport/src/data/subgroup.rs b/moq-transport/src/data/subgroup.rs index 45e89e9f..9cfb1127 100644 --- a/moq-transport/src/data/subgroup.rs +++ b/moq-transport/src/data/subgroup.rs @@ -16,7 +16,8 @@ pub struct SubgroupHeader { pub subgroup_id: Option, /// Publisher priority, where **smaller** values are sent first. - pub publisher_priority: u8, + /// Optional when using NoPriority header types (0x30-0x3D). 
+ pub publisher_priority: Option, } // Note: Not using the Decode trait, since we need to know the header_type to properly parse this, and it @@ -52,12 +53,20 @@ impl SubgroupHeader { } }; - let publisher_priority = u8::decode(r)?; - log::trace!( - "[DECODE] SubgroupHeader: publisher_priority={}, buffer_remaining={} bytes", - publisher_priority, - r.remaining() - ); + let publisher_priority = if header_type.has_priority() { + let priority = u8::decode(r)?; + log::trace!( + "[DECODE] SubgroupHeader: publisher_priority={}, buffer_remaining={} bytes", + priority, + r.remaining() + ); + Some(priority) + } else { + log::trace!( + "[DECODE] SubgroupHeader: publisher_priority=None (not present for NoPriority header type)" + ); + None + }; let result = Self { header_type, @@ -68,7 +77,7 @@ impl SubgroupHeader { }; log::debug!( - "[DECODE] SubgroupHeader complete: track_alias={}, group_id={}, subgroup_id={:?}, priority={}", + "[DECODE] SubgroupHeader complete: track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}", result.track_alias, result.group_id, result.subgroup_id, @@ -82,7 +91,7 @@ impl SubgroupHeader { impl Encode for SubgroupHeader { fn encode(&self, w: &mut W) -> Result<(), EncodeError> { log::trace!( - "[ENCODE] SubgroupHeader: starting encode - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", + "[ENCODE] SubgroupHeader: starting encode - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}", self.track_alias, self.group_id, self.subgroup_id, @@ -125,11 +134,25 @@ impl Encode for SubgroupHeader { log::trace!("[ENCODE] SubgroupHeader: subgroup_id not encoded (not required for this header type)"); } - self.publisher_priority.encode(w)?; - log::trace!( - "[ENCODE] SubgroupHeader: encoded publisher_priority={}", - self.publisher_priority - ); + if self.header_type.has_priority() { + if let Some(publisher_priority) = self.publisher_priority { + publisher_priority.encode(w)?; + log::trace!( + "[ENCODE] 
SubgroupHeader: encoded publisher_priority={}", + publisher_priority + ); + } else { + log::error!( + "[ENCODE] SubgroupHeader: MISSING publisher_priority for header_type={:?}", + self.header_type + ); + return Err(EncodeError::MissingField("PublisherPriority".to_string())); + } + } else { + log::trace!( + "[ENCODE] SubgroupHeader: publisher_priority not encoded (NoPriority header type)" + ); + } let bytes_written = start_pos - w.remaining_mut(); log::debug!( diff --git a/moq-transport/src/mlog/events.rs b/moq-transport/src/mlog/events.rs index 53c513f0..f074b747 100644 --- a/moq-transport/src/mlog/events.rs +++ b/moq-transport/src/mlog/events.rs @@ -2,10 +2,10 @@ // - SubscribeUpdate (parsed/created) // - PublishNamespaceDone (parsed/created) // - PublishNamespaceCancel (parsed/created) -// - TrackStatus, TrackStatusOk, TrackStatusError (parsed/created) -// - SubscribeNamespace, SubscribeNamespaceOk, SubscribeNamespaceError, UnsubscribeNamespace (parsed/created) -// - Fetch, FetchOk, FetchError, FetchCancel (parsed/created) -// - Publish, PublishOk, PublishError, PublishDone (parsed/created) +// - TrackStatus, TrackStatusOk (parsed/created) +// - SubscribeNamespace (parsed/created) +// - Fetch, FetchOk, FetchCancel (parsed/created) +// - Publish, PublishOk, PublishDone (parsed/created) // - MaxRequestId (parsed/created) // - RequestsBlocked (parsed/created) // @@ -207,7 +207,6 @@ fn create_control_message_event( /// Create a control_message_parsed event for CLIENT_SETUP pub fn client_setup_parsed(time: f64, stream_id: u64, msg: &setup::Client) -> Event { - let versions: Vec = msg.versions.0.iter().map(|v| format!("{:?}", v)).collect(); create_control_message_event( time, stream_id, @@ -215,8 +214,6 @@ pub fn client_setup_parsed(time: f64, stream_id: u64, msg: &setup::Client) -> Ev "client_setup", json!( { - "number_of_supported_versions": msg.versions.0.len(), - "supported_versions": versions, "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) @@ -231,7 
+228,6 @@ pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> E "server_setup", json!( { - "selected_version": format!("{:?}", msg.version), "parameters": key_value_pairs_to_vec(&msg.params.0), }), ) @@ -239,25 +235,12 @@ pub fn server_setup_created(time: f64, stream_id: u64, msg: &setup::Server) -> E /// Helper to convert SUBSCRIBE message to JSON fn subscribe_to_json(msg: &message::Subscribe) -> JsonValue { - let mut json = json!({ + let json = json!({ "subscribe_id": msg.id, "track_namespace": msg.track_namespace.to_string(), "track_name": &msg.track_name, - "subscriber_priority": msg.subscriber_priority, - "group_order": format!("{:?}", msg.group_order), - "filter_type": format!("{:?}", msg.filter_type), "parameters": key_value_pairs_to_vec(&msg.params.0), }); - - // Add optional fields based on filter type - if let Some(start_loc) = &msg.start_location { - json["start_group"] = json!(start_loc.group_id); - json["start_object"] = json!(start_loc.object_id); - } - if let Some(end_group) = msg.end_group_id { - json["end_group"] = json!(end_group); - } - json } @@ -273,23 +256,11 @@ pub fn subscribe_created(time: f64, stream_id: u64, msg: &message::Subscribe) -> /// Helper to convert SUBSCRIBE_OK message to JSON fn subscribe_ok_to_json(msg: &message::SubscribeOk) -> JsonValue { - let mut json = json!({ + let json = json!({ "subscribe_id": msg.id, "track_alias": msg.track_alias, - "expires": msg.expires, - "group_order": format!("{:?}", msg.group_order), - "content_exists": msg.content_exists, "parameters": key_value_pairs_to_vec(&msg.params.0), }); - - // Add optional largest_location fields if content exists - if msg.content_exists { - if let Some(largest) = &msg.largest_location { - json["largest_group_id"] = json!(largest.group_id); - json["largest_object_id"] = json!(largest.object_id); - } - } - json } @@ -316,33 +287,34 @@ pub fn subscribe_ok_created(time: f64, stream_id: u64, msg: &message::SubscribeO } /// Helper to convert 
SUBSCRIBE_ERROR message to JSON -fn subscribe_error_to_json(msg: &message::SubscribeError) -> JsonValue { +fn request_error_to_json(msg: &message::RequestError) -> JsonValue { json!({ - "subscribe_id": msg.id, + "request_id": msg.id, "error_code": msg.error_code, + "retry_interval": msg.retry_interval, "reason_phrase": &msg.reason_phrase.0, }) } /// Create a control_message_parsed event for SUBSCRIBE_ERROR -pub fn subscribe_error_parsed(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { +pub fn request_error_parsed(time: f64, stream_id: u64, msg: &message::RequestError) -> Event { create_control_message_event( time, stream_id, true, - "subscribe_error", - subscribe_error_to_json(msg), + "request_error", + request_error_to_json(msg), ) } /// Create a control_message_created event for SUBSCRIBE_ERROR -pub fn subscribe_error_created(time: f64, stream_id: u64, msg: &message::SubscribeError) -> Event { +pub fn reqeust_error_created(time: f64, stream_id: u64, msg: &message::RequestError) -> Event { create_control_message_event( time, stream_id, false, - "subscribe_error", - subscribe_error_to_json(msg), + "request_error", + request_error_to_json(msg), ) } @@ -386,78 +358,100 @@ pub fn publish_namespace_created( } /// Helper to convert PUBLISH_NAMESPACE_OK message to JSON -fn publish_namespace_ok_to_json(msg: &message::PublishNamespaceOk) -> JsonValue { +fn request_ok_to_json(msg: &message::RequestOk) -> JsonValue { json!({ "request_id": msg.id, }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_OK (was ANNOUNCE_OK) -pub fn publish_namespace_ok_parsed( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceOk, -) -> Event { +/// Create a control_message_parsed event for REQUEST_OK +pub fn request_ok_parsed(time: f64, stream_id: u64, msg: &message::RequestOk) -> Event { + create_control_message_event(time, stream_id, true, "request_ok", request_ok_to_json(msg)) +} + +/// Create a control_message_created event for Reqeust OK 
+pub fn reqeust_ok_created(time: f64, stream_id: u64, msg: &message::RequestOk) -> Event { create_control_message_event( time, stream_id, - true, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + false, + "request_ok", + request_ok_to_json(msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_OK -pub fn publish_namespace_ok_created( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceOk, -) -> Event { +fn publish_to_json(msg: &message::Publish) -> JsonValue { + json!({ + "publish_id": msg.id, + "track_namespace": msg.track_namespace.to_string(), + "track_name": &msg.track_name, + "track_alias": msg.track_alias, + "parameters": key_value_pairs_to_vec(&msg.params.0), + }) +} + +/// Create a control_message_parsed event for PUBLISH +pub fn publish_parsed(time: f64, stream_id: u64, msg: &message::Publish) -> Event { + create_control_message_event(time, stream_id, true, "publish", publish_to_json(msg)) +} + +/// Create a control_message_created event for PUBLISH +pub fn publish_created(time: f64, stream_id: u64, msg: &message::Publish) -> Event { + create_control_message_event(time, stream_id, false, "publish", publish_to_json(msg)) +} + +fn publish_ok_to_json(msg: &message::PublishOk) -> JsonValue { + json!({ + "publish_id": msg.id, + "parameters": key_value_pairs_to_vec(&msg.params.0), + }) +} + +/// Create a control_message_parsed event for PUBLISH_OK +pub fn publish_ok_parsed(time: f64, stream_id: u64, msg: &message::PublishOk) -> Event { + create_control_message_event(time, stream_id, true, "publish_ok", publish_ok_to_json(msg)) +} + +/// Create a control_message_created event for PUBLISH_OK +pub fn publish_ok_created(time: f64, stream_id: u64, msg: &message::PublishOk) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_ok", - publish_namespace_ok_to_json(msg), + "publish_ok", + publish_ok_to_json(msg), ) } -/// Helper to convert PUBLISH_NAMESPACE_ERROR message to JSON -fn 
publish_namespace_error_to_json(msg: &message::PublishNamespaceError) -> JsonValue { +/// Helper to convert PUBLISH_DONE message to JSON +fn publish_done_to_json(msg: &message::PublishDone) -> JsonValue { json!({ - "request_id": msg.id, - "error_code": msg.error_code, - "reason_phrase": &msg.reason_phrase.0, + "publish_id": msg.id, + "status_code": msg.status_code, + "stream_count": msg.stream_count, + "reason": &msg.reason.0, }) } -/// Create a control_message_parsed event for PUBLISH_NAMESPACE_ERROR (was ANNOUNCE_ERROR) -pub fn publish_namespace_error_parsed( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceError, -) -> Event { +/// Create a control_message_parsed event for PUBLISH_DONE +pub fn publish_done_parsed(time: f64, stream_id: u64, msg: &message::PublishDone) -> Event { create_control_message_event( time, stream_id, true, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "publish_done", + publish_done_to_json(msg), ) } -/// Create a control_message_created event for PUBLISH_NAMESPACE_ERROR -pub fn publish_namespace_error_created( - time: f64, - stream_id: u64, - msg: &message::PublishNamespaceError, -) -> Event { +/// Create a control_message_created event for PUBLISH_DONE +pub fn publish_done_created(time: f64, stream_id: u64, msg: &message::PublishDone) -> Event { create_control_message_event( time, stream_id, false, - "publish_namespace_error", - publish_namespace_error_to_json(msg), + "publish_done", + publish_done_to_json(msg), ) } diff --git a/moq-transport/src/session/error.rs b/moq-transport/src/session/error.rs index 25006b9f..2e5ceacb 100644 --- a/moq-transport/src/session/error.rs +++ b/moq-transport/src/session/error.rs @@ -2,14 +2,8 @@ use crate::{coding, serve, setup}; #[derive(thiserror::Error, Debug, Clone)] pub enum SessionError { - #[error("webtransport session: {0}")] - Session(#[from] web_transport::SessionError), - - #[error("webtransport write: {0}")] - Write(#[from] web_transport::WriteError), - 
- #[error("webtransport read: {0}")] - Read(#[from] web_transport::ReadError), + #[error("webtransport error: {0}")] + WebTransport(#[from] web_transport::Error), #[error("encode error: {0}")] Encode(#[from] coding::EncodeError), @@ -53,9 +47,7 @@ impl SessionError { // PROTOCOL_VIOLATION (0x3) - The role negotiated in the handshake was violated Self::RoleViolation => 0x3, // INTERNAL_ERROR (0x1) - Generic internal errors - Self::Session(_) => 0x1, - Self::Read(_) => 0x1, - Self::Write(_) => 0x1, + Self::WebTransport(_) => 0x1, Self::Encode(_) => 0x1, Self::BoundsExceeded(_) => 0x1, Self::Internal => 0x1, diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index 3b01dbd1..2b6dea17 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -1,19 +1,27 @@ -mod announce; -mod announced; mod error; +mod publish_namespace; +mod publish_namespace_received; +mod publish_received; +mod published; mod publisher; mod reader; mod subscribe; +mod subscribe_namespace; +mod subscribe_namespace_received; mod subscribed; mod subscriber; mod track_status_requested; mod writer; -pub use announce::*; -pub use announced::*; pub use error::*; +pub use publish_namespace::*; +pub use publish_namespace_received::*; +pub use publish_received::*; +pub use published::*; pub use publisher::*; pub use subscribe::*; +pub use subscribe_namespace::*; +pub use subscribe_namespace_received::*; pub use subscribed::*; pub use subscriber::*; pub use track_status_requested::*; @@ -52,14 +60,6 @@ pub struct Session { } impl Session { - // Helper for determining the largest supported version - fn largest_common(a: &[T], b: &[T]) -> Option { - a.iter() - .filter(|x| b.contains(x)) // keep only items also in b - .cloned() // clone because we return T, not &T - .max() // take the largest - } - fn new( webtransport: web_transport::Session, sender: Writer, @@ -101,7 +101,7 @@ impl Session { /// Create an outbound/client QUIC connection, by opening a 
bi-directional QUIC stream for /// MOQT control messaging. Performs SETUP messaging and version negotiation. pub async fn connect( - mut session: web_transport::Session, + session: web_transport::Session, mlog_path: Option, ) -> Result<(Session, Publisher, Subscriber), SessionError> { let mlog = mlog_path.and_then(|path| { @@ -113,16 +113,10 @@ impl Session { let mut sender = Writer::new(control.0); let mut recver = Reader::new(control.1); - let versions: setup::Versions = [setup::Version::DRAFT_14].into(); - - // TODO SLG - make configurable? let mut params = KeyValuePairs::default(); params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); - let client = setup::Client { - versions: versions.clone(), - params, - }; + let client = setup::Client { params }; log::debug!("sending CLIENT_SETUP: {:?}", client); sender.encode(&client).await?; @@ -142,7 +136,7 @@ impl Session { /// Accepts an inbound/server QUIC connection, by accepting a bi-directional QUIC stream for /// MOQT control messaging. Performs SETUP messaging and version negotiation. pub async fn accept( - mut session: web_transport::Session, + session: web_transport::Session, mlog_path: Option, ) -> Result<(Session, Option, Option), SessionError> { let mut mlog = mlog_path.and_then(|path| { @@ -163,35 +157,24 @@ impl Session { let _ = mlog.add_event(event); } - let server_versions = setup::Versions(vec![setup::Version::DRAFT_14]); - - if let Some(largest_common_version) = - Self::largest_common(&server_versions, &client.versions) - { - // TODO SLG - make configurable? - let mut params = KeyValuePairs::default(); - params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); + // TODO SLG - make configurable? 
+ let mut params = KeyValuePairs::default(); + params.set_intvalue(setup::ParameterType::MaxRequestId.into(), 100); - let server = setup::Server { - version: largest_common_version, - params, - }; + let server = setup::Server { params }; - log::debug!("sending SERVER_SETUP: {:?}", server); + log::debug!("sending SERVER_SETUP: {:?}", server); - // Emit mlog event for SERVER_SETUP created - if let Some(ref mut mlog) = mlog { - let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); - let _ = mlog.add_event(event); - } + // Emit mlog event for SERVER_SETUP created + if let Some(ref mut mlog) = mlog { + let event = mlog::events::server_setup_created(mlog.elapsed_ms(), 0, &server); + let _ = mlog.add_event(event); + } - sender.encode(&server).await?; + sender.encode(&server).await?; - // We are the server, so the first request id is 1 - Ok(Session::new(session, sender, recver, 1, mlog)) - } else { - Err(SessionError::Version(client.versions, server_versions)) - } + // We are the server, so the first request id is 1 + Ok(Session::new(session, sender, recver, 1, mlog)) } /// Run Tasks for the session, including sending of control messages, receiving and processing @@ -199,9 +182,10 @@ impl Session { /// and receiving and processing QUIC datagrams received pub async fn run(self) -> Result<(), SessionError> { tokio::select! 
{ - res = Self::run_recv(self.recver, self.publisher, self.subscriber.clone(), self.mlog.clone()) => res, + res = Self::run_recv(self.recver, self.publisher.clone(), self.subscriber.clone(), self.mlog.clone()) => res, res = Self::run_send(self.sender, self.outgoing, self.mlog.clone()) => res, res = Self::run_streams(self.webtransport.clone(), self.subscriber.clone()) => res, + res = Self::run_bidi_streams(self.webtransport.clone(), self.publisher) => res, res = Self::run_datagrams(self.webtransport, self.subscriber) => res, } } @@ -229,8 +213,8 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_created(time, stream_id, m)) } - Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_created(time, stream_id, m)) + Message::RequestError(m) => { + Some(mlog::events::reqeust_error_created(time, stream_id, m)) } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_created(time, stream_id, m)) @@ -238,16 +222,22 @@ impl Session { Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_created(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_created(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_created(time, stream_id, m), - ), + Message::RequestOk(m) => { + Some(mlog::events::reqeust_ok_created(time, stream_id, m)) + } Message::GoAway(m) => { Some(mlog::events::go_away_created(time, stream_id, m)) } - _ => None, // TODO: Add other message types + Message::Publish(m) => { + Some(mlog::events::publish_created(time, stream_id, m)) + } + Message::PublishOk(m) => { + Some(mlog::events::publish_ok_created(time, stream_id, m)) + } + Message::PublishDone(m) => { + Some(mlog::events::publish_done_created(time, stream_id, m)) + } + _ => None, }; if let Some(event) = event { @@ -291,8 +281,8 @@ impl Session { Message::SubscribeOk(m) => { Some(mlog::events::subscribe_ok_parsed(time, stream_id, m)) } - 
Message::SubscribeError(m) => { - Some(mlog::events::subscribe_error_parsed(time, stream_id, m)) + Message::RequestError(m) => { + Some(mlog::events::request_error_parsed(time, stream_id, m)) } Message::Unsubscribe(m) => { Some(mlog::events::unsubscribe_parsed(time, stream_id, m)) @@ -300,16 +290,22 @@ impl Session { Message::PublishNamespace(m) => { Some(mlog::events::publish_namespace_parsed(time, stream_id, m)) } - Message::PublishNamespaceOk(m) => Some( - mlog::events::publish_namespace_ok_parsed(time, stream_id, m), - ), - Message::PublishNamespaceError(m) => Some( - mlog::events::publish_namespace_error_parsed(time, stream_id, m), - ), + Message::RequestOk(m) => { + Some(mlog::events::request_ok_parsed(time, stream_id, m)) + } Message::GoAway(m) => { Some(mlog::events::go_away_parsed(time, stream_id, m)) } - _ => None, // TODO: Add other message types + Message::Publish(m) => { + Some(mlog::events::publish_parsed(time, stream_id, m)) + } + Message::PublishOk(m) => { + Some(mlog::events::publish_ok_parsed(time, stream_id, m)) + } + Message::PublishDone(m) => { + Some(mlog::events::publish_done_parsed(time, stream_id, m)) + } + _ => None, }; if let Some(event) = event { @@ -318,6 +314,29 @@ impl Session { } } + // RequestOk and RequestError are bidirectional — they can be responses + // to requests originated by either side (e.g., PUBLISH_NAMESPACE from the + // publisher or SUBSCRIBE_NAMESPACE from the subscriber). We must try both + // handlers so the response reaches whichever side owns that request ID. 
+ match &msg { + Message::RequestOk(_) | Message::RequestError(_) => { + // Try subscriber handler first (for SUBSCRIBE_NAMESPACE responses) + if let Ok(pub_msg) = TryInto::::try_into(msg.clone()) { + if let Some(sub) = subscriber.as_mut() { + let _ = sub.recv_message(pub_msg); + } + } + // Also try publisher handler (for PUBLISH_NAMESPACE responses) + if let Ok(sub_msg) = TryInto::::try_into(msg) { + if let Some(pub_) = publisher.as_mut() { + let _ = pub_.recv_message(sub_msg); + } + } + continue; + } + _ => {} + } + let msg = match TryInto::::try_into(msg) { Ok(msg) => { subscriber @@ -353,7 +372,7 @@ impl Session { /// Will read stream header to know what type of stream it is and create /// the appropriate stream handlers. async fn run_streams( - mut webtransport: web_transport::Session, + webtransport: web_transport::Session, subscriber: Option, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); @@ -375,9 +394,55 @@ impl Session { } } + /// Accepts bidirectional QUIC streams for messages like SUBSCRIBE_NAMESPACE. + /// In draft-16, SUBSCRIBE_NAMESPACE uses its own bidirectional stream. + async fn run_bidi_streams( + webtransport: web_transport::Session, + publisher: Option, + ) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + + loop { + tokio::select! 
{ + res = webtransport.accept_bi() => { + let (_send, recv) = res?; + let mut publisher = publisher.clone().ok_or(SessionError::RoleViolation)?; + + tasks.push(async move { + let mut reader = Reader::new(recv); + + // Read the message from the bidi stream + let msg: message::Message = match reader.decode().await { + Ok(msg) => msg, + Err(e) => { + log::warn!("failed to decode message on bidi stream: {}", e); + return; + } + }; + + log::debug!("received message on bidi stream: {:?}", msg); + + // Handle SUBSCRIBE_NAMESPACE on its dedicated bidi stream + match msg { + Message::SubscribeNamespace(subscribe_ns) => { + if let Err(e) = publisher.recv_message(message::Subscriber::SubscribeNamespace(subscribe_ns)) { + log::warn!("failed to handle SUBSCRIBE_NAMESPACE: {}", e); + } + } + other => { + log::warn!("unexpected message type on bidi stream: {:?}", other); + } + } + }); + }, + _ = tasks.next(), if !tasks.is_empty() => {}, + }; + } + } + /// Receives QUIC datagrams and processes them using the Subscriber logic async fn run_datagrams( - mut webtransport: web_transport::Session, + webtransport: web_transport::Session, mut subscriber: Option, ) -> Result<(), SessionError> { loop { diff --git a/moq-transport/src/session/publish_namespace.rs b/moq-transport/src/session/publish_namespace.rs new file mode 100644 index 00000000..14f7fcee --- /dev/null +++ b/moq-transport/src/session/publish_namespace.rs @@ -0,0 +1,157 @@ +use std::ops; + +use crate::coding::TrackNamespace; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Publisher; + +#[derive(Debug, Clone)] +pub struct PublishNamespaceInfo { + pub request_id: u64, + pub namespace: TrackNamespace, +} + +/// Internal state for PublishNamespace. +/// +/// PublishNamespace is a namespace registry that advertises to subscribers +/// that a publisher has tracks available in a namespace. It does NOT route +/// subscriptions - that happens via PUBLISH/SUBSCRIBE messages directly. 
+struct PublishNamespaceState { + ok: bool, + closed: Result<(), ServeError>, +} + +impl Default for PublishNamespaceState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + } + } +} + +/// Represents an outbound PUBLISH_NAMESPACE request (publisher side). +/// When dropped, sends PUBLISH_NAMESPACE_DONE to the peer. +#[must_use = "sends PUBLISH_NAMESPACE_DONE on drop"] +pub struct PublishNamespace { + publisher: Publisher, + state: State, + + pub info: PublishNamespaceInfo, +} + +impl PublishNamespace { + pub(super) fn new( + mut publisher: Publisher, + request_id: u64, + namespace: TrackNamespace, + ) -> (PublishNamespace, PublishNamespaceRecv) { + let info = PublishNamespaceInfo { + request_id, + namespace: namespace.clone(), + }; + + publisher.send_message(message::PublishNamespace { + id: request_id, + track_namespace: namespace.clone(), + params: Default::default(), + }); + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + }; + let recv = PublishNamespaceRecv { + state: recv, + request_id, + }; + + (send, recv) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn ok(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl Drop for PublishNamespace { + fn drop(&mut self) { + if self.state.lock().closed.is_err() { + return; + } + + self.publisher.send_message(message::PublishNamespaceDone { + track_namespace: self.namespace.clone(), + }); + } +} + +impl ops::Deref for PublishNamespace { + type Target = PublishNamespaceInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} 
+ +pub(super) struct PublishNamespaceRecv { + state: State, + pub request_id: u64, +} + +impl PublishNamespaceRecv { + pub fn recv_ok(&mut self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + if state.ok { + return Err(ServeError::Duplicate); + } + + state.ok = true; + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/publish_namespace_received.rs b/moq-transport/src/session/publish_namespace_received.rs new file mode 100644 index 00000000..a25e6b76 --- /dev/null +++ b/moq-transport/src/session/publish_namespace_received.rs @@ -0,0 +1,116 @@ +use std::ops; + +use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::{PublishNamespaceInfo, Subscriber}; + +#[derive(Default)] +struct PublishNamespaceReceivedState {} + +/// Represents an inbound PUBLISH_NAMESPACE that was received (subscriber side). +/// When dropped, sends PUBLISH_NAMESPACE_CANCEL (if ok'd) or PUBLISH_NAMESPACE_ERROR. 
+pub struct PublishNamespaceReceived { + session: Subscriber, + state: State, + + pub info: PublishNamespaceInfo, + + ok: bool, + error: Option, +} + +impl PublishNamespaceReceived { + pub(super) fn new( + session: Subscriber, + request_id: u64, + namespace: TrackNamespace, + ) -> (PublishNamespaceReceived, PublishNamespaceReceivedRecv) { + let info = PublishNamespaceInfo { + request_id, + namespace, + }; + + let (send, recv) = State::default().split(); + let send = Self { + session, + info, + ok: false, + error: None, + state: send, + }; + let recv = PublishNamespaceReceivedRecv { _state: recv }; + + (send, recv) + } + + pub fn ok(&mut self) -> Result<(), ServeError> { + if self.ok { + return Err(ServeError::Duplicate); + } + + self.session.send_message(message::RequestOk { + id: self.info.request_id, + params: Default::default(), + }); + + self.ok = true; + + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + self.state + .lock() + .modified() + .ok_or(ServeError::Cancel)? 
+ .await; + } + } + + pub fn close(mut self, err: ServeError) -> Result<(), ServeError> { + self.error = Some(err); + Ok(()) + } +} + +impl ops::Deref for PublishNamespaceReceived { + type Target = PublishNamespaceInfo; + + fn deref(&self) -> &PublishNamespaceInfo { + &self.info + } +} + +impl Drop for PublishNamespaceReceived { + fn drop(&mut self) { + let err = self.error.clone().unwrap_or(ServeError::Done); + + if self.ok { + self.session.send_message(message::PublishNamespaceCancel { + track_namespace: self.namespace.clone(), + error_code: err.code(), + reason_phrase: ReasonPhrase(err.to_string()), + }); + } else { + self.session.send_message(message::RequestError { + id: self.info.request_id, + error_code: err.code(), + retry_interval: 0, + reason_phrase: ReasonPhrase(err.to_string()), + }); + } + } +} + +pub(super) struct PublishNamespaceReceivedRecv { + _state: State, +} + +impl PublishNamespaceReceivedRecv { + pub fn recv_done(self) -> Result<(), ServeError> { + Ok(()) + } +} diff --git a/moq-transport/src/session/publish_received.rs b/moq-transport/src/session/publish_received.rs new file mode 100644 index 00000000..0a99c0fb --- /dev/null +++ b/moq-transport/src/session/publish_received.rs @@ -0,0 +1,282 @@ +use std::ops; + +use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::serve::ServeError; +use crate::watch::State; +use crate::{data, message, serve}; + +use super::Subscriber; + +#[derive(Debug, Clone)] +pub struct PublishReceivedInfo { + pub id: u64, + pub track_namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, +} + +impl PublishReceivedInfo { + pub fn new_from_publish(msg: &message::Publish) -> Self { + Self { + id: msg.id, + track_namespace: msg.track_namespace.clone(), + track_name: msg.track_name.clone(), + track_alias: msg.track_alias, + } + } +} + +struct PublishReceivedState { + ok: bool, + closed: Result<(), ServeError>, + writer: Option, +} + +impl Default for PublishReceivedState { + fn default() -> Self 
{ + Self { + ok: false, + closed: Ok(()), + writer: None, + } + } +} + +#[must_use = "sends PUBLISH_ERROR on drop if not accepted"] +pub struct PublishReceived { + subscriber: Subscriber, + pub info: PublishReceivedInfo, + state: State, + ok: bool, +} + +impl PublishReceived { + pub(super) fn new( + subscriber: Subscriber, + msg: &message::Publish, + ) -> (Self, PublishReceivedRecv) { + let info = PublishReceivedInfo::new_from_publish(msg); + + let (send, recv) = State::default().split(); + + let send = Self { + subscriber, + info, + state: send, + ok: false, + }; + + let recv = PublishReceivedRecv { + state: recv, + writer_mode: None, + }; + + (send, recv) + } + + pub fn accept( + mut self, + track: serve::TrackWriter, + publish_msg: message::PublishOk, + ) -> Result<(), ServeError> { + let state = self.state.lock(); + if state.ok { + return Err(ServeError::Duplicate); + } + state.closed.clone()?; + + self.subscriber.send_message(publish_msg); + + if let Some(mut state) = state.into_mut() { + state.ok = true; + state.writer = Some(track); + } + + self.ok = true; + + std::mem::forget(self); + + Ok(()) + } + + pub fn reject(mut self, error_code: u64, reason: &str) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + self.subscriber.send_message(message::RequestError { + id: self.info.id, + error_code, + retry_interval: 0, + reason_phrase: ReasonPhrase(reason.to_string()), + }); + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Closed(error_code)); + } + + std::mem::forget(self); + + Ok(()) + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notify) => 
notify, + None => return Ok(()), + } + } + .await; + } + } +} + +impl ops::Deref for PublishReceived { + type Target = PublishReceivedInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for PublishReceived { + fn drop(&mut self) { + if self.ok { + return; + } + + let state = self.state.lock(); + let err = state + .closed + .as_ref() + .err() + .cloned() + .unwrap_or(ServeError::NotFound); + drop(state); + + self.subscriber.send_message(message::RequestError { + id: self.info.id, + error_code: err.code(), + retry_interval: 0, + reason_phrase: ReasonPhrase(err.to_string()), + }); + } +} + +pub(super) struct PublishReceivedRecv { + state: State, + writer_mode: Option, +} + +impl PublishReceivedRecv { + pub fn track_alias(&self) -> Option { + None + } + + pub fn recv_done(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Done); + } + + Ok(()) + } + + fn take_writer(&mut self) -> Result { + if let Some(writer) = self.writer_mode.take() { + return Ok(writer); + } + + let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; + let writer = state.writer.take().ok_or(ServeError::Done)?; + Ok(writer.into()) + } + + fn put_writer(&mut self, writer: serve::TrackWriterMode) { + self.writer_mode = Some(writer); + } + + pub fn subgroup( + &mut self, + header: data::SubgroupHeader, + ) -> Result { + let writer = self.take_writer()?; + + let mut subgroups = match writer { + serve::TrackWriterMode::Track(track) => track.subgroups()?, + serve::TrackWriterMode::Subgroups(subgroups) => subgroups, + _ => return Err(ServeError::Mode), + }; + + let result = subgroups.create(serve::Subgroup { + group_id: header.group_id, + subgroup_id: header.subgroup_id.unwrap_or(0), + priority: header.publisher_priority.unwrap_or(127), + header_type: Some(header.header_type), + }); + + // Always put writer back, even on error, to avoid losing it + 
self.put_writer(subgroups.into()); + + result + } + + pub fn datagram(&mut self, datagram: data::Datagram) -> Result<(), ServeError> { + let writer = self.take_writer()?; + + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; + + match writer { + serve::TrackWriterMode::Track(track) => { + let mut datagrams = track.datagrams()?; + datagrams.write(serve::Datagram { + group_id: datagram.group_id, + object_id: datagram.object_id.unwrap_or(0), + priority: datagram.publisher_priority.unwrap_or(127), + payload: datagram.payload.unwrap_or_default(), + extension_headers: datagram.extension_headers.unwrap_or_default(), + status, + })?; + self.put_writer(serve::TrackWriterMode::Datagrams(datagrams)); + Ok(()) + } + serve::TrackWriterMode::Datagrams(mut datagrams) => { + datagrams.write(serve::Datagram { + group_id: datagram.group_id, + object_id: datagram.object_id.unwrap_or(0), + priority: datagram.publisher_priority.unwrap_or(127), + payload: datagram.payload.unwrap_or_default(), + extension_headers: datagram.extension_headers.unwrap_or_default(), + status, + })?; + self.put_writer(serve::TrackWriterMode::Datagrams(datagrams)); + Ok(()) + } + other => { + self.put_writer(other); + Err(ServeError::Mode) + } + } + } +} diff --git a/moq-transport/src/session/published.rs b/moq-transport/src/session/published.rs new file mode 100644 index 00000000..1d1c0f50 --- /dev/null +++ b/moq-transport/src/session/published.rs @@ -0,0 +1,621 @@ +use std::ops; +use std::sync::{Arc, Mutex}; + +use futures::stream::FuturesUnordered; +use futures::StreamExt; + +use crate::coding::{Encode, Location, ReasonPhrase, TrackNamespace}; +use crate::message::ParameterType; +use crate::mlog; +use crate::serve::{ServeError, TrackReaderMode}; +use crate::watch::State; +use crate::{data, message, serve}; + +use super::{Publisher, SessionError, Writer}; + 
+#[derive(Debug, Clone)] +pub struct PublishInfo { + pub id: u64, + pub track_namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, +} + +impl PublishInfo { + pub fn new_from_publish(msg: &message::Publish) -> Self { + Self { + id: msg.id, + track_namespace: msg.track_namespace.clone(), + track_name: msg.track_name.clone(), + track_alias: msg.track_alias, + } + } +} + +#[derive(Debug)] +struct PublishedState { + ok: bool, + forward: bool, + subscriber_priority: u8, + group_order: message::GroupOrder, + largest_location: Option, + closed: Result<(), ServeError>, +} + +impl PublishedState { + fn update_largest_location(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { + let new_location = Location::new(group_id, object_id); + if let Some(current) = self.largest_location { + if new_location > current { + self.largest_location = Some(new_location); + } + } else { + self.largest_location = Some(new_location); + } + Ok(()) + } +} + +impl Default for PublishedState { + fn default() -> Self { + Self { + ok: false, + forward: true, + subscriber_priority: 128, + group_order: message::GroupOrder::Ascending, + largest_location: None, + closed: Ok(()), + } + } +} + +#[must_use = "sends PUBLISH_DONE on drop"] +pub struct Published { + publisher: Publisher, + pub info: PublishInfo, + state: State, + ok: bool, + mlog: Option>>, +} + +impl Published { + pub(super) fn new( + mut publisher: Publisher, + msg: message::Publish, + mlog: Option>>, + ) -> (Self, PublishedRecv) { + let info = PublishInfo::new_from_publish(&msg); + + publisher.send_message(msg); + + let (send, recv) = State::default().split(); + + let send = Self { + publisher, + info, + state: send, + ok: false, + mlog, + }; + + let recv = PublishedRecv { state: recv }; + + (send, recv) + } + + pub async fn ok(&mut self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + self.ok = true; + return Ok(()); + } + state.closed.clone()?; + + 
match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notify) => notify, + None => return Ok(()), + } + } + .await; + } + } + + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } + + pub async fn serve(mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let res = self.serve_inner(track).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + /// Serve using a pre-acquired TrackReaderMode. + /// Use this when you need to acquire the mode early (before network round trips) + /// to avoid missing frames in late-join scenarios. + pub async fn serve_mode(mut self, mode: TrackReaderMode) -> Result<(), SessionError> { + let res = self.serve_mode_inner(mode).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + /// Serve immediately without waiting for PUBLISH_OK. + /// Use this for relay scenarios where you want to start forwarding data right away. + /// The subscriber will receive data as soon as they're ready. + pub async fn serve_immediately(mut self, track: serve::TrackReader) -> Result<(), SessionError> { + let res = self.serve_immediately_inner(track).await; + if let Err(err) = &res { + self.close(err.clone().into())?; + } + res + } + + async fn serve_inner(&mut self, track: serve::TrackReader) -> Result<(), SessionError> { + self.ok().await?; + + let forward = { + let state = self.state.lock(); + state.forward + }; + + if !forward { + self.closed().await?; + return Ok(()); + } + + match track.mode().await? 
{ + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_mode_inner(&mut self, mode: TrackReaderMode) -> Result<(), SessionError> { + self.ok().await?; + + let forward = { + let state = self.state.lock(); + state.forward + }; + + if !forward { + self.closed().await?; + return Ok(()); + } + + match mode { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_immediately_inner(&mut self, track: serve::TrackReader) -> Result<(), SessionError> { + // Don't wait for PUBLISH_OK - start streaming immediately + // This is useful for relay scenarios where we want minimal latency + + match track.mode().await? { + TrackReaderMode::Stream(_stream) => panic!("deprecated"), + TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, + TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + } + } + + async fn serve_subgroups( + &mut self, + mut subgroups: serve::SubgroupsReader, + ) -> Result<(), SessionError> { + let mut tasks = FuturesUnordered::new(); + let mut done: Option> = None; + + loop { + tokio::select! 
{ + res = subgroups.next(), if done.is_none() => match res { + Ok(Some(subgroup)) => { + // Header type will be determined in serve_subgroup based on extension headers + let track_alias = self.info.track_alias; + let publisher = self.publisher.clone(); + let state = self.state.clone(); + let info = subgroup.info.clone(); + let mlog = self.mlog.clone(); + + tasks.push(async move { + if let Err(err) = Self::serve_subgroup(track_alias, subgroup, publisher, state, mlog).await { + log::warn!("failed to serve subgroup: {:?}, error: {}", info, err); + } + }); + }, + Ok(None) => done = Some(Ok(())), + Err(err) => done = Some(Err(err)), + }, + res = self.closed(), if done.is_none() => done = Some(res), + _ = tasks.next(), if !tasks.is_empty() => {}, + else => return Ok(done.unwrap()?), + } + } + } + + async fn serve_subgroup( + track_alias: u64, + mut subgroup_reader: serve::SubgroupReader, + mut publisher: Publisher, + state: State, + mlog: Option>>, + ) -> Result<(), SessionError> { + log::debug!( + "[PUBLISHED] serve_subgroup: starting - track_alias={}, group_id={}, subgroup_id={:?}, priority={}", + track_alias, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + subgroup_reader.priority + ); + + // Read the first object to determine if we have extension headers + let first_object = match subgroup_reader.next().await? 
{ + Some(obj) => obj, + None => { + log::debug!("[PUBLISHED] serve_subgroup: no objects in subgroup, skipping"); + return Ok(()); + } + }; + + // Use preserved header type if available, otherwise determine from extension headers + let has_extension_headers = !first_object.extension_headers.is_empty(); + let header_type = subgroup_reader.info.header_type.unwrap_or_else(|| { + // Fallback: determine header type based on extension headers + if has_extension_headers { + data::StreamHeaderType::SubgroupZeroIdExtEndOfGroup + } else { + data::StreamHeaderType::SubgroupZeroIdEndOfGroup + } + }); + + // Set subgroup_id based on header type (ZeroId variants don't include it on wire) + let subgroup_id = if header_type.has_subgroup_id() { + Some(subgroup_reader.subgroup_id) + } else { + None + }; + + let header = data::SubgroupHeader { + header_type, + track_alias, + group_id: subgroup_reader.group_id, + subgroup_id, + publisher_priority: Some(subgroup_reader.priority), + }; + + let mut send_stream = publisher.open_uni().await?; + send_stream.set_priority(subgroup_reader.priority as i32); + + let mut writer = Writer::new(send_stream); + + log::debug!( + "[PUBLISHED] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}, has_ext={}", + header.track_alias, + header.group_id, + header.subgroup_id, + header.publisher_priority, + header.header_type, + has_extension_headers + ); + + writer.encode(&header).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_header_created(time, stream_id, &header); + let _ = mlog_guard.add_event(event); + } + } + + // Helper to write an object + async fn write_object( + writer: &mut Writer, + object_reader: &mut serve::SubgroupObjectReader, + has_extension_headers: bool, + object_count: u64, + subgroup_reader: &serve::SubgroupReader, + state: &State, + mlog: &Option>>, + ) 
-> Result<(), SessionError> { + if has_extension_headers { + let subgroup_object = data::SubgroupObjectExt { + object_id_delta: 0, + extension_headers: object_reader.extension_headers.clone(), + payload_length: object_reader.size, + status: if object_reader.size == 0 { + Some(object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHED] serve_subgroup: sending object #{} (ext) - object_id={}, payload_length={}, status={:?}", + object_count + 1, + object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_object_ext_created( + time, + stream_id, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_reader.object_id, + &subgroup_object, + ); + let _ = mlog_guard.add_event(event); + } + } + } else { + let subgroup_object = data::SubgroupObject { + object_id_delta: 0, + payload_length: object_reader.size, + status: if object_reader.size == 0 { + Some(object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHED] serve_subgroup: sending object #{} - object_id={}, payload_length={}, status={:?}", + object_count + 1, + object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; + + // No mlog for non-ext objects currently + } + + state + .lock_mut() + .ok_or(ServeError::Done)? + .update_largest_location( + subgroup_reader.group_id, + object_reader.object_id, + )?; + + while let Some(chunk) = object_reader.read().await? 
{ + writer.write(&chunk).await?; + } + + Ok(()) + } + + // Write the first object that we already read + let mut object_count = 0; + let mut first_object = first_object; + write_object( + &mut writer, + &mut first_object, + has_extension_headers, + object_count, + &subgroup_reader, + &state, + &mlog, + ) + .await?; + object_count += 1; + + // Continue with remaining objects + while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? { + write_object( + &mut writer, + &mut subgroup_object_reader, + has_extension_headers, + object_count, + &subgroup_reader, + &state, + &mlog, + ) + .await?; + object_count += 1; + } + + log::info!( + "[PUBLISHED] serve_subgroup: completed subgroup (group_id={}, subgroup_id={:?}, {} objects sent, header_type={:?})", + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_count, + header_type + ); + + Ok(()) + } + + async fn serve_datagrams( + &mut self, + mut datagrams: serve::DatagramsReader, + ) -> Result<(), SessionError> { + log::debug!("[PUBLISHED] serve_datagrams: starting"); + + let mut datagram_count = 0; + while let Some(datagram) = datagrams.read().await? 
{ + let has_extension_headers = !datagram.extension_headers.is_empty(); + let datagram_type = if has_extension_headers { + data::DatagramType::ObjectIdPayloadExt + } else { + data::DatagramType::ObjectIdPayload + }; + + let encoded_datagram = data::Datagram { + datagram_type, + track_alias: self.info.track_alias, + group_id: datagram.group_id, + object_id: Some(datagram.object_id), + publisher_priority: Some(datagram.priority), + extension_headers: if has_extension_headers { + Some(datagram.extension_headers.clone()) + } else { + None + }, + status: None, + payload: Some(datagram.payload), + }; + + let payload_len = encoded_datagram + .payload + .as_ref() + .map(|p| p.len()) + .unwrap_or(0); + let mut buffer = bytes::BytesMut::with_capacity(payload_len + 100); + encoded_datagram.encode(&mut buffer)?; + + log::debug!( + "[PUBLISHED] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={:?}, payload_len={}", + datagram_count + 1, + encoded_datagram.track_alias, + encoded_datagram.group_id, + encoded_datagram.object_id.unwrap(), + encoded_datagram.publisher_priority, + payload_len + ); + + if let Some(ref mlog) = self.mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let _ = mlog_guard.add_event(mlog::object_datagram_created( + time, + stream_id, + &encoded_datagram, + )); + } + } + + self.publisher.send_datagram(buffer.into()).await?; + + self.state + .lock_mut() + .ok_or(ServeError::Done)? 
+ .update_largest_location( + encoded_datagram.group_id, + encoded_datagram.object_id.unwrap(), + )?; + + datagram_count += 1; + } + + log::info!( + "[PUBLISHED] serve_datagrams: completed ({} datagrams sent)", + datagram_count + ); + + Ok(()) + } +} + +impl ops::Deref for Published { + type Target = PublishInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for Published { + fn drop(&mut self) { + let state = self.state.lock(); + let err = state + .closed + .as_ref() + .err() + .cloned() + .unwrap_or(ServeError::Done); + drop(state); + + self.publisher.send_message(message::PublishDone { + id: self.info.id, + status_code: err.code(), + stream_count: 0, // TODO SLG + reason: ReasonPhrase(err.to_string()), + }); + } +} + +pub(super) struct PublishedRecv { + state: State, +} + +impl PublishedRecv { + pub fn recv_ok(&mut self, msg: &message::PublishOk) -> Result<(), ServeError> { + let state = self.state.lock(); + if state.ok { + return Err(ServeError::Duplicate); + } + + if let Some(mut state) = state.into_mut() { + state.ok = true; + + // Extract subscription properties from parameters (draft-16) + if let Some(v) = msg.params.get_intvalue(ParameterType::Forward.into()) { + state.forward = v == 1; + } + if let Some(v) = msg.params.get_intvalue(ParameterType::SubscriberPriority.into()) { + state.subscriber_priority = v as u8; + } + if let Some(v) = msg.params.get_intvalue(ParameterType::GroupOrder.into()) { + state.group_order = match v { + 0x0 => message::GroupOrder::Publisher, + 0x1 => message::GroupOrder::Ascending, + 0x2 => message::GroupOrder::Descending, + _ => message::GroupOrder::Ascending, + }; + } + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/publisher.rs 
b/moq-transport/src/session/publisher.rs index 1d0c45b1..6bc4e5c2 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -1,55 +1,49 @@ use std::{ - collections::{hash_map, HashMap}, + collections::{hash_map, HashMap, HashSet}, sync::{atomic, Arc, Mutex}, }; -use futures::{stream::FuturesUnordered, StreamExt}; - use crate::{ - coding::TrackNamespace, - message::{self, Message}, + coding::{KeyValuePairs, TrackNamespace}, + message::{self, GroupOrder, Message, ParameterType}, mlog, - serve::{ServeError, TracksReader}, + serve::{self, ServeError, TracksReader}, }; use crate::watch::Queue; use super::{ - Announce, AnnounceRecv, Session, SessionError, Subscribed, SubscribedRecv, TrackStatusRequested, + PublishNamespace, PublishNamespaceRecv, Published, PublishedRecv, Session, SessionError, + SubscribeNamespaceReceived, SubscribeNamespaceReceivedRecv, Subscribed, SubscribedRecv, + TrackStatusRequested, }; -// TODO remove Clone. #[derive(Clone)] pub struct Publisher { webtransport: web_transport::Session, - /// When the announce method is used, a new entry is added to this HashMap to track outbound announcement - announces: Arc>>, + publish_namespaces: Arc>>, + + filtered_namespaces: Arc>>, - /// When a Subscribe is received and we have a previous announce for the namespace, then a new entry is - /// added to this HashMap to track the inbound subscription subscribeds: Arc>>, - /// When a Subscribe is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound subscription unknown_subscribed: Queue, - /// When a TrackStatus is received and we DO NOT have a previous announce for the namespace, then a new entry is - /// added to this Queue to track the inbound track status request unknown_track_status_requested: Queue, - /// The queue we will write any outbound control messages we want to sent, the session run_send task - /// will process the queue and send 
the message on the control stream. + subscribe_namespaces_received: Arc>>, + + subscribe_namespace_received_queue: Queue, + + publisheds: Arc>>, + + next_track_alias: Arc, + outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). next_requestid: Arc, - /// Optional mlog writer for logging transport events mlog: Option>>, } @@ -62,16 +56,26 @@ impl Publisher { ) -> Self { Self { webtransport, - announces: Default::default(), + publish_namespaces: Default::default(), + filtered_namespaces: Default::default(), subscribeds: Default::default(), unknown_subscribed: Default::default(), unknown_track_status_requested: Default::default(), + subscribe_namespaces_received: Default::default(), + subscribe_namespace_received_queue: Default::default(), + publisheds: Default::default(), + next_track_alias: Arc::new(atomic::AtomicU64::new(0)), outgoing, next_requestid, mlog, } } + pub fn next_track_alias(&self) -> u64 { + self.next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed) + } + pub async fn accept( session: web_transport::Session, ) -> Result<(Session, Publisher), SessionError> { @@ -86,81 +90,37 @@ impl Publisher { Ok((session, publisher)) } - /// Announce a namespace and serve tracks using the provided [serve::TracksReader]. - /// The caller uses [serve::TracksWriter] for static tracks and [serve::TracksRequest] for dynamic tracks. 
- pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - // Check if annouce for this namespace already exists or not, and if not, then create a new Announce - let announce = match self - .announces + pub async fn publish_namespace( + &mut self, + namespace: TrackNamespace, + ) -> Result { + if self + .filtered_namespaces .lock() .unwrap() - .entry(tracks.namespace.clone()) + .contains(&namespace) + { + return Err(ServeError::Cancel.into()); + } + + let publish_ns = match self + .publish_namespaces + .lock() + .unwrap() + .entry(namespace.clone()) { - // Namespace already exists in HashMap (has already been announced) - return Duplicate error hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), - // This is a new announce, send announce message to peer. hash_map::Entry::Vacant(entry) => { - // Get the current next request id to use and increment the value for by 2 for the next request let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); - let (send, recv) = - Announce::new(self.clone(), request_id, tracks.namespace.clone()); + let (send, recv) = PublishNamespace::new(self.clone(), request_id, namespace); entry.insert(recv); send } }; - let mut subscribe_tasks = FuturesUnordered::new(); - let mut status_tasks = FuturesUnordered::new(); - let mut subscribe_done = false; - let mut status_done = false; - - // The code enters an infinite loop and waits for one of several events: - // - A new subscription arrives. - // - A new track status request arrives. - // - One of the spawned subscription-handling tasks completes. - // - One of the spawned status-handling tasks completes. - // Exit the loop when all input streams are done (None), and all tasks have completed - loop { - tokio::select! { - // Get next subscription to this announce - res = announce.subscribed(), if !subscribe_done => { - match res? 
{ - Some(subscribed) => { - let tracks = tracks.clone(); - - subscribe_tasks.push(async move { - let info = subscribed.info.clone(); - if let Err(err) = Self::serve_subscribe(subscribed, tracks).await { - log::warn!("failed serving subscribe: {:?}, error: {}", info, err) - } - }); - }, - None => subscribe_done = true, - } - - }, - res = announce.track_status_requested(), if !status_done => { - match res? { - Some(status) => { - let tracks = tracks.clone(); - - status_tasks.push(async move { - let request_msg = status.request_msg.clone(); - if let Err(err) = Self::serve_track_status(status, tracks).await { - log::warn!("failed serving track status request: {:?}, error: {}", request_msg, err) - } - }); - }, - None => status_done = true, - } - }, - Some(res) = subscribe_tasks.next() => res, - Some(res) = status_tasks.next() => res, - else => return Ok(()) - } - } + Ok(publish_ns) } pub async fn serve_subscribe( @@ -206,16 +166,87 @@ impl Publisher { Ok(()) } - // Returns subscriptions that do not map to an active announce. pub async fn subscribed(&mut self) -> Option { self.unknown_subscribed.pop().await } - // Returns track_status requests that do not map to an active announce. 
pub async fn track_status_requested(&mut self) -> Option { self.unknown_track_status_requested.pop().await } + pub async fn subscribe_namespace_received(&mut self) -> Option { + self.subscribe_namespace_received_queue.pop().await + } + + pub async fn publish(&mut self, track: serve::TrackReader) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), GroupOrder::Ascending as u64); + params.set_intvalue(ParameterType::Forward.into(), 1); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions: Default::default(), + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + + pub async fn publish_with_options( + &mut self, + track: serve::TrackReader, + group_order: GroupOrder, + forward: bool, + ) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + .fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), group_order as u64); + params.set_intvalue(ParameterType::Forward.into(), if forward { 1 } else { 0 }); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = 
message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions: Default::default(), + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { let res = match msg { message::Subscriber::Subscribe(msg) => self.recv_subscribe(msg), @@ -226,23 +257,13 @@ impl Publisher { Err(SessionError::unimplemented("FETCH_CANCEL")) } message::Subscriber::TrackStatus(msg) => self.recv_track_status(msg), - message::Subscriber::SubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE")) - } - message::Subscriber::UnsubscribeNamespace(_msg) => { - Err(SessionError::unimplemented("UNSUBSCRIBE_NAMESPACE")) - } + message::Subscriber::SubscribeNamespace(msg) => self.recv_subscribe_namespace(msg), message::Subscriber::PublishNamespaceCancel(msg) => { self.recv_publish_namespace_cancel(msg) } - message::Subscriber::PublishNamespaceOk(msg) => self.recv_publish_namespace_ok(msg), - message::Subscriber::PublishNamespaceError(msg) => { - self.recv_publish_namespace_error(msg) - } - message::Subscriber::PublishOk(_msg) => Err(SessionError::unimplemented("PUBLISH_OK")), - message::Subscriber::PublishError(_msg) => { - Err(SessionError::unimplemented("PUBLISH_ERROR")) - } + message::Subscriber::RequestOk(msg) => self.recv_request_ok(msg), + message::Subscriber::PublishOk(msg) => self.recv_publish_ok(msg), + message::Subscriber::RequestError(msg) => self.recv_request_error(msg), }; if let Err(err) = res { @@ -252,42 +273,29 @@ impl Publisher { Ok(()) } - fn recv_publish_namespace_ok( - &mut self, - msg: message::PublishNamespaceOk, - ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a 
HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. - let mut announces = self.announces.lock().unwrap(); - let announce = announces.iter_mut().find(|(_k, v)| v.request_id == msg.id); - - if let Some(announce) = announce { - announce.1.recv_ok()?; + fn recv_request_ok(&mut self, msg: message::RequestOk) -> Result<(), SessionError> { + let mut publish_namespaces = self.publish_namespaces.lock().unwrap(); + let entry = publish_namespaces + .iter_mut() + .find(|(_k, v)| v.request_id == msg.id); + + if let Some(entry) = entry { + entry.1.recv_ok()?; } Ok(()) } - fn recv_publish_namespace_error( - &mut self, - msg: message::PublishNamespaceError, - ) -> Result<(), SessionError> { - // We need to find the announce request using the request id, however the self.announces data structure - // is a HashMap indexed by Namespace (which is needed for handling PUBLISH_NAMESPACE_CANCEL). TODO - make more efficient. - // For now iterate through all self.annouces until we find the matching id. 
- let mut announces = self.announces.lock().unwrap(); + fn recv_request_error(&mut self, msg: message::RequestError) -> Result<(), SessionError> { + let mut publish_namespaces = self.publish_namespaces.lock().unwrap(); - // Find the key first (immutable borrow only) - let key_opt = announces + let key_opt = publish_namespaces .iter() .find(|(_k, v)| v.request_id == msg.id) .map(|(k, _)| k.clone()); - // Remove from HashMap and take ownership if let Some(key) = key_opt { - if let Some((_ns, v)) = announces.remove_entry(&key) { - // Step 3: call recv_error, consuming v + if let Some((_ns, v)) = publish_namespaces.remove_entry(&key) { v.recv_error(ServeError::Closed(msg.error_code))?; } } @@ -299,10 +307,21 @@ impl Publisher { &mut self, msg: message::PublishNamespaceCancel, ) -> Result<(), SessionError> { - // TODO: If a publisher receives new subscriptions for that namespace after receiving an ANNOUNCE_CANCEL, - // it SHOULD close the session as a 'Protocol Violation'. - if let Some(announce) = self.announces.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_error(ServeError::Cancel)?; + if let Some(entry) = self + .publish_namespaces + .lock() + .unwrap() + .remove(&msg.track_namespace) + { + entry.recv_error(ServeError::Cancel)?; + } + + Ok(()) + } + + fn recv_publish_ok(&mut self, msg: message::PublishOk) -> Result<(), SessionError> { + if let Some(published) = self.publisheds.lock().unwrap().get_mut(&msg.id) { + published.recv_ok(&msg)?; } Ok(()) @@ -314,29 +333,18 @@ impl Publisher { let subscribed = { let mut subscribeds = self.subscribeds.lock().unwrap(); - // See if entry exists for this request id already, if so error out let entry = match subscribeds.entry(msg.id) { hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), hash_map::Entry::Vacant(entry) => entry, }; - // Create new Subscribed entry and add to HashMap let (send, recv) = Subscribed::new(self.clone(), msg, self.mlog.clone()); entry.insert(recv); send }; - // If we have 
an announce, route the subscribe to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce.recv_subscribe(subscribed).map_err(Into::into); - } - - // Otherwise, put it in the unknown queue. - // TODO Have some way to detect if the application is not reading from the unknown queue, - // then send SubscribeError. if let Err(err) = self.unknown_subscribed.push(subscribed) { - // Default to closing with a not found error I guess. err.close(ServeError::not_found_ctx(format!( "unknown_subscribed queue full for namespace {:?}", namespace @@ -355,26 +363,12 @@ impl Publisher { } fn recv_track_status(&mut self, msg: message::TrackStatus) -> Result<(), SessionError> { - let namespace = msg.track_namespace.clone(); - - // Create TrackStatusRequested to track this request let track_status_requested = TrackStatusRequested::new(self.clone(), msg); - // If we have an announce, route the track_status to it. - if let Some(announce) = self.announces.lock().unwrap().get_mut(&namespace) { - return announce - .recv_track_status_requested(track_status_requested) - .map_err(Into::into); - } - - // Otherwise, put it in the unknown_track_status queue. - // TODO Have some way to detect if the application is not reading from the unknown_track_status queue, - // then send TrackStatusError. if let Err(mut err) = self .unknown_track_status_requested .push(track_status_requested) { - // push only fails if the queue is dropped, send TrackStatusError, Internal error err.respond_error(0, "Internal error")?; } @@ -389,15 +383,48 @@ impl Publisher { Ok(()) } - /// Process a message before sending it, performing any necessary internal actions. 
+ fn recv_subscribe_namespace( + &mut self, + msg: message::SubscribeNamespace, + ) -> Result<(), SessionError> { + let namespace_prefix = msg.track_namespace_prefix.clone(); + + self.filtered_namespaces + .lock() + .unwrap() + .remove(&namespace_prefix); + + let mut entries = self.subscribe_namespaces_received.lock().unwrap(); + + let entry = match entries.entry(msg.id) { + hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (send, recv) = + SubscribeNamespaceReceived::new(self.clone(), msg.id, namespace_prefix); + + if let Err(send) = self.subscribe_namespace_received_queue.push(send) { + send.reject(0x0, "Internal error")?; + return Ok(()); + } + + entry.insert(recv); + + Ok(()) + } + fn act_on_message_to_send>( &mut self, msg: T, ) -> message::Publisher { let msg = msg.into(); match &msg { - message::Publisher::PublishDone(m) => self.drop_subscribe(m.id), - message::Publisher::SubscribeError(m) => self.drop_subscribe(m.id), + message::Publisher::PublishDone(m) => { + self.drop_subscribe(m.id); + self.drop_published(m.id); + } + message::Publisher::RequestError(m) => self.drop_subscribe(m.id), message::Publisher::PublishNamespaceDone(m) => { self.drop_publish_namespace(&m.track_namespace); } @@ -429,7 +456,11 @@ impl Publisher { } fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announces.lock().unwrap().remove(namespace); + self.publish_namespaces.lock().unwrap().remove(namespace); + } + + fn drop_published(&mut self, id: u64) { + self.publisheds.lock().unwrap().remove(&id); } pub(super) async fn open_uni(&mut self) -> Result { @@ -439,4 +470,16 @@ impl Publisher { pub(super) async fn send_datagram(&mut self, data: bytes::Bytes) -> Result<(), SessionError> { Ok(self.webtransport.send_datagram(data).await?) } + + /// Forward a PUBLISH message to the subscriber (used by relay for SUBSCRIBE_NAMESPACE flow). 
+ /// This sends the message without tracking it for PUBLISH_OK response handling. + pub fn forward_publish(&mut self, msg: message::Publish) { + self.outgoing.push(msg.into()).ok(); + } + + /// Forward a NAMESPACE message to the subscriber (used by relay for SUBSCRIBE_NAMESPACE flow). + /// This announces a namespace that matches the subscriber's SUBSCRIBE_NAMESPACE prefix. + pub fn forward_namespace(&mut self, msg: message::Namespace) { + self.outgoing.push(msg.into()).ok(); + } } diff --git a/moq-transport/src/session/reader.rs b/moq-transport/src/session/reader.rs index 18a6ae16..1dd05530 100644 --- a/moq-transport/src/session/reader.rs +++ b/moq-transport/src/session/reader.rs @@ -68,7 +68,7 @@ impl Reader { // We always read at least once to avoid an infinite loop if some dingus puts remain=0 loop { let before_read = self.buffer.len(); - if !self.stream.read_buf(&mut self.buffer).await? { + if self.stream.read_buf(&mut self.buffer).await?.is_none() { log::warn!( "[READER] decode: stream ended while waiting for data (have={} bytes, need={})", self.buffer.len(), @@ -113,7 +113,7 @@ impl Reader { return Ok(Some(data)); } - let chunk = self.stream.read_chunk(max).await?; + let chunk = self.stream.read(max).await?; if let Some(ref data) = chunk { log::trace!("[READER] read_chunk: read {} bytes from stream", data.len()); } else { @@ -127,6 +127,6 @@ impl Reader { return Ok(false); } - Ok(!self.stream.read_buf(&mut self.buffer).await?) 
+ Ok(self.stream.read_buf(&mut self.buffer).await?.is_none()) } } diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index b536641b..f9bc1f47 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -1,9 +1,8 @@ use std::ops; use crate::{ - coding::{KeyValuePairs, Location, TrackNamespace}, - data, - message::{self, FilterType, GroupOrder}, + coding::{KeyValuePairs, TrackNamespace}, + data, message, serve::{self, ServeError, TrackWriter, TrackWriterMode}, }; @@ -17,22 +16,6 @@ pub struct SubscribeInfo { pub id: u64, pub track_namespace: TrackNamespace, pub track_name: String, - - /// Subscriber Priority - pub subscriber_priority: u8, - pub group_order: GroupOrder, - - /// Forward Flag - pub forward: bool, - - /// Filter type - pub filter_type: FilterType, - - /// The starting location for this subscription. Only present for "AbsoluteStart" and "AbsoluteRange" filter types. - pub start_location: Option, - /// End group id, inclusive, for the subscription, if applicable. Only present for "AbsoluteRange" filter type. 
- pub end_group_id: Option, - /// Optional parameters pub params: KeyValuePairs, @@ -46,12 +29,6 @@ impl SubscribeInfo { id: msg.id, track_namespace: msg.track_namespace.clone(), track_name: msg.track_name.clone(), - subscriber_priority: msg.subscriber_priority, - group_order: msg.group_order, - forward: msg.forward, - filter_type: msg.filter_type, - start_location: msg.start_location, - end_group_id: msg.end_group_id, params: msg.params.clone(), track_status: false, } @@ -93,13 +70,6 @@ impl Subscribe { id: request_id, track_namespace: track.namespace.clone(), track_name: track.name.clone(), - // TODO add prioritization logic on the publisher side - subscriber_priority: 127, // default to mid value, see: https://github.com/moq-wg/moq-transport/issues/504 - group_order: GroupOrder::Publisher, // defer to publisher send order - forward: true, // default to forwarding objects - filter_type: FilterType::LargestObject, - start_location: None, - end_group_id: None, params: Default::default(), }; let info = SubscribeInfo::new_from_subscribe(&subscribe_message); @@ -205,16 +175,20 @@ impl SubscribeRecv { _ => return Err(ServeError::Mode), }; - let writer = subgroups.create(serve::Subgroup { + let result = subgroups.create(serve::Subgroup { group_id: header.group_id, // When subgroup_id is not present in the header type, it implicitly means subgroup 0 subgroup_id: header.subgroup_id.unwrap_or(0), - priority: header.publisher_priority, - })?; + // When priority is not present (NoPriority header types), default to 0 + priority: header.publisher_priority.unwrap_or(0), + // Preserve the incoming header type for forwarding + header_type: Some(header.header_type), + }); + // Always put writer back, even on error, to avoid losing it self.writer = Some(subgroups.into()); - Ok(writer) + result } pub fn datagram(&mut self, datagram: data::Datagram) -> Result<(), ServeError> { @@ -224,23 +198,39 @@ impl SubscribeRecv { TrackWriterMode::Track(track) => { // convert Track -> Datagrams 
writer, write, then put Datagrams back let mut datagrams = track.datagrams()?; + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; datagrams.write(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id.unwrap_or(0), - priority: datagram.publisher_priority, + // When priority is not present (NoPriority datagram types), default to 0 + priority: datagram.publisher_priority.unwrap_or(0), payload: datagram.payload.unwrap_or_default(), extension_headers: datagram.extension_headers.unwrap_or_default(), + status, })?; self.writer = Some(TrackWriterMode::Datagrams(datagrams)); Ok(()) } TrackWriterMode::Datagrams(mut datagrams) => { + // Determine status from datagram type or explicit status field + let status = if datagram.datagram_type.is_end_of_group() { + Some(crate::data::ObjectStatus::EndOfGroup) + } else { + datagram.status + }; datagrams.write(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id.unwrap_or(0), - priority: datagram.publisher_priority, + // When priority is not present (NoPriority datagram types), default to 0 + priority: datagram.publisher_priority.unwrap_or(0), payload: datagram.payload.unwrap_or_default(), extension_headers: datagram.extension_headers.unwrap_or_default(), + status, })?; self.writer = Some(TrackWriterMode::Datagrams(datagrams)); Ok(()) diff --git a/moq-transport/src/session/subscribe_namespace.rs b/moq-transport/src/session/subscribe_namespace.rs new file mode 100644 index 00000000..a89d7ecc --- /dev/null +++ b/moq-transport/src/session/subscribe_namespace.rs @@ -0,0 +1,143 @@ +use std::ops; + +use crate::coding::TrackNamespace; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Subscriber; + +#[derive(Debug, Clone)] +pub struct SubscribeNsInfo { + pub request_id: u64, + pub namespace_prefix: 
TrackNamespace, +} + +struct SubscribeNsState { + ok: bool, + closed: Result<(), ServeError>, +} + +impl Default for SubscribeNsState { + fn default() -> Self { + Self { + ok: false, + closed: Ok(()), + } + } +} + +/// Represents an outbound SUBSCRIBE_NAMESPACE request (subscriber side). +/// When dropped, sends UNSUBSCRIBE_NAMESPACE to the peer. +#[must_use = "sends UNSUBSCRIBE_NAMESPACE on drop"] +pub struct SubscribeNs { + subscriber: Subscriber, + state: State, + + pub info: SubscribeNsInfo, +} + +impl SubscribeNs { + pub(super) fn new( + mut subscriber: Subscriber, + request_id: u64, + namespace_prefix: TrackNamespace, + ) -> (SubscribeNs, SubscribeNsRecv) { + let info = SubscribeNsInfo { + request_id, + namespace_prefix: namespace_prefix.clone(), + }; + + subscriber.send_message(message::SubscribeNamespace::new( + request_id, + namespace_prefix, + 1, + )); + + let (send, recv) = State::default().split(); + + let send = Self { + subscriber, + info, + state: send, + }; + let recv = SubscribeNsRecv { state: recv }; + + (send, recv) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } + + pub async fn ok(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + if state.ok { + return Ok(()); + } + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl Drop for SubscribeNs { + fn drop(&mut self) { + // In draft-16, SUBSCRIBE_NAMESPACE uses its own bidirectional stream. + // Closing the stream implicitly unsubscribes. 
+ } +} + +impl ops::Deref for SubscribeNs { + type Target = SubscribeNsInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +pub(super) struct SubscribeNsRecv { + state: State, +} + +impl SubscribeNsRecv { + pub fn recv_ok(&mut self) -> Result<(), ServeError> { + if let Some(mut state) = self.state.lock_mut() { + if state.ok { + return Err(ServeError::Duplicate); + } + + state.ok = true; + } + + Ok(()) + } + + pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + let mut state = state.into_mut().ok_or(ServeError::Done)?; + state.closed = Err(err); + + Ok(()) + } +} diff --git a/moq-transport/src/session/subscribe_namespace_received.rs b/moq-transport/src/session/subscribe_namespace_received.rs new file mode 100644 index 00000000..60b3a333 --- /dev/null +++ b/moq-transport/src/session/subscribe_namespace_received.rs @@ -0,0 +1,148 @@ +use std::ops; + +use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::watch::State; +use crate::{message, serve::ServeError}; + +use super::Publisher; + +#[derive(Debug, Clone)] +pub struct SubscribeNamespaceReceivedInfo { + pub request_id: u64, + pub namespace_prefix: TrackNamespace, +} + +struct SubscribeNamespaceReceivedState { + closed: Result<(), ServeError>, +} + +impl Default for SubscribeNamespaceReceivedState { + fn default() -> Self { + Self { closed: Ok(()) } + } +} + +#[must_use = "sends SUBSCRIBE_NAMESPACE_ERROR on drop if not accepted"] +pub struct SubscribeNamespaceReceived { + publisher: Publisher, + state: State, + pub info: SubscribeNamespaceReceivedInfo, + ok: bool, +} + +impl SubscribeNamespaceReceived { + pub(super) fn new( + publisher: Publisher, + request_id: u64, + namespace_prefix: TrackNamespace, + ) -> (Self, SubscribeNamespaceReceivedRecv) { + let info = SubscribeNamespaceReceivedInfo { + request_id, + namespace_prefix: namespace_prefix.clone(), + }; + + let (send, recv) = State::default().split(); + + let 
send = Self { + publisher, + info, + state: send, + ok: false, + }; + + let recv = SubscribeNamespaceReceivedRecv { + state: recv, + namespace_prefix, + }; + + (send, recv) + } + + pub fn ok(&mut self) -> Result<(), ServeError> { + if self.ok { + return Err(ServeError::Duplicate); + } + + self.publisher.send_message(message::RequestOk { + id: self.info.request_id, + params: Default::default(), + }); + + self.ok = true; + + Ok(()) + } + + pub fn reject(mut self, error_code: u64, reason: &str) -> Result<(), ServeError> { + self.publisher.send_message(message::RequestError { + id: self.info.request_id, + error_code, + retry_interval: 0, + reason_phrase: ReasonPhrase(reason.to_string()), + }); + + self.ok = true; + + Ok(()) + } + + pub async fn closed(&self) -> Result<(), ServeError> { + loop { + { + let state = self.state.lock(); + state.closed.clone()?; + + match state.modified() { + Some(notified) => notified, + None => return Ok(()), + } + } + .await; + } + } +} + +impl ops::Deref for SubscribeNamespaceReceived { + type Target = SubscribeNamespaceReceivedInfo; + + fn deref(&self) -> &Self::Target { + &self.info + } +} + +impl Drop for SubscribeNamespaceReceived { + fn drop(&mut self) { + if self.ok { + return; + } + + self.publisher.send_message(message::RequestError { + id: self.info.request_id, + error_code: ServeError::NotFound.code(), + retry_interval: 0, + reason_phrase: ReasonPhrase("SUBSCRIBE_NAMESPACE not handled".to_string()), + }); + } +} + +pub(super) struct SubscribeNamespaceReceivedRecv { + state: State, + namespace_prefix: TrackNamespace, +} + +impl SubscribeNamespaceReceivedRecv { + pub fn namespace_prefix(&self) -> &TrackNamespace { + &self.namespace_prefix + } + + pub fn recv_unsubscribe(&mut self) -> Result<(), ServeError> { + let state = self.state.lock(); + state.closed.clone()?; + + if let Some(mut state) = state.into_mut() { + state.closed = Err(ServeError::Cancel); + } + + Ok(()) + } +} diff --git a/moq-transport/src/session/subscribed.rs 
b/moq-transport/src/session/subscribed.rs index e87961fc..4f166d02 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -20,7 +20,20 @@ struct SubscribedState { closed: Result<(), ServeError>, } +impl Default for SubscribedState { + fn default() -> Self { + Self { + largest_location: None, + closed: Ok(()), + } + } +} + impl SubscribedState { + fn is_closed(&self) -> bool { + self.closed.is_err() + } + fn update_largest_location(&mut self, group_id: u64, object_id: u64) -> Result<(), ServeError> { if let Some(current_largest_location) = self.largest_location { let update_largest_location = Location::new(group_id, object_id); @@ -33,15 +46,6 @@ impl SubscribedState { } } -impl Default for SubscribedState { - fn default() -> Self { - Self { - largest_location: None, - closed: Ok(()), - } - } -} - pub struct Subscribed { /// The sessions Publisher manager, used to send control messages, /// create new QUIC streams, and send datagrams @@ -66,7 +70,7 @@ impl Subscribed { msg: message::Subscribe, mlog: Option>>, ) -> (Self, SubscribedRecv) { - let (send, recv) = State::default().split(); + let (send, recv) = State::new(SubscribedState::default()).split(); let info = SubscribeInfo::new_from_subscribe(&msg); let send = Self { publisher, @@ -102,14 +106,12 @@ impl Subscribed { // Send SubscribeOk using send_message_and_wait to ensure it is sent at least to the QUIC stack before // we start serving the track. If a subscriber gets the stream before SubscribeOk // then they won't recognize the track_alias in the stream header. 
+ let track_alias = self.publisher.next_track_alias(); self.publisher .send_message_and_wait(message::SubscribeOk { id: self.info.id, - track_alias: self.info.id, // use subscription id as track alias - expires: 0, // TODO SLG - group_order: message::GroupOrder::Descending, // TODO: resolve correct value from publisher / subscriber prefs - content_exists: largest_location.is_some(), - largest_location, + track_alias, + track_extensions: Default::default(), params: Default::default(), }) .await; @@ -120,8 +122,12 @@ impl Subscribed { match track.mode().await? { // TODO cancel track/datagrams on closed TrackReaderMode::Stream(_stream) => panic!("deprecated"), - TrackReaderMode::Subgroups(subgroups) => self.serve_subgroups(subgroups).await, - TrackReaderMode::Datagrams(datagrams) => self.serve_datagrams(datagrams).await, + TrackReaderMode::Subgroups(subgroups) => { + self.serve_subgroups(subgroups, track_alias).await + } + TrackReaderMode::Datagrams(datagrams) => { + self.serve_datagrams(datagrams, track_alias).await + } } } @@ -178,9 +184,10 @@ impl Drop for Subscribed { reason: ReasonPhrase(err.to_string()), }); } else { - self.publisher.send_message(message::SubscribeError { + self.publisher.send_message(message::RequestError { id: self.info.id, error_code: err.code(), + retry_interval: 0, reason_phrase: ReasonPhrase(err.to_string()), }); }; @@ -191,6 +198,7 @@ impl Subscribed { async fn serve_subgroups( &mut self, mut subgroups: serve::SubgroupsReader, + track_alias: u64, ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); let mut done: Option> = None; @@ -199,12 +207,19 @@ impl Subscribed { tokio::select! 
{ res = subgroups.next(), if done.is_none() => match res { Ok(Some(subgroup)) => { + // Use preserved header type if available, otherwise default to SubgroupIdExt + let header_type = subgroup.info.header_type.unwrap_or(data::StreamHeaderType::SubgroupIdExt); + let subgroup_id = if header_type.has_subgroup_id() { + Some(subgroup.subgroup_id) + } else { + None + }; let header = data::SubgroupHeader { - header_type: data::StreamHeaderType::SubgroupIdExt, // SubGroupId = Yes, Extensions = Yes, ContainsEndOfGroup = No - track_alias: self.info.id, // use subscription id as track_alias + header_type, + track_alias, group_id: subgroup.group_id, - subgroup_id: Some(subgroup.subgroup_id), - publisher_priority: subgroup.priority, + subgroup_id, + publisher_priority: Some(subgroup.priority), }; let publisher = self.publisher.clone(); @@ -251,7 +266,7 @@ impl Subscribed { let mut writer = Writer::new(send_stream); log::debug!( - "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={}, header_type={:?}", + "[PUBLISHER] serve_subgroup: sending header - track_alias={}, group_id={}, subgroup_id={:?}, priority={:?}, header_type={:?}", header.track_alias, header.group_id, header.subgroup_id, @@ -273,6 +288,16 @@ impl Subscribed { let mut object_count = 0; while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? 
{ + if state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_subgroup: subscription cancelled, stopping (group_id={}, subgroup_id={:?}, {} objects sent)", + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + object_count + ); + return Ok(()); + } + let subgroup_object = data::SubgroupObjectExt { object_id_delta: 0, // before delta logic, used to be subgroup_object_reader.object_id, extension_headers: subgroup_object_reader.extension_headers.clone(), // Pass through extension headers @@ -325,6 +350,13 @@ impl Subscribed { let mut chunks_sent = 0; let mut bytes_sent = 0; while let Some(chunk) = subgroup_object_reader.read().await? { + if state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_subgroup: subscription cancelled during payload transfer" + ); + return Ok(()); + } + log::trace!( "[PUBLISHER] serve_subgroup: sending payload chunk #{} for object #{} ({} bytes)", chunks_sent + 1, @@ -358,12 +390,20 @@ impl Subscribed { async fn serve_datagrams( &mut self, mut datagrams: serve::DatagramsReader, + track_alias: u64, ) -> Result<(), SessionError> { log::debug!("[PUBLISHER] serve_datagrams: starting"); let mut datagram_count = 0; while let Some(datagram) = datagrams.read().await? 
{ - // Determine datagram type based on extension headers presence + if self.state.lock().is_closed() { + log::debug!( + "[PUBLISHER] serve_datagrams: subscription cancelled, stopping ({} datagrams sent)", + datagram_count + ); + return Ok(()); + } + let has_extension_headers = !datagram.extension_headers.is_empty(); let datagram_type = if has_extension_headers { data::DatagramType::ObjectIdPayloadExt @@ -373,10 +413,10 @@ impl Subscribed { let encoded_datagram = data::Datagram { datagram_type, - track_alias: self.info.id, // use subscription id as track_alias + track_alias, group_id: datagram.group_id, object_id: Some(datagram.object_id), - publisher_priority: datagram.priority, + publisher_priority: Some(datagram.priority), extension_headers: if has_extension_headers { Some(datagram.extension_headers.clone()) } else { @@ -395,7 +435,7 @@ impl Subscribed { encoded_datagram.encode(&mut buffer)?; log::debug!( - "[PUBLISHER] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={}, payload_len={}, extension_headers={:?}, total_encoded_len={}", + "[PUBLISHER] serve_datagrams: sending datagram #{} - track_alias={}, group_id={}, object_id={}, priority={:?}, payload_len={}, extension_headers={:?}, total_encoded_len={}", datagram_count + 1, encoded_datagram.track_alias, encoded_datagram.group_id, diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 49c59e8d..3ec9db57 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -17,41 +17,40 @@ use crate::{ use crate::watch::Queue; -use super::{Announced, AnnouncedRecv, Reader, Session, SessionError, Subscribe, SubscribeRecv}; +use super::{ + PublishNamespaceReceived, PublishNamespaceReceivedRecv, PublishReceived, PublishReceivedRecv, + Reader, Session, SessionError, Subscribe, SubscribeNs, SubscribeNsRecv, SubscribeRecv, +}; // Default timeout for waiting for subscribe aliases to become available 
via SUBSCRIBE_OK (1 second) const DEFAULT_ALIAS_WAIT_TIME_MS: u64 = 1000; -// TODO remove Clone. #[derive(Clone)] pub struct Subscriber { - /// The currently active inbound announces, keyed by namespace. - announced: Arc>>, + publish_namespaces_received: Arc>>, + + publish_namespace_received_queue: Queue, - /// Queue of announced namespaces we have received from the Publisher, waiting to be processed. - announced_queue: Queue, + subscribe_namespaces: Arc>>, - /// The currently active outbound subscribes, keyed by request id. subscribes: Arc>>, - /// Map of track alias to subscription id for quick lookup when receiving streams/datagrams. subscribe_alias_map: Arc>>, - /// Notify when subscribe alias map is updated subscribe_alias_notify: Arc, - /// The queue we will write any outbound control messages we want to send, the session run_send task - /// will process the queue and send the message on the control stream. + publishes_received: Arc>>, + + publish_received_queue: Queue, + + publish_alias_map: Arc>>, + + publish_alias_notify: Arc, + outgoing: Queue, - /// When we need a new Request Id for sending a request, we can get it from here. Note: The instance - /// of AtomicU64 is shared with the Subscriber, so the session uses unique request ids for all requests - /// generated. Note: If we initiated the QUIC connection then request id's start at 0 and increment by 2 - /// for each request (even numbers). If we accepted an inbound QUIC connection then request id's start at 1 and - /// increment by 2 for each request (odd numbers). 
next_requestid: Arc, - /// Optional mlog writer for logging transport events mlog: Option>>, } @@ -62,14 +61,19 @@ impl Subscriber { mlog: Option>>, ) -> Self { Self { - announced: Default::default(), - announced_queue: Default::default(), + publish_namespaces_received: Default::default(), + publish_namespace_received_queue: Default::default(), + subscribe_namespaces: Default::default(), subscribes: Default::default(), subscribe_alias_map: Default::default(), + subscribe_alias_notify: Arc::new(Notify::new()), + publishes_received: Default::default(), + publish_received_queue: Default::default(), + publish_alias_map: Default::default(), + publish_alias_notify: Arc::new(Notify::new()), outgoing, next_requestid, mlog, - subscribe_alias_notify: Arc::new(Notify::new()), } } @@ -85,9 +89,12 @@ impl Subscriber { Ok((session, subscriber)) } - /// Wait for the next announced namespace from the publisher, if any. - pub async fn announced(&mut self) -> Option { - self.announced_queue.pop().await + pub async fn publish_ns_recvd(&mut self) -> Option { + self.publish_namespace_received_queue.pop().await + } + + pub async fn publish_received(&mut self) -> Option { + self.publish_received_queue.pop().await } /// Get the current next request id to use and increment the value for by 2 for the next request @@ -120,45 +127,49 @@ impl Subscriber { send.closed().await } + pub fn subscribe_ns( + &mut self, + namespace_prefix: TrackNamespace, + ) -> Result { + let request_id = self.get_next_request_id(); + + let mut subscribe_namespaces = self.subscribe_namespaces.lock().unwrap(); + let entry = match subscribe_namespaces.entry(request_id) { + hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (send, recv) = SubscribeNs::new(self.clone(), request_id, namespace_prefix); + entry.insert(recv); + + Ok(send) + } + /// Send a message to the publisher via the control stream. 
pub(super) fn send_message>(&mut self, msg: M) { let msg = msg.into(); // Remove our entry on terminal state. - match &msg { - message::Subscriber::PublishNamespaceCancel(msg) => { - self.drop_publish_namespace(&msg.track_namespace) - } - // TODO SLG - there is no longer a namespace in the error, need to map via request id - message::Subscriber::PublishNamespaceError(_msg) => {} // Not implemented yet - need request id mapping - _ => {} + if let message::Subscriber::PublishNamespaceCancel(msg) = &msg { + self.drop_publish_namespace(&msg.track_namespace) } // TODO report dropped messages? let _ = self.outgoing.push(msg.into()); } - /// Receive a message from the publisher via the control stream. pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { let res = match &msg { message::Publisher::PublishNamespace(msg) => self.recv_publish_namespace(msg), - message::Publisher::PublishNamespaceDone(msg) => self.recv_publish_namespace_done(msg), - message::Publisher::Publish(_msg) => Err(SessionError::unimplemented("PUBLISH")), + message::Publisher::PublishNamespaceDone(msg) => self.recv_publish_ns_done(msg), + message::Publisher::Namespace(msg) => self.recv_namespace(msg), + message::Publisher::Publish(msg) => self.recv_publish(msg), message::Publisher::PublishDone(msg) => self.recv_publish_done(msg), message::Publisher::SubscribeOk(msg) => self.recv_subscribe_ok(msg), - message::Publisher::SubscribeError(msg) => self.recv_subscribe_error(msg), + message::Publisher::RequestError(msg) => self.recv_request_error(msg), message::Publisher::TrackStatusOk(msg) => self.recv_track_status_ok(msg), - message::Publisher::TrackStatusError(_msg) => { - Err(SessionError::unimplemented("TRACK_STATUS_ERROR")) - } message::Publisher::FetchOk(_msg) => Err(SessionError::unimplemented("FETCH_OK")), - message::Publisher::FetchError(_msg) => Err(SessionError::unimplemented("FETCH_ERROR")), - message::Publisher::SubscribeNamespaceOk(_msg) => { - 
Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_OK")) - } - message::Publisher::SubscribeNamespaceError(_msg) => { - Err(SessionError::unimplemented("SUBSCRIBE_NAMESPACE_ERROR")) - } + message::Publisher::RequestOk(msg) => self.recv_request_ok(msg), }; if let Err(SessionError::Serve(err)) = res { @@ -169,23 +180,24 @@ impl Subscriber { res } - /// Handle the reception of a PublishNamespace message from the publisher. fn recv_publish_namespace( &mut self, msg: &message::PublishNamespace, ) -> Result<(), SessionError> { - let mut announces = self.announced.lock().unwrap(); + let mut entries = self.publish_namespaces_received.lock().unwrap(); - // Check for duplicate namespace announcement - let entry = match announces.entry(msg.track_namespace.clone()) { + let entry = match entries.entry(msg.track_namespace.clone()) { hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), hash_map::Entry::Vacant(entry) => entry, }; - // Create the announced namespace and insert it into our map of active announces, and the announced queue. - let (announced, recv) = Announced::new(self.clone(), msg.id, msg.track_namespace.clone()); - if let Err(announced) = self.announced_queue.push(announced) { - announced.close(ServeError::Cancel)?; + let (publish_ns_received, recv) = + PublishNamespaceReceived::new(self.clone(), msg.id, msg.track_namespace.clone()); + if let Err(publish_ns_received) = self + .publish_namespace_received_queue + .push(publish_ns_received) + { + publish_ns_received.close(ServeError::Cancel)?; return Ok(()); } entry.insert(recv); @@ -193,14 +205,55 @@ impl Subscriber { Ok(()) } - /// Handle the reception of a PublishNamespaceDone message from the publisher. 
- fn recv_publish_namespace_done( + fn recv_publish_ns_done( &mut self, msg: &message::PublishNamespaceDone, ) -> Result<(), SessionError> { - if let Some(announce) = self.announced.lock().unwrap().remove(&msg.track_namespace) { - announce.recv_unannounce()?; + if let Some(entry) = self + .publish_namespaces_received + .lock() + .unwrap() + .remove(&msg.track_namespace) + { + entry.recv_done()?; + } + + Ok(()) + } + + /// Handle NAMESPACE message (draft-16) - relay forwards this in response to SUBSCRIBE_NAMESPACE + fn recv_namespace(&mut self, msg: &message::Namespace) -> Result<(), SessionError> { + log::info!( + "received NAMESPACE for {:?} (request_id={})", + msg.track_namespace, + msg.id + ); + // TODO: Implement proper handling - notify the SUBSCRIBE_NAMESPACE handler + // For now, just log and accept + Ok(()) + } + + fn recv_publish(&mut self, msg: &message::Publish) -> Result<(), SessionError> { + let mut entries = self.publishes_received.lock().unwrap(); + + let entry = match entries.entry(msg.id) { + hash_map::Entry::Occupied(_) => return Err(SessionError::Duplicate), + hash_map::Entry::Vacant(entry) => entry, + }; + + let (publish_received, recv) = PublishReceived::new(self.clone(), msg); + + self.publish_alias_map + .lock() + .unwrap() + .insert(msg.track_alias, msg.id); + self.publish_alias_notify.notify_waiters(); + + if let Err(publish_received) = self.publish_received_queue.push(publish_received) { + publish_received.close(ServeError::Cancel)?; + return Ok(()); } + entry.insert(recv); Ok(()) } @@ -240,19 +293,45 @@ impl Subscriber { } } - /// Handle the reception of a SubscribeError message from the publisher. 
- fn recv_subscribe_error(&mut self, msg: &message::SubscribeError) -> Result<(), SessionError> { + fn recv_request_ok(&mut self, msg: &message::RequestOk) -> Result<(), SessionError> { + if let Some(subscribe_ns) = self.subscribe_namespaces.lock().unwrap().get_mut(&msg.id) { + subscribe_ns.recv_ok()?; + return Ok(()); + } + + log::warn!( + "[SUBSCRIBER] recv_request_ok: request id {} not found", + msg.id + ); + Ok(()) + } + + fn recv_request_error(&mut self, msg: &message::RequestError) -> Result<(), SessionError> { if let Some(subscribe) = self.remove_subscribe(msg.id) { subscribe.error(ServeError::Closed(msg.error_code))?; + return Ok(()); } + if let Some(subscribe_ns) = self.subscribe_namespaces.lock().unwrap().remove(&msg.id) { + subscribe_ns.recv_error(ServeError::Closed(msg.error_code))?; + return Ok(()); + } + + log::warn!( + "[SUBSCRIBER] recv_request_error: request id {} not found", + msg.id + ); Ok(()) } - /// Handle the reception of a PublishDone message from the publisher. fn recv_publish_done(&mut self, msg: &message::PublishDone) -> Result<(), SessionError> { if let Some(subscribe) = self.remove_subscribe(msg.id) { subscribe.error(ServeError::Closed(msg.status_code))?; + return Ok(()); + } + + if let Some(mut publish_recv) = self.remove_publish_received(msg.id) { + publish_recv.recv_done()?; } Ok(()) @@ -266,23 +345,32 @@ impl Subscriber { Ok(()) } - /// Remove an announced namespace from our map of active announces. 
fn drop_publish_namespace(&mut self, namespace: &TrackNamespace) { - self.announced.lock().unwrap().remove(namespace); + self.publish_namespaces_received + .lock() + .unwrap() + .remove(namespace); + } + + fn remove_publish_received(&mut self, id: u64) -> Option { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().remove(&id) { + if let Some(track_alias) = publish_recv.track_alias() { + self.publish_alias_map.lock().unwrap().remove(&track_alias); + } + Some(publish_recv) + } else { + None + } } - /// Get a subscribe id by track alias, waiting up to the specified timeout if not present. - /// If timeout_ms is None, only check if already present and return None if not. async fn get_subscribe_id_by_alias( &self, track_alias: u64, timeout_ms: Option, ) -> Option { - // If no timeout specified, don't wait let timeout_ms = match timeout_ms { Some(ms) => ms, None => { - // Just check once return self .subscribe_alias_map .lock() @@ -292,14 +380,11 @@ impl Subscriber { } }; - // Wait for it to appear, checking after each notification let timeout_duration = Duration::from_millis(timeout_ms); tokio::time::timeout(timeout_duration, async { loop { - // Register for notification before checking map let notified = self.subscribe_alias_notify.notified(); - // Check Map for alias if let Some(id) = self .subscribe_alias_map .lock() @@ -310,7 +395,45 @@ impl Subscriber { return id; } - // Alias not present yet, wait for notification + notified.await; + } + }) + .await + .ok() + } + + async fn get_publish_id_by_alias( + &self, + track_alias: u64, + timeout_ms: Option, + ) -> Option { + let timeout_ms = match timeout_ms { + Some(ms) => ms, + None => { + return self + .publish_alias_map + .lock() + .unwrap() + .get(&track_alias) + .cloned(); + } + }; + + let timeout_duration = Duration::from_millis(timeout_ms); + tokio::time::timeout(timeout_duration, async { + loop { + let notified = self.publish_alias_notify.notified(); + + if let Some(id) = self + 
.publish_alias_map + .lock() + .unwrap() + .get(&track_alias) + .cloned() + { + return id; + } + notified.await; } }) @@ -376,7 +499,6 @@ impl Subscriber { res } - /// Continue handling the reception of a new stream from the QUIC session. async fn recv_stream_inner( &mut self, reader: Reader, @@ -389,19 +511,28 @@ impl Subscriber { track_alias ); - // This is super silly, but I couldn't figure out a way to avoid the mutex guard across awaits. enum Writer { - //Fetch(serve::FetchWriter), Subgroup(serve::SubgroupWriter), } + // First check both maps WITHOUT waiting - this is the fast path for subsequent groups + // where the alias mapping is already established + let (subscribe_id_immediate, publish_id_immediate) = { + let subscribe_id = self.get_subscribe_id_by_alias(track_alias, None).await; + let publish_id = self.get_publish_id_by_alias(track_alias, None).await; + (subscribe_id, publish_id) + }; + + log::debug!( + "[SUBSCRIBER] recv_stream_inner: track_alias={}, subscribe_id_immediate={:?}, publish_id_immediate={:?}", + track_alias, subscribe_id_immediate, publish_id_immediate + ); + + // Determine which path to use, waiting only if neither map has the alias yet let writer = { - // Look up the subscribe id for this track alias - if let Some(subscribe_id) = self - .get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { - // Look up the subscribe by id + if let Some(subscribe_id) = subscribe_id_immediate { + // Found in subscribe map immediately + log::debug!("[SUBSCRIBER] recv_stream_inner: using SUBSCRIBE path (immediate)"); let mut subscribes = self.subscribes.lock().unwrap(); let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { ServeError::not_found_ctx(format!( @@ -410,10 +541,37 @@ impl Subscriber { )) })?; - // Create the appropriate writer based on the stream header type if stream_header.header_type.is_subgroup() { - log::trace!("[SUBSCRIBER] recv_stream_inner: creating subgroup writer"); - 
Writer::Subgroup(subscribe.subgroup(stream_header.subgroup_header.unwrap())?) + log::trace!( + "[SUBSCRIBER] recv_stream_inner: creating subgroup writer from subscribe" + ); + Writer::Subgroup( + subscribe.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } else if let Some(publish_id) = publish_id_immediate { + // Found in publish map immediately + log::debug!("[SUBSCRIBER] recv_stream_inner: using PUBLISH path (immediate)"); + let mut publishes = self.publishes_received.lock().unwrap(); + let publish_recv = publishes.get_mut(&publish_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "publish_id={} not found for track_alias={}", + publish_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + log::trace!( + "[SUBSCRIBER] recv_stream_inner: creating subgroup writer from publish" + ); + Writer::Subgroup( + publish_recv.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) } else { return Err(SessionError::Serve(ServeError::internal_ctx(format!( "unsupported stream header type={}", @@ -421,16 +579,69 @@ impl Subscriber { )))); } } else { - return Err(SessionError::Serve(ServeError::not_found_ctx(format!( - "subscription track_alias={} not found", + // Not found in either map - wait for either to become available + // This only happens for the first stream before control messages establish the mapping + log::debug!( + "[SUBSCRIBER] recv_stream_inner: track_alias={} NOT FOUND in either map, WAITING for alias mapping", track_alias - )))); + ); + + // Race both lookups with timeout + let subscribe_fut = self.get_subscribe_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + let publish_fut = self.get_publish_id_by_alias(track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + + tokio::select! 
{ + Some(subscribe_id) = subscribe_fut => { + let mut subscribes = self.subscribes.lock().unwrap(); + let subscribe = subscribes.get_mut(&subscribe_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "subscribe_id={} not found for track_alias={}", + subscribe_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + Writer::Subgroup( + subscribe.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } + Some(publish_id) = publish_fut => { + let mut publishes = self.publishes_received.lock().unwrap(); + let publish_recv = publishes.get_mut(&publish_id).ok_or_else(|| { + ServeError::not_found_ctx(format!( + "publish_id={} not found for track_alias={}", + publish_id, track_alias + )) + })?; + + if stream_header.header_type.is_subgroup() { + Writer::Subgroup( + publish_recv.subgroup(stream_header.subgroup_header.clone().unwrap())?, + ) + } else { + return Err(SessionError::Serve(ServeError::internal_ctx(format!( + "unsupported stream header type={}", + stream_header.header_type + )))); + } + } + else => { + return Err(SessionError::Serve(ServeError::not_found_ctx(format!( + "subscription track_alias={} not found", + track_alias + )))); + } + } } }; - // Handle the stream based on the writer type match writer { - //Writer::Fetch(fetch) => Self::recv_fetch(fetch, reader).await?, Writer::Subgroup(subgroup_writer) => { log::trace!("[SUBSCRIBER] recv_stream_inner: receiving subgroup data"); Self::recv_subgroup(stream_header.header_type, subgroup_writer, reader, mlog) @@ -513,6 +724,20 @@ impl Subscriber { } } + // Check for Prior Object ID Gap (type 0x3E = 62) + if object.extension_headers.has(0x3E) { + log::info!( + "[SUBSCRIBER] recv_subgroup: object #{} contains PRIOR OBJECT ID GAP (type 0x3E)", + object_count + 1 + ); + if let Some(gap_ext) = object.extension_headers.get(0x3E) { 
+ log::debug!( + "[SUBSCRIBER] recv_subgroup: prior object id gap details: {:?}", + gap_ext + ); + } + } + let obj_copy = object.clone(); ( object.payload_length, @@ -581,10 +806,12 @@ impl Subscriber { } } - // Pass extension headers through to the serve layer - // TODO SLG - object_id_delta and object status are still being ignored - - let mut object_writer = subgroup_writer.create(remaining_bytes, extension_headers)?; + // Pass extension headers and status through to the serve layer + let mut object_writer = subgroup_writer.create_with_status( + remaining_bytes, + extension_headers, + status.unwrap_or(crate::data::ObjectStatus::NormalObject), + )?; log::trace!( "[SUBSCRIBER] recv_subgroup: reading payload for object #{} ({} bytes)", object_count + 1, @@ -624,11 +851,27 @@ impl Subscriber { object_count += 1; } + // If the stream header type signals end-of-group, write an EndOfGroup marker + // This forwards the "stream end = group end" semantic to downstream subscribers + if stream_header_type.signals_end_of_group() { + log::debug!( + "[SUBSCRIBER] recv_subgroup: writing EndOfGroup marker (header_type={:?} signals EOG)", + stream_header_type + ); + if let Err(e) = subgroup_writer.end_of_group() { + log::warn!( + "[SUBSCRIBER] recv_subgroup: failed to write EndOfGroup marker: {}", + e + ); + } + } + log::info!( - "[SUBSCRIBER] recv_subgroup: completed subgroup (group_id={}, subgroup_id={}, {} objects received)", + "[SUBSCRIBER] recv_subgroup: completed subgroup (group_id={}, subgroup_id={}, {} objects received, eog={})", subgroup_writer.info.group_id, subgroup_writer.info.subgroup_id, - object_count + object_count, + stream_header_type.signals_end_of_group() ); Ok(()) @@ -682,17 +925,28 @@ impl Subscriber { ); } } + + // Check for Prior Object ID Gap (type 0x3E = 62) + if ext_headers.has(0x3E) { + log::info!( + "[SUBSCRIBER] recv_datagram: datagram contains PRIOR OBJECT ID GAP (type 0x3E)" + ); + if let Some(gap_ext) = ext_headers.get(0x3E) { + log::debug!( + 
"[SUBSCRIBER] recv_datagram: prior object id gap details: {:?}", + gap_ext + ); + } + } } - // Look up the subscribe id for this track alias if let Some(subscribe_id) = self .get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) .await { - // Look up the subscribe by id if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { log::trace!( - "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={}, status={}, payload_length={}", + "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", datagram.track_alias, datagram.group_id, datagram.object_id.unwrap_or(0), @@ -701,9 +955,25 @@ impl Subscriber { datagram.payload.as_ref().map_or(0, |p| p.len())); subscribe.datagram(datagram)?; } + } else if let Some(publish_id) = self + .get_publish_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) + .await + { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().get_mut(&publish_id) + { + log::trace!( + "[SUBSCRIBER] recv_datagram from publish: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + publish_recv.datagram(datagram)?; + } } else { log::warn!( - "[SUBSCRIBER] recv_datagram: discarded due to unknown track_alias: track_alias={}, group_id={}, object_id={}, publisher_priority={}, status={}, payload_length={}", + "[SUBSCRIBER] recv_datagram: discarded due to unknown track_alias: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", datagram.track_alias, datagram.group_id, datagram.object_id.unwrap_or(0), diff --git 
a/moq-transport/src/session/track_status_requested.rs b/moq-transport/src/session/track_status_requested.rs index 4c9cf744..587610be 100644 --- a/moq-transport/src/session/track_status_requested.rs +++ b/moq-transport/src/session/track_status_requested.rs @@ -21,9 +21,10 @@ impl TrackStatusRequested { error_code: u64, error_message: &str, ) -> Result<(), SessionError> { - let status_error = message::TrackStatusError { + let status_error = message::RequestError { id: self.request_msg.id, error_code, + retry_interval: 0, reason_phrase: ReasonPhrase(error_message.to_string()), }; self.publisher.send_message(status_error); From 884bd7444dbd0d5fdd0326b71eaa56b53198da13 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Wed, 8 Apr 2026 21:25:18 -0700 Subject: [PATCH 03/21] Update serve layer with header type preservation and fixes --- moq-transport/src/serve/datagram.rs | 9 ++ moq-transport/src/serve/error.rs | 6 + moq-transport/src/serve/stream.rs | 5 + moq-transport/src/serve/subgroup.rs | 49 +++++++- moq-transport/src/serve/track.rs | 176 +++++++++++++++++++++++++++- moq-transport/src/serve/tracks.rs | 90 +++++++++++++- 6 files changed, 327 insertions(+), 8 deletions(-) diff --git a/moq-transport/src/serve/datagram.rs b/moq-transport/src/serve/datagram.rs index f7ff2697..4d62e8d2 100644 --- a/moq-transport/src/serve/datagram.rs +++ b/moq-transport/src/serve/datagram.rs @@ -114,6 +114,11 @@ impl DatagramsReader { .as_ref() .map(|datagram| (datagram.group_id, datagram.object_id)) } + + pub fn is_closed(&self) -> bool { + let state = self.state.lock(); + state.closed.is_err() || state.modified().is_none() + } } /// Static information about the datagram. 
@@ -126,6 +131,9 @@ pub struct Datagram { // Extension headers (for draft-14 compliance, particularly immutable extensions) pub extension_headers: crate::data::ExtensionHeaders, + + // Object status (e.g., EndOfGroup) + pub status: Option, } impl fmt::Debug for Datagram { @@ -136,6 +144,7 @@ impl fmt::Debug for Datagram { .field("priority", &self.priority) .field("payload", &self.payload.len()) .field("extension_headers", &self.extension_headers) + .field("status", &self.status) .finish() } } diff --git a/moq-transport/src/serve/error.rs b/moq-transport/src/serve/error.rs index bb0995b5..57666d3a 100644 --- a/moq-transport/src/serve/error.rs +++ b/moq-transport/src/serve/error.rs @@ -36,6 +36,10 @@ pub enum ServeError { #[error("not implemented: {0} [error:{1}]")] NotImplementedWithId(String, uuid::Uuid), + + /// Relay already has an active SUBSCRIBE path, not interested in PUBLISH + #[error("uninterested")] + Uninterested, } impl ServeError { @@ -60,6 +64,8 @@ impl ServeError { Self::NotImplemented(_) | Self::NotImplementedWithId(_, _) => 0x3, // INTERNAL_ERROR (0x0) - per-request error registries use 0x0 Self::Internal(_) | Self::InternalWithId(_, _) => 0x0, + // UNINTERESTED (0x1) - relay already has data path via SUBSCRIBE + Self::Uninterested => 0x1, } } diff --git a/moq-transport/src/serve/stream.rs b/moq-transport/src/serve/stream.rs index e56c1405..020e2aba 100644 --- a/moq-transport/src/serve/stream.rs +++ b/moq-transport/src/serve/stream.rs @@ -188,6 +188,11 @@ impl StreamReader { ) }) } + + pub fn is_closed(&self) -> bool { + let state = self.state.lock(); + state.closed.is_err() || state.modified().is_none() + } } impl Deref for StreamReader { diff --git a/moq-transport/src/serve/subgroup.rs b/moq-transport/src/serve/subgroup.rs index 2d0fc0c0..e47048b6 100644 --- a/moq-transport/src/serve/subgroup.rs +++ b/moq-transport/src/serve/subgroup.rs @@ -95,6 +95,7 @@ impl SubgroupsWriter { group_id, subgroup_id, priority, + header_type: None, }) } @@ -105,6 
+106,7 @@ impl SubgroupsWriter { group_id: subgroup.group_id, subgroup_id: subgroup.subgroup_id, priority: subgroup.priority, + header_type: subgroup.header_type, }; let (writer, reader) = subgroup.produce(); @@ -114,8 +116,17 @@ impl SubgroupsWriter { // TODO: Check this logic again if writer.group_id.cmp(&latest.group_id) == cmp::Ordering::Equal { match writer.subgroup_id.cmp(&latest.subgroup_id) { - cmp::Ordering::Less => return Ok(writer), // dropped immediately, lul - cmp::Ordering::Equal => return Err(ServeError::Duplicate), + cmp::Ordering::Less => return Ok(writer), // dropped immediately + cmp::Ordering::Equal => { + // Duplicate subgroup - silently drop instead of erroring + // This can happen with SubgroupZeroIdEndOfGroup streams + log::warn!( + "duplicate subgroup: group_id={}, subgroup_id={} - dropping", + writer.group_id, + writer.subgroup_id + ); + return Ok(writer); // writer dropped, data lost but relay continues + } cmp::Ordering::Greater => state.latest_subgroup_reader = Some(reader), } } else if writer.group_id.cmp(&latest.group_id) == cmp::Ordering::Greater { @@ -199,6 +210,12 @@ impl SubgroupsReader { .as_ref() .map(|group| (group.group_id, group.latest())) } + + /// Check if the subgroups writer has been closed or dropped. + pub fn is_closed(&self) -> bool { + let state = self.state.lock(); + state.closed.is_err() || state.modified().is_none() + } } impl Deref for SubgroupsReader { @@ -222,6 +239,9 @@ pub struct Subgroup { // The priority of the group within the track. pub priority: u8, + + // The stream header type used for this subgroup (preserved from incoming stream) + pub header_type: Option, } /// Static information about the group @@ -239,6 +259,9 @@ pub struct SubgroupInfo { // The priority of the group within the track. 
pub priority: u8, + + // The stream header type used for this subgroup (preserved from incoming stream) + pub header_type: Option, } impl SubgroupInfo { @@ -313,11 +336,21 @@ impl SubgroupWriter { &mut self, size: usize, extension_headers: Option, + ) -> Result { + self.create_with_status(size, extension_headers, ObjectStatus::NormalObject) + } + + /// Write an object with a specific status (e.g., EndOfGroup). + pub fn create_with_status( + &mut self, + size: usize, + extension_headers: Option, + status: ObjectStatus, ) -> Result { let (writer, reader) = SubgroupObject { group: self.info.clone(), object_id: self.next_object_id, - status: ObjectStatus::NormalObject, + status, size, extension_headers: extension_headers.unwrap_or_default(), } @@ -331,6 +364,16 @@ impl SubgroupWriter { Ok(writer) } + /// Write an EndOfGroup marker object to signal the end of this subgroup. + /// This should be called when the group is complete. + pub fn end_of_group(&mut self) -> Result<(), ServeError> { + // Create an object with size=0 and status=EndOfGroup + let object_writer = self.create_with_status(0, None, ObjectStatus::EndOfGroup)?; + // Object writer with size=0 will complete immediately when dropped + drop(object_writer); + Ok(()) + } + /// Close the stream with an error. pub fn close(self, err: ServeError) -> Result<(), ServeError> { let state = self.state.lock(); diff --git a/moq-transport/src/serve/track.rs b/moq-transport/src/serve/track.rs index 9dd9e101..01591ca8 100644 --- a/moq-transport/src/serve/track.rs +++ b/moq-transport/src/serve/track.rs @@ -199,10 +199,24 @@ impl TrackReader { /// This is used to detect stale cached TrackReaders that should not be reused. pub fn is_closed(&self) -> bool { let state = self.state.lock(); - // Track is closed if: - // 1. It was explicitly closed with an error, OR - // 2. 
The writer side has been dropped (modified() returns None) - state.closed.is_err() || state.modified().is_none() + + if state.closed.is_err() { + return true; + } + + // Clone the mode out before dropping the TrackState lock to avoid + // nested lock deadlocks (mode readers hold their own State locks). + if let Some(mode) = state.reader_mode.clone() { + // Mode has been set — the TrackWriter was consumed during the + // Track→Subgroups/Stream/Datagrams transition. Liveness is now + // determined by whether the mode-level writer is still alive. + drop(state); + return mode.is_closed(); + } + + // No mode set yet — check if the writer was abandoned before + // transitioning to a specific mode. + state.modified().is_none() } } @@ -234,6 +248,12 @@ macro_rules! track_readers { $(Self::$name(reader) => reader.latest(),)* } } + + pub fn is_closed(&self) -> bool { + match self { + $(Self::$name(reader) => reader.is_closed(),)* + } + } } } } @@ -266,3 +286,151 @@ macro_rules! track_writers { } track_writers!(Track, Stream, Subgroups, Objects, Datagrams,); + +#[cfg(test)] +mod tests { + use super::*; + use crate::coding::TrackNamespace; + use crate::serve::Subgroup; + + #[test] + fn test_is_closed_false_before_mode_set() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (_writer, reader) = track.produce(); + assert!(!reader.is_closed()); + } + + #[test] + fn test_is_closed_true_when_writer_dropped_without_mode() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + drop(writer); + assert!(reader.is_closed()); + } + + #[test] + fn test_is_closed_true_when_explicitly_closed() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + writer.close(ServeError::Cancel).unwrap(); + assert!(reader.is_closed()); + } + + #[test] + fn test_is_closed_false_after_subgroups_transition_while_writer_alive() { + 
let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let _subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while SubgroupsWriter is alive" + ); + } + + #[test] + fn test_is_closed_true_after_subgroups_writer_dropped() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); + drop(subgroups_writer); + + assert!( + reader.is_closed(), + "track should be closed after SubgroupsWriter is dropped" + ); + } + + #[test] + fn test_is_closed_true_after_subgroups_writer_explicitly_closed() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let subgroups_writer = writer.subgroups().expect("subgroups transition should succeed"); + subgroups_writer.close(ServeError::Cancel).unwrap(); + + assert!( + reader.is_closed(), + "track should be closed after SubgroupsWriter is explicitly closed" + ); + } + + #[test] + fn test_is_closed_false_after_stream_transition_while_writer_alive() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let _stream_writer = writer.stream(0).expect("stream transition should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while StreamWriter is alive" + ); + } + + #[test] + fn test_is_closed_true_after_stream_writer_dropped() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let stream_writer = writer.stream(0).expect("stream transition should succeed"); + drop(stream_writer); + + assert!( + reader.is_closed(), + "track should be closed after StreamWriter is dropped" + 
); + } + + #[test] + fn test_is_closed_false_after_datagrams_transition_while_writer_alive() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let _datagrams_writer = writer.datagrams().expect("datagrams transition should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while DatagramsWriter is alive" + ); + } + + #[test] + fn test_is_closed_true_after_datagrams_writer_dropped() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let datagrams_writer = writer.datagrams().expect("datagrams transition should succeed"); + drop(datagrams_writer); + + assert!( + reader.is_closed(), + "track should be closed after DatagramsWriter is dropped" + ); + } + + #[test] + fn test_is_closed_false_while_subgroups_actively_writing() { + let track = Track::new(TrackNamespace::from_utf8_path("ns"), "t".to_string()); + let (writer, reader) = track.produce(); + + let mut subgroups_writer = + writer.subgroups().expect("subgroups transition should succeed"); + + let _subgroup_writer = subgroups_writer + .create(Subgroup { + group_id: 0, + subgroup_id: 0, + priority: 0, + header_type: None, + }) + .expect("create subgroup should succeed"); + + assert!( + !reader.is_closed(), + "track should NOT be closed while actively writing subgroups" + ); + } +} diff --git a/moq-transport/src/serve/tracks.rs b/moq-transport/src/serve/tracks.rs index 7ce3aef0..cea6e954 100644 --- a/moq-transport/src/serve/tracks.rs +++ b/moq-transport/src/serve/tracks.rs @@ -91,6 +91,21 @@ impl TracksWriter { }; self.state.lock_mut()?.tracks.remove(&full_name) } + + /// Insert an existing track reader into the broadcast. + /// Returns None if all readers have been dropped or if a track with this name already exists. 
+ pub fn insert(&mut self, reader: TrackReader) -> Option<()> { + let full_name = FullTrackName { + namespace: reader.namespace.clone(), + name: reader.name.clone(), + }; + let mut state = self.state.lock_mut()?; + if state.tracks.contains_key(&full_name) { + return None; + } + state.tracks.insert(full_name, reader); + Some(()) + } } impl Deref for TracksWriter { @@ -201,7 +216,6 @@ impl TracksReader { return Some(track_reader.clone()); } // Track is closed/stale, fall through to create a new one - // We'll remove the stale entry and request a fresh track from the publisher } let mut state = state.into_mut()?; @@ -226,6 +240,13 @@ impl TracksReader { Some(track_writer_reader.1) } + + /// Forward an existing track writer to the upstream subscription queue. + /// The writer will be received by [TracksRequest::next()]. + /// Returns None if the queue is closed. + pub fn forward_upstream(&mut self, writer: TrackWriter) -> Option<()> { + self.queue.push(writer).ok() + } } impl Deref for TracksReader { @@ -324,6 +345,73 @@ mod tests { ); } + #[tokio::test] + async fn test_track_not_stale_after_subgroups_transition() { + let namespace = TrackNamespace::from_utf8_path("test/namespace"); + let track_name = "test-track"; + + let (_writer, mut request, mut reader) = Tracks::new(namespace.clone()).produce(); + + let _track_reader_1 = reader + .subscribe(namespace.clone(), track_name) + .expect("first subscribe should succeed"); + + let track_writer = request + .next() + .await + .expect("publisher should receive track request"); + + let _subgroups_writer = track_writer + .subgroups() + .expect("subgroups transition should succeed"); + + let _track_reader_2 = reader + .subscribe(namespace.clone(), track_name) + .expect("second subscribe should succeed"); + + let maybe_second_request = + tokio::time::timeout(std::time::Duration::from_millis(100), request.next()).await; + + assert!( + maybe_second_request.is_err(), + "publisher should NOT get a second request while 
SubgroupsWriter is alive" + ); + } + + #[tokio::test] + async fn test_track_stale_after_subgroups_writer_dropped() { + let namespace = TrackNamespace::from_utf8_path("test/namespace"); + let track_name = "test-track"; + + let (_writer, mut request, mut reader) = Tracks::new(namespace.clone()).produce(); + + let _track_reader_1 = reader + .subscribe(namespace.clone(), track_name) + .expect("first subscribe should succeed"); + + let track_writer = request + .next() + .await + .expect("publisher should receive track request"); + + let subgroups_writer = track_writer + .subgroups() + .expect("subgroups transition should succeed"); + drop(subgroups_writer); + + let _track_reader_2 = reader + .subscribe(namespace.clone(), track_name) + .expect("second subscribe should succeed"); + + let maybe_second_request = + tokio::time::timeout(std::time::Duration::from_millis(100), request.next()).await; + + assert!( + maybe_second_request.is_ok(), + "publisher should get a new request after SubgroupsWriter is dropped" + ); + } + /// Test that normal track caching works correctly when tracks are still alive. 
/// /// Multiple subscribers to the same track should share the same TrackReader From dbd2e9b2a28c521aa04033a5569cb202c8d30986 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Wed, 8 Apr 2026 21:25:22 -0700 Subject: [PATCH 04/21] Add SUBSCRIBE_NAMESPACE/PUBLISH relay support --- moq-relay-ietf/src/consumer.rs | 188 +++++++++-- moq-relay-ietf/src/lib.rs | 2 + moq-relay-ietf/src/local.rs | 360 ++++++++++++++++++++-- moq-relay-ietf/src/producer.rs | 281 ++++++++++++++++- moq-relay-ietf/src/relay.rs | 26 +- moq-relay-ietf/src/subscriber_registry.rs | 277 +++++++++++++++++ 6 files changed, 1069 insertions(+), 65 deletions(-) create mode 100644 moq-relay-ietf/src/subscriber_registry.rs diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index 8d636912..acb16ba5 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -3,11 +3,13 @@ use std::sync::Arc; use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ - serve::Tracks, - session::{Announced, SessionError, Subscriber}, + coding::KeyValuePairs, + message::PublishOk, + serve::{ServeError, Tracks}, + session::{PublishNamespaceReceived, PublishReceived, SessionError, Subscriber}, }; -use crate::{Coordinator, Locals, Producer}; +use crate::{Coordinator, Locals, Producer, SubscriberRegistry}; /// Consumer of tracks from a remote Publisher #[derive(Clone)] @@ -16,6 +18,7 @@ pub struct Consumer { locals: Locals, coordinator: Arc, forward: Option, // Forward all announcements to this subscriber + subscriber_registry: Option, } impl Consumer { @@ -30,28 +33,60 @@ impl Consumer { locals, coordinator, forward, + subscriber_registry: None, } } - /// Run the consumer to serve announce requests. - pub async fn run(mut self) -> Result<(), SessionError> { - let mut tasks = FuturesUnordered::new(); + /// Creates a consumer with a subscriber registry for PUBLISH notifications. 
+ pub fn with_registry( + subscriber: Subscriber, + locals: Locals, + coordinator: Arc, + forward: Option, + subscriber_registry: SubscriberRegistry, + ) -> Self { + Self { + subscriber, + locals, + coordinator, + forward, + subscriber_registry: Some(subscriber_registry), + } + } + + /// Run the consumer to serve announce requests and track-level publish messages. + pub async fn run(self) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); loop { + let mut subscriber_ns = self.subscriber.clone(); + let mut subscriber_publish = self.subscriber.clone(); + tokio::select! { - // Handle a new announce request - Some(announce) = self.subscriber.announced() => { + Some(publish_ns) = subscriber_ns.publish_ns_recvd() => { let this = self.clone(); tasks.push(async move { - let info = announce.clone(); - log::info!("serving announce: {:?}", info); + let info = publish_ns.clone(); + log::info!("serving publish_namespace: {:?}", info); - // Serve the announce request - if let Err(err) = this.serve(announce).await { - log::warn!("failed serving announce: {:?}, error: {}", info, err) + if let Err(err) = this.serve_publish_namespace(publish_ns).await { + log::warn!("failed serving publish_namespace: {:?}, error: {}", info, err) } - }); + }.boxed()); + }, + Some(publish) = subscriber_publish.publish_received() => { + let this = self.clone(); + + tasks.push(async move { + let info = publish.info.clone(); + log::info!("serving publish (track-level): {:?}", info); + + if let Err(err) = this.serve_publish(publish).await { + log::warn!("failed serving publish: {:?}, error: {}", info, err) + } + }.boxed()); }, _ = tasks.next(), if !tasks.is_empty() => {}, else => return Ok(()), @@ -59,12 +94,13 @@ impl Consumer { } } - /// Serve an announce request. 
- async fn serve(mut self, mut announce: Announced) -> Result<(), anyhow::Error> { + async fn serve_publish_namespace( + mut self, + mut publish_ns: PublishNamespaceReceived, + ) -> Result<(), anyhow::Error> { let mut tasks = FuturesUnordered::new(); - // Produce the tracks for this announce and return the reader - let (_, mut request, reader) = Tracks::new(announce.namespace.clone()).produce(); + let (writer, mut request, reader) = Tracks::new(publish_ns.namespace.clone()).produce(); // NOTE(mpandit): once the track is pulled from origin, internally it will be relayed // from this metal only, because now coordinator will have entry for the namespace. @@ -78,20 +114,40 @@ impl Consumer { .await?; // Register the local tracks, unregister on drop - let _register = self.locals.register(reader.clone()).await?; + let _register = self.locals.register(reader.clone(), writer).await?; - // Accept the announce with an OK response - announce.ok()?; + publish_ns.ok()?; - // Forward the announce, if needed - if let Some(mut forward) = self.forward { + // Notify subscriber registry of the new PUBLISH_NAMESPACE + // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + if let Some(ref registry) = self.subscriber_registry { + let notified = registry.notify_publish_namespace(&publish_ns.namespace); + if notified > 0 { + log::info!( + "notified {} SUBSCRIBE_NAMESPACE subscriptions of PUBLISH_NAMESPACE {:?}", + notified, + publish_ns.namespace + ); + } + } + + if let Some(mut forward) = self.forward.clone() { + let reader_clone = reader.clone(); tasks.push( async move { - log::info!("forwarding announce: {:?}", reader.info); - forward - .announce(reader) + log::info!("forwarding publish_namespace: {:?}", reader_clone.info); + let publish_ns = forward + .publish_namespace(reader_clone) .await - .context("failed forwarding announce") + .context("failed forwarding publish_namespace")?; + publish_ns + .ok() + .await + .context("publish_namespace not accepted")?; + 
publish_ns + .closed() + .await + .context("publish_namespace closed with error") } .boxed(), ); @@ -100,8 +156,7 @@ impl Consumer { // Serve subscribe requests loop { tokio::select! { - // If the announce is closed, return the error - Err(err) = announce.closed() => return Err(err.into()), + Err(err) = publish_ns.closed() => return Err(err.into()), // Wait for the next subscriber and serve the track. Some(track) = request.next() => { @@ -125,4 +180,79 @@ impl Consumer { } } } + + async fn serve_publish(self, publish: PublishReceived) -> Result<(), anyhow::Error> { + let namespace = publish.info.track_namespace.clone(); + let track_name = publish.info.track_name.clone(); + let track_alias = publish.info.track_alias; + + log::info!("received PUBLISH for track: {}/{}", namespace, track_name); + + // Use auto-register variant to support SUBSCRIBE_NAMESPACE flow + // where PUBLISH can arrive without prior PUBLISH_NAMESPACE + let track_info = self + .locals + .get_or_create_track_info_auto_register(&namespace, &track_name); + + let writer = match track_info.publish_arrived() { + Ok(w) => w, + Err(ServeError::Uninterested) => { + log::info!( + "PUBLISH rejected: already subscribed to {}/{}", + namespace, + track_name + ); + publish.reject(ServeError::Uninterested.code(), "Already subscribed")?; + return Err(ServeError::Uninterested.into()); + } + Err(ServeError::Duplicate) => { + log::info!( + "PUBLISH rejected: already publishing {}/{}", + namespace, + track_name + ); + publish.reject(ServeError::Duplicate.code(), "Already publishing")?; + return Err(ServeError::Duplicate.into()); + } + Err(e) => { + publish.reject(e.code(), &e.to_string())?; + return Err(e.into()); + } + }; + + let reader = track_info.get_reader(); + + self.locals + .insert_track(&namespace, reader) + .context("failed to insert track into namespace")?; + + let msg = PublishOk { + id: publish.info.id, + params: KeyValuePairs::default(), + }; + + publish.accept(writer, msg)?; + + log::info!( + "PUBLISH 
accepted, track {}/{} now in Publishing state", + namespace, + track_name + ); + + // Notify subscriber registry of the new PUBLISH + // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + if let Some(ref registry) = self.subscriber_registry { + let notified = registry.notify_publish(&namespace, &track_name, track_alias); + if notified > 0 { + log::info!( + "notified {} SUBSCRIBE_NAMESPACE subscriptions of PUBLISH {}/{}", + notified, + namespace, + track_name + ); + } + } + + Ok(()) + } } diff --git a/moq-relay-ietf/src/lib.rs b/moq-relay-ietf/src/lib.rs index aac39326..11a9456f 100644 --- a/moq-relay-ietf/src/lib.rs +++ b/moq-relay-ietf/src/lib.rs @@ -36,6 +36,7 @@ mod producer; mod relay; mod remote; mod session; +mod subscriber_registry; mod web; pub use api::*; @@ -46,4 +47,5 @@ pub use producer::*; pub use relay::*; pub use remote::*; pub use session::*; +pub use subscriber_registry::*; pub use web::*; diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 406e6650..624e3d1b 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -1,17 +1,169 @@ use std::collections::hash_map; use std::collections::HashMap; - -use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; use moq_transport::{ coding::TrackNamespace, - serve::{ServeError, TracksReader}, + serve::{ServeError, Track, TrackReader, TrackWriter, TracksReader, TracksWriter}, }; -/// Registry of local tracks +#[repr(u8)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum TrackState { + Pending = 0, + Subscribing = 1, + Subscribed = 2, + Publishing = 3, + Closed = 4, +} + +impl TrackState { + fn from_u8(v: u8) -> Self { + match v { + 0 => TrackState::Pending, + 1 => TrackState::Subscribing, + 2 => TrackState::Subscribed, + 3 => TrackState::Publishing, + _ => TrackState::Closed, + } + } +} + +pub struct TrackInfo { + pub namespace: TrackNamespace, + pub name: String, 
+ + state: AtomicU8, + track_reader: OnceLock, + track_writer: Mutex>, + upstream_subscribe_sent: AtomicBool, + upstream_request_id: Mutex>, +} + +impl TrackInfo { + pub fn new(namespace: TrackNamespace, name: String) -> Self { + Self { + namespace, + name, + state: AtomicU8::new(TrackState::Pending as u8), + track_reader: OnceLock::new(), + track_writer: Mutex::new(None), + upstream_subscribe_sent: AtomicBool::new(false), + upstream_request_id: Mutex::new(None), + } + } + + pub fn get_reader(&self) -> TrackReader { + self.ensure_track_created(); + self.track_reader.get().unwrap().clone() + } + + pub fn should_subscribe_upstream(&self) -> bool { + let state = self.state(); + + if state == TrackState::Publishing { + return false; + } + + !self.upstream_subscribe_sent.swap(true, Ordering::SeqCst) + } + + pub fn mark_subscribe_sent(&self, request_id: u64) { + *self.upstream_request_id.lock().unwrap() = Some(request_id); + + let _ = self.state.compare_exchange( + TrackState::Pending as u8, + TrackState::Subscribing as u8, + Ordering::SeqCst, + Ordering::SeqCst, + ); + } + + pub fn subscribe_ok_received(&self) { + let _ = self.state.compare_exchange( + TrackState::Subscribing as u8, + TrackState::Subscribed as u8, + Ordering::SeqCst, + Ordering::SeqCst, + ); + } + + pub fn publish_arrived(&self) -> Result { + self.ensure_track_created(); + + let current_state = self.state(); + + if current_state == TrackState::Subscribed { + return Err(ServeError::Uninterested); + } + + if current_state == TrackState::Publishing { + return Err(ServeError::Duplicate); + } + + self.state + .store(TrackState::Publishing as u8, Ordering::SeqCst); + + self.track_writer + .lock() + .unwrap() + .take() + .ok_or(ServeError::Duplicate) + } + + pub fn state(&self) -> TrackState { + TrackState::from_u8(self.state.load(Ordering::SeqCst)) + } + + pub fn is_publishing(&self) -> bool { + self.state() == TrackState::Publishing + } + + pub fn take_writer_for_upstream(&self) -> Result { + 
self.ensure_track_created(); + + let current_state = self.state(); + + if current_state == TrackState::Publishing { + return Err(ServeError::Duplicate); + } + + if current_state == TrackState::Subscribing || current_state == TrackState::Subscribed { + return Err(ServeError::Duplicate); + } + + self.state + .store(TrackState::Subscribing as u8, Ordering::SeqCst); + + self.track_writer + .lock() + .unwrap() + .take() + .ok_or(ServeError::Duplicate) + } + + fn ensure_track_created(&self) { + self.track_reader.get_or_init(|| { + let (writer, reader) = Track::new(self.namespace.clone(), self.name.clone()).produce(); + *self.track_writer.lock().unwrap() = Some(writer); + reader + }); + } +} + +struct LocalsEntry { + /// reader and writer hold the readers and writers for a namespace + reader: TracksReader, + writer: TracksWriter, + /// tracks holds the individual tracks for a namespace + tracks: Mutex>>, +} + +/// Locals is a map of TrackNamespace to LocalsEntry #[derive(Clone)] pub struct Locals { - lookup: Arc>>, + lookup: Arc>>, } impl Default for Locals { @@ -20,7 +172,6 @@ impl Default for Locals { } } -/// Local tracks registry. impl Locals { pub fn new() -> Self { Self { @@ -28,13 +179,19 @@ impl Locals { } } - /// Register new local tracks. 
- pub async fn register(&mut self, tracks: TracksReader) -> anyhow::Result { - let namespace = tracks.namespace.clone(); + pub async fn register( + &mut self, + reader: TracksReader, + writer: TracksWriter, + ) -> anyhow::Result { + let namespace = reader.namespace.clone(); - // Insert the tracks(TracksReader) into the lookup table match self.lookup.lock().unwrap().entry(namespace.clone()) { - hash_map::Entry::Vacant(entry) => entry.insert(tracks), + hash_map::Entry::Vacant(entry) => entry.insert(LocalsEntry { + reader, + writer, + tracks: Mutex::new(HashMap::new()), + }), hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), }; @@ -46,17 +203,122 @@ impl Locals { Ok(registration) } - /// Retrieve local tracks by namespace using hierarchical prefix matching. - /// Returns the TracksReader for the longest matching namespace prefix. pub fn retrieve(&self, namespace: &TrackNamespace) -> Option { let lookup = self.lookup.lock().unwrap(); - // Find the longest matching prefix let mut best_match: Option = None; let mut best_len = 0; - for (registered_ns, tracks) in lookup.iter() { - // Check if registered_ns is a prefix of namespace + for (registered_ns, entry) in lookup.iter() { + if namespace.fields.len() >= registered_ns.fields.len() { + let is_prefix = registered_ns + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b); + + if is_prefix && registered_ns.fields.len() > best_len { + best_match = Some(entry.reader.clone()); + best_len = registered_ns.fields.len(); + } + } + } + + best_match + } + + pub fn get_or_create_track_info( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Option> { + let lookup = self.lookup.lock().unwrap(); + + let entry = Self::find_best_match_entry(&lookup, namespace)?; + + // Use full namespace + track_name as key to avoid collisions + let track_key = format!("{}:{}", namespace, track_name); + + let mut tracks = entry.tracks.lock().unwrap(); + + let track_info = tracks + 
.entry(track_key) + .or_insert_with(|| Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string()))) + .clone(); + + Some(track_info) + } + + /// Get or create track info, auto-registering the namespace if needed. + /// This supports the SUBSCRIBE_NAMESPACE flow where PUBLISH can arrive + /// without a prior PUBLISH_NAMESPACE. + pub fn get_or_create_track_info_auto_register( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Arc { + let mut lookup = self.lookup.lock().unwrap(); + + // Use full namespace + track_name as key to avoid collisions + // when different namespaces have the same track_name + let track_key = format!("{}:{}", namespace, track_name); + + // First try to find an existing matching namespace entry + if let Some(entry) = Self::find_best_match_entry(&lookup, namespace) { + let mut tracks = entry.tracks.lock().unwrap(); + return tracks + .entry(track_key.clone()) + .or_insert_with(|| { + Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string())) + }) + .clone(); + } + + // No matching namespace found - auto-register for SUBSCRIBE_NAMESPACE flow + log::info!( + "auto-registering namespace {} for PUBLISH (no prior PUBLISH_NAMESPACE)", + namespace + ); + + let (writer, _request, reader) = + moq_transport::serve::Tracks::new(namespace.clone()).produce(); + + let entry = lookup.entry(namespace.clone()).or_insert(LocalsEntry { + reader, + writer, + tracks: Mutex::new(HashMap::new()), + }); + + let mut tracks = entry.tracks.lock().unwrap(); + tracks + .entry(track_key) + .or_insert_with(|| Arc::new(TrackInfo::new(namespace.clone(), track_name.to_string()))) + .clone() + } + + pub fn get_track_info( + &self, + namespace: &TrackNamespace, + track_name: &str, + ) -> Option> { + let lookup = self.lookup.lock().unwrap(); + + let entry = Self::find_best_match_entry(&lookup, namespace)?; + + // Use full namespace + track_name as key to match get_or_create_track_info + let track_key = format!("{}:{}", namespace, track_name); + let 
tracks = entry.tracks.lock().unwrap(); + tracks.get(&track_key).cloned() + } + + fn find_best_match_entry<'a>( + lookup: &'a HashMap, + namespace: &TrackNamespace, + ) -> Option<&'a LocalsEntry> { + let mut best_match: Option<&LocalsEntry> = None; + let mut best_len = 0; + + for (registered_ns, entry) in lookup.iter() { if namespace.fields.len() >= registered_ns.fields.len() { let is_prefix = registered_ns .fields @@ -65,7 +327,7 @@ impl Locals { .all(|(a, b)| a == b); if is_prefix && registered_ns.fields.len() > best_len { - best_match = Some(tracks.clone()); + best_match = Some(entry); best_len = registered_ns.fields.len(); } } @@ -73,6 +335,69 @@ impl Locals { best_match } + + pub fn insert_track( + &self, + namespace: &TrackNamespace, + track_reader: TrackReader, + ) -> Option<()> { + let mut lookup = self.lookup.lock().unwrap(); + + if let Some(entry) = lookup.get_mut(namespace) { + entry.writer.insert(track_reader) + } else { + None + } + } + + pub fn subscribe_upstream(&self, track_info: Arc) -> Option { + let mut lookup = self.lookup.lock().unwrap(); + + let entry = lookup.get_mut(&track_info.namespace)?; + + let writer = track_info.take_writer_for_upstream().ok()?; + let reader = track_info.get_reader(); + + entry.reader.forward_upstream(writer)?; + + let namespace = track_info.namespace.clone(); + + let entry_mut = lookup + .iter_mut() + .find(|(ns, _)| { + namespace.fields.len() >= ns.fields.len() + && ns + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b) + }) + .map(|(_, e)| e)?; + + entry_mut.writer.insert(reader.clone()); + + Some(reader) + } + + pub fn matching_namespaces(&self, prefix: &TrackNamespace) -> Vec { + let lookup = self.lookup.lock().unwrap(); + + lookup + .keys() + .filter(|ns| { + if ns.fields.len() >= prefix.fields.len() { + prefix + .fields + .iter() + .zip(ns.fields.iter()) + .all(|(a, b)| a == b) + } else { + false + } + }) + .cloned() + .collect() + } } pub struct Registration { @@ -80,7 +405,6 @@ pub 
struct Registration { namespace: TrackNamespace, } -/// Deregister local tracks on drop. impl Drop for Registration { fn drop(&mut self) { self.locals.lookup.lock().unwrap().remove(&self.namespace); diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 23ea49f3..dd029c7e 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -1,10 +1,15 @@ use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::{ + coding::{KeyValuePairs, TrackNamespace}, + message, serve::{ServeError, TracksReader}, - session::{Publisher, SessionError, Subscribed, TrackStatusRequested}, + session::{ + PublishNamespace, Publisher, SessionError, SubscribeNamespaceReceived, Subscribed, + TrackStatusRequested, + }, }; -use crate::{Locals, RemotesConsumer}; +use crate::{Locals, RemotesConsumer, SubscriberRegistry}; /// Producer of tracks to a remote Subscriber #[derive(Clone)] @@ -12,6 +17,7 @@ pub struct Producer { publisher: Publisher, locals: Locals, remotes: Option, + subscriber_registry: Option, } impl Producer { @@ -20,15 +26,34 @@ impl Producer { publisher, locals, remotes, + subscriber_registry: None, } } - /// Announce new tracks to the remote server. - pub async fn announce(&mut self, tracks: TracksReader) -> Result<(), SessionError> { - self.publisher.announce(tracks).await + /// Creates a producer with a subscriber registry. + pub fn with_registry( + publisher: Publisher, + locals: Locals, + remotes: Option, + subscriber_registry: SubscriberRegistry, + ) -> Self { + Self { + publisher, + locals, + remotes, + subscriber_registry: Some(subscriber_registry), + } + } + + pub async fn publish_namespace( + &mut self, + tracks: TracksReader, + ) -> Result { + self.publisher + .publish_namespace(tracks.namespace.clone()) + .await } - /// Run the producer to serve subscribe requests. 
pub async fn run(self) -> Result<(), SessionError> { //let mut tasks = FuturesUnordered::new(); let mut tasks: FuturesUnordered> = @@ -37,6 +62,7 @@ impl Producer { loop { let mut publisher_subscribed = self.publisher.clone(); let mut publisher_track_status = self.publisher.clone(); + let mut publisher_subscribe_ns = self.publisher.clone(); tokio::select! { // Handle a new subscribe request @@ -69,28 +95,60 @@ impl Producer { } }.boxed()) }, + Some(subscribe_ns) = publisher_subscribe_ns.subscribe_namespace_received() => { + let this = self.clone(); + + tasks.push(async move { + let info = subscribe_ns.info.clone(); + log::info!("serving subscribe_namespace: {:?}", info); + + if let Err(err) = this.serve_subscribe_namespace(subscribe_ns).await { + log::warn!("failed serving subscribe_namespace: {:?}, error: {}", info, err) + } + }.boxed()) + }, _= tasks.next(), if !tasks.is_empty() => {}, else => return Ok(()), }; } } - /// Serve a subscribe request. async fn serve_subscribe(self, subscribed: Subscribed) -> Result<(), anyhow::Error> { let namespace = subscribed.track_namespace.clone(); let track_name = subscribed.track_name.clone(); - // Check local tracks first, and serve from local if possible - if let Some(mut local) = self.locals.retrieve(&namespace) { - // Pass the full requested namespace, not the announced prefix - if let Some(track) = local.subscribe(namespace.clone(), &track_name) { - log::info!("serving subscribe from local: {:?}", track.info); - return Ok(subscribed.serve(track).await?); + if let Some(track_info) = self + .locals + .get_or_create_track_info(&namespace, &track_name) + { + if track_info.should_subscribe_upstream() { + log::info!( + "subscribe needs upstream request: {}/{}", + namespace, + track_name + ); + + if let Some(reader) = self.locals.subscribe_upstream(track_info.clone()) { + log::info!( + "forwarding subscribe upstream via TrackInfo: {}/{}", + namespace, + track_name + ); + return Ok(subscribed.serve(reader).await?); + } } + + let 
reader = track_info.get_reader(); + log::info!( + "serving subscribe from local: {}/{} (state: {:?})", + namespace, + track_name, + track_info.state() + ); + return Ok(subscribed.serve(reader).await?); } if let Some(remotes) = self.remotes { - // Check remote tracks second, and serve from remote if possible match remotes.route(&namespace).await { Ok(remote) => { if let Some(remote) = remote { @@ -105,7 +163,7 @@ impl Producer { } } } - // Track not found - close the subscription with not found error + let err = ServeError::not_found_ctx(format!( "track '{}/{}' not found in local or remote tracks", namespace, track_name @@ -114,7 +172,198 @@ impl Producer { Err(err.into()) } - /// Serve a track_status request. + async fn serve_subscribe_namespace( + mut self, + mut subscribe_ns: SubscribeNamespaceReceived, + ) -> Result<(), anyhow::Error> { + let namespace_prefix = subscribe_ns.namespace_prefix.clone(); + + // Register with subscriber registry to receive PUBLISH and PUBLISH_NAMESPACE notifications + let (_subscription_guard, mut publish_rx, mut publish_ns_rx) = + if let Some(ref registry) = self.subscriber_registry { + let (id, rx, rx_ns) = registry.register(namespace_prefix.clone()); + ( + Some(crate::SubscriptionGuard::new(registry.clone(), id)), + Some(rx), + Some(rx_ns), + ) + } else { + (None, None, None) + }; + + // Find existing namespaces that match the prefix + let matching_namespaces: Vec = self + .locals + .matching_namespaces(&namespace_prefix) + .into_iter() + .collect(); + + // Accept the subscription (even if no current matches - publisher may arrive later) + subscribe_ns.ok()?; + + log::info!( + "accepted SUBSCRIBE_NAMESPACE for prefix {:?}, {} existing matches", + namespace_prefix, + matching_namespaces.len() + ); + + // Send PUBLISH_NAMESPACE for existing namespaces + for namespace in matching_namespaces { + log::info!( + "sending PUBLISH_NAMESPACE for {:?} (matched prefix {:?})", + namespace, + namespace_prefix + ); + match 
self.publisher.publish_namespace(namespace.clone()).await { + Ok(_publish_ns) => { + log::debug!("sent PUBLISH_NAMESPACE for {:?}", namespace); + // Note: publish_ns is kept alive to maintain the announcement + } + Err(e) => { + log::warn!( + "failed to send PUBLISH_NAMESPACE for {:?}: {}", + namespace, + e + ); + } + } + } + + // If we have a publish receiver, listen for new PUBLISH and PUBLISH_NAMESPACE notifications + if publish_rx.is_some() || publish_ns_rx.is_some() { + loop { + tokio::select! { + // Wait for the subscription to close + result = subscribe_ns.closed() => { + result?; + break; + } + // Wait for PUBLISH notifications + notification = async { + if let Some(ref mut rx) = publish_rx { + rx.recv().await + } else { + std::future::pending().await + } + } => { + match notification { + Ok(publish_notif) => { + log::info!( + "received PUBLISH notification for {}/{} on subscription prefix {:?}", + publish_notif.namespace, + publish_notif.track_name, + namespace_prefix + ); + + // Get the TrackReader for this track so we can stream data + if let Some(track_info) = self.locals.get_track_info( + &publish_notif.namespace, + &publish_notif.track_name, + ) { + let track_reader = track_info.get_reader(); + + // Use publisher.publish() which sends PUBLISH with forward=1 + // This allows forwarding objects immediately + let mut publisher = self.publisher.clone(); + let ns = publish_notif.namespace.clone(); + let name = publish_notif.track_name.clone(); + tokio::spawn(async move { + match publisher.publish(track_reader.clone()).await { + Ok(published) => { + log::info!( + "forwarded PUBLISH for {}/{} with forward=1, streaming immediately", + ns, name + ); + // serve_immediately() starts streaming without waiting for PUBLISH_OK + // Since forward=1, subscriber expects data immediately + // If subscriber sends error, serve will end and we cleanup + match published.serve_immediately(track_reader).await { + Ok(()) => { + log::info!("track {}/{} serving completed", ns, 
name); + } + Err(e) => { + log::warn!( + "track {}/{} serving ended: {}", + ns, name, e + ); + // Cleanup handled by Published drop + } + } + } + Err(e) => { + log::warn!( + "failed to publish track {}/{}: {}", + ns, name, e + ); + } + } + }); + } else { + log::warn!( + "no track info found for {}/{}, cannot forward PUBLISH", + publish_notif.namespace, + publish_notif.track_name + ); + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + log::warn!("subscription lagged by {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + log::debug!("publish notification channel closed"); + break; + } + } + } + // Wait for PUBLISH_NAMESPACE notifications -> forward as NAMESPACE message + notification = async { + if let Some(ref mut rx) = publish_ns_rx { + rx.recv().await + } else { + std::future::pending().await + } + } => { + match notification { + Ok(ns_notif) => { + log::info!( + "received PUBLISH_NAMESPACE notification for {:?} on subscription prefix {:?}", + ns_notif.namespace, + namespace_prefix + ); + // Forward NAMESPACE message to the subscriber (not PUBLISH_NAMESPACE) + // NAMESPACE (0x08) is the draft-16 message for announcing namespaces + // to SUBSCRIBE_NAMESPACE subscribers + let namespace_msg = message::Namespace { + id: subscribe_ns.info.request_id, + track_namespace: ns_notif.namespace.clone(), + params: KeyValuePairs::new(), + }; + self.publisher.forward_namespace(namespace_msg); + log::debug!( + "forwarded NAMESPACE for {:?} (request_id={})", + ns_notif.namespace, + subscribe_ns.info.request_id + ); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + log::warn!("namespace subscription lagged by {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + log::debug!("publish_namespace notification channel closed"); + break; + } + } + } + } + } + } else { + // No registry, just wait for close + subscribe_ns.closed().await?; + } + + Ok(()) + } + async fn 
serve_track_status( self, mut track_status_requested: TrackStatusRequested, diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index 0daf6edb..f4756ea7 100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -8,6 +8,7 @@ use url::Url; use crate::{ Consumer, Coordinator, Locals, Producer, Remotes, RemotesConsumer, RemotesProducer, Session, + SubscriberRegistry, }; // A type alias for boxed future @@ -58,6 +59,7 @@ pub struct Relay { locals: Locals, remotes: Option<(RemotesProducer, RemotesConsumer)>, coordinator: Arc, + subscriber_registry: SubscriberRegistry, } impl Relay { @@ -107,6 +109,9 @@ impl Relay { } .produce(); + // Create subscriber registry for SUBSCRIBE_NAMESPACE tracking + let subscriber_registry = SubscriberRegistry::new(); + Ok(Self { quic_endpoints: endpoints, announce_url: config.announce, @@ -114,6 +119,7 @@ impl Relay { locals, remotes: Some(remotes), coordinator: config.coordinator, + subscriber_registry, }) } @@ -219,6 +225,7 @@ impl Relay { let remotes = remotes.clone(); let forward = forward_producer.clone(); let coordinator = self.coordinator.clone(); + let subscriber_registry = self.subscriber_registry.clone(); // Spawn a new task to handle the connection tasks.push(async move { @@ -235,8 +242,23 @@ impl Relay { let moq_session = session; let session = Session { session: moq_session, - producer: publisher.map(|publisher| Producer::new(publisher, locals.clone(), remotes)), - consumer: subscriber.map(|subscriber| Consumer::new(subscriber, locals, coordinator, forward)), + producer: publisher.map(|publisher| { + Producer::with_registry( + publisher, + locals.clone(), + remotes, + subscriber_registry.clone(), + ) + }), + consumer: subscriber.map(|subscriber| { + Consumer::with_registry( + subscriber, + locals, + coordinator, + forward, + subscriber_registry, + ) + }), }; if let Err(err) = session.run().await { diff --git a/moq-relay-ietf/src/subscriber_registry.rs 
b/moq-relay-ietf/src/subscriber_registry.rs new file mode 100644 index 00000000..a21119ba --- /dev/null +++ b/moq-relay-ietf/src/subscriber_registry.rs @@ -0,0 +1,277 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use moq_transport::coding::TrackNamespace; +use tokio::sync::broadcast; + +/// Information about an active SUBSCRIBE_NAMESPACE subscription +#[derive(Clone)] +pub struct NamespaceSubscription { + /// The namespace prefix this subscription is for + pub prefix: TrackNamespace, + /// Channel to send PUBLISH notifications to this subscriber + pub publish_tx: broadcast::Sender, + /// Channel to send PUBLISH_NAMESPACE notifications to this subscriber + pub publish_ns_tx: broadcast::Sender, +} + +/// Notification sent when a PUBLISH arrives that matches a subscription +#[derive(Clone, Debug)] +pub struct PublishNotification { + pub namespace: TrackNamespace, + pub track_name: String, + pub track_alias: u64, +} + +/// Notification sent when a PUBLISH_NAMESPACE arrives that matches a subscription +#[derive(Clone, Debug)] +pub struct PublishNamespaceNotification { + pub namespace: TrackNamespace, +} + +/// Registry for tracking active SUBSCRIBE_NAMESPACE subscriptions +/// +/// When a subscriber sends SUBSCRIBE_NAMESPACE, they register here. +/// When a publisher sends PUBLISH, we find matching subscriptions and notify. 
+#[derive(Clone)] +pub struct SubscriberRegistry { + inner: Arc>, +} + +struct SubscriberRegistryInner { + /// Map from subscription ID to subscription info + subscriptions: HashMap, + /// Next subscription ID + next_id: u64, +} + +impl SubscriberRegistry { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(SubscriberRegistryInner { + subscriptions: HashMap::new(), + next_id: 0, + })), + } + } + + /// Register a SUBSCRIBE_NAMESPACE subscription + /// Returns (subscription_id, receiver for PUBLISH notifications, receiver for PUBLISH_NAMESPACE notifications) + pub fn register( + &self, + prefix: TrackNamespace, + ) -> ( + u64, + broadcast::Receiver, + broadcast::Receiver, + ) { + let mut inner = self.inner.lock().unwrap(); + + let id = inner.next_id; + inner.next_id += 1; + + // Create broadcast channels for PUBLISH and PUBLISH_NAMESPACE notifications + let (publish_tx, publish_rx) = broadcast::channel(64); + let (publish_ns_tx, publish_ns_rx) = broadcast::channel(64); + + let subscription = NamespaceSubscription { + prefix, + publish_tx, + publish_ns_tx, + }; + + inner.subscriptions.insert(id, subscription); + + log::debug!("registered namespace subscription id={}", id); + + (id, publish_rx, publish_ns_rx) + } + + /// Unregister a subscription + pub fn unregister(&self, id: u64) { + let mut inner = self.inner.lock().unwrap(); + if inner.subscriptions.remove(&id).is_some() { + log::debug!("unregistered namespace subscription id={}", id); + } + } + + /// Find all subscriptions that match a given namespace and notify them of a PUBLISH + /// Returns the number of matching subscriptions notified + pub fn notify_publish( + &self, + namespace: &TrackNamespace, + track_name: &str, + track_alias: u64, + ) -> usize { + let inner = self.inner.lock().unwrap(); + + let notification = PublishNotification { + namespace: namespace.clone(), + track_name: track_name.to_string(), + track_alias, + }; + + let mut notified = 0; + + for (id, sub) in inner.subscriptions.iter() 
{ + // Check if the namespace matches the subscription prefix + // The subscription prefix should be a prefix of the namespace + if Self::prefix_matches(&sub.prefix, namespace) { + if let Err(e) = sub.publish_tx.send(notification.clone()) { + log::warn!("failed to notify subscription id={}: {}", id, e); + } else { + log::debug!( + "notified subscription id={} of PUBLISH {}/{}", + id, + namespace, + track_name + ); + notified += 1; + } + } + } + + notified + } + + /// Find all subscriptions that match a given namespace and notify them of a PUBLISH_NAMESPACE + /// Returns the number of matching subscriptions notified + pub fn notify_publish_namespace(&self, namespace: &TrackNamespace) -> usize { + let inner = self.inner.lock().unwrap(); + + let notification = PublishNamespaceNotification { + namespace: namespace.clone(), + }; + + let mut notified = 0; + + for (id, sub) in inner.subscriptions.iter() { + // Check if the namespace matches the subscription prefix + if Self::prefix_matches(&sub.prefix, namespace) { + if let Err(e) = sub.publish_ns_tx.send(notification.clone()) { + log::warn!( + "failed to notify subscription id={} of PUBLISH_NAMESPACE: {}", + id, + e + ); + } else { + log::debug!( + "notified subscription id={} of PUBLISH_NAMESPACE {:?}", + id, + namespace + ); + notified += 1; + } + } + } + + notified + } + + /// Check if prefix is a prefix of namespace + fn prefix_matches(prefix: &TrackNamespace, namespace: &TrackNamespace) -> bool { + if prefix.fields.len() > namespace.fields.len() { + return false; + } + + prefix + .fields + .iter() + .zip(namespace.fields.iter()) + .all(|(a, b)| a == b) + } + + /// Get all subscriptions matching a prefix (for debugging) + pub fn matching_subscriptions(&self, namespace: &TrackNamespace) -> Vec { + let inner = self.inner.lock().unwrap(); + + inner + .subscriptions + .iter() + .filter(|(_, sub)| Self::prefix_matches(&sub.prefix, namespace)) + .map(|(id, _)| *id) + .collect() + } +} + +impl Default for SubscriberRegistry 
{ + fn default() -> Self { + Self::new() + } +} + +/// RAII guard that unregisters on drop +pub struct SubscriptionGuard { + registry: SubscriberRegistry, + id: u64, +} + +impl SubscriptionGuard { + pub fn new(registry: SubscriberRegistry, id: u64) -> Self { + Self { registry, id } + } + + pub fn id(&self) -> u64 { + self.id + } +} + +impl Drop for SubscriptionGuard { + fn drop(&mut self) { + self.registry.unregister(self.id); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ns(path: &str) -> TrackNamespace { + TrackNamespace::from_utf8_path(path) + } + + #[test] + fn test_prefix_matching() { + assert!(SubscriberRegistry::prefix_matches(&ns("live"), &ns("live/stream1"))); + assert!(SubscriberRegistry::prefix_matches(&ns("live"), &ns("live"))); + // An empty prefix (zero fields) should match everything + let empty = TrackNamespace::new(); + assert!(SubscriberRegistry::prefix_matches(&empty, &ns("live/stream1"))); + assert!(!SubscriberRegistry::prefix_matches(&ns("live/stream1"), &ns("live"))); + assert!(!SubscriberRegistry::prefix_matches(&ns("other"), &ns("live/stream1"))); + } + + #[test] + fn test_register_unregister() { + let registry = SubscriberRegistry::new(); + + let (id1, _rx1, _rx1_ns) = registry.register(ns("live")); + let (id2, _rx2, _rx2_ns) = registry.register(ns("live/room1")); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 2); + + registry.unregister(id1); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 1); + + registry.unregister(id2); + + assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 0); + } + + #[tokio::test] + async fn test_notify_publish() { + let registry = SubscriberRegistry::new(); + + let (id, mut rx, _rx_ns) = registry.register(ns("live")); + + let notified = registry.notify_publish(&ns("live/stream1"), "video", 100); + assert_eq!(notified, 1); + + let notification = rx.recv().await.unwrap(); + assert_eq!(notification.track_name, "video"); 
+ assert_eq!(notification.track_alias, 100); + + registry.unregister(id); + } +} From 13c4160780fb676fb0906df2d41863972942321b Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Wed, 8 Apr 2026 21:25:28 -0700 Subject: [PATCH 05/21] Upgrade web-transport crates to v0.10 with subprotocol negotiation --- Cargo.lock | 126 ++++++++++++++++++++++-------------- Cargo.toml | 2 +- moq-native-ietf/Cargo.toml | 2 +- moq-native-ietf/src/quic.rs | 49 ++++++++------ 4 files changed, 109 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8de6eb4c..ec77f290 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -746,7 +746,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.2.6", + "indexmap 2.13.0", "slab", "tokio", "tokio-util", @@ -761,9 +761,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.5" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "heck" @@ -972,13 +972,14 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.16.1", "serde", + "serde_core", ] [[package]] @@ -1041,9 +1042,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.81" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" dependencies = [ "once_cell", 
"wasm-bindgen", @@ -1188,6 +1189,7 @@ dependencies = [ "chrono", "clap", "env_logger", + "futures", "log", "moq-native-ietf", "moq-transport", @@ -1228,6 +1230,7 @@ dependencies = [ "bytes", "clap", "env_logger", + "futures", "log", "moq-catalog", "moq-native-ietf", @@ -1580,6 +1583,7 @@ version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ + "aws-lc-rs", "bytes", "fastbloom", "getrandom 0.3.3", @@ -2125,7 +2129,7 @@ version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.13.0", "itoa", "memchr", "ryu", @@ -2165,7 +2169,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.2.6", + "indexmap 2.13.0", "schemars 0.9.0", "schemars 1.0.4", "serde", @@ -2187,6 +2191,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sfv" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d471eaefb14f4b30032525bdb124b36e55ba9cb1292080e06f1a236cd10fe87" +dependencies = [ + "base64", + "indexmap 2.13.0", + "ref-cast", +] + [[package]] name = "sha1_smol" version = "1.0.0" @@ -2670,9 +2685,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" dependencies = [ "cfg-if", "once_cell", @@ -2681,20 +2696,6 @@ dependencies = [ "wasm-bindgen-shared", ] -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.104" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" -dependencies = [ - "bumpalo", - 
"log", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - [[package]] name = "wasm-bindgen-futures" version = "0.4.42" @@ -2709,9 +2710,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2719,31 +2720,43 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" dependencies = [ + "bumpalo", "proc-macro2", "quote", "syn", - "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.104" +version = "0.2.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" dependencies = [ "unicode-ident", ] +[[package]] +name = "web-streams" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48465a648c14f53f6d8319b95bc336a44627f6aa6bd94270463777af8ed65deb" +dependencies = [ + "thiserror 2.0.17", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" dependencies = [ "js-sys", "wasm-bindgen", @@ -2761,56 +2774,73 
@@ dependencies = [ [[package]] name = "web-transport" -version = "0.3.0" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5793aee9b4cf993212042c6b1656d877de9ad32b9eb3281d7bc95f4dce3f6591" +checksum = "23c3f78eca5afa10eb7b8ab64b4e5e521a006f0cbd88de09e44d55ef37e8855a" dependencies = [ "bytes", - "thiserror 1.0.61", + "thiserror 2.0.17", + "url", "web-transport-quinn", "web-transport-wasm", ] [[package]] name = "web-transport-proto" -version = "0.2.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0922f754c890ceb9741c00a0f5c730aaa4b52fe8772934a0ad19a03daee0ca" +checksum = "0225d295c8ac00a2e9a498aefeaf3f3c6186da12a251c938189b15b82ea22808" dependencies = [ "bytes", "http", - "thiserror 1.0.61", + "sfv", + "thiserror 2.0.17", + "tokio", "url", ] [[package]] name = "web-transport-quinn" -version = "0.3.0" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d248fb83873166e1fce7e91370deb15bd5213cf4352242e32ccd4abc8aeb2cef" +checksum = "82e77c81fe4cf56c1049e07c6ed9c00862a967010fe9da4f5e02dc7f4d71fdac" dependencies = [ "bytes", "futures", "http", - "log", "quinn", - "quinn-proto", - "thiserror 1.0.61", + "rustls 0.23.31", + "rustls-native-certs 0.8.1", + "thiserror 2.0.17", "tokio", + "tracing", "url", "web-transport-proto", + "web-transport-trait", +] + +[[package]] +name = "web-transport-trait" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb67841c4a481ca3c1412ee4c9f463987401991e1ddc000903df2124f3dc85e9" +dependencies = [ + "bytes", ] [[package]] name = "web-transport-wasm" -version = "0.1.0" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64be28348e18cb1f44e4c8733dc2bd9520d782be840b2b978724dfd1b1bdefa3" +checksum = "6816176def6e8df1636c8fc2c401f37add41ccad1518705e209d9a7ada3d144c" dependencies = [ "bytes", "js-sys", + "thiserror 
2.0.17", + "url", "wasm-bindgen", "wasm-bindgen-futures", + "web-streams", "web-sys", ] diff --git a/Cargo.toml b/Cargo.toml index c903feaa..93663988 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ resolver = "2" [workspace.dependencies] -web-transport = "0.3" +web-transport = "0.10" env_logger = "0.11" log = { version = "0.4", features = ["std"] } diff --git a/moq-native-ietf/Cargo.toml b/moq-native-ietf/Cargo.toml index adc5b302..41eeb1c9 100644 --- a/moq-native-ietf/Cargo.toml +++ b/moq-native-ietf/Cargo.toml @@ -14,7 +14,7 @@ categories = ["multimedia", "network-programming", "web-programming"] [dependencies] moq-transport = { path = "../moq-transport", version = "0.12" } web-transport = { workspace = true } -web-transport-quinn = "0.3" +web-transport-quinn = { version = "0.11", default-features = false, features = ["ring"] } rustls = { version = "0.23", features = ["ring"] } rustls-pemfile = "2" diff --git a/moq-native-ietf/src/quic.rs b/moq-native-ietf/src/quic.rs index bed04d83..e5d6ac8d 100644 --- a/moq-native-ietf/src/quic.rs +++ b/moq-native-ietf/src/quic.rs @@ -160,7 +160,7 @@ impl Endpoint { if let Some(mut config) = config.tls.server { config.alpn_protocols = vec![ - web_transport_quinn::ALPN.to_vec(), + web_transport_quinn::ALPN.as_bytes().to_vec(), moq_transport::setup::ALPN.to_vec(), ]; config.key_log = Arc::new(rustls::KeyLogFile::new()); @@ -305,22 +305,24 @@ impl Server { server_name, ); - let session = match alpn.as_bytes() { - web_transport_quinn::ALPN => { - // Wait for the CONNECT request. - let request = web_transport_quinn::accept(conn) - .await - .context("failed to receive WebTransport request")?; - - // Accept the CONNECT request. - request - .ok() - .await - .context("failed to respond to WebTransport request")? 
- } - // A bit of a hack to pretend like we're a WebTransport session - moq_transport::setup::ALPN => conn.into(), - _ => anyhow::bail!("unsupported ALPN: {}", alpn), + let alpn_bytes = alpn.as_bytes(); + let session = if alpn_bytes == web_transport_quinn::ALPN.as_bytes() { + // Wait for the WebTransport CONNECT request (includes H3 SETTINGS exchange). + let request = web_transport_quinn::Request::accept(conn) + .await + .context("failed to receive WebTransport request")?; + + // Accept the CONNECT request. + request + .ok() + .await + .context("failed to respond to WebTransport request")? + } else if alpn_bytes == moq_transport::setup::ALPN { + // Raw QUIC mode — create a session with no H3 framing. + let request = url::Url::parse("moqt://localhost").unwrap(); + web_transport_quinn::Session::raw(conn, request, web_transport_quinn::proto::ConnectResponse::default()) + } else { + anyhow::bail!("unsupported ALPN: {}", alpn) }; Ok((session.into(), connection_id_hex)) @@ -373,7 +375,7 @@ impl Client { // TODO support connecting to both ALPNs at the same time config.alpn_protocols = vec![match url.scheme() { - "https" => web_transport_quinn::ALPN.to_vec(), + "https" => web_transport_quinn::ALPN.as_bytes().to_vec(), "moqt" => moq_transport::setup::ALPN.to_vec(), _ => anyhow::bail!("url scheme must be 'https' or 'moqt'"), }]; @@ -426,8 +428,15 @@ impl Client { .to_string(); let session = match url.scheme() { - "https" => web_transport_quinn::connect_with(connection, url).await?, - "moqt" => connection.into(), + "https" => { + // Build a ConnectRequest with the MoQT version as the WebTransport subprotocol. + // Per draft-15+, version negotiation uses ALPN (raw QUIC) or + // wt-available-protocols (WebTransport) instead of CLIENT_SETUP versions. + let request = web_transport_quinn::proto::ConnectRequest::new(url.clone()) + .with_protocol(std::str::from_utf8(moq_transport::setup::ALPN).unwrap()); + web_transport_quinn::Session::connect(connection, request).await? 
+ } + "moqt" => web_transport_quinn::Session::raw(connection, url.clone(), web_transport_quinn::proto::ConnectResponse::default()), _ => unreachable!(), }; From 15d2b84d2cf92426666e3fa502cfe0299e47f35e Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Wed, 8 Apr 2026 21:25:34 -0700 Subject: [PATCH 06/21] Add moq-test-client for interoperability testing --- moq-test-client/Cargo.toml | 2 +- moq-test-client/src/main.rs | 4 +- moq-test-client/src/scenarios.rs | 130 ++++++++++++++++++------------- 3 files changed, 80 insertions(+), 56 deletions(-) diff --git a/moq-test-client/Cargo.toml b/moq-test-client/Cargo.toml index f4a4db17..37840a02 100644 --- a/moq-test-client/Cargo.toml +++ b/moq-test-client/Cargo.toml @@ -18,7 +18,7 @@ path = "src/main.rs" [dependencies] moq-transport = { path = "../moq-transport", version = "0.12" } moq-native-ietf = { path = "../moq-native-ietf", version = "0.7" } -web-transport = "0.3" +web-transport = { workspace = true } url = "2" diff --git a/moq-test-client/src/main.rs b/moq-test-client/src/main.rs index 9438935b..7b925a66 100644 --- a/moq-test-client/src/main.rs +++ b/moq-test-client/src/main.rs @@ -140,9 +140,9 @@ async fn run_test(args: &Args, test_case: TestCase) -> TestResult { let result = match test_case { TestCase::SetupOnly => scenarios::test_setup_only(args).await, - TestCase::AnnounceOnly => scenarios::test_announce_only(args).await, + TestCase::AnnounceOnly => scenarios::test_publish_namespace_only(args).await, TestCase::SubscribeError => scenarios::test_subscribe_error(args).await, - TestCase::AnnounceSubscribe => scenarios::test_announce_subscribe(args).await, + TestCase::AnnounceSubscribe => scenarios::test_publish_namespace_subscribe(args).await, TestCase::SubscribeBeforeAnnounce => scenarios::test_subscribe_before_announce(args).await, TestCase::PublishNamespaceDone => scenarios::test_publish_namespace_done(args).await, }; diff --git a/moq-test-client/src/scenarios.rs b/moq-test-client/src/scenarios.rs index 
ce6a923c..6752c836 100644 --- a/moq-test-client/src/scenarios.rs +++ b/moq-test-client/src/scenarios.rs @@ -10,7 +10,11 @@ use anyhow::{Context, Result}; use tokio::time::{timeout, Duration}; use moq_native_ietf::quic; -use moq_transport::{coding::TrackNamespace, serve::Tracks, session::Session}; +use moq_transport::{ + coding::TrackNamespace, + serve::Tracks, + session::{Publisher, Session}, +}; use crate::Args; @@ -20,7 +24,7 @@ const TEST_TIMEOUT: Duration = Duration::from_secs(10); /// Namespace used for test operations const TEST_NAMESPACE: &str = "moq-test/interop"; -/// Track name used for test operations +/// Track name used for test operations const TEST_TRACK: &str = "test-track"; /// Helper to connect to a relay and establish a session @@ -72,10 +76,10 @@ pub async fn test_setup_only(args: &Args) -> Result { .context("test timed out")? } -/// T0.2: Announce Only +/// T0.2: Publish namespace Only /// -/// Connect to relay, announce a namespace, receive PUBLISH_NAMESPACE_OK, close. -pub async fn test_announce_only(args: &Args) -> Result { +/// Connect to relay, publish a namespace, receive PUBLISH_NAMESPACE_OK, close. +pub async fn test_publish_namespace_only(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let (session, cid) = connect(args).await.context("failed to connect to relay")?; let mut cids = TestConnectionIds::default(); @@ -86,32 +90,31 @@ pub async fn test_announce_only(args: &Args) -> Result { .context("SETUP exchange failed")?; let namespace = TrackNamespace::from_utf8_path(TEST_NAMESPACE); - let (_, _, reader) = Tracks::new(namespace.clone()).produce(); - log::info!("Announcing namespace: {}", TEST_NAMESPACE); + log::info!("Publishing namespace: {}", TEST_NAMESPACE); - // Run announce with a timeout - we want to verify we get PUBLISH_NAMESPACE_OK. - // NOTE: The announce() method blocks waiting for subscriptions after getting OK. + // Run publish namespace with a timeout - we want to verify we get PUBLISH_NAMESPACE_OK. 
+ // NOTE: The publish_namespace() method sends PUBLISH_NAMESPACE and wait for OK or ERROR. // If we get PUBLISH_NAMESPACE_ERROR instead of OK, the method returns Err immediately. - // So timing out here means: either (a) got OK and waiting for subs, or (b) relay never responded. - // We accept this limitation since (b) would indicate a broken relay anyway. - // TODO: For stricter verification, use lower-level Announce::ok() method directly. - let announce_result = tokio::select! { - res = publisher.announce(reader) => res, + // So timing out here means relay never responded and connection may be broken. + let publish_ns = publisher.publish_namespace(namespace).await?; + + let publish_ns_result = tokio::select! { + res = publish_ns.ok() => res, res = session.run() => { res.context("session error")?; anyhow::bail!("session ended before announce completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { // If we got an error from the relay, announce() would have returned already. - // Timing out means we're past the OK and now waiting for subscriptions. - log::info!("Announce succeeded (no error received, waiting for subscriptions timed out)"); - return Ok(cids); + // Timing out means the relay never responded and connection may be broken. + log::info!("Publishing namespace failed (relay did not reply)"); + return Err(anyhow::anyhow!("publish namespace timed out")); } }; - // If we get here, announce completed (which means it errored or namespace was cancelled) - announce_result.context("announce failed")?; + // If we get here, publish namespace completed (which means it errored or namespace was cancelled) + publish_ns_result.context("publish namespace failed")?; Ok(cids) }) @@ -190,11 +193,11 @@ pub async fn test_subscribe_error(args: &Args) -> Result { .context("test timed out")? } -/// T0.4: Announce + Subscribe +/// T0.4: Publish Namespace + Subscribe /// -/// Two clients: publisher announces a namespace, subscriber subscribes to a track. 
+/// Two clients: publisher publishes a namespace, subscriber subscribes to a track. /// Verifies the relay correctly routes the subscription to the publisher. -pub async fn test_announce_subscribe(args: &Args) -> Result { +pub async fn test_publish_namespace_subscribe(args: &Args) -> Result { timeout(TEST_TIMEOUT, async { let mut cids = TestConnectionIds::default(); @@ -222,7 +225,12 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { // Create the track that subscriber will request let _track_writer = pub_writer.create(TEST_TRACK); - log::info!("Publisher announcing namespace: {}", TEST_NAMESPACE); + log::info!("Publisher publishing namespace: {}", TEST_NAMESPACE); + + let publish_ns = publisher + .publish_namespace(namespace.clone()) + .await + .context("publish_namespace call failed")?; // Subscriber: set up tracks and subscribe let (mut sub_writer, _, _sub_reader) = Tracks::new(namespace.clone()).produce(); @@ -230,37 +238,51 @@ pub async fn test_announce_subscribe(args: &Args) -> Result { .create(TEST_TRACK) .ok_or_else(|| anyhow::anyhow!("failed to create subscriber track"))?; - log::info!( - "Subscriber subscribing to track: {}/{}", - TEST_NAMESPACE, - TEST_TRACK - ); - - // Run everything concurrently. We expect the subscriber to get a response - // (either SUBSCRIBE_OK or error) within the timeout. + // Run everything concurrently. Session::run() consumes self, so + // publish_namespace→subscribe must be sequenced inside a single async + // block running alongside both sessions. + let mut pub_subscriber_handler = publisher.clone(); tokio::select! 
{ - // Publisher announces and waits for subscriptions - res = publisher.announce(pub_reader) => { - res.context("publisher announce failed")?; - log::info!("Publisher announce completed"); - } - // Subscriber subscribes - this is the main thing we're testing - res = subscriber.subscribe(sub_track) => { - match res { + // Publisher publishes namespace, then subscriber subscribes + res = async { + publish_ns.ok().await.context("publish namespace failed")?; + log::info!("Publisher got PUBLISH_NAMESPACE_OK"); + + log::info!("Subscribing to track: {}/{}", TEST_NAMESPACE, TEST_TRACK); + // Subscriber subscribes - this is the main thing we're testing + match subscriber.subscribe(sub_track).await { Ok(()) => log::info!("Subscriber got SUBSCRIBE_OK - relay routed subscription correctly"), Err(e) => log::info!("Subscriber got error: {} - subscription was processed", e), } + Ok::<_, anyhow::Error>(()) + } => { + res?; + } + // Serve incoming subscriptions forwarded by the relay to the publisher + res = async { + while let Some(subscribed) = pub_subscriber_handler.subscribed().await { + let info = subscribed.info.clone(); + log::info!("Publisher serving subscribe: {:?}", info); + if let Err(err) = Publisher::serve_subscribe(subscribed, pub_reader.clone()).await { + log::warn!("Failed serving subscribe: {:?}, error: {}", info, err); + } + } + Ok::<_, anyhow::Error>(()) + } => { + res?; } // Run publisher session res = pub_session.run() => { res.context("publisher session error")?; + anyhow::bail!("publisher session ended unexpectedly"); } // Run subscriber session res = sub_session.run() => { res.context("subscriber session error")?; + anyhow::bail!("subscriber session ended unexpectedly"); } // Timeout: give the relay time to route the subscription - _ = tokio::time::sleep(Duration::from_secs(3)) => { + _ = tokio::time::sleep(Duration::from_secs(5)) => { // If we hit this timeout, the subscription may still be pending. 
// This isn't necessarily a failure - some relays may hold subscriptions // until the track has data. Log for visibility. @@ -289,27 +311,30 @@ pub async fn test_publish_namespace_done(args: &Args) -> Result res, + let publish_ns = publisher.publish_namespace(namespace).await?; + + let publish_ns_result = tokio::select! { + res = publish_ns.ok() => res, res = session.run() => { res.context("session error")?; anyhow::bail!("session ended before announce completed"); } _ = tokio::time::sleep(Duration::from_secs(2)) => { - // No error received - announce is active and waiting for subscriptions - log::info!("Announce active, now sending PUBLISH_NAMESPACE_DONE"); - // Dropping out of this block will drop the announce, which sends PUBLISH_NAMESPACE_DONE - Ok(()) + // If we got an error from the relay, announce() would have returned already. + // Timing out means the relay never responded and connection may be broken. + log::info!("Publishing namespace failed (relay did not reply)"); + return Err(anyhow::anyhow!("publish namespace timed out")); } }; - result.context("announce failed")?; + publish_ns_result.context("publish namespace failed")?; + + drop(publish_ns); // Small delay to ensure PUBLISH_NAMESPACE_DONE is sent before we close tokio::time::sleep(Duration::from_millis(100)).await; @@ -374,17 +399,16 @@ pub async fn test_subscribe_before_announce(args: &Args) -> Result { + res = publish_ns.ok() => { res.context("publisher announce failed")?; } res = pub_session.run() => { From c83a5c35d6550a81812be5d6eb3d82031342c8e4 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Wed, 8 Apr 2026 21:25:39 -0700 Subject: [PATCH 07/21] Update moq-pub, moq-sub, moq-clock for draft-16 --- moq-clock-ietf/Cargo.toml | 1 + moq-clock-ietf/src/clock.rs | 2 + moq-clock-ietf/src/main.rs | 48 +++- moq-pub/Cargo.toml | 1 + moq-pub/src/main.rs | 52 +++- moq-pub/src/media.rs | 7 +- moq-sub/src/media.rs | 42 +++- moq-transport/src/message/fetch_error.rs | 41 ---- 
.../src/message/pubilsh_namespace_done.rs | 41 ---- moq-transport/src/message/publish_error.rs | 41 ---- .../src/message/publish_namespace_error.rs | 41 ---- .../src/message/publish_namespace_ok.rs | 37 --- moq-transport/src/message/subscribe_error.rs | 41 ---- .../src/message/subscribe_namespace_error.rs | 41 ---- .../src/message/subscribe_namespace_ok.rs | 37 --- .../src/message/track_status_error.rs | 41 ---- .../src/message/unsubscribe_namespace.rs | 42 ---- moq-transport/src/session/announce.rs | 227 ------------------ moq-transport/src/session/announced.rs | 119 --------- 19 files changed, 139 insertions(+), 763 deletions(-) delete mode 100644 moq-transport/src/message/fetch_error.rs delete mode 100644 moq-transport/src/message/pubilsh_namespace_done.rs delete mode 100644 moq-transport/src/message/publish_error.rs delete mode 100644 moq-transport/src/message/publish_namespace_error.rs delete mode 100644 moq-transport/src/message/publish_namespace_ok.rs delete mode 100644 moq-transport/src/message/subscribe_error.rs delete mode 100644 moq-transport/src/message/subscribe_namespace_error.rs delete mode 100644 moq-transport/src/message/subscribe_namespace_ok.rs delete mode 100644 moq-transport/src/message/track_status_error.rs delete mode 100644 moq-transport/src/message/unsubscribe_namespace.rs delete mode 100644 moq-transport/src/session/announce.rs delete mode 100644 moq-transport/src/session/announced.rs diff --git a/moq-clock-ietf/Cargo.toml b/moq-clock-ietf/Cargo.toml index b717051b..854687a7 100644 --- a/moq-clock-ietf/Cargo.toml +++ b/moq-clock-ietf/Cargo.toml @@ -22,6 +22,7 @@ url = "2" # Async stuff tokio = { version = "1", features = ["full"] } +futures = "0.3" # CLI, logging, error handling clap = { version = "4", features = ["derive"] } diff --git a/moq-clock-ietf/src/clock.rs b/moq-clock-ietf/src/clock.rs index a96863ee..ada2bbd6 100644 --- a/moq-clock-ietf/src/clock.rs +++ b/moq-clock-ietf/src/clock.rs @@ -45,6 +45,7 @@ impl Publisher { group_id: 
next_group_id as u64, subgroup_id: 0, priority: 0, + header_type: None, }) .context("failed to create minute segment")?; @@ -66,6 +67,7 @@ impl Publisher { priority: 127, payload: time_str.clone().into_bytes().into(), extension_headers: Default::default(), + status: None, }) .context("failed to write datagram")?; diff --git a/moq-clock-ietf/src/main.rs b/moq-clock-ietf/src/main.rs index 7d0f6951..eab274b7 100644 --- a/moq-clock-ietf/src/main.rs +++ b/moq-clock-ietf/src/main.rs @@ -1,6 +1,7 @@ use moq_native_ietf::quic; use anyhow::Context; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; mod cli; mod clock; @@ -10,11 +11,36 @@ use cli::Cli; use moq_transport::{ coding::TrackNamespace, - serve, - session::{Publisher, Subscriber}, + serve::{self, TracksReader}, + session::{Publisher, SessionError, Subscriber}, }; -/// The main entry point for the MoQ Clock IETF example. +async fn serve_subscriptions( + mut publisher: Publisher, + tracks: TracksReader, +) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); + + loop { + tokio::select! { + Some(subscribed) = publisher.subscribed() => { + let info = subscribed.info.clone(); + let tracks = tracks.clone(); + log::info!("serving subscribe: {:?}", info); + + tasks.push(async move { + if let Err(err) = Publisher::serve_subscribe(subscribed, tracks).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); + } + }.boxed()); + } + _ = tasks.next(), if !tasks.is_empty() => {} + else => return Ok(()), + } + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init(); @@ -59,10 +85,16 @@ async fn main() -> anyhow::Result<()> { let track_writer = tracks_writer.create(&config.track).unwrap(); let clock_publisher = clock::Publisher::new_datagram(track_writer.datagrams()?); + let publish_ns = publisher + .publish_namespace(tracks_reader.namespace.clone()) + .await + .context("failed to register namespace")?; + tokio::select! 
{ res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = serve_subscriptions(publisher, tracks_reader) => res.context("failed to serve tracks")?, + res = publish_ns.closed() => res.context("namespace closed")?, } } else { log::info!("publishing clock via streams"); @@ -75,10 +107,16 @@ async fn main() -> anyhow::Result<()> { let track_writer = tracks_writer.create(&config.track).unwrap(); let clock_publisher = clock::Publisher::new(track_writer.subgroups()?); + let publish_ns = publisher + .publish_namespace(tracks_reader.namespace.clone()) + .await + .context("failed to register namespace")?; + tokio::select! { res = session.run() => res.context("session error")?, res = clock_publisher.run() => res.context("clock error")?, - res = publisher.announce(tracks_reader) => res.context("failed to serve tracks")?, + res = serve_subscriptions(publisher, tracks_reader) => res.context("failed to serve tracks")?, + res = publish_ns.closed() => res.context("namespace closed")?, } } } else { diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index 13085ca3..d37ebbaf 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -23,6 +23,7 @@ bytes = "1" # Async stuff tokio = { version = "1", features = ["full"] } +futures = "0.3" # CLI, logging, error handling clap = { version = "4", features = ["derive"] } diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index cd9350fc..c259b440 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -4,11 +4,16 @@ use url::Url; use anyhow::Context; use clap::Parser; +use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use tokio::io::AsyncReadExt; use moq_native_ietf::quic; use moq_pub::Media; -use moq_transport::{coding::TrackNamespace, serve, session::Publisher}; +use moq_transport::{ + coding::TrackNamespace, + serve::{self, TracksReader}, + session::{Publisher, 
SessionError}, +}; #[derive(Parser, Clone)] pub struct Cli { @@ -39,6 +44,32 @@ pub struct Cli { pub tls: moq_native_ietf::tls::Args, } +async fn serve_subscriptions( + mut publisher: Publisher, + tracks: TracksReader, +) -> Result<(), SessionError> { + let mut tasks: FuturesUnordered> = + FuturesUnordered::new(); + + loop { + tokio::select! { + Some(subscribed) = publisher.subscribed() => { + let info = subscribed.info.clone(); + let tracks = tracks.clone(); + log::info!("serving subscribe: {:?}", info); + + tasks.push(async move { + if let Err(err) = Publisher::serve_subscribe(subscribed, tracks).await { + log::warn!("failed serving subscribe: {:?}, error: {}", info, err); + } + }.boxed()); + } + _ = tasks.next(), if !tasks.is_empty() => {} + else => return Ok(()), + } + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { env_logger::init(); @@ -71,16 +102,25 @@ async fn main() -> anyhow::Result<()> { connection_id ); - let (session, mut publisher) = Publisher::connect(session) + let (session, publisher) = Publisher::connect(session) .await .context("failed to create MoQ Transport publisher")?; + let namespace = reader.namespace.clone(); + + let publish_ns = publisher + .clone() + .publish_namespace(namespace) + .await + .context("failed to register namespace")?; + + log::info!("namespace registered, starting media and subscription handling"); + tokio::select! { res = session.run() => res.context("session error")?, - res = run_media(media) => { - res.context("media error")? 
- }, - res = publisher.announce(reader) => res.context("publisher error")?, + res = run_media(media) => res.context("media error")?, + res = serve_subscriptions(publisher, reader) => res.context("publisher error")?, + res = publish_ns.closed() => res.context("publisher error")?, } Ok(()) diff --git a/moq-pub/src/media.rs b/moq-pub/src/media.rs index b46f5473..f1952781 100644 --- a/moq-pub/src/media.rs +++ b/moq-pub/src/media.rs @@ -384,7 +384,12 @@ impl Track { } pub fn end_group(&mut self) { - self.current = None; + // Send EndOfGroup marker before dropping the writer + if let Some(mut current) = self.current.take() { + if let Err(e) = current.end_of_group() { + log::warn!("failed to send EndOfGroup marker: {}", e); + } + } } } diff --git a/moq-sub/src/media.rs b/moq-sub/src/media.rs index e3fd43c5..5503d9eb 100644 --- a/moq-sub/src/media.rs +++ b/moq-sub/src/media.rs @@ -183,16 +183,54 @@ impl Media { async fn recv_group(mut group: SubgroupReader, out: Arc>) -> anyhow::Result<()> { trace!("group={} start", group.group_id); + + // Pair moof+mdat into a single atomic write to prevent concurrent + // audio/video tasks from interleaving between them on stdout. + let mut pending_moof: Option> = None; + while let Some(object) = group.next().await? 
{ trace!( "group={} fragment={} start", group.group_id, object.object_id ); - let out = out.clone(); let buf = Self::recv_object(object).await?; - out.lock().await.write_all(&buf).await?; + let is_moof = buf.len() >= 8 && &buf[4..8] == b"moof"; + let is_mdat = buf.len() >= 8 && &buf[4..8] == b"mdat"; + + if is_moof { + if let Some(orphan) = pending_moof.take() { + warn!( + "group={}: flushing orphaned moof ({} bytes) without mdat", + group.group_id, + orphan.len() + ); + out.lock().await.write_all(&orphan).await?; + } + pending_moof = Some(buf); + } else if is_mdat { + if let Some(mut moof) = pending_moof.take() { + moof.extend_from_slice(&buf); + out.lock().await.write_all(&moof).await?; + } else { + warn!( + "group={}: mdat without preceding moof ({} bytes)", + group.group_id, + buf.len() + ); + out.lock().await.write_all(&buf).await?; + } + } else { + if let Some(orphan) = pending_moof.take() { + out.lock().await.write_all(&orphan).await?; + } + out.lock().await.write_all(&buf).await?; + } + } + + if let Some(orphan) = pending_moof.take() { + out.lock().await.write_all(&orphan).await?; } Ok(()) diff --git a/moq-transport/src/message/fetch_error.rs b/moq-transport/src/message/fetch_error.rs deleted file mode 100644 index b1acc55b..00000000 --- a/moq-transport/src/message/fetch_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct FetchError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. 
- pub reason_phrase: ReasonPhrase, -} - -impl Decode for FetchError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for FetchError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/pubilsh_namespace_done.rs b/moq-transport/src/message/pubilsh_namespace_done.rs deleted file mode 100644 index 4540ab47..00000000 --- a/moq-transport/src/message/pubilsh_namespace_done.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; - -/// Sent by the publisher to terminate a PUBLISH_NAMESPACE. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct PublishNamespaceDone { - pub track_namespace: TrackNamespace, -} - -impl Decode for PublishNamespaceDone { - fn decode(r: &mut R) -> Result { - let track_namespace = TrackNamespace::decode(r)?; - - Ok(Self { track_namespace }) - } -} - -impl Encode for PublishNamespaceDone { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.track_namespace.encode(w)?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = PublishNamespaceDone { - track_namespace: TrackNamespace::from_utf8_path("test/path/to/resource"), - }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishNamespaceDone::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/publish_error.rs b/moq-transport/src/message/publish_error.rs deleted file mode 100644 index f8cc02b9..00000000 --- a/moq-transport/src/message/publish_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, 
ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct PublishError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/publish_namespace_error.rs b/moq-transport/src/message/publish_namespace_error.rs deleted file mode 100644 index 8a606621..00000000 --- a/moq-transport/src/message/publish_namespace_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an PUBLISH_NAMESPACE. -#[derive(Clone, Debug)] -pub struct PublishNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. 
- pub reason_phrase: ReasonPhrase, -} - -impl Decode for PublishNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for PublishNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/publish_namespace_ok.rs b/moq-transport/src/message/publish_namespace_ok.rs deleted file mode 100644 index 9025f03f..00000000 --- a/moq-transport/src/message/publish_namespace_ok.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Sent by the subscriber to accept a PUBLISH_NAMESPACE. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct PublishNamespaceOk { - /// The request ID of the PUBLISH_NAMESPACE this message is replying to. 
- pub id: u64, -} - -impl Decode for PublishNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for PublishNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = PublishNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = PublishNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/subscribe_error.rs b/moq-transport/src/message/subscribe_error.rs deleted file mode 100644 index 7481a4bf..00000000 --- a/moq-transport/src/message/subscribe_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. 
- pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace_error.rs b/moq-transport/src/message/subscribe_namespace_error.rs deleted file mode 100644 index a5d99d0d..00000000 --- a/moq-transport/src/message/subscribe_namespace_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct SubscribeNamespaceError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. 
- pub reason_phrase: ReasonPhrase, -} - -impl Decode for SubscribeNamespaceError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for SubscribeNamespaceError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/subscribe_namespace_ok.rs b/moq-transport/src/message/subscribe_namespace_ok.rs deleted file mode 100644 index 2e2a968d..00000000 --- a/moq-transport/src/message/subscribe_namespace_ok.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError}; - -/// Subscribe Namespace Ok -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct SubscribeNamespaceOk { - /// The SubscribeNamespace request ID this message is replying to. - pub id: u64, -} - -impl Decode for SubscribeNamespaceOk { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - Ok(Self { id }) - } -} - -impl Encode for SubscribeNamespaceOk { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = SubscribeNamespaceOk { id: 12345 }; - msg.encode(&mut buf).unwrap(); - let decoded = SubscribeNamespaceOk::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/message/track_status_error.rs b/moq-transport/src/message/track_status_error.rs deleted file mode 100644 index 7b015ea3..00000000 --- a/moq-transport/src/message/track_status_error.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, ReasonPhrase}; - -// TODO SLG - The next draft is going to merge all these error messages to a -// 
common RequestError message, so we won't do a lot of work on these -// existing messages. We should add an enum for all the various error codes. - -/// Sent by the subscriber to reject an Announce. -#[derive(Clone, Debug)] -pub struct TrackStatusError { - pub id: u64, - - // An error code. - pub error_code: u64, - - // An optional, human-readable reason. - pub reason_phrase: ReasonPhrase, -} - -impl Decode for TrackStatusError { - fn decode(r: &mut R) -> Result { - let id = u64::decode(r)?; - let error_code = u64::decode(r)?; - let reason_phrase = ReasonPhrase::decode(r)?; - - Ok(Self { - id, - error_code, - reason_phrase, - }) - } -} - -impl Encode for TrackStatusError { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w)?; - self.error_code.encode(w)?; - self.reason_phrase.encode(w)?; - - Ok(()) - } -} diff --git a/moq-transport/src/message/unsubscribe_namespace.rs b/moq-transport/src/message/unsubscribe_namespace.rs deleted file mode 100644 index de257378..00000000 --- a/moq-transport/src/message/unsubscribe_namespace.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::coding::{Decode, DecodeError, Encode, EncodeError, TrackNamespace}; - -/// Unsubscribe Namespace -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct UnsubscribeNamespace { - // Echo back the track namespace prefix from subscribe namespace - pub track_namespace_prefix: TrackNamespace, -} - -impl Decode for UnsubscribeNamespace { - fn decode(r: &mut R) -> Result { - let track_namespace_prefix = TrackNamespace::decode(r)?; - Ok(Self { - track_namespace_prefix, - }) - } -} - -impl Encode for UnsubscribeNamespace { - fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.track_namespace_prefix.encode(w)?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bytes::BytesMut; - - #[test] - fn encode_decode() { - let mut buf = BytesMut::new(); - - let msg = UnsubscribeNamespace { - track_namespace_prefix: TrackNamespace::from_utf8_path("test/path/to/resource"), 
- }; - msg.encode(&mut buf).unwrap(); - let decoded = UnsubscribeNamespace::decode(&mut buf).unwrap(); - assert_eq!(decoded, msg); - } -} diff --git a/moq-transport/src/session/announce.rs b/moq-transport/src/session/announce.rs deleted file mode 100644 index 278c614a..00000000 --- a/moq-transport/src/session/announce.rs +++ /dev/null @@ -1,227 +0,0 @@ -use std::{collections::VecDeque, ops}; - -use crate::coding::TrackNamespace; -use crate::watch::State; -use crate::{message, serve::ServeError}; - -use super::{Publisher, Subscribed, TrackStatusRequested}; - -#[derive(Debug, Clone)] -pub struct AnnounceInfo { - pub request_id: u64, - pub namespace: TrackNamespace, -} - -struct AnnounceState { - subscribers: VecDeque, - track_statuses_requested: VecDeque, - ok: bool, - closed: Result<(), ServeError>, -} - -impl Default for AnnounceState { - fn default() -> Self { - Self { - subscribers: Default::default(), - track_statuses_requested: Default::default(), - ok: false, - closed: Ok(()), - } - } -} - -impl Drop for AnnounceState { - fn drop(&mut self) { - for subscriber in self.subscribers.drain(..) 
{ - subscriber - .close(ServeError::not_found_ctx( - "announce dropped before subscription handled", - )) - .ok(); - } - } -} - -#[must_use = "unannounce on drop"] -pub struct Announce { - publisher: Publisher, - state: State, - - pub info: AnnounceInfo, -} - -impl Announce { - pub(super) fn new( - mut publisher: Publisher, - request_id: u64, - namespace: TrackNamespace, - ) -> (Announce, AnnounceRecv) { - let info = AnnounceInfo { - request_id, - namespace: namespace.clone(), - }; - - publisher.send_message(message::PublishNamespace { - id: request_id, - track_namespace: namespace.clone(), - params: Default::default(), - }); - - let (send, recv) = State::default().split(); - - let send = Self { - publisher, - info, - state: send, - }; - let recv = AnnounceRecv { - state: recv, - request_id, - }; - - (send, recv) - } - - // Run until we get an error - pub async fn closed(&self) -> Result<(), ServeError> { - loop { - { - let state = self.state.lock(); - state.closed.clone()?; - - match state.modified() { - Some(notified) => notified, - None => return Ok(()), - } - } - .await; - } - } - - /// Wait until a subscriber is received - pub async fn subscribed(&self) -> Result, ServeError> { - loop { - { - let state = self.state.lock(); - if !state.subscribers.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.subscribers.pop_front())); - } - - state.closed.clone()?; - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - } - .await; - } - } - - pub async fn track_status_requested(&self) -> Result, ServeError> { - loop { - { - let state = self.state.lock(); - if !state.track_statuses_requested.is_empty() { - return Ok(state - .into_mut() - .and_then(|mut state| state.track_statuses_requested.pop_front())); - } - - state.closed.clone()?; - match state.modified() { - Some(notified) => notified, - None => return Ok(None), - } - } - .await; - } - } - - // Wait until an OK is received - pub async fn ok(&self) -> 
Result<(), ServeError> { - loop { - { - let state = self.state.lock(); - if state.ok { - return Ok(()); - } - state.closed.clone()?; - - match state.modified() { - Some(notified) => notified, - None => return Ok(()), - } - } - .await; - } - } -} - -impl Drop for Announce { - fn drop(&mut self) { - if self.state.lock().closed.is_err() { - return; - } - - self.publisher.send_message(message::PublishNamespaceDone { - track_namespace: self.namespace.clone(), - }); - } -} - -impl ops::Deref for Announce { - type Target = AnnounceInfo; - - fn deref(&self) -> &Self::Target { - &self.info - } -} - -pub(super) struct AnnounceRecv { - state: State, - pub request_id: u64, // TODO SLG - Announcements need to be looked up by both request_id and namespace, consider 2 hashmaps in publisher instead of this -} - -impl AnnounceRecv { - pub fn recv_ok(&mut self) -> Result<(), ServeError> { - if let Some(mut state) = self.state.lock_mut() { - if state.ok { - return Err(ServeError::Duplicate); - } - - state.ok = true; - } - - Ok(()) - } - - pub fn recv_error(self, err: ServeError) -> Result<(), ServeError> { - let state = self.state.lock(); - state.closed.clone()?; - - let mut state = state.into_mut().ok_or(ServeError::Done)?; - state.closed = Err(err); - - Ok(()) - } - - pub fn recv_subscribe(&mut self, subscriber: Subscribed) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; - state.subscribers.push_back(subscriber); - - Ok(()) - } - - pub fn recv_track_status_requested( - &mut self, - track_status_requested: TrackStatusRequested, - ) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Done)?; - state - .track_statuses_requested - .push_back(track_status_requested); - Ok(()) - } -} diff --git a/moq-transport/src/session/announced.rs b/moq-transport/src/session/announced.rs deleted file mode 100644 index 5b2e466a..00000000 --- a/moq-transport/src/session/announced.rs +++ /dev/null @@ -1,119 +0,0 @@ -use 
std::ops; - -use crate::coding::{ReasonPhrase, TrackNamespace}; -use crate::watch::State; -use crate::{message, serve::ServeError}; - -use super::{AnnounceInfo, Subscriber}; - -// There's currently no feedback from the peer, so the shared state is empty. -// If Unannounce contained an error code then we'd be talking. -#[derive(Default)] -struct AnnouncedState {} - -pub struct Announced { - session: Subscriber, - state: State, - - pub info: AnnounceInfo, - - ok: bool, - error: Option, -} - -impl Announced { - pub(super) fn new( - session: Subscriber, - request_id: u64, - namespace: TrackNamespace, - ) -> (Announced, AnnouncedRecv) { - let info = AnnounceInfo { - request_id, - namespace, - }; - - let (send, recv) = State::default().split(); - let send = Self { - session, - info, - ok: false, - error: None, - state: send, - }; - let recv = AnnouncedRecv { _state: recv }; - - (send, recv) - } - - // Send an ANNOUNCE_OK - pub fn ok(&mut self) -> Result<(), ServeError> { - if self.ok { - return Err(ServeError::Duplicate); - } - - self.session.send_message(message::PublishNamespaceOk { - id: self.info.request_id, - }); - - self.ok = true; - - Ok(()) - } - - pub async fn closed(&self) -> Result<(), ServeError> { - loop { - // Wow this is dumb and yet pretty cool. - // Basically loop until the state changes and exit when Recv is dropped. - self.state - .lock() - .modified() - .ok_or(ServeError::Cancel)? 
- .await; - } - } - - pub fn close(mut self, err: ServeError) -> Result<(), ServeError> { - self.error = Some(err); - Ok(()) - } -} - -impl ops::Deref for Announced { - type Target = AnnounceInfo; - - fn deref(&self) -> &AnnounceInfo { - &self.info - } -} - -impl Drop for Announced { - fn drop(&mut self) { - let err = self.error.clone().unwrap_or(ServeError::Done); - - // TODO SLG - ServeError's do not align with draft-13 Announce error codes (section 8.25) - if self.ok { - self.session.send_message(message::PublishNamespaceCancel { - track_namespace: self.namespace.clone(), - error_code: err.code(), - reason_phrase: ReasonPhrase(err.to_string()), - }); - } else { - self.session.send_message(message::PublishNamespaceError { - id: self.info.request_id, - error_code: err.code(), - reason_phrase: ReasonPhrase(err.to_string()), - }); - } - } -} - -pub(super) struct AnnouncedRecv { - _state: State, -} - -impl AnnouncedRecv { - pub fn recv_unannounce(self) -> Result<(), ServeError> { - // Will cause the state to be dropped - Ok(()) - } -} From c8cb92365a279e302286737ab4561aa277e4172a Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 20:13:10 -0700 Subject: [PATCH 08/21] Add REQUEST_UPDATE with forward=1 when subscriber arrives for paused track When a publisher sends PUBLISH with forward=0 (paused), the relay now: 1. Parses and stores the forward state from PUBLISH params 2. Waits for subscribers to arrive 3. Sends REQUEST_UPDATE with forward=1 to tell publisher to start sending This fixes the issue where subscribers would never receive data because the publisher was waiting for forward=1 but the relay never sent it. 
--- moq-relay-ietf/src/consumer.rs | 69 +++++++++++++++++-- moq-relay-ietf/src/local.rs | 68 ++++++++++++++++++ moq-relay-ietf/src/producer.rs | 16 ++++- moq-transport/src/session/publish_received.rs | 11 +++ moq-transport/src/session/subscriber.rs | 4 +- 5 files changed, 159 insertions(+), 9 deletions(-) diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index acb16ba5..1078ba79 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -181,12 +181,19 @@ impl Consumer { } } - async fn serve_publish(self, publish: PublishReceived) -> Result<(), anyhow::Error> { + async fn serve_publish(mut self, publish: PublishReceived) -> Result<(), anyhow::Error> { let namespace = publish.info.track_namespace.clone(); let track_name = publish.info.track_name.clone(); let track_alias = publish.info.track_alias; + let initial_forward = publish.info.forward; + let publish_request_id = publish.info.id; - log::info!("received PUBLISH for track: {}/{}", namespace, track_name); + log::info!( + "received PUBLISH for track: {}/{} (forward={})", + namespace, + track_name, + initial_forward + ); // Use auto-register variant to support SUBSCRIBE_NAMESPACE flow // where PUBLISH can arrive without prior PUBLISH_NAMESPACE @@ -226,17 +233,25 @@ impl Consumer { .insert_track(&namespace, reader) .context("failed to insert track into namespace")?; + // Store publish info for forward state management + track_info.set_publish_info(publish_request_id, initial_forward); + + // Include forward=1 in PUBLISH_OK to tell publisher to start sending immediately + let mut params = KeyValuePairs::default(); + params.set_intvalue(0x10, 1); // Forward = 1 + let msg = PublishOk { id: publish.info.id, - params: KeyValuePairs::default(), + params, }; publish.accept(writer, msg)?; log::info!( - "PUBLISH accepted, track {}/{} now in Publishing state", + "PUBLISH accepted, track {}/{} now in Publishing state (forward={})", namespace, - track_name + track_name, + 
initial_forward ); // Notify subscriber registry of the new PUBLISH @@ -253,6 +268,50 @@ impl Consumer { } } + // If forward=0 (paused), wait for subscribers to request forwarding + // When forward state changes to 1, send REQUEST_UPDATE to publisher + if !initial_forward { + let forward_rx = track_info.forward_receiver(); + if let Some(mut rx) = forward_rx { + log::info!( + "track {}/{} is paused (forward=0), waiting for subscriber to request forwarding", + namespace, + track_name + ); + + // Wait for forward state to change to true + loop { + rx.changed().await.ok(); + if *rx.borrow() { + // Forward state changed to true, send REQUEST_UPDATE + log::info!( + "subscriber arrived for paused track {}/{}, sending REQUEST_UPDATE with forward=1", + namespace, + track_name + ); + + let mut params = KeyValuePairs::default(); + params.set_intvalue(0x10, 1); // Forward = 1 + + let request_update = moq_transport::message::SubscribeUpdate { + id: self.subscriber.get_next_request_id(), + existing_request_id: publish_request_id, + params, + }; + + self.subscriber.send_message(request_update); + log::info!( + "sent REQUEST_UPDATE for track {}/{} (existing_request_id={})", + namespace, + track_name, + publish_request_id + ); + break; + } + } + } + } + Ok(()) } } diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 624e3d1b..24cfa286 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -7,6 +7,7 @@ use moq_transport::{ coding::TrackNamespace, serve::{ServeError, Track, TrackReader, TrackWriter, TracksReader, TracksWriter}, }; +use tokio::sync::watch; #[repr(u8)] #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -39,6 +40,14 @@ pub struct TrackInfo { track_writer: Mutex>, upstream_subscribe_sent: AtomicBool, upstream_request_id: Mutex>, + + /// The PUBLISH request ID (set when publisher sends PUBLISH) + publish_request_id: Mutex>, + /// Forward state: true = forwarding, false = paused + /// Publisher watches this to know when to 
start/stop sending + forward_state: Mutex>>, + /// Receiver side for forward state changes + forward_receiver: Mutex>>, } impl TrackInfo { @@ -51,6 +60,9 @@ impl TrackInfo { track_writer: Mutex::new(None), upstream_subscribe_sent: AtomicBool::new(false), upstream_request_id: Mutex::new(None), + publish_request_id: Mutex::new(None), + forward_state: Mutex::new(None), + forward_receiver: Mutex::new(None), } } @@ -120,6 +132,62 @@ impl TrackInfo { self.state() == TrackState::Publishing } + /// Set up forward state tracking when PUBLISH is received. + /// Returns the initial forward value that was set. + pub fn set_publish_info(&self, request_id: u64, initial_forward: bool) { + *self.publish_request_id.lock().unwrap() = Some(request_id); + + let (tx, rx) = watch::channel(initial_forward); + *self.forward_state.lock().unwrap() = Some(tx); + *self.forward_receiver.lock().unwrap() = Some(rx); + + log::debug!( + "set_publish_info: track {}/{} request_id={} initial_forward={}", + self.namespace, + self.name, + request_id, + initial_forward + ); + } + + /// Get the PUBLISH request ID + pub fn publish_request_id(&self) -> Option { + *self.publish_request_id.lock().unwrap() + } + + /// Get current forward state + pub fn is_forwarding(&self) -> bool { + self.forward_receiver + .lock() + .unwrap() + .as_ref() + .map(|rx| *rx.borrow()) + .unwrap_or(true) // Default to true if not set (legacy behavior) + } + + /// Request forwarding to start (called when a subscriber arrives). + /// Returns true if the state changed from false to true. 
+ pub fn request_forward(&self) -> bool { + if let Some(tx) = self.forward_state.lock().unwrap().as_ref() { + let current = *tx.borrow(); + if !current { + let _ = tx.send(true); + log::info!( + "request_forward: track {}/{} forward state changed 0 -> 1", + self.namespace, + self.name + ); + return true; + } + } + false + } + + /// Get a receiver for forward state changes (for the publisher to watch) + pub fn forward_receiver(&self) -> Option> { + self.forward_receiver.lock().unwrap().clone() + } + pub fn take_writer_for_upstream(&self) -> Result { self.ensure_track_created(); diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index dd029c7e..115061a9 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -138,12 +138,24 @@ impl Producer { } } + // If the track is in Publishing state and forward=0, request forwarding + // This will trigger the consumer to send REQUEST_UPDATE to the publisher + if track_info.is_publishing() && !track_info.is_forwarding() { + log::info!( + "subscriber arrived for paused track {}/{}, requesting forward", + namespace, + track_name + ); + track_info.request_forward(); + } + let reader = track_info.get_reader(); log::info!( - "serving subscribe from local: {}/{} (state: {:?})", + "serving subscribe from local: {}/{} (state: {:?}, forwarding: {})", namespace, track_name, - track_info.state() + track_info.state(), + track_info.is_forwarding() ); return Ok(subscribed.serve(reader).await?); } diff --git a/moq-transport/src/session/publish_received.rs b/moq-transport/src/session/publish_received.rs index 0a99c0fb..8612d91c 100644 --- a/moq-transport/src/session/publish_received.rs +++ b/moq-transport/src/session/publish_received.rs @@ -13,15 +13,26 @@ pub struct PublishReceivedInfo { pub track_namespace: TrackNamespace, pub track_name: String, pub track_alias: u64, + /// Forward parameter from PUBLISH (0x10): true = forward immediately, false = paused + pub forward: bool, } impl 
PublishReceivedInfo { pub fn new_from_publish(msg: &message::Publish) -> Self { + // Forward parameter (0x10): default to true if not present + // Value of 0 means paused, 1 (or non-zero) means forward + let forward = msg + .params + .get_intvalue(0x10) // ParameterType::Forward + .map(|v| v != 0) + .unwrap_or(true); + Self { id: msg.id, track_namespace: msg.track_namespace.clone(), track_name: msg.track_name.clone(), track_alias: msg.track_alias, + forward, } } } diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 3ec9db57..11e185e8 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -98,7 +98,7 @@ impl Subscriber { } /// Get the current next request id to use and increment the value for by 2 for the next request - fn get_next_request_id(&self) -> u64 { + pub fn get_next_request_id(&self) -> u64 { self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed) } @@ -146,7 +146,7 @@ impl Subscriber { } /// Send a message to the publisher via the control stream. - pub(super) fn send_message>(&mut self, msg: M) { + pub fn send_message>(&mut self, msg: M) { let msg = msg.into(); // Remove our entry on terminal state. 
From 12ac6bfc95ba02912431720e33264a06be52424b Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 21:27:09 -0700 Subject: [PATCH 09/21] Fix PublishNamespace handle lifetime and stale track cleanup - Keep forwarded PublishNamespace handle alive in function scope to prevent premature PublishNamespaceDone being sent when task completes - Detect and remove stale TrackInfo entries (Publishing state with no writer) to allow publishers to reconnect without 'already publishing' errors --- moq-relay-ietf/src/consumer.rs | 45 +++++++++++++++++++--------------- moq-relay-ietf/src/local.rs | 20 +++++++++++++++ 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index 1078ba79..5d0ec8b6 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -98,7 +98,8 @@ impl Consumer { mut self, mut publish_ns: PublishNamespaceReceived, ) -> Result<(), anyhow::Error> { - let mut tasks = FuturesUnordered::new(); + let mut tasks: FuturesUnordered>> = + FuturesUnordered::new(); let (writer, mut request, reader) = Tracks::new(publish_ns.namespace.clone()).produce(); @@ -131,27 +132,31 @@ impl Consumer { } } - if let Some(mut forward) = self.forward.clone() { + // Forward publish_namespace upstream - keep handle alive in this scope + let _forwarded_publish_ns = if let Some(mut forward) = self.forward.clone() { let reader_clone = reader.clone(); - tasks.push( - async move { - log::info!("forwarding publish_namespace: {:?}", reader_clone.info); - let publish_ns = forward - .publish_namespace(reader_clone) - .await - .context("failed forwarding publish_namespace")?; - publish_ns - .ok() - .await - .context("publish_namespace not accepted")?; - publish_ns - .closed() - .await - .context("publish_namespace closed with error") + log::info!("forwarding publish_namespace: {:?}", reader_clone.info); + match forward.publish_namespace(reader_clone).await { + Ok(publish_ns) => { + if let 
Err(e) = publish_ns.ok().await { + log::warn!("publish_namespace not accepted: {}", e); + None + } else { + log::info!( + "publish_namespace forwarded and accepted: {:?}", + publish_ns.info.namespace + ); + Some(publish_ns) + } } - .boxed(), - ); - } + Err(e) => { + log::warn!("failed forwarding publish_namespace: {}", e); + None + } + } + } else { + None + }; // Serve subscribe requests loop { diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 24cfa286..0b816810 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -124,6 +124,7 @@ impl TrackInfo { .ok_or(ServeError::Duplicate) } + pub fn state(&self) -> TrackState { TrackState::from_u8(self.state.load(Ordering::SeqCst)) } @@ -334,6 +335,25 @@ impl Locals { // First try to find an existing matching namespace entry if let Some(entry) = Self::find_best_match_entry(&lookup, namespace) { let mut tracks = entry.tracks.lock().unwrap(); + + // Check if there's an existing track in stale Publishing state + if let Some(existing) = tracks.get(&track_key) { + if existing.state() == TrackState::Publishing { + // Check if the writer was already taken (stale state) + // by trying to see if we can get the writer + let has_writer = existing.track_writer.lock().unwrap().is_some(); + if !has_writer { + // Stale state - remove and create fresh TrackInfo + log::info!( + "removing stale TrackInfo for {}/{} (was Publishing, no writer)", + namespace, + track_name + ); + tracks.remove(&track_key); + } + } + } + return tracks .entry(track_key.clone()) .or_insert_with(|| { From 54a3557b904a462ff40a600302e1c8f46bedfdb9 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 21:31:21 -0700 Subject: [PATCH 10/21] Remove stale namespace entry on publisher reconnect When a publisher disconnects and reconnects, the entire namespace entry (not just the TrackInfo) needs to be recreated because the TracksWriter is also closed/stale. 
--- moq-relay-ietf/src/local.rs | 43 +++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 0b816810..0e4f0d13 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -332,28 +332,35 @@ impl Locals { // when different namespaces have the same track_name let track_key = format!("{}:{}", namespace, track_name); + // Check if there's an existing exact-match namespace entry that's stale + // and needs to be removed (this happens when publisher disconnects and reconnects) + let should_remove_namespace = if let Some(entry) = lookup.get(namespace) { + let tracks = entry.tracks.lock().unwrap(); + if let Some(existing) = tracks.get(&track_key) { + // Track exists and is in Publishing state but has no writer = stale + existing.state() == TrackState::Publishing + && existing.track_writer.lock().unwrap().is_none() + } else { + false + } + } else { + false + }; + + if should_remove_namespace { + log::info!( + "removing stale namespace entry {} (track {}/{} was Publishing with no writer)", + namespace, + namespace, + track_name + ); + lookup.remove(namespace); + } + // First try to find an existing matching namespace entry if let Some(entry) = Self::find_best_match_entry(&lookup, namespace) { let mut tracks = entry.tracks.lock().unwrap(); - // Check if there's an existing track in stale Publishing state - if let Some(existing) = tracks.get(&track_key) { - if existing.state() == TrackState::Publishing { - // Check if the writer was already taken (stale state) - // by trying to see if we can get the writer - let has_writer = existing.track_writer.lock().unwrap().is_some(); - if !has_writer { - // Stale state - remove and create fresh TrackInfo - log::info!( - "removing stale TrackInfo for {}/{} (was Publishing, no writer)", - namespace, - track_name - ); - tracks.remove(&track_key); - } - } - } - return tracks .entry(track_key.clone()) 
.or_insert_with(|| { From cd0bdcd4fb9438ddb5d254fd23afd8e8d3767c02 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 21:35:29 -0700 Subject: [PATCH 11/21] Keep PublishNamespace handles alive in serve_subscribe_namespace The handles were being dropped at end of each loop iteration, causing PublishNamespaceDone to be sent immediately after PublishNamespace. --- moq-relay-ietf/src/producer.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 115061a9..3ef4e657 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -220,6 +220,8 @@ impl Producer { ); // Send PUBLISH_NAMESPACE for existing namespaces + // Keep handles alive to prevent PublishNamespaceDone from being sent + let mut publish_ns_handles = Vec::new(); for namespace in matching_namespaces { log::info!( "sending PUBLISH_NAMESPACE for {:?} (matched prefix {:?})", @@ -227,9 +229,9 @@ impl Producer { namespace_prefix ); match self.publisher.publish_namespace(namespace.clone()).await { - Ok(_publish_ns) => { + Ok(publish_ns) => { log::debug!("sent PUBLISH_NAMESPACE for {:?}", namespace); - // Note: publish_ns is kept alive to maintain the announcement + publish_ns_handles.push(publish_ns); } Err(e) => { log::warn!( @@ -240,6 +242,7 @@ impl Producer { } } } + let _publish_ns_handles = publish_ns_handles; // If we have a publish receiver, listen for new PUBLISH and PUBLISH_NAMESPACE notifications if publish_rx.is_some() || publish_ns_rx.is_some() { From a29815ebd04755553eec6771bfc4a09aab60b453 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 21:39:55 -0700 Subject: [PATCH 12/21] Fix SUBSCRIBE_NAMESPACE to send NAMESPACE instead of PUBLISH SUBSCRIBE_NAMESPACE is for namespace discovery - the relay should send NAMESPACE messages to notify about available tracks, not send PUBLISH and stream data. The subscriber explicitly SUBSCRIBEs to tracks it wants. 
--- moq-relay-ietf/src/producer.rs | 67 ++++++++-------------------------- 1 file changed, 16 insertions(+), 51 deletions(-) diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 3ef4e657..97def379 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -253,7 +253,8 @@ impl Producer { result?; break; } - // Wait for PUBLISH notifications + // Wait for PUBLISH notifications -> send NAMESPACE message to notify subscriber + // The subscriber will then explicitly SUBSCRIBE to tracks it wants notification = async { if let Some(ref mut rx) = publish_rx { rx.recv().await @@ -270,56 +271,20 @@ impl Producer { namespace_prefix ); - // Get the TrackReader for this track so we can stream data - if let Some(track_info) = self.locals.get_track_info( - &publish_notif.namespace, - &publish_notif.track_name, - ) { - let track_reader = track_info.get_reader(); - - // Use publisher.publish() which sends PUBLISH with forward=1 - // This allows forwarding objects immediately - let mut publisher = self.publisher.clone(); - let ns = publish_notif.namespace.clone(); - let name = publish_notif.track_name.clone(); - tokio::spawn(async move { - match publisher.publish(track_reader.clone()).await { - Ok(published) => { - log::info!( - "forwarded PUBLISH for {}/{} with forward=1, streaming immediately", - ns, name - ); - // serve_immediately() starts streaming without waiting for PUBLISH_OK - // Since forward=1, subscriber expects data immediately - // If subscriber sends error, serve will end and we cleanup - match published.serve_immediately(track_reader).await { - Ok(()) => { - log::info!("track {}/{} serving completed", ns, name); - } - Err(e) => { - log::warn!( - "track {}/{} serving ended: {}", - ns, name, e - ); - // Cleanup handled by Published drop - } - } - } - Err(e) => { - log::warn!( - "failed to publish track {}/{}: {}", - ns, name, e - ); - } - } - }); - } else { - log::warn!( - "no track info found for {}/{}, cannot 
forward PUBLISH", - publish_notif.namespace, - publish_notif.track_name - ); - } + // Send NAMESPACE message to notify subscriber about the new track's namespace + // Subscriber will explicitly SUBSCRIBE to tracks it wants + let namespace_msg = message::Namespace { + id: subscribe_ns.info.request_id, + track_namespace: publish_notif.namespace.clone(), + params: KeyValuePairs::new(), + }; + self.publisher.forward_namespace(namespace_msg); + log::debug!( + "sent NAMESPACE for {:?} (track {}) on request_id={}", + publish_notif.namespace, + publish_notif.track_name, + subscribe_ns.info.request_id + ); } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { log::warn!("subscription lagged by {} messages", n); From 43b566539f5b241c21620cf6e3deeeffb354a9a4 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 21:41:42 -0700 Subject: [PATCH 13/21] Fix SUBSCRIBE_NAMESPACE to wait for PUBLISH_OK before streaming Flow: relay sends PUBLISH -> client sends PUBLISH_OK -> relay streams data. Changed from serve_immediately() to serve() which waits for PUBLISH_OK. 
--- moq-relay-ietf/src/producer.rs | 64 +++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 97def379..930c2beb 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -253,8 +253,8 @@ impl Producer { result?; break; } - // Wait for PUBLISH notifications -> send NAMESPACE message to notify subscriber - // The subscriber will then explicitly SUBSCRIBE to tracks it wants + // Wait for PUBLISH notifications -> forward PUBLISH to subscriber + // Subscriber sends PUBLISH_OK, then relay starts streaming data notification = async { if let Some(ref mut rx) = publish_rx { rx.recv().await @@ -271,20 +271,52 @@ impl Producer { namespace_prefix ); - // Send NAMESPACE message to notify subscriber about the new track's namespace - // Subscriber will explicitly SUBSCRIBE to tracks it wants - let namespace_msg = message::Namespace { - id: subscribe_ns.info.request_id, - track_namespace: publish_notif.namespace.clone(), - params: KeyValuePairs::new(), - }; - self.publisher.forward_namespace(namespace_msg); - log::debug!( - "sent NAMESPACE for {:?} (track {}) on request_id={}", - publish_notif.namespace, - publish_notif.track_name, - subscribe_ns.info.request_id - ); + // Get the TrackReader for this track so we can stream data + if let Some(track_info) = self.locals.get_track_info( + &publish_notif.namespace, + &publish_notif.track_name, + ) { + let track_reader = track_info.get_reader(); + + // Send PUBLISH and wait for PUBLISH_OK before streaming + let mut publisher = self.publisher.clone(); + let ns = publish_notif.namespace.clone(); + let name = publish_notif.track_name.clone(); + tokio::spawn(async move { + match publisher.publish(track_reader.clone()).await { + Ok(published) => { + log::info!( + "sent PUBLISH for {}/{}, waiting for PUBLISH_OK", + ns, name + ); + // serve() waits for PUBLISH_OK before streaming + match 
published.serve(track_reader).await { + Ok(()) => { + log::info!("track {}/{} serving completed", ns, name); + } + Err(e) => { + log::warn!( + "track {}/{} serving ended: {}", + ns, name, e + ); + } + } + } + Err(e) => { + log::warn!( + "failed to send PUBLISH for {}/{}: {}", + ns, name, e + ); + } + } + }); + } else { + log::warn!( + "no track info found for {}/{}, cannot forward PUBLISH", + publish_notif.namespace, + publish_notif.track_name + ); + } } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { log::warn!("subscription lagged by {} messages", n); From 4dcaa7add9b98386b236b404ff215ff3b4d8667f Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 21:52:03 -0700 Subject: [PATCH 14/21] Add self-exclusion to SUBSCRIBE_NAMESPACE flow When a client sends SUBSCRIBE_NAMESPACE and also PUBLISH on the same session, the relay should not forward the client's own PUBLISH back to them. Added session_id tracking to Consumer and Producer, used when registering subscriptions and notifying of PUBLISH events. Notifications skip subscriptions from the same session that originated the PUBLISH. 
--- .idea/workspace.xml | 237 ++++++++++++++++++++++ moq-relay-ietf/src/consumer.rs | 10 +- moq-relay-ietf/src/producer.rs | 7 +- moq-relay-ietf/src/relay.rs | 10 + moq-relay-ietf/src/subscriber_registry.rs | 62 +++++- 5 files changed, 316 insertions(+), 10 deletions(-) create mode 100644 .idea/workspace.xml diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 00000000..359cff88 --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,237 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + { + "associatedIndex": 6 +} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1772062719090 + + + + + + \ No newline at end of file diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index 5d0ec8b6..feb3655e 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -19,6 +19,7 @@ pub struct Consumer { coordinator: Arc, forward: Option, // Forward all announcements to this subscriber subscriber_registry: Option, + session_id: u64, } impl Consumer { @@ -34,6 +35,7 @@ impl Consumer { coordinator, forward, subscriber_registry: None, + session_id: 0, } } @@ -44,6 +46,7 @@ impl Consumer { coordinator: Arc, forward: Option, subscriber_registry: SubscriberRegistry, + session_id: u64, ) -> Self { Self { subscriber, @@ -51,6 +54,7 @@ impl Consumer { coordinator, forward, subscriber_registry: Some(subscriber_registry), + session_id, } } @@ -121,8 +125,9 @@ impl Consumer { // Notify subscriber registry of the new PUBLISH_NAMESPACE // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + // Uses session_id for self-exclusion if let Some(ref registry) = self.subscriber_registry { - let notified = registry.notify_publish_namespace(&publish_ns.namespace); + let notified = registry.notify_publish_namespace(&publish_ns.namespace, self.session_id); if notified > 0 { log::info!( "notified {} SUBSCRIBE_NAMESPACE subscriptions of 
PUBLISH_NAMESPACE {:?}", @@ -261,8 +266,9 @@ impl Consumer { // Notify subscriber registry of the new PUBLISH // This will trigger forwarding to matching SUBSCRIBE_NAMESPACE subscriptions + // Uses session_id for self-exclusion (don't notify the same session that sent the PUBLISH) if let Some(ref registry) = self.subscriber_registry { - let notified = registry.notify_publish(&namespace, &track_name, track_alias); + let notified = registry.notify_publish(&namespace, &track_name, track_alias, self.session_id); if notified > 0 { log::info!( "notified {} SUBSCRIBE_NAMESPACE subscriptions of PUBLISH {}/{}", diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 930c2beb..8ebd48af 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -18,6 +18,7 @@ pub struct Producer { locals: Locals, remotes: Option, subscriber_registry: Option, + session_id: u64, } impl Producer { @@ -27,6 +28,7 @@ impl Producer { locals, remotes, subscriber_registry: None, + session_id: 0, } } @@ -36,12 +38,14 @@ impl Producer { locals: Locals, remotes: Option, subscriber_registry: SubscriberRegistry, + session_id: u64, ) -> Self { Self { publisher, locals, remotes, subscriber_registry: Some(subscriber_registry), + session_id, } } @@ -191,9 +195,10 @@ impl Producer { let namespace_prefix = subscribe_ns.namespace_prefix.clone(); // Register with subscriber registry to receive PUBLISH and PUBLISH_NAMESPACE notifications + // Uses session_id so we can exclude PUBLISH messages from the same session (self-exclusion) let (_subscription_guard, mut publish_rx, mut publish_ns_rx) = if let Some(ref registry) = self.subscriber_registry { - let (id, rx, rx_ns) = registry.register(namespace_prefix.clone()); + let (id, rx, rx_ns) = registry.register(namespace_prefix.clone(), self.session_id); ( Some(crate::SubscriptionGuard::new(registry.clone(), id)), Some(rx), diff --git a/moq-relay-ietf/src/relay.rs b/moq-relay-ietf/src/relay.rs index f4756ea7..8657e1ab 
100644 --- a/moq-relay-ietf/src/relay.rs +++ b/moq-relay-ietf/src/relay.rs @@ -239,6 +239,14 @@ impl Relay { }; // Create our MoQ relay session + // Use connection_id hash as session_id for self-exclusion in pub/sub + use std::hash::{Hash, Hasher}; + let session_id = { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + connection_id.hash(&mut hasher); + hasher.finish() + }; + let moq_session = session; let session = Session { session: moq_session, @@ -248,6 +256,7 @@ impl Relay { locals.clone(), remotes, subscriber_registry.clone(), + session_id, ) }), consumer: subscriber.map(|subscriber| { @@ -257,6 +266,7 @@ impl Relay { coordinator, forward, subscriber_registry, + session_id, ) }), }; diff --git a/moq-relay-ietf/src/subscriber_registry.rs b/moq-relay-ietf/src/subscriber_registry.rs index a21119ba..52c8aec5 100644 --- a/moq-relay-ietf/src/subscriber_registry.rs +++ b/moq-relay-ietf/src/subscriber_registry.rs @@ -9,6 +9,8 @@ use tokio::sync::broadcast; pub struct NamespaceSubscription { /// The namespace prefix this subscription is for pub prefix: TrackNamespace, + /// Session ID of the subscriber (for self-exclusion) + pub session_id: u64, /// Channel to send PUBLISH notifications to this subscriber pub publish_tx: broadcast::Sender, /// Channel to send PUBLISH_NAMESPACE notifications to this subscriber @@ -60,6 +62,7 @@ impl SubscriberRegistry { pub fn register( &self, prefix: TrackNamespace, + session_id: u64, ) -> ( u64, broadcast::Receiver, @@ -76,13 +79,18 @@ impl SubscriberRegistry { let subscription = NamespaceSubscription { prefix, + session_id, publish_tx, publish_ns_tx, }; inner.subscriptions.insert(id, subscription); - log::debug!("registered namespace subscription id={}", id); + log::debug!( + "registered namespace subscription id={} session_id={}", + id, + session_id + ); (id, publish_rx, publish_ns_rx) } @@ -96,12 +104,14 @@ impl SubscriberRegistry { } /// Find all subscriptions that match a given namespace and notify them of a 
PUBLISH + /// Excludes the session that originated the PUBLISH (self-exclusion) /// Returns the number of matching subscriptions notified pub fn notify_publish( &self, namespace: &TrackNamespace, track_name: &str, track_alias: u64, + origin_session_id: u64, ) -> usize { let inner = self.inner.lock().unwrap(); @@ -114,6 +124,16 @@ impl SubscriberRegistry { let mut notified = 0; for (id, sub) in inner.subscriptions.iter() { + // Skip if this subscription belongs to the same session that sent the PUBLISH + if sub.session_id == origin_session_id { + log::debug!( + "skipping subscription id={} (same session {})", + id, + origin_session_id + ); + continue; + } + // Check if the namespace matches the subscription prefix // The subscription prefix should be a prefix of the namespace if Self::prefix_matches(&sub.prefix, namespace) { @@ -135,8 +155,9 @@ impl SubscriberRegistry { } /// Find all subscriptions that match a given namespace and notify them of a PUBLISH_NAMESPACE + /// Excludes the session that originated the PUBLISH_NAMESPACE (self-exclusion) /// Returns the number of matching subscriptions notified - pub fn notify_publish_namespace(&self, namespace: &TrackNamespace) -> usize { + pub fn notify_publish_namespace(&self, namespace: &TrackNamespace, origin_session_id: u64) -> usize { let inner = self.inner.lock().unwrap(); let notification = PublishNamespaceNotification { @@ -146,6 +167,16 @@ impl SubscriberRegistry { let mut notified = 0; for (id, sub) in inner.subscriptions.iter() { + // Skip if this subscription belongs to the same session that sent the PUBLISH_NAMESPACE + if sub.session_id == origin_session_id { + log::debug!( + "skipping subscription id={} for PUBLISH_NAMESPACE (same session {})", + id, + origin_session_id + ); + continue; + } + // Check if the namespace matches the subscription prefix if Self::prefix_matches(&sub.prefix, namespace) { if let Err(e) = sub.publish_ns_tx.send(notification.clone()) { @@ -245,8 +276,8 @@ mod tests { fn 
test_register_unregister() { let registry = SubscriberRegistry::new(); - let (id1, _rx1, _rx1_ns) = registry.register(ns("live")); - let (id2, _rx2, _rx2_ns) = registry.register(ns("live/room1")); + let (id1, _rx1, _rx1_ns) = registry.register(ns("live"), 100); + let (id2, _rx2, _rx2_ns) = registry.register(ns("live/room1"), 101); assert_eq!(registry.matching_subscriptions(&ns("live/room1/track")).len(), 2); @@ -263,15 +294,32 @@ mod tests { async fn test_notify_publish() { let registry = SubscriberRegistry::new(); - let (id, mut rx, _rx_ns) = registry.register(ns("live")); + // Register with session_id=100 + let (id, mut rx, _rx_ns) = registry.register(ns("live"), 100); - let notified = registry.notify_publish(&ns("live/stream1"), "video", 100); + // Notify from session 200 (different) - should be delivered + let notified = registry.notify_publish(&ns("live/stream1"), "video", 42, 200); assert_eq!(notified, 1); let notification = rx.recv().await.unwrap(); assert_eq!(notification.track_name, "video"); - assert_eq!(notification.track_alias, 100); + assert_eq!(notification.track_alias, 42); registry.unregister(id); } + + #[tokio::test] + async fn test_self_exclusion() { + let registry = SubscriberRegistry::new(); + + // Register with session_id=100 + let (_id, mut rx, _rx_ns) = registry.register(ns("live"), 100); + + // Notify from the same session (100) - should NOT be delivered + let notified = registry.notify_publish(&ns("live/stream1"), "video", 42, 100); + assert_eq!(notified, 0); + + // Verify nothing was received (use try_recv to avoid blocking) + assert!(rx.try_recv().is_err()); + } } From fbefe1d74126f0df7db099e968ded404be6b55ce Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 22:13:35 -0700 Subject: [PATCH 15/21] Send PUBLISH for existing tracks on SUBSCRIBE_NAMESPACE When handling SUBSCRIBE_NAMESPACE, in addition to PUBLISH_NAMESPACE for existing namespaces, also send PUBLISH for existing tracks that are already in Publishing state. 
This triggers the client's onMatch callback for track discovery (client expects PUBLISH not just ANNOUNCE/PUBLISH_NAMESPACE). --- moq-relay-ietf/src/local.rs | 29 +++++++++++++++++++++++++ moq-relay-ietf/src/producer.rs | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 0e4f0d13..82fcfe05 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -493,6 +493,35 @@ impl Locals { .cloned() .collect() } + + /// Get all tracks in namespaces matching a prefix that are in Publishing state. + /// Returns (namespace, track_name, TrackInfo) tuples. + pub fn matching_tracks(&self, prefix: &TrackNamespace) -> Vec<(TrackNamespace, String, Arc)> { + let lookup = self.lookup.lock().unwrap(); + + let mut result = Vec::new(); + + for (ns, entry) in lookup.iter() { + // Check if namespace matches prefix + if ns.fields.len() >= prefix.fields.len() + && prefix + .fields + .iter() + .zip(ns.fields.iter()) + .all(|(a, b)| a == b) + { + // Get all tracks in this namespace that are publishing + let tracks = entry.tracks.lock().unwrap(); + for (key, track_info) in tracks.iter() { + if track_info.is_publishing() { + result.push((ns.clone(), track_info.name.clone(), track_info.clone())); + } + } + } + } + + result + } } pub struct Registration { diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 8ebd48af..20336e4c 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -249,6 +249,45 @@ impl Producer { } let _publish_ns_handles = publish_ns_handles; + // Also send PUBLISH for existing tracks in matching namespaces + // This triggers the client's onMatch callback for track discovery + let matching_tracks = self.locals.matching_tracks(&namespace_prefix); + for (ns, track_name, track_info) in matching_tracks { + log::info!( + "sending PUBLISH for existing track {}/{} (matched prefix {:?})", + ns, + track_name, + 
namespace_prefix + ); + + let track_reader = track_info.get_reader(); + let mut publisher = self.publisher.clone(); + + tokio::spawn(async move { + match publisher.publish(track_reader.clone()).await { + Ok(published) => { + log::info!( + "sent PUBLISH for existing track {}/{}, waiting for PUBLISH_OK", + ns, + track_name + ); + // serve() waits for PUBLISH_OK before streaming + match published.serve(track_reader).await { + Ok(()) => { + log::info!("existing track {}/{} serving completed", ns, track_name); + } + Err(e) => { + log::warn!("existing track {}/{} serving ended: {}", ns, track_name, e); + } + } + } + Err(e) => { + log::warn!("failed to send PUBLISH for existing track {}/{}: {}", ns, track_name, e); + } + } + }); + } + // If we have a publish receiver, listen for new PUBLISH and PUBLISH_NAMESPACE notifications if publish_rx.is_some() || publish_ns_rx.is_some() { loop { From eddc7bc2d6d3c56f35b50449e1c9ad1c9a7fa9ef Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 22:30:46 -0700 Subject: [PATCH 16/21] Only send PUBLISH for tracks, not PUBLISH_NAMESPACE Client onMatch callback only fires on PUBLISH messages, not PUBLISH_NAMESPACE. Remove PUBLISH_NAMESPACE sending and only forward PUBLISH for existing tracks in matching namespaces. 
--- moq-relay-ietf/src/producer.rs | 68 +++++++--------------------------- 1 file changed, 14 insertions(+), 54 deletions(-) diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 20336e4c..24c84156 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -208,50 +208,24 @@ impl Producer { (None, None, None) }; - // Find existing namespaces that match the prefix - let matching_namespaces: Vec = self - .locals - .matching_namespaces(&namespace_prefix) - .into_iter() - .collect(); - // Accept the subscription (even if no current matches - publisher may arrive later) subscribe_ns.ok()?; log::info!( - "accepted SUBSCRIBE_NAMESPACE for prefix {:?}, {} existing matches", - namespace_prefix, - matching_namespaces.len() + "accepted SUBSCRIBE_NAMESPACE for prefix {:?}", + namespace_prefix ); - // Send PUBLISH_NAMESPACE for existing namespaces - // Keep handles alive to prevent PublishNamespaceDone from being sent - let mut publish_ns_handles = Vec::new(); - for namespace in matching_namespaces { - log::info!( - "sending PUBLISH_NAMESPACE for {:?} (matched prefix {:?})", - namespace, - namespace_prefix - ); - match self.publisher.publish_namespace(namespace.clone()).await { - Ok(publish_ns) => { - log::debug!("sent PUBLISH_NAMESPACE for {:?}", namespace); - publish_ns_handles.push(publish_ns); - } - Err(e) => { - log::warn!( - "failed to send PUBLISH_NAMESPACE for {:?}: {}", - namespace, - e - ); - } - } - } - let _publish_ns_handles = publish_ns_handles; - - // Also send PUBLISH for existing tracks in matching namespaces + // Send PUBLISH for existing tracks in matching namespaces // This triggers the client's onMatch callback for track discovery + // Note: We skip PUBLISH_NAMESPACE and send PUBLISH directly - client expects PUBLISH for tracks let matching_tracks = self.locals.matching_tracks(&namespace_prefix); + log::info!( + "found {} existing tracks matching prefix {:?}", + matching_tracks.len(), + namespace_prefix 
+ ); + for (ns, track_name, track_info) in matching_tracks { log::info!( "sending PUBLISH for existing track {}/{} (matched prefix {:?})", @@ -371,7 +345,8 @@ impl Producer { } } } - // Wait for PUBLISH_NAMESPACE notifications -> forward as NAMESPACE message + // PUBLISH_NAMESPACE notifications - we don't forward these as NAMESPACE messages + // Client expects PUBLISH for individual tracks, not namespace announcements notification = async { if let Some(ref mut rx) = publish_ns_rx { rx.recv().await @@ -381,24 +356,9 @@ impl Producer { } => { match notification { Ok(ns_notif) => { - log::info!( - "received PUBLISH_NAMESPACE notification for {:?} on subscription prefix {:?}", - ns_notif.namespace, - namespace_prefix - ); - // Forward NAMESPACE message to the subscriber (not PUBLISH_NAMESPACE) - // NAMESPACE (0x08) is the draft-16 message for announcing namespaces - // to SUBSCRIBE_NAMESPACE subscribers - let namespace_msg = message::Namespace { - id: subscribe_ns.info.request_id, - track_namespace: ns_notif.namespace.clone(), - params: KeyValuePairs::new(), - }; - self.publisher.forward_namespace(namespace_msg); log::debug!( - "forwarded NAMESPACE for {:?} (request_id={})", - ns_notif.namespace, - subscribe_ns.info.request_id + "ignoring PUBLISH_NAMESPACE notification for {:?} (client expects PUBLISH for tracks)", + ns_notif.namespace ); } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { From 7f955152cfa175ff0807eb56ab896e11b8f33852 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Mon, 20 Apr 2026 23:39:50 -0700 Subject: [PATCH 17/21] Forward track_extensions in PUBLISH messages Store track_extensions from incoming PUBLISH in TrackInfo and forward them when relaying PUBLISH to subscribers. This preserves end-to-end extension headers that clients require. 
--- moq-relay-ietf/src/consumer.rs | 9 ++++- moq-relay-ietf/src/local.rs | 14 +++++++ moq-relay-ietf/src/producer.rs | 15 ++++++-- moq-transport/src/session/publish_received.rs | 4 ++ moq-transport/src/session/publisher.rs | 37 +++++++++++++++++++ 5 files changed, 73 insertions(+), 6 deletions(-) diff --git a/moq-relay-ietf/src/consumer.rs b/moq-relay-ietf/src/consumer.rs index feb3655e..d0641740 100644 --- a/moq-relay-ietf/src/consumer.rs +++ b/moq-relay-ietf/src/consumer.rs @@ -197,12 +197,14 @@ impl Consumer { let track_alias = publish.info.track_alias; let initial_forward = publish.info.forward; let publish_request_id = publish.info.id; + let track_extensions = publish.info.track_extensions.clone(); log::info!( - "received PUBLISH for track: {}/{} (forward={})", + "received PUBLISH for track: {}/{} (forward={}, extensions={:?})", namespace, track_name, - initial_forward + initial_forward, + track_extensions ); // Use auto-register variant to support SUBSCRIBE_NAMESPACE flow @@ -246,6 +248,9 @@ impl Consumer { // Store publish info for forward state management track_info.set_publish_info(publish_request_id, initial_forward); + // Store track extensions for forwarding to subscribers + track_info.set_track_extensions(track_extensions); + // Include forward=1 in PUBLISH_OK to tell publisher to start sending immediately let mut params = KeyValuePairs::default(); params.set_intvalue(0x10, 1); // Forward = 1 diff --git a/moq-relay-ietf/src/local.rs b/moq-relay-ietf/src/local.rs index 82fcfe05..e56211b3 100644 --- a/moq-relay-ietf/src/local.rs +++ b/moq-relay-ietf/src/local.rs @@ -5,6 +5,7 @@ use std::sync::{Arc, Mutex, OnceLock}; use moq_transport::{ coding::TrackNamespace, + data::ExtensionHeaders, serve::{ServeError, Track, TrackReader, TrackWriter, TracksReader, TracksWriter}, }; use tokio::sync::watch; @@ -48,6 +49,8 @@ pub struct TrackInfo { forward_state: Mutex>>, /// Receiver side for forward state changes forward_receiver: Mutex>>, + /// Track extensions from 
the original PUBLISH message + track_extensions: Mutex>, } impl TrackInfo { @@ -63,6 +66,7 @@ impl TrackInfo { publish_request_id: Mutex::new(None), forward_state: Mutex::new(None), forward_receiver: Mutex::new(None), + track_extensions: Mutex::new(None), } } @@ -189,6 +193,16 @@ impl TrackInfo { self.forward_receiver.lock().unwrap().clone() } + /// Set track extensions from the original PUBLISH message + pub fn set_track_extensions(&self, extensions: ExtensionHeaders) { + *self.track_extensions.lock().unwrap() = Some(extensions); + } + + /// Get track extensions (cloned) + pub fn track_extensions(&self) -> Option { + self.track_extensions.lock().unwrap().clone() + } + pub fn take_writer_for_upstream(&self) -> Result { self.ensure_track_created(); diff --git a/moq-relay-ietf/src/producer.rs b/moq-relay-ietf/src/producer.rs index 24c84156..8149df46 100644 --- a/moq-relay-ietf/src/producer.rs +++ b/moq-relay-ietf/src/producer.rs @@ -227,18 +227,20 @@ impl Producer { ); for (ns, track_name, track_info) in matching_tracks { + let track_extensions = track_info.track_extensions().unwrap_or_default(); log::info!( - "sending PUBLISH for existing track {}/{} (matched prefix {:?})", + "sending PUBLISH for existing track {}/{} (matched prefix {:?}, extensions={:?})", ns, track_name, - namespace_prefix + namespace_prefix, + track_extensions ); let track_reader = track_info.get_reader(); let mut publisher = self.publisher.clone(); tokio::spawn(async move { - match publisher.publish(track_reader.clone()).await { + match publisher.publish_with_extensions(track_reader.clone(), track_extensions).await { Ok(published) => { log::info!( "sent PUBLISH for existing track {}/{}, waiting for PUBLISH_OK", @@ -295,13 +297,18 @@ impl Producer { &publish_notif.track_name, ) { let track_reader = track_info.get_reader(); + let track_extensions = track_info.track_extensions().unwrap_or_default(); // Send PUBLISH and wait for PUBLISH_OK before streaming let mut publisher = self.publisher.clone(); 
let ns = publish_notif.namespace.clone(); let name = publish_notif.track_name.clone(); + log::info!( + "forwarding PUBLISH for {}/{} with extensions {:?}", + ns, name, track_extensions + ); tokio::spawn(async move { - match publisher.publish(track_reader.clone()).await { + match publisher.publish_with_extensions(track_reader.clone(), track_extensions).await { Ok(published) => { log::info!( "sent PUBLISH for {}/{}, waiting for PUBLISH_OK", diff --git a/moq-transport/src/session/publish_received.rs b/moq-transport/src/session/publish_received.rs index 8612d91c..e0703aba 100644 --- a/moq-transport/src/session/publish_received.rs +++ b/moq-transport/src/session/publish_received.rs @@ -1,6 +1,7 @@ use std::ops; use crate::coding::{ReasonPhrase, TrackNamespace}; +use crate::data::ExtensionHeaders; use crate::serve::ServeError; use crate::watch::State; use crate::{data, message, serve}; @@ -15,6 +16,8 @@ pub struct PublishReceivedInfo { pub track_alias: u64, /// Forward parameter from PUBLISH (0x10): true = forward immediately, false = paused pub forward: bool, + /// Track extensions from the original PUBLISH message + pub track_extensions: ExtensionHeaders, } impl PublishReceivedInfo { @@ -33,6 +36,7 @@ impl PublishReceivedInfo { track_name: msg.track_name.clone(), track_alias: msg.track_alias, forward, + track_extensions: msg.track_extensions.clone(), } } } diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 6bc4e5c2..7303c979 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -247,6 +247,43 @@ impl Publisher { Ok(send) } + /// Publish a track with specific track extensions (for relay forwarding) + pub async fn publish_with_extensions( + &mut self, + track: serve::TrackReader, + track_extensions: crate::data::ExtensionHeaders, + ) -> Result { + let request_id = self.next_requestid.fetch_add(2, atomic::Ordering::Relaxed); + let track_alias = self + .next_track_alias + 
.fetch_add(1, atomic::Ordering::Relaxed); + + let mut params = KeyValuePairs::new(); + params.set_intvalue(ParameterType::GroupOrder.into(), GroupOrder::Ascending as u64); + params.set_intvalue(ParameterType::Forward.into(), 1); + if let Some(loc) = track.largest_location() { + let mut buf = bytes::BytesMut::new(); + use crate::coding::Encode; + loc.encode(&mut buf).ok(); + params.set_bytesvalue(ParameterType::LargestObject.into(), buf.to_vec()); + } + + let msg = message::Publish { + id: request_id, + track_namespace: track.namespace.clone(), + track_name: track.name.clone(), + track_alias, + params, + track_extensions, + }; + + let (send, recv) = Published::new(self.clone(), msg, self.mlog.clone()); + + self.publisheds.lock().unwrap().insert(request_id, recv); + + Ok(send) + } + pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { let res = match msg { message::Subscriber::Subscribe(msg) => self.recv_subscribe(msg), From 4e336758ee3c1980031b34c4511fb031d61bed66 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Tue, 21 Apr 2026 00:41:23 -0700 Subject: [PATCH 18/21] Fix stream header type mismatch when forwarding objects without extensions When objects have empty extension headers but the preserved header type has extensions enabled (e.g., SubgroupIdExt), convert to the non-Ext variant before writing the stream header. This prevents the subscriber from expecting extension_length fields that aren't present. Added StreamHeaderType::without_extensions() to convert Ext types to their non-Ext equivalents. 
--- moq-transport/src/data/header.rs | 21 +++++++++++++++++++++ moq-transport/src/session/published.rs | 12 ++++++++++++ 2 files changed, 33 insertions(+) diff --git a/moq-transport/src/data/header.rs b/moq-transport/src/data/header.rs index 2f83f3c7..3991e759 100644 --- a/moq-transport/src/data/header.rs +++ b/moq-transport/src/data/header.rs @@ -105,6 +105,27 @@ impl StreamHeaderType { | StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority ) } + + /// Returns the equivalent header type without extensions. + /// Used when forwarding streams where objects have empty extension headers. + pub fn without_extensions(&self) -> Self { + match *self { + StreamHeaderType::SubgroupZeroIdExt => StreamHeaderType::SubgroupZeroId, + StreamHeaderType::SubgroupFirstObjectIdExt => StreamHeaderType::SubgroupFirstObjectId, + StreamHeaderType::SubgroupIdExt => StreamHeaderType::SubgroupId, + StreamHeaderType::SubgroupZeroIdExtEndOfGroup => StreamHeaderType::SubgroupZeroIdEndOfGroup, + StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroup => StreamHeaderType::SubgroupFirstObjectIdEndOfGroup, + StreamHeaderType::SubgroupIdExtEndOfGroup => StreamHeaderType::SubgroupIdEndOfGroup, + StreamHeaderType::SubgroupZeroIdExtNoPriority => StreamHeaderType::SubgroupZeroIdNoPriority, + StreamHeaderType::SubgroupFirstObjectIdExtNoPriority => StreamHeaderType::SubgroupFirstObjectIdNoPriority, + StreamHeaderType::SubgroupIdExtNoPriority => StreamHeaderType::SubgroupIdNoPriority, + StreamHeaderType::SubgroupZeroIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupZeroIdEndOfGroupNoPriority, + StreamHeaderType::SubgroupFirstObjectIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupFirstObjectIdEndOfGroupNoPriority, + StreamHeaderType::SubgroupIdExtEndOfGroupNoPriority => StreamHeaderType::SubgroupIdEndOfGroupNoPriority, + // Already non-Ext or Fetch + other => other, + } + } } impl Encode for StreamHeaderType { diff --git a/moq-transport/src/session/published.rs 
b/moq-transport/src/session/published.rs index 1d1c0f50..bf004def 100644 --- a/moq-transport/src/session/published.rs +++ b/moq-transport/src/session/published.rs @@ -297,6 +297,18 @@ impl Published { } }); + // If we're not writing extension headers but the preserved header type has extensions, + // convert to the non-Ext variant to avoid mismatch between header and object encoding + let header_type = if !has_extension_headers && header_type.has_extension_headers() { + log::debug!( + "[PUBLISHED] serve_subgroup: converting header_type {:?} to non-Ext variant (objects have no extensions)", + header_type + ); + header_type.without_extensions() + } else { + header_type + }; + // Set subgroup_id based on header type (ZeroId variants don't include it on wire) let subgroup_id = if header_type.has_subgroup_id() { Some(subgroup_reader.subgroup_id) From 0112f915976a05fbe3d52101c80ce07373aaa715 Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Tue, 21 Apr 2026 00:54:46 -0700 Subject: [PATCH 19/21] Fix datagram forwarding: use broadcast channel for proper queueing The datagram serve layer was only keeping the latest datagram, causing most datagrams to be dropped when they arrive faster than the reader can consume them (e.g., 50/sec audio). Changed to use tokio broadcast channel (1024 buffer) so datagrams are queued and forwarded in order. Broadcast allows cloning the reader. Logs warning if reader lags behind. Also fixed is_closed() which was incorrectly returning true when the buffer was empty (rx.len() == 0), potentially causing premature exit. 
--- moq-transport/src/serve/datagram.rs | 108 +++++++++++----------------- 1 file changed, 43 insertions(+), 65 deletions(-) diff --git a/moq-transport/src/serve/datagram.rs b/moq-transport/src/serve/datagram.rs index 4d62e8d2..c168c429 100644 --- a/moq-transport/src/serve/datagram.rs +++ b/moq-transport/src/serve/datagram.rs @@ -1,123 +1,101 @@ use std::{fmt, sync::Arc}; -use crate::watch::State; +use tokio::sync::broadcast; use super::{ServeError, Track}; +const DATAGRAM_CHANNEL_SIZE: usize = 1024; + pub struct Datagrams { pub track: Arc, } impl Datagrams { pub fn produce(self) -> (DatagramsWriter, DatagramsReader) { - let (writer, reader) = State::default().split(); + let (tx, rx) = broadcast::channel(DATAGRAM_CHANNEL_SIZE); - let writer = DatagramsWriter::new(writer, self.track.clone()); - let reader = DatagramsReader::new(reader, self.track); + let writer = DatagramsWriter::new(tx, self.track.clone()); + let reader = DatagramsReader::new(rx, self.track); (writer, reader) } } -struct DatagramsState { - // The latest datagram - latest: Option, - - // Increased each time datagram changes. - epoch: u64, - - // Set when the writer or all readers are dropped. 
- closed: Result<(), ServeError>, -} - -impl Default for DatagramsState { - fn default() -> Self { - Self { - latest: None, - epoch: 0, - closed: Ok(()), - } - } -} - pub struct DatagramsWriter { - state: State, + tx: broadcast::Sender, pub track: Arc, } impl DatagramsWriter { - fn new(state: State, track: Arc) -> Self { - Self { state, track } + fn new(tx: broadcast::Sender, track: Arc) -> Self { + Self { tx, track } } pub fn write(&mut self, datagram: Datagram) -> Result<(), ServeError> { - let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?; - - state.latest = Some(datagram); - state.epoch += 1; - + // Ignore send errors (no receivers) - datagrams are fire-and-forget + let _ = self.tx.send(datagram); Ok(()) } - pub fn close(self, err: ServeError) -> Result<(), ServeError> { - let state = self.state.lock(); - state.closed.clone()?; - - let mut state = state.into_mut().ok_or(ServeError::Cancel)?; - state.closed = Err(err); - + pub fn close(self, _err: ServeError) -> Result<(), ServeError> { + // Channel closes when tx is dropped Ok(()) } } -#[derive(Clone)] pub struct DatagramsReader { - state: State, + rx: broadcast::Receiver, pub track: Arc, + latest: Option<(u64, u64)>, +} - epoch: u64, +impl Clone for DatagramsReader { + fn clone(&self) -> Self { + Self { + rx: self.rx.resubscribe(), + track: self.track.clone(), + latest: self.latest, + } + } } impl DatagramsReader { - fn new(state: State, track: Arc) -> Self { + fn new(rx: broadcast::Receiver, track: Arc) -> Self { Self { - state, + rx, track, - epoch: 0, + latest: None, } } pub async fn read(&mut self) -> Result, ServeError> { loop { - { - let state = self.state.lock(); - if self.epoch < state.epoch { - self.epoch = state.epoch; - return Ok(state.latest.clone()); + match self.rx.recv().await { + Ok(datagram) => { + self.latest = Some((datagram.group_id, datagram.object_id)); + return Ok(Some(datagram)); } - - state.closed.clone()?; - match state.modified() { - Some(notify) => notify, - None => 
return Ok(None), // No more updates will come + Err(broadcast::error::RecvError::Lagged(n)) => { + log::warn!("[DATAGRAMS] reader lagged by {} datagrams", n); + // Continue reading - we'll get the next available datagram + } + Err(broadcast::error::RecvError::Closed) => { + return Ok(None); // Channel closed } } - .await; } } - // Returns the largest group/sequence pub fn latest(&self) -> Option<(u64, u64)> { - let state = self.state.lock(); - state - .latest - .as_ref() - .map(|datagram| (datagram.group_id, datagram.object_id)) + self.latest } pub fn is_closed(&self) -> bool { - let state = self.state.lock(); - state.closed.is_err() || state.modified().is_none() + // Broadcast receiver doesn't have a direct is_closed check. + // We return false (not closed) since we can't reliably detect sender drop + // without actually trying to receive. The read() method will return None + // when the channel is truly closed. + false } } From 1148fa19bf13d1548d439a8d67a38c317b64157b Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Tue, 21 Apr 2026 01:24:15 -0700 Subject: [PATCH 20/21] Fix datagram forwarding rate from 1/sec to 50/sec Use fast-path immediate lookup for datagram alias resolution instead of sequential 1-second timeout lookups. Only fall back to waiting on the first datagram before the alias mapping is established. 
--- moq-transport/src/serve/datagram.rs | 21 +++++--- moq-transport/src/session/subscriber.rs | 69 +++++++++++++++++++------ 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/moq-transport/src/serve/datagram.rs b/moq-transport/src/serve/datagram.rs index c168c429..1eb07e73 100644 --- a/moq-transport/src/serve/datagram.rs +++ b/moq-transport/src/serve/datagram.rs @@ -4,7 +4,7 @@ use tokio::sync::broadcast; use super::{ServeError, Track}; -const DATAGRAM_CHANNEL_SIZE: usize = 1024; +const DATAGRAM_CHANNEL_SIZE: usize = 4096; pub struct Datagrams { pub track: Arc, @@ -14,8 +14,10 @@ impl Datagrams { pub fn produce(self) -> (DatagramsWriter, DatagramsReader) { let (tx, rx) = broadcast::channel(DATAGRAM_CHANNEL_SIZE); + // Keep a reference to the sender in the reader so clones get fresh receivers + let tx_for_reader = tx.clone(); let writer = DatagramsWriter::new(tx, self.track.clone()); - let reader = DatagramsReader::new(rx, self.track); + let reader = DatagramsReader::new(rx, tx_for_reader, self.track); (writer, reader) } @@ -45,14 +47,18 @@ impl DatagramsWriter { pub struct DatagramsReader { rx: broadcast::Receiver, + tx: broadcast::Sender, pub track: Arc, latest: Option<(u64, u64)>, } impl Clone for DatagramsReader { fn clone(&self) -> Self { + // Subscribe to get a NEW receiver that will get all FUTURE datagrams + // This is correct for relay: each subscriber gets datagrams from now on Self { - rx: self.rx.resubscribe(), + rx: self.tx.subscribe(), + tx: self.tx.clone(), track: self.track.clone(), latest: self.latest, } @@ -60,9 +66,10 @@ impl Clone for DatagramsReader { } impl DatagramsReader { - fn new(rx: broadcast::Receiver, track: Arc) -> Self { + fn new(rx: broadcast::Receiver, tx: broadcast::Sender, track: Arc) -> Self { Self { rx, + tx, track, latest: None, } @@ -91,10 +98,8 @@ impl DatagramsReader { } pub fn is_closed(&self) -> bool { - // Broadcast receiver doesn't have a direct is_closed check. 
- // We return false (not closed) since we can't reliably detect sender drop - // without actually trying to receive. The read() method will return None - // when the channel is truly closed. + // Check if sender is gone (receiver_count would be 0 or send would fail) + // But we can't easily check this, so return false (conservative) false } } diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 11e185e8..7dd4b77c 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -940,10 +940,15 @@ impl Subscriber { } } - if let Some(subscribe_id) = self - .get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { + // Fast path: check both maps immediately WITHOUT waiting + // This allows datagrams to flow at full rate once alias mapping is established + let (subscribe_id_immediate, publish_id_immediate) = { + let subscribe_id = self.get_subscribe_id_by_alias(datagram.track_alias, None).await; + let publish_id = self.get_publish_id_by_alias(datagram.track_alias, None).await; + (subscribe_id, publish_id) + }; + + if let Some(subscribe_id) = subscribe_id_immediate { if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { log::trace!( "[SUBSCRIBER] recv_datagram: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", @@ -955,10 +960,7 @@ impl Subscriber { datagram.payload.as_ref().map_or(0, |p| p.len())); subscribe.datagram(datagram)?; } - } else if let Some(publish_id) = self - .get_publish_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)) - .await - { + } else if let Some(publish_id) = publish_id_immediate { if let Some(publish_recv) = self.publishes_received.lock().unwrap().get_mut(&publish_id) { log::trace!( @@ -972,14 +974,49 @@ impl Subscriber { publish_recv.datagram(datagram)?; } } else { - log::warn!( - "[SUBSCRIBER] recv_datagram: discarded due to 
unknown track_alias: track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", - datagram.track_alias, - datagram.group_id, - datagram.object_id.unwrap_or(0), - datagram.publisher_priority, - datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), - datagram.payload.as_ref().map_or(0, |p| p.len())); + // Slow path: alias not found immediately, wait with timeout (only for first datagram) + let subscribe_fut = self.get_subscribe_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + let publish_fut = self.get_publish_id_by_alias(datagram.track_alias, Some(DEFAULT_ALIAS_WAIT_TIME_MS)); + + tokio::select! { + Some(subscribe_id) = subscribe_fut => { + if let Some(subscribe) = self.subscribes.lock().unwrap().get_mut(&subscribe_id) { + log::trace!( + "[SUBSCRIBER] recv_datagram (waited): track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + subscribe.datagram(datagram)?; + } + } + Some(publish_id) = publish_fut => { + if let Some(publish_recv) = self.publishes_received.lock().unwrap().get_mut(&publish_id) + { + log::trace!( + "[SUBSCRIBER] recv_datagram from publish (waited): track_alias={}, group_id={}, object_id={}, publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + publish_recv.datagram(datagram)?; + } + } + else => { + log::warn!( + "[SUBSCRIBER] recv_datagram: discarded due to unknown track_alias: track_alias={}, group_id={}, object_id={}, 
publisher_priority={:?}, status={}, payload_length={}", + datagram.track_alias, + datagram.group_id, + datagram.object_id.unwrap_or(0), + datagram.publisher_priority, + datagram.status.as_ref().map_or("None".to_string(), |s| format!("{:?}", s)), + datagram.payload.as_ref().map_or(0, |p| p.len())); + } + } } Ok(()) From 5c0606d7fb60c8ca2b13448d3e1820f70d7c47fd Mon Sep 17 00:00:00 2001 From: Suhas Nandakumar Date: Tue, 21 Apr 2026 01:46:50 -0700 Subject: [PATCH 21/21] Fix object encoding to match header type in SUBSCRIBE flow subscribed.rs: Use conditional encoding based on header_type.has_extension_headers() - SubgroupObjectExt for Ext header types (includes extension_length field) - SubgroupObject for non-Ext header types (no extension_length field) subscriber.rs: Use fast-path immediate lookup for datagram alias resolution - Fixes datagram forwarding rate from 1/sec to 50/sec --- moq-transport/src/session/subscribed.rs | 93 +++++++++++++++---------- 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index 4f166d02..fc17c129 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -286,6 +286,7 @@ impl Subscribed { } } + let has_extension_headers = header.header_type.has_extension_headers(); let mut object_count = 0; while let Some(mut subgroup_object_reader) = subgroup_reader.next().await? 
{ if state.lock().is_closed() { @@ -298,45 +299,65 @@ impl Subscribed { return Ok(()); } - let subgroup_object = data::SubgroupObjectExt { - object_id_delta: 0, // before delta logic, used to be subgroup_object_reader.object_id, - extension_headers: subgroup_object_reader.extension_headers.clone(), // Pass through extension headers - payload_length: subgroup_object_reader.size, - status: if subgroup_object_reader.size == 0 { - // Only set status if payload length is zero - Some(subgroup_object_reader.status) - } else { - None - }, - }; - - log::debug!( - "[PUBLISHER] serve_subgroup: sending object #{} - object_id={}, object_id_delta={}, payload_length={}, status={:?}, extension_headers={:?}", - object_count + 1, - subgroup_object_reader.object_id, - subgroup_object.object_id_delta, - subgroup_object.payload_length, - subgroup_object.status, - subgroup_object.extension_headers - ); + // Encode object based on header type - must match what receiver expects + if has_extension_headers { + let subgroup_object = data::SubgroupObjectExt { + object_id_delta: 0, + extension_headers: subgroup_object_reader.extension_headers.clone(), + payload_length: subgroup_object_reader.size, + status: if subgroup_object_reader.size == 0 { + Some(subgroup_object_reader.status) + } else { + None + }, + }; - writer.encode(&subgroup_object).await?; + log::debug!( + "[PUBLISHER] serve_subgroup: sending object #{} (ext) - object_id={}, payload_length={}, status={:?}, extension_headers={:?}", + object_count + 1, + subgroup_object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status, + subgroup_object.extension_headers + ); - // Log subgroup object created/sent - if let Some(ref mlog) = mlog { - if let Ok(mut mlog_guard) = mlog.lock() { - let time = mlog_guard.elapsed_ms(); - let stream_id = 0; // TODO: Placeholder, need actual QUIC stream ID - let event = mlog::subgroup_object_ext_created( - time, - stream_id, - subgroup_reader.group_id, - subgroup_reader.subgroup_id, - 
subgroup_object_reader.object_id, - &subgroup_object, - ); - let _ = mlog_guard.add_event(event); + writer.encode(&subgroup_object).await?; + + if let Some(ref mlog) = mlog { + if let Ok(mut mlog_guard) = mlog.lock() { + let time = mlog_guard.elapsed_ms(); + let stream_id = 0; + let event = mlog::subgroup_object_ext_created( + time, + stream_id, + subgroup_reader.group_id, + subgroup_reader.subgroup_id, + subgroup_object_reader.object_id, + &subgroup_object, + ); + let _ = mlog_guard.add_event(event); + } } + } else { + let subgroup_object = data::SubgroupObject { + object_id_delta: 0, + payload_length: subgroup_object_reader.size, + status: if subgroup_object_reader.size == 0 { + Some(subgroup_object_reader.status) + } else { + None + }, + }; + + log::debug!( + "[PUBLISHER] serve_subgroup: sending object #{} - object_id={}, payload_length={}, status={:?}", + object_count + 1, + subgroup_object_reader.object_id, + subgroup_object.payload_length, + subgroup_object.status + ); + + writer.encode(&subgroup_object).await?; } state